├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── LJSpeech
│   │   └── config.yaml
│   ├── MS_Persian
│   │   └── config.yaml
│   └── Persian
│       └── config.yaml
├── inference.py
├── requirements.txt
├── resgrad.PNG
├── resgrad
│   ├── data.py
│   ├── inference.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── diffusion.py
│   │   └── optimizer.py
│   ├── train.py
│   └── utils.py
├── resgrad_data.py
├── synthesizer
│   ├── audio
│   │   ├── __init__.py
│   │   ├── audio_processing.py
│   │   ├── stft.py
│   │   └── tools.py
│   ├── dataset.py
│   ├── evaluate.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── fastspeech2.py
│   │   ├── loss.py
│   │   ├── modules.py
│   │   └── optimizer.py
│   ├── prepare_align.py
│   ├── preprocess.py
│   ├── preprocessor
│   │   ├── ljspeech.py
│   │   ├── persian.py
│   │   ├── persian_v1.py
│   │   └── preprocessor.py
│   ├── synthesize.py
│   ├── text
│   │   ├── __init__.py
│   │   ├── cleaners.py
│   │   ├── cmudict.py
│   │   ├── numbers.py
│   │   ├── pinyin.py
│   │   └── symbols.py
│   ├── train.py
│   ├── transformer
│   │   ├── Constants.py
│   │   ├── Layers.py
│   │   ├── Models.py
│   │   ├── Modules.py
│   │   ├── SubLayers.py
│   │   └── __init__.py
│   ├── utils
│   │   ├── model.py
│   │   └── tools.py
│   └── vocoder
├── train_resgrad.py
├── train_synthesizer.py
├── utils.py
└── vocoder
    ├── ckpt
    │   └── config.json
    ├── inference.py
    ├── models.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Synthesizer
2 | synthesizer/raw_data/*
3 | synthesizer/preprocessed_data/*
4 | synthesizer/*.npy
5 |
6 | ## Vocoder
7 | vocoder/ckpt/g_2500000_persian
8 | vocoder/ckpt/g_2500000
9 |
10 | ## Outputs
11 | output/*
12 |
13 | ## Data
14 | dataset/*
15 | # dataset/Persian/synthesizer_data/*
16 | # dataset/Persian/resgrad_data/*
17 | # dataset/MS_Persian/synthesizer_data/*
18 | # dataset/MS_Persian/resgrad_data/*
19 | # dataset/LJSpeech/synthesizer_data/wavs
20 | # dataset/LJSpeech/resgrad_data/*
21 |
22 | ## Others
23 | auto_inference.py
24 | commands.txt
25 | *__pycache__/
26 |
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Majid Adibian
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ResGrad - PyTorch Implementation
2 | [**ResGrad: Residual Denoising Diffusion Probabilistic Models for Text to Speech**](https://arxiv.org/abs/2212.14518)
3 |
4 | This is an *unofficial* PyTorch implementation of **ResGrad**, a high-quality denoising model for text-to-speech. In short, the model first generates a spectrogram with FastSpeech2 and then removes the residual noise in that spectrogram with a diffusion model to synthesize high-quality speech. As described in the paper, the implementation is based on FastSpeech2 and Grad-TTS, and the HiFi-GAN model is used to generate waveforms from the synthesized spectrograms. A minimal end-to-end sketch of this pipeline is shown right after this README.
5 |
6 |
7 |
8 | ## Quickstart
9 | Expected dataset structure:
10 | ```
11 | dataset/data_name/synthesizer_data/
12 | test_data/
13 | speaker1/
14 | sample1.txt
15 | sample1.wav
16 | ...
17 | ...
18 | train_data/
19 | ...
20 | test.txt (sample1|speaker1|*phoneme_sequence \n ...)
21 | train.txt (sample1|speaker1|*phoneme_sequence \n ...)
22 | ```
23 |
24 | Preprocessing:
25 | ```
26 | python synthesizer/prepare_align.py config/data_name/config.yaml
27 | python synthesizer/preprocess.py config/data_name/config.yaml
28 | ```
29 |
30 | Train synthesizer:
31 | ```
32 | python train_synthesizer.py --config config/data_name/config.yaml
33 | ```
34 |
35 | Prepare data for ResGrad:
36 | ```
37 | python resgrad_data.py --synthesizer_restore_step 1000000 --data_file_path dataset/data_name/synthesizer_data/train.txt \
38 | --config config/data_name/config.yaml
39 | ```
40 |
41 | Train ResGrad:
42 | ```
43 | python train_resgrad.py --config config/data_name/config.yaml
44 | ```
45 |
46 | Inference:
47 | ```
48 | python inference.py --text "phoneme sequence example" \
49 |     --synthesizer_restore_step 1000000 --resgrad_restore_step 1000000 --vocoder_restore_step 2500000 \
50 | --config config/data_name/config.yaml --result_dir output/data_name/results
51 | ```
52 |
53 | ## References :notebook_with_decorative_cover:
54 | - [ResGrad: Residual Denoising Diffusion Probabilistic Models for Text to Speech](https://arxiv.org/abs/2212.14518), Z. Chen, *et al*.
55 | - [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558), Y. Ren, *et al*.
56 | - [Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech](https://arxiv.org/abs/2105.06337), V. Popov, *et al*.
57 |
58 |
59 | ## License
60 | This repository is an implementation of the [**ResGrad**](https://arxiv.org/abs/2212.14518) paper, and is licensed under the MIT License.
61 |
--------------------------------------------------------------------------------
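A minimal, illustrative end-to-end sketch of the pipeline the README describes (FastSpeech2 mel prediction, ResGrad refinement, HiFi-GAN vocoding). It only mirrors the entry points that `inference.py` (shown later in this dump) already uses; the config path, restore steps, speaker id, and phoneme string are placeholder examples, not values shipped with the repository.

```
## Hedged sketch mirroring inference.py; all concrete values below are placeholders.
from utils import load_models, load_yaml_file
from synthesizer.synthesize import infer as synthesizer_infer
from resgrad.inference import infer as resgrad_infer
from vocoder.inference import infer as vocoder_infer

config = load_yaml_file("config/LJSpeech/config.yaml")
restore_steps = {"synthesizer": 1000000, "resgrad": 1000000, "vocoder": 2500000}
synthesizer_model, resgrad_model, vocoder_model = load_models(restore_steps, config)

## 1) FastSpeech2: phoneme sequence -> coarse mel-spectrogram (plus prosody predictions)
mel, durations, pitch, energy = synthesizer_infer(
    synthesizer_model, "HH AH0 L OW1", (1.0, 1.0, 1.0),
    config["synthesizer"]["preprocess"], config["main"]["device"], speaker=0)

## 2) ResGrad: refine the predicted mel with the residual diffusion model
mel = resgrad_infer(resgrad_model, mel, durations, 0,
                    config["resgrad"], config["main"]["device"])

## 3) HiFi-GAN: mel-spectrogram -> waveform
max_wav_value = config["synthesizer"]["preprocess"]["preprocessing"]["audio"]["max_wav_value"]
wav = vocoder_infer(vocoder_model, mel, max_wav_value)
```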
/config/LJSpeech/config.yaml:
--------------------------------------------------------------------------------
1 | ############################################################################################
2 | ####################################### Main Config ########################################
3 | ############################################################################################
4 | main:
5 | dataset: &Dataset "LJSpeech"
6 | multi_speaker: False
7 | device: "cuda" ## "cpu" or "cuda" (for using all GPUs) or "cuda:0" or "cuda:1" or ...
8 |
9 |
10 | ############################################################################################
11 | ####################################### Synthesizer ########################################
12 | ############################################################################################
13 | synthesizer:
14 | ##################### Model #######################
15 | model:
16 | transformer:
17 | encoder_layer: 4
18 | encoder_head: 2
19 | encoder_hidden: 256
20 | decoder_layer: 6
21 | decoder_head: 2
22 | decoder_hidden: 256
23 | conv_filter_size: 1024
24 | conv_kernel_size: [9, 1]
25 | encoder_dropout: 0.2
26 | decoder_dropout: 0.2
27 |
28 | variance_predictor:
29 | filter_size: 256
30 | kernel_size: 3
31 | dropout: 0.5
32 |
33 | variance_embedding:
34 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
35 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
36 | n_bins: 256
37 | max_seq_len: 1000
38 |
39 |   #################### Preprocess #####################
40 | preprocess:
41 | path:
42 | corpus_path: !join ["dataset/", *Dataset, "synthesizer_data"]
43 | raw_path: !join ["synthesizer/raw_data/", *Dataset]
44 | preprocessed_path: !join ["synthesizer/preprocessed_data/", *Dataset]
45 | lexicon_path: !join ["dataset/", *Dataset, "synthesizer_data/librispeech-lexicon.txt"]
46 | preprocessing:
47 | val_size: 100
48 | text:
49 | text_cleaners: "english_cleaner" ## english_cleaners or persian_cleaners
50 | language: "en" ## fa or en
51 | audio:
52 | sampling_rate: 22050
53 | max_wav_value: 32768.0
54 | stft:
55 | filter_length: 1024
56 | hop_length: 256
57 | win_length: 1024
58 | mel:
59 | n_mel_channels: 80
60 | mel_fmin: 0
61 | mel_fmax: 8000
62 | pitch:
63 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
64 | normalization: True
65 | energy:
66 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
67 | normalization: True
68 |
69 | #################### Training #####################
70 | train:
71 | path:
72 | ckpt_path: !join ["output/", *Dataset, "synthesizer/ckpt"]
73 | log_path: !join ["output/", *Dataset, "synthesizer/log"]
74 | # result_path: !join ["output/", *Dataset, "synthesizer/result"]
75 | optimizer:
76 | batch_size: 16
77 | betas: [0.9, 0.98]
78 | eps: 0.000000001
79 | weight_decay: 0.0
80 | grad_clip_thresh: 1.0
81 | grad_acc_step: 1
82 | warm_up_step: 4000
83 | anneal_steps: [300000, 400000, 500000]
84 | anneal_rate: 0.3
85 | step:
86 | total_step: 1010000
87 | log_step: 500
88 | synth_step: 1000
89 | val_step: 1000
90 | save_step: 100000
91 |
92 |
93 | ############################################################################################
94 | ########################################### ResGrad ########################################
95 | ############################################################################################
96 | resgrad:
97 | #################### Data #####################
98 | data:
99 | batch_size: 32
100 | metadata_path: !join ["dataset/", *Dataset, "resgrad_data/metadata.csv"]
101 | input_mel_dir: !join ["dataset/", *Dataset, "resgrad_data/input_mel"]
102 | speaker_map_path: !join ["synthesizer/preprocessed_data", *Dataset, "speakers.json"]
103 | val_size: 128
104 | preprocessed_path: "processed_data"
105 | normalized_method: "min-max"
106 |
107 | shuffle_data: True
108 | normallize_spectrum: True
109 | min_spec_value: -13
110 | max_spec_value: 3
111 | normallize_residual: True
112 | min_residual_value: -0.25
113 | max_residual_value: 0.25
114 | max_win_length: 100 ## maximum size of window in spectrum
115 |
116 | ################## Training ###################
117 | train:
118 | lr: 0.0001
119 | save_model_path: !join ["output/", *Dataset, "resgrad/ckpt"]
120 | log_dir: !join ["output/", *Dataset, "resgrad/log"]
121 | total_steps: 100000
122 | save_ckpt_step: 10000
123 | validate_step: 100
124 |
125 | # optimizer:
126 | # betas: [0.9, 0.98]
127 | # eps: 0.000000001
128 | # weight_decay: 0.0
129 | # grad_clip_thresh: 1.0
130 | # grad_acc_step: 1
131 | # warm_up_step: 4000
132 | # anneal_steps: [10000, 20000, 30000]
133 | # anneal_rate: 0.3
134 |
135 | ############ Model Parameters #################
136 | model:
137 | model_type1: "spec2residual" ## "spec2spec" or "spec2residual"
138 | model_type2: "segment-based" ## "segment-based" or "sentence-based"
139 | n_feats: 80
140 | dim: 64
141 | n_spks: 1
142 | spk_emb_dim: 64
143 | beta_min: 0.05
144 | beta_max: 20.0
145 | pe_scale: 1000
146 |
147 |
148 |
149 | ############################################################################################
150 | ######################################### Vocoder ##########################################
151 | ############################################################################################
152 | vocoder:
153 | model_name: ""
--------------------------------------------------------------------------------
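A note on the custom `!join` tag and the `&Dataset`/`*Dataset` anchor used throughout these configs: plain `yaml.safe_load` does not know `!join`, so the repo's loader (`load_yaml_file` in the top-level `utils.py`, which is not part of this excerpt) has to register a constructor for it. A minimal sketch of such a constructor, assuming `!join` is meant to build a path from the listed pieces:

```
## Hypothetical sketch of a !join constructor; the repo's own load_yaml_file may differ.
import os
import yaml

def _join_constructor(loader, node):
    ## ["dataset/", "LJSpeech", "synthesizer_data"] -> "dataset/LJSpeech/synthesizer_data"
    parts = loader.construct_sequence(node)
    return os.path.join(*parts)

yaml.SafeLoader.add_constructor("!join", _join_constructor)

with open("config/LJSpeech/config.yaml") as f:
    config = yaml.load(f, Loader=yaml.SafeLoader)

## the &Dataset anchor is resolved by YAML itself, so *Dataset expands to "LJSpeech"
print(config["synthesizer"]["preprocess"]["path"]["corpus_path"])
```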
/config/MS_Persian/config.yaml:
--------------------------------------------------------------------------------
1 | ############################################################################################
2 | ####################################### Main Config ########################################
3 | ############################################################################################
4 | main:
5 | dataset: &Dataset "MS_Persian"
6 | multi_speaker: True
7 | device: "cpu" ## "cpu" or "cuda" (for using all GPUs) or "cuda:0" or "cuda:1" or ...
8 |
9 |
10 | ############################################################################################
11 | ####################################### Synthesizer ########################################
12 | ############################################################################################
13 | synthesizer:
14 | ##################### Model #######################
15 | model:
16 | transformer:
17 | encoder_layer: 4
18 | encoder_head: 2
19 | encoder_hidden: 256
20 | decoder_layer: 6
21 | decoder_head: 2
22 | decoder_hidden: 256
23 | conv_filter_size: 1024
24 | conv_kernel_size: [9, 1]
25 | encoder_dropout: 0.2
26 | decoder_dropout: 0.2
27 |
28 | variance_predictor:
29 | filter_size: 256
30 | kernel_size: 3
31 | dropout: 0.5
32 |
33 | variance_embedding:
34 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
35 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
36 | n_bins: 256
37 | max_seq_len: 1000
38 |
39 |   #################### Preprocess #####################
40 | preprocess:
41 | path:
42 | corpus_path: !join ["dataset/", *Dataset, "synthesizer_data"]
43 | raw_path: !join ["synthesizer/raw_data/", *Dataset]
44 | preprocessed_path: !join ["synthesizer/preprocessed_data/", *Dataset]
45 | preprocessing:
46 | val_size: 100
47 | text:
48 | text_cleaners: "persian_cleaner" ## english_cleaners or persian_cleaners
49 | language: "fa" ## fa or en
50 | audio:
51 | sampling_rate: 22050
52 | max_wav_value: 32768.0
53 | stft:
54 | filter_length: 1024
55 | hop_length: 256
56 | win_length: 1024
57 | mel:
58 | n_mel_channels: 80
59 | mel_fmin: 0
60 | mel_fmax: 8000
61 | pitch:
62 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
63 | normalization: True
64 | energy:
65 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
66 | normalization: True
67 |
68 | #################### Training #####################
69 | train:
70 | path:
71 | ckpt_path: !join ["output/", *Dataset, "synthesizer/ckpt"]
72 | log_path: !join ["output/", *Dataset, "synthesizer/log"]
73 | # result_path: !join ["output/", *Dataset, "synthesizer/result"]
74 | optimizer:
75 | batch_size: 16
76 | betas: [0.9, 0.98]
77 | eps: 0.000000001
78 | weight_decay: 0.0
79 | grad_clip_thresh: 1.0
80 | grad_acc_step: 1
81 | warm_up_step: 4000
82 | anneal_steps: [300000, 400000, 500000]
83 | anneal_rate: 0.3
84 | step:
85 | total_step: 2010000
86 | log_step: 500
87 | synth_step: 1000
88 | val_step: 1000
89 | save_step: 100000
90 |
91 |
92 |
93 | ############################################################################################
94 | ########################################### ResGrad ########################################
95 | ############################################################################################
96 | resgrad:
97 | #################### Data #####################
98 | data:
99 | batch_size: 32
100 | metadata_path: !join ["dataset/", *Dataset, "resgrad_data/metadata.csv"]
101 | input_mel_dir: !join ["dataset/", *Dataset, "resgrad_data/input_mel"]
102 | speaker_map_path: !join ["synthesizer/preprocessed_data", *Dataset, "speakers.json"]
103 | val_size: 16
104 | preprocessed_path: "processed_data"
105 | normalized_method: "min-max"
106 |
107 | shuffle_data: True
108 | normallize_spectrum: True
109 | min_spec_value: -13
110 | max_spec_value: 3
111 | normallize_residual: True
112 | min_residual_value: -0.25
113 | max_residual_value: 0.25
114 | max_win_length: 100 ## maximum size of window in spectrum
115 |
116 | ################## Training ###################
117 | train:
118 | lr: 0.0001
119 | total_steps: 200000
120 | validate_step: 200
121 | save_ckpt_step: 10000
122 | save_model_path: !join ["output/", *Dataset, "resgrad/ckpt"]
123 | log_dir: !join ["output/", *Dataset, "resgrad/log"]
124 |
125 | ############ Model Parameters #################
126 | model:
127 | model_type1: "spec2residual" ## "spec2spec" or "spec2residual"
128 | model_type2: "segment-based" ## "segment-based" or "sentence-based"
129 | n_feats: 80
130 | dim: 64
131 | # n_spks: 1
132 | spk_emb_dim: 64
133 | beta_min: 0.05
134 | beta_max: 20.0
135 | pe_scale: 1000
136 |
137 |
138 | ############################################################################################
139 | ######################################### Vocoder ##########################################
140 | ############################################################################################
141 | vocoder:
142 | model_name: "g_2500000_persian"
143 |
--------------------------------------------------------------------------------
/config/Persian/config.yaml:
--------------------------------------------------------------------------------
1 | ############################################################################################
2 | ####################################### Main Config ########################################
3 | ############################################################################################
4 | main:
5 | dataset: &Dataset "Persian"
6 | multi_speaker: False
7 | device: "cpu" ## "cpu" or "cuda" (for using all GPUs) or "cuda:0" or "cuda:1" or ...
8 |
9 |
10 | ############################################################################################
11 | ####################################### Synthesizer ########################################
12 | ############################################################################################
13 | synthesizer:
14 | ##################### Model #######################
15 | model:
16 | transformer:
17 | encoder_layer: 4
18 | encoder_head: 2
19 | encoder_hidden: 256
20 | decoder_layer: 6
21 | decoder_head: 2
22 | decoder_hidden: 256
23 | conv_filter_size: 1024
24 | conv_kernel_size: [9, 1]
25 | encoder_dropout: 0.2
26 | decoder_dropout: 0.2
27 |
28 | variance_predictor:
29 | filter_size: 256
30 | kernel_size: 3
31 | dropout: 0.5
32 |
33 | variance_embedding:
34 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
35 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
36 | n_bins: 256
37 | max_seq_len: 1000
38 |
39 |   #################### Preprocess #####################
40 | preprocess:
41 | path:
42 | corpus_path: !join ["dataset/", *Dataset, "synthesizer_data"]
43 | raw_path: !join ["synthesizer/raw_data/", *Dataset]
44 | preprocessed_path: !join ["synthesizer/preprocessed_data/", *Dataset]
45 | preprocessing:
46 | val_size: 100
47 | text:
48 | text_cleaners: "persian_cleaner" ## english_cleaners or persian_cleaners
49 | language: "fa" ## fa or en
50 | audio:
51 | sampling_rate: 22050
52 | max_wav_value: 32768.0
53 | stft:
54 | filter_length: 1024
55 | hop_length: 256
56 | win_length: 1024
57 | mel:
58 | n_mel_channels: 80
59 | mel_fmin: 0
60 | mel_fmax: 8000
61 | pitch:
62 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
63 | normalization: True
64 | energy:
65 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
66 | normalization: True
67 |
68 | #################### Training #####################
69 | train:
70 | path:
71 | ckpt_path: !join ["output/", *Dataset, "synthesizer/ckpt"]
72 | log_path: !join ["output/", *Dataset, "synthesizer/log"]
73 | # result_path: !join ["output/", *Dataset, "synthesizer/result"]
74 | optimizer:
75 | batch_size: 16
76 | betas: [0.9, 0.98]
77 | eps: 0.000000001
78 | weight_decay: 0.0
79 | grad_clip_thresh: 1.0
80 | grad_acc_step: 1
81 | warm_up_step: 4000
82 | anneal_steps: [300000, 400000, 500000]
83 | anneal_rate: 0.3
84 | step:
85 | total_step: 1010000
86 | log_step: 500
87 | synth_step: 1000
88 | val_step: 1000
89 | save_step: 100000
90 |
91 |
92 | ############################################################################################
93 | ########################################### ResGrad ########################################
94 | ############################################################################################
95 | resgrad:
96 | #################### Data #####################
97 | data:
98 | batch_size: 32
99 | metadata_path: !join ["dataset/", *Dataset, "resgrad_data/metadata.csv"]
100 | input_mel_dir: !join ["dataset/", *Dataset, "resgrad_data/input_mel"]
101 | speaker_map_path: !join ["synthesizer/preprocessed_data", *Dataset, "speakers.json"]
102 | val_size: 128
103 | preprocessed_path: "processed_data"
104 | normalized_method: "min-max"
105 |
106 | shuffle_data: True
107 | normallize_spectrum: True
108 | min_spec_value: -13
109 | max_spec_value: 3
110 | normallize_residual: True
111 | min_residual_value: -0.25
112 | max_residual_value: 0.25
113 | max_win_length: 100 ## maximum size of window in spectrum
114 |
115 | ################## Training ###################
116 | train:
117 | lr: 0.0001
118 | save_model_path: !join ["output/", *Dataset, "resgrad/ckpt"]
119 | log_dir: !join ["output/", *Dataset, "resgrad/log"]
120 | total_steps: 100000
121 | save_ckpt_step: 10000
122 | validate_step: 100
123 |
124 | # optimizer:
125 | # betas: [0.9, 0.98]
126 | # eps: 0.000000001
127 | # weight_decay: 0.0
128 | # grad_clip_thresh: 1.0
129 | # grad_acc_step: 1
130 | # warm_up_step: 4000
131 | # anneal_steps: [10000, 20000, 30000]
132 | # anneal_rate: 0.3
133 |
134 | ############ Model Parameters #################
135 | model:
136 | model_type1: "spec2residual" ## "spec2spec" or "spec2residual"
137 | model_type2: "segment-based" ## "segment-based" or "sentence-based"
138 | n_feats: 80
139 | dim: 64
140 | n_spks: 1
141 | spk_emb_dim: 64
142 | beta_min: 0.05
143 | beta_max: 20.0
144 | pe_scale: 1000
145 |
146 |
147 | ############################################################################################
148 | ######################################### Vocoder ##########################################
149 | ############################################################################################
150 | vocoder:
151 | model_name: "g_2500000_persian"
152 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from synthesizer.synthesize import infer as synthesizer_infer
2 | from resgrad.inference import infer as resgrad_infer
3 | from vocoder.inference import infer as vocoder_infer
4 | from utils import load_models, save_result, load_yaml_file, get_file_name
5 |
6 | import argparse
7 | import time
8 |
9 | def infer():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("--text", type=str, required=True)
12 | parser.add_argument("--speaker_id", type=int, default=0, required=False)
13 | parser.add_argument("--synthesizer_restore_step", type=int, required=True)
14 | parser.add_argument("--regrad_restore_step", type=int, required=False)
15 | parser.add_argument("--vocoder_restore_step", type=int, default=0 ,required=False)
16 | parser.add_argument("--result_dir", type=str, default="output/Persian/results", required=False)
17 | parser.add_argument("--result_file_name", type=str, default="", required=False)
18 | parser.add_argument("--pitch_control", type=float, default=1.0, required=False)
19 | parser.add_argument("--energy_control", type=float, default=1.0, required=False)
20 | parser.add_argument("--duration_control", type=float, default=1.0, required=False)
21 | parser.add_argument("--config", type=str, default='config/Persian/config.yaml', required=False, help="path to config.yaml")
22 | args = parser.parse_args()
23 |
24 | # Read Config
25 | config = load_yaml_file(args.config)
26 |
27 | print("load models...")
28 |     restore_steps = {"synthesizer":args.synthesizer_restore_step, "resgrad":args.resgrad_restore_step, "vocoder":args.vocoder_restore_step}
29 | synthesizer_model, resgrad_model, vocoder_model = load_models(restore_steps, config)
30 |
31 | ## Synthesizer
32 | control_values = args.pitch_control, args.energy_control, args.duration_control
33 | start_time = time.time()
34 | mel_prediction, duration_prediction, pitch_prediction, energy_prediction = synthesizer_infer(synthesizer_model, args.text, control_values, \
35 | config['synthesizer']['preprocess'], \
36 | config['main']['device'], \
37 | speaker = args.speaker_id)
38 | end_time = time.time()
39 | FastSpeech_process_time = end_time-start_time
40 |
41 | ## Save FastSpeech2 result as wav
42 | wav = vocoder_infer(vocoder_model, mel_prediction, config['synthesizer']['preprocess']["preprocessing"]["audio"]["max_wav_value"])
43 | print("Save FastSpeech2 result...")
44 | file_name = get_file_name(args)
45 | save_result(mel_prediction, wav, pitch_prediction, energy_prediction, config['synthesizer']['preprocess'], args.result_dir, file_name)
46 |
47 | ## Real-Time factor of FastSpeech2
48 | wav_length = len(wav)/config['synthesizer']['preprocess']["preprocessing"]["audio"]["sampling_rate"]
49 | RTF_FastSpeech = FastSpeech_process_time / wav_length
50 | print("FastSpeech2 RTF: {:.6f}".format(RTF_FastSpeech))
51 |
52 | ## ResGrad
53 | # print("Inference from ResGrad...")
54 | # start_time = time.time()
55 | # mel_prediction = resgrad_infer(resgrad_model, mel_prediction, duration_prediction, args.speaker_id, config['resgrad'], config['main']['device'])
56 | # end_time = time.time()
57 | # ResGrad_process_time = end_time-start_time
58 |
59 | # ## Vocoder
60 | # wav = vocoder_infer(vocoder_model, mel_prediction, config['synthesizer']['preprocess']["preprocessing"]["audio"]["max_wav_value"])
61 |
62 | # ## Save result
63 | # print("Save ResGrad result...")
64 | # file_name = file_name.replace("FastSpeech", "ResGrad")
65 | # save_result(mel_prediction.squeeze(), wav, pitch_prediction, energy_prediction, config['synthesizer']['preprocess'], args.result_dir, file_name)
66 |
67 | # ## Real-Time factor of ResGrad
68 | # wav_length = len(wav)/config['synthesizer']['preprocess']["preprocessing"]["audio"]["sampling_rate"]
69 | # RTF_ResGrad = ResGrad_process_time / wav_length
70 | # print("ResGrad RTF: {:.6f}".format(RTF_ResGrad))
71 |
72 |
73 | if __name__ == "__main__":
74 | infer()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | g2p-en==2.1.0
2 | inflect==4.1.0
3 | librosa==0.7.2
4 | matplotlib==3.2.2
5 | numba==0.48
6 | numpy==1.19.0
7 | pypinyin==0.39.0
8 | pyworld==0.2.10
9 | PyYAML==5.4.1
10 | scikit-learn==0.23.2
11 | scipy==1.5.0
12 | soundfile==0.10.3.post1
13 | tensorboard==2.2.2
14 | tgt==1.4.4
15 | torch==1.7.0
16 | tqdm==4.46.1
17 | unidecode==1.1.1
18 | einops==0.3.0
--------------------------------------------------------------------------------
/resgrad.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adibian/ResGrad/a79625a781fd5c3c8d630566b2c72b8f6e011371/resgrad.PNG
--------------------------------------------------------------------------------
/resgrad/data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import numpy as np
3 | import torch
4 | import os
5 | import csv
6 | import json
7 |
8 | from .utils import normalize_residual, normalize_data
9 |
10 | class SpectumDataset(Dataset):
11 | def __init__(self, config):
12 | super(SpectumDataset, self).__init__()
13 | self.config = config
14 | with open(config['data']['speaker_map_path']) as f:
15 | self.speaker_map = json.load(f)
16 |
17 | self.input_data_path = []
18 | self.target_data_path = []
19 | self.duration_data_path = []
20 | self.speakers = []
21 | with open(config['data']['metadata_path'], mode='r') as csv_file:
22 | csv_reader = csv.DictReader(csv_file)
23 | line_count = 0
24 | for row in csv_reader:
25 | line_count += 1
26 | if line_count > 1:
27 | self.input_data_path.append(row['predicted_mel'])
28 | self.target_data_path.append(row['target_mel'])
29 | self.duration_data_path.append(row['duration'])
30 | self.speakers.append(self.speaker_map[row['speaker']])
31 |
32 | if config['model']['model_type2'] == "segment-based":
33 | self.max_len = config['data']['max_win_length']
34 | else:
35 | self.max_len = config['data']['spectrum_max_size']
36 |
37 | def __getitem__(self, index):
38 | input_spec_path = self.input_data_path[index]
39 | input_spec = np.load(input_spec_path)
40 | target_spec_path = self.target_data_path[index]
41 | target_spec = np.load(target_spec_path)
42 |         durations_path = self.duration_data_path[index]
43 |         durations = np.load(durations_path)
44 | target_spec = torch.from_numpy(target_spec).T
45 | input_spec = torch.from_numpy(input_spec).squeeze()
46 | if self.config['data']['normallize_spectrum']:
47 | input_spec = normalize_data(input_spec, self.config)
48 | target_spec = normalize_data(target_spec, self.config)
49 |
50 | if self.config['model']['model_type2'] == "segment-based":
51 | start_phoneme_index = np.random.choice(len(durations)-min(4, len(durations)-1), 1)[0]
52 | end_phoneme_index = 0
53 | for i in range(start_phoneme_index+1, len(durations)+1):
54 | win_length = sum(durations[start_phoneme_index:i])
55 | if win_length > self.max_len:
56 | end_phoneme_index = i-1
57 | break
58 | if end_phoneme_index == 0:
59 | end_phoneme_index = len(durations)
60 | for i in range(start_phoneme_index):
61 | start_phoneme_index -= 1
62 | win_length = sum(durations[start_phoneme_index:end_phoneme_index])
63 | if win_length > self.max_len:
64 | start_phoneme_index += 1
65 | break
66 | win_start = sum(durations[:start_phoneme_index])
67 | win_end = sum(durations[:end_phoneme_index])
68 |
69 | input_spec = input_spec[:,win_start:win_end]
70 | target_spec = target_spec[:,win_start:win_end]
71 |
72 | spec_size = input_spec.shape[-1]
73 | input_spec = torch.nn.functional.pad(input_spec, (0, self.max_len-spec_size), mode = "constant", value = 0.0)
74 | target_spec = torch.nn.functional.pad(target_spec, (0, self.max_len-spec_size), mode = "constant", value = 0.0)
75 |
76 | residual_spec = target_spec - input_spec
77 | if self.config['data']['normallize_residual']:
78 | residual_spec = normalize_residual(residual_spec, self.config)
79 |
80 | mask = torch.ones((1, input_spec.shape[-1]))
81 | mask[:,spec_size:] = 0
82 |
83 | speaker = self.speakers[index]
84 |
85 | if self.config['model']['model_type1'] == "spec2residual":
86 | residual_spec = target_spec - input_spec
87 | if self.config['data']['normallize_residual']:
88 | residual_spec = normalize_residual(residual_spec, self.config)
89 | residual_spec = residual_spec*mask
90 | return input_spec, target_spec, residual_spec, mask, speaker
91 | else:
92 | return input_spec, target_spec, mask, speaker
93 |
94 |
95 | def __len__(self):
96 | return len(self.input_data_path)
97 |
98 |
99 | def create_dataset(config):
100 | dataset = SpectumDataset(config)
101 | val_dataset, train_dataset = torch.utils.data.random_split(dataset, [config['data']['val_size'], len(dataset)-(config['data']['val_size'])])
102 | return DataLoader(train_dataset, batch_size=config['data']['batch_size'], shuffle=config['data']['shuffle_data']), \
103 | DataLoader(val_dataset, batch_size=config['data']['batch_size'], shuffle=config['data']['shuffle_data'])
104 |
--------------------------------------------------------------------------------
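A quick usage sketch for the loaders above. It assumes the metadata CSV and the `.npy` mel/duration files produced by `resgrad_data.py` (plus the `speakers.json` written during synthesizer preprocessing) already exist, and that `load_yaml_file` from the top-level `utils.py` resolves the `!join` tags; the shapes shown are for the default `segment-based`/`spec2residual` setup with `batch_size: 32`, `n_mel_channels: 80`, and `max_win_length: 100`.

```
## Usage sketch; paths and shapes follow the LJSpeech config, purely illustrative.
from utils import load_yaml_file
from resgrad.data import create_dataset

config = load_yaml_file("config/LJSpeech/config.yaml")
train_loader, val_loader = create_dataset(config["resgrad"])

## "spec2residual" batches: synthesized mel, ground-truth mel, residual target, mask, speaker id
input_spec, target_spec, residual_spec, mask, speaker = next(iter(train_loader))
print(input_spec.shape)   # torch.Size([32, 80, 100]) -> (batch, n_mel_channels, max_win_length)
print(mask.shape)         # torch.Size([32, 1, 100])  -> zeros mark padded frames
```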
/resgrad/inference.py:
--------------------------------------------------------------------------------
1 | from .utils import denormalize_residual, denormalize_data, normalize_data
2 |
3 | import torch
4 |
5 |
6 | def infer(model, mel_prediction, duration_prediction, speaker, config, device):
7 | synthesized_spec = mel_prediction.unsqueeze(0).to(device)
8 | # synthesized_spec = torch.from_numpy(synthesized_spec)
9 | if config['data']['normallize_spectrum']:
10 | synthesized_spec = normalize_data(synthesized_spec, config)
11 |
12 | if config['model']['model_type2'] == "segment-based":
13 | durations = duration_prediction
14 | # durations = torch.round(torch.exp(duration_prediction.squeeze()) - 1)
15 |
16 | all_mask, all_segment_spec, all_start_points, all_spec_size = [], [], [], []
17 |         pred = torch.zeros(synthesized_spec.shape, device=device)
18 |
19 |         ## Create segments of data except the last segment
20 | start_phoneme_index = 0
21 | end_phoneme_index = 0
22 | for i in range(1, len(durations)+1):
23 | win_length = int(sum(durations[start_phoneme_index:i]))
24 | if win_length > config['data']['max_win_length']:
25 | end_phoneme_index = i-1
26 | start_point = int(sum(durations[:start_phoneme_index]))
27 | end_point = int(sum(durations[:end_phoneme_index]))
28 | segment_spec = synthesized_spec[:,:,start_point:end_point]
29 | all_start_points.append(start_point)
30 | spec_size = segment_spec.shape[-1]
31 | all_spec_size.append(spec_size)
32 | segment_spec = torch.nn.functional.pad(segment_spec, (0, config['data']['max_win_length']-spec_size), mode = "constant", value = 0.0)
33 | mask = torch.ones((1, segment_spec.shape[-1])).to(device)
34 | mask[:,spec_size:] = 0
35 | all_mask.append(mask.unsqueeze(0))
36 | all_segment_spec.append(segment_spec)
37 | start_phoneme_index = end_phoneme_index
38 |
39 |         ## Create the last segment of data, overlapping with the previous segment
40 | start_phoneme_index = len(durations)
41 | end_phoneme_index = len(durations)
42 | for i in range(len(durations)):
43 | start_phoneme_index -= 1
44 | win_length = int(sum(durations[start_phoneme_index:]))
45 | if win_length > config['data']['max_win_length']:
46 | start_phoneme_index += 1
47 | start_point = int(sum(durations[:start_phoneme_index]))
48 | end_point = int(sum(durations[:end_phoneme_index]))
49 | segment_spec = synthesized_spec[:,:,start_point:end_point]
50 | all_start_points.append(start_point)
51 | spec_size = segment_spec.shape[-1]
52 | all_spec_size.append(spec_size)
53 | segment_spec = torch.nn.functional.pad(segment_spec, (0, config['data']['max_win_length']-spec_size), mode = "constant", value = 0.0)
54 | mask = torch.ones((1, segment_spec.shape[-1])).to(device)
55 | mask[:,spec_size:] = 0
56 | all_mask.append(mask.unsqueeze(0))
57 | all_segment_spec.append(segment_spec)
58 | break
59 |
60 | speakers = [speaker] * len(all_segment_spec)
61 | mask = torch.cat(all_mask).to(device)
62 | segment_spec = torch.cat(all_segment_spec).to(device)
63 | z = segment_spec + torch.randn_like(segment_spec, device=device) / 1.5
64 |         segments_pred = torch.zeros(segment_spec.shape, device=device)
65 | speakers = torch.tensor(speakers)
66 | batch_size = config['data']['batch_size']
67 | for i in range(0, segment_spec.shape[0], batch_size):
68 |             segments_pred[i:i+batch_size] = model(z[i:i+batch_size], mask[i:i+batch_size], segment_spec[i:i+batch_size], \
69 | n_timesteps=50, stoc=False, spk_id=speakers[i:i+batch_size])
70 |
71 | for i in range(len(segments_pred)):
72 | segment_pred = segments_pred[i,:,:all_spec_size[i]]
73 | pred[:,:,all_start_points[i]:all_start_points[i]+all_spec_size[i]] = segment_pred
74 | else:
75 | mask = torch.ones(synthesized_spec.shape).to(device)
76 | z = synthesized_spec + torch.randn_like(synthesized_spec, device=device) / 1.5
77 | pred = model(z, mask, synthesized_spec, n_timesteps=100, stoc=False, spk_id=[speaker])
78 | pred = pred.to(device)
79 |
80 | if config['model']['model_type1'] == "spec2residual":
81 | if config['data']['normallize_residual']:
82 | spec_pred = denormalize_residual(pred, config) + synthesized_spec
83 | else:
84 | spec_pred = pred + synthesized_spec
85 | else:
86 | spec_pred = pred
87 |
88 | if config['data']['normallize_spectrum']:
89 | spec_pred = denormalize_data(spec_pred, config)
90 |
91 | return spec_pred
--------------------------------------------------------------------------------
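The duration-based windowing in the `segment-based` branch above is easier to see on a toy input. The sketch below re-states the same windowing rule with made-up phoneme durations (the frame counts are arbitrary, `max_win_length` follows the configs): a window is cut just before the phoneme that would make it overflow, and the final window is built backwards from the end, so it may overlap the previous one.

```
## Toy re-statement of the segmentation above; duration values are made up.
durations = [12, 30, 25, 18, 40, 22, 9]   # frames per phoneme (example values)
max_win_length = 100                      # as in config/*/config.yaml

windows, start, acc = [], 0, 0
for i, d in enumerate(durations):
    if acc + d > max_win_length:          # adding phoneme i would overflow the window
        windows.append((sum(durations[:start]), sum(durations[:i])))
        start, acc = i, 0
    acc += d

## last window: walk back from the end until one more phoneme would overflow
j, acc = len(durations), 0
while j > 0 and acc + durations[j - 1] <= max_win_length:
    j -= 1
    acc += durations[j]
windows.append((sum(durations[:j]), sum(durations)))

print(windows)   # [(0, 85), (67, 156)] -- frame ranges; the second window overlaps the first
```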
/resgrad/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2 | # This program is free software; you can redistribute it and/or modify
3 | # it under the terms of the MIT License.
4 | # This program is distributed in the hope that it will be useful,
5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7 | # MIT License for more details.
8 |
9 | from .diffusion import Diffusion
10 |
--------------------------------------------------------------------------------
/resgrad/model/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2 | # This program is free software; you can redistribute it and/or modify
3 | # it under the terms of the MIT License.
4 | # This program is distributed in the hope that it will be useful,
5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7 | # MIT License for more details.
8 |
9 | import numpy as np
10 | import torch
11 |
12 |
13 | class BaseModule(torch.nn.Module):
14 | def __init__(self):
15 | super(BaseModule, self).__init__()
16 |
17 | @property
18 | def nparams(self):
19 | """
20 | Returns number of trainable parameters of the module.
21 | """
22 | num_params = 0
23 | for name, param in self.named_parameters():
24 | if param.requires_grad:
25 | num_params += np.prod(param.detach().cpu().numpy().shape)
26 | return num_params
27 |
28 |
29 | def relocate_input(self, x: list):
30 | """
31 | Relocates provided tensors to the same device set for the module.
32 | """
33 | device = next(self.parameters()).device
34 | for i in range(len(x)):
35 | if isinstance(x[i], torch.Tensor) and x[i].device != device:
36 | x[i] = x[i].to(device)
37 | return x
38 |
--------------------------------------------------------------------------------
/resgrad/model/diffusion.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2 | # This program is free software; you can redistribute it and/or modify
3 | # it under the terms of the MIT License.
4 | # This program is distributed in the hope that it will be useful,
5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7 | # MIT License for more details.
8 |
9 | import math
10 | import torch
11 | from einops import rearrange
12 |
13 | from .base import BaseModule
14 |
15 |
16 | class Mish(BaseModule):
17 | def forward(self, x):
18 | return x * torch.tanh(torch.nn.functional.softplus(x))
19 |
20 |
21 | class Upsample(BaseModule):
22 | def __init__(self, dim):
23 | super(Upsample, self).__init__()
24 | self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
25 |
26 | def forward(self, x):
27 | return self.conv(x)
28 |
29 |
30 | class Downsample(BaseModule):
31 | def __init__(self, dim):
32 | super(Downsample, self).__init__()
33 | self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
34 |
35 | def forward(self, x):
36 | return self.conv(x)
37 |
38 |
39 | class Rezero(BaseModule):
40 | def __init__(self, fn):
41 | super(Rezero, self).__init__()
42 | self.fn = fn
43 | self.g = torch.nn.Parameter(torch.zeros(1))
44 |
45 | def forward(self, x):
46 | return self.fn(x) * self.g
47 |
48 |
49 | class Block(BaseModule):
50 | def __init__(self, dim, dim_out, groups=8):
51 | super(Block, self).__init__()
52 | self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
53 | padding=1), torch.nn.GroupNorm(
54 | groups, dim_out), Mish())
55 |
56 | def forward(self, x, mask):
57 | output = self.block(x * mask)
58 | return output * mask
59 |
60 |
61 | class ResnetBlock(BaseModule):
62 | def __init__(self, dim, dim_out, time_emb_dim, groups=8):
63 | super(ResnetBlock, self).__init__()
64 | self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
65 | dim_out))
66 |
67 | self.block1 = Block(dim, dim_out, groups=groups)
68 | self.block2 = Block(dim_out, dim_out, groups=groups)
69 | if dim != dim_out:
70 | self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
71 | else:
72 | self.res_conv = torch.nn.Identity()
73 |
74 | def forward(self, x, mask, time_emb):
75 | h = self.block1(x, mask)
76 | h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
77 | h = self.block2(h, mask)
78 | output = h + self.res_conv(x * mask)
79 | return output
80 |
81 |
82 | class LinearAttention(BaseModule):
83 | def __init__(self, dim, heads=4, dim_head=32):
84 | super(LinearAttention, self).__init__()
85 | self.heads = heads
86 | hidden_dim = dim_head * heads
87 | self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
88 | self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
89 |
90 | def forward(self, x):
91 | b, c, h, w = x.shape
92 | qkv = self.to_qkv(x)
93 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
94 | heads = self.heads, qkv=3)
95 | k = k.softmax(dim=-1)
96 | context = torch.einsum('bhdn,bhen->bhde', k, v)
97 | out = torch.einsum('bhde,bhdn->bhen', context, q)
98 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
99 | heads=self.heads, h=h, w=w)
100 | return self.to_out(out)
101 |
102 |
103 | class Residual(BaseModule):
104 | def __init__(self, fn):
105 | super(Residual, self).__init__()
106 | self.fn = fn
107 |
108 | def forward(self, x, *args, **kwargs):
109 | output = self.fn(x, *args, **kwargs) + x
110 | return output
111 |
112 |
113 | class SinusoidalPosEmb(BaseModule):
114 | def __init__(self, dim):
115 | super(SinusoidalPosEmb, self).__init__()
116 | self.dim = dim
117 |
118 | def forward(self, x, scale=1000):
119 | device = x.device
120 | half_dim = self.dim // 2
121 | emb = math.log(10000) / (half_dim - 1)
122 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
123 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
124 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
125 | return emb
126 |
127 |
128 | class GradLogPEstimator2d(BaseModule):
129 | def __init__(self, dim, dim_mults=(1, 2, 4), groups=8,
130 | n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
131 | super(GradLogPEstimator2d, self).__init__()
132 | self.dim = dim
133 | self.dim_mults = dim_mults
134 | self.groups = groups
135 | self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
136 | self.spk_emb_dim = spk_emb_dim
137 | self.pe_scale = pe_scale
138 |
139 | if n_spks > 1:
140 | self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
141 | torch.nn.Linear(spk_emb_dim * 4, n_feats))
142 | self.time_pos_emb = SinusoidalPosEmb(dim)
143 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
144 | torch.nn.Linear(dim * 4, dim))
145 |
146 | dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
147 | in_out = list(zip(dims[:-1], dims[1:]))
148 | self.downs = torch.nn.ModuleList([])
149 | self.ups = torch.nn.ModuleList([])
150 | num_resolutions = len(in_out)
151 |
152 | for ind, (dim_in, dim_out) in enumerate(in_out):
153 | is_last = ind >= (num_resolutions - 1)
154 | self.downs.append(torch.nn.ModuleList([
155 | ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
156 | ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
157 | Residual(Rezero(LinearAttention(dim_out))),
158 | Downsample(dim_out) if not is_last else torch.nn.Identity()]))
159 |
160 | mid_dim = dims[-1]
161 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
162 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
163 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
164 |
165 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
166 | self.ups.append(torch.nn.ModuleList([
167 | ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
168 | ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
169 | Residual(Rezero(LinearAttention(dim_in))),
170 | Upsample(dim_in)]))
171 | self.final_block = Block(dim, dim)
172 | self.final_conv = torch.nn.Conv2d(dim, 1, 1)
173 |
174 | def forward(self, x, mask, mu, t, spk=None):
175 | if not isinstance(spk, type(None)):
176 | s = self.spk_mlp(spk)
177 |
178 | t = self.time_pos_emb(t, scale=self.pe_scale)
179 | t = self.mlp(t)
180 |
181 | if self.n_spks < 2:
182 | x = torch.stack([mu, x], 1)
183 | else:
184 | s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
185 | x = torch.stack([mu, x, s], 1)
186 | mask = mask.unsqueeze(1)
187 |
188 | hiddens = []
189 | masks = [mask]
190 | for resnet1, resnet2, attn, downsample in self.downs:
191 | mask_down = masks[-1]
192 | x = resnet1(x, mask_down, t)
193 | x = resnet2(x, mask_down, t)
194 | x = attn(x)
195 | hiddens.append(x)
196 | x = downsample(x * mask_down)
197 | masks.append(mask_down[:, :, :, ::2])
198 |
199 | masks = masks[:-1]
200 | mask_mid = masks[-1]
201 | x = self.mid_block1(x, mask_mid, t)
202 | x = self.mid_attn(x)
203 | x = self.mid_block2(x, mask_mid, t)
204 |
205 | for resnet1, resnet2, attn, upsample in self.ups:
206 | mask_up = masks.pop()
207 | x = torch.cat((x, hiddens.pop()), dim=1)
208 | x = resnet1(x, mask_up, t)
209 | x = resnet2(x, mask_up, t)
210 | x = attn(x)
211 | x = upsample(x * mask_up)
212 |
213 | x = self.final_block(x, mask)
214 | output = self.final_conv(x * mask)
215 |
216 | return (output * mask).squeeze(1)
217 |
218 |
219 | def get_noise(t, beta_init, beta_term, cumulative=False):
220 | if cumulative:
221 | noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
222 | else:
223 | noise = beta_init + (beta_term - beta_init)*t
224 | return noise
225 |
226 |
227 | class Diffusion(BaseModule):
228 | def __init__(self, n_feats, dim,
229 | n_spks=1, spk_emb_dim=64,
230 | beta_min=0.05, beta_max=20, pe_scale=1000):
231 | super(Diffusion, self).__init__()
232 | self.n_feats = n_feats
233 | self.dim = dim
234 | self.n_spks = n_spks
235 | self.spk_emb_dim = spk_emb_dim
236 | self.beta_min = beta_min
237 | self.beta_max = beta_max
238 | self.pe_scale = pe_scale
239 | if n_spks>1:
240 | self.spk_emb = torch.nn.Embedding(
241 | n_spks,
242 | spk_emb_dim
243 | )
244 |
245 | self.estimator = GradLogPEstimator2d(dim, n_spks=n_spks,
246 | spk_emb_dim=spk_emb_dim,
247 | pe_scale=pe_scale, n_feats=n_feats)
248 |
249 | def forward_diffusion(self, x0, mask, mu, t):
250 | time = t.unsqueeze(-1).unsqueeze(-1)
251 | cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
252 | mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
253 | variance = 1.0 - torch.exp(-cum_noise)
254 | z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device,
255 | requires_grad=False)
256 | xt = mean + z * torch.sqrt(variance)
257 | return xt * mask, z * mask
258 |
259 | @torch.no_grad()
260 | def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
261 | h = 1.0 / n_timesteps
262 | xt = z * mask
263 | for i in range(n_timesteps):
264 | t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
265 | device=z.device)
266 | time = t.unsqueeze(-1).unsqueeze(-1)
267 | noise_t = get_noise(time, self.beta_min, self.beta_max,
268 | cumulative=False)
269 | if stoc: # adds stochastic term
270 | dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
271 | dxt_det = dxt_det * noise_t * h
272 | dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
273 | requires_grad=False)
274 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
275 | dxt = dxt_det + dxt_stoc
276 | else:
277 | dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
278 | dxt = dxt * noise_t * h
279 | xt = (xt - dxt) * mask
280 | return xt
281 |
282 | @torch.no_grad()
283 | def forward(self, z, mask, mu, n_timesteps, stoc=False, spk_id=None):
284 |         # speaker embedding only exists in the multi-speaker case; pass None otherwise
285 |         spk = self.spk_emb(spk_id) if self.n_spks > 1 else None
286 | return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)
287 |
288 | def loss_t(self, x0, mask, mu, t, spk_id=None):
289 | xt, z = self.forward_diffusion(x0, mask, mu, t)
290 | time = t.unsqueeze(-1).unsqueeze(-1)
291 | cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
292 |         # as in forward(): only look up a speaker embedding for multi-speaker models
293 |         spk = self.spk_emb(spk_id) if self.n_spks > 1 else None
294 | noise_estimation = self.estimator(xt, mask, mu, t, spk)
295 | noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
296 | loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feats)
297 | return loss, xt
298 |
299 | def compute_loss(self, x0, mask, mu, spk_id=None, offset=1e-5):
300 | t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
301 | requires_grad=False)
302 | t = torch.clamp(t, offset, 1.0 - offset)
303 | return self.loss_t(x0, mask, mu, t, spk_id)
304 |
--------------------------------------------------------------------------------
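A small shape-check sketch for the `Diffusion` module above, in the single-speaker case (`spk_id` stays `None`). The toy batch uses the 80-bin, 100-frame segments the rest of the repo works with; the `/ 1.5` noise scale and `n_timesteps=50` follow `resgrad/train.py` and `resgrad/inference.py`, and the hyper-parameters mirror the configs.

```
## Shape-check sketch with toy tensors; hyper-parameters follow config/*/config.yaml.
import torch
from resgrad.model import Diffusion

model = Diffusion(n_feats=80, dim=64, n_spks=1, spk_emb_dim=64,
                  beta_min=0.05, beta_max=20.0, pe_scale=1000)

mu = torch.randn(2, 80, 100)            # conditioning (synthesized) mel segments
mask = torch.ones(2, 1, 100)            # all 100 frames valid
z = mu + torch.randn_like(mu) / 1.5     # noisy starting point, as in training/inference

pred = model(z, mask, mu, n_timesteps=50, stoc=False)   # reverse diffusion (no_grad inside)
loss, xt = model.compute_loss(mu, mask, mu)             # dummy target, just for shapes
print(pred.shape, xt.shape, float(loss))                # both tensors are (2, 80, 100)
```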
/resgrad/model/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class ScheduledOptim:
6 | """ A simple wrapper class for learning rate scheduling """
7 |
8 | def __init__(self, model, config, current_step):
9 |
10 | self._optimizer = torch.optim.Adam(
11 | model.parameters(),
12 | betas=config["optimizer"]["betas"],
13 | eps=config["optimizer"]["eps"],
14 | weight_decay=config["optimizer"]["weight_decay"],
15 | )
16 | self.n_warmup_steps = config["optimizer"]["warm_up_step"]
17 | self.anneal_steps = config["optimizer"]["anneal_steps"]
18 | self.anneal_rate = config["optimizer"]["anneal_rate"]
19 | self.current_step = current_step
20 | self.init_lr = np.power(config["model"]["dim"]*4, -0.5)
21 |
22 | def step_and_update_lr(self):
23 | self._update_learning_rate()
24 | self._optimizer.step()
25 |
26 | def zero_grad(self):
27 | self._optimizer.zero_grad()
28 |
29 | def load_state_dict(self, path):
30 | self._optimizer.load_state_dict(path)
31 |
32 | def _get_lr_scale(self):
33 | lr = np.min(
34 | [
35 | np.power(self.current_step, -0.5),
36 | np.power(self.n_warmup_steps, -1.5) * self.current_step,
37 | ]
38 | )
39 | for s in self.anneal_steps:
40 | if self.current_step > s:
41 | lr = lr * self.anneal_rate
42 | return lr
43 |
44 | def _update_learning_rate(self):
45 | """ Learning rate scheduling per step """
46 | self.current_step += 1
47 | lr = self.init_lr * self._get_lr_scale()
48 |
49 | for param_group in self._optimizer.param_groups:
50 | param_group["lr"] = lr
51 |
--------------------------------------------------------------------------------
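For reference, the warm-up/annealing rule inside `_get_lr_scale` and `_update_learning_rate` can be previewed in isolation. The sketch below simply re-evaluates the same formula for a few steps, assuming `dim: 64` and the warm-up/anneal values from the synthesizer `optimizer` section of the configs; whether this wrapper is actually used by `resgrad/train.py` depends on `load_model` in `resgrad/utils.py`, which is truncated in this excerpt.

```
## Standalone re-evaluation of the schedule; dim/warm-up/anneal values are taken
## from the configs for illustration only.
import numpy as np

def scheduled_lr(step, dim=64, warmup=4000,
                 anneal_steps=(300000, 400000, 500000), anneal_rate=0.3):
    init_lr = np.power(dim * 4, -0.5)     # same init_lr as ScheduledOptim
    scale = min(np.power(step, -0.5), np.power(warmup, -1.5) * step)
    for s in anneal_steps:
        if step > s:
            scale *= anneal_rate
    return init_lr * scale

for step in (100, 4000, 40000, 400000):
    print(step, scheduled_lr(step))       # ramps to ~1e-3 at warm-up end, then decays ~1/sqrt(step)
```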
/resgrad/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.tensorboard import SummaryWriter
3 | from tqdm import tqdm
4 | import os
5 |
6 | from .utils import plot_tensor, denormalize_residual, load_model
7 | from .data import create_dataset
8 |
9 |
10 | def logging(logger, config, original_spec, synthesized_spec, target_residual, noisy_spec, pred, mask, title, step):
11 | zero_indexes = (mask == 0).nonzero()
12 | if len(zero_indexes):
13 | start_zero_index = int(zero_indexes[0][-1])
14 | else:
15 | start_zero_index = original_spec.shape[-1]
16 | original_spec = original_spec[:,:start_zero_index]
17 | synthesized_spec = synthesized_spec[:,:start_zero_index]
18 | noisy_spec = noisy_spec[:,:start_zero_index]
19 | if config['model']['model_type1'] == "spec2residual":
20 | target_residual = target_residual[:,:start_zero_index]
21 | pred_residual = pred[:,:start_zero_index]
22 | if config['data']['normallize_residual']:
23 | pred_spec = denormalize_residual(pred[:,:start_zero_index], config) + synthesized_spec
24 | else:
25 | pred_spec = pred[:,:start_zero_index] + synthesized_spec
26 | logger.add_image(f'{title}/target_residual_spec', plot_tensor(target_residual.squeeze().cpu().detach().numpy(), "residual", config), global_step=step, dataformats='HWC')
27 | logger.add_image(f'{title}/predicted_residual_spec', plot_tensor(pred_residual.squeeze().cpu().detach().numpy(), "residual", config), global_step=step, dataformats='HWC')
28 | else:
29 | pred_spec = pred[:,:start_zero_index]
30 |
31 | logger.add_image(f'{title}/input_spec', plot_tensor(synthesized_spec.squeeze().cpu().detach().numpy(), "spectrum", config), global_step=step, dataformats='HWC')
32 | logger.add_image(f'{title}/predicted_spec', plot_tensor(pred_spec.squeeze().cpu().detach().numpy(), "spectrum", config), global_step=step, dataformats='HWC')
33 | logger.add_image(f'{title}/target_spec', plot_tensor(original_spec.squeeze().cpu().detach().numpy(), "spectrum", config), global_step=step, dataformats='HWC')
34 | logger.add_image(f'{title}/noisy_spec', plot_tensor(noisy_spec.squeeze().cpu().detach().numpy(), "noisy_spectrum", config), global_step=step, dataformats='HWC')
35 |
36 |
37 | def resgrad_train(args, config):
38 | os.makedirs(config['train']['log_dir'], exist_ok=True)
39 | os.makedirs(config['train']['save_model_path'], exist_ok=True)
40 |
41 | device = config['main']['device']
42 |
43 | print('Initializing logger...')
44 | logger = SummaryWriter(log_dir=config['train']['log_dir'])
45 | print("Load data...")
46 | train_dataset, val_dataset = create_dataset(config)
47 |
48 | print("Load model...")
49 | model, optimizer = load_model(config, train=True, restore_model_step=args.restore_step)
50 |
51 | scaler = torch.cuda.amp.GradScaler()
52 | # grad_acc_step = config["optimizer"]["grad_acc_step"]
53 | # grad_clip_thresh = config["optimizer"]["grad_clip_thresh"]
54 |
55 | step = args.restore_step - 1
56 | epoch = args.restore_step // (len(train_dataset)//config['data']['batch_size'] + 1)
57 | avg_val_loss = 0
58 | avg_train_loss = 0
59 |
60 | print("Start training...")
61 | outer_bar = tqdm(total=config['train']['total_steps'], desc="Total Training", position=0)
62 | outer_bar.n = step
63 | outer_bar.update()
64 |
65 | while True:
66 | inner_bar = tqdm(total=len(train_dataset), desc="Epoch {}".format(epoch), position=1)
67 | train_loss_list = []
68 | epoch += 1
69 | for train_data in train_dataset:
70 | step += 1
71 | inner_bar.update(1)
72 | outer_bar.update(1)
73 | if config['model']['model_type1'] == "spec2residual":
74 | synthesized_spec, original_spec, residual_spec, mask, speakers = train_data
75 | synthesized_spec = synthesized_spec.to(device)
76 | mask = mask.to(device)
77 | residual_spec = residual_spec.to(device)
78 | if config['main']['multi_speaker']:
79 | speakers = speakers.to(device)
80 | loss, pred = model.compute_loss(residual_spec, mask, synthesized_spec, speakers)
81 | else:
82 | loss, pred = model.compute_loss(residual_spec, mask, synthesized_spec)
83 |
84 | else:
85 | synthesized_spec, original_spec, mask, speakers = train_data
86 | mask = mask.to(device)
87 | synthesized_spec = synthesized_spec.to(device)
88 | original_spec = original_spec.to(device)
89 | if config['main']['multi_speaker']:
90 | speakers = speakers.to(device)
91 | loss, pred = model.compute_loss(original_spec, mask, synthesized_spec, speakers)
92 | else:
93 | loss, pred = model.compute_loss(original_spec, mask, synthesized_spec)
94 |
95 | optimizer.zero_grad()
96 | scaler.scale(loss).backward()
97 | scaler.step(optimizer)
98 | scaler.update()
99 | train_loss_list.append(loss.item())
100 |
101 | # loss.backward()
102 | # train_loss_list.append(loss.item())
103 | # if step % grad_acc_step == 0:
104 | # # Clipping gradients to avoid gradient explosion
105 | # torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
106 | # # Update weights
107 | # optimizer.step_and_update_lr()
108 | # optimizer.zero_grad()
109 |
110 | if step % config['train']['validate_step'] == 0:
111 | model.eval()
112 | with torch.no_grad():
113 | all_val_loss = []
114 | val_num = 0
115 | for val_data in val_dataset:
116 | val_num += 1
117 | if config['model']['model_type1'] == "spec2residual":
118 | synthesized_spec, original_spec, target_residual, mask, speakers = val_data
119 | synthesized_spec = synthesized_spec.to(device)
120 | mask = mask.to(device)
121 | target_residual = target_residual.to(device)
122 | if config['main']['multi_speaker']:
123 | speakers = speakers.to(device)
124 | val_loss, noisy_spec = model.compute_loss(target_residual, mask, synthesized_spec, speakers)
125 | else:
126 | val_loss, noisy_spec = model.compute_loss(target_residual, mask, synthesized_spec)
127 |
128 | else:
129 | synthesized_spec, original_spec, mask, speakers = val_data
130 | synthesized_spec = synthesized_spec.to(device)
131 | mask = mask.to(device)
132 | original_spec = original_spec.to(device)
133 | if config['main']['multi_speaker']:
134 | speakers = speakers.to(device)
135 | val_loss, noisy_spec = model.compute_loss(original_spec, mask, synthesized_spec, speakers)
136 | else:
137 | val_loss, noisy_spec = model.compute_loss(original_spec, mask, synthesized_spec)
138 | target_residual = [None for _ in range(len(original_spec))]
139 | all_val_loss.append(val_loss.item())
140 |
141 |                         ## log sample spectrograms
142 | if val_num == 1:
143 | z = synthesized_spec + torch.randn_like(synthesized_spec, device=device) / 1.5
144 | # Generate sample by performing reverse dynamics
145 | if config['main']['multi_speaker']:
146 | pred = model(z, mask, synthesized_spec, n_timesteps=50, stoc=False, spk_id=speakers)
147 | else:
148 | pred = model(z, mask, synthesized_spec, n_timesteps=50, stoc=False, spk_id=None)
149 | for i in range(3):
150 | logging(logger, config, original_spec[i], synthesized_spec[i], target_residual[i], noisy_spec[i], pred[i], mask[i], \
151 | f'image{i}_step{step}', step)
152 |
153 | avg_val_loss = sum(all_val_loss) / len(all_val_loss)
154 | logger.add_scalar('validation/loss', avg_val_loss, global_step=step)
155 | avg_train_loss = sum(train_loss_list) / len(train_loss_list)
156 | logger.add_scalar('training/loss', avg_train_loss, global_step=step)
157 | train_loss_list = []
158 | model.train()
159 |
160 | inner_bar.set_postfix(train_loss=avg_train_loss, val_loss=avg_val_loss)
161 |
162 | if step % config['train']['save_ckpt_step'] == 0:
163 | ## Save checkpoints
164 | torch.save(
165 | {
166 | "model": model.state_dict(),
167 | "optimizer": optimizer.state_dict()
168 | # "optimizer": optimizer._optimizer.state_dict(),
169 | },
170 | os.path.join(config['train']['save_model_path'], f'ResGrad_step{step}.pth')
171 | )
172 | # torch.save(model.state_dict(), os.path.join(config['train']['save_model_path'], f'ResGrad_step{step}.pth'))
173 | # torch.save(optimizer.state_dict(), os.path.join(config['train']['save_model_path'], 'optimizer.pth'))
174 |
175 | if step > config['train']['total_steps']:
176 | quit()
177 |
178 |
179 |
180 |
--------------------------------------------------------------------------------
/resgrad/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import os
5 | import json
6 | from .model import Diffusion
7 | from .model.optimizer import ScheduledOptim
8 |
9 | def save_figure_to_numpy(fig):
10 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
11 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
12 | return data
13 |
14 | def plot_tensor(tensor, tensor_name, config):
15 | plt.style.use('default')
16 | fig, ax = plt.subplots(figsize=(8, 3))
17 | if tensor_name == "spectrum" and config['data']['normallize_spectrum']:
18 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none', \
19 | vmin = 0.2, vmax = 0.9)
20 | elif tensor_name == "residual" and config['data']['normallize_residual']:
21 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none', \
22 | vmin = 0.2, vmax = 0.9)
23 | else:
24 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
25 | plt.colorbar(im, ax=ax)
26 | plt.tight_layout()
27 | fig.canvas.draw()
28 | data = save_figure_to_numpy(fig)
29 | plt.close()
30 | return data
31 |
32 | def plot_spectrum(spec, path):
33 | plt.figure(figsize=(10, 3))
34 | im = plt.imshow(spec, aspect="auto", origin="lower", interpolation='none')
35 | plt.colorbar(im)
36 | plt.tight_layout()
37 | plt.savefig(path)
38 | plt.close()
39 |
40 | def crop_masked_values(mat_list, length):
41 | new_mat_list = []
42 | for mat in mat_list:
43 | new_mat_list.append(mat[:,:length])
44 | return new_mat_list
45 |
46 | def normalize_data(data, config):
47 | if config['data']['normalized_method'] == "min-max":
48 | data = (data - config['data']['min_spec_value'])/(config['data']['max_spec_value'] - config['data']['min_spec_value'])
49 | else:
50 | print("normalized method is not supported!")
51 | return data
52 |
53 | def normalize_residual(residual_spec, config):
54 | if config['data']['normalized_method'] == "min-max":
55 | residual_spec = (residual_spec - config['data']['min_residual_value'])/(config['data']['max_residual_value'] - \
56 | config['data']['min_residual_value'])
57 | else:
58 | print("normalized method is not supported!")
59 | return residual_spec
60 |
61 | def denormalize_data(data, config):
62 | if config['data']['normalized_method'] == "min-max":
63 | data = data * (config['data']['max_spec_value'] - config['data']['min_spec_value']) + config['data']['min_spec_value']
64 | else:
65 | print("normalized method is not supported!")
66 | return data
67 |
68 | def denormalize_residual(residual_spec, config):
69 | if config['data']['normalized_method'] == "min-max":
70 | residual_spec = residual_spec * (config['data']['max_residual_value'] - config['data']['min_residual_value']) + \
71 | config['data']['min_residual_value']
72 | else:
73 | print("normalized method is not supported!")
74 | return residual_spec
75 |
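# Worked example (illustrative config values): with min_spec_value = -12 and max_spec_value = 3,
# normalize_data maps a mel value of -4.5 to (-4.5 - (-12)) / (3 - (-12)) = 0.5, and
# denormalize_data(0.5, config) maps it back to -4.5. The residual variants work the same way
# with min_residual_value / max_residual_value.
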
76 | def load_model(config, device, train=False, restore_model_step=0):
77 | with open(config['data']['speaker_map_path']) as f:
78 | speaker_map = json.load(f)
79 | n_spks = len(speaker_map.keys())
80 | model = Diffusion(n_feats=config['model']['n_feats'], dim=config['model']['dim'], n_spks=n_spks, \
81 | spk_emb_dim=config['model']['spk_emb_dim'], beta_min=config['model']['beta_min'], \
82 | beta_max=config['model']['beta_max'], pe_scale=config['model']['pe_scale'])
83 | model = model.to(device)
84 | if restore_model_step > 0:
85 | checkpoint = torch.load(os.path.join(config['train']['save_model_path'], f'ResGrad_step{restore_model_step}.pth'), \
86 | map_location=device)
87 | # checkpoint = torch.load(os.path.join("/mnt/hdd1/adibian/FastSpeech2/ResGrad/output/Persian/dur_taget_pitch_target/resgrad/ckpt", \
88 | # f'ResGrad_epoch{restore_model_epoch}.pth'), \
89 | # map_location=device)
90 | model.load_state_dict(checkpoint['model'])
91 |
92 | if train:
93 | # scheduled_optim = ScheduledOptim(model, config, restore_model_step)
94 | optimizer = torch.optim.Adam(params=model.parameters(), lr=config['train']['lr'])
95 | if restore_model_step > 0:
96 |             # The training loop stores the optimizer state inside the same checkpoint file as the
97 |             # model, so restore it from there instead of a separate 'optimizer.pth' that is never written.
98 |             optimizer.load_state_dict(checkpoint['optimizer'])
99 | model.train()
100 | return model, optimizer
101 |
102 | model.eval()
103 | return model
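
# Rough usage sketch (arguments are illustrative):
#   model, optimizer = load_model(config, device, train=True, restore_model_step=0)   # for training
#   model = load_model(config, device)                                                # for inference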
--------------------------------------------------------------------------------
/resgrad_data.py:
--------------------------------------------------------------------------------
1 | from utils import load_models, load_yaml_file
2 | from synthesizer.synthesize import infer as synthesizer_infer
3 |
4 | import argparse
5 | import os
6 | import csv
7 | from tqdm import tqdm
8 | import numpy as np
9 | import torch
10 | import json
11 |
12 | def read_input_data(data_file_path):
13 | input_texts = {}
14 | with open(data_file_path, mode='r') as f:
15 | lines = f.readlines()
16 | for line in lines:
17 | fields = line.split("|")
18 | file_name, speaker, input_text = fields[0], fields[1], fields[2]
19 | input_texts[(speaker, file_name)] = input_text.strip()
20 | return input_texts
21 |
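# Each metadata line is expected to look like "file_name|speaker|text", e.g. (illustrative):
#   utt_0001|speaker_01|some transcript text
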
22 | def main():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--synthesizer_restore_step", type=int, required=True)
25 | parser.add_argument("--data_file_path", type=str, default="dataset/Persian/synthesizer_data/train.txt", required=False)
26 | parser.add_argument("--config", type=str, default='config/Persian/config.yaml', required=False, help="path to config.yaml")
27 | args = parser.parse_args()
28 |
29 | # Read Config
30 | config = load_yaml_file(args.config)
31 |
32 | restore_steps = {"synthesizer":args.synthesizer_restore_step, "resgrad":None, "vocoder":None}
33 | synthesizer_model, _, _ = load_models(restore_steps, config)
34 |     print("Loading input data...")
35 |     text_data = read_input_data(args.data_file_path)
36 |     print("{} input texts loaded.".format(len(text_data)))
37 |
38 | with open(os.path.join(config['synthesizer']['preprocess']['path']['preprocessed_path'], "speakers.json")) as f:
39 | speaker_map = json.load(f)
40 |
41 |     mel_pred_dir = config['resgrad']['data']['input_mel_dir']
42 | os.makedirs(mel_pred_dir, exist_ok=True)
43 |
44 | resgrad_data = []
45 | device = config['main']['device']
46 | i = 0
47 | for (speaker, file_name), text in tqdm(text_data.items()):
48 | # i +=1
49 | # if i>30:
50 | # break
51 | dur_file_name = speaker + "-duration-" + file_name + ".npy"
52 | pitch_file_name = speaker + "-pitch-" + file_name + ".npy"
53 | dur_path = os.path.join(config['synthesizer']['preprocess']['path']['preprocessed_path'], 'duration', dur_file_name)
54 | pitch_path = os.path.join(config['synthesizer']['preprocess']['path']['preprocessed_path'], 'pitch', pitch_file_name)
55 |
56 | if not os.path.exists(dur_path):
57 | continue
58 |
59 | mel_target_file_name = speaker + "-mel-" + file_name + ".npy"
60 | mel_target_path = os.path.join(config['synthesizer']['preprocess']['path']['preprocessed_path'], 'mel', mel_target_file_name)
61 |
62 |         ### Synthesize mel-spectrogram and save it as training data for ResGrad
63 | dur_target = torch.from_numpy(np.load(dur_path)).to(device).unsqueeze(0)
64 | pitch_target = torch.from_numpy(np.load(pitch_path)).to(device).unsqueeze(0)
65 |         control_values = (1.0, 1.0, 1.0)
66 | mel_prediction, _, _, _ = synthesizer_infer(synthesizer_model, text, control_values, config['synthesizer']['preprocess'],
67 | device, speaker=speaker_map[speaker], d_target=dur_target, p_target=pitch_target)
68 |
69 | mel_pred_path = os.path.join(mel_pred_dir, speaker + "-pred_mel-" + file_name + ".npy")
70 | np.save(mel_pred_path, mel_prediction.cpu())
71 |
72 | resgrad_data.append({'speaker': speaker, 'predicted_mel':mel_pred_path, 'target_mel':mel_target_path, 'duration':dur_path})
73 |
74 | with open(config['resgrad']['data']['metadata_path'], 'w') as file:
75 | fields = ['speaker', 'predicted_mel', 'target_mel', 'duration']
76 | writer = csv.DictWriter(file, fieldnames = fields)
77 | writer.writeheader()
78 | writer.writerows(resgrad_data)
79 |
80 |
81 | if __name__ == "__main__":
82 | main()
83 | # python resgrad_data.py --synthesizer_restore_step 240
84 |
--------------------------------------------------------------------------------
/synthesizer/audio/__init__.py:
--------------------------------------------------------------------------------
1 | import audio.tools
2 | import audio.stft
3 | import audio.audio_processing
4 |
--------------------------------------------------------------------------------
/synthesizer/audio/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import librosa.util as librosa_util
4 | from scipy.signal import get_window
5 |
6 |
7 | def window_sumsquare(
8 | window,
9 | n_frames,
10 | hop_length,
11 | win_length,
12 | n_fft,
13 | dtype=np.float32,
14 | norm=None,
15 | ):
16 | """
17 | # from librosa 0.6
18 | Compute the sum-square envelope of a window function at a given hop length.
19 |
20 | This is used to estimate modulation effects induced by windowing
21 |     observations in short-time Fourier transforms.
22 |
23 | Parameters
24 | ----------
25 | window : string, tuple, number, callable, or list-like
26 | Window specification, as in `get_window`
27 |
28 | n_frames : int > 0
29 | The number of analysis frames
30 |
31 | hop_length : int > 0
32 | The number of samples to advance between frames
33 |
34 | win_length : [optional]
35 | The length of the window function. By default, this matches `n_fft`.
36 |
37 | n_fft : int > 0
38 | The length of each analysis frame.
39 |
40 | dtype : np.dtype
41 | The data type of the output
42 |
43 | Returns
44 | -------
45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46 | The sum-squared envelope of the window function
47 | """
48 | if win_length is None:
49 | win_length = n_fft
50 |
51 | n = n_fft + hop_length * (n_frames - 1)
52 | x = np.zeros(n, dtype=dtype)
53 |
54 | # Compute the squared window at the desired length
55 | win_sq = get_window(window, win_length, fftbins=True)
56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57 |     win_sq = librosa_util.pad_center(win_sq, size=n_fft)
58 |
59 | # Fill the envelope
60 | for i in range(n_frames):
61 | sample = i * hop_length
62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63 | return x
64 |
65 |
66 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
67 | """
68 | PARAMS
69 | ------
70 | magnitudes: spectrogram magnitudes
71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72 | """
73 |
74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75 | angles = angles.astype(np.float32)
76 | angles = torch.autograd.Variable(torch.from_numpy(angles))
77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78 |
79 | for i in range(n_iters):
80 | _, angles = stft_fn.transform(signal)
81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82 | return signal
83 |
84 |
85 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
86 | """
87 | PARAMS
88 | ------
89 | C: compression factor
90 | """
91 | return torch.log(torch.clamp(x, min=clip_val) * C)
92 |
93 |
94 | def dynamic_range_decompression(x, C=1):
95 | """
96 | PARAMS
97 | ------
98 | C: compression factor used to compress
99 | """
100 | return torch.exp(x) / C
101 |
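# Note: dynamic_range_compression is a log-compression of the magnitudes; for any x >= clip_val
# (and C = 1), dynamic_range_decompression(dynamic_range_compression(x)) recovers x, since
# exp(log(x)) = x. Values below clip_val are clamped first, so they are not recovered exactly.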
--------------------------------------------------------------------------------
/synthesizer/audio/stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy.signal import get_window
5 | from librosa.util import pad_center, tiny
6 | from librosa.filters import mel as librosa_mel_fn
7 |
8 | from audio.audio_processing import (
9 | dynamic_range_compression,
10 | dynamic_range_decompression,
11 | window_sumsquare,
12 | )
13 |
14 |
15 | class STFT(torch.nn.Module):
16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17 |
18 | def __init__(self, filter_length, hop_length, win_length, window="hann"):
19 | super(STFT, self).__init__()
20 | self.filter_length = filter_length
21 | self.hop_length = hop_length
22 | self.win_length = win_length
23 | self.window = window
24 | self.forward_transform = None
25 | scale = self.filter_length / self.hop_length
26 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
27 |
28 | cutoff = int((self.filter_length / 2 + 1))
29 | fourier_basis = np.vstack(
30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31 | )
32 |
33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34 | inverse_basis = torch.FloatTensor(
35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36 | )
37 |
38 | if window is not None:
39 | assert filter_length >= win_length
40 | # get window and zero center pad it to filter_length
41 | fft_window = get_window(window, win_length, fftbins=True)
42 |             fft_window = pad_center(fft_window, size=filter_length)
43 | fft_window = torch.from_numpy(fft_window).float()
44 |
45 | # window the bases
46 | forward_basis *= fft_window
47 | inverse_basis *= fft_window
48 |
49 | self.register_buffer("forward_basis", forward_basis.float())
50 | self.register_buffer("inverse_basis", inverse_basis.float())
51 |
52 | def transform(self, input_data):
53 | num_batches = input_data.size(0)
54 | num_samples = input_data.size(1)
55 |
56 | self.num_samples = num_samples
57 |
58 | # similar to librosa, reflect-pad the input
59 | input_data = input_data.view(num_batches, 1, num_samples)
60 | input_data = F.pad(
61 | input_data.unsqueeze(1),
62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63 | mode="reflect",
64 | )
65 | input_data = input_data.squeeze(1)
66 |
67 |         forward_transform = F.conv1d(
68 |             input_data,                 # keep the computation on the input's device instead of forcing CUDA
69 |             self.forward_basis,         # registered buffer; no Variable wrapper or .cuda() needed
70 |             stride=self.hop_length,
71 |             padding=0,
72 |         )
73 |
74 | cutoff = int((self.filter_length / 2) + 1)
75 | real_part = forward_transform[:, :cutoff, :]
76 | imag_part = forward_transform[:, cutoff:, :]
77 |
78 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80 |
81 | return magnitude, phase
82 |
83 | def inverse(self, magnitude, phase):
84 | recombine_magnitude_phase = torch.cat(
85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86 | )
87 |
88 | inverse_transform = F.conv_transpose1d(
89 | recombine_magnitude_phase,
90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91 | stride=self.hop_length,
92 | padding=0,
93 | )
94 |
95 | if self.window is not None:
96 | window_sum = window_sumsquare(
97 | self.window,
98 | magnitude.size(-1),
99 | hop_length=self.hop_length,
100 | win_length=self.win_length,
101 | n_fft=self.filter_length,
102 | dtype=np.float32,
103 | )
104 | # remove modulation effects
105 | approx_nonzero_indices = torch.from_numpy(
106 | np.where(window_sum > tiny(window_sum))[0]
107 | )
108 | window_sum = torch.autograd.Variable(
109 | torch.from_numpy(window_sum), requires_grad=False
110 | )
111 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113 | approx_nonzero_indices
114 | ]
115 |
116 | # scale by hop ratio
117 | inverse_transform *= float(self.filter_length) / self.hop_length
118 |
119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120 |         inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]
121 |
122 | return inverse_transform
123 |
124 | def forward(self, input_data):
125 | self.magnitude, self.phase = self.transform(input_data)
126 | reconstruction = self.inverse(self.magnitude, self.phase)
127 | return reconstruction
128 |
129 |
130 | class TacotronSTFT(torch.nn.Module):
131 | def __init__(
132 | self,
133 | filter_length,
134 | hop_length,
135 | win_length,
136 | n_mel_channels,
137 | sampling_rate,
138 | mel_fmin,
139 | mel_fmax,
140 | ):
141 | super(TacotronSTFT, self).__init__()
142 | self.n_mel_channels = n_mel_channels
143 | self.sampling_rate = sampling_rate
144 | self.stft_fn = STFT(filter_length, hop_length, win_length)
145 |         mel_basis = librosa_mel_fn(
146 |             sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
147 |         )
148 | mel_basis = torch.from_numpy(mel_basis).float()
149 | self.register_buffer("mel_basis", mel_basis)
150 |
151 | def spectral_normalize(self, magnitudes):
152 | output = dynamic_range_compression(magnitudes)
153 | return output
154 |
155 | def spectral_de_normalize(self, magnitudes):
156 | output = dynamic_range_decompression(magnitudes)
157 | return output
158 |
159 | def mel_spectrogram(self, y):
160 | """Computes mel-spectrograms from a batch of waves
161 | PARAMS
162 | ------
163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164 |
165 | RETURNS
166 | -------
167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168 | """
169 | assert torch.min(y.data) >= -1
170 | assert torch.max(y.data) <= 1
171 |
172 | magnitudes, phases = self.stft_fn.transform(y)
173 | magnitudes = magnitudes.data
174 | mel_output = torch.matmul(self.mel_basis, magnitudes)
175 | mel_output = self.spectral_normalize(mel_output)
176 | energy = torch.norm(magnitudes, dim=1)
177 |
178 | return mel_output, energy
179 |
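# Rough usage sketch (parameter values are illustrative, not the project defaults):
#   stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024,
#                       n_mel_channels=80, sampling_rate=22050, mel_fmin=0, mel_fmax=8000)
#   mel, energy = stft.mel_spectrogram(wav_batch)   # wav_batch: (B, T) floats in [-1, 1]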
--------------------------------------------------------------------------------
/synthesizer/audio/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.io.wavfile import write
4 | from librosa.filters import mel as librosa_mel_fn
5 |
6 | from audio.audio_processing import griffin_lim
7 |
8 |
9 | def get_mel_from_wav(audio, _stft):
10 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
11 | audio = torch.autograd.Variable(audio, requires_grad=False)
12 | melspec, energy = _stft.mel_spectrogram(audio)
13 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
14 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
15 |
16 | return melspec, energy
17 |
18 |
19 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
20 | mel = torch.stack([mel])
21 | mel_decompress = _stft.spectral_de_normalize(mel)
22 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
23 | spec_from_mel_scaling = 1000
24 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
25 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
26 | spec_from_mel = spec_from_mel * spec_from_mel_scaling
27 |
28 | audio = griffin_lim(
29 |         torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters
30 | )
31 |
32 | audio = audio.squeeze()
33 | audio = audio.cpu().numpy()
34 | audio_path = out_filename
35 | write(audio_path, _stft.sampling_rate, audio)
36 |
37 |
38 | mel_basis = {}
39 | hann_window = {}
40 |
41 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
42 | return torch.log(torch.clamp(x, min=clip_val) * C)
43 | def spectral_normalize_torch(magnitudes):
44 | output = dynamic_range_compression_torch(magnitudes)
45 | return output
46 |
47 | def get_mel_from_wav_as_hifigan(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
48 | if torch.min(y) < -1.:
49 | print('min value is ', torch.min(y))
50 | if torch.max(y) > 1.:
51 | print('max value is ', torch.max(y))
52 |
53 | global mel_basis, hann_window
54 |     if str(fmax)+'_'+str(y.device) not in mel_basis:  # key must match the one used to store the basis below
55 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
56 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
57 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
58 |
59 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
60 | y = y.squeeze(1)
61 |     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
62 |                       center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
63 | 
64 |     magnitudes = torch.sqrt(spec.abs().pow(2) + 1e-9)
65 |
66 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], magnitudes)
67 | spec = spectral_normalize_torch(spec)
68 |
69 | energy = torch.norm(magnitudes, dim=1)
70 | return spec.squeeze(0).numpy(), energy.squeeze(0).numpy()
71 |
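# Rough usage sketch: preprocessor.py calls this on a (1, T) float tensor in [-1, 1] and gets back
# a mel-spectrogram of shape (num_mels, frames) and a per-frame energy of shape (frames,), both as
# numpy arrays; the STFT parameters come from the preprocessing section of the config.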
72 |
--------------------------------------------------------------------------------
/synthesizer/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 | from .text import text_to_sequence
7 | from .utils.tools import pad_1D, pad_2D
8 |
9 |
10 | class Dataset(Dataset):
11 | def __init__(
12 | self, filename, preprocess_config, train_config, sort=False, drop_last=False
13 | ):
14 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
15 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
16 | self.batch_size = train_config["optimizer"]["batch_size"]
17 |
18 | self.basename, self.speaker, self.text = self.process_meta(
19 | filename
20 | )
21 | with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
22 | self.speaker_map = json.load(f)
23 | self.sort = sort
24 | self.drop_last = drop_last
25 |
26 | def __len__(self):
27 | return len(self.text)
28 |
29 | def __getitem__(self, idx):
30 | basename = self.basename[idx]
31 | speaker = self.speaker[idx]
32 | speaker_id = self.speaker_map[speaker]
33 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
34 | mel_path = os.path.join(
35 | self.preprocessed_path,
36 | "mel",
37 | "{}-mel-{}.npy".format(speaker, basename),
38 | )
39 | mel = np.load(mel_path)
40 | pitch_path = os.path.join(
41 | self.preprocessed_path,
42 | "pitch",
43 | "{}-pitch-{}.npy".format(speaker, basename),
44 | )
45 | pitch = np.load(pitch_path)
46 | energy_path = os.path.join(
47 | self.preprocessed_path,
48 | "energy",
49 | "{}-energy-{}.npy".format(speaker, basename),
50 | )
51 | energy = np.load(energy_path)
52 | duration_path = os.path.join(
53 | self.preprocessed_path,
54 | "duration",
55 | "{}-duration-{}.npy".format(speaker, basename),
56 | )
57 | duration = np.load(duration_path)
58 |
59 | sample = {
60 | "id": basename,
61 | "speaker": speaker_id,
62 | "text": phone,
63 | "mel": mel,
64 | "pitch": pitch,
65 | "energy": energy,
66 | "duration": duration,
67 | }
68 |
69 | return sample
70 |
71 | def process_meta(self, filename):
72 | with open(
73 | os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8"
74 | ) as f:
75 | name = []
76 | speaker = []
77 | text = []
78 | for line in f.readlines():
79 | n, s, t, r = line.strip("\n").split("|")
80 | name.append(n)
81 | speaker.append(s)
82 | text.append(t)
83 | return name, speaker, text
84 |
85 | def reprocess(self, data, idxs):
86 | ids = [data[idx]["id"] for idx in idxs]
87 | speakers = [data[idx]["speaker"] for idx in idxs]
88 | texts = [data[idx]["text"] for idx in idxs]
89 | mels = [data[idx]["mel"] for idx in idxs]
90 | pitches = [data[idx]["pitch"] for idx in idxs]
91 | energies = [data[idx]["energy"] for idx in idxs]
92 | durations = [data[idx]["duration"] for idx in idxs]
93 |
94 | text_lens = np.array([text.shape[0] for text in texts])
95 | mel_lens = np.array([mel.shape[0] for mel in mels])
96 |
97 | speakers = np.array(speakers)
98 | texts = pad_1D(texts)
99 | mels = pad_2D(mels)
100 | pitches = pad_1D(pitches)
101 | energies = pad_1D(energies)
102 | durations = pad_1D(durations)
103 |
104 | return (
105 | ids,
106 | speakers,
107 | texts,
108 | text_lens,
109 | max(text_lens),
110 | mels,
111 | mel_lens,
112 | max(mel_lens),
113 | pitches,
114 | energies,
115 | durations,
116 | )
117 |
118 | def collate_fn(self, data):
119 | data_size = len(data)
120 |
121 | if self.sort:
122 | len_arr = np.array([d["text"].shape[0] for d in data])
123 | idx_arr = np.argsort(-len_arr)
124 | else:
125 | idx_arr = np.arange(data_size)
126 |
127 | tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :]
128 | idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]
129 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
130 | if not self.drop_last and len(tail) > 0:
131 | idx_arr += [tail.tolist()]
132 |
133 | output = list()
134 | for idx in idx_arr:
135 | output.append(self.reprocess(data, idx))
136 |
137 | return output
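
# Usage sketch (see evaluate.py): the Dataset is wrapped in a DataLoader with collate_fn=dataset.collate_fn,
# and each item yielded by the loader is itself a list of batches of size
# train_config["optimizer"]["batch_size"], hence the nested "for batchs in loader: for batch in batchs" loops.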
--------------------------------------------------------------------------------
/synthesizer/evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | from .utils.tools import to_device, log, synth_one_sample
5 | from .model import FastSpeech2Loss
6 | from .dataset import Dataset
7 |
8 |
9 | def evaluate(model, step, config, logger=None, vocoder=None):
10 | preprocess_config, model_config, train_config = config['synthesizer']['preprocess'], config['synthesizer']['model'], config['synthesizer']['train']
11 | device = config['main']['device']
12 | # Get dataset
13 | dataset = Dataset(
14 | "val.txt", preprocess_config, train_config, sort=False, drop_last=False
15 | )
16 | batch_size = train_config["optimizer"]["batch_size"]
17 | loader = DataLoader(
18 | dataset,
19 | batch_size=batch_size,
20 | shuffle=False,
21 | collate_fn=dataset.collate_fn,
22 | )
23 |
24 | # Get loss function
25 | Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
26 |
27 | # Evaluation
28 | loss_sums = [0 for _ in range(6)]
29 | for batchs in loader:
30 | for batch in batchs:
31 | batch = to_device(batch, device)
32 | with torch.no_grad():
33 | # Forward
34 | output = model(*(batch[1:]))
35 |
36 |                 # Compute the losses
37 | losses = Loss(batch, output)
38 |
39 | for i in range(len(losses)):
40 | loss_sums[i] += losses[i].item() * len(batch[0])
41 |
42 | loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums]
43 |
44 | message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
45 | *([step] + [l for l in loss_means])
46 | )
47 |
48 | if logger is not None:
49 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
50 | batch,
51 | output,
52 | vocoder,
53 | model_config,
54 | preprocess_config,
55 | )
56 |
57 | log(logger, step, losses=loss_means)
58 | log(
59 | logger,
60 | fig=fig,
61 | tag="Validation/step_{}_{}".format(step, tag),
62 | )
63 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
64 | log(
65 | logger,
66 | audio=wav_reconstruction,
67 | sampling_rate=sampling_rate,
68 | tag="Validation/step_{}_{}_reconstructed".format(step, tag),
69 | )
70 | log(
71 | logger,
72 | audio=wav_prediction,
73 | sampling_rate=sampling_rate,
74 | tag="Validation/step_{}_{}_synthesized".format(step, tag),
75 | )
76 |
77 | return message
78 |
79 |
80 |
--------------------------------------------------------------------------------
/synthesizer/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .fastspeech2 import FastSpeech2
2 | from .loss import FastSpeech2Loss
3 | from .optimizer import ScheduledOptim
--------------------------------------------------------------------------------
/synthesizer/model/fastspeech2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from ..transformer import Encoder, Decoder, PostNet
8 | from .modules import VarianceAdaptor
9 | from ..utils.tools import get_mask_from_lengths
10 |
11 |
12 | class FastSpeech2(nn.Module):
13 | """ FastSpeech2 """
14 |
15 | def __init__(self, config):
16 | super(FastSpeech2, self).__init__()
17 | preprocess_config, model_config = config['synthesizer']['preprocess'], config['synthesizer']['model']
18 | self.model_config = model_config
19 | self.device = config['main']['device']
20 |
21 | self.encoder = Encoder(model_config, preprocess_config['preprocessing']['text']['language'])
22 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config, self.device)
23 | self.decoder = Decoder(model_config)
24 | self.mel_linear = nn.Linear(
25 | model_config["transformer"]["decoder_hidden"],
26 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
27 | )
28 | self.postnet = PostNet()
29 |
30 | self.speaker_emb = None
31 | if config['main']['multi_speaker']:
32 | with open(
33 | os.path.join(
34 | preprocess_config["path"]["preprocessed_path"], "speakers.json"
35 | ),
36 | "r",
37 | ) as f:
38 | n_speaker = len(json.load(f))
39 | self.speaker_emb = nn.Embedding(
40 | n_speaker,
41 | model_config["transformer"]["encoder_hidden"],
42 | )
43 |
44 | def forward(
45 | self,
46 | speakers,
47 | texts,
48 | src_lens,
49 | max_src_len,
50 | mels=None,
51 | mel_lens=None,
52 | max_mel_len=None,
53 | p_targets=None,
54 | e_targets=None,
55 | d_targets=None,
56 | p_control=1.0,
57 | e_control=1.0,
58 | d_control=1.0,
59 | ):
60 | src_masks = get_mask_from_lengths(src_lens, max_src_len, self.device)
61 | mel_masks = (
62 | get_mask_from_lengths(mel_lens, max_mel_len, self.device)
63 | if mel_lens is not None
64 | else None
65 | )
66 | output = self.encoder(texts, src_masks)
67 |
68 | if self.speaker_emb is not None:
69 | self.speaker_vec = self.speaker_emb(speakers).unsqueeze(1)
70 | output = output + self.speaker_vec.expand(-1, max_src_len, -1)
71 |
72 |
73 | (
74 | output,
75 | p_predictions,
76 | e_predictions,
77 | log_d_predictions,
78 | d_rounded,
79 | mel_lens,
80 | mel_masks,
81 | ) = self.variance_adaptor(
82 | output,
83 | src_masks,
84 | mel_masks,
85 | max_mel_len,
86 | p_targets,
87 | e_targets,
88 | d_targets,
89 | p_control,
90 | e_control,
91 | d_control,
92 | )
93 |
94 | # if self.speaker_emb is not None:
95 | # output = output + self.speaker_vec.expand(-1, output.shape[1], -1)
96 |
97 | output, mel_masks = self.decoder(output, mel_masks)
98 | output = self.mel_linear(output)
99 | postnet_output = self.postnet(output) + output
100 |
101 | return (
102 | output,
103 | postnet_output,
104 | p_predictions,
105 | e_predictions,
106 | log_d_predictions,
107 | d_rounded,
108 | src_masks,
109 | mel_masks,
110 | src_lens,
111 | mel_lens,
112 | )
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/synthesizer/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class FastSpeech2Loss(nn.Module):
6 | """ FastSpeech2 Loss """
7 |
8 | def __init__(self, preprocess_config, model_config):
9 | super(FastSpeech2Loss, self).__init__()
10 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
11 | "feature"
12 | ]
13 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
14 | "feature"
15 | ]
16 | self.mse_loss = nn.MSELoss()
17 | self.mae_loss = nn.L1Loss()
18 |
19 | def forward(self, inputs, predictions):
20 | (
21 | mel_targets,
22 | _,
23 | _,
24 | pitch_targets,
25 | energy_targets,
26 | duration_targets,
27 | ) = inputs[5:]
28 | (
29 | mel_predictions,
30 | postnet_mel_predictions,
31 | pitch_predictions,
32 | energy_predictions,
33 | log_duration_predictions,
34 | _,
35 | src_masks,
36 | mel_masks,
37 | _,
38 | _,
39 | ) = predictions
40 | src_masks = ~src_masks
41 | mel_masks = ~mel_masks
42 | log_duration_targets = torch.log(duration_targets.float() + 1)
43 | mel_targets = mel_targets[:, : mel_masks.shape[1], :]
44 | mel_masks = mel_masks[:, :mel_masks.shape[1]]
45 |
46 | log_duration_targets.requires_grad = False
47 | pitch_targets.requires_grad = False
48 | energy_targets.requires_grad = False
49 | mel_targets.requires_grad = False
50 |
51 | if self.pitch_feature_level == "phoneme_level":
52 | pitch_predictions = pitch_predictions.masked_select(src_masks)
53 | pitch_targets = pitch_targets.masked_select(src_masks)
54 | elif self.pitch_feature_level == "frame_level":
55 | pitch_predictions = pitch_predictions.masked_select(mel_masks)
56 | pitch_targets = pitch_targets.masked_select(mel_masks)
57 |
58 | if self.energy_feature_level == "phoneme_level":
59 | energy_predictions = energy_predictions.masked_select(src_masks)
60 | energy_targets = energy_targets.masked_select(src_masks)
61 | if self.energy_feature_level == "frame_level":
62 | energy_predictions = energy_predictions.masked_select(mel_masks)
63 | energy_targets = energy_targets.masked_select(mel_masks)
64 |
65 | log_duration_predictions = log_duration_predictions.masked_select(src_masks)
66 | log_duration_targets = log_duration_targets.masked_select(src_masks)
67 |
68 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
69 | postnet_mel_predictions = postnet_mel_predictions.masked_select(
70 | mel_masks.unsqueeze(-1)
71 | )
72 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
73 |
74 | mel_loss = self.mae_loss(mel_predictions, mel_targets)
75 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
76 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
77 | energy_loss = self.mse_loss(energy_predictions, energy_targets)
78 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
79 |
80 | total_loss = (
81 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
82 | )
83 |
84 | return (
85 | total_loss,
86 | mel_loss,
87 | postnet_mel_loss,
88 | pitch_loss,
89 | energy_loss,
90 | duration_loss,
91 | )
92 |
--------------------------------------------------------------------------------
/synthesizer/model/modules.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from collections import OrderedDict
4 |
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 | from ..utils.tools import get_mask_from_lengths, pad
11 |
12 |
13 | class VarianceAdaptor(nn.Module):
14 | """Variance Adaptor"""
15 |
16 | def __init__(self, preprocess_config, model_config, device):
17 | super(VarianceAdaptor, self).__init__()
18 | self.duration_predictor = VariancePredictor(model_config)
19 | self.length_regulator = LengthRegulator()
20 | self.pitch_predictor = VariancePredictor(model_config)
21 | self.energy_predictor = VariancePredictor(model_config)
22 |
23 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
24 | "feature"
25 | ]
26 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
27 | "feature"
28 | ]
29 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
30 | assert self.energy_feature_level in ["phoneme_level", "frame_level"]
31 |
32 | pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
33 | energy_quantization = model_config["variance_embedding"]["energy_quantization"]
34 | n_bins = model_config["variance_embedding"]["n_bins"]
35 | assert pitch_quantization in ["linear", "log"]
36 | assert energy_quantization in ["linear", "log"]
37 | with open(
38 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
39 | ) as f:
40 | stats = json.load(f)
41 | pitch_min, pitch_max = stats["pitch"][:2]
42 | energy_min, energy_max = stats["energy"][:2]
43 |
44 | if pitch_quantization == "log":
45 | self.pitch_bins = nn.Parameter(
46 | torch.exp(
47 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
48 | ),
49 | requires_grad=False,
50 | )
51 | else:
52 | self.pitch_bins = nn.Parameter(
53 | torch.linspace(pitch_min, pitch_max, n_bins - 1),
54 | requires_grad=False,
55 | )
56 | if energy_quantization == "log":
57 | self.energy_bins = nn.Parameter(
58 | torch.exp(
59 | torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
60 | ),
61 | requires_grad=False,
62 | )
63 | else:
64 | self.energy_bins = nn.Parameter(
65 | torch.linspace(energy_min, energy_max, n_bins - 1),
66 | requires_grad=False,
67 | )
68 |
69 | self.pitch_embedding = nn.Embedding(
70 | n_bins, model_config["transformer"]["encoder_hidden"]
71 | )
72 | self.energy_embedding = nn.Embedding(
73 | n_bins, model_config["transformer"]["encoder_hidden"]
74 | )
75 | self.device = device
76 |
77 | def get_pitch_embedding(self, x, target, mask, control):
78 | prediction = self.pitch_predictor(x, mask)
79 | if target is not None:
80 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
81 | else:
82 | prediction = prediction * control
83 | embedding = self.pitch_embedding(
84 | torch.bucketize(prediction, self.pitch_bins)
85 | )
86 | return prediction, embedding
87 |
88 | def get_energy_embedding(self, x, target, mask, control):
89 | prediction = self.energy_predictor(x, mask)
90 | if target is not None:
91 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
92 | else:
93 | prediction = prediction * control
94 | embedding = self.energy_embedding(
95 | torch.bucketize(prediction, self.energy_bins)
96 | )
97 | return prediction, embedding
98 |
99 | def forward(
100 | self,
101 | x,
102 | src_mask,
103 | mel_mask=None,
104 | max_len=None,
105 | pitch_target=None,
106 | energy_target=None,
107 | duration_target=None,
108 | p_control=1.0,
109 | e_control=1.0,
110 | d_control=1.0,
111 | ):
112 |
113 | log_duration_prediction = self.duration_predictor(x, src_mask)
114 | if self.pitch_feature_level == "phoneme_level":
115 | pitch_prediction, pitch_embedding = self.get_pitch_embedding(
116 | x, pitch_target, src_mask, p_control
117 | )
118 | x = x + pitch_embedding
119 | if self.energy_feature_level == "phoneme_level":
120 | energy_prediction, energy_embedding = self.get_energy_embedding(
121 |                 x, energy_target, src_mask, e_control
122 | )
123 | x = x + energy_embedding
124 |
125 | if duration_target is not None:
126 | x, mel_len = self.length_regulator(x, duration_target, max_len, self.device)
127 | duration_rounded = duration_target
128 | if mel_mask is None:
129 | mel_mask = get_mask_from_lengths(mel_len, device=self.device)
130 |
131 | else:
132 | duration_rounded = torch.clamp(
133 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
134 | min=0,
135 | )
136 | x, mel_len = self.length_regulator(x, duration_rounded, max_len, self.device)
137 | mel_mask = get_mask_from_lengths(mel_len, device=self.device)
138 |
139 | if self.pitch_feature_level == "frame_level":
140 | pitch_prediction, pitch_embedding = self.get_pitch_embedding(
141 | x, pitch_target, mel_mask, p_control
142 | )
143 | x = x + pitch_embedding
144 | if self.energy_feature_level == "frame_level":
145 | energy_prediction, energy_embedding = self.get_energy_embedding(
146 |                 x, energy_target, mel_mask, e_control
147 | )
148 | x = x + energy_embedding
149 |
150 | return (
151 | x,
152 | pitch_prediction,
153 | energy_prediction,
154 | log_duration_prediction,
155 | duration_rounded,
156 | mel_len,
157 | mel_mask,
158 | )
159 |
160 |
161 | class LengthRegulator(nn.Module):
162 | """Length Regulator"""
163 |
164 | def __init__(self):
165 | super(LengthRegulator, self).__init__()
166 |
167 | def LR(self, x, duration, max_len, device):
168 | output = list()
169 | mel_len = list()
170 | for batch, expand_target in zip(x, duration):
171 | expanded = self.expand(batch, expand_target)
172 | output.append(expanded)
173 | mel_len.append(expanded.shape[0])
174 |
175 | if max_len is not None:
176 | output = pad(output, max_len)
177 | else:
178 | output = pad(output)
179 |
180 | return output, torch.LongTensor(mel_len).to(device)
181 |
182 | def expand(self, batch, predicted):
183 | out = list()
184 |
185 | for i, vec in enumerate(batch):
186 | expand_size = predicted[i].item()
187 | out.append(vec.expand(max(int(expand_size), 0), -1))
188 | out = torch.cat(out, 0)
189 |
190 | return out
191 |
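    # Worked example (illustrative): for a batch item with 3 phoneme hidden vectors and
    # durations [2, 1, 3], expand() repeats the vectors 2, 1 and 3 times along the time axis,
    # producing 2 + 1 + 3 = 6 mel-frame positions; LR() then pads all items to max_len.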
192 | def forward(self, x, duration, max_len, device):
193 | output, mel_len = self.LR(x, duration, max_len, device)
194 | return output, mel_len
195 |
196 |
197 | class VariancePredictor(nn.Module):
198 | """Duration, Pitch and Energy Predictor"""
199 |
200 | def __init__(self, model_config):
201 | super(VariancePredictor, self).__init__()
202 |
203 | self.input_size = model_config["transformer"]["encoder_hidden"]
204 | self.filter_size = model_config["variance_predictor"]["filter_size"]
205 | self.kernel = model_config["variance_predictor"]["kernel_size"]
206 | self.conv_output_size = model_config["variance_predictor"]["filter_size"]
207 | self.dropout = model_config["variance_predictor"]["dropout"]
208 |
209 | self.conv_layer = nn.Sequential(
210 | OrderedDict(
211 | [
212 | (
213 | "conv1d_1",
214 | Conv(
215 | self.input_size,
216 | self.filter_size,
217 | kernel_size=self.kernel,
218 | padding=(self.kernel - 1) // 2,
219 | ),
220 | ),
221 | ("relu_1", nn.ReLU()),
222 | ("layer_norm_1", nn.LayerNorm(self.filter_size)),
223 | ("dropout_1", nn.Dropout(self.dropout)),
224 | (
225 | "conv1d_2",
226 | Conv(
227 | self.filter_size,
228 | self.filter_size,
229 | kernel_size=self.kernel,
230 | padding=1,
231 | ),
232 | ),
233 | ("relu_2", nn.ReLU()),
234 | ("layer_norm_2", nn.LayerNorm(self.filter_size)),
235 | ("dropout_2", nn.Dropout(self.dropout)),
236 | ]
237 | )
238 | )
239 |
240 | self.linear_layer = nn.Linear(self.conv_output_size, 1)
241 |
242 | def forward(self, encoder_output, mask):
243 | out = self.conv_layer(encoder_output)
244 | out = self.linear_layer(out)
245 | out = out.squeeze(-1)
246 |
247 | if mask is not None:
248 | out = out.masked_fill(mask, 0.0)
249 |
250 | return out
251 |
252 |
253 | class Conv(nn.Module):
254 | """
255 | Convolution Module
256 | """
257 |
258 | def __init__(
259 | self,
260 | in_channels,
261 | out_channels,
262 | kernel_size=1,
263 | stride=1,
264 | padding=0,
265 | dilation=1,
266 | bias=True,
267 | w_init="linear",
268 | ):
269 | """
270 | :param in_channels: dimension of input
271 | :param out_channels: dimension of output
272 | :param kernel_size: size of kernel
273 | :param stride: size of stride
274 | :param padding: size of padding
275 | :param dilation: dilation rate
276 | :param bias: boolean. if True, bias is included.
277 | :param w_init: str. weight inits with xavier initialization.
278 | """
279 | super(Conv, self).__init__()
280 |
281 | self.conv = nn.Conv1d(
282 | in_channels,
283 | out_channels,
284 | kernel_size=kernel_size,
285 | stride=stride,
286 | padding=padding,
287 | dilation=dilation,
288 | bias=bias,
289 | )
290 |
291 | def forward(self, x):
292 | x = x.contiguous().transpose(1, 2)
293 | x = self.conv(x)
294 | x = x.contiguous().transpose(1, 2)
295 |
296 | return x
297 |
--------------------------------------------------------------------------------
/synthesizer/model/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class ScheduledOptim:
6 | """ A simple wrapper class for learning rate scheduling """
7 |
8 | def __init__(self, model, train_config, model_config, current_step):
9 |
10 | self._optimizer = torch.optim.Adam(
11 | model.parameters(),
12 | betas=train_config["optimizer"]["betas"],
13 | eps=train_config["optimizer"]["eps"],
14 | weight_decay=train_config["optimizer"]["weight_decay"],
15 | )
16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"]
17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"]
18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"]
19 | self.current_step = current_step
20 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5)
21 |
22 | def step_and_update_lr(self):
23 | self._update_learning_rate()
24 | self._optimizer.step()
25 |
26 | def zero_grad(self):
27 | self._optimizer.zero_grad()
28 |
29 |     def load_state_dict(self, state_dict):
30 |         self._optimizer.load_state_dict(state_dict)
31 |
32 | def _get_lr_scale(self):
33 | lr = np.min(
34 | [
35 | np.power(self.current_step, -0.5),
36 | np.power(self.n_warmup_steps, -1.5) * self.current_step,
37 | ]
38 | )
39 | for s in self.anneal_steps:
40 | if self.current_step > s:
41 | lr = lr * self.anneal_rate
42 | return lr
43 |
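    # In effect this is the Transformer/Noam schedule: lr = encoder_hidden^(-0.5) *
    # min(step^(-0.5), step * warm_up_step^(-1.5)), further multiplied by anneal_rate once for
    # every anneal step the current step has already passed.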
44 | def _update_learning_rate(self):
45 | """ Learning rate scheduling per step """
46 | self.current_step += 1
47 | lr = self.init_lr * self._get_lr_scale()
48 |
49 | for param_group in self._optimizer.param_groups:
50 | param_group["lr"] = lr
51 |
--------------------------------------------------------------------------------
/synthesizer/prepare_align.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from preprocessor import persian, ljspeech
4 | from utils.tools import load_yaml_file
5 |
6 |
7 | def main(config):
8 | if "Persian" in config["main"]["dataset"]:
9 | persian.prepare_align(config['synthesizer']['preprocess'])
10 | elif config["main"]["dataset"] == "LJSpeech":
11 | ljspeech.prepare_align(config['synthesizer']['preprocess'])
12 |
13 |
14 | if __name__ == "__main__":
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("config", type=str, help="path to config.yaml")
17 | args = parser.parse_args()
18 |
19 | config = load_yaml_file(args.config)
20 | main(config)
21 |
--------------------------------------------------------------------------------
/synthesizer/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from preprocessor.preprocessor import Preprocessor
4 | from utils.tools import load_yaml_file
5 |
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("config", type=str, help="path to config.yaml")
10 | args = parser.parse_args()
11 |
12 | config = load_yaml_file(args.config)
13 | preprocessor = Preprocessor(config['synthesizer']['preprocess'])
14 | preprocessor.build_from_path()
15 |
--------------------------------------------------------------------------------
/synthesizer/preprocessor/ljspeech.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | from scipy.io import wavfile
6 | from tqdm import tqdm
7 | from librosa.util import normalize
8 |
9 | from text import _clean_text
10 |
11 |
12 | def prepare_align(config):
13 | in_dir = config["path"]["corpus_path"]
14 | out_dir = config["path"]["raw_path"]
15 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
16 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
17 | cleaners = config["preprocessing"]["text"]["text_cleaners"]
18 |
19 | speaker = "single_speaker"
20 | with open(os.path.join(in_dir, "train.txt"), encoding="utf-8") as f:
21 | for line in tqdm(f.readlines()):
22 | parts = line.strip().split("|")
23 | base_name = parts[0]
24 | text = parts[2]
25 | wav_path = os.path.join(in_dir, 'wavs', base_name + '.wav')
26 |
27 | text = _clean_text(text, cleaners)
28 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
29 | wav, _ = librosa.load(wav_path, sr=sampling_rate)
30 | wav = wav / max_wav_value
31 | wav = normalize(wav) * 0.95
32 | wavfile.write(
33 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
34 | sampling_rate,
35 | wav.astype(np.float32),
36 | )
37 | with open(
38 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
39 | "w",
40 | ) as f1:
41 | f1.write(text)
42 |
--------------------------------------------------------------------------------
/synthesizer/preprocessor/persian.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | from scipy.io import wavfile
6 | from tqdm import tqdm
7 | from librosa.util import normalize
8 |
9 | from text import _clean_text
10 |
11 |
12 | def prepare_align(config):
13 | in_dir = config["path"]["corpus_path"]
14 | out_dir = config["path"]["raw_path"]
15 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
16 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
17 | cleaners = config["preprocessing"]["text"]["text_cleaners"]
18 |
19 | with open(os.path.join(in_dir, "train.txt"), encoding="utf-8") as f:
20 | for line in tqdm(f.readlines()):
21 | parts = line.strip().split("|")
22 | base_name = parts[0]
23 | speaker = parts[1]
24 | text = parts[2]
25 | wav_path = os.path.join(in_dir, 'train_data',speaker, base_name + '.wav')
26 |
27 | text = _clean_text(text, cleaners)
28 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
29 | wav, _ = librosa.load(wav_path, sr=sampling_rate)
30 | wav = wav / max_wav_value
31 | wav = normalize(wav) * 0.95
32 | wavfile.write(
33 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
34 | sampling_rate,
35 | wav.astype(np.float32),
36 | )
37 | with open(
38 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
39 | "w",
40 | ) as f1:
41 | f1.write(text)
42 |
--------------------------------------------------------------------------------
/synthesizer/preprocessor/persian_v1.py:
--------------------------------------------------------------------------------
1 | import os
2 | import librosa
3 | import numpy as np
4 | from scipy.io import wavfile
5 | from tqdm import tqdm
6 | from librosa.util import normalize
7 | from scipy.io.wavfile import read
8 |
9 | from text import _clean_text
10 |
11 |
12 | def prepare_align(config):
13 | in_dir = config["path"]["corpus_path"]
14 | out_dir = config["path"]["raw_path"]
15 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
16 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
17 | cleaners = config["preprocessing"]["text"]["text_cleaners"]
18 | speakers = os.listdir(in_dir)
19 | for i, speaker in enumerate(speakers):
20 |         loop = tqdm(os.listdir(os.path.join(in_dir, speaker)))
21 | loop.set_description(f'speaker count = {i+1}/{len(speakers)}')
22 | for file_name in loop:
23 | if file_name[-4:] != ".wav":
24 | continue
25 | base_name = file_name[:-4]
26 | text_path = os.path.join(
27 | in_dir, speaker, "{}.txt".format(base_name)
28 | )
29 | wav_path = os.path.join(
30 | in_dir, speaker, "{}.wav".format(base_name)
31 | )
32 | with open(text_path) as f:
33 | text = f.readline().strip("\n")
34 | text = _clean_text(text, cleaners)
35 |
36 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
37 | wav, _ = librosa.load(wav_path, sr=sampling_rate)
38 | wav = wav / max_wav_value
39 | wav = normalize(wav) * 0.95
40 |
41 | wavfile.write(
42 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
43 | sampling_rate,
44 | wav.astype(np.float32),
45 | )
46 | with open(
47 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
48 | "w",
49 | ) as f1:
50 | f1.write(text)
51 |
--------------------------------------------------------------------------------
/synthesizer/preprocessor/preprocessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 |
5 | import tgt
6 | import librosa
7 | import numpy as np
8 | import pyworld as pw
9 | from scipy.interpolate import interp1d
10 | from sklearn.preprocessing import StandardScaler
11 | from tqdm import tqdm
12 | import torch
13 | from scipy.io.wavfile import read
14 |
15 | import audio as Audio
16 |
17 |
18 | class Preprocessor:
19 | def __init__(self, config):
20 | self.config = config
21 | self.in_dir = config["path"]["raw_path"]
22 | self.out_dir = config["path"]["preprocessed_path"]
23 | self.val_size = config["preprocessing"]["val_size"]
24 | self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
25 | self.hop_length = config["preprocessing"]["stft"]["hop_length"]
26 |
27 | assert config["preprocessing"]["pitch"]["feature"] in [
28 | "phoneme_level",
29 | "frame_level",
30 | ]
31 | assert config["preprocessing"]["energy"]["feature"] in [
32 | "phoneme_level",
33 | "frame_level",
34 | ]
35 | self.pitch_phoneme_averaging = (
36 | config["preprocessing"]["pitch"]["feature"] == "phoneme_level"
37 | )
38 | self.energy_phoneme_averaging = (
39 | config["preprocessing"]["energy"]["feature"] == "phoneme_level"
40 | )
41 |
42 | self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"]
43 | self.energy_normalization = config["preprocessing"]["energy"]["normalization"]
44 |
45 | self.preprocessing_confing = config["preprocessing"]
46 |
47 | def build_from_path(self):
48 | os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True)
49 | os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True)
50 | os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True)
51 | os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True)
52 |
53 | print("Processing Data ...")
54 | out = list()
55 | n_frames = 0
56 | pitch_scaler = StandardScaler()
57 | energy_scaler = StandardScaler()
58 |
59 | # Compute pitch, energy, duration, and mel-spectrogram
60 | speakers = {}
61 | speakers_list = os.listdir(self.in_dir)
62 | for i, speaker in enumerate(speakers_list):
63 | speakers[speaker] = i
64 | loop = tqdm(os.listdir(os.path.join(self.in_dir, speaker)))
65 | loop.set_description(f'speaker count = {i+1}/{len(speakers_list)}')
66 | for wav_name in loop:
67 | if ".wav" not in wav_name:
68 | continue
69 |
70 | basename = wav_name.split(".")[0]
71 | tg_path = os.path.join(
72 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
73 | )
74 | if os.path.exists(tg_path):
75 | ret = self.process_utterance(speaker, basename)
76 | if ret is None:
77 | continue
78 | else:
79 | info, pitch, energy, n = ret
80 | out.append(info)
81 |
82 | if len(pitch) > 0:
83 | pitch_scaler.partial_fit(pitch.reshape((-1, 1)))
84 | if len(energy) > 0:
85 | energy_scaler.partial_fit(energy.reshape((-1, 1)))
86 |
87 | n_frames += n
88 |
89 | print("Computing statistic quantities ...")
90 | # Perform normalization if necessary
91 | if self.pitch_normalization:
92 | pitch_mean = pitch_scaler.mean_[0]
93 | pitch_std = pitch_scaler.scale_[0]
94 | else:
95 | # A numerical trick to avoid normalization...
96 | pitch_mean = 0
97 | pitch_std = 1
98 | if self.energy_normalization:
99 | energy_mean = energy_scaler.mean_[0]
100 | energy_std = energy_scaler.scale_[0]
101 | else:
102 | energy_mean = 0
103 | energy_std = 1
104 |
105 | pitch_min, pitch_max = self.normalize(
106 | os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std
107 | )
108 | energy_min, energy_max = self.normalize(
109 | os.path.join(self.out_dir, "energy"), energy_mean, energy_std
110 | )
111 |
112 | # Save files
113 | with open(os.path.join(self.out_dir, "speakers.json"), "w") as f:
114 | f.write(json.dumps(speakers))
115 |
116 | with open(os.path.join(self.out_dir, "stats.json"), "w") as f:
117 | stats = {
118 | "pitch": [
119 | float(pitch_min),
120 | float(pitch_max),
121 | float(pitch_mean),
122 | float(pitch_std),
123 | ],
124 | "energy": [
125 | float(energy_min),
126 | float(energy_max),
127 | float(energy_mean),
128 | float(energy_std),
129 | ],
130 | }
131 | f.write(json.dumps(stats))
132 |
133 | print(
134 | "Total data time: {} hours".format(
135 | n_frames * self.hop_length / self.sampling_rate / 3600
136 | )
137 | )
138 |
139 | random.shuffle(out)
140 | out = [r for r in out if r is not None]
141 |
142 | # Write metadata
143 | with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f:
144 | for m in out[self.val_size :]:
145 | f.write(m + "\n")
146 | with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f:
147 | for m in out[: self.val_size]:
148 | f.write(m + "\n")
149 |
150 | return out
151 |
152 | def process_utterance(self, speaker, basename):
153 | wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename))
154 | text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename))
155 | tg_path = os.path.join(
156 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
157 | )
158 |
159 | # Get alignments
160 | textgrid = tgt.io.read_textgrid(tg_path)
161 | phone, duration, start, end = self.get_alignment(
162 | textgrid.get_tier_by_name("phones")
163 | )
164 | text = "{" + " ".join(phone) + "}"
165 | if start >= end:
166 | return None
167 |
168 | # Read and trim wav files
169 | wav, _ = librosa.load(wav_path, sr=None)
170 | wav = wav[
171 | int(self.sampling_rate * start) : int(self.sampling_rate * end)
172 | ].astype(np.float32)
173 |
174 | # Read raw text
175 | with open(text_path, "r") as f:
176 | raw_text = f.readline().strip("\n")
177 |
178 | # Compute fundamental frequency
179 | pitch, t = pw.dio(
180 | wav.astype(np.float64),
181 | self.sampling_rate,
182 | frame_period=self.hop_length / self.sampling_rate * 1000,
183 | )
184 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate)
185 |
186 | wav = torch.FloatTensor(wav)
187 | wav = wav.unsqueeze(0)
188 |         mel_spectrogram, energy = Audio.tools.get_mel_from_wav_as_hifigan(wav, self.preprocessing_config["stft"]["filter_length"], \
189 |                                     self.preprocessing_config["mel"]["n_mel_channels"], self.preprocessing_config["audio"]["sampling_rate"], \
190 |                                     self.preprocessing_config["stft"]["hop_length"], self.preprocessing_config["stft"]["win_length"], \
191 |                                     self.preprocessing_config["mel"]["mel_fmin"], self.preprocessing_config["mel"]["mel_fmax"], center=False)
192 |
193 |         ## Match tensor sizes (duration frames, mel, energy, pitch)
194 | min_size = min(sum(duration), mel_spectrogram.shape[1], energy.shape[0], pitch.shape[0])
195 | duration[-1] = duration[-1] + (min_size - sum(duration))
196 | mel_spectrogram = mel_spectrogram[:, : sum(duration)]
197 | energy = energy[: sum(duration)]
198 | pitch = pitch[: sum(duration)]
199 |
200 | if np.sum(pitch != 0) <= 1:
201 | return None
202 |
203 | if self.pitch_phoneme_averaging:
204 | # perform linear interpolation
205 | nonzero_ids = np.where(pitch != 0)[0]
206 | interp_fn = interp1d(
207 | nonzero_ids,
208 | pitch[nonzero_ids],
209 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
210 | bounds_error=False,
211 | )
212 | pitch = interp_fn(np.arange(0, len(pitch)))
213 |
214 | # Phoneme-level average
215 | pos = 0
216 | for i, d in enumerate(duration):
217 | if d > 0:
218 | pitch[i] = np.mean(pitch[pos : pos + d])
219 | else:
220 | pitch[i] = 0
221 | pos += d
222 | pitch = pitch[: len(duration)]
223 |
224 | if self.energy_phoneme_averaging:
225 | # Phoneme-level average
226 | pos = 0
227 | for i, d in enumerate(duration):
228 | if d > 0:
229 | energy[i] = np.mean(energy[pos : pos + d])
230 | else:
231 | energy[i] = 0
232 | pos += d
233 | energy = energy[: len(duration)]
234 |
235 | # Save files
236 | # if speaker:
237 | dur_filename = "{}-duration-{}.npy".format(speaker, basename)
238 | pitch_filename = "{}-pitch-{}.npy".format(speaker, basename)
239 | energy_filename = "{}-energy-{}.npy".format(speaker, basename)
240 | mel_filename = "{}-mel-{}.npy".format(speaker, basename)
241 | # else:
242 | # dur_filename = "duration-{}.npy".format(basename)
243 | # pitch_filename = "pitch-{}.npy".format(basename)
244 | # energy_filename = "energy-{}.npy".format(basename)
245 | # mel_filename = "mel-{}.npy".format(basename)
246 |
247 | np.save(os.path.join(self.out_dir, "duration", dur_filename), duration)
248 | np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch)
249 | np.save(os.path.join(self.out_dir, "energy", energy_filename), energy)
250 | np.save(
251 | os.path.join(self.out_dir, "mel", mel_filename),
252 | mel_spectrogram.T,
253 | )
254 |
255 | return (
256 | "|".join([basename, speaker, text, raw_text]),
257 | self.remove_outlier(pitch),
258 | self.remove_outlier(energy),
259 | mel_spectrogram.shape[1],
260 | )
261 |
262 | def get_alignment(self, tier):
263 | sil_phones = ["sil", "sp", "spn"]
264 |
265 | phones = []
266 | durations = []
267 | start_time = 0
268 | end_time = 0
269 | end_idx = 0
270 | for t in tier._objects:
271 | s, e, p = t.start_time, t.end_time, t.text
272 |
273 | # Trim leading silences
274 | if phones == []:
275 | if p in sil_phones:
276 | continue
277 | else:
278 | start_time = s
279 |
280 | if p not in sil_phones:
281 | # For ordinary phones
282 | phones.append(p)
283 | end_time = e
284 | end_idx = len(phones)
285 | else:
286 | # For silent phones
287 | phones.append(p)
288 |
289 | durations.append(
290 | int(
291 | np.round(e * self.sampling_rate / self.hop_length)
292 | - np.round(s * self.sampling_rate / self.hop_length)
293 | )
294 | )
295 |
296 |         # Trim trailing silences
297 | phones = phones[:end_idx]
298 | durations = durations[:end_idx]
299 |
300 | return phones, durations, start_time, end_time
301 |
302 | def remove_outlier(self, values):
303 | values = np.array(values)
304 | p25 = np.percentile(values, 25)
305 | p75 = np.percentile(values, 75)
306 | lower = p25 - 1.5 * (p75 - p25)
307 | upper = p75 + 1.5 * (p75 - p25)
308 | normal_indices = np.logical_and(values > lower, values < upper)
309 |
310 | return values[normal_indices]
311 |
312 | def normalize(self, in_dir, mean, std):
313 | max_value = np.finfo(np.float64).min
314 | min_value = np.finfo(np.float64).max
315 | for filename in os.listdir(in_dir):
316 | filename = os.path.join(in_dir, filename)
317 | values = (np.load(filename) - mean) / std
318 | np.save(filename, values)
319 |
320 | max_value = max(max_value, max(values))
321 | min_value = min(min_value, min(values))
322 |
323 | return min_value, max_value
324 |
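325 | # Note (illustrative addition, not in the original file): get_alignment() converts phone
326 | # boundaries from seconds to mel-frame counts.  Assuming sampling_rate=22050 and
327 | # hop_length=256 (typical values; check the config), a phone spanning 0.10 s - 0.30 s yields
328 | #     round(0.30 * 22050 / 256) - round(0.10 * 22050 / 256) = 26 - 9 = 17 frames.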
--------------------------------------------------------------------------------
/synthesizer/synthesize.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 | from string import punctuation
5 | import re
6 | from g2p_en import G2p
7 |
8 | from .utils.tools import to_device, prepare_outputs
9 | from .text import text_to_sequence
10 |
11 | def read_lexicon(lex_path):
12 | lexicon = {}
13 | with open(lex_path) as f:
14 | for line in f:
15 | temp = re.split(r"\s+", line.strip("\n"))
16 | word = temp[0]
17 | phones = temp[1:]
18 | if word.lower() not in lexicon:
19 | lexicon[word.lower()] = phones
20 | return lexicon
21 |
22 |
23 | def preprocess_english(text, preprocess_config):
24 | text = text.rstrip(punctuation)
25 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
26 |
27 | g2p = G2p()
28 | phones = []
29 | words = re.split(r"([,;.\-\?\!\s+])", text)
30 | for w in words:
31 | if w.lower() in lexicon:
32 | phones += lexicon[w.lower()]
33 | else:
34 | phones += list(filter(lambda p: p != " ", g2p(w)))
35 | phones = "{" + "}{".join(phones) + "}"
36 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones)
37 | phones = phones.replace("}{", " ")
38 |
39 | print("Raw Text Sequence: {}".format(text))
40 | print("Phoneme Sequence: {}".format(phones))
41 | sequence = np.array(
42 | text_to_sequence(
43 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
44 | )
45 | )
46 |     return sequence
47 |
48 |
49 | def synthesize(model, batch, control_values, preprocess_config, device, p_target=None, d_target=None, e_target=None):
50 | pitch_control, energy_control, duration_control = control_values
51 |
52 | batch = to_device(batch, device)
53 | with torch.no_grad():
54 | # Forward
55 | output = model(
56 | *(batch[1:]),
57 |
58 | p_targets=p_target,
59 | e_targets=e_target,
60 | d_targets=d_target,
61 |
62 | p_control=pitch_control,
63 | e_control=energy_control,
64 | d_control=duration_control
65 | )
66 | mel, durations, pitch, energy = prepare_outputs(
67 | batch,
68 | output,
69 | preprocess_config,
70 | )
71 | return mel[0].to(device), durations[0].to(device), pitch[0].to(device), energy[0].to(device)
72 |
73 | def infer(model, text, control_values, preprocess_config, device, speaker=0, p_target=None, d_target=None, e_target=None):
74 | t = str(time.time()).replace('.', '_')
75 | ids = [t]
76 | speakers = np.array([speaker])
77 | if preprocess_config["preprocessing"]["text"]["language"] == "fa":
78 | texts = np.array([text_to_sequence(text, preprocess_config['preprocessing']['text']['text_cleaners'])])
79 | elif preprocess_config["preprocessing"]["text"]["language"] == "en":
80 | texts = np.array([preprocess_english(text, preprocess_config)])
81 | text_lens = np.array([len(texts[0])])
82 | batch = (ids, speakers, texts, text_lens, max(text_lens))
83 | model.eval()
84 | mel, durations, pitch, energy = synthesize(model, batch, control_values, preprocess_config, device, p_target, d_target, e_target)
85 | return mel, durations, pitch, energy
86 |
87 |
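88 | # Illustrative note (not part of the original file): for English input, preprocess_english()
89 | # first builds a phoneme string such as "{HH AH0 L OW1 W ER1 L D}" for "Hello world"
90 | # (the exact ARPAbet depends on the lexicon / g2p_en output), and text_to_sequence() then
91 | # maps it to integer symbol IDs that infer() feeds to FastSpeech2.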
--------------------------------------------------------------------------------
/synthesizer/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | from . import cleaners
3 | from .symbols import persian_symbols, english_symbols
4 | import re
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _persian_symbol_to_id = {s: i for i, s in enumerate(persian_symbols)}
8 | _persian_id_to_symbol = {i: s for i, s in enumerate(persian_symbols)}
9 |
10 | _english_symbol_to_id = {s: i for i, s in enumerate(english_symbols)}
11 | _english_id_to_symbol = {i: s for i, s in enumerate(english_symbols)}
12 |
13 | # Regular expression matching text enclosed in curly braces:
14 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
15 |
16 | def text_to_sequence(text, cleaner_name):
17 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
18 | Args:
19 | text: string to convert to a sequence
20 |       cleaner_name: name of the cleaner function to run the text through
21 |
22 | Returns:
23 | List of integers corresponding to the symbols in the text
24 | """
25 |     sequence = []
26 |     if cleaner_name == 'persian_cleaner':
27 |         text = text.replace('{', '').replace('}', '')
28 |         sequence = [_persian_symbol_to_id[phoneme] for phoneme in text.split()]
29 | 
30 |     elif cleaner_name == 'english_cleaner':
31 |         while len(text):
32 |             m = _curly_re.match(text)
33 |             if not m:
34 |                 sequence += _symbols_to_sequence(_clean_text(text, cleaner_name), cleaner_name)
35 |                 break
36 |             sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_name), cleaner_name)
37 | sequence += _arpabet_to_sequence(m.group(2))
38 | text = m.group(3)
39 | return sequence
40 |
41 |
42 |
43 | def sequence_to_text(sequence, cleaner_name):
44 | """Converts a sequence of IDs back to a string"""
45 | result = ""
46 | for symbol_id in sequence:
47 | if "persian_cleaner" == cleaner_name:
48 | if symbol_id in _persian_id_to_symbol:
49 | s = _persian_id_to_symbol[symbol_id]
50 | result += s
51 | elif "english_cleaner" == cleaner_name:
52 | if symbol_id in _english_id_to_symbol:
53 | s = _english_id_to_symbol[symbol_id]
54 | # Enclose ARPAbet back in curly braces:
55 | if len(s) > 1 and s[0] == "@":
56 | s = "{%s}" % s[1:]
57 | result += s
58 | return result.replace("}{", " ")
59 |
60 |
61 | def _clean_text(text, cleaner_name):
62 | cleaner = getattr(cleaners, cleaner_name)
63 | if not cleaner:
64 | raise Exception("Unknown cleaner: %s" % cleaner_name)
65 | text = cleaner(text)
66 | return text
67 |
68 |
69 | def _symbols_to_sequence(symbols, cleaner_name):
70 | if cleaner_name == "persian_cleaner":
71 | return [_persian_symbol_to_id[s] for s in symbols if _should_keep_symbol(s, _persian_symbol_to_id)]
72 | elif cleaner_name == "english_cleaner":
73 | return [_english_symbol_to_id[s] for s in symbols if _should_keep_symbol(s, _english_symbol_to_id)]
74 |
75 |
76 | def _arpabet_to_sequence(text):
77 |     return _symbols_to_sequence(["@" + s for s in text.split()], "english_cleaner")
78 |
79 |
80 | def _should_keep_symbol(s, _symbol_to_id):
81 | return s in _symbol_to_id and s != "_" and s != "~"
82 |
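83 | # Illustrative example (assumed input, not part of the original file): with the Persian
84 | # cleaner the input is a space-separated phoneme string, so
85 | #     text_to_sequence("SIL B A SIL", "persian_cleaner")
86 | # returns the IDs of "SIL", "B", "A", "SIL" as defined in symbols.persian_symbols.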
--------------------------------------------------------------------------------
/synthesizer/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected via the "text_cleaners" entry of the preprocessing config. Some
7 | cleaners are English-specific. You'll typically want to use:
8 |   1. "english_cleaner" for English text (expands numbers and abbreviations)
9 |   2. "persian_cleaner" for Persian phoneme input (only collapses whitespace)
10 |   3. "transliteration_cleaners" for other non-English text that can be transliterated to
11 |      ASCII using the Unidecode library (https://pypi.python.org/pypi/Unidecode)
12 |   4. "basic_cleaners" if you do not want to transliterate (update symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 | import re
18 | from unidecode import unidecode
19 | from .numbers import normalize_numbers
20 | _whitespace_re = re.compile(r'\s+')
21 |
22 | # List of (regular expression, replacement) pairs for abbreviations:
23 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
24 | ('mrs', 'misess'),
25 | ('mr', 'mister'),
26 | ('dr', 'doctor'),
27 | ('st', 'saint'),
28 | ('co', 'company'),
29 | ('jr', 'junior'),
30 | ('maj', 'major'),
31 | ('gen', 'general'),
32 | ('drs', 'doctors'),
33 | ('rev', 'reverend'),
34 | ('lt', 'lieutenant'),
35 | ('hon', 'honorable'),
36 | ('sgt', 'sergeant'),
37 | ('capt', 'captain'),
38 | ('esq', 'esquire'),
39 | ('ltd', 'limited'),
40 | ('col', 'colonel'),
41 | ('ft', 'fort'),
42 | ]]
43 |
44 |
45 | def expand_abbreviations(text):
46 | for regex, replacement in _abbreviations:
47 | text = re.sub(regex, replacement, text)
48 | return text
49 |
50 |
51 | def expand_numbers(text):
52 | return normalize_numbers(text)
53 |
54 |
55 | def lowercase(text):
56 | return text.lower()
57 |
58 |
59 | def collapse_whitespace(text):
60 | return re.sub(_whitespace_re, ' ', text)
61 |
62 |
63 | def convert_to_ascii(text):
64 | return unidecode(text)
65 |
66 |
67 | def basic_cleaners(text):
68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
69 | text = lowercase(text)
70 | text = collapse_whitespace(text)
71 | return text
72 |
73 |
74 | def transliteration_cleaners(text):
75 | '''Pipeline for non-English text that transliterates to ASCII.'''
76 | text = convert_to_ascii(text)
77 | text = lowercase(text)
78 | text = collapse_whitespace(text)
79 | return text
80 |
81 |
82 | def english_cleaner(text):
83 | '''Pipeline for English text, including number and abbreviation expansion.'''
84 | text = convert_to_ascii(text)
85 | text = lowercase(text)
86 | text = expand_numbers(text)
87 | text = expand_abbreviations(text)
88 | text = collapse_whitespace(text)
89 | return text
90 |
91 | def persian_cleaner(text):
92 | text = collapse_whitespace(text)
93 | return text
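94 | 
95 | # Illustrative examples (not part of the original file; exact output depends on inflect/unidecode):
96 | #     english_cleaner("Mr. Smith bought 2 books")  ->  "mister smith bought two books"
97 | #     persian_cleaner("SIL  B  A  SIL")            ->  "SIL B A SIL"  (only whitespace is collapsed)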
--------------------------------------------------------------------------------
/synthesizer/text/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | "AA",
8 | "AA0",
9 | "AA1",
10 | "AA2",
11 | "AE",
12 | "AE0",
13 | "AE1",
14 | "AE2",
15 | "AH",
16 | "AH0",
17 | "AH1",
18 | "AH2",
19 | "AO",
20 | "AO0",
21 | "AO1",
22 | "AO2",
23 | "AW",
24 | "AW0",
25 | "AW1",
26 | "AW2",
27 | "AY",
28 | "AY0",
29 | "AY1",
30 | "AY2",
31 | "B",
32 | "CH",
33 | "D",
34 | "DH",
35 | "EH",
36 | "EH0",
37 | "EH1",
38 | "EH2",
39 | "ER",
40 | "ER0",
41 | "ER1",
42 | "ER2",
43 | "EY",
44 | "EY0",
45 | "EY1",
46 | "EY2",
47 | "F",
48 | "G",
49 | "HH",
50 | "IH",
51 | "IH0",
52 | "IH1",
53 | "IH2",
54 | "IY",
55 | "IY0",
56 | "IY1",
57 | "IY2",
58 | "JH",
59 | "K",
60 | "L",
61 | "M",
62 | "N",
63 | "NG",
64 | "OW",
65 | "OW0",
66 | "OW1",
67 | "OW2",
68 | "OY",
69 | "OY0",
70 | "OY1",
71 | "OY2",
72 | "P",
73 | "R",
74 | "S",
75 | "SH",
76 | "T",
77 | "TH",
78 | "UH",
79 | "UH0",
80 | "UH1",
81 | "UH2",
82 | "UW",
83 | "UW0",
84 | "UW1",
85 | "UW2",
86 | "V",
87 | "W",
88 | "Y",
89 | "Z",
90 | "ZH",
91 | ]
92 |
93 | _valid_symbol_set = set(valid_symbols)
94 |
95 |
96 | class CMUDict:
97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
98 |
99 | def __init__(self, file_or_path, keep_ambiguous=True):
100 | if isinstance(file_or_path, str):
101 | with open(file_or_path, encoding="latin-1") as f:
102 | entries = _parse_cmudict(f)
103 | else:
104 | entries = _parse_cmudict(file_or_path)
105 | if not keep_ambiguous:
106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
107 | self._entries = entries
108 |
109 | def __len__(self):
110 | return len(self._entries)
111 |
112 | def lookup(self, word):
113 | """Returns list of ARPAbet pronunciations of the given word."""
114 | return self._entries.get(word.upper())
115 |
116 |
117 | _alt_re = re.compile(r"\([0-9]+\)")
118 |
119 |
120 | def _parse_cmudict(file):
121 | cmudict = {}
122 | for line in file:
123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
124 | parts = line.split(" ")
125 | word = re.sub(_alt_re, "", parts[0])
126 | pronunciation = _get_pronunciation(parts[1])
127 | if pronunciation:
128 | if word in cmudict:
129 | cmudict[word].append(pronunciation)
130 | else:
131 | cmudict[word] = [pronunciation]
132 | return cmudict
133 |
134 |
135 | def _get_pronunciation(s):
136 | parts = s.strip().split(" ")
137 | for part in parts:
138 | if part not in _valid_symbol_set:
139 | return None
140 | return " ".join(parts)
141 |
--------------------------------------------------------------------------------
/synthesizer/text/numbers.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import inflect
4 | import re
5 |
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13 | _number_re = re.compile(r"[0-9]+")
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(",", "")
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace(".", " point ")
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split(".")
27 | if len(parts) > 2:
28 | return match + " dollars" # Unexpected format
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = "dollar" if dollars == 1 else "dollars"
33 | cent_unit = "cent" if cents == 1 else "cents"
34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
35 | elif dollars:
36 | dollar_unit = "dollar" if dollars == 1 else "dollars"
37 | return "%s %s" % (dollars, dollar_unit)
38 | elif cents:
39 | cent_unit = "cent" if cents == 1 else "cents"
40 | return "%s %s" % (cents, cent_unit)
41 | else:
42 | return "zero dollars"
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return "two thousand"
54 | elif num > 2000 and num < 2010:
55 | return "two thousand " + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + " hundred"
58 | else:
59 | return _inflect.number_to_words(
60 | num, andword="", zero="oh", group=2
61 | ).replace(", ", " ")
62 | else:
63 | return _inflect.number_to_words(num, andword="")
64 |
65 |
66 | def normalize_numbers(text):
67 | text = re.sub(_comma_number_re, _remove_commas, text)
68 | text = re.sub(_pounds_re, r"\1 pounds", text)
69 | text = re.sub(_dollars_re, _expand_dollars, text)
70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
71 | text = re.sub(_ordinal_re, _expand_ordinal, text)
72 | text = re.sub(_number_re, _expand_number, text)
73 | return text
74 |
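75 | # Illustrative example (not part of the original file; exact wording depends on inflect):
76 | #     normalize_numbers("I owe $3.50 for 2 books")
77 | #     -> "I owe three dollars, fifty cents for two books"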
--------------------------------------------------------------------------------
/synthesizer/text/pinyin.py:
--------------------------------------------------------------------------------
1 | initials = [
2 | "b",
3 | "c",
4 | "ch",
5 | "d",
6 | "f",
7 | "g",
8 | "h",
9 | "j",
10 | "k",
11 | "l",
12 | "m",
13 | "n",
14 | "p",
15 | "q",
16 | "r",
17 | "s",
18 | "sh",
19 | "t",
20 | "w",
21 | "x",
22 | "y",
23 | "z",
24 | "zh",
25 | ]
26 | finals = [
27 | "a1",
28 | "a2",
29 | "a3",
30 | "a4",
31 | "a5",
32 | "ai1",
33 | "ai2",
34 | "ai3",
35 | "ai4",
36 | "ai5",
37 | "an1",
38 | "an2",
39 | "an3",
40 | "an4",
41 | "an5",
42 | "ang1",
43 | "ang2",
44 | "ang3",
45 | "ang4",
46 | "ang5",
47 | "ao1",
48 | "ao2",
49 | "ao3",
50 | "ao4",
51 | "ao5",
52 | "e1",
53 | "e2",
54 | "e3",
55 | "e4",
56 | "e5",
57 | "ei1",
58 | "ei2",
59 | "ei3",
60 | "ei4",
61 | "ei5",
62 | "en1",
63 | "en2",
64 | "en3",
65 | "en4",
66 | "en5",
67 | "eng1",
68 | "eng2",
69 | "eng3",
70 | "eng4",
71 | "eng5",
72 | "er1",
73 | "er2",
74 | "er3",
75 | "er4",
76 | "er5",
77 | "i1",
78 | "i2",
79 | "i3",
80 | "i4",
81 | "i5",
82 | "ia1",
83 | "ia2",
84 | "ia3",
85 | "ia4",
86 | "ia5",
87 | "ian1",
88 | "ian2",
89 | "ian3",
90 | "ian4",
91 | "ian5",
92 | "iang1",
93 | "iang2",
94 | "iang3",
95 | "iang4",
96 | "iang5",
97 | "iao1",
98 | "iao2",
99 | "iao3",
100 | "iao4",
101 | "iao5",
102 | "ie1",
103 | "ie2",
104 | "ie3",
105 | "ie4",
106 | "ie5",
107 | "ii1",
108 | "ii2",
109 | "ii3",
110 | "ii4",
111 | "ii5",
112 | "iii1",
113 | "iii2",
114 | "iii3",
115 | "iii4",
116 | "iii5",
117 | "in1",
118 | "in2",
119 | "in3",
120 | "in4",
121 | "in5",
122 | "ing1",
123 | "ing2",
124 | "ing3",
125 | "ing4",
126 | "ing5",
127 | "iong1",
128 | "iong2",
129 | "iong3",
130 | "iong4",
131 | "iong5",
132 | "iou1",
133 | "iou2",
134 | "iou3",
135 | "iou4",
136 | "iou5",
137 | "o1",
138 | "o2",
139 | "o3",
140 | "o4",
141 | "o5",
142 | "ong1",
143 | "ong2",
144 | "ong3",
145 | "ong4",
146 | "ong5",
147 | "ou1",
148 | "ou2",
149 | "ou3",
150 | "ou4",
151 | "ou5",
152 | "u1",
153 | "u2",
154 | "u3",
155 | "u4",
156 | "u5",
157 | "ua1",
158 | "ua2",
159 | "ua3",
160 | "ua4",
161 | "ua5",
162 | "uai1",
163 | "uai2",
164 | "uai3",
165 | "uai4",
166 | "uai5",
167 | "uan1",
168 | "uan2",
169 | "uan3",
170 | "uan4",
171 | "uan5",
172 | "uang1",
173 | "uang2",
174 | "uang3",
175 | "uang4",
176 | "uang5",
177 | "uei1",
178 | "uei2",
179 | "uei3",
180 | "uei4",
181 | "uei5",
182 | "uen1",
183 | "uen2",
184 | "uen3",
185 | "uen4",
186 | "uen5",
187 | "uo1",
188 | "uo2",
189 | "uo3",
190 | "uo4",
191 | "uo5",
192 | "v1",
193 | "v2",
194 | "v3",
195 | "v4",
196 | "v5",
197 | "van1",
198 | "van2",
199 | "van3",
200 | "van4",
201 | "van5",
202 | "ve1",
203 | "ve2",
204 | "ve3",
205 | "ve4",
206 | "ve5",
207 | "vn1",
208 | "vn2",
209 | "vn3",
210 | "vn4",
211 | "vn5",
212 | ]
213 | valid_symbols = initials + finals + ["rr"]
--------------------------------------------------------------------------------
/synthesizer/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | """
4 | Defines the set of symbols used in text input to the model.
5 |
6 | The default English set combines ASCII characters with ARPAbet and pinyin phoneme symbols and works well for English or for text that has been run through Unidecode. For other languages, modify the symbol lists below (e.g. persian_phonemes). """
7 |
8 | from . import cmudict, pinyin
9 |
10 | _pad = "_"
11 | _punctuation = "!'(),.:;? "
12 | _special = "-"
13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
14 | _silences = ["@sp", "@spn", "@sil"]
15 |
16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
17 | _arpabet = ["@" + s for s in cmudict.valid_symbols]
18 | _pinyin = ["@" + s for s in pinyin.valid_symbols]
19 |
20 | # Export all symbols:
21 | english_symbols = (
22 | [_pad]
23 | + list(_special)
24 | + list(_punctuation)
25 | + list(_letters)
26 | + _arpabet
27 | + _pinyin
28 | + _silences
29 | )
30 |
31 | ## for Persian Language
32 | persian_phonemes = ['U', 'Q', 'G', 'AA', 'V', 'N', 'CH', 'R', 'KH', 'B', 'Z', 'SH', 'O', 'A', 'E', 'ZH', 'H', 'SIL', 'AH', \
33 | 'S', 'D', 'J', 'L', 'F', 'K', 'I', 'T', 'P', 'M', 'Y']
34 | persian_phonemes += ['?', '!', '.', ',', ';', ':']
35 | persian_symbols = (
36 | [_pad]
37 | + persian_phonemes
38 | )
39 |
40 | def get_symbols(language):
41 | if language == "fa":
42 | return persian_symbols
43 | elif language == "en":
44 | return english_symbols
45 |
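46 | # Note (not part of the original file): index 0 of both symbol sets is the padding symbol
47 | # "_", which matches transformer.Constants.PAD == 0 used as padding_idx in the encoder
48 | # embedding.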
--------------------------------------------------------------------------------
/synthesizer/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | from torch.utils.tensorboard import SummaryWriter
7 | from tqdm import tqdm
8 |
9 | from .utils.model import get_model, get_param_num
10 | from .utils.tools import to_device, log, synth_one_sample
11 | from .vocoder.utils import get_vocoder
12 | from .model import FastSpeech2Loss
13 | from .dataset import Dataset
14 | from .evaluate import evaluate
15 |
16 |
17 | def train_model(args, config):
18 | print("Prepare training ...")
19 |
20 | preprocess_config, model_config, train_config = config['synthesizer']['preprocess'], config['synthesizer']['model'], config['synthesizer']['train']
21 | device = config['main']['device']
22 | # Get dataset
23 | dataset = Dataset(
24 | "train.txt", preprocess_config, train_config, sort=True, drop_last=True
25 | )
26 | batch_size = train_config["optimizer"]["batch_size"]
27 | group_size = 4 # Set this larger than 1 to enable sorting in Dataset
28 | assert batch_size * group_size < len(dataset)
29 | loader = DataLoader(
30 | dataset,
31 | batch_size=batch_size * group_size,
32 | shuffle=True,
33 | collate_fn=dataset.collate_fn,
34 | )
35 |
36 | # Prepare model
37 | model, optimizer = get_model(args.restore_step, config, train=True)
38 | if "cuda" in device:
39 |         if "cuda" == device: ## if no CUDA device id is specified, use all GPUs
40 | model = nn.DataParallel(model)
41 | else:
42 | device_id = int(device.replace("cuda:", ""))
43 | model = nn.DataParallel(model, device_ids=[device_id], output_device=[device_id])
44 | num_param = get_param_num(model)
45 | Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
46 | print("Number of FastSpeech2 Parameters:", num_param)
47 |
48 | # Load vocoder
49 | vocoder = get_vocoder(config['vocoder'], device)
50 |
51 | # Init logger
52 | for p in train_config["path"].values():
53 | os.makedirs(p, exist_ok=True)
54 | train_log_path = os.path.join(train_config["path"]["log_path"], "train")
55 | val_log_path = os.path.join(train_config["path"]["log_path"], "val")
56 | os.makedirs(train_log_path, exist_ok=True)
57 | os.makedirs(val_log_path, exist_ok=True)
58 | train_logger = SummaryWriter(train_log_path)
59 | val_logger = SummaryWriter(val_log_path)
60 |
61 | # Training
62 | step = args.restore_step + 1
63 | epoch = 1
64 | grad_acc_step = train_config["optimizer"]["grad_acc_step"]
65 | grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
66 | total_step = train_config["step"]["total_step"]
67 | log_step = train_config["step"]["log_step"]
68 | save_step = train_config["step"]["save_step"]
69 | synth_step = train_config["step"]["synth_step"]
70 | val_step = train_config["step"]["val_step"]
71 |
72 | outer_bar = tqdm(total=total_step, desc="Training", position=0)
73 | outer_bar.n = args.restore_step
74 | outer_bar.update()
75 |
76 | while True:
77 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
78 | for batchs in loader:
79 | for batch in batchs:
80 | batch = to_device(batch, device)
81 | # Forward
82 | output = model(*(batch[1:]))
83 |
84 | # Cal Loss
85 | losses = Loss(batch, output)
86 | total_loss = losses[0]
87 |
88 | # Backward
89 | total_loss = total_loss / grad_acc_step
90 | total_loss.backward()
91 | if step % grad_acc_step == 0:
92 | # Clipping gradients to avoid gradient explosion
93 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
94 |
95 | # Update weights
96 | optimizer.step_and_update_lr()
97 | optimizer.zero_grad()
98 |
99 | if step % log_step == 0:
100 | losses = [l.item() for l in losses]
101 | message1 = "Step {}/{}, ".format(step, total_step)
102 | message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
103 | *losses
104 | )
105 |
106 | with open(os.path.join(train_log_path, "log.txt"), "a") as f:
107 | f.write(message1 + message2 + "\n")
108 |
109 | outer_bar.write(message1 + message2)
110 |
111 | log(train_logger, step, losses=losses)
112 |
113 | if step % synth_step == 0:
114 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
115 | batch,
116 | output,
117 | vocoder,
118 | model_config,
119 | preprocess_config,
120 | )
121 | log(
122 | train_logger,
123 | fig=fig,
124 | tag="Training/step_{}_{}".format(step, tag),
125 | )
126 | sampling_rate = preprocess_config["preprocessing"]["audio"][
127 | "sampling_rate"
128 | ]
129 | log(
130 | train_logger,
131 | audio=wav_reconstruction,
132 | sampling_rate=sampling_rate,
133 | tag="Training/step_{}_{}_reconstructed".format(step, tag),
134 | )
135 | log(
136 | train_logger,
137 | audio=wav_prediction,
138 | sampling_rate=sampling_rate,
139 | tag="Training/step_{}_{}_synthesized".format(step, tag),
140 | )
141 |
142 | if step % val_step == 0:
143 | model.eval()
144 | message = evaluate(model, step, config, val_logger, vocoder)
145 | with open(os.path.join(val_log_path, "log.txt"), "a") as f:
146 | f.write(message + "\n")
147 | outer_bar.write(message)
148 |
149 | model.train()
150 |
151 | if step % save_step == 0:
152 | torch.save(
153 | {
154 | "model": model.module.state_dict(),
155 | "optimizer": optimizer._optimizer.state_dict(),
156 | },
157 | os.path.join(
158 | train_config["path"]["ckpt_path"],
159 | "{}.pth.tar".format(step),
160 | ),
161 | )
162 |
163 | if step == total_step:
164 | quit()
165 | step += 1
166 | outer_bar.update(1)
167 |
168 | inner_bar.update(1)
169 | epoch += 1
170 |
171 |
172 |
173 |
--------------------------------------------------------------------------------
/synthesizer/transformer/Constants.py:
--------------------------------------------------------------------------------
1 | PAD = 0
2 | UNK = 1
3 | BOS = 2
4 | EOS = 3
5 |
6 | PAD_WORD = "<blank>"
7 | UNK_WORD = "<unk>"
8 | BOS_WORD = "<s>"
9 | EOS_WORD = "</s>"
10 |
--------------------------------------------------------------------------------
/synthesizer/transformer/Layers.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from torch.nn import functional as F
7 |
8 | from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
9 |
10 |
11 | class FFTBlock(torch.nn.Module):
12 | """FFT Block"""
13 |
14 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
15 | super(FFTBlock, self).__init__()
16 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
17 | self.pos_ffn = PositionwiseFeedForward(
18 | d_model, d_inner, kernel_size, dropout=dropout
19 | )
20 |
21 | def forward(self, enc_input, mask=None, slf_attn_mask=None):
22 | enc_output, enc_slf_attn = self.slf_attn(
23 | enc_input, enc_input, enc_input, mask=slf_attn_mask
24 | )
25 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
26 |
27 | enc_output = self.pos_ffn(enc_output)
28 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
29 |
30 | return enc_output, enc_slf_attn
31 |
32 |
33 | class ConvNorm(torch.nn.Module):
34 | def __init__(
35 | self,
36 | in_channels,
37 | out_channels,
38 | kernel_size=1,
39 | stride=1,
40 | padding=None,
41 | dilation=1,
42 | bias=True,
43 | w_init_gain="linear",
44 | ):
45 | super(ConvNorm, self).__init__()
46 |
47 | if padding is None:
48 | assert kernel_size % 2 == 1
49 | padding = int(dilation * (kernel_size - 1) / 2)
50 |
51 | self.conv = torch.nn.Conv1d(
52 | in_channels,
53 | out_channels,
54 | kernel_size=kernel_size,
55 | stride=stride,
56 | padding=padding,
57 | dilation=dilation,
58 | bias=bias,
59 | )
60 |
61 | def forward(self, signal):
62 | conv_signal = self.conv(signal)
63 |
64 | return conv_signal
65 |
66 |
67 | class PostNet(nn.Module):
68 | """
69 | PostNet: Five 1-d convolution with 512 channels and kernel size 5
70 | """
71 |
72 | def __init__(
73 | self,
74 | n_mel_channels=80,
75 | postnet_embedding_dim=512,
76 | postnet_kernel_size=5,
77 | postnet_n_convolutions=5,
78 | ):
79 |
80 | super(PostNet, self).__init__()
81 | self.convolutions = nn.ModuleList()
82 |
83 | self.convolutions.append(
84 | nn.Sequential(
85 | ConvNorm(
86 | n_mel_channels,
87 | postnet_embedding_dim,
88 | kernel_size=postnet_kernel_size,
89 | stride=1,
90 | padding=int((postnet_kernel_size - 1) / 2),
91 | dilation=1,
92 | w_init_gain="tanh",
93 | ),
94 | nn.BatchNorm1d(postnet_embedding_dim),
95 | )
96 | )
97 |
98 | for i in range(1, postnet_n_convolutions - 1):
99 | self.convolutions.append(
100 | nn.Sequential(
101 | ConvNorm(
102 | postnet_embedding_dim,
103 | postnet_embedding_dim,
104 | kernel_size=postnet_kernel_size,
105 | stride=1,
106 | padding=int((postnet_kernel_size - 1) / 2),
107 | dilation=1,
108 | w_init_gain="tanh",
109 | ),
110 | nn.BatchNorm1d(postnet_embedding_dim),
111 | )
112 | )
113 |
114 | self.convolutions.append(
115 | nn.Sequential(
116 | ConvNorm(
117 | postnet_embedding_dim,
118 | n_mel_channels,
119 | kernel_size=postnet_kernel_size,
120 | stride=1,
121 | padding=int((postnet_kernel_size - 1) / 2),
122 | dilation=1,
123 | w_init_gain="linear",
124 | ),
125 | nn.BatchNorm1d(n_mel_channels),
126 | )
127 | )
128 |
129 | def forward(self, x):
130 | x = x.contiguous().transpose(1, 2)
131 |
132 | for i in range(len(self.convolutions) - 1):
133 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
134 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
135 |
136 | x = x.contiguous().transpose(1, 2)
137 | return x
138 |
--------------------------------------------------------------------------------
/synthesizer/transformer/Models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from ..transformer import Constants
6 | from .Layers import FFTBlock
7 | from ..text.symbols import get_symbols
8 |
9 |
10 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
11 | """ Sinusoid position encoding table """
12 |
13 | def cal_angle(position, hid_idx):
14 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
15 |
16 | def get_posi_angle_vec(position):
17 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
18 |
19 | sinusoid_table = np.array(
20 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)]
21 | )
22 |
23 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
24 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
25 |
26 | if padding_idx is not None:
27 | # zero vector for padding dimension
28 | sinusoid_table[padding_idx] = 0.0
29 |
30 | return torch.FloatTensor(sinusoid_table)
31 |
32 |
33 | class Encoder(nn.Module):
34 | """ Encoder """
35 |
36 | def __init__(self, config, language):
37 | super(Encoder, self).__init__()
38 |
39 | n_position = config["max_seq_len"] + 1
40 |         n_src_vocab = len(get_symbols(language)) + 1  # use the symbol set of the configured language
41 | d_word_vec = config["transformer"]["encoder_hidden"]
42 | n_layers = config["transformer"]["encoder_layer"]
43 | n_head = config["transformer"]["encoder_head"]
44 | d_k = d_v = (
45 | config["transformer"]["encoder_hidden"]
46 | // config["transformer"]["encoder_head"]
47 | )
48 | d_model = config["transformer"]["encoder_hidden"]
49 | d_inner = config["transformer"]["conv_filter_size"]
50 | kernel_size = config["transformer"]["conv_kernel_size"]
51 | dropout = config["transformer"]["encoder_dropout"]
52 |
53 | self.max_seq_len = config["max_seq_len"]
54 | self.d_model = d_model
55 |
56 | self.src_word_emb = nn.Embedding(
57 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD
58 | )
59 | self.position_enc = nn.Parameter(
60 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
61 | requires_grad=False,
62 | )
63 |
64 | self.layer_stack = nn.ModuleList(
65 | [
66 | FFTBlock(
67 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
68 | )
69 | for _ in range(n_layers)
70 | ]
71 | )
72 |
73 | def forward(self, src_seq, mask, return_attns=False):
74 |
75 | enc_slf_attn_list = []
76 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1]
77 |
78 | # -- Prepare masks
79 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
80 |
81 | # -- Forward
82 | if not self.training and src_seq.shape[1] > self.max_seq_len:
83 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table(
84 | src_seq.shape[1], self.d_model
85 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
86 | src_seq.device
87 | )
88 | else:
89 | enc_output = self.src_word_emb(src_seq) + self.position_enc[
90 | :, :max_len, :
91 | ].expand(batch_size, -1, -1)
92 |
93 | for enc_layer in self.layer_stack:
94 | enc_output, enc_slf_attn = enc_layer(
95 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask
96 | )
97 | if return_attns:
98 | enc_slf_attn_list += [enc_slf_attn]
99 |
100 | return enc_output
101 |
102 |
103 | class Decoder(nn.Module):
104 | """ Decoder """
105 |
106 | def __init__(self, config):
107 | super(Decoder, self).__init__()
108 |
109 | n_position = config["max_seq_len"] + 1
110 | d_word_vec = config["transformer"]["decoder_hidden"]
111 | n_layers = config["transformer"]["decoder_layer"]
112 | n_head = config["transformer"]["decoder_head"]
113 | d_k = d_v = (
114 | config["transformer"]["decoder_hidden"]
115 | // config["transformer"]["decoder_head"]
116 | )
117 | d_model = config["transformer"]["decoder_hidden"]
118 | d_inner = config["transformer"]["conv_filter_size"]
119 | kernel_size = config["transformer"]["conv_kernel_size"]
120 | dropout = config["transformer"]["decoder_dropout"]
121 |
122 | self.max_seq_len = config["max_seq_len"]
123 | self.d_model = d_model
124 |
125 | self.position_enc = nn.Parameter(
126 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
127 | requires_grad=False,
128 | )
129 |
130 | self.layer_stack = nn.ModuleList(
131 | [
132 | FFTBlock(
133 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
134 | )
135 | for _ in range(n_layers)
136 | ]
137 | )
138 |
139 | def forward(self, enc_seq, mask, return_attns=False):
140 |
141 | dec_slf_attn_list = []
142 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1]
143 |
144 | # -- Forward
145 | if not self.training and enc_seq.shape[1] > self.max_seq_len:
146 | # -- Prepare masks
147 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
148 | dec_output = enc_seq + get_sinusoid_encoding_table(
149 | enc_seq.shape[1], self.d_model
150 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
151 | enc_seq.device
152 | )
153 | else:
154 | max_len = min(max_len, self.max_seq_len)
155 |
156 | # -- Prepare masks
157 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
158 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[
159 | :, :max_len, :
160 | ].expand(batch_size, -1, -1)
161 | mask = mask[:, :max_len]
162 | slf_attn_mask = slf_attn_mask[:, :, :max_len]
163 |
164 | for dec_layer in self.layer_stack:
165 | dec_output, dec_slf_attn = dec_layer(
166 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask
167 | )
168 | if return_attns:
169 | dec_slf_attn_list += [dec_slf_attn]
170 |
171 | return dec_output, mask
172 |
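173 | # Note (not part of the original file): get_sinusoid_encoding_table() implements the
174 | # standard positional encoding, PE[pos, 2i] = sin(pos / 10000**(2i / d_hid)) and
175 | # PE[pos, 2i+1] = cos(pos / 10000**(2i / d_hid)); positions beyond max_seq_len are
176 | # recomputed on the fly at inference time (see the forward passes above).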
--------------------------------------------------------------------------------
/synthesizer/transformer/Modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class ScaledDotProductAttention(nn.Module):
7 | """ Scaled Dot-Product Attention """
8 |
9 | def __init__(self, temperature):
10 | super().__init__()
11 | self.temperature = temperature
12 | self.softmax = nn.Softmax(dim=2)
13 |
14 | def forward(self, q, k, v, mask=None):
15 |
16 | attn = torch.bmm(q, k.transpose(1, 2))
17 | attn = attn / self.temperature
18 |
19 | if mask is not None:
20 | attn = attn.masked_fill(mask, -np.inf)
21 |
22 | attn = self.softmax(attn)
23 | output = torch.bmm(attn, v)
24 |
25 | return output, attn
26 |
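27 | # Note (not part of the original file): this module computes softmax(Q K^T / temperature) V;
28 | # MultiHeadAttention in SubLayers.py instantiates it with temperature = sqrt(d_k), i.e. the
29 | # standard scaled dot-product attention.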
--------------------------------------------------------------------------------
/synthesizer/transformer/SubLayers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from .Modules import ScaledDotProductAttention
6 |
7 |
8 | class MultiHeadAttention(nn.Module):
9 | """ Multi-Head Attention module """
10 |
11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
12 | super().__init__()
13 |
14 | self.n_head = n_head
15 | self.d_k = d_k
16 | self.d_v = d_v
17 |
18 | self.w_qs = nn.Linear(d_model, n_head * d_k)
19 | self.w_ks = nn.Linear(d_model, n_head * d_k)
20 | self.w_vs = nn.Linear(d_model, n_head * d_v)
21 |
22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
23 | self.layer_norm = nn.LayerNorm(d_model)
24 |
25 | self.fc = nn.Linear(n_head * d_v, d_model)
26 |
27 | self.dropout = nn.Dropout(dropout)
28 |
29 | def forward(self, q, k, v, mask=None):
30 |
31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
32 |
33 | sz_b, len_q, _ = q.size()
34 | sz_b, len_k, _ = k.size()
35 | sz_b, len_v, _ = v.size()
36 |
37 | residual = q
38 |
39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
45 |
46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
47 | output, attn = self.attention(q, k, v, mask=mask)
48 |
49 | output = output.view(n_head, sz_b, len_q, d_v)
50 | output = (
51 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)
52 | ) # b x lq x (n*dv)
53 |
54 | output = self.dropout(self.fc(output))
55 | output = self.layer_norm(output + residual)
56 |
57 | return output, attn
58 |
59 |
60 | class PositionwiseFeedForward(nn.Module):
61 | """ A two-feed-forward-layer module """
62 |
63 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1):
64 | super().__init__()
65 |
66 | # Use Conv1D
67 | # position-wise
68 | self.w_1 = nn.Conv1d(
69 | d_in,
70 | d_hid,
71 | kernel_size=kernel_size[0],
72 | padding=(kernel_size[0] - 1) // 2,
73 | )
74 | # position-wise
75 | self.w_2 = nn.Conv1d(
76 | d_hid,
77 | d_in,
78 | kernel_size=kernel_size[1],
79 | padding=(kernel_size[1] - 1) // 2,
80 | )
81 |
82 | self.layer_norm = nn.LayerNorm(d_in)
83 | self.dropout = nn.Dropout(dropout)
84 |
85 | def forward(self, x):
86 | residual = x
87 | output = x.transpose(1, 2)
88 | output = self.w_2(F.relu(self.w_1(output)))
89 | output = output.transpose(1, 2)
90 | output = self.dropout(output)
91 | output = self.layer_norm(output + residual)
92 |
93 | return output
94 |
--------------------------------------------------------------------------------
/synthesizer/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .Models import Encoder, Decoder
2 | from .Layers import PostNet
--------------------------------------------------------------------------------
/synthesizer/utils/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from ..model import FastSpeech2, ScheduledOptim
5 |
6 |
7 | def get_model(restore_step, config, train=False):
8 |
9 | device = config['main']['device']
10 | model = FastSpeech2(config).to(device)
11 | if restore_step:
12 | ckpt_path = os.path.join(
13 | config['synthesizer']['train']["path"]["ckpt_path"],
14 | "{}.pth.tar".format(restore_step),
15 | )
16 | # ckpt_path = "/mnt/hdd1/adibian/FastSpeech2/ResGrad/output/MS_Persian/add_speaker_befor_and_after_VA/synthesizer/ckpt/1000000.pth.tar"
17 | # ckpt_path = "/mnt/hdd1/adibian/FastSpeech2/ResGrad/output/MS_Persian/add_speaker_before_VA/synthesizer/ckpt/1000000.pth.tar"
18 | ckpt = torch.load(ckpt_path, map_location=torch.device(device))
19 | model.load_state_dict(ckpt["model"])
20 |
21 | if train:
22 | scheduled_optim = ScheduledOptim(
23 | model, config['synthesizer']['train'], config['synthesizer']['model'], restore_step
24 | )
25 | if restore_step:
26 | scheduled_optim.load_state_dict(ckpt["optimizer"])
27 | model.train()
28 | return model, scheduled_optim
29 |
30 | model.eval()
31 |     model.requires_grad_(False)
32 | return model
33 |
34 |
35 | def get_param_num(model):
36 | num_param = sum(param.numel() for param in model.parameters())
37 | return num_param
38 |
39 |
--------------------------------------------------------------------------------
/synthesizer/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import yaml
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import matplotlib
9 | from scipy.io import wavfile
10 | from matplotlib import pyplot as plt
11 |
12 | from vocoder.inference import infer as vocoder_infer
13 |
14 | matplotlib.use("Agg")
15 |
16 | def to_device(data, device):
17 | if len(data) == 11:
18 | (
19 | ids,
20 | speakers,
21 | texts,
22 | src_lens,
23 | max_src_len,
24 | mels,
25 | mel_lens,
26 | max_mel_len,
27 | pitches,
28 | energies,
29 | durations,
30 | ) = data
31 |
32 | speakers = torch.from_numpy(speakers).long().to(device)
33 | texts = torch.from_numpy(texts).long().to(device)
34 | src_lens = torch.from_numpy(src_lens).to(device)
35 | mels = torch.from_numpy(mels).float().to(device)
36 | mel_lens = torch.from_numpy(mel_lens).to(device)
37 | pitches = torch.from_numpy(pitches).float().to(device)
38 | energies = torch.from_numpy(energies).to(device)
39 | durations = torch.from_numpy(durations).long().to(device)
40 |
41 | return (
42 | ids,
43 | speakers,
44 | texts,
45 | src_lens,
46 | max_src_len,
47 | mels,
48 | mel_lens,
49 | max_mel_len,
50 | pitches,
51 | energies,
52 | durations,
53 | )
54 |
55 | if len(data) == 5:
56 | (ids, speakers, texts, src_lens, max_src_len) = data
57 | speakers = torch.from_numpy(speakers).long().to(device)
58 | texts = torch.from_numpy(texts).long().to(device)
59 | src_lens = torch.from_numpy(src_lens).to(device)
60 |
61 | return (ids, speakers, texts, src_lens, max_src_len)
62 |
63 |
64 | def log(
65 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""
66 | ):
67 | if losses is not None:
68 | logger.add_scalar("Loss/total_loss", losses[0], step)
69 | logger.add_scalar("Loss/mel_loss", losses[1], step)
70 | logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
71 | logger.add_scalar("Loss/pitch_loss", losses[3], step)
72 | logger.add_scalar("Loss/energy_loss", losses[4], step)
73 | logger.add_scalar("Loss/duration_loss", losses[5], step)
74 |
75 | if fig is not None:
76 | logger.add_figure(tag, fig)
77 |
78 | if audio is not None:
79 | logger.add_audio(
80 | tag,
81 | audio / max(abs(audio)),
82 | sample_rate=sampling_rate,
83 | )
84 |
85 |
86 | def get_mask_from_lengths(lengths, max_len=None, device=None):
87 | batch_size = lengths.shape[0]
88 | if max_len is None:
89 | max_len = torch.max(lengths).item()
90 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
91 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
92 |
93 | return mask
94 |
95 |
96 | def expand(values, durations):
97 | out = list()
98 | for value, d in zip(values, durations):
99 | out += [value] * max(0, int(d))
100 | return np.array(out)
101 |
102 |
103 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config):
104 |
105 | basename = targets[0][0]
106 | src_len = predictions[8][0].item()
107 | mel_len = predictions[9][0].item()
108 | mel_target = targets[5][0, :mel_len].detach().transpose(0, 1)
109 | mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1)
110 | duration = targets[10][0, :src_len].detach().cpu().numpy()
111 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
112 | pitch = targets[8][0, :src_len].detach().cpu().numpy()
113 | pitch = expand(pitch, duration)
114 | else:
115 | pitch = targets[8][0, :mel_len].detach().cpu().numpy()
116 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
117 | energy = targets[9][0, :src_len].detach().cpu().numpy()
118 | energy = expand(energy, duration)
119 | else:
120 | energy = targets[9][0, :mel_len].detach().cpu().numpy()
121 |
122 | with open(
123 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
124 | ) as f:
125 | stats = json.load(f)
126 | stats = stats["pitch"] + stats["energy"][:2]
127 |
128 | fig = plot_mel(
129 | [
130 | (mel_prediction.cpu().numpy(), pitch, energy),
131 | (mel_target.cpu().numpy(), pitch, energy),
132 | ],
133 | stats,
134 |         ["Synthesized Spectrogram", "Ground-Truth Spectrogram"],
135 | )
136 | if vocoder is not None:
137 | wav_reconstruction = vocoder_infer(
138 | vocoder,
139 | mel_target.unsqueeze(0),
140 | preprocess_config["preprocessing"]["audio"]["max_wav_value"],
141 | )
142 | wav_prediction = vocoder_infer(
143 | vocoder,
144 | mel_prediction.unsqueeze(0),
145 | preprocess_config["preprocessing"]["audio"]["max_wav_value"],
146 | )
147 | else:
148 | wav_reconstruction = wav_prediction = None
149 |
150 | return fig, wav_reconstruction, wav_prediction, basename
151 |
152 |
153 | def prepare_outputs(targets, predictions, preprocess_config):
154 | all_mel, all_durations, all_pitch, all_energy = [], [], [], []
155 | for i in range(len(predictions[0])):
156 | src_len = predictions[8][i].item()
157 | mel_len = predictions[9][i].item()
158 | mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
159 | duration = predictions[5][i, :src_len].detach()
160 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
161 | pitch = predictions[2][i, :src_len].detach().cpu()
162 | pitch = expand(pitch, duration)
163 | pitch = torch.from_numpy(pitch)
164 | else:
165 | pitch = predictions[2][i, :mel_len].detach()
166 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
167 | energy = predictions[3][i, :src_len].detach().cpu()
168 | energy = expand(energy, duration)
169 | energy = torch.from_numpy(energy)
170 | else:
171 | energy = predictions[3][i, :mel_len].detach()
172 |
173 | all_mel.append(mel_prediction)
174 | all_durations.append(duration)
175 | all_pitch.append(pitch)
176 | all_energy.append(energy)
177 |
178 | return all_mel, all_durations, all_pitch, all_energy
179 |
180 |
181 | def plot_mel(data, stats, titles):
182 | fig, axes = plt.subplots(len(data), 1, squeeze=False)
183 | if titles is None:
184 | titles = [None for i in range(len(data))]
185 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
186 | pitch_min = pitch_min * pitch_std + pitch_mean
187 | pitch_max = pitch_max * pitch_std + pitch_mean
188 |
189 | def add_axis(fig, old_ax):
190 | ax = fig.add_axes(old_ax.get_position(), anchor="W")
191 | ax.set_facecolor("None")
192 | return ax
193 |
194 | for i in range(len(data)):
195 | mel, pitch, energy = data[i]
196 | pitch = pitch * pitch_std + pitch_mean
197 | axes[i][0].imshow(mel, origin="lower")
198 | axes[i][0].set_aspect(2.5, adjustable="box")
199 | axes[i][0].set_ylim(0, mel.shape[0])
200 | axes[i][0].set_title(titles[i], fontsize="medium")
201 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
202 | axes[i][0].set_anchor("W")
203 |
204 | ax1 = add_axis(fig, axes[i][0])
205 | ax1.plot(pitch, color="tomato")
206 | ax1.set_xlim(0, mel.shape[1])
207 | ax1.set_ylim(0, pitch_max)
208 | ax1.set_ylabel("F0", color="tomato")
209 | ax1.tick_params(
210 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
211 | )
212 |
213 | ax2 = add_axis(fig, axes[i][0])
214 | ax2.plot(energy, color="darkviolet")
215 | ax2.set_xlim(0, mel.shape[1])
216 | ax2.set_ylim(energy_min, energy_max)
217 | ax2.set_ylabel("Energy", color="darkviolet")
218 | ax2.yaxis.set_label_position("right")
219 | ax2.tick_params(
220 | labelsize="x-small",
221 | colors="darkviolet",
222 | bottom=False,
223 | labelbottom=False,
224 | left=False,
225 | labelleft=False,
226 | right=True,
227 | labelright=True,
228 | )
229 |
230 | return fig
231 |
232 |
233 | def pad_1D(inputs, PAD=0):
234 | def pad_data(x, length, PAD):
235 | x_padded = np.pad(
236 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
237 | )
238 | return x_padded
239 |
240 | max_len = max((len(x) for x in inputs))
241 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
242 |
243 | return padded
244 |
245 |
246 | def pad_2D(inputs, maxlen=None):
247 | def pad(x, max_len):
248 | PAD = 0
249 | if np.shape(x)[0] > max_len:
250 | raise ValueError("not max_len")
251 |
252 | s = np.shape(x)[1]
253 | x_padded = np.pad(
254 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
255 | )
256 | return x_padded[:, :s]
257 |
258 | if maxlen:
259 | output = np.stack([pad(x, maxlen) for x in inputs])
260 | else:
261 | max_len = max(np.shape(x)[0] for x in inputs)
262 | output = np.stack([pad(x, max_len) for x in inputs])
263 |
264 | return output
265 |
266 |
267 | def pad(input_ele, mel_max_length=None):
268 | if mel_max_length:
269 | max_len = mel_max_length
270 | else:
271 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
272 |
273 | out_list = list()
274 | for i, batch in enumerate(input_ele):
275 | if len(batch.shape) == 1:
276 | one_batch_padded = F.pad(
277 | batch, (0, max_len - batch.size(0)), "constant", 0.0
278 | )
279 | elif len(batch.shape) == 2:
280 | one_batch_padded = F.pad(
281 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
282 | )
283 | out_list.append(one_batch_padded)
284 | out_padded = torch.stack(out_list)
285 | return out_padded
286 |
287 | def load_yaml_file(path):
288 | ## define custom tag handler
289 | def join(loader, node):
290 | seq = loader.construct_sequence(node)
291 | return str(os.path.join(*[str(i) for i in seq]))
292 |
293 | ## register the tag handler
294 | yaml.add_constructor('!join', join)
295 | data = yaml.load(open(path, "r"), Loader=yaml.FullLoader)
296 | return data
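297 | 
298 | # Illustrative config snippet (assumed layout, not part of this module): the custom !join
299 | # tag simply os.path.join()s the listed items, e.g.
300 | #     ckpt_path: !join [output/LJSpeech, synthesizer, ckpt]   # -> "output/LJSpeech/synthesizer/ckpt"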
--------------------------------------------------------------------------------
/synthesizer/vocoder:
--------------------------------------------------------------------------------
1 | ../vocoder
--------------------------------------------------------------------------------
/train_resgrad.py:
--------------------------------------------------------------------------------
1 | from utils import load_yaml_file
2 | from resgrad.train import resgrad_train
3 |
4 | import argparse
5 |
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--restore_step", type=int, default=0)
10 | parser.add_argument("--config", type=str, default='config/Persian/config.yaml', required=False, help="path to config.yaml")
11 | args = parser.parse_args()
12 |
13 | # Read Config
14 | config = load_yaml_file(args.config)
15 | resgrad_config = config['resgrad']
16 | resgrad_config['main'] = config['main']
17 | resgrad_train(args, resgrad_config)
--------------------------------------------------------------------------------
/train_synthesizer.py:
--------------------------------------------------------------------------------
1 | from synthesizer.train import train_model
2 | from utils import load_yaml_file
3 |
4 | import argparse
5 |
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--restore_step", type=int, default=0)
10 | parser.add_argument("--config", type=str, default='config/Persian/config.yaml', required=False, help="path to config.yaml")
11 | args = parser.parse_args()
12 |
13 | # Read Config
14 | config = load_yaml_file(args.config)
15 | train_model(args, config)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from synthesizer.utils.model import get_model as load_synthesizer_model
2 | from resgrad.utils import load_model as load_resgrad_model
3 | from vocoder.utils import get_vocoder
4 |
5 | from synthesizer.utils.tools import plot_mel
6 |
7 | import os
8 | import json
9 | from scipy.io import wavfile
10 | from matplotlib import pyplot as plt
11 | import yaml
12 | import time
13 |
14 | def save_result(mel_prediction, wav, pitch, energy, preprocess_config, result_dir, file_name):
15 | with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")) as f:
16 | stats = json.load(f)
17 | stats = stats["pitch"] + stats["energy"][:2]
18 |     fig = plot_mel([(mel_prediction.cpu().numpy(), pitch.cpu().numpy(), energy.cpu().numpy())], stats, ["Synthesized Spectrogram"])
19 | plt.savefig(os.path.join(result_dir, "{}.png".format(file_name)))
20 | plt.close()
21 |
22 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
23 | wavfile.write(os.path.join(result_dir, "{}.wav".format(file_name)), sampling_rate, wav)
24 |
25 |
26 | def load_models(all_restore_step, config):
27 | synthesizer_model, resgrad_model, vocoder_model = None, None, None
28 | device = config['main']['device']
29 | if all_restore_step['synthesizer'] not in [None, 0]:
30 | synthesizer_model = load_synthesizer_model(all_restore_step['synthesizer'], config).to(device)
31 | if all_restore_step['vocoder'] not in [None, 0]:
32 | vocoder_model = get_vocoder(config['vocoder'], device).to(device)
33 | if all_restore_step['resgrad'] not in [None, 0]:
34 | resgrad_model = load_resgrad_model(config['resgrad'], train=False, restore_model_step=all_restore_step['resgrad'], device=device).to(device)
35 | return synthesizer_model, resgrad_model, vocoder_model
36 |
37 | def load_yaml_file(path):
38 | ## define custom tag handler
39 | def join(loader, node):
40 | seq = loader.construct_sequence(node)
41 | return str(os.path.join(*[str(i) for i in seq]))
42 |
43 | ## register the tag handler
44 | yaml.add_constructor('!join', join)
45 | data = yaml.load(open(path, "r"), Loader=yaml.FullLoader)
46 | return data
47 |
48 | def get_file_name(args):
49 | file_name_parts = []
50 | if args.result_file_name:
51 | file_name_parts.append(args.result_file_name)
52 | if args.speaker_id:
53 | file_name_parts.append("spk" + str(args.speaker_id))
54 | if len(file_name_parts) == 0:
55 | file_name_parts.append(str(time.time()).replace('.', '_'))
56 | file_name_parts.append("FastSpeech")
57 | file_name = "_".join(file_name_parts)
58 | return file_name
--------------------------------------------------------------------------------
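
The !join tag registered by load_yaml_file lets the YAML configs build filesystem paths from parts via os.path.join. A small sketch of the behavior (the YAML snippet and temporary file are assumptions, not taken from the repo's config files):

import os
import tempfile
from utils import load_yaml_file

yaml_text = "dataset_dir: !join [dataset, Persian]\n"
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(yaml_text)
    tmp_path = f.name

config = load_yaml_file(tmp_path)
print(config["dataset_dir"])   # dataset/Persian (sequence items joined with os.path.join)
os.remove(tmp_path)
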
/vocoder/ckpt/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
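
A quick sanity-check sketch of the numbers in this config: the product of upsample_rates must equal hop_size, so each mel frame expands to exactly 256 waveform samples (about 11.6 ms at 22050 Hz).

import json
import math

with open("vocoder/ckpt/config.json") as f:
    h = json.load(f)

upsample_factor = math.prod(h["upsample_rates"])        # 8 * 8 * 2 * 2 = 256
assert upsample_factor == h["hop_size"]
frame_ms = 1000 * upsample_factor / h["sampling_rate"]  # ~11.6 ms per mel frame
print(upsample_factor, round(frame_ms, 1))
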
/vocoder/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def infer(model, mel_spectrum, max_wav_value):
5 | if len(mel_spectrum.shape) == 2:
6 | mel_spectrum = mel_spectrum.unsqueeze(0)
7 | with torch.no_grad():
8 | wav = model(mel_spectrum).squeeze(1)[0]
9 |     wav = wav.cpu().numpy() * max_wav_value * 0.97  # scale the [-1, 1] output to the int16 range, with 0.97 headroom against clipping
10 | wav = wav.astype("int16")
11 | return wav
--------------------------------------------------------------------------------
/vocoder/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Conv1d, ConvTranspose1d
5 | from torch.nn.utils import weight_norm, remove_weight_norm
6 |
7 | LRELU_SLOPE = 0.1
8 |
9 |
10 | def init_weights(m, mean=0.0, std=0.01):
11 | classname = m.__class__.__name__
12 | if classname.find("Conv") != -1:
13 | m.weight.data.normal_(mean, std)
14 |
15 |
16 | def get_padding(kernel_size, dilation=1):
17 | return int((kernel_size * dilation - dilation) / 2)
18 |
19 |
20 | class ResBlock(torch.nn.Module):
21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22 | super(ResBlock, self).__init__()
23 | self.h = h
24 | self.convs1 = nn.ModuleList(
25 | [
26 | weight_norm(
27 | Conv1d(
28 | channels,
29 | channels,
30 | kernel_size,
31 | 1,
32 | dilation=dilation[0],
33 | padding=get_padding(kernel_size, dilation[0]),
34 | )
35 | ),
36 | weight_norm(
37 | Conv1d(
38 | channels,
39 | channels,
40 | kernel_size,
41 | 1,
42 | dilation=dilation[1],
43 | padding=get_padding(kernel_size, dilation[1]),
44 | )
45 | ),
46 | weight_norm(
47 | Conv1d(
48 | channels,
49 | channels,
50 | kernel_size,
51 | 1,
52 | dilation=dilation[2],
53 | padding=get_padding(kernel_size, dilation[2]),
54 | )
55 | ),
56 | ]
57 | )
58 | self.convs1.apply(init_weights)
59 |
60 | self.convs2 = nn.ModuleList(
61 | [
62 | weight_norm(
63 | Conv1d(
64 | channels,
65 | channels,
66 | kernel_size,
67 | 1,
68 | dilation=1,
69 | padding=get_padding(kernel_size, 1),
70 | )
71 | ),
72 | weight_norm(
73 | Conv1d(
74 | channels,
75 | channels,
76 | kernel_size,
77 | 1,
78 | dilation=1,
79 | padding=get_padding(kernel_size, 1),
80 | )
81 | ),
82 | weight_norm(
83 | Conv1d(
84 | channels,
85 | channels,
86 | kernel_size,
87 | 1,
88 | dilation=1,
89 | padding=get_padding(kernel_size, 1),
90 | )
91 | ),
92 | ]
93 | )
94 | self.convs2.apply(init_weights)
95 |
96 | def forward(self, x):
97 | for c1, c2 in zip(self.convs1, self.convs2):
98 | xt = F.leaky_relu(x, LRELU_SLOPE)
99 | xt = c1(xt)
100 | xt = F.leaky_relu(xt, LRELU_SLOPE)
101 | xt = c2(xt)
102 | x = xt + x
103 | return x
104 |
105 | def remove_weight_norm(self):
106 | for l in self.convs1:
107 | remove_weight_norm(l)
108 | for l in self.convs2:
109 | remove_weight_norm(l)
110 |
111 |
112 | class Generator(torch.nn.Module):
113 | def __init__(self, h):
114 | super(Generator, self).__init__()
115 | self.h = h
116 | self.num_kernels = len(h.resblock_kernel_sizes)
117 | self.num_upsamples = len(h.upsample_rates)
118 | self.conv_pre = weight_norm(
119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
120 | )
121 | resblock = ResBlock
122 |
123 | self.ups = nn.ModuleList()
124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125 | self.ups.append(
126 | weight_norm(
127 | ConvTranspose1d(
128 | h.upsample_initial_channel // (2 ** i),
129 | h.upsample_initial_channel // (2 ** (i + 1)),
130 | k,
131 | u,
132 | padding=(k - u) // 2,
133 | )
134 | )
135 | )
136 |
137 | self.resblocks = nn.ModuleList()
138 | for i in range(len(self.ups)):
139 | ch = h.upsample_initial_channel // (2 ** (i + 1))
140 | for j, (k, d) in enumerate(
141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142 | ):
143 | self.resblocks.append(resblock(h, ch, k, d))
144 |
145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146 | self.ups.apply(init_weights)
147 | self.conv_post.apply(init_weights)
148 |
149 | def forward(self, x):
150 | x = self.conv_pre(x)
151 | for i in range(self.num_upsamples):
152 | x = F.leaky_relu(x, LRELU_SLOPE)
153 | x = self.ups[i](x)
154 | xs = None
155 | for j in range(self.num_kernels):
156 | if xs is None:
157 | xs = self.resblocks[i * self.num_kernels + j](x)
158 | else:
159 | xs += self.resblocks[i * self.num_kernels + j](x)
160 | x = xs / self.num_kernels
161 | x = F.leaky_relu(x)
162 | x = self.conv_post(x)
163 | x = torch.tanh(x)
164 |
165 | return x
166 |
167 | def remove_weight_norm(self):
168 | for l in self.ups:
169 | remove_weight_norm(l)
170 | for l in self.resblocks:
171 | l.remove_weight_norm()
172 | remove_weight_norm(self.conv_pre)
173 | remove_weight_norm(self.conv_post)
--------------------------------------------------------------------------------
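
A shape-check sketch (the random mel input is a placeholder): a Generator built from the checkpoint config maps a [B, 80, T] mel-spectrogram to a [B, 1, T * 256] waveform, since every ConvTranspose1d multiplies the length by its stride.

import json
import torch
from vocoder.models import Generator
from vocoder.utils import AttrDict

with open("vocoder/ckpt/config.json") as f:
    h = AttrDict(json.load(f))

model = Generator(h)
mel = torch.randn(1, 80, 10)   # 10 mel frames
with torch.no_grad():
    wav = model(mel)
print(wav.shape)               # torch.Size([1, 1, 2560]) = 10 frames * 256-sample hop
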
/vocoder/utils.py:
--------------------------------------------------------------------------------
1 |
2 | from vocoder.models import Generator
3 | import json
4 | import torch
5 |
6 | class AttrDict(dict):
7 | def __init__(self, *args, **kwargs):
8 | super(AttrDict, self).__init__(*args, **kwargs)
9 | self.__dict__ = self
10 |
11 | def get_vocoder(config, device):
12 | with open("vocoder/ckpt/config.json", "r") as f:
13 | model_config = json.load(f)
14 | model_config = AttrDict(model_config)
15 | vocoder = Generator(model_config)
16 |
17 | # if config['restore_step']:
18 | # ckpt = torch.load(f"vocoder/ckpt/g_{config['restore_step']}", map_location=config['device'])
19 | # else:
20 | ckpt = torch.load(f"vocoder/ckpt/{config['model_name']}", map_location=device)
21 |
22 | vocoder.load_state_dict(ckpt["generator"])
23 | vocoder.eval()
24 | vocoder.remove_weight_norm()
25 | vocoder.to(device)
26 | return vocoder
27 |
28 |
--------------------------------------------------------------------------------
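
Putting get_vocoder and vocoder/inference.py together, a minimal end-to-end sketch (the checkpoint name, max_wav_value, and the random mel are assumptions; in practice the mel comes from the synthesizer/ResGrad stage):

import torch
from scipy.io import wavfile
from vocoder.utils import get_vocoder
from vocoder.inference import infer

device = "cpu"
vocoder = get_vocoder({"model_name": "g_2500000"}, device)  # any HiFi-GAN checkpoint placed under vocoder/ckpt/
mel = torch.randn(80, 200, device=device)                   # placeholder mel-spectrogram [n_mels, T]
wav = infer(vocoder, mel, max_wav_value=32768.0)            # int16-scaled waveform
wavfile.write("sample.wav", 22050, wav)
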