├── .gitattributes ├── LICENSE ├── README.md ├── __init__.py ├── conf ├── glow_tts_baseline.yaml ├── glow_tts_std.yaml ├── glow_tts_stdp.yaml └── hifigan16k_ft.yaml ├── dataset ├── dev_tts_common_voice.json └── train_tts_common_voice.json ├── glow_tts_with_pitch.py ├── media └── glow_tts_stdp.png ├── models ├── __init__.py └── glow_tts_with_pitch.py ├── modules ├── glow_tts_modules │ ├── __init__.py │ ├── glow_tts_submodules.py │ ├── glow_tts_submodules_with_pitch.py │ ├── stocpred_modules.py │ └── transforms.py └── glow_tts_with_pitch.py ├── notebook └── inference_glowtts_stdp.ipynb ├── train_glowtts_baseline.sh ├── train_glowtts_std.sh ├── train_glowtts_stdp.sh └── utils ├── __init__.py ├── data.py ├── glow_tts_loss.py ├── helpers.py ├── tts_data_types.py └── tts_tokenizers.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Stochastic Pitch Prediction Improves the Diversity and Naturalness of Speech in Glow-TTS 3 | 4 | ### Sewade Ogun, Vincent Colotte, Emmanuel Vincent 5 | 6 | In our recent [paper](https://arxiv.org/abs/2305.17724), we proposed GlowTTS-STDP, a flow-based TTS model that improves the naturalness and diversity of generated utterances. 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
Glow-TTS-STDP at Inference (architecture diagram: `media/glow_tts_stdp.png`)
17 | 18 | 19 | ## 1. Requirements 20 | 21 | Our model is based on the GlowTTS architecture, which we implemented in the NeMo toolkit. Install the same version of NeMo used in our experiments to ensure that all dependencies work correctly. 22 | 23 | ```bash 24 | apt-get update && apt-get install -y libsndfile1 ffmpeg 25 | git clone https://github.com/NVIDIA/NeMo 26 | cd NeMo 27 | git checkout v1.5.0 28 | ./reinstall.sh 29 | ``` 30 | 31 | After installation, you should have: 32 | 1) NeMo toolkit (version 1.5.0), [https://github.com/NVIDIA/NeMo](https://github.com/NVIDIA/NeMo) 33 | 2) PyTorch 1.10.0 or above 34 | 3) PyTorch Lightning 35 | 36 | 37 | GPUs are required for model training. Note that we used mixed-precision training for all our experiments. 38 | 39 | PS: Check out the NeMo GitHub page if you have problems with the library installation. 40 | 41 | ## 2. Model setup 42 | 43 | Clone this GitHub repo after installing NeMo and checking out the correct branch. This repo contains: 44 | i. the model, 45 | ii. the dataset (without the audio files), 46 | iii. the training scripts, 47 | iv. the configuration files for all the experiments. 48 | 49 | ```bash 50 | git clone https://github.com/ogunlao/glowtts_stdp 51 | ``` 52 | 53 | ## 3. Pre-requisites 54 | 55 | a) Download and extract [the English subset of Common Voice Version 7.0](https://commonvoice.mozilla.org/en/datasets) into the `dataset` directory. Convert the files from mp3 to wav, and resample them to 16 kHz for faster data loading. The training and validation json files, which contain the [CommonVoice WV-MOS-4.0-all](https://arxiv.org/abs/2210.06370) subset, are provided in the `dataset` directory. 56 | 57 | b) A HiFi-GAN vocoder trained on 16 kHz multi-speaker speech is required. We trained a HiFi-GAN V1 on [LibriTTS](http://www.openslr.org/60). HiFi-GAN can be trained using the NeMo toolkit. 58 | The config file for HiFi-GAN is provided in `glowtts_stdp/conf/hifigan16k_ft.yaml`. 59 | 60 | c) A speaker embedding file is required, either as a pickle or a json file. We extract embedding vectors using the open-source library [resemblyzer](https://github.com/resemble-ai/Resemblyzer). A minimal extraction sketch is shown after the training examples below. 61 | 62 | Embeddings should be saved as a lookup table (dictionary) using the structure: 63 | 64 | ``` 65 | { 66 | audio1: [[embedding vector1]], 67 | audio2: [[embedding vector2]], 68 | } 69 | ``` 70 | 71 | Note that the keys are audio filenames without their extensions. The lookup table can be saved to disk as either a pickle or a json file. 72 | 73 | 74 | ## 4. Training Example 75 | 76 | To train the baseline GlowTTS model: 77 | ```sh 78 | cd glowtts_stdp 79 | sh train_glowtts_baseline.sh 80 | ``` 81 | 82 | To train the GlowTTS-STD model (with stochastic duration prediction): 83 | ```sh 84 | cd glowtts_stdp 85 | sh train_glowtts_std.sh 86 | ``` 87 | 88 | To train the GlowTTS-STDP model (with stochastic duration prediction and stochastic pitch prediction): 89 | ```sh 90 | cd glowtts_stdp 91 | sh train_glowtts_stdp.sh 92 | ``` 93 | 94 | NeMo uses [Hydra](https://hydra.cc/) for hyperparameter configuration, so hyperparameters can be changed either in the respective config files or in the training scripts. 95 |
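For example, assuming the Hydra entry point in `glow_tts_with_pitch.py`, the mandatory `???` config fields and trainer settings can be overridden directly on the command line; the paths below are placeholders, and the provided `train_glowtts_*.sh` scripts may use a different invocation:

```bash
python glow_tts_with_pitch.py --config-name=glow_tts_stdp \
    train_dataset=dataset/train_tts_common_voice.json \
    validation_datasets=dataset/dev_tts_common_voice.json \
    sup_data_path=<path_to_sup_data> \
    speaker_emb_path=<path_to_speaker_embeddings> \
    trainer.devices=1
```

The sketch below shows one possible way to build the speaker-embedding lookup table described in the pre-requisites with resemblyzer and save it as a pickle; the directory layout and output filename are assumptions, not fixed by this repo:

```python
# Hypothetical helper (not part of this repo): build the speaker-embedding lookup table.
import pickle
from pathlib import Path

from resemblyzer import VoiceEncoder, preprocess_wav

encoder = VoiceEncoder()
lookup = {}
for wav_path in Path("dataset").rglob("*.wav"):
    wav = preprocess_wav(wav_path)            # load, resample and normalize the audio
    embed = encoder.embed_utterance(wav)      # 256-dim speaker embedding (numpy array)
    lookup[wav_path.stem] = [embed.tolist()]  # key: filename without extension
with open("speaker_emb.pkl", "wb") as f:
    pickle.dump(lookup, f)
```

96 | ## 5. 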
Inference Example 97 | 98 | See [inference_glowtts_stdp.ipynb](notebook/inference_glowtts_stdp.ipynb) 99 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ogunlao/glowtts_stdp/07f71bbfce405018f43db7c086f46c6b91defa28/__init__.py -------------------------------------------------------------------------------- /conf/glow_tts_baseline.yaml: -------------------------------------------------------------------------------- 1 | name: "GlowTTS_baseline" 2 | gin_channels: 256 3 | use_stoch_dur_pred: false 4 | use_stoch_pitch_pred: false 5 | use_log_pitch: false 6 | sup_data_path: ??? 7 | sup_data_types: ["speaker_emb"] 8 | 9 | train_dataset: ??? 10 | validation_datasets: ??? 11 | test_datasets: null 12 | 13 | phoneme_dict_path: "../NeMo/scripts/tts_dataset_files/cmudict-0.7b_nv22.10" 14 | heteronyms_path: "../NeMo/scripts/tts_dataset_files/heteronyms-052722" 15 | whitelist_path: "../NeMo/nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv" 16 | 17 | speaker_emb_path: ??? 18 | 19 | # Default values from librosa.pyin 20 | pitch_fmin: null 21 | pitch_fmax: null 22 | pitch_mean: null 23 | pitch_std: null 24 | 25 | # Default values for dataset with sample_rate=22050 26 | sample_rate: 16000 27 | n_mel_channels: 80 28 | n_window_size: 1024 29 | n_window_stride: 256 30 | n_fft: 1024 31 | lowfreq: 0 32 | highfreq: 8000 33 | window: hann 34 | pad_value: 0.0 35 | 36 | model: 37 | n_speakers: 4469 38 | gin_channels: ${gin_channels} 39 | use_external_speaker_emb: true 40 | speaker_emb_path: ${speaker_emb_path} 41 | use_stoch_dur_pred: ${use_stoch_dur_pred} 42 | use_stoch_pitch_pred: ${use_stoch_pitch_pred} 43 | pitch_loss_scale: 0.0 44 | 45 | max_token_duration: 75 46 | symbols_embedding_dim: 256 #384 47 | pitch_embedding_kernel_size: 3 48 | 49 | pitch_fmin: ${pitch_fmin} 50 | pitch_fmax: ${pitch_fmax} 51 | 52 | pitch_mean: ${pitch_mean} 53 | pitch_std: ${pitch_std} 54 | 55 | sample_rate: ${sample_rate} 56 | n_mel_channels: ${n_mel_channels} 57 | n_window_size: ${n_window_size} 58 | n_window_stride: ${n_window_stride} 59 | n_fft: ${n_fft} 60 | lowfreq: ${lowfreq} 61 | highfreq: ${highfreq} 62 | window: ${window} 63 | pad_value: ${pad_value} 64 | 65 | # use_log_pitch: ${use_log_pitch} 66 | 67 | # text_normalizer: 68 | # _target_: nemo_text_processing.text_normalization.normalize.Normalizer 69 | # lang: en 70 | # input_case: cased 71 | # whitelist: ${whitelist_path} 72 | 73 | # text_normalizer_call_kwargs: 74 | # verbose: false 75 | # punct_pre_process: true 76 | # punct_post_process: true 77 | 78 | text_tokenizer: 79 | _target_: utils.tts_tokenizers.EnglishPhonemesTokenizer 80 | punct: true 81 | stresses: false # true 82 | chars: true 83 | apostrophe: true 84 | pad_with_space: false 85 | add_blank_at: "last" 86 | add_blank_to_text: true 87 | g2p: 88 | _target_: nemo_text_processing.g2p.modules.EnglishG2p 89 | phoneme_dict: ${phoneme_dict_path} 90 | heteronyms: ${heteronyms_path} 91 | phoneme_probability: 0.8 92 | 93 | train_ds: 94 | dataset: 95 | _target_: utils.data.TTSDataset 96 | manifest_filepath: ${train_dataset} 97 | sample_rate: ${model.sample_rate} 98 | sup_data_path: ${sup_data_path} 99 | sup_data_types: ${sup_data_types} 100 | n_fft: ${n_fft} 101 | win_length: ${model.n_fft} 102 | hop_length: ${model.n_window_stride} 103 | window: ${model.window} 104 | n_mels: ${model.n_mel_channels} 105 | lowfreq: ${model.lowfreq} 106 | 
highfreq: ${model.highfreq} 107 | max_duration: 16.7 108 | min_duration: 0.1 109 | ignore_file: null 110 | trim: false 111 | # pitch_fmin: ${model.pitch_fmin} 112 | # pitch_fmax: ${model.pitch_fmax} 113 | # pitch_norm: true 114 | # pitch_mean: ${model.pitch_mean} 115 | # pitch_std: ${model.pitch_std} 116 | # use_log_pitch: ${model.use_log_pitch} 117 | speaker_emb_path: ${speaker_emb_path} 118 | 119 | dataloader_params: 120 | drop_last: false 121 | shuffle: true 122 | batch_size: 32 123 | num_workers: 32 124 | pin_memory: true 125 | 126 | validation_ds: 127 | dataset: 128 | _target_: utils.data.TTSDataset 129 | manifest_filepath: ${validation_datasets} 130 | sample_rate: ${model.sample_rate} 131 | sup_data_path: ${sup_data_path} 132 | sup_data_types: ${sup_data_types} 133 | n_fft: ${model.n_fft} 134 | win_length: ${model.n_window_size} 135 | hop_length: ${model.n_window_stride} 136 | window: ${model.window} 137 | n_mels: ${model.n_mel_channels} 138 | lowfreq: ${model.lowfreq} 139 | highfreq: ${model.highfreq} 140 | max_duration: null 141 | min_duration: null 142 | ignore_file: null 143 | trim: false 144 | # pitch_fmin: ${model.pitch_fmin} 145 | # pitch_fmax: ${model.pitch_fmax} 146 | # pitch_norm: true 147 | # pitch_mean: ${model.pitch_mean} 148 | # pitch_std: ${model.pitch_std} 149 | # use_log_pitch: ${model.use_log_pitch} 150 | speaker_emb_path: ${speaker_emb_path} 151 | 152 | dataloader_params: 153 | drop_last: false 154 | shuffle: false 155 | batch_size: 2 156 | num_workers: 2 157 | pin_memory: true 158 | 159 | preprocessor: 160 | _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor 161 | features: ${model.n_mel_channels} 162 | lowfreq: ${model.lowfreq} 163 | highfreq: ${model.highfreq} 164 | n_fft: ${model.n_fft} 165 | n_window_size: ${model.n_window_size} 166 | window_size: false 167 | n_window_stride: ${model.n_window_stride} 168 | window_stride: false 169 | pad_to: 1 170 | pad_value: ${model.pad_value} 171 | sample_rate: ${model.sample_rate} 172 | window: ${model.window} 173 | normalize: null 174 | preemph: null 175 | dither: 0.0 176 | frame_splicing: 1 177 | log: true 178 | log_zero_guard_type: add 179 | log_zero_guard_value: 1e-05 180 | mag_power: 1.0 181 | 182 | encoder: 183 | _target_: modules.glow_tts_with_pitch.TextEncoder 184 | n_vocab: 148 185 | out_channels: ${model.n_mel_channels} 186 | hidden_channels: 192 187 | filter_channels: 768 188 | filter_channels_dp: 256 189 | kernel_size: 3 190 | p_dropout: 0.1 191 | n_layers: 6 192 | n_heads: 2 193 | window_size: 4 194 | prenet: true 195 | mean_only: true 196 | gin_channels: ${gin_channels} 197 | use_stoch_dur_pred: ${use_stoch_dur_pred} 198 | 199 | decoder: 200 | _target_: modules.glow_tts_with_pitch.FlowSpecDecoder 201 | in_channels: ${model.n_mel_channels} 202 | hidden_channels: 192 203 | kernel_size: 5 204 | n_blocks: 12 205 | n_layers: 4 206 | n_sqz: 2 207 | n_split: 4 208 | sigmoid_scale: false 209 | p_dropout: 0.05 210 | dilation_rate: 1 211 | gin_channels: ${gin_channels} 212 | 213 | optim: 214 | name: radam 215 | lr: 2e-4 216 | # optimizer arguments 217 | betas: [0.9, 0.98] 218 | weight_decay: 0.0 219 | 220 | # scheduler setup 221 | sched: 222 | name: CosineAnnealing 223 | 224 | # Scheduler params 225 | warmup_steps: 6000 226 | min_lr: 1e-5 227 | last_epoch: -1 228 | 229 | 230 | trainer: 231 | accelerator: auto 232 | devices: -1 #-1 # number of gpus 233 | strategy: ddp 234 | num_nodes: 1 235 | enable_checkpointing: false # Provided by exp_manager 236 | logger: false # Provided by exp_manager 237 | 
max_epochs: 1000 238 | max_steps: -1 # computed at runtime if not set 239 | accumulate_grad_batches: 2 240 | log_every_n_steps: 100 # Interval of logging. 241 | check_val_every_n_epoch: 2 242 | amp_backend: native 243 | precision: 16 # mixed-precision training 244 | gradient_clip_val: 5.0 245 | 246 | 247 | exp_manager: 248 | exp_dir: null 249 | name: ${name} 250 | resume_if_exists: False 251 | resume_ignore_no_checkpoint: True 252 | create_tensorboard_logger: True 253 | create_checkpoint_callback: True 254 | checkpoint_callback_params: 255 | always_save_nemo: True 256 | save_top_k: 1 257 | monitor: "val_loss" 258 | mode: "min" 259 | create_early_stopping_callback: False 260 | early_stopping_params: 261 | monitor: "val_loss" 262 | patience: 10 263 | verbose: True 264 | mode: "min" 265 | create_wandb_logger: False 266 | wandb_logger_kwargs: 267 | name: null 268 | project: null 269 | 270 | hydra: 271 | run: 272 | dir: . 273 | job_logging: 274 | root: 275 | handlers: null -------------------------------------------------------------------------------- /conf/glow_tts_std.yaml: -------------------------------------------------------------------------------- 1 | name: "GlowTTS_std" 2 | gin_channels: 256 3 | use_stoch_dur_pred: true 4 | use_stoch_pitch_pred: false 5 | use_log_pitch: false 6 | sup_data_path: ??? 7 | sup_data_types: ["speaker_emb"] 8 | 9 | train_dataset: ??? 10 | validation_datasets: ??? 11 | test_datasets: null 12 | 13 | phoneme_dict_path: "../NeMo/scripts/tts_dataset_files/cmudict-0.7b_nv22.10" 14 | heteronyms_path: "../NeMo/scripts/tts_dataset_files/heteronyms-052722" 15 | whitelist_path: "../NeMo/nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv" 16 | 17 | speaker_emb_path: ??? 18 | 19 | # Default values from librosa.pyin 20 | pitch_fmin: null 21 | pitch_fmax: null 22 | pitch_mean: null 23 | pitch_std: null 24 | 25 | # Default values for dataset with sample_rate=22050 26 | sample_rate: 16000 27 | n_mel_channels: 80 28 | n_window_size: 1024 29 | n_window_stride: 256 30 | n_fft: 1024 31 | lowfreq: 0 32 | highfreq: 8000 33 | window: hann 34 | pad_value: 0.0 35 | 36 | model: 37 | n_speakers: 4469 38 | gin_channels: ${gin_channels} 39 | use_external_speaker_emb: true 40 | speaker_emb_path: ${speaker_emb_path} 41 | use_stoch_dur_pred: ${use_stoch_dur_pred} 42 | use_stoch_pitch_pred: ${use_stoch_pitch_pred} 43 | pitch_loss_scale: 0.0 44 | 45 | max_token_duration: 75 46 | symbols_embedding_dim: 256 #384 47 | pitch_embedding_kernel_size: 3 48 | 49 | pitch_fmin: ${pitch_fmin} 50 | pitch_fmax: ${pitch_fmax} 51 | 52 | pitch_mean: ${pitch_mean} 53 | pitch_std: ${pitch_std} 54 | 55 | sample_rate: ${sample_rate} 56 | n_mel_channels: ${n_mel_channels} 57 | n_window_size: ${n_window_size} 58 | n_window_stride: ${n_window_stride} 59 | n_fft: ${n_fft} 60 | lowfreq: ${lowfreq} 61 | highfreq: ${highfreq} 62 | window: ${window} 63 | pad_value: ${pad_value} 64 | 65 | # use_log_pitch: ${use_log_pitch} 66 | 67 | # text_normalizer: 68 | # _target_: nemo_text_processing.text_normalization.normalize.Normalizer 69 | # lang: en 70 | # input_case: cased 71 | # whitelist: ${whitelist_path} 72 | 73 | # text_normalizer_call_kwargs: 74 | # verbose: false 75 | # punct_pre_process: true 76 | # punct_post_process: true 77 | 78 | text_tokenizer: 79 | _target_: utils.tts_tokenizers.EnglishPhonemesTokenizer 80 | punct: true 81 | stresses: false # true 82 | chars: true 83 | apostrophe: true 84 | pad_with_space: false 85 | add_blank_at: "last" 86 | add_blank_to_text: true 87 | g2p: 88 | _target_: 
nemo_text_processing.g2p.modules.EnglishG2p 89 | phoneme_dict: ${phoneme_dict_path} 90 | heteronyms: ${heteronyms_path} 91 | phoneme_probability: 0.8 92 | 93 | train_ds: 94 | dataset: 95 | _target_: utils.data.TTSDataset 96 | manifest_filepath: ${train_dataset} 97 | sample_rate: ${model.sample_rate} 98 | sup_data_path: ${sup_data_path} 99 | sup_data_types: ${sup_data_types} 100 | n_fft: ${n_fft} 101 | win_length: ${model.n_fft} 102 | hop_length: ${model.n_window_stride} 103 | window: ${model.window} 104 | n_mels: ${model.n_mel_channels} 105 | lowfreq: ${model.lowfreq} 106 | highfreq: ${model.highfreq} 107 | max_duration: 16.7 108 | min_duration: 0.1 109 | ignore_file: null 110 | trim: false 111 | # pitch_fmin: ${model.pitch_fmin} 112 | # pitch_fmax: ${model.pitch_fmax} 113 | # pitch_norm: true 114 | # pitch_mean: ${model.pitch_mean} 115 | # pitch_std: ${model.pitch_std} 116 | # use_log_pitch: ${model.use_log_pitch} 117 | speaker_emb_path: ${speaker_emb_path} 118 | 119 | dataloader_params: 120 | drop_last: false 121 | shuffle: true 122 | batch_size: 32 123 | num_workers: 32 124 | pin_memory: true 125 | 126 | validation_ds: 127 | dataset: 128 | _target_: utils.data.TTSDataset 129 | manifest_filepath: ${validation_datasets} 130 | sample_rate: ${model.sample_rate} 131 | sup_data_path: ${sup_data_path} 132 | sup_data_types: ${sup_data_types} 133 | n_fft: ${model.n_fft} 134 | win_length: ${model.n_window_size} 135 | hop_length: ${model.n_window_stride} 136 | window: ${model.window} 137 | n_mels: ${model.n_mel_channels} 138 | lowfreq: ${model.lowfreq} 139 | highfreq: ${model.highfreq} 140 | max_duration: null 141 | min_duration: null 142 | ignore_file: null 143 | trim: false 144 | # pitch_fmin: ${model.pitch_fmin} 145 | # pitch_fmax: ${model.pitch_fmax} 146 | # pitch_norm: true 147 | # pitch_mean: ${model.pitch_mean} 148 | # pitch_std: ${model.pitch_std} 149 | # use_log_pitch: ${model.use_log_pitch} 150 | speaker_emb_path: ${speaker_emb_path} 151 | 152 | dataloader_params: 153 | drop_last: false 154 | shuffle: false 155 | batch_size: 2 156 | num_workers: 2 157 | pin_memory: true 158 | 159 | preprocessor: 160 | _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor 161 | features: ${model.n_mel_channels} 162 | lowfreq: ${model.lowfreq} 163 | highfreq: ${model.highfreq} 164 | n_fft: ${model.n_fft} 165 | n_window_size: ${model.n_window_size} 166 | window_size: false 167 | n_window_stride: ${model.n_window_stride} 168 | window_stride: false 169 | pad_to: 1 170 | pad_value: ${model.pad_value} 171 | sample_rate: ${model.sample_rate} 172 | window: ${model.window} 173 | normalize: null 174 | preemph: null 175 | dither: 0.0 176 | frame_splicing: 1 177 | log: true 178 | log_zero_guard_type: add 179 | log_zero_guard_value: 1e-05 180 | mag_power: 1.0 181 | 182 | encoder: 183 | _target_: modules.glow_tts_with_pitch.TextEncoder 184 | n_vocab: 148 185 | out_channels: ${model.n_mel_channels} 186 | hidden_channels: 192 187 | filter_channels: 768 188 | filter_channels_dp: 256 189 | kernel_size: 3 190 | p_dropout: 0.1 191 | n_layers: 6 192 | n_heads: 2 193 | window_size: 4 194 | prenet: true 195 | mean_only: true 196 | gin_channels: ${gin_channels} 197 | use_stoch_dur_pred: ${use_stoch_dur_pred} 198 | 199 | decoder: 200 | _target_: modules.glow_tts_with_pitch.FlowSpecDecoder 201 | in_channels: ${model.n_mel_channels} 202 | hidden_channels: 192 203 | kernel_size: 5 204 | n_blocks: 12 205 | n_layers: 4 206 | n_sqz: 2 207 | n_split: 4 208 | sigmoid_scale: false 209 | p_dropout: 0.05 210 | 
dilation_rate: 1 211 | gin_channels: ${gin_channels} 212 | 213 | optim: 214 | name: radam 215 | lr: 2e-4 216 | # optimizer arguments 217 | betas: [0.9, 0.98] 218 | weight_decay: 0.0 219 | 220 | # scheduler setup 221 | sched: 222 | name: CosineAnnealing 223 | 224 | # Scheduler params 225 | warmup_steps: 6000 226 | min_lr: 1e-5 227 | last_epoch: -1 228 | 229 | 230 | trainer: 231 | accelerator: auto 232 | devices: -1 #-1 # number of gpus 233 | strategy: ddp 234 | num_nodes: 1 235 | enable_checkpointing: false # Provided by exp_manager 236 | logger: false # Provided by exp_manager 237 | max_epochs: 1000 238 | max_steps: -1 # computed at runtime if not set 239 | accumulate_grad_batches: 2 240 | log_every_n_steps: 100 # Interval of logging. 241 | check_val_every_n_epoch: 4 242 | amp_backend: native 243 | precision: 16 # mixed-precision training 244 | gradient_clip_val: 5.0 245 | 246 | exp_manager: 247 | exp_dir: null 248 | name: ${name} 249 | resume_if_exists: False 250 | resume_ignore_no_checkpoint: True 251 | create_tensorboard_logger: True 252 | create_checkpoint_callback: True 253 | checkpoint_callback_params: 254 | always_save_nemo: True 255 | save_top_k: 1 256 | monitor: "val_loss" 257 | mode: "min" 258 | create_early_stopping_callback: False 259 | early_stopping_params: 260 | monitor: "val_loss" 261 | patience: 10 262 | verbose: True 263 | mode: "min" 264 | create_wandb_logger: False 265 | wandb_logger_kwargs: 266 | name: null 267 | project: null 268 | 269 | hydra: 270 | run: 271 | dir: . 272 | job_logging: 273 | root: 274 | handlers: null -------------------------------------------------------------------------------- /conf/glow_tts_stdp.yaml: -------------------------------------------------------------------------------- 1 | name: "GlowTTS_stdp" 2 | gin_channels: 256 3 | use_stoch_dur_pred: true 4 | use_stoch_pitch_pred: true 5 | sup_data_path: ??? 6 | sup_data_types: ["pitch", "speaker_emb"] 7 | 8 | use_log_pitch: true 9 | use_normalized_pitch: false 10 | unvoiced_value: null 11 | use_frame_emb_for_pitch: true 12 | 13 | 14 | train_dataset: ??? 15 | validation_datasets: ??? 16 | test_datasets: null 17 | 18 | phoneme_dict_path: "../NeMo/scripts/tts_dataset_files/cmudict-0.7b_nv22.10" 19 | heteronyms_path: "../NeMo/scripts/tts_dataset_files/heteronyms-052722" 20 | whitelist_path: "../NeMo/nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv" 21 | 22 | speaker_emb_path: ??? 
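# NOTE: the ??? values in this config (sup_data_path, train_dataset, validation_datasets,
# speaker_emb_path) are mandatory OmegaConf fields; supply them via Hydra command-line
# overrides or the training scripts before launching training.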
23 | 24 | # Default values from librosa.pyin 25 | pitch_fmin: null 26 | pitch_fmax: null 27 | pitch_mean: null 28 | pitch_std: null 29 | pitch_norm: false 30 | 31 | # Default values for dataset with sample_rate=22050 32 | sample_rate: 16000 33 | n_mel_channels: 80 34 | n_window_size: 1024 35 | n_window_stride: 256 36 | n_fft: 1024 37 | lowfreq: 0 38 | highfreq: 8000 39 | window: hann 40 | pad_value: 0.0 41 | 42 | model: 43 | n_speakers: 4469 44 | gin_channels: ${gin_channels} 45 | use_external_speaker_emb: true 46 | speaker_emb_path: ${speaker_emb_path} 47 | use_stoch_dur_pred: ${use_stoch_dur_pred} 48 | use_stoch_pitch_pred: ${use_stoch_pitch_pred} 49 | pitch_loss_scale: 0.5 50 | 51 | unvoiced_value: ${unvoiced_value} 52 | use_log_pitch: ${use_log_pitch} 53 | use_normalized_pitch: ${use_normalized_pitch} 54 | use_frame_emb_for_pitch: ${use_frame_emb_for_pitch} 55 | 56 | max_token_duration: 75 57 | symbols_embedding_dim: 256 #384 58 | pitch_embedding_kernel_size: 3 59 | 60 | pitch_fmin: ${pitch_fmin} 61 | pitch_fmax: ${pitch_fmax} 62 | 63 | pitch_mean: ${pitch_mean} 64 | pitch_std: ${pitch_std} 65 | pitch_norm: ${pitch_norm} 66 | 67 | sample_rate: ${sample_rate} 68 | n_mel_channels: ${n_mel_channels} 69 | n_window_size: ${n_window_size} 70 | n_window_stride: ${n_window_stride} 71 | n_fft: ${n_fft} 72 | lowfreq: ${lowfreq} 73 | highfreq: ${highfreq} 74 | window: ${window} 75 | pad_value: ${pad_value} 76 | 77 | # text_normalizer: 78 | # _target_: nemo_text_processing.text_normalization.normalize.Normalizer 79 | # lang: en 80 | # input_case: cased 81 | # whitelist: ${whitelist_path} 82 | 83 | # text_normalizer_call_kwargs: 84 | # verbose: false 85 | # punct_pre_process: true 86 | # punct_post_process: true 87 | 88 | text_tokenizer: 89 | _target_: utils.tts_tokenizers.EnglishPhonemesTokenizer 90 | punct: true 91 | stresses: false #true 92 | chars: true 93 | apostrophe: true 94 | pad_with_space: false 95 | add_blank_at: "last" 96 | add_blank_to_text: true 97 | g2p: 98 | _target_: nemo_text_processing.g2p.modules.EnglishG2p 99 | phoneme_dict: ${phoneme_dict_path} 100 | heteronyms: ${heteronyms_path} 101 | phoneme_probability: 0.8 102 | 103 | train_ds: 104 | dataset: 105 | _target_: utils.data.TTSDataset 106 | manifest_filepath: ${train_dataset} 107 | sample_rate: ${model.sample_rate} 108 | sup_data_path: ${sup_data_path} 109 | sup_data_types: ${sup_data_types} 110 | n_fft: ${n_fft} 111 | win_length: ${model.n_fft} 112 | hop_length: ${model.n_window_stride} 113 | window: ${model.window} 114 | n_mels: ${model.n_mel_channels} 115 | lowfreq: ${model.lowfreq} 116 | highfreq: ${model.highfreq} 117 | max_duration: 16.7 118 | min_duration: 0.1 119 | ignore_file: null 120 | trim: false 121 | pitch_fmin: ${model.pitch_fmin} 122 | pitch_fmax: ${model.pitch_fmax} 123 | pitch_norm: ${model.pitch_norm} 124 | pitch_mean: ${model.pitch_mean} 125 | pitch_std: ${model.pitch_std} 126 | use_log_pitch: ${model.use_log_pitch} 127 | use_beta_binomial_interpolator: true 128 | speaker_emb_path: ${speaker_emb_path} 129 | 130 | dataloader_params: 131 | drop_last: false 132 | shuffle: true 133 | batch_size: 24 134 | num_workers: 24 135 | pin_memory: true 136 | 137 | validation_ds: 138 | dataset: 139 | _target_: utils.data.TTSDataset 140 | manifest_filepath: ${validation_datasets} 141 | sample_rate: ${model.sample_rate} 142 | sup_data_path: ${sup_data_path} 143 | sup_data_types: ${sup_data_types} 144 | n_fft: ${model.n_fft} 145 | win_length: ${model.n_window_size} 146 | hop_length: ${model.n_window_stride} 147 | window: 
${model.window} 148 | n_mels: ${model.n_mel_channels} 149 | lowfreq: ${model.lowfreq} 150 | highfreq: ${model.highfreq} 151 | max_duration: null 152 | min_duration: null 153 | ignore_file: null 154 | trim: false 155 | pitch_fmin: ${model.pitch_fmin} 156 | pitch_fmax: ${model.pitch_fmax} 157 | pitch_norm: ${model.pitch_norm} 158 | pitch_mean: ${model.pitch_mean} 159 | pitch_std: ${model.pitch_std} 160 | use_log_pitch: ${model.use_log_pitch} 161 | use_beta_binomial_interpolator: true 162 | speaker_emb_path: ${speaker_emb_path} 163 | 164 | dataloader_params: 165 | drop_last: false 166 | shuffle: false 167 | batch_size: 4 168 | num_workers: 4 169 | pin_memory: true 170 | 171 | preprocessor: 172 | _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor 173 | features: ${model.n_mel_channels} 174 | lowfreq: ${model.lowfreq} 175 | highfreq: ${model.highfreq} 176 | n_fft: ${model.n_fft} 177 | n_window_size: ${model.n_window_size} 178 | window_size: false 179 | n_window_stride: ${model.n_window_stride} 180 | window_stride: false 181 | pad_to: 1 #16 182 | pad_value: ${model.pad_value} 183 | sample_rate: ${model.sample_rate} 184 | window: ${model.window} 185 | normalize: null 186 | preemph: null 187 | dither: 0.0 188 | frame_splicing: 1 189 | log: true 190 | log_zero_guard_type: add 191 | log_zero_guard_value: 1e-05 192 | mag_power: 1.0 193 | 194 | encoder: 195 | _target_: modules.glow_tts_with_pitch.TextEncoder 196 | n_vocab: 148 197 | out_channels: ${model.n_mel_channels} 198 | hidden_channels: 192 #192 #256 if using stoch_dur_pred 199 | filter_channels: 768 200 | filter_channels_dp: 256 #256 if not using stoch_dur_pred 201 | kernel_size: 3 202 | p_dropout: 0.1 203 | n_layers: 6 204 | n_heads: 2 205 | window_size: 4 206 | prenet: true 207 | mean_only: true 208 | gin_channels: ${gin_channels} 209 | use_stoch_dur_pred: ${use_stoch_dur_pred} 210 | 211 | decoder: 212 | _target_: modules.glow_tts_with_pitch.FlowSpecDecoder 213 | in_channels: ${model.n_mel_channels} 214 | hidden_channels: 192 #256 if using stoch_dur_pred 215 | kernel_size: 5 216 | n_blocks: 12 217 | n_layers: 4 218 | n_sqz: 2 219 | n_split: 4 220 | sigmoid_scale: false 221 | p_dropout: 0.05 222 | dilation_rate: 1 223 | gin_channels: ${gin_channels} 224 | 225 | pitch_predictor: 226 | _target_: modules.glow_tts_modules.StochasticPitchPredictor 227 | in_channels: 192 #80 #192 228 | filter_channels: 256 229 | kernel_size: 3 230 | p_dropout: 0.1 231 | n_flows: 4 232 | gin_channels: ${gin_channels} 233 | 234 | optim: 235 | name: radam 236 | lr: 2e-4 237 | # optimizer arguments 238 | betas: [0.9, 0.98] 239 | weight_decay: 0.0 240 | 241 | # scheduler setup 242 | sched: 243 | name: CosineAnnealing 244 | 245 | # Scheduler params 246 | warmup_steps: 6000 247 | min_lr: 1e-5 248 | last_epoch: -1 249 | 250 | 251 | trainer: 252 | accelerator: auto 253 | devices: -1 # number of gpus 254 | strategy: ddp 255 | num_nodes: 1 256 | enable_checkpointing: false # Provided by exp_manager 257 | logger: false # Provided by exp_manager 258 | max_epochs: 1000 259 | max_steps: -1 # computed at runtime if not set 260 | accumulate_grad_batches: 2 261 | log_every_n_steps: 100 # Interval of logging. 
262 | check_val_every_n_epoch: 2 263 | amp_backend: native 264 | precision: 16 # mixed-precision training 265 | gradient_clip_val: 5.0 266 | 267 | exp_manager: 268 | exp_dir: null 269 | name: ${name} 270 | resume_if_exists: False 271 | resume_ignore_no_checkpoint: True 272 | create_tensorboard_logger: True 273 | create_checkpoint_callback: True 274 | checkpoint_callback_params: 275 | always_save_nemo: True 276 | save_top_k: 1 277 | monitor: "val_loss" 278 | mode: "min" 279 | create_early_stopping_callback: False 280 | early_stopping_params: 281 | monitor: "val_loss" 282 | patience: 10 283 | verbose: True 284 | mode: "min" 285 | create_wandb_logger: False 286 | wandb_logger_kwargs: 287 | name: null 288 | project: null 289 | 290 | hydra: 291 | run: 292 | dir: . 293 | job_logging: 294 | root: 295 | handlers: null -------------------------------------------------------------------------------- /conf/hifigan16k_ft.yaml: -------------------------------------------------------------------------------- 1 | name: "HifiGAN_CV_16k" 2 | train_dataset: ??? 3 | validation_datasets: ??? 4 | 5 | defaults: 6 | - model/generator: v1 7 | - model/train_ds: train_ds 8 | - model/validation_ds: val_ds 9 | 10 | model: 11 | preprocessor: 12 | _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures 13 | dither: 0.0 14 | frame_splicing: 1 15 | nfilt: 80 16 | highfreq: 8000 17 | log: true 18 | log_zero_guard_type: clamp 19 | log_zero_guard_value: 1e-05 20 | lowfreq: 0 21 | mag_power: 1.0 22 | n_fft: 1024 23 | n_window_size: 1024 24 | n_window_stride: 256 25 | normalize: null 26 | pad_to: 0 27 | pad_value: -11.52 28 | preemph: null 29 | sample_rate: 16000 30 | window: hann 31 | use_grads: false 32 | exact_pad: true 33 | 34 | optim: 35 | _target_: torch.optim.AdamW 36 | lr: 0.0002 37 | betas: [0.8, 0.99] 38 | 39 | sched: 40 | name: CosineAnnealing 41 | min_lr: 1e-8 # 1e-5 42 | warmup_ratio: 0.02 43 | 44 | max_steps: 2500000 45 | l1_loss_factor: 45 46 | denoise_strength: 0.0025 47 | 48 | trainer: 49 | accelerator: auto 50 | devices: -1 # number of gpus 51 | strategy: ddp 52 | num_nodes: 1 53 | enable_checkpointing: false # Provided by exp_manager 54 | logger: false # Provided by exp_manager 55 | max_steps: ${model.max_steps} 56 | accumulate_grad_batches: 1 57 | log_every_n_steps: 100 58 | check_val_every_n_epoch: 1 59 | 60 | exp_manager: 61 | exp_dir: null 62 | name: ${name} 63 | create_tensorboard_logger: True 64 | create_checkpoint_callback: True 65 | checkpoint_callback_params: 66 | save_top_k: 1 67 | monitor: "val_loss" 68 | mode: "min" 69 | early_stopping_params: 70 | monitor: "val_loss" 71 | patience: 10 72 | verbose: True 73 | mode: "min" 74 | 75 | hydra: 76 | run: 77 | dir: . 78 | job_logging: 79 | root: 80 | handlers: null 81 | 82 | -------------------------------------------------------------------------------- /glow_tts_with_pitch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytorch_lightning as pl 16 | 17 | from nemo.collections.common.callbacks import LogEpochTimeCallback 18 | from nemo.core.config import hydra_runner 19 | from nemo.utils.exp_manager import exp_manager 20 | 21 | from models.glow_tts_with_pitch import GlowTTSModel 22 | 23 | @hydra_runner(config_path="conf", config_name="glow_tts") 24 | def main(cfg): 25 | trainer = pl.Trainer(**cfg.trainer) 26 | exp_manager(trainer, cfg.get("exp_manager", None)) 27 | model = GlowTTSModel(cfg=cfg.model, trainer=trainer) 28 | model.maybe_init_from_pretrained_checkpoint(cfg=cfg) 29 | lr_logger = pl.callbacks.LearningRateMonitor() 30 | epoch_time_logger = LogEpochTimeCallback() 31 | trainer.callbacks.extend([lr_logger, epoch_time_logger]) 32 | trainer.fit(model) 33 | 34 | 35 | if __name__ == "__main__": 36 | main() # noqa pylint: disable=no-value-for-parameter -------------------------------------------------------------------------------- /media/glow_tts_stdp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ogunlao/glowtts_stdp/07f71bbfce405018f43db7c086f46c6b91defa28/media/glow_tts_stdp.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ogunlao/glowtts_stdp/07f71bbfce405018f43db7c086f46c6b91defa28/models/__init__.py -------------------------------------------------------------------------------- /models/glow_tts_with_pitch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
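# GlowTTSModel below is a NeMo SpectrogramGenerator that wraps the Glow-TTS text encoder and
# flow-based spectrogram decoder; when a pitch_predictor section is present in the config it also
# adds stochastic pitch prediction on top of the (optionally stochastic) duration predictor.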
14 | import contextlib 15 | import math 16 | from copy import deepcopy 17 | from dataclasses import dataclass 18 | from typing import Any, Dict, Optional 19 | 20 | import torch 21 | import torch.utils.data 22 | from hydra.utils import instantiate 23 | from omegaconf import MISSING, DictConfig, OmegaConf 24 | from pytorch_lightning import Trainer 25 | from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger 26 | 27 | from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations 28 | from utils.helpers import ( 29 | log_audio_to_tb, 30 | plot_alignment_to_numpy, 31 | plot_spectrogram_to_numpy, 32 | process_batch, 33 | ) 34 | from utils.glow_tts_loss import GlowTTSLoss 35 | from utils.data import load_speaker_emb 36 | from nemo.collections.tts.losses.fastpitchloss import PitchLoss 37 | from nemo.collections.tts.models.base import SpectrogramGenerator 38 | 39 | from modules.glow_tts_with_pitch import GlowTTSModule 40 | from modules.glow_tts_modules.glow_tts_submodules import sequence_mask 41 | 42 | from nemo.core.classes.common import PretrainedModelInfo, typecheck 43 | from nemo.core.neural_types.elements import ( 44 | AcousticEncodedRepresentation, 45 | LengthsType, 46 | MelSpectrogramType, 47 | TokenIndex, 48 | RegressionValuesType, 49 | ) 50 | from nemo.core.neural_types.neural_type import NeuralType 51 | from nemo.utils import logging 52 | 53 | 54 | @dataclass 55 | class GlowTTSConfig: 56 | encoder: Dict[Any, Any] = MISSING 57 | decoder: Dict[Any, Any] = MISSING 58 | parser: Dict[Any, Any] = MISSING 59 | preprocessor: Dict[Any, Any] = MISSING 60 | train_ds: Optional[Dict[Any, Any]] = None 61 | validation_ds: Optional[Dict[Any, Any]] = None 62 | test_ds: Optional[Dict[Any, Any]] = None 63 | 64 | 65 | class GlowTTSModel(SpectrogramGenerator): 66 | """ 67 | GlowTTS model used to generate spectrograms from text 68 | Consists of a text encoder and an invertible spectrogram decoder 69 | """ 70 | 71 | def __init__(self, cfg: DictConfig, trainer: Trainer = None): 72 | if isinstance(cfg, dict): 73 | cfg = OmegaConf.create(cfg) 74 | 75 | # Setup normalizer 76 | self.normalizer = None 77 | self.text_normalizer_call = None 78 | self.text_normalizer_call_kwargs = {} 79 | self._setup_normalizer(cfg) 80 | 81 | # Setup vocabulary (=tokenizer) and input_fft_kwargs (supported only with self.learn_alignment=True) 82 | input_fft_kwargs = {} 83 | 84 | self.vocab = None 85 | 86 | self.ds_class_name = cfg.train_ds.dataset._target_.split(".")[-1] 87 | 88 | if self.ds_class_name == "TTSDataset": 89 | self._setup_tokenizer(cfg) 90 | assert self.vocab is not None 91 | input_fft_kwargs["n_embed"] = len(self.vocab.tokens) 92 | input_fft_kwargs["padding_idx"] = self.vocab.pad 93 | 94 | self._parser = None 95 | 96 | super().__init__(cfg=cfg, trainer=trainer) 97 | 98 | schema = OmegaConf.structured(GlowTTSConfig) 99 | # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes 100 | if isinstance(cfg, dict): 101 | cfg = OmegaConf.create(cfg) 102 | elif not isinstance(cfg, DictConfig): 103 | raise ValueError(f"cfg was type: {type(cfg)}. 
Expected either a dict or a DictConfig") 104 | # Ensure passed cfg is compliant with schema 105 | OmegaConf.merge(cfg, schema) 106 | 107 | self.preprocessor = instantiate(self._cfg.preprocessor) 108 | 109 | # if self._cfg.parser.get("add_blank"): 110 | self._cfg.encoder.n_vocab = len(self.vocab.tokens) 111 | encoder = instantiate(self._cfg.encoder) 112 | decoder = instantiate(self._cfg.decoder) 113 | pitch_stats = None 114 | unvoiced_value = 0.0 115 | if "pitch_predictor" in self._cfg: 116 | pitch_predictor = instantiate(self._cfg.pitch_predictor) 117 | # inject into encoder 118 | if self._cfg.get("unvoiced_value") is not None: 119 | unvoiced_value = self._cfg.get("unvoiced_value") 120 | elif self._cfg.get("use_normalized_pitch"): 121 | # compute unoviced value 122 | 123 | pitch_mean, pitch_std = self._cfg.get("pitch_mean", 0.0), self._cfg.get("pitch_std", 1.0) 124 | pitch_mean = pitch_mean or 0.0 125 | pitch_std = pitch_std or 1.0 126 | unvoiced_value = -pitch_mean/pitch_std 127 | pitch_stats = {"pitch_mean": pitch_mean, "pitch_std": pitch_std, 128 | # "pitch_fmin": self._cfg.get("pitch_fmin"), 129 | } 130 | else: 131 | pitch_predictor = None 132 | self.glow_tts = GlowTTSModule( 133 | encoder, 134 | decoder, 135 | pitch_predictor, 136 | n_speakers=cfg.n_speakers, 137 | gin_channels=cfg.gin_channels, 138 | use_external_speaker_emb=cfg.get("use_external_speaker_emb", False), 139 | use_stoch_dur_pred=cfg.get("use_stoch_dur_pred", False), 140 | use_stoch_pitch_pred=cfg.get("use_stoch_pitch_pred", False), 141 | unvoiced_value = unvoiced_value, 142 | use_log_pitch = cfg.get("use_log_pitch", False), 143 | use_normalized_pitch = cfg.get("use_normalized_pitch", False), 144 | use_frame_emb_for_pitch = cfg.get("use_frame_emb_for_pitch", False), 145 | pitch_stats=pitch_stats, 146 | ) 147 | self.loss = GlowTTSLoss() 148 | if pitch_predictor is not None: 149 | self.pitch_loss_scale = cfg.get("pitch_loss_scale", 0.1) 150 | self.pitch_loss_fn = PitchLoss(loss_scale=self.pitch_loss_scale) 151 | 152 | else: 153 | self.pitch_loss_scale = 0.0 154 | self.pitch_loss_fn = None 155 | 156 | 157 | def parse(self, str_input: str, normalize=True) -> torch.tensor: 158 | if str_input[-1] not in [".", "!", "?"]: 159 | str_input = str_input + "." 
160 | 161 | if self.training: 162 | logging.warning("parse() is meant to be called in eval mode.") 163 | 164 | if normalize and self.text_normalizer_call is not None: 165 | str_input = self.text_normalizer_call(str_input, **self.text_normalizer_call_kwargs) 166 | 167 | eval_phon_mode = contextlib.nullcontext() 168 | if hasattr(self.vocab, "set_phone_prob"): 169 | eval_phon_mode = self.vocab.set_phone_prob(prob=1.0) 170 | 171 | # Disable mixed g2p representation if necessary 172 | with eval_phon_mode: 173 | tokens = self.parser(str_input) 174 | 175 | x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device) 176 | return x 177 | 178 | @property 179 | def parser(self): 180 | if self._parser is not None: 181 | return self._parser 182 | 183 | ds_class_name = self._cfg.train_ds.dataset._target_.split(".")[-1] 184 | 185 | if ds_class_name == "TTSDataset": 186 | self._parser = self.vocab.encode 187 | else: 188 | raise ValueError(f"Unknown dataset class: {ds_class_name}") 189 | 190 | return self._parser 191 | 192 | @typecheck( 193 | input_types={ 194 | "x": NeuralType(("B", "T"), TokenIndex()), 195 | "x_lengths": NeuralType(("B"), LengthsType()), 196 | "y": NeuralType(("B", "D", "T"), MelSpectrogramType(), optional=True), 197 | "y_lengths": NeuralType(("B"), LengthsType(), optional=True), 198 | "gen": NeuralType(optional=True), 199 | "noise_scale": NeuralType(optional=True), 200 | "length_scale": NeuralType(optional=True), 201 | "speaker": NeuralType(("B"), TokenIndex(), optional=True), 202 | "speaker_embeddings": NeuralType( 203 | ("B", "D"), AcousticEncodedRepresentation(), optional=True 204 | ), 205 | "stoch_dur_noise_scale": NeuralType(optional=True), 206 | "stoch_pitch_noise_scale": NeuralType(optional=True), 207 | "pitch_scale": NeuralType(optional=True), 208 | "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()), 209 | } 210 | ) 211 | 212 | def forward( 213 | self, 214 | *, 215 | x, 216 | x_lengths, 217 | y=None, 218 | y_lengths=None, 219 | speaker=None, 220 | gen=False, 221 | noise_scale=0.0, 222 | length_scale=1.0, 223 | speaker_embeddings=None, 224 | stoch_dur_noise_scale=1.0, 225 | stoch_pitch_noise_scale=1.0, 226 | pitch_scale=0.0, 227 | pitch=None, 228 | ): 229 | if gen: 230 | return self.glow_tts.generate_spect( 231 | text=x, 232 | text_lengths=x_lengths, 233 | noise_scale=noise_scale, 234 | length_scale=length_scale, 235 | speaker=speaker, 236 | speaker_embeddings=speaker_embeddings, 237 | stoch_dur_noise_scale=stoch_dur_noise_scale, 238 | stoch_pitch_noise_scale=stoch_pitch_noise_scale, 239 | pitch_scale=pitch_scale, 240 | ) 241 | else: 242 | return self.glow_tts( 243 | text=x, 244 | text_lengths=x_lengths, 245 | spect=y, 246 | spect_lengths=y_lengths, 247 | speaker=speaker, 248 | speaker_embeddings=speaker_embeddings, 249 | pitch=pitch, 250 | ) 251 | 252 | def step( 253 | self, 254 | y, 255 | y_lengths, 256 | x, 257 | x_lengths, 258 | speaker, 259 | speaker_embeddings, 260 | pitch, 261 | ): 262 | 263 | z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn, stoch_dur_loss, pitch, pitch_pred, pitch_loss = self( 264 | x=x, 265 | x_lengths=x_lengths, 266 | y=y, 267 | y_lengths=y_lengths, 268 | speaker=speaker, 269 | speaker_embeddings=speaker_embeddings, 270 | pitch=pitch, 271 | ) 272 | 273 | l_mle, l_length, logdet = self.loss( 274 | z=z, 275 | y_m=y_m, 276 | y_logs=y_logs, 277 | logdet=logdet, 278 | logw=logw, 279 | logw_=logw_, 280 | x_lengths=x_lengths, 281 | y_lengths=y_lengths, 282 | stoch_dur_loss=stoch_dur_loss, 283 | ) 284 | 285 | if self.pitch_loss_fn is not None: 
286 | if pitch_loss is None: 287 | pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=x_lengths) 288 | else: 289 | pitch_loss = self.pitch_loss_scale*pitch_loss 290 | 291 | if pitch_loss is None: 292 | loss = sum([l_mle, l_length]) 293 | pitch_loss = torch.tensor([0.0]).to(device=l_mle.device) 294 | else: 295 | loss = sum([l_mle, l_length, pitch_loss]) 296 | 297 | 298 | return l_mle, l_length, logdet, loss, attn, pitch_loss 299 | 300 | def training_step(self, batch, batch_idx): 301 | 302 | batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) 303 | y = batch_dict.get("audio") 304 | y_lengths = batch_dict.get("audio_lens") 305 | x = batch_dict.get("text") 306 | x_lengths = batch_dict.get("text_lens") 307 | attn_prior = batch_dict.get("align_prior_matrix", None) 308 | pitch = batch_dict.get("pitch", None) 309 | energy = batch_dict.get("energy", None) 310 | speaker = batch_dict.get("speaker_id", None) 311 | speaker_embeddings = batch_dict.get("speaker_emb", None) 312 | 313 | y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) 314 | 315 | l_mle, l_length, logdet, loss, _, pitch_loss = self.step( 316 | y, y_lengths, x, x_lengths, speaker, speaker_embeddings, pitch 317 | ) 318 | 319 | output = { 320 | "loss": loss, # required 321 | "progress_bar": {"l_mle": l_mle, "l_length": l_length, "logdet": logdet}, 322 | "log": {"loss": loss, "l_mle": l_mle, "l_length": l_length, "logdet": logdet, "pitch_loss": pitch_loss}, 323 | } 324 | 325 | return output 326 | 327 | @torch.no_grad() 328 | def compute_likelihood(self, batch_dict): 329 | y = batch_dict.get("audio") 330 | y_lengths = batch_dict.get("audio_lens") 331 | x = batch_dict.get("text") 332 | x_lengths = batch_dict.get("text_lens") 333 | attn_prior = batch_dict.get("align_prior_matrix", None) 334 | pitch = batch_dict.get("pitch", None) 335 | energy = batch_dict.get("energy", None) 336 | speaker = batch_dict.get("speaker_id", None) 337 | speaker_embeddings = batch_dict.get("speaker_emb", None) 338 | 339 | y, y_lengths, x, x_lengths, speaker_embeddings = ( 340 | y.to(self.device), 341 | y_lengths.to(self.device), 342 | x.to(self.device), 343 | x_lengths.to(self.device), 344 | speaker_embeddings.to(self.device), 345 | ) 346 | y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) 347 | 348 | z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn, stoch_dur_loss, pitch, pitch_pred, pitch_loss = self( 349 | x=x, 350 | x_lengths=x_lengths, 351 | y=y, 352 | y_lengths=y_lengths, 353 | speaker=speaker, 354 | speaker_embeddings=speaker_embeddings, 355 | pitch=pitch, 356 | ) 357 | 358 | l_mle_normal = 0.5 * math.log(2 * math.pi) + ( 359 | torch.sum(y_logs) + 0.5 * torch.sum(torch.exp(-2 * y_logs) * (z - y_m) ** 2) 360 | ) / (torch.sum(y_lengths) * z.shape[1]) 361 | 362 | return l_mle_normal 363 | 364 | def validation_step(self, batch, batch_idx): 365 | 366 | batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) 367 | y = batch_dict.get("audio") 368 | y_lengths = batch_dict.get("audio_lens") 369 | x = batch_dict.get("text") 370 | x_lengths = batch_dict.get("text_lens") 371 | attn_prior = batch_dict.get("align_prior_matrix", None) 372 | pitch = batch_dict.get("pitch", None) 373 | energy = batch_dict.get("energy", None) 374 | speaker = batch_dict.get("speaker_id", None) 375 | speaker_embeddings = batch_dict.get("speaker_emb", None) 376 | 377 | y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) 378 | 379 | l_mle, l_length, logdet, loss, attn, 
pitch_loss = self.step( 380 | y, 381 | y_lengths, 382 | x, 383 | x_lengths, 384 | speaker, 385 | speaker_embeddings, 386 | pitch, 387 | ) 388 | 389 | y_gen, attn_gen = self( 390 | x=x, 391 | x_lengths=x_lengths, 392 | gen=True, 393 | speaker=speaker, 394 | speaker_embeddings=speaker_embeddings, 395 | pitch=None, # use predicted pitch 396 | noise_scale=0.667, 397 | ) 398 | 399 | return { 400 | "loss": loss, 401 | "l_mle": l_mle, 402 | "l_length": l_length, 403 | "logdet": logdet, 404 | "y": y, 405 | "y_gen": y_gen, 406 | "x": x, 407 | "attn": attn, 408 | "attn_gen": attn_gen, 409 | "progress_bar": {"l_mle": l_mle, "l_length": l_length, "logdet": logdet}, 410 | "pitch_loss": pitch_loss, 411 | } 412 | 413 | def validation_epoch_end(self, outputs): 414 | avg_loss = torch.stack([x["loss"] for x in outputs]).mean() 415 | avg_mle = torch.stack([x["l_mle"] for x in outputs]).mean() 416 | avg_length_loss = torch.stack([x["l_length"] for x in outputs]).mean() 417 | avg_logdet = torch.stack([x["logdet"] for x in outputs]).mean() 418 | avg_pitch_loss = torch.stack([x["pitch_loss"] for x in outputs]).mean() 419 | tensorboard_logs = { 420 | "val_loss": avg_loss, 421 | "val_mle": avg_mle, 422 | "val_length_loss": avg_length_loss, 423 | "val_logdet": avg_logdet, 424 | "val_pitch_loss": avg_pitch_loss, 425 | } 426 | if self.logger is not None and self.logger.experiment is not None: 427 | tb_logger = self.logger.experiment 428 | if isinstance(self.logger, LoggerCollection): 429 | for logger in self.logger: 430 | if isinstance(logger, TensorBoardLogger): 431 | tb_logger = logger.experiment 432 | break 433 | 434 | separated_tokens = self.vocab.decode(outputs[0]["x"][0]) 435 | 436 | tb_logger.add_text("separated tokens", separated_tokens, self.global_step) 437 | tb_logger.add_image( 438 | "real_spectrogram", 439 | plot_spectrogram_to_numpy(outputs[0]["y"][0].data.cpu().numpy()), 440 | self.global_step, 441 | dataformats="HWC", 442 | ) 443 | tb_logger.add_image( 444 | "generated_spectrogram", 445 | plot_spectrogram_to_numpy(outputs[0]["y_gen"][0].data.cpu().numpy()), 446 | self.global_step, 447 | dataformats="HWC", 448 | ) 449 | tb_logger.add_image( 450 | "alignment_for_real_sp", 451 | plot_alignment_to_numpy(outputs[0]["attn"][0].data.cpu().numpy()), 452 | self.global_step, 453 | dataformats="HWC", 454 | ) 455 | tb_logger.add_image( 456 | "alignment_for_generated_sp", 457 | plot_alignment_to_numpy(outputs[0]["attn_gen"][0].data.cpu().numpy()), 458 | self.global_step, 459 | dataformats="HWC", 460 | ) 461 | log_audio_to_tb(tb_logger, outputs[0]["y"][0], "true_audio_gf", self.global_step) 462 | log_audio_to_tb( 463 | tb_logger, outputs[0]["y_gen"][0], "generated_audio_gf", self.global_step 464 | ) 465 | self.log("val_loss", avg_loss) 466 | return {"val_loss": avg_loss, "log": tensorboard_logs} 467 | 468 | def _setup_normalizer(self, cfg): 469 | if "text_normalizer" in cfg: 470 | normalizer_kwargs = {} 471 | 472 | if "whitelist" in cfg.text_normalizer: 473 | normalizer_kwargs["whitelist"] = self.register_artifact( 474 | 'text_normalizer.whitelist', cfg.text_normalizer.whitelist 475 | ) 476 | 477 | try: 478 | self.normalizer = instantiate(cfg.text_normalizer, **normalizer_kwargs) 479 | except Exception as e: 480 | logging.error(e) 481 | raise ImportError( 482 | "`pynini` not installed, please install via NeMo/nemo_text_processing/pynini_install.sh" 483 | ) 484 | 485 | self.text_normalizer_call = self.normalizer.normalize 486 | if "text_normalizer_call_kwargs" in cfg: 487 | self.text_normalizer_call_kwargs = 
cfg.text_normalizer_call_kwargs 488 | 489 | def _setup_tokenizer(self, cfg): 490 | text_tokenizer_kwargs = {} 491 | 492 | if "phoneme_dict" in cfg.text_tokenizer: 493 | text_tokenizer_kwargs["phoneme_dict"] = self.register_artifact( 494 | "text_tokenizer.phoneme_dict", cfg.text_tokenizer.phoneme_dict, 495 | ) 496 | if "heteronyms" in cfg.text_tokenizer: 497 | text_tokenizer_kwargs["heteronyms"] = self.register_artifact( 498 | "text_tokenizer.heteronyms", cfg.text_tokenizer.heteronyms, 499 | ) 500 | 501 | if "g2p" in cfg.text_tokenizer: 502 | g2p_kwargs = {} 503 | 504 | if "phoneme_dict" in cfg.text_tokenizer.g2p: 505 | g2p_kwargs["phoneme_dict"] = self.register_artifact( 506 | 'text_tokenizer.g2p.phoneme_dict', cfg.text_tokenizer.g2p.phoneme_dict, 507 | ) 508 | 509 | if "heteronyms" in cfg.text_tokenizer.g2p: 510 | g2p_kwargs["heteronyms"] = self.register_artifact( 511 | 'text_tokenizer.g2p.heteronyms', cfg.text_tokenizer.g2p.heteronyms, 512 | ) 513 | 514 | text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p, **g2p_kwargs) 515 | 516 | self.vocab = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs) 517 | 518 | def _setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"): 519 | if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig): 520 | raise ValueError(f"No dataset for {name}") 521 | if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig): 522 | raise ValueError(f"No dataloder_params for {name}") 523 | if shuffle_should_be: 524 | if 'shuffle' not in cfg.dataloader_params: 525 | logging.warning( 526 | f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its " 527 | "config. Manually setting to True" 528 | ) 529 | with open_dict(cfg.dataloader_params): 530 | cfg.dataloader_params.shuffle = True 531 | elif not cfg.dataloader_params.shuffle: 532 | logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!") 533 | elif cfg.dataloader_params.shuffle: 534 | logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!") 535 | 536 | if cfg.dataset._target_ == "utils.data.TTSDataset": 537 | phon_mode = contextlib.nullcontext() 538 | if hasattr(self.vocab, "set_phone_prob"): 539 | phon_mode = self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability) 540 | 541 | with phon_mode: 542 | print("I got here!!!!!!!", self.vocab) 543 | dataset = instantiate( 544 | cfg.dataset, 545 | text_normalizer=self.normalizer, 546 | text_normalizer_call_kwargs=self.text_normalizer_call_kwargs, 547 | text_tokenizer=self.vocab, 548 | ) 549 | else: 550 | dataset = instantiate(cfg.dataset) 551 | 552 | return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params) 553 | 554 | def setup_training_data(self, train_data_config: Optional[DictConfig]): 555 | self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config, name="train") 556 | 557 | def setup_validation_data(self, val_data_config: Optional[DictConfig]): 558 | self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config, shuffle_should_be=False, name="val") 559 | 560 | def setup_test_data(self, test_data_config: Optional[DictConfig]): 561 | self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config, shuffle_should_be=False, name="test") 562 | 563 | 564 | def generate_spectrogram( 565 | self, 566 | tokens: "torch.tensor", 567 | noise_scale: float = 0.0, 568 | length_scale: float = 1.0, 569 | speaker: 
int = None, 570 | speaker_embeddings: "torch.tensor" = None, 571 | stoch_dur_noise_scale: float = 1.0, 572 | stoch_pitch_noise_scale: float = 1.0, 573 | pitch_scale: float = 0.0, 574 | ) -> torch.tensor: 575 | 576 | self.eval() 577 | 578 | token_len = torch.tensor([tokens.shape[1]]).to(self.device) 579 | 580 | if isinstance(speaker, int): 581 | speaker = torch.tensor([speaker]).to(self.device) 582 | else: 583 | speaker = None 584 | 585 | if speaker_embeddings is not None: 586 | speaker_embeddings = speaker_embeddings.to(self.device) 587 | 588 | spect, _ = self( 589 | x=tokens, 590 | x_lengths=token_len, 591 | speaker=speaker, 592 | gen=True, 593 | noise_scale=noise_scale, 594 | length_scale=length_scale, 595 | speaker_embeddings=speaker_embeddings, 596 | stoch_dur_noise_scale=stoch_dur_noise_scale, 597 | stoch_pitch_noise_scale=stoch_pitch_noise_scale, 598 | pitch_scale=pitch_scale, 599 | ) 600 | 601 | return spect 602 | 603 | @torch.no_grad() 604 | def generate_spectrogram_with_mas( 605 | self, 606 | batch, 607 | noise_scale: float = 0.667, 608 | length_scale: float = 1.0, 609 | external_speaker: int = None, 610 | external_speaker_embeddings: "torch.tensor" = None, 611 | gen: bool = False, 612 | randomize_speaker=False, 613 | ) -> torch.tensor: 614 | """Forced aligned generation of synthetic mel-spectrogram of mel-spectrograms.""" 615 | # y is audio or melspec, x is text token ids 616 | self.eval() 617 | 618 | batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) 619 | y = batch_dict.get("audio") 620 | y_lengths = batch_dict.get("audio_lens") 621 | x = batch_dict.get("text") 622 | x_lengths = batch_dict.get("text_lens") 623 | attn_prior = batch_dict.get("align_prior_matrix", None) 624 | pitch = batch_dict.get("pitch", None) 625 | energy = batch_dict.get("energy", None) 626 | speaker = batch_dict.get("speaker_id", None) 627 | speaker_embeddings = batch_dict.get("speaker_emb", None) 628 | 629 | 630 | y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) 631 | 632 | if len(speaker_embeddings.shape) == 1: 633 | speaker_embeddings = speaker_embeddings.unsqueeze(0) 634 | speaker = None 635 | if speaker_embeddings is not None: 636 | speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(-1).to(self.device) 637 | 638 | y, y_lengths, x, x_lengths, speaker_embeddings = ( 639 | y.to(self.device), 640 | y_lengths.to(self.device), 641 | x.to(self.device), 642 | x_lengths.to(self.device), 643 | speaker_embeddings, 644 | ) 645 | (z, y_m, y_logs, logdet, log_durs_predicted, log_durs_extracted, \ 646 | spect_lengths, attn, stoch_dur_loss, pitch, pitch_pred, pitch_pred_loss 647 | ) = self( 648 | x=x, 649 | x_lengths=x_lengths, 650 | y=y, 651 | y_lengths=y_lengths, 652 | speaker=speaker, 653 | gen=gen, 654 | speaker_embeddings=speaker_embeddings.squeeze(-1), 655 | pitch=pitch, 656 | ) 657 | 658 | y_max_length = z.size(2) 659 | 660 | y_mask = torch.unsqueeze(sequence_mask(spect_lengths, y_max_length), 1) 661 | 662 | # predicted aligned feature 663 | z = (y_m + torch.exp(y_logs) * torch.randn_like(y_m) * noise_scale) * y_mask 664 | 665 | if external_speaker_embeddings is not None: 666 | external_speaker_embeddings = torch.nn.functional.normalize( 667 | external_speaker_embeddings 668 | ).unsqueeze(-1) 669 | 670 | if randomize_speaker: 671 | # use a random speaker embedding to decode utterance 672 | idx = torch.randperm(speaker_embeddings.shape[0]) 673 | speaker_embeddings = speaker_embeddings[idx].view(speaker_embeddings.size()) 674 | if 
len(speaker_embeddings.shape) == 2: 675 | speaker_embeddings = speaker_embeddings.unsqueeze(-1).to(self.device) 676 | 677 | # invert with same or different speaker through the decoder 678 | y_pred, _ = self.glow_tts.decoder( 679 | spect=z.to(self.device), 680 | spect_mask=y_mask.to(self.device), 681 | speaker_embeddings=external_speaker_embeddings or speaker_embeddings, 682 | reverse=True, pitch=pitch, 683 | ) 684 | return y_pred, spect_lengths 685 | 686 | 687 | @classmethod 688 | def list_available_models(cls) -> "List[PretrainedModelInfo]": 689 | """ 690 | This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. 691 | Returns: 692 | List of available pre-trained models. 693 | """ 694 | list_of_models = [] 695 | model = PretrainedModelInfo( 696 | pretrained_model_name="tts_en_glowtts", 697 | location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_glowtts/versions/1.0.0rc1/files/tts_en_glowtts.nemo", 698 | description="This model is trained on LJSpeech sampled at 22050Hz, and can be used to generate female English voices with an American accent.", 699 | class_=cls, 700 | aliases=["GlowTTS-22050Hz"], 701 | ) 702 | list_of_models.append(model) 703 | return list_of_models -------------------------------------------------------------------------------- /modules/glow_tts_modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from modules.glow_tts_modules.glow_tts_submodules_with_pitch import StochasticPitchPredictor -------------------------------------------------------------------------------- /modules/glow_tts_modules/glow_tts_submodules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # MIT License 16 | # 17 | # Copyright (c) 2020 Jaehyeon Kim 18 | # 19 | # Permission is hereby granted, free of charge, to any person obtaining a copy 20 | # of this software and associated documentation files (the "Software"), to deal 21 | # in the Software without restriction, including without limitation the rights 22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 | # copies of the Software, and to permit persons to whom the Software is 24 | # furnished to do so, subject to the following conditions: 25 | # 26 | # The above copyright notice and this permission notice shall be included in all 27 | # copies or substantial portions of the Software. 28 | # 29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 | # SOFTWARE. 36 | 37 | import math 38 | 39 | import numpy as np 40 | import torch 41 | from torch import nn 42 | from torch.nn import functional as F 43 | 44 | from modules.glow_tts_modules import stocpred_modules 45 | 46 | 47 | def convert_pad_shape(pad_shape): 48 | """ 49 | Used to get arguments for F.pad 50 | """ 51 | l = pad_shape[::-1] 52 | pad_shape = [item for sublist in l for item in sublist] 53 | return pad_shape 54 | 55 | 56 | def sequence_mask(length, max_length=None): 57 | """ 58 | Get masks for given lengths 59 | """ 60 | if max_length is None: 61 | max_length = length.max() 62 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 63 | 64 | return x.unsqueeze(0) < length.unsqueeze(1) 65 | 66 | 67 | def maximum_path(value, mask, max_neg_val=None): 68 | """ 69 | Monotonic alignment search algorithm 70 | Numpy-friendly version. It's about 4 times faster than torch version. 71 | value: [b, t_x, t_y] 72 | mask: [b, t_x, t_y] 73 | """ 74 | if max_neg_val is None: 75 | max_neg_val = -np.inf # Patch for Sphinx complaint 76 | value = value * mask 77 | 78 | device = value.device 79 | dtype = value.dtype 80 | value = value.cpu().detach().numpy() 81 | mask = mask.cpu().detach().numpy().astype(np.bool) 82 | 83 | b, t_x, t_y = value.shape 84 | direction = np.zeros(value.shape, dtype=np.int64) 85 | v = np.zeros((b, t_x), dtype=np.float32) 86 | x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1) 87 | for j in range(t_y): 88 | v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1] 89 | v1 = v 90 | max_mask = v1 >= v0 91 | v_max = np.where(max_mask, v1, v0) 92 | direction[:, :, j] = max_mask 93 | 94 | index_mask = x_range <= j 95 | v = np.where(index_mask, v_max + value[:, :, j], max_neg_val) 96 | direction = np.where(mask, direction, 1) 97 | 98 | path = np.zeros(value.shape, dtype=np.float32) 99 | index = mask[:, :, 0].sum(1).astype(np.int64) - 1 100 | index_range = np.arange(b) 101 | try: 102 | for j in reversed(range(t_y)): 103 | path[index_range, index, j] = 1 104 | index = index + direction[index_range, index, j] - 1 105 | except Exception as e: 106 | print("index range", index_range) 107 | print(e) 108 | path = path * mask.astype(np.float32) 109 | path = torch.from_numpy(path).to(device=device, dtype=dtype) 110 | return path 111 | 112 | 113 | def generate_path(duration, mask): 114 | """ 115 | Generate alignment based on predicted durations 116 | duration: [b, t_x] 117 | mask: [b, t_x, t_y] 118 | """ 119 | 120 | b, t_x, t_y = mask.shape 121 | cum_duration = torch.cumsum(duration, 1) 122 | 123 | cum_duration_flat = cum_duration.view(b * t_x) 124 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 125 | path = path.view(b, t_x, t_y) 126 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 127 | path = path * mask 128 | return path 129 | 130 | 131 | @torch.jit.script 132 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 133 | n_channels_int = n_channels[0] 134 | in_act = input_a + input_b 135 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 136 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 137 | acts = t_act * s_act 138 | return acts 139 | 140 | 141 | class LayerNorm(nn.Module): 142 | def 
__init__(self, channels, eps=1e-4): 143 | super().__init__() 144 | self.channels = channels 145 | self.eps = eps 146 | 147 | self.gamma = nn.Parameter(torch.ones(channels)) 148 | self.beta = nn.Parameter(torch.zeros(channels)) 149 | 150 | def forward(self, x): 151 | n_dims = len(x.shape) 152 | mean = torch.mean(x, 1, keepdim=True) 153 | variance = torch.mean((x - mean) ** 2, 1, keepdim=True) 154 | 155 | x = (x - mean) * torch.rsqrt(variance + self.eps) 156 | 157 | shape = [1, -1] + [1] * (n_dims - 2) 158 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 159 | return x 160 | 161 | 162 | class ConvReluNorm(nn.Module): 163 | def __init__( 164 | self, 165 | in_channels, 166 | hidden_channels, 167 | out_channels, 168 | kernel_size, 169 | n_layers, 170 | p_dropout, 171 | ): 172 | super().__init__() 173 | self.in_channels = in_channels 174 | self.hidden_channels = hidden_channels 175 | self.out_channels = out_channels 176 | self.kernel_size = kernel_size 177 | self.n_layers = n_layers 178 | self.p_dropout = p_dropout 179 | assert n_layers > 1, "Number of layers should be larger than 0." 180 | 181 | self.conv_layers = nn.ModuleList() 182 | self.norm_layers = nn.ModuleList() 183 | self.conv_layers.append( 184 | nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) 185 | ) 186 | self.norm_layers.append(LayerNorm(hidden_channels)) 187 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) 188 | for _ in range(n_layers - 1): 189 | self.conv_layers.append( 190 | nn.Conv1d( 191 | hidden_channels, 192 | hidden_channels, 193 | kernel_size, 194 | padding=kernel_size // 2, 195 | ) 196 | ) 197 | self.norm_layers.append(LayerNorm(hidden_channels)) 198 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 199 | self.proj.weight.data.zero_() 200 | self.proj.bias.data.zero_() 201 | 202 | def forward(self, x, x_mask): 203 | x_org = x 204 | for i in range(self.n_layers): 205 | x = self.conv_layers[i](x * x_mask) 206 | x = self.norm_layers[i](x) 207 | x = self.relu_drop(x) 208 | x = x_org + self.proj(x) 209 | return x * x_mask 210 | 211 | 212 | class WN(torch.nn.Module): 213 | def __init__( 214 | self, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0 215 | ): 216 | super(WN, self).__init__() 217 | assert kernel_size % 2 == 1 218 | assert hidden_channels % 2 == 0 219 | 220 | self.hidden_channels = hidden_channels 221 | self.n_layers = n_layers 222 | 223 | self.in_layers = torch.nn.ModuleList() 224 | self.res_skip_layers = torch.nn.ModuleList() 225 | self.drop = nn.Dropout(p_dropout) 226 | 227 | if gin_channels != 0: 228 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) 229 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") 230 | 231 | for i in range(n_layers): 232 | dilation = dilation_rate ** i 233 | padding = int((kernel_size * dilation - dilation) / 2) 234 | in_layer = torch.nn.Conv1d( 235 | hidden_channels, 236 | 2 * hidden_channels, 237 | kernel_size, 238 | dilation=dilation, 239 | padding=padding, 240 | ) 241 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") 242 | self.in_layers.append(in_layer) 243 | 244 | # last one is not necessary 245 | if i < n_layers - 1: 246 | res_skip_channels = 2 * hidden_channels 247 | else: 248 | res_skip_channels = hidden_channels 249 | 250 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 251 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 252 | 
self.res_skip_layers.append(res_skip_layer) 253 | 254 | def forward(self, x, x_mask=None, g=None, **kwargs): 255 | output = torch.zeros_like(x) 256 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 257 | 258 | if g is not None: 259 | g = self.cond_layer(g) 260 | 261 | for i in range(self.n_layers): 262 | x_in = self.in_layers[i](x) 263 | x_in = self.drop(x_in) 264 | if g is not None: 265 | cond_offset = i * 2 * self.hidden_channels 266 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 267 | else: 268 | g_l = torch.zeros_like(x_in) 269 | 270 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 271 | 272 | res_skip_acts = self.res_skip_layers[i](acts) 273 | if i < self.n_layers - 1: 274 | x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask 275 | output = output + res_skip_acts[:, self.hidden_channels :, :] 276 | else: 277 | output = output + res_skip_acts 278 | return output * x_mask 279 | 280 | def remove_weight_norm(self): 281 | for l in self.in_layers: 282 | torch.nn.utils.remove_weight_norm(l) 283 | for l in self.res_skip_layers: 284 | torch.nn.utils.remove_weight_norm(l) 285 | 286 | 287 | class ActNorm(nn.Module): 288 | def __init__(self, channels, ddi=False, **kwargs): 289 | super().__init__() 290 | self.channels = channels 291 | self.initialized = not ddi 292 | 293 | self.logs = nn.Parameter(torch.zeros(1, channels, 1)) 294 | self.bias = nn.Parameter(torch.zeros(1, channels, 1)) 295 | 296 | def forward(self, x, x_mask=None, reverse=False, **kwargs): 297 | if x_mask is None: 298 | x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) 299 | x_len = torch.sum(x_mask, [1, 2]) 300 | if not self.initialized: 301 | self.initialize(x, x_mask) 302 | self.initialized = True 303 | 304 | if reverse: 305 | z = (x - self.bias) * torch.exp(-self.logs) * x_mask 306 | logdet = None 307 | else: 308 | z = (self.bias + torch.exp(self.logs) * x) * x_mask 309 | logdet = torch.sum(self.logs) * x_len # [b] 310 | 311 | return z, logdet 312 | 313 | def store_inverse(self): 314 | pass 315 | 316 | def set_ddi(self, ddi): 317 | self.initialized = not ddi 318 | 319 | def initialize(self, x, x_mask): 320 | with torch.no_grad(): 321 | denom = torch.sum(x_mask, [0, 2]) 322 | m = torch.sum(x * x_mask, [0, 2]) / denom 323 | m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom 324 | v = m_sq - (m ** 2) 325 | logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) 326 | 327 | bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) 328 | logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) 329 | 330 | self.bias.data.copy_(bias_init) 331 | self.logs.data.copy_(logs_init) 332 | 333 | 334 | class InvConvNear(nn.Module): 335 | def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs): 336 | super().__init__() 337 | assert n_split % 2 == 0 338 | self.channels = channels 339 | self.n_split = n_split 340 | self.no_jacobian = no_jacobian 341 | 342 | w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] 343 | if torch.det(w_init) < 0: 344 | w_init[:, 0] = -1 * w_init[:, 0] 345 | self.weight = nn.Parameter(w_init) 346 | 347 | def forward(self, x, x_mask=None, reverse=False, **kwargs): 348 | b, c, t = x.size() 349 | assert c % self.n_split == 0 350 | if x_mask is None: 351 | x_mask = 1 352 | x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t 353 | else: 354 | x_len = torch.sum(x_mask, [1, 2]) 355 | 356 | x = x.view(b, 2, c // self.n_split, self.n_split // 2, t) 
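# Regroup the channel axis into n_split-sized blocks so the learned n_split x n_split
# weight can be applied below as an invertible 1x1 convolution over those blocks.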
357 | x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t) 358 | 359 | if reverse: 360 | if hasattr(self, "weight_inv"): 361 | weight = self.weight_inv 362 | else: 363 | weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) 364 | logdet = None 365 | else: 366 | weight = self.weight 367 | if self.no_jacobian: 368 | logdet = 0 369 | else: 370 | logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b] 371 | 372 | weight = weight.view(self.n_split, self.n_split, 1, 1) 373 | z = F.conv2d(x, weight) 374 | 375 | z = z.view(b, 2, self.n_split // 2, c // self.n_split, t) 376 | z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask 377 | return z, logdet 378 | 379 | def store_inverse(self): 380 | self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) 381 | 382 | 383 | class CouplingBlock(nn.Module): 384 | def __init__( 385 | self, 386 | in_channels, 387 | hidden_channels, 388 | kernel_size, 389 | dilation_rate, 390 | n_layers, 391 | gin_channels=0, 392 | p_dropout=0, 393 | sigmoid_scale=False, 394 | ): 395 | super().__init__() 396 | self.in_channels = in_channels 397 | self.hidden_channels = hidden_channels 398 | self.kernel_size = kernel_size 399 | self.dilation_rate = dilation_rate 400 | self.n_layers = n_layers 401 | self.gin_channels = gin_channels 402 | self.p_dropout = p_dropout 403 | self.sigmoid_scale = sigmoid_scale 404 | 405 | start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) 406 | start = torch.nn.utils.weight_norm(start) 407 | self.start = start 408 | # Initializing last layer to 0 makes the affine coupling layers 409 | # do nothing at first. This helps with training stability 410 | end = torch.nn.Conv1d(hidden_channels, in_channels, 1) 411 | end.weight.data.zero_() 412 | end.bias.data.zero_() 413 | self.end = end 414 | 415 | self.wn = WN( 416 | hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout, gin_channels 417 | ) 418 | 419 | def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): 420 | if x_mask is None: 421 | x_mask = 1 422 | x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :] 423 | 424 | x = self.start(x_0) * x_mask 425 | x = self.wn(x, x_mask, g) 426 | out = self.end(x) 427 | 428 | z_0 = x_0 429 | m = out[:, : self.in_channels // 2, :] 430 | logs = out[:, self.in_channels // 2 :, :] 431 | if self.sigmoid_scale: 432 | logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) 433 | 434 | if reverse: 435 | z_1 = (x_1 - m) * torch.exp(-logs) * x_mask 436 | logdet = None 437 | else: 438 | z_1 = (m + torch.exp(logs) * x_1) * x_mask 439 | logdet = torch.sum(logs * x_mask, [1, 2]) 440 | 441 | z = torch.cat([z_0, z_1], 1) 442 | return z, logdet 443 | 444 | def store_inverse(self): 445 | self.wn.remove_weight_norm() 446 | 447 | 448 | class AttentionBlock(nn.Module): 449 | def __init__( 450 | self, 451 | channels, 452 | out_channels, 453 | n_heads, 454 | window_size=None, 455 | heads_share=True, 456 | p_dropout=0.0, 457 | block_length=None, 458 | proximal_bias=False, 459 | proximal_init=False, 460 | ): 461 | super().__init__() 462 | assert channels % n_heads == 0 463 | 464 | self.channels = channels 465 | self.out_channels = out_channels 466 | self.n_heads = n_heads 467 | self.window_size = window_size 468 | self.heads_share = heads_share 469 | self.block_length = block_length 470 | self.proximal_bias = proximal_bias 471 | self.p_dropout = p_dropout 472 | self.attn = None 473 | 474 | self.k_channels = channels // n_heads 475 | self.conv_q = 
nn.Conv1d(channels, channels, 1) 476 | self.conv_k = nn.Conv1d(channels, channels, 1) 477 | self.conv_v = nn.Conv1d(channels, channels, 1) 478 | if window_size is not None: 479 | n_heads_rel = 1 if heads_share else n_heads 480 | rel_stddev = self.k_channels ** -0.5 481 | self.emb_rel_k = nn.Parameter( 482 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev 483 | ) 484 | self.emb_rel_v = nn.Parameter( 485 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev 486 | ) 487 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 488 | self.drop = nn.Dropout(p_dropout) 489 | 490 | nn.init.xavier_uniform_(self.conv_q.weight) 491 | nn.init.xavier_uniform_(self.conv_k.weight) 492 | if proximal_init: 493 | self.conv_k.weight.data.copy_(self.conv_q.weight.data) 494 | self.conv_k.bias.data.copy_(self.conv_q.bias.data) 495 | nn.init.xavier_uniform_(self.conv_v.weight) 496 | 497 | def forward(self, x, c, attn_mask=None): 498 | q = self.conv_q(x) 499 | k = self.conv_k(c) 500 | v = self.conv_v(c) 501 | 502 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 503 | 504 | x = self.conv_o(x) 505 | return x 506 | 507 | def attention(self, query, key, value, mask=None): 508 | # reshape [b, d, t] -> [b, n_h, t, d_k] 509 | b, d, t_s, t_t = key.size(0), key.size(1), key.size(2), query.size(2) 510 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 511 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 512 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 513 | 514 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) 515 | if self.window_size is not None: 516 | assert t_s == t_t, "Relative attention is only available for self-attention." 517 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 518 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) 519 | rel_logits = self._relative_position_to_absolute_position(rel_logits) 520 | scores_local = rel_logits / math.sqrt(self.k_channels) 521 | scores = scores + scores_local 522 | if self.proximal_bias: 523 | assert t_s == t_t, "Proximal bias is only available for self-attention." 
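# Add -log(1 + |i - j|) to the scores so that, before the softmax, queries favor
# nearby key positions (see _attention_bias_proximal below).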
524 | scores = scores + self._attention_bias_proximal(t_s).to( 525 | device=scores.device, dtype=scores.dtype 526 | ) 527 | if mask is not None: 528 | scores = scores.masked_fill(mask == 0, -1e4) 529 | if self.block_length is not None: 530 | block_mask = ( 531 | torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 532 | ) 533 | scores = scores * block_mask + -1e4 * (1 - block_mask) 534 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 535 | p_attn = self.drop(p_attn) 536 | output = torch.matmul(p_attn, value) 537 | if self.window_size is not None: 538 | relative_weights = self._absolute_position_to_relative_position(p_attn) 539 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 540 | output = output + self._matmul_with_relative_values( 541 | relative_weights, value_relative_embeddings 542 | ) 543 | output = ( 544 | output.transpose(2, 3).contiguous().view(b, d, t_t) 545 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 546 | return output, p_attn 547 | 548 | def _matmul_with_relative_values(self, x, y): 549 | """ 550 | x: [b, h, l, m] 551 | y: [h or 1, m, d] 552 | ret: [b, h, l, d] 553 | """ 554 | ret = torch.matmul(x, y.unsqueeze(0)) 555 | return ret 556 | 557 | def _matmul_with_relative_keys(self, x, y): 558 | """ 559 | x: [b, h, l, d] 560 | y: [h or 1, m, d] 561 | ret: [b, h, l, m] 562 | """ 563 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 564 | return ret 565 | 566 | def _get_relative_embeddings(self, relative_embeddings, length): 567 | # Pad first before slice to avoid using cond ops. 568 | pad_length = max(length - (self.window_size + 1), 0) 569 | slice_start_position = max((self.window_size + 1) - length, 0) 570 | slice_end_position = slice_start_position + 2 * length - 1 571 | if pad_length > 0: 572 | padded_relative_embeddings = F.pad( 573 | relative_embeddings, 574 | convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 575 | ) 576 | else: 577 | padded_relative_embeddings = relative_embeddings 578 | used_relative_embeddings = padded_relative_embeddings[ 579 | :, slice_start_position:slice_end_position 580 | ] 581 | return used_relative_embeddings 582 | 583 | def _relative_position_to_absolute_position(self, x): 584 | """ 585 | x: [b, h, l, 2*l-1] 586 | ret: [b, h, l, l] 587 | """ 588 | batch, heads, length, _ = x.size() 589 | # Concat columns of pad to shift from relative to absolute indexing. 590 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 591 | 592 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 593 | x_flat = x.view([batch, heads, length * 2 * length]) 594 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) 595 | 596 | # Reshape and slice out the padded elements. 
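# Skewing trick: after the pads and reshape, entry [i, j] of the result holds the
# logit originally stored at relative offset (j - i); e.g. for l = 2, a row stores
# offsets (-1, 0, +1) and the slice keeps (0, +1) for row 0 and (-1, 0) for row 1.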
597 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 598 | :, :, :length, length - 1 : 599 | ] 600 | return x_final 601 | 602 | def _absolute_position_to_relative_position(self, x): 603 | """ 604 | x: [b, h, l, l] 605 | ret: [b, h, l, 2*l-1] 606 | """ 607 | batch, heads, length, _ = x.size() 608 | # padd along column 609 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) 610 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) 611 | # add 0's in the beginning that will skew the elements after reshape 612 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 613 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 614 | return x_final 615 | 616 | def _attention_bias_proximal(self, length): 617 | """Bias for self-attention to encourage attention to close positions. 618 | Args: 619 | length: an integer scalar. 620 | Returns: 621 | a Tensor with shape [1, 1, length, length] 622 | """ 623 | r = torch.arange(length, dtype=torch.float32) 624 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 625 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 626 | 627 | 628 | class FeedForwardNetwork(nn.Module): 629 | def __init__( 630 | self, 631 | in_channels, 632 | out_channels, 633 | filter_channels, 634 | kernel_size, 635 | p_dropout=0.0, 636 | activation=None, 637 | ): 638 | super().__init__() 639 | self.in_channels = in_channels 640 | self.out_channels = out_channels 641 | self.filter_channels = filter_channels 642 | self.kernel_size = kernel_size 643 | self.p_dropout = p_dropout 644 | self.activation = activation 645 | 646 | self.conv_1 = nn.Conv1d( 647 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2 648 | ) 649 | self.conv_2 = nn.Conv1d( 650 | filter_channels, out_channels, kernel_size, padding=kernel_size // 2 651 | ) 652 | self.drop = nn.Dropout(p_dropout) 653 | 654 | def forward(self, x, x_mask): 655 | x = self.conv_1(x * x_mask) 656 | if self.activation == "gelu": 657 | x = x * torch.sigmoid(1.702 * x) 658 | else: 659 | x = torch.relu(x) 660 | x = self.drop(x) 661 | x = self.conv_2(x * x_mask) 662 | return x * x_mask 663 | 664 | 665 | class DurationPredictor(nn.Module): 666 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): 667 | """ 668 | Token duration predictor for the GlowTTS model. 669 | Takes in embeddings of the input tokens and predicts how many frames of 670 | mel-spectrogram are aligned to each text token. 671 | Architecture is the same as the duration predictor in FastSpeech. 
672 | Args: 673 | in_channels: Number of channels for the token embeddings 674 | filter_channels: Number of channels in the intermediate layers 675 | kernel_size: Kernels size for the convolution layers 676 | p_dropout: Dropout probability 677 | """ 678 | 679 | super().__init__() 680 | 681 | self.drop = nn.Dropout(p_dropout) 682 | self.conv_1 = nn.Conv1d( 683 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2 684 | ) 685 | self.norm_1 = LayerNorm(filter_channels) 686 | self.conv_2 = nn.Conv1d( 687 | filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 688 | ) 689 | self.norm_2 = LayerNorm(filter_channels) 690 | self.proj = nn.Conv1d(filter_channels, 1, 1) 691 | 692 | def forward(self, spect, mask, **kwargs): 693 | x = self.conv_1(spect * mask) 694 | x = torch.relu(x) 695 | x = self.norm_1(x) 696 | x = self.drop(x) 697 | x = self.conv_2(x * mask) 698 | x = torch.relu(x) 699 | x = self.norm_2(x) 700 | x = self.drop(x) 701 | x = self.proj(x * mask) 702 | durs = x * mask 703 | return durs.squeeze(1) 704 | 705 | 706 | class StochasticDurationPredictor(nn.Module): 707 | """Borrowed from VITS""" 708 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 709 | super().__init__() 710 | 711 | filter_channels = in_channels # it needs to be removed from future version. 712 | self.in_channels = in_channels 713 | self.filter_channels = filter_channels 714 | self.kernel_size = kernel_size 715 | self.p_dropout = p_dropout 716 | self.n_flows = n_flows 717 | self.gin_channels = gin_channels 718 | 719 | self.log_flow = stocpred_modules.Log() 720 | self.flows = nn.ModuleList() 721 | self.flows.append(stocpred_modules.ElementwiseAffine(2)) 722 | for i in range(n_flows): 723 | self.flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 724 | self.flows.append(stocpred_modules.Flip()) 725 | 726 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 727 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 728 | self.post_convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 729 | self.post_flows = nn.ModuleList() 730 | self.post_flows.append(stocpred_modules.ElementwiseAffine(2)) 731 | for i in range(4): 732 | self.post_flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 733 | self.post_flows.append(stocpred_modules.Flip()) 734 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 735 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 736 | self.convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 737 | if gin_channels != 0: 738 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 739 | 740 | def forward(self, spect, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 741 | x = torch.detach(spect) 742 | x = self.pre(x) 743 | if g is not None: 744 | g = torch.detach(g) 745 | x = x + self.cond(g) 746 | x = self.convs(x, x_mask) 747 | x = self.proj(x) * x_mask 748 | if not reverse: 749 | flows = self.flows 750 | assert w is not None 751 | 752 | logdet_tot_q = 0 753 | h_w = self.post_pre(w) 754 | h_w = self.post_convs(h_w, x_mask) 755 | h_w = self.post_proj(h_w) * x_mask 756 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 757 | z_q = e_q 758 | for flow in self.post_flows: 759 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 760 | logdet_tot_q += logdet_q 761 | z_u, z1 = torch.split(z_q, [1, 1], 1) 762 | u = 
torch.sigmoid(z_u) * x_mask 763 | z0 = (w - u) * x_mask 764 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 765 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 766 | 767 | logdet_tot = 0 768 | z0, logdet = self.log_flow(z0, x_mask) 769 | logdet_tot += logdet 770 | z = torch.cat([z0, z1], 1) 771 | for flow in flows: 772 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 773 | logdet_tot = logdet_tot + logdet 774 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 775 | stoch_dur_loss = nll + logq 776 | stoch_dur_loss = stoch_dur_loss / torch.sum(x_mask) 777 | stoch_dur_loss = torch.sum(stoch_dur_loss) 778 | 779 | return stoch_dur_loss, None # [b] 780 | 781 | else: 782 | flows = list(reversed(self.flows)) 783 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 784 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 785 | for flow in flows: 786 | z = flow(z, x_mask, g=x, reverse=reverse) 787 | z0, z1 = torch.split(z, [1, 1], 1) 788 | logw = z0 789 | return None, logw 790 | 791 | -------------------------------------------------------------------------------- /modules/glow_tts_modules/glow_tts_submodules_with_pitch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # MIT License 16 | # 17 | # Copyright (c) 2020 Jaehyeon Kim 18 | # 19 | # Permission is hereby granted, free of charge, to any person obtaining a copy 20 | # of this software and associated documentation files (the "Software"), to deal 21 | # in the Software without restriction, including without limitation the rights 22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 | # copies of the Software, and to permit persons to whom the Software is 24 | # furnished to do so, subject to the following conditions: 25 | # 26 | # The above copyright notice and this permission notice shall be included in all 27 | # copies or substantial portions of the Software. 28 | # 29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 | # SOFTWARE. 
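# This module mirrors glow_tts_submodules.py; the pitch-specific additions are WNP
# (a WaveNet-style conditioning stack fed a frame-level pitch contour, squeezed in
# time to match the decoder) and a CouplingBlock that conditions on speaker and pitch.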
36 | 37 | import math 38 | 39 | import numpy as np 40 | import torch 41 | from torch import nn 42 | from torch.nn import functional as F 43 | 44 | from modules.glow_tts_modules import stocpred_modules 45 | 46 | 47 | def convert_pad_shape(pad_shape): 48 | """ 49 | Used to get arguments for F.pad 50 | """ 51 | l = pad_shape[::-1] 52 | pad_shape = [item for sublist in l for item in sublist] 53 | return pad_shape 54 | 55 | 56 | def sequence_mask(length, max_length=None): 57 | """ 58 | Get masks for given lengths 59 | """ 60 | if max_length is None: 61 | max_length = length.max() 62 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 63 | 64 | return x.unsqueeze(0) < length.unsqueeze(1) 65 | 66 | 67 | def maximum_path(value, mask, max_neg_val=None): 68 | """ 69 | Monotonic alignment search algorithm 70 | Numpy-friendly version. It's about 4 times faster than torch version. 71 | value: [b, t_x, t_y] 72 | mask: [b, t_x, t_y] 73 | """ 74 | if max_neg_val is None: 75 | max_neg_val = -np.inf # Patch for Sphinx complaint 76 | value = value * mask 77 | 78 | device = value.device 79 | dtype = value.dtype 80 | value = value.cpu().detach().numpy() 81 | mask = mask.cpu().detach().numpy().astype(np.bool) 82 | 83 | b, t_x, t_y = value.shape 84 | direction = np.zeros(value.shape, dtype=np.int64) 85 | v = np.zeros((b, t_x), dtype=np.float32) 86 | x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1) 87 | for j in range(t_y): 88 | v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1] 89 | v1 = v 90 | max_mask = v1 >= v0 91 | v_max = np.where(max_mask, v1, v0) 92 | direction[:, :, j] = max_mask 93 | 94 | index_mask = x_range <= j 95 | v = np.where(index_mask, v_max + value[:, :, j], max_neg_val) 96 | direction = np.where(mask, direction, 1) 97 | 98 | path = np.zeros(value.shape, dtype=np.float32) 99 | index = mask[:, :, 0].sum(1).astype(np.int64) - 1 100 | index_range = np.arange(b) 101 | try: 102 | for j in reversed(range(t_y)): 103 | path[index_range, index, j] = 1 104 | index = index + direction[index_range, index, j] - 1 105 | except Exception as e: 106 | print("index range", index_range) 107 | print(e) 108 | path = path * mask.astype(np.float32) 109 | path = torch.from_numpy(path).to(device=device, dtype=dtype) 110 | return path 111 | 112 | 113 | def generate_path(duration, mask): 114 | """ 115 | Generate alignment based on predicted durations 116 | duration: [b, t_x] 117 | mask: [b, t_x, t_y] 118 | """ 119 | 120 | b, t_x, t_y = mask.shape 121 | cum_duration = torch.cumsum(duration, 1) 122 | 123 | cum_duration_flat = cum_duration.view(b * t_x) 124 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 125 | path = path.view(b, t_x, t_y) 126 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 127 | path = path * mask 128 | return path 129 | 130 | 131 | @torch.jit.script 132 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 133 | n_channels_int = n_channels[0] 134 | in_act = input_a + input_b 135 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 136 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 137 | acts = t_act * s_act 138 | return acts 139 | 140 | 141 | class LayerNorm(nn.Module): 142 | def __init__(self, channels, eps=1e-4): 143 | super().__init__() 144 | self.channels = channels 145 | self.eps = eps 146 | 147 | self.gamma = nn.Parameter(torch.ones(channels)) 148 | self.beta = nn.Parameter(torch.zeros(channels)) 149 | 150 | def forward(self, x): 151 | n_dims = len(x.shape) 152 | 
mean = torch.mean(x, 1, keepdim=True) 153 | variance = torch.mean((x - mean) ** 2, 1, keepdim=True) 154 | 155 | x = (x - mean) * torch.rsqrt(variance + self.eps) 156 | 157 | shape = [1, -1] + [1] * (n_dims - 2) 158 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 159 | return x 160 | 161 | 162 | class ConvReluNorm(nn.Module): 163 | def __init__( 164 | self, 165 | in_channels, 166 | hidden_channels, 167 | out_channels, 168 | kernel_size, 169 | n_layers, 170 | p_dropout, 171 | ): 172 | super().__init__() 173 | self.in_channels = in_channels 174 | self.hidden_channels = hidden_channels 175 | self.out_channels = out_channels 176 | self.kernel_size = kernel_size 177 | self.n_layers = n_layers 178 | self.p_dropout = p_dropout 179 | assert n_layers > 1, "Number of layers should be larger than 0." 180 | 181 | self.conv_layers = nn.ModuleList() 182 | self.norm_layers = nn.ModuleList() 183 | self.conv_layers.append( 184 | nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) 185 | ) 186 | self.norm_layers.append(LayerNorm(hidden_channels)) 187 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) 188 | for _ in range(n_layers - 1): 189 | self.conv_layers.append( 190 | nn.Conv1d( 191 | hidden_channels, 192 | hidden_channels, 193 | kernel_size, 194 | padding=kernel_size // 2, 195 | ) 196 | ) 197 | self.norm_layers.append(LayerNorm(hidden_channels)) 198 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 199 | self.proj.weight.data.zero_() 200 | self.proj.bias.data.zero_() 201 | 202 | def forward(self, x, x_mask): 203 | x_org = x 204 | for i in range(self.n_layers): 205 | x = self.conv_layers[i](x * x_mask) 206 | x = self.norm_layers[i](x) 207 | x = self.relu_drop(x) 208 | x = x_org + self.proj(x) 209 | return x * x_mask 210 | 211 | 212 | class WN(torch.nn.Module): 213 | def __init__( 214 | self, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0 215 | ): 216 | super(WN, self).__init__() 217 | assert kernel_size % 2 == 1 218 | assert hidden_channels % 2 == 0 219 | 220 | self.hidden_channels = hidden_channels 221 | self.n_layers = n_layers 222 | 223 | self.in_layers = torch.nn.ModuleList() 224 | self.res_skip_layers = torch.nn.ModuleList() 225 | self.drop = nn.Dropout(p_dropout) 226 | 227 | if gin_channels != 0: 228 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) 229 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") 230 | 231 | for i in range(n_layers): 232 | dilation = dilation_rate ** i 233 | padding = int((kernel_size * dilation - dilation) / 2) 234 | in_layer = torch.nn.Conv1d( 235 | hidden_channels, 236 | 2 * hidden_channels, 237 | kernel_size, 238 | dilation=dilation, 239 | padding=padding, 240 | ) 241 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") 242 | self.in_layers.append(in_layer) 243 | 244 | # last one is not necessary 245 | if i < n_layers - 1: 246 | res_skip_channels = 2 * hidden_channels 247 | else: 248 | res_skip_channels = hidden_channels 249 | 250 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 251 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 252 | self.res_skip_layers.append(res_skip_layer) 253 | 254 | def forward(self, x, x_mask=None, g=None, **kwargs): 255 | output = torch.zeros_like(x) 256 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 257 | if g is not None: 258 | g = self.cond_layer(g) 259 | for i in range(self.n_layers): 260 | x_in = 
self.in_layers[i](x) 261 | x_in = self.drop(x_in) 262 | if g is not None: 263 | cond_offset = i * 2 * self.hidden_channels 264 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 265 | else: 266 | g_l = torch.zeros_like(x_in) 267 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 268 | 269 | res_skip_acts = self.res_skip_layers[i](acts) 270 | if i < self.n_layers - 1: 271 | x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask 272 | output = output + res_skip_acts[:, self.hidden_channels :, :] 273 | else: 274 | output = output + res_skip_acts 275 | return output * x_mask 276 | 277 | def remove_weight_norm(self): 278 | for l in self.in_layers: 279 | torch.nn.utils.remove_weight_norm(l) 280 | for l in self.res_skip_layers: 281 | torch.nn.utils.remove_weight_norm(l) 282 | 283 | 284 | class WNP(torch.nn.Module): 285 | def __init__( 286 | self, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, n_sqz=2, 287 | ): 288 | super(WNP, self).__init__() 289 | assert kernel_size % 2 == 1 290 | assert hidden_channels % 2 == 0 291 | 292 | self.hidden_channels = hidden_channels 293 | self.n_layers = n_layers 294 | 295 | self.in_layers = torch.nn.ModuleList() 296 | self.res_skip_layers = torch.nn.ModuleList() 297 | self.drop = nn.Dropout(p_dropout) 298 | self.n_sqz = n_sqz 299 | 300 | if gin_channels != 0: 301 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers // self.n_sqz, 1) 302 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") 303 | 304 | for i in range(n_layers): 305 | dilation = dilation_rate ** i 306 | padding = int((kernel_size * dilation - dilation) / 2) 307 | in_layer = torch.nn.Conv1d( 308 | hidden_channels, 309 | 2 * hidden_channels, 310 | kernel_size, 311 | dilation=dilation, 312 | padding=padding, 313 | ) 314 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") 315 | self.in_layers.append(in_layer) 316 | 317 | # last one is not necessary 318 | if i < n_layers - 1: 319 | res_skip_channels = 2 * hidden_channels 320 | else: 321 | res_skip_channels = hidden_channels 322 | 323 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 324 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 325 | self.res_skip_layers.append(res_skip_layer) 326 | 327 | def forward(self, x, x_mask=None, g=None, **kwargs): 328 | output = torch.zeros_like(x) 329 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 330 | if g is not None: 331 | g = self.cond_layer(g) 332 | if self.n_sqz > 1: 333 | g = self.squeeze(g, self.n_sqz) 334 | else: 335 | return x 336 | 337 | for i in range(self.n_layers): 338 | x_in = self.in_layers[i](x) 339 | x_in = self.drop(x_in) 340 | if g is not None: 341 | cond_offset = i * 2 * self.hidden_channels 342 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 343 | else: 344 | g_l = torch.zeros_like(x_in) 345 | 346 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 347 | 348 | res_skip_acts = self.res_skip_layers[i](acts) 349 | if i < self.n_layers - 1: 350 | x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask 351 | output = output + res_skip_acts[:, self.hidden_channels :, :] 352 | else: 353 | output = output + res_skip_acts 354 | return output * x_mask 355 | 356 | def remove_weight_norm(self): 357 | for l in self.in_layers: 358 | torch.nn.utils.remove_weight_norm(l) 359 | for l in self.res_skip_layers: 360 | torch.nn.utils.remove_weight_norm(l) 361 | 362 | 
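# squeeze() folds n_sqz consecutive frames into the channel axis,
# (B, C, T) -> (B, C * n_sqz, T // n_sqz), so the frame-level pitch conditioning
# can line up with the decoder's squeezed feature maps.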
def squeeze(self, x, n_sqz=2): 363 | b, c, t = x.size() 364 | 365 | t = (t // n_sqz) * n_sqz 366 | x = x[:, :, :t] 367 | x_sqz = x.view(b, c, t // n_sqz, n_sqz) 368 | 369 | x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) 370 | 371 | return x_sqz 372 | 373 | 374 | class ActNorm(nn.Module): 375 | def __init__(self, channels, ddi=False, **kwargs): 376 | super().__init__() 377 | self.channels = channels 378 | self.initialized = not ddi 379 | 380 | self.logs = nn.Parameter(torch.zeros(1, channels, 1)) 381 | self.bias = nn.Parameter(torch.zeros(1, channels, 1)) 382 | 383 | def forward(self, x, x_mask=None, reverse=False, **kwargs): 384 | if x_mask is None: 385 | x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) 386 | x_len = torch.sum(x_mask, [1, 2]) 387 | if not self.initialized: 388 | self.initialize(x, x_mask) 389 | self.initialized = True 390 | 391 | if reverse: 392 | z = (x - self.bias) * torch.exp(-self.logs) * x_mask 393 | logdet = None 394 | else: 395 | z = (self.bias + torch.exp(self.logs) * x) * x_mask 396 | logdet = torch.sum(self.logs) * x_len # [b] 397 | 398 | return z, logdet 399 | 400 | def store_inverse(self): 401 | pass 402 | 403 | def set_ddi(self, ddi): 404 | self.initialized = not ddi 405 | 406 | def initialize(self, x, x_mask): 407 | with torch.no_grad(): 408 | denom = torch.sum(x_mask, [0, 2]) 409 | m = torch.sum(x * x_mask, [0, 2]) / denom 410 | m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom 411 | v = m_sq - (m ** 2) 412 | logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) 413 | 414 | bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) 415 | logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) 416 | 417 | self.bias.data.copy_(bias_init) 418 | self.logs.data.copy_(logs_init) 419 | 420 | 421 | class InvConvNear(nn.Module): 422 | def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs): 423 | super().__init__() 424 | assert n_split % 2 == 0 425 | self.channels = channels 426 | self.n_split = n_split 427 | self.no_jacobian = no_jacobian 428 | 429 | w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] 430 | if torch.det(w_init) < 0: 431 | w_init[:, 0] = -1 * w_init[:, 0] 432 | self.weight = nn.Parameter(w_init) 433 | 434 | def forward(self, x, x_mask=None, reverse=False, **kwargs): 435 | b, c, t = x.size() 436 | assert c % self.n_split == 0 437 | if x_mask is None: 438 | x_mask = 1 439 | x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t 440 | else: 441 | x_len = torch.sum(x_mask, [1, 2]) 442 | 443 | x = x.view(b, 2, c // self.n_split, self.n_split // 2, t) 444 | x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t) 445 | 446 | if reverse: 447 | if hasattr(self, "weight_inv"): 448 | weight = self.weight_inv 449 | else: 450 | weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) 451 | logdet = None 452 | else: 453 | weight = self.weight 454 | if self.no_jacobian: 455 | logdet = 0 456 | else: 457 | logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b] 458 | 459 | weight = weight.view(self.n_split, self.n_split, 1, 1) 460 | z = F.conv2d(x, weight) 461 | 462 | z = z.view(b, 2, self.n_split // 2, c // self.n_split, t) 463 | z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask 464 | return z, logdet 465 | 466 | def store_inverse(self): 467 | self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) 468 | 469 | 470 | class 
CouplingBlock(nn.Module): 471 | def __init__( 472 | self, 473 | in_channels, 474 | hidden_channels, 475 | kernel_size, 476 | dilation_rate, 477 | n_layers, 478 | gin_channels=0, 479 | p_dropout=0, 480 | sigmoid_scale=False, 481 | n_sqz=2, 482 | ): 483 | super().__init__() 484 | self.in_channels = in_channels 485 | self.hidden_channels = hidden_channels 486 | self.kernel_size = kernel_size 487 | self.dilation_rate = dilation_rate 488 | self.n_layers = n_layers 489 | self.gin_channels = gin_channels 490 | self.p_dropout = p_dropout 491 | self.sigmoid_scale = sigmoid_scale 492 | self.n_sqz = n_sqz 493 | 494 | start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) 495 | start = torch.nn.utils.weight_norm(start) 496 | self.start = start 497 | # Initializing last layer to 0 makes the affine coupling layers 498 | # do nothing at first. This helps with training stability 499 | end = torch.nn.Conv1d(hidden_channels, in_channels, 1) 500 | end.weight.data.zero_() 501 | end.bias.data.zero_() 502 | self.end = end 503 | 504 | self.wn = WN( 505 | hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout, gin_channels 506 | ) 507 | self.wn_pitch = WNP( 508 | hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout, 1, n_sqz 509 | ) # 1 dimension of pitch 510 | 511 | def forward(self, x, x_mask=None, reverse=False, g=None, pitch=None, **kwargs): 512 | if x_mask is None: 513 | x_mask = 1 514 | x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :] 515 | 516 | x = self.start(x_0) * x_mask 517 | # add speaker emb 518 | x = self.wn(x, x_mask, g) 519 | # add pitch emb 520 | if pitch is not None and len(pitch.shape) == 2: 521 | pitch = pitch.unsqueeze(1) # B, T -> B,C,T 522 | 523 | x = self.wn_pitch(x, x_mask, pitch) 524 | out = self.end(x) 525 | 526 | z_0 = x_0 527 | m = out[:, : self.in_channels // 2, :] 528 | logs = out[:, self.in_channels // 2 :, :] 529 | if self.sigmoid_scale: 530 | logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) 531 | 532 | if reverse: 533 | z_1 = (x_1 - m) * torch.exp(-logs) * x_mask 534 | logdet = None 535 | else: 536 | z_1 = (m + torch.exp(logs) * x_1) * x_mask 537 | logdet = torch.sum(logs * x_mask, [1, 2]) 538 | 539 | z = torch.cat([z_0, z_1], 1) 540 | return z, logdet 541 | 542 | def store_inverse(self): 543 | self.wn.remove_weight_norm() 544 | 545 | 546 | class AttentionBlock(nn.Module): 547 | def __init__( 548 | self, 549 | channels, 550 | out_channels, 551 | n_heads, 552 | window_size=None, 553 | heads_share=True, 554 | p_dropout=0.0, 555 | block_length=None, 556 | proximal_bias=False, 557 | proximal_init=False, 558 | ): 559 | super().__init__() 560 | assert channels % n_heads == 0 561 | 562 | self.channels = channels 563 | self.out_channels = out_channels 564 | self.n_heads = n_heads 565 | self.window_size = window_size 566 | self.heads_share = heads_share 567 | self.block_length = block_length 568 | self.proximal_bias = proximal_bias 569 | self.p_dropout = p_dropout 570 | self.attn = None 571 | 572 | self.k_channels = channels // n_heads 573 | self.conv_q = nn.Conv1d(channels, channels, 1) 574 | self.conv_k = nn.Conv1d(channels, channels, 1) 575 | self.conv_v = nn.Conv1d(channels, channels, 1) 576 | if window_size is not None: 577 | n_heads_rel = 1 if heads_share else n_heads 578 | rel_stddev = self.k_channels ** -0.5 579 | self.emb_rel_k = nn.Parameter( 580 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev 581 | ) 582 | self.emb_rel_v = nn.Parameter( 583 | torch.randn(n_heads_rel, window_size * 2 + 1, 
self.k_channels) * rel_stddev 584 | ) 585 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 586 | self.drop = nn.Dropout(p_dropout) 587 | 588 | nn.init.xavier_uniform_(self.conv_q.weight) 589 | nn.init.xavier_uniform_(self.conv_k.weight) 590 | if proximal_init: 591 | self.conv_k.weight.data.copy_(self.conv_q.weight.data) 592 | self.conv_k.bias.data.copy_(self.conv_q.bias.data) 593 | nn.init.xavier_uniform_(self.conv_v.weight) 594 | 595 | def forward(self, x, c, attn_mask=None): 596 | q = self.conv_q(x) 597 | k = self.conv_k(c) 598 | v = self.conv_v(c) 599 | 600 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 601 | 602 | x = self.conv_o(x) 603 | return x 604 | 605 | def attention(self, query, key, value, mask=None): 606 | # reshape [b, d, t] -> [b, n_h, t, d_k] 607 | b, d, t_s, t_t = key.size(0), key.size(1), key.size(2), query.size(2) 608 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 609 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 610 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 611 | 612 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) 613 | if self.window_size is not None: 614 | assert t_s == t_t, "Relative attention is only available for self-attention." 615 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 616 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) 617 | rel_logits = self._relative_position_to_absolute_position(rel_logits) 618 | scores_local = rel_logits / math.sqrt(self.k_channels) 619 | scores = scores + scores_local 620 | if self.proximal_bias: 621 | assert t_s == t_t, "Proximal bias is only available for self-attention." 622 | scores = scores + self._attention_bias_proximal(t_s).to( 623 | device=scores.device, dtype=scores.dtype 624 | ) 625 | if mask is not None: 626 | scores = scores.masked_fill(mask == 0, -1e4) 627 | if self.block_length is not None: 628 | block_mask = ( 629 | torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 630 | ) 631 | scores = scores * block_mask + -1e4 * (1 - block_mask) 632 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 633 | p_attn = self.drop(p_attn) 634 | output = torch.matmul(p_attn, value) 635 | if self.window_size is not None: 636 | relative_weights = self._absolute_position_to_relative_position(p_attn) 637 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 638 | output = output + self._matmul_with_relative_values( 639 | relative_weights, value_relative_embeddings 640 | ) 641 | output = ( 642 | output.transpose(2, 3).contiguous().view(b, d, t_t) 643 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 644 | return output, p_attn 645 | 646 | def _matmul_with_relative_values(self, x, y): 647 | """ 648 | x: [b, h, l, m] 649 | y: [h or 1, m, d] 650 | ret: [b, h, l, d] 651 | """ 652 | ret = torch.matmul(x, y.unsqueeze(0)) 653 | return ret 654 | 655 | def _matmul_with_relative_keys(self, x, y): 656 | """ 657 | x: [b, h, l, d] 658 | y: [h or 1, m, d] 659 | ret: [b, h, l, m] 660 | """ 661 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 662 | return ret 663 | 664 | def _get_relative_embeddings(self, relative_embeddings, length): 665 | # Pad first before slice to avoid using cond ops. 
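        # For a window of size W, `relative_embeddings` holds 2*W + 1 rows, one per
        # relative offset in [-W, W]. A query of length L needs 2*L - 1 rows
        # (offsets -(L-1) .. L-1): when L > W + 1 the table is padded symmetrically,
        # otherwise a centered slice of 2*L - 1 rows is taken.
        # Illustrative example: W = 4 gives 9 rows; for L = 3 we need 5 rows,
        # so pad_length = 0 and the slice covers rows [2:7].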
666 | pad_length = max(length - (self.window_size + 1), 0) 667 | slice_start_position = max((self.window_size + 1) - length, 0) 668 | slice_end_position = slice_start_position + 2 * length - 1 669 | if pad_length > 0: 670 | padded_relative_embeddings = F.pad( 671 | relative_embeddings, 672 | convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 673 | ) 674 | else: 675 | padded_relative_embeddings = relative_embeddings 676 | used_relative_embeddings = padded_relative_embeddings[ 677 | :, slice_start_position:slice_end_position 678 | ] 679 | return used_relative_embeddings 680 | 681 | def _relative_position_to_absolute_position(self, x): 682 | """ 683 | x: [b, h, l, 2*l-1] 684 | ret: [b, h, l, l] 685 | """ 686 | batch, heads, length, _ = x.size() 687 | # Concat columns of pad to shift from relative to absolute indexing. 688 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 689 | 690 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 691 | x_flat = x.view([batch, heads, length * 2 * length]) 692 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) 693 | 694 | # Reshape and slice out the padded elements. 695 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 696 | :, :, :length, length - 1 : 697 | ] 698 | return x_final 699 | 700 | def _absolute_position_to_relative_position(self, x): 701 | """ 702 | x: [b, h, l, l] 703 | ret: [b, h, l, 2*l-1] 704 | """ 705 | batch, heads, length, _ = x.size() 706 | # padd along column 707 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) 708 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) 709 | # add 0's in the beginning that will skew the elements after reshape 710 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 711 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 712 | return x_final 713 | 714 | def _attention_bias_proximal(self, length): 715 | """Bias for self-attention to encourage attention to close positions. 716 | Args: 717 | length: an integer scalar. 
718 | Returns: 719 | a Tensor with shape [1, 1, length, length] 720 | """ 721 | r = torch.arange(length, dtype=torch.float32) 722 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 723 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 724 | 725 | 726 | class FeedForwardNetwork(nn.Module): 727 | def __init__( 728 | self, 729 | in_channels, 730 | out_channels, 731 | filter_channels, 732 | kernel_size, 733 | p_dropout=0.0, 734 | activation=None, 735 | ): 736 | super().__init__() 737 | self.in_channels = in_channels 738 | self.out_channels = out_channels 739 | self.filter_channels = filter_channels 740 | self.kernel_size = kernel_size 741 | self.p_dropout = p_dropout 742 | self.activation = activation 743 | 744 | self.conv_1 = nn.Conv1d( 745 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2 746 | ) 747 | self.conv_2 = nn.Conv1d( 748 | filter_channels, out_channels, kernel_size, padding=kernel_size // 2 749 | ) 750 | self.drop = nn.Dropout(p_dropout) 751 | 752 | def forward(self, x, x_mask): 753 | x = self.conv_1(x * x_mask) 754 | if self.activation == "gelu": 755 | x = x * torch.sigmoid(1.702 * x) 756 | else: 757 | x = torch.relu(x) 758 | x = self.drop(x) 759 | x = self.conv_2(x * x_mask) 760 | return x * x_mask 761 | 762 | 763 | class DurationPredictor(nn.Module): 764 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): 765 | """ 766 | Token duration predictor for the GlowTTS model. 767 | Takes in embeddings of the input tokens and predicts how many frames of 768 | mel-spectrogram are aligned to each text token. 769 | Architecture is the same as the duration predictor in FastSpeech. 770 | Args: 771 | in_channels: Number of channels for the token embeddings 772 | filter_channels: Number of channels in the intermediate layers 773 | kernel_size: Kernels size for the convolution layers 774 | p_dropout: Dropout probability 775 | """ 776 | 777 | super().__init__() 778 | 779 | self.drop = nn.Dropout(p_dropout) 780 | self.conv_1 = nn.Conv1d( 781 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2 782 | ) 783 | self.norm_1 = LayerNorm(filter_channels) 784 | self.conv_2 = nn.Conv1d( 785 | filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 786 | ) 787 | self.norm_2 = LayerNorm(filter_channels) 788 | self.proj = nn.Conv1d(filter_channels, 1, 1) 789 | 790 | def forward(self, spect, mask, **kwargs): 791 | x = torch.detach(spect) 792 | x = self.conv_1(spect * mask) 793 | x = torch.relu(x) 794 | x = self.norm_1(x) 795 | x = self.drop(x) 796 | x = self.conv_2(x * mask) 797 | x = torch.relu(x) 798 | x = self.norm_2(x) 799 | x = self.drop(x) 800 | x = self.proj(x * mask) 801 | durs = x * mask 802 | return durs.squeeze(1) 803 | 804 | 805 | class StochasticDurationPredictor(nn.Module): 806 | """Borrowed from VITS""" 807 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 808 | super().__init__() 809 | 810 | filter_channels = in_channels # it needs to be removed from future version. 
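        # The layers below follow VITS's stochastic duration predictor: `flows`
        # model log-durations conditioned on the (detached) text encoding, while
        # the `post_*` layers form a posterior network used only during training
        # to build a variational bound. In forward (training) mode the module
        # returns that bound (nll + logq, normalized by the mask sum); in reverse
        # (inference) mode it samples noise scaled by `noise_scale` and runs the
        # flows backwards to produce log-durations `logw`.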
811 | self.in_channels = in_channels 812 | self.filter_channels = filter_channels 813 | self.kernel_size = kernel_size 814 | self.p_dropout = p_dropout 815 | self.n_flows = n_flows 816 | self.gin_channels = gin_channels 817 | 818 | self.log_flow = stocpred_modules.Log() 819 | self.flows = nn.ModuleList() 820 | self.flows.append(stocpred_modules.ElementwiseAffine(2)) 821 | for i in range(n_flows): 822 | self.flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 823 | self.flows.append(stocpred_modules.Flip()) 824 | 825 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 826 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 827 | self.post_convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 828 | self.post_flows = nn.ModuleList() 829 | self.post_flows.append(stocpred_modules.ElementwiseAffine(2)) 830 | for i in range(4): 831 | self.post_flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 832 | self.post_flows.append(stocpred_modules.Flip()) 833 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 834 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 835 | self.convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 836 | if gin_channels != 0: 837 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 838 | 839 | def forward(self, spect, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 840 | x = torch.detach(spect) 841 | x = self.pre(x) 842 | if g is not None: 843 | g = torch.detach(g) 844 | x = x + self.cond(g) 845 | x = self.convs(x, x_mask) 846 | x = self.proj(x) * x_mask 847 | if not reverse: 848 | flows = self.flows 849 | assert w is not None 850 | 851 | logdet_tot_q = 0 852 | h_w = self.post_pre(w) 853 | h_w = self.post_convs(h_w, x_mask) 854 | h_w = self.post_proj(h_w) * x_mask 855 | 856 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 857 | z_q = e_q 858 | for flow in self.post_flows: 859 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 860 | logdet_tot_q += logdet_q 861 | z_u, z1 = torch.split(z_q, [1, 1], 1) 862 | u = torch.sigmoid(z_u) * x_mask 863 | z0 = (w - u) * x_mask 864 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 865 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 866 | 867 | logdet_tot = 0 868 | z0, logdet = self.log_flow(z0, x_mask) 869 | logdet_tot += logdet 870 | z = torch.cat([z0, z1], 1) 871 | for flow in flows: 872 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 873 | logdet_tot = logdet_tot + logdet 874 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 875 | stoch_dur_loss = nll + logq 876 | stoch_dur_loss = stoch_dur_loss / torch.sum(x_mask) 877 | stoch_dur_loss = torch.sum(stoch_dur_loss) 878 | 879 | return stoch_dur_loss, None # [b] 880 | 881 | else: 882 | flows = list(reversed(self.flows)) 883 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 884 | # generator=torch.Generator().manual_seed(42) 885 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 886 | for flow in flows: 887 | z = flow(z, x_mask, g=x, reverse=reverse) 888 | z0, z1 = torch.split(z, [1, 1], 1) 889 | logw = z0 890 | return None, logw 891 | 892 | 893 | class StochasticPitchPredictor(nn.Module): 894 | """Borrowed from VITS""" 895 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, 
n_flows=4, gin_channels=0): 896 | super().__init__() 897 | 898 | filter_channels = in_channels # it needs to be removed from future version. 899 | self.in_channels = in_channels 900 | self.filter_channels = filter_channels 901 | self.kernel_size = kernel_size 902 | self.p_dropout = p_dropout 903 | self.n_flows = n_flows 904 | self.gin_channels = gin_channels 905 | 906 | self.log_flow = stocpred_modules.Log() 907 | self.flows = nn.ModuleList() 908 | self.flows.append(stocpred_modules.ElementwiseAffine(2)) 909 | for i in range(n_flows): 910 | self.flows.append(stocpred_modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 911 | self.flows.append(stocpred_modules.Flip()) 912 | 913 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 914 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 915 | self.convs = stocpred_modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 916 | if gin_channels != 0: 917 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 918 | 919 | def forward(self, spect, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 920 | x = torch.detach(spect) 921 | x = self.pre(x) 922 | if g is not None: 923 | g = torch.detach(g) 924 | a = self.cond(g) 925 | x = x + self.cond(g) 926 | x = self.convs(x, x_mask) 927 | x = self.proj(x) * x_mask 928 | if not reverse: 929 | flows = self.flows 930 | assert w is not None 931 | 932 | e_q = torch.randn(w.size()).to(device=x.device, dtype=x.dtype) * x_mask 933 | 934 | logdet_tot = 0 935 | z = torch.cat([w, e_q], 1) 936 | for flow in flows: 937 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 938 | logdet_tot = logdet_tot + logdet 939 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 940 | 941 | stoch_pitch_loss = nll / torch.sum(x_mask) 942 | stoch_pitch_loss = torch.sum(stoch_pitch_loss) 943 | return stoch_pitch_loss, None # [b] 944 | 945 | else: 946 | flows = list(reversed(self.flows)) 947 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 948 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 949 | for flow in flows: 950 | z = flow(z, x_mask, g=x, reverse=reverse) 951 | z0, z1 = torch.split(z, [1, 1], 1) 952 | w = z0 953 | return None, w 954 | -------------------------------------------------------------------------------- /modules/glow_tts_modules/stocpred_modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from modules.glow_tts_modules.transforms import piecewise_rational_quadratic_transform 7 | 8 | 9 | class LayerNorm(nn.Module): 10 | def __init__(self, channels, eps=1e-4): 11 | super().__init__() 12 | self.channels = channels 13 | self.eps = eps 14 | 15 | self.gamma = nn.Parameter(torch.ones(channels)) 16 | self.beta = nn.Parameter(torch.zeros(channels)) 17 | 18 | def forward(self, x): 19 | n_dims = len(x.shape) 20 | mean = torch.mean(x, 1, keepdim=True) 21 | variance = torch.mean((x - mean) ** 2, 1, keepdim=True) 22 | 23 | x = (x - mean) * torch.rsqrt(variance + self.eps) 24 | 25 | shape = [1, -1] + [1] * (n_dims - 2) 26 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 27 | return x 28 | 29 | 30 | class DDSConv(nn.Module): 31 | """ 32 | Dialted and Depth-Separable Convolution 33 | """ 34 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 35 | super().__init__() 36 | self.channels = channels 37 | self.kernel_size = 
kernel_size 38 | self.n_layers = n_layers 39 | self.p_dropout = p_dropout 40 | 41 | self.drop = nn.Dropout(p_dropout) 42 | self.convs_sep = nn.ModuleList() 43 | self.convs_1x1 = nn.ModuleList() 44 | self.norms_1 = nn.ModuleList() 45 | self.norms_2 = nn.ModuleList() 46 | for i in range(n_layers): 47 | dilation = kernel_size ** i 48 | padding = (kernel_size * dilation - dilation) // 2 49 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 50 | groups=channels, dilation=dilation, padding=padding 51 | )) 52 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 53 | self.norms_1.append(LayerNorm(channels)) 54 | self.norms_2.append(LayerNorm(channels)) 55 | 56 | def forward(self, x, x_mask, g=None): 57 | if g is not None: 58 | x = x + g 59 | for i in range(self.n_layers): 60 | y = self.convs_sep[i](x * x_mask) 61 | y = self.norms_1[i](y) 62 | y = F.gelu(y) 63 | y = self.convs_1x1[i](y) 64 | y = self.norms_2[i](y) 65 | y = F.gelu(y) 66 | y = self.drop(y) 67 | x = x + y 68 | return x * x_mask 69 | 70 | 71 | class Log(nn.Module): 72 | def forward(self, x, x_mask, reverse=False, **kwargs): 73 | if not reverse: 74 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 75 | logdet = torch.sum(-y, [1, 2]) 76 | return y, logdet 77 | else: 78 | x = torch.exp(x) * x_mask 79 | return x 80 | 81 | 82 | class Flip(nn.Module): 83 | def forward(self, x, *args, reverse=False, **kwargs): 84 | x = torch.flip(x, [1]) 85 | if not reverse: 86 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 87 | return x, logdet 88 | else: 89 | return x 90 | 91 | 92 | class ElementwiseAffine(nn.Module): 93 | def __init__(self, channels): 94 | super().__init__() 95 | self.channels = channels 96 | self.m = nn.Parameter(torch.zeros(channels,1)) 97 | self.logs = nn.Parameter(torch.zeros(channels,1)) 98 | 99 | def forward(self, x, x_mask, reverse=False, **kwargs): 100 | if not reverse: 101 | y = self.m + torch.exp(self.logs) * x 102 | y = y * x_mask 103 | logdet = torch.sum(self.logs * x_mask, [1,2]) 104 | return y, logdet 105 | else: 106 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 107 | return x 108 | 109 | 110 | class ConvFlow(nn.Module): 111 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 112 | super().__init__() 113 | self.in_channels = in_channels 114 | self.filter_channels = filter_channels 115 | self.kernel_size = kernel_size 116 | self.n_layers = n_layers 117 | self.num_bins = num_bins 118 | self.tail_bound = tail_bound 119 | self.half_channels = in_channels // 2 120 | 121 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 122 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 123 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 124 | self.proj.weight.data.zero_() 125 | self.proj.bias.data.zero_() 126 | 127 | def forward(self, x, x_mask, g=None, reverse=False): 128 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 129 | h = self.pre(x0) 130 | h = self.convs(h, x_mask, g=g) 131 | h = self.proj(h) * x_mask 132 | 133 | b, c, t = x0.shape 134 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
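        # `proj` emits half_channels * (3 * num_bins - 1) values per frame; the
        # reshape above moves them to the last axis so they can be split into
        # num_bins unnormalized widths, num_bins unnormalized heights and
        # num_bins - 1 unnormalized derivatives of the piecewise
        # rational-quadratic spline applied to x1 below. The division by
        # sqrt(filter_channels) simply rescales the width/height logits.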
135 | 136 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 137 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) 138 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 139 | 140 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 141 | unnormalized_widths, 142 | unnormalized_heights, 143 | unnormalized_derivatives, 144 | inverse=reverse, 145 | tails='linear', 146 | tail_bound=self.tail_bound 147 | ) 148 | 149 | x = torch.cat([x0, x1], 1) * x_mask 150 | logdet = torch.sum(logabsdet * x_mask, [1,2]) 151 | if not reverse: 152 | return x, logdet 153 | else: 154 | return x -------------------------------------------------------------------------------- /modules/glow_tts_modules/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 | inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | 
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = 
input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet -------------------------------------------------------------------------------- /notebook/inference_glowtts_stdp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TTS Inference" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "vscode": { 15 | "languageId": "python" 16 | } 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "\n", 21 | "import torch\n", 22 | "\n", 23 | "from models.glow_tts_with_pitch import GlowTTSModel\n", 24 | "from utils.data import load_speaker_emb\n", 25 | "\n", 26 | "from nemo.collections.tts.models import HifiGanModel" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "vscode": { 34 | "languageId": "python" 35 | } 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "def infer(\n", 40 | " spec_gen_model,\n", 41 | " vocoder_model,\n", 42 | " str_input,\n", 43 | " noise_scale=0.0,\n", 44 | " length_scale=1.0,\n", 45 | " speaker=None,\n", 46 | " speaker_embeddings=None,\n", 47 | " stoch_dur_noise_scale=0.8,\n", 48 | " stoch_pitch_noise_scale=1.0,\n", 49 | " pitch_scale=0.0,\n", 50 | "):\n", 51 | "\n", 52 | " with torch.no_grad():\n", 53 | " parsed = spec_gen_model.parse(str_input)\n", 54 | "\n", 55 | " spectrogram = spec_gen_model.generate_spectrogram(\n", 56 | " tokens=parsed,\n", 57 | " noise_scale=noise_scale,\n", 58 | " length_scale=length_scale,\n", 59 | " speaker=speaker,\n", 60 | " speaker_embeddings=speaker_embeddings,\n", 61 | " stoch_dur_noise_scale=stoch_dur_noise_scale,\n", 62 | " stoch_pitch_noise_scale=stoch_pitch_noise_scale,\n", 63 | " pitch_scale=pitch_scale,\n", 64 | " )\n", 65 | "\n", 66 | " audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)\n", 67 | "\n", 68 | " if spectrogram is not None:\n", 69 | " if isinstance(spectrogram, torch.Tensor):\n", 70 | " spectrogram = spectrogram.to(\"cpu\").numpy()\n", 71 | " if len(spectrogram.shape) == 3:\n", 72 | " spectrogram = spectrogram[0]\n", 73 | " if isinstance(audio, torch.Tensor):\n", 74 | " audio = audio.to(\"cpu\").numpy()\n", 75 | " return spectrogram, audio" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "vscode": { 83 | "languageId": "python" 84 | } 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 89 | "print(device)" 90 | ] 91 | 
}, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "vscode": { 97 | "languageId": "python" 98 | } 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "# load glowtts model from checkpoint\n", 103 | "spec_gen = GlowTTSModel.load_from_checkpoint(checkpoint_path=checkpoint_path)\n", 104 | "spec_gen = spec_gen.eval().to(device)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": { 111 | "vscode": { 112 | "languageId": "python" 113 | } 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "# load vocoder from checkpoint\n", 118 | "vocoder = HifiGanModel.load_from_checkpoint(checkpoint).eval().to(device)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": { 125 | "vscode": { 126 | "languageId": "python" 127 | } 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "# Load speaker embeddings for conditioning\n", 132 | "speaker_emb_dict = load_speaker_emb(spk_emb_path)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": { 138 | "tags": [] 139 | }, 140 | "source": [ 141 | "## Inference" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "Now that everything is set up, let's give an input that we want our models to speak" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "vscode": { 156 | "languageId": "python" 157 | } 158 | }, 159 | "outputs": [], 160 | "source": [ 161 | "# Extract speaker embedding from file\n", 162 | "\n", 163 | "audio_path = \"common_voice_en_18498899.wav\"\n", 164 | "audio_path_wo = audio_path.split(\".\")[0]\n", 165 | "\n", 166 | "speaker_embeddings = speaker_emb_dict.get(audio_path_wo)\n", 167 | "speaker_embeddings = torch.from_numpy(speaker_embeddings).reshape(1, -1).to(device)\n", 168 | "\n", 169 | "if speaker_embeddings is None:\n", 170 | " print(\"Could not load speaker embedding\")" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "## Inference" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "vscode": { 185 | "languageId": "python" 186 | } 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "# Inference hyperparameters\n", 191 | "\n", 192 | "sr=16000\n", 193 | "noise_scale=0.667\n", 194 | "length_scale=1.0 #\n", 195 | "stoch_dur_noise_scale=0.8 #0.0-1.0\n", 196 | "stoch_pitch_noise_scale=0.8\n", 197 | "pitch_scale=0.0\n", 198 | "speaker=None" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": { 205 | "vscode": { 206 | "languageId": "python" 207 | } 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "from nemo_text_processing.text_normalization.normalize import Normalizer\n", 212 | "\n", 213 | "# initialize normalizer\n", 214 | "normalizer = Normalizer(input_case=\"cased\", lang=\"en\")" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": { 221 | "vscode": { 222 | "languageId": "python" 223 | } 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "text_to_generate = \"A look of fear crossed his face, but he regained his serenity immediately.\"\n", 228 | "\n", 229 | "# normalize text. 
necessary in case text contains numeric text, dates, and abbreviations\n", 230 | "text_to_generate = normalizer.normalize(text_to_generate)\n", 231 | "print(text_to_generate)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": { 238 | "vscode": { 239 | "languageId": "python" 240 | } 241 | }, 242 | "outputs": [], 243 | "source": [ 244 | "\n", 245 | "log_spec, audio = infer(spec_gen, vocoder, text_to_generate, \n", 246 | " noise_scale=noise_scale,\n", 247 | " length_scale=length_scale,\n", 248 | " speaker=speaker,\n", 249 | " stoch_dur_noise_scale=stoch_dur_noise_scale,\n", 250 | " stoch_pitch_noise_scale=stoch_pitch_noise_scale,\n", 251 | " pitch_scale=pitch_scale,\n", 252 | " speaker_embeddings=speaker_embeddings,)\n" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": { 259 | "vscode": { 260 | "languageId": "python" 261 | } 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "ipd.Audio(audio, rate=sr)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": { 272 | "vscode": { 273 | "languageId": "python" 274 | } 275 | }, 276 | "outputs": [], 277 | "source": [] 278 | } 279 | ], 280 | "metadata": { 281 | "kernelspec": { 282 | "display_name": "Python 3 (ipykernel)", 283 | "language": "python", 284 | "name": "python3" 285 | }, 286 | "vscode": { 287 | "interpreter": { 288 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 289 | } 290 | } 291 | }, 292 | "nbformat": 4, 293 | "nbformat_minor": 4 294 | } 295 | -------------------------------------------------------------------------------- /train_glowtts_baseline.sh: -------------------------------------------------------------------------------- 1 | 2 | python glow_tts_with_pitch.py \ 3 | --config-path="conf" \ 4 | --config-name="glow_tts_baseline" \ 5 | train_dataset="dataset/train_tts_common_voice.json" \ 6 | validation_datasets="dataset/dev_tts_common_voice.json" \ 7 | speaker_emb_path="../tts_experiments/embeddings/spk_emb_exp_1plus.pkl" \ 8 | sup_data_path="../cv-corpus-7.0-2021-07-21/en/cv_mos4_all_sup_data_folder" -------------------------------------------------------------------------------- /train_glowtts_std.sh: -------------------------------------------------------------------------------- 1 | 2 | python glow_tts_with_pitch.py \ 3 | --config-path="conf" \ 4 | --config-name="glow_tts_std" \ 5 | train_dataset="dataset/train_tts_common_voice.json" \ 6 | validation_datasets="dataset/dev_tts_common_voice.json" \ 7 | speaker_emb_path="../tts_experiments/embeddings/spk_emb_exp_1plus.pkl" \ 8 | sup_data_path="../cv-corpus-7.0-2021-07-21/en/cv_mos4_all_sup_data_folder" -------------------------------------------------------------------------------- /train_glowtts_stdp.sh: -------------------------------------------------------------------------------- 1 | 2 | python glow_tts_with_pitch.py \ 3 | --config-path="conf" \ 4 | --config-name="glow_tts_stdp" \ 5 | train_dataset="dataset/train_tts_common_voice.json" \ 6 | validation_datasets="dataset/dev_tts_common_voice.json" \ 7 | speaker_emb_path="../tts_experiments/embeddings/spk_emb_exp_1plus.pkl" \ 8 | sup_data_path="../cv-corpus-7.0-2021-07-21/en/cv_mos4_all_sup_data_folder" 9 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ogunlao/glowtts_stdp/07f71bbfce405018f43db7c086f46c6b91defa28/utils/__init__.py -------------------------------------------------------------------------------- /utils/glow_tts_loss.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import json 3 | 4 | from nemo.utils import logging 5 | 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from nemo.core.classes import Loss, typecheck 12 | from nemo.core.neural_types.elements import * 13 | from nemo.core.neural_types.neural_type import NeuralType 14 | 15 | 16 | class GlowTTSLoss(Loss): 17 | """ 18 | Loss for the GlowTTS model 19 | """ 20 | 21 | @property 22 | def input_types(self): 23 | return { 24 | "z": NeuralType(('B', 'D', 'T'), NormalDistributionSamplesType()), 25 | "y_m": NeuralType(('B', 'D', 'T'), NormalDistributionMeanType()), 26 | "y_logs": NeuralType(('B', 'D', 'T'), NormalDistributionLogVarianceType()), 27 | "logdet": NeuralType(('B',), LogDeterminantType()), 28 | "logw": NeuralType(('B', 'T'), TokenLogDurationType()), 29 | "logw_": NeuralType(('B', 'T'), TokenLogDurationType()), 30 | "x_lengths": NeuralType(('B',), LengthsType()), 31 | "y_lengths": NeuralType(('B',), LengthsType()), 32 | "stoch_dur_loss": NeuralType(optional=True), 33 | } 34 | 35 | @property 36 | def output_types(self): 37 | return { 38 | "l_mle": NeuralType(elements_type=LossType()), 39 | "l_length": NeuralType(elements_type=LossType()), 40 | "logdet": NeuralType(elements_type=VoidType()), 41 | } 42 | 43 | @typecheck() 44 | def forward(self, z, y_m, y_logs, logdet, logw, logw_, x_lengths, y_lengths, stoch_dur_loss,): 45 | 46 | logdet = torch.sum(logdet) 47 | l_mle = 0.5 * math.log(2 * math.pi) + ( 48 | torch.sum(y_logs) + 0.5 * torch.sum(torch.exp(-2 * y_logs) * (z - y_m) ** 2) - logdet 49 | ) / (torch.sum(y_lengths) * z.shape[1]) 50 | 51 | if stoch_dur_loss is None: 52 | l_length = torch.sum((logw - logw_) ** 2) / torch.sum(x_lengths) 53 | else: 54 | l_length = stoch_dur_loss 55 | return l_mle, l_length, logdet -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # BSD 3-Clause License 16 | # 17 | # Copyright (c) 2021, NVIDIA Corporation 18 | # All rights reserved. 19 | # 20 | # Redistribution and use in source and binary forms, with or without 21 | # modification, are permitted provided that the following conditions are met: 22 | # 23 | # * Redistributions of source code must retain the above copyright notice, this 24 | # list of conditions and the following disclaimer. 
25 | # 26 | # * Redistributions in binary form must reproduce the above copyright notice, 27 | # this list of conditions and the following disclaimer in the documentation 28 | # and/or other materials provided with the distribution. 29 | # 30 | # * Neither the name of the copyright holder nor the names of its 31 | # contributors may be used to endorse or promote products derived from 32 | # this software without specific prior written permission. 33 | # 34 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 35 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 36 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 37 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 38 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 39 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 40 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 41 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 42 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 43 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 44 | 45 | from enum import Enum 46 | from typing import Dict, Optional, Sequence 47 | 48 | import pickle as pkl 49 | import json 50 | 51 | import librosa 52 | import matplotlib.pylab as plt 53 | import numpy as np 54 | import torch 55 | from numba import jit, prange 56 | from numpy import ndarray 57 | from pesq import pesq 58 | from pystoi import stoi 59 | 60 | from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens 61 | from nemo.utils import logging 62 | 63 | HAVE_WANDB = True 64 | try: 65 | import wandb 66 | except ModuleNotFoundError: 67 | HAVE_WANDB = False 68 | 69 | try: 70 | from pytorch_lightning.utilities import rank_zero_only 71 | except ModuleNotFoundError: 72 | from functools import wraps 73 | 74 | def rank_zero_only(fn): 75 | @wraps(fn) 76 | def wrapped_fn(*args, **kwargs): 77 | logging.error( 78 | f"Function {fn} requires lighting to be installed, but it was not found. Please install lightning first" 79 | ) 80 | exit(1) 81 | 82 | 83 | class OperationMode(Enum): 84 | """Training or Inference (Evaluation) mode""" 85 | 86 | training = 0 87 | validation = 1 88 | infer = 2 89 | 90 | 91 | def get_batch_size(train_dataloader): 92 | if train_dataloader.batch_size is not None: 93 | return train_dataloader.batch_size 94 | elif train_dataloader.batch_sampler is not None: 95 | if train_dataloader.batch_sampler.micro_batch_size is not None: 96 | return train_dataloader.batch_sampler.micro_batch_size 97 | else: 98 | raise ValueError(f'Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}') 99 | else: 100 | raise ValueError(f'Could not find batch_size from train_dataloader: {train_dataloader}') 101 | 102 | 103 | def get_num_workers(trainer): 104 | return trainer.num_devices * trainer.num_nodes 105 | 106 | 107 | def binarize_attention(attn, in_len, out_len): 108 | """Convert soft attention matrix to hard attention matrix. 109 | 110 | Args: 111 | attn (torch.Tensor): B x 1 x max_mel_len x max_text_len. Soft attention matrix. 112 | in_len (torch.Tensor): B. Lengths of texts. 113 | out_len (torch.Tensor): B. Lengths of spectrograms. 114 | 115 | Output: 116 | attn_out (torch.Tensor): B x 1 x max_mel_len x max_text_len. 
Hard attention matrix, final dim max_text_len should sum to 1. 117 | """ 118 | b_size = attn.shape[0] 119 | with torch.no_grad(): 120 | attn_cpu = attn.data.cpu().numpy() 121 | attn_out = torch.zeros_like(attn) 122 | for ind in range(b_size): 123 | hard_attn = mas(attn_cpu[ind, 0, : out_len[ind], : in_len[ind]]) 124 | attn_out[ind, 0, : out_len[ind], : in_len[ind]] = torch.tensor(hard_attn, device=attn.device) 125 | return attn_out 126 | 127 | 128 | def binarize_attention_parallel(attn, in_lens, out_lens): 129 | """For training purposes only. Binarizes attention with MAS. 130 | These will no longer receive a gradient. 131 | 132 | Args: 133 | attn: B x 1 x max_mel_len x max_text_len 134 | """ 135 | with torch.no_grad(): 136 | attn_cpu = attn.data.cpu().numpy() 137 | attn_out = b_mas(attn_cpu, in_lens.cpu().numpy(), out_lens.cpu().numpy(), width=1) 138 | return torch.from_numpy(attn_out).to(attn.device) 139 | 140 | 141 | def get_mask_from_lengths(lengths, max_len: Optional[int] = None): 142 | if max_len is None: 143 | max_len = lengths.max() 144 | ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long) 145 | mask = (ids < lengths.unsqueeze(1)).bool() 146 | return mask 147 | 148 | 149 | @jit(nopython=True) 150 | def mas(attn_map, width=1): 151 | # assumes mel x text 152 | opt = np.zeros_like(attn_map) 153 | attn_map = np.log(attn_map) 154 | attn_map[0, 1:] = -np.inf 155 | log_p = np.zeros_like(attn_map) 156 | log_p[0, :] = attn_map[0, :] 157 | prev_ind = np.zeros_like(attn_map, dtype=np.int64) 158 | for i in range(1, attn_map.shape[0]): 159 | for j in range(attn_map.shape[1]): # for each text dim 160 | prev_j = np.arange(max(0, j - width), j + 1) 161 | prev_log = np.array([log_p[i - 1, prev_idx] for prev_idx in prev_j]) 162 | 163 | ind = np.argmax(prev_log) 164 | log_p[i, j] = attn_map[i, j] + prev_log[ind] 165 | prev_ind[i, j] = prev_j[ind] 166 | 167 | # now backtrack 168 | curr_text_idx = attn_map.shape[1] - 1 169 | for i in range(attn_map.shape[0] - 1, -1, -1): 170 | opt[i, curr_text_idx] = 1 171 | curr_text_idx = prev_ind[i, curr_text_idx] 172 | opt[0, curr_text_idx] = 1 173 | 174 | assert opt.sum(0).all() 175 | assert opt.sum(1).all() 176 | 177 | return opt 178 | 179 | 180 | @jit(nopython=True) 181 | def mas_width1(attn_map): 182 | """mas with hardcoded width=1""" 183 | # assumes mel x text 184 | opt = np.zeros_like(attn_map) 185 | attn_map = np.log(attn_map) 186 | attn_map[0, 1:] = -np.inf 187 | log_p = np.zeros_like(attn_map) 188 | log_p[0, :] = attn_map[0, :] 189 | prev_ind = np.zeros_like(attn_map, dtype=np.int64) 190 | for i in range(1, attn_map.shape[0]): 191 | for j in range(attn_map.shape[1]): # for each text dim 192 | prev_log = log_p[i - 1, j] 193 | prev_j = j 194 | 195 | if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]: 196 | prev_log = log_p[i - 1, j - 1] 197 | prev_j = j - 1 198 | 199 | log_p[i, j] = attn_map[i, j] + prev_log 200 | prev_ind[i, j] = prev_j 201 | 202 | # now backtrack 203 | curr_text_idx = attn_map.shape[1] - 1 204 | for i in range(attn_map.shape[0] - 1, -1, -1): 205 | opt[i, curr_text_idx] = 1 206 | curr_text_idx = prev_ind[i, curr_text_idx] 207 | opt[0, curr_text_idx] = 1 208 | return opt 209 | 210 | 211 | def b_mas(b_attn_map, in_lens, out_lens, width=1): 212 | assert width == 1 213 | attn_out = np.zeros_like(b_attn_map) 214 | 215 | for b in range(b_attn_map.shape[0]): 216 | out = mas_width1(b_attn_map[b, 0, : out_lens[b], : in_lens[b]]) 217 | attn_out[b, 0, : out_lens[b], : in_lens[b]] = out 218 | return attn_out 219 | 220 | # 
@jit(nopython=True, parallel=True) 221 | # def b_mas(b_attn_map, in_lens, out_lens, width=1): 222 | # assert width == 1 223 | # attn_out = np.zeros_like(b_attn_map) 224 | 225 | # for b in prange(b_attn_map.shape[0]): 226 | # out = mas_width1(b_attn_map[b, 0, : out_lens[b], : in_lens[b]]) 227 | # attn_out[b, 0, : out_lens[b], : in_lens[b]] = out 228 | # return attn_out 229 | 230 | 231 | def griffin_lim(magnitudes, n_iters=50, n_fft=1024): 232 | """ 233 | Griffin-Lim algorithm to convert magnitude spectrograms to audio signals 234 | """ 235 | phase = np.exp(2j * np.pi * np.random.rand(*magnitudes.shape)) 236 | complex_spec = magnitudes * phase 237 | signal = librosa.istft(complex_spec) 238 | if not np.isfinite(signal).all(): 239 | logging.warning("audio was not finite, skipping audio saving") 240 | return np.array([0]) 241 | 242 | for _ in range(n_iters): 243 | _, phase = librosa.magphase(librosa.stft(signal, n_fft=n_fft)) 244 | complex_spec = magnitudes * phase 245 | signal = librosa.istft(complex_spec) 246 | return signal 247 | 248 | 249 | @rank_zero_only 250 | def log_audio_to_tb( 251 | swriter, 252 | spect, 253 | name, 254 | step, 255 | griffin_lim_mag_scale=1024, 256 | griffin_lim_power=1.2, 257 | sr=22050, 258 | n_fft=1024, 259 | n_mels=80, 260 | fmax=8000, 261 | ): 262 | filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax) 263 | log_mel = spect.data.cpu().numpy().T 264 | mel = np.exp(log_mel) 265 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale 266 | audio = griffin_lim(magnitude.T ** griffin_lim_power) 267 | swriter.add_audio(name, audio / max(np.abs(audio)), step, sample_rate=sr) 268 | 269 | 270 | @rank_zero_only 271 | def tacotron2_log_to_tb_func( 272 | swriter, 273 | tensors, 274 | step, 275 | tag="train", 276 | log_images=False, 277 | log_images_freq=1, 278 | add_audio=True, 279 | griffin_lim_mag_scale=1024, 280 | griffin_lim_power=1.2, 281 | sr=22050, 282 | n_fft=1024, 283 | n_mels=80, 284 | fmax=8000, 285 | ): 286 | _, spec_target, mel_postnet, gate, gate_target, alignments = tensors 287 | if log_images and step % log_images_freq == 0: 288 | swriter.add_image( 289 | f"{tag}_alignment", plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), step, dataformats="HWC", 290 | ) 291 | swriter.add_image( 292 | f"{tag}_mel_target", plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), step, dataformats="HWC", 293 | ) 294 | swriter.add_image( 295 | f"{tag}_mel_predicted", 296 | plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()), 297 | step, 298 | dataformats="HWC", 299 | ) 300 | swriter.add_image( 301 | f"{tag}_gate", 302 | plot_gate_outputs_to_numpy(gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(),), 303 | step, 304 | dataformats="HWC", 305 | ) 306 | 307 | if add_audio: 308 | filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax) 309 | log_mel = mel_postnet[0].data.cpu().numpy().T 310 | mel = np.exp(log_mel) 311 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale 312 | audio = griffin_lim(magnitude.T ** griffin_lim_power) 313 | swriter.add_audio(f"audio/{tag}_predicted", audio / max(np.abs(audio)), step, sample_rate=sr) 314 | 315 | log_mel = spec_target[0].data.cpu().numpy().T 316 | mel = np.exp(log_mel) 317 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale 318 | audio = griffin_lim(magnitude.T ** griffin_lim_power) 319 | swriter.add_audio(f"audio/{tag}_target", audio / max(np.abs(audio)), step, sample_rate=sr) 320 | 321 | 322 | def 
tacotron2_log_to_wandb_func( 323 | swriter, 324 | tensors, 325 | step, 326 | tag="train", 327 | log_images=False, 328 | log_images_freq=1, 329 | add_audio=True, 330 | griffin_lim_mag_scale=1024, 331 | griffin_lim_power=1.2, 332 | sr=22050, 333 | n_fft=1024, 334 | n_mels=80, 335 | fmax=8000, 336 | ): 337 | _, spec_target, mel_postnet, gate, gate_target, alignments = tensors 338 | if not HAVE_WANDB: 339 | return 340 | if log_images and step % log_images_freq == 0: 341 | alignments = [] 342 | specs = [] 343 | gates = [] 344 | alignments += [ 345 | wandb.Image(plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), caption=f"{tag}_alignment",) 346 | ] 347 | alignments += [ 348 | wandb.Image(plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), caption=f"{tag}_mel_target",), 349 | wandb.Image(plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()), caption=f"{tag}_mel_predicted",), 350 | ] 351 | gates += [ 352 | wandb.Image( 353 | plot_gate_outputs_to_numpy( 354 | gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(), 355 | ), 356 | caption=f"{tag}_gate", 357 | ) 358 | ] 359 | 360 | swriter.log({"specs": specs, "alignments": alignments, "gates": gates}) 361 | 362 | if add_audio: 363 | audios = [] 364 | filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax) 365 | log_mel = mel_postnet[0].data.cpu().numpy().T 366 | mel = np.exp(log_mel) 367 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale 368 | audio_pred = griffin_lim(magnitude.T ** griffin_lim_power) 369 | 370 | log_mel = spec_target[0].data.cpu().numpy().T 371 | mel = np.exp(log_mel) 372 | magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale 373 | audio_true = griffin_lim(magnitude.T ** griffin_lim_power) 374 | 375 | audios += [ 376 | wandb.Audio(audio_true / max(np.abs(audio_true)), caption=f"{tag}_wav_target", sample_rate=sr,), 377 | wandb.Audio(audio_pred / max(np.abs(audio_pred)), caption=f"{tag}_wav_predicted", sample_rate=sr,), 378 | ] 379 | 380 | swriter.log({"audios": audios}) 381 | 382 | 383 | def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vmin=None, vmax=None): 384 | if phoneme_seq: 385 | fig, ax = plt.subplots(figsize=(15, 10)) 386 | else: 387 | fig, ax = plt.subplots(figsize=(6, 4)) 388 | im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax) 389 | ax.set_title(title) 390 | fig.colorbar(im, ax=ax) 391 | xlabel = 'Decoder timestep' 392 | if info is not None: 393 | xlabel += '\n\n' + info 394 | plt.xlabel(xlabel) 395 | plt.ylabel('Encoder timestep') 396 | plt.tight_layout() 397 | 398 | if phoneme_seq != None: 399 | # for debugging of phonemes and durs in maps. 
Not used by def in training code 400 | ax.set_yticks(np.arange(len(phoneme_seq))) 401 | ax.set_yticklabels(phoneme_seq) 402 | ax.hlines(np.arange(len(phoneme_seq)), xmin=0.0, xmax=max(ax.get_xticks())) 403 | 404 | fig.canvas.draw() 405 | data = save_figure_to_numpy(fig) 406 | plt.close() 407 | return data 408 | 409 | 410 | def plot_pitch_to_numpy(pitch, ylim_range=None): 411 | fig, ax = plt.subplots(figsize=(12, 3)) 412 | plt.plot(pitch) 413 | if ylim_range is not None: 414 | plt.ylim(ylim_range) 415 | plt.xlabel("Frames") 416 | plt.ylabel("Pitch") 417 | plt.tight_layout() 418 | 419 | fig.canvas.draw() 420 | data = save_figure_to_numpy(fig) 421 | plt.close() 422 | return data 423 | 424 | 425 | def plot_spectrogram_to_numpy(spectrogram): 426 | spectrogram = spectrogram.astype(np.float32) 427 | fig, ax = plt.subplots(figsize=(12, 3)) 428 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation='none') 429 | plt.colorbar(im, ax=ax) 430 | plt.xlabel("Frames") 431 | plt.ylabel("Channels") 432 | plt.tight_layout() 433 | 434 | fig.canvas.draw() 435 | data = save_figure_to_numpy(fig) 436 | plt.close() 437 | return data 438 | 439 | 440 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 441 | fig, ax = plt.subplots(figsize=(12, 3)) 442 | ax.scatter( 443 | range(len(gate_targets)), gate_targets, alpha=0.5, color='green', marker='+', s=1, label='target', 444 | ) 445 | ax.scatter( 446 | range(len(gate_outputs)), gate_outputs, alpha=0.5, color='red', marker='.', s=1, label='predicted', 447 | ) 448 | 449 | plt.xlabel("Frames (Green target, Red predicted)") 450 | plt.ylabel("Gate State") 451 | plt.tight_layout() 452 | 453 | fig.canvas.draw() 454 | data = save_figure_to_numpy(fig) 455 | plt.close() 456 | return data 457 | 458 | 459 | def save_figure_to_numpy(fig): 460 | # save it to a numpy array. 461 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 462 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 463 | return data 464 | 465 | 466 | @rank_zero_only 467 | def waveglow_log_to_tb_func( 468 | swriter, tensors, step, tag="train", n_fft=1024, hop_length=256, window="hann", mel_fb=None, 469 | ): 470 | _, audio_pred, spec_target, mel_length = tensors 471 | mel_length = mel_length[0] 472 | spec_target = spec_target[0].data.cpu().numpy()[:, :mel_length] 473 | swriter.add_image( 474 | f"{tag}_mel_target", plot_spectrogram_to_numpy(spec_target), step, dataformats="HWC", 475 | ) 476 | if mel_fb is not None: 477 | mag, _ = librosa.core.magphase( 478 | librosa.core.stft( 479 | np.nan_to_num(audio_pred[0].cpu().detach().numpy()), n_fft=n_fft, hop_length=hop_length, window=window, 480 | ) 481 | ) 482 | mel_pred = np.matmul(mel_fb.cpu().numpy(), mag).squeeze() 483 | log_mel_pred = np.log(np.clip(mel_pred, a_min=1e-5, a_max=None)) 484 | swriter.add_image( 485 | f"{tag}_mel_predicted", plot_spectrogram_to_numpy(log_mel_pred[:, :mel_length]), step, dataformats="HWC", 486 | ) 487 | 488 | 489 | def remove(conv_list): 490 | new_conv_list = torch.nn.ModuleList() 491 | for old_conv in conv_list: 492 | old_conv = torch.nn.utils.remove_weight_norm(old_conv) 493 | new_conv_list.append(old_conv) 494 | return new_conv_list 495 | 496 | 497 | def eval_tts_scores( 498 | y_clean: ndarray, y_est: ndarray, T_ys: Sequence[int] = (0,), sampling_rate=22050 499 | ) -> Dict[str, float]: 500 | """ 501 | calculate metric using EvalModule. y can be a batch. 
502 | Args: 503 | y_clean: real audio 504 | y_est: estimated audio 505 | T_ys: length of the non-zero parts of the histograms 506 | sampling_rate: The used Sampling rate. 507 | 508 | Returns: 509 | A dictionary mapping scoring systems (string) to numerical scores. 510 | 1st entry: 'STOI' 511 | 2nd entry: 'PESQ' 512 | """ 513 | 514 | if y_clean.ndim == 1: 515 | y_clean = y_clean[np.newaxis, ...] 516 | y_est = y_est[np.newaxis, ...] 517 | if T_ys == (0,): 518 | T_ys = (y_clean.shape[1],) * y_clean.shape[0] 519 | 520 | clean = y_clean[0, : T_ys[0]] 521 | estimated = y_est[0, : T_ys[0]] 522 | stoi_score = stoi(clean, estimated, sampling_rate, extended=False) 523 | pesq_score = pesq(16000, np.asarray(clean), estimated, 'wb') 524 | ## fs was set 16,000, as pesq lib doesnt currently support felxible fs. 525 | 526 | return {'STOI': stoi_score, 'PESQ': pesq_score} 527 | 528 | 529 | def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None): 530 | """A function that takes predicted durations per encoded token, and repeats enc_out according to the duration. 531 | NOTE: durations.shape[1] == enc_out.shape[1] 532 | 533 | Args: 534 | durations (torch.tensor): A tensor of shape (batch x enc_length) that represents how many times to repeat each 535 | token in enc_out. 536 | enc_out (torch.tensor): A tensor of shape (batch x enc_length x enc_hidden) that represents the encoded tokens. 537 | pace (float): The pace of speaker. Higher values result in faster speaking pace. Defaults to 1.0. 538 | max_mel_len (int): The maximum length above which the output will be removed. If sum(durations, dim=1) > 539 | max_mel_len, the values after max_mel_len will be removed. Defaults to None, which has no max length. 540 | """ 541 | 542 | dtype = enc_out.dtype 543 | reps = durations.float() / pace 544 | reps = (reps + 0.5).floor().long() 545 | dec_lens = reps.sum(dim=1) 546 | 547 | max_len = dec_lens.max() 548 | reps_cumsum = torch.cumsum(torch.nn.functional.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] 549 | reps_cumsum = reps_cumsum.to(dtype=dtype, device=enc_out.device) 550 | 551 | range_ = torch.arange(max_len).to(enc_out.device)[None, :, None] 552 | mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_) 553 | mult = mult.to(dtype) 554 | enc_rep = torch.matmul(mult, enc_out) 555 | 556 | if mel_max_len is not None: 557 | enc_rep = enc_rep[:, :mel_max_len] 558 | dec_lens = torch.clamp_max(dec_lens, mel_max_len) 559 | 560 | return enc_rep, dec_lens 561 | 562 | 563 | def split_view(tensor, split_size: int, dim: int = 0): 564 | if dim < 0: # Support negative indexing 565 | dim = len(tensor.shape) + dim 566 | # If not divisible by split_size, we need to pad with 0 567 | if tensor.shape[dim] % split_size != 0: 568 | to_pad = split_size - (tensor.shape[dim] % split_size) 569 | padding = [0] * len(tensor.shape) * 2 570 | padding[dim * 2 + 1] = to_pad 571 | padding.reverse() 572 | tensor = torch.nn.functional.pad(tensor, padding) 573 | cur_shape = tensor.shape 574 | new_shape = cur_shape[:dim] + (tensor.shape[dim] // split_size, split_size) + cur_shape[dim + 1 :] 575 | return tensor.reshape(*new_shape) 576 | 577 | 578 | def process_batch(batch_data, sup_data_types_set): 579 | batch_dict = {} 580 | batch_index = 0 581 | for name, datatype in DATA_STR2DATA_CLASS.items(): 582 | if datatype in MAIN_DATA_TYPES or datatype in sup_data_types_set: 583 | batch_dict[name] = batch_data[batch_index] 584 | batch_index = batch_index + 1 585 | if issubclass(datatype, WithLens): 586 | batch_dict[name + 
"_lens"] = batch_data[batch_index] 587 | batch_index = batch_index + 1 588 | return batch_dict 589 | 590 | -------------------------------------------------------------------------------- /utils/tts_data_types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class TTSDataType: 17 | """Represent TTSDataType.""" 18 | 19 | name = None 20 | 21 | 22 | class WithLens: 23 | """Represent that this data type also returns lengths for data.""" 24 | 25 | 26 | class Audio(TTSDataType, WithLens): 27 | name = "audio" 28 | 29 | 30 | class Text(TTSDataType, WithLens): 31 | name = "text" 32 | 33 | 34 | class LogMel(TTSDataType, WithLens): 35 | name = "log_mel" 36 | 37 | 38 | class Durations(TTSDataType): 39 | name = "durations" 40 | 41 | 42 | class AlignPriorMatrix(TTSDataType): 43 | name = "align_prior_matrix" 44 | 45 | 46 | class Pitch(TTSDataType, WithLens): 47 | name = "pitch" 48 | 49 | 50 | class Energy(TTSDataType, WithLens): 51 | name = "energy" 52 | 53 | 54 | class SpeakerID(TTSDataType): 55 | name = "speaker_id" 56 | 57 | 58 | class SpeakerEmb(TTSDataType): 59 | name = "speaker_emb" 60 | 61 | 62 | class Voiced_mask(TTSDataType): 63 | name = "voiced_mask" 64 | 65 | 66 | class P_voiced(TTSDataType): 67 | name = "p_voiced" 68 | 69 | 70 | class LMTokens(TTSDataType): 71 | name = "lm_tokens" 72 | 73 | 74 | MAIN_DATA_TYPES = [Audio, Text] 75 | VALID_SUPPLEMENTARY_DATA_TYPES = [ 76 | LogMel, 77 | Durations, 78 | AlignPriorMatrix, 79 | Pitch, 80 | Energy, 81 | SpeakerID, 82 | SpeakerEmb, 83 | LMTokens, 84 | Voiced_mask, 85 | P_voiced, 86 | ] 87 | DATA_STR2DATA_CLASS = {d.name: d for d in MAIN_DATA_TYPES + VALID_SUPPLEMENTARY_DATA_TYPES} 88 | -------------------------------------------------------------------------------- /utils/tts_tokenizers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
--------------------------------------------------------------------------------
/utils/tts_data_types.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | class TTSDataType:
17 |     """Represent TTSDataType."""
18 | 
19 |     name = None
20 | 
21 | 
22 | class WithLens:
23 |     """Represent that this data type also returns lengths for data."""
24 | 
25 | 
26 | class Audio(TTSDataType, WithLens):
27 |     name = "audio"
28 | 
29 | 
30 | class Text(TTSDataType, WithLens):
31 |     name = "text"
32 | 
33 | 
34 | class LogMel(TTSDataType, WithLens):
35 |     name = "log_mel"
36 | 
37 | 
38 | class Durations(TTSDataType):
39 |     name = "durations"
40 | 
41 | 
42 | class AlignPriorMatrix(TTSDataType):
43 |     name = "align_prior_matrix"
44 | 
45 | 
46 | class Pitch(TTSDataType, WithLens):
47 |     name = "pitch"
48 | 
49 | 
50 | class Energy(TTSDataType, WithLens):
51 |     name = "energy"
52 | 
53 | 
54 | class SpeakerID(TTSDataType):
55 |     name = "speaker_id"
56 | 
57 | 
58 | class SpeakerEmb(TTSDataType):
59 |     name = "speaker_emb"
60 | 
61 | 
62 | class Voiced_mask(TTSDataType):
63 |     name = "voiced_mask"
64 | 
65 | 
66 | class P_voiced(TTSDataType):
67 |     name = "p_voiced"
68 | 
69 | 
70 | class LMTokens(TTSDataType):
71 |     name = "lm_tokens"
72 | 
73 | 
74 | MAIN_DATA_TYPES = [Audio, Text]
75 | VALID_SUPPLEMENTARY_DATA_TYPES = [
76 |     LogMel,
77 |     Durations,
78 |     AlignPriorMatrix,
79 |     Pitch,
80 |     Energy,
81 |     SpeakerID,
82 |     SpeakerEmb,
83 |     LMTokens,
84 |     Voiced_mask,
85 |     P_voiced,
86 | ]
87 | DATA_STR2DATA_CLASS = {d.name: d for d in MAIN_DATA_TYPES + VALID_SUPPLEMENTARY_DATA_TYPES}
88 | 
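The `DATA_STR2DATA_CLASS` registry above is what lets a YAML config refer to supplementary data by name. A short sketch of resolving hypothetical config values (the module path is an assumption about how this repo is imported):

```python
from utils.tts_data_types import DATA_STR2DATA_CLASS, WithLens  # assumed import path

sup_data_types = ["log_mel", "align_prior_matrix", "pitch"]  # hypothetical config entries
classes = [DATA_STR2DATA_CLASS[name] for name in sup_data_types]
needs_lens = [c.name for c in classes if issubclass(c, WithLens)]
print(needs_lens)  # ['log_mel', 'pitch'] -- these contribute an extra "_lens" tensor
```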
113 | """ 114 | 115 | tokens = [] 116 | self.space, tokens = len(tokens), tokens + [' '] # Space 117 | tokens.extend(chars) 118 | if apostrophe: 119 | tokens.append("'") # Apostrophe for saving "don't" and "Joe's" 120 | 121 | if punct: 122 | if non_default_punct_list is not None: 123 | self.PUNCT_LIST = non_default_punct_list 124 | tokens.extend(self.PUNCT_LIST) 125 | 126 | add_blank_at = "last" if add_blank_to_text else add_blank_at 127 | self.add_blank_to_text = add_blank_to_text 128 | 129 | super().__init__(tokens, add_blank_at=add_blank_at) 130 | 131 | self.punct = punct 132 | self.pad_with_space = pad_with_space 133 | 134 | self.text_preprocessing_func = text_preprocessing_func 135 | 136 | def encode(self, text): 137 | """See base class.""" 138 | cs, space, tokens = [], self.tokens[self.space], set(self.tokens) 139 | 140 | text = self.text_preprocessing_func(text) 141 | for c in text: 142 | # Add space if last one isn't one 143 | if c == space and len(cs) > 0 and cs[-1] != space: 144 | cs.append(c) 145 | # Add next char 146 | elif (c.isalnum() or c == "'") and c in tokens: 147 | cs.append(c) 148 | # Add punct 149 | elif (c in self.PUNCT_LIST) and self.punct: 150 | cs.append(c) 151 | # Warn about unknown char 152 | elif c != space: 153 | logging.warning(f"Text: [{text}] contains unknown char: [{c}]. Symbol will be skipped.") 154 | 155 | # Remove trailing spaces 156 | if cs: 157 | while cs[-1] == space: 158 | cs.pop() 159 | 160 | if self.add_blank_to_text: 161 | blank = self._id2token[self.blank] 162 | cs = self.intersperse(cs, blank) 163 | elif self.pad_with_space: 164 | cs = [space] + cs + [space] 165 | 166 | return [self._token2id[p] for p in cs] 167 | 168 | 169 | class EnglishPhonemesTokenizer(BaseTokenizer): 170 | # fmt: off 171 | PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally 172 | ',', '.', '!', '?', '-', 173 | ':', ';', '/', '"', '(', 174 | ')', '[', ']', '{', '}', 175 | ) 176 | VOWELS = ( 177 | 'AA', 'AE', 'AH', 'AO', 'AW', 178 | 'AY', 'EH', 'ER', 'EY', 'IH', 179 | 'IY', 'OW', 'OY', 'UH', 'UW', 180 | ) 181 | CONSONANTS = ( 182 | 'B', 'CH', 'D', 'DH', 'F', 'G', 183 | 'HH', 'JH', 'K', 'L', 'M', 'N', 184 | 'NG', 'P', 'R', 'S', 'SH', 'T', 185 | 'TH', 'V', 'W', 'Y', 'Z', 'ZH', 186 | ) 187 | # fmt: on 188 | 189 | def __init__( 190 | self, 191 | g2p, 192 | punct=True, 193 | non_default_punct_list=None, 194 | stresses=False, 195 | chars=False, 196 | *, 197 | space=' ', 198 | silence=None, 199 | apostrophe=True, 200 | oov=BaseTokenizer.OOV, 201 | sep='|', # To be able to distinguish between 2/3 letters codes. 202 | add_blank_at=None, 203 | pad_with_space=False, 204 | text_preprocessing_func=lambda text: english_text_preprocessing(text, lower=False), 205 | add_blank_to_text=False, 206 | ): 207 | """English phoneme-based tokenizer. 208 | Args: 209 | g2p: Grapheme to phoneme module. 210 | punct: Whether to reserve grapheme for basic punctuation or not. 211 | non_default_punct_list: List of punctuation marks which will be used instead default. 212 | stresses: Whether to use phonemes codes with stresses (0-2) or not. 213 | chars: Whether to additionally use chars together with phonemes. It is useful if g2p module can return chars too. 214 | space: Space token as string. 215 | silence: Silence token as string (will be disabled if it is None). 216 | apostrophe: Whether to use apostrophe or not. 217 | oov: OOV token as string. 218 | sep: Separation token as string. 
83 | class BaseCharsTokenizer(BaseTokenizer):
84 |     # fmt: off
85 |     PUNCT_LIST = (  # Derived from LJSpeech and "/" additionally
86 |         ',', '.', '!', '?', '-',
87 |         ':', ';', '/', '"', '(',
88 |         ')', '[', ']', '{', '}',
89 |     )
90 |     # fmt: on
91 | 
92 |     def __init__(
93 |         self,
94 |         chars,
95 |         punct=True,
96 |         apostrophe=True,
97 |         add_blank_at=None,
98 |         pad_with_space=False,
99 |         non_default_punct_list=None,
100 |         text_preprocessing_func=lambda x: x,
101 |         add_blank_to_text=False,
102 |     ):
103 |         """Base class for char-based tokenizer.
104 |         Args:
105 |             chars: string that represents all possible characters.
106 |             punct: Whether to reserve grapheme for basic punctuation or not.
107 |             apostrophe: Whether to use apostrophe or not.
108 |             add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None),
109 |                 if None then no blank in labels.
110 |             pad_with_space: Whether to pad text with spaces at the beginning and at the end or not.
111 |             non_default_punct_list: List of punctuation marks which will be used instead of the default list.
112 |             text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer.
113 |         """
114 | 
115 |         tokens = []
116 |         self.space, tokens = len(tokens), tokens + [' ']  # Space
117 |         tokens.extend(chars)
118 |         if apostrophe:
119 |             tokens.append("'")  # Apostrophe for saving "don't" and "Joe's"
120 | 
121 |         if punct:
122 |             if non_default_punct_list is not None:
123 |                 self.PUNCT_LIST = non_default_punct_list
124 |             tokens.extend(self.PUNCT_LIST)
125 | 
126 |         add_blank_at = "last" if add_blank_to_text else add_blank_at
127 |         self.add_blank_to_text = add_blank_to_text
128 | 
129 |         super().__init__(tokens, add_blank_at=add_blank_at)
130 | 
131 |         self.punct = punct
132 |         self.pad_with_space = pad_with_space
133 | 
134 |         self.text_preprocessing_func = text_preprocessing_func
135 | 
136 |     def encode(self, text):
137 |         """See base class."""
138 |         cs, space, tokens = [], self.tokens[self.space], set(self.tokens)
139 | 
140 |         text = self.text_preprocessing_func(text)
141 |         for c in text:
142 |             # Add space if last one isn't one
143 |             if c == space and len(cs) > 0 and cs[-1] != space:
144 |                 cs.append(c)
145 |             # Add next char
146 |             elif (c.isalnum() or c == "'") and c in tokens:
147 |                 cs.append(c)
148 |             # Add punct
149 |             elif (c in self.PUNCT_LIST) and self.punct:
150 |                 cs.append(c)
151 |             # Warn about unknown char
152 |             elif c != space:
153 |                 logging.warning(f"Text: [{text}] contains unknown char: [{c}]. Symbol will be skipped.")
154 | 
155 |         # Remove trailing spaces
156 |         if cs:
157 |             while cs[-1] == space:
158 |                 cs.pop()
159 | 
160 |         if self.add_blank_to_text:
161 |             blank = self._id2token[self.blank]
162 |             cs = self.intersperse(cs, blank)
163 |         elif self.pad_with_space:
164 |             cs = [space] + cs + [space]
165 | 
166 |         return [self._token2id[p] for p in cs]
167 | 
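A minimal sketch of using the char tokenizer directly (assuming `utils.tts_tokenizers` and its NeMo text-processing dependency are importable; since only lowercase characters are registered here, the preprocessing function folds case):

```python
import string

from utils.tts_tokenizers import BaseCharsTokenizer  # assumed import path within this repo

tokenizer = BaseCharsTokenizer(
    chars=string.ascii_lowercase,
    punct=True,
    apostrophe=True,
    text_preprocessing_func=str.lower,  # fold case before the character lookup
)
token_ids = tokenizer("Hello, world!")  # __call__ dispatches to encode()
print(token_ids)                        # list of int ids
print(tokenizer.decode(token_ids))      # "hello, world!"
```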
168 | 
169 | class EnglishPhonemesTokenizer(BaseTokenizer):
170 |     # fmt: off
171 |     PUNCT_LIST = (  # Derived from LJSpeech and "/" additionally
172 |         ',', '.', '!', '?', '-',
173 |         ':', ';', '/', '"', '(',
174 |         ')', '[', ']', '{', '}',
175 |     )
176 |     VOWELS = (
177 |         'AA', 'AE', 'AH', 'AO', 'AW',
178 |         'AY', 'EH', 'ER', 'EY', 'IH',
179 |         'IY', 'OW', 'OY', 'UH', 'UW',
180 |     )
181 |     CONSONANTS = (
182 |         'B', 'CH', 'D', 'DH', 'F', 'G',
183 |         'HH', 'JH', 'K', 'L', 'M', 'N',
184 |         'NG', 'P', 'R', 'S', 'SH', 'T',
185 |         'TH', 'V', 'W', 'Y', 'Z', 'ZH',
186 |     )
187 |     # fmt: on
188 | 
189 |     def __init__(
190 |         self,
191 |         g2p,
192 |         punct=True,
193 |         non_default_punct_list=None,
194 |         stresses=False,
195 |         chars=False,
196 |         *,
197 |         space=' ',
198 |         silence=None,
199 |         apostrophe=True,
200 |         oov=BaseTokenizer.OOV,
201 |         sep='|',  # To be able to distinguish between 2/3 letter codes.
202 |         add_blank_at=None,
203 |         pad_with_space=False,
204 |         text_preprocessing_func=lambda text: english_text_preprocessing(text, lower=False),
205 |         add_blank_to_text=False,
206 |     ):
207 |         """English phoneme-based tokenizer.
208 |         Args:
209 |             g2p: Grapheme to phoneme module.
210 |             punct: Whether to reserve grapheme for basic punctuation or not.
211 |             non_default_punct_list: List of punctuation marks which will be used instead of the default list.
212 |             stresses: Whether to use phoneme codes with stresses (0-2) or not.
213 |             chars: Whether to additionally use chars together with phonemes. It is useful if the g2p module can return chars too.
214 |             space: Space token as string.
215 |             silence: Silence token as string (will be disabled if it is None).
216 |             apostrophe: Whether to use apostrophe or not.
217 |             oov: OOV token as string.
218 |             sep: Separation token as string.
219 |             add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None),
220 |                 if None then no blank in labels.
221 |             pad_with_space: Whether to pad text with spaces at the beginning and at the end or not.
222 |             text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer.
223 |                 Basically, it replaces all non-unicode characters with unicode ones.
224 |                 Note that the lower() function shouldn't be applied here, in case the text contains phonemes (it will be handled by g2p).
225 |         """
226 | 
227 |         self.phoneme_probability = None
228 |         if hasattr(g2p, "phoneme_probability"):
229 |             self.phoneme_probability = g2p.phoneme_probability
230 |         tokens = []
231 |         self.space, tokens = len(tokens), tokens + [space]  # Space
232 | 
233 |         if silence is not None:
234 |             self.silence, tokens = len(tokens), tokens + [silence]  # Silence
235 | 
236 |         tokens.extend(self.CONSONANTS)
237 |         vowels = list(self.VOWELS)
238 | 
239 |         if stresses:
240 |             vowels = [f'{p}{s}' for p, s in itertools.product(vowels, (0, 1, 2))]
241 |         tokens.extend(vowels)
242 | 
243 |         if chars or self.phoneme_probability is not None:
244 |             if not chars:
245 |                 logging.warning(
246 |                     "phoneme_probability was not None, characters will be enabled even though "
247 |                     "chars was set to False."
248 |                 )
249 |             tokens.extend(string.ascii_lowercase)
250 | 
251 |         if apostrophe:
252 |             tokens.append("'")  # Apostrophe
253 | 
254 |         if punct:
255 |             if non_default_punct_list is not None:
256 |                 self.PUNCT_LIST = non_default_punct_list
257 |             tokens.extend(self.PUNCT_LIST)
258 | 
259 |         add_blank_at = "last" if add_blank_to_text else add_blank_at
260 |         self.add_blank_to_text = add_blank_to_text
261 | 
262 |         super().__init__(tokens, oov=oov, sep=sep, add_blank_at=add_blank_at)
263 | 
264 |         self.chars = chars if self.phoneme_probability is None else True
265 |         self.punct = punct
266 |         self.stresses = stresses
267 |         self.pad_with_space = pad_with_space
268 | 
269 |         self.text_preprocessing_func = text_preprocessing_func
270 |         self.g2p = g2p
271 | 
272 |     def encode(self, text):
273 |         """See base class for more information."""
274 | 
275 |         text = self.text_preprocessing_func(text)
276 |         g2p_text = self.g2p(text)  # TODO: handle infer
277 |         return self.encode_from_g2p(g2p_text, text)
278 | 
279 |     def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None):
280 |         """
281 |         Encodes text that has already been run through G2P.
282 |         Called for encoding to tokens after text preprocessing and G2P.
283 | 
284 |         Args:
285 |             g2p_text: G2P's output, could be a mixture of phonemes and graphemes,
286 |                 e.g. "see OOV" -> ['S', 'IY1', ' ', 'O', 'O', 'V']
287 |             raw_text: original raw input
288 |         """
289 |         ps, space, tokens = [], self.tokens[self.space], set(self.tokens)
290 |         for p in g2p_text:  # noqa
291 |             # Remove stress
292 |             if p.isalnum() and len(p) == 3 and not self.stresses:
293 |                 p = p[:2]
294 | 
295 |             # Add space if last one isn't one
296 |             if p == space and len(ps) > 0 and ps[-1] != space:
297 |                 ps.append(p)
298 |             # Add next phoneme or char (if chars=True)
299 |             elif (p.isalnum() or p == "'") and p in tokens:
300 |                 ps.append(p)
301 |             # Add punct
302 |             elif (p in self.PUNCT_LIST) and self.punct:
303 |                 ps.append(p)
304 |             # Warn about unknown char/phoneme
305 |             elif p != space:
306 |                 pass
307 |                 # message = f"Text: [{''.join(g2p_text)}] contains unknown char/phoneme: [{p}]."
308 |                 # if raw_text is not None:
309 |                 #     message += f"Original text: [{raw_text}]. Symbol will be skipped."
310 |                 # logging.warning(message)
311 | 
312 |         # Remove trailing spaces
313 |         if ps:
314 |             while ps[-1] == space:
315 |                 ps.pop()
316 | 
317 |         if self.add_blank_to_text:
318 |             blank = self._id2token[self.blank]
319 |             ps = self.intersperse(ps, blank)
320 |         elif self.pad_with_space:
321 |             ps = [space] + ps + [space]
322 |         # if self.pad_with_space:
323 |         #     ps = [space] + ps + [space]
324 | 
325 |         return [self._token2id[p] for p in ps]
326 | 
327 |     @contextmanager
328 |     def set_phone_prob(self, prob):
329 |         if hasattr(self.g2p, "phoneme_probability"):
330 |             self.g2p.phoneme_probability = prob
331 |         try:
332 |             yield
333 |         finally:
334 |             if hasattr(self.g2p, "phoneme_probability"):
335 |                 self.g2p.phoneme_probability = self.phoneme_probability
336 | 
337 | 
--------------------------------------------------------------------------------
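A usage sketch for the phoneme tokenizer. The training configs in this repo pass a real NeMo English G2P; the toy G2P below is an illustration-only stand-in that satisfies the two things the tokenizer relies on: being callable on the preprocessed text and (optionally) exposing `phoneme_probability`. The import path is an assumption, and NeMo's text-processing dependency must be installed.

```python
from utils.tts_tokenizers import EnglishPhonemesTokenizer  # assumed import path within this repo


class ToyG2p:
    """Illustration-only stand-in for a real grapheme-to-phoneme module."""

    phoneme_probability = 1.0

    def __call__(self, text):
        table = {"hi there": ["HH", "AY1", " ", "DH", "EH1", "R"]}
        return table.get(text.lower(), list(text))


tokenizer = EnglishPhonemesTokenizer(g2p=ToyG2p(), stresses=False, pad_with_space=True)
ids = tokenizer.encode("Hi there")
print(tokenizer.decode(ids))  # ' |HH|AY| |DH|EH|R| ' -- stress digits stripped, '|' separator

# During training, mixed phoneme/grapheme sampling can be toggled temporarily
# (the toy G2P ignores the value; a real G2P uses it to decide how often words
# are kept as graphemes instead of phonemes):
with tokenizer.set_phone_prob(0.5):
    ids_mixed = tokenizer.encode("Hi there")
```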