├── .gitignore ├── .gitmodules ├── LICENSE ├── LightGrad ├── __init__.py ├── base.py ├── conv.py ├── dataset.py ├── diffusion.py ├── dpm_solver.py ├── model.py ├── text_encoder.py └── utils.py ├── README.md ├── config ├── bznsyp_config.yaml └── ljspeech_config.yaml ├── dataset └── .gitignore ├── inference.ipynb ├── preprocess.py ├── requirements.txt ├── text ├── __init__.py ├── en_cleaners.py ├── g2p_en.py ├── g2p_zh.py └── numbers.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | __pycache__ 3 | log/ 4 | venv/ 5 | logs/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "hifi_gan"] 2 | path = hifi_gan 3 | url = https://github.com/jik876/hifi-gan.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Huawei Technologies Co., Ltd. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /LightGrad/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 8 | 9 | from .model import LightGrad 10 | -------------------------------------------------------------------------------- /LightGrad/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 
8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | class BaseModule(torch.nn.Module): 14 | def __init__(self): 15 | super(BaseModule, self).__init__() 16 | 17 | @property 18 | def nparams(self): 19 | """ 20 | Returns number of trainable parameters of the module. 21 | """ 22 | num_params = 0 23 | for name, param in self.named_parameters(): 24 | if param.requires_grad: 25 | num_params += np.prod(param.detach().cpu().numpy().shape) 26 | return num_params 27 | 28 | 29 | def relocate_input(self, x: list): 30 | """ 31 | Relocates provided tensors to the same device set for the module. 32 | """ 33 | device = next(self.parameters()).device 34 | for i in range(len(x)): 35 | if isinstance(x[i], torch.Tensor) and x[i].device != device: 36 | x[i] = x[i].to(device) 37 | return x 38 | -------------------------------------------------------------------------------- /LightGrad/conv.py: -------------------------------------------------------------------------------- 1 | """modified from (https://github.com/tstandley/Xception-PyTorch/blob/master/xception.py) 2 | """ 3 | import torch 4 | from torch import nn 5 | from .base import BaseModule 6 | 7 | 8 | class Mish(BaseModule): 9 | 10 | def forward(self, x): 11 | return x * torch.tanh(torch.nn.functional.softplus(x)) 12 | 13 | 14 | class SeparableConv2d(BaseModule): 15 | 16 | def __init__(self, 17 | in_channels, 18 | out_channels, 19 | kernel_size=1, 20 | stride=1, 21 | padding=0, 22 | dilation=1, 23 | bias=True): 24 | super(SeparableConv2d, self).__init__() 25 | 26 | self.conv1 = nn.Conv2d(in_channels, 27 | in_channels, 28 | kernel_size, 29 | stride, 30 | padding, 31 | dilation, 32 | groups=in_channels, 33 | bias=bias) 34 | self.pointwise = nn.Conv2d(in_channels, 35 | out_channels, 36 | 1, 37 | 1, 38 | 0, 39 | 1, 40 | 1, 41 | bias=bias) 42 | 43 | def forward(self, x): 44 | x = self.conv1(x) 45 | x = self.pointwise(x) 46 | return x 47 | 48 | 49 | class SeparableLinearAttention(BaseModule): 50 | 51 | def __init__(self, dim, heads=4, dim_head=32): 52 | super().__init__() 53 | self.heads = heads 54 | self.hidden_dim = dim_head * heads 55 | self.to_q = SeparableConv2d(dim, self.hidden_dim, 1, 1, 0, 1, False) 56 | self.to_k = SeparableConv2d(dim, self.hidden_dim, 1, 1, 0, 1, False) 57 | self.to_v = SeparableConv2d(dim, self.hidden_dim, 1, 1, 0, 1, False) 58 | self.to_out = SeparableConv2d(self.hidden_dim, dim, 1, 1, 0, 1) 59 | 60 | def forward(self, x): 61 | b, c, h, w = x.shape 62 | q = self.to_q(x).reshape((b, self.heads, -1, h * w)) # (b,heads,d,h*w) 63 | k = self.to_k(x).reshape((b, self.heads, -1, h * w)) # (b,heads,d,h*w) 64 | v = self.to_v(x).reshape((b, self.heads, -1, h * w)) # (b,heads,e,h*w) 65 | k = k.softmax(dim=-1) 66 | context = torch.matmul(k, v.permute(0, 1, 3, 2)) # (b,heads,d,e) 67 | out = torch.matmul(context.permute(0, 1, 3, 2), q) # (b,heads,e,n) 68 | out = out.reshape(b, self.hidden_dim, h, w) 69 | return self.to_out(out) 70 | 71 | 72 | class SeparableBlock(BaseModule): 73 | 74 | def __init__(self, dim, dim_out, groups=8): 75 | super().__init__() 76 | self.block = torch.nn.Sequential( 77 | SeparableConv2d(dim, dim_out, 3, padding=1), 78 | nn.GroupNorm(groups, dim_out), Mish()) 79 | 80 | def forward(self, x, mask): 81 | output = self.block(x * mask) 82 | return output * mask 83 | 84 | 85 | class SeparableResnetBlock(BaseModule): 86 | 87 | def __init__(self, dim, dim_out, time_emb_dim, groups=8): 88 | super().__init__() 89 | self.mlp = nn.Linear(time_emb_dim, dim_out) 90 | 91 | self.block1 = SeparableBlock(dim, dim_out, groups=groups) 
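        # Note on cost (illustrative numbers only): each SeparableBlock here builds
        # its 3x3 convolution from a depthwise 3x3 conv (groups=in_channels) plus a
        # 1x1 pointwise conv, so it needs roughly in_ch*3*3 + in_ch*out_ch weights
        # instead of in_ch*out_ch*3*3 for a full 3x3 conv. For 64 -> 64 channels:
        #   separable: 64*9 + 64*64 = 4,672    vs    full 3x3: 64*64*9 = 36,864.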
92 | self.block2 = SeparableBlock(dim_out, dim_out, groups=groups) 93 | if dim != dim_out: 94 | self.res_conv = nn.Conv2d(dim, dim_out, 1) 95 | else: 96 | self.res_conv = nn.Identity() 97 | 98 | def forward(self, x, mask, time_emb): 99 | h = self.block1(x, mask) 100 | h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) 101 | h = self.block2(h, mask) 102 | output = h + self.res_conv(x * mask) 103 | return output 104 | -------------------------------------------------------------------------------- /LightGrad/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import json 4 | import math 5 | from librosa.filters import mel as librosa_mel_fn 6 | import re 7 | import torchaudio 8 | 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | 12 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 13 | return torch.log(torch.clamp(x, min=clip_val) * C) 14 | 15 | 16 | def spectral_normalize_torch(magnitudes): 17 | output = dynamic_range_compression_torch(magnitudes) 18 | return output 19 | 20 | 21 | class Dataset(torch.utils.data.Dataset): 22 | 23 | def __init__(self, 24 | datalist_path, 25 | phn2id_path, 26 | sample_rate, 27 | n_fft, 28 | n_mels, 29 | fmin, 30 | fmax, 31 | hop_size, 32 | win_size, 33 | add_blank=True): 34 | super().__init__() 35 | with open(datalist_path) as f: 36 | self.datalist = json.load(f) 37 | with open(phn2id_path) as f: 38 | self.phone_set = json.load(f) 39 | 40 | self.add_blank = add_blank 41 | self.sample_rate = sample_rate 42 | self.n_fft = n_fft 43 | self.n_mels = n_mels 44 | self.fmin = fmin 45 | self.fmax = fmax 46 | self.hop_size = hop_size 47 | self.win_size = win_size 48 | self.cache = {} 49 | self.hann_window = torch.hann_window(win_size) 50 | self.mel_basis = torch.from_numpy( 51 | librosa_mel_fn(sr=sample_rate, 52 | n_fft=n_fft, 53 | n_mels=n_mels, 54 | fmin=fmin, 55 | fmax=fmax)).float() 56 | 57 | def get_vocab_size(self): 58 | # PAD is also considered 59 | return len(self.phone_set) + 1 60 | 61 | def load_audio_and_melspectrogram(self, audio_path): 62 | audio, original_sr = torchaudio.load(audio_path) 63 | if original_sr != self.sample_rate: 64 | audio = torchaudio.functional.resample(audio, original_sr, 65 | self.sample_rate) 66 | audio = torch.nn.functional.pad(audio.unsqueeze(1), (int( 67 | (self.n_fft - self.hop_size) / 68 | 2), int((self.n_fft - self.hop_size) / 2)), 69 | mode='reflect') 70 | audio = audio.squeeze(1) 71 | spec = torch.stft(audio, 72 | self.n_fft, 73 | self.hop_size, 74 | self.win_size, 75 | self.hann_window, 76 | False, 77 | onesided=True, 78 | return_complex=True) 79 | spec = spec.abs() 80 | spec = torch.matmul(self.mel_basis, spec) 81 | spec = spectral_normalize_torch(spec).squeeze(0) 82 | # audio: (1,T) spec: (T,n_mels) 83 | return audio, spec.T 84 | 85 | def load_item(self, i): 86 | #item_name, wav_path, text, phonemes = self.datalist[i] 87 | item_name = self.datalist[i]['name'] 88 | wav_path = self.datalist[i]['wav_path'] 89 | text = self.datalist[i]['text'] 90 | phonemes = self.datalist[i]['phonemes'] 91 | 92 | audio, mel = self.load_audio_and_melspectrogram(wav_path) 93 | if self.add_blank: 94 | phonemes = " ".join(phonemes).split(' ') 95 | phonemes = [''] + phonemes + [''] 96 | ph_idx = [self.phone_set[x] for x in phonemes if x in self.phone_set] 97 | self.cache[i] = { 98 | 'item_name': item_name, 99 | 'txt': text, 100 | 'wav': audio, 101 | 'ph': phonemes, 102 | 'mel': mel, 103 | 'ph_idx': ph_idx 104 | } 105 | return self.cache[i] 
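    # Items are cached after the first access, so later epochs skip recomputing
    # mel-spectrograms. A minimal usage sketch (the dataset/ paths below are
    # placeholders for the files produced by preprocess.py):
    #
    #   dataset = Dataset('dataset/ljspeech_processed/train_dataset.json',
    #                     'dataset/ljspeech_processed/phn2id.json',
    #                     22050, 1024, 80, 0, 8000, 256, 1024)
    #   loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True,
    #                                        collate_fn=collateFn)
    #   batch = next(iter(loader))
    #   # batch['x']: (B, max_phone_len) phoneme ids, batch['y']: (B, n_mels, T),
    #   # with T padded to a multiple of 4 for the U-Net decoder.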
106 | 107 | def __getitem__(self, i): 108 | return self.cache.get(i, self.load_item(i)) 109 | 110 | def process_item(self, item): 111 | ph = item['ph'] 112 | # remove original | because this indicates word boundary 113 | ph = re.sub(r' \|', '', ph).split(' ') 114 | if self.add_blank: 115 | # add | as the phoneme boundary 116 | ph = ' | '.join(ph).split(' ') 117 | new_item = { 118 | 'item_name': item['item_name'], 119 | 'txt': item['txt'], 120 | 'ph': ph, 121 | 'mel': item['mel'], 122 | 'ph_idx': [self.phone_set[x] for x in ph if x in self.phone_set], 123 | 'wav': item['wav'], 124 | } 125 | return new_item 126 | 127 | def __len__(self): 128 | return len(self.datalist) 129 | 130 | 131 | def collateFn(batch): 132 | phs_lengths, sorted_idx = torch.sort(torch.LongTensor( 133 | [len(x['ph_idx']) for x in batch]), 134 | descending=True) 135 | 136 | mel_lengths = torch.tensor([batch[i]['mel'].shape[0] for i in sorted_idx]) 137 | padded_phs = pad_sequence( 138 | [torch.tensor(batch[i]['ph_idx']) for i in sorted_idx], 139 | batch_first=True) 140 | 141 | padded_mels = pad_sequence([batch[i]['mel'] for i in sorted_idx], 142 | batch_first=True) 143 | batch_size, old_t, mel_d = padded_mels.shape 144 | txts = [batch[i]['txt'] for i in sorted_idx] 145 | wavs = [batch[i]['wav'] for i in sorted_idx] 146 | item_names = [batch[i]['item_name'] for i in sorted_idx] 147 | if old_t % 4 != 0: 148 | new_t = int(math.ceil(old_t / 4) * 4) 149 | temp = torch.zeros((batch_size, new_t, mel_d)) 150 | temp[:, :old_t] = padded_mels 151 | padded_mels = temp 152 | return { 153 | 'x': padded_phs, 154 | 'x_lengths': phs_lengths, 155 | 'y': padded_mels.permute(0, 2, 1), 156 | 'y_lengths': mel_lengths, 157 | 'txts': txts, 158 | 'wavs': wavs, 159 | 'names': item_names 160 | } 161 | 162 | 163 | if __name__ == '__main__': 164 | import tqdm 165 | #dataset = Dataset('dataset/bznsyp_processed/train_dataset.json', 166 | # 'dataset/bznsyp_processed/phn2id.json', 22050, 167 | # 1024, 80, 0, 8000, 256, 1024) 168 | dataset = Dataset('dataset/ljspeech_processed/train_dataset.json', 169 | 'dataset/ljspeech_processed/phn2id.json', 22050, 1024, 170 | 80, 0, 8000, 256, 1024) 171 | #for i in tqdm.tqdm(range(len(dataset))): 172 | # dataset[i] 173 | data = collateFn([dataset[i] for i in range(2)]) 174 | print(data['x']) 175 | print(data['x_lengths']) 176 | print(data['y'].shape) 177 | print(data['y_lengths']) 178 | print(data['txts']) 179 | print(data['names']) -------------------------------------------------------------------------------- /LightGrad/diffusion.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 
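# The forward (noising) process implemented by `get_noise` and
# `Diffusion.forward_diffusion` below follows the Grad-TTS mean-reverting
# formulation:
#
#   beta(t) = beta_min + (beta_max - beta_min) * t
#   I(t)    = int_0^t beta(s) ds = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
#   x_t | x_0 ~ N( x_0 * exp(-0.5 * I(t)) + mu * (1 - exp(-0.5 * I(t))),
#                  (1 - exp(-I(t))) * Id )
#
# where mu is the text-conditional prior from the encoder. `loss_t` trains the
# estimator to predict -z / sqrt(1 - exp(-I(t))), i.e. the score of x_t.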
8 | 9 | import math 10 | import torch 11 | from torch import nn 12 | 13 | from .base import BaseModule 14 | from .conv import (SeparableResnetBlock as ResnetBlock, 15 | SeparableLinearAttention as LinearAttention) 16 | from .dpm_solver import NoiseScheduleVP 17 | 18 | 19 | class Mish(BaseModule): 20 | 21 | def forward(self, x): 22 | return x * torch.tanh(torch.nn.functional.softplus(x)) 23 | 24 | 25 | class Upsample(BaseModule): 26 | 27 | def __init__(self, dim): 28 | super(Upsample, self).__init__() 29 | self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) 30 | 31 | def forward(self, x): 32 | return self.conv(x) 33 | 34 | 35 | class Downsample(BaseModule): 36 | 37 | def __init__(self, dim): 38 | super(Downsample, self).__init__() 39 | self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) 40 | 41 | def forward(self, x): 42 | return self.conv(x) 43 | 44 | 45 | class Block(BaseModule): 46 | 47 | def __init__(self, dim, dim_out, groups=8): 48 | super(Block, self).__init__() 49 | self.block = torch.nn.Sequential( 50 | torch.nn.Conv2d(dim, dim_out, 3, padding=1), 51 | torch.nn.GroupNorm(groups, dim_out), Mish()) 52 | 53 | def forward(self, x, mask): 54 | output = self.block(x * mask) 55 | return output * mask 56 | 57 | 58 | class Rezero(BaseModule): 59 | 60 | def __init__(self, fn): 61 | super(Rezero, self).__init__() 62 | self.fn = fn 63 | self.g = torch.nn.Parameter(torch.zeros(1)) 64 | 65 | def forward(self, x): 66 | return self.fn(x) * self.g 67 | 68 | 69 | class Residual(BaseModule): 70 | 71 | def __init__(self, fn): 72 | super(Residual, self).__init__() 73 | self.fn = fn 74 | 75 | def forward(self, x, *args, **kwargs): 76 | output = self.fn(x, *args, **kwargs) + x 77 | return output 78 | 79 | 80 | class SinusoidalPosEmb(BaseModule): 81 | 82 | def __init__(self, dim): 83 | super(SinusoidalPosEmb, self).__init__() 84 | self.dim = dim 85 | 86 | def forward(self, x, scale=1000): 87 | device = x.device 88 | half_dim = self.dim // 2 89 | emb = math.log(10000) / (half_dim - 1) 90 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 91 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) 92 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 93 | return emb 94 | 95 | 96 | class GradLogPEstimator2d(BaseModule): 97 | 98 | def __init__(self, 99 | dim, 100 | dim_mults=(1, 2, 4), 101 | groups=8, 102 | n_spks=None, 103 | spk_emb_dim=64, 104 | n_feats=80, 105 | pe_scale=1000): 106 | super(GradLogPEstimator2d, self).__init__() 107 | self.dim = dim # 64 108 | self.dim_mults = dim_mults # (1,2,4) 109 | self.groups = groups # 8 110 | self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1 # 1 111 | self.spk_emb_dim = spk_emb_dim # None 112 | self.pe_scale = pe_scale # 1000 113 | 114 | if n_spks > 1: 115 | self.spk_mlp = torch.nn.Sequential( 116 | torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), 117 | torch.nn.Linear(spk_emb_dim * 4, n_feats)) 118 | self.time_pos_emb = SinusoidalPosEmb(dim) 119 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), 120 | torch.nn.Linear(dim * 4, dim), Mish()) 121 | 122 | dims = [ 123 | 2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults) 124 | ] # [2,64,128,256] 125 | in_out = list(zip(dims[:-1], dims[1:])) # [(2,64),(64,128),(128,256)] 126 | self.downs = torch.nn.ModuleList([]) 127 | self.ups = torch.nn.ModuleList([]) 128 | num_resolutions = len(in_out) # 3 129 | 130 | for ind, (dim_in, dim_out) in enumerate(in_out): 131 | is_last = ind >= (num_resolutions - 1) 132 | self.downs.append( 133 | torch.nn.ModuleList([ 134 | 
ResnetBlock(dim_in, dim_out, time_emb_dim=dim), 135 | ResnetBlock(dim_out, dim_out, time_emb_dim=dim), 136 | Residual(Rezero(LinearAttention(dim_out))), 137 | Downsample(dim_out) 138 | if not is_last else torch.nn.Identity() 139 | ])) 140 | 141 | mid_dim = dims[-1] # 256 142 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 143 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) 144 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 145 | 146 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 147 | self.ups.append( 148 | torch.nn.ModuleList([ 149 | ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), 150 | ResnetBlock(dim_in, dim_in, time_emb_dim=dim), 151 | Residual(Rezero(LinearAttention(dim_in))), 152 | Upsample(dim_in) 153 | ])) 154 | self.final_block = Block(dim, dim) 155 | self.final_conv = torch.nn.Conv2d(dim, 1, 1) 156 | 157 | def forward(self, x, mask, mu, t, spk=None): 158 | """ 159 | Args: 160 | x (_type_): shape (b,80,tx) 161 | mask (_type_): shape (b,1,tx) 162 | mu (_type_): shape (b,80,tx) 163 | t (_type_): shape (b) 164 | spk (_type_, optional): 165 | """ 166 | if not isinstance(spk, type(None)): 167 | s = self.spk_mlp(spk) 168 | 169 | t = self.time_pos_emb(t, scale=self.pe_scale) 170 | t = self.mlp(t) 171 | 172 | if self.n_spks < 2: 173 | x = torch.stack([mu, x], 1) 174 | else: 175 | s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1]) 176 | x = torch.stack([mu, x, s], 1) 177 | # mask: (b,1,tx) 178 | mask = mask.unsqueeze(1) 179 | hiddens = [] 180 | masks = [mask] 181 | # x: (b,2,80,tx) -> (b,64,40,86) -> (b,128,20,43) -> (b,256,20,43)->mid 182 | # -> (b,256,20,43) -> (b,128,40,86) -> (b,64,80,172) -> (b,64,80,172) -> (b,1,80,172) 183 | for resnet1, resnet2, attn, downsample in self.downs: 184 | mask_down = masks[-1] 185 | x = resnet1(x, mask_down, t) 186 | x = resnet2(x, mask_down, t) 187 | x = attn(x) 188 | hiddens.append(x) 189 | x = downsample(x * mask_down) 190 | masks.append(mask_down[:, :, :, ::2]) 191 | masks = masks[:-1] 192 | mask_mid = masks[-1] 193 | x = self.mid_block1(x, mask_mid, t) 194 | x = self.mid_attn(x) 195 | x = self.mid_block2(x, mask_mid, t) 196 | for resnet1, resnet2, attn, upsample in self.ups: 197 | mask_up = masks.pop() 198 | x = torch.cat((x, hiddens.pop()), dim=1) 199 | x = resnet1(x, mask_up, t) 200 | x = resnet2(x, mask_up, t) 201 | x = attn(x) 202 | x = upsample(x * mask_up) 203 | x = self.final_block(x, mask) 204 | output = self.final_conv(x * mask) 205 | 206 | return (output * mask).squeeze(1) 207 | 208 | 209 | def get_noise(t, beta_init, beta_term, cumulative=False): 210 | if cumulative: 211 | # int(beta_0+(beta_1-beta_0)*t) = beta_0*t+0.5*(beta_1-beta_0)*t^2 212 | noise = beta_init * t + 0.5 * (beta_term - beta_init) * (t**2) 213 | else: 214 | noise = beta_init + (beta_term - beta_init) * t 215 | return noise 216 | 217 | 218 | class Diffusion(BaseModule): 219 | 220 | def __init__(self, 221 | n_feats, 222 | dim, 223 | n_spks=1, 224 | spk_emb_dim=64, 225 | beta_min=0.05, 226 | beta_max=20, 227 | pe_scale=1000): 228 | super(Diffusion, self).__init__() 229 | self.n_feats = n_feats 230 | self.dim = dim 231 | self.n_spks = n_spks 232 | self.spk_emb_dim = spk_emb_dim 233 | self.beta_min = beta_min 234 | self.beta_max = beta_max 235 | self.pe_scale = pe_scale 236 | self.dpm_solver_sch = NoiseScheduleVP() 237 | self.estimator = GradLogPEstimator2d(dim, 238 | n_spks=n_spks, 239 | spk_emb_dim=spk_emb_dim, 240 | pe_scale=pe_scale) 241 | 242 | def forward_diffusion(self, x0, mask, mu, t): 243 | """ 
244 | Args: 245 | x0 (_type_): shape (b,80,tx) 246 | mask (_type_): shape (b,1,tx) 247 | mu (_type_): shape (b,80,tx) 248 | t (_type_): shape (b) 249 | """ 250 | time = t.unsqueeze(-1).unsqueeze(-1) # (b,1,1) 251 | cum_noise = get_noise(time, 252 | self.beta_min, 253 | self.beta_max, 254 | cumulative=True) 255 | mean = x0 * torch.exp( 256 | -0.5 * cum_noise) + mu * (1.0 - torch.exp(-0.5 * cum_noise)) 257 | variance = 1.0 - torch.exp(-cum_noise) 258 | z = torch.randn(x0.shape, 259 | dtype=x0.dtype, 260 | device=x0.device, 261 | requires_grad=False) 262 | xt = mean + z * torch.sqrt(variance) 263 | return xt * mask, z * mask 264 | 265 | def get_beta(self, t): 266 | beta = self.beta_min + (self.beta_max - self.beta_min) * t 267 | return beta 268 | 269 | def get_gamma(self, s, t, p=1.0): 270 | beta_integral = self.beta_min + 0.5 * (self.beta_max - 271 | self.beta_min) * (t + s) 272 | beta_integral *= (t - s) 273 | gamma = math.exp(-0.5 * p * beta_integral) 274 | return gamma 275 | 276 | def get_mu(self, s, t): 277 | a = self.get_gamma(s, t) 278 | b = 1.0 - self.get_gamma(0, s, p=2.0) 279 | c = 1.0 - self.get_gamma(0, t, p=2.0) 280 | return a * b / c 281 | 282 | def get_nu(self, s, t): 283 | a = self.get_gamma(0, s) 284 | b = 1.0 - self.get_gamma(s, t, p=2.0) 285 | c = 1.0 - self.get_gamma(0, t, p=2.0) 286 | return a * b / c 287 | 288 | def get_sigma(self, s, t): 289 | a = 1.0 - self.get_gamma(0, s, p=2.0) 290 | b = 1.0 - self.get_gamma(s, t, p=2.0) 291 | c = 1.0 - self.get_gamma(0, t, p=2.0) 292 | return math.sqrt(a * b / c) 293 | 294 | @torch.no_grad() 295 | def reverse_diffusion_ml(self, 296 | z, 297 | mask, 298 | mu, 299 | n_timesteps, 300 | stoc=False, 301 | spk=None): 302 | h = 1.0 / n_timesteps 303 | xt = z * mask 304 | for i in range(n_timesteps): 305 | t = 1.0 - i * h 306 | time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) 307 | beta_t = self.get_beta(t) 308 | 309 | kappa = self.get_gamma( 310 | 0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0)) 311 | kappa /= (self.get_gamma(0, t) * beta_t * h) 312 | kappa -= 1.0 313 | omega = self.get_nu(t - h, t) / self.get_gamma(0, t) 314 | omega += self.get_mu(t - h, t) 315 | omega -= (0.5 * beta_t * h + 1.0) 316 | sigma = self.get_sigma(t - h, t) 317 | 318 | dxt = (mu - xt) * (0.5 * beta_t * h + omega) 319 | dxt -= self.estimator(xt, mask, mu, time, 320 | spk) * (1.0 + kappa) * (beta_t * h) 321 | dxt += torch.randn_like(z, device=z.device) * sigma 322 | xt = (xt - dxt) * mask 323 | return xt 324 | 325 | @torch.no_grad() 326 | def reverse_diffusion_original(self, 327 | z, 328 | mask, 329 | mu, 330 | n_timesteps, 331 | stoc=False, 332 | spk=None): 333 | h = 1.0 / n_timesteps 334 | xt = z * mask 335 | for i in range(n_timesteps): 336 | t = (1.0 - (i + 0.5) * h) * torch.ones( 337 | z.shape[0], dtype=z.dtype, device=z.device) 338 | time = t.unsqueeze(-1).unsqueeze(-1) 339 | noise_t = get_noise(time, 340 | self.beta_min, 341 | self.beta_max, 342 | cumulative=False) 343 | if stoc: # adds stochastic term 344 | dxt_det = 0.5 * (mu - xt) - self.estimator( 345 | xt, mask, mu, t, spk) 346 | dxt_det = dxt_det * noise_t * h 347 | dxt_stoc = torch.randn(z.shape, 348 | dtype=z.dtype, 349 | device=z.device, 350 | requires_grad=False) 351 | dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h) 352 | dxt = dxt_det + dxt_stoc 353 | else: 354 | dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk)) 355 | dxt = dxt * noise_t * h 356 | xt = (xt - dxt) * mask 357 | return xt 358 | 359 | @torch.no_grad() 360 | def forward(self, 361 | z, 362 | mask, 363 | 
mu, 364 | n_timesteps, 365 | stoc=False, 366 | spk=None, 367 | solver='original'): 368 | if solver == 'original': 369 | return self.reverse_diffusion_original(z, mask, mu, n_timesteps, 370 | stoc, spk) 371 | elif solver == 'dpm': 372 | return self.reverse_diffusion_dpm_solver(z, mask, mu, n_timesteps, 373 | stoc, spk) 374 | elif solver == 'ml': 375 | return self.reverse_diffusion_ml(z, mask, mu, n_timesteps, stoc, 376 | spk) 377 | else: 378 | raise ValueError(f'Wrong solver:{solver}!') 379 | 380 | @torch.no_grad() 381 | def reverse_diffusion_dpm_solver(self, 382 | z, 383 | mask, 384 | mu, 385 | n_timesteps, 386 | stoc, 387 | spk=None): 388 | xt = z * mask 389 | yt = xt - mu 390 | T = 1 391 | eps = 1e-3 392 | time = self.dpm_solver_sch.get_time_steps(T, eps, n_timesteps) 393 | for i in range(n_timesteps): 394 | s = torch.ones((xt.shape[0], )).to(xt.device) * time[i] 395 | t = torch.ones((xt.shape[0], )).to(xt.device) * time[i + 1] 396 | ns = self.dpm_solver_sch 397 | lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) 398 | h = lambda_t - lambda_s 399 | log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff( 400 | s), ns.marginal_log_mean_coeff(t) 401 | sigma_t = ns.marginal_std(t) 402 | phi_1 = torch.expm1(h) 403 | 404 | noise_s = self.estimator(yt + mu, mask, mu, s, spk) 405 | lt = 1 - torch.exp( 406 | -get_noise(s, self.beta_min, self.beta_max, cumulative=True)) 407 | a = torch.exp(log_alpha_t - log_alpha_s) 408 | b = sigma_t * phi_1 * torch.sqrt(lt) 409 | yt = a * yt + (b * noise_s) 410 | xt = yt + mu 411 | return xt 412 | 413 | def loss_t(self, x0, mask, mu, t, spk=None): 414 | """ 415 | Args: 416 | x0 (_type_): shape (b,80,tx) 417 | mask (_type_): shape (b,1,tx) 418 | mu (_type_): shape (b,80,tx) 419 | t (_type_): shape (b) 420 | spk (_type_, optional): 421 | """ 422 | xt, z = self.forward_diffusion(x0, mask, mu, t) 423 | time = t.unsqueeze(-1).unsqueeze(-1) 424 | cum_noise = get_noise(time, 425 | self.beta_min, 426 | self.beta_max, 427 | cumulative=True) 428 | noise_estimation = self.estimator(xt, mask, mu, t, spk) 429 | noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise)) 430 | loss = torch.sum( 431 | (noise_estimation + z)**2) / (torch.sum(mask) * self.n_feats) 432 | return loss, xt 433 | 434 | def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5): 435 | """ 436 | Args: 437 | x0 (_type_): shape (b,80,tx) 438 | mask (_type_): shape (b,1,tx) 439 | mu (_type_): shape (b,80,tx) 440 | """ 441 | t = torch.rand(x0.shape[0], 442 | dtype=x0.dtype, 443 | device=x0.device, 444 | requires_grad=False) 445 | t = torch.clamp(t, offset, 1.0 - offset) 446 | # t: (b) 447 | return self.loss_t(x0, mask, mu, t, spk) 448 | -------------------------------------------------------------------------------- /LightGrad/dpm_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class NoiseScheduleVP: 5 | 6 | def __init__(self, beta_0=0.05, beta_1=20): 7 | self.beta_0 = beta_0 8 | self.beta_1 = beta_1 9 | self.T = 1. 10 | 11 | def marginal_log_mean_coeff(self, t): 12 | return -0.25 * t**2 * (self.beta_1 - 13 | self.beta_0) - 0.5 * t * self.beta_0 14 | 15 | def marginal_std(self, t): 16 | return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) 17 | 18 | def marginal_lambda(self, t): 19 | log_mean_coeff = self.marginal_log_mean_coeff(t) 20 | log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) 21 | return log_mean_coeff - log_std 22 | 23 | def inverse_lambda(self, lamb): 24 | tmp = 2. 
* (self.beta_1 - self.beta_0) * torch.logaddexp( 25 | -2. * lamb, 26 | torch.zeros((1, )).to(lamb)) 27 | Delta = self.beta_0**2 + tmp 28 | return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - 29 | self.beta_0) 30 | 31 | def get_time_steps(self, t_T, t_0, N): 32 | lambda_T = self.marginal_lambda(torch.tensor(t_T)) 33 | lambda_0 = self.marginal_lambda(torch.tensor(t_0)) 34 | logSNR_steps = torch.linspace(lambda_T, lambda_0, N + 1) 35 | return self.inverse_lambda(logSNR_steps) 36 | -------------------------------------------------------------------------------- /LightGrad/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 8 | 9 | import math 10 | import random 11 | 12 | import torch 13 | 14 | import monotonic_align 15 | from .base import BaseModule 16 | from .text_encoder import TextEncoder 17 | from .diffusion import Diffusion 18 | from .utils import (sequence_mask, generate_path, duration_loss, 19 | fix_len_compatibility) 20 | 21 | 22 | class LightGrad(BaseModule): 23 | 24 | def __init__(self, n_vocab, n_spks, spk_emb_dim, n_enc_channels, 25 | filter_channels, filter_channels_dp, n_heads, n_enc_layers, 26 | enc_kernel, enc_dropout, window_size, n_feats, dec_dim, 27 | beta_min, beta_max, pe_scale): 28 | super().__init__() 29 | self.n_vocab = n_vocab 30 | self.n_spks = n_spks 31 | self.spk_emb_dim = spk_emb_dim 32 | self.n_enc_channels = n_enc_channels 33 | self.filter_channels = filter_channels 34 | self.filter_channels_dp = filter_channels_dp 35 | self.n_heads = n_heads 36 | self.n_enc_layers = n_enc_layers 37 | self.enc_kernel = enc_kernel 38 | self.enc_dropout = enc_dropout 39 | self.window_size = window_size 40 | self.n_feats = n_feats 41 | self.dec_dim = dec_dim 42 | self.beta_min = beta_min 43 | self.beta_max = beta_max 44 | self.pe_scale = pe_scale 45 | 46 | if n_spks > 1: 47 | self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) 48 | self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels, 49 | filter_channels, filter_channels_dp, 50 | n_heads, n_enc_layers, enc_kernel, 51 | enc_dropout, window_size) 52 | self.decoder = Diffusion(n_feats, dec_dim, n_spks, spk_emb_dim, 53 | beta_min, beta_max, pe_scale) 54 | 55 | @classmethod 56 | def build_model(cls, config, vocab_size): 57 | return cls(vocab_size, 1, None, config['n_enc_channels'], 58 | config['filter_channels'], config['filter_channels_dp'], 59 | config['n_heads'], config['n_enc_layers'], 60 | config['enc_kernel'], config['enc_dropout'], 61 | config['window_size'], config['n_mels'], config['dec_dim'], 62 | config['beta_min'], config['beta_max'], config['pe_scale']) 63 | 64 | @torch.no_grad() 65 | def forward_streaming(self, 66 | x, 67 | x_lengths, 68 | n_timesteps, 69 | temperature=1.0, 70 | stoc=False, 71 | spk=None, 72 | length_scale=1.0, 73 | out_size=None, 74 | solver='original'): 75 | # if chunk_method == 'simple': 76 | # return self.forward_streaming_simple_chunk(x, x_lengths, 77 | # n_timesteps, 78 | # temperature, stoc, spk, 79 | # length_scale, out_size, 80 | # solver) 81 | # elif chunk_method == 'padding': 82 | # return 
self.forward_streaming_padding_chunk( 83 | # x, x_lengths, n_timesteps, temperature, stoc, spk, 84 | # length_scale, out_size, solver) 85 | # else: 86 | # raise ValueError(f'Wrong chunk method: {chunk_method}!') 87 | 88 | #@torch.no_grad() 89 | #def forward_streaming_padding_chunk(self, x, x_lengths, n_timesteps, 90 | # temperature, stoc, spk, length_scale, 91 | # out_size, solver): 92 | """ 93 | Generates mel-spectrogram from text. Returns: 94 | 1. encoder outputs 95 | 2. decoder outputs 96 | 3. generated alignment 97 | 98 | Args: 99 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 100 | x_lengths (torch.Tensor): lengths of texts in batch. 101 | n_timesteps (int): number of steps to use for reverse diffusion in decoder. 102 | temperature (float, optional): controls variance of terminal distribution. 103 | stoc (bool, optional): flag that adds stochastic term to the decoder sampler. 104 | Usually, does not provide synthesis improvements. 105 | length_scale (float, optional): controls speech pace. 106 | Increase value to slow down generated speech and vice versa. 107 | """ 108 | x, x_lengths = self.relocate_input([x, x_lengths]) 109 | assert x.shape[0] == 1 # streaming inference only support batch size 1 110 | if self.n_spks > 1: 111 | # Get speaker embedding 112 | spk = self.spk_emb(spk) 113 | 114 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 115 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk) 116 | w = (torch.exp(logw) - 1) * x_mask 117 | w_ceil = torch.ceil(w * length_scale).squeeze(1) 118 | y_lengths = torch.clamp(torch.sum(w_ceil, dim=1), min=0).long() 119 | y_max_length = int(y_lengths.max()) 120 | y_max_length_ = fix_len_compatibility(y_max_length) 121 | out_size = fix_len_compatibility(out_size) 122 | 123 | # Using obtained durations `w` construct alignment map `attn` 124 | y_mask = sequence_mask(y_lengths, 125 | y_max_length_).unsqueeze(1).to(x_mask.dtype) 126 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 127 | attn = generate_path(w_ceil, attn_mask.squeeze(1)) 128 | (num_chunks, chunk_lengths, start_frames, end_frames, lpad, 129 | rpad) = generate_idxs(w_ceil, out_size) 130 | for i in range(num_chunks): 131 | lp = lpad[i] 132 | rp = rpad[i] 133 | l = chunk_lengths[i] 134 | 135 | # start_idx should be divisible by downsampling factor 136 | start_idx = fix_len_compatibility(start_frames[i] - lp, 137 | type='floor') 138 | # adjust left padding part according to start_idx 139 | lp += start_frames[i] - lp - start_idx 140 | end_idx = min(y_max_length_, 141 | fix_len_compatibility(end_frames[i] + rp)) 142 | 143 | y_mask_cut = y_mask[:, :, start_idx:end_idx] 144 | attn_cut = attn[:, :, start_idx:end_idx] 145 | mu_y_cut = torch.matmul(attn_cut.transpose(1, 2), 146 | mu_x.transpose(1, 2)) 147 | mu_y_cut = mu_y_cut.transpose(1, 2) 148 | z_cut = mu_y_cut + torch.randn_like( 149 | mu_y_cut, device=mu_y_cut.device) / temperature 150 | decoder_output_cut = self.decoder(z_cut, y_mask_cut, mu_y_cut, 151 | n_timesteps, stoc, spk, solver) 152 | yield (mu_y_cut[:, :, lp:lp + l], 153 | decoder_output_cut[:, :, lp:lp + l], attn_cut[:, :, 154 | lp:lp + l]) 155 | 156 | @torch.no_grad() 157 | def forward(self, 158 | x, 159 | x_lengths, 160 | n_timesteps, 161 | temperature=1.0, 162 | stoc=False, 163 | spk=None, 164 | length_scale=1.0, 165 | solver='original'): 166 | """ 167 | Generates mel-spectrogram from text. Returns: 168 | 1. encoder outputs 169 | 2. decoder outputs 170 | 3. 
generated alignment 171 | 172 | Args: 173 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 174 | x_lengths (torch.Tensor): lengths of texts in batch. 175 | n_timesteps (int): number of steps to use for reverse diffusion in decoder. 176 | temperature (float, optional): controls variance of terminal distribution. 177 | stoc (bool, optional): flag that adds stochastic term to the decoder sampler. 178 | Usually, does not provide synthesis improvements. 179 | length_scale (float, optional): controls speech pace. 180 | Increase value to slow down generated speech and vice versa. 181 | """ 182 | x, x_lengths = self.relocate_input([x, x_lengths]) 183 | 184 | if self.n_spks > 1: 185 | # Get speaker embedding 186 | spk = self.spk_emb(spk) 187 | 188 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 189 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk) 190 | 191 | #w = torch.exp(logw) * x_mask 192 | w = (torch.exp(logw) - 1) * x_mask 193 | w_ceil = torch.ceil(w * length_scale) 194 | # y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 195 | y_lengths = torch.clamp(torch.sum(w_ceil, [1, 2]), min=0).long() 196 | y_max_length = int(y_lengths.max()) 197 | y_max_length_ = fix_len_compatibility(y_max_length) 198 | 199 | # Using obtained durations `w` construct alignment map `attn` 200 | y_mask = sequence_mask(y_lengths, 201 | y_max_length_).unsqueeze(1).to(x_mask.dtype) 202 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 203 | attn = generate_path(w_ceil.squeeze(1), 204 | attn_mask.squeeze(1)).unsqueeze(1) 205 | 206 | # Align encoded text and get mu_y 207 | mu_y = torch.matmul( 208 | attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 209 | mu_y = mu_y.transpose(1, 2) 210 | encoder_outputs = mu_y[:, :, :y_max_length] 211 | 212 | # Sample latent representation from terminal distribution N(mu_y, I) 213 | z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature 214 | # Generate sample by performing reverse dynamics 215 | decoder_outputs = self.decoder(z, y_mask, mu_y, n_timesteps, stoc, spk, 216 | solver) 217 | decoder_outputs = decoder_outputs[:, :, :y_max_length] 218 | 219 | return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length] 220 | 221 | def compute_loss(self, 222 | x, 223 | x_lengths, 224 | y, 225 | y_lengths, 226 | spk=None, 227 | out_size=None): 228 | """ 229 | Computes 3 losses: 230 | 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). 231 | 2. prior loss: loss between mel-spectrogram and encoder outputs. 232 | 3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder. 233 | 234 | Args: 235 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 236 | x_lengths (torch.Tensor): lengths of texts in batch. 237 | y (torch.Tensor): batch of corresponding mel-spectrograms. 238 | y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. 239 | out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. 240 | Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. 
241 | """ 242 | out_size = fix_len_compatibility(out_size) 243 | x, x_lengths, y, y_lengths = self.relocate_input( 244 | [x, x_lengths, y, y_lengths]) 245 | 246 | if self.n_spks > 1: 247 | # Get speaker embedding 248 | spk = self.spk_emb(spk) 249 | 250 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 251 | mu_x, logw, x_mask = self.encoder(x, x_lengths, spk) 252 | # mu_x: (b,80,tx) 253 | # logw: (b,1,tx) 254 | # y: (b,80,ty) 255 | y_max_length = y.shape[-1] 256 | 257 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) 258 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 259 | 260 | # Use MAS to find most likely alignment `attn` between text and mel-spectrogram 261 | with torch.no_grad(): 262 | # sum(-0.5*log(2*pi*sigma_i^2)) 263 | const = -0.5 * math.log(2 * math.pi) * self.n_feats 264 | # factor: (b,80,tx) 265 | factor = -0.5 * torch.ones( 266 | mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) 267 | # y_square: (b,tx,ty), y_square_{i,j}: mu_i is aligned with y_j 268 | # sum(-0.5*y_i^2*sigma^(-2)) 269 | y_square = torch.matmul(factor.transpose(1, 2), y**2) 270 | # y_mu_double: (b,tx,ty), sum(y_i*mu_i*sigma_i^(-2)) 271 | y_mu_double = torch.matmul(mu_x.transpose(1, 2), y) 272 | # mu_square: (b,tx,1), -0.5*sum(mu_i^2*sigma_i^(-2)) 273 | mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) 274 | log_prior = y_square + y_mu_double + mu_square + const 275 | 276 | attn = monotonic_align.maximum_path( 277 | log_prior.permute(0, 2, 1), 278 | attn_mask.squeeze(1).permute(0, 2, 1)).permute(0, 2, 1) 279 | # attn: (b,tx,ty) 280 | # attn = attn.detach() 281 | 282 | # Compute loss between predicted log-scaled durations and those obtained from MAS 283 | logw_ = torch.log(1 + torch.sum(attn.unsqueeze(1), -1)) * x_mask 284 | # logw_ = torch.log(1 + torch.sum(attn.unsqueeze(1), -1)) * x_mask 285 | dur_loss = duration_loss(logw, logw_, x_lengths) 286 | 287 | # Cut a small segment of mel-spectrogram in order to increase batch size 288 | if not isinstance(out_size, type(None)): 289 | max_offset = (y_lengths - out_size).clamp(0) 290 | offset_ranges = list( 291 | zip([0] * max_offset.shape[0], 292 | max_offset.cpu().numpy())) 293 | out_offset = torch.LongTensor([ 294 | torch.tensor( 295 | random.choice(range(start, end)) if end > start else 0) 296 | for start, end in offset_ranges 297 | ]).to(y_lengths) 298 | 299 | attn_cut = torch.zeros(attn.shape[0], 300 | attn.shape[1], 301 | out_size, 302 | dtype=attn.dtype, 303 | device=attn.device) 304 | y_cut = torch.zeros(y.shape[0], 305 | self.n_feats, 306 | out_size, 307 | dtype=y.dtype, 308 | device=y.device) 309 | y_cut_lengths = [] 310 | for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): 311 | y_cut_length = out_size + (y_lengths[i] - out_size).clamp( 312 | None, 0) 313 | y_cut_lengths.append(y_cut_length) 314 | cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length 315 | y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] 316 | attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] 317 | y_cut_lengths = torch.LongTensor(y_cut_lengths) 318 | y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) 319 | 320 | attn = attn_cut 321 | y = y_cut 322 | y_mask = y_cut_mask 323 | 324 | # Align encoded text with mel-spectrogram and get mu_y segment 325 | mu_y = torch.matmul(attn.transpose(1, 2), mu_x.transpose(1, 2)) 326 | mu_y = mu_y.transpose(1, 2) 327 | # mu_y: (b,80,t_y_clip) 328 | # Compute loss of score-based decoder 329 | diff_loss, xt = self.decoder.compute_loss(y, y_mask, 
mu_y, spk) 330 | 331 | # Compute loss between aligned encoder outputs and mel-spectrogram 332 | # prior_loss: sum(0.5*log(2*pi*sigma_i^2)+0.5*(y_i-mu_i)^2*sigma_i^(-2)) 333 | prior_loss = torch.sum(0.5 * ((y - mu_y)**2 + math.log(2 * math.pi)) * 334 | y_mask) 335 | prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) 336 | 337 | return dur_loss, prior_loss, diff_loss 338 | 339 | 340 | def generate_idxs(durations, chunk_frames): 341 | """ 342 | 343 | Args: 344 | durations (_type_): duration for each token. Shape (1,tx) 345 | chunk_frames (_type_): frames per chunk. 346 | 347 | """ 348 | durations = durations.flatten() 349 | cum_sum = durations.cumsum(dim=0).int() 350 | 351 | idx = torch.div(cum_sum, chunk_frames, rounding_mode='trunc').int() 352 | start_token_idx = 0 353 | num_chunks = idx.max() + 1 354 | lengths = torch.zeros((num_chunks), 355 | device=durations.device, 356 | dtype=torch.int) 357 | lpad = torch.zeros_like(lengths, device=lengths.device) 358 | rpad = torch.zeros_like(lengths, device=lengths.device) 359 | for i in range(num_chunks): 360 | duration_chunk_mask = i == idx 361 | duration_chunk_tokens = duration_chunk_mask.sum() 362 | duration_chunk = durations * duration_chunk_mask 363 | duration_chunk_frames = duration_chunk.sum() 364 | if i > 0: 365 | lpad[i] = durations[start_token_idx - 1] 366 | if i < num_chunks - 1: 367 | rpad[i] = durations[start_token_idx + duration_chunk_tokens] 368 | lengths[i] = duration_chunk_frames 369 | start_token_idx += duration_chunk_tokens 370 | end_frames = lengths.cumsum(dim=0) 371 | start_frames = end_frames - lengths 372 | return num_chunks, lengths, start_frames, end_frames, lpad, rpad 373 | -------------------------------------------------------------------------------- /LightGrad/text_encoder.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import math 4 | 5 | import torch 6 | 7 | from .base import BaseModule 8 | from .utils import sequence_mask, convert_pad_shape 9 | 10 | 11 | class LayerNorm(BaseModule): 12 | def __init__(self, channels, eps=1e-4): 13 | super(LayerNorm, self).__init__() 14 | self.channels = channels 15 | self.eps = eps 16 | 17 | self.gamma = torch.nn.Parameter(torch.ones(channels)) 18 | self.beta = torch.nn.Parameter(torch.zeros(channels)) 19 | 20 | def forward(self, x): 21 | n_dims = len(x.shape) 22 | mean = torch.mean(x, 1, keepdim=True) 23 | variance = torch.mean((x - mean)**2, 1, keepdim=True) 24 | 25 | x = (x - mean) * torch.rsqrt(variance + self.eps) 26 | 27 | shape = [1, -1] + [1] * (n_dims - 2) 28 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 29 | return x 30 | 31 | 32 | class ConvReluNorm(BaseModule): 33 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, 34 | n_layers, p_dropout): 35 | super(ConvReluNorm, self).__init__() 36 | self.in_channels = in_channels 37 | self.hidden_channels = hidden_channels 38 | self.out_channels = out_channels 39 | self.kernel_size = kernel_size 40 | self.n_layers = n_layers 41 | self.p_dropout = p_dropout 42 | 43 | self.conv_layers = torch.nn.ModuleList() 44 | self.norm_layers = torch.nn.ModuleList() 45 | self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, 46 | kernel_size, padding=kernel_size//2)) 47 | self.norm_layers.append(LayerNorm(hidden_channels)) 48 | self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) 49 | for _ in range(n_layers - 1): 50 | 
self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, 51 | kernel_size, padding=kernel_size//2)) 52 | self.norm_layers.append(LayerNorm(hidden_channels)) 53 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) 54 | self.proj.weight.data.zero_() 55 | self.proj.bias.data.zero_() 56 | 57 | def forward(self, x, x_mask): 58 | x_org = x 59 | for i in range(self.n_layers): 60 | x = self.conv_layers[i](x * x_mask) 61 | x = self.norm_layers[i](x) 62 | x = self.relu_drop(x) 63 | x = x_org + self.proj(x) 64 | return x * x_mask 65 | 66 | 67 | class DurationPredictor(BaseModule): 68 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): 69 | super(DurationPredictor, self).__init__() 70 | self.in_channels = in_channels 71 | self.filter_channels = filter_channels 72 | self.p_dropout = p_dropout 73 | 74 | self.drop = torch.nn.Dropout(p_dropout) 75 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, 76 | kernel_size, padding=kernel_size//2) 77 | self.norm_1 = LayerNorm(filter_channels) 78 | self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, 79 | kernel_size, padding=kernel_size//2) 80 | self.norm_2 = LayerNorm(filter_channels) 81 | self.proj = torch.nn.Conv1d(filter_channels, 1, 1) 82 | 83 | def forward(self, x, x_mask): 84 | x = self.conv_1(x * x_mask) 85 | x = torch.relu(x) 86 | x = self.norm_1(x) 87 | x = self.drop(x) 88 | x = self.conv_2(x * x_mask) 89 | x = torch.relu(x) 90 | x = self.norm_2(x) 91 | x = self.drop(x) 92 | x = self.proj(x * x_mask) 93 | return x * x_mask 94 | 95 | 96 | class MultiHeadAttention(BaseModule): 97 | def __init__(self, channels, out_channels, n_heads, window_size=None, 98 | heads_share=True, p_dropout=0.0, proximal_bias=False, 99 | proximal_init=False): 100 | super(MultiHeadAttention, self).__init__() 101 | assert channels % n_heads == 0 102 | 103 | self.channels = channels 104 | self.out_channels = out_channels 105 | self.n_heads = n_heads 106 | self.window_size = window_size 107 | self.heads_share = heads_share 108 | self.proximal_bias = proximal_bias 109 | self.p_dropout = p_dropout 110 | self.attn = None 111 | 112 | self.k_channels = channels // n_heads 113 | self.conv_q = torch.nn.Conv1d(channels, channels, 1) 114 | self.conv_k = torch.nn.Conv1d(channels, channels, 1) 115 | self.conv_v = torch.nn.Conv1d(channels, channels, 1) 116 | if window_size is not None: 117 | n_heads_rel = 1 if heads_share else n_heads 118 | rel_stddev = self.k_channels**-0.5 119 | self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, 120 | window_size * 2 + 1, self.k_channels) * rel_stddev) 121 | self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, 122 | window_size * 2 + 1, self.k_channels) * rel_stddev) 123 | self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) 124 | self.drop = torch.nn.Dropout(p_dropout) 125 | 126 | torch.nn.init.xavier_uniform_(self.conv_q.weight) 127 | torch.nn.init.xavier_uniform_(self.conv_k.weight) 128 | if proximal_init: 129 | self.conv_k.weight.data.copy_(self.conv_q.weight.data) 130 | self.conv_k.bias.data.copy_(self.conv_q.bias.data) 131 | torch.nn.init.xavier_uniform_(self.conv_v.weight) 132 | 133 | def forward(self, x, c, attn_mask=None): 134 | q = self.conv_q(x) 135 | k = self.conv_k(c) 136 | v = self.conv_v(c) 137 | 138 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 139 | 140 | x = self.conv_o(x) 141 | return x 142 | 143 | def attention(self, query, key, value, mask=None): 144 | b, d, t_s, t_t = (*key.size(), query.size(2)) 145 | query = query.view(b, 
self.n_heads, self.k_channels, t_t).transpose(2, 3) 146 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 147 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 148 | 149 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) 150 | if self.window_size is not None: 151 | assert t_s == t_t, "Relative attention is only available for self-attention." 152 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 153 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) 154 | rel_logits = self._relative_position_to_absolute_position(rel_logits) 155 | scores_local = rel_logits / math.sqrt(self.k_channels) 156 | scores = scores + scores_local 157 | if self.proximal_bias: 158 | assert t_s == t_t, "Proximal bias is only available for self-attention." 159 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, 160 | dtype=scores.dtype) 161 | if mask is not None: 162 | scores = scores.masked_fill(mask == 0, -1e4) 163 | p_attn = torch.nn.functional.softmax(scores, dim=-1) 164 | p_attn = self.drop(p_attn) 165 | output = torch.matmul(p_attn, value) 166 | if self.window_size is not None: 167 | relative_weights = self._absolute_position_to_relative_position(p_attn) 168 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 169 | output = output + self._matmul_with_relative_values(relative_weights, 170 | value_relative_embeddings) 171 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) 172 | return output, p_attn 173 | 174 | def _matmul_with_relative_values(self, x, y): 175 | ret = torch.matmul(x, y.unsqueeze(0)) 176 | return ret 177 | 178 | def _matmul_with_relative_keys(self, x, y): 179 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 180 | return ret 181 | 182 | def _get_relative_embeddings(self, relative_embeddings, length): 183 | pad_length = max(length - (self.window_size + 1), 0) 184 | slice_start_position = max((self.window_size + 1) - length, 0) 185 | slice_end_position = slice_start_position + 2 * length - 1 186 | if pad_length > 0: 187 | padded_relative_embeddings = torch.nn.functional.pad( 188 | relative_embeddings, convert_pad_shape([[0, 0], 189 | [pad_length, pad_length], [0, 0]])) 190 | else: 191 | padded_relative_embeddings = relative_embeddings 192 | used_relative_embeddings = padded_relative_embeddings[:, 193 | slice_start_position:slice_end_position] 194 | return used_relative_embeddings 195 | 196 | def _relative_position_to_absolute_position(self, x): 197 | batch, heads, length, _ = x.size() 198 | x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 199 | x_flat = x.view([batch, heads, length * 2 * length]) 200 | x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]])) 201 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 202 | return x_final 203 | 204 | def _absolute_position_to_relative_position(self, x): 205 | batch, heads, length, _ = x.size() 206 | x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 207 | x_flat = x.view([batch, heads, length**2 + length*(length - 1)]) 208 | x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 209 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 210 | return x_final 211 | 212 | def _attention_bias_proximal(self, length): 213 | r = torch.arange(length, dtype=torch.float32) 214 
| diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 215 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 216 | 217 | 218 | class FFN(BaseModule): 219 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, 220 | p_dropout=0.0): 221 | super(FFN, self).__init__() 222 | self.in_channels = in_channels 223 | self.out_channels = out_channels 224 | self.filter_channels = filter_channels 225 | self.kernel_size = kernel_size 226 | self.p_dropout = p_dropout 227 | 228 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, 229 | padding=kernel_size//2) 230 | self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, 231 | padding=kernel_size//2) 232 | self.drop = torch.nn.Dropout(p_dropout) 233 | 234 | def forward(self, x, x_mask): 235 | x = self.conv_1(x * x_mask) 236 | x = torch.relu(x) 237 | x = self.drop(x) 238 | x = self.conv_2(x * x_mask) 239 | return x * x_mask 240 | 241 | 242 | class Encoder(BaseModule): 243 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, 244 | kernel_size=1, p_dropout=0.0, window_size=None, **kwargs): 245 | super(Encoder, self).__init__() 246 | self.hidden_channels = hidden_channels 247 | self.filter_channels = filter_channels 248 | self.n_heads = n_heads 249 | self.n_layers = n_layers 250 | self.kernel_size = kernel_size 251 | self.p_dropout = p_dropout 252 | self.window_size = window_size 253 | 254 | self.drop = torch.nn.Dropout(p_dropout) 255 | self.attn_layers = torch.nn.ModuleList() 256 | self.norm_layers_1 = torch.nn.ModuleList() 257 | self.ffn_layers = torch.nn.ModuleList() 258 | self.norm_layers_2 = torch.nn.ModuleList() 259 | for _ in range(self.n_layers): 260 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, 261 | n_heads, window_size=window_size, p_dropout=p_dropout)) 262 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 263 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, 264 | filter_channels, kernel_size, p_dropout=p_dropout)) 265 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 266 | 267 | def forward(self, x, x_mask): 268 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 269 | for i in range(self.n_layers): 270 | x = x * x_mask 271 | y = self.attn_layers[i](x, x, attn_mask) 272 | y = self.drop(y) 273 | x = self.norm_layers_1[i](x + y) 274 | y = self.ffn_layers[i](x, x_mask) 275 | y = self.drop(y) 276 | x = self.norm_layers_2[i](x + y) 277 | x = x * x_mask 278 | return x 279 | 280 | 281 | class TextEncoder(BaseModule): 282 | def __init__(self, n_vocab, n_feats, n_channels, filter_channels, 283 | filter_channels_dp, n_heads, n_layers, kernel_size, 284 | p_dropout, window_size=None, spk_emb_dim=64, n_spks=1): 285 | super(TextEncoder, self).__init__() 286 | self.n_vocab = n_vocab 287 | self.n_feats = n_feats 288 | self.n_channels = n_channels 289 | self.filter_channels = filter_channels 290 | self.filter_channels_dp = filter_channels_dp 291 | self.n_heads = n_heads 292 | self.n_layers = n_layers 293 | self.kernel_size = kernel_size 294 | self.p_dropout = p_dropout 295 | self.window_size = window_size 296 | self.spk_emb_dim = spk_emb_dim 297 | self.n_spks = n_spks 298 | 299 | self.emb = torch.nn.Embedding(n_vocab, n_channels) 300 | torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5) 301 | 302 | self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, 303 | kernel_size=5, n_layers=3, p_dropout=0.5) 304 | 305 | self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 
else 0), filter_channels, n_heads, n_layers, 306 | kernel_size, 0, window_size=window_size) 307 | 308 | self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1) 309 | self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, 310 | 5, 0.2) 311 | 312 | def forward(self, x, x_lengths, spk=None): 313 | x = self.emb(x) * math.sqrt(self.n_channels) 314 | x = torch.transpose(x, 1, -1) 315 | # x: (b,d,tx) 316 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 317 | 318 | x = self.prenet(x, x_mask) 319 | if self.n_spks > 1: 320 | x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) 321 | x = self.encoder(x, x_mask) 322 | # x: (b,n_channels,tx) 323 | mu = self.proj_m(x) * x_mask 324 | # mu: (b,n_feats,tx) 325 | x_dp = torch.detach(x) 326 | 327 | logw = self.proj_w(x_dp, x_mask) 328 | # logw: (b,d,1) 329 | return mu, logw, x_mask 330 | -------------------------------------------------------------------------------- /LightGrad/utils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import torch 4 | import librosa 5 | import numpy as np 6 | import math 7 | 8 | 9 | def sequence_mask(length, max_length=None): 10 | """Generating mask tensor according to `length`. 11 | 12 | Args: 13 | length (Tensor): length. 14 | max_length (int, optional): max length. Defaults to None. 15 | 16 | Returns: 17 | Tensor: mask tensor of shape (b,t), where t is the maximum of `length`. 18 | True indicates a non-padding element. 19 | """ 20 | if max_length is None: 21 | max_length = length.max() 22 | x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) 23 | return x.unsqueeze(0) < length.unsqueeze(1) 24 | 25 | 26 | def fix_len_compatibility(length, num_downsamplings_in_unet=2, type='ceil'): 27 | factor = 2**num_downsamplings_in_unet 28 | if type == 'ceil': 29 | return int(math.ceil(length / factor) * factor) 30 | elif type == 'floor': 31 | return int(math.floor(length / factor) * factor) 32 | else: 33 | raise ValueError(f'Wrong type: {type}') 34 | 35 | 36 | def convert_pad_shape(pad_shape): 37 | l = pad_shape[::-1] 38 | pad_shape = [item for sublist in l for item in sublist] 39 | return pad_shape 40 | 41 | 42 | def generate_path(duration, mask): 43 | device = duration.device 44 | 45 | b, t_x, t_y = mask.shape 46 | cum_duration = torch.cumsum(duration, 1) 47 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 48 | 49 | cum_duration_flat = cum_duration.view(b * t_x) 50 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 51 | path = path.view(b, t_x, t_y) 52 | path = path - torch.nn.functional.pad( 53 | path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 54 | path = path * mask 55 | return path 56 | 57 | 58 | def duration_loss(logw, logw_, lengths): 59 | loss = torch.sum((logw - logw_)**2) / torch.sum(lengths) 60 | return loss 61 | 62 | 63 | def get_mcd(ground_truth_mel, predicted_mel): 64 | """Getting MCD from dtw. 65 | 66 | Args: 67 | ground_truth_mel: Ground truth mel. Shape (mel_d,t1) 68 | predicted_mel: Predicted mel. 
Shape (mel_d,t2) 69 | """ 70 | cost = librosa.sequence.dtw(ground_truth_mel, 71 | predicted_mel, 72 | backtrack=False) 73 | return cost[-1, -1] / ground_truth_mel.shape[1] 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightGrad: Lightweight Diffusion Probabilistic Model for Text-to-speech 2 | Demos are available at: https://thuhcsi.github.io/LightGrad/ 3 | 4 | ## Setup Environment 5 | 6 | Install python 3.10. 7 | 8 | Then, run: 9 | ```bash 10 | git clone --recursive https://github.com/thuhcsi/LightGrad.git 11 | python -m pip install -r requirements.txt 12 | ``` 13 | 14 | ## Training 15 | ### Preprocess for BZNSYP 16 | 17 | Download dataset from [url](https://www.data-baker.com/data/index/TNtts). 18 | Run 19 | ```bash 20 | python preprocess.py bznsyp [PATH_TO_DIRECTORY_CONTAINING_DATASET] \ 21 | [PATH_TO_DIRECTORY_FOR_SAVING_PREPROCESS_RESULTS] \ 22 | --test_sample_count 200 --valid_sample_count 200 23 | ``` 24 | This will produce `phn2id.json`, `train_dataset.json`, `test_dataset.json`, `valid_dataset.json` in `[PATH_TO_DIRECTORY_FOR_SAVING_PREPROCESS_RESULTS]`. 25 | 26 | ### Preprocess for LJSpeech 27 | 28 | Download dataset from [url](https://keithito.com/LJ-Speech-Dataset/). 29 | Run 30 | ```bash 31 | python preprocess.py ljspeech [PATH_TO_DIRECTORY_CONTAINING_DATASET] \ 32 | [PATH_TO_DIRECTORY_FOR_SAVING_PREPROCESS_RESULTS] \ 33 | --test_sample_count 200 --valid_sample_count 200 34 | ``` 35 | This will produce `phn2id.json`, `train_dataset.json`, `test_dataset.json`, `valid_dataset.json` in `[PATH_TO_DIRECTORY_FOR_SAVING_PREPROCESS_RESULTS]`. 36 | 37 | ### Training for BZNSYP 38 | 39 | Edit `config/bznsyp_config.yaml`, set `train_datalist_path`, `valid_datalist_path`, `phn2id_path` and `log_dir`. 40 | Run: 41 | ```bash 42 | python train.py -c config/bznsyp_config.yaml 43 | ``` 44 | 45 | ### Training for LJSpeech 46 | 47 | Edit `config/ljspeech_config.yaml`, set `train_datalist_path`, `valid_datalist_path`, `phn2id_path` and `log_dir`. 48 | Run: 49 | ```bash 50 | python train.py -c config/ljspeech_config.yaml 51 | ``` 52 | 53 | ## Inference 54 | 55 | Edit `inference.ipynb`. 56 | Set `HiFiGAN_CONFIG`, `HiFiGAN_ckpt` and `ckpt_path` to corresponding files, respectively. 57 | 58 | * Note: `add_blank` in `inference.ipynb` should be the same as that in `LightGrad/dataset.py`. 59 | 60 | ## References 61 | 62 | * Our model is based on [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones). 63 | * [HiFi-GAN](https://github.com/jik876/hifi-gan) is used as vocoder. 
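As a quick sanity check after running `preprocess.py` (see the Training section above), the generated JSON files can be inspected directly. This is a minimal sketch; the directory name is a placeholder for `[PATH_TO_DIRECTORY_FOR_SAVING_PREPROCESS_RESULTS]`:

```python
import json

result_dir = "preprocessed"  # placeholder path

# phn2id.json maps phoneme symbols to integer ids; ids start at 1, 0 is reserved for padding.
with open(f"{result_dir}/phn2id.json") as f:
    phn2id = json.load(f)
print(len(phn2id), "phoneme symbols")

# train_dataset.json is a list of utterances, each with name, wav_path, text and phonemes.
with open(f"{result_dir}/train_dataset.json") as f:
    train_dataset = json.load(f)
print(train_dataset[0]["name"], train_dataset[0]["phonemes"][:10])
```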
64 | -------------------------------------------------------------------------------- /config/bznsyp_config.yaml: -------------------------------------------------------------------------------- 1 | train_datalist_path: 2 | valid_datalist_path: 3 | phn2id_path: 4 | ckpt: 5 | 6 | n_mels: 80 7 | n_fft: 1024 8 | sample_rate: 22050 9 | hop_size: 256 10 | win_size: 1024 11 | f_min: 0 12 | f_max: 8000 13 | 14 | 15 | n_enc_channels: 128 16 | filter_channels: 512 17 | filter_channels_dp: 256 18 | n_enc_layers: 6 19 | enc_kernel: 3 20 | enc_dropout: 0.1 21 | n_heads: 2 22 | window_size: 4 23 | 24 | dec_dim: 64 25 | beta_min: 0.05 26 | beta_max: 20.0 27 | pe_scale: 1000 28 | 29 | log_dir: 30 | max_step: 1700000 31 | batch_size: 16 32 | learning_rate: 0.0001 33 | out_size: 2 34 | 35 | random_seed: 37 -------------------------------------------------------------------------------- /config/ljspeech_config.yaml: -------------------------------------------------------------------------------- 1 | train_datalist_path: 2 | valid_datalist_path: 3 | phn2id_path: 4 | ckpt: 5 | 6 | n_mels: 80 7 | n_fft: 1024 8 | sample_rate: 22050 9 | hop_size: 256 10 | win_size: 1024 11 | f_min: 0 12 | f_max: 8000 13 | 14 | 15 | n_enc_channels: 128 16 | filter_channels: 512 17 | filter_channels_dp: 256 18 | n_enc_layers: 6 19 | enc_kernel: 3 20 | enc_dropout: 0.1 21 | n_heads: 2 22 | window_size: 4 23 | 24 | dec_dim: 64 25 | beta_min: 0.05 26 | beta_max: 20.0 27 | pe_scale: 1000 28 | 29 | log_dir: 30 | max_step: 1700000 31 | batch_size: 16 32 | learning_rate: 0.0001 33 | out_size: 2 34 | 35 | random_seed: 37 -------------------------------------------------------------------------------- /dataset/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import yaml\n", 11 | "\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import torch\n", 14 | "\n", 15 | "from LightGrad import LightGrad\n", 16 | "\n", 17 | "\n", 18 | "import IPython.display as ipd" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "def convert_phn_to_id(phonemes, phn2id):\n", 28 | " \"\"\"\n", 29 | " phonemes: phonemes separated by ' '\n", 30 | " phn2id: phn2id dict\n", 31 | " \"\"\"\n", 32 | " return [phn2id[x] for x in [''] + phonemes.split(' ') + ['']]\n", 33 | "\n", 34 | "\n", 35 | "def text2phnid(text, phn2id, language='zh', add_blank=True):\n", 36 | " if language == 'zh':\n", 37 | " from text import G2pZh\n", 38 | " character2phn = G2pZh()\n", 39 | " pinyin, phonemes = character2phn.character2phoneme(text)\n", 40 | " if add_blank:\n", 41 | " phonemes = ' '.join(phonemes.split(' '))\n", 42 | " return pinyin, phonemes, convert_phn_to_id(phonemes, phn2id)\n", 43 | " elif language == 'en':\n", 44 | " from text import G2pEn\n", 45 | " word2phn = G2pEn()\n", 46 | " phonemes = word2phn(text)\n", 47 | " if add_blank:\n", 48 | " phonemes = ' '.join(phonemes)\n", 49 | " return phonemes, convert_phn_to_id(phonemes, phn2id)\n", 50 | " else:\n", 51 | " raise ValueError(\n", 52 | " 'Language should be zh (for Chinese) or en (for English)!')\n", 53 | "\n", 54 | "\n", 55 | "def 
plot_mel(tensors, titles):\n", 56 | " xlim = max([t.shape[1] for t in tensors])\n", 57 | " fig, axs = plt.subplots(nrows=len(tensors),\n", 58 | " ncols=1,\n", 59 | " figsize=(12, 9),\n", 60 | " constrained_layout=True)\n", 61 | " for i in range(len(tensors)):\n", 62 | " im = axs[i].imshow(tensors[i],\n", 63 | " aspect=\"auto\",\n", 64 | " origin=\"lower\",\n", 65 | " interpolation='none')\n", 66 | " plt.colorbar(im, ax=axs[i])\n", 67 | " axs[i].set_title(titles[i])\n", 68 | " axs[i].set_xlim([0, xlim])\n", 69 | " fig.canvas.draw()\n", 70 | " return plt" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Setup HiFi-GAN\n", 80 | "\n", 81 | "from hifi_gan import models, env\n", 82 | "\n", 83 | "HiFiGAN_CONFIG = ''\n", 84 | "HiFiGAN_ckpt = ''\n", 85 | "with open(HiFiGAN_CONFIG) as f:\n", 86 | " hifigan_hparams = env.AttrDict(json.load(f))\n", 87 | "\n", 88 | "generator = models.Generator(hifigan_hparams)\n", 89 | "\n", 90 | "generator.load_state_dict(torch.load(\n", 91 | " HiFiGAN_ckpt, map_location='cpu')['generator'])\n", 92 | "generator = generator.eval()\n", 93 | "generator.remove_weight_norm()\n", 94 | "\n", 95 | "\n", 96 | "def convert_mel_to_audio(mel):\n", 97 | " # only support batch size of 1\n", 98 | " assert mel.shape[0] == 1\n", 99 | " with torch.no_grad():\n", 100 | " audio = generator(mel).squeeze(1) # (b,t)\n", 101 | " return audio" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# inference for bznsyp\n", 111 | "\n", 112 | "N_STEP = 4\n", 113 | "TEMP = 1.5\n", 114 | "STREAMING_CLIP_SIZE = 0.5 # in seconds\n", 115 | "\n", 116 | "config_path = 'config/bznsyp_config.yaml'\n", 117 | "ckpt_path = ''\n", 118 | "\n", 119 | "print('loading ', ckpt_path)\n", 120 | "_, _, state_dict = torch.load(ckpt_path,\n", 121 | " map_location='cpu')\n", 122 | "\n", 123 | "\n", 124 | "with open(config_path) as f:\n", 125 | " config = yaml.load(f, yaml.SafeLoader)\n", 126 | "\n", 127 | "with open(config['phn2id_path']) as f:\n", 128 | " phn2id = json.load(f)\n", 129 | "vocab_size = len(phn2id) + 1\n", 130 | "\n", 131 | "model = LightGrad.build_model(config, vocab_size)\n", 132 | "model.load_state_dict(state_dict)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "text = \"做一个测试\"\n", 142 | "\n", 143 | "pinyin, phonemes, phnid = text2phnid(text, phn2id, 'zh')\n", 144 | "print(f'pinyin seq: {pinyin}')\n", 145 | "print(f'phoneme seq: {phonemes}')\n", 146 | "phnid_len = torch.tensor(len(phnid), dtype=torch.long).unsqueeze(0)\n", 147 | "phnid = torch.tensor(phnid).unsqueeze(0)\n", 148 | "\n", 149 | "mel_clips = []\n", 150 | "\n", 151 | "streaming_clip_frames = STREAMING_CLIP_SIZE * config['sample_rate'] // config[\n", 152 | " 'hop_size']\n", 153 | "\n", 154 | "for _, mel_clip, _ in model.forward_streaming(phnid,\n", 155 | " phnid_len,\n", 156 | " n_timesteps=N_STEP,\n", 157 | " temperature=TEMP,\n", 158 | " out_size=streaming_clip_frames,\n", 159 | " solver='dpm'):\n", 160 | " mel_clips.append(mel_clip)\n", 161 | "\n", 162 | "mel_prediction_streaming = torch.cat(mel_clips, dim=2)\n", 163 | "\n", 164 | "_, mel_prediction, _ = model.forward(phnid,\n", 165 | " phnid_len,\n", 166 | " n_timesteps=N_STEP,\n", 167 | " temperature=TEMP,\n", 168 | " solver='dpm')\n", 169 | "\n", 170 | "plot_mel([mel_prediction_streaming[0], 
mel_prediction[0]],\n", 171 | " ['streaming inference', 'non-streaming inference'])\n", 172 | "\n", 173 | "ipd.display(\n", 174 | " ipd.Audio(convert_mel_to_audio(mel_prediction_streaming), rate=22050))\n", 175 | "ipd.display(ipd.Audio(convert_mel_to_audio(mel_prediction), rate=22050))" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "# inference for ljspeech\n", 185 | "\n", 186 | "N_STEP = 4\n", 187 | "TEMP = 1.5\n", 188 | "STREAMING_CLIP_SIZE = 0.5 # in seconds\n", 189 | "\n", 190 | "config_path = 'config/ljspeech_config.yaml'\n", 191 | "ckpt_path = ''\n", 192 | "\n", 193 | "print('loading ', ckpt_path)\n", 194 | "_, _, state_dict = torch.load(ckpt_path,\n", 195 | " map_location='cpu')\n", 196 | "\n", 197 | "\n", 198 | "with open(config_path) as f:\n", 199 | " config = yaml.load(f, yaml.SafeLoader)\n", 200 | "\n", 201 | "with open(config['phn2id_path']) as f:\n", 202 | " phn2id = json.load(f)\n", 203 | "vocab_size = len(phn2id) + 1\n", 204 | "\n", 205 | "model = LightGrad.build_model(config, vocab_size)\n", 206 | "model.load_state_dict(state_dict)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "text = \"This is a test\"\n", 216 | "\n", 217 | "phonemes, phnid = text2phnid(text, phn2id, 'en')\n", 218 | "print(f'phoneme seq: {phonemes}', type(phonemes))\n", 219 | "phnid_len = torch.tensor(len(phnid), dtype=torch.long).unsqueeze(0)\n", 220 | "phnid = torch.tensor(phnid).unsqueeze(0)\n", 221 | "\n", 222 | "mel_clips = []\n", 223 | "\n", 224 | "streaming_clip_frames = STREAMING_CLIP_SIZE * config['sample_rate'] // config[\n", 225 | " 'hop_size']\n", 226 | "\n", 227 | "for _, mel_clip, _ in model.forward_streaming(phnid,\n", 228 | " phnid_len,\n", 229 | " n_timesteps=N_STEP,\n", 230 | " temperature=TEMP,\n", 231 | " out_size=streaming_clip_frames,\n", 232 | " solver='dpm'):\n", 233 | " mel_clips.append(mel_clip)\n", 234 | "\n", 235 | "mel_prediction_streaming = torch.cat(mel_clips, dim=2)\n", 236 | "\n", 237 | "_, mel_prediction, _ = model.forward(phnid,\n", 238 | " phnid_len,\n", 239 | " n_timesteps=N_STEP,\n", 240 | " temperature=TEMP,\n", 241 | " solver='dpm')\n", 242 | "\n", 243 | "plot_mel([mel_prediction_streaming[0], mel_prediction[0]],\n", 244 | " ['streaming inference', 'non-streaming inference'])\n", 245 | "\n", 246 | "ipd.display(ipd.Audio(convert_mel_to_audio(\n", 247 | " mel_prediction_streaming), rate=22050))\n", 248 | "ipd.display(ipd.Audio(convert_mel_to_audio(mel_prediction), rate=22050))" 249 | ] 250 | } 251 | ], 252 | "metadata": { 253 | "kernelspec": { 254 | "display_name": "gradtts", 255 | "language": "python", 256 | "name": "python3" 257 | }, 258 | "language_info": { 259 | "codemirror_mode": { 260 | "name": "ipython", 261 | "version": 3 262 | }, 263 | "file_extension": ".py", 264 | "mimetype": "text/x-python", 265 | "name": "python", 266 | "nbconvert_exporter": "python", 267 | "pygments_lexer": "ipython3", 268 | "version": "3.10.12" 269 | }, 270 | "metadata": { 271 | "interpreter": { 272 | "hash": "1c27759576147a09f82f75fe7e6da160ee29ac300de0ba196702adc9d307c9a1" 273 | } 274 | }, 275 | "vscode": { 276 | "interpreter": { 277 | "hash": "1059529bf0eac96a858df282bfcfa3b0fdcaa085677d3010c56aeec385ff20b6" 278 | } 279 | } 280 | }, 281 | "nbformat": 4, 282 | "nbformat_minor": 4 283 | } 284 | -------------------------------------------------------------------------------- 
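The notebook above plays the generated audio inline via IPython; outside a notebook, the same mel predictions can be written to a wav file instead. A minimal sketch, assuming the `convert_mel_to_audio` helper and the `mel_prediction` tensor from the cells above (`torchaudio` is already listed in `requirements.txt`; the output filename is arbitrary):

```python
import torchaudio

# With the default settings above, each streaming clip covers roughly
# 0.5 s * 22050 Hz / 256 hop size ≈ 43 mel frames.
audio = convert_mel_to_audio(mel_prediction)  # shape (1, t), float32, on CPU
torchaudio.save("lightgrad_sample.wav", audio, sample_rate=22050)
```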
/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | import random 4 | import re 5 | import json 6 | import tqdm 7 | import itertools 8 | 9 | 10 | # for BZNSYP, 200 samples for test, 200 samples for validation 11 | # for LJSpeech, 523 samples for test, 348 samples for validation 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("dataset", choices=["ljspeech", "bznsyp"]) 17 | parser.add_argument("dataset_path", type=str, help="path to dataset dir") 18 | parser.add_argument("export_dir", type=str, 19 | help="path to save preprocess result") 20 | parser.add_argument("--test_sample_count", type=int, default=200) 21 | parser.add_argument("--valid_sample_count", type=int, default=200) 22 | return parser.parse_args() 23 | 24 | 25 | def main(): 26 | args = get_args() 27 | if args.dataset == "ljspeech": 28 | (train_dataset, valid_dataset, test_dataset, 29 | phn2id) = preprocess_ljspeech(args) 30 | if args.dataset == "bznsyp": 31 | (train_dataset, valid_dataset, test_dataset, 32 | phn2id) = preprocess_bznsyp(args) 33 | export_dir = pathlib.Path(args.export_dir) 34 | export_dir.mkdir(parents=True, exist_ok=True) 35 | with open(export_dir / "train_dataset.json", "w") as f: 36 | json.dump(train_dataset, f) 37 | with open(export_dir / "valid_dataset.json", "w") as f: 38 | json.dump(valid_dataset, f) 39 | with open(export_dir / "test_dataset.json", "w") as f: 40 | json.dump(test_dataset, f) 41 | with open(export_dir / "phn2id.json", "w") as f: 42 | json.dump(phn2id, f) 43 | 44 | 45 | def preprocess_ljspeech(args): 46 | from text import G2pEn, phn2id_en 47 | 48 | dataset_path = pathlib.Path(args.dataset_path) 49 | metadata_path = dataset_path / "metadata.csv.txt" 50 | meta_info = [] 51 | g2p = G2pEn() 52 | with open(metadata_path) as f: 53 | for line in tqdm.tqdm(f.readlines()): 54 | name, _, normalized_text = line.strip().split("|") 55 | wav_path = dataset_path / "wavs" / f"{name}.wav" 56 | if wav_path.exists(): 57 | phonemes = g2p(normalized_text) 58 | meta_info.append( 59 | { 60 | "name": name, 61 | "wav_path": str(wav_path), 62 | "text": normalized_text, 63 | "phonemes": phonemes, 64 | } 65 | ) 66 | random.shuffle(meta_info) 67 | test_dataset = meta_info[: args.test_sample_count] 68 | valid_dataset = meta_info[ 69 | args.test_sample_count: args.test_sample_count + args.valid_sample_count 70 | ] 71 | train_dataset = meta_info[args.test_sample_count + 72 | args.valid_sample_count:] 73 | return train_dataset, valid_dataset, test_dataset, phn2id_en 74 | 75 | 76 | def preprocess_bznsyp(args): 77 | from text import G2pZh 78 | 79 | punc = set([",", '、', '。', '!', ':', ';', '?']) 80 | 81 | dataset_path = pathlib.Path(args.dataset_path) 82 | metadata_path = dataset_path / 'ProsodyLabeling' / '000001-010000.txt' 83 | meta_info = [] 84 | g2p = G2pZh() 85 | with open(metadata_path) as f: 86 | all_lines = f.readlines() 87 | text_labels = all_lines[0::2] 88 | pinyin_labels = all_lines[1::2] 89 | for text_label, pinyin_label in tqdm.tqdm(zip(text_labels, pinyin_labels)): 90 | name, text = text_label.split() 91 | wav_path = dataset_path / "Wave" / f"{name}.wav" 92 | if wav_path.exists(): 93 | pinyin = re.sub('ng1 yuan4 le5', 94 | 'en1 yuan4 le5', pinyin_label[1:]) 95 | pinyin = re.sub('P IY1 guo4', 'pi1 guo4', pinyin).split() 96 | text = re.sub('…”$', '。”', text) 97 | text = re.sub('[“”]', '', text) 98 | text = re.sub('…。$', '。', text) 99 | text = re.sub('…{1,}$', '。', text) 100 | text 
= re.sub('…{1,}', ',', text) 101 | text = re.sub('—{1,}', '。', text) 102 | text = re.sub('[()]', '', text) 103 | i = 0 104 | j = 0 105 | phonemes = [] 106 | while i < len(text): 107 | # insert prosodic structure label 108 | if text[i] == '#': 109 | if text[i+1] in {'1', '2', '3', '4'}: 110 | phonemes.append('#'+text[i+1]) 111 | i += 2 112 | else: 113 | i += 1 114 | # insert punctuation 115 | elif text[i] in punc: 116 | phonemes.append(text[i]) 117 | i += 1 118 | else: 119 | # skip erhua 120 | if text[i] == '儿': 121 | if j < len(pinyin): 122 | if not pinyin[j].startswith('er'): 123 | i += 1 124 | continue 125 | # erhua at the end of sentence 126 | else: 127 | i += 1 128 | continue 129 | # insert pinyin for current character 130 | phonemes.append(pinyin[j]) 131 | i += 1 132 | j += 1 133 | 134 | phonemes = g2p.pinyin2phoneme(' '.join(phonemes)) 135 | meta_info.append( 136 | { 137 | "name": name, 138 | "wav_path": str(wav_path), 139 | "text": text, 140 | "phonemes": phonemes, 141 | } 142 | ) 143 | random.shuffle(meta_info) 144 | test_dataset = meta_info[: args.test_sample_count] 145 | valid_dataset = meta_info[ 146 | args.test_sample_count: args.test_sample_count + args.valid_sample_count 147 | ] 148 | train_dataset = meta_info[args.test_sample_count + 149 | args.valid_sample_count:] 150 | phn2id = {x: i+1 for i, x in enumerate(sorted(itertools.chain( 151 | g2p.phn2id().keys(), punc, set(['#1', '#2', '#3', '#4']))))} 152 | return train_dataset, valid_dataset, test_dataset, phn2id 153 | 154 | 155 | if __name__ == "__main__": 156 | main() 157 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio 3 | matplotlib 4 | inflect 5 | librosa 6 | scipy 7 | tensorboard 8 | Unidecode 9 | g2pM 10 | g2p_en 11 | pyyaml 12 | git+https://github.com/unrea1-sama/monotonic_align.git 13 | notebook -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | from .g2p_en import G2pEn, _symbol_to_id as phn2id_en 2 | from .g2p_zh import G2pZh -------------------------------------------------------------------------------- /text/en_cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | from unidecode import unidecode 5 | from .numbers import normalize_numbers 6 | 7 | 8 | _whitespace_re = re.compile(r'\s+') 9 | 10 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 11 | ('mrs', 'misess'), 12 | ('mr', 'mister'), 13 | ('dr', 'doctor'), 14 | ('st', 'saint'), 15 | ('co', 'company'), 16 | ('jr', 'junior'), 17 | ('maj', 'major'), 18 | ('gen', 'general'), 19 | ('drs', 'doctors'), 20 | ('rev', 'reverend'), 21 | ('lt', 'lieutenant'), 22 | ('hon', 'honorable'), 23 | ('sgt', 'sergeant'), 24 | ('capt', 'captain'), 25 | ('esq', 'esquire'), 26 | ('ltd', 'limited'), 27 | ('col', 'colonel'), 28 | ('ft', 'fort'), 29 | ]] 30 | 31 | 32 | def expand_abbreviations(text): 33 | for regex, replacement in _abbreviations: 34 | text = re.sub(regex, replacement, text) 35 | return text 36 | 37 | 38 | def expand_numbers(text): 39 | return normalize_numbers(text) 40 | 41 | 42 | def lowercase(text): 43 | return text.lower() 44 | 45 | 46 | def collapse_whitespace(text): 47 | return re.sub(_whitespace_re, ' ', text) 48 | 49 | 50 | def convert_to_ascii(text): 51 | return unidecode(text) 52 | 53 | 54 | def basic_cleaners(text): 55 | text = lowercase(text) 56 | text = collapse_whitespace(text) 57 | return text 58 | 59 | 60 | def transliteration_cleaners(text): 61 | text = convert_to_ascii(text) 62 | text = lowercase(text) 63 | text = collapse_whitespace(text) 64 | return text 65 | 66 | 67 | def english_cleaners(text): 68 | text = convert_to_ascii(text) 69 | text = lowercase(text) 70 | text = expand_numbers(text) 71 | text = expand_abbreviations(text) 72 | text = collapse_whitespace(text) 73 | return text 74 | -------------------------------------------------------------------------------- /text/g2p_en.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | from text import en_cleaners 5 | from g2p_en import G2p 6 | 7 | valid_symbols = [ 8 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 9 | 'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 10 | 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 11 | 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 12 | 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 13 | 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 14 | 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 15 | 'Y', 'Z', 'ZH' 16 | ] 17 | 18 | _punctuation = '!\'(),.:;? 
' 19 | _special = ['-', '', '', ''] 20 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 21 | _alt_re = re.compile(r'\([0-9]+\)') 22 | 23 | _arpabet = [s for s in valid_symbols] 24 | 25 | # Export all symbols: 26 | symbols = _special + list(_punctuation) + _arpabet 27 | _valid_symbol_set = set(valid_symbols) 28 | 29 | # zero is reserved for padding 30 | _symbol_to_id = {s: i + 1 for i, s in enumerate(symbols)} 31 | _id_to_symbol = {i + 1: s for i, s in enumerate(symbols)} 32 | 33 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 34 | 35 | 36 | def _clean_text(text, cleaner_names): 37 | for name in cleaner_names: 38 | cleaner = getattr(en_cleaners, name) 39 | if not cleaner: 40 | raise Exception('Unknown cleaner: %s' % name) 41 | text = cleaner(text) 42 | text = re.sub('-','',text) 43 | return text 44 | 45 | 46 | class G2pEn(): 47 | 48 | def __init__(self) -> None: 49 | self.g2p = G2p() 50 | 51 | def __call__(self, text): 52 | phonemes = self.g2p(_clean_text(text, ["english_cleaners"])) 53 | text = ' '.join(phonemes) 54 | text = re.sub(' ,',',',text) 55 | text = re.sub(', ',',',text) 56 | text = re.sub(' !','!',text) 57 | text = re.sub('! ','!',text) 58 | text = re.sub('\? ','?',text) 59 | text = re.sub(' \?','?',text) 60 | text = re.sub(" '","'",text) 61 | text = re.sub("' ","'",text) 62 | text = re.sub(" "," - ",text) 63 | return text.split(' ') 64 | 65 | -------------------------------------------------------------------------------- /text/g2p_zh.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Generate lexicon and symbols for Mandarin Chinese phonology. 16 | The lexicon is used for Montreal Force Aligner. 17 | Note that syllables are used as word in this lexicon. Since syllables rather 18 | than words are used in transcriptions produced by `reorganize_baker.py`. 19 | We make this choice to better leverage other software for chinese text to 20 | pinyin tools like pypinyin. This is the convention for G2P in Chinese. 21 | """ 22 | import re 23 | from collections import OrderedDict 24 | from g2pM import G2pM 25 | 26 | INITIALS = [ 27 | 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh', 28 | 'r', 'z', 'c', 's', 'j', 'q', 'x' 29 | ] 30 | 31 | FINALS = [ 32 | 'a', 'ai', 'ao', 'an', 'ang', 'e', 'er', 'ei', 'en', 'eng', 'o', 'ou', 33 | 'ong', 'ii', 'iii', 'i', 'ia', 'iao', 'ian', 'iang', 'ie', 'io', 'iou', 34 | 'iong', 'in', 'ing', 'u', 'ua', 'uai', 'uan', 'uang', 'uei', 'uo', 'uen', 35 | 'ueng', 'v', 've', 'van', 'vn' 36 | ] 37 | 38 | SPECIALS = ['sil', 'sp', '', '', ''] 39 | 40 | 41 | def rule(C, V, R, T): 42 | """Generate a syllable given the initial, the final, erhua indicator, 43 | and tone. Orthographical rules for pinyin are 44 | applied. 
(special case for y, w, ui, un, iu) 45 | Note that in this system, 'ü' is alway written as 'v' when appeared in 46 | phoneme, but converted to 'u' in syllables when certain conditions 47 | are satisfied. 48 | 'i' is distinguished when appeared in phonemes, and separated into 3 49 | categories, 'i', 'ii' and 'iii'. 50 | Erhua is is possibly applied to every finals, except for finals that 51 | already ends with 'r'. 52 | When a syllable is impossible or does not have any characters with this 53 | pronunciation, return None to filter it out. 54 | """ 55 | 56 | # 不可拼的音节, ii 只能和 z, c, s 拼 57 | if V in ["ii"] and (C not in ['z', 'c', 's']): 58 | return None 59 | # iii 只能和 zh, ch, sh, r 拼 60 | if V in ['iii'] and (C not in ['zh', 'ch', 'sh', 'r']): 61 | return None 62 | 63 | # 齐齿呼或者撮口呼不能和 f, g, k, h, zh, ch, sh, r, z, c, s 64 | if (V not in ['ii', 'iii']) and V[0] in ['i', 'v'] and (C in [ 65 | 'f', 'g', 'k', 'h', 'zh', 'ch', 'sh', 'r', 'z', 'c', 's' 66 | ]): 67 | return None 68 | 69 | # 撮口呼只能和 j, q, x l, n 拼 70 | if V.startswith("v"): 71 | # v, ve 只能和 j ,q , x, n, l 拼 72 | if V in ['v', 've']: 73 | if C not in ['j', 'q', 'x', 'n', 'l', '']: 74 | return None 75 | # 其他只能和 j, q, x 拼 76 | else: 77 | if C not in ['j', 'q', 'x', '']: 78 | return None 79 | 80 | # j, q, x 只能和齐齿呼或者撮口呼拼 81 | if (C in ['j', 'q', 'x' 82 | ]) and not ((V not in ['ii', 'iii']) and V[0] in ['i', 'v']): 83 | return None 84 | 85 | # b, p ,m, f 不能和合口呼拼,除了 u 之外 86 | # bm p, m, f 不能和撮口呼拼 87 | if (C in ['b', 'p', 'm', 'f']) and ((V[0] in ['u', 'v'] and V != "u") 88 | or V == 'ong'): 89 | return None 90 | 91 | # ua, uai, uang 不能和 d, t, n, l, r, z, c, s 拼 92 | if V in ['ua', 'uai', 'uang' 93 | ] and C in ['d', 't', 'n', 'l', 'r', 'z', 'c', 's']: 94 | return None 95 | 96 | # sh 和 ong 不能拼 97 | if V == 'ong' and C in ['sh']: 98 | return None 99 | 100 | # o 和 gkh, zh ch sh r z c s 不能拼 101 | if V == "o" and C in [ 102 | 'd', 't', 'n', 'g', 'k', 'h', 'zh', 'ch', 'sh', 'r', 'z', 'c', 's' 103 | ]: 104 | return None 105 | 106 | # ueng 只是 weng 这个 ad-hoc 其他情况下都是 ong 107 | if V == 'ueng' and C != '': 108 | return 109 | 110 | # 非儿化的 er 只能单独存在 111 | if V == 'er' and C != '': 112 | return None 113 | 114 | if C == '': 115 | if V in ["i", "in", "ing"]: 116 | C = 'y' 117 | elif V == 'u': 118 | C = 'w' 119 | elif V.startswith('i') and V not in ["ii", "iii"]: 120 | C = 'y' 121 | V = V[1:] 122 | elif V.startswith('u'): 123 | C = 'w' 124 | V = V[1:] 125 | elif V.startswith('v'): 126 | C = 'yu' 127 | V = V[1:] 128 | else: 129 | if C in ['j', 'q', 'x']: 130 | if V.startswith('v'): 131 | V = re.sub('v', 'u', V) 132 | if V == 'iou': 133 | V = 'iu' 134 | elif V == 'uei': 135 | V = 'ui' 136 | elif V == 'uen': 137 | V = 'un' 138 | result = C + V 139 | 140 | # Filter er 不能再儿化 141 | if result.endswith('r') and R == 'r': 142 | return None 143 | 144 | # ii and iii, change back to i 145 | result = re.sub(r'i+', 'i', result) 146 | 147 | result = result + R + T 148 | return result 149 | 150 | 151 | def generate_lexicon(with_tone=False, with_erhua=False): 152 | """Generate lexicon for Mandarin Chinese.""" 153 | syllables = OrderedDict() 154 | 155 | for C in [''] + INITIALS: 156 | for V in FINALS: 157 | for R in [''] if not with_erhua else ['', 'r']: 158 | for T in [''] if not with_tone else ['1', '2', '3', '4', '5']: 159 | result = rule(C, V, R, T) 160 | if result: 161 | # remove whitespace at the begining 162 | syllables[result] = re.sub('^ ', '', f'{C} {V}{R}{T}') 163 | return syllables 164 | 165 | 166 | def generate_symbols(lexicon): 167 | """Generate phoneme list for 
a lexicon.""" 168 | symbols = set() 169 | for p in SPECIALS: 170 | symbols.add(p) 171 | for syllable, phonemes in lexicon.items(): 172 | phonemes = phonemes.split() 173 | for p in phonemes: 174 | symbols.add(p) 175 | return sorted(list(symbols)) 176 | 177 | 178 | class G2pZh: 179 | 180 | def __init__(self): 181 | self.lexicon = generate_lexicon(True, True) 182 | self.symbols = generate_symbols(self.lexicon) 183 | self.model = G2pM() 184 | self.re = re.compile('u:') 185 | 186 | def character2phoneme(self, text): 187 | pinyin = ' '.join(self.model(text, tone=True, char_split=False)) 188 | pinyin = self.re.sub('v',pinyin) 189 | phonemes = self.pinyin2phoneme(pinyin) 190 | return pinyin,' '.join(phonemes) 191 | 192 | def pinyin2phoneme(self, text): 193 | result = [] 194 | for pinyin in text.split(' '): 195 | phoneme = self.lexicon.get(pinyin, pinyin) 196 | result.extend(phoneme.split(' ')) 197 | return result 198 | 199 | def phn2id(self): 200 | return {phn: i + 1 for i, phn in enumerate(self.symbols)} 201 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', 60 | group=2).replace(', ', ' ') 61 | else: 62 | return _inflect.number_to_words(num, andword='') 63 | 64 | 65 | def normalize_numbers(text): 66 | text = re.sub(_comma_number_re, _remove_commas, text) 67 | text = re.sub(_pounds_re, r'\1 pounds', text) 68 | text = re.sub(_dollars_re, _expand_dollars, text) 69 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 70 | text = re.sub(_ordinal_re, _expand_ordinal, text) 71 | text = re.sub(_number_re, _expand_number, text) 72 | return 
text 73 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 8 | 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | from torch.utils.tensorboard import SummaryWriter 15 | from LightGrad import LightGrad 16 | from utils import plot_tensor, save_plot 17 | import yaml 18 | import argparse 19 | import random 20 | import pathlib 21 | 22 | from LightGrad.dataset import Dataset, collateFn 23 | 24 | 25 | def get_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("-c", "--config", type=str, help="path to config file") 28 | return parser.parse_args() 29 | 30 | 31 | if __name__ == "__main__": 32 | args = get_args() 33 | print("Initializing data loaders...") 34 | with open(args.config) as f: 35 | config = yaml.load(f, yaml.SafeLoader) 36 | log_dir = pathlib.Path(config["log_dir"]) 37 | torch.manual_seed(config["random_seed"]) 38 | np.random.seed(config["random_seed"]) 39 | train_dataset = Dataset( 40 | config["train_datalist_path"], 41 | config["phn2id_path"], 42 | config["sample_rate"], 43 | config["n_fft"], 44 | config["n_mels"], 45 | config["f_min"], 46 | config["f_max"], 47 | config["hop_size"], 48 | config["win_size"], 49 | ) 50 | val_dataset = Dataset( 51 | config["valid_datalist_path"], 52 | config["phn2id_path"], 53 | config["sample_rate"], 54 | config["n_fft"], 55 | config["n_mels"], 56 | config["f_min"], 57 | config["f_max"], 58 | config["hop_size"], 59 | config["win_size"], 60 | ) 61 | train_loader = DataLoader( 62 | train_dataset, 63 | batch_size=config["batch_size"], 64 | collate_fn=collateFn, 65 | num_workers=16, 66 | ) 67 | val_loader = DataLoader( 68 | val_dataset, batch_size=config["batch_size"], shuffle=True, collate_fn=collateFn 69 | ) 70 | 71 | print("Initializing model...") 72 | model = LightGrad.build_model(config, train_dataset.get_vocab_size()) 73 | print(f"Total parameters: {model.nparams}") 74 | start_epoch = 1 75 | start_steps = 1 76 | if config["ckpt"]: 77 | print("loading ", config["ckpt"]) 78 | epoch, steps, state_dict = torch.load(config["ckpt"], map_location="cpu") 79 | start_epoch = epoch + 1 80 | start_steps = steps + 1 81 | model.load_state_dict(state_dict) 82 | 83 | model = model.cuda() 84 | 85 | print("Initializing optimizer...") 86 | optimizer = torch.optim.Adam(params=model.parameters(), lr=config["learning_rate"]) 87 | 88 | print("Initializing logger...") 89 | logger = SummaryWriter(log_dir=log_dir) 90 | 91 | ckpt_dir = log_dir / "ckpt" 92 | pic_dir = log_dir / "pic" 93 | ckpt_dir.mkdir(parents=True, exist_ok=True) 94 | pic_dir.mkdir(parents=True, exist_ok=True) 95 | print("Start training...") 96 | iteration = start_steps 97 | out_size = config["out_size"] * config["sample_rate"] // config["hop_size"] 98 | 99 | for epoch in range(start_epoch, start_epoch + 10000): 100 | model.train() 101 | dur_losses = [] 102 | prior_losses = [] 103 | diff_losses = [] 104 | with tqdm( 105 | train_loader, total=len(train_dataset) // 
config["batch_size"] 106 | ) as progress_bar: 107 | for batch_idx, batch in enumerate(progress_bar): 108 | model.zero_grad() 109 | x, x_lengths = batch["x"].cuda(), batch["x_lengths"].cuda() 110 | y, y_lengths = batch["y"].cuda(), batch["y_lengths"].cuda() 111 | dur_loss, prior_loss, diff_loss = model.compute_loss( 112 | x, x_lengths, y, y_lengths, out_size=out_size 113 | ) 114 | loss = sum([dur_loss, prior_loss, diff_loss]) 115 | loss.backward() 116 | 117 | enc_grad_norm = torch.nn.utils.clip_grad_norm_( 118 | model.encoder.parameters(), max_norm=1 119 | ) 120 | dec_grad_norm = torch.nn.utils.clip_grad_norm_( 121 | model.decoder.parameters(), max_norm=1 122 | ) 123 | optimizer.step() 124 | 125 | logger.add_scalar( 126 | "training/duration_loss", dur_loss.item(), global_step=iteration 127 | ) 128 | logger.add_scalar( 129 | "training/prior_loss", prior_loss.item(), global_step=iteration 130 | ) 131 | logger.add_scalar( 132 | "training/diffusion_loss", diff_loss.item(), global_step=iteration 133 | ) 134 | logger.add_scalar( 135 | "training/encoder_grad_norm", enc_grad_norm, global_step=iteration 136 | ) 137 | logger.add_scalar( 138 | "training/decoder_grad_norm", dec_grad_norm, global_step=iteration 139 | ) 140 | 141 | dur_losses.append(dur_loss.item()) 142 | prior_losses.append(prior_loss.item()) 143 | diff_losses.append(diff_loss.item()) 144 | 145 | if batch_idx % 5 == 0: 146 | msg = ( 147 | f"LightGrad Epoch: {epoch}, iteration: {iteration} | " 148 | f" dur_loss: {dur_loss.item()}, " 149 | f"prior_loss: {prior_loss.item()}, " 150 | f"diff_loss: {diff_loss.item()}" 151 | ) 152 | progress_bar.set_description(msg) 153 | 154 | iteration += 1 155 | if iteration >= config["max_step"]: 156 | torch.save( 157 | [epoch, iteration, model.state_dict()], 158 | f=ckpt_dir / f"LightGrad_{epoch}_{iteration}.pt", 159 | ) 160 | model.eval() 161 | with torch.no_grad(): 162 | all_dur_loss = [] 163 | all_prior_loss = [] 164 | all_diffusion_loss = [] 165 | for _, item in enumerate(val_loader): 166 | x, x_lengths = batch["x"].cuda(), batch["x_lengths"].cuda() 167 | y, y_lengths = batch["y"].cuda(), batch["y_lengths"].cuda() 168 | 169 | dur_loss, prior_loss, diff_loss = model.compute_loss( 170 | x, x_lengths, y, y_lengths, out_size=out_size 171 | ) 172 | loss = sum([dur_loss, prior_loss, diff_loss]) 173 | all_dur_loss.append(dur_loss) 174 | all_prior_loss.append(prior_loss) 175 | all_diffusion_loss.append(diff_loss) 176 | average_dur_loss = sum(all_dur_loss) / len(all_dur_loss) 177 | average_prior_loss = sum(all_prior_loss) / len(all_prior_loss) 178 | average_diffusion_loss = sum(all_diffusion_loss) / len(all_diffusion_loss) 179 | logger.add_scalar("val/duration_loss", average_dur_loss, global_step=epoch) 180 | logger.add_scalar("val/prior_loss", average_prior_loss, global_step=epoch) 181 | logger.add_scalar( 182 | "val/diffusion_loss", average_diffusion_loss, global_step=epoch 183 | ) 184 | print( 185 | f"val duration_loss: {average_dur_loss}, " 186 | f"prior_loss: {average_prior_loss}, " 187 | f"diffusion_loss: {average_diffusion_loss}" 188 | ) 189 | y_enc, y_dec, attn = model(x, x_lengths, n_timesteps=10) 190 | idx = random.randrange(0, y_enc.shape[0]) 191 | y_enc = y_enc[idx].cpu() 192 | y_dec = y_dec[idx].cpu() 193 | y = y[idx].cpu() 194 | attn = attn[idx][0].cpu() 195 | logger.add_image( 196 | "image/generated_enc", 197 | plot_tensor(y_enc), 198 | global_step=epoch, 199 | dataformats="HWC", 200 | ) 201 | logger.add_image( 202 | "image/generated_dec", 203 | plot_tensor(y_dec), 204 | global_step=epoch, 205 | 
dataformats="HWC", 206 | ) 207 | logger.add_image( 208 | "image/alignment", 209 | plot_tensor(attn), 210 | global_step=epoch, 211 | dataformats="HWC", 212 | ) 213 | logger.add_image( 214 | "image/ground_truth", 215 | plot_tensor(y), 216 | global_step=epoch, 217 | dataformats="HWC", 218 | ) 219 | save_plot(y_enc, pic_dir / f"generated_enc_{epoch}.png") 220 | save_plot(y_dec, pic_dir / f"generated_dec_{epoch}.png") 221 | save_plot(attn, pic_dir / f"alignment_{epoch}.png") 222 | save_plot(y, pic_dir / f"ground_truth_{epoch}.png") 223 | torch.save( 224 | [epoch, iteration, model.state_dict()], 225 | f=ckpt_dir / f"LightGrad_{epoch}_{iteration}.pt", 226 | ) 227 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. 2 | # This program is free software; you can redistribute it and/or modify 3 | # it under the terms of the MIT License. 4 | # This program is distributed in the hope that it will be useful, 5 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 6 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 7 | # MIT License for more details. 8 | 9 | import os 10 | import glob 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from torch.nn.utils import weight_norm 14 | 15 | import torch 16 | 17 | 18 | def intersperse(lst, item): 19 | # Adds blank symbol 20 | result = [item] * (len(lst) * 2 + 1) 21 | result[1::2] = lst 22 | return result 23 | 24 | 25 | def parse_filelist(filelist_path, split_char="|"): 26 | with open(filelist_path, encoding="utf-8") as f: 27 | filepaths_and_text = [line.strip().split(split_char) for line in f] 28 | return filepaths_and_text 29 | 30 | 31 | def latest_checkpoint_path(dir_path, regex="grad_*.pt"): 32 | f_list = glob.glob(os.path.join(dir_path, regex)) 33 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 34 | x = f_list[-1] 35 | return x 36 | 37 | 38 | def load_checkpoint(logdir, model, num=None): 39 | if num is None: 40 | model_path = latest_checkpoint_path(logdir, regex="grad_*.pt") 41 | else: 42 | model_path = os.path.join(logdir, f"grad_{num}.pt") 43 | print(f"Loading checkpoint {model_path}...") 44 | model_dict = torch.load(model_path, map_location=lambda loc, storage: loc) 45 | model.load_state_dict(model_dict, strict=False) 46 | return model 47 | 48 | 49 | def save_figure_to_numpy(fig): 50 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 51 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 52 | return data 53 | 54 | 55 | def plot_tensor(tensor): 56 | plt.style.use("default") 57 | fig, ax = plt.subplots(figsize=(12, 3)) 58 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 59 | plt.colorbar(im, ax=ax) 60 | plt.tight_layout() 61 | fig.canvas.draw() 62 | data = save_figure_to_numpy(fig) 63 | plt.close() 64 | return data 65 | 66 | 67 | def save_plot(tensor, savepath): 68 | plt.style.use("default") 69 | fig, ax = plt.subplots(figsize=(12, 3)) 70 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 71 | plt.colorbar(im, ax=ax) 72 | plt.tight_layout() 73 | fig.canvas.draw() 74 | plt.savefig(savepath) 75 | plt.close() 76 | 77 | 78 | def init_weights(m, mean=0.0, std=0.01): 79 | classname = m.__class__.__name__ 80 | if classname.find("Conv") != -1: 81 | m.weight.data.normal_(mean, std) 82 | 83 | 84 | def 
apply_weight_norm(m): 85 | classname = m.__class__.__name__ 86 | if classname.find("Conv") != -1: 87 | weight_norm(m) 88 | 89 | 90 | def get_padding(kernel_size, dilation=1): 91 | return int((kernel_size * dilation - dilation) / 2) 92 | --------------------------------------------------------------------------------
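Finally, a tiny usage sketch for two helpers in the top-level `utils.py` above: `intersperse`, which interleaves a blank item between tokens, and `get_padding`, which computes the "same" padding for a (possibly dilated) convolution. The values below are illustrative only:

```python
from utils import intersperse, get_padding

# Interleave a blank id (here 0) between token ids: [7, 8, 9] -> [0, 7, 0, 8, 0, 9, 0]
print(intersperse([7, 8, 9], 0))

# "Same" padding: kernel 3, dilation 1 -> 1; kernel 5, dilation 2 -> 4
print(get_padding(3), get_padding(5, dilation=2))
```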