├── .gitignore ├── LICENSE ├── README.md ├── config ├── LJSpeech │ └── config.yaml ├── MS_Persian │ └── config.yaml └── Persian │ └── config.yaml ├── inference.py ├── requirements.txt ├── resgrad.PNG ├── resgrad ├── data.py ├── inference.py ├── model │ ├── __init__.py │ ├── base.py │ ├── diffusion.py │ └── optimizer.py ├── train.py └── utils.py ├── resgrad_data.py ├── synthesizer ├── audio │ ├── __init__.py │ ├── audio_processing.py │ ├── stft.py │ └── tools.py ├── dataset.py ├── evaluate.py ├── model │ ├── __init__.py │ ├── fastspeech2.py │ ├── loss.py │ ├── modules.py │ └── optimizer.py ├── prepare_align.py ├── preprocess.py ├── preprocessor │ ├── ljspeech.py │ ├── persian.py │ ├── persian_v1.py │ └── preprocessor.py ├── synthesize.py ├── text │ ├── __init__.py │ ├── cleaners.py │ ├── cmudict.py │ ├── numbers.py │ ├── pinyin.py │ └── symbols.py ├── train.py ├── transformer │ ├── Constants.py │ ├── Layers.py │ ├── Models.py │ ├── Modules.py │ ├── SubLayers.py │ └── __init__.py ├── utils │ ├── model.py │ └── tools.py └── vocoder ├── train_resgrad.py ├── train_synthesizer.py ├── utils.py └── vocoder ├── ckpt └── config.json ├── inference.py ├── models.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | ## Synthesizer 2 | synthesizer/raw_data/* 3 | synthesizer/preprocessed_data/* 4 | synthesizer/*.npy 5 | 6 | ## Vocoder 7 | vocoder/ckpt/g_2500000_persian 8 | vocoder/ckpt/g_2500000 9 | 10 | ## Outputs 11 | output/* 12 | 13 | ## Data 14 | dataset/* 15 | # dataset/Persian/synthesizer_data/* 16 | # dataset/Persian/resgrad_data/* 17 | # dataset/MS_Persian/synthesizer_data/* 18 | # dataset/MS_Persian/resgrad_data/* 19 | # dataset/LJSpeech/synthesizer_data/wavs 20 | # dataset/LJSpeech/resgrad_data/* 21 | 22 | ## Others 23 | auto_inference.py 24 | commands.txt 25 | *__pycache__/ 26 | 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Majid Adibian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ResGrad - PyTorch Implementation 2 | [**ResGrad: Residual Denoising Diffusion Probabilistic Models for Text to Speech**](https://arxiv.org/abs/2212.14518) 3 | 4 | This is an *unofficial* PyTorch implementation of **ResGrad** as a high-quality denoising model for Text to Speech. In short, this model generates the spectrogram using FastSpeech2 and then removes the noise in the spectrogram using the Diffusion method to synthesize high-quality speeches. As mentioned in the paper the implementation is based on FastSpeech2 and Grad-TTS. Also, the HiFiGAN model is used to generate waveforms from synthesized spectrograms. 5 | 6 | 7 | 8 | ## Quickstart 9 | Data structures: 10 | ``` 11 | dataset/data_name/synthesizer_data/ 12 | test_data/ 13 | speaker1/ 14 | sample1.txt 15 | sample1.wav 16 | ... 17 | ... 18 | train_data/ 19 | ... 20 | test.txt (sample1|speaker1|*phoneme_sequence \n ...) 21 | train.txt (sample1|speaker1|*phoneme_sequence \n ...) 22 | ``` 23 | 24 | Preprocessing: 25 | ``` 26 | python synthesizer/prepare_align.py config/data_name/config.yaml 27 | python synthesizer/preprocess.py config/data_name/config.yaml 28 | ``` 29 | 30 | Train synthesizer: 31 | ``` 32 | python train_synthesizer.py --config config/data_name/config.yaml 33 | ``` 34 | 35 | Prepare data for ResGrade: 36 | ``` 37 | python resgrad_data.py --synthesizer_restore_step 1000000 --data_file_path dataset/data_name/synthesizer_data/train.txt \ 38 | --config config/data_name/config.yaml 39 | ``` 40 | 41 | Train ResGrade: 42 | ``` 43 | python train_resgrad.py --config config/data_name/config.yaml 44 | ``` 45 | 46 | Inference: 47 | ``` 48 | python inference.py --text "phonemes sequence example" \ 49 | --synthesizer_restore_step 1000000 --regrad_restore_step 1000000 --vocoder_restore_step 2500000 \ 50 | --config config/data_name/config.yaml --result_dir output/data_name/results 51 | ``` 52 | 53 | ## References :notebook_with_decorative_cover: 54 | - [ResGrad: Residual Denoising Diffusion Probabilistic Models for Text to Speech](https://arxiv.org/abs/2212.14518), Z. Chen, *et al*. 55 | - [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558), Y. Ren, *et al*. 56 | - [Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech](https://arxiv.org/abs/2105.06337), V. Popov, *et al*. 57 | 58 | 59 | ## License 60 | This repository is an implementation of the [**ResGrad**](https://arxiv.org/abs/2212.14518) paper, and is licensed under the MIT License. 61 | -------------------------------------------------------------------------------- /config/LJSpeech/config.yaml: -------------------------------------------------------------------------------- 1 | ############################################################################################ 2 | ####################################### Main Config ######################################## 3 | ############################################################################################ 4 | main: 5 | dataset: &Dataset "LJSpeech" 6 | multi_speaker: False 7 | device: "cuda" ## "cpu" or "cuda" (for using all GPUs) or "cuda:0" or "cuda:1" or ... 
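  ## NOTE: `&Dataset` above is a plain YAML anchor that is re-used below as `*Dataset`.
  ## The `!join` tag used for the path entries is NOT standard YAML; the loader has to
  ## register a custom constructor for it (presumably done inside `load_yaml_file` in the
  ## top-level utils.py). The assumed behaviour is a path join, e.g.
  ##   corpus_path: !join ["dataset/", *Dataset, "synthesizer_data"]   # -> "dataset/LJSpeech/synthesizer_data"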
8 | 9 | 10 | ############################################################################################ 11 | ####################################### Synthesizer ######################################## 12 | ############################################################################################ 13 | synthesizer: 14 | ##################### Model ####################### 15 | model: 16 | transformer: 17 | encoder_layer: 4 18 | encoder_head: 2 19 | encoder_hidden: 256 20 | decoder_layer: 6 21 | decoder_head: 2 22 | decoder_hidden: 256 23 | conv_filter_size: 1024 24 | conv_kernel_size: [9, 1] 25 | encoder_dropout: 0.2 26 | decoder_dropout: 0.2 27 | 28 | variance_predictor: 29 | filter_size: 256 30 | kernel_size: 3 31 | dropout: 0.5 32 | 33 | variance_embedding: 34 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 35 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 36 | n_bins: 256 37 | max_seq_len: 1000 38 | 39 | #################### Preprocesss ##################### 40 | preprocess: 41 | path: 42 | corpus_path: !join ["dataset/", *Dataset, "synthesizer_data"] 43 | raw_path: !join ["synthesizer/raw_data/", *Dataset] 44 | preprocessed_path: !join ["synthesizer/preprocessed_data/", *Dataset] 45 | lexicon_path: !join ["dataset/", *Dataset, "synthesizer_data/librispeech-lexicon.txt"] 46 | preprocessing: 47 | val_size: 100 48 | text: 49 | text_cleaners: "english_cleaner" ## english_cleaners or persian_cleaners 50 | language: "en" ## fa or en 51 | audio: 52 | sampling_rate: 22050 53 | max_wav_value: 32768.0 54 | stft: 55 | filter_length: 1024 56 | hop_length: 256 57 | win_length: 1024 58 | mel: 59 | n_mel_channels: 80 60 | mel_fmin: 0 61 | mel_fmax: 8000 62 | pitch: 63 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 64 | normalization: True 65 | energy: 66 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 67 | normalization: True 68 | 69 | #################### Training ##################### 70 | train: 71 | path: 72 | ckpt_path: !join ["output/", *Dataset, "synthesizer/ckpt"] 73 | log_path: !join ["output/", *Dataset, "synthesizer/log"] 74 | # result_path: !join ["output/", *Dataset, "synthesizer/result"] 75 | optimizer: 76 | batch_size: 16 77 | betas: [0.9, 0.98] 78 | eps: 0.000000001 79 | weight_decay: 0.0 80 | grad_clip_thresh: 1.0 81 | grad_acc_step: 1 82 | warm_up_step: 4000 83 | anneal_steps: [300000, 400000, 500000] 84 | anneal_rate: 0.3 85 | step: 86 | total_step: 1010000 87 | log_step: 500 88 | synth_step: 1000 89 | val_step: 1000 90 | save_step: 100000 91 | 92 | 93 | ############################################################################################ 94 | ########################################### ResGrad ######################################## 95 | ############################################################################################ 96 | resgrad: 97 | #################### Data ##################### 98 | data: 99 | batch_size: 32 100 | metadata_path: !join ["dataset/", *Dataset, "resgrad_data/metadata.csv"] 101 | input_mel_dir: !join ["dataset/", *Dataset, "resgrad_data/input_mel"] 102 | speaker_map_path: !join ["synthesizer/preprocessed_data", *Dataset, "speakers.json"] 103 | val_size: 128 104 | preprocessed_path: "processed_data" 105 | normalized_method: "min-max" 106 | 107 | shuffle_data: True 108 | normallize_spectrum: True 109 | 
min_spec_value: -13 110 | max_spec_value: 3 111 | normallize_residual: True 112 | min_residual_value: -0.25 113 | max_residual_value: 0.25 114 | max_win_length: 100 ## maximum size of window in spectrum 115 | 116 | ################## Training ################### 117 | train: 118 | lr: 0.0001 119 | save_model_path: !join ["output/", *Dataset, "resgrad/ckpt"] 120 | log_dir: !join ["output/", *Dataset, "resgrad/log"] 121 | total_steps: 100000 122 | save_ckpt_step: 10000 123 | validate_step: 100 124 | 125 | # optimizer: 126 | # betas: [0.9, 0.98] 127 | # eps: 0.000000001 128 | # weight_decay: 0.0 129 | # grad_clip_thresh: 1.0 130 | # grad_acc_step: 1 131 | # warm_up_step: 4000 132 | # anneal_steps: [10000, 20000, 30000] 133 | # anneal_rate: 0.3 134 | 135 | ############ Model Parameters ################# 136 | model: 137 | model_type1: "spec2residual" ## "spec2spec" or "spec2residual" 138 | model_type2: "segment-based" ## "segment-based" or "sentence-based" 139 | n_feats: 80 140 | dim: 64 141 | n_spks: 1 142 | spk_emb_dim: 64 143 | beta_min: 0.05 144 | beta_max: 20.0 145 | pe_scale: 1000 146 | 147 | 148 | 149 | ############################################################################################ 150 | ######################################### Vocoder ########################################## 151 | ############################################################################################ 152 | vocoder: 153 | model_name: "" -------------------------------------------------------------------------------- /config/MS_Persian/config.yaml: -------------------------------------------------------------------------------- 1 | ############################################################################################ 2 | ####################################### Main Config ######################################## 3 | ############################################################################################ 4 | main: 5 | dataset: &Dataset "MS_Persian" 6 | multi_speaker: True 7 | device: "cpu" ## "cpu" or "cuda" (for using all GPUs) or "cuda:0" or "cuda:1" or ... 
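  ## NOTE: with `multi_speaker: True` the speaker inventory appears to come from the
  ## `speakers.json` file written under `synthesizer/preprocessed_data/MS_Persian/` during
  ## preprocessing; `load_model` in resgrad/utils.py also sizes `n_spks` from that file,
  ## which is why `n_spks` is left commented out in the `resgrad.model` section below.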
8 | 9 | 10 | ############################################################################################ 11 | ####################################### Synthesizer ######################################## 12 | ############################################################################################ 13 | synthesizer: 14 | ##################### Model ####################### 15 | model: 16 | transformer: 17 | encoder_layer: 4 18 | encoder_head: 2 19 | encoder_hidden: 256 20 | decoder_layer: 6 21 | decoder_head: 2 22 | decoder_hidden: 256 23 | conv_filter_size: 1024 24 | conv_kernel_size: [9, 1] 25 | encoder_dropout: 0.2 26 | decoder_dropout: 0.2 27 | 28 | variance_predictor: 29 | filter_size: 256 30 | kernel_size: 3 31 | dropout: 0.5 32 | 33 | variance_embedding: 34 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 35 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 36 | n_bins: 256 37 | max_seq_len: 1000 38 | 39 | #################### Preprocesss ##################### 40 | preprocess: 41 | path: 42 | corpus_path: !join ["dataset/", *Dataset, "synthesizer_data"] 43 | raw_path: !join ["synthesizer/raw_data/", *Dataset] 44 | preprocessed_path: !join ["synthesizer/preprocessed_data/", *Dataset] 45 | preprocessing: 46 | val_size: 100 47 | text: 48 | text_cleaners: "persian_cleaner" ## english_cleaners or persian_cleaners 49 | language: "fa" ## fa or en 50 | audio: 51 | sampling_rate: 22050 52 | max_wav_value: 32768.0 53 | stft: 54 | filter_length: 1024 55 | hop_length: 256 56 | win_length: 1024 57 | mel: 58 | n_mel_channels: 80 59 | mel_fmin: 0 60 | mel_fmax: 8000 61 | pitch: 62 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 63 | normalization: True 64 | energy: 65 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 66 | normalization: True 67 | 68 | #################### Training ##################### 69 | train: 70 | path: 71 | ckpt_path: !join ["output/", *Dataset, "synthesizer/ckpt"] 72 | log_path: !join ["output/", *Dataset, "synthesizer/log"] 73 | # result_path: !join ["output/", *Dataset, "synthesizer/result"] 74 | optimizer: 75 | batch_size: 16 76 | betas: [0.9, 0.98] 77 | eps: 0.000000001 78 | weight_decay: 0.0 79 | grad_clip_thresh: 1.0 80 | grad_acc_step: 1 81 | warm_up_step: 4000 82 | anneal_steps: [300000, 400000, 500000] 83 | anneal_rate: 0.3 84 | step: 85 | total_step: 2010000 86 | log_step: 500 87 | synth_step: 1000 88 | val_step: 1000 89 | save_step: 100000 90 | 91 | 92 | 93 | ############################################################################################ 94 | ########################################### ResGrad ######################################## 95 | ############################################################################################ 96 | resgrad: 97 | #################### Data ##################### 98 | data: 99 | batch_size: 32 100 | metadata_path: !join ["dataset/", *Dataset, "resgrad_data/metadata.csv"] 101 | input_mel_dir: !join ["dataset/", *Dataset, "resgrad_data/input_mel"] 102 | speaker_map_path: !join ["synthesizer/preprocessed_data", *Dataset, "speakers.json"] 103 | val_size: 16 104 | preprocessed_path: "processed_data" 105 | normalized_method: "min-max" 106 | 107 | shuffle_data: True 108 | normallize_spectrum: True 109 | min_spec_value: -13 110 | max_spec_value: 3 111 | normallize_residual: True 112 | 
min_residual_value: -0.25 113 | max_residual_value: 0.25 114 | max_win_length: 100 ## maximum size of window in spectrum 115 | 116 | ################## Training ################### 117 | train: 118 | lr: 0.0001 119 | total_steps: 200000 120 | validate_step: 200 121 | save_ckpt_step: 10000 122 | save_model_path: !join ["output/", *Dataset, "resgrad/ckpt"] 123 | log_dir: !join ["output/", *Dataset, "resgrad/log"] 124 | 125 | ############ Model Parameters ################# 126 | model: 127 | model_type1: "spec2residual" ## "spec2spec" or "spec2residual" 128 | model_type2: "segment-based" ## "segment-based" or "sentence-based" 129 | n_feats: 80 130 | dim: 64 131 | # n_spks: 1 132 | spk_emb_dim: 64 133 | beta_min: 0.05 134 | beta_max: 20.0 135 | pe_scale: 1000 136 | 137 | 138 | ############################################################################################ 139 | ######################################### Vocoder ########################################## 140 | ############################################################################################ 141 | vocoder: 142 | model_name: "g_2500000_persian" 143 | -------------------------------------------------------------------------------- /config/Persian/config.yaml: -------------------------------------------------------------------------------- 1 | ############################################################################################ 2 | ####################################### Main Config ######################################## 3 | ############################################################################################ 4 | main: 5 | dataset: &Dataset "Persian" 6 | multi_speaker: False 7 | device: "cpu" ## "cpu" or "cuda" (for using all GPUs) or "cuda:0" or "cuda:1" or ... 8 | 9 | 10 | ############################################################################################ 11 | ####################################### Synthesizer ######################################## 12 | ############################################################################################ 13 | synthesizer: 14 | ##################### Model ####################### 15 | model: 16 | transformer: 17 | encoder_layer: 4 18 | encoder_head: 2 19 | encoder_hidden: 256 20 | decoder_layer: 6 21 | decoder_head: 2 22 | decoder_hidden: 256 23 | conv_filter_size: 1024 24 | conv_kernel_size: [9, 1] 25 | encoder_dropout: 0.2 26 | decoder_dropout: 0.2 27 | 28 | variance_predictor: 29 | filter_size: 256 30 | kernel_size: 3 31 | dropout: 0.5 32 | 33 | variance_embedding: 34 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 35 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 36 | n_bins: 256 37 | max_seq_len: 1000 38 | 39 | #################### Preprocesss ##################### 40 | preprocess: 41 | path: 42 | corpus_path: !join ["dataset/", *Dataset, "synthesizer_data"] 43 | raw_path: !join ["synthesizer/raw_data/", *Dataset] 44 | preprocessed_path: !join ["synthesizer/preprocessed_data/", *Dataset] 45 | preprocessing: 46 | val_size: 100 47 | text: 48 | text_cleaners: "persian_cleaner" ## english_cleaners or persian_cleaners 49 | language: "fa" ## fa or en 50 | audio: 51 | sampling_rate: 22050 52 | max_wav_value: 32768.0 53 | stft: 54 | filter_length: 1024 55 | hop_length: 256 56 | win_length: 1024 57 | mel: 58 | n_mel_channels: 80 59 | mel_fmin: 0 60 | mel_fmax: 8000 61 | 
pitch: 62 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 63 | normalization: True 64 | energy: 65 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 66 | normalization: True 67 | 68 | #################### Training ##################### 69 | train: 70 | path: 71 | ckpt_path: !join ["output/", *Dataset, "synthesizer/ckpt"] 72 | log_path: !join ["output/", *Dataset, "synthesizer/log"] 73 | # result_path: !join ["output/", *Dataset, "synthesizer/result"] 74 | optimizer: 75 | batch_size: 16 76 | betas: [0.9, 0.98] 77 | eps: 0.000000001 78 | weight_decay: 0.0 79 | grad_clip_thresh: 1.0 80 | grad_acc_step: 1 81 | warm_up_step: 4000 82 | anneal_steps: [300000, 400000, 500000] 83 | anneal_rate: 0.3 84 | step: 85 | total_step: 1010000 86 | log_step: 500 87 | synth_step: 1000 88 | val_step: 1000 89 | save_step: 100000 90 | 91 | 92 | ############################################################################################ 93 | ########################################### ResGrad ######################################## 94 | ############################################################################################ 95 | resgrad: 96 | #################### Data ##################### 97 | data: 98 | batch_size: 32 99 | metadata_path: !join ["dataset/", *Dataset, "resgrad_data/metadata.csv"] 100 | input_mel_dir: !join ["dataset/", *Dataset, "resgrad_data/input_mel"] 101 | speaker_map_path: !join ["synthesizer/preprocessed_data", *Dataset, "speakers.json"] 102 | val_size: 128 103 | preprocessed_path: "processed_data" 104 | normalized_method: "min-max" 105 | 106 | shuffle_data: True 107 | normallize_spectrum: True 108 | min_spec_value: -13 109 | max_spec_value: 3 110 | normallize_residual: True 111 | min_residual_value: -0.25 112 | max_residual_value: 0.25 113 | max_win_length: 100 ## maximum size of window in spectrum 114 | 115 | ################## Training ################### 116 | train: 117 | lr: 0.0001 118 | save_model_path: !join ["output/", *Dataset, "resgrad/ckpt"] 119 | log_dir: !join ["output/", *Dataset, "resgrad/log"] 120 | total_steps: 100000 121 | save_ckpt_step: 10000 122 | validate_step: 100 123 | 124 | # optimizer: 125 | # betas: [0.9, 0.98] 126 | # eps: 0.000000001 127 | # weight_decay: 0.0 128 | # grad_clip_thresh: 1.0 129 | # grad_acc_step: 1 130 | # warm_up_step: 4000 131 | # anneal_steps: [10000, 20000, 30000] 132 | # anneal_rate: 0.3 133 | 134 | ############ Model Parameters ################# 135 | model: 136 | model_type1: "spec2residual" ## "spec2spec" or "spec2residual" 137 | model_type2: "segment-based" ## "segment-based" or "sentence-based" 138 | n_feats: 80 139 | dim: 64 140 | n_spks: 1 141 | spk_emb_dim: 64 142 | beta_min: 0.05 143 | beta_max: 20.0 144 | pe_scale: 1000 145 | 146 | 147 | ############################################################################################ 148 | ######################################### Vocoder ########################################## 149 | ############################################################################################ 150 | vocoder: 151 | model_name: "g_2500000_persian" 152 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from synthesizer.synthesize import infer as synthesizer_infer 2 | from resgrad.inference import infer as resgrad_infer 3 | from vocoder.inference import infer as vocoder_infer 4 | from utils import load_models, save_result, 
load_yaml_file, get_file_name 5 | 6 | import argparse 7 | import time 8 | 9 | def infer(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--text", type=str, required=True) 12 | parser.add_argument("--speaker_id", type=int, default=0, required=False) 13 | parser.add_argument("--synthesizer_restore_step", type=int, required=True) 14 | parser.add_argument("--regrad_restore_step", type=int, required=False) 15 | parser.add_argument("--vocoder_restore_step", type=int, default=0 ,required=False) 16 | parser.add_argument("--result_dir", type=str, default="output/Persian/results", required=False) 17 | parser.add_argument("--result_file_name", type=str, default="", required=False) 18 | parser.add_argument("--pitch_control", type=float, default=1.0, required=False) 19 | parser.add_argument("--energy_control", type=float, default=1.0, required=False) 20 | parser.add_argument("--duration_control", type=float, default=1.0, required=False) 21 | parser.add_argument("--config", type=str, default='config/Persian/config.yaml', required=False, help="path to config.yaml") 22 | args = parser.parse_args() 23 | 24 | # Read Config 25 | config = load_yaml_file(args.config) 26 | 27 | print("load models...") 28 | restore_steps = {"synthesizer":args.synthesizer_restore_step, "resgrad":args.regrad_restore_step, "vocoder":args.vocoder_restore_step} 29 | synthesizer_model, resgrad_model, vocoder_model = load_models(restore_steps, config) 30 | 31 | ## Synthesizer 32 | control_values = args.pitch_control, args.energy_control, args.duration_control 33 | start_time = time.time() 34 | mel_prediction, duration_prediction, pitch_prediction, energy_prediction = synthesizer_infer(synthesizer_model, args.text, control_values, \ 35 | config['synthesizer']['preprocess'], \ 36 | config['main']['device'], \ 37 | speaker = args.speaker_id) 38 | end_time = time.time() 39 | FastSpeech_process_time = end_time-start_time 40 | 41 | ## Save FastSpeech2 result as wav 42 | wav = vocoder_infer(vocoder_model, mel_prediction, config['synthesizer']['preprocess']["preprocessing"]["audio"]["max_wav_value"]) 43 | print("Save FastSpeech2 result...") 44 | file_name = get_file_name(args) 45 | save_result(mel_prediction, wav, pitch_prediction, energy_prediction, config['synthesizer']['preprocess'], args.result_dir, file_name) 46 | 47 | ## Real-Time factor of FastSpeech2 48 | wav_length = len(wav)/config['synthesizer']['preprocess']["preprocessing"]["audio"]["sampling_rate"] 49 | RTF_FastSpeech = FastSpeech_process_time / wav_length 50 | print("FastSpeech2 RTF: {:.6f}".format(RTF_FastSpeech)) 51 | 52 | ## ResGrad 53 | # print("Inference from ResGrad...") 54 | # start_time = time.time() 55 | # mel_prediction = resgrad_infer(resgrad_model, mel_prediction, duration_prediction, args.speaker_id, config['resgrad'], config['main']['device']) 56 | # end_time = time.time() 57 | # ResGrad_process_time = end_time-start_time 58 | 59 | # ## Vocoder 60 | # wav = vocoder_infer(vocoder_model, mel_prediction, config['synthesizer']['preprocess']["preprocessing"]["audio"]["max_wav_value"]) 61 | 62 | # ## Save result 63 | # print("Save ResGrad result...") 64 | # file_name = file_name.replace("FastSpeech", "ResGrad") 65 | # save_result(mel_prediction.squeeze(), wav, pitch_prediction, energy_prediction, config['synthesizer']['preprocess'], args.result_dir, file_name) 66 | 67 | # ## Real-Time factor of ResGrad 68 | # wav_length = len(wav)/config['synthesizer']['preprocess']["preprocessing"]["audio"]["sampling_rate"] 69 | # RTF_ResGrad = ResGrad_process_time / 
wav_length 70 | # print("ResGrad RTF: {:.6f}".format(RTF_ResGrad)) 71 | 72 | 73 | if __name__ == "__main__": 74 | infer() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | g2p-en==2.1.0 2 | inflect==4.1.0 3 | librosa==0.7.2 4 | matplotlib==3.2.2 5 | numba==0.48 6 | numpy==1.19.0 7 | pypinyin==0.39.0 8 | pyworld==0.2.10 9 | PyYAML==5.4.1 10 | scikit-learn==0.23.2 11 | scipy==1.5.0 12 | soundfile==0.10.3.post1 13 | tensorboard==2.2.2 14 | tgt==1.4.4 15 | torch==1.7.0 16 | tqdm==4.46.1 17 | unidecode==1.1.1 18 | einops==0.3.0 -------------------------------------------------------------------------------- /resgrad.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adibian/ResGrad/a79625a781fd5c3c8d630566b2c72b8f6e011371/resgrad.PNG -------------------------------------------------------------------------------- /resgrad/data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import numpy as np 3 | import torch 4 | import os 5 | import csv 6 | import json 7 | 8 | from .utils import normalize_residual, normalize_data 9 | 10 | class SpectumDataset(Dataset): 11 | def __init__(self, config): 12 | super(SpectumDataset, self).__init__() 13 | self.config = config 14 | with open(config['data']['speaker_map_path']) as f: 15 | self.speaker_map = json.load(f) 16 | 17 | self.input_data_path = [] 18 | self.target_data_path = [] 19 | self.duration_data_path = [] 20 | self.speakers = [] 21 | with open(config['data']['metadata_path'], mode='r') as csv_file: 22 | csv_reader = csv.DictReader(csv_file) 23 | line_count = 0 24 | for row in csv_reader: 25 | line_count += 1 26 | if line_count > 1: 27 | self.input_data_path.append(row['predicted_mel']) 28 | self.target_data_path.append(row['target_mel']) 29 | self.duration_data_path.append(row['duration']) 30 | self.speakers.append(self.speaker_map[row['speaker']]) 31 | 32 | if config['model']['model_type2'] == "segment-based": 33 | self.max_len = config['data']['max_win_length'] 34 | else: 35 | self.max_len = config['data']['spectrum_max_size'] 36 | 37 | def __getitem__(self, index): 38 | input_spec_path = self.input_data_path[index] 39 | input_spec = np.load(input_spec_path) 40 | target_spec_path = self.target_data_path[index] 41 | target_spec = np.load(target_spec_path) 42 | dutarions_path = self.duration_data_path[index] 43 | durations = np.load(dutarions_path) 44 | target_spec = torch.from_numpy(target_spec).T 45 | input_spec = torch.from_numpy(input_spec).squeeze() 46 | if self.config['data']['normallize_spectrum']: 47 | input_spec = normalize_data(input_spec, self.config) 48 | target_spec = normalize_data(target_spec, self.config) 49 | 50 | if self.config['model']['model_type2'] == "segment-based": 51 | start_phoneme_index = np.random.choice(len(durations)-min(4, len(durations)-1), 1)[0] 52 | end_phoneme_index = 0 53 | for i in range(start_phoneme_index+1, len(durations)+1): 54 | win_length = sum(durations[start_phoneme_index:i]) 55 | if win_length > self.max_len: 56 | end_phoneme_index = i-1 57 | break 58 | if end_phoneme_index == 0: 59 | end_phoneme_index = len(durations) 60 | for i in range(start_phoneme_index): 61 | start_phoneme_index -= 1 62 | win_length = sum(durations[start_phoneme_index:end_phoneme_index]) 63 | if win_length > self.max_len: 64 | 
start_phoneme_index += 1 65 | break 66 | win_start = sum(durations[:start_phoneme_index]) 67 | win_end = sum(durations[:end_phoneme_index]) 68 | 69 | input_spec = input_spec[:,win_start:win_end] 70 | target_spec = target_spec[:,win_start:win_end] 71 | 72 | spec_size = input_spec.shape[-1] 73 | input_spec = torch.nn.functional.pad(input_spec, (0, self.max_len-spec_size), mode = "constant", value = 0.0) 74 | target_spec = torch.nn.functional.pad(target_spec, (0, self.max_len-spec_size), mode = "constant", value = 0.0) 75 | 76 | residual_spec = target_spec - input_spec 77 | if self.config['data']['normallize_residual']: 78 | residual_spec = normalize_residual(residual_spec, self.config) 79 | 80 | mask = torch.ones((1, input_spec.shape[-1])) 81 | mask[:,spec_size:] = 0 82 | 83 | speaker = self.speakers[index] 84 | 85 | if self.config['model']['model_type1'] == "spec2residual": 86 | residual_spec = target_spec - input_spec 87 | if self.config['data']['normallize_residual']: 88 | residual_spec = normalize_residual(residual_spec, self.config) 89 | residual_spec = residual_spec*mask 90 | return input_spec, target_spec, residual_spec, mask, speaker 91 | else: 92 | return input_spec, target_spec, mask, speaker 93 | 94 | 95 | def __len__(self): 96 | return len(self.input_data_path) 97 | 98 | 99 | def create_dataset(config): 100 | dataset = SpectumDataset(config) 101 | val_dataset, train_dataset = torch.utils.data.random_split(dataset, [config['data']['val_size'], len(dataset)-(config['data']['val_size'])]) 102 | return DataLoader(train_dataset, batch_size=config['data']['batch_size'], shuffle=config['data']['shuffle_data']), \ 103 | DataLoader(val_dataset, batch_size=config['data']['batch_size'], shuffle=config['data']['shuffle_data']) 104 | -------------------------------------------------------------------------------- /resgrad/inference.py: -------------------------------------------------------------------------------- 1 | from .utils import denormalize_residual, denormalize_data, normalize_data 2 | 3 | import torch 4 | 5 | 6 | def infer(model, mel_prediction, duration_prediction, speaker, config, device): 7 | synthesized_spec = mel_prediction.unsqueeze(0).to(device) 8 | # synthesized_spec = torch.from_numpy(synthesized_spec) 9 | if config['data']['normallize_spectrum']: 10 | synthesized_spec = normalize_data(synthesized_spec, config) 11 | 12 | if config['model']['model_type2'] == "segment-based": 13 | durations = duration_prediction 14 | # durations = torch.round(torch.exp(duration_prediction.squeeze()) - 1) 15 | 16 | all_mask, all_segment_spec, all_start_points, all_spec_size = [], [], [], [] 17 | pred = torch.zeros(synthesized_spec.shape) 18 | 19 | ## Create segments of date exept last segment 20 | start_phoneme_index = 0 21 | end_phoneme_index = 0 22 | for i in range(1, len(durations)+1): 23 | win_length = int(sum(durations[start_phoneme_index:i])) 24 | if win_length > config['data']['max_win_length']: 25 | end_phoneme_index = i-1 26 | start_point = int(sum(durations[:start_phoneme_index])) 27 | end_point = int(sum(durations[:end_phoneme_index])) 28 | segment_spec = synthesized_spec[:,:,start_point:end_point] 29 | all_start_points.append(start_point) 30 | spec_size = segment_spec.shape[-1] 31 | all_spec_size.append(spec_size) 32 | segment_spec = torch.nn.functional.pad(segment_spec, (0, config['data']['max_win_length']-spec_size), mode = "constant", value = 0.0) 33 | mask = torch.ones((1, segment_spec.shape[-1])).to(device) 34 | mask[:,spec_size:] = 0 35 | all_mask.append(mask.unsqueeze(0)) 
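                # The start frame and true (unpadded) length recorded just above, together with
                # the padded segment collected here, are what allow the denoised segments from the
                # batched reverse diffusion below to be stitched back into the full spectrogram.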
36 | all_segment_spec.append(segment_spec) 37 | start_phoneme_index = end_phoneme_index 38 | 39 | ## Create last segment of data with overlapping to last previous segments 40 | start_phoneme_index = len(durations) 41 | end_phoneme_index = len(durations) 42 | for i in range(len(durations)): 43 | start_phoneme_index -= 1 44 | win_length = int(sum(durations[start_phoneme_index:])) 45 | if win_length > config['data']['max_win_length']: 46 | start_phoneme_index += 1 47 | start_point = int(sum(durations[:start_phoneme_index])) 48 | end_point = int(sum(durations[:end_phoneme_index])) 49 | segment_spec = synthesized_spec[:,:,start_point:end_point] 50 | all_start_points.append(start_point) 51 | spec_size = segment_spec.shape[-1] 52 | all_spec_size.append(spec_size) 53 | segment_spec = torch.nn.functional.pad(segment_spec, (0, config['data']['max_win_length']-spec_size), mode = "constant", value = 0.0) 54 | mask = torch.ones((1, segment_spec.shape[-1])).to(device) 55 | mask[:,spec_size:] = 0 56 | all_mask.append(mask.unsqueeze(0)) 57 | all_segment_spec.append(segment_spec) 58 | break 59 | 60 | speakers = [speaker] * len(all_segment_spec) 61 | mask = torch.cat(all_mask).to(device) 62 | segment_spec = torch.cat(all_segment_spec).to(device) 63 | z = segment_spec + torch.randn_like(segment_spec, device=device) / 1.5 64 | segments_pred = torch.zeros(segment_spec.shape) 65 | speakers = torch.tensor(speakers) 66 | batch_size = config['data']['batch_size'] 67 | for i in range(0, segment_spec.shape[0], batch_size): 68 | segments_pred = model(z[i:i+batch_size], mask[i:i+batch_size], segment_spec[i:i+batch_size], \ 69 | n_timesteps=50, stoc=False, spk_id=speakers[i:i+batch_size]) 70 | 71 | for i in range(len(segments_pred)): 72 | segment_pred = segments_pred[i,:,:all_spec_size[i]] 73 | pred[:,:,all_start_points[i]:all_start_points[i]+all_spec_size[i]] = segment_pred 74 | else: 75 | mask = torch.ones(synthesized_spec.shape).to(device) 76 | z = synthesized_spec + torch.randn_like(synthesized_spec, device=device) / 1.5 77 | pred = model(z, mask, synthesized_spec, n_timesteps=100, stoc=False, spk_id=[speaker]) 78 | pred = pred.to(device) 79 | 80 | if config['model']['model_type1'] == "spec2residual": 81 | if config['data']['normallize_residual']: 82 | spec_pred = denormalize_residual(pred, config) + synthesized_spec 83 | else: 84 | spec_pred = pred + synthesized_spec 85 | else: 86 | spec_pred = pred 87 | 88 | if config['data']['normallize_spectrum']: 89 | spec_pred = denormalize_data(spec_pred, config) 90 | 91 | return spec_pred -------------------------------------------------------------------------------- /resgrad/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 8 | 9 | from .diffusion import Diffusion 10 | -------------------------------------------------------------------------------- /resgrad/model/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 
2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | class BaseModule(torch.nn.Module): 14 | def __init__(self): 15 | super(BaseModule, self).__init__() 16 | 17 | @property 18 | def nparams(self): 19 | """ 20 | Returns number of trainable parameters of the module. 21 | """ 22 | num_params = 0 23 | for name, param in self.named_parameters(): 24 | if param.requires_grad: 25 | num_params += np.prod(param.detach().cpu().numpy().shape) 26 | return num_params 27 | 28 | 29 | def relocate_input(self, x: list): 30 | """ 31 | Relocates provided tensors to the same device set for the module. 32 | """ 33 | device = next(self.parameters()).device 34 | for i in range(len(x)): 35 | if isinstance(x[i], torch.Tensor) and x[i].device != device: 36 | x[i] = x[i].to(device) 37 | return x 38 | -------------------------------------------------------------------------------- /resgrad/model/diffusion.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 
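# ---------------------------------------------------------------------------------------
# Note: this module mirrors the Grad-TTS diffusion decoder. In ResGrad the conditioning
# signal `mu` passed around below is the (optionally normalized) FastSpeech2 spectrogram,
# and the network is trained on either the clean spectrogram or the residual
# (target - prediction), depending on `model_type1` in the config. The noise schedule is
#     beta(t) = beta_min + (beta_max - beta_min) * t
# with cumulative noise
#     int_0^t beta(s) ds = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
# (see `get_noise` below), and `reverse_diffusion` integrates the reverse process with
# `n_timesteps` Euler-type steps.
# ---------------------------------------------------------------------------------------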
8 | 9 | import math 10 | import torch 11 | from einops import rearrange 12 | 13 | from .base import BaseModule 14 | 15 | 16 | class Mish(BaseModule): 17 | def forward(self, x): 18 | return x * torch.tanh(torch.nn.functional.softplus(x)) 19 | 20 | 21 | class Upsample(BaseModule): 22 | def __init__(self, dim): 23 | super(Upsample, self).__init__() 24 | self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) 25 | 26 | def forward(self, x): 27 | return self.conv(x) 28 | 29 | 30 | class Downsample(BaseModule): 31 | def __init__(self, dim): 32 | super(Downsample, self).__init__() 33 | self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) 34 | 35 | def forward(self, x): 36 | return self.conv(x) 37 | 38 | 39 | class Rezero(BaseModule): 40 | def __init__(self, fn): 41 | super(Rezero, self).__init__() 42 | self.fn = fn 43 | self.g = torch.nn.Parameter(torch.zeros(1)) 44 | 45 | def forward(self, x): 46 | return self.fn(x) * self.g 47 | 48 | 49 | class Block(BaseModule): 50 | def __init__(self, dim, dim_out, groups=8): 51 | super(Block, self).__init__() 52 | self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, 53 | padding=1), torch.nn.GroupNorm( 54 | groups, dim_out), Mish()) 55 | 56 | def forward(self, x, mask): 57 | output = self.block(x * mask) 58 | return output * mask 59 | 60 | 61 | class ResnetBlock(BaseModule): 62 | def __init__(self, dim, dim_out, time_emb_dim, groups=8): 63 | super(ResnetBlock, self).__init__() 64 | self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 65 | dim_out)) 66 | 67 | self.block1 = Block(dim, dim_out, groups=groups) 68 | self.block2 = Block(dim_out, dim_out, groups=groups) 69 | if dim != dim_out: 70 | self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) 71 | else: 72 | self.res_conv = torch.nn.Identity() 73 | 74 | def forward(self, x, mask, time_emb): 75 | h = self.block1(x, mask) 76 | h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) 77 | h = self.block2(h, mask) 78 | output = h + self.res_conv(x * mask) 79 | return output 80 | 81 | 82 | class LinearAttention(BaseModule): 83 | def __init__(self, dim, heads=4, dim_head=32): 84 | super(LinearAttention, self).__init__() 85 | self.heads = heads 86 | hidden_dim = dim_head * heads 87 | self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 88 | self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) 89 | 90 | def forward(self, x): 91 | b, c, h, w = x.shape 92 | qkv = self.to_qkv(x) 93 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', 94 | heads = self.heads, qkv=3) 95 | k = k.softmax(dim=-1) 96 | context = torch.einsum('bhdn,bhen->bhde', k, v) 97 | out = torch.einsum('bhde,bhdn->bhen', context, q) 98 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', 99 | heads=self.heads, h=h, w=w) 100 | return self.to_out(out) 101 | 102 | 103 | class Residual(BaseModule): 104 | def __init__(self, fn): 105 | super(Residual, self).__init__() 106 | self.fn = fn 107 | 108 | def forward(self, x, *args, **kwargs): 109 | output = self.fn(x, *args, **kwargs) + x 110 | return output 111 | 112 | 113 | class SinusoidalPosEmb(BaseModule): 114 | def __init__(self, dim): 115 | super(SinusoidalPosEmb, self).__init__() 116 | self.dim = dim 117 | 118 | def forward(self, x, scale=1000): 119 | device = x.device 120 | half_dim = self.dim // 2 121 | emb = math.log(10000) / (half_dim - 1) 122 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 123 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) 124 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 125 | return emb 126 | 127 | 128 
| class GradLogPEstimator2d(BaseModule): 129 | def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, 130 | n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): 131 | super(GradLogPEstimator2d, self).__init__() 132 | self.dim = dim 133 | self.dim_mults = dim_mults 134 | self.groups = groups 135 | self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1 136 | self.spk_emb_dim = spk_emb_dim 137 | self.pe_scale = pe_scale 138 | 139 | if n_spks > 1: 140 | self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), 141 | torch.nn.Linear(spk_emb_dim * 4, n_feats)) 142 | self.time_pos_emb = SinusoidalPosEmb(dim) 143 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), 144 | torch.nn.Linear(dim * 4, dim)) 145 | 146 | dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] 147 | in_out = list(zip(dims[:-1], dims[1:])) 148 | self.downs = torch.nn.ModuleList([]) 149 | self.ups = torch.nn.ModuleList([]) 150 | num_resolutions = len(in_out) 151 | 152 | for ind, (dim_in, dim_out) in enumerate(in_out): 153 | is_last = ind >= (num_resolutions - 1) 154 | self.downs.append(torch.nn.ModuleList([ 155 | ResnetBlock(dim_in, dim_out, time_emb_dim=dim), 156 | ResnetBlock(dim_out, dim_out, time_emb_dim=dim), 157 | Residual(Rezero(LinearAttention(dim_out))), 158 | Downsample(dim_out) if not is_last else torch.nn.Identity()])) 159 | 160 | mid_dim = dims[-1] 161 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 162 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) 163 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 164 | 165 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 166 | self.ups.append(torch.nn.ModuleList([ 167 | ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), 168 | ResnetBlock(dim_in, dim_in, time_emb_dim=dim), 169 | Residual(Rezero(LinearAttention(dim_in))), 170 | Upsample(dim_in)])) 171 | self.final_block = Block(dim, dim) 172 | self.final_conv = torch.nn.Conv2d(dim, 1, 1) 173 | 174 | def forward(self, x, mask, mu, t, spk=None): 175 | if not isinstance(spk, type(None)): 176 | s = self.spk_mlp(spk) 177 | 178 | t = self.time_pos_emb(t, scale=self.pe_scale) 179 | t = self.mlp(t) 180 | 181 | if self.n_spks < 2: 182 | x = torch.stack([mu, x], 1) 183 | else: 184 | s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1]) 185 | x = torch.stack([mu, x, s], 1) 186 | mask = mask.unsqueeze(1) 187 | 188 | hiddens = [] 189 | masks = [mask] 190 | for resnet1, resnet2, attn, downsample in self.downs: 191 | mask_down = masks[-1] 192 | x = resnet1(x, mask_down, t) 193 | x = resnet2(x, mask_down, t) 194 | x = attn(x) 195 | hiddens.append(x) 196 | x = downsample(x * mask_down) 197 | masks.append(mask_down[:, :, :, ::2]) 198 | 199 | masks = masks[:-1] 200 | mask_mid = masks[-1] 201 | x = self.mid_block1(x, mask_mid, t) 202 | x = self.mid_attn(x) 203 | x = self.mid_block2(x, mask_mid, t) 204 | 205 | for resnet1, resnet2, attn, upsample in self.ups: 206 | mask_up = masks.pop() 207 | x = torch.cat((x, hiddens.pop()), dim=1) 208 | x = resnet1(x, mask_up, t) 209 | x = resnet2(x, mask_up, t) 210 | x = attn(x) 211 | x = upsample(x * mask_up) 212 | 213 | x = self.final_block(x, mask) 214 | output = self.final_conv(x * mask) 215 | 216 | return (output * mask).squeeze(1) 217 | 218 | 219 | def get_noise(t, beta_init, beta_term, cumulative=False): 220 | if cumulative: 221 | noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2) 222 | else: 223 | noise = beta_init + (beta_term - beta_init)*t 224 | 
return noise 225 | 226 | 227 | class Diffusion(BaseModule): 228 | def __init__(self, n_feats, dim, 229 | n_spks=1, spk_emb_dim=64, 230 | beta_min=0.05, beta_max=20, pe_scale=1000): 231 | super(Diffusion, self).__init__() 232 | self.n_feats = n_feats 233 | self.dim = dim 234 | self.n_spks = n_spks 235 | self.spk_emb_dim = spk_emb_dim 236 | self.beta_min = beta_min 237 | self.beta_max = beta_max 238 | self.pe_scale = pe_scale 239 | if n_spks>1: 240 | self.spk_emb = torch.nn.Embedding( 241 | n_spks, 242 | spk_emb_dim 243 | ) 244 | 245 | self.estimator = GradLogPEstimator2d(dim, n_spks=n_spks, 246 | spk_emb_dim=spk_emb_dim, 247 | pe_scale=pe_scale, n_feats=n_feats) 248 | 249 | def forward_diffusion(self, x0, mask, mu, t): 250 | time = t.unsqueeze(-1).unsqueeze(-1) 251 | cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True) 252 | mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise)) 253 | variance = 1.0 - torch.exp(-cum_noise) 254 | z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, 255 | requires_grad=False) 256 | xt = mean + z * torch.sqrt(variance) 257 | return xt * mask, z * mask 258 | 259 | @torch.no_grad() 260 | def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None): 261 | h = 1.0 / n_timesteps 262 | xt = z * mask 263 | for i in range(n_timesteps): 264 | t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, 265 | device=z.device) 266 | time = t.unsqueeze(-1).unsqueeze(-1) 267 | noise_t = get_noise(time, self.beta_min, self.beta_max, 268 | cumulative=False) 269 | if stoc: # adds stochastic term 270 | dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk) 271 | dxt_det = dxt_det * noise_t * h 272 | dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device, 273 | requires_grad=False) 274 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h) 275 | dxt = dxt_det + dxt_stoc 276 | else: 277 | dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk)) 278 | dxt = dxt * noise_t * h 279 | xt = (xt - dxt) * mask 280 | return xt 281 | 282 | @torch.no_grad() 283 | def forward(self, z, mask, mu, n_timesteps, stoc=False, spk_id=None): 284 | if self.n_spks > 1: 285 | spk = self.spk_emb(spk_id) 286 | return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk) 287 | 288 | def loss_t(self, x0, mask, mu, t, spk_id=None): 289 | xt, z = self.forward_diffusion(x0, mask, mu, t) 290 | time = t.unsqueeze(-1).unsqueeze(-1) 291 | cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True) 292 | if self.n_spks > 1: 293 | spk = self.spk_emb(spk_id) 294 | noise_estimation = self.estimator(xt, mask, mu, t, spk) 295 | noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise)) 296 | loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feats) 297 | return loss, xt 298 | 299 | def compute_loss(self, x0, mask, mu, spk_id=None, offset=1e-5): 300 | t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device, 301 | requires_grad=False) 302 | t = torch.clamp(t, offset, 1.0 - offset) 303 | return self.loss_t(x0, mask, mu, t, spk_id) 304 | -------------------------------------------------------------------------------- /resgrad/model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim: 6 | """ A simple wrapper class for learning rate scheduling """ 7 | 8 | def __init__(self, model, config, current_step): 9 | 10 | self._optimizer = torch.optim.Adam( 11 | model.parameters(), 12 | 
betas=config["optimizer"]["betas"], 13 | eps=config["optimizer"]["eps"], 14 | weight_decay=config["optimizer"]["weight_decay"], 15 | ) 16 | self.n_warmup_steps = config["optimizer"]["warm_up_step"] 17 | self.anneal_steps = config["optimizer"]["anneal_steps"] 18 | self.anneal_rate = config["optimizer"]["anneal_rate"] 19 | self.current_step = current_step 20 | self.init_lr = np.power(config["model"]["dim"]*4, -0.5) 21 | 22 | def step_and_update_lr(self): 23 | self._update_learning_rate() 24 | self._optimizer.step() 25 | 26 | def zero_grad(self): 27 | self._optimizer.zero_grad() 28 | 29 | def load_state_dict(self, path): 30 | self._optimizer.load_state_dict(path) 31 | 32 | def _get_lr_scale(self): 33 | lr = np.min( 34 | [ 35 | np.power(self.current_step, -0.5), 36 | np.power(self.n_warmup_steps, -1.5) * self.current_step, 37 | ] 38 | ) 39 | for s in self.anneal_steps: 40 | if self.current_step > s: 41 | lr = lr * self.anneal_rate 42 | return lr 43 | 44 | def _update_learning_rate(self): 45 | """ Learning rate scheduling per step """ 46 | self.current_step += 1 47 | lr = self.init_lr * self._get_lr_scale() 48 | 49 | for param_group in self._optimizer.param_groups: 50 | param_group["lr"] = lr 51 | -------------------------------------------------------------------------------- /resgrad/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | from tqdm import tqdm 4 | import os 5 | 6 | from .utils import plot_tensor, denormalize_residual, load_model 7 | from .data import create_dataset 8 | 9 | 10 | def logging(logger, config, original_spec, synthesized_spec, target_residual, noisy_spec, pred, mask, title, step): 11 | zero_indexes = (mask == 0).nonzero() 12 | if len(zero_indexes): 13 | start_zero_index = int(zero_indexes[0][-1]) 14 | else: 15 | start_zero_index = original_spec.shape[-1] 16 | original_spec = original_spec[:,:start_zero_index] 17 | synthesized_spec = synthesized_spec[:,:start_zero_index] 18 | noisy_spec = noisy_spec[:,:start_zero_index] 19 | if config['model']['model_type1'] == "spec2residual": 20 | target_residual = target_residual[:,:start_zero_index] 21 | pred_residual = pred[:,:start_zero_index] 22 | if config['data']['normallize_residual']: 23 | pred_spec = denormalize_residual(pred[:,:start_zero_index], config) + synthesized_spec 24 | else: 25 | pred_spec = pred[:,:start_zero_index] + synthesized_spec 26 | logger.add_image(f'{title}/target_residual_spec', plot_tensor(target_residual.squeeze().cpu().detach().numpy(), "residual", config), global_step=step, dataformats='HWC') 27 | logger.add_image(f'{title}/predicted_residual_spec', plot_tensor(pred_residual.squeeze().cpu().detach().numpy(), "residual", config), global_step=step, dataformats='HWC') 28 | else: 29 | pred_spec = pred[:,:start_zero_index] 30 | 31 | logger.add_image(f'{title}/input_spec', plot_tensor(synthesized_spec.squeeze().cpu().detach().numpy(), "spectrum", config), global_step=step, dataformats='HWC') 32 | logger.add_image(f'{title}/predicted_spec', plot_tensor(pred_spec.squeeze().cpu().detach().numpy(), "spectrum", config), global_step=step, dataformats='HWC') 33 | logger.add_image(f'{title}/target_spec', plot_tensor(original_spec.squeeze().cpu().detach().numpy(), "spectrum", config), global_step=step, dataformats='HWC') 34 | logger.add_image(f'{title}/noisy_spec', plot_tensor(noisy_spec.squeeze().cpu().detach().numpy(), "noisy_spectrum", config), global_step=step, dataformats='HWC') 35 | 36 | 37 | def 
resgrad_train(args, config): 38 | os.makedirs(config['train']['log_dir'], exist_ok=True) 39 | os.makedirs(config['train']['save_model_path'], exist_ok=True) 40 | 41 | device = config['main']['device'] 42 | 43 | print('Initializing logger...') 44 | logger = SummaryWriter(log_dir=config['train']['log_dir']) 45 | print("Load data...") 46 | train_dataset, val_dataset = create_dataset(config) 47 | 48 | print("Load model...") 49 | model, optimizer = load_model(config, train=True, restore_model_step=args.restore_step) 50 | 51 | scaler = torch.cuda.amp.GradScaler() 52 | # grad_acc_step = config["optimizer"]["grad_acc_step"] 53 | # grad_clip_thresh = config["optimizer"]["grad_clip_thresh"] 54 | 55 | step = args.restore_step - 1 56 | epoch = args.restore_step // (len(train_dataset)//config['data']['batch_size'] + 1) 57 | avg_val_loss = 0 58 | avg_train_loss = 0 59 | 60 | print("Start training...") 61 | outer_bar = tqdm(total=config['train']['total_steps'], desc="Total Training", position=0) 62 | outer_bar.n = step 63 | outer_bar.update() 64 | 65 | while True: 66 | inner_bar = tqdm(total=len(train_dataset), desc="Epoch {}".format(epoch), position=1) 67 | train_loss_list = [] 68 | epoch += 1 69 | for train_data in train_dataset: 70 | step += 1 71 | inner_bar.update(1) 72 | outer_bar.update(1) 73 | if config['model']['model_type1'] == "spec2residual": 74 | synthesized_spec, original_spec, residual_spec, mask, speakers = train_data 75 | synthesized_spec = synthesized_spec.to(device) 76 | mask = mask.to(device) 77 | residual_spec = residual_spec.to(device) 78 | if config['main']['multi_speaker']: 79 | speakers = speakers.to(device) 80 | loss, pred = model.compute_loss(residual_spec, mask, synthesized_spec, speakers) 81 | else: 82 | loss, pred = model.compute_loss(residual_spec, mask, synthesized_spec) 83 | 84 | else: 85 | synthesized_spec, original_spec, mask, speakers = train_data 86 | mask = mask.to(device) 87 | synthesized_spec = synthesized_spec.to(device) 88 | original_spec = original_spec.to(device) 89 | if config['main']['multi_speaker']: 90 | speakers = speakers.to(device) 91 | loss, pred = model.compute_loss(original_spec, mask, synthesized_spec, speakers) 92 | else: 93 | loss, pred = model.compute_loss(original_spec, mask, synthesized_spec) 94 | 95 | optimizer.zero_grad() 96 | scaler.scale(loss).backward() 97 | scaler.step(optimizer) 98 | scaler.update() 99 | train_loss_list.append(loss.item()) 100 | 101 | # loss.backward() 102 | # train_loss_list.append(loss.item()) 103 | # if step % grad_acc_step == 0: 104 | # # Clipping gradients to avoid gradient explosion 105 | # torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh) 106 | # # Update weights 107 | # optimizer.step_and_update_lr() 108 | # optimizer.zero_grad() 109 | 110 | if step % config['train']['validate_step'] == 0: 111 | model.eval() 112 | with torch.no_grad(): 113 | all_val_loss = [] 114 | val_num = 0 115 | for val_data in val_dataset: 116 | val_num += 1 117 | if config['model']['model_type1'] == "spec2residual": 118 | synthesized_spec, original_spec, target_residual, mask, speakers = val_data 119 | synthesized_spec = synthesized_spec.to(device) 120 | mask = mask.to(device) 121 | target_residual = target_residual.to(device) 122 | if config['main']['multi_speaker']: 123 | speakers = speakers.to(device) 124 | val_loss, noisy_spec = model.compute_loss(target_residual, mask, synthesized_spec, speakers) 125 | else: 126 | val_loss, noisy_spec = model.compute_loss(target_residual, mask, synthesized_spec) 127 | 128 | else: 129 | 
synthesized_spec, original_spec, mask, speakers = val_data 130 | synthesized_spec = synthesized_spec.to(device) 131 | mask = mask.to(device) 132 | original_spec = original_spec.to(device) 133 | if config['main']['multi_speaker']: 134 | speakers = speakers.to(device) 135 | val_loss, noisy_spec = model.compute_loss(original_spec, mask, synthesized_spec, speakers) 136 | else: 137 | val_loss, noisy_spec = model.compute_loss(original_spec, mask, synthesized_spec) 138 | target_residual = [None for _ in range(len(original_spec))] 139 | all_val_loss.append(val_loss.item()) 140 | 141 | ## logging result spectrums 142 | if val_num == 1: 143 | z = synthesized_spec + torch.randn_like(synthesized_spec, device=device) / 1.5 144 | # Generate sample by performing reverse dynamics 145 | if config['main']['multi_speaker']: 146 | pred = model(z, mask, synthesized_spec, n_timesteps=50, stoc=False, spk_id=speakers) 147 | else: 148 | pred = model(z, mask, synthesized_spec, n_timesteps=50, stoc=False, spk_id=None) 149 | for i in range(3): 150 | logging(logger, config, original_spec[i], synthesized_spec[i], target_residual[i], noisy_spec[i], pred[i], mask[i], \ 151 | f'image{i}_step{step}', step) 152 | 153 | avg_val_loss = sum(all_val_loss) / len(all_val_loss) 154 | logger.add_scalar('validation/loss', avg_val_loss, global_step=step) 155 | avg_train_loss = sum(train_loss_list) / len(train_loss_list) 156 | logger.add_scalar('training/loss', avg_train_loss, global_step=step) 157 | train_loss_list = [] 158 | model.train() 159 | 160 | inner_bar.set_postfix(train_loss=avg_train_loss, val_loss=avg_val_loss) 161 | 162 | if step % config['train']['save_ckpt_step'] == 0: 163 | ## Save checkpoints 164 | torch.save( 165 | { 166 | "model": model.state_dict(), 167 | "optimizer": optimizer.state_dict() 168 | # "optimizer": optimizer._optimizer.state_dict(), 169 | }, 170 | os.path.join(config['train']['save_model_path'], f'ResGrad_step{step}.pth') 171 | ) 172 | # torch.save(model.state_dict(), os.path.join(config['train']['save_model_path'], f'ResGrad_step{step}.pth')) 173 | # torch.save(optimizer.state_dict(), os.path.join(config['train']['save_model_path'], 'optimizer.pth')) 174 | 175 | if step > config['train']['total_steps']: 176 | quit() 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /resgrad/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import os 5 | import json 6 | from .model import Diffusion 7 | from .model.optimizer import ScheduledOptim 8 | 9 | def save_figure_to_numpy(fig): 10 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 11 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 12 | return data 13 | 14 | def plot_tensor(tensor, tensor_name, config): 15 | plt.style.use('default') 16 | fig, ax = plt.subplots(figsize=(8, 3)) 17 | if tensor_name == "spectrum" and config['data']['normallize_spectrum']: 18 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none', \ 19 | vmin = 0.2, vmax = 0.9) 20 | elif tensor_name == "residual" and config['data']['normallize_residual']: 21 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none', \ 22 | vmin = 0.2, vmax = 0.9) 23 | else: 24 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none') 25 | plt.colorbar(im, ax=ax) 26 | plt.tight_layout() 27 | fig.canvas.draw() 28 | data = save_figure_to_numpy(fig) 
29 | plt.close() 30 | return data 31 | 32 | def plot_spectrum(spec, path): 33 | plt.figure(figsize=(10, 3)) 34 | im = plt.imshow(spec, aspect="auto", origin="lower", interpolation='none') 35 | plt.colorbar(im) 36 | plt.tight_layout() 37 | plt.savefig(path) 38 | plt.close() 39 | 40 | def crop_masked_values(mat_list, length): 41 | new_mat_list = [] 42 | for mat in mat_list: 43 | new_mat_list.append(mat[:,:length]) 44 | return new_mat_list 45 | 46 | def normalize_data(data, config): 47 | if config['data']['normalized_method'] == "min-max": 48 | data = (data - config['data']['min_spec_value'])/(config['data']['max_spec_value'] - config['data']['min_spec_value']) 49 | else: 50 | print("normalized method is not supported!") 51 | return data 52 | 53 | def normalize_residual(residual_spec, config): 54 | if config['data']['normalized_method'] == "min-max": 55 | residual_spec = (residual_spec - config['data']['min_residual_value'])/(config['data']['max_residual_value'] - \ 56 | config['data']['min_residual_value']) 57 | else: 58 | print("normalized method is not supported!") 59 | return residual_spec 60 | 61 | def denormalize_data(data, config): 62 | if config['data']['normalized_method'] == "min-max": 63 | data = data * (config['data']['max_spec_value'] - config['data']['min_spec_value']) + config['data']['min_spec_value'] 64 | else: 65 | print("normalized method is not supported!") 66 | return data 67 | 68 | def denormalize_residual(residual_spec, config): 69 | if config['data']['normalized_method'] == "min-max": 70 | residual_spec = residual_spec * (config['data']['max_residual_value'] - config['data']['min_residual_value']) + \ 71 | config['data']['min_residual_value'] 72 | else: 73 | print("normalized method is not supported!") 74 | return residual_spec 75 | 76 | def load_model(config, device, train=False, restore_model_step=0): 77 | with open(config['data']['speaker_map_path']) as f: 78 | speaker_map = json.load(f) 79 | n_spks = len(speaker_map.keys()) 80 | model = Diffusion(n_feats=config['model']['n_feats'], dim=config['model']['dim'], n_spks=n_spks, \ 81 | spk_emb_dim=config['model']['spk_emb_dim'], beta_min=config['model']['beta_min'], \ 82 | beta_max=config['model']['beta_max'], pe_scale=config['model']['pe_scale']) 83 | model = model.to(device) 84 | if restore_model_step > 0: 85 | checkpoint = torch.load(os.path.join(config['train']['save_model_path'], f'ResGrad_step{restore_model_step}.pth'), \ 86 | map_location=device) 87 | # checkpoint = torch.load(os.path.join("/mnt/hdd1/adibian/FastSpeech2/ResGrad/output/Persian/dur_taget_pitch_target/resgrad/ckpt", \ 88 | # f'ResGrad_epoch{restore_model_epoch}.pth'), \ 89 | # map_location=device) 90 | model.load_state_dict(checkpoint['model']) 91 | 92 | if train: 93 | # scheduled_optim = ScheduledOptim(model, config, restore_model_step) 94 | optimizer = torch.optim.Adam(params=model.parameters(), lr=config['train']['lr']) 95 | if restore_model_step > 0: 96 | optimizer_state = torch.load(os.path.join(config['train']['save_model_path'], 'optimizer.pth')) 97 | # scheduled_optim.load_state_dict(checkpoint['optimizer']) 98 | optimizer.load_state_dict(optimizer_state) 99 | model.train() 100 | return model, optimizer 101 | 102 | model.eval() 103 | return model -------------------------------------------------------------------------------- /resgrad_data.py: -------------------------------------------------------------------------------- 1 | from utils import load_models, load_yaml_file 2 | from synthesizer.synthesize import infer as synthesizer_infer 3 | 
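# This script teacher-forces FastSpeech2 with ground-truth durations and pitch to
# synthesize a mel-spectrogram for every training utterance, saves the predictions,
# and writes a metadata CSV pairing each predicted mel with its ground-truth mel so
# that ResGrad can later be trained on the residual between the two.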
4 | import argparse 5 | import os 6 | import csv 7 | from tqdm import tqdm 8 | import numpy as np 9 | import torch 10 | import json 11 | 12 | def read_input_data(data_file_path): 13 | input_texts = {} 14 | with open(data_file_path, mode='r') as f: 15 | lines = f.readlines() 16 | for line in lines: 17 | fields = line.split("|") 18 | file_name, speaker, input_text = fields[0], fields[1], fields[2] 19 | input_texts[(speaker, file_name)] = input_text.strip() 20 | return input_texts 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--synthesizer_restore_step", type=int, required=True) 25 | parser.add_argument("--data_file_path", type=str, default="dataset/Persian/synthesizer_data/train.txt", required=False) 26 | parser.add_argument("--config", type=str, default='config/Persian/config.yaml', required=False, help="path to config.yaml") 27 | args = parser.parse_args() 28 | 29 | # Read Config 30 | config = load_yaml_file(args.config) 31 | 32 | restore_steps = {"synthesizer":args.synthesizer_restore_step, "resgrad":None, "vocoder":None} 33 | synthesizer_model, _, _ = load_models(restore_steps, config) 34 | print("Load input data...") 35 | text_data = read_input_data(args.data_file_path) 36 | print("{} inputs data is loaded.".format(len(text_data))) 37 | 38 | with open(os.path.join(config['synthesizer']['preprocess']['path']['preprocessed_path'], "speakers.json")) as f: 39 | speaker_map = json.load(f) 40 | 41 | mel_pred_dir = os.path.join(config['resgrad']['data']['input_mel_dir']) 42 | os.makedirs(mel_pred_dir, exist_ok=True) 43 | 44 | resgrad_data = [] 45 | device = config['main']['device'] 46 | i = 0 47 | for (speaker, file_name), text in tqdm(text_data.items()): 48 | # i +=1 49 | # if i>30: 50 | # break 51 | dur_file_name = speaker + "-duration-" + file_name + ".npy" 52 | pitch_file_name = speaker + "-pitch-" + file_name + ".npy" 53 | dur_path = os.path.join(config['synthesizer']['preprocess']['path']['preprocessed_path'], 'duration', dur_file_name) 54 | pitch_path = os.path.join(config['synthesizer']['preprocess']['path']['preprocessed_path'], 'pitch', pitch_file_name) 55 | 56 | if not os.path.exists(dur_path): 57 | continue 58 | 59 | mel_target_file_name = speaker + "-mel-" + file_name + ".npy" 60 | mel_target_path = os.path.join(config['synthesizer']['preprocess']['path']['preprocessed_path'], 'mel', mel_target_file_name) 61 | 62 | ### Synthersize mel-spectrum and save as data for resgrad 63 | dur_target = torch.from_numpy(np.load(dur_path)).to(device).unsqueeze(0) 64 | pitch_target = torch.from_numpy(np.load(pitch_path)).to(device).unsqueeze(0) 65 | control_values = 1.0,1.0,1.0 66 | mel_prediction, _, _, _ = synthesizer_infer(synthesizer_model, text, control_values, config['synthesizer']['preprocess'], 67 | device, speaker=speaker_map[speaker], d_target=dur_target, p_target=pitch_target) 68 | 69 | mel_pred_path = os.path.join(mel_pred_dir, speaker + "-pred_mel-" + file_name + ".npy") 70 | np.save(mel_pred_path, mel_prediction.cpu()) 71 | 72 | resgrad_data.append({'speaker': speaker, 'predicted_mel':mel_pred_path, 'target_mel':mel_target_path, 'duration':dur_path}) 73 | 74 | with open(config['resgrad']['data']['metadata_path'], 'w') as file: 75 | fields = ['speaker', 'predicted_mel', 'target_mel', 'duration'] 76 | writer = csv.DictWriter(file, fieldnames = fields) 77 | writer.writeheader() 78 | writer.writerows(resgrad_data) 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | # python resgrad_data.py --synthesizer_restore_step 240 84 | 
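# Illustrative sketch (an assumption about downstream usage, not part of the pipeline):
# how the metadata CSV written above could be read back to build ResGrad training pairs.
# The real loading, padding and normalization live in resgrad/data.py; this only shows
# the pairing idea. It reuses the csv/numpy imports at the top of this file, and
# "metadata_path" stands for whatever config['resgrad']['data']['metadata_path'] points to.
def _example_load_resgrad_pairs(metadata_path):
    pairs = []
    with open(metadata_path) as f:
        for row in csv.DictReader(f):
            pred_mel = np.load(row["predicted_mel"])   # mel predicted by FastSpeech2
            target_mel = np.load(row["target_mel"])    # ground-truth mel from preprocessing
            # ResGrad is trained to model the gap between these two spectrograms
            # (any transposition/cropping is handled in resgrad/data.py).
            pairs.append((row["speaker"], pred_mel, target_mel))
    return pairs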
-------------------------------------------------------------------------------- /synthesizer/audio/__init__.py: -------------------------------------------------------------------------------- 1 | import audio.tools 2 | import audio.stft 3 | import audio.audio_processing 4 | -------------------------------------------------------------------------------- /synthesizer/audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return torch.log(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /synthesizer/audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from 
librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data.cuda(), 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum.cuda() if 
magnitude.is_cuda else window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes): 152 | output = dynamic_range_compression(magnitudes) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1 170 | assert torch.max(y.data) <= 1 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | return mel_output, energy 179 | -------------------------------------------------------------------------------- /synthesizer/audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | from librosa.filters import mel as librosa_mel_fn 5 | 6 | from audio.audio_processing import griffin_lim 7 | 8 | 9 | def get_mel_from_wav(audio, _stft): 10 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 11 | audio = torch.autograd.Variable(audio, requires_grad=False) 12 | melspec, energy = _stft.mel_spectrogram(audio) 13 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 14 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 15 | 16 | return melspec, energy 17 | 18 | 19 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 20 | mel = torch.stack([mel]) 21 | mel_decompress = _stft.spectral_de_normalize(mel) 22 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 23 | spec_from_mel_scaling = 1000 24 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 25 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 26 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 27 | 28 | audio = griffin_lim( 29 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), 
_stft._stft_fn, griffin_iters 30 | ) 31 | 32 | audio = audio.squeeze() 33 | audio = audio.cpu().numpy() 34 | audio_path = out_filename 35 | write(audio_path, _stft.sampling_rate, audio) 36 | 37 | 38 | mel_basis = {} 39 | hann_window = {} 40 | 41 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 42 | return torch.log(torch.clamp(x, min=clip_val) * C) 43 | def spectral_normalize_torch(magnitudes): 44 | output = dynamic_range_compression_torch(magnitudes) 45 | return output 46 | 47 | def get_mel_from_wav_as_hifigan(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 48 | if torch.min(y) < -1.: 49 | print('min value is ', torch.min(y)) 50 | if torch.max(y) > 1.: 51 | print('max value is ', torch.max(y)) 52 | 53 | global mel_basis, hann_window 54 | if fmax not in mel_basis: 55 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 56 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 57 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 58 | 59 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 60 | y = y.squeeze(1) 61 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 62 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 63 | 64 | magnitudes = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 65 | 66 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], magnitudes) 67 | spec = spectral_normalize_torch(spec) 68 | 69 | energy = torch.norm(magnitudes, dim=1) 70 | return spec.squeeze(0).numpy(), energy.squeeze(0).numpy() 71 | 72 | -------------------------------------------------------------------------------- /synthesizer/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | from .text import text_to_sequence 7 | from .utils.tools import pad_1D, pad_2D 8 | 9 | 10 | class Dataset(Dataset): 11 | def __init__( 12 | self, filename, preprocess_config, train_config, sort=False, drop_last=False 13 | ): 14 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"] 15 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 16 | self.batch_size = train_config["optimizer"]["batch_size"] 17 | 18 | self.basename, self.speaker, self.text = self.process_meta( 19 | filename 20 | ) 21 | with open(os.path.join(self.preprocessed_path, "speakers.json")) as f: 22 | self.speaker_map = json.load(f) 23 | self.sort = sort 24 | self.drop_last = drop_last 25 | 26 | def __len__(self): 27 | return len(self.text) 28 | 29 | def __getitem__(self, idx): 30 | basename = self.basename[idx] 31 | speaker = self.speaker[idx] 32 | speaker_id = self.speaker_map[speaker] 33 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 34 | mel_path = os.path.join( 35 | self.preprocessed_path, 36 | "mel", 37 | "{}-mel-{}.npy".format(speaker, basename), 38 | ) 39 | mel = np.load(mel_path) 40 | pitch_path = os.path.join( 41 | self.preprocessed_path, 42 | "pitch", 43 | "{}-pitch-{}.npy".format(speaker, basename), 44 | ) 45 | pitch = np.load(pitch_path) 46 | energy_path = os.path.join( 47 | self.preprocessed_path, 48 | "energy", 49 | "{}-energy-{}.npy".format(speaker, basename), 50 | ) 51 | energy = np.load(energy_path) 52 | duration_path = os.path.join( 53 | 
self.preprocessed_path, 54 | "duration", 55 | "{}-duration-{}.npy".format(speaker, basename), 56 | ) 57 | duration = np.load(duration_path) 58 | 59 | sample = { 60 | "id": basename, 61 | "speaker": speaker_id, 62 | "text": phone, 63 | "mel": mel, 64 | "pitch": pitch, 65 | "energy": energy, 66 | "duration": duration, 67 | } 68 | 69 | return sample 70 | 71 | def process_meta(self, filename): 72 | with open( 73 | os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8" 74 | ) as f: 75 | name = [] 76 | speaker = [] 77 | text = [] 78 | for line in f.readlines(): 79 | n, s, t, r = line.strip("\n").split("|") 80 | name.append(n) 81 | speaker.append(s) 82 | text.append(t) 83 | return name, speaker, text 84 | 85 | def reprocess(self, data, idxs): 86 | ids = [data[idx]["id"] for idx in idxs] 87 | speakers = [data[idx]["speaker"] for idx in idxs] 88 | texts = [data[idx]["text"] for idx in idxs] 89 | mels = [data[idx]["mel"] for idx in idxs] 90 | pitches = [data[idx]["pitch"] for idx in idxs] 91 | energies = [data[idx]["energy"] for idx in idxs] 92 | durations = [data[idx]["duration"] for idx in idxs] 93 | 94 | text_lens = np.array([text.shape[0] for text in texts]) 95 | mel_lens = np.array([mel.shape[0] for mel in mels]) 96 | 97 | speakers = np.array(speakers) 98 | texts = pad_1D(texts) 99 | mels = pad_2D(mels) 100 | pitches = pad_1D(pitches) 101 | energies = pad_1D(energies) 102 | durations = pad_1D(durations) 103 | 104 | return ( 105 | ids, 106 | speakers, 107 | texts, 108 | text_lens, 109 | max(text_lens), 110 | mels, 111 | mel_lens, 112 | max(mel_lens), 113 | pitches, 114 | energies, 115 | durations, 116 | ) 117 | 118 | def collate_fn(self, data): 119 | data_size = len(data) 120 | 121 | if self.sort: 122 | len_arr = np.array([d["text"].shape[0] for d in data]) 123 | idx_arr = np.argsort(-len_arr) 124 | else: 125 | idx_arr = np.arange(data_size) 126 | 127 | tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :] 128 | idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)] 129 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist() 130 | if not self.drop_last and len(tail) > 0: 131 | idx_arr += [tail.tolist()] 132 | 133 | output = list() 134 | for idx in idx_arr: 135 | output.append(self.reprocess(data, idx)) 136 | 137 | return output -------------------------------------------------------------------------------- /synthesizer/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | from .utils.tools import to_device, log, synth_one_sample 5 | from .model import FastSpeech2Loss 6 | from .dataset import Dataset 7 | 8 | 9 | def evaluate(model, step, config, logger=None, vocoder=None): 10 | preprocess_config, model_config, train_config = config['synthesizer']['preprocess'], config['synthesizer']['model'], config['synthesizer']['train'] 11 | device = config['main']['device'] 12 | # Get dataset 13 | dataset = Dataset( 14 | "val.txt", preprocess_config, train_config, sort=False, drop_last=False 15 | ) 16 | batch_size = train_config["optimizer"]["batch_size"] 17 | loader = DataLoader( 18 | dataset, 19 | batch_size=batch_size, 20 | shuffle=False, 21 | collate_fn=dataset.collate_fn, 22 | ) 23 | 24 | # Get loss function 25 | Loss = FastSpeech2Loss(preprocess_config, model_config).to(device) 26 | 27 | # Evaluation 28 | loss_sums = [0 for _ in range(6)] 29 | for batchs in loader: 30 | for batch in batchs: 31 | batch = to_device(batch, device) 32 | with 
torch.no_grad(): 33 | # Forward 34 | output = model(*(batch[1:])) 35 | 36 | # Cal Loss 37 | losses = Loss(batch, output) 38 | 39 | for i in range(len(losses)): 40 | loss_sums[i] += losses[i].item() * len(batch[0]) 41 | 42 | loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums] 43 | 44 | message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format( 45 | *([step] + [l for l in loss_means]) 46 | ) 47 | 48 | if logger is not None: 49 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample( 50 | batch, 51 | output, 52 | vocoder, 53 | model_config, 54 | preprocess_config, 55 | ) 56 | 57 | log(logger, step, losses=loss_means) 58 | log( 59 | logger, 60 | fig=fig, 61 | tag="Validation/step_{}_{}".format(step, tag), 62 | ) 63 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 64 | log( 65 | logger, 66 | audio=wav_reconstruction, 67 | sampling_rate=sampling_rate, 68 | tag="Validation/step_{}_{}_reconstructed".format(step, tag), 69 | ) 70 | log( 71 | logger, 72 | audio=wav_prediction, 73 | sampling_rate=sampling_rate, 74 | tag="Validation/step_{}_{}_synthesized".format(step, tag), 75 | ) 76 | 77 | return message 78 | 79 | 80 | -------------------------------------------------------------------------------- /synthesizer/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .fastspeech2 import FastSpeech2 2 | from .loss import FastSpeech2Loss 3 | from .optimizer import ScheduledOptim -------------------------------------------------------------------------------- /synthesizer/model/fastspeech2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from ..transformer import Encoder, Decoder, PostNet 8 | from .modules import VarianceAdaptor 9 | from ..utils.tools import get_mask_from_lengths 10 | 11 | 12 | class FastSpeech2(nn.Module): 13 | """ FastSpeech2 """ 14 | 15 | def __init__(self, config): 16 | super(FastSpeech2, self).__init__() 17 | preprocess_config, model_config = config['synthesizer']['preprocess'], config['synthesizer']['model'] 18 | self.model_config = model_config 19 | self.device = config['main']['device'] 20 | 21 | self.encoder = Encoder(model_config, preprocess_config['preprocessing']['text']['language']) 22 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config, self.device) 23 | self.decoder = Decoder(model_config) 24 | self.mel_linear = nn.Linear( 25 | model_config["transformer"]["decoder_hidden"], 26 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"], 27 | ) 28 | self.postnet = PostNet() 29 | 30 | self.speaker_emb = None 31 | if config['main']['multi_speaker']: 32 | with open( 33 | os.path.join( 34 | preprocess_config["path"]["preprocessed_path"], "speakers.json" 35 | ), 36 | "r", 37 | ) as f: 38 | n_speaker = len(json.load(f)) 39 | self.speaker_emb = nn.Embedding( 40 | n_speaker, 41 | model_config["transformer"]["encoder_hidden"], 42 | ) 43 | 44 | def forward( 45 | self, 46 | speakers, 47 | texts, 48 | src_lens, 49 | max_src_len, 50 | mels=None, 51 | mel_lens=None, 52 | max_mel_len=None, 53 | p_targets=None, 54 | e_targets=None, 55 | d_targets=None, 56 | p_control=1.0, 57 | e_control=1.0, 58 | d_control=1.0, 59 | ): 60 | src_masks = get_mask_from_lengths(src_lens, max_src_len, self.device) 61 | mel_masks = ( 62 | 
get_mask_from_lengths(mel_lens, max_mel_len, self.device) 63 | if mel_lens is not None 64 | else None 65 | ) 66 | output = self.encoder(texts, src_masks) 67 | 68 | if self.speaker_emb is not None: 69 | self.speaker_vec = self.speaker_emb(speakers).unsqueeze(1) 70 | output = output + self.speaker_vec.expand(-1, max_src_len, -1) 71 | 72 | 73 | ( 74 | output, 75 | p_predictions, 76 | e_predictions, 77 | log_d_predictions, 78 | d_rounded, 79 | mel_lens, 80 | mel_masks, 81 | ) = self.variance_adaptor( 82 | output, 83 | src_masks, 84 | mel_masks, 85 | max_mel_len, 86 | p_targets, 87 | e_targets, 88 | d_targets, 89 | p_control, 90 | e_control, 91 | d_control, 92 | ) 93 | 94 | # if self.speaker_emb is not None: 95 | # output = output + self.speaker_vec.expand(-1, output.shape[1], -1) 96 | 97 | output, mel_masks = self.decoder(output, mel_masks) 98 | output = self.mel_linear(output) 99 | postnet_output = self.postnet(output) + output 100 | 101 | return ( 102 | output, 103 | postnet_output, 104 | p_predictions, 105 | e_predictions, 106 | log_d_predictions, 107 | d_rounded, 108 | src_masks, 109 | mel_masks, 110 | src_lens, 111 | mel_lens, 112 | ) 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /synthesizer/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FastSpeech2Loss(nn.Module): 6 | """ FastSpeech2 Loss """ 7 | 8 | def __init__(self, preprocess_config, model_config): 9 | super(FastSpeech2Loss, self).__init__() 10 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 11 | "feature" 12 | ] 13 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 14 | "feature" 15 | ] 16 | self.mse_loss = nn.MSELoss() 17 | self.mae_loss = nn.L1Loss() 18 | 19 | def forward(self, inputs, predictions): 20 | ( 21 | mel_targets, 22 | _, 23 | _, 24 | pitch_targets, 25 | energy_targets, 26 | duration_targets, 27 | ) = inputs[5:] 28 | ( 29 | mel_predictions, 30 | postnet_mel_predictions, 31 | pitch_predictions, 32 | energy_predictions, 33 | log_duration_predictions, 34 | _, 35 | src_masks, 36 | mel_masks, 37 | _, 38 | _, 39 | ) = predictions 40 | src_masks = ~src_masks 41 | mel_masks = ~mel_masks 42 | log_duration_targets = torch.log(duration_targets.float() + 1) 43 | mel_targets = mel_targets[:, : mel_masks.shape[1], :] 44 | mel_masks = mel_masks[:, :mel_masks.shape[1]] 45 | 46 | log_duration_targets.requires_grad = False 47 | pitch_targets.requires_grad = False 48 | energy_targets.requires_grad = False 49 | mel_targets.requires_grad = False 50 | 51 | if self.pitch_feature_level == "phoneme_level": 52 | pitch_predictions = pitch_predictions.masked_select(src_masks) 53 | pitch_targets = pitch_targets.masked_select(src_masks) 54 | elif self.pitch_feature_level == "frame_level": 55 | pitch_predictions = pitch_predictions.masked_select(mel_masks) 56 | pitch_targets = pitch_targets.masked_select(mel_masks) 57 | 58 | if self.energy_feature_level == "phoneme_level": 59 | energy_predictions = energy_predictions.masked_select(src_masks) 60 | energy_targets = energy_targets.masked_select(src_masks) 61 | if self.energy_feature_level == "frame_level": 62 | energy_predictions = energy_predictions.masked_select(mel_masks) 63 | energy_targets = energy_targets.masked_select(mel_masks) 64 | 65 | log_duration_predictions = log_duration_predictions.masked_select(src_masks) 66 | log_duration_targets = 
log_duration_targets.masked_select(src_masks) 67 | 68 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1)) 69 | postnet_mel_predictions = postnet_mel_predictions.masked_select( 70 | mel_masks.unsqueeze(-1) 71 | ) 72 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1)) 73 | 74 | mel_loss = self.mae_loss(mel_predictions, mel_targets) 75 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets) 76 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets) 77 | energy_loss = self.mse_loss(energy_predictions, energy_targets) 78 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets) 79 | 80 | total_loss = ( 81 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss 82 | ) 83 | 84 | return ( 85 | total_loss, 86 | mel_loss, 87 | postnet_mel_loss, 88 | pitch_loss, 89 | energy_loss, 90 | duration_loss, 91 | ) 92 | -------------------------------------------------------------------------------- /synthesizer/model/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | from ..utils.tools import get_mask_from_lengths, pad 11 | 12 | 13 | class VarianceAdaptor(nn.Module): 14 | """Variance Adaptor""" 15 | 16 | def __init__(self, preprocess_config, model_config, device): 17 | super(VarianceAdaptor, self).__init__() 18 | self.duration_predictor = VariancePredictor(model_config) 19 | self.length_regulator = LengthRegulator() 20 | self.pitch_predictor = VariancePredictor(model_config) 21 | self.energy_predictor = VariancePredictor(model_config) 22 | 23 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 24 | "feature" 25 | ] 26 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 27 | "feature" 28 | ] 29 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"] 30 | assert self.energy_feature_level in ["phoneme_level", "frame_level"] 31 | 32 | pitch_quantization = model_config["variance_embedding"]["pitch_quantization"] 33 | energy_quantization = model_config["variance_embedding"]["energy_quantization"] 34 | n_bins = model_config["variance_embedding"]["n_bins"] 35 | assert pitch_quantization in ["linear", "log"] 36 | assert energy_quantization in ["linear", "log"] 37 | with open( 38 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 39 | ) as f: 40 | stats = json.load(f) 41 | pitch_min, pitch_max = stats["pitch"][:2] 42 | energy_min, energy_max = stats["energy"][:2] 43 | 44 | if pitch_quantization == "log": 45 | self.pitch_bins = nn.Parameter( 46 | torch.exp( 47 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1) 48 | ), 49 | requires_grad=False, 50 | ) 51 | else: 52 | self.pitch_bins = nn.Parameter( 53 | torch.linspace(pitch_min, pitch_max, n_bins - 1), 54 | requires_grad=False, 55 | ) 56 | if energy_quantization == "log": 57 | self.energy_bins = nn.Parameter( 58 | torch.exp( 59 | torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1) 60 | ), 61 | requires_grad=False, 62 | ) 63 | else: 64 | self.energy_bins = nn.Parameter( 65 | torch.linspace(energy_min, energy_max, n_bins - 1), 66 | requires_grad=False, 67 | ) 68 | 69 | self.pitch_embedding = nn.Embedding( 70 | n_bins, model_config["transformer"]["encoder_hidden"] 71 | ) 72 | self.energy_embedding = nn.Embedding( 73 | n_bins, 
model_config["transformer"]["encoder_hidden"] 74 | ) 75 | self.device = device 76 | 77 | def get_pitch_embedding(self, x, target, mask, control): 78 | prediction = self.pitch_predictor(x, mask) 79 | if target is not None: 80 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins)) 81 | else: 82 | prediction = prediction * control 83 | embedding = self.pitch_embedding( 84 | torch.bucketize(prediction, self.pitch_bins) 85 | ) 86 | return prediction, embedding 87 | 88 | def get_energy_embedding(self, x, target, mask, control): 89 | prediction = self.energy_predictor(x, mask) 90 | if target is not None: 91 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins)) 92 | else: 93 | prediction = prediction * control 94 | embedding = self.energy_embedding( 95 | torch.bucketize(prediction, self.energy_bins) 96 | ) 97 | return prediction, embedding 98 | 99 | def forward( 100 | self, 101 | x, 102 | src_mask, 103 | mel_mask=None, 104 | max_len=None, 105 | pitch_target=None, 106 | energy_target=None, 107 | duration_target=None, 108 | p_control=1.0, 109 | e_control=1.0, 110 | d_control=1.0, 111 | ): 112 | 113 | log_duration_prediction = self.duration_predictor(x, src_mask) 114 | if self.pitch_feature_level == "phoneme_level": 115 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 116 | x, pitch_target, src_mask, p_control 117 | ) 118 | x = x + pitch_embedding 119 | if self.energy_feature_level == "phoneme_level": 120 | energy_prediction, energy_embedding = self.get_energy_embedding( 121 | x, energy_target, src_mask, p_control 122 | ) 123 | x = x + energy_embedding 124 | 125 | if duration_target is not None: 126 | x, mel_len = self.length_regulator(x, duration_target, max_len, self.device) 127 | duration_rounded = duration_target 128 | if mel_mask is None: 129 | mel_mask = get_mask_from_lengths(mel_len, device=self.device) 130 | 131 | else: 132 | duration_rounded = torch.clamp( 133 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control), 134 | min=0, 135 | ) 136 | x, mel_len = self.length_regulator(x, duration_rounded, max_len, self.device) 137 | mel_mask = get_mask_from_lengths(mel_len, device=self.device) 138 | 139 | if self.pitch_feature_level == "frame_level": 140 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 141 | x, pitch_target, mel_mask, p_control 142 | ) 143 | x = x + pitch_embedding 144 | if self.energy_feature_level == "frame_level": 145 | energy_prediction, energy_embedding = self.get_energy_embedding( 146 | x, energy_target, mel_mask, p_control 147 | ) 148 | x = x + energy_embedding 149 | 150 | return ( 151 | x, 152 | pitch_prediction, 153 | energy_prediction, 154 | log_duration_prediction, 155 | duration_rounded, 156 | mel_len, 157 | mel_mask, 158 | ) 159 | 160 | 161 | class LengthRegulator(nn.Module): 162 | """Length Regulator""" 163 | 164 | def __init__(self): 165 | super(LengthRegulator, self).__init__() 166 | 167 | def LR(self, x, duration, max_len, device): 168 | output = list() 169 | mel_len = list() 170 | for batch, expand_target in zip(x, duration): 171 | expanded = self.expand(batch, expand_target) 172 | output.append(expanded) 173 | mel_len.append(expanded.shape[0]) 174 | 175 | if max_len is not None: 176 | output = pad(output, max_len) 177 | else: 178 | output = pad(output) 179 | 180 | return output, torch.LongTensor(mel_len).to(device) 181 | 182 | def expand(self, batch, predicted): 183 | out = list() 184 | 185 | for i, vec in enumerate(batch): 186 | expand_size = predicted[i].item() 187 | 
out.append(vec.expand(max(int(expand_size), 0), -1)) 188 | out = torch.cat(out, 0) 189 | 190 | return out 191 | 192 | def forward(self, x, duration, max_len, device): 193 | output, mel_len = self.LR(x, duration, max_len, device) 194 | return output, mel_len 195 | 196 | 197 | class VariancePredictor(nn.Module): 198 | """Duration, Pitch and Energy Predictor""" 199 | 200 | def __init__(self, model_config): 201 | super(VariancePredictor, self).__init__() 202 | 203 | self.input_size = model_config["transformer"]["encoder_hidden"] 204 | self.filter_size = model_config["variance_predictor"]["filter_size"] 205 | self.kernel = model_config["variance_predictor"]["kernel_size"] 206 | self.conv_output_size = model_config["variance_predictor"]["filter_size"] 207 | self.dropout = model_config["variance_predictor"]["dropout"] 208 | 209 | self.conv_layer = nn.Sequential( 210 | OrderedDict( 211 | [ 212 | ( 213 | "conv1d_1", 214 | Conv( 215 | self.input_size, 216 | self.filter_size, 217 | kernel_size=self.kernel, 218 | padding=(self.kernel - 1) // 2, 219 | ), 220 | ), 221 | ("relu_1", nn.ReLU()), 222 | ("layer_norm_1", nn.LayerNorm(self.filter_size)), 223 | ("dropout_1", nn.Dropout(self.dropout)), 224 | ( 225 | "conv1d_2", 226 | Conv( 227 | self.filter_size, 228 | self.filter_size, 229 | kernel_size=self.kernel, 230 | padding=1, 231 | ), 232 | ), 233 | ("relu_2", nn.ReLU()), 234 | ("layer_norm_2", nn.LayerNorm(self.filter_size)), 235 | ("dropout_2", nn.Dropout(self.dropout)), 236 | ] 237 | ) 238 | ) 239 | 240 | self.linear_layer = nn.Linear(self.conv_output_size, 1) 241 | 242 | def forward(self, encoder_output, mask): 243 | out = self.conv_layer(encoder_output) 244 | out = self.linear_layer(out) 245 | out = out.squeeze(-1) 246 | 247 | if mask is not None: 248 | out = out.masked_fill(mask, 0.0) 249 | 250 | return out 251 | 252 | 253 | class Conv(nn.Module): 254 | """ 255 | Convolution Module 256 | """ 257 | 258 | def __init__( 259 | self, 260 | in_channels, 261 | out_channels, 262 | kernel_size=1, 263 | stride=1, 264 | padding=0, 265 | dilation=1, 266 | bias=True, 267 | w_init="linear", 268 | ): 269 | """ 270 | :param in_channels: dimension of input 271 | :param out_channels: dimension of output 272 | :param kernel_size: size of kernel 273 | :param stride: size of stride 274 | :param padding: size of padding 275 | :param dilation: dilation rate 276 | :param bias: boolean. if True, bias is included. 277 | :param w_init: str. weight inits with xavier initialization. 
278 | """ 279 | super(Conv, self).__init__() 280 | 281 | self.conv = nn.Conv1d( 282 | in_channels, 283 | out_channels, 284 | kernel_size=kernel_size, 285 | stride=stride, 286 | padding=padding, 287 | dilation=dilation, 288 | bias=bias, 289 | ) 290 | 291 | def forward(self, x): 292 | x = x.contiguous().transpose(1, 2) 293 | x = self.conv(x) 294 | x = x.contiguous().transpose(1, 2) 295 | 296 | return x 297 | -------------------------------------------------------------------------------- /synthesizer/model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim: 6 | """ A simple wrapper class for learning rate scheduling """ 7 | 8 | def __init__(self, model, train_config, model_config, current_step): 9 | 10 | self._optimizer = torch.optim.Adam( 11 | model.parameters(), 12 | betas=train_config["optimizer"]["betas"], 13 | eps=train_config["optimizer"]["eps"], 14 | weight_decay=train_config["optimizer"]["weight_decay"], 15 | ) 16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"] 17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"] 18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"] 19 | self.current_step = current_step 20 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5) 21 | 22 | def step_and_update_lr(self): 23 | self._update_learning_rate() 24 | self._optimizer.step() 25 | 26 | def zero_grad(self): 27 | self._optimizer.zero_grad() 28 | 29 | def load_state_dict(self, path): 30 | self._optimizer.load_state_dict(path) 31 | 32 | def _get_lr_scale(self): 33 | lr = np.min( 34 | [ 35 | np.power(self.current_step, -0.5), 36 | np.power(self.n_warmup_steps, -1.5) * self.current_step, 37 | ] 38 | ) 39 | for s in self.anneal_steps: 40 | if self.current_step > s: 41 | lr = lr * self.anneal_rate 42 | return lr 43 | 44 | def _update_learning_rate(self): 45 | """ Learning rate scheduling per step """ 46 | self.current_step += 1 47 | lr = self.init_lr * self._get_lr_scale() 48 | 49 | for param_group in self._optimizer.param_groups: 50 | param_group["lr"] = lr 51 | -------------------------------------------------------------------------------- /synthesizer/prepare_align.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from preprocessor import persian, ljspeech 4 | from utils.tools import load_yaml_file 5 | 6 | 7 | def main(config): 8 | if "Persian" in config["main"]["dataset"]: 9 | persian.prepare_align(config['synthesizer']['preprocess']) 10 | elif config["main"]["dataset"] == "LJSpeech": 11 | ljspeech.prepare_align(config['synthesizer']['preprocess']) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("config", type=str, help="path to config.yaml") 17 | args = parser.parse_args() 18 | 19 | config = load_yaml_file(args.config) 20 | main(config) 21 | -------------------------------------------------------------------------------- /synthesizer/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from preprocessor.preprocessor import Preprocessor 4 | from utils.tools import load_yaml_file 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("config", type=str, help="path to config.yaml") 10 | args = parser.parse_args() 11 | 12 | config = load_yaml_file(args.config) 13 | preprocessor = 
Preprocessor(config['synthesizer']['preprocess']) 14 | preprocessor.build_from_path() 15 | -------------------------------------------------------------------------------- /synthesizer/preprocessor/ljspeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | from librosa.util import normalize 8 | 9 | from text import _clean_text 10 | 11 | 12 | def prepare_align(config): 13 | in_dir = config["path"]["corpus_path"] 14 | out_dir = config["path"]["raw_path"] 15 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 16 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 17 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 18 | 19 | speaker = "single_speaker" 20 | with open(os.path.join(in_dir, "train.txt"), encoding="utf-8") as f: 21 | for line in tqdm(f.readlines()): 22 | parts = line.strip().split("|") 23 | base_name = parts[0] 24 | text = parts[2] 25 | wav_path = os.path.join(in_dir, 'wavs', base_name + '.wav') 26 | 27 | text = _clean_text(text, cleaners) 28 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 29 | wav, _ = librosa.load(wav_path, sr=sampling_rate) 30 | wav = wav / max_wav_value 31 | wav = normalize(wav) * 0.95 32 | wavfile.write( 33 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)), 34 | sampling_rate, 35 | wav.astype(np.float32), 36 | ) 37 | with open( 38 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)), 39 | "w", 40 | ) as f1: 41 | f1.write(text) 42 | -------------------------------------------------------------------------------- /synthesizer/preprocessor/persian.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | from librosa.util import normalize 8 | 9 | from text import _clean_text 10 | 11 | 12 | def prepare_align(config): 13 | in_dir = config["path"]["corpus_path"] 14 | out_dir = config["path"]["raw_path"] 15 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 16 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 17 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 18 | 19 | with open(os.path.join(in_dir, "train.txt"), encoding="utf-8") as f: 20 | for line in tqdm(f.readlines()): 21 | parts = line.strip().split("|") 22 | base_name = parts[0] 23 | speaker = parts[1] 24 | text = parts[2] 25 | wav_path = os.path.join(in_dir, 'train_data',speaker, base_name + '.wav') 26 | 27 | text = _clean_text(text, cleaners) 28 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 29 | wav, _ = librosa.load(wav_path, sr=sampling_rate) 30 | wav = wav / max_wav_value 31 | wav = normalize(wav) * 0.95 32 | wavfile.write( 33 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)), 34 | sampling_rate, 35 | wav.astype(np.float32), 36 | ) 37 | with open( 38 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)), 39 | "w", 40 | ) as f1: 41 | f1.write(text) 42 | -------------------------------------------------------------------------------- /synthesizer/preprocessor/persian_v1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import librosa 3 | import numpy as np 4 | from scipy.io import wavfile 5 | from tqdm import tqdm 6 | from librosa.util import normalize 7 | from scipy.io.wavfile import read 8 | 9 | 
from text import _clean_text 10 | 11 | 12 | def prepare_align(config): 13 | in_dir = config["path"]["corpus_path"] 14 | out_dir = config["path"]["raw_path"] 15 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 16 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 17 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 18 | speakers = os.listdir(in_dir) 19 | for i, speaker in enumerate(speakers): 20 | loop =tqdm(os.listdir(os.path.join(in_dir, speaker))) 21 | loop.set_description(f'speaker count = {i+1}/{len(speakers)}') 22 | for file_name in loop: 23 | if file_name[-4:] != ".wav": 24 | continue 25 | base_name = file_name[:-4] 26 | text_path = os.path.join( 27 | in_dir, speaker, "{}.txt".format(base_name) 28 | ) 29 | wav_path = os.path.join( 30 | in_dir, speaker, "{}.wav".format(base_name) 31 | ) 32 | with open(text_path) as f: 33 | text = f.readline().strip("\n") 34 | text = _clean_text(text, cleaners) 35 | 36 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 37 | wav, _ = librosa.load(wav_path, sr=sampling_rate) 38 | wav = wav / max_wav_value 39 | wav = normalize(wav) * 0.95 40 | 41 | wavfile.write( 42 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)), 43 | sampling_rate, 44 | wav.astype(np.float32), 45 | ) 46 | with open( 47 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)), 48 | "w", 49 | ) as f1: 50 | f1.write(text) 51 | -------------------------------------------------------------------------------- /synthesizer/preprocessor/preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | 5 | import tgt 6 | import librosa 7 | import numpy as np 8 | import pyworld as pw 9 | from scipy.interpolate import interp1d 10 | from sklearn.preprocessing import StandardScaler 11 | from tqdm import tqdm 12 | import torch 13 | from scipy.io.wavfile import read 14 | 15 | import audio as Audio 16 | 17 | 18 | class Preprocessor: 19 | def __init__(self, config): 20 | self.config = config 21 | self.in_dir = config["path"]["raw_path"] 22 | self.out_dir = config["path"]["preprocessed_path"] 23 | self.val_size = config["preprocessing"]["val_size"] 24 | self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 25 | self.hop_length = config["preprocessing"]["stft"]["hop_length"] 26 | 27 | assert config["preprocessing"]["pitch"]["feature"] in [ 28 | "phoneme_level", 29 | "frame_level", 30 | ] 31 | assert config["preprocessing"]["energy"]["feature"] in [ 32 | "phoneme_level", 33 | "frame_level", 34 | ] 35 | self.pitch_phoneme_averaging = ( 36 | config["preprocessing"]["pitch"]["feature"] == "phoneme_level" 37 | ) 38 | self.energy_phoneme_averaging = ( 39 | config["preprocessing"]["energy"]["feature"] == "phoneme_level" 40 | ) 41 | 42 | self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"] 43 | self.energy_normalization = config["preprocessing"]["energy"]["normalization"] 44 | 45 | self.preprocessing_confing = config["preprocessing"] 46 | 47 | def build_from_path(self): 48 | os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True) 49 | os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True) 50 | os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True) 51 | os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True) 52 | 53 | print("Processing Data ...") 54 | out = list() 55 | n_frames = 0 56 | pitch_scaler = StandardScaler() 57 | energy_scaler = StandardScaler() 58 | 59 | # Compute 
pitch, energy, duration, and mel-spectrogram 60 | speakers = {} 61 | speakers_list = os.listdir(self.in_dir) 62 | for i, speaker in enumerate(speakers_list): 63 | speakers[speaker] = i 64 | loop = tqdm(os.listdir(os.path.join(self.in_dir, speaker))) 65 | loop.set_description(f'speaker count = {i+1}/{len(speakers_list)}') 66 | for wav_name in loop: 67 | if ".wav" not in wav_name: 68 | continue 69 | 70 | basename = wav_name.split(".")[0] 71 | tg_path = os.path.join( 72 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename) 73 | ) 74 | if os.path.exists(tg_path): 75 | ret = self.process_utterance(speaker, basename) 76 | if ret is None: 77 | continue 78 | else: 79 | info, pitch, energy, n = ret 80 | out.append(info) 81 | 82 | if len(pitch) > 0: 83 | pitch_scaler.partial_fit(pitch.reshape((-1, 1))) 84 | if len(energy) > 0: 85 | energy_scaler.partial_fit(energy.reshape((-1, 1))) 86 | 87 | n_frames += n 88 | 89 | print("Computing statistic quantities ...") 90 | # Perform normalization if necessary 91 | if self.pitch_normalization: 92 | pitch_mean = pitch_scaler.mean_[0] 93 | pitch_std = pitch_scaler.scale_[0] 94 | else: 95 | # A numerical trick to avoid normalization... 96 | pitch_mean = 0 97 | pitch_std = 1 98 | if self.energy_normalization: 99 | energy_mean = energy_scaler.mean_[0] 100 | energy_std = energy_scaler.scale_[0] 101 | else: 102 | energy_mean = 0 103 | energy_std = 1 104 | 105 | pitch_min, pitch_max = self.normalize( 106 | os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std 107 | ) 108 | energy_min, energy_max = self.normalize( 109 | os.path.join(self.out_dir, "energy"), energy_mean, energy_std 110 | ) 111 | 112 | # Save files 113 | with open(os.path.join(self.out_dir, "speakers.json"), "w") as f: 114 | f.write(json.dumps(speakers)) 115 | 116 | with open(os.path.join(self.out_dir, "stats.json"), "w") as f: 117 | stats = { 118 | "pitch": [ 119 | float(pitch_min), 120 | float(pitch_max), 121 | float(pitch_mean), 122 | float(pitch_std), 123 | ], 124 | "energy": [ 125 | float(energy_min), 126 | float(energy_max), 127 | float(energy_mean), 128 | float(energy_std), 129 | ], 130 | } 131 | f.write(json.dumps(stats)) 132 | 133 | print( 134 | "Total data time: {} hours".format( 135 | n_frames * self.hop_length / self.sampling_rate / 3600 136 | ) 137 | ) 138 | 139 | random.shuffle(out) 140 | out = [r for r in out if r is not None] 141 | 142 | # Write metadata 143 | with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f: 144 | for m in out[self.val_size :]: 145 | f.write(m + "\n") 146 | with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f: 147 | for m in out[: self.val_size]: 148 | f.write(m + "\n") 149 | 150 | return out 151 | 152 | def process_utterance(self, speaker, basename): 153 | wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename)) 154 | text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename)) 155 | tg_path = os.path.join( 156 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename) 157 | ) 158 | 159 | # Get alignments 160 | textgrid = tgt.io.read_textgrid(tg_path) 161 | phone, duration, start, end = self.get_alignment( 162 | textgrid.get_tier_by_name("phones") 163 | ) 164 | text = "{" + " ".join(phone) + "}" 165 | if start >= end: 166 | return None 167 | 168 | # Read and trim wav files 169 | wav, _ = librosa.load(wav_path, sr=None) 170 | wav = wav[ 171 | int(self.sampling_rate * start) : int(self.sampling_rate * end) 172 | ].astype(np.float32) 173 | 174 | # Read raw text 
175 | with open(text_path, "r") as f: 176 | raw_text = f.readline().strip("\n") 177 | 178 | # Compute fundamental frequency 179 | pitch, t = pw.dio( 180 | wav.astype(np.float64), 181 | self.sampling_rate, 182 | frame_period=self.hop_length / self.sampling_rate * 1000, 183 | ) 184 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate) 185 | 186 | wav = torch.FloatTensor(wav) 187 | wav = wav.unsqueeze(0) 188 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav_as_hifigan(wav, self.preprocessing_confing["stft"]["filter_length"], \ 189 | self.preprocessing_confing["mel"]["n_mel_channels"], self.preprocessing_confing["audio"]["sampling_rate"], \ 190 | self.preprocessing_confing["stft"]["hop_length"], self.preprocessing_confing["stft"]["win_length"], \ 191 | self.preprocessing_confing["mel"]["mel_fmin"], self.preprocessing_confing["mel"]["mel_fmax"], center=False) 192 | 193 | ## Matching tensors size 194 | min_size = min(sum(duration), mel_spectrogram.shape[1], energy.shape[0], pitch.shape[0]) 195 | duration[-1] = duration[-1] + (min_size - sum(duration)) 196 | mel_spectrogram = mel_spectrogram[:, : sum(duration)] 197 | energy = energy[: sum(duration)] 198 | pitch = pitch[: sum(duration)] 199 | 200 | if np.sum(pitch != 0) <= 1: 201 | return None 202 | 203 | if self.pitch_phoneme_averaging: 204 | # perform linear interpolation 205 | nonzero_ids = np.where(pitch != 0)[0] 206 | interp_fn = interp1d( 207 | nonzero_ids, 208 | pitch[nonzero_ids], 209 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), 210 | bounds_error=False, 211 | ) 212 | pitch = interp_fn(np.arange(0, len(pitch))) 213 | 214 | # Phoneme-level average 215 | pos = 0 216 | for i, d in enumerate(duration): 217 | if d > 0: 218 | pitch[i] = np.mean(pitch[pos : pos + d]) 219 | else: 220 | pitch[i] = 0 221 | pos += d 222 | pitch = pitch[: len(duration)] 223 | 224 | if self.energy_phoneme_averaging: 225 | # Phoneme-level average 226 | pos = 0 227 | for i, d in enumerate(duration): 228 | if d > 0: 229 | energy[i] = np.mean(energy[pos : pos + d]) 230 | else: 231 | energy[i] = 0 232 | pos += d 233 | energy = energy[: len(duration)] 234 | 235 | # Save files 236 | # if speaker: 237 | dur_filename = "{}-duration-{}.npy".format(speaker, basename) 238 | pitch_filename = "{}-pitch-{}.npy".format(speaker, basename) 239 | energy_filename = "{}-energy-{}.npy".format(speaker, basename) 240 | mel_filename = "{}-mel-{}.npy".format(speaker, basename) 241 | # else: 242 | # dur_filename = "duration-{}.npy".format(basename) 243 | # pitch_filename = "pitch-{}.npy".format(basename) 244 | # energy_filename = "energy-{}.npy".format(basename) 245 | # mel_filename = "mel-{}.npy".format(basename) 246 | 247 | np.save(os.path.join(self.out_dir, "duration", dur_filename), duration) 248 | np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch) 249 | np.save(os.path.join(self.out_dir, "energy", energy_filename), energy) 250 | np.save( 251 | os.path.join(self.out_dir, "mel", mel_filename), 252 | mel_spectrogram.T, 253 | ) 254 | 255 | return ( 256 | "|".join([basename, speaker, text, raw_text]), 257 | self.remove_outlier(pitch), 258 | self.remove_outlier(energy), 259 | mel_spectrogram.shape[1], 260 | ) 261 | 262 | def get_alignment(self, tier): 263 | sil_phones = ["sil", "sp", "spn"] 264 | 265 | phones = [] 266 | durations = [] 267 | start_time = 0 268 | end_time = 0 269 | end_idx = 0 270 | for t in tier._objects: 271 | s, e, p = t.start_time, t.end_time, t.text 272 | 273 | # Trim leading silences 274 | if phones == []: 275 | 
if p in sil_phones: 276 | continue 277 | else: 278 | start_time = s 279 | 280 | if p not in sil_phones: 281 | # For ordinary phones 282 | phones.append(p) 283 | end_time = e 284 | end_idx = len(phones) 285 | else: 286 | # For silent phones 287 | phones.append(p) 288 | 289 | durations.append( 290 | int( 291 | np.round(e * self.sampling_rate / self.hop_length) 292 | - np.round(s * self.sampling_rate / self.hop_length) 293 | ) 294 | ) 295 | 296 | # Trim tailing silences 297 | phones = phones[:end_idx] 298 | durations = durations[:end_idx] 299 | 300 | return phones, durations, start_time, end_time 301 | 302 | def remove_outlier(self, values): 303 | values = np.array(values) 304 | p25 = np.percentile(values, 25) 305 | p75 = np.percentile(values, 75) 306 | lower = p25 - 1.5 * (p75 - p25) 307 | upper = p75 + 1.5 * (p75 - p25) 308 | normal_indices = np.logical_and(values > lower, values < upper) 309 | 310 | return values[normal_indices] 311 | 312 | def normalize(self, in_dir, mean, std): 313 | max_value = np.finfo(np.float64).min 314 | min_value = np.finfo(np.float64).max 315 | for filename in os.listdir(in_dir): 316 | filename = os.path.join(in_dir, filename) 317 | values = (np.load(filename) - mean) / std 318 | np.save(filename, values) 319 | 320 | max_value = max(max_value, max(values)) 321 | min_value = min(min_value, min(values)) 322 | 323 | return min_value, max_value 324 | -------------------------------------------------------------------------------- /synthesizer/synthesize.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | from string import punctuation 5 | import re 6 | from g2p_en import G2p 7 | 8 | from .utils.tools import to_device, prepare_outputs 9 | from .text import text_to_sequence 10 | 11 | def read_lexicon(lex_path): 12 | lexicon = {} 13 | with open(lex_path) as f: 14 | for line in f: 15 | temp = re.split(r"\s+", line.strip("\n")) 16 | word = temp[0] 17 | phones = temp[1:] 18 | if word.lower() not in lexicon: 19 | lexicon[word.lower()] = phones 20 | return lexicon 21 | 22 | 23 | def preprocess_english(text, preprocess_config): 24 | text = text.rstrip(punctuation) 25 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"]) 26 | 27 | g2p = G2p() 28 | phones = [] 29 | words = re.split(r"([,;.\-\?\!\s+])", text) 30 | for w in words: 31 | if w.lower() in lexicon: 32 | phones += lexicon[w.lower()] 33 | else: 34 | phones += list(filter(lambda p: p != " ", g2p(w))) 35 | phones = "{" + "}{".join(phones) + "}" 36 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones) 37 | phones = phones.replace("}{", " ") 38 | 39 | print("Raw Text Sequence: {}".format(text)) 40 | print("Phoneme Sequence: {}".format(phones)) 41 | sequence = np.array( 42 | text_to_sequence( 43 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"] 44 | ) 45 | ) 46 | return np.array(sequence) 47 | 48 | 49 | def synthesize(model, batch, control_values, preprocess_config, device, p_target=None, d_target=None, e_target=None): 50 | pitch_control, energy_control, duration_control = control_values 51 | 52 | batch = to_device(batch, device) 53 | with torch.no_grad(): 54 | # Forward 55 | output = model( 56 | *(batch[1:]), 57 | 58 | p_targets=p_target, 59 | e_targets=e_target, 60 | d_targets=d_target, 61 | 62 | p_control=pitch_control, 63 | e_control=energy_control, 64 | d_control=duration_control 65 | ) 66 | mel, durations, pitch, energy = prepare_outputs( 67 | batch, 68 | output, 69 | preprocess_config, 70 | ) 71 | 
return mel[0].to(device), durations[0].to(device), pitch[0].to(device), energy[0].to(device) 72 | 73 | def infer(model, text, control_values, preprocess_config, device, speaker=0, p_target=None, d_target=None, e_target=None): 74 | t = str(time.time()).replace('.', '_') 75 | ids = [t] 76 | speakers = np.array([speaker]) 77 | if preprocess_config["preprocessing"]["text"]["language"] == "fa": 78 | texts = np.array([text_to_sequence(text, preprocess_config['preprocessing']['text']['text_cleaners'])]) 79 | elif preprocess_config["preprocessing"]["text"]["language"] == "en": 80 | texts = np.array([preprocess_english(text, preprocess_config)]) 81 | text_lens = np.array([len(texts[0])]) 82 | batch = (ids, speakers, texts, text_lens, max(text_lens)) 83 | model.eval() 84 | mel, durations, pitch, energy = synthesize(model, batch, control_values, preprocess_config, device, p_target, d_target, e_target) 85 | return mel, durations, pitch, energy 86 | 87 | -------------------------------------------------------------------------------- /synthesizer/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from . import cleaners 3 | from .symbols import persian_symbols, english_symbols 4 | import re 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _persian_symbol_to_id = {s: i for i, s in enumerate(persian_symbols)} 8 | _persian_id_to_symbol = {i: s for i, s in enumerate(persian_symbols)} 9 | 10 | _english_symbol_to_id = {s: i for i, s in enumerate(english_symbols)} 11 | _english_id_to_symbol = {i: s for i, s in enumerate(english_symbols)} 12 | 13 | # Regular expression matching text enclosed in curly braces: 14 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 15 | 16 | def text_to_sequence(text, cleaner_name): 17 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
18 |     Args:
19 |         text: string to convert to a sequence
20 |         cleaner_name: name of the cleaner function to run the text through
21 | 
22 |     Returns:
23 |         List of integers corresponding to the symbols in the text
24 |     """
25 |     sequence = []
26 |     if cleaner_name == 'persian_cleaner':
27 |         text = text.replace('{', '').replace('}', '')
28 |         sequence = [_persian_symbol_to_id[phonem] for phonem in text.split()]
29 | 
30 |     elif cleaner_name == 'english_cleaner':
31 |         while len(text):
32 |             m = _curly_re.match(text)
33 |             if not m:
34 |                 sequence += _symbols_to_sequence(_clean_text(text, cleaner_name), cleaner_name)
35 |                 break
36 |             sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_name), cleaner_name)
37 |             sequence += _arpabet_to_sequence(m.group(2))
38 |             text = m.group(3)
39 |     return sequence
40 | 
41 | 
42 | 
43 | def sequence_to_text(sequence, cleaner_name):
44 |     """Converts a sequence of IDs back to a string"""
45 |     result = ""
46 |     for symbol_id in sequence:
47 |         if "persian_cleaner" == cleaner_name:
48 |             if symbol_id in _persian_id_to_symbol:
49 |                 s = _persian_id_to_symbol[symbol_id]
50 |                 result += s
51 |         elif "english_cleaner" == cleaner_name:
52 |             if symbol_id in _english_id_to_symbol:
53 |                 s = _english_id_to_symbol[symbol_id]
54 |                 # Enclose ARPAbet back in curly braces:
55 |                 if len(s) > 1 and s[0] == "@":
56 |                     s = "{%s}" % s[1:]
57 |                 result += s
58 |     return result.replace("}{", " ")
59 | 
60 | 
61 | def _clean_text(text, cleaner_name):
62 |     cleaner = getattr(cleaners, cleaner_name, None)
63 |     if not cleaner:
64 |         raise Exception("Unknown cleaner: %s" % cleaner_name)
65 |     text = cleaner(text)
66 |     return text
67 | 
68 | 
69 | def _symbols_to_sequence(symbols, cleaner_name):
70 |     if cleaner_name == "persian_cleaner":
71 |         return [_persian_symbol_to_id[s] for s in symbols if _should_keep_symbol(s, _persian_symbol_to_id)]
72 |     elif cleaner_name == "english_cleaner":
73 |         return [_english_symbol_to_id[s] for s in symbols if _should_keep_symbol(s, _english_symbol_to_id)]
74 | 
75 | 
76 | def _arpabet_to_sequence(text):
77 |     return _symbols_to_sequence(["@" + s for s in text.split()], "english_cleaner")  # ARPAbet symbols are English-only
78 | 
79 | 
80 | def _should_keep_symbol(s, _symbol_to_id):
81 |     return s in _symbol_to_id and s != "_" and s != "~"
82 | 
--------------------------------------------------------------------------------
/synthesizer/text/cleaners.py:
--------------------------------------------------------------------------------
 1 | """ from https://github.com/keithito/tacotron """
 2 | 
 3 | '''
 4 | Cleaners are transformations that run over the input text at both training and eval time.
 5 | 
 6 | In this repository the cleaner is selected through the "text_cleaners" entry of the preprocessing
 7 | config. Some cleaners are English-specific. You'll typically want to use:
 8 |   1. "english_cleaner" for English text
 9 |   2. "persian_cleaner" for Persian phoneme sequences
10 |   3. "transliteration_cleaners" for other text that can be transliterated to ASCII using
11 |      the Unidecode library (https://pypi.python.org/pypi/Unidecode)
12 |   4. "basic_cleaners" if you do not want to transliterate (also update the symbols in symbols.py to match your data).
13 | '''
14 | 
15 | 
16 | # Regular expression matching whitespace:
17 | import re
18 | from unidecode import unidecode
19 | from .numbers import normalize_numbers
20 | _whitespace_re = re.compile(r'\s+')
21 | 
22 | # List of (regular expression, replacement) pairs for abbreviations:
23 | _abbreviations = [(re.compile('\\b%s\\.'
% x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 69 | text = lowercase(text) 70 | text = collapse_whitespace(text) 71 | return text 72 | 73 | 74 | def transliteration_cleaners(text): 75 | '''Pipeline for non-English text that transliterates to ASCII.''' 76 | text = convert_to_ascii(text) 77 | text = lowercase(text) 78 | text = collapse_whitespace(text) 79 | return text 80 | 81 | 82 | def english_cleaner(text): 83 | '''Pipeline for English text, including number and abbreviation expansion.''' 84 | text = convert_to_ascii(text) 85 | text = lowercase(text) 86 | text = expand_numbers(text) 87 | text = expand_abbreviations(text) 88 | text = collapse_whitespace(text) 89 | return text 90 | 91 | def persian_cleaner(text): 92 | text = collapse_whitespace(text) 93 | return text -------------------------------------------------------------------------------- /synthesizer/text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /synthesizer/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def 
normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /synthesizer/text/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | "v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /synthesizer/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | 
""" 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from . import cmudict, pinyin 9 | 10 | _pad = "_" 11 | _punctuation = "!'(),.:;? " 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 18 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 19 | 20 | # Export all symbols: 21 | english_symbols = ( 22 | [_pad] 23 | + list(_special) 24 | + list(_punctuation) 25 | + list(_letters) 26 | + _arpabet 27 | + _pinyin 28 | + _silences 29 | ) 30 | 31 | ## for Persian Language 32 | persian_phonemes = ['U', 'Q', 'G', 'AA', 'V', 'N', 'CH', 'R', 'KH', 'B', 'Z', 'SH', 'O', 'A', 'E', 'ZH', 'H', 'SIL', 'AH', \ 33 | 'S', 'D', 'J', 'L', 'F', 'K', 'I', 'T', 'P', 'M', 'Y'] 34 | persian_phonemes += ['?', '!', '.', ',', ';', ':'] 35 | persian_symbols = ( 36 | [_pad] 37 | + persian_phonemes 38 | ) 39 | 40 | def get_symbols(language): 41 | if language == "fa": 42 | return persian_symbols 43 | elif language == "en": 44 | return english_symbols 45 | -------------------------------------------------------------------------------- /synthesizer/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torch.utils.tensorboard import SummaryWriter 7 | from tqdm import tqdm 8 | 9 | from .utils.model import get_model, get_param_num 10 | from .utils.tools import to_device, log, synth_one_sample 11 | from .vocoder.utils import get_vocoder 12 | from .model import FastSpeech2Loss 13 | from .dataset import Dataset 14 | from .evaluate import evaluate 15 | 16 | 17 | def train_model(args, config): 18 | print("Prepare training ...") 19 | 20 | preprocess_config, model_config, train_config = config['synthesizer']['preprocess'], config['synthesizer']['model'], config['synthesizer']['train'] 21 | device = config['main']['device'] 22 | # Get dataset 23 | dataset = Dataset( 24 | "train.txt", preprocess_config, train_config, sort=True, drop_last=True 25 | ) 26 | batch_size = train_config["optimizer"]["batch_size"] 27 | group_size = 4 # Set this larger than 1 to enable sorting in Dataset 28 | assert batch_size * group_size < len(dataset) 29 | loader = DataLoader( 30 | dataset, 31 | batch_size=batch_size * group_size, 32 | shuffle=True, 33 | collate_fn=dataset.collate_fn, 34 | ) 35 | 36 | # Prepare model 37 | model, optimizer = get_model(args.restore_step, config, train=True) 38 | if "cuda" in device: 39 | if "cuda" == device: ## If cuda id is not defined (use all GPUs) 40 | model = nn.DataParallel(model) 41 | else: 42 | device_id = int(device.replace("cuda:", "")) 43 | model = nn.DataParallel(model, device_ids=[device_id], output_device=[device_id]) 44 | num_param = get_param_num(model) 45 | Loss = FastSpeech2Loss(preprocess_config, model_config).to(device) 46 | print("Number of FastSpeech2 Parameters:", num_param) 47 | 48 | # Load vocoder 49 | vocoder = get_vocoder(config['vocoder'], device) 50 | 51 | # Init logger 52 | for p in train_config["path"].values(): 53 | os.makedirs(p, exist_ok=True) 54 | train_log_path = 
os.path.join(train_config["path"]["log_path"], "train") 55 | val_log_path = os.path.join(train_config["path"]["log_path"], "val") 56 | os.makedirs(train_log_path, exist_ok=True) 57 | os.makedirs(val_log_path, exist_ok=True) 58 | train_logger = SummaryWriter(train_log_path) 59 | val_logger = SummaryWriter(val_log_path) 60 | 61 | # Training 62 | step = args.restore_step + 1 63 | epoch = 1 64 | grad_acc_step = train_config["optimizer"]["grad_acc_step"] 65 | grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"] 66 | total_step = train_config["step"]["total_step"] 67 | log_step = train_config["step"]["log_step"] 68 | save_step = train_config["step"]["save_step"] 69 | synth_step = train_config["step"]["synth_step"] 70 | val_step = train_config["step"]["val_step"] 71 | 72 | outer_bar = tqdm(total=total_step, desc="Training", position=0) 73 | outer_bar.n = args.restore_step 74 | outer_bar.update() 75 | 76 | while True: 77 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1) 78 | for batchs in loader: 79 | for batch in batchs: 80 | batch = to_device(batch, device) 81 | # Forward 82 | output = model(*(batch[1:])) 83 | 84 | # Cal Loss 85 | losses = Loss(batch, output) 86 | total_loss = losses[0] 87 | 88 | # Backward 89 | total_loss = total_loss / grad_acc_step 90 | total_loss.backward() 91 | if step % grad_acc_step == 0: 92 | # Clipping gradients to avoid gradient explosion 93 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh) 94 | 95 | # Update weights 96 | optimizer.step_and_update_lr() 97 | optimizer.zero_grad() 98 | 99 | if step % log_step == 0: 100 | losses = [l.item() for l in losses] 101 | message1 = "Step {}/{}, ".format(step, total_step) 102 | message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format( 103 | *losses 104 | ) 105 | 106 | with open(os.path.join(train_log_path, "log.txt"), "a") as f: 107 | f.write(message1 + message2 + "\n") 108 | 109 | outer_bar.write(message1 + message2) 110 | 111 | log(train_logger, step, losses=losses) 112 | 113 | if step % synth_step == 0: 114 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample( 115 | batch, 116 | output, 117 | vocoder, 118 | model_config, 119 | preprocess_config, 120 | ) 121 | log( 122 | train_logger, 123 | fig=fig, 124 | tag="Training/step_{}_{}".format(step, tag), 125 | ) 126 | sampling_rate = preprocess_config["preprocessing"]["audio"][ 127 | "sampling_rate" 128 | ] 129 | log( 130 | train_logger, 131 | audio=wav_reconstruction, 132 | sampling_rate=sampling_rate, 133 | tag="Training/step_{}_{}_reconstructed".format(step, tag), 134 | ) 135 | log( 136 | train_logger, 137 | audio=wav_prediction, 138 | sampling_rate=sampling_rate, 139 | tag="Training/step_{}_{}_synthesized".format(step, tag), 140 | ) 141 | 142 | if step % val_step == 0: 143 | model.eval() 144 | message = evaluate(model, step, config, val_logger, vocoder) 145 | with open(os.path.join(val_log_path, "log.txt"), "a") as f: 146 | f.write(message + "\n") 147 | outer_bar.write(message) 148 | 149 | model.train() 150 | 151 | if step % save_step == 0: 152 | torch.save( 153 | { 154 | "model": model.module.state_dict(), 155 | "optimizer": optimizer._optimizer.state_dict(), 156 | }, 157 | os.path.join( 158 | train_config["path"]["ckpt_path"], 159 | "{}.pth.tar".format(step), 160 | ), 161 | ) 162 | 163 | if step == total_step: 164 | quit() 165 | step += 1 166 | outer_bar.update(1) 167 | 168 | inner_bar.update(1) 169 | epoch += 1 170 | 
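# Resuming is driven entirely by --restore_step: for example (hypothetical step/config)
#   python train_synthesizer.py --restore_step 900000 --config config/LJSpeech/config.yaml
# makes get_model() reload "<ckpt_path>/900000.pth.tar" together with its optimizer state,
# and the loop above continues from step 900001 until total_step is reached. Checkpoints
# saved here hold {"model": model.module.state_dict(), "optimizer": ...}, i.e. weights
# without the DataParallel "module." prefix, so they also load into a bare FastSpeech2.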
171 | 172 | 173 | -------------------------------------------------------------------------------- /synthesizer/transformer/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = "" 7 | UNK_WORD = "" 8 | BOS_WORD = "" 9 | EOS_WORD = "" 10 | -------------------------------------------------------------------------------- /synthesizer/transformer/Layers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.nn import functional as F 7 | 8 | from .SubLayers import MultiHeadAttention, PositionwiseFeedForward 9 | 10 | 11 | class FFTBlock(torch.nn.Module): 12 | """FFT Block""" 13 | 14 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 15 | super(FFTBlock, self).__init__() 16 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 17 | self.pos_ffn = PositionwiseFeedForward( 18 | d_model, d_inner, kernel_size, dropout=dropout 19 | ) 20 | 21 | def forward(self, enc_input, mask=None, slf_attn_mask=None): 22 | enc_output, enc_slf_attn = self.slf_attn( 23 | enc_input, enc_input, enc_input, mask=slf_attn_mask 24 | ) 25 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 26 | 27 | enc_output = self.pos_ffn(enc_output) 28 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 29 | 30 | return enc_output, enc_slf_attn 31 | 32 | 33 | class ConvNorm(torch.nn.Module): 34 | def __init__( 35 | self, 36 | in_channels, 37 | out_channels, 38 | kernel_size=1, 39 | stride=1, 40 | padding=None, 41 | dilation=1, 42 | bias=True, 43 | w_init_gain="linear", 44 | ): 45 | super(ConvNorm, self).__init__() 46 | 47 | if padding is None: 48 | assert kernel_size % 2 == 1 49 | padding = int(dilation * (kernel_size - 1) / 2) 50 | 51 | self.conv = torch.nn.Conv1d( 52 | in_channels, 53 | out_channels, 54 | kernel_size=kernel_size, 55 | stride=stride, 56 | padding=padding, 57 | dilation=dilation, 58 | bias=bias, 59 | ) 60 | 61 | def forward(self, signal): 62 | conv_signal = self.conv(signal) 63 | 64 | return conv_signal 65 | 66 | 67 | class PostNet(nn.Module): 68 | """ 69 | PostNet: Five 1-d convolution with 512 channels and kernel size 5 70 | """ 71 | 72 | def __init__( 73 | self, 74 | n_mel_channels=80, 75 | postnet_embedding_dim=512, 76 | postnet_kernel_size=5, 77 | postnet_n_convolutions=5, 78 | ): 79 | 80 | super(PostNet, self).__init__() 81 | self.convolutions = nn.ModuleList() 82 | 83 | self.convolutions.append( 84 | nn.Sequential( 85 | ConvNorm( 86 | n_mel_channels, 87 | postnet_embedding_dim, 88 | kernel_size=postnet_kernel_size, 89 | stride=1, 90 | padding=int((postnet_kernel_size - 1) / 2), 91 | dilation=1, 92 | w_init_gain="tanh", 93 | ), 94 | nn.BatchNorm1d(postnet_embedding_dim), 95 | ) 96 | ) 97 | 98 | for i in range(1, postnet_n_convolutions - 1): 99 | self.convolutions.append( 100 | nn.Sequential( 101 | ConvNorm( 102 | postnet_embedding_dim, 103 | postnet_embedding_dim, 104 | kernel_size=postnet_kernel_size, 105 | stride=1, 106 | padding=int((postnet_kernel_size - 1) / 2), 107 | dilation=1, 108 | w_init_gain="tanh", 109 | ), 110 | nn.BatchNorm1d(postnet_embedding_dim), 111 | ) 112 | ) 113 | 114 | self.convolutions.append( 115 | nn.Sequential( 116 | ConvNorm( 117 | postnet_embedding_dim, 118 | n_mel_channels, 119 | kernel_size=postnet_kernel_size, 120 | stride=1, 121 | 
padding=int((postnet_kernel_size - 1) / 2), 122 | dilation=1, 123 | w_init_gain="linear", 124 | ), 125 | nn.BatchNorm1d(n_mel_channels), 126 | ) 127 | ) 128 | 129 | def forward(self, x): 130 | x = x.contiguous().transpose(1, 2) 131 | 132 | for i in range(len(self.convolutions) - 1): 133 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 134 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 135 | 136 | x = x.contiguous().transpose(1, 2) 137 | return x 138 | -------------------------------------------------------------------------------- /synthesizer/transformer/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from ..transformer import Constants 6 | from .Layers import FFTBlock 7 | from ..text.symbols import get_symbols 8 | 9 | 10 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 11 | """ Sinusoid position encoding table """ 12 | 13 | def cal_angle(position, hid_idx): 14 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 15 | 16 | def get_posi_angle_vec(position): 17 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 18 | 19 | sinusoid_table = np.array( 20 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)] 21 | ) 22 | 23 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 24 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 25 | 26 | if padding_idx is not None: 27 | # zero vector for padding dimension 28 | sinusoid_table[padding_idx] = 0.0 29 | 30 | return torch.FloatTensor(sinusoid_table) 31 | 32 | 33 | class Encoder(nn.Module): 34 | """ Encoder """ 35 | 36 | def __init__(self, config, language): 37 | super(Encoder, self).__init__() 38 | 39 | n_position = config["max_seq_len"] + 1 40 | n_src_vocab = len(get_symbols("en")) + 1 41 | d_word_vec = config["transformer"]["encoder_hidden"] 42 | n_layers = config["transformer"]["encoder_layer"] 43 | n_head = config["transformer"]["encoder_head"] 44 | d_k = d_v = ( 45 | config["transformer"]["encoder_hidden"] 46 | // config["transformer"]["encoder_head"] 47 | ) 48 | d_model = config["transformer"]["encoder_hidden"] 49 | d_inner = config["transformer"]["conv_filter_size"] 50 | kernel_size = config["transformer"]["conv_kernel_size"] 51 | dropout = config["transformer"]["encoder_dropout"] 52 | 53 | self.max_seq_len = config["max_seq_len"] 54 | self.d_model = d_model 55 | 56 | self.src_word_emb = nn.Embedding( 57 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD 58 | ) 59 | self.position_enc = nn.Parameter( 60 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 61 | requires_grad=False, 62 | ) 63 | 64 | self.layer_stack = nn.ModuleList( 65 | [ 66 | FFTBlock( 67 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 68 | ) 69 | for _ in range(n_layers) 70 | ] 71 | ) 72 | 73 | def forward(self, src_seq, mask, return_attns=False): 74 | 75 | enc_slf_attn_list = [] 76 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 77 | 78 | # -- Prepare masks 79 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 80 | 81 | # -- Forward 82 | if not self.training and src_seq.shape[1] > self.max_seq_len: 83 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table( 84 | src_seq.shape[1], self.d_model 85 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 86 | src_seq.device 87 | ) 88 | else: 89 | enc_output = self.src_word_emb(src_seq) + 
self.position_enc[ 90 | :, :max_len, : 91 | ].expand(batch_size, -1, -1) 92 | 93 | for enc_layer in self.layer_stack: 94 | enc_output, enc_slf_attn = enc_layer( 95 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask 96 | ) 97 | if return_attns: 98 | enc_slf_attn_list += [enc_slf_attn] 99 | 100 | return enc_output 101 | 102 | 103 | class Decoder(nn.Module): 104 | """ Decoder """ 105 | 106 | def __init__(self, config): 107 | super(Decoder, self).__init__() 108 | 109 | n_position = config["max_seq_len"] + 1 110 | d_word_vec = config["transformer"]["decoder_hidden"] 111 | n_layers = config["transformer"]["decoder_layer"] 112 | n_head = config["transformer"]["decoder_head"] 113 | d_k = d_v = ( 114 | config["transformer"]["decoder_hidden"] 115 | // config["transformer"]["decoder_head"] 116 | ) 117 | d_model = config["transformer"]["decoder_hidden"] 118 | d_inner = config["transformer"]["conv_filter_size"] 119 | kernel_size = config["transformer"]["conv_kernel_size"] 120 | dropout = config["transformer"]["decoder_dropout"] 121 | 122 | self.max_seq_len = config["max_seq_len"] 123 | self.d_model = d_model 124 | 125 | self.position_enc = nn.Parameter( 126 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 127 | requires_grad=False, 128 | ) 129 | 130 | self.layer_stack = nn.ModuleList( 131 | [ 132 | FFTBlock( 133 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 134 | ) 135 | for _ in range(n_layers) 136 | ] 137 | ) 138 | 139 | def forward(self, enc_seq, mask, return_attns=False): 140 | 141 | dec_slf_attn_list = [] 142 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 143 | 144 | # -- Forward 145 | if not self.training and enc_seq.shape[1] > self.max_seq_len: 146 | # -- Prepare masks 147 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 148 | dec_output = enc_seq + get_sinusoid_encoding_table( 149 | enc_seq.shape[1], self.d_model 150 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 151 | enc_seq.device 152 | ) 153 | else: 154 | max_len = min(max_len, self.max_seq_len) 155 | 156 | # -- Prepare masks 157 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 158 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[ 159 | :, :max_len, : 160 | ].expand(batch_size, -1, -1) 161 | mask = mask[:, :max_len] 162 | slf_attn_mask = slf_attn_mask[:, :, :max_len] 163 | 164 | for dec_layer in self.layer_stack: 165 | dec_output, dec_slf_attn = dec_layer( 166 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask 167 | ) 168 | if return_attns: 169 | dec_slf_attn_list += [dec_slf_attn] 170 | 171 | return dec_output, mask 172 | -------------------------------------------------------------------------------- /synthesizer/transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class ScaledDotProductAttention(nn.Module): 7 | """ Scaled Dot-Product Attention """ 8 | 9 | def __init__(self, temperature): 10 | super().__init__() 11 | self.temperature = temperature 12 | self.softmax = nn.Softmax(dim=2) 13 | 14 | def forward(self, q, k, v, mask=None): 15 | 16 | attn = torch.bmm(q, k.transpose(1, 2)) 17 | attn = attn / self.temperature 18 | 19 | if mask is not None: 20 | attn = attn.masked_fill(mask, -np.inf) 21 | 22 | attn = self.softmax(attn) 23 | output = torch.bmm(attn, v) 24 | 25 | return output, attn 26 | -------------------------------------------------------------------------------- 
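ScaledDotProductAttention in Modules.py above is shape-agnostic: MultiHeadAttention (SubLayers.py, next) feeds it per-head batches of shape (n_head * batch, len, d_k). A minimal stand-alone sketch of that contract, assuming only that the repository root is on PYTHONPATH (the import path below is otherwise an assumption):

import torch

from synthesizer.transformer.Modules import ScaledDotProductAttention

d_k, d_v = 64, 64
attn = ScaledDotProductAttention(temperature=d_k ** 0.5)  # same temperature = sqrt(d_k) as in SubLayers.py
q = torch.randn(2, 10, d_k)                               # (batch, query_len, d_k)
k = torch.randn(2, 12, d_k)                               # (batch, key_len, d_k)
v = torch.randn(2, 12, d_v)                               # (batch, key_len, d_v)
mask = torch.zeros(2, 10, 12, dtype=torch.bool)           # True marks key positions to ignore
out, weights = attn(q, k, v, mask=mask)
print(out.shape, weights.shape)                           # (2, 10, 64) and (2, 10, 12)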
/synthesizer/transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from .Modules import ScaledDotProductAttention 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | """ Multi-Head Attention module """ 10 | 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 12 | super().__init__() 13 | 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v) 21 | 22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 23 | self.layer_norm = nn.LayerNorm(d_model) 24 | 25 | self.fc = nn.Linear(n_head * d_v, d_model) 26 | 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, q, k, v, mask=None): 30 | 31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 32 | 33 | sz_b, len_q, _ = q.size() 34 | sz_b, len_k, _ = k.size() 35 | sz_b, len_v, _ = v.size() 36 | 37 | residual = q 38 | 39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 45 | 46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 47 | output, attn = self.attention(q, k, v, mask=mask) 48 | 49 | output = output.view(n_head, sz_b, len_q, d_v) 50 | output = ( 51 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) 52 | ) # b x lq x (n*dv) 53 | 54 | output = self.dropout(self.fc(output)) 55 | output = self.layer_norm(output + residual) 56 | 57 | return output, attn 58 | 59 | 60 | class PositionwiseFeedForward(nn.Module): 61 | """ A two-feed-forward-layer module """ 62 | 63 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1): 64 | super().__init__() 65 | 66 | # Use Conv1D 67 | # position-wise 68 | self.w_1 = nn.Conv1d( 69 | d_in, 70 | d_hid, 71 | kernel_size=kernel_size[0], 72 | padding=(kernel_size[0] - 1) // 2, 73 | ) 74 | # position-wise 75 | self.w_2 = nn.Conv1d( 76 | d_hid, 77 | d_in, 78 | kernel_size=kernel_size[1], 79 | padding=(kernel_size[1] - 1) // 2, 80 | ) 81 | 82 | self.layer_norm = nn.LayerNorm(d_in) 83 | self.dropout = nn.Dropout(dropout) 84 | 85 | def forward(self, x): 86 | residual = x 87 | output = x.transpose(1, 2) 88 | output = self.w_2(F.relu(self.w_1(output))) 89 | output = output.transpose(1, 2) 90 | output = self.dropout(output) 91 | output = self.layer_norm(output + residual) 92 | 93 | return output 94 | -------------------------------------------------------------------------------- /synthesizer/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Models import Encoder, Decoder 2 | from .Layers import PostNet -------------------------------------------------------------------------------- /synthesizer/utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from ..model import FastSpeech2, ScheduledOptim 5 | 6 | 7 | def get_model(restore_step, config, train=False): 8 | 9 | device = config['main']['device'] 10 | model = FastSpeech2(config).to(device) 11 | if 
restore_step: 12 | ckpt_path = os.path.join( 13 | config['synthesizer']['train']["path"]["ckpt_path"], 14 | "{}.pth.tar".format(restore_step), 15 | ) 16 | # ckpt_path = "/mnt/hdd1/adibian/FastSpeech2/ResGrad/output/MS_Persian/add_speaker_befor_and_after_VA/synthesizer/ckpt/1000000.pth.tar" 17 | # ckpt_path = "/mnt/hdd1/adibian/FastSpeech2/ResGrad/output/MS_Persian/add_speaker_before_VA/synthesizer/ckpt/1000000.pth.tar" 18 | ckpt = torch.load(ckpt_path, map_location=torch.device(device)) 19 | model.load_state_dict(ckpt["model"]) 20 | 21 | if train: 22 | scheduled_optim = ScheduledOptim( 23 | model, config['synthesizer']['train'], config['synthesizer']['model'], restore_step 24 | ) 25 | if restore_step: 26 | scheduled_optim.load_state_dict(ckpt["optimizer"]) 27 | model.train() 28 | return model, scheduled_optim 29 | 30 | model.eval() 31 | model.requires_grad_ = False 32 | return model 33 | 34 | 35 | def get_param_num(model): 36 | num_param = sum(param.numel() for param in model.parameters()) 37 | return num_param 38 | 39 | -------------------------------------------------------------------------------- /synthesizer/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import yaml 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import matplotlib 9 | from scipy.io import wavfile 10 | from matplotlib import pyplot as plt 11 | 12 | from vocoder.inference import infer as vocoder_infer 13 | 14 | matplotlib.use("Agg") 15 | 16 | def to_device(data, device): 17 | if len(data) == 11: 18 | ( 19 | ids, 20 | speakers, 21 | texts, 22 | src_lens, 23 | max_src_len, 24 | mels, 25 | mel_lens, 26 | max_mel_len, 27 | pitches, 28 | energies, 29 | durations, 30 | ) = data 31 | 32 | speakers = torch.from_numpy(speakers).long().to(device) 33 | texts = torch.from_numpy(texts).long().to(device) 34 | src_lens = torch.from_numpy(src_lens).to(device) 35 | mels = torch.from_numpy(mels).float().to(device) 36 | mel_lens = torch.from_numpy(mel_lens).to(device) 37 | pitches = torch.from_numpy(pitches).float().to(device) 38 | energies = torch.from_numpy(energies).to(device) 39 | durations = torch.from_numpy(durations).long().to(device) 40 | 41 | return ( 42 | ids, 43 | speakers, 44 | texts, 45 | src_lens, 46 | max_src_len, 47 | mels, 48 | mel_lens, 49 | max_mel_len, 50 | pitches, 51 | energies, 52 | durations, 53 | ) 54 | 55 | if len(data) == 5: 56 | (ids, speakers, texts, src_lens, max_src_len) = data 57 | speakers = torch.from_numpy(speakers).long().to(device) 58 | texts = torch.from_numpy(texts).long().to(device) 59 | src_lens = torch.from_numpy(src_lens).to(device) 60 | 61 | return (ids, speakers, texts, src_lens, max_src_len) 62 | 63 | 64 | def log( 65 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag="" 66 | ): 67 | if losses is not None: 68 | logger.add_scalar("Loss/total_loss", losses[0], step) 69 | logger.add_scalar("Loss/mel_loss", losses[1], step) 70 | logger.add_scalar("Loss/mel_postnet_loss", losses[2], step) 71 | logger.add_scalar("Loss/pitch_loss", losses[3], step) 72 | logger.add_scalar("Loss/energy_loss", losses[4], step) 73 | logger.add_scalar("Loss/duration_loss", losses[5], step) 74 | 75 | if fig is not None: 76 | logger.add_figure(tag, fig) 77 | 78 | if audio is not None: 79 | logger.add_audio( 80 | tag, 81 | audio / max(abs(audio)), 82 | sample_rate=sampling_rate, 83 | ) 84 | 85 | 86 | def get_mask_from_lengths(lengths, max_len=None, device=None): 87 | batch_size 
= lengths.shape[0] 88 | if max_len is None: 89 | max_len = torch.max(lengths).item() 90 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) 91 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 92 | 93 | return mask 94 | 95 | 96 | def expand(values, durations): 97 | out = list() 98 | for value, d in zip(values, durations): 99 | out += [value] * max(0, int(d)) 100 | return np.array(out) 101 | 102 | 103 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config): 104 | 105 | basename = targets[0][0] 106 | src_len = predictions[8][0].item() 107 | mel_len = predictions[9][0].item() 108 | mel_target = targets[5][0, :mel_len].detach().transpose(0, 1) 109 | mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1) 110 | duration = targets[10][0, :src_len].detach().cpu().numpy() 111 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 112 | pitch = targets[8][0, :src_len].detach().cpu().numpy() 113 | pitch = expand(pitch, duration) 114 | else: 115 | pitch = targets[8][0, :mel_len].detach().cpu().numpy() 116 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 117 | energy = targets[9][0, :src_len].detach().cpu().numpy() 118 | energy = expand(energy, duration) 119 | else: 120 | energy = targets[9][0, :mel_len].detach().cpu().numpy() 121 | 122 | with open( 123 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 124 | ) as f: 125 | stats = json.load(f) 126 | stats = stats["pitch"] + stats["energy"][:2] 127 | 128 | fig = plot_mel( 129 | [ 130 | (mel_prediction.cpu().numpy(), pitch, energy), 131 | (mel_target.cpu().numpy(), pitch, energy), 132 | ], 133 | stats, 134 | ["Synthetized Spectrogram", "Ground-Truth Spectrogram"], 135 | ) 136 | if vocoder is not None: 137 | wav_reconstruction = vocoder_infer( 138 | vocoder, 139 | mel_target.unsqueeze(0), 140 | preprocess_config["preprocessing"]["audio"]["max_wav_value"], 141 | ) 142 | wav_prediction = vocoder_infer( 143 | vocoder, 144 | mel_prediction.unsqueeze(0), 145 | preprocess_config["preprocessing"]["audio"]["max_wav_value"], 146 | ) 147 | else: 148 | wav_reconstruction = wav_prediction = None 149 | 150 | return fig, wav_reconstruction, wav_prediction, basename 151 | 152 | 153 | def prepare_outputs(targets, predictions, preprocess_config): 154 | all_mel, all_durations, all_pitch, all_energy = [], [], [], [] 155 | for i in range(len(predictions[0])): 156 | src_len = predictions[8][i].item() 157 | mel_len = predictions[9][i].item() 158 | mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) 159 | duration = predictions[5][i, :src_len].detach() 160 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 161 | pitch = predictions[2][i, :src_len].detach().cpu() 162 | pitch = expand(pitch, duration) 163 | pitch = torch.from_numpy(pitch) 164 | else: 165 | pitch = predictions[2][i, :mel_len].detach() 166 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 167 | energy = predictions[3][i, :src_len].detach().cpu() 168 | energy = expand(energy, duration) 169 | energy = torch.from_numpy(energy) 170 | else: 171 | energy = predictions[3][i, :mel_len].detach() 172 | 173 | all_mel.append(mel_prediction) 174 | all_durations.append(duration) 175 | all_pitch.append(pitch) 176 | all_energy.append(energy) 177 | 178 | return all_mel, all_durations, all_pitch, all_energy 179 | 180 | 181 | def plot_mel(data, stats, titles): 182 | fig, axes = 
plt.subplots(len(data), 1, squeeze=False) 183 | if titles is None: 184 | titles = [None for i in range(len(data))] 185 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats 186 | pitch_min = pitch_min * pitch_std + pitch_mean 187 | pitch_max = pitch_max * pitch_std + pitch_mean 188 | 189 | def add_axis(fig, old_ax): 190 | ax = fig.add_axes(old_ax.get_position(), anchor="W") 191 | ax.set_facecolor("None") 192 | return ax 193 | 194 | for i in range(len(data)): 195 | mel, pitch, energy = data[i] 196 | pitch = pitch * pitch_std + pitch_mean 197 | axes[i][0].imshow(mel, origin="lower") 198 | axes[i][0].set_aspect(2.5, adjustable="box") 199 | axes[i][0].set_ylim(0, mel.shape[0]) 200 | axes[i][0].set_title(titles[i], fontsize="medium") 201 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 202 | axes[i][0].set_anchor("W") 203 | 204 | ax1 = add_axis(fig, axes[i][0]) 205 | ax1.plot(pitch, color="tomato") 206 | ax1.set_xlim(0, mel.shape[1]) 207 | ax1.set_ylim(0, pitch_max) 208 | ax1.set_ylabel("F0", color="tomato") 209 | ax1.tick_params( 210 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False 211 | ) 212 | 213 | ax2 = add_axis(fig, axes[i][0]) 214 | ax2.plot(energy, color="darkviolet") 215 | ax2.set_xlim(0, mel.shape[1]) 216 | ax2.set_ylim(energy_min, energy_max) 217 | ax2.set_ylabel("Energy", color="darkviolet") 218 | ax2.yaxis.set_label_position("right") 219 | ax2.tick_params( 220 | labelsize="x-small", 221 | colors="darkviolet", 222 | bottom=False, 223 | labelbottom=False, 224 | left=False, 225 | labelleft=False, 226 | right=True, 227 | labelright=True, 228 | ) 229 | 230 | return fig 231 | 232 | 233 | def pad_1D(inputs, PAD=0): 234 | def pad_data(x, length, PAD): 235 | x_padded = np.pad( 236 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD 237 | ) 238 | return x_padded 239 | 240 | max_len = max((len(x) for x in inputs)) 241 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 242 | 243 | return padded 244 | 245 | 246 | def pad_2D(inputs, maxlen=None): 247 | def pad(x, max_len): 248 | PAD = 0 249 | if np.shape(x)[0] > max_len: 250 | raise ValueError("not max_len") 251 | 252 | s = np.shape(x)[1] 253 | x_padded = np.pad( 254 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD 255 | ) 256 | return x_padded[:, :s] 257 | 258 | if maxlen: 259 | output = np.stack([pad(x, maxlen) for x in inputs]) 260 | else: 261 | max_len = max(np.shape(x)[0] for x in inputs) 262 | output = np.stack([pad(x, max_len) for x in inputs]) 263 | 264 | return output 265 | 266 | 267 | def pad(input_ele, mel_max_length=None): 268 | if mel_max_length: 269 | max_len = mel_max_length 270 | else: 271 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 272 | 273 | out_list = list() 274 | for i, batch in enumerate(input_ele): 275 | if len(batch.shape) == 1: 276 | one_batch_padded = F.pad( 277 | batch, (0, max_len - batch.size(0)), "constant", 0.0 278 | ) 279 | elif len(batch.shape) == 2: 280 | one_batch_padded = F.pad( 281 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 282 | ) 283 | out_list.append(one_batch_padded) 284 | out_padded = torch.stack(out_list) 285 | return out_padded 286 | 287 | def load_yaml_file(path): 288 | ## define custom tag handler 289 | def join(loader, node): 290 | seq = loader.construct_sequence(node) 291 | return str(os.path.join(*[str(i) for i in seq])) 292 | 293 | ## register the tag handler 294 | yaml.add_constructor('!join', join) 295 | data = 
yaml.load(open(path, "r"), Loader=yaml.FullLoader) 296 | return data -------------------------------------------------------------------------------- /synthesizer/vocoder: -------------------------------------------------------------------------------- 1 | ../vocoder -------------------------------------------------------------------------------- /train_resgrad.py: -------------------------------------------------------------------------------- 1 | from utils import load_yaml_file 2 | from resgrad.train import resgrad_train 3 | 4 | import argparse 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--restore_step", type=int, default=0) 10 | parser.add_argument("--config", type=str, default='config/Persian/config.yaml', required=False, help="path to config.yaml") 11 | args = parser.parse_args() 12 | 13 | # Read Config 14 | config = load_yaml_file(args.config) 15 | resgrad_config = config['resgrad'] 16 | resgrad_config['main'] = config['main'] 17 | resgrad_train(args, resgrad_config) -------------------------------------------------------------------------------- /train_synthesizer.py: -------------------------------------------------------------------------------- 1 | from synthesizer.train import train_model 2 | from utils import load_yaml_file 3 | 4 | import argparse 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--restore_step", type=int, default=0) 10 | parser.add_argument("--config", type=str, default='config/Persian/config.yaml', required=False, help="path to config.yaml") 11 | args = parser.parse_args() 12 | 13 | # Read Config 14 | config = load_yaml_file(args.config) 15 | train_model(args, config) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from synthesizer.utils.model import get_model as load_synthesizer_model 2 | from resgrad.utils import load_model as load_resgrad_model 3 | from vocoder.utils import get_vocoder 4 | 5 | from synthesizer.utils.tools import plot_mel 6 | 7 | import os 8 | import json 9 | from scipy.io import wavfile 10 | from matplotlib import pyplot as plt 11 | import yaml 12 | import time 13 | 14 | def save_result(mel_prediction, wav, pitch, energy, preprocess_config, result_dir, file_name): 15 | with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")) as f: 16 | stats = json.load(f) 17 | stats = stats["pitch"] + stats["energy"][:2] 18 | fig = plot_mel([(mel_prediction.cpu().numpy(), pitch.cpu().numpy(), energy.cpu().numpy())], stats, ["Synthetized Spectrogram"]) 19 | plt.savefig(os.path.join(result_dir, "{}.png".format(file_name))) 20 | plt.close() 21 | 22 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 23 | wavfile.write(os.path.join(result_dir, "{}.wav".format(file_name)), sampling_rate, wav) 24 | 25 | 26 | def load_models(all_restore_step, config): 27 | synthesizer_model, resgrad_model, vocoder_model = None, None, None 28 | device = config['main']['device'] 29 | if all_restore_step['synthesizer'] not in [None, 0]: 30 | synthesizer_model = load_synthesizer_model(all_restore_step['synthesizer'], config).to(device) 31 | if all_restore_step['vocoder'] not in [None, 0]: 32 | vocoder_model = get_vocoder(config['vocoder'], device).to(device) 33 | if all_restore_step['resgrad'] not in [None, 0]: 34 | resgrad_model = load_resgrad_model(config['resgrad'], train=False, 
restore_model_step=all_restore_step['resgrad'], device=device).to(device) 35 | return synthesizer_model, resgrad_model, vocoder_model 36 | 37 | def load_yaml_file(path): 38 | ## define custom tag handler 39 | def join(loader, node): 40 | seq = loader.construct_sequence(node) 41 | return str(os.path.join(*[str(i) for i in seq])) 42 | 43 | ## register the tag handler 44 | yaml.add_constructor('!join', join) 45 | data = yaml.load(open(path, "r"), Loader=yaml.FullLoader) 46 | return data 47 | 48 | def get_file_name(args): 49 | file_name_parts = [] 50 | if args.result_file_name: 51 | file_name_parts.append(args.result_file_name) 52 | if args.speaker_id: 53 | file_name_parts.append("spk" + str(args.speaker_id)) 54 | if len(file_name_parts) == 0: 55 | file_name_parts.append(str(time.time()).replace('.', '_')) 56 | file_name_parts.append("FastSpeech") 57 | file_name = "_".join(file_name_parts) 58 | return file_name -------------------------------------------------------------------------------- /vocoder/ckpt/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /vocoder/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def infer(model, mel_spectrum, max_wav_value): 5 | if len(mel_spectrum.shape) == 2: 6 | mel_spectrum = mel_spectrum.unsqueeze(0) 7 | with torch.no_grad(): 8 | wav = model(mel_spectrum).squeeze(1)[0] 9 | wav = wav.cpu().numpy() * max_wav_value * 0.97 10 | wav = wav.astype("int16") 11 | return wav -------------------------------------------------------------------------------- /vocoder/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | 
Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2 ** i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | for l in self.ups: 169 | remove_weight_norm(l) 170 | for l in self.resblocks: 171 | l.remove_weight_norm() 172 | remove_weight_norm(self.conv_pre) 173 | remove_weight_norm(self.conv_post) -------------------------------------------------------------------------------- /vocoder/utils.py: -------------------------------------------------------------------------------- 1 | 2 | from vocoder.models import Generator 3 | import json 4 | import torch 5 | 6 | class AttrDict(dict): 7 | def 
__init__(self, *args, **kwargs): 8 | super(AttrDict, self).__init__(*args, **kwargs) 9 | self.__dict__ = self 10 | 11 | def get_vocoder(config, device): 12 | with open("vocoder/ckpt/config.json", "r") as f: 13 | model_config = json.load(f) 14 | model_config = AttrDict(model_config) 15 | vocoder = Generator(model_config) 16 | 17 | # if config['restore_step']: 18 | # ckpt = torch.load(f"vocoder/ckpt/g_{config['restore_step']}", map_location=config['device']) 19 | # else: 20 | ckpt = torch.load(f"vocoder/ckpt/{config['model_name']}", map_location=device) 21 | 22 | vocoder.load_state_dict(ckpt["generator"]) 23 | vocoder.eval() 24 | vocoder.remove_weight_norm() 25 | vocoder.to(device) 26 | return vocoder 27 | 28 | --------------------------------------------------------------------------------
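
For quick reference, the vocoder utilities above (`get_vocoder` in `vocoder/utils.py` and `infer` in `vocoder/inference.py`) can be combined to vocode a saved mel spectrogram without running the full `inference.py` pipeline. The snippet below is only a sketch under stated assumptions: the config path, the `sample_mel.npy` input file, the `max_wav_value` of 32768.0, and the 22050 Hz output rate are placeholders to be replaced by values from your own config, and it must be run from the repository root so `vocoder/ckpt/config.json` resolves.

```
# Minimal usage sketch (not a file in this repository): turn a precomputed mel
# spectrogram into a waveform with get_vocoder() and infer() shown above.
# The config path, "sample_mel.npy", and max_wav_value=32768.0 are assumptions.
import numpy as np
import torch
from scipy.io import wavfile

from utils import load_yaml_file
from vocoder.inference import infer
from vocoder.utils import get_vocoder

config = load_yaml_file("config/LJSpeech/config.yaml")  # any config with a 'vocoder.model_name' entry
device = config["main"]["device"]                       # "cpu" or "cuda"

vocoder = get_vocoder(config["vocoder"], device)        # HiFi-GAN Generator in eval mode, weight norm removed

mel = np.load("sample_mel.npy")                         # hypothetical (80, n_frames) mel spectrogram
mel = torch.from_numpy(mel).float().to(device)

wav = infer(vocoder, mel, max_wav_value=32768.0)        # int16 numpy waveform
wavfile.write("sample.wav", 22050, wav)                 # 22050 Hz matches vocoder/ckpt/config.json
```

Since `get_vocoder` already calls `eval()` and `remove_weight_norm()`, the returned generator is intended for inference only.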