├── .gitignore ├── LICENSE ├── README.md ├── config └── diffvar │ ├── model.yaml │ ├── preprocess.yaml │ └── train.yaml ├── dataset.py ├── hfgan ├── env.py ├── mel_extractor │ ├── LICENSE │ ├── README.md │ ├── config16k.json │ ├── config22k.json │ └── mel.py ├── meldataset.py ├── models.py ├── utils.py ├── vocoder.py └── vocoderutils.py ├── img └── model.png ├── model ├── __init__.py ├── diffnet.py ├── diffspeech.py ├── diffvar.py ├── fastspeech2.py ├── loss.py ├── modules.py └── optimizer.py ├── requirements.txt ├── synthesize.py ├── text ├── __init__.py ├── cleaners.py ├── cmudict.py ├── numbers.py ├── pinyin.py ├── symbols.py └── zjl_symbols.py ├── train.py ├── transformer ├── Constants.py ├── Layers.py ├── Models.py ├── Modules.py ├── SubLayers.py └── __init__.py └── utils ├── model.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | __pycache__ 107 | .vscode 108 | .DS_Store 109 | 110 | # MFA 111 | montreal-forced-aligner/ 112 | 113 | # data, checkpoint, and models 114 | raw_data/ 115 | output/ 116 | *.npy 117 | TextGrid/ 118 | hifigan/*.pth.tar 119 | eval/ 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 THUHCSI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The 
above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffVar - PyTorch Implementation 2 | 3 | This is a PyTorch implementation of the Interspeech 2023 paper [**Diverse and Expressive Speech Prosody Prediction with Denoising Diffusion Probabilistic Model**](https://arxiv.org/abs/2305.16749). 4 | 5 | ![](./img/model.png) 6 | 7 | # Audio Samples 8 | Audio samples generated by this implementation can be found [here](https://thuhcsi.github.io/interspeech2023-DiffVar/). 9 | 10 | # References 11 | - [ming024's FastSpeech2 implementation](https://github.com/ming024/Fastspeech2) 12 | - [Official DiffSpeech implementation](https://github.com/NATSpeech/NATSpeech) 13 | 14 | # Citation 15 | ``` 16 | @misc{li2023diverse, 17 | title={Diverse and Expressive Speech Prosody Prediction with Denoising Diffusion Probabilistic Model}, 18 | author={Xiang Li and Songxiang Liu and Max W. Y. Lam and Zhiyong Wu and Chao Weng and Helen Meng}, 19 | year={2023}, 20 | eprint={2305.16749}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.SD} 23 | } 24 | ``` 25 | -------------------------------------------------------------------------------- /config/diffvar/model.yaml: -------------------------------------------------------------------------------- 1 | transformer: 2 | encoder_layer: 4 3 | encoder_head: 2 4 | encoder_hidden: 256 5 | decoder_layer: 6 6 | decoder_head: 2 7 | decoder_hidden: 256 8 | conv_filter_size: 1024 9 | conv_kernel_size: [9, 1] 10 | encoder_dropout: 0.2 11 | decoder_dropout: 0.2 12 | 13 | variance_predictor: 14 | filter_size: 256 15 | kernel_size: 3 16 | dropout: 0.5 17 | 18 | variance_embedding: 19 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 21 | n_bins: 256 22 | 23 | multi_speaker: True 24 | 25 | max_seq_len: 1000 26 | 27 | vocoder: 28 | model: "hifigan" 29 | speaker: "universal" 30 | ckpt: "/path/to/hifi-gan/ckpt" 31 | 32 | diffusion: 33 | in_dim: 3 34 | # special configs for diffspeech 35 | timesteps: 500 36 | diff_loss_type: l2 37 | schedule_type: 'linear' 38 | max_beta: 0.06 39 | 40 | ## model configs for diffspeech 41 | diff_decoder_type: 'wavenet' 42 | dilation_cycle_length: 1 43 | residual_layers: 10 44 | residual_channels: 64 45 | 46 | ## normalize range: pitch, energy, duration 47 | x_max: [1, 1, 1] 48 | x_min: [-1., -1., -1.] 
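  # NOTE: with schedule_type 'linear', model/diffspeech.py builds the noise schedule as
  #   betas = np.linspace(1e-4, max_beta, timesteps)
  # i.e. 500 beta values ramping from 1e-4 up to 0.06. in_dim and the normalize range
  # above refer to the three variance channels (pitch, energy, duration), which are
  # expected to lie in [-1, 1]; clip_denoised below clamps the denoised estimate back
  # into that range during sampling.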
49 | clip_denoised: True -------------------------------------------------------------------------------- /config/diffvar/preprocess.yaml: -------------------------------------------------------------------------------- 1 | dataset: "name" 2 | 3 | path: 4 | preprocessed_path: "/path/to/preprocessed_data" 5 | variance_path: "/path/to/prosody/feature" 6 | 7 | preprocessing: 8 | val_size: 512 9 | text: 10 | text_cleaners: [] 11 | language: "zh" 12 | audio: 13 | sampling_rate: 16000 14 | max_wav_value: 32768.0 15 | stft: 16 | filter_length: 2048 17 | hop_length: 200 18 | win_length: 800 19 | mel: 20 | n_mel_channels: 80 21 | mel_fmin: 0 22 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder 23 | pitch: 24 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 25 | normalization: True 26 | energy: 27 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 28 | normalization: True 29 | -------------------------------------------------------------------------------- /config/diffvar/train.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt" 3 | log_path: "./output/log" 4 | result_path: "./output/result" 5 | optimizer: 6 | batch_size: 16 7 | betas: [0.9, 0.98] 8 | eps: 0.000000001 9 | weight_decay: 0.0 10 | grad_clip_thresh: 1.0 11 | grad_acc_step: 1 12 | warm_up_step: 4000 13 | anneal_steps: [300000, 400000, 500000] 14 | anneal_rate: 0.3 15 | step: 16 | total_step: 900000 17 | log_step: 500 18 | synth_step: 1000 19 | val_step: 5000 20 | save_step: 5000 21 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | 8 | from text import text_to_sequence 9 | from utils.tools import pad_1D, pad_2D, drop_idxes 10 | 11 | 12 | class Dataset(Dataset): 13 | def __init__( 14 | self, filename, preprocess_config, train_config, sort=False, drop_last=False 15 | ): 16 | self.dataset_name = preprocess_config["dataset"] 17 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"] 18 | self.variance_path = preprocess_config["path"]["variance_path"] 19 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 20 | self.batch_size = train_config["optimizer"]["batch_size"] 21 | 22 | self.basename, self.speaker, self.text, self.raw_text = self.process_meta( 23 | filename 24 | ) 25 | with open(os.path.join(self.preprocessed_path, "speakers.json")) as f: 26 | self.speaker_map = json.load(f) 27 | self.sort = sort 28 | self.drop_last = drop_last 29 | 30 | def __len__(self): 31 | return len(self.text) 32 | 33 | def __getitem__(self, idx): 34 | basename = self.basename[idx] 35 | speaker = self.speaker[idx] 36 | speaker_id = self.speaker_map[speaker] 37 | raw_text = self.raw_text[idx] 38 | # phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 39 | # mel_path = os.path.join( 40 | # self.preprocessed_path, 41 | # "mel", 42 | # "{}-mel-{}.npy".format(speaker, basename), 43 | # ) 44 | # pitch_path = os.path.join( 45 | # self.preprocessed_path, 46 | # "pitch", 47 | # "{}-pitch-{}.npy".format(speaker, basename), 48 | # ) 49 | # pitch = np.load(pitch_path) 50 | # energy_path = os.path.join( 51 | # self.preprocessed_path, 52 | # "energy", 53 | # "{}-energy-{}.npy".format(speaker, basename), 54 | # ) 55 | # energy 
= np.load(energy_path) 56 | # duration_path = os.path.join( 57 | # self.preprocessed_path, 58 | # "duration", 59 | # "{}-duration-{}.npy".format(speaker, basename), 60 | # ) 61 | # duration = np.load(duration_path) 62 | 63 | phone, quasi_flag = np.array(text_to_sequence(self.text[idx], self.cleaners)).T 64 | mel = np.load(f"/home/zhousp/DB-para/mels/mel-{basename}.npy") 65 | duration, pitch, energy = np.load(f"{self.variance_path}/{basename}.npy").T 66 | duration = duration[np.where(quasi_flag==0)] 67 | pitch = pitch[np.where(quasi_flag==0)] 68 | energy = energy[np.where(quasi_flag==0)] 69 | # print(phone.shape, quasi_flag.shape, duration.shape, energy.shape, pitch.shape) 70 | 71 | sample = { 72 | "id": basename, 73 | "speaker": speaker_id, 74 | "text": phone, 75 | "raw_text": raw_text, 76 | "mel": mel, 77 | "pitch": pitch, 78 | "energy": energy, 79 | "duration": duration, 80 | "quasi_flag": quasi_flag, 81 | } 82 | 83 | return sample 84 | 85 | def process_meta(self, filename): 86 | with open( 87 | os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8" 88 | ) as f: 89 | name = [] 90 | speaker = [] 91 | text = [] 92 | raw_text = [] 93 | for line in f.readlines(): 94 | n, s, t, r = line.strip("\n").split("|") 95 | name.append(n) 96 | speaker.append(s) 97 | text.append(t) 98 | raw_text.append(r) 99 | return name, speaker, text, raw_text 100 | 101 | def reprocess(self, data, idxs): 102 | ids = [data[idx]["id"] for idx in idxs] 103 | speakers = [data[idx]["speaker"] for idx in idxs] 104 | texts = [data[idx]["text"] for idx in idxs] 105 | raw_texts = [data[idx]["raw_text"] for idx in idxs] 106 | mels = [data[idx]["mel"] for idx in idxs] 107 | pitches = [data[idx]["pitch"] for idx in idxs] 108 | energies = [data[idx]["energy"] for idx in idxs] 109 | durations = [data[idx]["duration"] for idx in idxs] 110 | 111 | text_lens = np.array([text.shape[0] for text in texts]) 112 | mel_lens = np.array([mel.shape[0] for mel in mels]) 113 | 114 | speakers = np.array(speakers) 115 | texts = pad_1D(texts) 116 | mels = pad_2D(mels) 117 | pitches = pad_1D(pitches) 118 | energies = pad_1D(energies) 119 | durations = pad_1D(durations) 120 | 121 | quasi_flags = [data[idx]["quasi_flag"] for idx in idxs] 122 | quasi_flags = pad_1D(quasi_flags) 123 | 124 | return ( 125 | ids, 126 | raw_texts, 127 | speakers, 128 | texts, 129 | text_lens, 130 | max(text_lens), 131 | mels, 132 | mel_lens, 133 | max(mel_lens), 134 | pitches, 135 | energies, 136 | durations, 137 | quasi_flags, 138 | ) 139 | 140 | def collate_fn(self, data): 141 | data_size = len(data) 142 | 143 | if self.sort: 144 | len_arr = np.array([d["text"].shape[0] for d in data]) 145 | idx_arr = np.argsort(-len_arr) 146 | else: 147 | idx_arr = np.arange(data_size) 148 | 149 | tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :] 150 | idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)] 151 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist() 152 | if not self.drop_last and len(tail) > 0: 153 | idx_arr += [tail.tolist()] 154 | 155 | output = list() 156 | for idx in idx_arr: 157 | output.append(self.reprocess(data, idx)) 158 | 159 | return output 160 | 161 | 162 | class TextDataset(Dataset): 163 | def __init__(self, filepath, preprocess_config): 164 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 165 | 166 | self.basename, self.speaker, self.text, self.raw_text = self.process_meta( 167 | filepath 168 | ) 169 | with open( 170 | os.path.join( 171 | 
preprocess_config["path"]["preprocessed_path"], "speakers.json" 172 | ) 173 | ) as f: 174 | self.speaker_map = json.load(f) 175 | 176 | def __len__(self): 177 | return len(self.text) 178 | 179 | def __getitem__(self, idx): 180 | basename = self.basename[idx] 181 | speaker = self.speaker[idx] 182 | speaker_id = self.speaker_map[speaker] 183 | raw_text = self.raw_text[idx] 184 | # phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 185 | phone, quasi_flag = np.array(text_to_sequence(self.text[idx], self.cleaners)).T 186 | 187 | return (basename, speaker_id, phone, raw_text, quasi_flag) 188 | 189 | def process_meta(self, filename): 190 | with open(filename, "r", encoding="utf-8") as f: 191 | name = [] 192 | speaker = [] 193 | text = [] 194 | raw_text = [] 195 | for line in f.readlines(): 196 | n, s, t, r = line.strip("\n").split("|") 197 | name.append(n) 198 | speaker.append(s) 199 | text.append(t) 200 | raw_text.append(r) 201 | return name, speaker, text, raw_text 202 | 203 | def collate_fn(self, data): 204 | ids = [d[0] for d in data] 205 | speakers = np.array([d[1] for d in data]) 206 | texts = [d[2] for d in data] 207 | raw_texts = [d[3] for d in data] 208 | text_lens = np.array([text.shape[0] for text in texts]) 209 | 210 | texts = pad_1D(texts) 211 | 212 | quasi_flag = [d[4] for d in data] 213 | quasi_flag = pad_1D(quasi_flag) 214 | 215 | return ids, raw_texts, speakers, texts, text_lens, max(text_lens), quasi_flag 216 | 217 | 218 | if __name__ == "__main__": 219 | # Test 220 | import torch 221 | import yaml 222 | from torch.utils.data import DataLoader 223 | from utils.tools import to_device 224 | 225 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 226 | preprocess_config = yaml.load( 227 | open("./config/LJSpeech/preprocess.yaml", "r"), Loader=yaml.FullLoader 228 | ) 229 | train_config = yaml.load( 230 | open("./config/LJSpeech/train.yaml", "r"), Loader=yaml.FullLoader 231 | ) 232 | 233 | train_dataset = Dataset( 234 | "train.txt", preprocess_config, train_config, sort=True, drop_last=True 235 | ) 236 | val_dataset = Dataset( 237 | "val.txt", preprocess_config, train_config, sort=False, drop_last=False 238 | ) 239 | train_loader = DataLoader( 240 | train_dataset, 241 | batch_size=train_config["optimizer"]["batch_size"] * 4, 242 | shuffle=True, 243 | collate_fn=train_dataset.collate_fn, 244 | ) 245 | val_loader = DataLoader( 246 | val_dataset, 247 | batch_size=train_config["optimizer"]["batch_size"], 248 | shuffle=False, 249 | collate_fn=val_dataset.collate_fn, 250 | ) 251 | 252 | n_batch = 0 253 | for batchs in train_loader: 254 | for batch in batchs: 255 | to_device(batch, device) 256 | n_batch += 1 257 | print( 258 | "Training set with size {} is composed of {} batches.".format( 259 | len(train_dataset), n_batch 260 | ) 261 | ) 262 | 263 | n_batch = 0 264 | for batchs in val_loader: 265 | for batch in batchs: 266 | to_device(batch, device) 267 | n_batch += 1 268 | print( 269 | "Validation set with size {} is composed of {} batches.".format( 270 | len(val_dataset), n_batch 271 | ) 272 | ) -------------------------------------------------------------------------------- /hfgan/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 
| if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /hfgan/mel_extractor/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 THUHCSI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /hfgan/mel_extractor/README.md: -------------------------------------------------------------------------------- 1 | # mel_extractor 2 | Extract Mel & Linear spectrograms from a wav file. (Depends on the `librosa` library) 3 | 4 | There are 3 functions provided in this module to help extract the **Mel-spectrogram** and **Linear spectrogram** from a given input waveform. 5 | 6 | 1. `wav2mel_npy(wav, **hparams)` 7 | 8 | **Basic function to extract spectrograms from one wav, or a batch of wavs.** 9 | 10 | It **takes**: 11 | - input wav(s) in `numpy.ndarray` (typically loaded with `librosa`), whose shape might be: 12 | - `(#sample_points,)`: for one single wav 13 | - or `(#batch_size, #sample_points)`: for one batch of wavs 14 | - **DO note that**: input wavs within the same batch need to be reformed (clipped/padded) into the same shape before being stacked together. 15 | - Currently, wavs within a single batch are handled one by one. To process a large number of wave files, the following function `wav2mel` combined with Python concurrency features is recommended. 16 | - parameters related to the spectrogram extraction process (for details about each parameter, please refer to the function's docstring in `mel.py`) 17 | 18 | and **returns** a tuple made up of 3 `numpy.ndarray` objects: 19 | - the extracted Mel-spectrogram, whose shape is `([#batch_size, ]#spec_frame, @n_mels)` 20 | - the extracted Linear spectrogram, whose shape is `([#batch_size, ]#spec_frame, @n_freq)` 21 | - the waveform data (after preprocessing), whose shape is `([#batch_size, ]#sample_points,)` 22 | 23 | 24 | 2. `wav2mel(wavpath, **hparams)` 25 | 26 | **A wrapped up version of the previous function. Usually used for dataset preprocessing.** 27 | 28 | Before calling `wav2mel_npy`, the waveform data is first loaded from the wave file with `librosa`. 29 | 30 | 31 | 3. 
`wav2mel_config(wavpath, config_path)` 32 | 33 | **A wrapped up version of the previous function.** 34 | 35 | Instead of taking dozens of parameters, it only requires the path of the **JSON** configuration file (`config16k.json` and `config22k.json` are shown as examples). 36 | 37 | -------------------------------------------------------------------------------- /hfgan/mel_extractor/config16k.json: -------------------------------------------------------------------------------- 1 | { 2 | "sr": 16000, 3 | 4 | "wav_pad": true, 5 | "wav_pad_mode": 1, 6 | "wav_pad_val": 0.0, 7 | 8 | "wav_rescale": false, 9 | "wav_rescale_max": 0.95, 10 | 11 | "pre_emph": true, 12 | "pre_emph_cof": 0.85, 13 | 14 | "n_fft": 2048, 15 | "hop_size": 200, 16 | "win_size": 800, 17 | 18 | "mag_pow": 1.0, 19 | 20 | "n_mels": 80, 21 | 22 | "fmin": 0.0, 23 | "fmax": 8000.0, 24 | 25 | "spec_ref_db": 20, 26 | "spec_min_db": -115, 27 | 28 | "spec_norm": true, 29 | "spec_max": 4.0, 30 | "spec_sym": true, 31 | "spec_clip": true 32 | } -------------------------------------------------------------------------------- /hfgan/mel_extractor/config22k.json: -------------------------------------------------------------------------------- 1 | { 2 | "sr": 22050, 3 | 4 | "wav_pad": true, 5 | "wav_pad_mode": 1, 6 | "wav_pad_val": 0.0, 7 | 8 | "wav_rescale": false, 9 | "wav_rescale_max": 0.95, 10 | 11 | "pre_emph": true, 12 | "pre_emph_cof": 0.85, 13 | 14 | "n_fft": 2048, 15 | "hop_size": 256, 16 | "win_size": 1024, 17 | 18 | "mag_pow": 1.0, 19 | 20 | "n_mels": 80, 21 | "fmin": 0.0, 22 | "fmax": 8000.0, 23 | 24 | "spec_ref_db": 20, 25 | "spec_min_db": -115, 26 | 27 | "spec_norm": true, 28 | "spec_max": 4.0, 29 | "spec_sym": true, 30 | "spec_clip": true 31 | } -------------------------------------------------------------------------------- /hfgan/mel_extractor/mel.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | from librosa.core import load as loadwav 5 | from librosa.core import stft 6 | from librosa.filters import mel as mel_basis 7 | from librosa.util import normalize 8 | from scipy import signal 9 | 10 | 11 | def _wav_addpadding(wav, hop_size:int, pad_mode:int, pad_val:float): 12 | pad_size = (wav.shape[0] // hop_size + 1) * hop_size - wav.shape[0] 13 | if pad_mode == 1: 14 | l_pad = 0 15 | r_pad = pad_size 16 | else: 17 | l_pad = pad_size // 2 18 | r_pad = pad_size // 2 + pad_size % 2 19 | return np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=pad_val) 20 | 21 | def _wav_preemphasis(wav, k): 22 | return signal.lfilter([1, -k], [1], wav) 23 | 24 | def _spec_normalize(S, clip:bool, symmetric:bool, max_val:float, min_db:float): 25 | if clip: 26 | if symmetric: 27 | return np.clip((2 * max_val) * ((S - min_db) / (-min_db)) - max_val, 28 | -max_val, max_val) 29 | else: 30 | return np.clip(max_val * ((S - min_db) / (-min_db)), 0, max_val) 31 | 32 | assert S.max() <= 0 and S.min() - min_db >= 0 33 | 34 | if symmetric: 35 | return (2 * max_val) * ((S - min_db) / (-min_db)) - max_val 36 | else: 37 | return max_val * ((S - min_db) / (-min_db)) 38 | 39 | def wav2mel_npy(wav, sr=16000, \ 40 | wav_pad=True, wav_pad_mode=1, wav_pad_val=0., \ 41 | wav_rescale=False, wav_rescale_max=0.95, \ 42 | pre_emph=True, pre_emph_cof=0.85, \ 43 | n_fft=2048, hop_size=200, win_size=800, mag_pow=1., \ 44 | n_mels=80, fmin=0., fmax=8000., \ 45 | spec_ref_db=20, spec_min_db=-115, \ 46 | spec_norm=True, spec_max=4., spec_sym=True, spec_clip=True): 47 | """ Extract 
Mel-spectrogram from wav file 48 | 49 | Parameters 50 | ---------- 51 | wavpath : string, pathlib.Path object 52 | path of the source wavfile 53 | 54 | sr : int 55 | target sampling rate 56 | 57 | wav_pad : bool 58 | whether to pad the input waveform for STFT 59 | 60 | wav_pad_mode : integer with value 0 or 1 61 | 0: add paddings to both sides of waveform 62 | 1: add padding to the tail of waveform 63 | 64 | wav_pad_val : float 65 | value used as padding 66 | (default: zero-padding) 67 | 68 | wav_rescale : bool 69 | whether to normalize input waveform into 70 | range [0, @wav_rescale_max] 71 | 72 | wav_rescale_max : float 73 | upper bound of the rescaled waveform 74 | 75 | pre_emph : bool 76 | whether to pre-emphasize waveform before STFT 77 | 78 | pre_emph_cof : float 79 | coefficient of pre-emphasis 80 | 81 | n_fft : int > 0 82 | length of the windowed signal after padding with zeros 83 | 84 | hop_size : int > 0 85 | number of audio samples between adjacent STFT columns 86 | 87 | win_size : int > 0 88 | Each frame of audio is windowed by `window()` of length `win_length` 89 | and then padded with zeros to match `n_fft` 90 | 91 | mag_pow : float > 0 92 | Exponent for the magnitude spectrogram, 93 | e.g., 1 for energy, 2 for power, etc. 94 | 95 | n_mels : int > 0 96 | number of Mel filter bands 97 | 98 | fmin : float >= 0 99 | lowest frequency (in Hz) of Mel filter 100 | 101 | fmax : float >= 0 102 | highest frequency (in Hz) of Mel filter 103 | If `None`, use `fmax = sr / 2.0` 104 | 105 | spec_ref_db : int 106 | "reference" level of mel spectrogram in dB, 107 | which is subtracted from the raw spectrogram 108 | 109 | spec_min_db : int 110 | minimum of the spectrogram in dB 111 | values below it are clipped 112 | 113 | spec_norm : bool 114 | whether to normalize generated spectrogram 115 | 116 | spec_max : float 117 | the supremum of the absolute value of normalized spectrogram 118 | 119 | spec_sym : bool 120 | whether to normalize spectrogram to a range 121 | that is symmetric by 0. 122 | i.e. eihter [-spec_max, spec_max], or [0, spec_max] 123 | 124 | spec_clip : bool 125 | whethter to allow clipping while doing normalization 126 | 127 | Returns 128 | ------- 129 | mel: np.ndarray [shape=(#spec_frame, @n_mels)] 130 | Extracted Mel spectrogram 131 | 132 | linear: np.ndarray [shape=(#spec_frame, #n_freq), with #n_freq = @n_fft / 2 + 1] 133 | Extracted Linear spectrogram 134 | 135 | wav: np.ndarray [shape=(#sample_point,)] 136 | Waveform after preprocessing 137 | 138 | """ 139 | 140 | # 1. if a batch of wave array is provided 141 | # do wav2mel for all wav samples in one single batch 142 | if len(wav.shape) > 1: 143 | result = [wav2mel_npy(single_wav) for single_wav in wav] 144 | return [np.stack(items) for items in list(zip(*result))] 145 | 146 | # 2. waveform preprocessing 147 | # 2.1 [optional] padding 148 | if wav_pad: 149 | wav = _wav_addpadding(wav, \ 150 | hop_size=hop_size, pad_mode=wav_pad_mode, pad_val=wav_pad_val) 151 | 152 | # 2.2 [optional] normalization 153 | if wav_rescale: 154 | wav = normalize(wav) * wav_rescale_max 155 | 156 | # 3. 
spectrogram generation 157 | # 3.1 [optional] preemphasize waveform 158 | if pre_emph: 159 | pre_emphed_wav = _wav_preemphasis(wav, pre_emph_cof) 160 | else: 161 | pre_emphed_wav = wav 162 | 163 | # 3.2 magnitude spectrogram retrieve 164 | S = np.abs( \ 165 | stft(pre_emphed_wav, n_fft=n_fft, hop_length=hop_size, win_length=win_size) 166 | ).T ** mag_pow 167 | mel_S = np.dot(S, mel_basis(sr, n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax).T) 168 | 169 | # 3.3 convert magnitude spectrogram to dB-scaled units 170 | min_amp = np.power(10, spec_min_db/20) 171 | linear = 20 * np.log10(np.maximum(min_amp, S)) - spec_ref_db 172 | mel = 20 * np.log10(np.maximum(min_amp, mel_S)) - spec_ref_db 173 | 174 | # 3.4 [optional] spectrogram normalization 175 | if spec_norm: 176 | mel = _spec_normalize(mel, clip=spec_clip, symmetric=spec_sym, \ 177 | max_val=spec_max, min_db=spec_min_db) 178 | linear = _spec_normalize(linear, clip=spec_clip, symmetric=spec_sym, \ 179 | max_val=spec_max, min_db=spec_min_db) 180 | 181 | return mel, linear, wav 182 | 183 | def wav2mel(wavpath, sr=16000, **hparams): 184 | # load wavfile into float waveform 185 | wav, _sr = loadwav(wavpath, sr=sr) 186 | return wav2mel_npy(wav, sr=_sr, **hparams) 187 | 188 | 189 | def wav2mel_config(wavpath, config_path): 190 | with open(config_path, 'r') as f: 191 | hparams = json.load(f) 192 | return wav2mel(wavpath, **hparams) 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /hfgan/meldataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | import json 7 | import librosa 8 | import numpy as np 9 | import torch 10 | import torch.utils.data 11 | from librosa.filters import mel as librosa_mel_fn 12 | from librosa.util import normalize 13 | from scipy.io.wavfile import read 14 | 15 | MAX_WAV_VALUE = 32768.0 16 | 17 | 18 | def load_wav(full_path): 19 | sampling_rate, data = read(full_path) 20 | return data, sampling_rate 21 | 22 | 23 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 24 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 25 | 26 | 27 | def dynamic_range_decompression(x, C=1): 28 | return np.exp(x) / C 29 | 30 | 31 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 32 | return torch.log10(torch.clamp(x, min=clip_val) * C) 33 | 34 | 35 | def dynamic_range_decompression_torch(x, C=1): 36 | return torch.exp(x) / C 37 | 38 | 39 | def spectral_normalize_torch(magnitudes): 40 | output = dynamic_range_compression_torch(magnitudes) 41 | return output 42 | 43 | 44 | def spectral_de_normalize_torch(magnitudes): 45 | output = dynamic_range_decompression_torch(magnitudes) 46 | return output 47 | 48 | 49 | mel_basis = {} 50 | hann_window = {} 51 | 52 | 53 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 54 | # if torch.min(y) < -1.: 55 | # print('min value is ', torch.min(y)) 56 | # if torch.max(y) > 1.: 57 | # print('max value is ', torch.max(y)) 58 | 59 | global mel_basis, hann_window 60 | if fmax not in mel_basis: 61 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 62 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 63 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 64 | 65 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 66 | y = y.squeeze(1) 67 | 68 | 
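    # NOTE: the `if fmax not in mel_basis` check above never matches the key actually
    # stored (f"{fmax}_{device}"), so the mel filterbank and Hann window are rebuilt on
    # every call; this mirrors the upstream HiFi-GAN code and only costs speed, not correctness.
    # The torch.stft call below relies on the pre-1.8 PyTorch behaviour of returning a real
    # tensor with real/imaginary parts stacked in the last dimension, which is why the
    # magnitude is recovered as torch.sqrt(spec.pow(2).sum(-1) + 1e-9). On newer PyTorch
    # versions an equivalent form (an assumption, not part of this repo) would be:
    #   spec = torch.stft(..., return_complex=True).abs()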
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 69 | center=center, pad_mode='reflect', normalized=False, onesided=True) 70 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 71 | 72 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 73 | spec = spectral_normalize_torch(spec) 74 | 75 | return spec 76 | 77 | 78 | def get_dataset_filelist(a): 79 | with open(a.input_training_file, 'r', encoding='utf-8') as fi: 80 | training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') 81 | for x in fi.read().split('\n') if len(x) > 0] 82 | 83 | with open(a.input_validation_file, 'r', encoding='utf-8') as fi: 84 | validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') 85 | for x in fi.read().split('\n') if len(x) > 0] 86 | return training_files, validation_files 87 | 88 | def get_dataset_filelist_DB6(a): 89 | with open(a.input_training_file, 'r', encoding='utf-8-sig') as f: 90 | training_files = [l.split("|")[2:4] for l in f.readlines()] 91 | with open(a.input_validation_file, 'r', encoding='utf-8-sig') as f: 92 | validation_files = [l.split("|")[2:4] for l in f.readlines()] 93 | return training_files, validation_files 94 | 95 | def get_dataset_filelist_DBpara(a): 96 | with open(a.input_training_file, 'r', encoding='utf-8-sig') as f: 97 | training_files = [l.split("|")[0]+".npy" for l in f.readlines()] 98 | with open(a.input_validation_file, 'r', encoding='utf-8-sig') as f: 99 | validation_files = [l.split("|")[0]+".npy" for l in f.readlines()] 100 | return training_files, validation_files 101 | 102 | class MelDataset(torch.utils.data.Dataset): 103 | def __init__(self, metadata, segment_size, n_fft, num_mels, 104 | hop_size, win_size, sampling_rate, fmin, fmax, audio_dir, mel_dir=None, split=False, shuffle=True, n_cache_reuse=1, 105 | device=None, fmax_loss=None, fine_tuning=False, win_center=False): 106 | 107 | self.audio_files = metadata 108 | self.audio_dir = Path(audio_dir) 109 | self.mel_dir = Path(mel_dir) 110 | 111 | random.seed(1234) 112 | if shuffle: 113 | random.shuffle(self.audio_files) 114 | 115 | self.segment_size = segment_size 116 | self.sampling_rate = sampling_rate 117 | self.split = split 118 | self.n_fft = n_fft 119 | self.num_mels = num_mels 120 | self.hop_size = hop_size 121 | self.win_size = win_size 122 | self.fmin = fmin 123 | self.fmax = fmax 124 | self.fmax_loss = fmax_loss 125 | 126 | self.cached_wav = None 127 | self.n_cache_reuse = n_cache_reuse 128 | self._cache_ref_count = 0 129 | self.device = device 130 | self.fine_tuning = fine_tuning 131 | self.win_center = win_center 132 | 133 | def __getitem__(self, index): 134 | # modified Aug 25, 2021 135 | # audio_filepath, mel_filepath = self.audio_files[index] 136 | # audio_filepath = self.audio_dir / audio_filepath 137 | # mel_filepath = self.mel_dir / mel_filepath 138 | filepath = self.audio_files[index] 139 | audio_filepath = self.audio_dir / ("audio-"+filepath) 140 | if not audio_filepath.is_file(): 141 | audio_filepath = self.audio_dir / filepath 142 | mel_filepath = self.mel_dir / ("mel-"+filepath) 143 | if not mel_filepath.is_file(): 144 | mel_filepath = self.mel_dir / filepath 145 | 146 | if self._cache_ref_count == 0: 147 | audio = np.load(audio_filepath) 148 | self.cached_wav = audio 149 | self._cache_ref_count = self.n_cache_reuse 150 | else: 151 | audio = self.cached_wav 152 | self._cache_ref_count -= 1 153 | 154 | if not self.fine_tuning: 155 | raise ValueError("Non-fine-tuning training not supported.") 156 | 
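        # Only the fine-tuning path is supported: ground-truth mel-spectrograms are assumed
        # to be precomputed and stored as .npy files in mel_dir, so the branch below simply
        # loads them and transposes to [n_mels, frames].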
else: 157 | mel = np.load(mel_filepath).T 158 | 159 | # self.split = True : training 160 | # self.split = False : validation/inference (batch size = 1) 161 | if self.split: 162 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 163 | 164 | if audio.shape[0] >= self.segment_size: 165 | mel_start = random.randint(0, mel.shape[1] - frames_per_seg - 1) 166 | mel = mel[:, mel_start:mel_start + frames_per_seg] 167 | audio = audio[mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size] 168 | else: 169 | mel = np.pad(mel, ((0,0), (0, frames_per_seg - mel.shape[1])), mode='constant') 170 | audio = np.pad(audio, (0, self.segment_size - audio.shape[0]), mode='constant') 171 | else: 172 | audio = np.pad(audio, (0, mel.shape[1] * self.hop_size - audio.shape[0]), mode='constant') 173 | 174 | audio = torch.FloatTensor(audio) 175 | mel_loss = mel_spectrogram(audio.unsqueeze(0), self.n_fft, self.num_mels, 176 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss, 177 | center=self.win_center) 178 | mel = torch.FloatTensor(mel) 179 | 180 | return (mel.squeeze(), audio.squeeze(), mel_loss.squeeze()) 181 | 182 | def __len__(self): 183 | return len(self.audio_files) 184 | -------------------------------------------------------------------------------- /hfgan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d 5 | from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm 6 | 7 | from .utils import get_padding, init_weights 8 | 9 | LRELU_SLOPE = 0.1 10 | 11 | 12 | class ResBlock1(torch.nn.Module): 13 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 14 | super(ResBlock1, self).__init__() 15 | self.h = h 16 | self.convs1 = nn.ModuleList([ 17 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 18 | padding=get_padding(kernel_size, dilation[0]))), 19 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 20 | padding=get_padding(kernel_size, dilation[1]))), 21 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 22 | padding=get_padding(kernel_size, dilation[2]))) 23 | ]) 24 | self.convs1.apply(init_weights) 25 | 26 | self.convs2 = nn.ModuleList([ 27 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 28 | padding=get_padding(kernel_size, 1))), 29 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 30 | padding=get_padding(kernel_size, 1))), 31 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 32 | padding=get_padding(kernel_size, 1))) 33 | ]) 34 | self.convs2.apply(init_weights) 35 | 36 | def forward(self, x): 37 | for c1, c2 in zip(self.convs1, self.convs2): 38 | xt = F.leaky_relu(x, LRELU_SLOPE) 39 | xt = c1(xt) 40 | xt = F.leaky_relu(xt, LRELU_SLOPE) 41 | xt = c2(xt) 42 | x = xt + x 43 | return x 44 | 45 | def remove_weight_norm(self): 46 | for l in self.convs1: 47 | remove_weight_norm(l) 48 | for l in self.convs2: 49 | remove_weight_norm(l) 50 | 51 | 52 | class ResBlock2(torch.nn.Module): 53 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 54 | super(ResBlock2, self).__init__() 55 | self.h = h 56 | self.convs = nn.ModuleList([ 57 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 58 | padding=get_padding(kernel_size, dilation[0]))), 59 | weight_norm(Conv1d(channels, 
channels, kernel_size, 1, dilation=dilation[1], 60 | padding=get_padding(kernel_size, dilation[1]))) 61 | ]) 62 | self.convs.apply(init_weights) 63 | 64 | def forward(self, x): 65 | for c in self.convs: 66 | xt = F.leaky_relu(x, LRELU_SLOPE) 67 | xt = c(xt) 68 | x = xt + x 69 | return x 70 | 71 | def remove_weight_norm(self): 72 | for l in self.convs: 73 | remove_weight_norm(l) 74 | 75 | 76 | class Generator(torch.nn.Module): 77 | def __init__(self, h): 78 | super(Generator, self).__init__() 79 | self.h = h 80 | self.num_kernels = len(h.resblock_kernel_sizes) 81 | self.num_upsamples = len(h.upsample_rates) 82 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) 83 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 84 | 85 | self.ups = nn.ModuleList() 86 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 87 | self.ups.append(weight_norm( 88 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 89 | # k, u, padding=(k-u)//2))) # for 22kHz version 90 | k, u, padding=(u//2 + u%2), output_padding=u%2))) # for 16kHz version 91 | 92 | self.resblocks = nn.ModuleList() 93 | for i in range(len(self.ups)): 94 | ch = h.upsample_initial_channel//(2**(i+1)) 95 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 96 | self.resblocks.append(resblock(h, ch, k, d)) 97 | 98 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 99 | self.ups.apply(init_weights) 100 | self.conv_post.apply(init_weights) 101 | 102 | def forward(self, x): 103 | x = self.conv_pre(x) 104 | for i in range(self.num_upsamples): 105 | x = F.leaky_relu(x, LRELU_SLOPE) 106 | x = self.ups[i](x) 107 | xs = None 108 | for j in range(self.num_kernels): 109 | if xs is None: 110 | xs = self.resblocks[i*self.num_kernels+j](x) 111 | else: 112 | xs += self.resblocks[i*self.num_kernels+j](x) 113 | x = xs / self.num_kernels 114 | x = F.leaky_relu(x) 115 | x = self.conv_post(x) 116 | x = torch.tanh(x) 117 | 118 | return x 119 | 120 | def remove_weight_norm(self): 121 | print('Removing weight norm...') 122 | for l in self.ups: 123 | remove_weight_norm(l) 124 | for l in self.resblocks: 125 | l.remove_weight_norm() 126 | remove_weight_norm(self.conv_pre) 127 | remove_weight_norm(self.conv_post) 128 | 129 | 130 | class DiscriminatorP(torch.nn.Module): 131 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 132 | super(DiscriminatorP, self).__init__() 133 | self.period = period 134 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 135 | self.convs = nn.ModuleList([ 136 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 137 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 138 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 139 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 140 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 141 | ]) 142 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 143 | 144 | def forward(self, x): 145 | fmap = [] 146 | 147 | # 1d to 2d 148 | b, c, t = x.shape 149 | if t % self.period != 0: # pad first 150 | n_pad = self.period - (t % self.period) 151 | x = F.pad(x, (0, n_pad), "reflect") 152 | t = t + n_pad 153 | x = x.view(b, c, t // self.period, self.period) 154 | 155 | for l in self.convs: 156 | x = l(x) 157 | x = F.leaky_relu(x, 
LRELU_SLOPE) 158 | fmap.append(x) 159 | x = self.conv_post(x) 160 | fmap.append(x) 161 | x = torch.flatten(x, 1, -1) 162 | 163 | return x, fmap 164 | 165 | 166 | class MultiPeriodDiscriminator(torch.nn.Module): 167 | def __init__(self): 168 | super(MultiPeriodDiscriminator, self).__init__() 169 | self.discriminators = nn.ModuleList([ 170 | DiscriminatorP(2), 171 | DiscriminatorP(3), 172 | DiscriminatorP(5), 173 | DiscriminatorP(7), 174 | DiscriminatorP(11), 175 | ]) 176 | 177 | def forward(self, y, y_hat): 178 | y_d_rs = [] 179 | y_d_gs = [] 180 | fmap_rs = [] 181 | fmap_gs = [] 182 | for i, d in enumerate(self.discriminators): 183 | y_d_r, fmap_r = d(y) 184 | y_d_g, fmap_g = d(y_hat) 185 | y_d_rs.append(y_d_r) 186 | fmap_rs.append(fmap_r) 187 | y_d_gs.append(y_d_g) 188 | fmap_gs.append(fmap_g) 189 | 190 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 191 | 192 | 193 | class DiscriminatorS(torch.nn.Module): 194 | def __init__(self, use_spectral_norm=False): 195 | super(DiscriminatorS, self).__init__() 196 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 197 | self.convs = nn.ModuleList([ 198 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 199 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 200 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 201 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 202 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 203 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 204 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 205 | ]) 206 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 207 | 208 | def forward(self, x): 209 | fmap = [] 210 | for l in self.convs: 211 | x = l(x) 212 | x = F.leaky_relu(x, LRELU_SLOPE) 213 | fmap.append(x) 214 | x = self.conv_post(x) 215 | fmap.append(x) 216 | x = torch.flatten(x, 1, -1) 217 | 218 | return x, fmap 219 | 220 | 221 | class MultiScaleDiscriminator(torch.nn.Module): 222 | def __init__(self): 223 | super(MultiScaleDiscriminator, self).__init__() 224 | self.discriminators = nn.ModuleList([ 225 | DiscriminatorS(use_spectral_norm=True), 226 | DiscriminatorS(), 227 | DiscriminatorS(), 228 | ]) 229 | self.meanpools = nn.ModuleList([ 230 | AvgPool1d(4, 2, padding=2), 231 | AvgPool1d(4, 2, padding=2) 232 | ]) 233 | 234 | def forward(self, y, y_hat): 235 | y_d_rs = [] 236 | y_d_gs = [] 237 | fmap_rs = [] 238 | fmap_gs = [] 239 | for i, d in enumerate(self.discriminators): 240 | if i != 0: 241 | y = self.meanpools[i-1](y) 242 | y_hat = self.meanpools[i-1](y_hat) 243 | y_d_r, fmap_r = d(y) 244 | y_d_g, fmap_g = d(y_hat) 245 | y_d_rs.append(y_d_r) 246 | fmap_rs.append(fmap_r) 247 | y_d_gs.append(y_d_g) 248 | fmap_gs.append(fmap_g) 249 | 250 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 251 | 252 | 253 | def feature_loss(fmap_r, fmap_g): 254 | loss = 0 255 | for dr, dg in zip(fmap_r, fmap_g): 256 | for rl, gl in zip(dr, dg): 257 | loss += torch.mean(torch.abs(rl - gl)) 258 | 259 | return loss*2 260 | 261 | 262 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 263 | loss = 0 264 | r_losses = [] 265 | g_losses = [] 266 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 267 | r_loss = torch.mean((1-dr)**2) 268 | g_loss = torch.mean(dg**2) 269 | loss += (r_loss + g_loss) 270 | r_losses.append(r_loss.item()) 271 | g_losses.append(g_loss.item()) 272 | 273 | return loss, r_losses, g_losses 274 | 275 | 276 | def generator_loss(disc_outputs): 277 | loss = 0 278 | gen_losses = [] 279 | for dg in disc_outputs: 280 | l = torch.mean((1-dg)**2) 281 | 
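        # Least-squares GAN objective (as in HiFi-GAN): the generator term above is
        # mean((1 - D(G(s)))^2), pushing generated audio towards the real label 1,
        # mirroring discriminator_loss above, which uses mean((1 - D(x))^2) + mean(D(G(s))^2).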
gen_losses.append(l) 282 | loss += l 283 | 284 | return loss, gen_losses 285 | 286 | -------------------------------------------------------------------------------- /hfgan/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def scan_checkpoint(cp_dir, prefix): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern) 55 | if len(cp_list) == 0: 56 | return None 57 | return sorted(cp_list)[-1] 58 | 59 | -------------------------------------------------------------------------------- /hfgan/vocoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | import argparse 5 | import glob 6 | import json 7 | import os 8 | from pathlib import Path 9 | from numpy import isin 10 | 11 | import torch 12 | from scipy.io.wavfile import write 13 | from tqdm import tqdm 14 | 15 | from .env import AttrDict 16 | from .mel_extractor.mel import wav2mel 17 | from .meldataset import MAX_WAV_VALUE 18 | from .models import Generator 19 | 20 | 21 | class Vocoder: 22 | def __init__(self, checkpoint_file, config_file="config.json", device="cpu"): 23 | checkpoint_file = Path(checkpoint_file) 24 | checkpoint_dir = checkpoint_file.parent 25 | self.config = self.load_config(checkpoint_dir / config_file) 26 | self.generator = self.load_generator(checkpoint_file, self.config, device) 27 | self.device = device 28 | 29 | 30 | def load_config(self, config_file): 31 | with open(config_file) as f: 32 | data = f.read() 33 | json_config = json.loads(data) 34 | return AttrDict(json_config) 35 | 36 | def load_generator(self, checkpoint_file, config, device): 37 | torch.manual_seed(config.seed) 38 | 39 | if isinstance(device, str): 40 | if device.startswith("cuda") and torch.cuda.is_available(): 41 | torch.cuda.manual_seed(config.seed) 42 | device = torch.device('cuda') 43 | else: 44 | device = torch.device('cpu') 45 | 46 | print("Loading '{}'".format(checkpoint_file)) 47 | assert checkpoint_file.is_file() 48 | generator = Generator(config).to(device) 49 | checkpoint_dict = torch.load(checkpoint_file, map_location=device) 50 | 
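        # map_location=device allows a GPU-trained HiFi-GAN checkpoint to be loaded on CPU;
        # the checkpoint is expected to store the generator weights under the 'generator' key,
        # as consumed on the next line.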
generator.load_state_dict(checkpoint_dict['generator']) 51 | generator.eval() 52 | generator.remove_weight_norm() 53 | print("Complete.") 54 | 55 | return generator 56 | 57 | def mel2wav(self, mel, output_file=None): 58 | with torch.no_grad(): 59 | if not isinstance(mel, torch.Tensor): 60 | mel = torch.FloatTensor(mel.T) 61 | x = mel.unsqueeze(0).to(self.device) 62 | y_g_hat = self.generator(x) 63 | audio = y_g_hat.squeeze() 64 | audio = audio.cpu().numpy() 65 | audio = (audio * MAX_WAV_VALUE).astype('int16') 66 | 67 | if output_file: 68 | write(output_file, self.config.sampling_rate, audio) 69 | return audio 70 | 71 | 72 | def main(): 73 | print('Initializing Inference Process..') 74 | 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('-i', '--input_wavs_dir', required=True) 77 | parser.add_argument('-o', '--output_dir', default='generated_files') 78 | parser.add_argument('--checkpoint_file', required=True) 79 | args = parser.parse_args() 80 | 81 | input_wavs_dir = Path(args.input_wavs_dir) 82 | output_wavs_dir = Path(args.output_dir) 83 | output_wavs_dir.mkdir(exist_ok=True) 84 | 85 | 86 | synthesizer = Vocoder(args.checkpoint_file, device="cpu") 87 | 88 | mels = [(wav2mel(wavfile)[0], wavfile.name) for wavfile in input_wavs_dir.iterdir()] 89 | for mel, savename in tqdm(mels): 90 | synthesizer.mel2wav(mel, output_file=str(output_wavs_dir/savename)) 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | 96 | -------------------------------------------------------------------------------- /hfgan/vocoderutils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def scan_checkpoint(cp_dir, prefix): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern) 55 | if len(cp_list) == 0: 56 | return None 57 | return sorted(cp_list)[-1] 58 | 59 | -------------------------------------------------------------------------------- /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuhcsi/DiffVar/5380ca314dfe368cf3760d33f5ab9c142b01db0f/img/model.png -------------------------------------------------------------------------------- /model/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .fastspeech2 import FastSpeech2 2 | from .loss import FastSpeech2Loss 3 | from .optimizer import ScheduledOptim 4 | from .diffspeech import GaussianDiffusion 5 | from .diffvar import DiffVariancePredictor -------------------------------------------------------------------------------- /model/diffnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/NATSpeech/NATSpeech/blob/238165e8cd430531b69c484cabb032c1313ee73b/modules/tts/diffspeech/net.py 3 | """ 4 | 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from math import sqrt 12 | 13 | Linear = nn.Linear 14 | ConvTranspose2d = nn.ConvTranspose2d 15 | 16 | 17 | class Mish(nn.Module): 18 | def forward(self, x): 19 | return x * torch.tanh(F.softplus(x)) 20 | 21 | 22 | class SinusoidalPosEmb(nn.Module): 23 | def __init__(self, dim): 24 | super().__init__() 25 | self.dim = dim 26 | 27 | def forward(self, x): 28 | device = x.device 29 | half_dim = self.dim // 2 30 | emb = math.log(10000) / (half_dim - 1) 31 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 32 | emb = x[:, None] * emb[None, :] 33 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 34 | return emb 35 | 36 | 37 | def Conv1d(*args, **kwargs): 38 | layer = nn.Conv1d(*args, **kwargs) 39 | nn.init.kaiming_normal_(layer.weight) 40 | return layer 41 | 42 | 43 | class ResidualBlock(nn.Module): 44 | def __init__(self, encoder_hidden, residual_channels, dilation): 45 | super().__init__() 46 | self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) 47 | self.diffusion_projection = Linear(residual_channels, residual_channels) 48 | self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) 49 | self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) 50 | 51 | def forward(self, x, conditioner, diffusion_step): 52 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 53 | conditioner = self.conditioner_projection(conditioner) 54 | y = x + diffusion_step 55 | 56 | y = self.dilated_conv(y) + conditioner 57 | 58 | gate, filter = torch.chunk(y, 2, dim=1) 59 | y = torch.sigmoid(gate) * torch.tanh(filter) 60 | 61 | y = self.output_projection(y) 62 | residual, skip = torch.chunk(y, 2, dim=1) 63 | return (x + residual) / sqrt(2.0), skip 64 | 65 | 66 | class DiffNet(nn.Module): 67 | def __init__(self, in_dims, hidden_size, residual_layers, 68 | residual_channels, dilation_cycle_length): 69 | super().__init__() 70 | 71 | self.encoder_hidden = hidden_size 72 | self.residual_layers = residual_layers 73 | self.residual_channels = residual_channels 74 | self.dilation_cycle_length = dilation_cycle_length 75 | 76 | self.input_projection = Conv1d(in_dims, self.residual_channels, 1) 77 | self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels) 78 | dim = self.residual_channels 79 | self.mlp = nn.Sequential( 80 | nn.Linear(dim, dim * 4), 81 | Mish(), 82 | nn.Linear(dim * 4, dim) 83 | ) 84 | self.residual_layers = nn.ModuleList([ 85 | ResidualBlock(self.encoder_hidden, self.residual_channels, 2 ** (i % self.dilation_cycle_length)) 86 | for i in range(self.residual_layers) 87 | ]) 88 | self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1) 89 | self.output_projection = Conv1d(self.residual_channels, in_dims, 1) 90 | 
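        # Zero-initialising the output projection (next line) makes the untrained network
        # predict zero noise at the start of training, a common stabilisation trick carried
        # over from DiffWave/DiffSpeech.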
nn.init.zeros_(self.output_projection.weight) 91 | 92 | def forward(self, spec, diffusion_step, cond): 93 | """ 94 | 95 | :param spec: [B, 1, M, T] 96 | :param diffusion_step: [B, 1] 97 | :param cond: [B, M, T] 98 | :return: 99 | """ 100 | x = spec[:, 0] 101 | x = self.input_projection(x) # x [B, residual_channel, T] 102 | 103 | x = F.relu(x) 104 | diffusion_step = self.diffusion_embedding(diffusion_step) 105 | diffusion_step = self.mlp(diffusion_step) 106 | skip = [] 107 | for layer_id, layer in enumerate(self.residual_layers): 108 | x, skip_connection = layer(x, cond, diffusion_step) 109 | skip.append(skip_connection) 110 | 111 | x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) 112 | x = self.skip_projection(x) 113 | x = F.relu(x) 114 | x = self.output_projection(x) # [B, 80, T] 115 | return x[:, None, :, :] -------------------------------------------------------------------------------- /model/diffspeech.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/NATSpeech/NATSpeech/blob/238165e8cd430531b69c484cabb032c1313ee73b/modules/tts/diffspeech/shallow_diffusion_tts.py 3 | """ 4 | 5 | import torch.nn as nn 6 | from model.fastspeech2 import FastSpeech2 7 | 8 | import math 9 | import random 10 | from functools import partial 11 | from inspect import isfunction 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn 16 | from tqdm import tqdm 17 | 18 | # from modules.tts.fs2_orig import FastSpeech2Orig 19 | # from modules.tts.diffspeech.net import DiffNet 20 | # from modules.tts.commons.align_ops import expand_states 21 | from model.diffnet import DiffNet 22 | 23 | 24 | def exists(x): 25 | return x is not None 26 | 27 | 28 | def default(val, d): 29 | if exists(val): 30 | return val 31 | return d() if isfunction(d) else d 32 | 33 | 34 | # gaussian diffusion trainer class 35 | 36 | def extract(a, t, x_shape): 37 | b, *_ = t.shape 38 | out = a.gather(-1, t) 39 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 40 | 41 | 42 | def noise_like(shape, device, repeat=False): 43 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 44 | noise = lambda: torch.randn(shape, device=device) 45 | return repeat_noise() if repeat else noise() 46 | 47 | 48 | def linear_beta_schedule(timesteps, max_beta=0.01): 49 | """ 50 | linear schedule 51 | """ 52 | betas = np.linspace(1e-4, max_beta, timesteps) 53 | return betas 54 | 55 | 56 | def cosine_beta_schedule(timesteps, s=0.008): 57 | """ 58 | cosine schedule 59 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 60 | """ 61 | steps = timesteps + 1 62 | x = np.linspace(0, steps, steps) 63 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 64 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 65 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 66 | return np.clip(betas, a_min=0, a_max=0.999) 67 | 68 | 69 | beta_schedule = { 70 | "cosine": cosine_beta_schedule, 71 | "linear": linear_beta_schedule, 72 | } 73 | 74 | 75 | DIFF_DECODERS = { 76 | 'wavenet': DiffNet, 77 | } 78 | 79 | 80 | # class AuxModel(FastSpeech2Orig): 81 | # def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None, 82 | # f0=None, uv=None, energy=None, infer=False, **kwargs): 83 | # ret = {} 84 | # encoder_out = self.encoder(txt_tokens) # [B, T, C] 85 | # src_nonpadding = (txt_tokens > 0).float()[:, :, None] 86 | # style_embed = 
self.forward_style_embed(spk_embed, spk_id) 87 | 88 | # # add dur 89 | # dur_inp = (encoder_out + style_embed) * src_nonpadding 90 | # mel2ph = self.forward_dur(dur_inp, mel2ph, txt_tokens, ret) 91 | # tgt_nonpadding = (mel2ph > 0).float()[:, :, None] 92 | # decoder_inp = decoder_inp_ = expand_states(encoder_out, mel2ph) 93 | 94 | # # add pitch and energy embed 95 | # if self.hparams['use_pitch_embed']: 96 | # pitch_inp = (decoder_inp_ + style_embed) * tgt_nonpadding 97 | # decoder_inp = decoder_inp + self.forward_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out) 98 | 99 | # # add pitch and energy embed 100 | # if self.hparams['use_energy_embed']: 101 | # energy_inp = (decoder_inp_ + style_embed) * tgt_nonpadding 102 | # decoder_inp = decoder_inp + self.forward_energy(energy_inp, energy, ret) 103 | 104 | # # decoder input 105 | # ret['decoder_inp'] = decoder_inp = (decoder_inp + style_embed) * tgt_nonpadding 106 | # if self.hparams['dec_inp_add_noise']: 107 | # B, T, _ = decoder_inp.shape 108 | # z = kwargs.get('adv_z', torch.randn([B, T, self.z_channels])).to(decoder_inp.device) 109 | # ret['adv_z'] = z 110 | # decoder_inp = torch.cat([decoder_inp, z], -1) 111 | # decoder_inp = self.dec_inp_noise_proj(decoder_inp) * tgt_nonpadding 112 | # if kwargs['skip_decoder']: 113 | # return ret 114 | # ret['mel_out'] = self.forward_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs) 115 | # return ret 116 | 117 | 118 | class GaussianDiffusion(nn.Module): 119 | def __init__(self, preprocess_config, model_config): 120 | # def __init__(self, dict_size, hparams, out_dims=None): 121 | super().__init__() 122 | 123 | hparams = model_config["diffusion"] 124 | self.hparams = hparams 125 | self.mel_bins = preprocess_config["preprocessing"]["mel"]["n_mel_channels"] 126 | # out_dims = hparams['audio_num_mel_bins'] 127 | denoise_fn = DIFF_DECODERS[hparams['diff_decoder_type']]( 128 | in_dims = self.mel_bins, 129 | hidden_size = model_config["transformer"]["decoder_hidden"], 130 | residual_layers = hparams["residual_layers"], 131 | residual_channels = hparams["residual_channels"], 132 | dilation_cycle_length = hparams["dilation_cycle_length"] 133 | ) 134 | timesteps = hparams['timesteps'] 135 | K_step = hparams['K_step'] 136 | loss_type = hparams['diff_loss_type'] 137 | # spec_min = hparams['spec_min'] 138 | # spec_max = hparams['spec_max'] 139 | 140 | self.denoise_fn = denoise_fn 141 | # self.fs2 = AuxModel(dict_size, hparams) 142 | self.fs2 = FastSpeech2(preprocess_config, model_config) 143 | pre_trained_fs2_ckpt = torch.load(hparams["fs2_path"]) 144 | self.fs2.load_state_dict(pre_trained_fs2_ckpt["model"], strict=True) 145 | for k, v in self.fs2.named_parameters(): 146 | v.requires_grad = False 147 | # self.mel_bins = out_dims 148 | 149 | if hparams['schedule_type'] == 'linear': 150 | betas = linear_beta_schedule(timesteps, hparams['max_beta']) 151 | else: 152 | betas = cosine_beta_schedule(timesteps) 153 | 154 | alphas = 1. 
- betas 155 | alphas_cumprod = np.cumprod(alphas, axis=0) 156 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 157 | 158 | timesteps, = betas.shape 159 | self.num_timesteps = int(timesteps) 160 | self.K_step = K_step 161 | self.loss_type = loss_type 162 | 163 | to_torch = partial(torch.tensor, dtype=torch.float32) 164 | 165 | self.register_buffer('betas', to_torch(betas)) 166 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 167 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 168 | 169 | # calculations for diffusion q(x_t | x_{t-1}) and others 170 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 171 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 172 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 173 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 174 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 175 | 176 | # calculations for posterior q(x_{t-1} | x_t, x_0) 177 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 178 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 179 | self.register_buffer('posterior_variance', to_torch(posterior_variance)) 180 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 181 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 182 | self.register_buffer('posterior_mean_coef1', to_torch( 183 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 184 | self.register_buffer('posterior_mean_coef2', to_torch( 185 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 186 | 187 | # self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']]) 188 | # self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']]) 189 | 190 | def q_mean_variance(self, x_start, t): 191 | mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 192 | variance = extract(1. - self.alphas_cumprod, t, x_start.shape) 193 | log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) 194 | return mean, variance, log_variance 195 | 196 | def predict_start_from_noise(self, x_t, t, noise): 197 | return ( 198 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 199 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 200 | ) 201 | 202 | def q_posterior(self, x_start, x_t, t): 203 | posterior_mean = ( 204 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 205 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 206 | ) 207 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 208 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 209 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 210 | 211 | def p_mean_variance(self, x, t, cond, clip_denoised: bool): 212 | noise_pred = self.denoise_fn(x, t, cond=cond) 213 | x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) 214 | 215 | if clip_denoised: 216 | x_recon.clamp_(-1., 1.) 
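# x_recon is the estimate of the clean target x_0 recovered from the predicted
# noise; clamping it assumes the diffusion target was normalised to [-1, 1]
# (see norm_spec below).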
217 | 218 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 219 | return model_mean, posterior_variance, posterior_log_variance 220 | 221 | @torch.no_grad() 222 | def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): 223 | b, *_, device = *x.shape, x.device 224 | model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised) 225 | noise = noise_like(x.shape, device, repeat_noise) 226 | # no noise when t == 0 227 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 228 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 229 | 230 | def q_sample(self, x_start, t, noise=None): 231 | noise = default(noise, lambda: torch.randn_like(x_start)) 232 | return ( 233 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 234 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 235 | ) 236 | 237 | def p_losses(self, x_start, t, cond, noise=None, nonpadding=None): 238 | noise = default(noise, lambda: torch.randn_like(x_start)) 239 | 240 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 241 | x_recon = self.denoise_fn(x_noisy, t, cond) 242 | 243 | if self.loss_type == 'l1': 244 | if nonpadding is not None: 245 | loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean() 246 | else: 247 | # print('are you sure w/o nonpadding?') 248 | loss = (noise - x_recon).abs().mean() 249 | 250 | elif self.loss_type == 'l2': 251 | loss = F.mse_loss(noise, x_recon) 252 | else: 253 | raise NotImplementedError() 254 | 255 | return loss 256 | 257 | def forward(self, *inputs, **kwinputs): 258 | txt_tokens = inputs[1] 259 | b, *_, device = *txt_tokens.shape, txt_tokens.device 260 | 261 | fs2_ret = self.fs2(*inputs, **kwinputs, skip_decoder = True) 262 | cond = fs2_ret[0].transpose(1, 2) 263 | if self.training: 264 | t = torch.randint(0, self.K_step, (b,), device=device).long() 265 | x = inputs[4] 266 | x = self.norm_spec(x) 267 | x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] 268 | diff_loss = self.p_losses(x, t, cond) 269 | # nonpadding = (mel2ph != 0).float() 270 | # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding) 271 | return [diff_loss] 272 | else: 273 | t = self.K_step 274 | fs2_mels, postnet_output, mel_masks = self.fs2.decode(fs2_ret[0], fs2_ret[-3]) 275 | _cond, _post, *fs2_ret, _mel_masks, src_lens, mel_lens = fs2_ret 276 | fs2_ret = (fs2_mels, postnet_output, *fs2_ret, mel_masks, src_lens, mel_lens) 277 | 278 | fs2_mels = self.norm_spec(fs2_mels) 279 | fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :] 280 | 281 | x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long()) 282 | if self.hparams.get('gaussian_start') is not None and self.hparams['gaussian_start']: 283 | print('===> gaussian start.') 284 | shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2]) 285 | x = torch.randn(shape, device=device) 286 | # for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): 287 | for i in reversed(range(0, t)): 288 | x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) 289 | x = x[:, 0].transpose(1, 2) 290 | mel_out = self.denorm_spec(x) 291 | 292 | return mel_out, fs2_ret 293 | 294 | return ret 295 | 296 | def norm_spec(self, x): 297 | return x / 4. 298 | # return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 299 | 300 | def denorm_spec(self, x): 301 | return x * 4. 
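# The fixed /4 and *4 scaling is a global stand-in for the per-bin min/max
# normalisation kept in the commented-out lines; it presumes the mel values
# stay roughly within [-4, 4].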
302 | # return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min 303 | 304 | def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph): 305 | return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph) 306 | 307 | def out2mel(self, x): 308 | return x -------------------------------------------------------------------------------- /model/diffvar.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | from model.diffnet import DiffNet 9 | from model.diffspeech import default, extract, noise_like 10 | 11 | 12 | class DenoiseDiffusion(nn.Module): 13 | def __init__(self, timesteps: int, loss_type: str, schedule_type:str, 14 | x_max: List[float], x_min: List[float], clip_denoised: bool, 15 | beta_min: float=1e-4, beta_max: float=0.01, s: float=8e-3): 16 | if schedule_type == 'linear': 17 | betas = np.linspace(beta_min, beta_max, timesteps) 18 | elif schedule_type == 'cosine': 19 | steps = timesteps + 1 20 | x = np.linspace(0, steps, steps) 21 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 22 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 23 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 24 | else: 25 | raise NotImplementedError(f"Unknow noise schedule type {schedule_type}") 26 | 27 | self.clip_denoised = clip_denoised 28 | 29 | alphas = 1. - betas 30 | alphas_cumprod = np.cumprod(alphas, axis=0) 31 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 32 | 33 | timesteps, = betas.shape 34 | self.num_timesteps = int(timesteps) 35 | # self.K_step = K_step 36 | self.loss_type = loss_type 37 | 38 | to_torch = partial(torch.tensor, dtype=torch.float32) 39 | 40 | self.register_buffer('betas', to_torch(betas)) 41 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 42 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 43 | 44 | # calculations for diffusion q(x_t | x_{t-1}) and others 45 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 46 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 47 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 48 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 49 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 50 | 51 | # calculations for posterior q(x_{t-1} | x_t, x_0) 52 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 53 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 54 | self.register_buffer('posterior_variance', to_torch(posterior_variance)) 55 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 56 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 57 | self.register_buffer('posterior_mean_coef1', to_torch( 58 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 59 | self.register_buffer('posterior_mean_coef2', to_torch( 60 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) 61 | 62 | self.register_buffer('x_min', torch.FloatTensor(x_min)[None, None, :]) 63 | self.register_buffer('x_max', torch.FloatTensor(x_max)[None, None, :]) 64 | 65 | def q_mean_variance(self, x_start, t): 66 | mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 67 | variance = extract(1. - self.alphas_cumprod, t, x_start.shape) 68 | log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) 69 | return mean, variance, log_variance 70 | 71 | def predict_start_from_noise(self, x_t, t, noise): 72 | return ( 73 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 74 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 75 | ) 76 | 77 | def q_posterior(self, x_start, x_t, t): 78 | posterior_mean = ( 79 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 80 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 81 | ) 82 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 83 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 84 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 85 | 86 | def p_mean_variance(self, x, t, cond, clip_denoised: bool, denoise_fn=None): 87 | if denoise_fn is None: 88 | denoise_fn = self.denoise_fn 89 | noise_pred = denoise_fn(x, t, cond=cond) 90 | x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) 91 | 92 | if clip_denoised: 93 | x_recon.clamp_(-1., 1.) 94 | 95 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 96 | return model_mean, posterior_variance, posterior_log_variance 97 | 98 | @torch.no_grad() 99 | def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False, denoise_fn=None): 100 | b, *_, device = *x.shape, x.device 101 | model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised, denoise_fn=denoise_fn) 102 | noise = noise_like(x.shape, device, repeat_noise) 103 | # no noise when t == 0 104 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 105 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 106 | 107 | def q_sample(self, x_start, t, noise=None): 108 | noise = default(noise, lambda: torch.randn_like(x_start)) 109 | return ( 110 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 111 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 112 | ) 113 | 114 | def p_losses(self, x_start, t, cond, noise=None, nonpadding=None, denoise_fn=None): 115 | if denoise_fn is None: 116 | denoise_fn = self.denoise_fn 117 | 118 | noise = default(noise, lambda: torch.randn_like(x_start)) 119 | 120 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 121 | x_recon = denoise_fn(x_noisy, t, cond) 122 | 123 | if self.loss_type == 'l1': 124 | if nonpadding is not None: 125 | loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean() 126 | else: 127 | # print('are you sure w/o nonpadding?') 128 | loss = (noise - x_recon).abs().mean() 129 | 130 | elif self.loss_type == 'l2': 131 | # loss = F.mse_loss(noise, x_recon) 132 | 133 | if nonpadding is not None: 134 | loss = (((noise - x_recon)**2) * nonpadding.unsqueeze(1)).mean() 135 | else: 136 | # print('are you sure w/o nonpadding?') 137 | loss = ((noise - x_recon)**2).mean() 138 | else: 139 | raise NotImplementedError() 140 | 141 | return loss 142 | 143 | def norm(self, x): 144 | return (x - self.x_min) / (self.x_max - self.x_min) * 2 
- 1 145 | 146 | def denorm(self, x): 147 | return (x + 1) / 2 * (self.x_max - self.x_min) + self.x_min 148 | 149 | class DiffVariancePredictor(DenoiseDiffusion): 150 | def __init__(self, model_config, in_dim=None): 151 | hparams = model_config["diffusion"] 152 | self.timestep = hparams["timesteps"] 153 | self.in_dim = hparams["in_dim"] if in_dim is None else in_dim 154 | 155 | nn.Module.__init__(self) 156 | DenoiseDiffusion.__init__( 157 | self, 158 | timesteps = hparams["timesteps"], 159 | loss_type=hparams["diff_loss_type"], 160 | schedule_type=hparams["schedule_type"], 161 | beta_max=hparams["max_beta"], 162 | x_max = hparams["x_max"], # [channels] 163 | x_min = hparams["x_min"], # [channels] 164 | clip_denoised = hparams["clip_denoised"], 165 | ) 166 | 167 | if isinstance(self.in_dim, int): 168 | self.total_dim = self.in_dim 169 | self.in_dim = [list(range(self.in_dim))] 170 | else: 171 | self.total_dim = sum([len(chs) for chs in self.in_dim]) 172 | 173 | self.denoise_fn = nn.ModuleList([DiffNet( 174 | in_dims = len(chs), 175 | hidden_size = model_config["transformer"]["decoder_hidden"], 176 | residual_layers = hparams["residual_layers"], 177 | residual_channels = hparams["residual_channels"], 178 | dilation_cycle_length = hparams["dilation_cycle_length"] 179 | ) for chs in self.in_dim]) 180 | 181 | def split_channel(self, x): # ipnut x: [B, 1, C, T] 182 | return [x[:, :, chs] for chs in self.in_dim] 183 | 184 | def gather_channel(self, x): 185 | # TODO: implement proper gather method for non-continuous channel index 186 | return torch.cat(x, dim=2) 187 | 188 | def get_cond(self, fs2_model, batch): 189 | speakers, texts, src_lens, max_src_len = batch[:4] 190 | quasi_symbols = batch[-1] 191 | 192 | output, src_lens, max_src_len, src_masks = fs2_model.encode(speakers, texts, src_lens, max_src_len, quasi_symbols) 193 | 194 | return output.transpose(1,2), src_masks 195 | 196 | def training_step(self, batch, fs2_model): 197 | with torch.no_grad(): 198 | cond, cond_masks = self.get_cond(fs2_model, batch) 199 | 200 | t = torch.randint(0, self.timestep, (cond.size(0),), device=cond.device) 201 | 202 | p_targets = batch[-4] 203 | e_targets = batch[-3] 204 | batch[-2][batch[-2]==0] = 1 205 | log_d_targets = torch.log(batch[-2]) 206 | # d = batch[-2] 207 | # print(d.size()) 208 | # print(d.max(dim=-1)[0].max(dim=0)[0], d.min(dim=-1)[0].min(dim=0)[0]) 209 | # print(torch.log(d).max(), torch.log(d).min()) 210 | # exit(0) 211 | var_tgt = self.norm(torch.stack([p_targets, e_targets, log_d_targets], dim=-1)) 212 | var_tgt = var_tgt.transpose(1,2)[:, None, :, :] # [B, 1, M, T] 213 | 214 | diff_loss = [ 215 | self.p_losses(v_tgt, t, cond, nonpadding=~cond_masks, denoise_fn=d_fn) 216 | for d_fn, v_tgt in zip(self.denoise_fn, self.split_channel(var_tgt)) 217 | ] 218 | 219 | return diff_loss 220 | 221 | def validation_step(self, batch, fs2_model): 222 | with torch.no_grad(): 223 | cond, cond_masks = self.get_cond(fs2_model, batch) 224 | 225 | x = torch.randn(( 226 | cond.shape[0], 1, self.total_dim, cond.shape[2] 227 | ), device=cond.device) 228 | 229 | res = [] 230 | for d_fn, x_d in zip(self.denoise_fn, self.split_channel(x)): 231 | for i in reversed(range(0, self.timestep)): 232 | x_d = self.p_sample( 233 | x_d, 234 | torch.full((cond.shape[0],), i, device=cond.device, dtype=torch.long), 235 | cond, 236 | clip_denoised=self.clip_denoised, 237 | denoise_fn=d_fn, 238 | ) 239 | res.append(x_d) 240 | x = self.gather_channel(res) 241 | 242 | x = x[:, 0].transpose(1, 2) 243 | x = self.denorm(x) 244 | return 
x, cond, cond_masks 245 | -------------------------------------------------------------------------------- /model/fastspeech2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from transformer import Encoder, Decoder, PostNet 9 | from .modules import VarianceAdaptor 10 | from utils.tools import get_mask_from_lengths, drop_idxes 11 | 12 | 13 | class FastSpeech2(nn.Module): 14 | """ FastSpeech2 """ 15 | 16 | def __init__(self, preprocess_config, model_config): 17 | super(FastSpeech2, self).__init__() 18 | self.model_config = model_config 19 | 20 | self.encoder = Encoder(model_config) 21 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config) 22 | self.decoder = Decoder(model_config) 23 | self.mel_linear = nn.Linear( 24 | model_config["transformer"]["decoder_hidden"], 25 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"], 26 | ) 27 | self.postnet = PostNet() 28 | 29 | self.speaker_emb = None 30 | if model_config["multi_speaker"]: 31 | with open( 32 | os.path.join( 33 | preprocess_config["path"]["preprocessed_path"], "speakers.json" 34 | ), 35 | "r", 36 | ) as f: 37 | n_speaker = len(json.load(f)) 38 | self.speaker_emb = nn.Embedding( 39 | n_speaker, 40 | model_config["transformer"]["encoder_hidden"], 41 | ) 42 | 43 | def forward( 44 | self, 45 | speakers, 46 | texts, 47 | src_lens, 48 | max_src_len, 49 | mels=None, 50 | mel_lens=None, 51 | max_mel_len=None, 52 | p_targets=None, 53 | e_targets=None, 54 | d_targets=None, 55 | quasi_symbols=None, 56 | p_control=1.0, 57 | e_control=1.0, 58 | d_control=1.0, 59 | skip_decoder=False, 60 | ): 61 | 62 | output, src_lens, max_src_len, src_masks = self.encode(speakers, texts, src_lens, max_src_len, quasi_symbols) 63 | 64 | mel_masks = ( 65 | get_mask_from_lengths(mel_lens, max_mel_len) 66 | if mel_lens is not None 67 | else None 68 | ) 69 | 70 | ( 71 | output, 72 | p_predictions, 73 | e_predictions, 74 | log_d_predictions, 75 | d_rounded, 76 | mel_lens, 77 | mel_masks, 78 | ) = self.variance_adaptor( 79 | output, 80 | src_masks, 81 | mel_masks, 82 | max_mel_len, 83 | p_targets, 84 | e_targets, 85 | d_targets, 86 | p_control, 87 | e_control, 88 | d_control, 89 | ) 90 | 91 | if not skip_decoder: 92 | output, postnet_output, mel_masks = self.decode(output, mel_masks) 93 | else: 94 | postnet_output = None 95 | 96 | return ( 97 | output, 98 | postnet_output, 99 | p_predictions, 100 | e_predictions, 101 | log_d_predictions, 102 | d_rounded, 103 | src_masks, 104 | mel_masks, 105 | src_lens, 106 | mel_lens, 107 | ) 108 | 109 | def encode( 110 | self, 111 | speakers, 112 | texts, 113 | src_lens, 114 | max_src_len, 115 | quasi_symbols=None, 116 | skip_speaker=False, 117 | ): 118 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 119 | 120 | output = self.encoder(texts, src_masks) 121 | 122 | 123 | if quasi_symbols is not None: 124 | assert quasi_symbols.size(1) == output.size(1), \ 125 | f"{quasi_symbols.size()} {output.size()}" 126 | 127 | quasi_symbol_cnt = torch.sum(quasi_symbols, dim=1) 128 | src_lens = src_lens - quasi_symbol_cnt 129 | 130 | max_src_len, shrinkage = torch.max(src_lens), max_src_len - torch.max(src_lens) 131 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 132 | 133 | output = drop_idxes(output, quasi_symbols, shrinkage) 134 | 135 | if not skip_speaker and self.speaker_emb is not None: 136 | output = output + 
self.speaker_emb(speakers).unsqueeze(1).expand( 137 | -1, max_src_len, -1 138 | ) 139 | 140 | return output, src_lens, max_src_len, src_masks 141 | 142 | def decode(self, dec_input, mel_masks): 143 | output, mel_masks = self.decoder(dec_input, mel_masks) 144 | output = self.mel_linear(output) 145 | postnet_output = self.postnet(output) + output 146 | 147 | return output, postnet_output, mel_masks -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FastSpeech2Loss(nn.Module): 6 | """ FastSpeech2 Loss """ 7 | 8 | def __init__(self, preprocess_config, model_config): 9 | super(FastSpeech2Loss, self).__init__() 10 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 11 | "feature" 12 | ] 13 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 14 | "feature" 15 | ] 16 | self.mse_loss = nn.MSELoss() 17 | self.mae_loss = nn.L1Loss() 18 | 19 | def forward(self, inputs, predictions): 20 | ( 21 | mel_targets, 22 | _, 23 | _, 24 | pitch_targets, 25 | energy_targets, 26 | duration_targets, 27 | *_ 28 | ) = inputs[6:] 29 | ( 30 | mel_predictions, 31 | postnet_mel_predictions, 32 | pitch_predictions, 33 | energy_predictions, 34 | log_duration_predictions, 35 | _, 36 | src_masks, 37 | mel_masks, 38 | _, 39 | _, 40 | ) = predictions 41 | src_masks = ~src_masks 42 | mel_masks = ~mel_masks 43 | log_duration_targets = torch.log(duration_targets.float() + 1) 44 | mel_targets = mel_targets[:, : mel_masks.shape[1], :] 45 | mel_masks = mel_masks[:, :mel_masks.shape[1]] 46 | 47 | log_duration_targets.requires_grad = False 48 | pitch_targets.requires_grad = False 49 | energy_targets.requires_grad = False 50 | mel_targets.requires_grad = False 51 | 52 | if self.pitch_feature_level == "phoneme_level": 53 | pitch_predictions = pitch_predictions.masked_select(src_masks) 54 | pitch_targets = pitch_targets.masked_select(src_masks) 55 | elif self.pitch_feature_level == "frame_level": 56 | pitch_predictions = pitch_predictions.masked_select(mel_masks) 57 | pitch_targets = pitch_targets.masked_select(mel_masks) 58 | 59 | if self.energy_feature_level == "phoneme_level": 60 | energy_predictions = energy_predictions.masked_select(src_masks) 61 | energy_targets = energy_targets.masked_select(src_masks) 62 | if self.energy_feature_level == "frame_level": 63 | energy_predictions = energy_predictions.masked_select(mel_masks) 64 | energy_targets = energy_targets.masked_select(mel_masks) 65 | 66 | log_duration_predictions = log_duration_predictions.masked_select(src_masks) 67 | log_duration_targets = log_duration_targets.masked_select(src_masks) 68 | 69 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1)) 70 | postnet_mel_predictions = postnet_mel_predictions.masked_select( 71 | mel_masks.unsqueeze(-1) 72 | ) 73 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1)) 74 | 75 | mel_loss = self.mae_loss(mel_predictions, mel_targets) 76 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets) 77 | 78 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets) 79 | energy_loss = self.mse_loss(energy_predictions, energy_targets) 80 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets) 81 | 82 | total_loss = ( 83 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss 84 | ) 85 | 86 | return ( 87 | total_loss, 88 | 
mel_loss, 89 | postnet_mel_loss, 90 | pitch_loss, 91 | energy_loss, 92 | duration_loss, 93 | ) 94 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import math 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | import torch.nn.functional as F 11 | 12 | from utils.tools import get_mask_from_lengths, pad 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | class VarianceAdaptor(nn.Module): 18 | """Variance Adaptor""" 19 | 20 | def __init__(self, preprocess_config, model_config): 21 | super(VarianceAdaptor, self).__init__() 22 | self.duration_predictor = VariancePredictor(model_config) 23 | self.length_regulator = LengthRegulator() 24 | self.pitch_predictor = VariancePredictor(model_config) 25 | self.energy_predictor = VariancePredictor(model_config) 26 | 27 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 28 | "feature" 29 | ] 30 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 31 | "feature" 32 | ] 33 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"] 34 | assert self.energy_feature_level in ["phoneme_level", "frame_level"] 35 | 36 | pitch_quantization = model_config["variance_embedding"]["pitch_quantization"] 37 | energy_quantization = model_config["variance_embedding"]["energy_quantization"] 38 | n_bins = model_config["variance_embedding"]["n_bins"] 39 | assert pitch_quantization in ["linear", "log"] 40 | assert energy_quantization in ["linear", "log"] 41 | with open( 42 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 43 | ) as f: 44 | stats = json.load(f) 45 | pitch_min, pitch_max = stats["pitch"][:2] 46 | energy_min, energy_max = stats["energy"][:2] 47 | 48 | if pitch_quantization == "log": 49 | self.pitch_bins = nn.Parameter( 50 | torch.exp( 51 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1) 52 | ), 53 | requires_grad=False, 54 | ) 55 | else: 56 | self.pitch_bins = nn.Parameter( 57 | torch.linspace(pitch_min, pitch_max, n_bins - 1), 58 | requires_grad=False, 59 | ) 60 | if energy_quantization == "log": 61 | self.energy_bins = nn.Parameter( 62 | torch.exp( 63 | torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1) 64 | ), 65 | requires_grad=False, 66 | ) 67 | else: 68 | self.energy_bins = nn.Parameter( 69 | torch.linspace(energy_min, energy_max, n_bins - 1), 70 | requires_grad=False, 71 | ) 72 | 73 | self.pitch_embedding = nn.Embedding( 74 | n_bins, model_config["transformer"]["encoder_hidden"] 75 | ) 76 | self.energy_embedding = nn.Embedding( 77 | n_bins, model_config["transformer"]["encoder_hidden"] 78 | ) 79 | 80 | def get_pitch_embedding(self, x, target, mask, control): 81 | prediction = self.pitch_predictor(x, mask) 82 | if target is not None: 83 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins)) 84 | else: 85 | prediction = prediction * control 86 | embedding = self.pitch_embedding( 87 | torch.bucketize(prediction, self.pitch_bins) 88 | ) 89 | return prediction, embedding 90 | 91 | def get_energy_embedding(self, x, target, mask, control): 92 | prediction = self.energy_predictor(x, mask) 93 | if target is not None: 94 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins)) 95 | else: 96 | prediction = prediction * 
control 97 | embedding = self.energy_embedding( 98 | torch.bucketize(prediction, self.energy_bins) 99 | ) 100 | return prediction, embedding 101 | 102 | def forward( 103 | self, 104 | x, 105 | src_mask, 106 | mel_mask=None, 107 | max_len=None, 108 | pitch_target=None, 109 | energy_target=None, 110 | duration_target=None, 111 | p_control=1.0, 112 | e_control=1.0, 113 | d_control=1.0, 114 | ): 115 | pred_input_x = x.detach() 116 | 117 | log_duration_prediction = self.duration_predictor(pred_input_x, src_mask) 118 | if self.pitch_feature_level == "phoneme_level": 119 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 120 | pred_input_x, pitch_target, src_mask, p_control 121 | ) 122 | pred_input_x = pred_input_x + pitch_embedding 123 | if self.energy_feature_level == "phoneme_level": 124 | energy_prediction, energy_embedding = self.get_energy_embedding( 125 | x, energy_target, src_mask, e_control 126 | ) 127 | pred_input_x = pred_input_x + energy_embedding 128 | 129 | x = x + energy_embedding + pitch_embedding 130 | 131 | if duration_target is not None: 132 | x, mel_len = self.length_regulator(x, duration_target, max_len) 133 | duration_rounded = duration_target 134 | mel_mask = get_mask_from_lengths(mel_len) 135 | else: 136 | duration_rounded = torch.clamp( 137 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control), 138 | min=0, 139 | ) 140 | x, mel_len = self.length_regulator(x, duration_rounded, max_len) 141 | mel_mask = get_mask_from_lengths(mel_len) 142 | 143 | if self.pitch_feature_level == "frame_level": 144 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 145 | x, pitch_target, mel_mask, p_control 146 | ) 147 | x = x + pitch_embedding 148 | if self.energy_feature_level == "frame_level": 149 | energy_prediction, energy_embedding = self.get_energy_embedding( 150 | x, energy_target, mel_mask, e_control 151 | ) 152 | x = x + energy_embedding 153 | 154 | return ( 155 | x, 156 | pitch_prediction, 157 | energy_prediction, 158 | log_duration_prediction, 159 | duration_rounded, 160 | mel_len, 161 | mel_mask, 162 | ) 163 | 164 | 165 | class LengthRegulator(nn.Module): 166 | """Length Regulator""" 167 | 168 | def __init__(self): 169 | super(LengthRegulator, self).__init__() 170 | 171 | def LR(self, x, duration, max_len): 172 | output = list() 173 | mel_len = list() 174 | for batch, expand_target in zip(x, duration): 175 | expanded = self.expand(batch, expand_target) 176 | output.append(expanded) 177 | mel_len.append(expanded.shape[0]) 178 | 179 | if max_len is not None: 180 | output = pad(output, max_len) 181 | else: 182 | output = pad(output) 183 | 184 | return output, torch.LongTensor(mel_len).to(device) 185 | 186 | def expand(self, batch, predicted): 187 | out = list() 188 | 189 | for i, vec in enumerate(batch): 190 | expand_size = predicted[i].item() 191 | out.append(vec.expand(max(int(expand_size), 0), -1)) 192 | out = torch.cat(out, 0) 193 | 194 | return out 195 | 196 | def forward(self, x, duration, max_len): 197 | output, mel_len = self.LR(x, duration, max_len) 198 | return output, mel_len 199 | 200 | 201 | class VariancePredictor(nn.Module): 202 | """Duration, Pitch and Energy Predictor""" 203 | 204 | def __init__(self, model_config): 205 | super(VariancePredictor, self).__init__() 206 | 207 | self.input_size = model_config["transformer"]["encoder_hidden"] 208 | self.filter_size = model_config["variance_predictor"]["filter_size"] 209 | self.kernel = model_config["variance_predictor"]["kernel_size"] 210 | self.conv_output_size = 
model_config["variance_predictor"]["filter_size"] 211 | self.dropout = model_config["variance_predictor"]["dropout"] 212 | 213 | self.conv_layer = nn.Sequential( 214 | OrderedDict( 215 | [ 216 | ( 217 | "conv1d_1", 218 | Conv( 219 | self.input_size, 220 | self.filter_size, 221 | kernel_size=self.kernel, 222 | padding=(self.kernel - 1) // 2, 223 | ), 224 | ), 225 | ("relu_1", nn.ReLU()), 226 | ("layer_norm_1", nn.LayerNorm(self.filter_size)), 227 | ("dropout_1", nn.Dropout(self.dropout)), 228 | ( 229 | "conv1d_2", 230 | Conv( 231 | self.filter_size, 232 | self.filter_size, 233 | kernel_size=self.kernel, 234 | padding=1, 235 | ), 236 | ), 237 | ("relu_2", nn.ReLU()), 238 | ("layer_norm_2", nn.LayerNorm(self.filter_size)), 239 | ("dropout_2", nn.Dropout(self.dropout)), 240 | ] 241 | ) 242 | ) 243 | 244 | self.linear_layer = nn.Linear(self.conv_output_size, 1) 245 | 246 | def forward(self, encoder_output, mask): 247 | out = self.conv_layer(encoder_output) 248 | out = self.linear_layer(out) 249 | out = out.squeeze(-1) 250 | 251 | if mask is not None: 252 | out = out.masked_fill(mask, 0.0) 253 | 254 | return out 255 | 256 | 257 | class Conv(nn.Module): 258 | """ 259 | Convolution Module 260 | """ 261 | 262 | def __init__( 263 | self, 264 | in_channels, 265 | out_channels, 266 | kernel_size=1, 267 | stride=1, 268 | padding=0, 269 | dilation=1, 270 | bias=True, 271 | w_init="linear", 272 | ): 273 | """ 274 | :param in_channels: dimension of input 275 | :param out_channels: dimension of output 276 | :param kernel_size: size of kernel 277 | :param stride: size of stride 278 | :param padding: size of padding 279 | :param dilation: dilation rate 280 | :param bias: boolean. if True, bias is included. 281 | :param w_init: str. weight inits with xavier initialization. 
282 | """ 283 | super(Conv, self).__init__() 284 | 285 | self.conv = nn.Conv1d( 286 | in_channels, 287 | out_channels, 288 | kernel_size=kernel_size, 289 | stride=stride, 290 | padding=padding, 291 | dilation=dilation, 292 | bias=bias, 293 | ) 294 | 295 | def forward(self, x): 296 | x = x.contiguous().transpose(1, 2) 297 | x = self.conv(x) 298 | x = x.contiguous().transpose(1, 2) 299 | 300 | return x 301 | -------------------------------------------------------------------------------- /model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim: 6 | """ A simple wrapper class for learning rate scheduling """ 7 | 8 | def __init__(self, model, train_config, model_config, current_step): 9 | 10 | self._optimizer = torch.optim.Adam( 11 | model.parameters(), 12 | betas=train_config["optimizer"]["betas"], 13 | eps=train_config["optimizer"]["eps"], 14 | weight_decay=train_config["optimizer"]["weight_decay"], 15 | ) 16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"] 17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"] 18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"] 19 | self.current_step = current_step 20 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5) 21 | 22 | def step_and_update_lr(self): 23 | lr = self._update_learning_rate() 24 | self._optimizer.step() 25 | return lr 26 | 27 | def zero_grad(self): 28 | # print(self.init_lr) 29 | self._optimizer.zero_grad() 30 | 31 | def load_state_dict(self, path): 32 | self._optimizer.load_state_dict(path) 33 | 34 | def _get_lr_scale(self): 35 | lr = np.min( 36 | [ 37 | np.power(self.current_step, -0.5), 38 | np.power(self.n_warmup_steps, -1.5) * self.current_step, 39 | ] 40 | ) 41 | for s in self.anneal_steps: 42 | if self.current_step > s: 43 | lr = lr * self.anneal_rate 44 | return lr 45 | 46 | def _update_learning_rate(self): 47 | """ Learning rate scheduling per step """ 48 | self.current_step += 1 49 | lr = self.init_lr * self._get_lr_scale() 50 | 51 | for param_group in self._optimizer.param_groups: 52 | param_group["lr"] = lr 53 | 54 | return lr -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | aiohttp==3.7.4.post0 3 | appdirs==1.4.4 4 | argon2-cffi==20.1.0 5 | async-generator==1.10 6 | async-timeout==3.0.1 7 | attrs==21.2.0 8 | audioread==2.1.9 9 | backcall==0.2.0 10 | batchrenorm==0.1.0 11 | bcrypt==3.2.0 12 | bleach==4.0.0 13 | blinker==1.4 14 | Bottleneck==1.3.2 15 | brotlipy==0.7.0 16 | cachetools==4.2.2 17 | certifi==2021.10.8 18 | cffi==1.14.6 19 | chardet==4.0.0 20 | charset-normalizer==2.0.0 21 | click==8.0.1 22 | colorama==0.4.4 23 | coverage==5.5 24 | cryptography==3.4.7 25 | cycler==0.10.0 26 | Cython==0.29.24 27 | dataclasses==0.8 28 | decorator==5.0.9 29 | defusedxml==0.7.1 30 | Distance==0.1.3 31 | einops==0.4.1 32 | entrypoints==0.3 33 | fabric==2.6.0 34 | filelock==3.4.0 35 | fsspec==2021.7.0 36 | future==0.18.2 37 | g2p-en==2.1.0 38 | google-auth==1.35.0 39 | google-auth-oauthlib==0.4.5 40 | grpcio==1.36.1 41 | huggingface-hub==0.2.1 42 | idna==3.2 43 | imageio==2.9.0 44 | importlib-metadata==4.6.4 45 | inflect==5.3.0 46 | invoke==1.6.0 47 | ipykernel==5.3.4 48 | ipython==7.26.0 49 | ipython-genutils==0.2.0 50 | ipywidgets==7.6.3 51 | jedi==0.18.0 52 | Jinja2==3.0.1 53 | joblib==1.0.1 54 | 
jsonargparse==3.18.0 55 | jsonschema==3.0.2 56 | jupyter==1.0.0 57 | jupyter-client==6.1.12 58 | jupyter-console==6.4.0 59 | jupyter-core==4.7.1 60 | jupyterlab-pygments==0.1.2 61 | jupyterlab-widgets==1.0.0 62 | kiwisolver==1.3.1 63 | librosa==0.8.1 64 | llvmlite==0.36.0 65 | Markdown==3.3.4 66 | MarkupSafe==2.0.1 67 | matplotlib==3.3.4 68 | matplotlib-inline==0.1.2 69 | mistune==0.8.4 70 | mkl-fft==1.3.0 71 | mkl-random==1.2.2 72 | mkl-service==2.4.0 73 | Montreal-Forced-Aligner==2.0.0b3 74 | multidict==5.1.0 75 | nbclient==0.5.3 76 | nbconvert==6.1.0 77 | nbformat==5.1.3 78 | nest-asyncio==1.5.1 79 | nltk==3.7 80 | notebook==6.4.3 81 | numba==0.53.1 82 | numexpr==2.7.3 83 | numpy==1.20.3 84 | oauthlib==3.1.1 85 | olefile==0.46 86 | packaging==21.0 87 | pandas==1.3.5 88 | pandocfilters==1.4.3 89 | paramiko==2.8.0 90 | parso==0.8.2 91 | pathlib2==2.3.6 92 | pexpect==4.8.0 93 | pickleshare==0.7.5 94 | Pillow==8.3.1 95 | pip==21.2.4 96 | plotly==5.14.1 97 | pooch==1.4.0 98 | praatio==5.0.0 99 | prometheus-client==0.11.0 100 | prompt-toolkit==3.0.17 101 | protobuf==3.17.2 102 | ptyprocess==0.7.0 103 | pyasn1==0.4.8 104 | pyasn1-modules==0.2.8 105 | pycparser==2.20 106 | pyDeprecate==0.3.1 107 | Pygments==2.10.0 108 | PyJWT==2.1.0 109 | PyNaCl==1.4.0 110 | pynini==2.1.4 111 | pyOpenSSL==20.0.1 112 | pyparsing==2.4.7 113 | pypinyin==0.44.0 114 | pyrsistent==0.18.0 115 | PySocks==1.7.1 116 | python-dateutil==2.8.2 117 | pytorch-lightning==1.4.2 118 | pytz==2021.3 119 | pyu2f==0.1.5 120 | pyworld==0.3.0 121 | PyYAML==5.4.1 122 | pyzmq==22.2.1 123 | qtconsole==5.1.0 124 | QtPy==1.10.0 125 | regex==2021.11.10 126 | requests==2.26.0 127 | requests-oauthlib==1.3.0 128 | resampy==0.2.2 129 | rsa==4.7.2 130 | sacremoses==0.0.46 131 | scikit-learn==0.24.1 132 | scipy==1.7.1 133 | seaborn==0.11.2 134 | Send2Trash==1.5.0 135 | setuptools==52.0.0.post20210125 136 | sip==4.19.13 137 | six==1.16.0 138 | SoundFile==0.10.3.post1 139 | tenacity==8.2.2 140 | tensorboard==2.4.0 141 | tensorboard-plugin-wit==1.8.0 142 | terminado==0.9.4 143 | test-tube==0.7.5 144 | testpath==0.5.0 145 | thop==0.1.1.post2209072238 146 | threadpoolctl==2.2.0 147 | tokenizers==0.10.3 148 | torch==1.9.0 149 | torchmetrics==0.5.0 150 | tornado==6.1 151 | tqdm==4.62.1 152 | traitlets==5.0.5 153 | transformers==4.14.1 154 | typing-extensions==3.10.0.0 155 | Unidecode==1.2.0 156 | urllib3==1.26.6 157 | wcwidth==0.2.5 158 | webencodings==0.5.1 159 | Werkzeug==2.0.1 160 | wheel==0.37.0 161 | widgetsnbextension==3.5.1 162 | yarl==1.6.3 163 | zipp==3.5.0 164 | -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import torch 4 | import yaml 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | 8 | from model import FastSpeech2, DiffVariancePredictor 9 | from utils.model import get_model, get_vocoder, get_param_num 10 | from utils.tools import to_device, synth_samples 11 | from dataset import Dataset 12 | 13 | from pathlib import Path 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | def synthesize_batch(batch, model, configs, vocoder, outdir, control_values, 18 | use_gt_var:bool=False, var_diff_pred=None): 19 | 20 | outdir = Path(outdir) 21 | outdir.mkdir(parents=True, exist_ok=True) 22 | 23 | ( 24 | ids, 25 | raw_texts, 26 | speakers, 27 | texts, 28 | text_lens, 29 | max_text_lens, 30 | mels, 31 | mel_lens, 32 | 
max_mel_lens, 33 | pitches, 34 | energies, 35 | durations, 36 | quasi_flags, 37 | ) = batch 38 | 39 | if var_diff_pred is not None: 40 | durations = torch.clamp( 41 | (torch.round(torch.exp(var_diff_pred[:, :, 2]))), 42 | min=0, 43 | ) 44 | for i, (l, f) in enumerate(zip(text_lens, quasi_flags)): 45 | durations[i, l-sum(f):] = 0. 46 | elif not use_gt_var: 47 | durations = None 48 | 49 | batch = ( 50 | ids, 51 | raw_texts, 52 | speakers, 53 | texts, 54 | text_lens, 55 | max(text_lens), 56 | None, 57 | mel_lens if use_gt_var else None, 58 | max_mel_lens if use_gt_var else None, 59 | var_diff_pred[:, :, 0] if var_diff_pred is not None else \ 60 | (pitches if use_gt_var else None), 61 | var_diff_pred[:, :, 1] if var_diff_pred is not None else \ 62 | (energies if use_gt_var else None), 63 | durations, 64 | quasi_flags, 65 | ) 66 | 67 | # Forward 68 | pitch_control, energy_control, duration_control = control_values 69 | output = model( 70 | *(batch[2:-1]), 71 | quasi_symbols = batch[-1], 72 | p_control=pitch_control, 73 | e_control=energy_control, 74 | d_control=duration_control 75 | ) 76 | 77 | preprocess_config, model_config, train_config = configs 78 | synth_samples( 79 | batch, 80 | output, 81 | vocoder, 82 | model_config, 83 | preprocess_config, 84 | outdir, 85 | ) 86 | 87 | def synthesize(fs2_model, configs, vocoder, dataloader, outdir, diffvar_model=None, 88 | use_gt_var=False, control_values=[1.,1.,1.], control_dv_spker=None,control_dec_spkers=None): 89 | for batchs in tqdm(dataloader): 90 | for batch in batchs: 91 | batch = to_device(batch, device) 92 | with torch.no_grad(): 93 | if diffvar_model is not None: 94 | if control_dv_spker is not None: 95 | batch[2].fill_(control_dv_spker) 96 | var, cond, var_mask = diffvar_model.validation_step(batch[2:], fs2_model) 97 | # pitch_diff_pred = var[:, :, 0] 98 | else: 99 | # pitch_diff_pred = None 100 | # var = None 101 | if control_dv_spker is not None: 102 | batch[2].fill_(control_dv_spker) 103 | outputs = fs2_model(*(list(batch[2:-4])+[None]*3+[batch[-1]]), skip_decoder=True) 104 | var = torch.stack(outputs[2:5], dim=-1) 105 | 106 | 107 | if control_dec_spkers is not None: 108 | for control_dec_spker in control_dec_spkers: 109 | batch[2].fill_(control_dec_spker) 110 | synthesize_batch( 111 | batch, fs2_model, configs, vocoder, 112 | os.path.join(outdir, str(control_dec_spker)), control_values, 113 | use_gt_var=use_gt_var, var_diff_pred=var) 114 | 115 | def getConfig(config_name): 116 | preprocess_config = yaml.load( 117 | open(f"config/{config_name}/preprocess.yaml", "r"), 118 | Loader=yaml.FullLoader 119 | ) 120 | model_config = yaml.load(open(f"config/{config_name}/model.yaml", "r"), Loader=yaml.FullLoader) 121 | train_config = yaml.load(open(f"config/{config_name}/train.yaml", "r"), Loader=yaml.FullLoader) 122 | return preprocess_config, model_config, train_config 123 | 124 | if __name__ == "__main__": 125 | # Read Config 126 | preprocess_config, model_config, train_config = configs = getConfig("zjl_enc_detach") 127 | 128 | # Get model 129 | fs2_model = FastSpeech2(preprocess_config, model_config).to(device) 130 | ckpt_path = os.path.join( 131 | train_config["path"]["ckpt_path"], 132 | "900000.pth.tar", 133 | ) 134 | ckpt = torch.load(ckpt_path) 135 | fs2_model.load_state_dict(ckpt["model"]) 136 | fs2_model.eval() 137 | fs2_model.requires_grad_ = False 138 | 139 | preprocess_config, model_config, train_config = configs = getConfig("zdl2_split") 140 | diffvar_model = DiffVariancePredictor(model_config).to(device) 141 | ckpt_path = 
os.path.join( 142 | train_config["path"]["ckpt_path"], 143 | "900000.pth.tar", 144 | ) 145 | ckpt = torch.load(ckpt_path) 146 | diffvar_model.load_state_dict(ckpt["model"]) 147 | diffvar_model.eval() 148 | diffvar_model.requires_grad_ = False 149 | print(get_param_num(diffvar_model)) 150 | 151 | # Load vocoder 152 | vocoder = get_vocoder(model_config, device) 153 | 154 | # Get dataset 155 | dataset = Dataset( 156 | "val.txt", 157 | preprocess_config, 158 | train_config, 159 | sort=False, drop_last=False 160 | ) 161 | dataloader = DataLoader( 162 | dataset, 163 | batch_size=16, 164 | collate_fn=dataset.collate_fn, 165 | ) 166 | 167 | synthesize( 168 | fs2_model, configs, vocoder, dataloader, 169 | f"output/result", 170 | diffvar_model=diffvar_model, 171 | use_gt_var=False, 172 | ) -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.zjl_symbols import symbols, quasi_symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | """ 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | 34 | if not m: 35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names).split()) 36 | break 37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 38 | sequence += _arpabet_to_sequence(m.group(2)) 39 | text = m.group(3) 40 | 41 | return sequence 42 | 43 | 44 | def sequence_to_text(sequence): 45 | """Converts a sequence of IDs back to a string""" 46 | result = "" 47 | for symbol_id in sequence: 48 | if symbol_id in _id_to_symbol: 49 | s = _id_to_symbol[symbol_id] 50 | # Enclose ARPAbet back in curly braces: 51 | if len(s) > 1 and s[0] == "@": 52 | s = "{%s}" % s[1:] 53 | result += s 54 | return result.replace("}{", " ") 55 | 56 | 57 | def _clean_text(text, cleaner_names): 58 | for name in cleaner_names: 59 | cleaner = getattr(cleaners, name) 60 | if not cleaner: 61 | raise Exception("Unknown cleaner: %s" % name) 62 | text = cleaner(text) 63 | return text 64 | 65 | 66 | def _symbols_to_sequence(symbols): 67 | return [[_symbol_to_id[s], int(s in quasi_symbols)] for s in symbols if _should_keep_symbol(s)] 68 | 69 | 70 | def _arpabet_to_sequence(text): 71 | return _symbols_to_sequence(["@" + s for s in text.split()]) 72 | 73 | 74 | def _should_keep_symbol(s): 75 | return s in _symbol_to_id and s != "_" and s != "~" 76 | -------------------------------------------------------------------------------- /text/cleaners.py: 
-------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | _whitespace_re = re.compile(r'\s+') 21 | 22 | # List of (regular expression, replacement) pairs for abbreviations: 23 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 69 | text = lowercase(text) 70 | text = collapse_whitespace(text) 71 | return text 72 | 73 | 74 | def transliteration_cleaners(text): 75 | '''Pipeline for non-English text that transliterates to ASCII.''' 76 | text = convert_to_ascii(text) 77 | text = lowercase(text) 78 | text = collapse_whitespace(text) 79 | return text 80 | 81 | 82 | def english_cleaners(text): 83 | '''Pipeline for English text, including number and abbreviation expansion.''' 84 | text = convert_to_ascii(text) 85 | text = lowercase(text) 86 | text = expand_numbers(text) 87 | text = expand_abbreviations(text) 88 | text = collapse_whitespace(text) 89 | return text 90 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | 
"EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero 
dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /text/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 
188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | "v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from text import cmudict, pinyin 9 | 10 | _pad = "_" 11 | _punctuation = "!'(),.:;? " 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 18 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = ( 22 | [_pad] 23 | + list(_special) 24 | + list(_punctuation) 25 | + list(_letters) 26 | + _arpabet 27 | + _pinyin 28 | + _silences 29 | ) 30 | -------------------------------------------------------------------------------- /text/zjl_symbols.py: -------------------------------------------------------------------------------- 1 | _pad = "_" 2 | 3 | _final = [f"{s}{i}" for s in ['u', 'o', 'iy', 'er', 'i', 'ng', 'v', 'e', 'a', 'n', 'ix'] for i in range(1,7)] 4 | _init_mid = ['a', 'b', 'c', 'ch', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'ng', 'o', 'p', 'q', 'r', 's', 'sh', 't', 'u', 'v', 'x', 'z', 'zh'] 5 | 6 | _prosodic = [f'#{i}' for i in range(1, 5)] 7 | _sil = [f'{sym}@{lv}' for sym in ['sil', 'pau'] for lv in list(range(5))+['S']] 8 | 9 | _punc = ['!', '"', "'", '(', ')', ',', '.', ':', ';', '?', '—', '…', '、', '《', '》'] 10 | 11 | quasi_symbols = _punc + _prosodic 12 | 13 | symbols = ( 14 | [_pad] 15 | + _init_mid 16 | + _final 17 | + _prosodic 18 | + _sil 19 | + _punc 20 | ) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from turtle import st 4 | 5 | import torch 6 | import yaml 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | from tqdm import tqdm 11 | 12 | from utils.model import get_vocoder, get_param_num, ScheduledOptim 13 | from utils.tools import to_device 14 | from model import FastSpeech2, DiffVariancePredictor 15 | from dataset import Dataset 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | print(device) 19 | 20 | def main(args, configs): 21 | print("Prepare training ...") 22 | 23 | preprocess_config, model_config, train_config = configs 24 | 25 | # Get dataset 26 | dataset = Dataset( 27 | "train.txt", preprocess_config, train_config, sort=True, drop_last=True 28 | ) 29 | val_dataset = Dataset( 30 | "val.txt", preprocess_config, train_config, sort=False, drop_last=False 31 | ) 32 | batch_size = train_config["optimizer"]["batch_size"] 33 | 
group_size = 4 # Set this larger than 1 to enable sorting in Dataset 34 | assert batch_size * group_size < len(dataset) 35 | loader = DataLoader( 36 | dataset, 37 | batch_size=batch_size * group_size, 38 | shuffle=True, 39 | collate_fn=dataset.collate_fn, 40 | ) 41 | val_loader = DataLoader( 42 | val_dataset, 43 | batch_size=batch_size, 44 | shuffle=False, 45 | collate_fn=val_dataset.collate_fn, 46 | ) 47 | 48 | # Prepare model 49 | fs2_model = FastSpeech2(preprocess_config, model_config) 50 | fs2_model.to(device) 51 | pre_trained_fs2_model = torch.load(args.fs_path) 52 | fs2_model.load_state_dict(pre_trained_fs2_model["model"], strict=True) 53 | 54 | model = DiffVariancePredictor(model_config) 55 | model.to(device) 56 | 57 | optimizers = [ScheduledOptim( 58 | d_fn, train_config, model_config, 0 59 | ) for d_fn in model.denoise_fn] 60 | 61 | print("Number of FS2 Total Parameters:", get_param_num(fs2_model)) 62 | print("Number of FS2 pitch predictor:", get_param_num(fs2_model.variance_adaptor.pitch_predictor)) 63 | print("Number of FS2 energy predictor:", get_param_num(fs2_model.variance_adaptor.energy_predictor)) 64 | print("Number of FS2 duration predictor:", get_param_num(fs2_model.variance_adaptor.duration_predictor)) 65 | print("Number of DiffVar predictor:", get_param_num(model)) 66 | 67 | # Load vocoder 68 | vocoder = get_vocoder(model_config, device) 69 | 70 | # Init logger 71 | for p in train_config["path"].values(): 72 | os.makedirs(p, exist_ok=True) 73 | train_log_path = os.path.join(train_config["path"]["log_path"], "train") 74 | val_log_path = os.path.join(train_config["path"]["log_path"], "val") 75 | os.makedirs(train_log_path, exist_ok=True) 76 | os.makedirs(val_log_path, exist_ok=True) 77 | train_logger = SummaryWriter(train_log_path) 78 | val_logger = SummaryWriter(val_log_path) 79 | 80 | # Training 81 | step = args.restore_step + 1 82 | epoch = 1 83 | grad_acc_step = train_config["optimizer"]["grad_acc_step"] 84 | grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"] 85 | total_step = train_config["step"]["total_step"] 86 | log_step = train_config["step"]["log_step"] 87 | save_step = train_config["step"]["save_step"] 88 | synth_step = train_config["step"]["synth_step"] 89 | val_step = train_config["step"]["val_step"] 90 | 91 | outer_bar = tqdm(total=total_step, desc="Training", position=0) 92 | outer_bar.n = args.restore_step 93 | outer_bar.update() 94 | 95 | while True: 96 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1) 97 | for batchs in loader: 98 | for batch in batchs: 99 | batch = to_device(batch, device) 100 | 101 | # Forward 102 | losses = model.training_step(batch[2:], fs2_model) 103 | 104 | # Cal Loss & Backward 105 | for ch_loss in losses: 106 | ch_loss = ch_loss / grad_acc_step 107 | ch_loss.backward() 108 | 109 | if step % grad_acc_step == 0: 110 | 111 | # Clipping gradients to avoid gradient explosion 112 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh) 113 | 114 | # Update weights 115 | for optimizer in optimizers: 116 | new_lr = optimizer.step_and_update_lr() 117 | optimizer.zero_grad() 118 | 119 | 120 | 121 | if step % log_step == 0: 122 | message1 = "Step {}/{}, ".format(step, total_step) 123 | 124 | loss_msgs, losses = [], [l.item() for l in losses] 125 | for l, idx in zip(losses, model.in_dims): 126 | sidx = f"Total Loss{str(idx)}" 127 | loss_msgs.append(f"{sidx}: {l:.4f}") 128 | train_logger.add_scalar(f"Loss/{sidx}", l, step) 129 | 130 | total_msg = message1 + ' '.join(loss_msgs) 131 | with 
open(os.path.join(train_log_path, "log.txt"), "a") as f: 132 | f.write(total_msg + "\n") 133 | outer_bar.write(total_msg) 134 | 135 | train_logger.add_scalar("Learning_rate", new_lr, step) 136 | 137 | if step % val_step == 0: 138 | model.eval() 139 | 140 | mse_loss = nn.MSELoss() 141 | p_loss, e_loss, log_d_loss = 0., 0., 0. 142 | p_preds, e_preds, log_d_preds = [], [], [] 143 | p_tgts, e_tgts, log_d_tgts = [], [], [] 144 | for v_batchs in val_loader: 145 | for v_batch in v_batchs: 146 | v_batch = to_device(v_batch, device) 147 | var, cond, var_mask = model.validation_step(v_batch[2:], fs2_model) 148 | var_mask = ~var_mask 149 | p_targets = v_batch[-4].masked_select(var_mask) 150 | e_targets = v_batch[-3].masked_select(var_mask) 151 | log_d_targets = torch.log(v_batch[-2].masked_select(var_mask)) 152 | p_pred = var[:, :, 0].masked_select(var_mask) 153 | e_pred = var[:, :, 1].masked_select(var_mask) 154 | log_d_pred = var[:, :, 2].masked_select(var_mask) 155 | p_loss += mse_loss(p_pred, p_targets)# * len(v_batch[0]) 156 | e_loss += mse_loss(e_pred, e_targets)# * len(v_batch[0]) 157 | log_d_loss += mse_loss(log_d_pred, log_d_targets)# * len(v_batch[0]) 158 | 159 | p_preds.append(p_pred) 160 | e_preds.append(e_pred) 161 | log_d_preds.append(log_d_pred) 162 | p_tgts.append(p_targets) 163 | e_tgts.append(e_targets) 164 | log_d_tgts.append(log_d_targets) 165 | 166 | break 167 | break 168 | 169 | 170 | val_logger.add_scalar("Loss/pitch_loss", p_loss, step) 171 | val_logger.add_scalar("Loss/energy_loss", e_loss, step) 172 | val_logger.add_scalar("Loss/duration_loss", log_d_loss, step) 173 | 174 | p_preds = torch.cat(p_preds) 175 | e_preds = torch.cat(e_preds) 176 | log_d_preds = torch.cat(log_d_preds) 177 | p_tgts = torch.cat(p_tgts) 178 | e_tgts = torch.cat(e_tgts) 179 | log_d_tgts = torch.cat(log_d_tgts) 180 | val_logger.add_histogram("Hist/pitch_prediction", p_preds, step) 181 | val_logger.add_histogram("Hist/energy_prediction", e_preds, step) 182 | val_logger.add_histogram("Hist/duration_prediction", log_d_preds, step) 183 | val_logger.add_histogram("Hist/pitch_targets", p_targets, step) 184 | val_logger.add_histogram("Hist/energy_targets", e_tgts, step) 185 | val_logger.add_histogram("Hist/duration_targets", log_d_tgts, step) 186 | 187 | for k, v in model.named_parameters(): 188 | val_logger.add_histogram(k, v, step) 189 | 190 | message = f"Validation Step {step},Pitch loss: {p_loss:.4f}, Energy loss: {e_loss:.4f}, Duration loss: {log_d_loss:.4f}" 191 | 192 | with open(os.path.join(val_log_path, "log.txt"), "a") as f: 193 | f.write(message + "\n") 194 | outer_bar.write(message) 195 | 196 | model.train() 197 | 198 | if step % save_step == 0: 199 | torch.save( 200 | { 201 | "model": model.state_dict(), 202 | "optimizers": [opt._optimizer.state_dict() for opt in optimizers], 203 | }, 204 | os.path.join( 205 | train_config["path"]["ckpt_path"], 206 | "{}.pth.tar".format(step), 207 | ), 208 | ) 209 | 210 | if step == total_step: 211 | quit() 212 | step += 1 213 | outer_bar.update(1) 214 | 215 | inner_bar.update(1) 216 | epoch += 1 217 | 218 | 219 | if __name__ == "__main__": 220 | parser = argparse.ArgumentParser() 221 | parser.add_argument("--restore_step", type=int, default=0) 222 | parser.add_argument( 223 | "-p", 224 | "--preprocess_config", 225 | type=str, 226 | required=True, 227 | help="path to preprocess.yaml", 228 | ) 229 | parser.add_argument( 230 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 231 | ) 232 | parser.add_argument( 233 | "-t", "--train_config", 
type=str, required=True, help="path to train.yaml" 234 | ) 235 | parser.add_argument( 236 | "-f", "--fs_path", type=str, required=True, help="path to pre-trained FastSpeech2 checkpoint" 237 | ) 238 | args = parser.parse_args() 239 | 240 | # Read Config 241 | preprocess_config = yaml.load( 242 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 243 | ) 244 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 245 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 246 | configs = (preprocess_config, model_config, train_config) 247 | 248 | main(args, configs) 249 | -------------------------------------------------------------------------------- /transformer/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = "" 7 | UNK_WORD = "" 8 | BOS_WORD = "" 9 | EOS_WORD = "" 10 | -------------------------------------------------------------------------------- /transformer/Layers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.nn import functional as F 7 | 8 | from .SubLayers import MultiHeadAttention, PositionwiseFeedForward 9 | 10 | 11 | class FFTBlock(torch.nn.Module): 12 | """FFT Block""" 13 | 14 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 15 | super(FFTBlock, self).__init__() 16 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 17 | self.pos_ffn = PositionwiseFeedForward( 18 | d_model, d_inner, kernel_size, dropout=dropout 19 | ) 20 | 21 | def forward(self, enc_input, mask=None, slf_attn_mask=None): 22 | enc_output, enc_slf_attn = self.slf_attn( 23 | enc_input, enc_input, enc_input, mask=slf_attn_mask 24 | ) 25 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 26 | 27 | enc_output = self.pos_ffn(enc_output) 28 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 29 | 30 | return enc_output, enc_slf_attn 31 | 32 | 33 | class ConvNorm(torch.nn.Module): 34 | def __init__( 35 | self, 36 | in_channels, 37 | out_channels, 38 | kernel_size=1, 39 | stride=1, 40 | padding=None, 41 | dilation=1, 42 | bias=True, 43 | w_init_gain="linear", 44 | ): 45 | super(ConvNorm, self).__init__() 46 | 47 | if padding is None: 48 | assert kernel_size % 2 == 1 49 | padding = int(dilation * (kernel_size - 1) / 2) 50 | 51 | self.conv = torch.nn.Conv1d( 52 | in_channels, 53 | out_channels, 54 | kernel_size=kernel_size, 55 | stride=stride, 56 | padding=padding, 57 | dilation=dilation, 58 | bias=bias, 59 | ) 60 | 61 | def forward(self, signal): 62 | conv_signal = self.conv(signal) 63 | 64 | return conv_signal 65 | 66 | 67 | class PostNet(nn.Module): 68 | """ 69 | PostNet: Five 1-d convolution with 512 channels and kernel size 5 70 | """ 71 | 72 | def __init__( 73 | self, 74 | n_mel_channels=80, 75 | postnet_embedding_dim=512, 76 | postnet_kernel_size=5, 77 | postnet_n_convolutions=5, 78 | ): 79 | 80 | super(PostNet, self).__init__() 81 | self.convolutions = nn.ModuleList() 82 | 83 | self.convolutions.append( 84 | nn.Sequential( 85 | ConvNorm( 86 | n_mel_channels, 87 | postnet_embedding_dim, 88 | kernel_size=postnet_kernel_size, 89 | stride=1, 90 | padding=int((postnet_kernel_size - 1) / 2), 91 | dilation=1, 92 | w_init_gain="tanh", 93 | ), 94 | nn.BatchNorm1d(postnet_embedding_dim), 95 | ) 96 | ) 97 | 98 | for 
i in range(1, postnet_n_convolutions - 1): 99 | self.convolutions.append( 100 | nn.Sequential( 101 | ConvNorm( 102 | postnet_embedding_dim, 103 | postnet_embedding_dim, 104 | kernel_size=postnet_kernel_size, 105 | stride=1, 106 | padding=int((postnet_kernel_size - 1) / 2), 107 | dilation=1, 108 | w_init_gain="tanh", 109 | ), 110 | nn.BatchNorm1d(postnet_embedding_dim), 111 | ) 112 | ) 113 | 114 | self.convolutions.append( 115 | nn.Sequential( 116 | ConvNorm( 117 | postnet_embedding_dim, 118 | n_mel_channels, 119 | kernel_size=postnet_kernel_size, 120 | stride=1, 121 | padding=int((postnet_kernel_size - 1) / 2), 122 | dilation=1, 123 | w_init_gain="linear", 124 | ), 125 | nn.BatchNorm1d(n_mel_channels), 126 | ) 127 | ) 128 | 129 | def forward(self, x): 130 | x = x.contiguous().transpose(1, 2) 131 | 132 | for i in range(len(self.convolutions) - 1): 133 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 134 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 135 | 136 | x = x.contiguous().transpose(1, 2) 137 | return x 138 | -------------------------------------------------------------------------------- /transformer/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import transformer.Constants as Constants 6 | from .Layers import FFTBlock 7 | from text.symbols import symbols 8 | 9 | 10 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 11 | """ Sinusoid position encoding table """ 12 | 13 | def cal_angle(position, hid_idx): 14 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 15 | 16 | def get_posi_angle_vec(position): 17 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 18 | 19 | sinusoid_table = np.array( 20 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)] 21 | ) 22 | 23 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 24 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 25 | 26 | if padding_idx is not None: 27 | # zero vector for padding dimension 28 | sinusoid_table[padding_idx] = 0.0 29 | 30 | return torch.FloatTensor(sinusoid_table) 31 | 32 | 33 | class Encoder(nn.Module): 34 | """ Encoder """ 35 | 36 | def __init__(self, config): 37 | super(Encoder, self).__init__() 38 | 39 | n_position = config["max_seq_len"] + 1 40 | n_src_vocab = len(symbols) + 1 41 | d_word_vec = config["transformer"]["encoder_hidden"] 42 | n_layers = config["transformer"]["encoder_layer"] 43 | n_head = config["transformer"]["encoder_head"] 44 | d_k = d_v = ( 45 | config["transformer"]["encoder_hidden"] 46 | // config["transformer"]["encoder_head"] 47 | ) 48 | d_model = config["transformer"]["encoder_hidden"] 49 | d_inner = config["transformer"]["conv_filter_size"] 50 | kernel_size = config["transformer"]["conv_kernel_size"] 51 | dropout = config["transformer"]["encoder_dropout"] 52 | 53 | self.max_seq_len = config["max_seq_len"] 54 | self.d_model = d_model 55 | 56 | self.src_word_emb = nn.Embedding( 57 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD 58 | ) 59 | self.position_enc = nn.Parameter( 60 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 61 | requires_grad=False, 62 | ) 63 | 64 | self.layer_stack = nn.ModuleList( 65 | [ 66 | FFTBlock( 67 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 68 | ) 69 | for _ in range(n_layers) 70 | ] 71 | ) 72 | 73 | def forward(self, src_seq, mask, return_attns=False): 74 | 75 | 
enc_slf_attn_list = [] 76 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 77 | 78 | # -- Prepare masks 79 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 80 | 81 | # -- Forward 82 | if not self.training and src_seq.shape[1] > self.max_seq_len: 83 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table( 84 | src_seq.shape[1], self.d_model 85 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 86 | src_seq.device 87 | ) 88 | else: 89 | enc_output = self.src_word_emb(src_seq) + self.position_enc[ 90 | :, :max_len, : 91 | ].expand(batch_size, -1, -1) 92 | 93 | for enc_layer in self.layer_stack: 94 | enc_output, enc_slf_attn = enc_layer( 95 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask 96 | ) 97 | if return_attns: 98 | enc_slf_attn_list += [enc_slf_attn] 99 | 100 | return enc_output 101 | 102 | 103 | class Decoder(nn.Module): 104 | """ Decoder """ 105 | 106 | def __init__(self, config): 107 | super(Decoder, self).__init__() 108 | 109 | n_position = config["max_seq_len"] + 1 110 | d_word_vec = config["transformer"]["decoder_hidden"] 111 | n_layers = config["transformer"]["decoder_layer"] 112 | n_head = config["transformer"]["decoder_head"] 113 | d_k = d_v = ( 114 | config["transformer"]["decoder_hidden"] 115 | // config["transformer"]["decoder_head"] 116 | ) 117 | d_model = config["transformer"]["decoder_hidden"] 118 | d_inner = config["transformer"]["conv_filter_size"] 119 | kernel_size = config["transformer"]["conv_kernel_size"] 120 | dropout = config["transformer"]["decoder_dropout"] 121 | 122 | self.max_seq_len = config["max_seq_len"] 123 | self.d_model = d_model 124 | 125 | self.position_enc = nn.Parameter( 126 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 127 | requires_grad=False, 128 | ) 129 | 130 | self.layer_stack = nn.ModuleList( 131 | [ 132 | FFTBlock( 133 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 134 | ) 135 | for _ in range(n_layers) 136 | ] 137 | ) 138 | 139 | def forward(self, enc_seq, mask, return_attns=False): 140 | 141 | dec_slf_attn_list = [] 142 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 143 | 144 | # -- Forward 145 | if not self.training and enc_seq.shape[1] > self.max_seq_len: 146 | # -- Prepare masks 147 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 148 | dec_output = enc_seq + get_sinusoid_encoding_table( 149 | enc_seq.shape[1], self.d_model 150 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 151 | enc_seq.device 152 | ) 153 | else: 154 | max_len = min(max_len, self.max_seq_len) 155 | 156 | # -- Prepare masks 157 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 158 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[ 159 | :, :max_len, : 160 | ].expand(batch_size, -1, -1) 161 | mask = mask[:, :max_len] 162 | slf_attn_mask = slf_attn_mask[:, :, :max_len] 163 | 164 | for dec_layer in self.layer_stack: 165 | dec_output, dec_slf_attn = dec_layer( 166 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask 167 | ) 168 | if return_attns: 169 | dec_slf_attn_list += [dec_slf_attn] 170 | 171 | return dec_output, mask 172 | -------------------------------------------------------------------------------- /transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class ScaledDotProductAttention(nn.Module): 7 | """ Scaled Dot-Product Attention """ 8 | 9 | def __init__(self, 
temperature): 10 | super().__init__() 11 | self.temperature = temperature 12 | self.softmax = nn.Softmax(dim=2) 13 | 14 | def forward(self, q, k, v, mask=None): 15 | 16 | attn = torch.bmm(q, k.transpose(1, 2)) 17 | attn = attn / self.temperature 18 | 19 | if mask is not None: 20 | attn = attn.masked_fill(mask, -np.inf) 21 | 22 | attn = self.softmax(attn) 23 | output = torch.bmm(attn, v) 24 | 25 | return output, attn 26 | -------------------------------------------------------------------------------- /transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from .Modules import ScaledDotProductAttention 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | """ Multi-Head Attention module """ 10 | 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 12 | super().__init__() 13 | 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v) 21 | 22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 23 | self.layer_norm = nn.LayerNorm(d_model) 24 | 25 | self.fc = nn.Linear(n_head * d_v, d_model) 26 | 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, q, k, v, mask=None): 30 | 31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 32 | 33 | sz_b, len_q, _ = q.size() 34 | sz_b, len_k, _ = k.size() 35 | sz_b, len_v, _ = v.size() 36 | 37 | residual = q 38 | 39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 45 | 46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
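# Note: the permute/view above folds the head dimension into the batch dimension, so q/k/v are (n_head * sz_b, len, d) and a single batched matmul inside ScaledDotProductAttention covers every head at once; the padding mask is tiled n_head times here so it stays aligned with that folded layout.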
47 | output, attn = self.attention(q, k, v, mask=mask) 48 | 49 | output = output.view(n_head, sz_b, len_q, d_v) 50 | output = ( 51 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) 52 | ) # b x lq x (n*dv) 53 | 54 | output = self.dropout(self.fc(output)) 55 | output = self.layer_norm(output + residual) 56 | 57 | return output, attn 58 | 59 | 60 | class PositionwiseFeedForward(nn.Module): 61 | """ A two-feed-forward-layer module """ 62 | 63 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1): 64 | super().__init__() 65 | 66 | # Use Conv1D 67 | # position-wise 68 | self.w_1 = nn.Conv1d( 69 | d_in, 70 | d_hid, 71 | kernel_size=kernel_size[0], 72 | padding=(kernel_size[0] - 1) // 2, 73 | ) 74 | # position-wise 75 | self.w_2 = nn.Conv1d( 76 | d_hid, 77 | d_in, 78 | kernel_size=kernel_size[1], 79 | padding=(kernel_size[1] - 1) // 2, 80 | ) 81 | 82 | self.layer_norm = nn.LayerNorm(d_in) 83 | self.dropout = nn.Dropout(dropout) 84 | 85 | def forward(self, x): 86 | residual = x 87 | output = x.transpose(1, 2) 88 | output = self.w_2(F.relu(self.w_1(output))) 89 | output = output.transpose(1, 2) 90 | output = self.dropout(output) 91 | output = self.layer_norm(output + residual) 92 | 93 | return output 94 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Models import Encoder, Decoder 2 | from .Layers import PostNet -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from model import FastSpeech2, GaussianDiffusion, ScheduledOptim 8 | 9 | from hfgan.vocoder import Vocoder 10 | 11 | 12 | def get_model(args, configs, device, train=False): 13 | (preprocess_config, model_config, train_config) = configs 14 | 15 | ModelCls = GaussianDiffusion if "model_cls" in model_config and \ 16 | model_config["model_cls"]=="GaussianDiffusion" else FastSpeech2 17 | 18 | model = ModelCls(preprocess_config, model_config).to(device) 19 | if args.restore_step: 20 | ckpt_path = os.path.join( 21 | train_config["path"]["ckpt_path"], 22 | "{}.pth.tar".format(args.restore_step), 23 | ) 24 | ckpt = torch.load(ckpt_path) 25 | model.load_state_dict(ckpt["model"]) 26 | 27 | if train: 28 | scheduled_optim = ScheduledOptim( 29 | model, train_config, model_config, args.restore_step 30 | ) 31 | if args.restore_step: 32 | scheduled_optim.load_state_dict(ckpt["optimizer"]) 33 | model.train() 34 | return model, scheduled_optim 35 | 36 | model.eval() 37 | model.requires_grad_(False) 38 | return model 39 | 40 | 41 | def get_param_num(model): 42 | num_param = sum(param.numel() for param in model.parameters()) 43 | return num_param 44 | 45 | 46 | def get_vocoder(config, device): 47 | name = config["vocoder"]["model"] 48 | speaker = config["vocoder"]["speaker"] 49 | 50 | if name == "MelGAN": 51 | if speaker == "LJSpeech": 52 | vocoder = torch.hub.load( 53 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" 54 | ) 55 | elif speaker == "universal": 56 | vocoder = torch.hub.load( 57 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 58 | ) 59 | vocoder.mel2wav.eval() 60 | vocoder.mel2wav.to(device) 61 | elif name == "hifigan": 62 | ckpt_path = config["vocoder"]["ckpt"] 63 | vocoder = Vocoder(ckpt_path, device=device) 64 | 65 | return vocoder 66 | 67 | 
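# Usage sketch (illustrative): get_vocoder and vocoder_infer below are typically wired together at synthesis time, as in utils/tools.py. Here model_config and preprocess_config are the parsed YAML configs, and mels/mel_lens come from the acoustic model:
#
#     vocoder = get_vocoder(model_config, device)
#     lengths = mel_lens * preprocess_config["preprocessing"]["stft"]["hop_length"]
#     wavs = vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=lengths)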
68 | def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None): 69 | name = model_config["vocoder"]["model"] 70 | with torch.no_grad(): 71 | if name == "MelGAN": 72 | wavs = vocoder.inverse(mels / np.log(10)) 73 | elif name == "hifigan": 74 | wavs = [vocoder.mel2wav(mel) for mel in mels] 75 | 76 | # wavs = ( 77 | # wavs.cpu().numpy() 78 | # * preprocess_config["preprocessing"]["audio"]["max_wav_value"] 79 | # ).astype("int16") 80 | # wavs = [wav for wav in wavs] 81 | 82 | for i in range(len(mels)): 83 | if lengths is not None: 84 | wavs[i] = wavs[i][: lengths[i]] 85 | 86 | return wavs 87 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import matplotlib 8 | from scipy.io import wavfile 9 | from matplotlib import pyplot as plt 10 | from mpl_toolkits.axes_grid1 import make_axes_locatable 11 | 12 | matplotlib.use("Agg") 13 | 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | def to_device(data, device): 19 | if len(data) == 12+1: 20 | ( 21 | ids, 22 | raw_texts, 23 | speakers, 24 | texts, 25 | src_lens, 26 | max_src_len, 27 | mels, 28 | mel_lens, 29 | max_mel_len, 30 | pitches, 31 | energies, 32 | durations, 33 | quasi_symbols, 34 | ) = data 35 | 36 | speakers = torch.from_numpy(speakers).long().to(device) 37 | texts = torch.from_numpy(texts).long().to(device) 38 | src_lens = torch.from_numpy(src_lens).to(device) 39 | mels = torch.from_numpy(mels).float().to(device) 40 | mel_lens = torch.from_numpy(mel_lens).to(device) 41 | pitches = torch.from_numpy(pitches).float().to(device) 42 | energies = torch.from_numpy(energies).float().to(device) 43 | durations = torch.from_numpy(durations).long().to(device) 44 | 45 | quasi_symbols = torch.from_numpy(quasi_symbols).long().to(device) 46 | 47 | return ( 48 | ids, 49 | raw_texts, 50 | speakers, 51 | texts, 52 | src_lens, 53 | max_src_len, 54 | mels, 55 | mel_lens, 56 | max_mel_len, 57 | pitches, 58 | energies, 59 | durations, 60 | 61 | quasi_symbols, 62 | ) 63 | 64 | if len(data) == 6+1: 65 | (ids, raw_texts, speakers, texts, src_lens, max_src_len, quasi_flags) = data 66 | 67 | speakers = torch.from_numpy(speakers).long().to(device) 68 | texts = torch.from_numpy(texts).long().to(device) 69 | src_lens = torch.from_numpy(src_lens).to(device) 70 | quasi_flags = torch.from_numpy(quasi_flags).to(device) 71 | 72 | return (ids, raw_texts, speakers, texts, src_lens, max_src_len, quasi_flags) 73 | 74 | 75 | def log( 76 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag="" 77 | ): 78 | if losses is not None: 79 | logger.add_scalar("Loss/total_loss", losses[0], step) 80 | logger.add_scalar("Loss/mel_loss", losses[1], step) 81 | logger.add_scalar("Loss/mel_postnet_loss", losses[2], step) 82 | logger.add_scalar("Loss/pitch_loss", losses[3], step) 83 | logger.add_scalar("Loss/energy_loss", losses[4], step) 84 | logger.add_scalar("Loss/duration_loss", losses[5], step) 85 | 86 | if fig is not None: 87 | logger.add_figure(tag, fig) 88 | 89 | if audio is not None: 90 | if not isinstance(audio, torch.Tensor): 91 | audio = torch.FloatTensor(audio) 92 | logger.add_audio( 93 | tag, 94 | audio / max(abs(audio)), 95 | sample_rate=sampling_rate, 96 | ) 97 | 98 | 99 | def get_mask_from_lengths(lengths, max_len=None): 100 | batch_size = 
lengths.shape[0] 101 | if max_len is None: 102 | max_len = torch.max(lengths).item() 103 | 104 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) 105 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 106 | 107 | return mask 108 | 109 | 110 | def expand(values, durations): 111 | out = list() 112 | for value, d in zip(values, durations): 113 | out += [value] * max(0, int(d)) 114 | return np.array(out) 115 | 116 | 117 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config): 118 | 119 | basename = targets[0][0] 120 | src_len = predictions[8][0].item() 121 | mel_len = predictions[9][0].item() 122 | mel_target = targets[6][0, :mel_len].detach().transpose(0, 1) 123 | mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1) 124 | duration = targets[11][0, :src_len].detach().cpu().numpy() 125 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 126 | pitch = targets[9][0, :src_len].detach().cpu().numpy() 127 | pitch = expand(pitch, duration) 128 | else: 129 | pitch = targets[9][0, :mel_len].detach().cpu().numpy() 130 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 131 | energy = targets[10][0, :src_len].detach().cpu().numpy() 132 | energy = expand(energy, duration) 133 | else: 134 | energy = targets[10][0, :mel_len].detach().cpu().numpy() 135 | 136 | with open( 137 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 138 | ) as f: 139 | stats = json.load(f) 140 | stats = stats["pitch"] + stats["energy"][:2] 141 | 142 | fig = plot_mel( 143 | [ 144 | (mel_prediction.cpu().numpy(), pitch, energy), 145 | (mel_target.cpu().numpy(), pitch, energy), 146 | ], 147 | stats, 148 | ["Synthetized Spectrogram", "Ground-Truth Spectrogram"], 149 | ) 150 | 151 | if vocoder is not None: 152 | from .model import vocoder_infer 153 | 154 | wav_reconstruction = vocoder_infer( 155 | mel_target.unsqueeze(0), 156 | vocoder, 157 | model_config, 158 | preprocess_config, 159 | )[0] 160 | wav_prediction = vocoder_infer( 161 | mel_prediction.unsqueeze(0), 162 | vocoder, 163 | model_config, 164 | preprocess_config, 165 | )[0] 166 | else: 167 | wav_reconstruction = wav_prediction = None 168 | 169 | return fig, wav_reconstruction, wav_prediction, basename 170 | 171 | 172 | def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path, comp_mels=None): 173 | 174 | basenames = targets[0] 175 | for i in range(len(predictions[0])): 176 | basename = basenames[i] 177 | src_len = predictions[8][i].item() 178 | mel_len = predictions[9][i].item() 179 | mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) 180 | duration = predictions[5][i, :src_len].detach().cpu().numpy() 181 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 182 | pitch = predictions[2][i, :src_len].detach().cpu().numpy() 183 | pitch = expand(pitch, duration) 184 | else: 185 | pitch = predictions[2][i, :mel_len].detach().cpu().numpy() 186 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 187 | energy = predictions[3][i, :src_len].detach().cpu().numpy() 188 | energy = expand(energy, duration) 189 | else: 190 | energy = predictions[3][i, :mel_len].detach().cpu().numpy() 191 | 192 | with open( 193 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 194 | ) as f: 195 | stats = json.load(f) 196 | stats = stats["pitch"] + stats["energy"][:2] 197 | 198 | if comp_mels is None: 199 | fig = 
plot_mel( 200 | [ 201 | (mel_prediction.cpu().numpy(), pitch, energy), 202 | ], 203 | stats, 204 | ["Synthesized Spectrogram"], 205 | ) 206 | else: 207 | comp_mel = comp_mels[i, :mel_len].detach().transpose(0, 1) 208 | fig = plot_mel_diff( 209 | mel_prediction.cpu().numpy(), 210 | comp_mel.cpu().numpy(), 211 | pitch, energy, pitch, energy, 212 | stats, 213 | "DiffSpeech", 214 | "FastSpeech2", 215 | ) 216 | plt.savefig(os.path.join(path, "{}.png".format(basename))) 217 | plt.close() 218 | 219 | from .model import vocoder_infer 220 | 221 | mel_predictions = predictions[1].transpose(1, 2) 222 | lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"] 223 | wav_predictions = vocoder_infer( 224 | mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths 225 | ) 226 | 227 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 228 | for wav, basename in zip(wav_predictions, basenames): 229 | wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav) 230 | 231 | 232 | def plot_mel_diff(mel1, mel2, pitch1, energy1, pitch2, energy2, stats, title1, title2): 233 | data = [ 234 | (mel1, pitch1, energy1), 235 | (mel2, pitch2, energy2), 236 | (np.abs(mel2-mel1), pitch2-pitch1, energy2-energy1), 237 | ] 238 | fig, axes = plt.subplots(3, 1, squeeze=False, figsize=(10, 16)) 239 | titles = [title1, title2, f"{title2}-{title1}"] 240 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats 241 | pitch_min = pitch_min * pitch_std + pitch_mean 242 | pitch_max = pitch_max * pitch_std + pitch_mean 243 | 244 | def add_axis(fig, old_ax): 245 | ax = fig.add_axes(old_ax.get_position(), anchor="W") 246 | ax.set_facecolor("None") 247 | return ax 248 | 249 | for i in range(len(data)): 250 | mel, pitch, energy = data[i] 251 | pitch = pitch * pitch_std + pitch_mean 252 | im = axes[i][0].imshow(mel, origin="lower") 253 | # divider = make_axes_locatable(axes[i][0]) 254 | # cax = divider.append_axes("right", size="5%", pad=0.05) 255 | # fig.colorbar(im, cax=cax) 256 | axes[i][0].set_aspect(2.5, adjustable="box") 257 | axes[i][0].set_ylim(0, mel.shape[0]) 258 | axes[i][0].set_title(titles[i], fontsize="medium") 259 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 260 | axes[i][0].set_anchor("W") 261 | 262 | ax1 = add_axis(fig, axes[i][0]) 263 | ax1.plot(pitch, color="tomato") 264 | ax1.set_xlim(0, mel.shape[1]) 265 | if i < len(data)-1: 266 | ax1.set_ylim(0, pitch_max) 267 | else: 268 | ax1.set_ylim(pitch.min()-1., pitch.max()+1.) 269 | ax1.set_ylabel("F0", color="tomato") 270 | ax1.tick_params( 271 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False 272 | ) 273 | 274 | ax2 = add_axis(fig, axes[i][0]) 275 | ax2.plot(energy, color="darkviolet") 276 | ax2.set_xlim(0, mel.shape[1]) 277 | if i < len(data)-1: 278 | ax2.set_ylim(energy_min, energy_max) 279 | else: 280 | ax2.set_ylim(energy.min()-1., energy.max()+1.) 
281 | ax2.set_ylabel("Energy", color="darkviolet") 282 | ax2.yaxis.set_label_position("right") 283 | ax2.tick_params( 284 | labelsize="x-small", 285 | colors="darkviolet", 286 | bottom=False, 287 | labelbottom=False, 288 | left=False, 289 | labelleft=False, 290 | right=True, 291 | labelright=True, 292 | ) 293 | 294 | return fig 295 | 296 | def plot_mel(data, stats, titles): 297 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 298 | if titles is None: 299 | titles = [None for i in range(len(data))] 300 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats 301 | pitch_min = pitch_min * pitch_std + pitch_mean 302 | pitch_max = pitch_max * pitch_std + pitch_mean 303 | 304 | def add_axis(fig, old_ax): 305 | ax = fig.add_axes(old_ax.get_position(), anchor="W") 306 | ax.set_facecolor("None") 307 | return ax 308 | 309 | for i in range(len(data)): 310 | mel, pitch, energy = data[i] 311 | pitch = pitch * pitch_std + pitch_mean 312 | axes[i][0].imshow(mel, origin="lower") 313 | axes[i][0].set_aspect(2.5, adjustable="box") 314 | axes[i][0].set_ylim(0, mel.shape[0]) 315 | axes[i][0].set_title(titles[i], fontsize="medium") 316 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 317 | axes[i][0].set_anchor("W") 318 | 319 | ax1 = add_axis(fig, axes[i][0]) 320 | ax1.plot(pitch, color="tomato") 321 | ax1.set_xlim(0, mel.shape[1]) 322 | ax1.set_ylim(0, pitch_max) 323 | ax1.set_ylabel("F0", color="tomato") 324 | ax1.tick_params( 325 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False 326 | ) 327 | 328 | ax2 = add_axis(fig, axes[i][0]) 329 | ax2.plot(energy, color="darkviolet") 330 | ax2.set_xlim(0, mel.shape[1]) 331 | ax2.set_ylim(energy_min, energy_max) 332 | ax2.set_ylabel("Energy", color="darkviolet") 333 | ax2.yaxis.set_label_position("right") 334 | ax2.tick_params( 335 | labelsize="x-small", 336 | colors="darkviolet", 337 | bottom=False, 338 | labelbottom=False, 339 | left=False, 340 | labelleft=False, 341 | right=True, 342 | labelright=True, 343 | ) 344 | 345 | return fig 346 | 347 | 348 | def pad_1D(inputs, PAD=0): 349 | def pad_data(x, length, PAD): 350 | x_padded = np.pad( 351 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD 352 | ) 353 | return x_padded 354 | 355 | max_len = max((len(x) for x in inputs)) 356 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 357 | 358 | return padded 359 | 360 | 361 | def pad_2D(inputs, maxlen=None): 362 | def pad(x, max_len): 363 | PAD = 0 364 | if np.shape(x)[0] > max_len: 365 | raise ValueError("not max_len") 366 | 367 | s = np.shape(x)[1] 368 | x_padded = np.pad( 369 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD 370 | ) 371 | return x_padded[:, :s] 372 | 373 | if maxlen: 374 | output = np.stack([pad(x, maxlen) for x in inputs]) 375 | else: 376 | max_len = max(np.shape(x)[0] for x in inputs) 377 | output = np.stack([pad(x, max_len) for x in inputs]) 378 | 379 | return output 380 | 381 | 382 | def pad(input_ele, mel_max_length=None): 383 | if mel_max_length: 384 | max_len = mel_max_length 385 | else: 386 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 387 | 388 | out_list = list() 389 | for i, batch in enumerate(input_ele): 390 | if len(batch.shape) == 1: 391 | one_batch_padded = F.pad( 392 | batch, (0, max_len - batch.size(0)), "constant", 0.0 393 | ) 394 | elif len(batch.shape) == 2: 395 | one_batch_padded = F.pad( 396 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 397 | ) 398 | 
out_list.append(one_batch_padded) 399 | out_padded = torch.stack(out_list) 400 | return out_padded 401 | 402 | def drop_idxes(x, idxes, shrinkage: int=0): 403 | """ On BATCH basis """ 404 | padding_dim_prefix = [0] * ((len(x.size())-2) * 2) 405 | return torch.stack([ 406 | F.pad( 407 | row[torch.where(idx == 0)], 408 | padding_dim_prefix + [0, torch.sum(idx) - shrinkage] 409 | ) 410 | for row, idx in zip(x, idxes) 411 | ]) --------------------------------------------------------------------------------
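A minimal usage sketch for drop_idxes in utils/tools.py (the tensors, flag values, and variable names here are illustrative; in this codebase idxes would typically be the quasi_flags returned by to_device, marking quasi-symbol positions to strip from a padded batch): positions flagged with 1 are removed from each row and the row is zero-padded back on the right so the batch stays rectangular, with shrinkage trimming that padding.

from utils.tools import drop_idxes
import torch

# toy batch: 2 sequences of length 4 with hidden size 3
x = torch.arange(24, dtype=torch.float32).view(2, 4, 3)
# 1 marks a position to drop (e.g. a punctuation/prosody quasi-symbol)
idxes = torch.tensor([[0, 1, 0, 0],
                      [0, 0, 1, 1]])

out = drop_idxes(x, idxes)                     # (2, 4, 3): kept frames left-aligned, zero-padded
out_trim = drop_idxes(x, idxes, shrinkage=1)   # (2, 3, 3): pad one step less per row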