230 | ```
231 |
232 | Synthesized samples 🔈
233 | You can listen to some synthesized samples here.
234 |
235 | Citation 🎓
236 |
237 | If you use our dataset and/or model in your work, please cite our paper. Citing the relevant sources upholds academic integrity, properly acknowledges the authors' efforts, and supports further research in this area. We genuinely appreciate your support.
238 |
239 | ```bibtex
240 | @misc{abilbekov2024kazemotts,
241 | title={KazEmoTTS: A Dataset for Kazakh Emotional Text-to-Speech Synthesis},
242 | author={Adal Abilbekov and Saida Mussakhojayeva and Rustem Yeshpanov and Huseyin Atakan Varol},
243 | year={2024},
244 | eprint={2404.01033},
245 | archivePrefix={arXiv},
246 | primaryClass={eess.AS}
247 | }
248 | ```
249 |
--------------------------------------------------------------------------------
/configs/hifigan-config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 1,
4 | "batch_size": 64,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54320",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/configs/train_grad.json:
--------------------------------------------------------------------------------
1 | {
2 | "xvector": false,
3 | "pe": false,
4 | "train": {
5 | "test_size": 6,
6 | "n_epochs": 10000,
7 | "batch_size": 64,
8 | "learning_rate": 1e-4,
9 | "seed": 37,
10 | "save_every": 1,
11 | "use_gt_dur": false
12 | },
13 | "data": {
14 | "load_mel_from_disk": false,
15 | "train_utts": "filelists/all_spks/train_utts.txt",
16 | "val_utts": "filelists/all_spks/eval_utts.txt",
17 | "train_utt2phns": "filelists/all_spks/text",
18 | "val_utt2phns": "filelists/all_spks/text",
19 | "train_feats_scp": "filelists/all_spks/feats.scp",
20 | "val_feats_scp": "filelists/all_spks/feats.scp",
21 | "train_utt2spk": "filelists/all_spks/utt2spk.json",
22 | "val_utt2spk": "filelists/all_spks/utt2spk.json",
23 | "train_utt2emo": "filelists/all_spks/utt2emo.json",
24 | "val_utt2emo": "filelists/all_spks/utt2emo.json",
25 |
26 | "train_var_scp": "",
27 | "val_var_scp": "",
28 |
29 | "text_cleaners": [
30 | "kazakh_cleaners"
31 | ],
32 | "max_wav_value": 32768.0,
33 | "sampling_rate": 22050,
34 | "filter_length": 1024,
35 | "hop_length": 200,
36 | "win_length": 800,
37 | "n_mel_channels": 80,
38 | "mel_fmin": 20.0,
39 | "mel_fmax": 8000.0,
40 | "utt2phn_path": "data/res_utt2phns.json",
41 | "add_blank": false
42 | },
43 | "model": {
44 | "n_vocab": 200,
45 | "n_spks": 3,
46 | "n_emos": 6,
47 | "spk_emb_dim": 64,
48 | "n_enc_channels": 192,
49 | "filter_channels": 768,
50 | "filter_channels_dp": 256,
51 | "n_enc_layers": 6,
52 | "enc_kernel": 3,
53 | "enc_dropout": 0.1,
54 | "n_heads": 2,
55 | "window_size": 4,
56 | "dec_dim": 64,
57 | "beta_min": 0.05,
58 | "beta_max": 20.0,
59 | "pe_scale": 1000,
60 | "d_decoder": 128,
61 | "l_decoder": 3,
62 | "k_decoder": 7,
63 | "h_decoder": 4,
64 | "decoder_dropout":0.1,
65 |
66 | "classifier_type": "CNN-with-time"
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/data_collate.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import numpy as np
4 | import torch
5 | import re
6 | import torch.utils.data
7 | import json
8 |
9 | import kaldiio
10 | from tqdm import tqdm
11 |
12 |
13 | class BaseCollate:
14 | def __init__(self, n_frames_per_step=1):
15 | self.n_frames_per_step = n_frames_per_step
16 |
17 | def collate_text_mel(self, batch: [dict]):
18 | """
19 | :param batch: list of dicts
20 | """
21 | utt = list(map(lambda x: x['utt'], batch))
22 | input_lengths, ids_sorted_decreasing = torch.sort(
23 | torch.LongTensor([len(x['text']) for x in batch]),
24 | dim=0, descending=True)
25 | max_input_len = input_lengths[0]
26 |
27 | text_padded = torch.LongTensor(len(batch), max_input_len)
28 | text_padded.zero_()
29 | for i in range(len(ids_sorted_decreasing)):
30 | text = batch[ids_sorted_decreasing[i]]['text']
31 | text_padded[i, :text.size(0)] = text
32 |
33 | # Right zero-pad mel-spec
34 | num_mels = batch[0]['mel'].size(0)
35 | max_target_len = max([x['mel'].size(1) for x in batch])
36 | if max_target_len % self.n_frames_per_step != 0:
37 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
38 | assert max_target_len % self.n_frames_per_step == 0
39 |
40 | # include mel padded
41 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
42 | mel_padded.zero_()
43 | output_lengths = torch.LongTensor(len(batch))
44 | for i in range(len(ids_sorted_decreasing)):
45 | mel = batch[ids_sorted_decreasing[i]]['mel']
46 | mel_padded[i, :, :mel.size(1)] = mel
47 | output_lengths[i] = mel.size(1)
48 |
49 | utt_name = np.array(utt)[ids_sorted_decreasing].tolist()
50 | if isinstance(utt_name, str):
51 | utt_name = [utt_name]
52 |
53 | res = {
54 | "utt": utt_name,
55 | "text_padded": text_padded,
56 | "input_lengths": input_lengths,
57 | "mel_padded": mel_padded,
58 | "output_lengths": output_lengths,
59 | }
60 | return res, ids_sorted_decreasing
61 |
62 |
63 | class SpkIDCollate(BaseCollate):
64 | def __call__(self, batch, *args, **kwargs):
65 | base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
66 | spk_ids = torch.LongTensor(list(map(lambda x: x["spk_ids"], batch)))
67 | spk_ids = spk_ids[ids_sorted_decreasing]
68 | base_data.update({
69 | "spk_ids": spk_ids
70 | })
71 | return base_data
72 |
73 |
74 | class SpkIDCollateWithEmo(BaseCollate):
75 | def __call__(self, batch, *args, **kwargs):
76 | base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
77 |
78 | spk_ids = torch.LongTensor(list(map(lambda x: x["spk_ids"], batch)))
79 | spk_ids = spk_ids[ids_sorted_decreasing]
80 | emo_ids = torch.LongTensor(list(map(lambda x: x['emo_ids'], batch)))
81 | emo_ids = emo_ids[ids_sorted_decreasing]
82 | base_data.update({
83 | "spk_ids": spk_ids,
84 | "emo_ids": emo_ids
85 | })
86 | return base_data
87 |
88 |
89 | class XvectorCollate(BaseCollate):
90 | def __call__(self, batch, *args, **kwargs):
91 | base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
92 | xvectors = torch.cat(list(map(lambda x: x["xvector"].unsqueeze(0), batch)), dim=0)
93 | xvectors = xvectors[ids_sorted_decreasing]
94 | base_data.update({
95 | "xvector": xvectors
96 | })
97 | return base_data
98 |
99 |
100 | class SpkIDCollateWithPE(BaseCollate):
101 | def __call__(self, batch, *args, **kwargs):
102 | base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
103 | spk_ids = torch.LongTensor(list(map(lambda x: x["spk_ids"], batch)))
104 | spk_ids = spk_ids[ids_sorted_decreasing]
105 |
106 | num_var = batch[0]["var"].size(0)
107 | max_target_len = max([x["var"].size(1) for x in batch])
108 | if max_target_len % self.n_frames_per_step != 0:
109 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
110 | assert max_target_len % self.n_frames_per_step == 0
111 |
112 | var_padded = torch.FloatTensor(len(batch), num_var, max_target_len)
113 | var_padded.zero_()
114 | for i in range(len(ids_sorted_decreasing)):
115 | var = batch[ids_sorted_decreasing[i]]["var"]
116 | var_padded[i, :, :var.size(1)] = var
117 |
118 | base_data.update({
119 | "spk_ids": spk_ids,
120 | "var_padded": var_padded
121 | })
122 | return base_data
123 |
124 |
125 | class XvectorCollateWithPE(BaseCollate):
126 | def __call__(self, batch, *args, **kwargs):
127 | base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
128 | xvectors = torch.cat(list(map(lambda x: x["xvector"].unsqueeze(0), batch)), dim=0)
129 | xvectors = xvectors[ids_sorted_decreasing]
130 |
131 | num_var = batch[0]["var"].size(0)
132 | max_target_len = max([x["var"].size(1) for x in batch])
133 | if max_target_len % self.n_frames_per_step != 0:
134 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
135 | assert max_target_len % self.n_frames_per_step == 0
136 |
137 | var_padded = torch.FloatTensor(len(batch), num_var, max_target_len)
138 | var_padded.zero_()
139 | for i in range(len(ids_sorted_decreasing)):
140 | var = batch[ids_sorted_decreasing[i]]["var"]
141 | var_padded[i, :, :var.size(1)] = var
142 |
143 | base_data.update({
144 | "xvector": xvectors,
145 | "var_padded": var_padded
146 | })
147 | return base_data
148 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import numpy as np
4 | import torch
5 | import re
6 | import torch.utils.data
7 | import json
8 |
9 | import kaldiio
10 | from tqdm import tqdm
11 | from text import text_to_sequence
12 |
13 | class BaseLoader(torch.utils.data.Dataset):
14 | def __init__(self, utts: str, hparams, feats_scp: str, utt2text:str):
15 | """
16 | :param utts: file path. A list of utts for this loader. These are the only utts that this loader has access.
17 | This loader only deals with text, duration and feats. Other files despite `utts` can be larger.
18 | """
19 | self.n_mel_channels = hparams.n_mel_channels
20 | self.sampling_rate = hparams.sampling_rate
21 | self.utts = self.get_utts(utts)
22 | self.utt2feat = self.get_utt2feat(feats_scp)
23 | self.utt2text = self.get_utt2text(utt2text)
24 |
25 | def get_utts(self, utts: str) -> list:
26 | with open(utts, 'r') as f:
27 | L = f.readlines()
28 | L = list(map(lambda x: x.strip(), L))
29 | random.seed(1234)
30 | random.shuffle(L)
31 | return L
32 |
33 |
34 | def get_utt2feat(self, feats_scp: str):
35 | utt2feat = kaldiio.load_scp(feats_scp) # lazy load mode
36 | print(f"Succeed reading feats from {feats_scp}")
37 | return utt2feat
38 |
39 | def get_utt2text(self, utt2text: str):
40 | with open(utt2text, 'r') as f:
41 | L = f.readlines()
42 | utt2text = {line.split()[0]: line.strip().split(" ", 1)[1] for line in L}
43 | return utt2text
44 |
45 | def get_mel_from_kaldi(self, utt):
46 | feat = self.utt2feat[utt]
47 | feat = torch.FloatTensor(feat).squeeze()
48 | assert self.n_mel_channels in feat.shape
49 | if feat.shape[0] == self.n_mel_channels:
50 | return feat
51 | else:
52 | return feat.T
53 |
54 | def get_text(self, utt):
55 | text = self.utt2text[utt]
56 | text_norm = text_to_sequence(text)
57 | text_norm = torch.IntTensor(text_norm)
58 | return text_norm
59 |
60 | def __getitem__(self, index):
61 | res = self.get_mel_text_pair(self.utts[index])
62 | return res
63 |
64 | def __len__(self):
65 | return len(self.utts)
66 |
67 | def sample_test_batch(self, size):
68 | idx = np.random.choice(range(len(self)), size=size, replace=False)
69 | test_batch = []
70 | for index in idx:
71 | test_batch.append(self.__getitem__(index))
72 | return test_batch
73 |
74 |
75 | class SpkIDLoader(BaseLoader):
76 | def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
77 | utt2phn_duration: str, utt2spk: str):
78 | """
79 | :param utt2spk: json file path (utt name -> spk id)
80 |         This loader represents the speaker as an integer ID used to index an embedding table.
81 | """
82 | super(SpkIDLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
83 | self.utt2spk = self.get_utt2spk(utt2spk)
84 |
85 | def get_utt2spk(self, utt2spk: str) -> dict:
86 | with open(utt2spk, 'r') as f:
87 | res = json.load(f)
88 | return res
89 |
90 | def get_mel_text_pair(self, utt):
91 | # separate filename and text
92 | spkid = self.utt2spk[utt]
93 | phn_ids = self.get_text(utt)
94 | mel = self.get_mel_from_kaldi(utt)
95 | dur = self.get_dur_from_kaldi(utt)
96 |
97 | assert sum(dur) == mel.shape[1], f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"
98 | res = {
99 | "utt": utt,
100 | "mel": mel,
101 | "spk_ids": spkid
102 | }
103 | return res
104 |
105 | def __getitem__(self, index):
106 | res = self.get_mel_text_pair(self.utts[index])
107 | return res
108 |
109 | def __len__(self):
110 | return len(self.utts)
111 |
112 |
113 | class SpkIDLoaderWithEmo(BaseLoader):
114 | def __init__(self, utts: str, hparams, feats_scp: str, utt2text:str, utt2spk: str, utt2emo: str):
115 | """
116 | :param utt2spk: json file path (utt name -> spk id)
117 |         This loader represents the speaker as an integer ID used to index an embedding table.
118 | """
119 | super(SpkIDLoaderWithEmo, self).__init__(utts, hparams, feats_scp, utt2text)
120 | self.utt2spk = self.get_utt2spk(utt2spk)
121 | self.utt2emo = self.get_utt2emo(utt2emo)
122 |
123 | def get_utt2spk(self, utt2spk: str) -> dict:
124 | with open(utt2spk, 'r') as f:
125 | res = json.load(f)
126 | return res
127 |
128 | def get_utt2emo(self, utt2emo: str) -> dict:
129 | with open(utt2emo, 'r') as f:
130 | res = json.load(f)
131 | return res
132 |
133 | def get_mel_text_pair(self, utt):
134 | # separate filename and text
135 | spkid = int(self.utt2spk[utt])
136 | emoid = int(self.utt2emo[utt])
137 | text = self.get_text(utt)
138 | mel = self.get_mel_from_kaldi(utt)
139 |
140 | res = {
141 | "utt": utt,
142 | "text": text,
143 | "mel": mel,
144 | "spk_ids": spkid,
145 | "emo_ids": emoid
146 | }
147 | return res
148 |
149 | def __getitem__(self, index):
150 | res = self.get_mel_text_pair(self.utts[index])
151 | return res
152 |
153 | def __len__(self):
154 | return len(self.utts)
155 |
156 |
157 | class SpkIDLoaderWithPE(SpkIDLoader):
158 | def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
159 | utt2phn_duration: str, utt2spk: str, var_scp: str):
160 | """
161 | This loader loads speaker ID together with variance (4-dim pitch, 1-dim energy)
162 | """
163 | super(SpkIDLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration, utt2spk)
164 | self.utt2var = self.get_utt2var(var_scp)
165 |
166 | def get_utt2var(self, utt2var: str) -> dict:
167 | res = kaldiio.load_scp(utt2var)
168 | print(f"Succeed reading feats from {utt2var}")
169 | return res
170 |
171 | def get_var_from_kaldi(self, utt):
172 | var = self.utt2var[utt]
173 | var = torch.FloatTensor(var).squeeze()
174 | assert 5 in var.shape
175 | if var.shape[0] == 5:
176 | return var
177 | else:
178 | return var.T
179 |
180 | def get_mel_text_pair(self, utt):
181 | # separate filename and text
182 | spkid = self.utt2spk[utt]
183 | phn_ids = self.get_text(utt)
184 | mel = self.get_mel_from_kaldi(utt)
185 | dur = self.get_dur_from_kaldi(utt)
186 | var = self.get_var_from_kaldi(utt)
187 |
188 | assert sum(dur) == mel.shape[1] == var.shape[1], \
189 | f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"
190 |
191 | res = {
192 | "utt": utt,
193 | "phn_ids": phn_ids,
194 | "mel": mel,
195 | "dur": dur,
196 | "spk_ids": spkid,
197 | "var": var
198 | }
199 | return res
200 |
201 |
202 | class XvectorLoader(BaseLoader):
203 | def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
204 | utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str):
205 | """
206 | :param utt2spk_name: like kaldi-style utt2spk
207 | :param spk_xvector_scp: kaldi-style speaker-level xvector.scp
208 | """
209 | super(XvectorLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
210 | self.utt2spk = self.get_utt2spk(utt2spk_name)
211 | self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)
212 |
213 | def get_utt2spk(self, utt2spk):
214 | res = dict()
215 | with open(utt2spk, 'r') as f:
216 | for l in f.readlines():
217 | res[l.split()[0]] = l.split()[1]
218 | return res
219 |
220 | def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
221 | res = kaldiio.load_scp(spk_xvector_scp)
222 | print(f"Succeed reading xvector from {spk_xvector_scp}")
223 | return res
224 |
225 | def get_xvector(self, utt):
226 | xv = self.spk2xvector[self.utt2spk[utt]]
227 | xv = torch.FloatTensor(xv).squeeze()
228 | return xv
229 |
230 | def get_mel_text_pair(self, utt):
231 | phn_ids = self.get_text(utt)
232 | mel = self.get_mel_from_kaldi(utt)
233 | dur = self.get_dur_from_kaldi(utt)
234 | xvector = self.get_xvector(utt)
235 |
236 | assert sum(dur) == mel.shape[1], \
237 | f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"
238 |
239 | res = {
240 | "utt": utt,
241 | "phn_ids": phn_ids,
242 | "mel": mel,
243 | "dur": dur,
244 | "xvector": xvector,
245 | }
246 | return res
247 |
248 |
249 | class XvectorLoaderWithPE(BaseLoader):
250 | def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
251 | utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str, var_scp: str):
252 | super(XvectorLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
253 | self.utt2spk = self.get_utt2spk(utt2spk_name)
254 | self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)
255 | self.utt2var = self.get_utt2var(var_scp)
256 |
257 | def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
258 | res = kaldiio.load_scp(spk_xvector_scp)
259 | print(f"Succeed reading xvector from {spk_xvector_scp}")
260 | return res
261 |
262 | def get_utt2spk(self, utt2spk):
263 | res = dict()
264 | with open(utt2spk, 'r') as f:
265 | for l in f.readlines():
266 | res[l.split()[0]] = l.split()[1]
267 | return res
268 |
269 | def get_utt2var(self, utt2var: str) -> dict:
270 | res = kaldiio.load_scp(utt2var)
271 | print(f"Succeed reading feats from {utt2var}")
272 | return res
273 |
274 | def get_var_from_kaldi(self, utt):
275 | var = self.utt2var[utt]
276 | var = torch.FloatTensor(var).squeeze()
277 | assert 5 in var.shape
278 | if var.shape[0] == 5:
279 | return var
280 | else:
281 | return var.T
282 |
283 | def get_xvector(self, utt):
284 | xv = self.spk2xvector[self.utt2spk[utt]]
285 | xv = torch.FloatTensor(xv).squeeze()
286 | return xv
287 |
288 | def get_mel_text_pair(self, utt):
289 | # separate filename and text
290 | spkid = self.utt2spk[utt]
291 | phn_ids = self.get_text(utt)
292 | mel = self.get_mel_from_kaldi(utt)
293 | dur = self.get_dur_from_kaldi(utt)
294 | var = self.get_var_from_kaldi(utt)
295 | xvector = self.get_xvector(utt)
296 |
297 | assert sum(dur) == mel.shape[1] == var.shape[1], \
298 | f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"
299 |
300 | res = {
301 | "utt": utt,
302 | "phn_ids": phn_ids,
303 | "mel": mel,
304 | "dur": dur,
305 | "spk_ids": spkid,
306 | "var": var,
307 | "xvector": xvector
308 | }
309 | return res
310 |
--------------------------------------------------------------------------------
/data_preparation.py:
--------------------------------------------------------------------------------
1 | import kaldiio
2 | import os
3 | import librosa
4 | from tqdm import tqdm
5 | import glob
6 | import json
7 | from shutil import copyfile
8 | import pandas as pd
9 | import argparse
10 | from text import _clean_text, symbols
11 | from num2words import num2words
12 | import re
13 | from melspec import mel_spectrogram
14 | import torchaudio
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('-d', '--data', type=str, required=True, help='path to the emotional dataset')
19 | args = parser.parse_args()
20 | dataset_path = args.data
21 | filelists_path = 'filelists/all_spks/'
22 | feats_scp_file = filelists_path + 'feats.scp'
23 | feats_ark_file = filelists_path + 'feats.ark'
24 |
25 |
26 | spks = ['1263201035', '805570882', '399172782']
27 | train_files = []
28 | eval_files = []
29 | for spk in spks:
30 | train_files += glob.glob(dataset_path + spk + "/train/*.wav")
31 | eval_files += glob.glob(dataset_path + spk + "/eval/*.wav")
32 |
33 | os.makedirs(filelists_path, exist_ok=True)
34 |
35 | with open(filelists_path + 'train_utts.txt', 'w', encoding='utf-8') as f:
36 | for wav_path in train_files:
37 | wav_name = os.path.splitext(os.path.basename(wav_path))[0]
38 | f.write(wav_name + '\n')
39 | with open(filelists_path + 'eval_utts.txt', 'w', encoding='utf-8') as f:
40 | for wav_path in eval_files:
41 | wav_name = os.path.splitext(os.path.basename(wav_path))[0]
42 | f.write(wav_name + '\n')
43 |
44 |     # kaldiio.WriteHelper writes both feats.ark and feats.scp, so no separate open() is needed
45 |     with kaldiio.WriteHelper(f'ark,scp:{feats_ark_file},{feats_scp_file}') as writer:
46 | for root, dirs, files in os.walk(dataset_path):
47 | for file in tqdm(files):
48 | if file.endswith('.wav'):
49 | # Get the file name and relative path to the root folder
50 | wav_path = os.path.join(root, file)
51 | rel_path = os.path.relpath(wav_path, dataset_path)
52 | wav_name = os.path.splitext(os.path.basename(wav_path))[0]
53 | signal, rate = torchaudio.load(wav_path)
54 | spec = mel_spectrogram(signal, 1024, 80, 22050, 256,
55 | 1024, 0, 8000, center=False).squeeze()
56 | # Write the features to feats.ark and feats.scp
57 | writer[wav_name] = spec
58 |
59 |
60 | emotions = [os.path.basename(x).split("_")[1] for x in glob.glob(dataset_path + '/**/**/*')]
61 | emotions = sorted(set(emotions))
62 |
63 | utt2spk = {}
64 | utt2emo = {}
65 | wavs = glob.glob(dataset_path + '**/**/*.wav')
66 | for wav_path in tqdm(wavs):
67 | wav_name = os.path.splitext(os.path.basename(wav_path))[0]
68 | emotion = emotions.index(wav_name.split("_")[1])
69 | if wav_path.split('/')[-3] == '1263201035':
70 | spk = 0 ## labels should start with 0
71 | elif wav_path.split('/')[-3] == '805570882':
72 | spk = 1
73 | else:
74 | spk = 2
75 | utt2spk[wav_name] = str(spk)
76 | utt2emo[wav_name] = str(emotion)
77 | utt2spk = dict(sorted(utt2spk.items()))
78 | utt2emo = dict(sorted(utt2emo.items()))
79 |
80 | with open(filelists_path + 'utt2emo.json', 'w') as fp:
81 | json.dump(utt2emo, fp, indent=4)
82 | with open(filelists_path + 'utt2spk.json', 'w') as fp:
83 | json.dump(utt2spk, fp, indent=4)
84 |
85 | txt_files = sorted(glob.glob(dataset_path + '/**/**/*.txt'))
86 | count = 0
87 | txt = []
88 | basenames = []
89 | utt2text = {}
90 | flag = False
91 | with open(filelists_path + 'text', 'w', encoding='utf-8') as write:
92 | for txt_path in txt_files:
93 | basename = os.path.basename(txt_path).replace('.txt', '')
94 | with open(txt_path, 'r', encoding='utf-8') as f:
95 | txt.append(_clean_text(f.read().strip("\n"), cleaner_names=["kazakh_cleaners"]).replace("'", ""))
96 | basenames.append(basename)
97 | output_string = [re.sub('(\d+)', lambda m: num2words(m.group(), lang='kz'), sentence) for sentence in txt]
98 | cleaned_txt = []
99 | for t in output_string:
100 | cleaned_txt.append(''.join([s for s in t if s in symbols]))
101 | utt2text = {basenames[i]: cleaned_txt[i] for i in range(len(cleaned_txt))}
102 | utt2text = dict(sorted(utt2text.items()))
103 |
104 | vocab = set()
105 | with open(filelists_path + '/text', 'w', encoding='utf-8') as f:
106 | for x, y in utt2text.items():
107 | for c in y: vocab.add(c)
108 | f.write(x + ' ' + y + '\n')
109 |
--------------------------------------------------------------------------------
/filelists/all_spks/feats.ark:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/filelists/all_spks/feats.ark
--------------------------------------------------------------------------------
/filelists/all_spks/feats.scp:
--------------------------------------------------------------------------------
1 | utt1 /Users/Desktop/code/GradTTS-emo/filelists/example/feats.ark:5
2 | utt2 /Users/Desktop/code/GradTTS-emo/filelists/example/feats.ark:78745
3 | utt3 /Users/Desktop/code/GradTTS-emo/filelists/example/feats.ark:370605
4 |
--------------------------------------------------------------------------------
/filelists/inference_generated.txt:
--------------------------------------------------------------------------------
1 | Августың аяқ жағына мүсінші тәңірия Венераның баласы Амур бейнесін орналастырған.|0|0
2 | Қарғыс айтқалы жатыр ғой, өз балаларына!– десіп үркіп үн салды.|1|1
--------------------------------------------------------------------------------
/inference_EMA.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import datetime as dt
4 | import numpy as np
5 | from scipy.io.wavfile import write
6 | import IPython.display as ipd
7 | import glob
8 | import torch
9 | from pydub import AudioSegment
10 | from torch.utils.data import DataLoader
11 | from text import text_to_sequence, cmudict
12 | from text.symbols import symbols
13 | import utils_data
14 | import re
15 | from num2words import num2words
16 | from kaldiio import WriteHelper
17 | import os
18 | from tqdm import tqdm
19 | from text import text_to_sequence, convert_text
20 | from model import GradTTSWithEmo
21 | import utils_data as utils
22 | from attrdict import AttrDict
23 | from models import Generator as HiFiGAN
24 |
25 |
26 | HIFIGAN_CONFIG = './configs/hifigan-config.json'
27 | HIFIGAN_CHECKPT = './checkpts/hifigan.pt'
28 |
29 |
30 | if __name__ == '__main__':
31 | hps, args = utils.get_hparams_decode()
32 | device = torch.device('cpu' if not torch.cuda.is_available() else "cuda")
33 | ckpt = utils_data.latest_checkpoint_path(hps.model_dir, "EMA_grad_*.pt")
34 | print(ckpt)
35 | model = GradTTSWithEmo(**hps.model).to(device)
36 | logger = utils_data.get_logger(hps.model_dir, "inference.log")
37 | utils_data.load_checkpoint(ckpt, model, None)
38 | _ = model.cuda().eval()
39 |
40 | print('Initializing HiFi-GAN...')
41 | with open(HIFIGAN_CONFIG) as f:
42 | h = AttrDict(json.load(f))
43 | vocoder = HiFiGAN(h)
44 | vocoder.load_state_dict(torch.load(HIFIGAN_CHECKPT, map_location=lambda loc, storage: loc)['generator'])
45 | _ = vocoder.cuda().eval()
46 | vocoder.remove_weight_norm()
47 |
48 | emos = sorted(["angry", "surprise", "fear", "happy", "neutral", "sad"])
49 | speakers = ['M1', 'F1', 'M2']
50 |
51 | with open(args.file, 'r', encoding='utf-8') as f:
52 | texts = [line.strip() for line in f.readlines()]
53 |
54 | replace_nums = []
55 | for i in texts:
56 | replace_nums.append(i.split('|', 1))
57 |
58 | nums2word = [re.sub('(\d+)', lambda m: num2words(m.group(), lang='kz'), sentence) for sentence in np.array(replace_nums)[:, 0]]
59 | # Speakers id.
60 | # M1 = 0
61 | # F1 = 1
62 | # M2 = 2
63 | text2speech = []
64 | for i, j in zip(nums2word, np.array(replace_nums)[:, 1]):
65 | text2speech.append(f'{i}|{j}')
66 |
67 | for i, line in enumerate(text2speech):
68 | emo_i = int(line.split('|')[1])
69 | control_spk_id = int(line.split('|')[2])
70 | control_emo_id = emos.index(emos[emo_i])
71 | text = line.split('|')[0]
72 | with torch.no_grad():
73 | ### define emotion
74 | emo = torch.LongTensor([control_emo_id]).to(device)
75 | sid = torch.LongTensor([control_spk_id]).to(device)
76 | text_padded, text_len = convert_text(text)
77 | y_enc, y_dec, attn = model.forward(text_padded, text_len,
78 | n_timesteps=args.timesteps,
79 | temperature=args.noise,
80 | stoc=args.stoc, spk=sid,emo=emo, length_scale=1.,
81 | classifier_free_guidance=args.guidance)
82 | res = y_dec.squeeze().cpu().numpy()
83 | x = torch.from_numpy(res).cuda().unsqueeze(0)
84 | y_g_hat = vocoder(x)
85 | audio = y_g_hat.squeeze()
86 | audio = audio * 32768.0
87 | audio = audio.detach().cpu().numpy().astype('int16')
88 | audio = AudioSegment(audio.data, frame_rate=22050, sample_width=2, channels=1)
89 | audio.export(f'{args.generated_path}/{emos[emo_i]}_{speakers[int(line.split("|")[2])]}.wav', format="wav")
--------------------------------------------------------------------------------
/melspec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | import librosa
4 |
5 | mel_basis = {}
6 | hann_window = {}
7 |
8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
9 | return torch.log(torch.clamp(x, min=clip_val) * C)
10 |
11 | def spectral_normalize_torch(magnitudes):
12 | output = dynamic_range_compression_torch(magnitudes)
13 | return output
14 |
15 |
16 |
17 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
18 | if torch.min(y) < -1.:
19 | print('min value is ', torch.min(y))
20 | if torch.max(y) > 1.:
21 | print('max value is ', torch.max(y))
22 |
23 | global mel_basis, hann_window
24 |     if str(fmax)+'_'+str(y.device) not in mel_basis:
25 | mel = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
26 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
27 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
28 |
29 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
30 | y = y.squeeze(1)
31 |
32 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
33 | center=center, pad_mode='reflect', normalized=False, onesided=True)
34 |
35 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
36 |
37 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
38 | spec = spectral_normalize_torch(spec)
39 |
40 | return spec.numpy()
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .tts import GradTTSWithEmo, GradTTSXvector
3 |
--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class BaseModule(torch.nn.Module):
6 | def __init__(self):
7 | super(BaseModule, self).__init__()
8 |
9 | @property
10 | def nparams(self):
11 | """
12 | Returns number of trainable parameters of the module.
13 | """
14 | num_params = 0
15 | for name, param in self.named_parameters():
16 | if param.requires_grad:
17 | num_params += np.prod(param.detach().cpu().numpy().shape)
18 | return num_params
19 |
20 | def relocate_input(self, x: list):
21 | """
22 | Relocates provided tensors to the same device set for the module.
23 | """
24 | device = next(self.parameters()).device
25 | for i in range(len(x)):
26 | if isinstance(x[i], torch.Tensor) and x[i].device != device:
27 | x[i] = x[i].to(device)
28 | return x
29 |
--------------------------------------------------------------------------------
/model/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor, BoolTensor
4 |
5 | from typing import Optional, Tuple, Iterable
6 | from model.diffusion import SinusoidalPosEmb
7 | from torch.nn.functional import pad
8 |
9 |
10 | import math
11 |
12 | def silu(input):
13 | '''
14 | Applies the Sigmoid Linear Unit (SiLU) function element-wise:
15 | SiLU(x) = x * sigmoid(x)
16 | '''
17 |     return input * torch.sigmoid(input)  # use torch.sigmoid for the most efficient implementation based on built-in PyTorch functions
18 |
19 |
20 | class RelPositionMultiHeadedAttention(nn.Module):
21 | """Multi-Head Self-Attention layer with relative position encoding.
22 | Paper: https://arxiv.org/abs/1901.02860
23 | Args:
24 |         d: The number of features.
25 |         n_head: The number of heads.
26 |         dropout: Dropout rate.
27 |
28 | """
29 |
30 | def __init__(
31 | self, d: int, n_head: int, dropout: float
32 | ):
33 | super().__init__()
34 | assert d % n_head == 0
35 | self.c = d // n_head
36 | self.h = n_head
37 |
38 | self.linear_q = nn.Linear(d, d)
39 | self.linear_k = nn.Linear(d, d)
40 | self.linear_v = nn.Linear(d, d)
41 | self.linear_out = nn.Linear(d, d)
42 |
43 | self.p_attn = None
44 | self.dropout = nn.Dropout(p=dropout)
45 |
46 | # linear transformation for positional encoding
47 | self.linear_pos = nn.Linear(d, d, bias=False)
48 |
49 | # these two learnable bias are used in matrix c and matrix d
50 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3
51 | self.u = nn.Parameter(torch.Tensor(self.h, self.c))
52 | self.v = nn.Parameter(torch.Tensor(self.h, self.c))
53 | # [H, C]
54 | torch.nn.init.xavier_uniform_(self.u)
55 | torch.nn.init.xavier_uniform_(self.v)
56 |
57 | def forward_qkv(self, query, key, value) -> Tuple[Tensor, ...]:
58 | """Transform query, key and value.
59 | Args:
60 | query (Tensor): [B, S, D].
61 | key (Tensor): [B, T, D].
62 | value (Tensor): [B, T, D].
63 | Returns:
64 | q (Tensor): [B, H, S, C].
65 | k (Tensor): [B, H, T, C].
66 | v (Tensor): [B, H, T, C].
67 | """
68 | n_batch = query.size(0)
69 | q = self.linear_q(query).view(n_batch, -1, self.h, self.c)
70 | k = self.linear_k(key).view(n_batch, -1, self.h, self.c)
71 | v = self.linear_v(value).view(n_batch, -1, self.h, self.c)
72 | q = q.transpose(1, 2)
73 | k = k.transpose(1, 2)
74 | v = v.transpose(1, 2)
75 | return q, k, v
76 |
77 | def forward_attention(self, v, scores, mask, causal=False) -> Tensor:
78 | """Compute attention context vector.
79 | Args:
80 | v (Tensor): [B, H, T, C].
81 | scores (Tensor): [B, H, S, T].
82 | mask (BoolTensor): [B, T], True values are masked from scores.
83 | Returns:
84 | result (Tensor): [B, S, D]. Attention result weighted by the score.
85 | """
86 | n_batch, H, S, T = scores.shape
87 | if mask is not None:
88 | scores = scores.masked_fill(
89 | mask.unsqueeze(1).unsqueeze(2).to(bool),
90 | float("-inf"), # [B, H, S, T]
91 | )
92 | if causal:
93 | k_grid = torch.arange(0, S, dtype=torch.int32, device=scores.device)
94 | v_grid = torch.arange(0, T, dtype=torch.int32, device=scores.device)
95 | kk, vv = torch.meshgrid(k_grid, v_grid, indexing="ij")
96 | causal_mask = vv > kk
97 | scores = scores.masked_fill(
98 | causal_mask.view(1, 1, S, T), float("-inf")
99 | )
100 |
101 | p_attn = self.p_attn = torch.softmax(scores, dim=-1) # [B, H, S, T]
102 | p_attn = self.dropout(p_attn) # [B, H, S, T]
103 |
104 | x = torch.matmul(p_attn, v) # [B, H, S, C]
105 | x = (
106 | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.c)
107 | ) # [B, S, D]
108 |
109 | return self.linear_out(x) # [B, S, D]
110 |
111 | def rel_shift(self, x):
112 | """Converting (..., i, i - j) matrix into (..., i, j) matrix.
113 | Args:
114 | x (Tensor): [B, H, S, 2S-1].
115 | Returns:
116 | x (Tensor): [B, H, S, S].
117 | Example: Take S = 2 for example, larger values work similarly.
118 | x = [
119 | [(0, -1), (0, 0), (0, 1)],
120 | [(1, 0), (1, 1), (1, 2)]
121 | ]
122 | x_padded = [
123 | [(x, x), (0, -1), (0, 0), (0, 1)],
124 |                 [(x, x), (1, 0), (1, 1), (1, 2)]
125 | ]
126 | x_padded = [
127 | [(x, x), (0, -1)],
128 | [(0, 0), (0, 1)],
129 | [(x, x), (1, 0)],
130 | [(1, 1), (1, 2)]
131 | ]
132 | x = [
133 | [(0, 0), (0, 1)],
134 | [(1, 0), (1, 1)]
135 | ]
136 | """
137 | B, H, S, _ = x.shape
138 | zero_pad = torch.zeros((B, H, S, 1), device=x.device, dtype=x.dtype)
139 | # [B, H, S, 1]
140 | x_padded = torch.cat([zero_pad, x], dim=-1)
141 | # [B, H, S, 2S]
142 | x_padded = x_padded.view(B, H, 2 * S, S)
143 | # [B, H, 2S, S]
144 | x = x_padded[:, :, 1:].view_as(x)[:, :, :, :S]
145 | # only keep the positions from 0 to S
146 | # [B, H, 2S-1, S] [B, H, S, 2S - 1] [B, H, S, S]
147 | return x
148 |
149 | def forward(
150 | self, query, key, value, pos_emb, mask=None, causal=False):
151 | """Compute self-attention with relative positional embedding.
152 | Args:
153 | query (Tensor): [B, S, D].
154 | key (Tensor): [B, S, D].
155 | value (Tensor): [B, S, D].
156 | pos_emb (Tensor): [1/B, 2S-1, D]. Positional embedding.
157 | mask (BoolTensor): [B, S], True for masked.
158 | causal (bool): True for applying causal mask.
159 | Returns:
160 | output (Tensor): [B, S, D].
161 | """
162 | # Splitting Q, K, V:
163 | q, k, v = self.forward_qkv(query, key, value)
164 | # [B, H, S, C], [B, H, S, C], [B, H, S, C]
165 |
166 | # Adding per head & channel biases to the query vectors:
167 | q_u = q + self.u.unsqueeze(1)
168 | q_v = q + self.v.unsqueeze(1)
169 | # [B, H, S, C]
170 |
171 | # Splitting relative positional coding:
172 | n_batch_pos = pos_emb.size(0) # [1/B, 2S-1, D]
173 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.c)
174 | # [1/B, 2S-1, H, C]
175 | p = p.transpose(1, 2) # [1/B, H, 2S-1, C].
176 |
177 | # Compute query, key similarity:
178 | matrix_ac = torch.matmul(q_u, k.transpose(-2, -1))
179 | # [B, H, S, C] x [B, H, C, S] -> [B, H, S, S]
180 |
181 | matrix_bd = torch.matmul(q_v, p.transpose(-2, -1))
182 | # [B, H, S, C] x [1/B, H, C, 2S-1] -> [B, H, S, 2S-1]
183 | matrix_bd = self.rel_shift(matrix_bd)
184 |
185 | scores = (matrix_ac + matrix_bd) / math.sqrt(self.c)
186 | # [B, H, S, S]
187 |
188 | return self.forward_attention(v, scores, mask, causal) # [B, S, D]
189 |
190 |
191 | class ConditionalBiasScale(nn.Module):
192 | def __init__(self, channels: int, cond_channels: int):
193 | super().__init__()
194 | self.scale_transform = nn.Linear(
195 | cond_channels, channels, bias=True
196 | )
197 | self.bias_transform = nn.Linear(
198 | cond_channels, channels, bias=True
199 | )
200 | self.init_parameters()
201 |
202 | def init_parameters(self):
203 | torch.nn.init.constant_(self.scale_transform.weight, 0.0)
204 | torch.nn.init.constant_(self.scale_transform.bias, 1.0)
205 | torch.nn.init.constant_(self.bias_transform.weight, 0.0)
206 | torch.nn.init.constant_(self.bias_transform.bias, 0.0)
207 |
208 | def forward(self, x: Tensor, cond: Tensor) -> Tensor:
209 | """Applying conditional bias and scale.
210 | Args:
211 | x (Tensor): [..., channels].
212 | cond (Tensor): [..., cond_channels].
213 | Returns:
214 | y (Tensor): [..., channels].
215 | """
216 | a = self.scale_transform.forward(cond)
217 | b = self.bias_transform.forward(cond)
218 | return x * a + b
219 |
220 |
221 | class FeedForwardModule(torch.nn.Module):
222 | """Positionwise feed forward layer used in conformer"""
223 |
224 | def __init__(
225 | self, d_in: int, d_hidden: int,
226 | dropout: float, bias: bool = True, d_cond: int = 0
227 | ):
228 | """
229 | Args:
230 | d_in (int): Input feature dimension.
231 | d_hidden (int): Hidden unit dimension.
232 | dropout (float): dropout value for first Linear Layer.
233 | bias (bool): If linear layers should have bias.
234 | d_cond (int, optional): The channels of conditional tensor.
235 | """
236 | super(FeedForwardModule, self).__init__()
237 | self.layer_norm = torch.nn.LayerNorm(d_in)
238 |
239 | if d_cond > 0:
240 | self.cond_layer = ConditionalBiasScale(d_in, d_cond)
241 |
242 | self.w_1 = torch.nn.Linear(d_in, d_hidden, bias=bias)
243 | self.w_2 = torch.nn.Linear(d_hidden, d_in, bias=bias)
244 | self.dropout = torch.nn.Dropout(dropout)
245 |
246 | def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
247 | """
248 |         Args:
249 |             x (Tensor): [..., D].
250 |             cond (Tensor, optional): [..., D_cond].
251 |         Returns:
252 |             y (Tensor): [..., D].
253 | """
254 | x = self.layer_norm(x)
255 |
256 | if cond is not None:
257 | x = self.cond_layer.forward(x, cond)
258 |
259 | x = self.w_1(x)
260 | x = silu(x)
261 | x = self.dropout(x)
262 | x = self.w_2(x)
263 | return self.dropout(x)
264 |
265 |
266 | class RelPositionalEncoding(nn.Module):
267 | """Relative positional encoding cache.
268 |
269 | Args:
270 |         max_len: Default maximum input length (the cache grows lazily if longer inputs are seen).
271 |         d_model: Embedding dimension.
272 |
273 | """
274 |
275 | def __init__(self, max_len: int, d_model: int):
276 | super().__init__()
277 | self.d_model = d_model
278 | self.cached_code = None
279 | self.l = 0
280 | self.gen_code(torch.tensor(0.0).expand(1, max_len))
281 |
282 | def gen_code(self, x: Tensor):
283 | """Generate positional encoding with a reference tensor x.
284 | Args:
285 | x (Tensor): [B, L, ...], we extract the device, length, and dtype from it.
286 | Effects:
287 | self.cached_code (Tensor): [1, >=(2L-1), D].
288 | """
289 | l = x.size(1)
290 | if self.l >= l:
291 | if self.cached_code.dtype != x.dtype or self.cached_code.device != x.device:
292 | self.cached_code = self.cached_code.to(dtype=x.dtype, device=x.device)
293 | return
294 |         # Suppose `i` means the position of the query vector and `j` means the
295 |         # position of the key vector. We use positive relative positions when keys
296 |         # are to the left (i > j) and negative relative positions otherwise (i < j).
297 |         # NOTE: the code below is a reconstruction following the standard
298 |         # Transformer-XL / ESPnet sinusoidal formulation and may differ in detail.
299 |         position = torch.arange(0, l, dtype=torch.float32).unsqueeze(1)
300 |         div_term = torch.exp(
301 |             torch.arange(0, self.d_model, 2, dtype=torch.float32)
302 |             * -(math.log(10000.0) / self.d_model)
303 |         )
304 |         pe_positive = torch.zeros(l, self.d_model)
305 |         pe_negative = torch.zeros(l, self.d_model)
306 |         pe_positive[:, 0::2] = torch.sin(position * div_term)
307 |         pe_positive[:, 1::2] = torch.cos(position * div_term)
308 |         pe_negative[:, 0::2] = torch.sin(-position * div_term)
309 |         pe_negative[:, 1::2] = torch.cos(-position * div_term)
310 |         # Flip the positive part so the largest positive offset comes first, then drop
311 |         # the duplicated zero-offset row from the negative part before concatenation.
312 |         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
313 |         pe_negative = pe_negative[1:].unsqueeze(0)
314 |         # cached_code: [1, 2L-1, D]
315 |         self.cached_code = torch.cat([pe_positive, pe_negative], dim=1).to(dtype=x.dtype, device=x.device)
316 |         self.l = l
317 |
318 |     def forward(self, x: Tensor) -> Tensor:
319 | """Get positional encoding of appropriate shape given a reference Tensor.
320 | Args:
321 | x (Tensor): [B, L, ...].
322 | Returns:
323 | y (Tensor): [1, 2L-1, D].
324 | """
325 | self.gen_code(x)
326 | l = x.size(1)
327 | pos_emb = self.cached_code[
328 | :, self.l - l: self.l + l - 1,
329 | ]
330 | return pos_emb
331 |
332 |
333 | class ConformerBlock(torch.nn.Module):
334 | """Conformer block based on https://arxiv.org/abs/2005.08100."""
335 |
336 | def __init__(
337 | self, d: int, d_hidden: int,
338 | attention_heads: int, dropout: float,
339 | depthwise_conv_kernel_size: int = 7,
340 | causal: bool = False, d_cond: int = 0
341 | ):
342 | """
343 | Args:
344 | d (int): Block input output channel number.
345 | d_hidden (int): FFN layer dimension.
346 | attention_heads (int): Number of attention heads.
347 | dropout (float): dropout value.
348 | depthwise_conv_kernel_size (int): Size of kernel in depthwise conv.
349 | d_cond (int, optional): The channels of conditional tensor.
350 | """
351 | super(ConformerBlock, self).__init__()
352 | self.causal = causal
353 | self.ffn1 = FeedForwardModule(
354 | d, d_hidden, dropout, bias=True, d_cond=d_cond
355 | )
356 |
357 | self.self_attn_layer_norm = torch.nn.LayerNorm(d)
358 |
359 | if d_cond > 0:
360 | self.cond_layer = ConditionalBiasScale(d, d_cond)
361 |
362 | self.self_attn = RelPositionMultiHeadedAttention(
363 | d, attention_heads, dropout=dropout
364 | )
365 | self.self_attn_dropout = torch.nn.Dropout(dropout)
366 |
367 | self.conv_module = ConvolutionModule(
368 | d_in=d, d_hidden=d,
369 | depthwise_kernel_size=depthwise_conv_kernel_size,
370 | dropout=dropout, d_cond=d_cond
371 | )
372 |
373 | self.ffn2 = FeedForwardModule(
374 | d, d_hidden, dropout, bias=True, d_cond=d_cond
375 | )
376 |
377 | self.final_layer_norm = torch.nn.LayerNorm(d)
378 |
379 | def forward(
380 | self, x: Tensor, mask: BoolTensor, pos_emb: Tensor,
381 | cond: Optional[Tensor] = None
382 | ) -> Tensor:
383 | """
384 | Args:
385 | x (Tensor): [B, T, D_in].
386 | mask (BoolTensor): [B, T], True for masked.
387 | pos_emb (Tensor): [1 or B, 2T-1, D].
388 | cond (Tensor, optional): [B, ?, D_cond].
389 | Returns:
390 | y (Tensor): [B, T, D_in].
391 | """
392 | y = x
393 |
394 | x = self.ffn1(x) * 0.5 + y
395 | y = x
396 | # [B, T, D_in]
397 |
398 | x = self.self_attn_layer_norm(x)
399 |
400 | if cond is not None:
401 | x = self.cond_layer.forward(x, cond)
402 |
403 | x = self.self_attn.forward(
404 | query=x, key=x, value=x,
405 | pos_emb=pos_emb,
406 | mask=mask, causal=self.causal
407 | )
408 | x = self.self_attn_dropout(x) + y
409 | y = x
410 | # [B, T, D_in]
411 |
412 | x = self.conv_module.forward(x, mask) + y
413 | y = x
414 | # [B, T, D_in]
415 |
416 | x = self.ffn2(x) * 0.5 + y
417 |
418 | x = self.final_layer_norm(x)
419 |
420 |         x = x.masked_fill(mask.unsqueeze(-1), 0.0)
421 |
422 | return x
423 |
424 |
425 | class ConvolutionModule(torch.nn.Module):
426 | """Convolution Block inside a Conformer Block."""
427 |
428 | def __init__(
429 | self, d_in: int, d_hidden: int,
430 | depthwise_kernel_size: int,
431 | dropout: float, bias: bool = False,
432 | causal: bool = False, d_cond: int = 0
433 | ):
434 | """
435 | Args:
436 | d_in (int): Embedding dimension.
437 | d_hidden (int): Number of channels in depthwise conv layers.
438 | depthwise_kernel_size (int): Depthwise conv layer kernel size.
439 | dropout (float): dropout value.
440 | bias (bool): If bias should be added to conv layers.
441 |             d_cond (int, optional): Channels of the conditioning tensor (0 disables conditional bias/scale).
442 | """
443 | super(ConvolutionModule, self).__init__()
444 | assert (depthwise_kernel_size - 1) % 2 == 0, "kernel_size should be odd"
445 | self.causal = causal
446 | self.causal_padding = (depthwise_kernel_size - 1, 0)
447 | self.layer_norm = torch.nn.LayerNorm(d_in)
448 |
449 | # Optional conditional LayerNorm:
450 | self.d_cond = d_cond
451 | if d_cond > 0:
452 | self.cond_layer = ConditionalBiasScale(d_in, d_cond)
453 |
454 | self.pointwise_conv1 = torch.nn.Conv1d(
455 | d_in, 2 * d_hidden,
456 | kernel_size=1,
457 | stride=1, padding=0,
458 | bias=bias
459 | )
460 | self.glu = torch.nn.GLU(dim=1)
461 | self.depthwise_conv = torch.nn.Conv1d(
462 | d_hidden, d_hidden,
463 | kernel_size=depthwise_kernel_size,
464 | stride=1,
465 | padding=(depthwise_kernel_size - 1) // 2 if not causal else 0,
466 | groups=d_hidden, bias=bias
467 | )
468 | self.pointwise_conv2 = torch.nn.Conv1d(
469 | d_hidden, d_in,
470 | kernel_size=1,
471 | stride=1, padding=0,
472 | bias=bias,
473 | )
474 | self.dropout = torch.nn.Dropout(dropout)
475 |
476 | def forward(self, x: Tensor, mask: BoolTensor, cond: Optional[Tensor] = None) -> Tensor:
477 | """
478 | Args:
479 | x (Tensor): [B, T, D_in].
480 | mask (BoolTensor): [B, T], True for masked.
481 | cond (Tensor): [B, T, D_cond].
482 | Returns:
483 | y (Tensor): [B, T, D_in].
484 | """
485 | x = self.layer_norm(x)
486 |
487 | if cond is not None:
488 | x = self.cond_layer.forward(x, cond)
489 |
490 | x = x.transpose(-1, -2) # [B, D_in, T]
491 |
492 | x = self.pointwise_conv1(x) # [B, 2C, T]
493 | x = self.glu(x) # [B, C, T]
494 |
495 | # Take care of masking the input tensor:
496 | if mask is not None:
497 | x = x.masked_fill(mask.unsqueeze(1), 0.0)
498 |
499 | # 1D Depthwise Conv
500 | if self.causal: # Causal padding
501 | x = pad(x, self.causal_padding)
502 | x = self.depthwise_conv(x)
503 | # FIXME: BatchNorm should not be used in variable length training.
504 | x = silu(x) # [B, C, T]
505 |
506 | if mask is not None:
507 | x = x.masked_fill(mask.unsqueeze(1), 0.0)
508 |
509 | x = self.pointwise_conv2(x)
510 | x = self.dropout(x)
511 | return x.transpose(-1, -2) # [B, T, D_in]
512 |
513 |
514 | class Conformer(torch.nn.Module):
515 | def __init__(
516 | self,
517 | d: int,
518 | d_hidden: int,
519 | n_heads: int,
520 | n_layers: int,
521 | dropout: float,
522 | depthwise_conv_kernel_size: int,
523 | causal: bool = False,
524 | d_cond: int = 0
525 | ):
526 | super().__init__()
527 | self.pos_encoding = RelPositionalEncoding(1024, d)
528 | self.causal = causal
529 |
530 | self.blocks = torch.nn.ModuleList(
531 | [
532 | ConformerBlock(
533 | d=d,
534 | d_hidden=d_hidden,
535 | attention_heads=n_heads,
536 | dropout=dropout,
537 | depthwise_conv_kernel_size=depthwise_conv_kernel_size,
538 | causal=causal,
539 | d_cond=d_cond
540 | )
541 | for _ in range(n_layers)
542 | ]
543 | ) # type: Iterable[ConformerBlock]
544 |
545 | def forward(
546 | self, x: Tensor, mask: BoolTensor, cond: Tensor = None
547 | ) -> Tensor:
548 | """Conformer forwarding.
549 | Args:
550 | x (Tensor): [B, T, D].
551 | mask (BoolTensor): [B, T], with True for masked.
552 | cond (Tensor, optional): [B, T, D_cond].
553 | Returns:
554 | y (Tensor): [B, T, D]
555 | """
556 | pos_emb = self.pos_encoding(x) # [1, 2T-1, D]
557 |
558 | for block in self.blocks:
559 | x = block.forward(x, mask, pos_emb, cond)
560 |
561 | return x
562 |
563 |
564 | class CNNBlock(nn.Module):
565 | def __init__(self, in_dim, out_dim, dropout, cond_dim, kernel_size, stride):
566 | super(CNNBlock, self).__init__()
567 | self.layers = nn.Sequential(
568 | nn.Conv1d(in_dim, out_dim, kernel_size, stride),
569 | nn.ReLU(),
570 | nn.BatchNorm1d(out_dim,),
571 | nn.Dropout(p=dropout)
572 | )
573 |
574 | def forward(self, inp):
575 | out = self.layers(inp)
576 | return out
577 |
578 |
579 | class CNNClassifier(nn.Module):
580 | def __init__(self, in_dim, d_decoder, decoder_dropout, cond_dim):
581 | super(CNNClassifier, self).__init__()
582 | self.cnn = nn.Sequential(
583 | CNNBlock(in_dim, d_decoder, decoder_dropout, cond_dim, 8, 4),
584 | CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 8, 4),
585 | CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2),
586 | CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2),
587 | ) # receptive field is 180, frame shift is 64
588 | self.cond_layer = nn.Sequential(
589 | nn.Linear(cond_dim, in_dim),
590 | nn.LeakyReLU(),
591 | nn.Linear(in_dim, in_dim)
592 | )
593 |
594 | def forward(self, inp, mask, cond):
595 | inp = inp.transpose(-1, -2)
596 | cond = cond.transpose(-1, -2)
597 | inp.masked_fill_(mask.unsqueeze(1), 0.0)
598 | cond = self.cond_layer(cond.transpose(-1, -2)).transpose(-1, -2)
599 | cond.masked_fill_(mask.unsqueeze(1), 0.0)
600 | inp = inp + cond
601 | return self.cnn(inp)
602 |
603 |
604 | class CNNClassifierWithTime(nn.Module):
605 | def __init__(self, in_dim, d_decoder, decoder_dropout, cond_dim, time_emb_dim=512):
606 | super(CNNClassifierWithTime, self).__init__()
607 | self.cnn = nn.Sequential(
608 | CNNBlock(in_dim, d_decoder, decoder_dropout, cond_dim, 8, 4),
609 | CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 8, 4),
610 | CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2),
611 | CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2),
612 | ) # receptive field is 180, frame shift is 64
613 | self.cond_layer = nn.Sequential(
614 | nn.Linear(cond_dim, in_dim),
615 | nn.LeakyReLU(),
616 | nn.Linear(in_dim, in_dim)
617 | )
618 | self.time_emb = SinusoidalPosEmb(time_emb_dim)
619 | self.time_layer = nn.Sequential(
620 | nn.Linear(time_emb_dim, in_dim),
621 | nn.LeakyReLU(),
622 | nn.Linear(in_dim, in_dim)
623 | )
624 |
625 | def forward(self, inp, mask, cond, t):
626 |         time_emb = self.time_emb(t) # [B, time_emb_dim]
627 | time_emb = self.time_layer(time_emb.unsqueeze(1)).transpose(-1, -2)
628 | inp = inp.transpose(-1, -2)
629 | cond = cond.transpose(-1, -2)
630 | inp.masked_fill_(mask.unsqueeze(1), 0.0)
631 | cond = self.cond_layer(cond.transpose(-1, -2)).transpose(-1, -2)
632 | cond.masked_fill_(mask.unsqueeze(1), 0.0)
633 | inp = inp + cond + time_emb
634 | return self.cnn(inp)
635 |
636 |
637 | class SpecClassifier(nn.Module):
638 | def __init__(self, in_dim, d_decoder, h_decoder,
639 | l_decoder, decoder_dropout,
640 | k_decoder, n_class, cond_dim, model_type='conformer'):
641 | super(SpecClassifier, self).__init__()
642 | self.model_type = model_type
643 | self.prenet = nn.Sequential(
644 | nn.Linear(in_features=in_dim, out_features=d_decoder)
645 | )
646 | if model_type == 'conformer':
647 | self.conformer = Conformer(d=d_decoder, d_hidden=d_decoder, n_heads=h_decoder,
648 | n_layers=l_decoder, dropout=decoder_dropout,
649 | depthwise_conv_kernel_size=k_decoder, d_cond=cond_dim)
650 | elif model_type == 'CNN':
651 | self.conformer = CNNClassifier(in_dim=d_decoder, d_decoder=d_decoder,
652 | decoder_dropout=decoder_dropout, cond_dim=cond_dim)
653 | elif model_type == 'CNN-with-time':
654 | self.conformer = CNNClassifierWithTime(in_dim=d_decoder, d_decoder=d_decoder,
655 | decoder_dropout=decoder_dropout, cond_dim=cond_dim, time_emb_dim=256)
656 | self.classifier = nn.Linear(d_decoder, n_class)
657 |
658 | def forward(self, noisy_mel, condition, mask, **kwargs):
659 | """
660 | Args:
661 | noisy_mel: [B, T, D]
662 | condition: [B, T, D]
663 | mask: [B, T] with True for un-masked (real-values)
664 |
665 | Returns:
666 | classification logits (un-softmaxed)
667 | """
668 | # print(noisy_mel.shape)
669 | noisy_mel = noisy_mel.masked_fill(~mask.unsqueeze(-1), 0.0)
670 |
671 | # print(self.prenet, noisy_mel.shape)
672 | hiddens = self.prenet(noisy_mel)
673 |
674 | if self.model_type == 'CNN-with-time':
675 | hiddens = self.conformer.forward(hiddens, ~mask, condition, kwargs['t'])
676 | else:
677 | hiddens = self.conformer.forward(hiddens, ~mask, condition) # [B, T, D]
678 |
679 | if self.model_type == 'conformer':
680 | averaged_hiddens = torch.mean(hiddens, dim=1) # [B, D]
681 | logits = self.classifier(averaged_hiddens)
682 | return logits
683 | elif self.model_type == 'CNN' or self.model_type == 'CNN-with-time':
684 | hiddens = hiddens.transpose(-1, -2)
685 | return self.classifier(hiddens) # [B, T', C]
686 |
687 | @property
688 | def nparams(self):
689 | return sum([p.numel() for p in self.parameters()])
690 |
691 |
--------------------------------------------------------------------------------
/model/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from einops import rearrange
4 |
5 | from model.base import BaseModule
6 |
7 |
8 | class Mish(BaseModule):
9 | def forward(self, x):
10 | return x * torch.tanh(torch.nn.functional.softplus(x))
11 |
12 |
13 | class Upsample(BaseModule):
14 | def __init__(self, dim):
15 | super(Upsample, self).__init__()
16 | self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
17 |
18 | def forward(self, x):
19 | return self.conv(x)
20 |
21 |
22 | class Downsample(BaseModule):
23 | def __init__(self, dim):
24 | super(Downsample, self).__init__()
25 | self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) # kernel=3, stride=2, padding=1.
26 |
27 | def forward(self, x):
28 | return self.conv(x)
29 |
30 |
31 | class Rezero(BaseModule):
32 | def __init__(self, fn):
33 | super(Rezero, self).__init__()
34 | self.fn = fn
35 | self.g = torch.nn.Parameter(torch.zeros(1))
36 |
37 | def forward(self, x):
38 | return self.fn(x) * self.g
39 |
40 |
41 | class Block(BaseModule):
42 | def __init__(self, dim, dim_out, groups=8):
43 | super(Block, self).__init__()
44 | self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
45 | padding=1), torch.nn.GroupNorm(
46 | groups, dim_out), Mish())
47 |
48 | def forward(self, x, mask):
49 | output = self.block(x * mask)
50 | return output * mask
51 |
52 |
53 | class ResnetBlock(BaseModule):
54 | def __init__(self, dim, dim_out, time_emb_dim, groups=8):
55 | super(ResnetBlock, self).__init__()
56 | self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
57 | dim_out))
58 |
59 | self.block1 = Block(dim, dim_out, groups=groups)
60 | self.block2 = Block(dim_out, dim_out, groups=groups)
61 | if dim != dim_out:
62 | self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
63 | else:
64 | self.res_conv = torch.nn.Identity()
65 |
66 | def forward(self, x, mask, time_emb):
67 | h = self.block1(x, mask)
68 | h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
69 | h = self.block2(h, mask)
70 | output = h + self.res_conv(x * mask)
71 | return output
72 |
73 |
74 | class LinearAttention(BaseModule):
75 | def __init__(self, dim, heads=4, dim_head=32):
76 | super(LinearAttention, self).__init__()
77 | self.heads = heads
78 | hidden_dim = dim_head * heads
79 | self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) # NOTE: 1x1 conv
80 | self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
81 |
82 | def forward(self, x):
83 | b, c, h, w = x.shape
84 | qkv = self.to_qkv(x)
85 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
86 | k = k.softmax(dim=-1)
87 | context = torch.einsum('bhdn,bhen->bhde', k, v)
88 | out = torch.einsum('bhde,bhdn->bhen', context, q)
89 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
90 | heads=self.heads, h=h, w=w)
91 | return self.to_out(out)
92 |
93 |
94 | class Residual(BaseModule):
95 | def __init__(self, fn):
96 | super(Residual, self).__init__()
97 | self.fn = fn
98 |
99 | def forward(self, x, *args, **kwargs):
100 | output = self.fn(x, *args, **kwargs) + x
101 | return output
102 |
103 |
104 | class SinusoidalPosEmb(BaseModule):
105 | def __init__(self, dim):
106 | super(SinusoidalPosEmb, self).__init__()
107 | self.dim = dim
108 |
109 | def forward(self, x, scale=1000):
110 | device = x.device
111 | half_dim = self.dim // 2
112 | emb = math.log(10000) / (half_dim - 1)
113 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
114 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
115 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
116 | return emb
117 |
118 |
119 | class GradLogPEstimator2d(BaseModule):
120 | def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, spk_emb_dim=64, n_feats=80, pe_scale=1000):
121 | super(GradLogPEstimator2d, self).__init__()
122 | self.dim = dim
123 | self.dim_mults = dim_mults
124 | self.groups = groups
125 | self.spk_emb_dim = spk_emb_dim
126 | self.pe_scale = pe_scale
127 |
128 | self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
129 | torch.nn.Linear(spk_emb_dim * 4, n_feats))
130 | self.time_pos_emb = SinusoidalPosEmb(dim)
131 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
132 | torch.nn.Linear(dim * 4, dim))
133 |
134 | dims = [3, *map(lambda m: dim * m, dim_mults)]
135 | in_out = list(zip(dims[:-1], dims[1:]))
136 | self.downs = torch.nn.ModuleList([])
137 | self.ups = torch.nn.ModuleList([])
138 | num_resolutions = len(in_out)
139 |
140 | for ind, (dim_in, dim_out) in enumerate(in_out):
141 | is_last = ind >= (num_resolutions - 1)
142 | self.downs.append(torch.nn.ModuleList([
143 | ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
144 | ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
145 | Residual(Rezero(LinearAttention(dim_out))),
146 | Downsample(dim_out) if not is_last else torch.nn.Identity()]))
147 |
148 | mid_dim = dims[-1]
149 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
150 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
151 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
152 |
153 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
154 | self.ups.append(torch.nn.ModuleList([
155 | ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
156 | ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
157 | Residual(Rezero(LinearAttention(dim_in))),
158 | Upsample(dim_in)]))
159 | self.final_block = Block(dim, dim)
160 | self.final_conv = torch.nn.Conv2d(dim, 1, 1)
161 |
162 | def forward(self, x, mask, mu, t, spk=None):
163 | # x, mu: [B, 80, L], t: [B, ], mask: [B, 1, L]
164 |         if not isinstance(spk, type(None)):  # spk (the merged speaker/emotion embedding) must be provided; `s` is used below
165 |             s = self.spk_mlp(spk)
166 |
167 | t = self.time_pos_emb(t, scale=self.pe_scale)
168 | t = self.mlp(t) # [B, 64]
169 |
170 | s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
171 | x = torch.stack([mu, x, s], 1) # [B, 3, 80, L]
172 | mask = mask.unsqueeze(1) # [B, 1, 1, L]
173 |
174 | hiddens = []
175 | masks = [mask]
176 | for resnet1, resnet2, attn, downsample in self.downs:
177 | mask_down = masks[-1]
178 | x = resnet1(x, mask_down, t) # [B, 64, 80, L]
179 | x = resnet2(x, mask_down, t)
180 | x = attn(x)
181 | hiddens.append(x)
182 | x = downsample(x * mask_down)
183 | masks.append(mask_down[:, :, :, ::2])
184 |
185 | masks = masks[:-1]
186 | mask_mid = masks[-1]
187 | x = self.mid_block1(x, mask_mid, t)
188 | x = self.mid_attn(x)
189 | x = self.mid_block2(x, mask_mid, t)
190 |
191 | for resnet1, resnet2, attn, upsample in self.ups:
192 | mask_up = masks.pop()
193 | x = torch.cat((x, hiddens.pop()), dim=1)
194 | x = resnet1(x, mask_up, t)
195 | x = resnet2(x, mask_up, t)
196 | x = attn(x)
197 | x = upsample(x * mask_up)
198 |
199 | x = self.final_block(x, mask)
200 | output = self.final_conv(x * mask)
201 |
202 | return (output * mask).squeeze(1)
203 |
204 |
205 | def get_noise(t, beta_init, beta_term, cumulative=False):
206 | if cumulative:
207 | noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
208 | else:
209 | noise = beta_init + (beta_term - beta_init)*t
210 | return noise
211 |
212 |
213 | class Diffusion(BaseModule):
214 | def __init__(self, n_feats, dim, spk_emb_dim=64,
215 | beta_min=0.05, beta_max=20, pe_scale=1000):
216 | super(Diffusion, self).__init__()
217 | self.n_feats = n_feats
218 | self.dim = dim
219 | # self.n_spks = n_spks
220 | self.spk_emb_dim = spk_emb_dim
221 | self.beta_min = beta_min
222 | self.beta_max = beta_max
223 | self.pe_scale = pe_scale
224 |
225 | self.estimator = GradLogPEstimator2d(dim,
226 | spk_emb_dim=spk_emb_dim,
227 | pe_scale=pe_scale,
228 | n_feats=n_feats)
229 |
230 | def forward_diffusion(self, x0, mask, mu, t):
231 | time = t.unsqueeze(-1).unsqueeze(-1)
232 | cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True) # it is actually the integral of beta
233 | mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
234 | variance = 1.0 - torch.exp(-cum_noise)
235 | z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device,
236 | requires_grad=False)
237 | xt = mean + z * torch.sqrt(variance)
238 | return xt * mask, z * mask
239 |
240 | @torch.no_grad()
241 | def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None,
242 | use_classifier_free=False,
243 | classifier_free_guidance=3.0,
244 |                           dummy_spk=None):  # the emotion embedding is expected to be already merged into spk
245 |
246 | # looks like a plain Euler-Maruyama method
247 | h = 1.0 / n_timesteps
248 | xt = z * mask
249 | for i in range(n_timesteps):
250 | t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
251 | device=z.device)
252 | time = t.unsqueeze(-1).unsqueeze(-1)
253 | noise_t = get_noise(time, self.beta_min, self.beta_max,
254 | cumulative=False)
255 |
256 | if not use_classifier_free:
257 | if stoc: # adds stochastic term
258 | dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
259 | dxt_det = dxt_det * noise_t * h
260 | dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
261 | requires_grad=False)
262 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
263 | dxt = dxt_det + dxt_stoc
264 | else:
265 | dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
266 | dxt = dxt * noise_t * h
267 | xt = (xt - dxt) * mask
268 | else:
269 | if stoc: # adds stochastic term
270 | score_estimate = (1 + classifier_free_guidance) * self.estimator(xt, mask, mu, t, spk) \
271 | - classifier_free_guidance * self.estimator(xt, mask, mu, t, dummy_spk)
272 | dxt_det = 0.5 * (mu - xt) - score_estimate
273 | dxt_det = dxt_det * noise_t * h
274 | dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
275 | requires_grad=False)
276 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
277 | dxt = dxt_det + dxt_stoc
278 | else:
279 | score_estimate = (1 + classifier_free_guidance) * self.estimator(xt, mask, mu, t, spk) \
280 | - classifier_free_guidance * self.estimator(xt, mask, mu, t, dummy_spk)
281 | dxt = 0.5 * (mu - xt - score_estimate)
282 | dxt = dxt * noise_t * h
283 | xt = (xt - dxt) * mask
284 | return xt
285 |
286 | @torch.no_grad()
287 | def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None,
288 | use_classifier_free=False,
289 | classifier_free_guidance=3.0,
290 | dummy_spk=None
291 | ):
292 | return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk, use_classifier_free, classifier_free_guidance, dummy_spk)
293 |
294 | def loss_t(self, x0, mask, mu, t, spk=None):
295 | xt, z = self.forward_diffusion(x0, mask, mu, t) # z is sampled from N(0, I)
296 | time = t.unsqueeze(-1).unsqueeze(-1)
297 | cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
298 | noise_estimation = self.estimator(xt, mask, mu, t, spk)
299 |         noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))  # weight by lambda, chosen to be the forward-process variance
300 |         # strictly speaking, this multiplies by sqrt(lambda) rather than lambda itself
301 |         # NOTE: folding lambda into the L2 norm below avoids having to divide z by its standard deviation
302 | loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feats)
303 | return loss, xt
304 |
305 | def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
306 | t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
307 | requires_grad=False)
308 | t = torch.clamp(t, offset, 1.0 - offset)
309 | return self.loss_t(x0, mask, mu, t, spk)
310 |
311 | def classifier_decode(self, z, mask, mu, n_timesteps, stoc=False, spk=None, classifier_func=None, guidance=1.0, control_emo=None, classifier_type="conformer"):
312 | # control_emo should be [B, ] tensor
313 | h = 1.0 / n_timesteps
314 | xt = z * mask
315 | for i in range(n_timesteps):
316 | t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype,
317 | device=z.device)
318 | time = t.unsqueeze(-1).unsqueeze(-1)
319 | noise_t = get_noise(time, self.beta_min, self.beta_max,
320 | cumulative=False)
321 | # =========== classifier part ==============
322 | xt = xt.detach()
323 | xt.requires_grad_(True)
324 | if classifier_type == 'CNN-with-time':
325 | logits = classifier_func(xt.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1), t=t)
326 | else:
327 | logits = classifier_func(xt.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1))
328 |
329 | if classifier_type == 'conformer': # [B, C]
330 | probs = torch.log_softmax(logits, dim=-1) # [B, C]
331 | elif classifier_type == 'CNN' or classifier_type == 'CNN-with-time' :
332 | probs_every_place = torch.softmax(logits, dim=-1) # [B, T', C]
333 | probs_mean = torch.mean(probs_every_place, dim=1) # [B, C]
334 | probs = torch.log(probs_mean)
335 | else:
336 | raise NotImplementedError
337 |
338 | control_emo_probs = probs[torch.arange(len(control_emo)).to(control_emo.device), control_emo]
339 | control_emo_probs.sum().backward(retain_graph=True)
340 |             # NOTE: summing gives every batch element equal weight in the gradient.
341 | xt_grad = xt.grad
342 | # ==========================================
343 |
344 | if stoc: # adds stochastic term
345 | dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk) - guidance * xt_grad
346 | dxt_det = dxt_det * noise_t * h
347 | dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
348 | requires_grad=False)
349 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
350 | dxt = dxt_det + dxt_stoc
351 | else:
352 | dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk) - guidance * xt_grad)
353 | dxt = dxt * noise_t * h
354 | xt = (xt - dxt) * mask
355 | return xt
356 |
357 | def classifier_decode_DPS(self, z, mask, mu, n_timesteps, stoc=False, spk=None, classifier_func=None, guidance=1.0, control_emo=None, classifier_type="conformer"):
358 | # control_emo should be [B, ] tensor
359 | h = 1.0 / n_timesteps
360 | xt = z * mask
361 | for i in range(n_timesteps):
362 | t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
363 | time = t.unsqueeze(-1).unsqueeze(-1)
364 | noise_t = get_noise(time, self.beta_min, self.beta_max, cumulative=False)
365 | beta_integral_t = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
366 |             bar_alpha_t = torch.exp(-beta_integral_t)  # torch.exp: beta_integral_t is a [B, 1, 1] tensor
367 |
368 | # =========== classifier part ==============
369 | xt = xt.detach()
370 | xt.requires_grad_(True)
371 | score_estimate = self.estimator(xt, mask, mu, t, spk)
372 |             x0_hat = (xt + (1 - bar_alpha_t) * score_estimate) / torch.sqrt(bar_alpha_t)
373 |
374 | if classifier_type == 'CNN-with-time':
375 | raise NotImplementedError
376 | else:
377 | logits = classifier_func(x0_hat.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1))
378 | if classifier_type == 'conformer': # [B, C]
379 | probs = torch.log_softmax(logits, dim=-1) # [B, C]
380 | elif classifier_type == 'CNN':
381 | probs_every_place = torch.softmax(logits, dim=-1) # [B, T', C]
382 | probs_mean = torch.mean(probs_every_place, dim=1) # [B, C]
383 |
384 | probs_mean = probs_mean + 10E-10
385 |                     # NOTE: during the first few reverse steps x0_hat can be very large, so the classifier logits can take extreme values;
386 |                     # the small epsilon keeps the log well defined.
387 |
388 | probs = torch.log(probs_mean)
389 | else:
390 | raise NotImplementedError
391 |
392 | control_emo_probs = probs[torch.arange(len(control_emo)).to(control_emo.device), control_emo]
393 | control_emo_probs.sum().backward(retain_graph=True)
394 |             # NOTE: summing gives every batch element equal weight in the gradient.
395 | xt_grad = xt.grad
396 | # ==========================================
397 |
398 | if stoc: # adds stochastic term
399 | dxt_det = 0.5 * (mu - xt) - score_estimate - guidance * xt_grad
400 | dxt_det = dxt_det * noise_t * h
401 | dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device, requires_grad=False)
402 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
403 | dxt = dxt_det + dxt_stoc
404 | else:
405 | dxt = 0.5 * (mu - xt - score_estimate - guidance * xt_grad)
406 | dxt = dxt * noise_t * h
407 | xt = (xt - dxt) * mask
408 | return xt
409 |
410 | def classifier_decode_mixture(self, z, mask, mu, n_timesteps, stoc=False, spk=None, classifier_func=None, guidance=1.0, control_emo1=None,control_emo2=None, emo1_weight=None, classifier_type="conformer"):
411 | # control_emo should be [B, ] tensor
412 | h = 1.0 / n_timesteps
413 | xt = z * mask
414 | for i in range(n_timesteps):
415 | t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype,
416 | device=z.device)
417 | time = t.unsqueeze(-1).unsqueeze(-1)
418 | noise_t = get_noise(time, self.beta_min, self.beta_max,
419 | cumulative=False)
420 | # =========== classifier part ==============
421 | xt = xt.detach()
422 | xt.requires_grad_(True)
423 | if classifier_type == 'CNN-with-time':
424 | logits = classifier_func(xt.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1), t=t)
425 | else:
426 | logits = classifier_func(xt.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1))
427 |
428 | if classifier_type == 'conformer': # [B, C]
429 | probs = torch.log_softmax(logits, dim=-1) # [B, C]
430 | elif classifier_type == 'CNN' or classifier_type == 'CNN-with-time' :
431 | probs_every_place = torch.softmax(logits, dim=-1) # [B, T', C]
432 | probs_mean = torch.mean(probs_every_place, dim=1) # [B, C]
433 | probs = torch.log(probs_mean)
434 | else:
435 | raise NotImplementedError
436 |
437 | control_emo_probs1 = probs[torch.arange(len(control_emo1)).to(control_emo1.device), control_emo1]
438 | control_emo_probs2 = probs[torch.arange(len(control_emo2)).to(control_emo2.device), control_emo2]
439 | control_emo_probs = control_emo_probs1 * emo1_weight + control_emo_probs2 * (1-emo1_weight) # interpolate
440 |
441 | control_emo_probs.sum().backward(retain_graph=True)
442 |             # NOTE: summing gives every batch element equal weight in the gradient.
443 | xt_grad = xt.grad
444 | # ==========================================
445 |
446 | if stoc: # adds stochastic term
447 | dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk) - guidance * xt_grad
448 | dxt_det = dxt_det * noise_t * h
449 | dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
450 | requires_grad=False)
451 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
452 | dxt = dxt_det + dxt_stoc
453 | else:
454 | dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk) - guidance * xt_grad)
455 | dxt = dxt * noise_t * h
456 | xt = (xt - dxt) * mask
457 | return xt
458 |
459 | def classifier_decode_mixture_DPS(self, z, mask, mu, n_timesteps, stoc=False, spk=None, classifier_func=None, guidance=1.0, control_emo1=None,control_emo2=None, emo1_weight=None, classifier_type="conformer"):
460 | # control_emo should be [B, ] tensor
461 | h = 1.0 / n_timesteps
462 | xt = z * mask
463 | for i in range(n_timesteps):
464 | t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype,
465 | device=z.device)
466 | time = t.unsqueeze(-1).unsqueeze(-1)
467 | noise_t = get_noise(time, self.beta_min, self.beta_max,
468 | cumulative=False)
469 | beta_integral_t = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
470 |             bar_alpha_t = torch.exp(-beta_integral_t)  # torch.exp: beta_integral_t is a [B, 1, 1] tensor
471 | # =========== classifier part ==============
472 | xt = xt.detach()
473 | xt.requires_grad_(True)
474 | score_estimate = self.estimator(xt, mask, mu, t, spk)
475 |             x0_hat = (xt + (1 - bar_alpha_t) * score_estimate) / torch.sqrt(bar_alpha_t)
476 |
477 | if classifier_type == 'CNN-with-time':
478 | raise NotImplementedError
479 | else:
480 | logits = classifier_func(x0_hat.transpose(1, 2), mu.transpose(1, 2), (mask == 1.0).squeeze(1))
481 |
482 | if classifier_type == 'conformer': # [B, C]
483 | probs = torch.log_softmax(logits, dim=-1) # [B, C]
484 | elif classifier_type == 'CNN' or classifier_type == 'CNN-with-time' :
485 | probs_every_place = torch.softmax(logits, dim=-1) # [B, T', C]
486 | probs_mean = torch.mean(probs_every_place, dim=1) # [B, C]
487 | probs_mean = probs_mean + 10E-10
488 |
489 | probs = torch.log(probs_mean)
490 | else:
491 | raise NotImplementedError
492 |
493 | control_emo_probs1 = probs[torch.arange(len(control_emo1)).to(control_emo1.device), control_emo1]
494 | control_emo_probs2 = probs[torch.arange(len(control_emo2)).to(control_emo2.device), control_emo2]
495 | control_emo_probs = control_emo_probs1 * emo1_weight + control_emo_probs2 * (1-emo1_weight) # interpolate
496 |
497 | control_emo_probs.sum().backward(retain_graph=True)
498 |             # NOTE: summing gives every batch element equal weight in the gradient.
499 | xt_grad = xt.grad
500 | # ==========================================
501 |
502 | if stoc: # adds stochastic term
503 | dxt_det = 0.5 * (mu - xt) - score_estimate - guidance * xt_grad
504 | dxt_det = dxt_det * noise_t * h
505 | dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
506 | requires_grad=False)
507 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
508 | dxt = dxt_det + dxt_stoc
509 | else:
510 | dxt = 0.5 * (mu - xt - score_estimate - guidance * xt_grad)
511 | dxt = dxt * noise_t * h
512 | xt = (xt - dxt) * mask
513 | return xt
514 |
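
For reference, the process implemented above can be written compactly as follows (a restatement of `get_noise`, `forward_diffusion`, and the deterministic branch of `reverse_diffusion`; `spk` denotes the merged speaker/emotion conditioning):

```latex
% linear noise schedule (get_noise), beta_0 = beta_min, beta_1 = beta_max
\beta_t = \beta_0 + (\beta_1 - \beta_0)\,t, \qquad
\int_0^t \beta_s\,ds = \beta_0 t + \tfrac{1}{2}(\beta_1 - \beta_0)\,t^2

% forward_diffusion: sample x_t given x_0 and the text-conditioned mean mu
x_t = \mu + (x_0 - \mu)\,e^{-\frac{1}{2}\int_0^t \beta_s\,ds}
      + z\,\sqrt{1 - e^{-\int_0^t \beta_s\,ds}}, \qquad z \sim \mathcal{N}(0, I)

% reverse_diffusion (deterministic branch): one Euler step of size h = 1 / n_timesteps
x_{t-h} = x_t - \tfrac{1}{2}\,\beta_t\,h\,\bigl(\mu - x_t - s_\theta(x_t, \mu, t, \mathrm{spk})\bigr)
```

With classifier-free guidance enabled, the score estimate $s_\theta(x_t, \mu, t, \mathrm{spk})$ is replaced by $(1 + w)\,s_\theta(x_t, \mu, t, \mathrm{spk}) - w\,s_\theta(x_t, \mu, t, \mathrm{spk}_{\text{dummy}})$, where $w$ is `classifier_free_guidance`.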
--------------------------------------------------------------------------------
/model/monotonic_align/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jaehyeon Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/model/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jaywalnut310/glow-tts """
2 |
3 | import numpy as np
4 | import torch
5 | from .model.monotonic_align.core import maximum_path_c
6 |
7 |
8 | def maximum_path(value, mask):
9 | """ Cython optimised version.
10 | value: [b, t_x, t_y]
11 | mask: [b, t_x, t_y]
12 | """
13 | value = value * mask
14 | device = value.device
15 | dtype = value.dtype
16 | value = value.data.cpu().numpy().astype(np.float32)
17 | path = np.zeros_like(value).astype(np.int32)
18 | mask = mask.data.cpu().numpy()
19 |
20 | t_x_max = mask.sum(1)[:, 0].astype(np.int32)
21 | t_y_max = mask.sum(2)[:, 0].astype(np.int32)
22 | maximum_path_c(path, value, t_x_max, t_y_max)
23 | return torch.from_numpy(path).to(device=device, dtype=dtype)
24 |
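
A minimal usage sketch of `maximum_path` (assuming the Cython extension has been built and the repository root is on `PYTHONPATH`; the tensors below are random stand-ins for the log-prior and attention mask computed in `model/tts.py`):

```python
import torch
from model import monotonic_align

b, t_x, t_y = 1, 12, 40                    # batch, text length, mel length
log_prior = torch.randn(b, t_x, t_y)       # log-likelihood of pairing token x with frame y
attn_mask = torch.ones(b, t_x, t_y)        # 1 inside the valid text/mel region, 0 elsewhere

attn = monotonic_align.maximum_path(log_prior, attn_mask)  # hard 0/1 alignment, [b, t_x, t_y]
print(attn.shape)                          # each mel frame is assigned to exactly one token
```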
--------------------------------------------------------------------------------
/model/monotonic_align/build/temp.linux-x86_64-3.6/core.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/model/monotonic_align/build/temp.linux-x86_64-3.6/core.o
--------------------------------------------------------------------------------
/model/monotonic_align/build/temp.macosx-10.9-x86_64-3.6/core.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/model/monotonic_align/build/temp.macosx-10.9-x86_64-3.6/core.o
--------------------------------------------------------------------------------
/model/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | cimport numpy as np
3 | cimport cython
4 | from cython.parallel import prange
5 |
6 |
7 | @cython.boundscheck(False)
8 | @cython.wraparound(False)
9 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
10 | cdef int x
11 | cdef int y
12 | cdef float v_prev
13 | cdef float v_cur
14 | cdef float tmp
15 | cdef int index = t_x - 1
16 |
17 | for y in range(t_y):
18 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
19 | if x == y:
20 | v_cur = max_neg_val
21 | else:
22 | v_cur = value[x, y-1]
23 | if x == 0:
24 | if y == 0:
25 | v_prev = 0.
26 | else:
27 | v_prev = max_neg_val
28 | else:
29 | v_prev = value[x-1, y-1]
30 | value[x, y] = max(v_cur, v_prev) + value[x, y]
31 |
32 | for y in range(t_y - 1, -1, -1):
33 | path[index, y] = 1
34 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
35 | index = index - 1
36 |
37 |
38 | @cython.boundscheck(False)
39 | @cython.wraparound(False)
40 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
41 | cdef int b = values.shape[0]
42 |
43 | cdef int i
44 | for i in prange(b, nogil=True):
45 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
46 |
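
The recursion implemented in `maximum_path_each` above is the standard monotonic-alignment dynamic programme: frame $y$ extends the best path that ended, on frame $y-1$, either at the same token or at the previous token (out-of-range terms are set to `max_neg_val`), and the 0/1 path is recovered by backtracking from $(t_x - 1,\, t_y - 1)$:

```latex
Q(x, y) = \log p(x, y) + \max\bigl(Q(x, y-1),\; Q(x-1, y-1)\bigr), \qquad Q(0, 0) = \log p(0, 0)
```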
--------------------------------------------------------------------------------
/model/monotonic_align/model/monotonic_align/core.cpython-36m-darwin.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/model/monotonic_align/model/monotonic_align/core.cpython-36m-darwin.so
--------------------------------------------------------------------------------
/model/monotonic_align/model/monotonic_align/core.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IS2AI/KazEmoTTS/d01c33a8c08e9a8ea04e805c0085f4ef8f740b92/model/monotonic_align/model/monotonic_align/core.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/model/monotonic_align/setup.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jaywalnut310/glow-tts """
2 |
3 | from distutils.core import setup
4 | from Cython.Build import cythonize
5 | import numpy
6 |
7 | setup(
8 | name = 'monotonic_align',
9 | ext_modules = cythonize("core.pyx"),
10 | include_dirs=[numpy.get_include()]
11 | )
12 |
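
This extension is typically compiled in place with Cython (e.g. `python setup.py build_ext --inplace`); the pre-built `core.*.so` objects shipped above come from such a build, and the working directory chosen for the build has to be consistent with the relative import used in `__init__.py` (`.model.monotonic_align.core`).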
--------------------------------------------------------------------------------
/model/text_encoder.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jaywalnut310/glow-tts """
2 |
3 | import math
4 |
5 | import torch
6 |
7 | from model.base import BaseModule
8 | from model.utils import sequence_mask, convert_pad_shape
9 |
10 |
11 | class LayerNorm(BaseModule):
12 | def __init__(self, channels, eps=1e-4):
13 | super(LayerNorm, self).__init__()
14 | self.channels = channels
15 | self.eps = eps
16 |
17 | self.gamma = torch.nn.Parameter(torch.ones(channels))
18 | self.beta = torch.nn.Parameter(torch.zeros(channels))
19 |
20 | def forward(self, x):
21 | n_dims = len(x.shape)
22 | mean = torch.mean(x, 1, keepdim=True)
23 | variance = torch.mean((x - mean)**2, 1, keepdim=True)
24 |
25 | x = (x - mean) * torch.rsqrt(variance + self.eps)
26 |
27 | shape = [1, -1] + [1] * (n_dims - 2)
28 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
29 | return x
30 |
31 |
32 | class ConvReluNorm(BaseModule):
33 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
34 | n_layers, p_dropout):
35 | super(ConvReluNorm, self).__init__()
36 | self.in_channels = in_channels
37 | self.hidden_channels = hidden_channels
38 | self.out_channels = out_channels
39 | self.kernel_size = kernel_size
40 | self.n_layers = n_layers
41 | self.p_dropout = p_dropout
42 |
43 | self.conv_layers = torch.nn.ModuleList()
44 | self.norm_layers = torch.nn.ModuleList()
45 | self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
46 | kernel_size, padding=kernel_size//2))
47 | self.norm_layers.append(LayerNorm(hidden_channels))
48 | self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
49 | for _ in range(n_layers - 1):
50 | self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
51 | kernel_size, padding=kernel_size//2))
52 | self.norm_layers.append(LayerNorm(hidden_channels))
53 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
54 | self.proj.weight.data.zero_()
55 | self.proj.bias.data.zero_()
56 |
57 | def forward(self, x, x_mask):
58 | x_org = x
59 | for i in range(self.n_layers):
60 | x = self.conv_layers[i](x * x_mask)
61 | x = self.norm_layers[i](x)
62 | x = self.relu_drop(x)
63 | x = x_org + self.proj(x)
64 | return x * x_mask
65 |
66 |
67 | class DurationPredictor(BaseModule):
68 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
69 | super(DurationPredictor, self).__init__()
70 | self.in_channels = in_channels
71 | self.filter_channels = filter_channels
72 | self.p_dropout = p_dropout
73 |
74 | self.drop = torch.nn.Dropout(p_dropout)
75 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
76 | kernel_size, padding=kernel_size//2)
77 | self.norm_1 = LayerNorm(filter_channels)
78 | self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
79 | kernel_size, padding=kernel_size//2)
80 | self.norm_2 = LayerNorm(filter_channels)
81 | self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
82 |
83 | def forward(self, x, x_mask):
84 | x = self.conv_1(x * x_mask)
85 | x = torch.relu(x)
86 | x = self.norm_1(x)
87 | x = self.drop(x)
88 | x = self.conv_2(x * x_mask)
89 | x = torch.relu(x)
90 | x = self.norm_2(x)
91 | x = self.drop(x)
92 | x = self.proj(x * x_mask)
93 | return x * x_mask
94 |
95 |
96 | class MultiHeadAttention(BaseModule):
97 | def __init__(self, channels, out_channels, n_heads, window_size=None,
98 | heads_share=True, p_dropout=0.0, proximal_bias=False,
99 | proximal_init=False):
100 | super(MultiHeadAttention, self).__init__()
101 | assert channels % n_heads == 0
102 |
103 | self.channels = channels
104 | self.out_channels = out_channels
105 | self.n_heads = n_heads
106 | self.window_size = window_size
107 | self.heads_share = heads_share
108 | self.proximal_bias = proximal_bias
109 | self.p_dropout = p_dropout
110 | self.attn = None
111 |
112 | self.k_channels = channels // n_heads
113 | self.conv_q = torch.nn.Conv1d(channels, channels, 1)
114 | self.conv_k = torch.nn.Conv1d(channels, channels, 1)
115 | self.conv_v = torch.nn.Conv1d(channels, channels, 1)
116 | if window_size is not None:
117 | n_heads_rel = 1 if heads_share else n_heads
118 | rel_stddev = self.k_channels**-0.5
119 | self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
120 | window_size * 2 + 1, self.k_channels) * rel_stddev)
121 | self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
122 | window_size * 2 + 1, self.k_channels) * rel_stddev)
123 | self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
124 | self.drop = torch.nn.Dropout(p_dropout)
125 |
126 | torch.nn.init.xavier_uniform_(self.conv_q.weight)
127 | torch.nn.init.xavier_uniform_(self.conv_k.weight)
128 | if proximal_init:
129 | self.conv_k.weight.data.copy_(self.conv_q.weight.data)
130 | self.conv_k.bias.data.copy_(self.conv_q.bias.data)
131 | torch.nn.init.xavier_uniform_(self.conv_v.weight)
132 |
133 | def forward(self, x, c, attn_mask=None):
134 | q = self.conv_q(x)
135 | k = self.conv_k(c)
136 | v = self.conv_v(c)
137 |
138 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
139 |
140 | x = self.conv_o(x)
141 | return x
142 |
143 | def attention(self, query, key, value, mask=None):
144 | b, d, t_s, t_t = (*key.size(), query.size(2))
145 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
146 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
147 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
148 |
149 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
150 | if self.window_size is not None:
151 | assert t_s == t_t, "Relative attention is only available for self-attention."
152 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
153 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
154 | rel_logits = self._relative_position_to_absolute_position(rel_logits)
155 | scores_local = rel_logits / math.sqrt(self.k_channels)
156 | scores = scores + scores_local
157 | if self.proximal_bias:
158 | assert t_s == t_t, "Proximal bias is only available for self-attention."
159 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
160 | dtype=scores.dtype)
161 | if mask is not None:
162 | scores = scores.masked_fill(mask == 0, -1e4)
163 | p_attn = torch.nn.functional.softmax(scores, dim=-1)
164 | p_attn = self.drop(p_attn)
165 | output = torch.matmul(p_attn, value)
166 | if self.window_size is not None:
167 | relative_weights = self._absolute_position_to_relative_position(p_attn)
168 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
169 | output = output + self._matmul_with_relative_values(relative_weights,
170 | value_relative_embeddings)
171 | output = output.transpose(2, 3).contiguous().view(b, d, t_t)
172 | return output, p_attn
173 |
174 | def _matmul_with_relative_values(self, x, y):
175 | ret = torch.matmul(x, y.unsqueeze(0))
176 | return ret
177 |
178 | def _matmul_with_relative_keys(self, x, y):
179 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
180 | return ret
181 |
182 | def _get_relative_embeddings(self, relative_embeddings, length):
183 | pad_length = max(length - (self.window_size + 1), 0)
184 | slice_start_position = max((self.window_size + 1) - length, 0)
185 | slice_end_position = slice_start_position + 2 * length - 1
186 | if pad_length > 0:
187 | padded_relative_embeddings = torch.nn.functional.pad(
188 | relative_embeddings, convert_pad_shape([[0, 0],
189 | [pad_length, pad_length], [0, 0]]))
190 | else:
191 | padded_relative_embeddings = relative_embeddings
192 | used_relative_embeddings = padded_relative_embeddings[:,
193 | slice_start_position:slice_end_position]
194 | return used_relative_embeddings
195 |
196 | def _relative_position_to_absolute_position(self, x):
197 | batch, heads, length, _ = x.size()
198 | x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
199 | x_flat = x.view([batch, heads, length * 2 * length])
200 | x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
201 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
202 | return x_final
203 |
204 | def _absolute_position_to_relative_position(self, x):
205 | batch, heads, length, _ = x.size()
206 | x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
207 | x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
208 | x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
209 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
210 | return x_final
211 |
212 | def _attention_bias_proximal(self, length):
213 | r = torch.arange(length, dtype=torch.float32)
214 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
215 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
216 |
217 |
218 | class FFN(BaseModule):
219 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
220 | p_dropout=0.0):
221 | super(FFN, self).__init__()
222 | self.in_channels = in_channels
223 | self.out_channels = out_channels
224 | self.filter_channels = filter_channels
225 | self.kernel_size = kernel_size
226 | self.p_dropout = p_dropout
227 |
228 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
229 | padding=kernel_size//2)
230 | self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
231 | padding=kernel_size//2)
232 | self.drop = torch.nn.Dropout(p_dropout)
233 |
234 | def forward(self, x, x_mask):
235 | x = self.conv_1(x * x_mask)
236 | x = torch.relu(x)
237 | x = self.drop(x)
238 | x = self.conv_2(x * x_mask)
239 | return x * x_mask
240 |
241 |
242 | class Encoder(BaseModule):
243 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
244 | kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
245 | super(Encoder, self).__init__()
246 | self.hidden_channels = hidden_channels
247 | self.filter_channels = filter_channels
248 | self.n_heads = n_heads
249 | self.n_layers = n_layers
250 | self.kernel_size = kernel_size
251 | self.p_dropout = p_dropout
252 | self.window_size = window_size
253 |
254 | self.drop = torch.nn.Dropout(p_dropout)
255 | self.attn_layers = torch.nn.ModuleList()
256 | self.norm_layers_1 = torch.nn.ModuleList()
257 | self.ffn_layers = torch.nn.ModuleList()
258 | self.norm_layers_2 = torch.nn.ModuleList()
259 | for _ in range(self.n_layers):
260 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
261 | n_heads, window_size=window_size, p_dropout=p_dropout))
262 | self.norm_layers_1.append(LayerNorm(hidden_channels))
263 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
264 | filter_channels, kernel_size, p_dropout=p_dropout))
265 | self.norm_layers_2.append(LayerNorm(hidden_channels))
266 |
267 | def forward(self, x, x_mask):
268 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
269 | for i in range(self.n_layers):
270 | x = x * x_mask
271 | y = self.attn_layers[i](x, x, attn_mask)
272 | y = self.drop(y)
273 | x = self.norm_layers_1[i](x + y)
274 | y = self.ffn_layers[i](x, x_mask)
275 | y = self.drop(y)
276 | x = self.norm_layers_2[i](x + y)
277 | x = x * x_mask
278 | return x
279 |
280 |
281 | class TextEncoder(BaseModule):
282 | def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
283 | filter_channels_dp, n_heads, n_layers, kernel_size,
284 | p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
285 | super(TextEncoder, self).__init__()
286 | self.n_vocab = n_vocab
287 | self.n_feats = n_feats
288 | self.n_channels = n_channels
289 | self.filter_channels = filter_channels
290 | self.filter_channels_dp = filter_channels_dp
291 | self.n_heads = n_heads
292 | self.n_layers = n_layers
293 | self.kernel_size = kernel_size
294 | self.p_dropout = p_dropout
295 | self.window_size = window_size
296 | self.spk_emb_dim = spk_emb_dim
297 | self.n_spks = n_spks
298 |
299 | self.emb = torch.nn.Embedding(n_vocab, n_channels)
300 | torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
301 |
302 | self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
303 | kernel_size=5, n_layers=3, p_dropout=0.5)
304 |
305 | self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
306 | kernel_size, p_dropout, window_size=window_size)
307 |
308 | self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
309 | self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
310 | kernel_size, p_dropout)
311 |
312 | def forward(self, x, x_lengths, spk=None):
313 | x = self.emb(x) * math.sqrt(self.n_channels)
314 | x = torch.transpose(x, 1, -1)
315 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
316 |
317 | x = self.prenet(x, x_mask)
318 | if self.n_spks > 1:
319 | x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
320 | x = self.encoder(x, x_mask)
321 | mu = self.proj_m(x) * x_mask
322 |
323 | x_dp = torch.detach(x)
324 | logw = self.proj_w(x_dp, x_mask)
325 |
326 | return mu, logw, x_mask
327 |
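
A minimal shape sketch of how `TextEncoder` is called (the constructor arguments mirror those used by `GradTTSWithEmo` in `model/tts.py`; the inputs are random stand-ins, and `spk` here is the merged speaker/emotion embedding):

```python
import torch
from model.text_encoder import TextEncoder

encoder = TextEncoder(n_vocab=148, n_feats=80, n_channels=192,
                      filter_channels=768, filter_channels_dp=256,
                      n_heads=2, n_layers=6, kernel_size=3,
                      p_dropout=0.1, window_size=4,
                      spk_emb_dim=64, n_spks=2)      # n_spks > 1: speaker embedding is concatenated

x = torch.randint(0, 148, (2, 25))                   # phoneme ids              [B, T_text]
x_lengths = torch.tensor([25, 18])                   # valid lengths            [B]
spk = torch.randn(2, 64)                             # merged speaker/emotion   [B, spk_emb_dim]

mu_x, logw, x_mask = encoder(x, x_lengths, spk)
# mu_x:   [B, 80, T_text]  per-token mean of the mel prior
# logw:   [B, 1, T_text]   log-scaled token durations
# x_mask: [B, 1, T_text]   padding mask
```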
--------------------------------------------------------------------------------
/model/tts.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import torch
5 |
6 | from model import monotonic_align
7 | from model.base import BaseModule
8 | from model.text_encoder import TextEncoder
9 | from model.diffusion import Diffusion
10 | from model.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility
11 |
12 |
13 | class GradTTSWithEmo(BaseModule):
14 |     def __init__(self, n_vocab=148, n_spks=1, n_emos=5, spk_emb_dim=64,
15 | n_enc_channels=192, filter_channels=768, filter_channels_dp=256,
16 | n_heads=2, n_enc_layers=6, enc_kernel=3, enc_dropout=0.1, window_size=4,
17 | n_feats=80, dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000,
18 | use_classifier_free=False, dummy_spk_rate=0.5,
19 | **kwargs):
20 | super(GradTTSWithEmo, self).__init__()
21 | self.n_vocab = n_vocab
22 | self.n_spks = n_spks
23 | self.n_emos = n_emos
24 | self.spk_emb_dim = spk_emb_dim
25 | self.n_enc_channels = n_enc_channels
26 | self.filter_channels = filter_channels
27 | self.filter_channels_dp = filter_channels_dp
28 | self.n_heads = n_heads
29 | self.n_enc_layers = n_enc_layers
30 | self.enc_kernel = enc_kernel
31 | self.enc_dropout = enc_dropout
32 | self.window_size = window_size
33 | self.n_feats = n_feats
34 | self.dec_dim = dec_dim
35 | self.beta_min = beta_min
36 | self.beta_max = beta_max
37 | self.pe_scale = pe_scale
38 | self.use_classifier_free = use_classifier_free
39 |
40 | # if n_spks > 1:
41 | self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
42 | self.emo_emb = torch.nn.Embedding(n_emos, spk_emb_dim)
43 | self.merge_spk_emo = torch.nn.Sequential(
44 | torch.nn.Linear(spk_emb_dim*2, spk_emb_dim),
45 | torch.nn.ReLU(),
46 | torch.nn.Linear(spk_emb_dim, spk_emb_dim)
47 | )
48 | self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels,
49 | filter_channels, filter_channels_dp, n_heads,
50 | n_enc_layers, enc_kernel, enc_dropout, window_size,
51 | spk_emb_dim=spk_emb_dim, n_spks=n_spks)
52 | self.decoder = Diffusion(n_feats, dec_dim, spk_emb_dim, beta_min, beta_max, pe_scale)
53 |
54 | if self.use_classifier_free:
55 | self.dummy_xv = torch.nn.Parameter(torch.randn(size=(spk_emb_dim, )))
56 | self.dummy_rate = dummy_spk_rate
57 | print(f"Using classifier free with rate {self.dummy_rate}")
58 |
59 | @torch.no_grad()
60 | def forward(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None,
61 | length_scale=1.0, classifier_free_guidance=1., force_dur=None):
62 | """
63 | Generates mel-spectrogram from text. Returns:
64 | 1. encoder outputs
65 | 2. decoder outputs
66 | 3. generated alignment
67 |
68 | Args:
69 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
70 | x_lengths (torch.Tensor): lengths of texts in batch.
71 | n_timesteps (int): number of steps to use for reverse diffusion in decoder.
72 | temperature (float, optional): controls variance of terminal distribution.
73 | stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
74 | Usually, does not provide synthesis improvements.
75 | length_scale (float, optional): controls speech pace.
76 | Increase value to slow down generated speech and vice versa.
77 | """
78 | x, x_lengths = self.relocate_input([x, x_lengths])
79 |
80 | # Get speaker embedding
81 | spk = self.spk_emb(spk)
82 | emo = self.emo_emb(emo)
83 |
84 | if self.use_classifier_free:
85 | emo = emo / torch.sqrt(torch.sum(emo**2, dim=1, keepdim=True)) # unit norm
86 |
87 | spk_merged = self.merge_spk_emo(torch.cat([spk, emo], dim=-1))
88 |
89 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
90 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
91 |
92 | w = torch.exp(logw) * x_mask
93 | w_ceil = torch.ceil(w) * length_scale
94 | if force_dur is not None:
95 | w_ceil = force_dur.unsqueeze(1) # [1, 1, Ltext]
96 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
97 | y_max_length = int(y_lengths.max())
98 | y_max_length_ = fix_len_compatibility(y_max_length)
99 |
100 | # Using obtained durations `w` construct alignment map `attn`
101 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
102 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
103 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
104 |
105 | # Align encoded text and get mu_y
106 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
107 | mu_y = mu_y.transpose(1, 2)
108 | encoder_outputs = mu_y[:, :, :y_max_length]
109 |
110 | # Sample latent representation from terminal distribution N(mu_y, I)
111 | z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
112 | # print(z)
113 | # Generate sample by performing reverse dynamics
114 |
115 | unit_dummy_emo = self.dummy_xv / torch.sqrt(torch.sum(self.dummy_xv**2)) if self.use_classifier_free else None
116 | dummy_spk = self.merge_spk_emo(torch.cat([spk, unit_dummy_emo.unsqueeze(0).repeat(len(spk), 1)], dim=-1)) if self.use_classifier_free else None
117 |
118 | decoder_outputs = self.decoder(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
119 | use_classifier_free=self.use_classifier_free,
120 | classifier_free_guidance=classifier_free_guidance,
121 | dummy_spk=dummy_spk)
122 | decoder_outputs = decoder_outputs[:, :, :y_max_length]
123 |
124 | return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
125 |
126 | def classifier_guidance_decode(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None,
127 | length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'):
128 | x, x_lengths = self.relocate_input([x, x_lengths])
129 |
130 | # Get speaker embedding
131 | spk = self.spk_emb(spk)
132 | dummy_emo = self.emo_emb(torch.zeros_like(emo).long()) # this is for feeding the text encoder.
133 |
134 | spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1))
135 |
136 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
137 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
138 |
139 | w = torch.exp(logw) * x_mask
140 | # print("w shape is ", w.shape)
141 | w_ceil = torch.ceil(w) * length_scale
142 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
143 | y_max_length = int(y_lengths.max())
144 | if classifier_type == 'CNN' or classifier_type == 'CNN-with-time' :
145 | y_max_length = max(y_max_length, 180) # NOTE: added for CNN classifier
146 | y_max_length_ = fix_len_compatibility(y_max_length)
147 |
148 | # Using obtained durations `w` construct alignment map `attn`
149 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
150 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
151 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
152 |
153 | # Align encoded text and get mu_y
154 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
155 | mu_y = mu_y.transpose(1, 2)
156 | encoder_outputs = mu_y[:, :, :y_max_length]
157 |
158 | # Sample latent representation from terminal distribution N(mu_y, I)
159 | z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
160 | # Generate sample by performing reverse dynamics
161 |
162 | decoder_outputs = self.decoder.classifier_decode(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
163 | classifier_func, guidance,
164 | control_emo=emo, classifier_type=classifier_type)
165 | decoder_outputs = decoder_outputs[:, :, :y_max_length]
166 | return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
167 |
168 | def classifier_guidance_decode_DPS(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None,
169 | length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'):
170 | x, x_lengths = self.relocate_input([x, x_lengths])
171 |
172 | # Get speaker embedding
173 | spk = self.spk_emb(spk)
174 | dummy_emo = self.emo_emb(torch.zeros_like(emo).long()) # this is for feeding the text encoder.
175 |
176 | spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1))
177 |
178 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
179 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
180 |
181 | w = torch.exp(logw) * x_mask
182 | w_ceil = torch.ceil(w) * length_scale
183 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
184 | y_max_length = int(y_lengths.max())
185 | if classifier_type == 'CNN' or classifier_type == 'CNN-with-time' :
186 | y_max_length = max(y_max_length, 180) # NOTE: added for CNN classifier
187 | y_max_length_ = fix_len_compatibility(y_max_length)
188 |
189 | # Using obtained durations `w` construct alignment map `attn`
190 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
191 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
192 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
193 |
194 | # Align encoded text and get mu_y
195 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
196 | mu_y = mu_y.transpose(1, 2)
197 | encoder_outputs = mu_y[:, :, :y_max_length]
198 |
199 | # Sample latent representation from terminal distribution N(mu_y, I)
200 | z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
201 | # Generate sample by performing reverse dynamics
202 |
203 | decoder_outputs = self.decoder.classifier_decode_DPS(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
204 | classifier_func, guidance,
205 | control_emo=emo, classifier_type=classifier_type)
206 | decoder_outputs = decoder_outputs[:, :, :y_max_length]
207 | return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
208 |
209 | def classifier_guidance_decode_two_mixture(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo1=None, emo2=None, emo1_weight=None,
210 | length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'):
211 | x, x_lengths = self.relocate_input([x, x_lengths])
212 |
213 | # Get speaker embedding
214 | spk = self.spk_emb(spk)
215 | dummy_emo = self.emo_emb(torch.zeros_like(emo1).long()) # this is for feeding the text encoder.
216 |
217 | spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1))
218 |
219 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
220 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
221 |
222 | w = torch.exp(logw) * x_mask
223 | w_ceil = torch.ceil(w) * length_scale
224 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
225 | y_max_length = int(y_lengths.max())
226 | if classifier_type == 'CNN' or classifier_type == 'CNN-with-time' :
227 | y_max_length = max(y_max_length, 180) # NOTE: added for CNN classifier
228 | y_max_length_ = fix_len_compatibility(y_max_length)
229 |
230 | # Using obtained durations `w` construct alignment map `attn`
231 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
232 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
233 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
234 |
235 | # Align encoded text and get mu_y
236 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
237 | mu_y = mu_y.transpose(1, 2)
238 | encoder_outputs = mu_y[:, :, :y_max_length]
239 |
240 | # Sample latent representation from terminal distribution N(mu_y, I)
241 | z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
242 | # Generate sample by performing reverse dynamics
243 |
244 | decoder_outputs = self.decoder.classifier_decode_mixture(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
245 | classifier_func, guidance,
246 | control_emo1=emo1, control_emo2=emo2, emo1_weight=emo1_weight, classifier_type=classifier_type)
247 | decoder_outputs = decoder_outputs[:, :, :y_max_length]
248 | return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
249 |
250 | def classifier_guidance_decode_two_mixture_DPS(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo1=None, emo2=None, emo1_weight=None,
251 | length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'):
252 | x, x_lengths = self.relocate_input([x, x_lengths])
253 |
254 | # Get speaker embedding
255 | spk = self.spk_emb(spk)
256 | dummy_emo = self.emo_emb(torch.zeros_like(emo1).long()) # this is for feeding the text encoder.
257 |
258 | spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1))
259 |
260 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
261 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged)
262 |
263 | w = torch.exp(logw) * x_mask
264 | w_ceil = torch.ceil(w) * length_scale
265 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
266 | y_max_length = int(y_lengths.max())
267 | if classifier_type == 'CNN' or classifier_type == 'CNN-with-time' :
268 | y_max_length = max(y_max_length, 180) # NOTE: added for CNN classifier
269 | y_max_length_ = fix_len_compatibility(y_max_length)
270 |
271 | # Using obtained durations `w` construct alignment map `attn`
272 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
273 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
274 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
275 |
276 | # Align encoded text and get mu_y
277 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
278 | mu_y = mu_y.transpose(1, 2)
279 | encoder_outputs = mu_y[:, :, :y_max_length]
280 |
281 | # Sample latent representation from terminal distribution N(mu_y, I)
282 | z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
283 | # Generate sample by performing reverse dynamics
284 |
285 | decoder_outputs = self.decoder.classifier_decode_mixture_DPS(z, y_mask, mu_y, n_timesteps, stoc, spk_merged,
286 | classifier_func, guidance,
287 | control_emo1=emo1, control_emo2=emo2, emo1_weight=emo1_weight, classifier_type=classifier_type)
288 | decoder_outputs = decoder_outputs[:, :, :y_max_length]
289 | return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
290 |
291 | def compute_loss(self, x, x_lengths, y, y_lengths, spk=None, emo=None, out_size=None, use_gt_dur=False, durs=None):
292 | """
293 | Computes 3 losses:
294 |         1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
295 | 2. prior loss: loss between mel-spectrogram and encoder outputs.
296 | 3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
297 |
298 | Args:
299 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
300 | x_lengths (torch.Tensor): lengths of texts in batch.
301 | y (torch.Tensor): batch of corresponding mel-spectrograms.
302 | y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
303 | out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
304 | Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
305 |             use_gt_dur (bool): if True, use the ground-truth durations in `durs` instead of the MAS alignment.
306 |             durs (torch.Tensor): ground-truth token durations, required when `use_gt_dur` is True.
307 | """
308 | x, x_lengths, y, y_lengths = self.relocate_input([x, x_lengths, y, y_lengths]) # y: B, 80, L
309 |
310 | spk = self.spk_emb(spk)
311 | emo = self.emo_emb(emo) # [B, D]
312 | if self.use_classifier_free:
313 | emo = emo / torch.sqrt(torch.sum(emo ** 2, dim=1, keepdim=True)) # unit norm
314 |             use_dummy_per_sample = torch.distributions.Binomial(1, torch.tensor(
315 |                 [self.dummy_rate] * len(emo))).sample().bool()  # [B, ] mask; True with probability dummy_rate
316 |             emo[use_dummy_per_sample] = (self.dummy_xv / torch.sqrt(
317 |                 torch.sum(self.dummy_xv ** 2)))  # replace with the dummy embedding (also unit-normed)
318 |
319 | spk = self.merge_spk_emo(torch.cat([spk, emo], dim=-1)) # [B, D]
320 |
321 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
322 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
323 | y_max_length = y.shape[-1]
324 |
325 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
326 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
327 |
328 | # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
329 | if use_gt_dur:
330 | attn = generate_path(durs, attn_mask.squeeze(1)).detach()
331 | else:
332 | with torch.no_grad():
333 | const = -0.5 * math.log(2 * math.pi) * self.n_feats
334 | factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
335 | y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
336 | y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
337 | mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
338 | log_prior = y_square - y_mu_double + mu_square + const
339 | # it's actually the log likelihood of y given the Gaussian with (mu_x, I)
340 |
341 | attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
342 | attn = attn.detach()
343 |
344 | # Compute loss between predicted log-scaled durations and those obtained from MAS
345 | logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
346 | dur_loss = duration_loss(logw, logw_, x_lengths)
347 | # print(attn.shape)
348 |
349 | # Cut a small segment of mel-spectrogram in order to increase batch size
350 | if not isinstance(out_size, type(None)):
351 | clip_size = min(out_size, y_max_length) # when out_size > max length, do not actually perform clipping
352 |             clip_size = -fix_len_compatibility(-clip_size)  # ensure the clip length is divisible by 2^{num of UNet downsamplings}
353 | max_offset = (y_lengths - clip_size).clamp(0)
354 | offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
355 | out_offset = torch.LongTensor([
356 | torch.tensor(random.choice(range(start, end)) if end > start else 0)
357 | for start, end in offset_ranges
358 | ]).to(y_lengths)
359 |
360 | attn_cut = torch.zeros(attn.shape[0], attn.shape[1], clip_size, dtype=attn.dtype, device=attn.device)
361 | y_cut = torch.zeros(y.shape[0], self.n_feats, clip_size, dtype=y.dtype, device=y.device)
362 | y_cut_lengths = []
363 | for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
364 | y_cut_length = clip_size + (y_lengths[i] - clip_size).clamp(None, 0)
365 | y_cut_lengths.append(y_cut_length)
366 | cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
367 | y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
368 | attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
369 | y_cut_lengths = torch.LongTensor(y_cut_lengths)
370 | y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
371 |
372 |             attn = attn_cut  # attn -> [B, text_length, cut_length]; the cut segment does not necessarily start at the top-left corner
373 | y = y_cut
374 | y_mask = y_cut_mask
375 |
376 | # Align encoded text with mel-spectrogram and get mu_y segment
377 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) # here mu_x is not cut.
378 | mu_y = mu_y.transpose(1, 2) # B, 80, cut_length
379 |
380 | # Compute loss of score-based decoder
381 | # print(y.shape, y_mask.shape, mu_y.shape)
382 | diff_loss, xt = self.decoder.compute_loss(y, y_mask, mu_y, spk)
383 |
384 | # Compute loss between aligned encoder outputs and mel-spectrogram
385 | prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
386 | prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
387 |
388 | return dur_loss, prior_loss, diff_loss
389 |
390 |
391 | class GradTTSXvector(BaseModule):
392 | def __init__(self, n_vocab=148, spk_emb_dim=64,
393 | n_enc_channels=192, filter_channels=768, filter_channels_dp=256,
394 | n_heads=2, n_enc_layers=6, enc_kernel=3, enc_dropout=0.1, window_size=4,
395 | n_feats=80, dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000, xvector_dim=512, **kwargs):
396 | super(GradTTSXvector, self).__init__()
397 | self.n_vocab = n_vocab
398 | # self.n_spks = n_spks
399 | self.spk_emb_dim = spk_emb_dim
400 | self.n_enc_channels = n_enc_channels
401 | self.filter_channels = filter_channels
402 | self.filter_channels_dp = filter_channels_dp
403 | self.n_heads = n_heads
404 | self.n_enc_layers = n_enc_layers
405 | self.enc_kernel = enc_kernel
406 | self.enc_dropout = enc_dropout
407 | self.window_size = window_size
408 | self.n_feats = n_feats
409 | self.dec_dim = dec_dim
410 | self.beta_min = beta_min
411 | self.beta_max = beta_max
412 | self.pe_scale = pe_scale
413 |
414 | self.xvector_proj = torch.nn.Linear(xvector_dim, spk_emb_dim)
415 | self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels,
416 | filter_channels, filter_channels_dp, n_heads,
417 | n_enc_layers, enc_kernel, enc_dropout, window_size,
418 |                                    spk_emb_dim=spk_emb_dim, n_spks=999)  # NOTE: the exact n_spks value is irrelevant here; x-vectors replace the speaker-embedding table
419 | self.decoder = Diffusion(n_feats, dec_dim, spk_emb_dim, beta_min, beta_max, pe_scale)
420 |
421 | @torch.no_grad()
422 | def forward(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, length_scale=1.0):
423 | """
424 | Generates mel-spectrogram from text. Returns:
425 | 1. encoder outputs
426 | 2. decoder outputs
427 | 3. generated alignment
428 |
429 | Args:
430 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
431 | x_lengths (torch.Tensor): lengths of texts in batch.
432 | n_timesteps (int): number of steps to use for reverse diffusion in decoder.
433 | temperature (float, optional): controls variance of terminal distribution.
434 | stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
435 | Usually, does not provide synthesis improvements.
436 | length_scale (float, optional): controls speech pace.
437 | Increase value to slow down generated speech and vice versa.
438 |             spk (torch.Tensor): speaker x-vectors of shape [B, xvector_dim].
439 | """
440 | x, x_lengths = self.relocate_input([x, x_lengths])
441 |
442 | spk = self.xvector_proj(spk) # NOTE: use x-vectors instead of speaker embedding
443 |
444 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
445 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
446 |
447 | w = torch.exp(logw) * x_mask
448 | w_ceil = torch.ceil(w) * length_scale
449 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
450 | y_max_length = int(y_lengths.max())
451 | y_max_length_ = fix_len_compatibility(y_max_length)
452 |
453 | # Using obtained durations `w` construct alignment map `attn`
454 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
455 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
456 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
457 |
458 | # Align encoded text and get mu_y
459 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
460 | mu_y = mu_y.transpose(1, 2)
461 | encoder_outputs = mu_y[:, :, :y_max_length]
462 |
463 | # Sample latent representation from terminal distribution N(mu_y, I)
464 | z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
465 | # Generate sample by performing reverse dynamics
466 | decoder_outputs = self.decoder(z, y_mask, mu_y, n_timesteps, stoc, spk)
467 | decoder_outputs = decoder_outputs[:, :, :y_max_length]
468 |
469 | return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length]
470 |
471 | def compute_loss(self, x, x_lengths, y, y_lengths, spk=None, out_size=None, use_gt_dur=False, durs=None):
472 | """
473 | Computes 3 losses:
474 | 1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
475 | 2. prior loss: loss between mel-spectrogram and encoder outputs.
476 | 3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
477 |
478 | Args:
479 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
480 | x_lengths (torch.Tensor): lengths of texts in batch.
481 | y (torch.Tensor): batch of corresponding mel-spectrograms.
482 | y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
483 | out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
484 | Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
485 |             spk (torch.Tensor): speaker x-vectors of shape [B, xvector_dim].
486 |             use_gt_dur (bool): if True, use ground-truth durations instead of MAS alignment.
487 |             durs: ground-truth durations (required when use_gt_dur is True).
488 | """
489 | x, x_lengths, y, y_lengths = self.relocate_input([x, x_lengths, y, y_lengths])
490 |
491 | spk = self.xvector_proj(spk) # NOTE: use x-vectors instead of speaker embedding
492 |
493 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
494 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk)
495 | y_max_length = y.shape[-1]
496 |
497 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
498 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
499 |
500 | # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
501 | if not use_gt_dur:
502 | with torch.no_grad():
503 | const = -0.5 * math.log(2 * math.pi) * self.n_feats
504 | factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
505 | y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
506 | y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
507 | mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
508 | log_prior = y_square - y_mu_double + mu_square + const
509 |
510 | attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
511 | attn = attn.detach()
512 | else:
513 | with torch.no_grad():
514 | attn = generate_path(durs, attn_mask.squeeze(1)).detach()
515 |
516 | # Compute loss between predicted log-scaled durations and those obtained from MAS
517 | logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
518 | dur_loss = duration_loss(logw, logw_, x_lengths)
519 |
520 | # print(attn.shape)
521 |
522 | # Cut a small segment of mel-spectrogram in order to increase batch size
523 | if not isinstance(out_size, type(None)):
524 | max_offset = (y_lengths - out_size).clamp(0)
525 | offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
526 | out_offset = torch.LongTensor([
527 | torch.tensor(random.choice(range(start, end)) if end > start else 0)
528 | for start, end in offset_ranges
529 | ]).to(y_lengths)
530 |
531 | attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
532 | y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
533 | y_cut_lengths = []
534 | for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
535 | y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
536 | y_cut_lengths.append(y_cut_length)
537 | cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
538 | y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
539 | attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
540 | y_cut_lengths = torch.LongTensor(y_cut_lengths)
541 | y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
542 |
543 | attn = attn_cut
544 | y = y_cut
545 | y_mask = y_cut_mask
546 |
547 | # Align encoded text with mel-spectrogram and get mu_y segment
548 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
549 | mu_y = mu_y.transpose(1, 2)
550 |
551 | # Compute loss of score-based decoder
552 | diff_loss, xt = self.decoder.compute_loss(y, y_mask, mu_y, spk)
553 |
554 | # Compute loss between aligned encoder outputs and mel-spectrogram
555 | prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
556 | prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
557 |
558 | return dur_loss, prior_loss, diff_loss
559 |
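A minimal, illustrative sketch of how the x-vector variant above can be called for synthesis. The tensors are dummy data chosen only to show the expected shapes, and the x-vector source is an assumption (in practice it would come from a speaker-verification model):

```python
import torch
from model import GradTTSXvector  # exported via model/__init__.py

model = GradTTSXvector(n_vocab=148, xvector_dim=512).eval()

x = torch.randint(1, 148, (1, 25))   # dummy symbol ids, shape [B, T_text]
x_lengths = torch.LongTensor([25])
xvec = torch.randn(1, 512)           # dummy speaker x-vector, shape [B, xvector_dim]

y_enc, y_dec, attn = model(x, x_lengths, n_timesteps=10, spk=xvec)
# y_enc / y_dec: [1, 80, T_mel] aligned encoder output and generated mel-spectrogram
# attn: hard alignment between input symbols and mel frames
```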
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jaywalnut310/glow-tts """
2 |
3 | import torch
4 |
5 |
6 | def sequence_mask(length, max_length=None):
7 | if max_length is None:
8 | max_length = length.max()
9 | x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
10 | return x.unsqueeze(0) < length.unsqueeze(1)
11 |
12 |
13 | def fix_len_compatibility(length, num_downsamplings_in_unet=2):
14 | while True:
15 | if length % (2**num_downsamplings_in_unet) == 0:
16 | return length
17 | length += 1
18 |
19 |
20 | def convert_pad_shape(pad_shape):
21 | l = pad_shape[::-1]
22 | pad_shape = [item for sublist in l for item in sublist]
23 | return pad_shape
24 |
25 |
26 | def generate_path(duration, mask):
27 | device = duration.device
28 |
29 | b, t_x, t_y = mask.shape
30 | cum_duration = torch.cumsum(duration, 1)
31 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
32 |
33 | cum_duration_flat = cum_duration.view(b * t_x)
34 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
35 | path = path.view(b, t_x, t_y)
36 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
37 | [1, 0], [0, 0]]))[:, :-1]
38 | path = path * mask
39 | return path
40 |
41 |
42 | def duration_loss(logw, logw_, lengths):
43 | loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
44 | return loss
45 |
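For reference, a quick illustration of what the two most frequently used helpers return (the values in the comments are what the functions above produce for these inputs):

```python
import torch
from model.utils import sequence_mask, fix_len_compatibility

lengths = torch.LongTensor([3, 5])
print(sequence_mask(lengths))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])

# fix_len_compatibility rounds a length up to the next multiple of
# 2**num_downsamplings_in_unet (4 by default):
print(fix_len_compatibility(173))  # 176
```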
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6 | from xutils import init_weights, get_padding
7 |
8 | LRELU_SLOPE = 0.1
9 |
10 |
11 | class ResBlock1(torch.nn.Module):
12 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
13 | super(ResBlock1, self).__init__()
14 | self.h = h
15 | self.convs1 = nn.ModuleList([
16 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
17 | padding=get_padding(kernel_size, dilation[0]))),
18 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
19 | padding=get_padding(kernel_size, dilation[1]))),
20 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
21 | padding=get_padding(kernel_size, dilation[2])))
22 | ])
23 | self.convs1.apply(init_weights)
24 |
25 | self.convs2 = nn.ModuleList([
26 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
27 | padding=get_padding(kernel_size, 1))),
28 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
29 | padding=get_padding(kernel_size, 1))),
30 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
31 | padding=get_padding(kernel_size, 1)))
32 | ])
33 | self.convs2.apply(init_weights)
34 |
35 | def forward(self, x):
36 | for c1, c2 in zip(self.convs1, self.convs2):
37 | xt = F.leaky_relu(x, LRELU_SLOPE)
38 | xt = c1(xt)
39 | xt = F.leaky_relu(xt, LRELU_SLOPE)
40 | xt = c2(xt)
41 | x = xt + x
42 | return x
43 |
44 | def remove_weight_norm(self):
45 | for l in self.convs1:
46 | remove_weight_norm(l)
47 | for l in self.convs2:
48 | remove_weight_norm(l)
49 |
50 |
51 | class ResBlock2(torch.nn.Module):
52 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
53 | super(ResBlock2, self).__init__()
54 | self.h = h
55 | self.convs = nn.ModuleList([
56 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
57 | padding=get_padding(kernel_size, dilation[0]))),
58 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
59 | padding=get_padding(kernel_size, dilation[1])))
60 | ])
61 | self.convs.apply(init_weights)
62 |
63 | def forward(self, x):
64 | for c in self.convs:
65 | xt = F.leaky_relu(x, LRELU_SLOPE)
66 | xt = c(xt)
67 | x = xt + x
68 | return x
69 |
70 | def remove_weight_norm(self):
71 | for l in self.convs:
72 | remove_weight_norm(l)
73 |
74 |
75 | class Generator(torch.nn.Module):
76 | def __init__(self, h):
77 | super(Generator, self).__init__()
78 | self.h = h
79 | self.num_kernels = len(h.resblock_kernel_sizes)
80 | self.num_upsamples = len(h.upsample_rates)
81 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
82 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2
83 |
84 | self.ups = nn.ModuleList()
85 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
86 | self.ups.append(weight_norm(
87 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
88 | k, u, padding=(k-u)//2)))
89 |
90 | self.resblocks = nn.ModuleList()
91 | for i in range(len(self.ups)):
92 | ch = h.upsample_initial_channel//(2**(i+1))
93 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
94 | self.resblocks.append(resblock(h, ch, k, d))
95 |
96 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
97 | self.ups.apply(init_weights)
98 | self.conv_post.apply(init_weights)
99 |
100 | def forward(self, x):
101 | x = self.conv_pre(x)
102 | for i in range(self.num_upsamples):
103 | x = F.leaky_relu(x, LRELU_SLOPE)
104 | x = self.ups[i](x)
105 | xs = None
106 | for j in range(self.num_kernels):
107 | if xs is None:
108 | xs = self.resblocks[i*self.num_kernels+j](x)
109 | else:
110 | xs += self.resblocks[i*self.num_kernels+j](x)
111 | x = xs / self.num_kernels
112 | x = F.leaky_relu(x)
113 | x = self.conv_post(x)
114 | x = torch.tanh(x)
115 |
116 | return x
117 |
118 | def remove_weight_norm(self):
119 | print('Removing weight norm...')
120 | for l in self.ups:
121 | remove_weight_norm(l)
122 | for l in self.resblocks:
123 | l.remove_weight_norm()
124 | remove_weight_norm(self.conv_pre)
125 | remove_weight_norm(self.conv_post)
126 |
127 |
128 | class DiscriminatorP(torch.nn.Module):
129 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
130 | super(DiscriminatorP, self).__init__()
131 | self.period = period
132 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
133 | self.convs = nn.ModuleList([
134 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
135 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
136 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
137 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
138 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
139 | ])
140 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
141 |
142 | def forward(self, x):
143 | fmap = []
144 |
145 | # 1d to 2d
146 | b, c, t = x.shape
147 | if t % self.period != 0: # pad first
148 | n_pad = self.period - (t % self.period)
149 | x = F.pad(x, (0, n_pad), "reflect")
150 | t = t + n_pad
151 | x = x.view(b, c, t // self.period, self.period)
152 |
153 | for l in self.convs:
154 | x = l(x)
155 | x = F.leaky_relu(x, LRELU_SLOPE)
156 | fmap.append(x)
157 | x = self.conv_post(x)
158 | fmap.append(x)
159 | x = torch.flatten(x, 1, -1)
160 |
161 | return x, fmap
162 |
163 |
164 | class MultiPeriodDiscriminator(torch.nn.Module):
165 | def __init__(self):
166 | super(MultiPeriodDiscriminator, self).__init__()
167 | self.discriminators = nn.ModuleList([
168 | DiscriminatorP(2),
169 | DiscriminatorP(3),
170 | DiscriminatorP(5),
171 | DiscriminatorP(7),
172 | DiscriminatorP(11),
173 | ])
174 |
175 | def forward(self, y, y_hat):
176 | y_d_rs = []
177 | y_d_gs = []
178 | fmap_rs = []
179 | fmap_gs = []
180 | for i, d in enumerate(self.discriminators):
181 | y_d_r, fmap_r = d(y)
182 | y_d_g, fmap_g = d(y_hat)
183 | y_d_rs.append(y_d_r)
184 | fmap_rs.append(fmap_r)
185 | y_d_gs.append(y_d_g)
186 | fmap_gs.append(fmap_g)
187 |
188 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
189 |
190 |
191 | class DiscriminatorS(torch.nn.Module):
192 | def __init__(self, use_spectral_norm=False):
193 | super(DiscriminatorS, self).__init__()
194 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
195 | self.convs = nn.ModuleList([
196 | norm_f(Conv1d(1, 128, 15, 1, padding=7)),
197 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
198 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
199 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
200 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
201 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
202 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
203 | ])
204 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
205 |
206 | def forward(self, x):
207 | fmap = []
208 | for l in self.convs:
209 | x = l(x)
210 | x = F.leaky_relu(x, LRELU_SLOPE)
211 | fmap.append(x)
212 | x = self.conv_post(x)
213 | fmap.append(x)
214 | x = torch.flatten(x, 1, -1)
215 |
216 | return x, fmap
217 |
218 |
219 | class MultiScaleDiscriminator(torch.nn.Module):
220 | def __init__(self):
221 | super(MultiScaleDiscriminator, self).__init__()
222 | self.discriminators = nn.ModuleList([
223 | DiscriminatorS(use_spectral_norm=True),
224 | DiscriminatorS(),
225 | DiscriminatorS(),
226 | ])
227 | self.meanpools = nn.ModuleList([
228 | AvgPool1d(4, 2, padding=2),
229 | AvgPool1d(4, 2, padding=2)
230 | ])
231 |
232 | def forward(self, y, y_hat):
233 | y_d_rs = []
234 | y_d_gs = []
235 | fmap_rs = []
236 | fmap_gs = []
237 | for i, d in enumerate(self.discriminators):
238 | if i != 0:
239 | y = self.meanpools[i-1](y)
240 | y_hat = self.meanpools[i-1](y_hat)
241 | y_d_r, fmap_r = d(y)
242 | y_d_g, fmap_g = d(y_hat)
243 | y_d_rs.append(y_d_r)
244 | fmap_rs.append(fmap_r)
245 | y_d_gs.append(y_d_g)
246 | fmap_gs.append(fmap_g)
247 |
248 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
249 |
250 |
251 | def feature_loss(fmap_r, fmap_g):
252 | loss = 0
253 | for dr, dg in zip(fmap_r, fmap_g):
254 | for rl, gl in zip(dr, dg):
255 | loss += torch.mean(torch.abs(rl - gl))
256 |
257 | return loss*2
258 |
259 |
260 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
261 | loss = 0
262 | r_losses = []
263 | g_losses = []
264 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
265 | r_loss = torch.mean((1-dr)**2)
266 | g_loss = torch.mean(dg**2)
267 | loss += (r_loss + g_loss)
268 | r_losses.append(r_loss.item())
269 | g_losses.append(g_loss.item())
270 |
271 | return loss, r_losses, g_losses
272 |
273 |
274 | def generator_loss(disc_outputs):
275 | loss = 0
276 | gen_losses = []
277 | for dg in disc_outputs:
278 | l = torch.mean((1-dg)**2)
279 | gen_losses.append(l)
280 | loss += l
281 |
282 | return loss, gen_losses
283 |
284 |
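A rough usage sketch for the HiFi-GAN generator above. Here `h` is assumed to be the hyperparameter object loaded from the vocoder config (an attribute-style wrapper around the JSON), and the mel input is random data used only to illustrate the shapes:

```python
import torch
from models import Generator

mel = torch.randn(1, 80, 120)        # [batch, n_mels, frames]
generator = Generator(h).eval()      # h: loaded vocoder hyperparameters (assumed)
with torch.no_grad():
    audio = generator(mel)           # [1, 1, 120 * prod(h.upsample_rates)]
```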
--------------------------------------------------------------------------------
/text/LICENSE:
--------------------------------------------------------------------------------
1 | CMUdict
2 | -------
3 |
4 | CMUdict (the Carnegie Mellon Pronouncing Dictionary) is a free
5 | pronouncing dictionary of English, suitable for uses in speech
6 | technology and is maintained by the Speech Group in the School of
7 | Computer Science at Carnegie Mellon University.
8 |
9 | The Carnegie Mellon Speech Group does not guarantee the accuracy of
10 | this dictionary, nor its suitability for any specific purpose. In
11 | fact, we expect a number of errors, omissions and inconsistencies to
12 | remain in the dictionary. We intend to continually update the
13 | dictionary by correction existing entries and by adding new ones. From
14 | time to time a new major version will be released.
15 |
16 | We welcome input from users: Please send email to Alex Rudnicky
17 | (air+cmudict@cs.cmu.edu).
18 |
19 | The Carnegie Mellon Pronouncing Dictionary, in its current and
20 | previous versions is Copyright (C) 1993-2014 by Carnegie Mellon
21 | University. Use of this dictionary for any research or commercial
22 | purpose is completely unrestricted. If you make use of or
23 | redistribute this material we request that you acknowledge its
24 | origin in your descriptions.
25 |
26 | If you add words to or correct words in your version of this
27 | dictionary, we would appreciate it if you could send these additions
28 | and corrections to us (air+cmudict@cs.cmu.edu) for consideration in a
29 | subsequent version. All submissions will be reviewed and approved by
30 | the current maintainer, Alex Rudnicky at Carnegie Mellon.
31 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 | from text import cleaners
5 | from text.symbols import symbols
6 | import torch
7 |
8 |
9 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
10 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
11 |
12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
13 |
14 |
15 | def get_arpabet(word, dictionary):
16 | word_arpabet = dictionary.lookup(word)
17 | if word_arpabet is not None:
18 | return "{" + word_arpabet[0] + "}"
19 | else:
20 | return word
21 |
22 |
23 | def text_to_sequence(text, cleaner_names=["kazakh_cleaners"], dictionary=None):
24 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
25 |
26 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
27 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
28 |
29 | Args:
30 | text: string to convert to a sequence
31 | cleaner_names: names of the cleaner functions to run the text through
32 | dictionary: arpabet class with arpabet dictionary
33 |
34 | Returns:
35 | List of integers corresponding to the symbols in the text
36 | '''
37 | sequence = []
38 | space = _symbols_to_sequence(' ')
39 | # Check for curly braces and treat their contents as ARPAbet:
40 | while len(text):
41 | m = _curly_re.match(text)
42 | if not m:
43 | clean_text = _clean_text(text, cleaner_names)
44 | #clean_text = text
45 | if dictionary is not None:
46 | clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
47 | for i in range(len(clean_text)):
48 | t = clean_text[i]
49 | if t.startswith("{"):
50 | sequence += _arpabet_to_sequence(t[1:-1])
51 | else:
52 | sequence += _symbols_to_sequence(t)
53 | sequence += space
54 | else:
55 | sequence += _symbols_to_sequence(clean_text)
56 | break
57 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
58 | sequence += _arpabet_to_sequence(m.group(2))
59 | text = m.group(3)
60 |
61 | # remove trailing space
62 | if dictionary is not None:
63 | sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
64 | return sequence
65 |
66 |
67 | def sequence_to_text(sequence):
68 | '''Converts a sequence of IDs back to a string'''
69 | result = ''
70 | for symbol_id in sequence:
71 | if symbol_id in _id_to_symbol:
72 | s = _id_to_symbol[symbol_id]
73 | # Enclose ARPAbet back in curly braces:
74 | if len(s) > 1 and s[0] == '@':
75 | s = '{%s}' % s[1:]
76 | result += s
77 | return result.replace('}{', ' ')
78 |
79 | def convert_text(string):
80 | text_norm = text_to_sequence(string.lower())
81 | text_norm = torch.IntTensor(text_norm)
82 | text_len = torch.IntTensor([text_norm.size(0)])
83 | text_padded = torch.LongTensor(1, len(text_norm))
84 | text_padded.zero_()
85 | text_padded[0, :text_norm.size(0)] = text_norm
86 | return text_padded, text_len
87 |
88 | def _clean_text(text, cleaner_names):
89 | for name in cleaner_names:
90 | cleaner = getattr(cleaners, name)
91 | if not cleaner:
92 | raise Exception('Unknown cleaner: %s' % name)
93 | text = cleaner(text)
94 | return text
95 |
96 |
97 | def _symbols_to_sequence(symbols):
98 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
99 |
100 |
101 | def _arpabet_to_sequence(text):
102 | return _symbols_to_sequence(['@' + s for s in text.split()])
103 |
104 |
105 | def _should_keep_symbol(s):
106 | return s in _symbol_to_id and s != '_' and s != '~'
107 |
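An illustrative round trip through the functions above using a short Kazakh phrase (no ARPAbet dictionary is passed, so the default `kazakh_cleaners` path is taken):

```python
from text import text_to_sequence, sequence_to_text, convert_text

ids = text_to_sequence("Сәлем, әлем!")   # list of symbol ids after cleaning
print(sequence_to_text(ids))             # 'сәлем, әлем!'

padded, length = convert_text("Сәлем, әлем!")
# padded: LongTensor of shape [1, T]; length: IntTensor([T])
```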
--------------------------------------------------------------------------------
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 | from unidecode import unidecode
5 |
6 |
7 | _whitespace_re = re.compile(r'\s+')
8 |
9 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
10 | ('mrs', 'misess'),
11 | ('mr', 'mister'),
12 | ('dr', 'doctor'),
13 | ('st', 'saint'),
14 | ('co', 'company'),
15 | ('jr', 'junior'),
16 | ('maj', 'major'),
17 | ('gen', 'general'),
18 | ('drs', 'doctors'),
19 | ('rev', 'reverend'),
20 | ('lt', 'lieutenant'),
21 | ('hon', 'honorable'),
22 | ('sgt', 'sergeant'),
23 | ('capt', 'captain'),
24 | ('esq', 'esquire'),
25 | ('ltd', 'limited'),
26 | ('col', 'colonel'),
27 | ('ft', 'fort'),
28 | ]]
29 |
30 |
31 | def expand_abbreviations(text):
32 | for regex, replacement in _abbreviations:
33 | text = re.sub(regex, replacement, text)
34 | return text
35 |
36 |
37 | def lowercase(text):
38 | return text.lower()
39 |
40 |
41 | def collapse_whitespace(text):
42 | return re.sub(_whitespace_re, ' ', text)
43 |
44 |
45 | def convert_to_ascii(text):
46 | return unidecode(text)
47 |
48 |
49 | def basic_cleaners(text):
50 | text = lowercase(text)
51 | text = collapse_whitespace(text)
52 | return text
53 |
54 |
55 | def transliteration_cleaners(text):
56 | text = convert_to_ascii(text)
57 | text = lowercase(text)
58 | text = collapse_whitespace(text)
59 | return text
60 |
61 | def replace_english_words(text):
62 | text = text.replace("bluetooth не usb", "блютуз не юэсби").replace("mega silk way", "мега силк уэй")
63 | return text
64 |
65 | def kazakh_cleaners(text):
66 | # text = convert_to_ascii(text)
67 | text = lowercase(text)
68 | # text = expand_numbers(text)
69 | text = expand_abbreviations(text)
70 | text = replace_english_words(text)
71 | text = collapse_whitespace(text)
72 |     return text.replace("c", "с").strip()  # map the Latin letter "c" to the Cyrillic "с"
73 |
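By way of illustration, the Kazakh cleaner lowercases the input, applies the hard-coded English-word replacements, collapses whitespace, and maps the Latin "c" to the Cyrillic "с":

```python
from text.cleaners import kazakh_cleaners

print(kazakh_cleaners("Bluetooth не USB,   Mega Silk Way"))
# -> 'блютуз не юэсби, мега силк уэй'
```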
--------------------------------------------------------------------------------
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
14 | ]
15 |
16 | _valid_symbol_set = set(valid_symbols)
17 |
18 |
19 | class CMUDict:
20 | def __init__(self, file_or_path, keep_ambiguous=True):
21 | if isinstance(file_or_path, str):
22 | with open(file_or_path, encoding='latin-1') as f:
23 | entries = _parse_cmudict(f)
24 | else:
25 | entries = _parse_cmudict(file_or_path)
26 | if not keep_ambiguous:
27 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
28 | self._entries = entries
29 |
30 | def __len__(self):
31 | return len(self._entries)
32 |
33 | def lookup(self, word):
34 | return self._entries.get(word.upper())
35 |
36 |
37 | _alt_re = re.compile(r'\([0-9]+\)')
38 |
39 |
40 | def _parse_cmudict(file):
41 | cmudict = {}
42 | for line in file:
43 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
44 | parts = line.split(' ')
45 | word = re.sub(_alt_re, '', parts[0])
46 | pronunciation = _get_pronunciation(parts[1])
47 | if pronunciation:
48 | if word in cmudict:
49 | cmudict[word].append(pronunciation)
50 | else:
51 | cmudict[word] = [pronunciation]
52 | return cmudict
53 |
54 |
55 | def _get_pronunciation(s):
56 | parts = s.strip().split(' ')
57 | for part in parts:
58 | if part not in _valid_symbol_set:
59 | return None
60 | return ' '.join(parts)
61 |
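The class is only used when an ARPAbet dictionary is supplied to `text_to_sequence`; a hypothetical lookup would look like this (the dictionary file path is an assumption, not part of the repository):

```python
from text.cmudict import CMUDict

cmu = CMUDict('cmudict-0.7b')   # hypothetical path to a CMUdict file
print(cmu.lookup('speech'))     # e.g. ['S P IY1 CH']; None if the word is absent
```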
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | from text import cmudict
4 |
5 | _pad = '_'
6 | _punctuation = '!\'(),.:;? '
7 | _special = '-'
8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzАӘБВГҒДЕЁЖЗИЙКҚЛМНҢОӨПРСТУҰҮФХҺЦЧШЩЪЫІЬЭЮЯаәбвгғдеёжзийкқлмнңоөпрстуұүфхһцчшщъыіьэюя'
9 | _numbers = '0123456789'
10 |
11 | # Prepend "@" to ARPAbet symbols to ensure uniqueness:
12 | _arpabet = ['@' + s for s in cmudict.valid_symbols]
13 |
14 | # Export all symbols:
15 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet + list(_numbers)
16 |
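These symbols form the model's input vocabulary; every id produced by `text_to_sequence` is an index into this list. A small illustrative check:

```python
from text.symbols import symbols

print(symbols[0])        # '_'  (padding symbol)
print('ә' in symbols)    # True: Kazakh Cyrillic letters are part of the vocabulary
```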
--------------------------------------------------------------------------------
/train_EMA.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 |
4 | from copy import deepcopy
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from torch.utils.tensorboard import SummaryWriter
8 | import data_collate
9 | import data_loader
10 | from utils_data import plot_tensor, save_plot
11 | from model.utils import fix_len_compatibility
12 | from text.symbols import symbols
13 | import utils_data as utils
14 |
15 |
16 | class ModelEmaV2(torch.nn.Module):
17 | def __init__(self, model, decay=0.9999, device=None):
18 | super(ModelEmaV2, self).__init__()
19 | self.model_state_dict = deepcopy(model.state_dict())
20 | self.decay = decay
21 | self.device = device # perform ema on different device from model if set
22 |
23 | def _update(self, model, update_fn):
24 | with torch.no_grad():
25 | for ema_v, model_v in zip(self.model_state_dict.values(), model.state_dict().values()):
26 | if self.device is not None:
27 | model_v = model_v.to(device=self.device)
28 | ema_v.copy_(update_fn(ema_v, model_v))
29 |
30 | def update(self, model):
31 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
32 |
33 | def set(self, model):
34 | self._update(model, update_fn=lambda e, m: m)
35 |
36 | def state_dict(self, destination=None, prefix='', keep_vars=False):
37 | return self.model_state_dict
38 |
39 |
40 | if __name__ == "__main__":
41 | hps = utils.get_hparams()
42 | logger_text = utils.get_logger(hps.model_dir)
43 | logger_text.info(hps)
44 |
45 | out_size = fix_len_compatibility(2 * hps.data.sampling_rate // hps.data.hop_length) # NOTE: 2-sec of mel-spec
46 |
47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48 | torch.manual_seed(hps.train.seed)
49 | np.random.seed(hps.train.seed)
50 |
51 | print('Initializing logger...')
52 | log_dir = hps.model_dir
53 | logger = SummaryWriter(log_dir=log_dir)
54 |
55 | train_dataset, collate, model = utils.get_correct_class(hps)
56 | test_dataset, _, _ = utils.get_correct_class(hps, train=False)
57 |
58 | print('Initializing data loaders...')
59 |
60 | batch_collate = collate
61 | loader = DataLoader(dataset=train_dataset, batch_size=hps.train.batch_size,
62 | collate_fn=batch_collate, drop_last=True,
63 |                         num_workers=4, shuffle=False)  # NOTE: adjust num_workers to the machine; 4 is reasonable on a server
64 |
65 | print('Initializing model...')
66 | model = model(**hps.model).to(device)
67 | print('Number of encoder + duration predictor parameters: %.2fm' % (model.encoder.nparams / 1e6))
68 | print('Number of decoder parameters: %.2fm' % (model.decoder.nparams / 1e6))
69 | print('Total parameters: %.2fm' % (model.nparams / 1e6))
70 |
71 | use_gt_dur = getattr(hps.train, "use_gt_dur", False)
72 | if use_gt_dur:
73 | print("++++++++++++++> Using ground truth duration for training")
74 |
75 | print('Initializing optimizer...')
76 | optimizer = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate)
77 |
78 | print('Logging test batch...')
79 | test_batch = test_dataset.sample_test_batch(size=hps.train.test_size)
80 | for i, item in enumerate(test_batch):
81 | mel = item['mel']
82 | logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()),
83 | global_step=0, dataformats='HWC')
84 | save_plot(mel.squeeze(), f'{log_dir}/original_{i}.png')
85 |
86 | try:
87 | model, optimizer, learning_rate, epoch_logged = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "grad_*.pt"), model, optimizer)
88 | epoch_start = epoch_logged + 1
89 | print(f"Loaded checkpoint from {epoch_logged} epoch, resuming training.")
90 |         global_step = int(epoch_logged * (len(train_dataset) / hps.train.batch_size))
91 |     except Exception:
92 |         print("Could not find a trained checkpoint; training from scratch.")
93 | epoch_start = 1
94 | global_step = 0
95 | learning_rate = hps.train.learning_rate
96 |
97 |     ema_model = ModelEmaV2(model, decay=0.9999)  # NOTE: create the EMA copy after loading the checkpoint so it starts from the restored weights
98 |
99 | print('Start training...')
100 | used_items = set()
101 | iteration = global_step
102 | for epoch in range(epoch_start, hps.train.n_epochs + 1):
103 | model.train()
104 | dur_losses = []
105 | prior_losses = []
106 | diff_losses = []
107 | with tqdm(loader, total=len(train_dataset) // hps.train.batch_size) as progress_bar:
108 | for batch_idx, batch in enumerate(progress_bar):
109 | model.zero_grad()
110 | x, x_lengths = batch['text_padded'].to(device), \
111 | batch['input_lengths'].to(device)
112 | y, y_lengths = batch['mel_padded'].to(device), \
113 | batch['output_lengths'].to(device)
114 | if hps.xvector:
115 | spk = batch['xvector'].to(device)
116 | else:
117 | spk = batch['spk_ids'].to(torch.long).to(device)
118 | emo = batch['emo_ids'].to(torch.long).to(device)
119 |
120 | dur_loss, prior_loss, diff_loss = model.compute_loss(x, x_lengths,
121 | y, y_lengths,
122 | spk=spk,
123 | emo=emo,
124 | out_size=out_size,
125 | use_gt_dur=use_gt_dur,
126 | durs=batch['dur_padded'].to(device) if use_gt_dur else None)
127 | loss = sum([dur_loss, prior_loss, diff_loss])
128 | loss.backward()
129 |
130 | enc_grad_norm = torch.nn.utils.clip_grad_norm_(model.encoder.parameters(),
131 | max_norm=1)
132 | dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(),
133 | max_norm=1)
134 | optimizer.step()
135 | ema_model.update(model)
136 |
137 | logger.add_scalar('training/duration_loss', dur_loss.item(),
138 | global_step=iteration)
139 | logger.add_scalar('training/prior_loss', prior_loss.item(),
140 | global_step=iteration)
141 | logger.add_scalar('training/diffusion_loss', diff_loss.item(),
142 | global_step=iteration)
143 | logger.add_scalar('training/encoder_grad_norm', enc_grad_norm,
144 | global_step=iteration)
145 | logger.add_scalar('training/decoder_grad_norm', dec_grad_norm,
146 | global_step=iteration)
147 |
148 | dur_losses.append(dur_loss.item())
149 | prior_losses.append(prior_loss.item())
150 | diff_losses.append(diff_loss.item())
151 |
152 | if batch_idx % 5 == 0:
153 | msg = f'Epoch: {epoch}, iteration: {iteration} | dur_loss: {dur_loss.item()}, prior_loss: {prior_loss.item()}, diff_loss: {diff_loss.item()}'
154 | progress_bar.set_description(msg)
155 |
156 | iteration += 1
157 |
158 | log_msg = 'Epoch %d: duration loss = %.3f ' % (epoch, float(np.mean(dur_losses)))
159 | log_msg += '| prior loss = %.3f ' % np.mean(prior_losses)
160 | log_msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses)
161 | with open(f'{log_dir}/train.log', 'a') as f:
162 | f.write(log_msg)
163 |
164 | if epoch % hps.train.save_every > 0:
165 | continue
166 |
167 | model.eval()
168 | print('Synthesis...')
169 |
170 | with torch.no_grad():
171 | for i, item in enumerate(test_batch):
172 | if item['utt'] + "/truth" not in used_items:
173 | used_items.add(item['utt'] + "/truth")
174 | x = item['text'].to(torch.long).unsqueeze(0).to(device)
175 | if not hps.xvector:
176 | spk = item['spk_ids']
177 | spk = torch.LongTensor([spk]).to(device)
178 | else:
179 | spk = item["xvector"]
180 | spk = spk.unsqueeze(0).to(device)
181 | emo = item['emo_ids']
182 | emo = torch.LongTensor([emo]).to(device)
183 |
184 | x_lengths = torch.LongTensor([x.shape[-1]]).to(device)
185 |
186 | y_enc, y_dec, attn = model(x, x_lengths, spk=spk, emo=emo, n_timesteps=10)
187 | logger.add_image(f'image_{i}/generated_enc',
188 | plot_tensor(y_enc.squeeze().cpu()),
189 | global_step=iteration, dataformats='HWC')
190 | logger.add_image(f'image_{i}/generated_dec',
191 | plot_tensor(y_dec.squeeze().cpu()),
192 | global_step=iteration, dataformats='HWC')
193 | logger.add_image(f'image_{i}/alignment',
194 | plot_tensor(attn.squeeze().cpu()),
195 | global_step=iteration, dataformats='HWC')
196 | save_plot(y_enc.squeeze().cpu(),
197 | f'{log_dir}/generated_enc_{i}.png')
198 | save_plot(y_dec.squeeze().cpu(),
199 | f'{log_dir}/generated_dec_{i}.png')
200 | save_plot(attn.squeeze().cpu(),
201 | f'{log_dir}/alignment_{i}.png')
202 |
203 | ckpt = model.state_dict()
204 |
205 | utils.save_checkpoint(ema_model, optimizer, learning_rate, epoch, checkpoint_path=f"{log_dir}/EMA_grad_{epoch}.pt")
206 | utils.save_checkpoint(model, optimizer, learning_rate, epoch, checkpoint_path=f"{log_dir}/grad_{epoch}.pt")
207 |
208 |
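A minimal, self-contained sketch of the exponential-moving-average bookkeeping that `ModelEmaV2` performs above. The tiny linear model and the training loop are only for illustration:

```python
import torch
from train_EMA import ModelEmaV2  # the class defined above

net = torch.nn.Linear(4, 4)
ema = ModelEmaV2(net, decay=0.999)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

for _ in range(3):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(net)   # ema <- 0.999 * ema + 0.001 * net, entry by entry

# ema.state_dict() holds the smoothed weights that get saved as EMA_grad_<epoch>.pt
```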
--------------------------------------------------------------------------------
/utils_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import glob
4 | import logging
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import data_loader as loaders
8 | import data_collate as collates
9 | import json
10 | from model import GradTTSXvector, GradTTSWithEmo
11 | import torch
12 |
13 |
14 | def intersperse(lst, item):
15 | # Adds blank symbol
16 | result = [item] * (len(lst) * 2 + 1)
17 | result[1::2] = lst
18 | return result
19 |
20 |
21 | def parse_filelist(filelist_path, split_char="|"):
22 | with open(filelist_path, encoding='utf-8') as f:
23 | filepaths_and_text = [line.strip().split(split_char) for line in f]
24 | return filepaths_and_text
25 |
26 |
27 | def latest_checkpoint_path(dir_path, regex="grad_*.pt"):
28 | f_list = glob.glob(os.path.join(dir_path, regex))
29 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
30 | x = f_list[-1]
31 | return x
32 |
33 | def load_checkpoint(checkpoint_path, model, optimizer=None):
34 | assert os.path.isfile(checkpoint_path)
35 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
36 | iteration = 1
37 | if 'iteration' in checkpoint_dict.keys():
38 | iteration = checkpoint_dict['iteration']
39 | if 'learning_rate' in checkpoint_dict.keys():
40 | learning_rate = checkpoint_dict['learning_rate']
41 | else:
42 | learning_rate = None
43 | if optimizer is not None and 'optimizer' in checkpoint_dict.keys():
44 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
45 | saved_state_dict = checkpoint_dict['model']
46 | if hasattr(model, 'module'):
47 | state_dict = model.module.state_dict()
48 | else:
49 | state_dict = model.state_dict()
50 | new_state_dict = {}
51 | for k, v in state_dict.items():
52 | try:
53 | new_state_dict[k] = saved_state_dict[k]
54 |         except KeyError:
55 | logger.info("%s is not in the checkpoint" % k)
56 | print("%s is not in the checkpoint" % k)
57 | new_state_dict[k] = v
58 | if hasattr(model, 'module'):
59 | model.module.load_state_dict(new_state_dict)
60 | else:
61 | model.load_state_dict(new_state_dict)
62 | logger.info("Loaded checkpoint '{}' (iteration {})".format(
63 | checkpoint_path, iteration))
64 | return model, optimizer, learning_rate, iteration
65 |
66 |
67 | def save_figure_to_numpy(fig):
68 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated
69 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
70 | return data
71 |
72 |
73 | def plot_tensor(tensor):
74 | plt.style.use('default')
75 | fig, ax = plt.subplots(figsize=(12, 3))
76 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
77 | plt.colorbar(im, ax=ax)
78 | plt.tight_layout()
79 | fig.canvas.draw()
80 | data = save_figure_to_numpy(fig)
81 | plt.close()
82 | return data
83 |
84 |
85 | def save_plot(tensor, savepath):
86 | plt.style.use('default')
87 | fig, ax = plt.subplots(figsize=(12, 3))
88 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
89 | plt.colorbar(im, ax=ax)
90 | plt.tight_layout()
91 | fig.canvas.draw()
92 | plt.savefig(savepath)
93 | plt.close()
94 | return
95 |
96 |
97 | def get_correct_class(hps, train=True):
98 | if train:
99 | if hps.xvector and hps.pe:
100 | raise NotImplementedError
101 | elif hps.xvector: # no pitch energy
102 | raise NotImplementedError
103 | loader = loaders.XvectorLoader
104 | collate = collates.XvectorCollate
105 | model = GradTTSXvector
106 | dataset = loader(utts=hps.data.train_utts,
107 | hparams=hps.data,
108 | feats_scp=hps.data.train_feats_scp,
109 | utt2phns=hps.data.train_utt2phns,
110 | phn2id=hps.data.phn2id,
111 | utt2phn_duration=hps.data.train_utt2phn_duration,
112 | spk_xvector_scp=hps.data.train_spk_xvector_scp,
113 | utt2spk_name=hps.data.train_utt2spk)
114 | elif hps.pe:
115 | raise NotImplementedError
116 | else: # no PE, no xvector
117 | loader = loaders.SpkIDLoaderWithEmo
118 | collate = collates.SpkIDCollateWithEmo
119 | model = GradTTSWithEmo
120 | dataset = loader(utts=hps.data.train_utts,
121 | hparams=hps.data,
122 | feats_scp=hps.data.train_feats_scp,
123 | utt2text=hps.data.train_utt2phns,
124 | utt2spk=hps.data.train_utt2spk,
125 | utt2emo=hps.data.train_utt2emo)
126 | else:
127 | if hps.xvector and hps.pe:
128 | raise NotImplementedError
129 | elif hps.xvector:
130 | raise NotImplementedError
131 | loader = loaders.XvectorLoader
132 | collate = collates.XvectorCollate
133 | model = GradTTSXvector
134 | dataset = loader(utts=hps.data.val_utts,
135 | hparams=hps.data,
136 | feats_scp=hps.data.val_feats_scp,
137 | utt2phns=hps.data.val_utt2phns,
138 | phn2id=hps.data.phn2id,
139 | utt2phn_duration=hps.data.val_utt2phn_duration,
140 | spk_xvector_scp=hps.data.val_spk_xvector_scp,
141 | utt2spk_name=hps.data.val_utt2spk)
142 | elif hps.pe:
143 | raise NotImplementedError
144 | else: # no PE, no xvector
145 | loader = loaders.SpkIDLoaderWithEmo
146 | collate = collates.SpkIDCollateWithEmo
147 | model = GradTTSWithEmo
148 | dataset = loader(utts=hps.data.val_utts,
149 | hparams=hps.data,
150 | feats_scp=hps.data.val_feats_scp,
151 | utt2text=hps.data.val_utt2phns,
152 | utt2spk=hps.data.val_utt2spk,
153 | utt2emo=hps.data.val_utt2emo)
154 | return dataset, collate(), model
155 |
156 |
157 | def get_hparams(init=True):
158 | parser = argparse.ArgumentParser()
159 | parser.add_argument('-c', '--config', type=str, default="./configs/train_grad.json",
160 | help='JSON file for configuration')
161 | parser.add_argument('-m', '--model', type=str, required=True,
162 | help='Model name')
163 | parser.add_argument('-s', '--seed', type=int, default=1234)
164 | parser.add_argument('--not-pretrained', action='store_true', help='if set to true, then train from scratch')
165 |
166 | args = parser.parse_args()
167 | model_dir = os.path.join("./logs", args.model)
168 |
169 | if not os.path.exists(model_dir):
170 | os.makedirs(model_dir)
171 |
172 | config_path = args.config
173 | config_save_path = os.path.join(model_dir, "config.json")
174 | if init:
175 | with open(config_path, "r") as f:
176 | data = f.read()
177 | with open(config_save_path, "w") as f:
178 | f.write(data)
179 | else:
180 | with open(config_save_path, "r") as f:
181 | data = f.read()
182 | config = json.loads(data)
183 |
184 | hparams = HParams(**config)
185 | hparams.model_dir = model_dir
186 | hparams.train.seed = args.seed
187 | hparams.not_pretrained = args.not_pretrained
188 | return hparams
189 |
190 |
191 | class HParams():
192 | def __init__(self, **kwargs):
193 | for k, v in kwargs.items():
194 | if type(v) == dict:
195 | v = HParams(**v)
196 | self[k] = v
197 |
198 | def keys(self):
199 | return self.__dict__.keys()
200 |
201 | def items(self):
202 | return self.__dict__.items()
203 |
204 | def values(self):
205 | return self.__dict__.values()
206 |
207 | def __len__(self):
208 | return len(self.__dict__)
209 |
210 | def __getitem__(self, key):
211 | return getattr(self, key)
212 |
213 | def __setitem__(self, key, value):
214 | return setattr(self, key, value)
215 |
216 | def __contains__(self, key):
217 | return key in self.__dict__
218 |
219 | def __repr__(self):
220 | return self.__dict__.__repr__()
221 |
222 |
223 | def get_logger(model_dir, filename="train.log"):
224 | global logger
225 | logger = logging.getLogger(os.path.basename(model_dir))
226 | logger.setLevel(logging.DEBUG)
227 |
228 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
229 | if not os.path.exists(model_dir):
230 | os.makedirs(model_dir)
231 | h = logging.FileHandler(os.path.join(model_dir, filename))
232 | h.setLevel(logging.DEBUG)
233 | h.setFormatter(formatter)
234 | logger.addHandler(h)
235 | return logger
236 |
237 |
238 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
239 | logger.info("Saving model and optimizer state at iteration {} to {}".format(
240 | iteration, checkpoint_path))
241 | if hasattr(model, 'module'):
242 | state_dict = model.module.state_dict()
243 | else:
244 | state_dict = model.state_dict()
245 | torch.save({'model': state_dict,
246 | 'iteration': iteration,
247 | 'optimizer': optimizer.state_dict(),
248 | 'learning_rate': learning_rate}, checkpoint_path)
249 |
250 |
251 | def get_hparams_decode(model_dir=None):
252 | parser = argparse.ArgumentParser()
253 | parser.add_argument('-c', '--config', type=str, default="./configs/train_grad.json",
254 | help='JSON file for configuration')
255 | parser.add_argument('-m', '--model', type=str, default=model_dir,
256 | help='Model name')
257 | parser.add_argument('-s', '--seed', type=int, default=1234)
258 | parser.add_argument('-t', "--timesteps", type=int, default=10, help='how many timesteps to perform reverse diffusion')
259 |
260 | parser.add_argument("--stoc", action='store_true', default=False, help="Whether to add stochastic term into decoding")
261 | parser.add_argument("-g", "--guidance", type=float, default=3, help='classifier guidance')
262 | parser.add_argument('-n', '--noise', type=float, default=1.5, help='to multiply sigma')
263 |
264 | parser.add_argument('-f', '--file', type=str, required=True, help='path to a file with texts to synthesize')
265 | parser.add_argument('-r', '--generated_path', type=str, required=True, help='path to save wav files')
266 |
267 | args = parser.parse_args()
268 | model_dir = os.path.join("./logs", args.model)
269 |
270 | if not os.path.exists(model_dir):
271 | os.makedirs(model_dir)
272 |
273 | config_path = args.config
274 | config_save_path = os.path.join(model_dir, "config.json") # NOTE: which config to load
275 | with open(config_path, "r") as f:
276 | data = f.read()
277 | config = json.loads(data)
278 |
279 | hparams = HParams(**config)
280 | hparams.model_dir = model_dir
281 | hparams.train.seed = args.seed
282 |
283 | return hparams, args
284 |
285 |
286 | def get_hparams_decode_two_mixture(model_dir=None):
287 | parser = argparse.ArgumentParser()
288 | parser.add_argument('-c', '--config', type=str, default="./configs/train_grad.json",
289 | help='JSON file for configuration')
290 | parser.add_argument('-m', '--model', type=str, required=False, default='/raid/adal_abilbekov/training_emodiff/Emo_diff/logs/logs_train',
291 | help='Model name')
292 | parser.add_argument('-s', '--seed', type=int, default=1234)
293 | parser.add_argument('--dataset', choices=['train', 'val'], default='val', type=str, help='which dataset to use')
294 | parser.add_argument('--use-control-spk', action='store_true', help='whether to use GT spk or other spk')
295 | parser.add_argument('--control-spk-id', default=None, type=int, help='if use control spk, then which spk')
296 | parser.add_argument("--use-control-emo", action='store_true')
297 | parser.add_argument("--control-emo-id1", type=int)
298 | parser.add_argument("--control-emo-id2", type=int)
299 | parser.add_argument("--emo1-weight", type=float, default=0.5)
300 |
301 | parser.add_argument('--control-spk-name', default=None, type=str, help='if use control spk, then which spk')
302 | parser.add_argument("--max-utt-num", default=100, type=int, help='maximum utts number to decode')
303 | parser.add_argument("--specify-utt-name", default=None, type=str, help='if specified, only decodes for that utt')
304 | parser.add_argument('-t', "--timesteps", type=int, default=10, help='how many timesteps to perform reverse diffusion')
305 |
306 | parser.add_argument("--stoc", action='store_true', default=False, help="Whether to add stochastic term into decoding")
307 | parser.add_argument("-g", "--guidance", type=float, default=3, help='classifier guidance')
308 | parser.add_argument('-n', '--noise', type=float, default=1.5, help='to multiply sigma')
309 |
310 | parser.add_argument('--text', type=str, default=None, help="given text file")
311 |
312 | args = parser.parse_args()
313 | model_dir = os.path.join("./logs", args.model)
314 |
315 | if not os.path.exists(model_dir):
316 | os.makedirs(model_dir)
317 |
318 | config_path = args.config
319 | config_save_path = os.path.join(model_dir, "config.json") # NOTE: which config to load
320 | with open(config_path, "r") as f:
321 | data = f.read()
322 | config = json.loads(data)
323 |
324 | hparams = HParams(**config)
325 | hparams.model_dir = model_dir
326 | hparams.train.seed = args.seed
327 |
328 | if args.use_control_spk:
329 | if hparams.xvector:
330 | assert args.control_spk_name is not None
331 | else:
332 | assert args.control_spk_id is not None
333 |
334 | return hparams, args
335 |
336 |
337 | def get_hparams_classifier_objective():
338 | parser = argparse.ArgumentParser()
339 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
340 | help='JSON file for configuration')
341 | parser.add_argument('-m', '--model', type=str, required=True,
342 | help='Model name')
343 | parser.add_argument('-s', '--seed', type=int, default=1234)
344 | parser.add_argument('--dataset', choices=['train', 'val'], default='val', type=str, help='which dataset to use')
345 | parser.add_argument('--use-control-spk', action='store_true', help='whether to use GT spk or other spk')
346 | parser.add_argument('--control-spk-id', default=None, type=int, help='if use control spk, then which spk')
347 | parser.add_argument("--use-control-emo", action='store_true')
348 | parser.add_argument("--max-utt-num", default=100, type=int, help='maximum utts number to decode')
349 | parser.add_argument("--specify-utt-name", default=None, type=str, help='if specified, only decodes for that utt')
350 |
351 | parser.add_argument('--text', type=str, default=None, help="given text file")
352 | parser.add_argument("--feat", type=str, default=None, help='given feats.scp after CMVN')
353 | parser.add_argument("--dur", type=str, default=None, help='Force durations')
354 |
355 | args = parser.parse_args()
356 | model_dir = os.path.join("./logs", args.model)
357 |
358 | if not os.path.exists(model_dir):
359 | os.makedirs(model_dir)
360 |
361 | config_path = args.config
362 | config_save_path = os.path.join(model_dir, "config.json") # NOTE: which config to load
363 | with open(config_path, "r") as f:
364 | data = f.read()
365 | config = json.loads(data)
366 |
367 | hparams = HParams(**config)
368 | hparams.model_dir = model_dir
369 | hparams.train.seed = args.seed
370 |
371 | if args.use_control_spk:
372 | if hparams.xvector:
373 | assert args.control_spk_name is not None
374 | else:
375 | assert args.control_spk_id is not None
376 |
377 | return hparams, args
378 |
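For reference, `HParams` simply wraps the nested JSON config so that fields can be read as attributes; an illustrative snippet (the config path mirrors the default used by `get_hparams`):

```python
import json
from utils_data import HParams

with open("./configs/train_grad.json") as f:
    hps = HParams(**json.load(f))

print(hps.train.batch_size)   # nested dicts become nested HParams objects
print("xvector" in hps)       # top-level keys support `in` checks
```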
--------------------------------------------------------------------------------
/xutils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import matplotlib
4 | import torch
5 | from torch.nn.utils import weight_norm
6 | matplotlib.use("Agg")
7 | import matplotlib.pylab as plt
8 |
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(10, 2))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 |
16 | fig.canvas.draw()
17 | plt.close()
18 |
19 | return fig
20 |
21 |
22 | def init_weights(m, mean=0.0, std=0.01):
23 | classname = m.__class__.__name__
24 | if classname.find("Conv") != -1:
25 | m.weight.data.normal_(mean, std)
26 |
27 |
28 | def apply_weight_norm(m):
29 | classname = m.__class__.__name__
30 | if classname.find("Conv") != -1:
31 | weight_norm(m)
32 |
33 |
34 | def get_padding(kernel_size, dilation=1):
35 | return int((kernel_size*dilation - dilation)/2)
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def scan_checkpoint(cp_dir, prefix):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern)
55 | if len(cp_list) == 0:
56 | return None
57 | return sorted(cp_list)[-1]
58 |
59 |
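A quick check of `get_padding`, which computes the "same" padding used by the dilated convolutions in models.py (illustrative only):

```python
from xutils import get_padding

assert get_padding(kernel_size=3, dilation=1) == 1
assert get_padding(kernel_size=7, dilation=3) == 9   # (7*3 - 3) // 2
```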
--------------------------------------------------------------------------------