```text
├── README.md
├── configs
│   ├── .DS_Store
│   ├── hifigan-config.json
│   └── train_grad.json
├── data_collate.py
├── data_loader.py
├── data_preparation.py
├── filelists
│   ├── all_spks
│   │   ├── eval_utts.txt
│   │   ├── feats.ark
│   │   ├── feats.scp
│   │   ├── text
│   │   ├── train_utts.txt
│   │   ├── utt2emo.json
│   │   └── utt2spk.json
│   └── inference_generated.txt
├── inference_EMA.py
├── melspec.py
├── model
│   ├── __init__.py
│   ├── base.py
│   ├── classifier.py
│   ├── diffusion.py
│   ├── monotonic_align
│   │   ├── LICENCE
│   │   ├── __init__.py
│   │   ├── build
│   │   │   ├── temp.linux-x86_64-3.6
│   │   │   │   └── core.o
│   │   │   └── temp.macosx-10.9-x86_64-3.6
│   │   │       └── core.o
│   │   ├── core.c
│   │   ├── core.pyx
│   │   ├── model
│   │   │   └── monotonic_align
│   │   │       ├── core.cpython-36m-darwin.so
│   │   │       └── core.cpython-36m-x86_64-linux-gnu.so
│   │   └── setup.py
│   ├── text_encoder.py
│   ├── tts.py
│   └── utils.py
├── models.py
├── text
│   ├── .DS_Store
│   ├── LICENSE
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   └── symbols.py
├── train_EMA.py
├── utils_data.py
└── xutils.py
```

/README.md:
--------------------------------------------------------------------------------

KazEmoTTS
⌨️ 😐 😠 🙂 😞 😱 😮 🗣


This repository provides a dataset and a text-to-speech (TTS) model for the paper
KazEmoTTS: A Dataset for Kazakh Emotional Text-to-Speech Synthesis.


Dataset Statistics 📊

| Emotion | # recordings | F1 total (h) | F1 mean (s) | F1 min (s) | F1 max (s) | M1 total (h) | M1 mean (s) | M1 min (s) | M1 max (s) | M2 total (h) | M2 mean (s) | M2 min (s) | M2 max (s) |
|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| neutral | 9,385 | 5.85 | 5.03 | 1.03 | 15.51 | 4.54 | 4.77 | 0.84 | 16.18 | 2.30 | 4.69 | 1.02 | 15.81 |
| angry | 9,059 | 5.44 | 4.78 | 1.11 | 14.09 | 4.27 | 4.75 | 0.93 | 17.03 | 2.31 | 4.81 | 1.02 | 15.67 |
| happy | 9,059 | 5.77 | 5.09 | 1.07 | 15.33 | 4.43 | 4.85 | 0.98 | 15.56 | 2.23 | 4.74 | 1.09 | 15.25 |
| sad | 8,980 | 5.60 | 5.04 | 1.11 | 15.21 | 4.62 | 5.13 | 0.72 | 18.00 | 2.65 | 5.52 | 1.16 | 18.16 |
| scared | 9,098 | 5.66 | 4.96 | 1.00 | 15.67 | 4.13 | 4.51 | 0.65 | 16.11 | 2.34 | 4.96 | 1.07 | 14.49 |
| surprised | 9,179 | 5.91 | 5.09 | 1.09 | 14.56 | 4.52 | 4.92 | 0.81 | 17.67 | 2.28 | 4.87 | 1.04 | 15.81 |

| Narrator | # recordings | Duration (h) |
|---|---:|---:|
| F1 | 24,656 | 34.23 |
| M1 | 19,802 | 26.51 |
| M2 | 10,302 | 14.11 |
| Total | 54,760 | 74.85 |
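The per-emotion rows in the first table should add up to the per-narrator totals in the second. The sketch below copies the numbers from both tables by hand (it reads no data files) and recomputes the column sums. It is only a consistency check and is not part of the repository.

```python
# Numbers copied by hand from the two statistics tables above.
# Per emotion: (# recordings, F1 hours, M1 hours, M2 hours)
EMOTION_STATS = {
    "neutral": (9385, 5.85, 4.54, 2.30),
    "angry": (9059, 5.44, 4.27, 2.31),
    "happy": (9059, 5.77, 4.43, 2.23),
    "sad": (8980, 5.60, 4.62, 2.65),
    "scared": (9098, 5.66, 4.13, 2.34),
    "surprised": (9179, 5.91, 4.52, 2.28),
}

# Per narrator: (# recordings, total hours)
NARRATOR_TOTALS = {"F1": (24656, 34.23), "M1": (19802, 26.51), "M2": (10302, 14.11)}


def column_sums():
    """Sum the per-emotion rows column by column (hours rounded to 2 decimals)."""
    recordings = sum(row[0] for row in EMOTION_STATS.values())
    hours = {
        narrator: round(sum(row[col] for row in EMOTION_STATS.values()), 2)
        for col, narrator in enumerate(["F1", "M1", "M2"], start=1)
    }
    return recordings, hours


if __name__ == "__main__":
    recordings, hours = column_sums()
    print(recordings, hours)
    for narrator, (_, expected_hours) in NARRATOR_TOTALS.items():
        print(narrator, hours[narrator] == expected_hours)
```

Running the script prints 54,760 recordings and per-narrator hours equal to the F1, M1 and M2 rows of the second table. Those three rows in turn sum to the 74.85 h total.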

Installation 🛠️


First, build the Cython extension in `model/monotonic_align`:

```bash
cd model/monotonic_align; python setup.py build_ext --inplace; cd ../..
```

**Note**: the Python version used is 3.9.13.

Pre-Processing Data for Training 🧹


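Download the KazEmoTTS dataset and convert it to the format shown in `filelists/all_spks` by running `data_preparation.py`:

```shell
python data_preparation.py -d
```

The `-d` (`--data`) argument takes the path to the dataset root. The script builds paths by string concatenation (for example `dataset_path + spk + "/train/*.wav"`), so the path should end with `/`. All outputs are written to `filelists/all_spks/`.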
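`data_preparation.py` expects one folder per narrator, named `1263201035`, `805570882` and `399172782`. Each folder holds `train/` and `eval/` subfolders of `.wav` files. The script assigns speaker ID 0 to `1263201035`, 1 to `805570882` and 2 to any other folder. It takes the emotion from the second `_`-separated field of each file name and numbers the emotions by their sorted order.

The sketch below mirrors that labelling logic, so you can check a local copy of the dataset before extracting features. The helper names are hypothetical. Unlike the script, the sketch joins paths with `os.path.join`, so a trailing slash is not required.

```python
import glob
import os

# Narrator folder names hard-coded in data_preparation.py.
SPEAKER_DIRS = ["1263201035", "805570882", "399172782"]
# data_preparation.py: 1263201035 -> 0, 805570882 -> 1, anything else -> 2.
SPEAKER_IDS = {"1263201035": 0, "805570882": 1}


def collect_split(dataset_path, split):
    """Sorted utterance names found under <dataset>/<speaker>/<split>/*.wav."""
    names = []
    for spk in SPEAKER_DIRS:
        pattern = os.path.join(dataset_path, spk, split, "*.wav")
        for wav in glob.glob(pattern):
            names.append(os.path.splitext(os.path.basename(wav))[0])
    return sorted(names)


def emotion_labels(utt_names):
    """Emotion vocabulary: the sorted set of second '_'-separated fields."""
    return sorted({name.split("_")[1] for name in utt_names})


def label_utterance(wav_path, emotions):
    """Return (utt_name, speaker_id, emotion_id) for one .wav path."""
    utt = os.path.splitext(os.path.basename(wav_path))[0]
    spk_dir = os.path.normpath(wav_path).split(os.sep)[-3]  # <speaker>/<split>/<file>.wav
    speaker_id = SPEAKER_IDS.get(spk_dir, 2)
    emotion_id = emotions.index(utt.split("_")[1])
    return utt, speaker_id, emotion_id
```

Build the emotion list once from both splits, for example `emotion_labels(collect_split(path, "train") + collect_split(path, "eval"))`. Calling `label_utterance` on any `.wav` path then gives the IDs that should appear (as strings) in `utt2spk.json` and `utt2emo.json`.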

Training Stage 🏋️‍♂️


To start training, pass the path to the model configuration (`configs/train_grad.json`) with `-c` and a checkpoint directory (typically `logs/train_logs`) with `-m`. Choose the GPU with `CUDA_VISIBLE_DEVICES`:

```shell
CUDA_VISIBLE_DEVICES=YOUR_GPU_ID python train_EMA.py -c -m
```
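The `model` section of `configs/train_grad.json` declares `n_spks: 3` and `n_emos: 6`. `data_preparation.py` writes the speaker and emotion IDs as strings into `utt2spk.json` and `utt2emo.json`.

A quick pre-training check that every ID falls inside those ranges might look like the following sketch. The function is hypothetical and not part of the repository. Run it from the repository root, because the paths in the config are relative.

```python
import json


def check_label_ranges(config_path):
    """Report label files whose IDs fall outside the model's n_spks / n_emos."""
    with open(config_path, encoding="utf-8") as f:
        hps = json.load(f)
    problems = []
    for data_key, model_key in [("train_utt2spk", "n_spks"), ("train_utt2emo", "n_emos")]:
        limit = hps["model"][model_key]
        with open(hps["data"][data_key], encoding="utf-8") as f:
            labels = json.load(f)  # utt name -> ID stored as a string, e.g. "0"
        bad = [utt for utt, value in labels.items() if not 0 <= int(value) < limit]
        if bad:
            problems.append(f"{data_key}: {len(bad)} IDs outside [0, {limit})")
    return problems
```

For example, `check_label_ranges("configs/train_grad.json")` should return an empty list when the label files match the model configuration.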

Inference 🧠


Pre-Trained Models 🏃


To use a pre-trained model, download the checkpoints for both the TTS model (based on Grad-TTS) and the HiFi-GAN vocoder.

To run inference:

- Create a text file containing the sentences you want to synthesize, such as `filelists/inference_generated.txt`.
- Write one sentence per line in the format `text|emotion id|speaker id`.
- Set the HiFi-GAN checkpoint path (`HIFIGAN_CHECKPT`) in `inference_EMA.py`.
- Set the classifier guidance level to 100 with the `-g` parameter.

```shell
python inference_EMA.py -c -m -t -g -f -r
```
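`inference_EMA.py` resolves the two IDs on each input line with hard-coded lists. Emotions are sorted alphabetically (`angry`, `fear`, `happy`, `neutral`, `sad`, `surprise`), and speakers are `M1`, `F1`, `M2`. These IDs are only meaningful if they match the ones `data_preparation.py` assigned, which come from the sorted emotion labels in the dataset file names.

The sketch below parses one line and resolves both IDs the same way. It is an illustration, not code from the repository. It splits from the right with `rsplit("|", 2)`, an assumption of this sketch that keeps a `|` inside the sentence from shifting the IDs.

```python
# Emotion id = index in this alphabetically sorted list, as in inference_EMA.py.
EMOTIONS = sorted(["angry", "surprise", "fear", "happy", "neutral", "sad"])
# Speaker id = index in this list, as in inference_EMA.py (M1 = 0, F1 = 1, M2 = 2).
SPEAKERS = ["M1", "F1", "M2"]


def parse_line(line):
    """Split `text|emotion id|speaker id` and resolve both ids to names."""
    text, emo_id, spk_id = line.strip().rsplit("|", 2)
    return text, EMOTIONS[int(emo_id)], SPEAKERS[int(spk_id)]
```

For instance, the second line of `filelists/inference_generated.txt` ends in `|1|1` and resolves to `fear` and `F1`. `inference_EMA.py` names each output file `<emotion>_<speaker>.wav`, so lines that share both IDs overwrite each other's output.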

Synthesized samples 🔈


You can listen to some synthesized samples here.


Citation 🎓


We kindly ask that you cite our paper if you use our dataset and/or model in your work. Citing it acknowledges the authors' efforts, and we sincerely appreciate your support.

```bibtex
@misc{abilbekov2024kazemotts,
      title={KazEmoTTS: A Dataset for Kazakh Emotional Text-to-Speech Synthesis},
      author={Adal Abilbekov and Saida Mussakhojayeva and Rustem Yeshpanov and Huseyin Atakan Varol},
      year={2024},
      eprint={2404.01033},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}
```

--------------------------------------------------------------------------------
/configs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/configs/.DS_Store

--------------------------------------------------------------------------------
/configs/hifigan-config.json:
--------------------------------------------------------------------------------
```json
{
    "resblock": "1",
    "num_gpus": 1,
    "batch_size": 64,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "upsample_rates": [8,8,2,2],
    "upsample_kernel_sizes": [16,16,4,4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],

    "segment_size": 8192,
    "num_mels": 80,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,

    "sampling_rate": 22050,

    "fmin": 0,
    "fmax": 8000,
    "fmax_for_loss": null,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54320",
        "world_size": 1
    }
}
```

--------------------------------------------------------------------------------
/configs/train_grad.json:
--------------------------------------------------------------------------------
```json
{
    "xvector": false,
    "pe": false,
    "train": {
        "test_size": 6,
        "n_epochs": 10000,
        "batch_size": 64,
        "learning_rate": 1e-4,
        "seed": 37,
        "save_every": 1,
        "use_gt_dur": false
    },
    "data": {
        "load_mel_from_disk": false,
        "train_utts": "filelists/all_spks/train_utts.txt",
        "val_utts": "filelists/all_spks/eval_utts.txt",
        "train_utt2phns": "filelists/all_spks/text",
        "val_utt2phns": "filelists/all_spks/text",
        "train_feats_scp": "filelists/all_spks/feats.scp",
        "val_feats_scp": "filelists/all_spks/feats.scp",
        "train_utt2spk": "filelists/all_spks/utt2spk.json",
        "val_utt2spk": "filelists/all_spks/utt2spk.json",
        "train_utt2emo": "filelists/all_spks/utt2emo.json",
        "val_utt2emo": "filelists/all_spks/utt2emo.json",

        "train_var_scp": "",
        "val_var_scp": "",

        "text_cleaners": [
            "kazakh_cleaners"
        ],
        "max_wav_value": 32768.0,
        "sampling_rate": 22050,
        "filter_length": 1024,
        "hop_length": 200,
        "win_length": 800,
        "n_mel_channels": 80,
        "mel_fmin": 20.0,
        "mel_fmax": 8000.0,
        "utt2phn_path": "data/res_utt2phns.json",
        "add_blank": false
    },
    "model": {
        "n_vocab": 200,
        "n_spks": 3,
        "n_emos": 6,
        "spk_emb_dim": 64,
        "n_enc_channels": 192,
        "filter_channels": 768,
        "filter_channels_dp": 256,
        "n_enc_layers": 6,
        "enc_kernel": 3,
        "enc_dropout": 0.1,
        "n_heads": 2,
        "window_size": 4,
        "dec_dim": 64,
        "beta_min": 0.05,
        "beta_max": 20.0,
        "pe_scale": 1000,
        "d_decoder": 128,
        "l_decoder": 3,
        "k_decoder": 7,
        "h_decoder": 4,
        "decoder_dropout": 0.1,

        "classifier_type": "CNN-with-time"
    }
}
```

--------------------------------------------------------------------------------
/data_collate.py:
--------------------------------------------------------------------------------
```python
import os.path
import random
import numpy as np
import torch
import re
import torch.utils.data
import json

import kaldiio
from tqdm import tqdm


class BaseCollate:
    def __init__(self, n_frames_per_step=1):
        self.n_frames_per_step = n_frames_per_step

    def collate_text_mel(self, batch: [dict]):
        """
        :param batch: list of dicts
        """
        utt = list(map(lambda x: x['utt'], batch))
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x['text']) for x in batch]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]]['text']
            text_padded[i, :text.size(0)] = text

        # Right zero-pad mel-spec
        num_mels = batch[0]['mel'].size(0)
        max_target_len = max([x['mel'].size(1) for x in batch])
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
            assert max_target_len % self.n_frames_per_step == 0

        # include mel padded
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        for i in range(len(ids_sorted_decreasing)):
            mel = batch[ids_sorted_decreasing[i]]['mel']
            mel_padded[i, :, :mel.size(1)] = mel
            output_lengths[i] = mel.size(1)

        utt_name = np.array(utt)[ids_sorted_decreasing].tolist()
        if isinstance(utt_name, str):
            utt_name = [utt_name]

        res = {
            "utt": utt_name,
            "text_padded": text_padded,
            "input_lengths": input_lengths,
            "mel_padded": mel_padded,
            "output_lengths": output_lengths,
        }
        return res, ids_sorted_decreasing


class SpkIDCollate(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        spk_ids = torch.LongTensor(list(map(lambda x: x["spk_ids"], batch)))
        spk_ids = spk_ids[ids_sorted_decreasing]
        base_data.update({
            "spk_ids": spk_ids
        })
        return base_data


class SpkIDCollateWithEmo(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)

        spk_ids = torch.LongTensor(list(map(lambda x: x["spk_ids"], batch)))
        spk_ids = spk_ids[ids_sorted_decreasing]
        emo_ids = torch.LongTensor(list(map(lambda x: x['emo_ids'], batch)))
        emo_ids = emo_ids[ids_sorted_decreasing]
        base_data.update({
            "spk_ids": spk_ids,
            "emo_ids": emo_ids
        })
        return base_data


class XvectorCollate(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        xvectors = torch.cat(list(map(lambda x: x["xvector"].unsqueeze(0), batch)), dim=0)
        xvectors = xvectors[ids_sorted_decreasing]
        base_data.update({
            "xvector": xvectors
        })
        return base_data


class SpkIDCollateWithPE(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        spk_ids = torch.LongTensor(list(map(lambda x: x["spk_ids"], batch)))
        spk_ids = spk_ids[ids_sorted_decreasing]

        num_var = batch[0]["var"].size(0)
        max_target_len = max([x["var"].size(1) for x in batch])
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
            assert max_target_len % self.n_frames_per_step == 0

        var_padded = torch.FloatTensor(len(batch), num_var, max_target_len)
        var_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            var = batch[ids_sorted_decreasing[i]]["var"]
            var_padded[i, :, :var.size(1)] = var

        base_data.update({
            "spk_ids": spk_ids,
            "var_padded": var_padded
        })
        return base_data


class XvectorCollateWithPE(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        xvectors = torch.cat(list(map(lambda x: x["xvector"].unsqueeze(0), batch)), dim=0)
        xvectors = xvectors[ids_sorted_decreasing]

        num_var = batch[0]["var"].size(0)
        max_target_len = max([x["var"].size(1) for x in batch])
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
            assert max_target_len % self.n_frames_per_step == 0

        var_padded = torch.FloatTensor(len(batch), num_var, max_target_len)
        var_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            var = batch[ids_sorted_decreasing[i]]["var"]
            var_padded[i, :, :var.size(1)] = var

        base_data.update({
            "xvector": xvectors,
            "var_padded": var_padded
        })
        return base_data
```

--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
```python
import os.path
import random
import numpy as np
import torch
import re
import torch.utils.data
import json

import kaldiio
from tqdm import tqdm
from text import text_to_sequence

class BaseLoader(torch.utils.data.Dataset):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2text: str):
        """
        :param utts: file path. A list of utts for this loader. These are the only utts that this loader has access.
        This loader only deals with text, duration and feats. Other files despite `utts` can be larger.
        """
        self.n_mel_channels = hparams.n_mel_channels
        self.sampling_rate = hparams.sampling_rate
        self.utts = self.get_utts(utts)
        self.utt2feat = self.get_utt2feat(feats_scp)
        self.utt2text = self.get_utt2text(utt2text)

    def get_utts(self, utts: str) -> list:
        with open(utts, 'r') as f:
            L = f.readlines()
        L = list(map(lambda x: x.strip(), L))
        random.seed(1234)
        random.shuffle(L)
        return L


    def get_utt2feat(self, feats_scp: str):
        utt2feat = kaldiio.load_scp(feats_scp)  # lazy load mode
        print(f"Succeed reading feats from {feats_scp}")
        return utt2feat

    def get_utt2text(self, utt2text: str):
        with open(utt2text, 'r') as f:
            L = f.readlines()
        utt2text = {line.split()[0]: line.strip().split(" ", 1)[1] for line in L}
        return utt2text

    def get_mel_from_kaldi(self, utt):
        feat = self.utt2feat[utt]
        feat = torch.FloatTensor(feat).squeeze()
        assert self.n_mel_channels in feat.shape
        if feat.shape[0] == self.n_mel_channels:
            return feat
        else:
            return feat.T

    def get_text(self, utt):
        text = self.utt2text[utt]
        text_norm = text_to_sequence(text)
        text_norm = torch.IntTensor(text_norm)
        return text_norm

    def __getitem__(self, index):
        res = self.get_mel_text_pair(self.utts[index])
        return res

    def __len__(self):
        return len(self.utts)

    def sample_test_batch(self, size):
        idx = np.random.choice(range(len(self)), size=size, replace=False)
        test_batch = []
        for index in idx:
            test_batch.append(self.__getitem__(index))
        return test_batch


class SpkIDLoader(BaseLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk: str):
        """
        :param utt2spk: json file path (utt name -> spk id)
        This loader loads speaker as a speaker ID for embedding table
        """
        super(SpkIDLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk)

    def get_utt2spk(self, utt2spk: str) -> dict:
        with open(utt2spk, 'r') as f:
            res = json.load(f)
        return res

    def get_mel_text_pair(self, utt):
        # separate filename and text
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)

        assert sum(dur) == mel.shape[1], f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"
        res = {
            "utt": utt,
            "mel": mel,
            "spk_ids": spkid
        }
        return res

    def __getitem__(self, index):
        res = self.get_mel_text_pair(self.utts[index])
        return res

    def __len__(self):
        return len(self.utts)


class SpkIDLoaderWithEmo(BaseLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2text: str, utt2spk: str, utt2emo: str):
        """
        :param utt2spk: json file path (utt name -> spk id)
        This loader loads speaker as a speaker ID for embedding table
        """
        super(SpkIDLoaderWithEmo, self).__init__(utts, hparams, feats_scp, utt2text)
        self.utt2spk = self.get_utt2spk(utt2spk)
        self.utt2emo = self.get_utt2emo(utt2emo)

    def get_utt2spk(self, utt2spk: str) -> dict:
        with open(utt2spk, 'r') as f:
            res = json.load(f)
        return res

    def get_utt2emo(self, utt2emo: str) -> dict:
        with open(utt2emo, 'r') as f:
            res = json.load(f)
        return res

    def get_mel_text_pair(self, utt):
        # separate filename and text
        spkid = int(self.utt2spk[utt])
        emoid = int(self.utt2emo[utt])
        text = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)

        res = {
            "utt": utt,
            "text": text,
            "mel": mel,
            "spk_ids": spkid,
            "emo_ids": emoid
        }
        return res

    def __getitem__(self, index):
        res = self.get_mel_text_pair(self.utts[index])
        return res

    def __len__(self):
        return len(self.utts)


class SpkIDLoaderWithPE(SpkIDLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk: str, var_scp: str):
        """
        This loader loads speaker ID together with variance (4-dim pitch, 1-dim energy)
        """
        super(SpkIDLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration, utt2spk)
        self.utt2var = self.get_utt2var(var_scp)

    def get_utt2var(self, utt2var: str) -> dict:
        res = kaldiio.load_scp(utt2var)
        print(f"Succeed reading feats from {utt2var}")
        return res

    def get_var_from_kaldi(self, utt):
        var = self.utt2var[utt]
        var = torch.FloatTensor(var).squeeze()
        assert 5 in var.shape
        if var.shape[0] == 5:
            return var
        else:
            return var.T

    def get_mel_text_pair(self, utt):
        # separate filename and text
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)
        var = self.get_var_from_kaldi(utt)

        assert sum(dur) == mel.shape[1] == var.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "spk_ids": spkid,
            "var": var
        }
        return res


class XvectorLoader(BaseLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str):
        """
        :param utt2spk_name: like kaldi-style utt2spk
        :param spk_xvector_scp: kaldi-style speaker-level xvector.scp
        """
        super(XvectorLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk_name)
        self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)

    def get_utt2spk(self, utt2spk):
        res = dict()
        with open(utt2spk, 'r') as f:
            for l in f.readlines():
                res[l.split()[0]] = l.split()[1]
        return res

    def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
        res = kaldiio.load_scp(spk_xvector_scp)
        print(f"Succeed reading xvector from {spk_xvector_scp}")
        return res

    def get_xvector(self, utt):
        xv = self.spk2xvector[self.utt2spk[utt]]
        xv = torch.FloatTensor(xv).squeeze()
        return xv

    def get_mel_text_pair(self, utt):
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)
        xvector = self.get_xvector(utt)

        assert sum(dur) == mel.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "xvector": xvector,
        }
        return res


class XvectorLoaderWithPE(BaseLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str, var_scp: str):
        super(XvectorLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk_name)
        self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)
        self.utt2var = self.get_utt2var(var_scp)

    def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
        res = kaldiio.load_scp(spk_xvector_scp)
        print(f"Succeed reading xvector from {spk_xvector_scp}")
        return res

    def get_utt2spk(self, utt2spk):
        res = dict()
        with open(utt2spk, 'r') as f:
            for l in f.readlines():
                res[l.split()[0]] = l.split()[1]
        return res

    def get_utt2var(self, utt2var: str) -> dict:
        res = kaldiio.load_scp(utt2var)
        print(f"Succeed reading feats from {utt2var}")
        return res

    def get_var_from_kaldi(self, utt):
        var = self.utt2var[utt]
        var = torch.FloatTensor(var).squeeze()
        assert 5 in var.shape
        if var.shape[0] == 5:
            return var
        else:
            return var.T

    def get_xvector(self, utt):
        xv = self.spk2xvector[self.utt2spk[utt]]
        xv = torch.FloatTensor(xv).squeeze()
        return xv

    def get_mel_text_pair(self, utt):
        # separate filename and text
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)
        var = self.get_var_from_kaldi(utt)
        xvector = self.get_xvector(utt)

        assert sum(dur) == mel.shape[1] == var.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "spk_ids": spkid,
            "var": var,
            "xvector": xvector
        }
        return res
```

--------------------------------------------------------------------------------
/data_preparation.py:
--------------------------------------------------------------------------------
```python
import kaldiio
import os
import librosa
from tqdm import tqdm
import glob
import json
from shutil import copyfile
import pandas as pd
import argparse
from text import _clean_text, symbols
from num2words import num2words
import re
from melspec import mel_spectrogram
import torchaudio

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data', type=str, required=True, help='path to the emotional dataset')
    args = parser.parse_args()
    dataset_path = args.data
    filelists_path = 'filelists/all_spks/'
    feats_scp_file = filelists_path + 'feats.scp'
    feats_ark_file = filelists_path + 'feats.ark'


    spks = ['1263201035', '805570882', '399172782']
    train_files = []
    eval_files = []
    for spk in spks:
        train_files += glob.glob(dataset_path + spk + "/train/*.wav")
        eval_files += glob.glob(dataset_path + spk + "/eval/*.wav")

    os.makedirs(filelists_path, exist_ok=True)

    with open(filelists_path + 'train_utts.txt', 'w', encoding='utf-8') as f:
        for wav_path in train_files:
            wav_name = os.path.splitext(os.path.basename(wav_path))[0]
            f.write(wav_name + '\n')
    with open(filelists_path + 'eval_utts.txt', 'w', encoding='utf-8') as f:
        for wav_path in eval_files:
            wav_name = os.path.splitext(os.path.basename(wav_path))[0]
            f.write(wav_name + '\n')

    with open(feats_scp_file, 'w') as feats_scp, \
            kaldiio.WriteHelper(f'ark,scp:{feats_ark_file},{feats_scp_file}') as writer:
        for root, dirs, files in os.walk(dataset_path):
            for file in tqdm(files):
                if file.endswith('.wav'):
                    # Get the file name and relative path to the root folder
                    wav_path = os.path.join(root, file)
                    rel_path = os.path.relpath(wav_path, dataset_path)
                    wav_name = os.path.splitext(os.path.basename(wav_path))[0]
                    signal, rate = torchaudio.load(wav_path)
                    spec = mel_spectrogram(signal, 1024, 80, 22050, 256,
                                           1024, 0, 8000, center=False).squeeze()
                    # Write the features to feats.ark and feats.scp
                    writer[wav_name] = spec


    emotions = [os.path.basename(x).split("_")[1] for x in glob.glob(dataset_path + '/**/**/*')]
    emotions = sorted(set(emotions))

    utt2spk = {}
    utt2emo = {}
    wavs = glob.glob(dataset_path + '**/**/*.wav')
    for wav_path in tqdm(wavs):
        wav_name = os.path.splitext(os.path.basename(wav_path))[0]
        emotion = emotions.index(wav_name.split("_")[1])
        if wav_path.split('/')[-3] == '1263201035':
            spk = 0  ## labels should start with 0
        elif wav_path.split('/')[-3] == '805570882':
            spk = 1
        else:
            spk = 2
        utt2spk[wav_name] = str(spk)
        utt2emo[wav_name] = str(emotion)
    utt2spk = dict(sorted(utt2spk.items()))
    utt2emo = dict(sorted(utt2emo.items()))

    with open(filelists_path + 'utt2emo.json', 'w') as fp:
        json.dump(utt2emo, fp, indent=4)
    with open(filelists_path + 'utt2spk.json', 'w') as fp:
        json.dump(utt2spk, fp, indent=4)

    txt_files = sorted(glob.glob(dataset_path + '/**/**/*.txt'))
    count = 0
    txt = []
    basenames = []
    utt2text = {}
    flag = False
    with open(filelists_path + 'text', 'w', encoding='utf-8') as write:
        for txt_path in txt_files:
            basename = os.path.basename(txt_path).replace('.txt', '')
            with open(txt_path, 'r', encoding='utf-8') as f:
                txt.append(_clean_text(f.read().strip("\n"), cleaner_names=["kazakh_cleaners"]).replace("'", ""))
            basenames.append(basename)
    output_string = [re.sub(r'(\d+)', lambda m: num2words(m.group(), lang='kz'), sentence) for sentence in txt]
    cleaned_txt = []
    for t in output_string:
        cleaned_txt.append(''.join([s for s in t if s in symbols]))
    utt2text = {basenames[i]: cleaned_txt[i] for i in range(len(cleaned_txt))}
    utt2text = dict(sorted(utt2text.items()))

    vocab = set()
    with open(filelists_path + '/text', 'w', encoding='utf-8') as f:
        for x, y in utt2text.items():
            for c in y: vocab.add(c)
            f.write(x + ' ' + y + '\n')
```
-------------------------------------------------------------------------------- /filelists/all_spks/feats.ark: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/filelists/all_spks/feats.ark -------------------------------------------------------------------------------- /filelists/all_spks/feats.scp: -------------------------------------------------------------------------------- 1 | utt1 /Users/Desktop/code/GradTTS-emo/filelists/example/feats.ark:5 2 | utt2 /Users/Desktop/code/GradTTS-emo/filelists/example/feats.ark:78745 3 | utt3 /Users/Desktop/code/GradTTS-emo/filelists/example/feats.ark:370605 4 | -------------------------------------------------------------------------------- /filelists/inference_generated.txt: -------------------------------------------------------------------------------- 1 | Августың аяқ жағына мүсінші тәңірия Венераның баласы Амур бейнесін орналастырған.|0|0 2 | Қарғыс айтқалы жатыр ғой, өз балаларына!– десіп үркіп үн салды.|1|1 -------------------------------------------------------------------------------- /inference_EMA.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import datetime as dt 4 | import numpy as np 5 | from scipy.io.wavfile import write 6 | import IPython.display as ipd 7 | import glob 8 | import torch 9 | from pydub import AudioSegment 10 | from torch.utils.data import DataLoader 11 | from text import text_to_sequence, cmudict 12 | from text.symbols import symbols 13 | import utils_data 14 | import re 15 | from num2words import num2words 16 | from kaldiio import WriteHelper 17 | import os 18 | from tqdm import tqdm 19 | from text import text_to_sequence, convert_text 20 | from model import GradTTSWithEmo 21 | import utils_data as utils 22 | from attrdict import AttrDict 23 | from models import Generator as HiFiGAN 24 | 25 | 26 
| HIFIGAN_CONFIG = './configs/hifigan-config.json' 27 | HIFIGAN_CHECKPT = './checkpts/hifigan.pt' 28 | 29 | 30 | if __name__ == '__main__': 31 | hps, args = utils.get_hparams_decode() 32 | device = torch.device('cpu' if not torch.cuda.is_available() else "cuda") 33 | ckpt = utils_data.latest_checkpoint_path(hps.model_dir, "EMA_grad_*.pt") 34 | print(ckpt) 35 | model = GradTTSWithEmo(**hps.model).to(device) 36 | logger = utils_data.get_logger(hps.model_dir, "inference.log") 37 | utils_data.load_checkpoint(ckpt, model, None) 38 | _ = model.cuda().eval() 39 | 40 | print('Initializing HiFi-GAN...') 41 | with open(HIFIGAN_CONFIG) as f: 42 | h = AttrDict(json.load(f)) 43 | vocoder = HiFiGAN(h) 44 | vocoder.load_state_dict(torch.load(HIFIGAN_CHECKPT, map_location=lambda loc, storage: loc)['generator']) 45 | _ = vocoder.cuda().eval() 46 | vocoder.remove_weight_norm() 47 | 48 | emos = sorted(["angry", "surprise", "fear", "happy", "neutral", "sad"]) 49 | speakers = ['M1', 'F1', 'M2'] 50 | 51 | with open(args.file, 'r', encoding='utf-8') as f: 52 | texts = [line.strip() for line in f.readlines()] 53 | 54 | replace_nums = [] 55 | for i in texts: 56 | replace_nums.append(i.split('|', 1)) 57 | 58 | nums2word = [re.sub('(\d+)', lambda m: num2words(m.group(), lang='kz'), sentence) for sentence in np.array(replace_nums)[:, 0]] 59 | # Speakers id. 
    # M1 = 0
    # F1 = 1
    # M2 = 2
    text2speech = []
    for i, j in zip(nums2word, np.array(replace_nums)[:, 1]):
        text2speech.append(f'{i}|{j}')

    for i, line in enumerate(text2speech):
        # Each line has the form: text|emotion_id|speaker_id
        emo_i = int(line.split('|')[1])
        control_spk_id = int(line.split('|')[2])
        control_emo_id = emo_i  # index into the sorted `emos` list
        text = line.split('|')[0]
        with torch.no_grad():
            ### define emotion
            emo = torch.LongTensor([control_emo_id]).to(device)
            sid = torch.LongTensor([control_spk_id]).to(device)
            text_padded, text_len = convert_text(text)
            y_enc, y_dec, attn = model.forward(text_padded, text_len,
                                               n_timesteps=args.timesteps,
                                               temperature=args.noise,
                                               stoc=args.stoc, spk=sid, emo=emo, length_scale=1.,
                                               classifier_free_guidance=args.guidance)
            res = y_dec.squeeze().cpu().numpy()
            x = torch.from_numpy(res).to(device).unsqueeze(0)
            y_g_hat = vocoder(x)
            audio = y_g_hat.squeeze()
            # Scale to 16-bit PCM; clamp so a sample at exactly 1.0 does not wrap around.
            audio = (audio * 32768.0).clamp(-32768, 32767)
            audio = audio.detach().cpu().numpy().astype('int16')
            audio = AudioSegment(audio.tobytes(), frame_rate=22050, sample_width=2, channels=1)
            audio.export(f'{args.generated_path}/{emos[emo_i]}_{speakers[control_spk_id]}.wav', format="wav")
--------------------------------------------------------------------------------
/melspec.py:
--------------------------------------------------------------------------------
import torch
import torchaudio
import librosa

mel_basis = {}
hann_window = {}


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Cache the mel filterbank per (fmax, device) and the Hann window per device.
    # The membership test must use the same key as the cache entry, otherwise
    # the filterbank is rebuilt on every call.
    mel_key = str(fmax) + '_' + str(y.device)
    if mel_key not in mel_basis:
        mel = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # Recent PyTorch releases require return_complex to be set explicitly; view_as_real
    # restores the [..., 2] (real, imag) layout that the magnitude computation expects.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec.detach().cpu().numpy()
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------

from .tts import GradTTSWithEmo, GradTTSXvector

--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


class BaseModule(torch.nn.Module):
    def __init__(self):
        super(BaseModule, self).__init__()

    @property
    def nparams(self):
        """
        Returns number of trainable parameters of the module.
        """
        num_params = 0
        for name, param in self.named_parameters():
            if param.requires_grad:
                num_params += param.numel()
        return num_params

    def relocate_input(self, x: list):
        """
        Relocates provided tensors to the same device set for the module.
        """
        device = next(self.parameters()).device
        for i in range(len(x)):
            if isinstance(x[i], torch.Tensor) and x[i].device != device:
                x[i] = x[i].to(device)
        return x

--------------------------------------------------------------------------------
/model/classifier.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch import Tensor, BoolTensor

from typing import Optional, Tuple, Iterable
from model.diffusion import SinusoidalPosEmb
from torch.nn.functional import pad


import math

def silu(input):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise:
        SiLU(x) = x * sigmoid(x)
    '''
    return input * torch.sigmoid(input)  # built on torch.sigmoid, an efficient built-in PyTorch op


class RelPositionMultiHeadedAttention(nn.Module):
    """Multi-Head Self-Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head: The number of heads.
        d: The number of features.
        dropout: Dropout rate.
    """

    def __init__(self, d: int, n_head: int, dropout: float):
        super().__init__()
        assert d % n_head == 0
        self.c = d // n_head
        self.h = n_head

        self.linear_q = nn.Linear(d, d)
        self.linear_k = nn.Linear(d, d)
        self.linear_v = nn.Linear(d, d)
        self.linear_out = nn.Linear(d, d)

        self.p_attn = None
        self.dropout = nn.Dropout(p=dropout)

        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(d, d, bias=False)

        # These two learnable biases (u and v) are used in terms (c) and (d)
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.u = nn.Parameter(torch.Tensor(self.h, self.c))
        self.v = nn.Parameter(torch.Tensor(self.h, self.c))
        # [H, C]
        torch.nn.init.xavier_uniform_(self.u)
        torch.nn.init.xavier_uniform_(self.v)

    def forward_qkv(self, query, key, value) -> Tuple[Tensor, ...]:
        """Transform query, key and value.
        Args:
            query (Tensor): [B, S, D].
            key (Tensor): [B, T, D].
            value (Tensor): [B, T, D].
        Returns:
            q (Tensor): [B, H, S, C].
            k (Tensor): [B, H, T, C].
            v (Tensor): [B, H, T, C].
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.c)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.c)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.c)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        return q, k, v

    def forward_attention(self, v, scores, mask, causal=False) -> Tensor:
        """Compute attention context vector.
        Args:
            v (Tensor): [B, H, T, C].
            scores (Tensor): [B, H, S, T].
            mask (BoolTensor): [B, T], True values are masked from scores.
        Returns:
            result (Tensor): [B, S, D]. Attention result weighted by the score.
        """
        n_batch, H, S, T = scores.shape
        if mask is not None:
            scores = scores.masked_fill(
                mask.unsqueeze(1).unsqueeze(2).to(bool),
                float("-inf"),  # [B, H, S, T]
            )
        if causal:
            k_grid = torch.arange(0, S, dtype=torch.int32, device=scores.device)
            v_grid = torch.arange(0, T, dtype=torch.int32, device=scores.device)
            kk, vv = torch.meshgrid(k_grid, v_grid, indexing="ij")
            causal_mask = vv > kk
            scores = scores.masked_fill(
                causal_mask.view(1, 1, S, T), float("-inf")
            )

        p_attn = self.p_attn = torch.softmax(scores, dim=-1)  # [B, H, S, T]
        p_attn = self.dropout(p_attn)  # [B, H, S, T]

        x = torch.matmul(p_attn, v)  # [B, H, S, C]
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.c)
        )  # [B, S, D]

        return self.linear_out(x)  # [B, S, D]

    def rel_shift(self, x):
        """Converting (..., i, i - j) matrix into (..., i, j) matrix.
        Args:
            x (Tensor): [B, H, S, 2S-1].
        Returns:
            x (Tensor): [B, H, S, S].
        Example: Take S = 2 for example, larger values work similarly.
            x = [
                [(0, -1), (0, 0), (0, 1)],
                [(1, 0), (1, 1), (1, 2)]
            ]
            x_padded = [
                [(x, x), (0, -1), (0, 0), (0, 1)],
                [(x, x), (1, 0), (1, 1), (1, 2)]
            ]
            x_padded = [
                [(x, x), (0, -1)],
                [(0, 0), (0, 1)],
                [(x, x), (1, 0)],
                [(1, 1), (1, 2)]
            ]
            x = [
                [(0, 0), (0, 1)],
                [(1, 0), (1, 1)]
            ]
        """
        B, H, S, _ = x.shape
        zero_pad = torch.zeros((B, H, S, 1), device=x.device, dtype=x.dtype)
        # [B, H, S, 1]
        x_padded = torch.cat([zero_pad, x], dim=-1)
        # [B, H, S, 2S]
        x_padded = x_padded.view(B, H, 2 * S, S)
        # [B, H, 2S, S]
        x = x_padded[:, :, 1:].view_as(x)[:, :, :, :S]
        # drop the first row: [B, H, 2S-1, S]; view as [B, H, S, 2S-1];
        # keep only the first S columns: [B, H, S, S]
        return x

    def forward(self, query, key, value, pos_emb, mask=None, causal=False):
        """Compute self-attention with relative positional embedding.
        Args:
            query (Tensor): [B, S, D].
            key (Tensor): [B, S, D].
            value (Tensor): [B, S, D].
            pos_emb (Tensor): [1/B, 2S-1, D]. Positional embedding.
            mask (BoolTensor): [B, S], True for masked.
            causal (bool): True for applying causal mask.
        Returns:
            output (Tensor): [B, S, D].
        """
        # Splitting Q, K, V:
        q, k, v = self.forward_qkv(query, key, value)
        # [B, H, S, C], [B, H, S, C], [B, H, S, C]

        # Adding per head & channel biases to the query vectors:
        q_u = q + self.u.unsqueeze(1)
        q_v = q + self.v.unsqueeze(1)
        # [B, H, S, C]

        # Splitting relative positional coding:
        n_batch_pos = pos_emb.size(0)  # [1/B, 2S-1, D]
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.c)
        # [1/B, 2S-1, H, C]
        p = p.transpose(1, 2)  # [1/B, H, 2S-1, C].

        # Compute query, key similarity:
        matrix_ac = torch.matmul(q_u, k.transpose(-2, -1))
        # [B, H, S, C] x [B, H, C, S] -> [B, H, S, S]

        matrix_bd = torch.matmul(q_v, p.transpose(-2, -1))
        # [B, H, S, C] x [1/B, H, C, 2S-1] -> [B, H, S, 2S-1]
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.c)
        # [B, H, S, S]

        return self.forward_attention(v, scores, mask, causal)  # [B, S, D]


class ConditionalBiasScale(nn.Module):
    def __init__(self, channels: int, cond_channels: int):
        super().__init__()
        self.scale_transform = nn.Linear(
            cond_channels, channels, bias=True
        )
        self.bias_transform = nn.Linear(
            cond_channels, channels, bias=True
        )
        self.init_parameters()

    def init_parameters(self):
        torch.nn.init.constant_(self.scale_transform.weight, 0.0)
        torch.nn.init.constant_(self.scale_transform.bias, 1.0)
        torch.nn.init.constant_(self.bias_transform.weight, 0.0)
        torch.nn.init.constant_(self.bias_transform.bias, 0.0)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """Applying conditional bias and scale.
        Args:
            x (Tensor): [..., channels].
            cond (Tensor): [..., cond_channels].
        Returns:
            y (Tensor): [..., channels].
        """
        a = self.scale_transform.forward(cond)
        b = self.bias_transform.forward(cond)
        return x * a + b


class FeedForwardModule(torch.nn.Module):
    """Positionwise feed forward layer used in conformer"""

    def __init__(
        self, d_in: int, d_hidden: int,
        dropout: float, bias: bool = True, d_cond: int = 0
    ):
        """
        Args:
            d_in (int): Input feature dimension.
            d_hidden (int): Hidden unit dimension.
            dropout (float): Dropout rate applied after each Linear layer.
            bias (bool): If linear layers should have bias.
            d_cond (int, optional): The channels of conditional tensor.
        """
        super(FeedForwardModule, self).__init__()
        self.layer_norm = torch.nn.LayerNorm(d_in)

        if d_cond > 0:
            self.cond_layer = ConditionalBiasScale(d_in, d_cond)

        self.w_1 = torch.nn.Linear(d_in, d_hidden, bias=bias)
        self.w_2 = torch.nn.Linear(d_hidden, d_in, bias=bias)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            x (Tensor): [..., D].
            cond (Tensor, optional): [..., D_cond].
        Returns:
            y (Tensor): [..., D].
        """
        x = self.layer_norm(x)

        if cond is not None:
            x = self.cond_layer.forward(x, cond)

        x = self.w_1(x)
        x = silu(x)
        x = self.dropout(x)
        x = self.w_2(x)
        return self.dropout(x)


class RelPositionalEncoding(nn.Module):
    """Relative positional encoding cache.

    Args:
        max_len: Default maximum input length.
        d_model: Embedding dimension.
    """

    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.cached_code = None
        self.l = 0
        self.gen_code(torch.tensor(0.0).expand(1, max_len))

    def gen_code(self, x: Tensor):
        """Generate positional encoding with a reference tensor x.
        Args:
            x (Tensor): [B, L, ...], we extract the device, length, and dtype from it.
        Effects:
            self.cached_code (Tensor): [1, >=(2L-1), D].
        """
        l = x.size(1)
        if self.l >= l:
            if self.cached_code.dtype != x.dtype or self.cached_code.device != x.device:
                self.cached_code = self.cached_code.to(dtype=x.dtype, device=x.device)
            return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i Tensor:
        """Get positional encoding of appropriate shape given a reference Tensor.
        Args:
            x (Tensor): [B, L, ...].
        Returns:
            y (Tensor): [1, 2L-1, D].
        """
        self.gen_code(x)
        l = x.size(1)
        pos_emb = self.cached_code[
            :, self.l - l: self.l + l - 1,
        ]
        return pos_emb


class ConformerBlock(torch.nn.Module):
    """Conformer block based on https://arxiv.org/abs/2005.08100."""

    def __init__(
        self, d: int, d_hidden: int,
        attention_heads: int, dropout: float,
        depthwise_conv_kernel_size: int = 7,
        causal: bool = False, d_cond: int = 0
    ):
        """
        Args:
            d (int): Block input and output channel number.
            d_hidden (int): FFN layer dimension.
            attention_heads (int): Number of attention heads.
            dropout (float): dropout value.
            depthwise_conv_kernel_size (int): Size of kernel in depthwise conv.
            causal (bool): Whether self-attention applies a causal mask.
            d_cond (int, optional): The channels of conditional tensor.
        """
        super(ConformerBlock, self).__init__()
        self.causal = causal
        self.ffn1 = FeedForwardModule(
            d, d_hidden, dropout, bias=True, d_cond=d_cond
        )

        self.self_attn_layer_norm = torch.nn.LayerNorm(d)

        if d_cond > 0:
            self.cond_layer = ConditionalBiasScale(d, d_cond)

        self.self_attn = RelPositionMultiHeadedAttention(
            d, attention_heads, dropout=dropout
        )
        self.self_attn_dropout = torch.nn.Dropout(dropout)

        self.conv_module = ConvolutionModule(
            d_in=d, d_hidden=d,
            depthwise_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout, d_cond=d_cond
        )

        self.ffn2 = FeedForwardModule(
            d, d_hidden, dropout, bias=True, d_cond=d_cond
        )

        self.final_layer_norm = torch.nn.LayerNorm(d)

    def forward(
        self, x: Tensor, mask: BoolTensor, pos_emb: Tensor,
        cond: Optional[Tensor] = None
    ) -> Tensor:
        """
        Args:
            x (Tensor): [B, T, D_in].
            mask (BoolTensor): [B, T], True for masked.
            pos_emb (Tensor): [1 or B, 2T-1, D].
            cond (Tensor, optional): [B, ?, D_cond].
        Returns:
            y (Tensor): [B, T, D_in].
        """
        y = x

        x = self.ffn1(x) * 0.5 + y
        y = x
        # [B, T, D_in]

        x = self.self_attn_layer_norm(x)

        if cond is not None:
            x = self.cond_layer.forward(x, cond)

        x = self.self_attn.forward(
            query=x, key=x, value=x,
            pos_emb=pos_emb,
            mask=mask, causal=self.causal
        )
        x = self.self_attn_dropout(x) + y
        y = x
        # [B, T, D_in]

        x = self.conv_module.forward(x, mask) + y
        y = x
        # [B, T, D_in]

        x = self.ffn2(x) * 0.5 + y

        x = self.final_layer_norm(x)

        # masked_fill is not in-place; assign the result so padded frames are zeroed.
        x = x.masked_fill(mask.unsqueeze(-1), 0.0)

        return x


class ConvolutionModule(torch.nn.Module):
    """Convolution Block inside a Conformer Block."""

    def __init__(
        self, d_in: int, d_hidden: int,
        depthwise_kernel_size: int,
        dropout: float, bias: bool = False,
        causal: bool = False, d_cond: int = 0
    ):
        """
        Args:
            d_in (int): Embedding dimension.
            d_hidden (int): Number of channels in depthwise conv layers.
            depthwise_kernel_size (int): Depthwise conv layer kernel size.
            dropout (float): dropout value.
            bias (bool): If bias should be added to conv layers.
            causal (bool): Whether to left-pad the depthwise conv so it is causal.
            d_cond (int, optional): Channels of the conditional tensor; if > 0, a
                conditional bias and scale is applied after the LayerNorm.
        """
        super(ConvolutionModule, self).__init__()
        assert (depthwise_kernel_size - 1) % 2 == 0, "kernel_size should be odd"
        self.causal = causal
        self.causal_padding = (depthwise_kernel_size - 1, 0)
        self.layer_norm = torch.nn.LayerNorm(d_in)

        # Optional conditional LayerNorm:
        self.d_cond = d_cond
        if d_cond > 0:
            self.cond_layer = ConditionalBiasScale(d_in, d_cond)

        self.pointwise_conv1 = torch.nn.Conv1d(
            d_in, 2 * d_hidden,
            kernel_size=1,
            stride=1, padding=0,
            bias=bias
        )
        self.glu = torch.nn.GLU(dim=1)
        self.depthwise_conv = torch.nn.Conv1d(
            d_hidden, d_hidden,
            kernel_size=depthwise_kernel_size,
            stride=1,
            padding=(depthwise_kernel_size - 1) // 2 if not causal else 0,
            groups=d_hidden, bias=bias
        )
        self.pointwise_conv2 = torch.nn.Conv1d(
            d_hidden, d_in,
            kernel_size=1,
            stride=1, padding=0,
            bias=bias,
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: BoolTensor, cond: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            x (Tensor): [B, T, D_in].
            mask (BoolTensor): [B, T], True for masked.
            cond (Tensor, optional): [B, T, D_cond].
        Returns:
            y (Tensor): [B, T, D_in].
        """
        x = self.layer_norm(x)

        if cond is not None:
            x = self.cond_layer.forward(x, cond)

        x = x.transpose(-1, -2)  # [B, D_in, T]

        x = self.pointwise_conv1(x)  # [B, 2C, T]
        x = self.glu(x)  # [B, C, T]

        # Take care of masking the input tensor:
        if mask is not None:
            x = x.masked_fill(mask.unsqueeze(1), 0.0)

        # 1D Depthwise Conv
        if self.causal:  # Causal padding
            x = pad(x, self.causal_padding)
        x = self.depthwise_conv(x)
        # FIXME: BatchNorm should not be used in variable length training.
        x = silu(x)  # [B, C, T]

        if mask is not None:
            x = x.masked_fill(mask.unsqueeze(1), 0.0)

        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        return x.transpose(-1, -2)  # [B, T, D_in]


class Conformer(torch.nn.Module):
    def __init__(
        self,
        d: int,
        d_hidden: int,
        n_heads: int,
        n_layers: int,
        dropout: float,
        depthwise_conv_kernel_size: int,
        causal: bool = False,
        d_cond: int = 0
    ):
        super().__init__()
        self.pos_encoding = RelPositionalEncoding(1024, d)
        self.causal = causal

        self.blocks = torch.nn.ModuleList(
            [
                ConformerBlock(
                    d=d,
                    d_hidden=d_hidden,
                    attention_heads=n_heads,
                    dropout=dropout,
                    depthwise_conv_kernel_size=depthwise_conv_kernel_size,
                    causal=causal,
                    d_cond=d_cond
                )
                for _ in range(n_layers)
            ]
        )  # type: Iterable[ConformerBlock]

    def forward(
        self, x: Tensor, mask: BoolTensor, cond: Tensor = None
    ) -> Tensor:
        """Conformer forwarding.
        Args:
            x (Tensor): [B, T, D].
            mask (BoolTensor): [B, T], with True for masked.
            cond (Tensor, optional): [B, T, D_cond].
        Returns:
            y (Tensor): [B, T, D]
        """
        pos_emb = self.pos_encoding(x)  # [1, 2T-1, D]

        for block in self.blocks:
            x = block.forward(x, mask, pos_emb, cond)

        return x


class CNNBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, cond_dim, kernel_size, stride):
        super(CNNBlock, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(in_dim, out_dim, kernel_size, stride),
            nn.ReLU(),
            nn.BatchNorm1d(out_dim),
            nn.Dropout(p=dropout)
        )

    def forward(self, inp):
        out = self.layers(inp)
        return out


class CNNClassifier(nn.Module):
    def __init__(self, in_dim, d_decoder, decoder_dropout, cond_dim):
        super(CNNClassifier, self).__init__()
        self.cnn = nn.Sequential(
            CNNBlock(in_dim, d_decoder, decoder_dropout, cond_dim, 8, 4),
            CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 8, 4),
            CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2),
            CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2),
        )  # receptive field is 180, frame shift is 64
        self.cond_layer = nn.Sequential(
            nn.Linear(cond_dim, in_dim),
            nn.LeakyReLU(),
            nn.Linear(in_dim, in_dim)
        )

    def forward(self, inp, mask, cond):
        inp = inp.transpose(-1, -2)
        cond = cond.transpose(-1, -2)
        inp.masked_fill_(mask.unsqueeze(1), 0.0)
        cond = self.cond_layer(cond.transpose(-1, -2)).transpose(-1, -2)
        cond.masked_fill_(mask.unsqueeze(1), 0.0)
        inp = inp + cond
        return self.cnn(inp)


class CNNClassifierWithTime(nn.Module):
    def __init__(self, in_dim, d_decoder, decoder_dropout, cond_dim, time_emb_dim=512):
        super(CNNClassifierWithTime, self).__init__()
        self.cnn = nn.Sequential(
            CNNBlock(in_dim, d_decoder, decoder_dropout, cond_dim, 8, 4),
            CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 8, 4),
            CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2),
            CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2),
        )  # receptive field is 180, frame shift is 64
        self.cond_layer = nn.Sequential(
            nn.Linear(cond_dim, in_dim),
            nn.LeakyReLU(),
            nn.Linear(in_dim, in_dim)
        )
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        self.time_layer = nn.Sequential(
            nn.Linear(time_emb_dim, in_dim),
            nn.LeakyReLU(),
            nn.Linear(in_dim, in_dim)
        )

    def forward(self, inp, mask, cond, t):
        time_emb = self.time_emb(t)  # [B, time_emb_dim]
        time_emb = self.time_layer(time_emb.unsqueeze(1)).transpose(-1, -2)  # [B, in_dim, 1]
        inp = inp.transpose(-1, -2)
        cond = cond.transpose(-1, -2)
        inp.masked_fill_(mask.unsqueeze(1), 0.0)
        cond = self.cond_layer(cond.transpose(-1, -2)).transpose(-1, -2)
        cond.masked_fill_(mask.unsqueeze(1), 0.0)
        inp = inp + cond + time_emb
        return self.cnn(inp)


class SpecClassifier(nn.Module):
    def __init__(self, in_dim, d_decoder, h_decoder,
                 l_decoder, decoder_dropout,
                 k_decoder, n_class, cond_dim, model_type='conformer'):
        super(SpecClassifier, self).__init__()
        self.model_type = model_type
        self.prenet = nn.Sequential(
            nn.Linear(in_features=in_dim, out_features=d_decoder)
        )
        if model_type == 'conformer':
            self.conformer = Conformer(d=d_decoder, d_hidden=d_decoder, n_heads=h_decoder,
                                       n_layers=l_decoder, dropout=decoder_dropout,
                                       depthwise_conv_kernel_size=k_decoder, d_cond=cond_dim)
        elif model_type == 'CNN':
            self.conformer = CNNClassifier(in_dim=d_decoder, d_decoder=d_decoder,
                                           decoder_dropout=decoder_dropout, cond_dim=cond_dim)
        elif model_type == 'CNN-with-time':
            self.conformer = CNNClassifierWithTime(in_dim=d_decoder, d_decoder=d_decoder,
                                                   decoder_dropout=decoder_dropout, cond_dim=cond_dim,
                                                   time_emb_dim=256)
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

        self.classifier = nn.Linear(d_decoder, n_class)

    def forward(self, noisy_mel, condition, mask, **kwargs):
        """
        Args:
            noisy_mel: [B, T, D]
            condition: [B, T, D]
            mask: [B, T] with True for un-masked (real-values)

        Returns:
            classification logits (un-softmaxed)
        """
        noisy_mel = noisy_mel.masked_fill(~mask.unsqueeze(-1), 0.0)

        hiddens = self.prenet(noisy_mel)

        if self.model_type == 'CNN-with-time':
            hiddens = self.conformer.forward(hiddens, ~mask, condition, kwargs['t'])
        else:
            hiddens = self.conformer.forward(hiddens, ~mask, condition)  # [B, T, D]

        if self.model_type == 'conformer':
            averaged_hiddens = torch.mean(hiddens, dim=1)  # [B, D]
            logits = self.classifier(averaged_hiddens)
            return logits
        elif self.model_type == 'CNN' or self.model_type == 'CNN-with-time':
            hiddens = hiddens.transpose(-1, -2)
            return self.classifier(hiddens)  # [B, T', C]

    @property
    def nparams(self):
        return sum([p.numel() for p in self.parameters()])

--------------------------------------------------------------------------------
/model/diffusion.py:
--------------------------------------------------------------------------------
import math
import torch
from einops import rearrange

from model.base import BaseModule


class Mish(BaseModule):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Upsample(BaseModule):
    def __init__(self, dim):
        super(Upsample, self).__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Downsample(BaseModule):
    def __init__(self, dim):
        super(Downsample, self).__init__()
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
# kernel=3, stride=2, padding=1. 26 | 27 | def forward(self, x): 28 | return self.conv(x) 29 | 30 | 31 | class Rezero(BaseModule): 32 | def __init__(self, fn): 33 | super(Rezero, self).__init__() 34 | self.fn = fn 35 | self.g = torch.nn.Parameter(torch.zeros(1)) 36 | 37 | def forward(self, x): 38 | return self.fn(x) * self.g 39 | 40 | 41 | class Block(BaseModule): 42 | def __init__(self, dim, dim_out, groups=8): 43 | super(Block, self).__init__() 44 | self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, 45 | padding=1), torch.nn.GroupNorm( 46 | groups, dim_out), Mish()) 47 | 48 | def forward(self, x, mask): 49 | output = self.block(x * mask) 50 | return output * mask 51 | 52 | 53 | class ResnetBlock(BaseModule): 54 | def __init__(self, dim, dim_out, time_emb_dim, groups=8): 55 | super(ResnetBlock, self).__init__() 56 | self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 57 | dim_out)) 58 | 59 | self.block1 = Block(dim, dim_out, groups=groups) 60 | self.block2 = Block(dim_out, dim_out, groups=groups) 61 | if dim != dim_out: 62 | self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) 63 | else: 64 | self.res_conv = torch.nn.Identity() 65 | 66 | def forward(self, x, mask, time_emb): 67 | h = self.block1(x, mask) 68 | h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) 69 | h = self.block2(h, mask) 70 | output = h + self.res_conv(x * mask) 71 | return output 72 | 73 | 74 | class LinearAttention(BaseModule): 75 | def __init__(self, dim, heads=4, dim_head=32): 76 | super(LinearAttention, self).__init__() 77 | self.heads = heads 78 | hidden_dim = dim_head * heads 79 | self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) # NOTE: 1x1 conv 80 | self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) 81 | 82 | def forward(self, x): 83 | b, c, h, w = x.shape 84 | qkv = self.to_qkv(x) 85 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3) 86 | k = k.softmax(dim=-1) 87 | context = 
torch.einsum('bhdn,bhen->bhde', k, v) 88 | out = torch.einsum('bhde,bhdn->bhen', context, q) 89 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', 90 | heads=self.heads, h=h, w=w) 91 | return self.to_out(out) 92 | 93 | 94 | class Residual(BaseModule): 95 | def __init__(self, fn): 96 | super(Residual, self).__init__() 97 | self.fn = fn 98 | 99 | def forward(self, x, *args, **kwargs): 100 | output = self.fn(x, *args, **kwargs) + x 101 | return output 102 | 103 | 104 | class SinusoidalPosEmb(BaseModule): 105 | def __init__(self, dim): 106 | super(SinusoidalPosEmb, self).__init__() 107 | self.dim = dim 108 | 109 | def forward(self, x, scale=1000): 110 | device = x.device 111 | half_dim = self.dim // 2 112 | emb = math.log(10000) / (half_dim - 1) 113 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 114 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) 115 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 116 | return emb 117 | 118 | 119 | class GradLogPEstimator2d(BaseModule): 120 | def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, spk_emb_dim=64, n_feats=80, pe_scale=1000): 121 | super(GradLogPEstimator2d, self).__init__() 122 | self.dim = dim 123 | self.dim_mults = dim_mults 124 | self.groups = groups 125 | self.spk_emb_dim = spk_emb_dim 126 | self.pe_scale = pe_scale 127 | 128 | self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), 129 | torch.nn.Linear(spk_emb_dim * 4, n_feats)) 130 | self.time_pos_emb = SinusoidalPosEmb(dim) 131 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), 132 | torch.nn.Linear(dim * 4, dim)) 133 | 134 | dims = [3, *map(lambda m: dim * m, dim_mults)] 135 | in_out = list(zip(dims[:-1], dims[1:])) 136 | self.downs = torch.nn.ModuleList([]) 137 | self.ups = torch.nn.ModuleList([]) 138 | num_resolutions = len(in_out) 139 | 140 | for ind, (dim_in, dim_out) in enumerate(in_out): 141 | is_last = ind >= (num_resolutions - 1) 142 | 
self.downs.append(torch.nn.ModuleList([ 143 | ResnetBlock(dim_in, dim_out, time_emb_dim=dim), 144 | ResnetBlock(dim_out, dim_out, time_emb_dim=dim), 145 | Residual(Rezero(LinearAttention(dim_out))), 146 | Downsample(dim_out) if not is_last else torch.nn.Identity()])) 147 | 148 | mid_dim = dims[-1] 149 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 150 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) 151 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 152 | 153 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 154 | self.ups.append(torch.nn.ModuleList([ 155 | ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), 156 | ResnetBlock(dim_in, dim_in, time_emb_dim=dim), 157 | Residual(Rezero(LinearAttention(dim_in))), 158 | Upsample(dim_in)])) 159 | self.final_block = Block(dim, dim) 160 | self.final_conv = torch.nn.Conv2d(dim, 1, 1) 161 | 162 | def forward(self, x, mask, mu, t, spk=None): 163 | # x, mu: [B, 80, L], t: [B, ], mask: [B, 1, L] 164 | if not isinstance(spk, type(None)): 165 | s = self.spk_mlp(spk) 166 | 167 | t = self.time_pos_emb(t, scale=self.pe_scale) 168 | t = self.mlp(t) # [B, 64] 169 | 170 | s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1]) 171 | x = torch.stack([mu, x, s], 1) # [B, 3, 80, L] 172 | mask = mask.unsqueeze(1) # [B, 1, 1, L] 173 | 174 | hiddens = [] 175 | masks = [mask] 176 | for resnet1, resnet2, attn, downsample in self.downs: 177 | mask_down = masks[-1] 178 | x = resnet1(x, mask_down, t) # [B, 64, 80, L] 179 | x = resnet2(x, mask_down, t) 180 | x = attn(x) 181 | hiddens.append(x) 182 | x = downsample(x * mask_down) 183 | masks.append(mask_down[:, :, :, ::2]) 184 | 185 | masks = masks[:-1] 186 | mask_mid = masks[-1] 187 | x = self.mid_block1(x, mask_mid, t) 188 | x = self.mid_attn(x) 189 | x = self.mid_block2(x, mask_mid, t) 190 | 191 | for resnet1, resnet2, attn, upsample in self.ups: 192 | mask_up = masks.pop() 193 | x = torch.cat((x, hiddens.pop()), dim=1) 194 | x = 
resnet1(x, mask_up, t) 195 | x = resnet2(x, mask_up, t) 196 | x = attn(x) 197 | x = upsample(x * mask_up) 198 | 199 | x = self.final_block(x, mask) 200 | output = self.final_conv(x * mask) 201 | 202 | return (output * mask).squeeze(1) 203 | 204 | 205 | def get_noise(t, beta_init, beta_term, cumulative=False): 206 | if cumulative: 207 | noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2) 208 | else: 209 | noise = beta_init + (beta_term - beta_init)*t 210 | return noise 211 | 212 | 213 | class Diffusion(BaseModule): 214 | def __init__(self, n_feats, dim, spk_emb_dim=64, 215 | beta_min=0.05, beta_max=20, pe_scale=1000): 216 | super(Diffusion, self).__init__() 217 | self.n_feats = n_feats 218 | self.dim = dim 219 | # self.n_spks = n_spks 220 | self.spk_emb_dim = spk_emb_dim 221 | self.beta_min = beta_min 222 | self.beta_max = beta_max 223 | self.pe_scale = pe_scale 224 | 225 | self.estimator = GradLogPEstimator2d(dim, 226 | spk_emb_dim=spk_emb_dim, 227 | pe_scale=pe_scale, 228 | n_feats=n_feats) 229 | 230 | def forward_diffusion(self, x0, mask, mu, t): 231 | time = t.unsqueeze(-1).unsqueeze(-1) 232 | cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True) # it is actually the integral of beta 233 | mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise)) 234 | variance = 1.0 - torch.exp(-cum_noise) 235 | z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, 236 | requires_grad=False) 237 | xt = mean + z * torch.sqrt(variance) 238 | return xt * mask, z * mask 239 | 240 | @torch.no_grad() 241 | def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None, 242 | use_classifier_free=False, 243 | classifier_free_guidance=3.0, 244 | dummy_spk=None): # emo need to be merged by spk 245 | 246 | # looks like a plain Euler-Maruyama method 247 | h = 1.0 / n_timesteps 248 | xt = z * mask 249 | for i in range(n_timesteps): 250 | t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, 251 | device=z.device) 252 | 
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = get_noise(time, self.beta_min, self.beta_max,
                                cumulative=False)

            if not use_classifier_free:
                if stoc:  # adds stochastic term
                    dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
                    dxt_det = dxt_det * noise_t * h
                    dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                           requires_grad=False)
                    dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                    dxt = dxt_det + dxt_stoc
                else:
                    dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
                    dxt = dxt * noise_t * h
                xt = (xt - dxt) * mask
            else:
                if stoc:  # adds stochastic term
                    score_estimate = (1 + classifier_free_guidance) * self.estimator(xt, mask, mu, t, spk) \
                                     - classifier_free_guidance * self.estimator(xt, mask, mu, t, dummy_spk)
                    dxt_det = 0.5 * (mu - xt) - score_estimate
                    dxt_det = dxt_det * noise_t * h
                    dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                           requires_grad=False)
                    dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                    dxt = dxt_det + dxt_stoc
                else:
                    score_estimate = (1 + classifier_free_guidance) * self.estimator(xt, mask, mu, t, spk) \
                                     - classifier_free_guidance * self.estimator(xt, mask, mu, t, dummy_spk)
                    dxt = 0.5 * (mu - xt - score_estimate)
                    dxt = dxt * noise_t * h
                xt = (xt - dxt) * mask
        return xt

    @torch.no_grad()
    def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None,
                use_classifier_free=False,
                classifier_free_guidance=3.0,
                dummy_spk=None
                ):
        return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk, use_classifier_free, classifier_free_guidance, dummy_spk)

    def loss_t(self, x0, mask, mu, t, spk=None):
        xt, z = self.forward_diffusion(x0, mask, mu, t)  # z is sampled from N(0, I)
        time = t.unsqueeze(-1).unsqueeze(-1)
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        noise_estimation = self.estimator(xt, mask, mu, t, spk)
        # Weight the score by sqrt(lambda), where lambda = 1 - exp(-cum_noise) is the
        # forward-process variance. Putting sqrt(lambda) inside the L2 norm gives
        # lambda * ||score + z / sqrt(lambda)||^2 without dividing z by the std.
        noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feats)
        return loss, xt

    def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
        t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
                       requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)
        return self.loss_t(x0, mask, mu, t, spk)

    def classifier_decode(self, z, mask, mu, n_timesteps, stoc=False, spk=None, classifier_func=None, guidance=1.0, control_emo=None, classifier_type="conformer"):
        # control_emo should be [B, ] tensor
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype,
                                                   device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = get_noise(time, self.beta_min, self.beta_max,
                                cumulative=False)
            # =========== classifier part ==============
            xt = xt.detach()
            xt.requires_grad_(True)
            if classifier_type == 'CNN-with-time':
                logits = classifier_func(xt.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1), t=t)
            else:
                logits = classifier_func(xt.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1))

            if classifier_type == 'conformer':  # [B, C]
                probs = torch.log_softmax(logits, dim=-1)  # [B, C]
            elif classifier_type == 'CNN' or classifier_type == 'CNN-with-time':
                probs_every_place = torch.softmax(logits, dim=-1)  # [B, T', C]
                probs_mean = torch.mean(probs_every_place, dim=1)  # [B, C]
                probs = torch.log(probs_mean)
            else:
                raise NotImplementedError

            control_emo_probs = probs[torch.arange(len(control_emo)).to(control_emo.device), control_emo]
            control_emo_probs.sum().backward(retain_graph=True)
            # NOTE: summing over the batch gives every item the same weight.
            xt_grad = xt.grad
            # ==========================================

            if stoc:  # adds stochastic term
                dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk) - guidance * xt_grad
                dxt_det = dxt_det * noise_t * h
                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                       requires_grad=False)
                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                dxt = dxt_det + dxt_stoc
            else:
                dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk) - guidance * xt_grad)
                dxt = dxt * noise_t * h
            xt = (xt - dxt) * mask
        return xt

    def classifier_decode_DPS(self, z, mask, mu, n_timesteps, stoc=False, spk=None, classifier_func=None, guidance=1.0, control_emo=None, classifier_type="conformer"):
        # control_emo should be [B, ] tensor
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = get_noise(time, self.beta_min, self.beta_max, cumulative=False)
            beta_integral_t = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
            # torch.exp/torch.sqrt work element-wise on the [B, 1, 1] tensor;
            # math.exp/math.sqrt only accept one-element tensors (batch size 1).
            bar_alpha_t = torch.exp(-beta_integral_t)

            # =========== classifier part ==============
            xt = xt.detach()
            xt.requires_grad_(True)
            score_estimate = self.estimator(xt, mask, mu, t, spk)
            x0_hat = (xt + (1 - bar_alpha_t) * score_estimate) / torch.sqrt(bar_alpha_t)

            if classifier_type == 'CNN-with-time':
                raise NotImplementedError
            else:
                logits = classifier_func(x0_hat.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1))
            if classifier_type == 'conformer':  # [B, C]
                probs = torch.log_softmax(logits, dim=-1)  # [B, C]
            elif classifier_type == 'CNN':
                probs_every_place = torch.softmax(logits, dim=-1)  # [B, T', C]
                probs_mean = torch.mean(probs_every_place, dim=1)  # [B, C]

                # NOTE: in the first few steps x0_hat may be very large, so the classifier
                # logits can take extreme values; the small epsilon keeps log() finite.
                probs_mean = probs_mean + 10E-10

                probs = torch.log(probs_mean)
            else:
                raise NotImplementedError

            control_emo_probs = probs[torch.arange(len(control_emo)).to(control_emo.device), control_emo]
            control_emo_probs.sum().backward(retain_graph=True)
            # NOTE: summing over the batch gives every item the same weight.
            xt_grad = xt.grad
            # ==========================================

            if stoc:  # adds stochastic term
                dxt_det = 0.5 * (mu - xt) - score_estimate - guidance * xt_grad
                dxt_det = dxt_det * noise_t * h
                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device, requires_grad=False)
                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                dxt = dxt_det + dxt_stoc
            else:
                dxt = 0.5 * (mu - xt - score_estimate - guidance * xt_grad)
                dxt = dxt * noise_t * h
            xt = (xt - dxt) * mask
        return xt

    def classifier_decode_mixture(self, z, mask, mu, n_timesteps, stoc=False, spk=None, classifier_func=None, guidance=1.0, control_emo1=None, control_emo2=None, emo1_weight=None, classifier_type="conformer"):
        # control_emo1 and control_emo2 should be [B, ] tensors; emo1_weight weights control_emo1
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype,
                                                   device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = get_noise(time, self.beta_min, self.beta_max,
                                cumulative=False)
            # =========== classifier part ==============
            xt = xt.detach()
            xt.requires_grad_(True)
            if classifier_type == 'CNN-with-time':
                logits = classifier_func(xt.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1), t=t)
            else:
                logits = classifier_func(xt.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1))

            if classifier_type == 'conformer':  # [B, C]
                probs = torch.log_softmax(logits, dim=-1)  # [B, C]
            elif classifier_type == 'CNN' or classifier_type == 'CNN-with-time':
                probs_every_place = torch.softmax(logits, dim=-1)  # [B, T', C]
                probs_mean = torch.mean(probs_every_place, dim=1)  # [B, C]
                probs = torch.log(probs_mean)
            else:
                raise NotImplementedError

            control_emo_probs1 = probs[torch.arange(len(control_emo1)).to(control_emo1.device), control_emo1]
            control_emo_probs2 = probs[torch.arange(len(control_emo2)).to(control_emo2.device), control_emo2]
            control_emo_probs = control_emo_probs1 * emo1_weight + control_emo_probs2 * (1-emo1_weight)  # interpolate

            control_emo_probs.sum().backward(retain_graph=True)
            # NOTE: summing over the batch gives every item the same weight.
            xt_grad = xt.grad
            # ==========================================

            if stoc:  # adds stochastic term
                dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk) - guidance * xt_grad
                dxt_det = dxt_det * noise_t * h
                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                       requires_grad=False)
                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                dxt = dxt_det + dxt_stoc
            else:
                dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk) - guidance * xt_grad)
                dxt = dxt * noise_t * h
            xt = (xt - dxt) * mask
        return xt

    def classifier_decode_mixture_DPS(self, z, mask, mu, n_timesteps, stoc=False, spk=None, classifier_func=None, guidance=1.0, control_emo1=None, control_emo2=None, emo1_weight=None, classifier_type="conformer"):
        # control_emo1 and control_emo2 should be [B, ] tensors; emo1_weight weights control_emo1
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype,
                                                   device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = get_noise(time, self.beta_min, self.beta_max,
                                cumulative=False)
            beta_integral_t = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
            # torch.exp/torch.sqrt work element-wise on the [B, 1, 1] tensor;
            # math.exp/math.sqrt only accept one-element tensors (batch size 1).
            bar_alpha_t = torch.exp(-beta_integral_t)
            # =========== classifier part ==============
            xt = xt.detach()
            xt.requires_grad_(True)
            score_estimate = self.estimator(xt, mask, mu, t, spk)
            x0_hat = (xt + (1 - bar_alpha_t) * score_estimate) / torch.sqrt(bar_alpha_t)

            if classifier_type == 'CNN-with-time':
                raise NotImplementedError
            else:
                logits = classifier_func(x0_hat.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1))

            if classifier_type == 'conformer':  # [B, C]
                probs = torch.log_softmax(logits, dim=-1)  # [B, C]
            elif classifier_type == 'CNN' or classifier_type == 'CNN-with-time':
                probs_every_place = torch.softmax(logits, dim=-1)  # [B, T', C]
                probs_mean = torch.mean(probs_every_place, dim=1)  # [B, C]
                probs_mean = probs_mean + 10E-10  # small epsilon keeps log() finite

                probs = torch.log(probs_mean)
            else:
                raise NotImplementedError

            control_emo_probs1 = probs[torch.arange(len(control_emo1)).to(control_emo1.device), control_emo1]
            control_emo_probs2 = probs[torch.arange(len(control_emo2)).to(control_emo2.device), control_emo2]
            control_emo_probs = control_emo_probs1 * emo1_weight + control_emo_probs2 * (1-emo1_weight)  # interpolate

            control_emo_probs.sum().backward(retain_graph=True)
            # NOTE: summing over the batch gives every item the same weight.
            xt_grad = xt.grad
            # ==========================================

            if stoc:  # adds stochastic term
                dxt_det = 0.5 * (mu - xt) - score_estimate - guidance * xt_grad
                dxt_det = dxt_det * noise_t * h
                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                       requires_grad=False)
                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
                dxt = dxt_det + dxt_stoc
            else:
                dxt = 0.5 * (mu - xt - score_estimate - guidance * xt_grad)
                dxt = dxt * noise_t * h
            xt = (xt - dxt) * mask
        return xt

--------------------------------------------------------------------------------
/model/monotonic_align/LICENCE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Jaehyeon Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/model/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
""" from https://github.com/jaywalnut310/glow-tts """

import numpy as np
import torch
from .model.monotonic_align.core import maximum_path_c


def maximum_path(value, mask):
    """ Cython optimised version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()

    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)

--------------------------------------------------------------------------------
/model/monotonic_align/build/temp.linux-x86_64-3.6/core.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/model/monotonic_align/build/temp.linux-x86_64-3.6/core.o
--------------------------------------------------------------------------------
/model/monotonic_align/build/temp.macosx-10.9-x86_64-3.6/core.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/model/monotonic_align/build/temp.macosx-10.9-x86_64-3.6/core.o
--------------------------------------------------------------------------------
/model/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
  cdef int x
  cdef int y
  cdef float v_prev
  cdef float v_cur
  cdef float tmp
  cdef int index = t_x - 1

  for y in range(t_y):
    for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
      if x == y:
        v_cur = max_neg_val
      else:
        v_cur = value[x, y-1]
      if x == 0:
        if y == 0:
          v_prev = 0.
        else:
          v_prev = max_neg_val
      else:
        v_prev = value[x-1, y-1]
      value[x, y] = max(v_cur, v_prev) + value[x, y]

  for y in range(t_y - 1, -1, -1):
    path[index, y] = 1
    if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
      index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
  cdef int b = values.shape[0]

  cdef int i
  for i in prange(b, nogil=True):
    maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)

--------------------------------------------------------------------------------
/model/monotonic_align/model/monotonic_align/core.cpython-36m-darwin.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/model/monotonic_align/model/monotonic_align/core.cpython-36m-darwin.so
--------------------------------------------------------------------------------
/model/monotonic_align/model/monotonic_align/core.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/model/monotonic_align/model/monotonic_align/core.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/model/monotonic_align/setup.py:
--------------------------------------------------------------------------------
""" from https://github.com/jaywalnut310/glow-tts """

from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
    name = 'monotonic_align',
    ext_modules = cythonize("core.pyx"),
    include_dirs=[numpy.get_include()]
)

--------------------------------------------------------------------------------
/model/text_encoder.py:
--------------------------------------------------------------------------------
""" from https://github.com/jaywalnut310/glow-tts """

import math

import torch

from model.base import BaseModule
from model.utils import sequence_mask, convert_pad_shape


class LayerNorm(BaseModule):
    def __init__(self, channels, eps=1e-4):
        super(LayerNorm, self).__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean)**2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(BaseModule):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
                 n_layers, p_dropout):
        super(ConvReluNorm, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
                                                kernel_size, padding=kernel_size//2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
                                                    kernel_size, padding=kernel_size//2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(BaseModule):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super(DurationPredictor, self).__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
                                      kernel_size, padding=kernel_size//2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
                                      kernel_size, padding=kernel_size//2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class MultiHeadAttention(BaseModule):
    def __init__(self, channels, out_channels, n_heads, window_size=None,
                 heads_share=True, p_dropout=0.0, proximal_bias=False,
                 proximal_init=False):
        super(MultiHeadAttention, self).__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
                             window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
                             window_size * 2 + 1, self.k_channels) * rel_stddev)
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
                                                                    dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights,
                                                                value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = torch.nn.functional.pad(
                relative_embeddings, convert_pad_shape([[0, 0],
                [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:,
                                   slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
        x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
        x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
        return x_final

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(BaseModule):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
                 p_dropout=0.0):
        super(FFN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
                                      padding=kernel_size//2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
                                      padding=kernel_size//2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class Encoder(BaseModule):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
                 kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
        super(Encoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
                                    n_heads, window_size=window_size, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
                                       filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class TextEncoder(BaseModule):
    def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
                 filter_channels_dp, n_heads, n_layers, kernel_size,
                 p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
        super(TextEncoder, self).__init__()
        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
                                   kernel_size=5, n_layers=3, p_dropout=0.5)

        self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
                               kernel_size, p_dropout, window_size=window_size)

        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
        self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
                                        kernel_size, p_dropout)

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask

--------------------------------------------------------------------------------
/model/tts.py:
--------------------------------------------------------------------------------
import math
import random

import torch

from model import monotonic_align
from model.base import BaseModule
from model.text_encoder import TextEncoder
from model.diffusion import Diffusion
from model.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility


class GradTTSWithEmo(BaseModule):
    def __init__(self, n_vocab=148, n_spks=1, n_emos=5, spk_emb_dim=64,
                 n_enc_channels=192, filter_channels=768, filter_channels_dp=256,
                 n_heads=2, n_enc_layers=6, enc_kernel=3, enc_dropout=0.1, window_size=4,
                 n_feats=80, dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000,
                 use_classifier_free=False, dummy_spk_rate=0.5,
                 **kwargs):
        super(GradTTSWithEmo, self).__init__()
        self.n_vocab = n_vocab
        self.n_spks = n_spks
        self.n_emos = n_emos
        self.spk_emb_dim = spk_emb_dim
        self.n_enc_channels = n_enc_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_enc_layers = n_enc_layers
        self.enc_kernel = enc_kernel
        self.enc_dropout = enc_dropout
        self.window_size = window_size
        self.n_feats = n_feats
        self.dec_dim = dec_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale
        self.use_classifier_free = use_classifier_free

        # if n_spks > 1:
        self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
        self.emo_emb = torch.nn.Embedding(n_emos, spk_emb_dim)
        self.merge_spk_emo = torch.nn.Sequential(
            torch.nn.Linear(spk_emb_dim*2, spk_emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(spk_emb_dim, spk_emb_dim)
        )
        self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels,
                                   filter_channels, filter_channels_dp, n_heads,
                                   n_enc_layers, enc_kernel, enc_dropout, window_size,
                                   spk_emb_dim=spk_emb_dim, n_spks=n_spks)
        self.decoder = Diffusion(n_feats, dec_dim, spk_emb_dim, beta_min, beta_max, pe_scale)

        if self.use_classifier_free:
            self.dummy_xv = torch.nn.Parameter(torch.randn(size=(spk_emb_dim, )))
            self.dummy_rate = dummy_spk_rate
            print(f"Using classifier free with rate {self.dummy_rate}")

    @torch.no_grad()
    def forward(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None,
                length_scale=1.0, classifier_free_guidance=1., force_dur=None):
        """
        Generates mel-spectrogram from text. Returns:
            1. encoder outputs
            2. decoder outputs
            3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
            x_lengths (torch.Tensor): lengths of texts in batch.
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
                Usually, does not provide synthesis improvements.
            spk (torch.Tensor): speaker ids, looked up in `self.spk_emb`.
            emo (torch.Tensor): emotion ids, looked up in `self.emo_emb`.
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.
            classifier_free_guidance (float, optional): guidance weight passed to the decoder;
                only used when the model was built with `use_classifier_free=True`.
            force_dur (torch.Tensor, optional): per-token durations that replace the
                predicted ones (`length_scale` is then ignored).
        """
        x, x_lengths = self.relocate_input([x, x_lengths])

        # Get speaker embedding
        spk = self.spk_emb(spk)
        emo = self.emo_emb(emo)

        if self.use_classifier_free:
            emo = emo / torch.sqrt(torch.sum(emo**2, dim=1, keepdim=True))  # unit norm

        spk_merged = self.merge_spk_emo(torch.cat([spk, emo], dim=-1))

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        if force_dur is not None:
            w_ceil = force_dur.unsqueeze(1)  # [1, 1, Ltext]
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = int(y_lengths.max())
        y_max_length_ = fix_len_compatibility(y_max_length)

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) 104 | 105 | # Align encoded text and get mu_y 106 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 107 | mu_y = mu_y.transpose(1, 2) 108 | encoder_outputs = mu_y[:, :, :y_max_length] 109 | 110 | # Sample latent representation from terminal distribution N(mu_y, I) 111 | z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature 112 | # print(z) 113 | # Generate sample by performing reverse dynamics 114 | 115 | unit_dummy_emo = self.dummy_xv / torch.sqrt(torch.sum(self.dummy_xv**2)) if self.use_classifier_free else None 116 | dummy_spk = self.merge_spk_emo(torch.cat([spk, unit_dummy_emo.unsqueeze(0).repeat(len(spk), 1)], dim=-1)) if self.use_classifier_free else None 117 | 118 | decoder_outputs = self.decoder(z, y_mask, mu_y, n_timesteps, stoc, spk_merged, 119 | use_classifier_free=self.use_classifier_free, 120 | classifier_free_guidance=classifier_free_guidance, 121 | dummy_spk=dummy_spk) 122 | decoder_outputs = decoder_outputs[:, :, :y_max_length] 123 | 124 | return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length] 125 | 126 | def classifier_guidance_decode(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None, 127 | length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'): 128 | x, x_lengths = self.relocate_input([x, x_lengths]) 129 | 130 | # Get speaker embedding 131 | spk = self.spk_emb(spk) 132 | dummy_emo = self.emo_emb(torch.zeros_like(emo).long()) # this is for feeding the text encoder. 
133 |
134 |         spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1))
135 |
136 |         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
137 |         mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
138 |
139 |         w = torch.exp(logw) * x_mask
140 |         # print("w shape is ", w.shape)
141 |         w_ceil = torch.ceil(w) * length_scale
142 |         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
143 |         y_max_length = int(y_lengths.max())
144 |         if classifier_type == 'CNN' or classifier_type == 'CNN-with-time':
145 |             y_max_length = max(y_max_length, 180)  # NOTE: added for CNN classifier
146 |         y_max_length_ = fix_len_compatibility(y_max_length)
147 |
148 |         # Using obtained durations `w` construct alignment map `attn`
149 |         y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
150 |         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
151 |         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
152 |
153 |         # Align encoded text and get mu_y
154 |         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
155 |         mu_y = mu_y.transpose(1, 2)
156 |         encoder_outputs = mu_y[:, :, :y_max_length]
157 |
158 |         # Sample latent representation from terminal distribution N(mu_y, I)
159 |         z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
160 |         # Generate sample by performing reverse dynamics
161 |
162 |         decoder_outputs = self.decoder.classifier_decode(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
163 |                                                          classifier_func, guidance,
164 |                                                          control_emo=emo, classifier_type=classifier_type)
165 |         decoder_outputs = decoder_outputs[:, :, :y_max_length]
166 |         return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
167 |
168 |     def classifier_guidance_decode_DPS(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None,
169 |                                        length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'):
170 |         x, x_lengths = self.relocate_input([x, x_lengths])
171 |
172 |         # Get speaker embedding
173 |         spk = self.spk_emb(spk)
174 |         dummy_emo = self.emo_emb(torch.zeros_like(emo).long())  # this is for feeding the text encoder.
175 |
176 |         spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1))
177 |
178 |         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
179 |         mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
180 |
181 |         w = torch.exp(logw) * x_mask
182 |         w_ceil = torch.ceil(w) * length_scale
183 |         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
184 |         y_max_length = int(y_lengths.max())
185 |         if classifier_type == 'CNN' or classifier_type == 'CNN-with-time':
186 |             y_max_length = max(y_max_length, 180)  # NOTE: added for CNN classifier
187 |         y_max_length_ = fix_len_compatibility(y_max_length)
188 |
189 |         # Using obtained durations `w` construct alignment map `attn`
190 |         y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
191 |         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
192 |         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
193 |
194 |         # Align encoded text and get mu_y
195 |         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
196 |         mu_y = mu_y.transpose(1, 2)
197 |         encoder_outputs = mu_y[:, :, :y_max_length]
198 |
199 |         # Sample latent representation from terminal distribution N(mu_y, I)
200 |         z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
201 |         # Generate sample by performing reverse dynamics
202 |
203 |         decoder_outputs = self.decoder.classifier_decode_DPS(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
204 |                                                              classifier_func, guidance,
205 |                                                              control_emo=emo, classifier_type=classifier_type)
206 |         decoder_outputs = decoder_outputs[:, :, :y_max_length]
207 |         return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
208 |
209 |     def classifier_guidance_decode_two_mixture(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo1=None, emo2=None, emo1_weight=None,
210 |                                                length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'):
211 |         x, x_lengths = self.relocate_input([x, x_lengths])
212 |
213 |         # Get speaker embedding
214 |         spk = self.spk_emb(spk)
215 |         dummy_emo = self.emo_emb(torch.zeros_like(emo1).long())  # this is for feeding the text encoder.
216 |
217 |         spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1))
218 |
219 |         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
220 |         mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
221 |
222 |         w = torch.exp(logw) * x_mask
223 |         w_ceil = torch.ceil(w) * length_scale
224 |         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
225 |         y_max_length = int(y_lengths.max())
226 |         if classifier_type == 'CNN' or classifier_type == 'CNN-with-time':
227 |             y_max_length = max(y_max_length, 180)  # NOTE: added for CNN classifier
228 |         y_max_length_ = fix_len_compatibility(y_max_length)
229 |
230 |         # Using obtained durations `w` construct alignment map `attn`
231 |         y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
232 |         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
233 |         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
234 |
235 |         # Align encoded text and get mu_y
236 |         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
237 |         mu_y = mu_y.transpose(1, 2)
238 |         encoder_outputs = mu_y[:, :, :y_max_length]
239 |
240 |         # Sample latent representation from terminal distribution N(mu_y, I)
241 |         z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
242 |         # Generate sample by performing reverse dynamics
243 |
244 |         decoder_outputs = self.decoder.classifier_decode_mixture(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
245 |                                                                  classifier_func, guidance,
246 |                                                                  control_emo1=emo1, control_emo2=emo2, emo1_weight=emo1_weight, classifier_type=classifier_type)
247 |         decoder_outputs = decoder_outputs[:, :, :y_max_length]
248 |         return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
249 |
250 |     def classifier_guidance_decode_two_mixture_DPS(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo1=None, emo2=None, emo1_weight=None,
251 |                                                    length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'):
252 |         x, x_lengths = self.relocate_input([x, x_lengths])
253 |
254 |         # Get speaker embedding
255 |         spk = self.spk_emb(spk)
256 |         dummy_emo = self.emo_emb(torch.zeros_like(emo1).long())  # this is for feeding the text encoder.
257 |
258 |         spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1))
259 |
260 |         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
261 |         mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
262 |
263 |         w = torch.exp(logw) * x_mask
264 |         w_ceil = torch.ceil(w) * length_scale
265 |         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
266 |         y_max_length = int(y_lengths.max())
267 |         if classifier_type == 'CNN' or classifier_type == 'CNN-with-time':
268 |             y_max_length = max(y_max_length, 180)  # NOTE: added for CNN classifier
269 |         y_max_length_ = fix_len_compatibility(y_max_length)
270 |
271 |         # Using obtained durations `w` construct alignment map `attn`
272 |         y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
273 |         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
274 |         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
275 |
276 |         # Align encoded text and get mu_y
277 |         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
278 |         mu_y = mu_y.transpose(1, 2)
279 |         encoder_outputs = mu_y[:, :, :y_max_length]
280 |
281 |         # Sample latent representation from terminal distribution N(mu_y, I)
282 |         z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
283 |         # Generate sample by performing reverse dynamics
284 |
285 |         decoder_outputs = self.decoder.classifier_decode_mixture_DPS(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
286 |                                                                      classifier_func, guidance,
287 |                                                                      control_emo1=emo1, control_emo2=emo2, emo1_weight=emo1_weight, classifier_type=classifier_type)
288 |         decoder_outputs = decoder_outputs[:, :, :y_max_length]
289 |         return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
290 |
291 |     def compute_loss(self, x, x_lengths, y, y_lengths, spk=None, emo=None, out_size=None, use_gt_dur=False, durs=None):
292 |         """
293 |         Computes 3 losses:
294 |             1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
295 |             2. prior loss: loss between mel-spectrogram and encoder outputs.
296 |             3. diffusion loss: loss between Gaussian noise and its reconstruction by diffusion-based decoder.
297 |
298 |         Args:
299 |             x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
300 |             x_lengths (torch.Tensor): lengths of texts in batch.
301 |             y (torch.Tensor): batch of corresponding mel-spectrograms.
302 |             y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
303 |             out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
304 |                 Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
305 |             use_gt_dur (bool, optional): if True, build the alignment from `durs` instead of running MAS.
306 |             durs (torch.Tensor, optional): ground-truth token durations, used when `use_gt_dur` is True.
307 |         """
308 |         x, x_lengths, y, y_lengths = self.relocate_input([x, x_lengths, y, y_lengths])  # y: B, 80, L
309 |
310 |         spk = self.spk_emb(spk)
311 |         emo = self.emo_emb(emo)  # [B, D]
312 |         if self.use_classifier_free:
313 |             emo = emo / torch.sqrt(torch.sum(emo ** 2, dim=1, keepdim=True))  # unit norm
314 |             use_dummy_per_sample = torch.distributions.Binomial(1, torch.tensor(
315 |                 [self.dummy_rate] * len(emo))).sample().bool()  # [B, ] each entry is True with probability dummy_rate
316 |             emo[use_dummy_per_sample] = (self.dummy_xv / torch.sqrt(
317 |                 torch.sum(self.dummy_xv ** 2)))  # substitute the dummy vector (also unit norm)
318 |
319 |         spk = self.merge_spk_emo(torch.cat([spk, emo], dim=-1))  # [B, D]
320 |
321 |         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
322 |         mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
323 |         y_max_length = y.shape[-1]
324 |
325 |         y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
326 |         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
327 |
328 |         # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
329 |         if use_gt_dur:
330 |             attn = generate_path(durs, attn_mask.squeeze(1)).detach()
331 |         else:
332 |             with torch.no_grad():
333 |                 const = -0.5 * math.log(2 * math.pi) * self.n_feats
334 |                 factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
335 |                 y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
336 |                 y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
337 |                 mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
338 |                 log_prior = y_square - y_mu_double + mu_square + const
339 |                 # this is the log-likelihood of y under the Gaussian N(mu_x, I)
340 |
341 |                 attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
342 |                 attn = attn.detach()
343 |
344 |         # Compute loss between predicted log-scaled durations and those obtained from MAS
345 |         logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
346 |         dur_loss = duration_loss(logw, logw_, x_lengths)
347 |         # print(attn.shape)
348 |
349 |         # Cut a small segment of mel-spectrogram in order to increase batch size
350 |         if not isinstance(out_size, type(None)):
351 |             clip_size = min(out_size, y_max_length)  # when out_size > max length, do not actually perform clipping
352 |             clip_size = -fix_len_compatibility(-clip_size)  # round down so the length stays divisible by the UNet downsampling factor
353 |             max_offset = (y_lengths - clip_size).clamp(0)
354 |             offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
355 |             out_offset = torch.LongTensor([
356 |                 torch.tensor(random.choice(range(start, end)) if end > start else 0)
357 |                 for start, end in offset_ranges
358 |             ]).to(y_lengths)
359 |
360 |             attn_cut = torch.zeros(attn.shape[0], attn.shape[1], clip_size, dtype=attn.dtype, device=attn.device)
361 |             y_cut = torch.zeros(y.shape[0], self.n_feats, clip_size, dtype=y.dtype, device=y.device)
362 |             y_cut_lengths = []
363 |             for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
364 |                 y_cut_length = clip_size + (y_lengths[i] - clip_size).clamp(None, 0)
365 |                 y_cut_lengths.append(y_cut_length)
366 |                 cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
367 |                 y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
368 |                 attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
369 |             y_cut_lengths = torch.LongTensor(y_cut_lengths)
370 |             y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
371 |
372 |             attn = attn_cut  # attn -> [B, text_length, cut_length]; the cut window need not start at the top-left corner
373 |             y = y_cut
374 |             y_mask = y_cut_mask
375 |
376 |         # Align encoded text with mel-spectrogram and get mu_y segment
377 |         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))  # here mu_x is not cut.
378 |         mu_y = mu_y.transpose(1, 2)  # B, 80, cut_length
379 |
380 |         # Compute loss of score-based decoder
381 |         # print(y.shape, y_mask.shape, mu_y.shape)
382 |         diff_loss, xt = self.decoder.compute_loss(y, y_mask, mu_y, spk)
383 |
384 |         # Compute loss between aligned encoder outputs and mel-spectrogram
385 |         prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
386 |         prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
387 |
388 |         return dur_loss, prior_loss, diff_loss
389 |
390 |
391 | class GradTTSXvector(BaseModule):
392 |     def __init__(self, n_vocab=148, spk_emb_dim=64,
393 |                  n_enc_channels=192, filter_channels=768, filter_channels_dp=256,
394 |                  n_heads=2, n_enc_layers=6, enc_kernel=3, enc_dropout=0.1, window_size=4,
395 |                  n_feats=80, dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000, xvector_dim=512, **kwargs):
396 |         super(GradTTSXvector, self).__init__()
397 |         self.n_vocab = n_vocab
398 |         # self.n_spks = n_spks
399 |         self.spk_emb_dim = spk_emb_dim
400 |         self.n_enc_channels = n_enc_channels
401 |         self.filter_channels = filter_channels
402 |         self.filter_channels_dp = filter_channels_dp
403 |         self.n_heads = n_heads
404 |         self.n_enc_layers = n_enc_layers
405 |         self.enc_kernel = enc_kernel
406 |         self.enc_dropout = enc_dropout
407 |         self.window_size = window_size
408 |         self.n_feats = n_feats
409 |         self.dec_dim = dec_dim
410 |         self.beta_min = beta_min
411 |         self.beta_max = beta_max
412 |         self.pe_scale = pe_scale
413 |
414 |         self.xvector_proj = torch.nn.Linear(xvector_dim, spk_emb_dim)
415 |         self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels,
416 |                                    filter_channels, filter_channels_dp, n_heads,
417 |                                    n_enc_layers, enc_kernel, enc_dropout, window_size,
418 |                                    spk_emb_dim=spk_emb_dim, n_spks=999)  # NOTE: the exact `n_spks` value is not important here
419 |         self.decoder = Diffusion(n_feats, dec_dim, spk_emb_dim, beta_min, beta_max, pe_scale)
420 |
421 |     @torch.no_grad()
422 |     def forward(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, length_scale=1.0):
423 |         """
424 |         Generates mel-spectrogram from text. Returns:
425 |             1. encoder outputs
426 |             2. decoder outputs
427 |             3. generated alignment
428 |
429 |         Args:
430 |             x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
431 |             x_lengths (torch.Tensor): lengths of texts in batch.
432 |             n_timesteps (int): number of steps to use for reverse diffusion in decoder.
433 |             temperature (float, optional): controls variance of terminal distribution.
434 |             stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
435 |                 Usually, does not provide synthesis improvements.
436 |             length_scale (float, optional): controls speech pace.
437 |                 Increase value to slow down generated speech and vice versa.
438 |             spk (torch.Tensor): batch of x-vectors (not speaker ids); projected to `spk_emb_dim` internally.
439 |         """
440 |         x, x_lengths = self.relocate_input([x, x_lengths])
441 |
442 |         spk = self.xvector_proj(spk)  # NOTE: use x-vectors instead of speaker embedding
443 |
444 |         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
445 |         mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
446 |
447 |         w = torch.exp(logw) * x_mask
448 |         w_ceil = torch.ceil(w) * length_scale
449 |         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
450 |         y_max_length = int(y_lengths.max())
451 |         y_max_length_ = fix_len_compatibility(y_max_length)
452 |
453 |         # Using obtained durations `w` construct alignment map `attn`
454 |         y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
455 |         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
456 |         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
457 |
458 |         # Align encoded text and get mu_y
459 |         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
460 |         mu_y = mu_y.transpose(1, 2)
461 |         encoder_outputs = mu_y[:, :, :y_max_length]
462 |
463 |         # Sample latent representation from terminal distribution N(mu_y, I)
464 |         z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
465 |         # Generate sample by performing reverse dynamics
466 |         decoder_outputs = self.decoder(z, y_mask, mu_y, n_timesteps, stoc, spk)
467 |         decoder_outputs = decoder_outputs[:, :, :y_max_length]
468 |
469 |         return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
470 |
471 |     def compute_loss(self, x, x_lengths, y, y_lengths, spk=None, out_size=None, use_gt_dur=False, durs=None):
472 |         """
473 |         Computes 3 losses:
474 |             1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
475 |             2. prior loss: loss between mel-spectrogram and encoder outputs.
476 |             3. diffusion loss: loss between Gaussian noise and its reconstruction by diffusion-based decoder.
477 |
478 |         Args:
479 |             x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
480 |             x_lengths (torch.Tensor): lengths of texts in batch.
481 |             y (torch.Tensor): batch of corresponding mel-spectrograms.
482 |             y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
483 |             out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
484 |                 Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
485 |             spk (torch.Tensor): batch of x-vectors.
486 |             use_gt_dur (bool, optional): if True, build the alignment from `durs` instead of running MAS.
487 |             durs (torch.Tensor, optional): ground-truth token durations, used when `use_gt_dur` is True.
488 |         """
489 |         x, x_lengths, y, y_lengths = self.relocate_input([x, x_lengths, y, y_lengths])
490 |
491 |         spk = self.xvector_proj(spk)  # NOTE: use x-vectors instead of speaker embedding
492 |
493 |         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
494 |         mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
495 |         y_max_length = y.shape[-1]
496 |
497 |         y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
498 |         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
499 |
500 |         # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
501 |         if not use_gt_dur:
502 |             with torch.no_grad():
503 |                 const = -0.5 * math.log(2 * math.pi) * self.n_feats
504 |                 factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
505 |                 y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
506 |                 y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
507 |                 mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
508 |                 log_prior = y_square - y_mu_double + mu_square + const
509 |
510 |                 attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
511 |                 attn = attn.detach()
512 |         else:
513 |             with torch.no_grad():
514 |                 attn = generate_path(durs, attn_mask.squeeze(1)).detach()
515 |
516 |         # Compute loss between predicted log-scaled durations and those obtained from MAS
517 |         logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
518 |         dur_loss = duration_loss(logw, logw_, x_lengths)
519 |
520 |         # print(attn.shape)
521 |
522 |         # Cut a small segment of mel-spectrogram in order to increase batch size
523 |         if not isinstance(out_size, type(None)):
524 |             max_offset = (y_lengths - out_size).clamp(0)
525 |             offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
526 |             out_offset = torch.LongTensor([
527 |                 torch.tensor(random.choice(range(start, end)) if end > start else 0)
528 |                 for start, end in offset_ranges
529 |             ]).to(y_lengths)
530 |
531 |             attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
532 |             y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
533 |             y_cut_lengths = []
534 |             for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
535 |                 y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
536 |                 y_cut_lengths.append(y_cut_length)
537 |                 cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
538 |                 y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
539 |                 attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
540 |             y_cut_lengths = torch.LongTensor(y_cut_lengths)
541 |             y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
542 |
543 |             attn = attn_cut
544 |             y = y_cut
545 |             y_mask = y_cut_mask
546 |
547 |         # Align encoded text with mel-spectrogram and get mu_y segment
548 |         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
549 |         mu_y = mu_y.transpose(1, 2)
550 |
551 |         # Compute loss of score-based decoder
552 |         diff_loss, xt = self.decoder.compute_loss(y, y_mask, mu_y, spk)
553 |
554 |         # Compute loss between aligned encoder outputs and mel-spectrogram
555 |         prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
556 |         prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
557 |
558 |         return dur_loss, prior_loss, diff_loss
559 |
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jaywalnut310/glow-tts """
2 |
3 | import torch
4 |
5 |
6 | def sequence_mask(length, max_length=None):
7 |     if max_length is None:
8 |         max_length = length.max()
9 |     x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
10 |     return x.unsqueeze(0) < length.unsqueeze(1)
11 |
12 |
13 | def fix_len_compatibility(length, num_downsamplings_in_unet=2):
14 |     while True:
15 |         if length % (2**num_downsamplings_in_unet) == 0:
16 |             return length
17 |         length += 1
18 |
19 |
20 | def convert_pad_shape(pad_shape):
21 |     l = pad_shape[::-1]
22 |     pad_shape = [item for sublist in l for item in sublist]
23 |     return pad_shape
24 |
25 |
26 | def generate_path(duration, mask):
27 |     device = duration.device
28 |
29 |     b, t_x, t_y = mask.shape
30 |     cum_duration = torch.cumsum(duration, 1)
31 |     path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
32 |
33 |     cum_duration_flat = cum_duration.view(b * t_x)
34 |     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
35 |     path = path.view(b, t_x, t_y)
36 |     path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
37 |                                           [1, 0], [0, 0]]))[:, :-1]
38 |     path = path * mask
39 |     return path
40 |
41 |
42 | def duration_loss(logw, logw_, lengths):
43 |     loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
44 |     return loss
45 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6 | from xutils import init_weights, get_padding
7 |
8 | LRELU_SLOPE = 0.1
9 |
10 |
11 | class ResBlock1(torch.nn.Module):
12 |     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
13 |         super(ResBlock1, self).__init__()
14 |         self.h = h
15 |         self.convs1 = nn.ModuleList([
16 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
17 |                                padding=get_padding(kernel_size, dilation[0]))),
18 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
19 |                                padding=get_padding(kernel_size, dilation[1]))),
20 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
21 |                                padding=get_padding(kernel_size, dilation[2])))
22 |         ])
23 |         self.convs1.apply(init_weights)
24 |
25 |         self.convs2 = nn.ModuleList([
26 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
27 |                                padding=get_padding(kernel_size, 1))),
28 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
29 |                                padding=get_padding(kernel_size, 1))),
30 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
31 |                                padding=get_padding(kernel_size, 1)))
32 |         ])
33 |         self.convs2.apply(init_weights)
34 |
35 |     def forward(self, x):
36 |         for c1, c2 in zip(self.convs1, self.convs2):
37 |             xt = F.leaky_relu(x, LRELU_SLOPE)
38 |             xt = c1(xt)
39 |             xt = F.leaky_relu(xt, LRELU_SLOPE)
40 |             xt = c2(xt)
41 |             x = xt + x
42 |         return x
43 |
44 |     def remove_weight_norm(self):
45 |         for l in self.convs1:
46 |             remove_weight_norm(l)
47 |         for l in self.convs2:
48 |             remove_weight_norm(l)
49 |
50 |
51 | class ResBlock2(torch.nn.Module):
52 |     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
53 |         super(ResBlock2, self).__init__()
54 |         self.h = h
55 |         self.convs = nn.ModuleList([
56 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
57 |                                padding=get_padding(kernel_size, dilation[0]))),
58 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
59 |                                padding=get_padding(kernel_size, dilation[1])))
60 |         ])
61 |         self.convs.apply(init_weights)
62 |
63 |     def forward(self, x):
64 |         for c in self.convs:
65 |             xt = F.leaky_relu(x, LRELU_SLOPE)
66 |             xt = c(xt)
67 |             x = xt + x
68 |         return x
69 |
70 |     def remove_weight_norm(self):
71 |         for l in self.convs:
72 |             remove_weight_norm(l)
73 |
74 |
75 | class Generator(torch.nn.Module):
76 |     def __init__(self, h):
77 |         super(Generator, self).__init__()
78 |         self.h = h
79 |         self.num_kernels = len(h.resblock_kernel_sizes)
80 |         self.num_upsamples = len(h.upsample_rates)
81 |         self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
82 |         resblock = ResBlock1 if h.resblock == '1' else ResBlock2
83 |
84 |         self.ups = nn.ModuleList()
85 |         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
86 |             self.ups.append(weight_norm(
87 |                 ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
88 |                                 k, u, padding=(k-u)//2)))
89 |
90 |         self.resblocks = nn.ModuleList()
91 |         for i in range(len(self.ups)):
92 |             ch = h.upsample_initial_channel//(2**(i+1))
93 |             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
94 |                 self.resblocks.append(resblock(h, ch, k, d))
95 |
96 |         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
97 |         self.ups.apply(init_weights)
98 |         self.conv_post.apply(init_weights)
99 |
100 |     def forward(self, x):
101 |         x = self.conv_pre(x)
102 |         for i in range(self.num_upsamples):
103 |             x = F.leaky_relu(x, LRELU_SLOPE)
104 |             x = self.ups[i](x)
105 |             xs = None
106 |             for j in range(self.num_kernels):
107 |                 if xs is None:
108 |                     xs = self.resblocks[i*self.num_kernels+j](x)
109 |                 else:
110 |                     xs += self.resblocks[i*self.num_kernels+j](x)
111 |             x = xs / self.num_kernels
112 |         x = F.leaky_relu(x)
113 |         x = self.conv_post(x)
114 |         x = torch.tanh(x)
115 |
116 |         return x
117 |
118 |     def remove_weight_norm(self):
119 |         print('Removing weight norm...')
120 |         for l in self.ups:
121 |             remove_weight_norm(l)
122 |         for l in self.resblocks:
123 |             l.remove_weight_norm()
124 |         remove_weight_norm(self.conv_pre)
125 |         remove_weight_norm(self.conv_post)
126 |
127 |
128 | class DiscriminatorP(torch.nn.Module):
129 |     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
130 |         super(DiscriminatorP, self).__init__()
131 |         self.period = period
132 |         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
133 |         self.convs = nn.ModuleList([
134 |             norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
135 |             norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
136 |             norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
137 |             norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
138 |             norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
139 |         ])
140 |         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
141 |
142 |     def forward(self, x):
143 |         fmap = []
144 |
145 |         # 1d to 2d
146 |         b, c, t = x.shape
147 |         if t % self.period != 0:  # pad first
148 |             n_pad = self.period - (t % self.period)
149 |             x = F.pad(x, (0, n_pad), "reflect")
150 |             t = t + n_pad
151 |         x = x.view(b, c, t // self.period, self.period)
152 |
153 |         for l in self.convs:
154 |             x = l(x)
155 |             x = F.leaky_relu(x, LRELU_SLOPE)
156 |             fmap.append(x)
157 |         x = self.conv_post(x)
158 |         fmap.append(x)
159 |         x = torch.flatten(x, 1, -1)
160 |
161 |         return x, fmap
162 |
163 |
164 | class MultiPeriodDiscriminator(torch.nn.Module):
165 |     def __init__(self):
166 |         super(MultiPeriodDiscriminator, self).__init__()
167 |         self.discriminators = nn.ModuleList([
168 |             DiscriminatorP(2),
169 |             DiscriminatorP(3),
170 |             DiscriminatorP(5),
171 |             DiscriminatorP(7),
172 |             DiscriminatorP(11),
173 |         ])
174 |
175 |     def forward(self, y, y_hat):
176 |         y_d_rs = []
177 |         y_d_gs = []
178 |         fmap_rs = []
179 |         fmap_gs = []
180 |         for i, d in enumerate(self.discriminators):
181 |             y_d_r, fmap_r = d(y)
182 |             y_d_g, fmap_g = d(y_hat)
183 |             y_d_rs.append(y_d_r)
184 |             fmap_rs.append(fmap_r)
185 |             y_d_gs.append(y_d_g)
186 |             fmap_gs.append(fmap_g)
187 |
188 |         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
189 |
190 |
191 | class DiscriminatorS(torch.nn.Module):
192 |     def __init__(self, use_spectral_norm=False):
193 |         super(DiscriminatorS, self).__init__()
194 |         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
195 |         self.convs = nn.ModuleList([
196 |             norm_f(Conv1d(1, 128, 15, 1, padding=7)),
197 |             norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
198 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 199 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 200 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 201 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 202 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 203 | ]) 204 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 205 | 206 | def forward(self, x): 207 | fmap = [] 208 | for l in self.convs: 209 | x = l(x) 210 | x = F.leaky_relu(x, LRELU_SLOPE) 211 | fmap.append(x) 212 | x = self.conv_post(x) 213 | fmap.append(x) 214 | x = torch.flatten(x, 1, -1) 215 | 216 | return x, fmap 217 | 218 | 219 | class MultiScaleDiscriminator(torch.nn.Module): 220 | def __init__(self): 221 | super(MultiScaleDiscriminator, self).__init__() 222 | self.discriminators = nn.ModuleList([ 223 | DiscriminatorS(use_spectral_norm=True), 224 | DiscriminatorS(), 225 | DiscriminatorS(), 226 | ]) 227 | self.meanpools = nn.ModuleList([ 228 | AvgPool1d(4, 2, padding=2), 229 | AvgPool1d(4, 2, padding=2) 230 | ]) 231 | 232 | def forward(self, y, y_hat): 233 | y_d_rs = [] 234 | y_d_gs = [] 235 | fmap_rs = [] 236 | fmap_gs = [] 237 | for i, d in enumerate(self.discriminators): 238 | if i != 0: 239 | y = self.meanpools[i-1](y) 240 | y_hat = self.meanpools[i-1](y_hat) 241 | y_d_r, fmap_r = d(y) 242 | y_d_g, fmap_g = d(y_hat) 243 | y_d_rs.append(y_d_r) 244 | fmap_rs.append(fmap_r) 245 | y_d_gs.append(y_d_g) 246 | fmap_gs.append(fmap_g) 247 | 248 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 249 | 250 | 251 | def feature_loss(fmap_r, fmap_g): 252 | loss = 0 253 | for dr, dg in zip(fmap_r, fmap_g): 254 | for rl, gl in zip(dr, dg): 255 | loss += torch.mean(torch.abs(rl - gl)) 256 | 257 | return loss*2 258 | 259 | 260 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 261 | loss = 0 262 | r_losses = [] 263 | g_losses = [] 264 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 265 | r_loss = torch.mean((1-dr)**2) 266 | g_loss = 
torch.mean(dg**2) 267 | loss += (r_loss + g_loss) 268 | r_losses.append(r_loss.item()) 269 | g_losses.append(g_loss.item()) 270 | 271 | return loss, r_losses, g_losses 272 | 273 | 274 | def generator_loss(disc_outputs): 275 | loss = 0 276 | gen_losses = [] 277 | for dg in disc_outputs: 278 | l = torch.mean((1-dg)**2) 279 | gen_losses.append(l) 280 | loss += l 281 | 282 | return loss, gen_losses 283 | 284 | -------------------------------------------------------------------------------- /text/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/text/.DS_Store -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | CMUdict 2 | ------- 3 | 4 | CMUdict (the Carnegie Mellon Pronouncing Dictionary) is a free 5 | pronouncing dictionary of English, suitable for uses in speech 6 | technology and is maintained by the Speech Group in the School of 7 | Computer Science at Carnegie Mellon University. 8 | 9 | The Carnegie Mellon Speech Group does not guarantee the accuracy of 10 | this dictionary, nor its suitability for any specific purpose. In 11 | fact, we expect a number of errors, omissions and inconsistencies to 12 | remain in the dictionary. We intend to continually update the 13 | dictionary by correction existing entries and by adding new ones. From 14 | time to time a new major version will be released. 15 | 16 | We welcome input from users: Please send email to Alex Rudnicky 17 | (air+cmudict@cs.cmu.edu). 18 | 19 | The Carnegie Mellon Pronouncing Dictionary, in its current and 20 | previous versions is Copyright (C) 1993-2014 by Carnegie Mellon 21 | University. Use of this dictionary for any research or commercial 22 | purpose is completely unrestricted. 
If you make use of or 23 | redistribute this material we request that you acknowledge its 24 | origin in your descriptions. 25 | 26 | If you add words to or correct words in your version of this 27 | dictionary, we would appreciate it if you could send these additions 28 | and corrections to us (air+cmudict@cs.cmu.edu) for consideration in a 29 | subsequent version. All submissions will be reviewed and approved by 30 | the current maintainer, Alex Rudnicky at Carnegie Mellon. 31 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | from text import cleaners 5 | from text.symbols import symbols 6 | import torch 7 | 8 | 9 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 10 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 11 | 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def get_arpabet(word, dictionary): 16 | word_arpabet = dictionary.lookup(word) 17 | if word_arpabet is not None: 18 | return "{" + word_arpabet[0] + "}" 19 | else: 20 | return word 21 | 22 | 23 | def text_to_sequence(text, cleaner_names=["kazakh_cleaners"], dictionary=None): 24 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 25 | 26 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 27 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
28 | 29 | Args: 30 | text: string to convert to a sequence 31 | cleaner_names: names of the cleaner functions to run the text through 32 | dictionary: arpabet class with arpabet dictionary 33 | 34 | Returns: 35 | List of integers corresponding to the symbols in the text 36 | ''' 37 | sequence = [] 38 | space = _symbols_to_sequence(' ') 39 | # Check for curly braces and treat their contents as ARPAbet: 40 | while len(text): 41 | m = _curly_re.match(text) 42 | if not m: 43 | clean_text = _clean_text(text, cleaner_names) 44 | #clean_text = text 45 | if dictionary is not None: 46 | clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")] 47 | for i in range(len(clean_text)): 48 | t = clean_text[i] 49 | if t.startswith("{"): 50 | sequence += _arpabet_to_sequence(t[1:-1]) 51 | else: 52 | sequence += _symbols_to_sequence(t) 53 | sequence += space 54 | else: 55 | sequence += _symbols_to_sequence(clean_text) 56 | break 57 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 58 | sequence += _arpabet_to_sequence(m.group(2)) 59 | text = m.group(3) 60 | 61 | # remove trailing space 62 | if dictionary is not None: 63 | sequence = sequence[:-1] if sequence[-1] == space[0] else sequence 64 | return sequence 65 | 66 | 67 | def sequence_to_text(sequence): 68 | '''Converts a sequence of IDs back to a string''' 69 | result = '' 70 | for symbol_id in sequence: 71 | if symbol_id in _id_to_symbol: 72 | s = _id_to_symbol[symbol_id] 73 | # Enclose ARPAbet back in curly braces: 74 | if len(s) > 1 and s[0] == '@': 75 | s = '{%s}' % s[1:] 76 | result += s 77 | return result.replace('}{', ' ') 78 | 79 | def convert_text(string): 80 | text_norm = text_to_sequence(string.lower()) 81 | text_norm = torch.IntTensor(text_norm) 82 | text_len = torch.IntTensor([text_norm.size(0)]) 83 | text_padded = torch.LongTensor(1, len(text_norm)) 84 | text_padded.zero_() 85 | text_padded[0, :text_norm.size(0)] = text_norm 86 | return text_padded, text_len 87 | 88 | def 
_clean_text(text, cleaner_names): 89 | for name in cleaner_names: 90 | cleaner = getattr(cleaners, name, None) 91 | if not cleaner: 92 | raise Exception('Unknown cleaner: %s' % name) 93 | text = cleaner(text) 94 | return text 95 | 96 | 97 | def _symbols_to_sequence(symbols): 98 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 99 | 100 | 101 | def _arpabet_to_sequence(text): 102 | return _symbols_to_sequence(['@' + s for s in text.split()]) 103 | 104 | 105 | def _should_keep_symbol(s): 106 | return s in _symbol_to_id and s != '_' and s != '~' 107 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | from unidecode import unidecode 5 | 6 | 7 | _whitespace_re = re.compile(r'\s+') 8 | 9 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 10 | ('mrs', 'misess'), 11 | ('mr', 'mister'), 12 | ('dr', 'doctor'), 13 | ('st', 'saint'), 14 | ('co', 'company'), 15 | ('jr', 'junior'), 16 | ('maj', 'major'), 17 | ('gen', 'general'), 18 | ('drs', 'doctors'), 19 | ('rev', 'reverend'), 20 | ('lt', 'lieutenant'), 21 | ('hon', 'honorable'), 22 | ('sgt', 'sergeant'), 23 | ('capt', 'captain'), 24 | ('esq', 'esquire'), 25 | ('ltd', 'limited'), 26 | ('col', 'colonel'), 27 | ('ft', 'fort'), 28 | ]] 29 | 30 | 31 | def expand_abbreviations(text): 32 | for regex, replacement in _abbreviations: 33 | text = re.sub(regex, replacement, text) 34 | return text 35 | 36 | 37 | def lowercase(text): 38 | return text.lower() 39 | 40 | 41 | def collapse_whitespace(text): 42 | return re.sub(_whitespace_re, ' ', text) 43 | 44 | 45 | def convert_to_ascii(text): 46 | return unidecode(text) 47 | 48 | 49 | def basic_cleaners(text): 50 | text = lowercase(text) 51 | text = collapse_whitespace(text) 52 | return text 53 | 54 | 55 | def transliteration_cleaners(text):
56 | text = convert_to_ascii(text) 57 | text = lowercase(text) 58 | text = collapse_whitespace(text) 59 | return text 60 | 61 | def replace_english_words(text): 62 | text = text.replace("bluetooth не usb", "блютуз не юэсби").replace("mega silk way", "мега силк уэй") 63 | return text 64 | 65 | def kazakh_cleaners(text): 66 | # text = convert_to_ascii(text) 67 | text = lowercase(text) 68 | # text = expand_numbers(text) 69 | text = expand_abbreviations(text) 70 | text = replace_english_words(text) 71 | text = collapse_whitespace(text) 72 | return text.replace("c", "с").strip() 73 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | def __init__(self, file_or_path, keep_ambiguous=True): 21 | if isinstance(file_or_path, str): 22 | with open(file_or_path, encoding='latin-1') as f: 23 | entries = _parse_cmudict(f) 24 | else: 25 | entries = _parse_cmudict(file_or_path) 26 | if not keep_ambiguous: 27 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 28 | self._entries = entries 29 | 30 | def __len__(self): 31 | return len(self._entries) 32 | 33 | def lookup(self, word): 34 | 
return self._entries.get(word.upper()) 35 | 36 | 37 | _alt_re = re.compile(r'\([0-9]+\)') 38 | 39 | 40 | def _parse_cmudict(file): 41 | cmudict = {} 42 | for line in file: 43 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 44 | parts = line.split('  ') 45 | word = re.sub(_alt_re, '', parts[0]) 46 | pronunciation = _get_pronunciation(parts[1]) 47 | if pronunciation: 48 | if word in cmudict: 49 | cmudict[word].append(pronunciation) 50 | else: 51 | cmudict[word] = [pronunciation] 52 | return cmudict 53 | 54 | 55 | def _get_pronunciation(s): 56 | parts = s.strip().split(' ') 57 | for part in parts: 58 | if part not in _valid_symbol_set: 59 | return None 60 | return ' '.join(parts) 61 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | from text import cmudict 4 | 5 | _pad = '_' 6 | _punctuation = '!\'(),.:;? 
' 7 | _special = '-' 8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzАӘБВГҒДЕЁЖЗИЙКҚЛМНҢОӨПРСТУҰҮФХҺЦЧШЩЪЫІЬЭЮЯаәбвгғдеёжзийкқлмнңоөпрстуұүфхһцчшщъыіьэюя' 9 | _numbers = '0123456789' 10 | 11 | # Prepend "@" to ARPAbet symbols to ensure uniqueness: 12 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 13 | 14 | # Export all symbols: 15 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet + list(_numbers) 16 | -------------------------------------------------------------------------------- /train_EMA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from copy import deepcopy 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torch.utils.tensorboard import SummaryWriter 8 | import data_collate 9 | import data_loader 10 | from utils_data import plot_tensor, save_plot 11 | from model.utils import fix_len_compatibility 12 | from text.symbols import symbols 13 | import utils_data as utils 14 | 15 | 16 | class ModelEmaV2(torch.nn.Module): 17 | def __init__(self, model, decay=0.9999, device=None): 18 | super(ModelEmaV2, self).__init__() 19 | self.model_state_dict = deepcopy(model.state_dict()) 20 | self.decay = decay 21 | self.device = device # perform ema on different device from model if set 22 | 23 | def _update(self, model, update_fn): 24 | with torch.no_grad(): 25 | for ema_v, model_v in zip(self.model_state_dict.values(), model.state_dict().values()): 26 | if self.device is not None: 27 | model_v = model_v.to(device=self.device) 28 | ema_v.copy_(update_fn(ema_v, model_v)) 29 | 30 | def update(self, model): 31 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. 
- self.decay) * m) 32 | 33 | def set(self, model): 34 | self._update(model, update_fn=lambda e, m: m) 35 | 36 | def state_dict(self, destination=None, prefix='', keep_vars=False): 37 | return self.model_state_dict 38 | 39 | 40 | if __name__ == "__main__": 41 | hps = utils.get_hparams() 42 | logger_text = utils.get_logger(hps.model_dir) 43 | logger_text.info(hps) 44 | 45 | out_size = fix_len_compatibility(2 * hps.data.sampling_rate // hps.data.hop_length) # NOTE: 2-sec of mel-spec 46 | 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | torch.manual_seed(hps.train.seed) 49 | np.random.seed(hps.train.seed) 50 | 51 | print('Initializing logger...') 52 | log_dir = hps.model_dir 53 | logger = SummaryWriter(log_dir=log_dir) 54 | 55 | train_dataset, collate, model = utils.get_correct_class(hps) 56 | test_dataset, _, _ = utils.get_correct_class(hps, train=False) 57 | 58 | print('Initializing data loaders...') 59 | 60 | batch_collate = collate 61 | loader = DataLoader(dataset=train_dataset, batch_size=hps.train.batch_size, 62 | collate_fn=batch_collate, drop_last=True, 63 | num_workers=4, shuffle=False) # NOTE: if on server, worker can be 4 64 | 65 | print('Initializing model...') 66 | model = model(**hps.model).to(device) 67 | print('Number of encoder + duration predictor parameters: %.2fm' % (model.encoder.nparams / 1e6)) 68 | print('Number of decoder parameters: %.2fm' % (model.decoder.nparams / 1e6)) 69 | print('Total parameters: %.2fm' % (model.nparams / 1e6)) 70 | 71 | use_gt_dur = getattr(hps.train, "use_gt_dur", False) 72 | if use_gt_dur: 73 | print("++++++++++++++> Using ground truth duration for training") 74 | 75 | print('Initializing optimizer...') 76 | optimizer = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate) 77 | 78 | print('Logging test batch...') 79 | test_batch = test_dataset.sample_test_batch(size=hps.train.test_size) 80 | for i, item in enumerate(test_batch): 81 | mel = item['mel'] 82 | 
logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()), 83 | global_step=0, dataformats='HWC') 84 | save_plot(mel.squeeze(), f'{log_dir}/original_{i}.png') 85 | 86 | try: 87 | model, optimizer, learning_rate, epoch_logged = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "grad_*.pt"), model, optimizer) 88 | epoch_start = epoch_logged + 1 89 | print(f"Loaded checkpoint from {epoch_logged} epoch, resuming training.") 90 | global_step = epoch_logged * (len(train_dataset) // hps.train.batch_size) 91 | except Exception: 92 | print(f"Cannot find trained checkpoint, begin to train from scratch") 93 | epoch_start = 1 94 | global_step = 0 95 | learning_rate = hps.train.learning_rate 96 | 97 | ema_model = ModelEmaV2(model, decay=0.9999) # It's necessary that we put this after loading model. 98 | 99 | print('Start training...') 100 | used_items = set() 101 | iteration = global_step 102 | for epoch in range(epoch_start, hps.train.n_epochs + 1): 103 | model.train() 104 | dur_losses = [] 105 | prior_losses = [] 106 | diff_losses = [] 107 | with tqdm(loader, total=len(train_dataset) // hps.train.batch_size) as progress_bar: 108 | for batch_idx, batch in enumerate(progress_bar): 109 | model.zero_grad() 110 | x, x_lengths = batch['text_padded'].to(device), \ 111 | batch['input_lengths'].to(device) 112 | y, y_lengths = batch['mel_padded'].to(device), \ 113 | batch['output_lengths'].to(device) 114 | if hps.xvector: 115 | spk = batch['xvector'].to(device) 116 | else: 117 | spk = batch['spk_ids'].to(torch.long).to(device) 118 | emo = batch['emo_ids'].to(torch.long).to(device) 119 | 120 | dur_loss, prior_loss, diff_loss = model.compute_loss(x, x_lengths, 121 | y, y_lengths, 122 | spk=spk, 123 | emo=emo, 124 | out_size=out_size, 125 | use_gt_dur=use_gt_dur, 126 | durs=batch['dur_padded'].to(device) if use_gt_dur else None) 127 | loss = sum([dur_loss, prior_loss, diff_loss]) 128 | loss.backward() 129 | 130 | enc_grad_norm =
torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 131 | max_norm=1) 132 | dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), 133 | max_norm=1) 134 | optimizer.step() 135 | ema_model.update(model) 136 | 137 | logger.add_scalar('training/duration_loss', dur_loss.item(), 138 | global_step=iteration) 139 | logger.add_scalar('training/prior_loss', prior_loss.item(), 140 | global_step=iteration) 141 | logger.add_scalar('training/diffusion_loss', diff_loss.item(), 142 | global_step=iteration) 143 | logger.add_scalar('training/encoder_grad_norm', enc_grad_norm, 144 | global_step=iteration) 145 | logger.add_scalar('training/decoder_grad_norm', dec_grad_norm, 146 | global_step=iteration) 147 | 148 | dur_losses.append(dur_loss.item()) 149 | prior_losses.append(prior_loss.item()) 150 | diff_losses.append(diff_loss.item()) 151 | 152 | if batch_idx % 5 == 0: 153 | msg = f'Epoch: {epoch}, iteration: {iteration} | dur_loss: {dur_loss.item()}, prior_loss: {prior_loss.item()}, diff_loss: {diff_loss.item()}' 154 | progress_bar.set_description(msg) 155 | 156 | iteration += 1 157 | 158 | log_msg = 'Epoch %d: duration loss = %.3f ' % (epoch, float(np.mean(dur_losses))) 159 | log_msg += '| prior loss = %.3f ' % np.mean(prior_losses) 160 | log_msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses) 161 | with open(f'{log_dir}/train.log', 'a') as f: 162 | f.write(log_msg) 163 | 164 | if epoch % hps.train.save_every > 0: 165 | continue 166 | 167 | model.eval() 168 | print('Synthesis...') 169 | 170 | with torch.no_grad(): 171 | for i, item in enumerate(test_batch): 172 | if item['utt'] + "/truth" not in used_items: 173 | used_items.add(item['utt'] + "/truth") 174 | x = item['text'].to(torch.long).unsqueeze(0).to(device) 175 | if not hps.xvector: 176 | spk = item['spk_ids'] 177 | spk = torch.LongTensor([spk]).to(device) 178 | else: 179 | spk = item["xvector"] 180 | spk = spk.unsqueeze(0).to(device) 181 | emo = item['emo_ids'] 182 | emo = 
torch.LongTensor([emo]).to(device) 183 | 184 | x_lengths = torch.LongTensor([x.shape[-1]]).to(device) 185 | 186 | y_enc, y_dec, attn = model(x, x_lengths, spk=spk, emo=emo, n_timesteps=10) 187 | logger.add_image(f'image_{i}/generated_enc', 188 | plot_tensor(y_enc.squeeze().cpu()), 189 | global_step=iteration, dataformats='HWC') 190 | logger.add_image(f'image_{i}/generated_dec', 191 | plot_tensor(y_dec.squeeze().cpu()), 192 | global_step=iteration, dataformats='HWC') 193 | logger.add_image(f'image_{i}/alignment', 194 | plot_tensor(attn.squeeze().cpu()), 195 | global_step=iteration, dataformats='HWC') 196 | save_plot(y_enc.squeeze().cpu(), 197 | f'{log_dir}/generated_enc_{i}.png') 198 | save_plot(y_dec.squeeze().cpu(), 199 | f'{log_dir}/generated_dec_{i}.png') 200 | save_plot(attn.squeeze().cpu(), 201 | f'{log_dir}/alignment_{i}.png') 202 | 203 | ckpt = model.state_dict() 204 | 205 | utils.save_checkpoint(ema_model, optimizer, learning_rate, epoch, checkpoint_path=f"{log_dir}/EMA_grad_{epoch}.pt") 206 | utils.save_checkpoint(model, optimizer, learning_rate, epoch, checkpoint_path=f"{log_dir}/grad_{epoch}.pt") 207 | 208 | -------------------------------------------------------------------------------- /utils_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import glob 4 | import logging 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import data_loader as loaders 8 | import data_collate as collates 9 | import json 10 | from model import GradTTSXvector, GradTTSWithEmo 11 | import torch 12 | 13 | 14 | def intersperse(lst, item): 15 | # Adds blank symbol 16 | result = [item] * (len(lst) * 2 + 1) 17 | result[1::2] = lst 18 | return result 19 | 20 | 21 | def parse_filelist(filelist_path, split_char="|"): 22 | with open(filelist_path, encoding='utf-8') as f: 23 | filepaths_and_text = [line.strip().split(split_char) for line in f] 24 | return filepaths_and_text 25 | 26 | 27 | def 
latest_checkpoint_path(dir_path, regex="grad_*.pt"): 28 | f_list = glob.glob(os.path.join(dir_path, regex)) 29 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 30 | x = f_list[-1] 31 | return x 32 | 33 | def load_checkpoint(checkpoint_path, model, optimizer=None): 34 | assert os.path.isfile(checkpoint_path) 35 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 36 | iteration = 1 37 | if 'iteration' in checkpoint_dict.keys(): 38 | iteration = checkpoint_dict['iteration'] 39 | if 'learning_rate' in checkpoint_dict.keys(): 40 | learning_rate = checkpoint_dict['learning_rate'] 41 | else: 42 | learning_rate = None 43 | if optimizer is not None and 'optimizer' in checkpoint_dict.keys(): 44 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 45 | saved_state_dict = checkpoint_dict['model'] 46 | if hasattr(model, 'module'): 47 | state_dict = model.module.state_dict() 48 | else: 49 | state_dict = model.state_dict() 50 | new_state_dict = {} 51 | for k, v in state_dict.items(): 52 | try: 53 | new_state_dict[k] = saved_state_dict[k] 54 | except KeyError: 55 | logger.info("%s is not in the checkpoint" % k) 56 | print("%s is not in the checkpoint" % k) 57 | new_state_dict[k] = v 58 | if hasattr(model, 'module'): 59 | model.module.load_state_dict(new_state_dict) 60 | else: 61 | model.load_state_dict(new_state_dict) 62 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 63 | checkpoint_path, iteration)) 64 | return model, optimizer, learning_rate, iteration 65 | 66 | 67 | def save_figure_to_numpy(fig): 68 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 69 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 70 | return data 71 | 72 | 73 | def plot_tensor(tensor): 74 | plt.style.use('default') 75 | fig, ax = plt.subplots(figsize=(12, 3)) 76 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none') 77 | plt.colorbar(im, ax=ax) 78 | plt.tight_layout() 79 | fig.canvas.draw() 80 | data =
save_figure_to_numpy(fig) 81 | plt.close() 82 | return data 83 | 84 | 85 | def save_plot(tensor, savepath): 86 | plt.style.use('default') 87 | fig, ax = plt.subplots(figsize=(12, 3)) 88 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none') 89 | plt.colorbar(im, ax=ax) 90 | plt.tight_layout() 91 | fig.canvas.draw() 92 | plt.savefig(savepath) 93 | plt.close() 94 | return 95 | 96 | 97 | def get_correct_class(hps, train=True): 98 | if train: 99 | if hps.xvector and hps.pe: 100 | raise NotImplementedError 101 | elif hps.xvector: # no pitch energy 102 | raise NotImplementedError 103 | loader = loaders.XvectorLoader 104 | collate = collates.XvectorCollate 105 | model = GradTTSXvector 106 | dataset = loader(utts=hps.data.train_utts, 107 | hparams=hps.data, 108 | feats_scp=hps.data.train_feats_scp, 109 | utt2phns=hps.data.train_utt2phns, 110 | phn2id=hps.data.phn2id, 111 | utt2phn_duration=hps.data.train_utt2phn_duration, 112 | spk_xvector_scp=hps.data.train_spk_xvector_scp, 113 | utt2spk_name=hps.data.train_utt2spk) 114 | elif hps.pe: 115 | raise NotImplementedError 116 | else: # no PE, no xvector 117 | loader = loaders.SpkIDLoaderWithEmo 118 | collate = collates.SpkIDCollateWithEmo 119 | model = GradTTSWithEmo 120 | dataset = loader(utts=hps.data.train_utts, 121 | hparams=hps.data, 122 | feats_scp=hps.data.train_feats_scp, 123 | utt2text=hps.data.train_utt2phns, 124 | utt2spk=hps.data.train_utt2spk, 125 | utt2emo=hps.data.train_utt2emo) 126 | else: 127 | if hps.xvector and hps.pe: 128 | raise NotImplementedError 129 | elif hps.xvector: 130 | raise NotImplementedError 131 | loader = loaders.XvectorLoader 132 | collate = collates.XvectorCollate 133 | model = GradTTSXvector 134 | dataset = loader(utts=hps.data.val_utts, 135 | hparams=hps.data, 136 | feats_scp=hps.data.val_feats_scp, 137 | utt2phns=hps.data.val_utt2phns, 138 | phn2id=hps.data.phn2id, 139 | utt2phn_duration=hps.data.val_utt2phn_duration, 140 | 
spk_xvector_scp=hps.data.val_spk_xvector_scp, 141 | utt2spk_name=hps.data.val_utt2spk) 142 | elif hps.pe: 143 | raise NotImplementedError 144 | else: # no PE, no xvector 145 | loader = loaders.SpkIDLoaderWithEmo 146 | collate = collates.SpkIDCollateWithEmo 147 | model = GradTTSWithEmo 148 | dataset = loader(utts=hps.data.val_utts, 149 | hparams=hps.data, 150 | feats_scp=hps.data.val_feats_scp, 151 | utt2text=hps.data.val_utt2phns, 152 | utt2spk=hps.data.val_utt2spk, 153 | utt2emo=hps.data.val_utt2emo) 154 | return dataset, collate(), model 155 | 156 | 157 | def get_hparams(init=True): 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument('-c', '--config', type=str, default="./configs/train_grad.json", 160 | help='JSON file for configuration') 161 | parser.add_argument('-m', '--model', type=str, required=True, 162 | help='Model name') 163 | parser.add_argument('-s', '--seed', type=int, default=1234) 164 | parser.add_argument('--not-pretrained', action='store_true', help='if set to true, then train from scratch') 165 | 166 | args = parser.parse_args() 167 | model_dir = os.path.join("./logs", args.model) 168 | 169 | if not os.path.exists(model_dir): 170 | os.makedirs(model_dir) 171 | 172 | config_path = args.config 173 | config_save_path = os.path.join(model_dir, "config.json") 174 | if init: 175 | with open(config_path, "r") as f: 176 | data = f.read() 177 | with open(config_save_path, "w") as f: 178 | f.write(data) 179 | else: 180 | with open(config_save_path, "r") as f: 181 | data = f.read() 182 | config = json.loads(data) 183 | 184 | hparams = HParams(**config) 185 | hparams.model_dir = model_dir 186 | hparams.train.seed = args.seed 187 | hparams.not_pretrained = args.not_pretrained 188 | return hparams 189 | 190 | 191 | class HParams(): 192 | def __init__(self, **kwargs): 193 | for k, v in kwargs.items(): 194 | if type(v) == dict: 195 | v = HParams(**v) 196 | self[k] = v 197 | 198 | def keys(self): 199 | return self.__dict__.keys() 200 | 201 | def 
items(self): 202 | return self.__dict__.items() 203 | 204 | def values(self): 205 | return self.__dict__.values() 206 | 207 | def __len__(self): 208 | return len(self.__dict__) 209 | 210 | def __getitem__(self, key): 211 | return getattr(self, key) 212 | 213 | def __setitem__(self, key, value): 214 | return setattr(self, key, value) 215 | 216 | def __contains__(self, key): 217 | return key in self.__dict__ 218 | 219 | def __repr__(self): 220 | return self.__dict__.__repr__() 221 | 222 | 223 | def get_logger(model_dir, filename="train.log"): 224 | global logger 225 | logger = logging.getLogger(os.path.basename(model_dir)) 226 | logger.setLevel(logging.DEBUG) 227 | 228 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 229 | if not os.path.exists(model_dir): 230 | os.makedirs(model_dir) 231 | h = logging.FileHandler(os.path.join(model_dir, filename)) 232 | h.setLevel(logging.DEBUG) 233 | h.setFormatter(formatter) 234 | logger.addHandler(h) 235 | return logger 236 | 237 | 238 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 239 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 240 | iteration, checkpoint_path)) 241 | if hasattr(model, 'module'): 242 | state_dict = model.module.state_dict() 243 | else: 244 | state_dict = model.state_dict() 245 | torch.save({'model': state_dict, 246 | 'iteration': iteration, 247 | 'optimizer': optimizer.state_dict(), 248 | 'learning_rate': learning_rate}, checkpoint_path) 249 | 250 | 251 | def get_hparams_decode(model_dir=None): 252 | parser = argparse.ArgumentParser() 253 | parser.add_argument('-c', '--config', type=str, default="./configs/train_grad.json", 254 | help='JSON file for configuration') 255 | parser.add_argument('-m', '--model', type=str, default=model_dir, 256 | help='Model name') 257 | parser.add_argument('-s', '--seed', type=int, default=1234) 258 | parser.add_argument('-t', "--timesteps", type=int, default=10, help='how 
many timesteps to perform reverse diffusion') 259 | 260 | parser.add_argument("--stoc", action='store_true', default=False, help="Whether to add stochastic term into decoding") 261 | parser.add_argument("-g", "--guidance", type=float, default=3, help='classifier guidance') 262 | parser.add_argument('-n', '--noise', type=float, default=1.5, help='to multiply sigma') 263 | 264 | parser.add_argument('-f', '--file', type=str, required=True, help='path to a file with texts to synthesize') 265 | parser.add_argument('-r', '--generated_path', type=str, required=True, help='path to save wav files') 266 | 267 | args = parser.parse_args() 268 | model_dir = os.path.join("./logs", args.model) 269 | 270 | if not os.path.exists(model_dir): 271 | os.makedirs(model_dir) 272 | 273 | config_path = args.config 274 | config_save_path = os.path.join(model_dir, "config.json") # NOTE: which config to load 275 | with open(config_path, "r") as f: 276 | data = f.read() 277 | config = json.loads(data) 278 | 279 | hparams = HParams(**config) 280 | hparams.model_dir = model_dir 281 | hparams.train.seed = args.seed 282 | 283 | return hparams, args 284 | 285 | 286 | def get_hparams_decode_two_mixture(model_dir=None): 287 | parser = argparse.ArgumentParser() 288 | parser.add_argument('-c', '--config', type=str, default="./configs/train_grad.json", 289 | help='JSON file for configuration') 290 | parser.add_argument('-m', '--model', type=str, required=False, default='/raid/adal_abilbekov/training_emodiff/Emo_diff/logs/logs_train', 291 | help='Model name') 292 | parser.add_argument('-s', '--seed', type=int, default=1234) 293 | parser.add_argument('--dataset', choices=['train', 'val'], default='val', type=str, help='which dataset to use') 294 | parser.add_argument('--use-control-spk', action='store_true', help='whether to use GT spk or other spk') 295 | parser.add_argument('--control-spk-id', default=None, type=int, help='if use control spk, then which spk') 296 | 
    parser.add_argument("--use-control-emo", action='store_true')
    parser.add_argument("--control-emo-id1", type=int)
    parser.add_argument("--control-emo-id2", type=int)
    parser.add_argument("--emo1-weight", type=float, default=0.5)

    parser.add_argument('--control-spk-name', default=None, type=str, help='if use control spk, then which spk')
    parser.add_argument("--max-utt-num", default=100, type=int, help='maximum number of utterances to decode')
    parser.add_argument("--specify-utt-name", default=None, type=str, help='if specified, only decode that utterance')
    parser.add_argument('-t', "--timesteps", type=int, default=10, help='number of reverse-diffusion timesteps')

    parser.add_argument("--stoc", action='store_true', default=False, help="whether to add the stochastic term during decoding")
    parser.add_argument("-g", "--guidance", type=float, default=3, help='classifier guidance scale')
    parser.add_argument('-n', '--noise', type=float, default=1.5, help='factor to multiply sigma by')

    parser.add_argument('--text', type=str, default=None, help="given text file")

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")  # NOTE: unused; the config is always read from --config
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    hparams.train.seed = args.seed

    if args.use_control_spk:
        if hparams.xvector:
            assert args.control_spk_name is not None
        else:
            assert args.control_spk_id is not None

    return hparams, args


def get_hparams_classifier_objective():
    parser = argparse.ArgumentParser()
    # NOTE: configs/ in this repository ships train_grad.json, not base.json; pass -c explicitly
    parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
                        help='JSON file for configuration')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help='Model name')
    parser.add_argument('-s', '--seed', type=int, default=1234)
    parser.add_argument('--dataset', choices=['train', 'val'], default='val', type=str, help='which dataset to use')
    parser.add_argument('--use-control-spk', action='store_true', help='whether to use GT spk or other spk')
    parser.add_argument('--control-spk-id', default=None, type=int, help='if use control spk, then which spk')
    # Needed by the xvector branch of the check below; previously undefined (AttributeError)
    parser.add_argument('--control-spk-name', default=None, type=str, help='if use control spk, then which spk')
    parser.add_argument("--use-control-emo", action='store_true')
    parser.add_argument("--max-utt-num", default=100, type=int, help='maximum number of utterances to decode')
    parser.add_argument("--specify-utt-name", default=None, type=str, help='if specified, only decode that utterance')

    parser.add_argument('--text', type=str, default=None, help="given text file")
    parser.add_argument("--feat", type=str, default=None, help='given feats.scp after CMVN')
    parser.add_argument("--dur", type=str, default=None, help='force durations')

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")  # NOTE: unused; the config is always read from --config
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    hparams.train.seed = args.seed

    if args.use_control_spk:
        if hparams.xvector:
            assert args.control_spk_name is not None
        else:
            assert args.control_spk_id is not None

    return hparams, args
--------------------------------------------------------------------------------
/xutils.py:
--------------------------------------------------------------------------------
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")  # non-interactive backend; selected before pyplot is imported
import matplotlib.pylab as plt


def plot_spectrogram(spectrogram):
    # Render a (freq, time) array as an image; the figure is closed and returned as an object
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    # Initialize the weights of every module whose class name contains "Conv" from N(mean, std)
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    # Wrap every module whose class name contains "Conv" with weight normalization
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    # dilation * (kernel_size - 1) / 2: keeps a stride-1 conv's output length equal to its input (odd kernel_size)
    return int((kernel_size*dilation - dilation)/2)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix):
    # Return the lexicographically last file named prefix + exactly 8 characters, or None
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]

--------------------------------------------------------------------------------
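`get_padding` in xutils.py returns `dilation * (kernel_size - 1) / 2`, the padding that keeps a stride-1 dilated 1-D convolution's output as long as its input. The sketch below checks that arithmetic in plain Python (no torch needed) against the output-length formula documented for `torch.nn.Conv1d`; `conv1d_out_len` is a helper written only for this illustration, and the kernel/dilation pairs are arbitrary examples, not values read from this repository's configs.

```python
def get_padding(kernel_size, dilation=1):
    # Same formula as xutils.get_padding: dilation * (kernel_size - 1) / 2
    return int((kernel_size * dilation - dilation) / 2)


def conv1d_out_len(in_len, kernel_size, dilation=1, padding=0, stride=1):
    # Hypothetical helper implementing the Conv1d output-length formula:
    # floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
    return (in_len + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Odd kernels: the padding exactly cancels the receptive-field shrinkage.
for k, d in [(3, 1), (3, 3), (3, 5), (7, 1), (11, 2)]:
    p = get_padding(k, d)
    print(f"kernel={k} dilation={d} padding={p} out_len={conv1d_out_len(100, k, d, p)}")

# Even kernel: (4 - 1) / 2 = 1.5 is truncated to 1, so one frame is lost.
print(conv1d_out_len(100, 4, 1, get_padding(4, 1)))  # 99
```

The truncation via `int(...)` is why the helper is only length-preserving for odd kernel sizes.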
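`scan_checkpoint` in xutils.py chooses the "latest" checkpoint purely by string order over files named `prefix` followed by exactly eight characters, so it only picks the highest step when step numbers are zero-padded to eight digits. The sketch below reproduces the function and runs it on a temporary directory. The file names (`g_` / `do_` prefixes, 8-digit steps, a stray `g_best`) are an assumption modeled on HiFi-GAN-style checkpoints; the naming scheme is not defined in this file.

```python
import glob
import os
import tempfile


def scan_checkpoint(cp_dir, prefix):
    # Same logic as xutils.scan_checkpoint: match prefix + exactly 8 characters,
    # then take the lexicographically largest path.
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]


with tempfile.TemporaryDirectory() as d:
    # Hypothetical names: zero-padding makes string order match numeric order.
    for name in ["g_00000500", "g_00009000", "g_00010000", "g_best", "do_00020000"]:
        open(os.path.join(d, name), "w").close()
    latest = scan_checkpoint(d, "g_")
    print(os.path.basename(latest))  # g_00010000
```

`g_best` is skipped because it has only four characters after the prefix, and `do_00020000` does not match the `g_` prefix. Without zero-padding, a name such as `g_9000` would sort after `g_10000`, which is why the fixed eight-character pattern matters.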
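The decoding entry points in utils_data.py build the run directory as `os.path.join("./logs", args.model)`. The sketch below reproduces the flags and defaults of `get_hparams_decode` (help strings dropped) in a hypothetical `build_decode_parser`, parses an explicit argument list instead of `sys.argv`, and shows two consequences: the defaults an invocation picks up, and that an absolute `--model` value (like the default in `get_hparams_decode_two_mixture`) replaces `./logs` entirely, because `os.path.join` discards earlier components when it meets an absolute one. The run name and paths are placeholders.

```python
import argparse
import os


def build_decode_parser(model_dir=None):
    # Hypothetical helper: same flags and defaults as get_hparams_decode
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default="./configs/train_grad.json")
    parser.add_argument('-m', '--model', type=str, default=model_dir)
    parser.add_argument('-s', '--seed', type=int, default=1234)
    parser.add_argument('-t', "--timesteps", type=int, default=10)
    parser.add_argument("--stoc", action='store_true', default=False)
    parser.add_argument("-g", "--guidance", type=float, default=3)
    parser.add_argument('-n', '--noise', type=float, default=1.5)
    parser.add_argument('-f', '--file', type=str, required=True)
    parser.add_argument('-r', '--generated_path', type=str, required=True)
    return parser


# Placeholder run name and paths; parse_args accepts an explicit list.
args = build_decode_parser().parse_args(
    ["-m", "my_run", "-f", "texts.txt", "-r", "generated_wavs", "-t", "50"])
print(args.timesteps, args.guidance, args.noise, args.stoc)  # 50 3 1.5 False
print(os.path.join("./logs", args.model))                   # ./logs/my_run

# An absolute --model value discards the "./logs" prefix (POSIX paths shown).
print(os.path.join("./logs", "/abs/path/logs_train"))       # /abs/path/logs_train
```

Note that `--guidance` prints as `3` rather than `3.0`: argparse applies `type=float` only to values given on the command line, not to a non-string default.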