├── .gitignore ├── LICENSE ├── README.md ├── assets ├── acoustic_embed.PNG └── adaspeech.PNG ├── compute_statistics.py ├── configs └── default.yaml ├── core ├── __init__.py ├── acoustic_encoder.py ├── attention.py ├── duration_modeling │ ├── __init__.py │ ├── duration_predictor.py │ └── length_regulator.py ├── embedding.py ├── encoder.py ├── modules.py ├── optimizer.py └── variance_predictor.py ├── dataset ├── __init__.py ├── audio_processing.py ├── dataloader.py ├── ljspeech.py └── texts │ ├── __init__.py │ ├── cleaners.py │ ├── cmudict.py │ ├── dict_.py │ ├── numbers.py │ └── symbols.py ├── demo_fastspeech2.ipynb ├── evaluation.py ├── export_torchscript.py ├── fastspeech.py ├── filelists ├── train_filelist.txt └── valid_filelist.txt ├── inference.py ├── nvidia_preprocessing.py ├── requirements.txt ├── tests ├── __init__.py └── test_fastspeech2.py ├── train_fastspeech.py └── utils ├── __init__.py ├── display.py ├── fastspeech2_script.py ├── hparams.py ├── plot.py ├── stft.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | idea/* 3 | /.idea 4 | /data 5 | /output 6 | /logs 7 | /__pycache__ 8 | /core/__pycache__ 9 | /core/duration_modeling/__pycache__ 10 | /core/energy_predictor/__pycache__ 11 | /core/pitch_predictor/__pycache__ 12 | /dataset/__pycache__ 13 | /dataset/texts/__pycache__ 14 | /utils/__pycache__ 15 | /checkpoints 16 | /trace_loss.txt 17 | /unused_code.txt 18 | /test.py 19 | /rest_tts.py 20 | /preprocess.py 21 | /trace_loss_nvidia.txt 22 | /conf 23 | /etc 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # AdaSpeech: Adaptive Text to Speech for Custom Voice [WIP]
2 | Unofficial PyTorch implementation of [AdaSpeech](https://arxiv.org/pdf/2103.00993.pdf).
3 |
4 |
5 | ![](./assets/adaspeech.PNG)
6 |
7 | ## Note:
8 | * I am not considering the multi-speaker use case; I am focusing only on a single speaker.
9 | * I will use only the `Utterance level encoder` and `Phoneme level encoder`, not conditional layer norm (which is the soul of the AdaSpeech paper). This definitely restricts the adaptive nature of AdaSpeech, but my focus is to improve the acoustic generalization of FastSpeech 2 rather than its adaptation (see the usage sketch at the end of this README).
10 |
11 | ![](./assets/acoustic_embed.PNG)
12 |
13 | ## Citations
14 | ```bibtex
15 | @misc{chen2021adaspeech,
16 |   title={AdaSpeech: Adaptive Text to Speech for Custom Voice},
17 |   author={Mingjian Chen and Xu Tan and Bohan Li and Yanqing Liu and Tao Qin and Sheng Zhao and Tie-Yan Liu},
18 |   year={2021},
19 |   eprint={2103.00993},
20 |   archivePrefix={arXiv},
21 |   primaryClass={eess.AS}
22 | }
23 | ```
24 |
25 | ## Requirements:
26 | All code is written in `Python 3.6.2`.
27 | * Install PyTorch
28 | > Before installing PyTorch, please check your CUDA version by running the following command:
29 | `nvcc --version`
30 | ```
31 | pip install torch torchvision
32 | ```
33 | This repo uses PyTorch 1.6.0 for the `torch.bucketize` feature, which is not present in earlier versions of PyTorch.
34 |
35 |
36 | * Install the other requirements:
37 | ```
38 | pip install -r requirements.txt
39 | ```
40 |
41 | * To use TensorBoard, install `tensorboard` version 1.14.0 separately with a supported `tensorflow` (1.14.0).
42 |
43 |
44 |
45 | ## For Preprocessing:
46 |
47 | The `filelists` folder contains MFA (Montreal Forced Aligner) processed LJSpeech dataset files, so you don't need to align text with audio (to extract durations) for the LJSpeech dataset.
48 | For other datasets, follow the instructions [here](https://github.com/ivanvovk/DurIAN#6-how-to-align-your-own-data). For the remaining pre-processing, run the following command:
49 | ```
50 | python nvidia_preprocessing.py -d path_of_wavs
51 | ```
52 | To find the min and max of F0 and energy:
53 | ```
54 | python compute_statistics.py
55 | ```
56 | Update the following values in `hparams.py` with the computed min and max of F0 and energy:
57 | ```
58 | p_min = Min F0/pitch
59 | p_max = Max F0
60 | e_min = Min energy
61 | e_max = Max energy
62 | ```
63 |
64 | ## For training
65 | ```
66 | python train_fastspeech.py --outdir etc -c configs/default.yaml -n "name"
67 | ```
68 |
69 | ## Note
70 | * For a more complete and end-to-end voice cloning or text-to-speech (TTS) toolbox, please visit [Deepsync Technologies](https://deepsync.co/).
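
## Acoustic encoders (usage sketch)
The utterance-level and phoneme-level encoders referenced in the note above are implemented in `core/acoustic_encoder.py`. The snippet below is only an illustrative sketch of their input/output shapes; the batch size, mel dimension, and sequence lengths are assumed example values, not repository defaults:
```python
import torch
from core.acoustic_encoder import UtteranceEncoder, PhonemeLevelEncoder

utt_encoder = UtteranceEncoder(idim=80)      # consumes the reference mel-spectrogram
phn_encoder = PhonemeLevelEncoder(idim=80)   # consumes phoneme-aligned (averaged) mel frames

mels = torch.randn(2, 80, 400)               # (B, n_mels, T) mel-spectrogram, assumed shape
utt_vector = utt_encoder(mels)               # (B, 256, 1): one conditioning vector per utterance
phn_mels = torch.randn(2, 80, 50)            # (B, n_mels, Lmax) phoneme-averaged mels, assumed input
phn_latent = phn_encoder(phn_mels)           # (B, Lmax, 4): per-phoneme latent (phn_latent_dim: 4 in the config)
```
Both outputs are intended to be added to the encoder hidden states before the variance adaptor, as shown in the acoustic-embedding figure above.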
71 | 72 | -------------------------------------------------------------------------------- /assets/acoustic_embed.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/AdaSpeech/c208a4fb83b09f4943cb4faf3c007892dc3f6b7b/assets/acoustic_embed.PNG -------------------------------------------------------------------------------- /assets/adaspeech.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/AdaSpeech/c208a4fb83b09f4943cb4faf3c007892dc3f6b7b/assets/adaspeech.PNG -------------------------------------------------------------------------------- /compute_statistics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from utils.util import get_files 4 | from tqdm import tqdm 5 | from utils.util import remove_outlier 6 | from utils.hparams import HParam 7 | 8 | if __name__ == "__main__": 9 | 10 | hp = HParam("./config/default.yaml") 11 | 12 | min_e = [] 13 | min_p = [] 14 | max_e = [] 15 | max_p = [] 16 | nz_min_p = [] 17 | nz_min_e = [] 18 | 19 | energy_path = os.path.join(hp.data.data_dir, "energy") 20 | pitch_path = os.path.join(hp.data.data_dir, "pitch") 21 | mel_path = os.path.join(hp.data.data_dir, "mels") 22 | energy_files = get_files(energy_path, extension=".npy") 23 | pitch_files = get_files(pitch_path, extension=".npy") 24 | mel_files = get_files(mel_path, extension=".npy") 25 | 26 | assert len(energy_files) == len(pitch_files) == len(mel_files) 27 | 28 | energy_vecs = [] 29 | for f in tqdm(energy_files): 30 | e = np.load(f) 31 | e = remove_outlier(e) 32 | energy_vecs.append(e) 33 | min_e.append(e.min()) 34 | nz_min_e.append(e[e > 0].min()) 35 | max_e.append(e.max()) 36 | 37 | nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs]) 38 | e_mean, e_std = np.mean(nonzeros), np.std(nonzeros) 39 | print("Non zero Min Energy : {}".format(min(nz_min_e))) 40 | print("Max Energy : {}".format(max(max_e))) 41 | print("Energy mean : {}".format(e_mean)) 42 | print("Energy std: {}".format(e_std)) 43 | 44 | pitch_vecs = [] 45 | bad_pitch = [] 46 | for f in tqdm(pitch_files): 47 | # print(f) 48 | p = np.load(f) 49 | p = remove_outlier(p) 50 | pitch_vecs.append(p) 51 | # print(len(p), "#########", p) 52 | try: 53 | min_p.append(p.min()) 54 | nz_min_p.append(p[p > 0].min()) 55 | max_p.append(p.max()) 56 | except ValueError: 57 | bad_pitch.append(f) 58 | 59 | nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) 60 | f0_mean, f0_std = np.mean(nonzeros), np.std(nonzeros) 61 | print("Min Pitch : {}".format(min(min_p))) 62 | print("Non zero Min Pitch : {}".format(min(nz_min_p))) 63 | print("Max Pitch : {}".format(max(max_p))) 64 | print("Pitch mean : {}".format(f0_mean)) 65 | print("Pitch std: {}".format(f0_std)) 66 | 67 | np.save( 68 | os.path.join(hp.data.data_dir, "e_mean.npy"), 69 | e_mean.astype(np.float32), 70 | allow_pickle=False, 71 | ) 72 | np.save( 73 | os.path.join(hp.data.data_dir, "e_std.npy"), 74 | e_std.astype(np.float32), 75 | allow_pickle=False, 76 | ) 77 | np.save( 78 | os.path.join(hp.data.data_dir, "f0_mean.npy"), 79 | f0_mean.astype(np.float32), 80 | allow_pickle=False, 81 | ) 82 | np.save( 83 | os.path.join(hp.data.data_dir, "f0_std.npy"), 84 | f0_std.astype(np.float32), 85 | allow_pickle=False, 86 | ) 87 | print("The len of bad Pitch Vectors is ", len(bad_pitch)) 88 | # print(bad_pitch) 89 | with open("bad_file.txt", 
"a") as f: 90 | for i in bad_pitch: 91 | c = i.split("/")[3].split(".")[0] 92 | f.write(c) 93 | f.write("\n") 94 | 95 | # print("Min Energy : {}".format(min(min_e))) 96 | 97 | # print("Min Pitch : {}".format(min(min_p))) 98 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: 'H:\Deepsync\backup\fastspeech\data\' 3 | wav_dir: 'H:\Deepsync\backup\deepsync\LJSpeech-1.1\wavs\' 4 | # Compute statistics 5 | e_mean: 21.578571319580078 6 | e_std: 18.916799545288086 7 | e_min: 0.01786651276051998 8 | e_max: 130.5338592529297 9 | 10 | f0_mean: 206.5135564772342 11 | f0_std: 53.633228905750336 12 | p_min: 71.0 13 | p_max: 676.2260946528305 # 799.8901977539062 14 | train_filelist: "./filelists/train_filelist.txt" 15 | valid_filelist: "./filelists/valid_filelist.txt" 16 | tts_cleaner_names: ['english_cleaners'] 17 | 18 | # feature extraction related 19 | audio: 20 | sample_rate: 22050 # sampling frequency 21 | fmax: 8000.0 # maximum frequency 22 | fmin: 0.0 # minimum frequency 23 | n_mels: 80 # number of mel basis 24 | n_fft: 1024 # number of fft points 25 | hop_length: 256 # number of shift points 26 | win_length: 1024 # window length 27 | num_mels : 80 28 | min_level_db : -100 29 | ref_level_db : 20 30 | bits : 9 # bit depth of signal 31 | mu_law : True # Recommended to suppress noise if using raw bits in hp.voc_mode below 32 | peak_norm : False # Normalise to the peak of each wav file 33 | 34 | 35 | 36 | 37 | # network architecture related 38 | model: 39 | embed_dim: 0 40 | eprenet_conv_layers: 0 # one more linear layer w/o non_linear will be added for 0_centor 41 | eprenet_conv_filts: 0 42 | eprenet_conv_chans: 0 43 | dprenet_layers: 2 # one more linear layer w/o non_linear will be added for 0_centor 44 | dprenet_units: 256 # 384 45 | adim: 256 46 | aheads: 2 47 | elayers: 4 48 | eunits: 1024 49 | ddim: 384 50 | dlayers: 4 51 | dunits: 1024 52 | positionwise_layer_type : "conv1d" # linear 53 | positionwise_conv_kernel_size : 9 # 1 54 | postnet_layers: 5 55 | postnet_filts: 5 56 | postnet_chans: 256 57 | use_masking: True 58 | use_weighted_masking: False 59 | bce_pos_weight: 5.0 60 | use_batch_norm: True 61 | use_scaled_pos_enc: True 62 | encoder_normalize_before: False 63 | decoder_normalize_before: False 64 | encoder_concat_after: False 65 | decoder_concat_after: False 66 | reduction_factor: 1 67 | loss_type : "L1" 68 | # minibatch related 69 | batch_sort_key: input # shuffle or input or output 70 | batch_bins: 2549760 # 12 * (870 * 80 + 180 * 35) 71 | # batch_size * (max_out * dim_out + max_in * dim_in) 72 | # resuling in 11 ~ 66 samples (avg 15 samples) in batch (809 batches per epochs) for ljspeech 73 | 74 | # training related 75 | transformer_init: 'pytorch' # choices:["pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"] 76 | transformer_warmup_steps: 4000 77 | transformer_lr: 1.0 78 | initial_encoder_alpha: 1.0 79 | initial_decoder_alpha: 1.0 80 | eprenet_dropout_rate: 0.0 81 | dprenet_dropout_rate: 0.5 82 | postnet_dropout_rate: 0.5 83 | transformer_enc_dropout_rate: 0.1 84 | transformer_enc_positional_dropout_rate: 0.1 85 | transformer_enc_attn_dropout_rate: 0.1 86 | transformer_dec_dropout_rate: 0.1 87 | transformer_dec_positional_dropout_rate: 0.1 88 | transformer_dec_attn_dropout_rate: 0.1 89 | transformer_enc_dec_attn_dropout_rate: 0.1 90 | use_guided_attn_loss: True 91 | 
num_heads_applied_guided_attn: 2 92 | num_layers_applied_guided_attn: 2 93 | modules_applied_guided_attn: ["encoder_decoder"] 94 | guided_attn_loss_sigma: 0.4 95 | guided_attn_loss_lambda: 1.0 96 | 97 | ### FastSpeech 98 | duration_predictor_layers : 2 99 | duration_predictor_chans : 256 100 | duration_predictor_kernel_size : 3 101 | transfer_encoder_from_teacher : True 102 | duration_predictor_dropout_rate : 0.5 103 | teacher_model : "" 104 | transferred_encoder_module : "all" # choices:["all", "embed"] 105 | 106 | attn_plot : False 107 | 108 | ####### AdaSpeech 109 | predictor_start_step: 60000 110 | phn_latent_dim: 4 111 | 112 | 113 | train: 114 | # optimization related 115 | eos: False #True 116 | opt: 'noam' 117 | accum_grad: 4 118 | grad_clip: 1.0 119 | weight_decay: 0.001 120 | patience: 0 121 | epochs: 1000 # 1,000 epochs * 809 batches / 5 accum_grad : 161,800 iters 122 | save_interval_epoch: 10 123 | GTA : False 124 | # other 125 | ngpu: 1 # number of gpus ("0" uses cpu, otherwise use gpu) 126 | nj: 4 # number of parallel jobs 127 | dumpdir: '' # directory to dump full features 128 | verbose: 0 # verbose option (if set > 0, get more log) 129 | N: 0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches. 130 | seed: 1 # random seed number 131 | resume: "" # the snapshot path to resume (if set empty, no effect) 132 | use_phonemes: True 133 | batch_size : 16 134 | # other 135 | melgan_vocoder : True 136 | save_interval : 1000 137 | chkpt_dir : './checkpoints' 138 | log_dir : './logs' 139 | summary_interval : 200 140 | validation_step : 500 141 | tts_max_mel_len : 870 # if you have a couple of extremely long spectrograms you might want to use this 142 | tts_bin_lengths : True # bins the spectrogram lengths before sampling in data loader - speeds up training -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/AdaSpeech/c208a4fb83b09f4943cb4faf3c007892dc3f6b7b/core/__init__.py -------------------------------------------------------------------------------- /core/acoustic_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from core.modules import LayerNorm 4 | from typing import Optional 5 | import torch.nn.functional as F 6 | 7 | class UtteranceEncoder(nn.Module): 8 | 9 | def __init__(self, idim: int, 10 | n_layers: int = 2, 11 | n_chans: int = 256, 12 | kernel_size: int = 5, 13 | pool_kernel: int = 3, 14 | dropout_rate: float = 0.5, 15 | stride: int = 3): 16 | super(UtteranceEncoder, self).__init__() 17 | self.conv = torch.nn.ModuleList() 18 | for idx in range(n_layers): 19 | in_chans = idim if idx == 0 else n_chans 20 | self.conv += [ 21 | torch.nn.Sequential( 22 | torch.nn.Conv1d( 23 | in_chans, 24 | n_chans, 25 | kernel_size, 26 | stride=stride, 27 | padding=(kernel_size - 1) // 2, 28 | ), 29 | torch.nn.ReLU(), 30 | LayerNorm(n_chans), 31 | torch.nn.Dropout(dropout_rate), 32 | ) 33 | ] 34 | 35 | def forward(self, 36 | xs: torch.Tensor, 37 | x_masks: Optional[torch.Tensor] = None 38 | ) -> torch.Tensor: 39 | 40 | for f in self.conv: 41 | xs = f(xs) # (B, C, Tmax) 42 | 43 | # NOTE: calculate in log domain 44 | xs = F.avg_pool1d(xs, xs.size(-1)) # (B, C, 1) 45 | 46 | return xs 47 | 48 | 49 | class PhonemeLevelEncoder(nn.Module): 50 | 51 | def __init__(self, idim: int, 52 | n_layers: int = 2, 53 | 
n_chans: int = 256, 54 | out: int = 4, 55 | kernel_size: int = 3, 56 | dropout_rate: float = 0.5, 57 | stride: int = 1): 58 | super(PhonemeLevelEncoder, self).__init__() 59 | self.conv = torch.nn.ModuleList() 60 | for idx in range(n_layers): 61 | in_chans = idim if idx == 0 else n_chans 62 | self.conv += [ 63 | torch.nn.Sequential( 64 | torch.nn.Conv1d( 65 | in_chans, 66 | n_chans, 67 | kernel_size, 68 | stride=stride, 69 | padding=(kernel_size - 1) // 2, 70 | ), 71 | torch.nn.ReLU(), 72 | LayerNorm(n_chans), 73 | torch.nn.Dropout(dropout_rate), 74 | ) 75 | ] 76 | 77 | self.linear = torch.nn.Linear(n_chans, out) 78 | 79 | def forward(self, 80 | xs: torch.Tensor, 81 | x_masks: Optional[torch.Tensor] = None 82 | ) -> torch.Tensor: 83 | 84 | for f in self.conv: 85 | xs = f(xs) # (B, C, Lmax) 86 | 87 | 88 | xs = self.linear(xs.transpose(1, 2)) # (B, Lmax, 4) 89 | 90 | return xs 91 | 92 | 93 | class PhonemeLevelPredictor(nn.Module): 94 | 95 | def __init__(self, idim: int, 96 | n_layers: int = 2, 97 | n_chans: int = 256, 98 | out: int = 4, 99 | kernel_size: int = 3, 100 | dropout_rate: float = 0.5, 101 | stride: int = 1): 102 | super(PhonemeLevelPredictor, self).__init__() 103 | self.conv = torch.nn.ModuleList() 104 | for idx in range(n_layers): 105 | in_chans = idim if idx == 0 else n_chans 106 | self.conv += [ 107 | torch.nn.Sequential( 108 | torch.nn.Conv1d( 109 | in_chans, 110 | n_chans, 111 | kernel_size, 112 | stride=stride, 113 | padding=(kernel_size - 1) // 2, 114 | ), 115 | torch.nn.ReLU(), 116 | LayerNorm(n_chans), 117 | torch.nn.Dropout(dropout_rate), 118 | ) 119 | ] 120 | 121 | self.linear = torch.nn.Linear(n_chans, out) 122 | 123 | def forward(self, 124 | xs: torch.Tensor, 125 | x_masks: Optional[torch.Tensor] = None 126 | ) -> torch.Tensor: 127 | 128 | for f in self.conv: 129 | xs = f(xs) # (B, C, Tmax) 130 | 131 | xs = self.linear(xs.transpose(1, 2)) # (B, Tmax) 132 | 133 | return xs 134 | 135 | class AcousticPredictorLoss(torch.nn.Module): 136 | """Loss function module for duration predictor. 137 | The loss value is Calculated in log domain to make it Gaussian. 138 | """ 139 | 140 | def __init__(self, offset=1.0): 141 | """Initilize duration predictor loss module. 142 | Args: 143 | offset (float, optional): Offset value to avoid nan in log domain. 144 | """ 145 | super(AcousticPredictorLoss, self).__init__() 146 | self.criterion = torch.nn.MSELoss() 147 | self.offset = offset 148 | 149 | def forward(self, outputs, targets): 150 | """Calculate forward propagation. 151 | Args: 152 | outputs (Tensor): Batch of prediction durations in log domain (B, T) 153 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) 154 | Returns: 155 | Tensor: Mean squared error loss value. 156 | Note: 157 | `outputs` is in log domain but `targets` is in linear domain. 
158 | """ 159 | # NOTE: outputs is in log domain while targets in linear 160 | loss = self.criterion(outputs, targets) 161 | 162 | return loss -------------------------------------------------------------------------------- /core/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class MultiHeadedAttention(nn.Module): 9 | """Multi-Head Attention layer 10 | 11 | :param int n_head: the number of head s 12 | :param int n_feat: the number of features 13 | :param float dropout_rate: dropout rate 14 | """ 15 | 16 | def __init__(self, n_head: int, n_feat: int, dropout_rate: float): 17 | super(MultiHeadedAttention, self).__init__() 18 | assert n_feat % n_head == 0 19 | # We assume d_v always equals d_k 20 | self.d_k = n_feat // n_head 21 | self.h = n_head 22 | self.linear_q = nn.Linear(n_feat, n_feat) 23 | self.linear_k = nn.Linear(n_feat, n_feat) 24 | self.linear_v = nn.Linear(n_feat, n_feat) 25 | self.linear_out = nn.Linear(n_feat, n_feat) 26 | # self.attn: Optional[torch.Tensor] = None # torch.empty(0) 27 | # self.register_buffer("attn", torch.empty(0)) 28 | self.dropout = nn.Dropout(p=dropout_rate) 29 | 30 | def forward( 31 | self, 32 | query: torch.Tensor, 33 | key: torch.Tensor, 34 | value: torch.Tensor, 35 | mask: Optional[torch.Tensor] = None, 36 | ) -> torch.Tensor: 37 | """Compute 'Scaled Dot Product Attention' 38 | 39 | :param torch.Tensor query: (batch, time1, size) 40 | :param torch.Tensor key: (batch, time2, size) 41 | :param torch.Tensor value: (batch, time2, size) 42 | :param torch.Tensor mask: (batch, time1, time2) 43 | :param torch.nn.Dropout dropout: 44 | :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model) 45 | weighted by the query dot key attention (batch, head, time1, time2) 46 | """ 47 | n_batch = query.size(0) 48 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) 49 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) 50 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 51 | q = q.transpose(1, 2) # (batch, head, time1, d_k) 52 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 53 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 54 | 55 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt( 56 | self.d_k 57 | ) # (batch, head, time1, time2) 58 | if mask is not None: 59 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) 60 | # min_value: float = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) 61 | mask = mask.to(device=scores.device) 62 | scores = scores.masked_fill_(mask, -np.inf) 63 | attn = torch.softmax(scores, dim=-1).masked_fill( 64 | mask, 0.0 65 | ) # (batch, head, time1, time2) 66 | else: 67 | attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 68 | 69 | p_attn = self.dropout(attn) 70 | x = torch.matmul(p_attn, v) # (batch, head, time1, d_k) 71 | x = ( 72 | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) 73 | ) # (batch, time1, d_model) 74 | return self.linear_out(x) # (batch, time1, d_model) 75 | -------------------------------------------------------------------------------- /core/duration_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/AdaSpeech/c208a4fb83b09f4943cb4faf3c007892dc3f6b7b/core/duration_modeling/__init__.py 
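
Usage sketch for the `MultiHeadedAttention` module defined in `core/attention.py` above. The batch size, sequence lengths, and padding-mask construction here are illustrative assumptions for demonstration, not repository code:
```python
import torch
from core.attention import MultiHeadedAttention

# Self-attention over a padded batch of encoder states.
attn = MultiHeadedAttention(n_head=2, n_feat=256, dropout_rate=0.1)
x = torch.randn(4, 100, 256)                      # (batch, time, n_feat)
lengths = torch.tensor([100, 80, 60, 30])         # valid lengths per example (assumed)
# Non-pad mask of shape (batch, 1, time2); it broadcasts over the query time axis inside forward().
mask = (torch.arange(100)[None, :] < lengths[:, None]).unsqueeze(1)
out = attn(query=x, key=x, value=x, mask=mask)    # (4, 100, 256)
```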
-------------------------------------------------------------------------------- /core/duration_modeling/duration_predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Tomoki Hayashi 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Duration predictor related loss.""" 8 | 9 | import torch 10 | from typing import Optional 11 | from core.modules import LayerNorm 12 | 13 | 14 | class DurationPredictor(torch.nn.Module): 15 | """Duration predictor module. 16 | 17 | This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. 18 | The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder. 19 | 20 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: 21 | https://arxiv.org/pdf/1905.09263.pdf 22 | 23 | Note: 24 | The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`, 25 | the outputs are calculated in log domain but in `inference`, those are calculated in linear domain. 26 | 27 | """ 28 | 29 | def __init__( 30 | self, idim, n_layers=2, n_chans=256, kernel_size=3, dropout_rate=0.1, offset=1.0 31 | ): 32 | """Initilize duration predictor module. 33 | 34 | Args: 35 | idim (int): Input dimension. 36 | n_layers (int, optional): Number of convolutional layers. 37 | n_chans (int, optional): Number of channels of convolutional layers. 38 | kernel_size (int, optional): Kernel size of convolutional layers. 39 | dropout_rate (float, optional): Dropout rate. 40 | offset (float, optional): Offset value to avoid nan in log domain. 41 | 42 | """ 43 | super(DurationPredictor, self).__init__() 44 | self.offset = offset 45 | self.conv = torch.nn.ModuleList() 46 | for idx in range(n_layers): 47 | in_chans = idim if idx == 0 else n_chans 48 | self.conv += [ 49 | torch.nn.Sequential( 50 | torch.nn.Conv1d( 51 | in_chans, 52 | n_chans, 53 | kernel_size, 54 | stride=1, 55 | padding=(kernel_size - 1) // 2, 56 | ), 57 | torch.nn.ReLU(), 58 | LayerNorm(n_chans), 59 | torch.nn.Dropout(dropout_rate), 60 | ) 61 | ] 62 | self.linear = torch.nn.Linear(n_chans, 1) 63 | 64 | def _forward( 65 | self, 66 | xs: torch.Tensor, 67 | x_masks: Optional[torch.Tensor] = None, 68 | is_inference: bool = False, 69 | ): 70 | xs = xs.transpose(1, -1) # (B, idim, Tmax) 71 | for f in self.conv: 72 | xs = f(xs) # (B, C, Tmax) 73 | 74 | # NOTE: calculate in log domain 75 | xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax) 76 | 77 | if is_inference: 78 | # NOTE: calculate in linear domain 79 | xs = torch.clamp( 80 | torch.round(xs.exp() - self.offset), min=0 81 | ).long() # avoid negative value 82 | 83 | if x_masks is not None: 84 | xs = xs.masked_fill(x_masks, 0.0) 85 | 86 | return xs 87 | 88 | def forward(self, xs: torch.Tensor, x_masks: Optional[torch.Tensor] = None): 89 | """Calculate forward propagation. 90 | 91 | Args: 92 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 93 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). 94 | 95 | Returns: 96 | Tensor: Batch of predicted durations in log domain (B, Tmax). 97 | 98 | """ 99 | return self._forward(xs, x_masks, False) 100 | 101 | def inference(self, xs, x_masks: Optional[torch.Tensor] = None): 102 | """Inference duration. 103 | 104 | Args: 105 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 
106 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). 107 | 108 | Returns: 109 | LongTensor: Batch of predicted durations in linear domain (B, Tmax). 110 | 111 | """ 112 | return self._forward(xs, x_masks, True) 113 | 114 | 115 | class DurationPredictorLoss(torch.nn.Module): 116 | """Loss function module for duration predictor. 117 | 118 | The loss value is Calculated in log domain to make it Gaussian. 119 | 120 | """ 121 | 122 | def __init__(self, offset=1.0): 123 | """Initilize duration predictor loss module. 124 | 125 | Args: 126 | offset (float, optional): Offset value to avoid nan in log domain. 127 | 128 | """ 129 | super(DurationPredictorLoss, self).__init__() 130 | self.criterion = torch.nn.MSELoss() 131 | self.offset = offset 132 | 133 | def forward(self, outputs, targets): 134 | """Calculate forward propagation. 135 | 136 | Args: 137 | outputs (Tensor): Batch of prediction durations in log domain (B, T) 138 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) 139 | 140 | Returns: 141 | Tensor: Mean squared error loss value. 142 | 143 | Note: 144 | `outputs` is in log domain but `targets` is in linear domain. 145 | 146 | """ 147 | # NOTE: outputs is in log domain while targets in linear 148 | targets = torch.log(targets.float() + self.offset) 149 | loss = self.criterion(outputs, targets) 150 | 151 | return loss 152 | -------------------------------------------------------------------------------- /core/duration_modeling/length_regulator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Tomoki Hayashi 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Length regulator related loss.""" 8 | 9 | import logging 10 | 11 | import torch 12 | 13 | from utils.util import pad_2d_tensor, pad_list 14 | 15 | 16 | class LengthRegulator(torch.nn.Module): 17 | """Length regulator module for feed-forward Transformer. 18 | 19 | This is a module of length regulator described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. 20 | The length regulator expands char or phoneme-level embedding features to frame-level by repeating each 21 | feature based on the corresponding predicted durations. 22 | 23 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: 24 | https://arxiv.org/pdf/1905.09263.pdf 25 | 26 | """ 27 | 28 | def __init__(self, pad_value: float = 0.0): 29 | """Initilize length regulator module. 30 | 31 | Args: 32 | pad_value (float, optional): Value used for padding. 33 | 34 | """ 35 | super(LengthRegulator, self).__init__() 36 | self.pad_value = pad_value 37 | 38 | def forward( 39 | self, 40 | xs: torch.Tensor, 41 | ds: torch.Tensor, 42 | ilens: torch.Tensor, 43 | alpha: float = 1.0, 44 | ) -> torch.Tensor: 45 | """Calculate forward propagation. 46 | 47 | Args: 48 | xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D). 49 | ds (LongTensor): Batch of durations of each frame (B, T). 50 | ilens (LongTensor): Batch of input lengths (B,). 51 | alpha (float, optional): Alpha value to control speed of speech. 52 | 53 | Returns: 54 | Tensor: replicated input tensor based on durations (B, T*, D). 
55 | 56 | """ 57 | assert alpha > 0 58 | if alpha != 1.0: 59 | ds = torch.round(ds.float() * alpha).long() 60 | xs = [x[:ilen] for x, ilen in zip(xs, ilens)] 61 | ds = [d[:ilen] for d, ilen in zip(ds, ilens)] 62 | 63 | xs = [self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)] 64 | 65 | return pad_2d_tensor(xs, 0.0) 66 | 67 | def _repeat_one_sequence(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor: 68 | """Repeat each frame according to duration. 69 | 70 | Examples: 71 | >>> x = torch.tensor([[1], [2], [3]]) 72 | tensor([[1], 73 | [2], 74 | [3]]) 75 | >>> d = torch.tensor([1, 2, 3]) 76 | tensor([1, 2, 3]) 77 | >>> self._repeat_one_sequence(x, d) 78 | tensor([[1], 79 | [2], 80 | [2], 81 | [3], 82 | [3], 83 | [3]]) 84 | 85 | """ 86 | if d.sum() == 0: 87 | # logging.warn("all of the predicted durations are 0. fill 0 with 1.") 88 | d = d.fill_(1) 89 | # return torch.cat([x_.repeat(int(d_), 1) for x_, d_ in zip(x, d) if d_ != 0], dim=0) for torchscript 90 | out = [] 91 | for x_, d_ in zip(x, d): 92 | if d_ != 0: 93 | out.append(x_.repeat(int(d_), 1)) 94 | 95 | return torch.cat(out, dim=0) 96 | -------------------------------------------------------------------------------- /core/embedding.py: -------------------------------------------------------------------------------- 1 | """Positonal Encoding Module.""" 2 | import math 3 | 4 | import torch 5 | 6 | 7 | def _pre_hook( 8 | state_dict, 9 | prefix, 10 | local_metadata, 11 | strict, 12 | missing_keys, 13 | unexpected_keys, 14 | error_msgs, 15 | ): 16 | """Perform pre-hook in load_state_dict for backward compatibility. 17 | 18 | Note: 19 | We saved self.pe until v.0.5.2 but we have omitted it later. 20 | Therefore, we remove the item "pe" from `state_dict` for backward compatibility. 21 | 22 | """ 23 | k = prefix + "pe" 24 | if k in state_dict: 25 | state_dict.pop(k) 26 | 27 | 28 | class PositionalEncoding(torch.nn.Module): 29 | """Positional encoding.""" 30 | 31 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): 32 | """Initialize class. 33 | 34 | :param int d_model: embedding dim 35 | :param float dropout_rate: dropout rate 36 | :param int max_len: maximum input length 37 | 38 | """ 39 | super(PositionalEncoding, self).__init__() 40 | self.d_model = d_model 41 | self.xscale = math.sqrt(self.d_model) 42 | self.dropout = torch.nn.Dropout(p=dropout_rate) 43 | # self.pe = None 44 | self.register_buffer("pe", None) 45 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 46 | # self._register_load_state_dict_pre_hook(_pre_hook) 47 | 48 | def extend_pe(self, x: torch.Tensor): 49 | """Reset the positional encodings.""" 50 | if self.pe is not None: 51 | if self.pe.size(1) >= x.size(1): 52 | if ( 53 | self.pe.dtype != x.dtype 54 | ): # or self.pe.device != x.device: comment because of torchscript 55 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 56 | return 57 | pe = torch.zeros(x.size(1), self.d_model) 58 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) 59 | div_term = torch.exp( 60 | torch.arange(0, self.d_model, 2, dtype=torch.float32) 61 | * -(math.log(10000.0) / self.d_model) 62 | ) 63 | pe[:, 0::2] = torch.sin(position * div_term) 64 | pe[:, 1::2] = torch.cos(position * div_term) 65 | pe = pe.unsqueeze(0) 66 | self.pe = pe.to(device=x.device, dtype=x.dtype) 67 | 68 | def forward(self, x: torch.Tensor) -> torch.Tensor: 69 | """Add positional encoding. 70 | 71 | Args: 72 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 
73 | 74 | Returns: 75 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 76 | 77 | """ 78 | self.extend_pe(x) 79 | x = x * self.xscale + self.pe[:, : x.size(1)] 80 | return self.dropout(x) 81 | 82 | 83 | class ScaledPositionalEncoding(PositionalEncoding): 84 | """Scaled positional encoding module. 85 | 86 | See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf 87 | 88 | """ 89 | 90 | def __init__(self, d_model, dropout_rate, max_len=5000): 91 | """Initialize class. 92 | 93 | :param int d_model: embedding dim 94 | :param float dropout_rate: dropout rate 95 | :param int max_len: maximum input length 96 | 97 | """ 98 | super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) 99 | self.alpha = torch.nn.Parameter(torch.tensor(1.0)) 100 | 101 | def reset_parameters(self): 102 | """Reset parameters.""" 103 | self.alpha.data = torch.tensor(1.0) 104 | 105 | def forward(self, x): 106 | """Add positional encoding. 107 | 108 | Args: 109 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 110 | 111 | Returns: 112 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 113 | 114 | """ 115 | device = x.device 116 | self.extend_pe(x) 117 | # print("Devices x :", x.device) 118 | self.alpha = self.alpha.to(device=device) 119 | x = x + self.alpha * self.pe[:, : x.size(1)].to(device=device) 120 | return self.dropout(x) 121 | -------------------------------------------------------------------------------- /core/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from core.attention import MultiHeadedAttention 4 | from core.embedding import PositionalEncoding 5 | from core.modules import MultiLayeredConv1d 6 | from core.modules import PositionwiseFeedForward 7 | from core.modules import Conv2dSubsampling 8 | from typing import Tuple, Optional 9 | 10 | 11 | class EncoderLayer(nn.Module): 12 | """Encoder layer module 13 | 14 | :param int size: input dim 15 | :param espnet.nets.pytorch_backend.core.attention.MultiHeadedAttention self_attn: self attention module 16 | :param espnet.nets.pytorch_backend.core.positionwise_feed_forward.PositionwiseFeedForward feed_forward: 17 | feed forward module 18 | :param float dropout_rate: dropout rate 19 | :param bool normalize_before: whether to use layer_norm before the first block 20 | :param bool concat_after: whether to concat attention layer's input and output 21 | if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) 22 | if False, no additional linear will be applied. i.e. 
x -> x + att(x) 23 | """ 24 | 25 | def __init__( 26 | self, 27 | size, 28 | self_attn, 29 | feed_forward, 30 | dropout_rate, 31 | normalize_before=True, 32 | concat_after=False, 33 | ): 34 | super(EncoderLayer, self).__init__() 35 | self.self_attn = self_attn 36 | self.feed_forward = feed_forward 37 | self.norm1 = torch.nn.LayerNorm(size) 38 | self.norm2 = torch.nn.LayerNorm(size) 39 | self.dropout = nn.Dropout(dropout_rate) 40 | self.size = size 41 | self.normalize_before = normalize_before 42 | self.concat_after = concat_after 43 | # if self.concat_after: 44 | self.concat_linear = nn.Linear(size + size, size) 45 | 46 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): 47 | """Compute encoded features 48 | 49 | :param torch.Tensor x: encoded source features (batch, max_time_in, size) 50 | :param torch.Tensor mask: mask for x (batch, max_time_in) 51 | :rtype: Tuple[torch.Tensor, torch.Tensor] 52 | """ 53 | residual = x 54 | if self.normalize_before: 55 | x = self.norm1(x) 56 | if self.concat_after: 57 | x_concat = torch.cat((x, self.self_attn(x, x, x, mask)), dim=-1) 58 | x = residual + self.concat_linear(x_concat) 59 | else: 60 | x = residual + self.dropout(self.self_attn(x, x, x, mask)) 61 | if not self.normalize_before: 62 | x = self.norm1(x) 63 | 64 | residual = x 65 | if self.normalize_before: 66 | x = self.norm2(x) 67 | x = residual + self.dropout(self.feed_forward(x)) 68 | if not self.normalize_before: 69 | x = self.norm2(x) 70 | 71 | return x, mask 72 | 73 | 74 | class Encoder(torch.nn.Module): 75 | """Transformer encoder module 76 | 77 | :param int idim: input dim 78 | :param int attention_dim: dimention of attention 79 | :param int attention_heads: the number of heads of multi head attention 80 | :param int linear_units: the number of units of position-wise feed forward 81 | :param int num_blocks: the number of decoder blocks 82 | :param float dropout_rate: dropout rate 83 | :param float attention_dropout_rate: dropout rate in attention 84 | :param float positional_dropout_rate: dropout rate after adding positional encoding 85 | :param str or torch.nn.Module input_layer: input layer type 86 | :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding 87 | :param bool normalize_before: whether to use layer_norm before the first block 88 | :param bool concat_after: whether to concat attention layer's input and output 89 | if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) 90 | if False, no additional linear will be applied. i.e. 
x -> x + att(x) 91 | :param str positionwise_layer_type: linear of conv1d 92 | :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer 93 | :param int padding_idx: padding_idx for input_layer=embed 94 | """ 95 | 96 | def __init__( 97 | self, 98 | idim: int, 99 | attention_dim: int = 256, 100 | attention_heads: int = 2, 101 | linear_units: int = 2048, 102 | num_blocks: int = 4, 103 | dropout_rate: float = 0.1, 104 | positional_dropout_rate: float = 0.1, 105 | attention_dropout_rate: float = 0.0, 106 | input_layer: str = "conv2d", 107 | pos_enc_class: torch.nn.Module = PositionalEncoding, 108 | normalize_before: bool = True, 109 | concat_after: bool = False, 110 | positionwise_layer_type: str = "linear", 111 | positionwise_conv_kernel_size: int = 1, 112 | padding_idx: int = -1, 113 | ): 114 | 115 | super(Encoder, self).__init__() 116 | # if self.normalize_before: 117 | self.after_norm = torch.nn.LayerNorm(attention_dim) 118 | if input_layer == "linear": 119 | self.embed = torch.nn.Sequential( 120 | torch.nn.Linear(idim, attention_dim), 121 | torch.nn.LayerNorm(attention_dim), 122 | torch.nn.Dropout(dropout_rate), 123 | torch.nn.ReLU(), 124 | pos_enc_class(attention_dim, positional_dropout_rate), 125 | ) 126 | elif input_layer == "conv2d": 127 | self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate) 128 | elif input_layer == "embed": 129 | self.embed = torch.nn.Sequential( 130 | torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), 131 | pos_enc_class(attention_dim, positional_dropout_rate), 132 | ) 133 | elif isinstance(input_layer, torch.nn.Module): 134 | self.embed = torch.nn.Sequential( 135 | input_layer, 136 | pos_enc_class(attention_dim, positional_dropout_rate), 137 | ) 138 | elif input_layer is None: 139 | self.embed = torch.nn.Sequential( 140 | pos_enc_class(attention_dim, positional_dropout_rate) 141 | ) 142 | else: 143 | raise ValueError("unknown input_layer: " + input_layer) 144 | self.normalize_before = normalize_before 145 | if positionwise_layer_type == "linear": 146 | positionwise_layer = PositionwiseFeedForward 147 | positionwise_layer_args = (attention_dim, linear_units, dropout_rate) 148 | elif positionwise_layer_type == "conv1d": 149 | positionwise_layer = MultiLayeredConv1d 150 | positionwise_layer_args = ( 151 | attention_dim, 152 | linear_units, 153 | positionwise_conv_kernel_size, 154 | dropout_rate, 155 | ) 156 | else: 157 | raise NotImplementedError("Support only linear or conv1d.") 158 | # self.encoders = repeat( 159 | # 4, 160 | # lambda: EncoderLayer( 161 | # attention_dim, 162 | # MultiHeadedAttention(attention_heads, attention_dim, attention_dropout_rate), 163 | # positionwise_layer(*positionwise_layer_args), 164 | # dropout_rate, 165 | # normalize_before, 166 | # concat_after 167 | # ) 168 | # ) 169 | self.encoders_ = nn.ModuleList( 170 | [ 171 | EncoderLayer( 172 | attention_dim, 173 | MultiHeadedAttention( 174 | attention_heads, attention_dim, attention_dropout_rate 175 | ), 176 | positionwise_layer(*positionwise_layer_args), 177 | dropout_rate, 178 | normalize_before, 179 | concat_after, 180 | ) 181 | for _ in range(num_blocks) 182 | ] 183 | ) 184 | 185 | def forward(self, xs: torch.Tensor, masks: Optional[torch.Tensor] = None): 186 | """Embed positions in tensor 187 | 188 | :param torch.Tensor xs: input tensor 189 | :param torch.Tensor masks: input mask 190 | :return: position embedded tensor and mask 191 | :rtype Tuple[torch.Tensor, torch.Tensor]: 192 | """ 193 | # if isinstance(self.embed, 
Conv2dSubsampling): 194 | # xs, masks = self.embed(xs, masks) 195 | # else: 196 | xs = self.embed(xs) 197 | 198 | # xs, masks = self.encoders_(xs, masks) 199 | for encoder in self.encoders_: 200 | xs, masks = encoder(xs, masks) 201 | if self.normalize_before: 202 | xs = self.after_norm(xs) 203 | 204 | return xs, masks 205 | -------------------------------------------------------------------------------- /core/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | from core.embedding import PositionalEncoding 4 | 5 | 6 | class Conv(torch.nn.Module): 7 | """ 8 | Convolution Module 9 | """ 10 | 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size=1, 16 | stride=1, 17 | padding=0, 18 | dilation=1, 19 | bias=True, 20 | ): 21 | """ 22 | :param in_channels: dimension of input 23 | :param out_channels: dimension of output 24 | :param kernel_size: size of kernel 25 | :param stride: size of stride 26 | :param padding: size of padding 27 | :param dilation: dilation rate 28 | :param bias: boolean. if True, bias is included. 29 | :param w_init: str. weight inits with xavier initialization. 30 | """ 31 | super(Conv, self).__init__() 32 | 33 | self.conv = torch.nn.Conv1d( 34 | in_channels, 35 | out_channels, 36 | kernel_size=kernel_size, 37 | stride=stride, 38 | padding=padding, 39 | dilation=dilation, 40 | bias=bias, 41 | ) 42 | 43 | def forward(self, x): 44 | x = x.contiguous().transpose(1, 2) 45 | x = self.conv(x) 46 | x = x.contiguous().transpose(1, 2) 47 | 48 | return x 49 | 50 | 51 | def initialize(model, init_type="pytorch"): 52 | """Initialize Transformer module 53 | 54 | :param torch.nn.Module model: core instance 55 | :param str init_type: initialization type 56 | """ 57 | if init_type == "pytorch": 58 | return 59 | 60 | # weight init 61 | for p in model.parameters(): 62 | if p.dim() > 1: 63 | if init_type == "xavier_uniform": 64 | torch.nn.init.xavier_uniform_(p.data) 65 | elif init_type == "xavier_normal": 66 | torch.nn.init.xavier_normal_(p.data) 67 | elif init_type == "kaiming_uniform": 68 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") 69 | elif init_type == "kaiming_normal": 70 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") 71 | else: 72 | raise ValueError("Unknown initialization: " + init_type) 73 | # bias init 74 | for p in model.parameters(): 75 | if p.dim() == 1: 76 | p.data.zero_() 77 | 78 | # reset some loss with default init 79 | for m in model.modules(): 80 | if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm)): 81 | m.reset_parameters() 82 | 83 | 84 | class MultiSequential(torch.nn.Sequential): 85 | """Multi-input multi-output torch.nn.Sequential""" 86 | 87 | def forward(self, *args): 88 | for m in self: 89 | args = m(*args) 90 | return args 91 | 92 | 93 | def repeat(N, fn): 94 | """repeat module N times 95 | 96 | :param int N: repeat time 97 | :param function fn: function to generate module 98 | :return: repeated loss 99 | :rtype: MultiSequential 100 | """ 101 | return MultiSequential(*[fn() for _ in range(N)]) 102 | 103 | 104 | # def layer_norm(x: torch.Tensor, dim): 105 | # if dim == -1: 106 | # return torch.nn.LayerNorm(x) 107 | # else: 108 | # out = torch.nn.LayerNorm(x.transpose(1, -1)) 109 | # return out.transpose(1, -1) 110 | 111 | 112 | class LayerNorm(torch.nn.Module): 113 | def __init__(self, nout: int): 114 | super(LayerNorm, self).__init__() 115 | self.layer_norm = torch.nn.LayerNorm(nout, eps=1e-12) 116 | 117 | def 
forward(self, x: torch.Tensor) -> torch.Tensor: 118 | x = self.layer_norm(x.transpose(1, -1)) 119 | x = x.transpose(1, -1) 120 | return x 121 | 122 | 123 | # class LayerNorm(torch.nn.LayerNorm): 124 | # """Layer normalization module 125 | # 126 | # :param int nout: output dim size 127 | # :param int dim: dimension to be normalized 128 | # """ 129 | # 130 | # def __init__(self, nout: int, dim: int=-1): 131 | # super(LayerNorm, self).__init__(nout, eps=1e-12) 132 | # self.dim = dim 133 | # 134 | # def forward(self, x: torch.Tensor) -> torch.Tensor: 135 | # """Apply layer normalization 136 | # 137 | # :param torch.Tensor x: input tensor 138 | # :return: layer normalized tensor 139 | # :rtype torch.Tensor 140 | # """ 141 | # if self.dim == -1: 142 | # return super(LayerNorm, self).forward(x) 143 | # return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) 144 | 145 | 146 | class Conv2dSubsampling(torch.nn.Module): 147 | """Convolutional 2D subsampling (to 1/4 length) 148 | 149 | :param int idim: input dim 150 | :param int odim: output dim 151 | :param flaot dropout_rate: dropout rate 152 | """ 153 | 154 | def __init__(self, idim: int, odim: int, dropout_rate: float): 155 | super(Conv2dSubsampling, self).__init__() 156 | self.conv = torch.nn.Sequential( 157 | torch.nn.Conv2d(1, odim, 3, 2), 158 | torch.nn.ReLU(), 159 | torch.nn.Conv2d(odim, odim, 3, 2), 160 | torch.nn.ReLU(), 161 | ) 162 | self.out = torch.nn.Sequential( 163 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), 164 | PositionalEncoding(odim, dropout_rate), 165 | ) 166 | 167 | def forward( 168 | self, x: torch.Tensor, x_mask: torch.Tensor 169 | ) -> Tuple[torch.Tensor, torch.Tensor]: 170 | """Subsample x 171 | 172 | :param torch.Tensor x: input tensor 173 | :param torch.Tensor x_mask: input mask 174 | :return: subsampled x and mask 175 | :rtype Tuple[torch.Tensor, torch.Tensor] 176 | """ 177 | x = x.unsqueeze(1) # (b, c, t, f) 178 | x = self.conv(x) 179 | b, c, t, f = x.size() 180 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 181 | if x_mask is None: 182 | return x, None 183 | return x, x_mask[:, :, :-2:2][:, :, :-2:2] 184 | 185 | 186 | class PositionwiseFeedForward(torch.nn.Module): 187 | """Positionwise feed forward 188 | 189 | :param int idim: input dimenstion 190 | :param int hidden_units: number of hidden units 191 | :param float dropout_rate: dropout rate 192 | """ 193 | 194 | def __init__(self, idim: int, hidden_units: int, dropout_rate: float): 195 | super(PositionwiseFeedForward, self).__init__() 196 | self.w_1 = torch.nn.Linear(idim, hidden_units) 197 | self.w_2 = torch.nn.Linear(hidden_units, idim) 198 | self.dropout = torch.nn.Dropout(dropout_rate) 199 | 200 | def forward(self, x: torch.Tensor) -> torch.Tensor: 201 | return self.w_2(self.dropout(torch.relu(self.w_1(x)))) 202 | 203 | 204 | class MultiLayeredConv1d(torch.nn.Module): 205 | """Multi-layered conv1d for Transformer block. 206 | 207 | This is a module of multi-leyered conv1d designed to replace positionwise feed-forward network 208 | in Transforner block, which is introduced in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. 209 | 210 | Args: 211 | in_chans (int): Number of input channels. 212 | hidden_chans (int): Number of hidden channels. 213 | kernel_size (int): Kernel size of conv1d. 214 | dropout_rate (float): Dropout rate. 215 | 216 | .. 
_`FastSpeech: Fast, Robust and Controllable Text to Speech`: 217 | https://arxiv.org/pdf/1905.09263.pdf 218 | 219 | """ 220 | 221 | def __init__( 222 | self, in_chans: int, hidden_chans: int, kernel_size: int, dropout_rate: float 223 | ): 224 | super(MultiLayeredConv1d, self).__init__() 225 | self.w_1 = torch.nn.Conv1d( 226 | in_chans, 227 | hidden_chans, 228 | kernel_size, 229 | stride=1, 230 | padding=(kernel_size - 1) // 2, 231 | ) 232 | self.w_2 = torch.nn.Conv1d( 233 | hidden_chans, in_chans, 1, stride=1, padding=(1 - 1) // 2 234 | ) 235 | self.dropout = torch.nn.Dropout(dropout_rate) 236 | 237 | def forward(self, x: torch.Tensor) -> torch.Tensor: 238 | """Calculate forward propagation. 239 | 240 | Args: 241 | x (Tensor): Batch of input tensors (B, *, in_chans). 242 | 243 | Returns: 244 | Tensor: Batch of output tensors (B, *, hidden_chans) 245 | 246 | """ 247 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) 248 | return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) 249 | 250 | 251 | class Postnet(torch.nn.Module): 252 | """Postnet module for Spectrogram prediction network. 253 | This is a module of Postnet in Spectrogram prediction network, 254 | which described in `Natural TTS Synthesis by 255 | Conditioning WaveNet on Mel Spectrogram Predictions`_. 256 | The Postnet predicts refines the predicted 257 | Mel-filterbank of the decoder, 258 | which helps to compensate the detail sturcture of spectrogram. 259 | .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: 260 | https://arxiv.org/abs/1712.05884 261 | """ 262 | 263 | def __init__( 264 | self, 265 | idim: int, 266 | odim: int, 267 | n_layers: int = 5, 268 | n_chans: int = 512, 269 | n_filts: int = 5, 270 | dropout_rate: float = 0.5, 271 | use_batch_norm: bool = True, 272 | ): 273 | """Initialize postnet module. 274 | Args: 275 | idim (int): Dimension of the inputs. 276 | odim (int): Dimension of the outputs. 277 | n_layers (int, optional): The number of layers. 278 | n_filts (int, optional): The number of filter size. 279 | n_units (int, optional): The number of filter channels. 280 | use_batch_norm (bool, optional): Whether to use batch normalization.. 281 | dropout_rate (float, optional): Dropout rate.. 
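        Example:
            Editor's illustrative sketch of the residual refinement this module
            performs; the shapes and values below are hypothetical, not taken
            from the repository configuration.

            >>> import torch
            >>> postnet = Postnet(idim=80, odim=80)
            >>> mel = torch.randn(2, 80, 100)   # (B, odim, Tmax) decoder output
            >>> refined = mel + postnet(mel)    # residual correction, same shape
            >>> refined.shape
            torch.Size([2, 80, 100])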
282 | """ 283 | super(Postnet, self).__init__() 284 | self.postnet = torch.nn.ModuleList() 285 | for layer in range(n_layers - 1): 286 | ichans = odim if layer == 0 else n_chans 287 | ochans = odim if layer == n_layers - 1 else n_chans 288 | if use_batch_norm: 289 | self.postnet += [ 290 | torch.nn.Sequential( 291 | torch.nn.Conv1d( 292 | ichans, 293 | ochans, 294 | n_filts, 295 | stride=1, 296 | padding=(n_filts - 1) // 2, 297 | bias=False, 298 | ), 299 | torch.nn.BatchNorm1d(ochans), 300 | torch.nn.Tanh(), 301 | torch.nn.Dropout(dropout_rate), 302 | ) 303 | ] 304 | else: 305 | self.postnet += [ 306 | torch.nn.Sequential( 307 | torch.nn.Conv1d( 308 | ichans, 309 | ochans, 310 | n_filts, 311 | stride=1, 312 | padding=(n_filts - 1) // 2, 313 | bias=False, 314 | ), 315 | torch.nn.Tanh(), 316 | torch.nn.Dropout(dropout_rate), 317 | ) 318 | ] 319 | ichans = n_chans if n_layers != 1 else odim 320 | if use_batch_norm: 321 | self.postnet += [ 322 | torch.nn.Sequential( 323 | torch.nn.Conv1d( 324 | ichans, 325 | odim, 326 | n_filts, 327 | stride=1, 328 | padding=(n_filts - 1) // 2, 329 | bias=False, 330 | ), 331 | torch.nn.BatchNorm1d(odim), 332 | torch.nn.Dropout(dropout_rate), 333 | ) 334 | ] 335 | else: 336 | self.postnet += [ 337 | torch.nn.Sequential( 338 | torch.nn.Conv1d( 339 | ichans, 340 | odim, 341 | n_filts, 342 | stride=1, 343 | padding=(n_filts - 1) // 2, 344 | bias=False, 345 | ), 346 | torch.nn.Dropout(dropout_rate), 347 | ) 348 | ] 349 | 350 | def forward(self, xs): 351 | """Calculate forward propagation. 352 | Args: 353 | xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax). 354 | Returns: 355 | Tensor: Batch of padded output tensor. (B, odim, Tmax). 356 | """ 357 | for postnet in self.postnet: 358 | xs = postnet(xs) 359 | return xs 360 | -------------------------------------------------------------------------------- /core/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class NoamOpt(object): 5 | "Optim wrapper that implements rate." 
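    # Editor's note: illustrative sketch only, not part of the original source.
    # The learning rate produced by ``rate()`` below follows the "Noam" schedule:
    #     lr(step) = factor * model_size**(-0.5) * min(step**(-0.5), step * warmup**(-1.5))
    # i.e. a linear warm-up for ``warmup`` steps followed by inverse-square-root
    # decay. A minimal usage sketch via ``get_std_opt`` defined at the end of
    # this file (the model and hyper-parameter values are hypothetical):
    #
    #     model = torch.nn.Linear(256, 80)
    #     opt = get_std_opt(model, d_model=256, warmup=4000, factor=1.0)
    #     for _ in range(5):
    #         opt.zero_grad()
    #         loss = model(torch.randn(8, 256)).pow(2).mean()
    #         loss.backward()
    #         opt.step()   # sets the scheduled lr, then updates the weights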
6 | 7 | def __init__(self, model_size, factor, warmup, optimizer): 8 | self.optimizer = optimizer 9 | self._step = 0 10 | self.warmup = warmup 11 | self.factor = factor 12 | self.model_size = model_size 13 | self._rate = 0 14 | 15 | @property 16 | def param_groups(self): 17 | return self.optimizer.param_groups 18 | 19 | def step(self): 20 | "Update parameters and rate" 21 | self._step += 1 22 | rate = self.rate() 23 | for p in self.optimizer.param_groups: 24 | p["lr"] = rate 25 | self._rate = rate 26 | self.optimizer.step() 27 | 28 | def rate(self, step=None): 29 | "Implement `lrate` above" 30 | if step is None: 31 | step = self._step 32 | return ( 33 | self.factor 34 | * self.model_size ** (-0.5) 35 | * min(step ** (-0.5), step * self.warmup ** (-1.5)) 36 | ) 37 | 38 | def zero_grad(self): 39 | self.optimizer.zero_grad() 40 | 41 | def state_dict(self): 42 | return { 43 | "_step": self._step, 44 | "warmup": self.warmup, 45 | "factor": self.factor, 46 | "model_size": self.model_size, 47 | "_rate": self._rate, 48 | "optimizer": self.optimizer.state_dict(), 49 | } 50 | 51 | def load_state_dict(self, state_dict): 52 | for key, value in state_dict.items(): 53 | if key == "optimizer": 54 | self.optimizer.load_state_dict(state_dict["optimizer"]) 55 | else: 56 | setattr(self, key, value) 57 | 58 | 59 | def get_std_opt(model, d_model, warmup, factor): 60 | base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9) 61 | return NoamOpt(d_model, factor, warmup, base) 62 | -------------------------------------------------------------------------------- /core/variance_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | from core.modules import LayerNorm 5 | 6 | 7 | class VariancePredictor(torch.nn.Module): 8 | def __init__( 9 | self, 10 | idim: int, 11 | n_layers: int = 2, 12 | n_chans: int = 256, 13 | out: int = 1, 14 | kernel_size: int = 3, 15 | dropout_rate: float = 0.5, 16 | offset: float = 1.0, 17 | ): 18 | super(VariancePredictor, self).__init__() 19 | self.offset = offset 20 | self.conv = torch.nn.ModuleList() 21 | for idx in range(n_layers): 22 | in_chans = idim if idx == 0 else n_chans 23 | self.conv += [ 24 | torch.nn.Sequential( 25 | torch.nn.Conv1d( 26 | in_chans, 27 | n_chans, 28 | kernel_size, 29 | stride=1, 30 | padding=(kernel_size - 1) // 2, 31 | ), 32 | torch.nn.ReLU(), 33 | LayerNorm(n_chans), 34 | torch.nn.Dropout(dropout_rate), 35 | ) 36 | ] 37 | self.linear = torch.nn.Linear(n_chans, out) 38 | 39 | def _forward( 40 | self, 41 | xs: torch.Tensor, 42 | is_inference: bool = False, 43 | is_log_output: bool = False, 44 | alpha: float = 1.0, 45 | ) -> torch.Tensor: 46 | xs = xs.transpose(1, -1) # (B, idim, Tmax) 47 | for f in self.conv: 48 | xs = f(xs) # (B, C, Tmax) 49 | 50 | # NOTE: calculate in log domain 51 | xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax) 52 | 53 | if is_inference and is_log_output: 54 | # # NOTE: calculate in linear domain 55 | xs = torch.clamp( 56 | torch.round(xs.exp() - self.offset), min=0 57 | ).long() # avoid negative value 58 | xs = xs * alpha 59 | 60 | return xs 61 | 62 | def forward( 63 | self, xs: torch.Tensor, x_masks: Optional[torch.Tensor] = None 64 | ) -> torch.Tensor: 65 | """Calculate forward propagation. 66 | 67 | Args: 68 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 69 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). 
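        Example:
            Editor's illustrative sketch (hypothetical sizes, no padding mask);
            the module emits one scalar prediction per input position.

            >>> import torch
            >>> predictor = VariancePredictor(idim=256)
            >>> xs = torch.randn(2, 50, 256)   # (B, Tmax, idim) encoder outputs
            >>> outs = predictor(xs)           # (B, Tmax)
            >>> outs.shape
            torch.Size([2, 50])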
70 | 71 | Returns: 72 | Tensor: Batch of predicted durations in log domain (B, Tmax). 73 | 74 | """ 75 | xs = self._forward(xs) 76 | if x_masks is not None: 77 | xs = xs.masked_fill(x_masks, 0.0) 78 | return xs 79 | 80 | def inference( 81 | self, xs: torch.Tensor, is_log_output: bool = False, alpha: float = 1.0 82 | ) -> torch.Tensor: 83 | """Inference duration. 84 | 85 | Args: 86 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 87 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). 88 | 89 | Returns: 90 | LongTensor: Batch of predicted durations in linear domain (B, Tmax). 91 | 92 | """ 93 | return self._forward( 94 | xs, is_inference=True, is_log_output=is_log_output, alpha=alpha 95 | ) 96 | 97 | 98 | class EnergyPredictor(torch.nn.Module): 99 | def __init__( 100 | self, 101 | idim, 102 | n_layers=2, 103 | n_chans=256, 104 | kernel_size=3, 105 | dropout_rate=0.1, 106 | offset=1.0, 107 | min=0, 108 | max=0, 109 | n_bins=256, 110 | ): 111 | """Initilize Energy predictor module. 112 | 113 | Args: 114 | idim (int): Input dimension. 115 | n_layers (int, optional): Number of convolutional layers. 116 | n_chans (int, optional): Number of channels of convolutional layers. 117 | kernel_size (int, optional): Kernel size of convolutional layers. 118 | dropout_rate (float, optional): Dropout rate. 119 | offset (float, optional): Offset value to avoid nan in log domain. 120 | 121 | """ 122 | super(EnergyPredictor, self).__init__() 123 | # self.bins = torch.linspace(min, max, n_bins - 1).cuda() 124 | self.register_buffer("energy_bins", torch.linspace(min, max, n_bins - 1)) 125 | self.predictor = VariancePredictor(idim) 126 | 127 | def forward(self, xs: torch.Tensor, x_masks: torch.Tensor): 128 | """Calculate forward propagation. 129 | 130 | Args: 131 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 132 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). 133 | 134 | Returns: 135 | Tensor: Batch of predicted durations in log domain (B, Tmax). 136 | 137 | """ 138 | return self.predictor(xs, x_masks) 139 | 140 | def inference(self, xs: torch.Tensor, alpha: float = 1.0): 141 | """Inference duration. 142 | 143 | Args: 144 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 145 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). 146 | 147 | Returns: 148 | LongTensor: Batch of predicted durations in linear domain (B, Tmax). 149 | 150 | """ 151 | out = self.predictor.inference(xs, False, alpha=alpha) 152 | return self.to_one_hot(out) # Need to do One hot code 153 | 154 | def to_one_hot(self, x): 155 | # e = de_norm_mean_std(e, hp.e_mean, hp.e_std) 156 | # For pytorch > = 1.6.0 157 | 158 | quantize = torch.bucketize(x, self.energy_bins).to(device=x.device) # .cuda() 159 | return F.one_hot(quantize.long(), 256).float() 160 | 161 | 162 | class PitchPredictor(torch.nn.Module): 163 | def __init__( 164 | self, 165 | idim, 166 | n_layers=2, 167 | n_chans=384, 168 | kernel_size=3, 169 | dropout_rate=0.1, 170 | offset=1.0, 171 | min=0, 172 | max=0, 173 | n_bins=256, 174 | ): 175 | """Initilize pitch predictor module. 176 | 177 | Args: 178 | idim (int): Input dimension. 179 | n_layers (int, optional): Number of convolutional layers. 180 | n_chans (int, optional): Number of channels of convolutional layers. 181 | kernel_size (int, optional): Kernel size of convolutional layers. 182 | dropout_rate (float, optional): Dropout rate. 183 | offset (float, optional): Offset value to avoid nan in log domain. 
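        Example:
            Editor's illustrative sketch; the dimension and pitch range below
            are hypothetical, not taken from the repository configuration.

            >>> import torch
            >>> predictor = PitchPredictor(idim=256, min=80.0, max=400.0)
            >>> hs = torch.randn(2, 50, 256)        # (B, Tmax, idim) hidden states
            >>> one_hot = predictor.inference(hs)   # bucketized pitch as one-hot vectors
            >>> one_hot.shape
            torch.Size([2, 50, 256])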
184 | 185 | """ 186 | super(PitchPredictor, self).__init__() 187 | # self.bins = torch.exp(torch.linspace(torch.log(torch.tensor(min)), torch.log(torch.tensor(max)), n_bins - 1)).cuda() 188 | self.register_buffer( 189 | "pitch_bins", 190 | torch.exp( 191 | torch.linspace( 192 | torch.log(torch.tensor(min)), 193 | torch.log(torch.tensor(max)), 194 | n_bins - 1, 195 | ) 196 | ), 197 | ) 198 | self.predictor = VariancePredictor(idim) 199 | 200 | def forward(self, xs: torch.Tensor, x_masks: torch.Tensor): 201 | """Calculate forward propagation. 202 | 203 | Args: 204 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 205 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). 206 | 207 | Returns: 208 | Tensor: Batch of predicted durations in log domain (B, Tmax). 209 | 210 | """ 211 | return self.predictor(xs, x_masks) 212 | 213 | def inference(self, xs: torch.Tensor, alpha: float = 1.0): 214 | """Inference duration. 215 | 216 | Args: 217 | xs (Tensor): Batch of input sequences (B, Tmax, idim). 218 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). 219 | 220 | Returns: 221 | LongTensor: Batch of predicted durations in linear domain (B, Tmax). 222 | 223 | """ 224 | out = self.predictor.inference(xs, False, alpha=alpha) 225 | return self.to_one_hot(out) 226 | 227 | def to_one_hot(self, x: torch.Tensor): 228 | # e = de_norm_mean_std(e, hp.e_mean, hp.e_std) 229 | # For pytorch > = 1.6.0 230 | 231 | quantize = torch.bucketize(x, self.pitch_bins).to(device=x.device) # .cuda() 232 | return F.one_hot(quantize.long(), 256).float() 233 | 234 | 235 | class PitchPredictorLoss(torch.nn.Module): 236 | """Loss function module for duration predictor. 237 | 238 | The loss value is Calculated in log domain to make it Gaussian. 239 | 240 | """ 241 | 242 | def __init__(self, offset=1.0): 243 | """Initilize duration predictor loss module. 244 | 245 | Args: 246 | offset (float, optional): Offset value to avoid nan in log domain. 247 | 248 | """ 249 | super(PitchPredictorLoss, self).__init__() 250 | self.criterion = torch.nn.MSELoss() 251 | self.offset = offset 252 | 253 | def forward(self, outputs, targets): 254 | """Calculate forward propagation. 255 | 256 | Args: 257 | outputs (Tensor): Batch of prediction durations in log domain (B, T) 258 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) 259 | 260 | Returns: 261 | Tensor: Mean squared error loss value. 262 | 263 | Note: 264 | `outputs` is in log domain but `targets` is in linear domain. 265 | 266 | """ 267 | # NOTE: We convert the output in log domain low error value 268 | # print("Output :", outputs[0]) 269 | # print("Before Output :", targets[0]) 270 | # targets = torch.log(targets.float() + self.offset) 271 | # print("Before Output :", targets[0]) 272 | # outputs = torch.log(outputs.float() + self.offset) 273 | loss = self.criterion(outputs, targets) 274 | # print(loss) 275 | return loss 276 | 277 | 278 | class EnergyPredictorLoss(torch.nn.Module): 279 | """Loss function module for duration predictor. 280 | 281 | The loss value is Calculated in log domain to make it Gaussian. 282 | 283 | """ 284 | 285 | def __init__(self, offset=1.0): 286 | """Initilize duration predictor loss module. 287 | 288 | Args: 289 | offset (float, optional): Offset value to avoid nan in log domain. 
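        Example:
            Editor's illustrative sketch with dummy tensors; since the log
            conversion in the forward pass is commented out, the loss reduces
            to a plain MSE between the two sequences.

            >>> import torch
            >>> criterion = EnergyPredictorLoss()
            >>> e_outs = torch.randn(2, 120)          # predicted energy (B, Lmax)
            >>> e_target = torch.randn(2, 120)        # ground-truth energy (B, Lmax)
            >>> loss = criterion(e_outs, e_target)    # scalar (0-dim) tensor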
290 | 291 | """ 292 | super(EnergyPredictorLoss, self).__init__() 293 | self.criterion = torch.nn.MSELoss() 294 | self.offset = offset 295 | 296 | def forward(self, outputs, targets): 297 | """Calculate forward propagation. 298 | 299 | Args: 300 | outputs (Tensor): Batch of prediction durations in log domain (B, T) 301 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) 302 | 303 | Returns: 304 | Tensor: Mean squared error loss value. 305 | 306 | Note: 307 | `outputs` is in log domain but `targets` is in linear domain. 308 | 309 | """ 310 | # NOTE: outputs is in log domain while targets in linear 311 | # targets = torch.log(targets.float() + self.offset) 312 | loss = self.criterion(outputs, targets) 313 | 314 | return loss 315 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/AdaSpeech/c208a4fb83b09f4943cb4faf3c007892dc3f6b7b/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/audio_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import librosa 4 | from scipy.signal import lfilter 5 | import pyworld as pw 6 | import torch 7 | from scipy.signal import get_window 8 | import librosa.util as librosa_util 9 | 10 | 11 | def label_2_float(x, bits): 12 | return 2 * x / (2 ** bits - 1.0) - 1.0 13 | 14 | 15 | def float_2_label(x, bits): 16 | assert abs(x).max() <= 1.0 17 | x = (x + 1.0) * (2 ** bits - 1) / 2 18 | return x.clip(0, 2 ** bits - 1) 19 | 20 | 21 | def load_wav(path, hp): 22 | return librosa.load(path, sr=hp.audio.sample_rate)[0] 23 | 24 | 25 | def save_wav(x, path, hp): 26 | librosa.output.write_wav(path, x.astype(np.float32), sr=hp.audio.sample_rate) 27 | 28 | 29 | def split_signal(x): 30 | unsigned = x + 2 ** 15 31 | coarse = unsigned // 256 32 | fine = unsigned % 256 33 | return coarse, fine 34 | 35 | 36 | def combine_signal(coarse, fine): 37 | return coarse * 256 + fine - 2 ** 15 38 | 39 | 40 | def encode_16bits(x): 41 | return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) 42 | 43 | 44 | mel_basis = None 45 | 46 | 47 | def energy(y): 48 | # Extract energy 49 | S = librosa.magphase(stft(y))[0] 50 | e = np.sqrt(np.sum(S ** 2, axis=0)) # np.linalg.norm(S, axis=0) 51 | return e.squeeze() # (Number of frames) => (654,) 52 | 53 | 54 | def pitch(y, hp): 55 | # Extract Pitch/f0 from raw waveform using PyWORLD 56 | y = y.astype(np.float64) 57 | """ 58 | f0_floor : float 59 | Lower F0 limit in Hz. 60 | Default: 71.0 61 | f0_ceil : float 62 | Upper F0 limit in Hz. 
63 | Default: 800.0 64 | """ 65 | f0, timeaxis = pw.dio( 66 | y, 67 | hp.audio.sample_rate, 68 | frame_period=hp.audio.hop_length / hp.audio.sample_rate * 1000, 69 | ) # For hop size 256 frame period is 11.6 ms 70 | return f0 # (Number of Frames) = (654,) 71 | 72 | 73 | def linear_to_mel(spectrogram, hp): 74 | global mel_basis 75 | if mel_basis is None: 76 | mel_basis = build_mel_basis(hp) 77 | return np.dot(mel_basis, spectrogram) 78 | 79 | 80 | def build_mel_basis(hp): 81 | return librosa.filters.mel( 82 | hp.audio.sample_rate, 83 | hp.audio.n_fft, 84 | n_mels=hp.audio.num_mels, 85 | fmin=hp.audio.fmin, 86 | ) 87 | 88 | 89 | def normalize(S, hp): 90 | return np.clip((S - hp.audio.min_level_db) / -hp.audio.min_level_db, 0, 1) 91 | 92 | 93 | def denormalize(S, hp): 94 | return (np.clip(S, 0, 1) * -hp.audio.min_level_db) + hp.audio.min_level_db 95 | 96 | 97 | def amp_to_db(x): 98 | return 20 * np.log10(np.maximum(1e-5, x)) 99 | 100 | 101 | def db_to_amp(x): 102 | return np.power(10.0, x * 0.05) 103 | 104 | 105 | def spectrogram(y, hp): 106 | D = stft(y, hp) 107 | S = amp_to_db(np.abs(D)) - hp.audio.ref_level_db 108 | return normalize(S, hp) 109 | 110 | 111 | def melspectrogram(y, hp): 112 | D = stft(y, hp) 113 | S = amp_to_db(linear_to_mel(np.abs(D), hp)) 114 | return normalize(S, hp) 115 | 116 | 117 | def stft(y, hp): 118 | return librosa.stft( 119 | y=y, 120 | n_fft=hp.audio.n_fft, 121 | hop_length=hp.audio.hop_length, 122 | win_length=hp.audio.win_length, 123 | ) 124 | 125 | 126 | def pre_emphasis(x, hp): 127 | return lfilter([1, -hp.audio.preemphasis], [1], x) 128 | 129 | 130 | def de_emphasis(x, hp): 131 | return lfilter([1], [1, -hp.audio.preemphasis], x) 132 | 133 | 134 | def encode_mu_law(x, mu): 135 | mu = mu - 1 136 | fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) 137 | return np.floor((fx + 1) / 2 * mu + 0.5) 138 | 139 | 140 | def decode_mu_law(y, mu, from_labels=True): 141 | # TODO : get rid of log2 - makes no sense 142 | if from_labels: 143 | y = label_2_float(y, math.log2(mu)) 144 | mu = mu - 1 145 | x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1) 146 | return x 147 | 148 | 149 | def reconstruct_waveform(mel, hp, n_iter=32): 150 | """Uses Griffin-Lim phase reconstruction to convert from a normalized 151 | mel spectrogram back into a waveform.""" 152 | denormalized = denormalize(mel) 153 | amp_mel = db_to_amp(denormalized) 154 | S = librosa.feature.inverse.mel_to_stft( 155 | amp_mel, 156 | power=1, 157 | sr=hp.audio.sample_rate, 158 | n_fft=hp.audio.n_fft, 159 | fmin=hp.audio.fmin, 160 | ) 161 | wav = librosa.core.griffinlim( 162 | S, n_iter=n_iter, hop_length=hp.audio.hop_length, win_length=hp.audio.win_length 163 | ) 164 | return wav 165 | 166 | 167 | def quantize_input(input, min, max, num_bins=256): 168 | bins = np.linspace(min, max, num=num_bins) 169 | quantize = np.digitize(input, bins) 170 | return quantize 171 | 172 | 173 | def window_sumsquare( 174 | window, 175 | n_frames, 176 | hop_length=200, 177 | win_length=800, 178 | n_fft=800, 179 | dtype=np.float32, 180 | norm=None, 181 | ): 182 | """ 183 | # from librosa 0.6 184 | Compute the sum-square envelope of a window function at a given hop length. 185 | This is used to estimate modulation effects induced by windowing 186 | observations in short-time fourier transforms. 
187 | Parameters 188 | ---------- 189 | window : string, tuple, number, callable, or list-like 190 | Window specification, as in `get_window` 191 | n_frames : int > 0 192 | The number of analysis frames 193 | hop_length : int > 0 194 | The number of samples to advance between frames 195 | win_length : [optional] 196 | The length of the window function. By default, this matches `n_fft`. 197 | n_fft : int > 0 198 | The length of each analysis frame. 199 | dtype : np.dtype 200 | The data type of the output 201 | Returns 202 | ------- 203 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 204 | The sum-squared envelope of the window function 205 | """ 206 | if win_length is None: 207 | win_length = n_fft 208 | 209 | n = n_fft + hop_length * (n_frames - 1) 210 | x = np.zeros(n, dtype=dtype) 211 | 212 | # Compute the squared window at the desired length 213 | win_sq = get_window(window, win_length, fftbins=True) 214 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 215 | win_sq = librosa_util.pad_center(win_sq, n_fft) 216 | 217 | # Fill the envelope 218 | for i in range(n_frames): 219 | sample = i * hop_length 220 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 221 | return x 222 | 223 | 224 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 225 | """ 226 | PARAMS 227 | ------ 228 | magnitudes: spectrogram magnitudes 229 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 230 | """ 231 | 232 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 233 | angles = angles.astype(np.float32) 234 | angles = torch.autograd.Variable(torch.from_numpy(angles).cuda()) 235 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 236 | 237 | for i in range(n_iters): 238 | _, angles = stft_fn.transform(signal) 239 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 240 | return signal 241 | 242 | 243 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 244 | """ 245 | PARAMS 246 | ------ 247 | C: compression factor 248 | """ 249 | return torch.log(torch.clamp(x, min=clip_val) * C) 250 | 251 | 252 | def dynamic_range_decompression(x, C=1): 253 | """ 254 | PARAMS 255 | ------ 256 | C: compression factor used to compress 257 | """ 258 | return torch.exp(x) / C 259 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch.utils.data.sampler import Sampler 5 | from dataset.texts import phonemes_to_sequence 6 | import numpy as np 7 | from dataset.texts import text_to_sequence 8 | from utils.util import pad_list, str_to_int_list, remove_outlier 9 | 10 | 11 | def get_tts_dataset(path, batch_size, hp, valid=False): 12 | 13 | if valid: 14 | file_ = hp.data.valid_filelist 15 | pin_mem = False 16 | num_workers = 0 17 | shuffle = False 18 | else: 19 | file_ = hp.data.train_filelist 20 | pin_mem = True 21 | num_workers = 4 22 | shuffle = True 23 | train_dataset = TTSDataset( 24 | path, file_, hp.train.use_phonemes, hp.data.tts_cleaner_names, hp.train.eos 25 | ) 26 | 27 | train_set = DataLoader( 28 | train_dataset, 29 | collate_fn=collate_tts, 30 | batch_size=batch_size, 31 | num_workers=num_workers, 32 | shuffle=shuffle, 33 | pin_memory=pin_mem, 34 | ) 35 | return train_set 36 | 37 | 38 | class TTSDataset(Dataset): 39 | def __init__(self, path, file_, use_phonemes, tts_cleaner_names, 
eos): 40 | self.path = path 41 | with open("{}".format(file_), encoding="utf-8") as f: 42 | self._metadata = [line.strip().split("|") for line in f] 43 | self.use_phonemes = use_phonemes 44 | self.tts_cleaner_names = tts_cleaner_names 45 | self.eos = eos 46 | 47 | def __getitem__(self, index): 48 | id = self._metadata[index][4].split(".")[0] 49 | x_ = self._metadata[index][3].split() 50 | if self.use_phonemes: 51 | x = phonemes_to_sequence(x_) 52 | else: 53 | x = text_to_sequence(x_, self.tts_cleaner_names, self.eos) 54 | mel = np.load(f"{self.path}mels/{id}.npy") 55 | durations = str_to_int_list(self._metadata[index][2]) 56 | e = remove_outlier( 57 | np.load(f"{self.path}energy/{id}.npy") 58 | ) # self._norm_mean_std(np.load(f'{self.path}energy/{id}.npy'), self.e_mean, self.e_std, True) 59 | p = remove_outlier( 60 | np.load(f"{self.path}pitch/{id}.npy") 61 | ) # self._norm_mean_std(np.load(f'{self.path}pitch/{id}.npy'), self.f0_mean, self.f0_std, True) 62 | mel_len = mel.shape[1] 63 | durations = durations[: len(x)] 64 | durations[-1] = durations[-1] + (mel.shape[1] - sum(durations)) 65 | assert mel.shape[1] == sum(durations) 66 | 67 | avg_mel = np.load(f"{self.path}avg_mel_ph/{id}.npy") 68 | assert avg_mel.shape[0] == len(x) 69 | 70 | 71 | return ( 72 | np.array(x), 73 | mel.T, 74 | id, 75 | mel_len, 76 | np.array(durations), 77 | e, 78 | p, 79 | avg_mel, 80 | ) # Mel [T, num_mel] 81 | 82 | def __len__(self): 83 | return len(self._metadata) 84 | 85 | def _norm_mean_std(self, x, mean, std, is_remove_outlier=False): 86 | if is_remove_outlier: 87 | x = remove_outlier(x) 88 | zero_idxs = np.where(x == 0.0)[0] 89 | x = (x - mean) / std 90 | x[zero_idxs] = 0.0 91 | return x 92 | 93 | 94 | def pad1d(x, max_len): 95 | return np.pad(x, (0, max_len - len(x)), mode="constant") 96 | 97 | 98 | def pad2d(x, max_len): 99 | return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant") 100 | 101 | 102 | def collate_tts(batch): 103 | 104 | ilens = torch.from_numpy(np.array([x[0].shape[0] for x in batch])).long() 105 | olens = torch.from_numpy(np.array([y[1].shape[0] for y in batch])).long() 106 | ids = [x[2] for x in batch] 107 | 108 | # perform padding and conversion to tensor 109 | inputs = pad_list([torch.from_numpy(x[0]).long() for x in batch], 0) 110 | mels = pad_list([torch.from_numpy(y[1]).float() for y in batch], 0) 111 | 112 | durations = pad_list([torch.from_numpy(x[4]).long() for x in batch], 0) 113 | energys = pad_list([torch.from_numpy(y[5]).float() for y in batch], 0) 114 | pitches = pad_list([torch.from_numpy(y[6]).float() for y in batch], 0) 115 | 116 | avg_mel = pad_list([torch.from_numpy(y[7]).float() for y in batch], 0) 117 | 118 | # make labels for stop prediction 119 | labels = mels.new_zeros(mels.size(0), mels.size(1)) 120 | for i, l in enumerate(olens): 121 | labels[i, l - 1 :] = 1.0 122 | 123 | # scale spectrograms to -4 <--> 4 124 | # mels = (mels * 8.) 
- 4 125 | 126 | return inputs, ilens, mels, labels, olens, ids, durations, energys, pitches, avg_mel 127 | 128 | 129 | class BinnedLengthSampler(Sampler): 130 | def __init__(self, lengths, batch_size, bin_size): 131 | _, self.idx = torch.sort(torch.tensor(lengths).long()) 132 | self.batch_size = batch_size 133 | self.bin_size = bin_size 134 | assert self.bin_size % self.batch_size == 0 135 | 136 | def __iter__(self): 137 | # Need to change to numpy since there's a bug in random.shuffle(tensor) 138 | # TODO : Post an issue on pytorch repo 139 | idx = self.idx.numpy() 140 | bins = [] 141 | 142 | for i in range(len(idx) // self.bin_size): 143 | this_bin = idx[i * self.bin_size : (i + 1) * self.bin_size] 144 | random.shuffle(this_bin) 145 | bins += [this_bin] 146 | 147 | random.shuffle(bins) 148 | binned_idx = np.stack(bins).reshape(-1) 149 | 150 | if len(binned_idx) < len(idx): 151 | last_bin = idx[len(binned_idx) :] 152 | random.shuffle(last_bin) 153 | binned_idx = np.concatenate([binned_idx, last_bin]) 154 | 155 | return iter(torch.tensor(binned_idx).long()) 156 | 157 | def __len__(self): 158 | return len(self.idx) 159 | -------------------------------------------------------------------------------- /dataset/ljspeech.py: -------------------------------------------------------------------------------- 1 | from utils.util import get_files 2 | 3 | 4 | def ljspeech(path, hp): 5 | 6 | csv_file = get_files(path, extension=".csv") 7 | 8 | assert len(csv_file) == 1 9 | 10 | wavs = [] 11 | # texts = [] 12 | # encode = [] 13 | 14 | with open(csv_file[0], encoding="utf-8") as f_: 15 | # if 'phoneme_cleaners' in hp.tts_cleaner_names: 16 | # print("Cleaner : {} Language Code : {}\n".format(hp.tts_cleaner_names[0],hp.phoneme_language)) 17 | # for line in f : 18 | # split = line.split('|') 19 | # text_dict[split[0]] = text2phone(split[-1].strip(),hp.phoneme_language) 20 | # else: 21 | print("Cleaner : {} \n".format(hp.tts_cleaner_names)) 22 | for line in f_: 23 | sub = {} 24 | split = line.split("|") 25 | t = split[-1].strip().upper() 26 | # t = t.replace('"', '') 27 | # t = t.replace('-', ' ') 28 | # t = t.replace(';','') 29 | # t = t.replace('(', '') 30 | # t = t.replace(')', '') 31 | # t = t.replace(':', '') 32 | # t = re.sub('[^A-Za-z0-9.!?,\' ]+', '', t) 33 | if len(t) > 0: 34 | wavs.append(split[0].strip()) 35 | # texts.append(t) 36 | # encode.append(text_to_sequence(t, hp.tts_cleaner_names)) 37 | # with open(os.path.join(data_dir, 'train.txt'), 'w', encoding='utf-8') as f: 38 | # for w, t, e in zip(wavs, texts, encode): 39 | # f.write('{}|{}|{}'.format(w,e,t) + '\n') 40 | 41 | return wavs # , texts, encode 42 | 43 | 44 | if __name__ == "__main__": 45 | ljspeech("metadata.csv", ["english_cleaners"]) 46 | -------------------------------------------------------------------------------- /dataset/texts/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from dataset.texts import cleaners 4 | from dataset.texts.symbols import ( 5 | symbols, 6 | _eos, 7 | phonemes_symbols, 8 | PAD, 9 | EOS, 10 | _PHONEME_SEP, 11 | ) 12 | from dataset.texts.dict_ import symbols_ 13 | import nltk 14 | from g2p_en import G2p 15 | 16 | # Mappings from symbol to numeric ID and vice versa: 17 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 18 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 19 | 20 | # Regular expression matching text enclosed in curly braces: 21 | _curly_re = 
re.compile(r"(.*?)\{(.+?)\}(.*)") 22 | 23 | symbols_inv = {v: k for k, v in symbols_.items()} 24 | 25 | valid_symbols = [ 26 | "AA", 27 | "AA1", 28 | "AE", 29 | "AE0", 30 | "AE1", 31 | "AH", 32 | "AH0", 33 | "AH1", 34 | "AO", 35 | "AO1", 36 | "AW", 37 | "AW0", 38 | "AW1", 39 | "AY", 40 | "AY0", 41 | "AY1", 42 | "B", 43 | "CH", 44 | "D", 45 | "DH", 46 | "EH", 47 | "EH0", 48 | "EH1", 49 | "ER", 50 | "EY", 51 | "EY0", 52 | "EY1", 53 | "F", 54 | "G", 55 | "HH", 56 | "IH", 57 | "IH0", 58 | "IH1", 59 | "IY", 60 | "IY0", 61 | "IY1", 62 | "JH", 63 | "K", 64 | "L", 65 | "M", 66 | "N", 67 | "NG", 68 | "OW", 69 | "OW0", 70 | "OW1", 71 | "OY", 72 | "OY0", 73 | "OY1", 74 | "P", 75 | "R", 76 | "S", 77 | "SH", 78 | "T", 79 | "TH", 80 | "UH", 81 | "UH0", 82 | "UH1", 83 | "UW", 84 | "UW0", 85 | "UW1", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | "pau", 92 | "sil", 93 | "spn" 94 | ] 95 | 96 | 97 | def pad_with_eos_bos(_sequence): 98 | return _sequence + [_symbol_to_id[_eos]] 99 | 100 | 101 | def text_to_sequence(text, cleaner_names, eos): 102 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 103 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 104 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 105 | Args: 106 | text: string to convert to a sequence 107 | cleaner_names: names of the cleaner functions to run the text through 108 | Returns: 109 | List of integers corresponding to the symbols in the text 110 | """ 111 | sequence = [] 112 | if eos: 113 | text = text + "~" 114 | try: 115 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 116 | except KeyError: 117 | print("text : ", text) 118 | exit(0) 119 | 120 | return sequence 121 | 122 | 123 | def sequence_to_text(sequence): 124 | """Converts a sequence of IDs back to a string""" 125 | result = "" 126 | for symbol_id in sequence: 127 | if symbol_id in symbols_inv: 128 | s = symbols_inv[symbol_id] 129 | # Enclose ARPAbet back in curly braces: 130 | if len(s) > 1 and s[0] == "@": 131 | s = "{%s}" % s[1:] 132 | result += s 133 | return result.replace("}{", " ") 134 | 135 | 136 | def _clean_text(text, cleaner_names): 137 | for name in cleaner_names: 138 | cleaner = getattr(cleaners, name) 139 | if not cleaner: 140 | raise Exception("Unknown cleaner: %s" % name) 141 | text = cleaner(text) 142 | return text 143 | 144 | 145 | def _symbols_to_sequence(symbols): 146 | return [symbols_[s.upper()] for s in symbols] 147 | 148 | 149 | def _arpabet_to_sequence(text): 150 | return _symbols_to_sequence(["@" + s for s in text.split()]) 151 | 152 | 153 | def _should_keep_symbol(s): 154 | return s in _symbol_to_id and s != "_" and s != "~" 155 | 156 | 157 | # For phonemes 158 | _phoneme_to_id = {s: i for i, s in enumerate(valid_symbols)} 159 | _id_to_phoneme = {i: s for i, s in enumerate(valid_symbols)} 160 | 161 | 162 | def _should_keep_token(token, token_dict): 163 | return ( 164 | token in token_dict 165 | and token != PAD 166 | and token != EOS 167 | and token != _phoneme_to_id[PAD] 168 | and token != _phoneme_to_id[EOS] 169 | ) 170 | 171 | 172 | def phonemes_to_sequence(phonemes): 173 | string = phonemes.split() if isinstance(phonemes, str) else phonemes 174 | # string.append(EOS) 175 | sequence = list(map(convert_phoneme_CMU, string)) 176 | sequence = [_phoneme_to_id[s] for s in sequence] 177 | # if _should_keep_token(s, _phoneme_to_id)] 178 | return sequence 179 | 180 | 181 | def sequence_to_phonemes(sequence, use_eos=False): 182 | string = 
[_id_to_phoneme[idx] for idx in sequence] 183 | # if _should_keep_token(idx, _id_to_phoneme)] 184 | string = _PHONEME_SEP.join(string) 185 | if use_eos: 186 | string = string.replace(EOS, "") 187 | return string 188 | 189 | 190 | def convert_phoneme_CMU(phoneme): 191 | REMAPPING = { 192 | 'AA0': 'AA1', 193 | 'AA2': 'AA1', 194 | 'AE2': 'AE1', 195 | 'AH2': 'AH1', 196 | 'AO0': 'AO1', 197 | 'AO2': 'AO1', 198 | 'AW2': 'AW1', 199 | 'AY2': 'AY1', 200 | 'EH2': 'EH1', 201 | 'ER0': 'EH1', 202 | 'ER1': 'EH1', 203 | 'ER2': 'EH1', 204 | 'EY2': 'EY1', 205 | 'IH2': 'IH1', 206 | 'IY2': 'IY1', 207 | 'OW2': 'OW1', 208 | 'OY2': 'OY1', 209 | 'UH2': 'UH1', 210 | 'UW2': 'UW1', 211 | } 212 | return REMAPPING.get(phoneme, phoneme) 213 | 214 | 215 | def text_to_phonemes(text, custom_words={}): 216 | """ 217 | Convert text into ARPAbet. 218 | For known words use CMUDict; for the rest try 'espeak' (to IPA) followed by 'listener'. 219 | :param text: str, input text. 220 | :param custom_words: 221 | dict {str: list of str}, optional 222 | Pronounciations (a list of ARPAbet phonemes) you'd like to override. 223 | Example: {'word': ['W', 'EU1', 'R', 'D']} 224 | :return: list of str, phonemes 225 | """ 226 | g2p = G2p() 227 | 228 | """def convert_phoneme_CMU(phoneme): 229 | REMAPPING = { 230 | 'AA0': 'AA1', 231 | 'AA2': 'AA1', 232 | 'AE2': 'AE1', 233 | 'AH2': 'AH1', 234 | 'AO0': 'AO1', 235 | 'AO2': 'AO1', 236 | 'AW2': 'AW1', 237 | 'AY2': 'AY1', 238 | 'EH2': 'EH1', 239 | 'ER0': 'EH1', 240 | 'ER1': 'EH1', 241 | 'ER2': 'EH1', 242 | 'EY2': 'EY1', 243 | 'IH2': 'IH1', 244 | 'IY2': 'IY1', 245 | 'OW2': 'OW1', 246 | 'OY2': 'OY1', 247 | 'UH2': 'UH1', 248 | 'UW2': 'UW1', 249 | } 250 | return REMAPPING.get(phoneme, phoneme) 251 | """ 252 | 253 | def convert_phoneme_listener(phoneme): 254 | VOWELS = ['A', 'E', 'I', 'O', 'U'] 255 | if phoneme[0] in VOWELS: 256 | phoneme += '1' 257 | return phoneme # convert_phoneme_CMU(phoneme) 258 | 259 | try: 260 | known_words = nltk.corpus.cmudict.dict() 261 | except LookupError: 262 | nltk.download("cmudict") 263 | known_words = nltk.corpus.cmudict.dict() 264 | 265 | for word, phonemes in custom_words.items(): 266 | known_words[word.lower()] = [phonemes] 267 | 268 | words = nltk.tokenize.WordPunctTokenizer().tokenize(text.lower()) 269 | 270 | phonemes = [] 271 | PUNCTUATION = "!?.,-:;\"'()" 272 | for word in words: 273 | if all(c in PUNCTUATION for c in word): 274 | pronounciation = ["pau"] 275 | elif word in known_words: 276 | pronounciation = known_words[word][0] 277 | pronounciation = list( 278 | pronounciation 279 | ) # map(convert_phoneme_CMU, pronounciation)) 280 | else: 281 | pronounciation = g2p(word) 282 | pronounciation = list( 283 | pronounciation 284 | ) # (map(convert_phoneme_CMU, pronounciation)) 285 | 286 | phonemes += pronounciation 287 | 288 | return phonemes 289 | -------------------------------------------------------------------------------- /dataset/texts/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. 
"transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | """ 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | 21 | _whitespace_re = re.compile(r"\s+") 22 | punctuations = """+-!()[]{};:'"\<>/?@#^&*_~""" 23 | 24 | # List of (regular expression, replacement) pairs for abbreviations: 25 | _abbreviations = [ 26 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) 27 | for x in [ 28 | ("mrs", "misess"), 29 | ("mr", "mister"), 30 | ("dr", "doctor"), 31 | ("st", "saint"), 32 | ("co", "company"), 33 | ("jr", "junior"), 34 | ("maj", "major"), 35 | ("gen", "general"), 36 | ("drs", "doctors"), 37 | ("rev", "reverend"), 38 | ("lt", "lieutenant"), 39 | ("hon", "honorable"), 40 | ("sgt", "sergeant"), 41 | ("capt", "captain"), 42 | ("esq", "esquire"), 43 | ("ltd", "limited"), 44 | ("col", "colonel"), 45 | ("ft", "fort"), 46 | ] 47 | ] 48 | 49 | 50 | def expand_abbreviations(text): 51 | for regex, replacement in _abbreviations: 52 | text = re.sub(regex, replacement, text) 53 | return text 54 | 55 | 56 | def expand_numbers(text): 57 | return normalize_numbers(text) 58 | 59 | 60 | def lowercase(text): 61 | return text.lower() 62 | 63 | 64 | def collapse_whitespace(text): 65 | return re.sub(_whitespace_re, " ", text) 66 | 67 | 68 | def convert_to_ascii(text): 69 | return unidecode(text) 70 | 71 | 72 | def basic_cleaners(text): 73 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 74 | text = lowercase(text) 75 | text = collapse_whitespace(text) 76 | return text 77 | 78 | 79 | def transliteration_cleaners(text): 80 | """Pipeline for non-English text that transliterates to ASCII.""" 81 | text = convert_to_ascii(text) 82 | text = lowercase(text) 83 | text = collapse_whitespace(text) 84 | return text 85 | 86 | 87 | def english_cleaners(text): 88 | """Pipeline for English text, including number and abbreviation expansion.""" 89 | text = convert_to_ascii(text) 90 | text = lowercase(text) 91 | text = expand_numbers(text) 92 | text = expand_abbreviations(text) 93 | text = collapse_whitespace(text) 94 | return text 95 | 96 | 97 | def punctuation_removers(text): 98 | no_punct = "" 99 | for char in text: 100 | if char not in punctuations: 101 | no_punct = no_punct + char 102 | return no_punct 103 | -------------------------------------------------------------------------------- /dataset/texts/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 
64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /dataset/texts/dict_.py: -------------------------------------------------------------------------------- 1 | symbols_ = { 2 | "": 1, 3 | "!": 2, 4 | "'": 3, 5 | ",": 4, 6 | ".": 5, 7 | " ": 6, 8 | "?": 7, 9 | "A": 8, 10 | "B": 9, 11 | "C": 10, 12 | "D": 11, 13 | "E": 12, 14 | "F": 13, 15 | "G": 14, 16 | "H": 15, 17 | "I": 16, 18 | "J": 17, 19 | "K": 18, 20 | "L": 19, 21 | "M": 20, 22 | "N": 21, 23 | "O": 22, 24 | "P": 23, 25 | "Q": 24, 26 | "R": 25, 27 | "S": 26, 28 | "T": 27, 29 | "U": 28, 30 | "V": 29, 31 | "W": 30, 32 | "X": 31, 33 | "Y": 32, 34 | "Z": 33, 35 | "~": 34, 36 | } 37 | -------------------------------------------------------------------------------- /dataset/texts/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | 
if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /dataset/texts/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from dataset.texts import cmudict 9 | 10 | _pad = "_" 11 | _eos = "~" 12 | _bos = "^" 13 | _punctuation = "!'(),.:;? 
" 14 | _special = "-" 15 | _letters = "abcdefghijklmnopqrstuvwxyz" 16 | 17 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 18 | # _arpabet = ['@' + s for s in cmudict.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + [_eos] 22 | 23 | # For Phonemes 24 | 25 | PAD = "#" 26 | EOS = "~" 27 | PHONEME_CODES = "AA1 AE0 AE1 AH0 AH1 AO0 AO1 AW0 AW1 AY0 AY1 B CH D DH EH0 EH1 EU0 EU1 EY0 EY1 F G HH IH0 IH1 IY0 IY1 JH K L M N NG OW0 OW1 OY0 OY1 P R S SH T TH UH0 UH1 UW0 UW1 V W Y Z ZH pau".split() 28 | _PHONEME_SEP = " " 29 | 30 | phonemes_symbols = [PAD, EOS] + PHONEME_CODES # PAD should be first to have zero id 31 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | from dataset import dataloader as loader 3 | from fastspeech import FeedForwardTransformer 4 | import sys 5 | import torch 6 | from dataset.texts import valid_symbols 7 | import os 8 | from utils.hparams import HParam, load_hparam_str 9 | import numpy as np 10 | 11 | 12 | def evaluate(hp, validloader, model): 13 | energy_diff = list() 14 | pitch_diff = list() 15 | dur_diff = list() 16 | 17 | l1 = torch.nn.L1Loss() 18 | model.eval() 19 | for valid in validloader: 20 | x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid 21 | 22 | with torch.no_grad(): 23 | ilens = torch.tensor([x_[-1].shape[0]], dtype=torch.long, device=x_.device) 24 | _, after_outs, d_outs, e_outs, p_outs = model._forward(x_.cuda(), ilens.cuda(), out_length_.cuda(), dur_.cuda(), es=e_.cuda(), ps=p_.cuda(), is_inference=False) # [T, num_mel] 25 | 26 | # e_orig = model.energy_predictor.to_one_hot(e_).squeeze() 27 | # p_orig = model.pitch_predictor.to_one_hot(p_).squeeze() 28 | 29 | #print(d_outs) 30 | 31 | dur_diff.append(l1(d_outs, dur_.cuda()).item()) #.numpy() 32 | energy_diff.append(l1(e_outs, e_.cuda()).item()) #.numpy() 33 | pitch_diff.append(l1(p_outs, p_.cuda()).item()) #.numpy() 34 | 35 | 36 | '''_, target = read_wav_np( hp.data.wav_dir + f"{ids_[-1]}.wav", sample_rate=hp.audio.sample_rate) 37 | target_pitch = np.load(hp.data.data_dir + f"pitch/{ids_[-1]}.wav" ) 38 | target_energy = np.load(hp.data.data_dir + f"energy/{ids_[-1]}.wav" ) 39 | ''' 40 | model.train() 41 | return np.mean(pitch_diff), np.mean(energy_diff), np.mean(dur_diff) 42 | 43 | 44 | def get_parser(): 45 | """Get parser of training arguments.""" 46 | parser = configargparse.ArgumentParser( 47 | description="Train a new text-to-speech (TTS) model on one CPU, one or multiple GPUs", 48 | config_file_parser_class=configargparse.YAMLConfigFileParser, 49 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter, 50 | ) 51 | parser.add_argument( 52 | "-c", "--config", type=str, required=True, help="yaml file for configuration" 53 | ) 54 | parser.add_argument( 55 | "-p", 56 | "--checkpoint_path", 57 | type=str, 58 | default=None, 59 | help="path of checkpoint pt to evaluate", 60 | ) 61 | 62 | parser.add_argument("--outdir", type=str, required=True, help="Output directory") 63 | 64 | return parser 65 | 66 | def main(cmd_args): 67 | """Run training.""" 68 | parser = get_parser() 69 | args, _ = parser.parse_known_args(cmd_args) 70 | args = parser.parse_args(cmd_args) 71 | 72 | if os.path.exists(args.checkpoint_path): 73 | checkpoint = torch.load(args.checkpoint_path) 74 | else: 75 | print("Checkpoint not exixts") 
76 | return None 77 | 78 | if args.config is not None: 79 | hp = HParam(args.config) 80 | else: 81 | hp = load_hparam_str(checkpoint["hp_str"]) 82 | 83 | validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True) 84 | print("Checkpoint : ", args.checkpoint_path) 85 | 86 | 87 | 88 | idim = len(valid_symbols) 89 | odim = hp.audio.num_mels 90 | model = FeedForwardTransformer( 91 | idim, odim, hp 92 | ) 93 | # os.makedirs(args.out, exist_ok=True) 94 | checkpoint = torch.load(args.checkpoint_path) 95 | model.load_state_dict(checkpoint["model"]) 96 | 97 | evaluate(hp, validloader, model) 98 | 99 | 100 | if __name__ == "__main__": 101 | main(sys.argv[1:]) 102 | -------------------------------------------------------------------------------- /export_torchscript.py: -------------------------------------------------------------------------------- 1 | from utils.hparams import HParam 2 | from dataset.texts import valid_symbols 3 | import utils.fastspeech2_script as fs2 4 | import configargparse 5 | import torch 6 | import sys 7 | 8 | 9 | def get_parser(): 10 | 11 | parser = configargparse.ArgumentParser( 12 | description="Train a new text-to-speech (TTS) model on one CPU, one or multiple GPUs", 13 | config_file_parser_class=configargparse.YAMLConfigFileParser, 14 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter, 15 | ) 16 | 17 | parser.add_argument( 18 | "-c", "--config", type=str, required=True, help="yaml file for configuration" 19 | ) 20 | parser.add_argument( 21 | "-n", 22 | "--name", 23 | type=str, 24 | required=True, 25 | help="name of the model for logging, saving checkpoint", 26 | ) 27 | parser.add_argument("--outdir", type=str, required=True, help="Output directory") 28 | parser.add_argument( 29 | "-t", "--trace", action="store_true", help="For JIT Trace Module" 30 | ) 31 | 32 | return parser 33 | 34 | 35 | def main(cmd_args): 36 | 37 | parser = get_parser() 38 | args, _ = parser.parse_known_args(cmd_args) 39 | 40 | args = parser.parse_args(cmd_args) 41 | 42 | hp = HParam(args.config) 43 | 44 | idim = len(valid_symbols) 45 | odim = hp.audio.num_mels 46 | model = fs2.FeedForwardTransformer(idim, odim, hp) 47 | my_script_module = torch.jit.script(model) 48 | print("Scripting") 49 | my_script_module.save("{}/{}.pt".format(args.outdir, args.name)) 50 | print("Script done") 51 | if args.trace: 52 | print("Tracing") 53 | model.eval() 54 | with torch.no_grad(): 55 | my_trace_module = torch.jit.trace( 56 | model, torch.ones(50).to(dtype=torch.int64) 57 | ) 58 | my_trace_module.save("{}/trace_{}.pt".format(args.outdir, args.name)) 59 | print("Trace Done") 60 | 61 | 62 | if __name__ == "__main__": 63 | main(sys.argv[1:]) 64 | -------------------------------------------------------------------------------- /fastspeech.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Tomoki Hayashi 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """FastSpeech related loss.""" 8 | 9 | import logging 10 | 11 | import torch 12 | from core.duration_modeling.duration_predictor import DurationPredictor 13 | from core.duration_modeling.duration_predictor import DurationPredictorLoss 14 | from core.variance_predictor import EnergyPredictor, EnergyPredictorLoss 15 | from core.variance_predictor import PitchPredictor, PitchPredictorLoss 16 | from core.duration_modeling.length_regulator import LengthRegulator 17 | from utils.util import make_non_pad_mask 18 | from 
utils.util import make_pad_mask 19 | from core.embedding import PositionalEncoding 20 | from core.embedding import ScaledPositionalEncoding 21 | from core.encoder import Encoder 22 | from core.modules import initialize 23 | from core.modules import Postnet 24 | from typeguard import check_argument_types 25 | from typing import Dict, Tuple, Sequence 26 | from core.acoustic_encoder import UtteranceEncoder, PhonemeLevelEncoder, PhonemeLevelPredictor, AcousticPredictorLoss 27 | 28 | class FeedForwardTransformer(torch.nn.Module): 29 | """Feed Forward Transformer for TTS a.k.a. FastSpeech. 30 | This is a module of FastSpeech, feed-forward Transformer with duration predictor described in 31 | `FastSpeech: Fast, Robust and Controllable Text to Speech`_, which does not require any auto-regressive 32 | processing during inference, resulting in fast decoding compared with auto-regressive Transformer. 33 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: 34 | https://arxiv.org/pdf/1905.09263.pdf 35 | """ 36 | 37 | def __init__(self, idim: int, odim: int, hp: Dict): 38 | """Initialize feed-forward Transformer module. 39 | Args: 40 | idim (int): Dimension of the inputs. 41 | odim (int): Dimension of the outputs. 42 | """ 43 | # initialize base classes 44 | assert check_argument_types() 45 | torch.nn.Module.__init__(self) 46 | 47 | # fill missing arguments 48 | 49 | # store hyperparameters 50 | self.idim = idim 51 | self.odim = odim 52 | 53 | self.use_scaled_pos_enc = hp.model.use_scaled_pos_enc 54 | self.use_masking = hp.model.use_masking 55 | 56 | # use idx 0 as padding idx 57 | padding_idx = 0 58 | 59 | # get positional encoding class 60 | pos_enc_class = ( 61 | ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding 62 | ) 63 | 64 | # define encoder 65 | encoder_input_layer = torch.nn.Embedding( 66 | num_embeddings=idim, embedding_dim=hp.model.adim, padding_idx=padding_idx 67 | ) 68 | self.encoder = Encoder( 69 | idim=idim, 70 | attention_dim=hp.model.adim, 71 | attention_heads=hp.model.aheads, 72 | linear_units=hp.model.eunits, 73 | num_blocks=hp.model.elayers, 74 | input_layer=encoder_input_layer, 75 | dropout_rate=0.2, 76 | positional_dropout_rate=0.2, 77 | attention_dropout_rate=0.2, 78 | pos_enc_class=pos_enc_class, 79 | normalize_before=hp.model.encoder_normalize_before, 80 | concat_after=hp.model.encoder_concat_after, 81 | positionwise_layer_type=hp.model.positionwise_layer_type, 82 | positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size, 83 | ) 84 | 85 | self.duration_predictor = DurationPredictor( 86 | idim=hp.model.adim, 87 | n_layers=hp.model.duration_predictor_layers, 88 | n_chans=hp.model.duration_predictor_chans, 89 | kernel_size=hp.model.duration_predictor_kernel_size, 90 | dropout_rate=hp.model.duration_predictor_dropout_rate, 91 | ) 92 | 93 | self.energy_predictor = EnergyPredictor( 94 | idim=hp.model.adim, 95 | n_layers=hp.model.duration_predictor_layers, 96 | n_chans=hp.model.duration_predictor_chans, 97 | kernel_size=hp.model.duration_predictor_kernel_size, 98 | dropout_rate=hp.model.duration_predictor_dropout_rate, 99 | min=hp.data.e_min, 100 | max=hp.data.e_max, 101 | ) 102 | self.energy_embed = torch.nn.Linear(hp.model.adim, hp.model.adim) 103 | 104 | self.pitch_predictor = PitchPredictor( 105 | idim=hp.model.adim, 106 | n_layers=hp.model.duration_predictor_layers, 107 | n_chans=hp.model.duration_predictor_chans, 108 | kernel_size=hp.model.duration_predictor_kernel_size, 109 | 
dropout_rate=hp.model.duration_predictor_dropout_rate, 110 | min=hp.data.p_min, 111 | max=hp.data.p_max, 112 | ) 113 | self.pitch_embed = torch.nn.Linear(hp.model.adim, hp.model.adim) 114 | 115 | # define length regulator 116 | self.length_regulator = LengthRegulator() 117 | 118 | ###### AdaSpeech 119 | 120 | self.utterance_encoder = UtteranceEncoder(idim=hp.audio.n_mels) 121 | 122 | 123 | self.phoneme_level_encoder = PhonemeLevelEncoder(idim=hp.audio.n_mels) 124 | 125 | self.phoneme_level_predictor = PhonemeLevelPredictor(idim=hp.model.adim) 126 | 127 | self.phone_level_embed = torch.nn.Linear(hp.model.phn_latent_dim, hp.model.adim) 128 | 129 | self.acoustic_criterion = AcousticPredictorLoss() 130 | 131 | # define decoder 132 | # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder 133 | self.decoder = Encoder( 134 | idim=hp.model.adim, 135 | attention_dim=hp.model.ddim, 136 | attention_heads=hp.model.aheads, 137 | linear_units=hp.model.dunits, 138 | num_blocks=hp.model.dlayers, 139 | input_layer="linear", 140 | dropout_rate=0.2, 141 | positional_dropout_rate=0.2, 142 | attention_dropout_rate=0.2, 143 | pos_enc_class=pos_enc_class, 144 | normalize_before=hp.model.decoder_normalize_before, 145 | concat_after=hp.model.decoder_concat_after, 146 | positionwise_layer_type=hp.model.positionwise_layer_type, 147 | positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size, 148 | ) 149 | 150 | # define postnet 151 | self.postnet = ( 152 | None 153 | if hp.model.postnet_layers == 0 154 | else Postnet( 155 | idim=idim, 156 | odim=odim, 157 | n_layers=hp.model.postnet_layers, 158 | n_chans=hp.model.postnet_chans, 159 | n_filts=hp.model.postnet_filts, 160 | use_batch_norm=hp.model.use_batch_norm, 161 | dropout_rate=hp.model.postnet_dropout_rate, 162 | ) 163 | ) 164 | 165 | # define final projection 166 | self.feat_out = torch.nn.Linear(hp.model.ddim, odim * hp.model.reduction_factor) 167 | 168 | # initialize parameters 169 | self._reset_parameters( 170 | init_type=hp.model.transformer_init, 171 | init_enc_alpha=hp.model.initial_encoder_alpha, 172 | init_dec_alpha=hp.model.initial_decoder_alpha, 173 | ) 174 | 175 | # define criterions 176 | self.duration_criterion = DurationPredictorLoss() 177 | self.energy_criterion = EnergyPredictorLoss() 178 | self.pitch_criterion = PitchPredictorLoss() 179 | self.criterion = torch.nn.L1Loss(reduction="mean") 180 | self.use_weighted_masking = hp.model.use_weighted_masking 181 | 182 | def _forward( 183 | self, 184 | xs: torch.Tensor, 185 | ilens: torch.Tensor, 186 | ys: torch.Tensor = None, 187 | olens: torch.Tensor = None, 188 | ds: torch.Tensor = None, 189 | es: torch.Tensor = None, 190 | ps: torch.Tensor = None, 191 | is_inference: bool = False, 192 | phn_level_predictor: bool = False, 193 | avg_mel: torch.Tensor = None, 194 | ) -> Sequence[torch.Tensor]: 195 | # forward encoder 196 | x_masks = self._source_mask( 197 | ilens 198 | ) # (B, Tmax, Tmax) -> torch.Size([32, 121, 121]) 199 | 200 | hs, _ = self.encoder( 201 | xs, x_masks 202 | ) # (B, Tmax, adim) -> torch.Size([32, 121, 256]) 203 | 204 | ## AdaSpeech Specific ## 205 | 206 | uttr = self.utterance_encoder(ys.transpose(1, 2)).transpose(1, 2) 207 | hs = hs + uttr.repeat(1, hs.size(1), 1) 208 | 209 | phn = None 210 | ys_phn = None 211 | 212 | if phn_level_predictor: 213 | if is_inference: 214 | ys_phn = self.phoneme_level_predictor(hs.transpose(1, 2)) # (B, Tmax, 4) 215 | hs = hs + self.phone_level_embed(ys_phn) 216 | else: 217 | with torch.no_grad(): 218 | ys_phn = 
self.phoneme_level_encoder(avg_mel.transpose(1, 2)) # (B, Tmax, 4) 219 | 220 | phn = self.phoneme_level_predictor(hs.transpose(1, 2)) # (B, Tmax, 4) 221 | hs = hs + self.phone_level_embed(ys_phn.detach()) # (B, Tmax, 256) 222 | 223 | else: 224 | ys_phn = self.phoneme_level_encoder(avg_mel.transpose(1, 2)) # (B, Tmax, 4) 225 | hs = hs + self.phone_level_embed(ys_phn) # (B, Tmax, 256) 226 | 227 | 228 | 229 | 230 | # forward duration predictor and length regulator 231 | d_masks = make_pad_mask(ilens).to(xs.device) 232 | 233 | if is_inference: 234 | d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) 235 | hs = self.length_regulator(hs, d_outs, ilens) # (B, Lmax, adim) 236 | one_hot_energy = self.energy_predictor.inference(hs) # (B, Lmax, adim) 237 | one_hot_pitch = self.pitch_predictor.inference(hs) # (B, Lmax, adim) 238 | else: 239 | with torch.no_grad(): 240 | 241 | one_hot_energy = self.energy_predictor.to_one_hot( 242 | es 243 | ) # (B, Lmax, adim) torch.Size([32, 868, 256]) 244 | 245 | one_hot_pitch = self.pitch_predictor.to_one_hot( 246 | ps 247 | ) # (B, Lmax, adim) torch.Size([32, 868, 256]) 248 | 249 | mel_masks = make_pad_mask(olens).to(xs.device) 250 | 251 | d_outs = self.duration_predictor(hs, d_masks) # (B, Tmax) 252 | 253 | hs = self.length_regulator(hs, ds, ilens) # (B, Lmax, adim) 254 | 255 | e_outs = self.energy_predictor(hs, mel_masks) 256 | 257 | p_outs = self.pitch_predictor(hs, mel_masks) 258 | 259 | hs = hs + self.pitch_embed(one_hot_pitch) # (B, Lmax, adim) 260 | hs = hs + self.energy_embed(one_hot_energy) # (B, Lmax, adim) 261 | # forward decoder 262 | if olens is not None: 263 | h_masks = self._source_mask(olens) 264 | else: 265 | h_masks = None 266 | 267 | zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) 268 | 269 | before_outs = self.feat_out(zs).view( 270 | zs.size(0), -1, self.odim 271 | ) # (B, Lmax, odim) 272 | 273 | # postnet -> (B, Lmax//r * r, odim) 274 | if self.postnet is None: 275 | after_outs = before_outs 276 | else: 277 | after_outs = before_outs + self.postnet( 278 | before_outs.transpose(1, 2) 279 | ).transpose(1, 2) 280 | 281 | if is_inference: 282 | return before_outs, after_outs, d_outs, one_hot_energy, one_hot_pitch 283 | else: 284 | return before_outs, after_outs, d_outs, e_outs, p_outs, phn, ys_phn 285 | 286 | def forward( 287 | self, 288 | xs: torch.Tensor, 289 | ilens: torch.Tensor, 290 | ys: torch.Tensor, 291 | olens: torch.Tensor, 292 | ds: torch.Tensor, 293 | es: torch.Tensor, 294 | ps: torch.Tensor, 295 | avg_mel: torch.Tensor = None, 296 | phn_level_predictor: bool = False 297 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 298 | """Calculate forward propagation. 299 | Args: 300 | xs (Tensor): Batch of padded character ids (B, Tmax). 301 | ilens (LongTensor): Batch of lengths of each input batch (B,). 302 | ys (Tensor): Batch of padded target features (B, Lmax, odim). 303 | olens (LongTensor): Batch of the lengths of each target (B,). 304 | spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim). 305 | Returns: 306 | Tensor: Loss value. 
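        Examples:
            A rough usage sketch (added for illustration, not part of the original
            docstring). Shapes follow tests/test_fastspeech2.py and the batches
            produced by train_fastspeech.py; the (B, Tmax, n_mels) shape of
            ``avg_mel`` is inferred from nvidia_preprocessing.py and is an
            assumption, not a documented contract.
            >>> hp = HParam("configs/default.yaml")
            >>> model = FeedForwardTransformer(len(valid_symbols), hp.audio.num_mels, hp)
            >>> x = torch.ones(2, 100, dtype=torch.long)   # (B, Tmax) phoneme ids
            >>> ilens = torch.tensor([100, 100])            # input lengths
            >>> y = torch.ones(2, 100, 80)                  # (B, Lmax, odim) target mels
            >>> olens = torch.tensor([100, 100])            # output lengths
            >>> d = torch.ones(2, 100)                      # per-phoneme durations (sum to Lmax)
            >>> e = torch.ones(2, 100)                      # frame-level energy
            >>> p = torch.ones(2, 100)                      # frame-level pitch
            >>> avg_mel = torch.ones(2, 100, 80)            # phoneme-averaged mels (assumed shape)
            >>> loss, report_keys = model(x, ilens, y, olens, d, e, p, avg_mel)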
307 | """ 308 | # remove unnecessary padded part (for multi-gpus) 309 | xs = xs[:, : max(ilens)] # torch.Size([32, 121]) -> [B, Tmax] 310 | ys = ys[:, : max(olens)] # torch.Size([32, 868, 80]) -> [B, Lmax, odim] 311 | 312 | # forward propagation 313 | before_outs, after_outs, d_outs, e_outs, p_outs, phn, ys_phn = self._forward( 314 | xs, ilens, olens, ds, es, ps, is_inference=False, avg_mel=avg_mel, phn_level_predictor=phn_level_predictor 315 | ) 316 | 317 | 318 | # apply mask to remove padded part 319 | if self.use_masking: 320 | in_masks = make_non_pad_mask(ilens).to(xs.device) 321 | d_outs = d_outs.masked_select(in_masks) 322 | ds = ds.masked_select(in_masks) 323 | out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) 324 | mel_masks = make_non_pad_mask(olens).to(ys.device) 325 | before_outs = before_outs.masked_select(out_masks) 326 | es = es.masked_select(mel_masks) # Write size 327 | ps = ps.masked_select(mel_masks) # Write size 328 | e_outs = e_outs.masked_select(mel_masks) # Write size 329 | p_outs = p_outs.masked_select(mel_masks) # Write size 330 | after_outs = ( 331 | after_outs.masked_select(out_masks) if after_outs is not None else None 332 | ) 333 | ys = ys.masked_select(out_masks) 334 | if phn is not None and ys_phn is not None: 335 | phn = phn.masked_select(in_masks.unsqueeze(-1)) 336 | ys_phn = ys_phn.masked_select(in_masks.unsqueeze(-1)) 337 | 338 | acoustic_loss = 0 339 | 340 | if phn_level_predictor: 341 | acoustic_loss = self.acoustic_criterion(ys_phn, phn) 342 | 343 | # calculate loss 344 | before_loss = self.criterion(before_outs, ys) 345 | after_loss = 0 346 | if after_outs is not None: 347 | after_loss = self.criterion(after_outs, ys) 348 | l1_loss = before_loss + after_loss 349 | duration_loss = self.duration_criterion(d_outs, ds) 350 | energy_loss = self.energy_criterion(e_outs, es) 351 | pitch_loss = self.pitch_criterion(p_outs, ps) 352 | 353 | # make weighted mask and apply it 354 | if self.use_weighted_masking: 355 | out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) 356 | out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() 357 | out_weights /= ys.size(0) * ys.size(2) 358 | duration_masks = make_non_pad_mask(ilens).to(ys.device) 359 | duration_weights = ( 360 | duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float() 361 | ) 362 | duration_weights /= ds.size(0) 363 | 364 | # apply weight 365 | l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum() 366 | duration_loss = ( 367 | duration_loss.mul(duration_weights).masked_select(duration_masks).sum() 368 | ) 369 | 370 | loss = l1_loss + duration_loss + energy_loss + pitch_loss + acoustic_loss 371 | report_keys = [ 372 | {"l1_loss": l1_loss.item()}, 373 | {"before_loss": before_loss.item()}, 374 | {"after_loss": after_loss.item()}, 375 | {"duration_loss": duration_loss.item()}, 376 | {"energy_loss": energy_loss.item()}, 377 | {"pitch_loss": pitch_loss.item()}, 378 | {"acostic_loss": acoustic_loss}, 379 | {"loss": loss.item()}, 380 | ] 381 | 382 | # self.reporter.report(report_keys) 383 | 384 | return loss, report_keys 385 | 386 | def inference(self, x: torch.Tensor, ref_mel: torch.Tensor = None, avg_mel: torch.Tensor = None 387 | , phn_level_predictor: bool = True) -> torch.Tensor: 388 | """Generate the sequence of features given the sequences of characters. 389 | Args: 390 | x (Tensor): Input sequence of characters (T,). 391 | inference_args (Namespace): Dummy for compatibility. 
392 | spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). 393 | Returns: 394 | Tensor: Output sequence of features (1, L, odim). 395 | None: Dummy for compatibility. 396 | None: Dummy for compatibility. 397 | """ 398 | # setup batch axis 399 | ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) 400 | xs = x.unsqueeze(0) 401 | 402 | if ref_mel is not None: 403 | ref_mel = ref_mel.unsqueeze(0) 404 | if avg_mel is not None: 405 | avg_mel = avg_mel.unsqueeze(0) 406 | # inference 407 | before_outs, outs, d_outs, _ = self._forward(xs, ilens=ilens, ys=ref_mel, avg_mel=avg_mel, 408 | is_inference=True, 409 | phn_level_predictor=phn_level_predictor) # (L, odim) 410 | else: 411 | before_outs, outs, d_outs, _ = self._forward(xs, ilens=ilens, ys=ref_mel, is_inference=True, 412 | phn_level_predictor=phn_level_predictor) # (L, odim) 413 | 414 | # inference 415 | _, outs, _, _, _ = self._forward(xs, ilens, is_inference=True) # (L, odim) 416 | 417 | return outs[0] 418 | 419 | def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: 420 | """Make masks for self-attention. 421 | Examples: 422 | >>> ilens = [5, 3] 423 | >>> self._source_mask(ilens) 424 | tensor([[[1, 1, 1, 1, 1], 425 | [1, 1, 1, 1, 1], 426 | [1, 1, 1, 1, 1], 427 | [1, 1, 1, 1, 1], 428 | [1, 1, 1, 1, 1]], 429 | [[1, 1, 1, 0, 0], 430 | [1, 1, 1, 0, 0], 431 | [1, 1, 1, 0, 0], 432 | [0, 0, 0, 0, 0], 433 | [0, 0, 0, 0, 0]]], dtype=torch.uint8) 434 | """ 435 | x_masks = make_non_pad_mask(ilens).to(device=next(self.parameters()).device) 436 | return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1) 437 | 438 | def _reset_parameters( 439 | self, init_type: str, init_enc_alpha: float = 1.0, init_dec_alpha: float = 1.0 440 | ): 441 | # initialize parameters 442 | initialize(self, init_type) 443 | 444 | # initialize alpha in scaled positional encoding 445 | if self.use_scaled_pos_enc: 446 | self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) 447 | self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) 448 | -------------------------------------------------------------------------------- /filelists/valid_filelist.txt: -------------------------------------------------------------------------------- 1 | printing in the only sense with which we are at present concerned differs from most if not from all the arts and crafts represented in the exhibition|-2 1 4 14 16 22 35 56 66 78 84 87 97 106 114 117 126 140 150 158 166 171 174 179 182 187 195 203 209 218 227 234 239 248 252 255 261 268 271 274 282 286 290 296 308 329 334 341 377 384 390 399 413 431 433 439 441 444 452 457 472 481 487 490 500 507 516 525 529 532 535 542 566 571 576 584 592 596 605 610 614 617 619 630 633 647 656 662 670 679 683 690 693 696 705 709 712 718 723 727 742 745 748 757 762 770 777 781 787 796 805 811 829|3 3 10 2 6 13 21 10 12 6 3 10 9 8 3 9 14 10 8 8 5 3 5 3 5 8 8 6 9 9 7 5 9 4 3 6 7 3 3 8 4 4 6 12 21 5 7 36 7 6 9 14 18 2 6 2 3 8 5 15 9 6 3 10 7 9 9 4 3 3 7 24 5 5 8 8 4 9 5 4 3 2 11 3 14 9 6 8 9 4 7 3 3 9 4 3 6 5 4 15 3 3 9 5 8 7 4 6 9 9 6 18|P R IH1 N T IH0 NG pau IH1 N DH IY0 OW1 N L IY0 S EH1 N S W IH1 DH pau W IH1 CH W IY1 AA1 R AE1 T P R EH1 Z AH0 N T K AH0 N S ER1 N D pau D IH1 F ER0 Z pau F R AH1 M M OW1 S T IH1 F N AA1 T F R AH1 M AO1 L DH IY0 AA1 R T S AH0 N D K R AE1 F T S R EH2 P R IH0 Z EH1 N T IH0 D IH1 N DH IY0 EH2 K S AH0 B IH1 SH AH0 N|LJ001-0001.wav 2 | in being comparatively modern |-2 4 10 14 25 28 32 38 41 48 58 65 71 75 83 88 93 103 107 118 132 137 142 155 159 161|6 6 4 11 3 4 6 3 7 10 7 6 4 8 5 5 10 4 
11 14 5 5 13 4 2|IH1 N B IY1 IH0 NG K AH0 M P EH1 R AH0 T IH0 V L IY0 M AA1 D ER0 N pau sil|LJ001-0002.wav 3 | for although the chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the netherlands by a similar process|-2 2 8 15 21 26 30 43 46 48 52 66 76 82 97 109 115 119 125 135 141 146 150 156 165 170 180 196 199 207 210 213 221 230 234 244 247 254 271 282 297 313 329 335 339 347 362 367 375 380 386 389 394 402 413 423 426 433 439 451 458 463 471 479 487 493 497 503 512 517 523 526 532 540 545 552 561 569 575 585 593 598 606 608 612 619 628 633 642 650 656 663 668 677 701 709 718 727 737 741 746 749 759 763 773 778 789 800 810 830|4 6 7 6 5 4 13 3 2 4 14 10 6 15 12 6 4 6 10 6 5 4 6 9 5 10 16 3 8 3 3 8 9 4 10 3 7 17 11 15 16 16 6 4 8 15 5 8 5 6 3 5 8 11 10 3 7 6 12 7 5 8 8 8 6 4 6 9 5 6 3 6 8 5 7 9 8 6 10 8 5 8 2 4 7 9 5 9 8 6 7 5 9 24 8 9 9 10 4 5 3 10 4 10 5 11 11 10 20|F AO1 R AO2 L DH OW1 pau DH AH1 CH AY0 N IY1 Z T UH1 K IH0 M P R EH1 SH AH0 N Z pau F R AH1 M W UH1 D B L AA1 K S pau IH0 N G R EY1 V D IH1 N R IH0 L IY1 F pau F ER0 S EH1 N CH ER0 IY0 Z B IH0 F AO1 R DH AH0 W UH1 D K AH2 T ER0 Z AH1 V DH AH0 N EH1 DH ER0 L AH0 N D Z pau B AY1 AH0 S IH1 M AH0 L ER0 P R AA1 S EH2 S|LJ001-0003.wav 4 | produced the block books which were the immediate predecessors of the true printed book |-2 1 4 8 15 27 32 47 50 54 60 66 80 88 92 109 116 135 147 153 159 165 171 179 183 196 201 211 221 223 233 235 245 249 255 259 263 268 278 286 297 305 312 317 324 327 331 341 348 361 368 371 377 381 387 391 397 406 426 437 440|3 3 4 7 12 5 15 3 4 6 6 14 8 4 17 7 19 12 6 6 6 6 8 4 13 5 10 10 2 10 2 10 4 6 4 4 5 10 8 11 8 7 5 7 3 4 10 7 13 7 3 6 4 6 4 6 9 20 11 3|P R AH0 D UW1 S T DH AH0 B L AA1 K B UH1 K S pau W IH1 CH W ER1 DH IY0 IH0 M IY1 D IY0 AH0 T P R EH1 D AH0 S EH2 S ER0 Z AH0 V DH AH0 T R UW1 P R IH1 N T AH0 D B UH1 K pau|LJ001-0004.wav 5 | the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing |-2 0 9 11 16 21 29 34 43 46 51 56 63 72 79 83 87 91 94 103 112 120 124 127 135 143 152 161 172 181 186 189 191 194 202 207 209 214 221 224 231 234 239 246 252 260 267 276 280 287 296 303 308 317 325 343 362 370 380 390 399 407 413 418 424 430 439 446 451 456 464 471 478 485 495 517 530 537 541 550 555 560 566 574 578 587 590 594 599 603 606 614 625 631 634 637 646 653 656 662 668 674 680 693 696|2 9 2 5 5 8 5 9 3 5 5 7 9 7 4 4 4 3 9 9 8 4 3 8 8 9 9 11 9 5 3 2 3 8 5 2 5 7 3 7 3 5 7 6 8 7 9 4 7 9 7 5 9 8 18 19 8 10 10 9 8 6 5 6 6 9 7 5 5 8 7 7 7 10 22 13 7 4 9 5 5 6 8 4 9 3 4 5 4 3 8 11 6 3 3 9 7 3 6 6 6 6 13 3|DH IY0 IH0 N V EH1 N SH AH0 N AH0 V M UW1 V AH0 B AH0 L M EH1 T AH0 L L EH1 T ER0 Z IH0 N DH AH0 M IH1 D AH0 L AH1 V DH AH0 F IH0 F T IY1 N TH S EH1 N CH ER0 IY0 pau M EY1 JH AH1 S T L IY0 B IY1 K AH0 N S IH1 D ER0 D pau AE1 Z DH IY0 IH0 N V EH1 N SH AH0 N AH0 V DH IY0 AA1 R T AH0 V P R IH1 N T IH0 NG pau|LJ001-0005.wav 6 | and it is worth mention in passing that as an example of fine typography |-2 18 22 34 40 53 56 65 72 83 91 100 105 111 116 126 132 137 145 151 162 176 185 194 210 236 242 254 261 264 277 283 285 291 295 302 311 326 332 334 337 355 358 369 376 391 398 405 415 423 434 442 446 452 461 484 487|20 4 12 6 13 3 9 7 11 8 9 5 6 5 10 6 5 8 6 11 14 9 9 16 26 6 12 7 3 13 6 2 6 4 7 9 15 6 2 3 18 3 11 7 15 7 7 10 8 11 8 4 6 9 23 3|AE1 N D pau IH1 T IH1 Z W ER1 TH M EH1 N SH AH0 N IH1 N P AE1 S IH0 NG pau DH AE1 T pau AE1 Z AE1 N IH0 G Z AE1 M P AH0 L AH1 V F AY1 
N T AH0 P AA1 G R AH0 F IY0 pau|LJ001-0006.wav 7 | the earliest book printed with movable types the gutenberg or forty two line bible of about fourteen fifty five |-2 0 10 27 35 45 49 55 63 68 87 103 109 112 117 122 128 134 139 142 148 157 164 172 176 182 186 192 199 209 227 235 252 273 277 287 295 304 310 317 320 326 346 358 363 377 394 396 404 409 414 419 427 436 446 461 476 487 493 508 513 516 535 544 554 558 563 569 582 594 600 605 609 623 633 639 646 652 659 666 672 684 707 718 720|2 10 17 8 10 4 6 8 5 19 16 6 3 5 5 6 6 5 3 6 9 7 8 4 6 4 6 7 10 18 8 17 21 4 10 8 9 6 7 3 6 20 12 5 14 17 2 8 5 5 5 8 9 10 15 15 11 6 15 5 3 19 9 10 4 5 6 13 12 6 5 4 14 10 6 7 6 7 7 6 12 23 11 2|DH IY0 ER1 L IY0 AH0 S T B UH1 K P R IH1 N T IH0 D W IH1 TH M UW1 V AH0 B AH0 L T AY1 P S pau DH IY0 G UW1 T AH0 N B ER0 G pau AO1 R pau F AO1 R T IY0 T UW1 L AY1 N B AY1 B AH0 L pau AH1 V AH0 B AW1 T F AO1 R T IY1 N F IH1 F T IY0 F AY1 V pau|LJ001-0007.wav 8 | has never been surpassed |-2 0 4 13 21 29 35 42 48 57 63 75 80 91 118 134 147 150 151|2 4 9 8 8 6 7 6 9 6 12 5 11 27 16 13 3 1|HH AE1 Z N EH1 V ER0 B IH1 N S ER0 P AE1 S T pau sil|LJ001-0008.wav 9 | printing then for our purpose may be considered as the art of making books by means of movable types|-2 0 3 15 17 22 30 42 47 64 73 82 84 91 110 116 128 139 144 157 181 202 209 217 221 229 237 241 246 258 265 270 278 283 292 300 303 314 329 335 346 352 358 366 374 382 385 399 405 420 426 452 467 474 487 494 507 513 520 528 537 547 556 561 566 571 575 584 598 612 624 648|2 3 12 2 5 8 12 5 17 9 9 2 7 19 6 12 11 5 13 24 21 7 8 4 8 8 4 5 12 7 5 8 5 9 8 3 11 15 6 11 6 6 8 8 8 3 14 6 15 6 26 15 7 13 7 13 6 7 8 9 10 9 5 5 5 4 9 14 14 12 24|P R IH1 N T IH0 NG DH EH1 N F AO1 R AW1 ER0 P ER1 P AH0 S pau M EY1 B IY1 K AH0 N S IH1 D ER0 D EH1 Z DH IY0 AA1 R T AH1 V M EY1 K IH0 NG B UH1 K S pau B AY1 M IY1 N Z AH0 V M UW1 V AH0 B AH0 L T AY1 P S|LJ001-0009.wav 10 | now as all books not primarily intended as picture books consist principally of types composed to form letterpress |-2 7 41 66 83 93 115 123 128 142 147 163 165 176 190 199 202 209 214 221 227 232 236 244 256 259 266 273 281 285 290 296 302 308 319 321 324 331 338 345 351 359 375 381 397 420 427 432 438 446 453 461 469 475 477 482 488 493 497 507 513 525 530 537 545 558 566 574 581 585 589 597 614 623 631 636 643 647 658 665 672 675 683 690 696 702 711 718 730 752 757|9 34 25 17 10 22 8 5 14 5 16 2 11 14 9 3 7 5 7 6 5 4 8 12 3 7 7 8 4 5 6 6 6 11 2 3 7 7 7 6 8 16 6 16 23 7 5 6 8 7 8 8 6 2 5 6 5 4 10 6 12 5 7 8 13 8 8 7 4 4 8 17 9 8 5 7 4 11 7 7 3 8 7 6 6 9 7 12 22 5|N AW1 pau AE1 Z AO1 L B UH1 K S pau N AA1 T P R AY0 M EH1 R AH0 L IY0 IH0 N T EH1 N D IH0 D EH1 Z pau P IH1 K CH ER0 B UH1 K S pau K AH0 N S IH1 S T P R IH1 N S IH0 P L IY0 AH0 V T AY1 P S K AH0 M P OW1 Z D pau T AH0 F AO1 R M L EH1 T ER0 P R EH2 S pau|LJ001-0010.wav 11 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | """TTS Inference script.""" 2 | 3 | import configargparse 4 | import logging 5 | import os 6 | import torch 7 | import sys 8 | from utils.util import set_deterministic_pytorch 9 | from fastspeech import FeedForwardTransformer 10 | from dataset.texts import phonemes_to_sequence 11 | import time 12 | from dataset.audio_processing import griffin_lim 13 | import numpy as np 14 | from utils.stft import STFT 15 | from scipy.io.wavfile import write 16 | from dataset.texts import valid_symbols 17 | from utils.hparams 
import HParam, load_hparam_str 18 | from dataset.texts.cleaners import english_cleaners, punctuation_removers 19 | import matplotlib.pyplot as plt 20 | from g2p_en import G2p 21 | 22 | 23 | def synthesis(args, text, hp): 24 | """Decode with E2E-TTS model.""" 25 | set_deterministic_pytorch(args) 26 | # read training config 27 | idim = hp.symbol_len 28 | odim = hp.num_mels 29 | model = FeedForwardTransformer(idim, odim, hp) 30 | print(model) 31 | 32 | if os.path.exists(args.path): 33 | print("\nSynthesis Session...\n") 34 | model.load_state_dict(torch.load(args.path), strict=False) 35 | else: 36 | print("Checkpoint not exixts") 37 | return None 38 | 39 | model.eval() 40 | 41 | # set torch device 42 | device = torch.device("cuda" if args.ngpu > 0 else "cpu") 43 | model = model.to(device) 44 | 45 | input = np.asarray(phonemes_to_sequence(text.split())) 46 | text = torch.LongTensor(input) 47 | text = text.cuda() 48 | # [num_char] 49 | 50 | with torch.no_grad(): 51 | # decode and write 52 | idx = input[:5] 53 | start_time = time.time() 54 | print("text :", text.size()) 55 | outs, probs, att_ws = model.inference(text, hp) 56 | print("Out size : ", outs.size()) 57 | 58 | logging.info( 59 | "inference speed = %s msec / frame." 60 | % ((time.time() - start_time) / (int(outs.size(0)) * 1000)) 61 | ) 62 | if outs.size(0) == text.size(0) * args.maxlenratio: 63 | logging.warning("output length reaches maximum length .") 64 | 65 | print("mels", outs.size()) 66 | mel = outs.cpu().numpy() # [T_out, num_mel] 67 | print("numpy ", mel.shape) 68 | 69 | return mel 70 | 71 | 72 | ### for direct text/para input ### 73 | 74 | 75 | g2p = G2p() 76 | 77 | 78 | def plot_mel(mels): 79 | melspec = mels.reshape(1, 80, -1) 80 | plt.imshow(melspec.detach().cpu()[0], aspect="auto", origin="lower") 81 | plt.savefig("mel.png") 82 | 83 | 84 | def preprocess(text): 85 | 86 | # input - line of text 87 | # output - list of phonemes 88 | str1 = " " 89 | clean_content = english_cleaners(text) 90 | clean_content = punctuation_removers(clean_content) 91 | phonemes = g2p(clean_content) 92 | 93 | phonemes = ["" if x == " " else x for x in phonemes] 94 | phonemes = ["pau" if x == "," else x for x in phonemes] 95 | phonemes = ["pau" if x == "." else x for x in phonemes] 96 | phonemes = str1.join(phonemes) 97 | 98 | return phonemes 99 | 100 | 101 | def process_paragraph(para): 102 | # input - paragraph with lines seperated by "." 
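    #   e.g. "first sentence. second sentence." -> ["first sentence", " second sentence", ""]
    #   (hypothetical illustration of the str.split(".") behaviour; a trailing "." adds an empty item)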
103 | # output - list with each item as lines of paragraph seperated by suitable padding 104 | text = [] 105 | for lines in para.split("."): 106 | text.append(lines) 107 | 108 | return text 109 | 110 | 111 | def synth(text, model, hp, ref_mel): 112 | """Decode with E2E-TTS model.""" 113 | 114 | print("TTS synthesis") 115 | 116 | model.eval() 117 | # set torch device 118 | device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu") 119 | model = model.to(device) 120 | 121 | input = np.asarray(phonemes_to_sequence(text)) 122 | 123 | text = torch.LongTensor(input) 124 | text = text.to(device) 125 | 126 | with torch.no_grad(): 127 | print("predicting") 128 | outs = model.inference(text, ref_mel = ref_mel) # model(text) for jit script 129 | mel = outs 130 | return mel 131 | 132 | 133 | def main(args): 134 | """Run deocding.""" 135 | para_mel = [] 136 | parser = get_parser() 137 | args = parser.parse_args(args) 138 | 139 | logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) 140 | 141 | print("Text : ", args.text) 142 | print("Checkpoint : ", args.checkpoint_path) 143 | ref_mel = np.load(args.ref_mel) 144 | ref_mel = torch.from_numpy(ref_mel).T 145 | if os.path.exists(args.checkpoint_path): 146 | checkpoint = torch.load(args.checkpoint_path) 147 | else: 148 | logging.info("Checkpoint not exixts") 149 | return None 150 | 151 | if args.config is not None: 152 | hp = HParam(args.config) 153 | else: 154 | hp = load_hparam_str(checkpoint["hp_str"]) 155 | 156 | idim = len(valid_symbols) 157 | odim = hp.audio.num_mels 158 | model = FeedForwardTransformer( 159 | idim, odim, hp 160 | ) # torch.jit.load("./etc/fastspeech_scrip_new.pt") 161 | 162 | os.makedirs(args.out, exist_ok=True) 163 | if args.old_model: 164 | logging.info("\nSynthesis Session...\n") 165 | model.load_state_dict(checkpoint, strict=False) 166 | else: 167 | checkpoint = torch.load(args.checkpoint_path) 168 | model.load_state_dict(checkpoint["model"]) 169 | 170 | text = process_paragraph(args.text) 171 | 172 | for i in range(0, len(text)): 173 | txt = preprocess(text[i]) 174 | audio = synth(txt, model, hp, ref_mel) 175 | m = audio.T 176 | para_mel.append(m) 177 | 178 | m = torch.cat(para_mel, dim=1) 179 | np.save("mel.npy", m.cpu().numpy()) 180 | plot_mel(m) 181 | 182 | if hp.train.melgan_vocoder: 183 | m = m.unsqueeze(0) 184 | print("Mel shape: ", m.shape) 185 | vocoder = torch.hub.load("seungwonpark/melgan", "melgan") 186 | vocoder.eval() 187 | if torch.cuda.is_available(): 188 | vocoder = vocoder.cuda() 189 | mel = m.cuda() 190 | 191 | with torch.no_grad(): 192 | wav = vocoder.inference( 193 | mel 194 | ) # mel ---> batch, num_mels, frames [1, 80, 234] 195 | wav = wav.cpu().float().numpy() 196 | else: 197 | stft = STFT(filter_length=1024, hop_length=256, win_length=1024) 198 | print(m.size()) 199 | m = m.unsqueeze(0) 200 | wav = griffin_lim(m, stft, 30) 201 | wav = wav.cpu().numpy() 202 | save_path = "{}/test_tts.wav".format(args.out) 203 | write(save_path, hp.audio.sample_rate, wav.astype("int16")) 204 | 205 | 206 | # NOTE: you need this func to generate our sphinx doc 207 | def get_parser(): 208 | """Get parser of decoding arguments.""" 209 | parser = configargparse.ArgumentParser( 210 | description="Synthesize speech from text using a TTS model on one CPU", 211 | config_file_parser_class=configargparse.YAMLConfigFileParser, 212 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter, 213 | ) 214 | # general configuration 215 | 216 | parser.add_argument( 217 | "-c", "--config", type=str, required=True, 
help="yaml file for configuration" 218 | ) 219 | parser.add_argument( 220 | "-p", 221 | "--checkpoint_path", 222 | type=str, 223 | default=None, 224 | help="path of checkpoint pt file to resume training", 225 | ) 226 | parser.add_argument("--out", type=str, required=True, help="Output filename") 227 | parser.add_argument( 228 | "-o", "--old_model", action="store_true", help="Resume Old model " 229 | ) 230 | # task related 231 | parser.add_argument( 232 | "--text", type=str, required=True, help="Filename of train label data (json)" 233 | ) 234 | parser.add_argument( 235 | "--ref_mel", type=str, required=True, help="Filename of Reference Mels" 236 | ) 237 | parser.add_argument( 238 | "--pad", default=2, type=int, help="padd value at the end of each sentence" 239 | ) 240 | return parser 241 | 242 | 243 | if __name__ == "__main__": 244 | print("Starting") 245 | main(sys.argv[1:]) 246 | -------------------------------------------------------------------------------- /nvidia_preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tqdm 4 | import torch 5 | import argparse 6 | import numpy as np 7 | from utils.stft import TacotronSTFT 8 | from utils.util import read_wav_np 9 | from dataset.audio_processing import pitch 10 | from utils.hparams import HParam 11 | import torch.nn.functional as F 12 | from utils.util import str_to_int_list 13 | 14 | def _average_mel_by_duration(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor: 15 | #print(d.sum(), len(x)) 16 | if d.sum() != x.shape[-1]: 17 | d[-1] += 1 18 | d_cumsum = F.pad(d.cumsum(dim=0), (1, 0)) 19 | x_avg = [ 20 | x[:, int(start):int(end)].sum(dim=1)//(end - start) if len(x[:, int(start):int(end)]) != 0 else x.zeros() 21 | for start, end in zip(d_cumsum[:-1], d_cumsum[1:]) 22 | ] 23 | return torch.stack(x_avg) 24 | 25 | def preprocess(data_path, hp, file): 26 | stft = TacotronSTFT( 27 | filter_length=hp.audio.n_fft, 28 | hop_length=hp.audio.hop_length, 29 | win_length=hp.audio.win_length, 30 | n_mel_channels=hp.audio.n_mels, 31 | sampling_rate=hp.audio.sample_rate, 32 | mel_fmin=hp.audio.fmin, 33 | mel_fmax=hp.audio.fmax, 34 | ) 35 | 36 | 37 | mel_path = os.path.join(hp.data.data_dir, "mels") 38 | energy_path = os.path.join(hp.data.data_dir, "energy") 39 | pitch_path = os.path.join(hp.data.data_dir, "pitch") 40 | avg_mel_phon = os.path.join(hp.data.data_dir, "avg_mel_ph") 41 | 42 | os.makedirs(mel_path, exist_ok=True) 43 | os.makedirs(energy_path, exist_ok=True) 44 | os.makedirs(pitch_path, exist_ok=True) 45 | os.makedirs(avg_mel_phon, exist_ok=True) 46 | print("Sample Rate : ", hp.audio.sample_rate) 47 | 48 | with open("{}".format(file), encoding="utf-8") as f: 49 | _metadata = [line.strip().split("|") for line in f] 50 | for metadata in tqdm.tqdm(_metadata, desc="preprocess wav to mel"): 51 | wavpath = os.path.join(data_path, metadata[4]) 52 | sr, wav = read_wav_np(wavpath, hp.audio.sample_rate) 53 | 54 | dur = str_to_int_list(metadata[2]) 55 | dur = torch.from_numpy(np.array(dur)) 56 | 57 | p = pitch(wav, hp) # [T, ] T = Number of frames 58 | wav = torch.from_numpy(wav).unsqueeze(0) 59 | mel, mag = stft.mel_spectrogram(wav) # mel [1, 80, T] mag [1, num_mag, T] 60 | mel = mel.squeeze(0) # [num_mel, T] 61 | mag = mag.squeeze(0) # [num_mag, T] 62 | e = torch.norm(mag, dim=0) # [T, ] 63 | p = p[: mel.shape[1]] 64 | 65 | avg_mel_ph = _average_mel_by_duration(mel, dur) # [num_mel, L] 66 | assert (avg_mel_ph.shape[0] == dur.shape[-1]) 67 | 68 | id = 
os.path.basename(wavpath).split(".")[0] 69 | np.save("{}/{}.npy".format(mel_path, id), mel.numpy(), allow_pickle=False) 70 | np.save("{}/{}.npy".format(energy_path, id), e.numpy(), allow_pickle=False) 71 | np.save("{}/{}.npy".format(pitch_path, id), p, allow_pickle=False) 72 | np.save("{}/{}.npy".format(avg_mel_phon, id), avg_mel_ph.numpy(), allow_pickle=False) 73 | 74 | def main(args, hp): 75 | print("Preprocess Training dataset :") 76 | preprocess(args.data_path, hp, hp.data.train_filelist) 77 | print("Preprocess Validation dataset :") 78 | preprocess(args.data_path, hp, hp.data.valid_filelist) 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument( 83 | "-d", "--data_path", type=str, required=True, help="root directory of wav files" 84 | ) 85 | parser.add_argument( 86 | "-c", "--config", type=str, required=True, help="yaml file for configuration" 87 | ) 88 | args = parser.parse_args() 89 | 90 | hp = HParam(args.config) 91 | 92 | main(args, hp) 93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.16.3 2 | librosa==0.7.0 3 | numba==0.48 4 | matplotlib 5 | unidecode 6 | inflect 7 | nltk 8 | tqdm 9 | pyyaml 10 | pyworld==0.2.10 11 | configargparse 12 | tensorboardX 13 | typeguard==2.9.1 14 | g2p_en 15 | 16 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/AdaSpeech/c208a4fb83b09f4943cb4faf3c007892dc3f6b7b/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_fastspeech2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.hparams import HParam 3 | from dataset.texts import valid_symbols 4 | from fastspeech import FeedForwardTransformer 5 | 6 | 7 | def test_fastspeech(): 8 | idim = len(valid_symbols) 9 | hp = HParam("configs/default.yaml") 10 | hp.train.ngpu = 0 11 | odim = hp.audio.num_mels 12 | model = FeedForwardTransformer(idim, odim, hp) 13 | x = torch.ones(2, 100).to(dtype=torch.int64) 14 | input_length = torch.tensor([100, 100]) 15 | y = torch.ones(2, 100, 80) 16 | out_length = torch.tensor([100, 100]) 17 | dur = torch.ones(2, 100) 18 | e = torch.ones(2, 100) 19 | p = torch.ones(2, 100) 20 | loss, report_dict = model(x, input_length, y, out_length, dur, e, p) 21 | -------------------------------------------------------------------------------- /train_fastspeech.py: -------------------------------------------------------------------------------- 1 | import fastspeech 2 | from tensorboardX import SummaryWriter 3 | import torch 4 | from dataset import dataloader as loader 5 | import logging 6 | import math 7 | import os 8 | import sys 9 | import numpy as np 10 | import configargparse 11 | import random 12 | import tqdm 13 | import time 14 | from evaluation import evaluate 15 | from utils.plot import generate_audio, plot_spectrogram_to_numpy 16 | from core.optimizer import get_std_opt 17 | from utils.util import read_wav_np 18 | from dataset.texts import valid_symbols 19 | from utils.util import get_commit_hash 20 | from utils.hparams import HParam 21 | 22 | BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"] 23 | BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"] 24 | 25 | 26 | def train(args, 
hp, hp_str, logger, vocoder): 27 | os.makedirs(os.path.join(hp.train.chkpt_dir, args.name), exist_ok=True) 28 | os.makedirs(os.path.join(args.outdir, args.name), exist_ok=True) 29 | os.makedirs(os.path.join(args.outdir, args.name, "assets"), exist_ok=True) 30 | device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu") 31 | 32 | dataloader = loader.get_tts_dataset(hp.data.data_dir, hp.train.batch_size, hp) 33 | validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True) 34 | 35 | idim = len(valid_symbols) 36 | odim = hp.audio.num_mels 37 | model = fastspeech.FeedForwardTransformer(idim, odim, hp) 38 | # set torch device 39 | model = model.to(device) 40 | print("Model is loaded ...") 41 | githash = get_commit_hash() 42 | if args.checkpoint_path is not None: 43 | if os.path.exists(args.checkpoint_path): 44 | logger.info("Resuming from checkpoint: %s" % args.checkpoint_path) 45 | checkpoint = torch.load(args.checkpoint_path) 46 | model.load_state_dict(checkpoint["model"]) 47 | optimizer = get_std_opt( 48 | model, 49 | hp.model.adim, 50 | hp.model.transformer_warmup_steps, 51 | hp.model.transformer_lr, 52 | ) 53 | optimizer.load_state_dict(checkpoint["optim"]) 54 | global_step = checkpoint["step"] 55 | 56 | if hp_str != checkpoint["hp_str"]: 57 | logger.warning( 58 | "New hparams is different from checkpoint. Will use new." 59 | ) 60 | 61 | if githash != checkpoint["githash"]: 62 | logger.warning("Code might be different: git hash is different.") 63 | logger.warning("%s -> %s" % (checkpoint["githash"], githash)) 64 | 65 | else: 66 | print("Checkpoint does not exixts") 67 | global_step = 0 68 | return None 69 | else: 70 | print("New Training") 71 | global_step = 0 72 | optimizer = get_std_opt( 73 | model, 74 | hp.model.adim, 75 | hp.model.transformer_warmup_steps, 76 | hp.model.transformer_lr, 77 | ) 78 | 79 | print("Batch Size :", hp.train.batch_size) 80 | 81 | num_params(model) 82 | 83 | os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True) 84 | writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name)) 85 | model.train() 86 | forward_count = 0 87 | phn_level_predictor = False 88 | # print(model) 89 | for epoch in range(hp.train.epochs): 90 | start = time.time() 91 | running_loss = 0 92 | j = 0 93 | 94 | pbar = tqdm.tqdm(dataloader, desc="Loading train data") 95 | for data in pbar: 96 | global_step += 1 97 | if hp.model.phoneme_acoustic_embed and global_step >= hp.model.predictor_start_step: 98 | phn_level_predictor = True 99 | x, input_length, y, _, out_length, _, dur, e, p, avg_mel = data 100 | # x : [batch , num_char], input_length : [batch], y : [batch, T_in, num_mel] 101 | # # stop_token : [batch, T_in], out_length : [batch] 102 | 103 | loss, report_dict = model( 104 | x.cuda(), 105 | input_length.cuda(), 106 | y.cuda(), 107 | out_length.cuda(), 108 | dur.cuda(), 109 | e.cuda(), 110 | p.cuda(), 111 | avg_mel.cuda(), 112 | phn_level_predictor 113 | ) 114 | loss = loss.mean() / hp.train.accum_grad 115 | running_loss += loss.item() 116 | 117 | loss.backward() 118 | 119 | # update parameters 120 | forward_count += 1 121 | j = j + 1 122 | if forward_count != hp.train.accum_grad: 123 | continue 124 | forward_count = 0 125 | step = global_step 126 | 127 | # compute the gradient norm to check if it is normal or not 128 | grad_norm = torch.nn.utils.clip_grad_norm_( 129 | model.parameters(), hp.train.grad_clip 130 | ) 131 | logging.debug("grad norm={}".format(grad_norm)) 132 | if math.isnan(grad_norm): 133 | logging.warning("grad norm is nan. 
Do not update model.") 134 | else: 135 | optimizer.step() 136 | optimizer.zero_grad() 137 | 138 | if step % hp.train.summary_interval == 0: 139 | pbar.set_description( 140 | "Average Loss %.04f Loss %.04f | step %d" 141 | % (running_loss / j, loss.item(), step) 142 | ) 143 | 144 | for r in report_dict: 145 | for k, v in r.items(): 146 | if k is not None and v is not None: 147 | if "cupy" in str(type(v)): 148 | v = v.get() 149 | if "cupy" in str(type(k)): 150 | k = k.get() 151 | writer.add_scalar("main/{}".format(k), v, step) 152 | 153 | if step % hp.train.validation_step == 0: 154 | 155 | for valid in validloader: 156 | x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_, avg_mel_ = valid 157 | model.eval() 158 | with torch.no_grad(): 159 | loss_, report_dict_ = model( 160 | x_.cuda(), 161 | input_length_.cuda(), 162 | y_.cuda(), 163 | out_length_.cuda(), 164 | dur_.cuda(), 165 | e_.cuda(), 166 | p_.cuda(), 167 | avg_mel_.cuda(), 168 | phn_level_predictor 169 | ) 170 | 171 | mels_ = model.inference(x_[-1].cuda(), ref_mel = y_[-1].cuda(), avg_mel = avg_mel_[-1].cuda(), 172 | phn_level_predictor = phn_level_predictor) # [T, num_mel] 173 | 174 | model.train() 175 | for r in report_dict_: 176 | for k, v in r.items(): 177 | if k is not None and v is not None: 178 | if "cupy" in str(type(v)): 179 | v = v.get() 180 | if "cupy" in str(type(k)): 181 | k = k.get() 182 | writer.add_scalar("validation/{}".format(k), v, step) 183 | 184 | mels_ = mels_.T # Out: [num_mels, T] 185 | writer.add_image( 186 | "melspectrogram_target_{}".format(ids_[-1]), 187 | plot_spectrogram_to_numpy( 188 | y_[-1].T.data.cpu().numpy()[:, : out_length_[-1]] 189 | ), 190 | step, 191 | dataformats="HWC", 192 | ) 193 | writer.add_image( 194 | "melspectrogram_prediction_{}".format(ids_[-1]), 195 | plot_spectrogram_to_numpy(mels_.data.cpu().numpy()), 196 | step, 197 | dataformats="HWC", 198 | ) 199 | 200 | # print(mels.unsqueeze(0).shape) 201 | 202 | audio = generate_audio( 203 | mels_.unsqueeze(0), vocoder 204 | ) # selecting the last data point to match mel generated above 205 | audio = audio.cpu().float().numpy() 206 | audio = audio / ( 207 | audio.max() - audio.min() 208 | ) # get values between -1 and 1 209 | 210 | writer.add_audio( 211 | tag="generated_audio_{}".format(ids_[-1]), 212 | snd_tensor=torch.Tensor(audio), 213 | global_step=step, 214 | sample_rate=hp.audio.sample_rate, 215 | ) 216 | 217 | _, target = read_wav_np( 218 | hp.data.wav_dir + f"{ids_[-1]}.wav", 219 | sample_rate=hp.audio.sample_rate, 220 | ) 221 | 222 | writer.add_audio( 223 | tag=" target_audio_{}".format(ids_[-1]), 224 | snd_tensor=torch.Tensor(target), 225 | global_step=step, 226 | sample_rate=hp.audio.sample_rate, 227 | ) 228 | 229 | ## 230 | if step % hp.train.save_interval == 0: 231 | avg_p, avg_e, avg_d = evaluate(hp, validloader, model) 232 | writer.add_scalar("evaluation/Pitch_Loss", avg_p, step) 233 | writer.add_scalar("evaluation/Energy_Loss", avg_e, step) 234 | writer.add_scalar("evaluation/Dur_Loss", avg_d, step) 235 | save_path = os.path.join( 236 | hp.train.chkpt_dir, 237 | args.name, 238 | "{}_fastspeech_{}_{}k_steps.pyt".format( 239 | args.name, githash, step // 1000 240 | ), 241 | ) 242 | 243 | torch.save( 244 | { 245 | "model": model.state_dict(), 246 | "optim": optimizer.state_dict(), 247 | "step": step, 248 | "hp_str": hp_str, 249 | "githash": githash, 250 | }, 251 | save_path, 252 | ) 253 | logger.info("Saved checkpoint to: %s" % save_path) 254 | print( 255 | "Time taken for epoch {} is {} sec\n".format( 256 | epoch + 1, 
int(time.time() - start) 257 | ) 258 | ) 259 | 260 | 261 | def num_params(model, print_out=True): 262 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 263 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 264 | if print_out: 265 | print("Trainable Parameters: %.3fM" % parameters) 266 | 267 | 268 | def create_gta(args, hp, hp_str, logger): 269 | os.makedirs(os.path.join(hp.data.data_dir, "gta"), exist_ok=True) 270 | device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu") 271 | 272 | dataloader = loader.get_tts_dataset(hp.data.data_dir, 1) 273 | validloader = loader.get_tts_dataset(hp.data.data_dir, 1, True) 274 | global_step = 0 275 | idim = len(valid_symbols) 276 | odim = hp.audio.num_mels 277 | model = fastspeech.FeedForwardTransformer(idim, odim, args) 278 | # set torch device 279 | if os.path.exists(args.checkpoint_path): 280 | print("\nSynthesis GTA Session...\n") 281 | checkpoint = torch.load(args.checkpoint_path) 282 | model.load_state_dict(checkpoint["model"]) 283 | else: 284 | print("Checkpoint not exixts") 285 | return None 286 | model.eval() 287 | model = model.to(device) 288 | print("Model is loaded ...") 289 | print("Batch Size :", hp.train.batch_size) 290 | num_params(model) 291 | onlyValidation = False 292 | if not onlyValidation: 293 | pbar = tqdm.tqdm(dataloader, desc="Loading train data") 294 | for data in pbar: 295 | # start_b = time.time() 296 | global_step += 1 297 | x, input_length, y, _, out_length, ids = data 298 | with torch.no_grad(): 299 | _, gta, _, _, _ = model._forward( 300 | x.cuda(), input_length.cuda(), y.cuda(), out_length.cuda() 301 | ) 302 | # gta = model._forward(x.cuda(), input_length.cuda(), is_inference=False) 303 | gta = gta.cpu().numpy() 304 | 305 | for j in range(len(ids)): 306 | mel = gta[j] 307 | mel = mel.T 308 | mel = mel[:, : out_length[j]] 309 | mel = (mel + 4) / 8 310 | id = ids[j] 311 | np.save( 312 | "{}/{}.npy".format(os.path.join(hp.data.data_dir, "gta"), id), 313 | mel, 314 | allow_pickle=False, 315 | ) 316 | 317 | pbar = tqdm.tqdm(validloader, desc="Loading Valid data") 318 | for data in pbar: 319 | # start_b = time.time() 320 | global_step += 1 321 | x, input_length, y, _, out_length, ids = data 322 | with torch.no_grad(): 323 | gta, _, _ = model._forward( 324 | x.cuda(), input_length.cuda(), y.cuda(), out_length.cuda() 325 | ) 326 | # gta = model._forward(x.cuda(), input_length.cuda(), is_inference=True) 327 | gta = gta.cpu().numpy() 328 | 329 | for j in range(len(ids)): 330 | print("Actual mel specs : {} = {}".format(ids[j], y[j].shape)) 331 | print("Out length:", out_length[j]) 332 | print("GTA size: {} = {}".format(ids[j], gta[j].shape)) 333 | mel = gta[j] 334 | mel = mel.T 335 | mel = mel[:, : out_length[j]] 336 | mel = (mel + 4) / 8 337 | print("Mel size: {} = {}".format(ids[j], mel.shape)) 338 | id = ids[j] 339 | np.save( 340 | "{}/{}.npy".format(os.path.join(hp.data.data_dir, "gta"), id), 341 | mel, 342 | allow_pickle=False, 343 | ) 344 | 345 | 346 | # define function for plot prob and att_ws 347 | def _plot_and_save(array, figname, figsize=(6, 4), dpi=150): 348 | import matplotlib.pyplot as plt 349 | 350 | shape = array.shape 351 | if len(shape) == 1: 352 | # for eos probability 353 | fig = plt.figure(figsize=figsize, dpi=dpi) 354 | plt.plot(array) 355 | plt.xlabel("Frame") 356 | plt.ylabel("Probability") 357 | plt.ylim([0, 1]) 358 | elif len(shape) == 2: 359 | # for tacotron 2 attention weights, whose shape is (out_length, in_length) 360 | fig = plt.figure(figsize=figsize, 
dpi=dpi) 361 | plt.imshow(array, aspect="auto") 362 | plt.xlabel("Input") 363 | plt.ylabel("Output") 364 | elif len(shape) == 4: 365 | # for transformer attention weights, whose shape is (#leyers, #heads, out_length, in_length) 366 | fig = plt.figure( 367 | figsize=(figsize[0] * shape[0], figsize[1] * shape[1]), dpi=dpi 368 | ) 369 | for idx1, xs in enumerate(array): 370 | for idx2, x in enumerate(xs, 1): 371 | plt.subplot(shape[0], shape[1], idx1 * shape[1] + idx2) 372 | plt.imshow(x.cpu().detach().numpy(), aspect="auto") 373 | plt.xlabel("Input") 374 | plt.ylabel("Output") 375 | else: 376 | raise NotImplementedError("Support only from 1D to 4D array.") 377 | plt.tight_layout() 378 | if not os.path.exists(os.path.dirname(figname)): 379 | # NOTE: exist_ok = True is needed for parallel process decoding 380 | os.makedirs(os.path.dirname(figname), exist_ok=True) 381 | plt.savefig(figname) 382 | plt.close() 383 | return fig 384 | 385 | 386 | # NOTE: you need this func to generate our sphinx doc 387 | def get_parser(): 388 | """Get parser of training arguments.""" 389 | parser = configargparse.ArgumentParser( 390 | description="Train a new text-to-speech (TTS) model on one CPU, one or multiple GPUs", 391 | config_file_parser_class=configargparse.YAMLConfigFileParser, 392 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter, 393 | ) 394 | 395 | parser.add_argument( 396 | "-c", "--config", type=str, required=True, help="yaml file for configuration" 397 | ) 398 | parser.add_argument( 399 | "-p", 400 | "--checkpoint_path", 401 | type=str, 402 | default=None, 403 | help="path of checkpoint pt file to resume training", 404 | ) 405 | parser.add_argument( 406 | "-n", 407 | "--name", 408 | type=str, 409 | required=True, 410 | help="name of the model for logging, saving checkpoint", 411 | ) 412 | parser.add_argument("--outdir", type=str, required=True, help="Output directory") 413 | 414 | return parser 415 | 416 | 417 | def main(cmd_args): 418 | """Run training.""" 419 | parser = get_parser() 420 | args, _ = parser.parse_known_args(cmd_args) 421 | 422 | args = parser.parse_args(cmd_args) 423 | 424 | hp = HParam(args.config) 425 | with open(args.config, "r") as f: 426 | hp_str = "".join(f.readlines()) 427 | 428 | # logging info 429 | os.makedirs(hp.train.log_dir, exist_ok=True) 430 | logging.basicConfig( 431 | level=logging.INFO, 432 | format="%(asctime)s - %(levelname)s - %(message)s", 433 | handlers=[ 434 | logging.FileHandler( 435 | os.path.join(hp.train.log_dir, "%s-%d.log" % (args.name, time.time())) 436 | ), 437 | logging.StreamHandler(), 438 | ], 439 | ) 440 | logger = logging.getLogger() 441 | 442 | # If --ngpu is not given, 443 | # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices 444 | # 2. if nvidia-smi exists, use all devices 445 | # 3. 
else ngpu=0 446 | ngpu = hp.train.ngpu 447 | logger.info(f"ngpu: {ngpu}") 448 | 449 | # set random seed 450 | logger.info("random seed = %d" % hp.train.seed) 451 | random.seed(hp.train.seed) 452 | np.random.seed(hp.train.seed) 453 | 454 | vocoder = torch.hub.load( 455 | "seungwonpark/melgan", "melgan" 456 | ) # load the vocoder for validation 457 | 458 | if hp.train.GTA: 459 | create_gta(args, hp, hp_str, logger) 460 | else: 461 | train(args, hp, hp_str, logger, vocoder) 462 | 463 | 464 | if __name__ == "__main__": 465 | main(sys.argv[1:]) 466 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/AdaSpeech/c208a4fb83b09f4943cb4faf3c007892dc3f6b7b/utils/__init__.py -------------------------------------------------------------------------------- /utils/display.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | import matplotlib 4 | 5 | matplotlib.use("Agg") 6 | 7 | 8 | def progbar(i, n, size=16): 9 | done = (i * size) // n 10 | bar = "" 11 | for i in range(size): 12 | bar += "█" if i <= done else "░" 13 | return bar 14 | 15 | 16 | def stream(message): 17 | sys.stdout.write(f"\r{message}") 18 | 19 | 20 | def simple_table(item_tuples): 21 | border_pattern = "+---------------------------------------" 22 | whitespace = " " 23 | 24 | headings, cells, = ( 25 | [], 26 | [], 27 | ) 28 | 29 | for item in item_tuples: 30 | 31 | heading, cell = str(item[0]), str(item[1]) 32 | 33 | pad_head = True if len(heading) < len(cell) else False 34 | 35 | pad = abs(len(heading) - len(cell)) 36 | pad = whitespace[:pad] 37 | 38 | pad_left = pad[: len(pad) // 2] 39 | pad_right = pad[len(pad) // 2 :] 40 | 41 | if pad_head: 42 | heading = pad_left + heading + pad_right 43 | else: 44 | cell = pad_left + cell + pad_right 45 | 46 | headings += [heading] 47 | cells += [cell] 48 | 49 | border, head, body = "", "", "" 50 | 51 | for i in range(len(item_tuples)): 52 | 53 | temp_head = f"| {headings[i]} " 54 | temp_body = f"| {cells[i]} " 55 | 56 | border += border_pattern[: len(temp_head)] 57 | head += temp_head 58 | body += temp_body 59 | 60 | if i == len(item_tuples) - 1: 61 | head += "|" 62 | body += "|" 63 | border += "+" 64 | 65 | print(border) 66 | print(head) 67 | print(border) 68 | print(body) 69 | print(border) 70 | print(" ") 71 | 72 | 73 | def time_since(started): 74 | elapsed = time.time() - started 75 | m = int(elapsed // 60) 76 | s = int(elapsed % 60) 77 | if m >= 60: 78 | h = int(m // 60) 79 | m = m % 60 80 | return f"{h}h {m}m {s}s" 81 | else: 82 | return f"{m}m {s}s" 83 | -------------------------------------------------------------------------------- /utils/fastspeech2_script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Tomoki Hayashi 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """FastSpeech related loss.""" 8 | 9 | import logging 10 | 11 | import torch 12 | from core.duration_modeling.duration_predictor import DurationPredictor 13 | from core.duration_modeling.duration_predictor import DurationPredictorLoss 14 | from core.variance_predictor import EnergyPredictor, EnergyPredictorLoss 15 | from core.variance_predictor import PitchPredictor, PitchPredictorLoss 16 | from core.duration_modeling.length_regulator import 
LengthRegulator 17 | from utils.util import make_non_pad_mask_script 18 | from utils.util import make_pad_mask_script 19 | from core.embedding import PositionalEncoding 20 | from core.embedding import ScaledPositionalEncoding 21 | from core.encoder import Encoder 22 | from core.modules import initialize 23 | from core.modules import Postnet 24 | from typeguard import check_argument_types 25 | from typing import Dict, Tuple, Sequence 26 | 27 | 28 | class FeedForwardTransformer(torch.nn.Module): 29 | def __init__(self, idim: int, odim: int, hp: Dict): 30 | """Initialize feed-forward Transformer module. 31 | Args: 32 | idim (int): Dimension of the inputs. 33 | odim (int): Dimension of the outputs. 34 | """ 35 | # initialize base classes 36 | assert check_argument_types() 37 | torch.nn.Module.__init__(self) 38 | 39 | # fill missing arguments 40 | 41 | # store hyperparameters 42 | self.idim = idim 43 | self.odim = odim 44 | 45 | self.use_scaled_pos_enc = hp.model.use_scaled_pos_enc 46 | self.use_masking = hp.model.use_masking 47 | 48 | # use idx 0 as padding idx 49 | padding_idx = 0 50 | 51 | # get positional encoding class 52 | pos_enc_class = ( 53 | ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding 54 | ) 55 | 56 | # define encoder 57 | encoder_input_layer = torch.nn.Embedding( 58 | num_embeddings=idim, embedding_dim=hp.model.adim, padding_idx=padding_idx 59 | ) 60 | self.encoder = Encoder( 61 | idim=idim, 62 | attention_dim=hp.model.adim, 63 | attention_heads=hp.model.aheads, 64 | linear_units=hp.model.eunits, 65 | num_blocks=hp.model.elayers, 66 | input_layer=encoder_input_layer, 67 | dropout_rate=0.2, 68 | positional_dropout_rate=0.2, 69 | attention_dropout_rate=0.2, 70 | pos_enc_class=pos_enc_class, 71 | normalize_before=hp.model.encoder_normalize_before, 72 | concat_after=hp.model.encoder_concat_after, 73 | positionwise_layer_type=hp.model.positionwise_layer_type, 74 | positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size, 75 | ) 76 | 77 | self.duration_predictor = DurationPredictor( 78 | idim=hp.model.adim, 79 | n_layers=hp.model.duration_predictor_layers, 80 | n_chans=hp.model.duration_predictor_chans, 81 | kernel_size=hp.model.duration_predictor_kernel_size, 82 | dropout_rate=hp.model.duration_predictor_dropout_rate, 83 | ) 84 | 85 | self.energy_predictor = EnergyPredictor( 86 | idim=hp.model.adim, 87 | n_layers=hp.model.duration_predictor_layers, 88 | n_chans=hp.model.duration_predictor_chans, 89 | kernel_size=hp.model.duration_predictor_kernel_size, 90 | dropout_rate=hp.model.duration_predictor_dropout_rate, 91 | min=hp.data.e_min, 92 | max=hp.data.e_max, 93 | ) 94 | self.energy_embed = torch.nn.Linear(hp.model.adim, hp.model.adim) 95 | 96 | self.pitch_predictor = PitchPredictor( 97 | idim=hp.model.adim, 98 | n_layers=hp.model.duration_predictor_layers, 99 | n_chans=hp.model.duration_predictor_chans, 100 | kernel_size=hp.model.duration_predictor_kernel_size, 101 | dropout_rate=hp.model.duration_predictor_dropout_rate, 102 | min=hp.data.p_min, 103 | max=hp.data.p_max, 104 | ) 105 | self.pitch_embed = torch.nn.Linear(hp.model.adim, hp.model.adim) 106 | 107 | # define length regulator 108 | self.length_regulator = LengthRegulator() 109 | 110 | # define decoder 111 | # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder 112 | self.decoder = Encoder( 113 | idim=256, 114 | attention_dim=256, 115 | attention_heads=hp.model.aheads, 116 | linear_units=hp.model.dunits, 117 | num_blocks=hp.model.dlayers, 118 | 
input_layer=None, 119 | dropout_rate=0.2, 120 | positional_dropout_rate=0.2, 121 | attention_dropout_rate=0.2, 122 | pos_enc_class=pos_enc_class, 123 | normalize_before=hp.model.decoder_normalize_before, 124 | concat_after=hp.model.decoder_concat_after, 125 | positionwise_layer_type=hp.model.positionwise_layer_type, 126 | positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size, 127 | ) 128 | 129 | # define postnet 130 | self.postnet = ( 131 | None 132 | if hp.model.postnet_layers == 0 133 | else Postnet( 134 | idim=idim, 135 | odim=odim, 136 | n_layers=hp.model.postnet_layers, 137 | n_chans=hp.model.postnet_chans, 138 | n_filts=hp.model.postnet_filts, 139 | use_batch_norm=hp.model.use_batch_norm, 140 | dropout_rate=hp.model.postnet_dropout_rate, 141 | ) 142 | ) 143 | 144 | # define final projection 145 | self.feat_out = torch.nn.Linear(hp.model.adim, odim * hp.model.reduction_factor) 146 | 147 | # initialize parameters 148 | self._reset_parameters( 149 | init_type=hp.model.transformer_init, 150 | init_enc_alpha=hp.model.initial_encoder_alpha, 151 | init_dec_alpha=hp.model.initial_decoder_alpha, 152 | ) 153 | 154 | # define criterions 155 | self.duration_criterion = DurationPredictorLoss() 156 | self.energy_criterion = EnergyPredictorLoss() 157 | self.pitch_criterion = PitchPredictorLoss() 158 | self.criterion = torch.nn.L1Loss(reduction="mean") 159 | self.use_weighted_masking = hp.model.use_weighted_masking 160 | 161 | def _forward(self, xs: torch.Tensor, ilens: torch.Tensor): 162 | # forward encoder 163 | x_masks = self._source_mask( 164 | ilens 165 | ) # (B, Tmax, Tmax) -> torch.Size([32, 121, 121]) 166 | 167 | hs, _ = self.encoder( 168 | xs, x_masks 169 | ) # (B, Tmax, adim) -> torch.Size([32, 121, 256]) 170 | # print("ys :", ys.shape) 171 | 172 | # # forward duration predictor and length regulator 173 | d_masks = make_pad_mask_script(ilens).to(xs.device) 174 | 175 | d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) 176 | hs = self.length_regulator(hs, d_outs, ilens) # (B, Lmax, adim) 177 | 178 | one_hot_energy = self.energy_predictor.inference(hs) # (B, Lmax, adim) 179 | 180 | one_hot_pitch = self.pitch_predictor.inference(hs) # (B, Lmax, adim) 181 | 182 | hs = hs + self.pitch_embed(one_hot_pitch) # (B, Lmax, adim) 183 | hs = hs + self.energy_embed(one_hot_energy) # (B, Lmax, adim) 184 | 185 | # # forward decoder 186 | # h_masks = self._source_mask(olens) we can find olens from length regulator and then calculate mask 187 | # h_masks = torch.empty(0) 188 | 189 | zs, _ = self.decoder(hs, None) # (B, Lmax, adim) 190 | 191 | before_outs = self.feat_out(zs).view( 192 | zs.size(0), -1, self.odim 193 | ) # (B, Lmax, odim) 194 | 195 | # postnet -> (B, Lmax//r * r, odim) 196 | after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose( 197 | 1, 2 198 | ) 199 | return after_outs 200 | 201 | def forward(self, x: torch.Tensor) -> torch.Tensor: 202 | """Generate the sequence of features given the sequences of characters. 203 | Args: 204 | x (Tensor): Input sequence of characters (T,). 205 | inference_args (Namespace): Dummy for compatibility. 206 | spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). 207 | Returns: 208 | Tensor: Output sequence of features (1, L, odim). 209 | None: Dummy for compatibility. 210 | None: Dummy for compatibility. 
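        Examples:
            A hedged export/inference sketch (added for illustration), following
            export_torchscript.py; the save path is a placeholder.
            >>> hp = HParam("configs/default.yaml")
            >>> model = FeedForwardTransformer(len(valid_symbols), hp.audio.num_mels, hp)
            >>> scripted = torch.jit.script(model)
            >>> scripted.save("output/adaspeech_scripted.pt")    # placeholder path
            >>> phoneme_ids = torch.ones(50, dtype=torch.long)   # (T,) dummy phoneme ids
            >>> mel = scripted(phoneme_ids)                      # (L, odim) mel-spectrogram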
211 | """ 212 | # setup batch axis 213 | ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) 214 | xs = x.unsqueeze(0) 215 | 216 | # inference 217 | outs = self._forward(xs, ilens) # (L, odim) 218 | 219 | return outs[0] 220 | 221 | def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: 222 | """Make masks for self-attention. 223 | Examples: 224 | >>> ilens = [5, 3] 225 | >>> self._source_mask(ilens) 226 | tensor([[[1, 1, 1, 1, 1], 227 | [1, 1, 1, 1, 1], 228 | [1, 1, 1, 1, 1], 229 | [1, 1, 1, 1, 1], 230 | [1, 1, 1, 1, 1]], 231 | [[1, 1, 1, 0, 0], 232 | [1, 1, 1, 0, 0], 233 | [1, 1, 1, 0, 0], 234 | [0, 0, 0, 0, 0], 235 | [0, 0, 0, 0, 0]]], dtype=torch.uint8) 236 | """ 237 | x_masks = make_non_pad_mask_script(ilens) 238 | return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1) 239 | 240 | def _reset_parameters( 241 | self, init_type: str, init_enc_alpha: float = 1.0, init_dec_alpha: float = 1.0 242 | ): 243 | # initialize parameters 244 | initialize(self, init_type) 245 | # 246 | # initialize alpha in scaled positional encoding 247 | if self.use_scaled_pos_enc: 248 | self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) 249 | self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) 250 | -------------------------------------------------------------------------------- /utils/hparams.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | 5 | def load_hparam_str(hp_str): 6 | path = "temp-restore.yaml" 7 | with open(path, "w") as f: 8 | f.write(hp_str) 9 | ret = HParam(path) 10 | os.remove(path) 11 | return ret 12 | 13 | 14 | def load_hparam(filename): 15 | stream = open(filename, "r") 16 | docs = yaml.load_all(stream, Loader=yaml.Loader) 17 | hparam_dict = dict() 18 | for doc in docs: 19 | for k, v in doc.items(): 20 | hparam_dict[k] = v 21 | return hparam_dict 22 | 23 | 24 | def merge_dict(user, default): 25 | if isinstance(user, dict) and isinstance(default, dict): 26 | for k, v in default.items(): 27 | if k not in user: 28 | user[k] = v 29 | else: 30 | user[k] = merge_dict(user[k], v) 31 | return user 32 | 33 | 34 | class Dotdict(dict): 35 | """ 36 | a dictionary that supports dot notation 37 | as well as dictionary access notation 38 | usage: d = DotDict() or d = DotDict({'val1':'first'}) 39 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 40 | get attributes: d.val2 or d['val2'] 41 | """ 42 | 43 | __getattr__ = dict.__getitem__ 44 | __setattr__ = dict.__setitem__ 45 | __delattr__ = dict.__delitem__ 46 | 47 | def __init__(self, dct=None): 48 | dct = dict() if not dct else dct 49 | for key, value in dct.items(): 50 | if hasattr(value, "keys"): 51 | value = Dotdict(value) 52 | self[key] = value 53 | 54 | 55 | class HParam(Dotdict): 56 | def __init__(self, file): 57 | super(Dotdict, self).__init__() 58 | hp_dict = load_hparam(file) 59 | hp_dotdict = Dotdict(hp_dict) 60 | for k, v in hp_dotdict.items(): 61 | setattr(self, k, v) 62 | 63 | __getattr__ = Dotdict.__getitem__ 64 | __setattr__ = Dotdict.__setitem__ 65 | __delattr__ = Dotdict.__delitem__ 66 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import torch 4 | 5 | matplotlib.use("Agg") 6 | from matplotlib import pyplot as plt 7 | 8 | 9 | def save_attention(attn, path): 10 | fig = plt.figure(figsize=(12, 6)) 11 | plt.imshow(attn.T, 
interpolation="nearest", aspect="auto") 12 | fig.savefig(f"{path}.png", bbox_inches="tight") 13 | plt.close(fig) 14 | 15 | 16 | def save_spectrogram(M, path, length=None): 17 | M = np.flip(M, axis=0) 18 | if length: 19 | M = M[:, :length] 20 | fig = plt.figure(figsize=(12, 6)) 21 | plt.imshow(M, interpolation="nearest", aspect="auto") 22 | fig.savefig(f"{path}.png", bbox_inches="tight") 23 | plt.close(fig) 24 | 25 | 26 | def plot(array): 27 | fig = plt.figure(figsize=(30, 5)) 28 | ax = fig.add_subplot(111) 29 | ax.xaxis.label.set_color("grey") 30 | ax.yaxis.label.set_color("grey") 31 | ax.xaxis.label.set_fontsize(23) 32 | ax.yaxis.label.set_fontsize(23) 33 | ax.tick_params(axis="x", colors="grey", labelsize=23) 34 | ax.tick_params(axis="y", colors="grey", labelsize=23) 35 | plt.plot(array) 36 | 37 | 38 | def plot_spec(M): 39 | M = np.flip(M, axis=0) 40 | plt.figure(figsize=(18, 4)) 41 | plt.imshow(M, interpolation="nearest", aspect="auto") 42 | plt.show() 43 | 44 | 45 | def plot_image(target, melspec, mel_lengths): # , alignments 46 | fig, axes = plt.subplots(2, 1, figsize=(20, 20)) 47 | T = mel_lengths[-1] 48 | 49 | axes[0].imshow(target[-1].T.detach().cpu()[:, :T], origin="lower", aspect="auto") 50 | 51 | axes[1].imshow(melspec.cpu()[:, :T], origin="lower", aspect="auto") 52 | 53 | return fig 54 | 55 | 56 | def save_figure_to_numpy(fig, spectrogram=False): 57 | # save it to a numpy array. 58 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 59 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 60 | if spectrogram: 61 | return data 62 | data = np.transpose(data, (2, 0, 1)) 63 | return data 64 | 65 | 66 | def plot_waveform_to_numpy(waveform): 67 | fig, ax = plt.subplots(figsize=(12, 3)) 68 | ax.plot() 69 | ax.plot(range(len(waveform)), waveform, linewidth=0.1, alpha=0.7, color="blue") 70 | 71 | plt.xlabel("Samples") 72 | plt.ylabel("Amplitude") 73 | plt.ylim(-1, 1) 74 | plt.tight_layout() 75 | 76 | fig.canvas.draw() 77 | data = save_figure_to_numpy(fig) 78 | plt.close() 79 | return data 80 | 81 | 82 | def plot_spectrogram_to_numpy(spectrogram): 83 | fig, ax = plt.subplots(figsize=(12, 3)) 84 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 85 | plt.colorbar(im, ax=ax) 86 | plt.xlabel("Frames") 87 | plt.ylabel("Channels") 88 | plt.tight_layout() 89 | 90 | fig.canvas.draw() 91 | data = save_figure_to_numpy(fig, True) 92 | plt.close() 93 | return data 94 | 95 | 96 | def generate_audio(mel, vocoder): 97 | # input mel shape - [1,80,T] 98 | vocoder.eval() 99 | if torch.cuda.is_available(): 100 | vocoder = vocoder.cuda() 101 | mel = mel.cuda() 102 | 103 | with torch.no_grad(): 104 | audio = vocoder.inference(mel) 105 | return audio 106 | -------------------------------------------------------------------------------- /utils/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | Copyright (c) 2017, Prem Seetharaman 4 | All rights reserved. 5 | * Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, this 10 | list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 
12 | * Neither the name of the copyright holder nor the names of its 13 | contributors may be used to endorse or promote products derived from this 14 | software without specific prior written permission. 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | """ 26 | 27 | import torch 28 | import numpy as np 29 | import torch.nn.functional as F 30 | from torch.autograd import Variable 31 | from scipy.signal import get_window 32 | from librosa.util import pad_center, tiny 33 | from dataset.audio_processing import ( 34 | window_sumsquare, 35 | dynamic_range_compression, 36 | dynamic_range_decompression, 37 | ) 38 | from librosa.filters import mel as librosa_mel_fn 39 | 40 | 41 | class STFT(torch.nn.Module): 42 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 43 | 44 | def __init__( 45 | self, filter_length=800, hop_length=200, win_length=800, window="hann" 46 | ): 47 | super(STFT, self).__init__() 48 | self.filter_length = filter_length 49 | self.hop_length = hop_length 50 | self.win_length = win_length 51 | self.window = window 52 | self.forward_transform = None 53 | scale = self.filter_length / self.hop_length 54 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 55 | 56 | cutoff = int((self.filter_length / 2 + 1)) 57 | fourier_basis = np.vstack( 58 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 59 | ) 60 | 61 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 62 | inverse_basis = torch.FloatTensor( 63 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 64 | ) 65 | 66 | if window is not None: 67 | assert filter_length >= win_length 68 | # get window and zero center pad it to filter_length 69 | fft_window = get_window(window, win_length, fftbins=True) 70 | fft_window = pad_center(fft_window, filter_length) 71 | fft_window = torch.from_numpy(fft_window).float() 72 | 73 | # window the bases 74 | forward_basis *= fft_window 75 | inverse_basis *= fft_window 76 | 77 | self.register_buffer("forward_basis", forward_basis.float()) 78 | self.register_buffer("inverse_basis", inverse_basis.float()) 79 | 80 | def transform(self, input_data): 81 | num_batches = input_data.size(0) 82 | num_samples = input_data.size(1) 83 | 84 | self.num_samples = num_samples 85 | 86 | # similar to librosa, reflect-pad the input 87 | input_data = input_data.view(num_batches, 1, num_samples) 88 | input_data = F.pad( 89 | input_data.unsqueeze(1), 90 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 91 | mode="reflect", 92 | ) 93 | input_data = input_data.squeeze(1) 94 | 95 | # https://github.com/NVIDIA/tacotron2/issues/125 96 | forward_transform = F.conv1d( 97 | input_data.cuda(), 98 | Variable(self.forward_basis, requires_grad=False).cuda(), 99 | stride=self.hop_length, 100 | padding=0, 
101 | ).cpu() 102 | 103 | cutoff = int((self.filter_length / 2) + 1) 104 | real_part = forward_transform[:, :cutoff, :] 105 | imag_part = forward_transform[:, cutoff:, :] 106 | 107 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 108 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 109 | 110 | return magnitude, phase 111 | 112 | def inverse(self, magnitude, phase): 113 | recombine_magnitude_phase = torch.cat( 114 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 115 | ) 116 | 117 | inverse_transform = F.conv_transpose1d( 118 | recombine_magnitude_phase, 119 | Variable(self.inverse_basis, requires_grad=False), 120 | stride=self.hop_length, 121 | padding=0, 122 | ) 123 | 124 | if self.window is not None: 125 | window_sum = window_sumsquare( 126 | self.window, 127 | magnitude.size(-1), 128 | hop_length=self.hop_length, 129 | win_length=self.win_length, 130 | n_fft=self.filter_length, 131 | dtype=np.float32, 132 | ) 133 | # remove modulation effects 134 | approx_nonzero_indices = torch.from_numpy( 135 | np.where(window_sum > tiny(window_sum))[0] 136 | ) 137 | window_sum = torch.autograd.Variable( 138 | torch.from_numpy(window_sum), requires_grad=False 139 | ) 140 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 141 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 142 | approx_nonzero_indices 143 | ] 144 | 145 | # scale by hop ratio 146 | inverse_transform *= float(self.filter_length) / self.hop_length 147 | 148 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 149 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 150 | 151 | return inverse_transform 152 | 153 | def forward(self, input_data): 154 | self.magnitude, self.phase = self.transform(input_data) 155 | reconstruction = self.inverse(self.magnitude, self.phase) 156 | return reconstruction 157 | 158 | 159 | class TacotronSTFT(torch.nn.Module): 160 | def __init__( 161 | self, 162 | filter_length=1024, 163 | hop_length=256, 164 | win_length=1024, 165 | n_mel_channels=80, 166 | sampling_rate=22050, 167 | mel_fmin=0.0, 168 | mel_fmax=None, 169 | ): 170 | super(TacotronSTFT, self).__init__() 171 | self.n_mel_channels = n_mel_channels 172 | self.sampling_rate = sampling_rate 173 | self.stft_fn = STFT(filter_length, hop_length, win_length) 174 | mel_basis = librosa_mel_fn( 175 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 176 | ) 177 | mel_basis = torch.from_numpy(mel_basis).float() 178 | self.register_buffer("mel_basis", mel_basis) 179 | 180 | def spectral_normalize(self, magnitudes): 181 | output = dynamic_range_compression(magnitudes) 182 | return output 183 | 184 | def spectral_de_normalize(self, magnitudes): 185 | output = dynamic_range_decompression(magnitudes) 186 | return output 187 | 188 | def mel_spectrogram(self, y): 189 | """Computes mel-spectrograms from a batch of waves 190 | PARAMS 191 | ------ 192 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 193 | RETURNS 194 | ------- 195 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 196 | """ 197 | assert torch.min(y.data) >= -1 198 | assert torch.max(y.data) <= 1 199 | 200 | magnitudes, phases = self.stft_fn.transform(y) 201 | magnitudes = magnitudes.data 202 | mel_output = torch.matmul(self.mel_basis, magnitudes) 203 | mel_output = self.spectral_normalize(mel_output) 204 | return mel_output, magnitudes 205 | 
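Editor's note: the TacotronSTFT class above wraps the custom STFT in a librosa mel filterbank plus dynamic-range compression, so a normalized waveform goes in and a compressed mel spectrogram comes out. Below is a minimal, hypothetical sketch of that call path; the file name is made up, the constructor arguments simply repeat the class defaults (the real values come from configs/default.yaml), and read_wav_np is the loader defined in utils/util.py below. Note that STFT.transform as written moves its inputs and basis to CUDA, so a GPU is assumed.

import torch
from utils.stft import TacotronSTFT
from utils.util import read_wav_np

# Mirror the constructor defaults; in practice these come from the hparams yaml.
stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024,
                    n_mel_channels=80, sampling_rate=22050)

# read_wav_np returns float32 audio scaled to [-1, 1], which mel_spectrogram asserts.
sr, wav = read_wav_np("sample.wav", sample_rate=22050)   # "sample.wav" is illustrative
wav = torch.from_numpy(wav).unsqueeze(0)                 # (B, T) with B = 1

# (1, 80, T_frames) mel output and (1, 513, T_frames) linear magnitudes.
mel, magnitudes = stft.mel_spectrogram(wav)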
-------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Network related utility tools.""" 4 | 5 | import numpy as np 6 | import torch 7 | import argparse 8 | import json 9 | import os 10 | import logging 11 | import subprocess 12 | from scipy.io.wavfile import read 13 | import librosa 14 | import glob 15 | from typing import List 16 | import torch.nn.functional as F 17 | 18 | 19 | def get_files(path, extension=".wav"): 20 | filenames = [] 21 | for filename in glob.iglob(f"{path}/**/*{extension}", recursive=True): 22 | filenames += [filename] 23 | return filenames 24 | 25 | 26 | def is_outlier(x, p25, p75): 27 | """Check if value is an outlier.""" 28 | lower = p25 - 1.5 * (p75 - p25) 29 | upper = p75 + 1.5 * (p75 - p25) 30 | 31 | return x <= lower or x >= upper 32 | 33 | 34 | def remove_outlier(x): 35 | """Remove outlier from x.""" 36 | p25 = np.percentile(x, 25) 37 | p75 = np.percentile(x, 75) 38 | zero_idxs = np.where(x == 0.0)[0] 39 | indices_of_outliers = [] 40 | for ind, value in enumerate(x): 41 | if is_outlier(value, p25, p75): 42 | indices_of_outliers.append(ind) 43 | 44 | x[indices_of_outliers] = 0.0 45 | 46 | # replace by mean f0. 47 | x[indices_of_outliers] = np.max(x) 48 | x[zero_idxs] = 0.0 49 | return x 50 | 51 | 52 | def str_to_int_list(s): 53 | return list(map(int, s.split())) 54 | 55 | 56 | def to_device(m, x): 57 | """Send tensor into the device of the module. 58 | 59 | Args: 60 | m (torch.nn.Module): Torch module. 61 | x (Tensor): Torch tensor. 62 | 63 | Returns: 64 | Tensor: Torch tensor located in the same place as torch module. 65 | 66 | """ 67 | assert isinstance(m, torch.nn.Module) 68 | device = next(m.parameters()).device 69 | return x.to(device) 70 | 71 | 72 | @torch.jit.script 73 | def pad_1d_tensor(xs: List[torch.Tensor]): 74 | 75 | length = torch.jit.annotate(List[int], []) 76 | 77 | for x in xs: 78 | 79 | length.append(x.size(0)) 80 | 81 | max_len = max(length) 82 | x_padded = [] 83 | 84 | for x in xs: 85 | x_padded.append(F.pad(x, (0, max_len - x.shape[0]))) 86 | padded = torch.stack(x_padded) 87 | 88 | return padded 89 | 90 | 91 | @torch.jit.script 92 | def pad_2d_tensor(xs: List[torch.Tensor], pad_value: float = 0.0): 93 | max_len = max([xs[i].size(0) for i in range(len(xs))]) 94 | 95 | out_list = [] 96 | 97 | for i, batch in enumerate(xs): 98 | one_batch_padded = F.pad( 99 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", pad_value 100 | ) 101 | out_list.append(one_batch_padded) 102 | 103 | out_padded = torch.stack(out_list) 104 | return out_padded 105 | 106 | 107 | def pad_list(xs, pad_value): 108 | """Perform padding for the list of tensors. 109 | 110 | Args: 111 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 112 | pad_value (float): Value for padding. 113 | 114 | Returns: 115 | Tensor: Padded tensor (B, Tmax, `*`). 
116 | 117 | Examples: 118 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 119 | >>> x 120 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 121 | >>> pad_list(x, 0) 122 | tensor([[1., 1., 1., 1.], 123 | [1., 1., 0., 0.], 124 | [1., 0., 0., 0.]]) 125 | 126 | """ 127 | n_batch = len(xs) 128 | max_len = max(x.size(0) for x in xs) 129 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) 130 | 131 | for i in range(n_batch): 132 | pad[i, : xs[i].size(0)] = xs[i] 133 | 134 | return pad 135 | 136 | 137 | def subsequent_mask(size, device="cuda", dtype=torch.uint8): 138 | """Create mask for subsequent steps (1, size, size) 139 | 140 | :param int size: size of mask 141 | :param str device: "cpu" or "cuda" or torch.Tensor.device 142 | :param torch.dtype dtype: result dtype 143 | :rtype: torch.Tensor 144 | >>> subsequent_mask(3) 145 | [[1, 0, 0], 146 | [1, 1, 0], 147 | [1, 1, 1]] 148 | """ 149 | ret = torch.ones(size, size, device=device, dtype=dtype) 150 | return torch.tril(ret, out=ret) 151 | 152 | 153 | @torch.jit.script 154 | def tensor_1d_tolist(x): 155 | result: List[int] = [] 156 | for i in x: 157 | result.append(i.item()) 158 | return result 159 | 160 | 161 | @torch.jit.script 162 | def make_pad_mask_script(lengths: torch.Tensor): 163 | 164 | if not isinstance(lengths, list): 165 | lengths = tensor_1d_tolist(lengths) 166 | 167 | bs = int(len(lengths)) 168 | maxlen = int(max(lengths)) 169 | 170 | seq_range = torch.arange(0, maxlen, dtype=torch.int64) 171 | seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) 172 | seq_length_expand = torch.tensor(lengths).unsqueeze(-1) 173 | mask = seq_range_expand >= seq_length_expand 174 | 175 | return mask 176 | 177 | 178 | def make_pad_mask(lengths: List[int], xs: torch.Tensor = None, length_dim: int = -1): 179 | """Make mask tensor containing indices of padded part. 180 | 181 | Args: 182 | lengths (LongTensor or List): Batch of lengths (B,). 183 | xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. 184 | length_dim (int, optional): Dimension indicator of the above tensor. See the example. 185 | 186 | Returns: 187 | Tensor: Mask tensor containing indices of padded part. 188 | 189 | Examples: 190 | With only lengths. 191 | 192 | >>> lengths = [5, 3, 2] 193 | >>> make_non_pad_mask(lengths) 194 | masks = [[0, 0, 0, 0 ,0], 195 | [0, 0, 0, 1, 1], 196 | [0, 0, 1, 1, 1]] 197 | 198 | With the reference tensor. 199 | 200 | >>> xs = torch.zeros((3, 2, 4)) 201 | >>> make_pad_mask(lengths, xs) 202 | tensor([[[0, 0, 0, 0], 203 | [0, 0, 0, 0]], 204 | [[0, 0, 0, 1], 205 | [0, 0, 0, 1]], 206 | [[0, 0, 1, 1], 207 | [0, 0, 1, 1]]], dtype=torch.uint8) 208 | >>> xs = torch.zeros((3, 2, 6)) 209 | >>> make_pad_mask(lengths, xs) 210 | tensor([[[0, 0, 0, 0, 0, 1], 211 | [0, 0, 0, 0, 0, 1]], 212 | [[0, 0, 0, 1, 1, 1], 213 | [0, 0, 0, 1, 1, 1]], 214 | [[0, 0, 1, 1, 1, 1], 215 | [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) 216 | 217 | With the reference tensor and dimension indicator. 
218 | 219 | >>> xs = torch.zeros((3, 6, 6)) 220 | >>> make_pad_mask(lengths, xs, 1) 221 | tensor([[[0, 0, 0, 0, 0, 0], 222 | [0, 0, 0, 0, 0, 0], 223 | [0, 0, 0, 0, 0, 0], 224 | [0, 0, 0, 0, 0, 0], 225 | [0, 0, 0, 0, 0, 0], 226 | [1, 1, 1, 1, 1, 1]], 227 | [[0, 0, 0, 0, 0, 0], 228 | [0, 0, 0, 0, 0, 0], 229 | [0, 0, 0, 0, 0, 0], 230 | [1, 1, 1, 1, 1, 1], 231 | [1, 1, 1, 1, 1, 1], 232 | [1, 1, 1, 1, 1, 1]], 233 | [[0, 0, 0, 0, 0, 0], 234 | [0, 0, 0, 0, 0, 0], 235 | [1, 1, 1, 1, 1, 1], 236 | [1, 1, 1, 1, 1, 1], 237 | [1, 1, 1, 1, 1, 1], 238 | [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) 239 | >>> make_pad_mask(lengths, xs, 2) 240 | tensor([[[0, 0, 0, 0, 0, 1], 241 | [0, 0, 0, 0, 0, 1], 242 | [0, 0, 0, 0, 0, 1], 243 | [0, 0, 0, 0, 0, 1], 244 | [0, 0, 0, 0, 0, 1], 245 | [0, 0, 0, 0, 0, 1]], 246 | [[0, 0, 0, 1, 1, 1], 247 | [0, 0, 0, 1, 1, 1], 248 | [0, 0, 0, 1, 1, 1], 249 | [0, 0, 0, 1, 1, 1], 250 | [0, 0, 0, 1, 1, 1], 251 | [0, 0, 0, 1, 1, 1]], 252 | [[0, 0, 1, 1, 1, 1], 253 | [0, 0, 1, 1, 1, 1], 254 | [0, 0, 1, 1, 1, 1], 255 | [0, 0, 1, 1, 1, 1], 256 | [0, 0, 1, 1, 1, 1], 257 | [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) 258 | 259 | """ 260 | if length_dim == 0: 261 | raise ValueError("length_dim cannot be 0: {}".format(length_dim)) 262 | 263 | if not isinstance(lengths, list): 264 | lengths = lengths.tolist() 265 | bs = int(len(lengths)) 266 | if xs is None: 267 | maxlen = int(max(lengths)) 268 | else: 269 | maxlen = xs.size(length_dim) 270 | 271 | seq_range = torch.arange(0, maxlen, dtype=torch.int64) 272 | seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) 273 | seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) 274 | mask = seq_range_expand >= seq_length_expand 275 | 276 | if xs is not None: 277 | assert xs.size(0) == bs, (xs.size(0), bs) 278 | 279 | if length_dim < 0: 280 | length_dim = xs.dim() + length_dim 281 | # ind = (:, None, ..., None, :, , None, ..., None) 282 | ind = tuple( 283 | slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) 284 | ) 285 | mask = mask[ind].expand_as(xs).to(xs.device) 286 | return mask 287 | 288 | 289 | @torch.jit.script 290 | def make_non_pad_mask_script(lengths: torch.Tensor): 291 | return ~make_pad_mask_script(lengths) 292 | 293 | 294 | def make_non_pad_mask(lengths, xs=None, length_dim=-1): 295 | """Make mask tensor containing indices of non-padded part. 296 | 297 | Args: 298 | lengths (LongTensor or List): Batch of lengths (B,). 299 | xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. 300 | length_dim (int, optional): Dimension indicator of the above tensor. See the example. 301 | 302 | Returns: 303 | ByteTensor: mask tensor containing indices of padded part. 304 | 305 | Examples: 306 | With only lengths. 307 | 308 | >>> lengths = [5, 3, 2] 309 | >>> make_non_pad_mask(lengths) 310 | masks = [[1, 1, 1, 1 ,1], 311 | [1, 1, 1, 0, 0], 312 | [1, 1, 0, 0, 0]] 313 | 314 | With the reference tensor. 315 | 316 | >>> xs = torch.zeros((3, 2, 4)) 317 | >>> make_non_pad_mask(lengths, xs) 318 | tensor([[[1, 1, 1, 1], 319 | [1, 1, 1, 1]], 320 | [[1, 1, 1, 0], 321 | [1, 1, 1, 0]], 322 | [[1, 1, 0, 0], 323 | [1, 1, 0, 0]]], dtype=torch.uint8) 324 | >>> xs = torch.zeros((3, 2, 6)) 325 | >>> make_non_pad_mask(lengths, xs) 326 | tensor([[[1, 1, 1, 1, 1, 0], 327 | [1, 1, 1, 1, 1, 0]], 328 | [[1, 1, 1, 0, 0, 0], 329 | [1, 1, 1, 0, 0, 0]], 330 | [[1, 1, 0, 0, 0, 0], 331 | [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) 332 | 333 | With the reference tensor and dimension indicator. 
334 | 335 | >>> xs = torch.zeros((3, 6, 6)) 336 | >>> make_non_pad_mask(lengths, xs, 1) 337 | tensor([[[1, 1, 1, 1, 1, 1], 338 | [1, 1, 1, 1, 1, 1], 339 | [1, 1, 1, 1, 1, 1], 340 | [1, 1, 1, 1, 1, 1], 341 | [1, 1, 1, 1, 1, 1], 342 | [0, 0, 0, 0, 0, 0]], 343 | [[1, 1, 1, 1, 1, 1], 344 | [1, 1, 1, 1, 1, 1], 345 | [1, 1, 1, 1, 1, 1], 346 | [0, 0, 0, 0, 0, 0], 347 | [0, 0, 0, 0, 0, 0], 348 | [0, 0, 0, 0, 0, 0]], 349 | [[1, 1, 1, 1, 1, 1], 350 | [1, 1, 1, 1, 1, 1], 351 | [0, 0, 0, 0, 0, 0], 352 | [0, 0, 0, 0, 0, 0], 353 | [0, 0, 0, 0, 0, 0], 354 | [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) 355 | >>> make_non_pad_mask(lengths, xs, 2) 356 | tensor([[[1, 1, 1, 1, 1, 0], 357 | [1, 1, 1, 1, 1, 0], 358 | [1, 1, 1, 1, 1, 0], 359 | [1, 1, 1, 1, 1, 0], 360 | [1, 1, 1, 1, 1, 0], 361 | [1, 1, 1, 1, 1, 0]], 362 | [[1, 1, 1, 0, 0, 0], 363 | [1, 1, 1, 0, 0, 0], 364 | [1, 1, 1, 0, 0, 0], 365 | [1, 1, 1, 0, 0, 0], 366 | [1, 1, 1, 0, 0, 0], 367 | [1, 1, 1, 0, 0, 0]], 368 | [[1, 1, 0, 0, 0, 0], 369 | [1, 1, 0, 0, 0, 0], 370 | [1, 1, 0, 0, 0, 0], 371 | [1, 1, 0, 0, 0, 0], 372 | [1, 1, 0, 0, 0, 0], 373 | [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) 374 | 375 | """ 376 | return ~make_pad_mask(lengths, xs, length_dim) 377 | 378 | 379 | def mask_by_length(xs, lengths, fill=0): 380 | """Mask tensor according to length. 381 | 382 | Args: 383 | xs (Tensor): Batch of input tensor (B, `*`). 384 | lengths (LongTensor or List): Batch of lengths (B,). 385 | fill (int or float): Value to fill masked part. 386 | 387 | Returns: 388 | Tensor: Batch of masked input tensor (B, `*`). 389 | 390 | Examples: 391 | >>> x = torch.arange(5).repeat(3, 1) + 1 392 | >>> x 393 | tensor([[1, 2, 3, 4, 5], 394 | [1, 2, 3, 4, 5], 395 | [1, 2, 3, 4, 5]]) 396 | >>> lengths = [5, 3, 2] 397 | >>> mask_by_length(x, lengths) 398 | tensor([[1, 2, 3, 4, 5], 399 | [1, 2, 3, 0, 0], 400 | [1, 2, 0, 0, 0]]) 401 | 402 | """ 403 | assert xs.size(0) == len(lengths) 404 | ret = xs.data.new(*xs.size()).fill_(fill) 405 | for i, l in enumerate(lengths): 406 | ret[i, :l] = xs[i, :l] 407 | return ret 408 | 409 | 410 | def th_accuracy(pad_outputs, pad_targets, ignore_label): 411 | """Calculate accuracy. 412 | 413 | Args: 414 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 415 | pad_targets (LongTensor): Target label tensors (B, Lmax, D). 416 | ignore_label (int): Ignore label id. 417 | 418 | Returns: 419 | float: Accuracy value (0.0 - 1.0). 420 | 421 | """ 422 | pad_pred = pad_outputs.view( 423 | pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) 424 | ).argmax(2) 425 | mask = pad_targets != ignore_label 426 | numerator = torch.sum( 427 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask) 428 | ) 429 | denominator = torch.sum(mask) 430 | return float(numerator) / float(denominator) 431 | 432 | 433 | def to_torch_tensor(x): 434 | """Change to torch.Tensor or ComplexTensor from numpy.ndarray. 435 | 436 | Args: 437 | x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. 438 | 439 | Returns: 440 | Tensor or ComplexTensor: Type converted inputs. 
441 | 442 | Examples: 443 | >>> xs = np.ones(3, dtype=np.float32) 444 | >>> xs = to_torch_tensor(xs) 445 | tensor([1., 1., 1.]) 446 | >>> xs = torch.ones(3, 4, 5) 447 | >>> assert to_torch_tensor(xs) is xs 448 | >>> xs = {'real': xs, 'imag': xs} 449 | >>> to_torch_tensor(xs) 450 | ComplexTensor( 451 | Real: 452 | tensor([1., 1., 1.]) 453 | Imag; 454 | tensor([1., 1., 1.]) 455 | ) 456 | 457 | """ 458 | # If numpy, change to torch tensor 459 | if isinstance(x, np.ndarray): 460 | if x.dtype.kind == "c": 461 | # Dynamically importing because torch_complex requires python3 462 | from torch_complex.tensor import ComplexTensor 463 | 464 | return ComplexTensor(x) 465 | else: 466 | return torch.from_numpy(x) 467 | 468 | # If {'real': ..., 'imag': ...}, convert to ComplexTensor 469 | elif isinstance(x, dict): 470 | # Dynamically importing because torch_complex requires python3 471 | from torch_complex.tensor import ComplexTensor 472 | 473 | if "real" not in x or "imag" not in x: 474 | raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) 475 | # Relative importing because of using python3 syntax 476 | return ComplexTensor(x["real"], x["imag"]) 477 | 478 | # If torch.Tensor, as it is 479 | elif isinstance(x, torch.Tensor): 480 | return x 481 | 482 | else: 483 | error = ( 484 | "x must be numpy.ndarray, torch.Tensor or a dict like " 485 | "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " 486 | "but got {}".format(type(x)) 487 | ) 488 | try: 489 | from torch_complex.tensor import ComplexTensor 490 | except Exception: 491 | # If PY2 492 | raise ValueError(error) 493 | else: 494 | # If PY3 495 | if isinstance(x, ComplexTensor): 496 | return x 497 | else: 498 | raise ValueError(error) 499 | 500 | 501 | def set_deterministic_pytorch(args): 502 | """Ensures pytorch produces deterministic results depending on the program arguments 503 | :param Namespace args: The program arguments 504 | """ 505 | # seed setting 506 | torch.manual_seed(args.seed) 507 | 508 | # debug mode setting 509 | # 0 would be fastest, but 1 seems to be reasonable 510 | # considering reproducibility 511 | # remove type check 512 | torch.backends.cudnn.deterministic = True 513 | torch.backends.cudnn.benchmark = ( 514 | False # https://github.com/pytorch/pytorch/issues/6351 515 | ) 516 | 517 | 518 | def torch_load(path, model): 519 | """Load torch model states. 520 | 521 | Args: 522 | path (str): Model path or snapshot file path to be loaded. 523 | model (torch.nn.Module): Torch model. 524 | 525 | """ 526 | if "snapshot" in path: 527 | model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[ 528 | "model" 529 | ] 530 | else: 531 | model_state_dict = torch.load(path, map_location=lambda storage, loc: storage) 532 | 533 | if hasattr(model, "module"): 534 | model.module.load_state_dict(model_state_dict) 535 | else: 536 | model.load_state_dict(model_state_dict) 537 | 538 | del model_state_dict 539 | 540 | # * -------------------- general -------------------- * 541 | 542 | 543 | def get_model_conf(model_path, conf_path=None): 544 | """Get model config information by reading a model config file (model.json). 545 | 546 | Args: 547 | model_path (str): Model path. 548 | conf_path (str): Optional model config path. 549 | 550 | Returns: 551 | list[int, int, dict[str, Any]]: Config information loaded from json file. 
552 | 553 | """ 554 | if conf_path is None: 555 | model_conf = os.path.dirname(model_path) + "/model.json" 556 | else: 557 | model_conf = conf_path 558 | with open(model_conf, "rb") as f: 559 | logging.info("reading a config file from " + model_conf) 560 | confs = json.load(f) 561 | if isinstance(confs, dict): 562 | # for lm 563 | args = confs 564 | return argparse.Namespace(**args) 565 | else: 566 | # for asr, tts, mt 567 | idim, odim, args = confs 568 | return idim, odim, argparse.Namespace(**args) 569 | 570 | 571 | def get_commit_hash(): 572 | message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) 573 | return message.strip().decode("utf-8") 574 | 575 | 576 | def read_wav_np(path, sample_rate): 577 | sr, wav = read(path) 578 | if sr == sample_rate: 579 | 580 | if len(wav.shape) == 2: 581 | wav = wav[:, 0] 582 | 583 | if wav.dtype == np.int16: 584 | wav = wav / 32768.0 585 | elif wav.dtype == np.int32: 586 | wav = wav / 2147483648.0 587 | elif wav.dtype == np.uint8: 588 | wav = (wav - 128) / 128.0 589 | else: 590 | wav = librosa.load(path, sr=sample_rate)[0] 591 | 592 | wav = wav.astype(np.float32) 593 | 594 | return sr, wav 595 | --------------------------------------------------------------------------------
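Editor's note: utils/util.py mixes plain-Python helpers with @torch.jit.script variants (pad_1d_tensor, pad_2d_tensor, make_pad_mask_script, make_non_pad_mask_script) so the same padding and masking logic can be reused from the TorchScript-exported model. The following is a small, hypothetical example of how the scripted helpers compose when batching variable-length features; the tensors and lengths are made up purely for illustration.

import torch
from utils.util import pad_2d_tensor, make_non_pad_mask_script

# Three fake feature sequences of shape (T_i, n_mels).
feats = [torch.ones(5, 80), torch.ones(3, 80), torch.ones(2, 80)]
lengths = torch.tensor([f.size(0) for f in feats])   # tensor([5, 3, 2])

batch = pad_2d_tensor(feats)               # (3, 5, 80), zero-padded along time
mask = make_non_pad_mask_script(lengths)   # (3, 5), True at valid frames (cf. the docstring example above)
batch = batch * mask.unsqueeze(-1)         # explicitly zero out the padded frames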