├── .gitignore
├── README.md
├── assets
│   └── attention.gif
├── config.py
├── data.py
├── lj_eval_idx.npy
├── model.py
├── module.py
├── network.py
├── prepro.py
├── synthesize.py
├── test_sents.txt
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | logs
4 | runs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DCTTS (Deep Convolutional TTS) - PyTorch implementation
2 | ### Paper: [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention](https://arxiv.org/abs/1710.08969)
3 |
4 | ## Prerequisites
5 | - Python 3.6
6 | - PyTorch 1.0
7 | - librosa, scipy, tqdm, tensorboardX
8 |
9 | ## Dataset
10 | - [LJ Speech 1.1](https://keithito.com/LJ-Speech-Dataset/), a single-speaker (female) dataset.
11 | - Speech signal preprocessing follows [Kyubyong's TensorFlow DCTTS repo](https://github.com/Kyubyong/dc_tts), which worked well.
12 |
13 | ## Usage
14 | 1. Download the dataset above and set `data_path` in config.py, then run the command below. The 1st argument toggles signal preprocessing, the 2nd metadata preprocessing (train/test split).
15 | ```
16 | python prepro.py 1 1
17 | ```
18 |
19 | 2. DCTTS consists of two models. First, train Text2Mel. Around 20k steps (roughly an hour of training) is enough for a first result, but you should keep training while the guided attention loss decays (a sketch of this loss is given at the end of the Notes section). The 2nd argument below selects the GPU index (see train.py).
20 | ```
21 | python train.py 1 <gpu_id>
22 | ```
23 |
24 | 3. Second, train the SSRN. The SSRN outputs high-resolution (full) spectrograms, so it trains more slowly than Text2Mel.
25 | ```
26 | python train.py 2 <gpu_id>
27 | ```
28 |
29 | 4. After training, you can synthesize speech from the sentences in test_sents.txt.
30 | ```
31 | python synthesize.py <gpu_id>
32 | ```
33 |
34 | ## Attention
35 | - In speech synthesis, the attention module is important. If the model trains properly, you can see monotonic attention like the following figure.
36 |
37 | ![](assets/attention.gif)
38 |
39 | ## Notes
40 | - To do: previous attention for inference.
41 | - To do: alleviate overfitting.
42 | - The paper does not mention normalization, so I used weight normalization as in DeepVoice3.
43 | - Some hyperparameters differ from the paper.
44 | - If you want to improve performance, use all of the data; for various experiments, I separated a training set and a validation set.
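
The "guided attention" term used during Text2Mel training is the mean of the alignment matrix weighted by a diagonal-penalty mask (see `prepro_guided_attention` in utils.py and the `att_loss` term in train.py). Below is a minimal sketch; the lengths are illustrative toy values, and the padded-region branch (`ty > T`) handled in utils.py is omitted:
```
import numpy as np
import torch

N, T, g = 40, 100, 0.2   # toy text length (Tx), reduced mel length (Ty/r), guide width (config.g)
W = np.zeros((N, T), dtype=np.float32)
for tx in range(N):
    for ty in range(T):
        # penalty grows as attention strays from the diagonal ty/T == tx/N
        W[tx, ty] = 1.0 - np.exp(-0.5 * (ty / T - tx / N) ** 2 / g ** 2)

A = torch.softmax(torch.randn(1, N, T), dim=1)    # dummy alignment of shape (B, Tx, Ty/r)
att_loss = torch.mean(A * torch.from_numpy(W))    # added to the L1 + BCE losses in train.py
```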
45 | 46 | ## Other Codes 47 | - [Another pytorch implementation](https://github.com/chaiyujin/dctts-pytorch) 48 | - [TensorFlow implementation](https://github.com/Kyubyong/dc_tts) 49 | -------------------------------------------------------------------------------- /assets/attention.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yangyangii/DeepConvolutionalTTS-pytorch/4807de0dba0398fd6e757d9f81cc3e572f3c7f94/assets/attention.gif -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | 2 | class ConfigArgs: 3 | data_path = '/home/yangyangii/ssd/data/LJSpeech-1.1' 4 | mel_dir, mag_dir = 'd_mels', 'd_mags' 5 | ga_dir = 'guides' # guided attention 6 | meta = 'metadata.csv' 7 | meta_train = 'meta-train.csv' 8 | meta_eval = 'meta-eval.csv' 9 | testset = 'test_sents.txt' 10 | logdir = 'logs' 11 | sampledir = 'samples' 12 | prepro = True 13 | mem_mode= True 14 | ga_mode = True 15 | log_mode = True 16 | save_term = 1000 17 | n_workers = 8 18 | n_gpu = 2 19 | global_step = 0 20 | 21 | sr = 22050 # sampling rate 22 | preemph = 0.97 # pre-emphasize 23 | n_fft = 2048 24 | n_mags = n_fft//2 + 1 25 | n_mels = 80 26 | frame_shift = 0.0125 27 | frame_length = 0.05 28 | hop_length = int(sr*frame_shift) 29 | win_length = int(sr*frame_length) 30 | gl_iter = 50 # Griffin-Lim iteration 31 | max_db = 100 32 | ref_db = 20 33 | power = 1.5 34 | r = 4 # reduction factor. mel/4 35 | g = 0.2 36 | 37 | batch_size = 32 38 | test_batch = 50 # for test 39 | max_step = 200000 40 | lr = 0.001 41 | lr_decay_step = 50000 # actually not decayed per this step 42 | Ce = 128 # for text embedding and encoding 43 | Cx = 256 # for text embedding and encoding 44 | Cy = 256 # for audio encoding 45 | Cs = 512 # for SSRN 46 | drop_rate = 0.05 47 | 48 | max_Tx = 188 49 | max_Ty = 250 50 | 51 | vocab = u'''PE !',-.?abcdefghijklmnopqrstuvwxyz''' 52 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os, sys 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | import glob, re 7 | import utils 8 | import codecs, unicodedata 9 | from config import ConfigArgs as args 10 | 11 | class SpeechDataset(Dataset): 12 | def __init__(self, data_path, metadata, model_name, mem_mode=False, ga_mode=False): 13 | ''' 14 | Args: 15 | data_path (str): path to dataset 16 | meta_path (str): path to metadata csv file 17 | model_name (str): {'Text2Mel', 'SSRN', 'All'} 18 | ''' 19 | self.data_path = data_path 20 | self.model_name = model_name 21 | self.mem_mode = mem_mode 22 | self.ga_mode = ga_mode 23 | self.fpaths, self.texts, self.norms = read_meta(os.path.join(data_path, metadata)) 24 | if self.mem_mode: 25 | self.mels = [torch.tensor(np.load(os.path.join( 26 | self.data_path, args.mel_dir, path))) for path in self.fpaths] 27 | if self.ga_mode: 28 | self.g_att = [torch.tensor(np.load(os.path.join( 29 | self.data_path, args.ga_dir, path))) for path in self.fpaths] 30 | 31 | def __getitem__(self, idx): 32 | text, mel, mag = None, None, None 33 | text = torch.tensor(self.norms[idx], dtype=torch.long) 34 | # Memory mode is faster 35 | if not self.mem_mode: 36 | mel_path = os.path.join(self.data_path, args.mel_dir, self.fpaths[idx]) 37 | mel = torch.tensor(np.load(mel_path)) 38 | 
else: 39 | mel = self.mels[idx] 40 | 41 | if self.model_name == 'Text2Mel': 42 | if not self.ga_mode: 43 | return (text, mel) 44 | else: 45 | # Guided attention mode 46 | return (text, mel, self.g_att[idx]) 47 | 48 | mag_path = os.path.join(self.data_path, args.mag_dir, self.fpaths[idx]) 49 | mag = torch.tensor(np.load(mag_path)) 50 | return (text, mel, mag) 51 | 52 | def __len__(self): 53 | return len(self.fpaths) 54 | 55 | def load_vocab(): 56 | char2idx = {char: idx for idx, char in enumerate(args.vocab)} 57 | idx2char = {idx: char for idx, char in enumerate(args.vocab)} 58 | return char2idx, idx2char 59 | 60 | def text_normalize(text): 61 | text = ''.join(char for char in unicodedata.normalize('NFD', text) 62 | if unicodedata.category(char) != 'Mn') # Strip accents 63 | text = text.lower() 64 | text = re.sub(u"[^{}]".format(args.vocab), " ", text) 65 | text = re.sub("[ ]+", " ", text) 66 | return text 67 | 68 | def read_meta(path): 69 | ''' 70 | If we use pandas instead of this function, it may not cover quotes. 71 | Args: 72 | path: metadata path 73 | Returns: 74 | fpaths, texts, norms 75 | ''' 76 | char2idx, _ = load_vocab() 77 | lines = codecs.open(path, 'r', 'utf-8').readlines() 78 | fpaths, texts, norms = [], [], [] 79 | for line in lines: 80 | fname, text, norm = line.strip().split('|') 81 | fpath = fname + '.npy' 82 | text = text_normalize(text).strip() + u'E' # ␃: EOS 83 | text = [char2idx[char] for char in text] 84 | norm = text_normalize(norm).strip() + u'E' # ␃: EOS 85 | norm = [char2idx[char] for char in norm] 86 | fpaths.append(fpath) 87 | texts.append(text) 88 | norms.append(norm) 89 | return fpaths, texts, norms 90 | 91 | def collate_fn(data): 92 | """ 93 | Creates mini-batch tensors from the list of tuples (texts, mels, mags). 94 | Args: 95 | data: list of tuple (texts, mels, mags). 96 | - texts: torch tensor of shape (B, Tx). 97 | - mels: torch tensor of shape (B, Ty/4, n_mels). 98 | - mags: torch tensor of shape (B, Ty, n_mags). 99 | Returns: 100 | texts: torch tensor of shape (batch_size, padded_length). 101 | mels: torch tensor of shape (batch_size, padded_length, n_mels). 102 | mels: torch tensor of shape (batch_size, padded_length, n_mags). 103 | """ 104 | # Sort a data list by text length (descending order). 105 | data.sort(key=lambda x: len(x[0]), reverse=True) 106 | texts, mels, mags = zip(*data) 107 | 108 | # Merge (from tuple of 1D tensor to 2D tensor). 109 | text_lengths = [len(text) for text in texts] 110 | mel_lengths = [len(mel) for mel in mels] 111 | mag_lengths = [len(mag) for mag in mags] 112 | # (number of mels, max_len, feature_dims) 113 | text_pads = torch.zeros(len(texts), max(text_lengths), dtype=torch.long) 114 | mel_pads = torch.zeros(len(mels), max(mel_lengths), mels[0].shape[-1]) 115 | mag_pads = torch.zeros(len(mags), max(mag_lengths), mags[0].shape[-1]) 116 | for idx in range(len(mels)): 117 | text_end = text_lengths[idx] 118 | text_pads[idx, :text_end] = texts[idx] 119 | mel_end = mel_lengths[idx] 120 | mel_pads[idx, :mel_end] = mels[idx] 121 | mag_end = mag_lengths[idx] 122 | mag_pads[idx, :mag_end] = mags[idx] 123 | return text_pads, mel_pads, mag_pads 124 | 125 | def t2m_collate_fn(data): 126 | """ 127 | Creates mini-batch tensors from the list of tuples (texts, mels, mags). 128 | Args: 129 | data: list of tuple (texts). 130 | - texts: torch tensor of shape (B, Tx). 131 | - mels: torch tensor of shape (B, Ty/4, n_mels). 132 | Returns: 133 | texts: torch tensor of shape (batch_size, padded_length). 
134 | mels: torch tensor of shape (batch_size, padded_length, n_mels). 135 | """ 136 | # Sort a data list by text length (descending order). 137 | data.sort(key=lambda x: len(x[0]), reverse=True) 138 | texts, mels = zip(*data) 139 | 140 | # Merge (from tuple of 1D tensor to 2D tensor). 141 | text_lengths = [len(text) for text in texts] 142 | mel_lengths = [len(mel) for mel in mels] 143 | # (number of mels, max_len, feature_dims) 144 | text_pads = torch.zeros(len(texts), max(text_lengths), dtype=torch.long) 145 | mel_pads = torch.zeros(len(mels), max(mel_lengths), mels[0].shape[-1]) 146 | for idx in range(len(mels)): 147 | text_end = text_lengths[idx] 148 | text_pads[idx, :text_end] = texts[idx] 149 | mel_end = mel_lengths[idx] 150 | mel_pads[idx, :mel_end] = mels[idx] 151 | return text_pads, mel_pads, None 152 | 153 | def t2m_ga_collate_fn(data): 154 | """ 155 | Creates mini-batch tensors from the list of tuples (texts, mels, mags). 156 | Args: 157 | data: list of tuple (texts). 158 | - texts: torch tensor of shape (B, Tx). 159 | - mels: torch tensor of shape (B, Ty/4, n_mels). 160 | - gas: torch tensor of shape (B, max_Tx, max_Ty). 161 | Returns: 162 | texts: torch tensor of shape (B, padded_length). 163 | mels: torch tensor of shape (B, padded_length, n_mels). 164 | gas: torch tensor of shape (B, Tx, Ty/4) 165 | """ 166 | # Sort a data list by text length (descending order). 167 | data.sort(key=lambda x: len(x[0]), reverse=True) 168 | texts, mels, gas = zip(*data) 169 | # Merge (from tuple of 1D tensor to 2D tensor). 170 | text_lengths = [len(text) for text in texts] 171 | mel_lengths = [len(mel) for mel in mels] 172 | # (number of mels, max_len, feature_dims) 173 | text_pads = torch.zeros(len(texts), max(text_lengths), dtype=torch.long) 174 | mel_pads = torch.zeros(len(mels), max(mel_lengths), mels[0].shape[-1]) 175 | ga_pads = torch.zeros(len(mels), max(text_lengths), max(mel_lengths)) 176 | for idx in range(len(mels)): 177 | text_end = text_lengths[idx] 178 | text_pads[idx, :text_end] = texts[idx] 179 | mel_end = mel_lengths[idx] 180 | mel_pads[idx, :mel_end] = mels[idx] 181 | ga_pads[idx] = gas[idx][:max(text_lengths), :max(mel_lengths)] 182 | return text_pads, mel_pads, ga_pads 183 | 184 | class TextDataset(Dataset): 185 | def __init__(self, text_path): 186 | ''' 187 | Args: 188 | text path (str): path to text set 189 | ''' 190 | self.texts = read_text(text_path) 191 | 192 | def __getitem__(self, idx): 193 | text = torch.tensor(self.texts[idx], dtype=torch.long) 194 | return text 195 | 196 | def __len__(self): 197 | return len(self.texts) 198 | 199 | 200 | def read_text(path): 201 | ''' 202 | If we use pandas instead of this function, it may not cover quotes. 203 | Args: 204 | path: metadata path 205 | Returns: 206 | fpaths, texts, norms 207 | ''' 208 | char2idx, _ = load_vocab() 209 | lines = codecs.open(path, 'r', 'utf-8').readlines()[1:] 210 | texts = [] 211 | for line in lines: 212 | text = text_normalize(line.split(' ', 1)[-1]).strip() + u'E' # ␃: EOS 213 | text = [char2idx[char] for char in text] 214 | texts.append(text) 215 | return texts 216 | 217 | def synth_collate_fn(data): 218 | """ 219 | Creates mini-batch tensors from the list of tuples (texts, mels, mags). 220 | Args: 221 | data: list of tuple (texts,). 222 | - texts: torch tensor of shape (B, Tx). 223 | Returns: 224 | texts: torch tensor of shape (batch_size, padded_length). 225 | """ 226 | texts = data 227 | 228 | # Merge (from tuple of 1D tensor to 2D tensor). 
229 | text_lengths = [len(text) for text in texts] 230 | # (number of mels, max_len, feature_dims) 231 | text_pads = torch.zeros(len(texts), max(text_lengths), dtype=torch.long) 232 | for idx in range(len(texts)): 233 | text_end = text_lengths[idx] 234 | text_pads[idx, :text_end] = texts[idx] 235 | return text_pads, None, None 236 | -------------------------------------------------------------------------------- /lj_eval_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yangyangii/DeepConvolutionalTTS-pytorch/4807de0dba0398fd6e757d9f81cc3e572f3c7f94/lj_eval_idx.npy -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from config import ConfigArgs as args 2 | import torch 3 | import torch.nn as nn 4 | from network import TextEncoder, AudioEncoder, AudioDecoder, DotProductAttention 5 | from torch.nn.utils import weight_norm as norm 6 | import module as mm 7 | 8 | class Text2Mel(nn.Module): 9 | """ 10 | Text2Mel 11 | Args: 12 | L: (N, Tx) text 13 | S: (N, Ty/r, n_mels) previous audio 14 | Returns: 15 | Y: (N, Ty/r, n_mels) 16 | """ 17 | def __init__(self): 18 | super(Text2Mel, self).__init__() 19 | self.name = 'Text2Mel' 20 | self.embed = nn.Embedding(len(args.vocab), args.Ce, padding_idx=0) 21 | self.TextEnc = TextEncoder() 22 | self.AudioEnc = AudioEncoder() 23 | self.Attention = DotProductAttention() 24 | self.AudioDec = AudioDecoder() 25 | 26 | def forward(self, L, S): 27 | L = self.embed(L).transpose(1,2) # -> (N, Cx, Tx) for conv1d 28 | S = S.transpose(1,2) # (N, n_mels, Ty/r) for conv1d 29 | K, V = self.TextEnc(L) # (N, Cx, Tx) respectively 30 | Q = self.AudioEnc(S) # -> (N, Cx, Ty/r) 31 | R, A = self.Attention(K, V, Q) # -> (N, Cx, Ty/r) 32 | R_ = torch.cat((R, Q), 1) # -> (N, Cx*2, Ty/r) 33 | Y = self.AudioDec(R_) # -> (N, n_mels, Ty/r) 34 | return Y.transpose(1, 2), A # (N, Ty/r, n_mels) 35 | 36 | class SSRN(nn.Module): 37 | """ 38 | SSRN 39 | Args: 40 | Y: (N, Ty/r, n_mels) 41 | Returns: 42 | Z: (N, Ty, n_mags) 43 | """ 44 | def __init__(self): 45 | super(SSRN, self).__init__() 46 | self.name = 'SSRN' 47 | # (N, n_mels, Ty/r) -> (N, Cs, Ty/r) 48 | self.hc_blocks = nn.ModuleList([norm(mm.Conv1d(args.n_mels, args.Cs, 1, activation_fn=torch.relu))]) 49 | self.hc_blocks.extend([norm(mm.HighwayConv1d(args.Cs, args.Cs, 3, dilation=3**i)) 50 | for i in range(2)]) 51 | # (N, Cs, Ty/r*2) -> (N, Cs, Ty/r*2) 52 | self.hc_blocks.extend([norm(mm.ConvTranspose1d(args.Cs, args.Cs, 4, stride=2, padding=1))]) 53 | self.hc_blocks.extend([norm(mm.HighwayConv1d(args.Cs, args.Cs, 3, dilation=3**i)) 54 | for i in range(2)]) 55 | # (N, Cs, Ty/r*2) -> (N, Cs, Ty/r*4==Ty) 56 | self.hc_blocks.extend([norm(mm.ConvTranspose1d(args.Cs, args.Cs, 4, stride=2, padding=1))]) 57 | self.hc_blocks.extend([norm(mm.HighwayConv1d(args.Cs, args.Cs, 3, dilation=3**i)) 58 | for i in range(2)]) 59 | # (N, Cs, Ty) -> (N, Cs*2, Ty) 60 | self.hc_blocks.extend([norm(mm.Conv1d(args.Cs, args.Cs*2, 1))]) 61 | self.hc_blocks.extend([norm(mm.HighwayConv1d(args.Cs*2, args.Cs*2, 3, dilation=1)) 62 | for i in range(2)]) 63 | # (N, Cs*2, Ty) -> (N, n_mags, Ty) 64 | self.hc_blocks.extend([norm(mm.Conv1d(args.Cs*2, args.n_mags, 1))]) 65 | self.hc_blocks.extend([norm(mm.Conv1d(args.n_mags, args.n_mags, 1, activation_fn=torch.relu)) 66 | for i in range(2)]) 67 | self.hc_blocks.extend([norm(mm.Conv1d(args.n_mags, args.n_mags, 1))]) 68 | 69 | def 
forward(self, Y):
70 | Y = Y.transpose(1, 2) # -> (N, n_mels, Ty/r)
71 | Z = Y
72 | # -> (N, n_mags, Ty)
73 | for i in range(len(self.hc_blocks)):
74 | Z = self.hc_blocks[i](Z)
75 | Z = torch.sigmoid(Z)
76 | return Z.transpose(1, 2) # (N, Ty, n_mags)
77 |
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
1 | from config import ConfigArgs as args
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | class Conv1d(nn.Conv1d):
7 | """
8 | Convolution 1d (with optional activation and dropout)
9 | Args:
10 | x: (N, C_in, L)
11 | Returns:
12 | y: (N, C_out, L)
13 | """
14 |
15 | def __init__(self, in_channels, out_channels, kernel_size, activation_fn=None, drop_rate=0.,
16 | stride=1, padding='same', dilation=1, groups=1, bias=True):
17 | self.activation_fn = activation_fn
18 | self.drop_rate = drop_rate
19 | if padding == 'same':
20 | padding = kernel_size // 2 * dilation
21 | super(Conv1d, self).__init__(in_channels, out_channels, kernel_size,
22 | stride=stride, padding=padding, dilation=dilation,
23 | groups=groups, bias=bias)
24 | self.drop_out = nn.Dropout(self.drop_rate) if self.drop_rate > 0 else None
25 |
26 | def forward(self, x):
27 | y = super(Conv1d, self).forward(x)
28 | y = self.activation_fn(y) if self.activation_fn is not None else y
29 | y = self.drop_out(y) if self.drop_out is not None else y
30 | return y
31 |
32 | class HighwayConv1d(Conv1d):
33 | """
34 | Highway Convolution 1d
35 | Args:
36 | x: (N, C_in, T)
37 | Returns:
38 | y: (N, C_out, T)
39 | """
40 | def __init__(self, in_channels, out_channels, kernel_size, drop_rate=0.,
41 | stride=1, padding='same', dilation=1, groups=1, bias=True):
42 | self.drop_rate = drop_rate
43 | super(HighwayConv1d, self).__init__(in_channels, out_channels*2, kernel_size, activation_fn=None,
44 | stride=stride, padding=padding, dilation=dilation,
45 | groups=groups, bias=bias)
46 | self.drop_out = nn.Dropout(self.drop_rate) if drop_rate > 0 else None
47 |
48 | def forward(self, x):
49 | y = super(HighwayConv1d, self).forward(x) # (N, C_out*2, T)
50 | h, y_ = y.chunk(2, dim=1) # half size for axis C_out.
(N, C_out, T) respectively 51 | h = torch.sigmoid(h) # Gate 52 | y_ = torch.relu(y_) 53 | y_ = h*y_ + (1-h)*x 54 | y_ = self.drop_out(y_) if self.drop_out is not None else y_ 55 | return y_ 56 | 57 | class CausalConv1d(Conv1d): 58 | """ 59 | Causal convolution 1d 60 | Args: 61 | x: (N, C_in, L) 62 | Returns: 63 | y: (N, C_out, L) 64 | """ 65 | def __init__(self, in_channels, out_channels, kernel_size, activation_fn=None, drop_rate=0., 66 | stride=1, dilation=1, groups=1, bias=True): 67 | padding = (kernel_size - 1) * dilation 68 | super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size, 69 | activation_fn=activation_fn, drop_rate=drop_rate, 70 | stride=stride, padding=padding, dilation=dilation, 71 | groups=groups, bias=bias) 72 | 73 | def forward(self, x): 74 | y = super(CausalConv1d, self).forward(x) 75 | return y[:, :, :x.size(2)] # (N, C, :-(ksize-1)) slicing 76 | 77 | class CausalHighwayConv1d(CausalConv1d): 78 | """ 79 | Causal convolution 1d 80 | Args: 81 | x: (N, C_in, L) 82 | Returns: 83 | y: (N, C_out, L) 84 | """ 85 | def __init__(self, in_channels, out_channels, kernel_size, drop_rate=0., 86 | stride=1, dilation=1, groups=1, bias=True): 87 | self.drop_rate = drop_rate 88 | super(CausalHighwayConv1d, self).__init__(in_channels, out_channels*2, kernel_size, 89 | activation_fn=None, 90 | stride=stride, dilation=dilation, 91 | groups=groups, bias=bias) 92 | self.drop_out = nn.Dropout(self.drop_rate) if self.drop_rate > 0 else None 93 | 94 | def forward(self, x): 95 | y = super(CausalHighwayConv1d, self).forward(x) 96 | h, y_ = y.chunk(2, dim=1) # half size for axis C_out 97 | h = torch.sigmoid(h) # Gate 98 | y_ = torch.relu(y_) 99 | y_ = h*y_ + (1.0-h)*x 100 | y_ = self.drop_out(y_) if self.drop_out is not None else y_ 101 | return y_ 102 | 103 | class ConvTranspose1d(nn.ConvTranspose1d): 104 | """ 105 | Transposed Convolution 1d 106 | Args: 107 | x: (N, C_in, L) 108 | Returns: 109 | y: (N, C_out, L*stride) naive shape 110 | """ 111 | 112 | def __init__(self, in_channels, out_channels, kernel_size, activation_fn=None, drop_rate=0., 113 | stride=1, padding='same', dilation=1, groups=1, bias=True): 114 | self.activation_fn = activation_fn 115 | self.drop_rate = drop_rate 116 | if padding == 'same': 117 | padding = kernel_size // 2 * dilation 118 | super(ConvTranspose1d, self).__init__(in_channels, out_channels, kernel_size, 119 | stride=stride, padding=padding, dilation=dilation, 120 | groups=groups, bias=bias) 121 | self.drop_out = nn.Dropout(self.drop_rate) if self.drop_rate > 0 else None 122 | 123 | def forward(self, x): 124 | y = super(ConvTranspose1d, self).forward(x) 125 | y = self.activation_fn(y) if self.activation_fn is not None else y 126 | y = self.drop_out(y) if self.drop_out is not None else y 127 | return y 128 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | from config import ConfigArgs as args 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils import weight_norm as norm 5 | import numpy as np 6 | import module as mm 7 | 8 | class TextEncoder(nn.Module): 9 | """ 10 | Text Encoder 11 | T: (N, Cx, Tx) Text embedding (variable length) 12 | Returns: 13 | K: (N, Cx, Tx) Text Encoding for Key 14 | V: (N, Cx, Tx) Text Encoding for Value 15 | """ 16 | def __init__(self): 17 | super(TextEncoder, self).__init__() 18 | self.hc_blocks = nn.ModuleList([norm(mm.Conv1d(args.Ce, args.Cx*2, 1, padding='same', 
activation_fn=torch.relu))]) # filter up to split into K, V
19 | self.hc_blocks.extend([norm(mm.Conv1d(args.Cx*2, args.Cx*2, 1, padding='same', activation_fn=None))])
20 | self.hc_blocks.extend([norm(mm.HighwayConv1d(args.Cx*2, args.Cx*2, 3, dilation=3**i, padding='same'))
21 | for _ in range(2) for i in range(4)])
22 | self.hc_blocks.extend([norm(mm.HighwayConv1d(args.Cx*2, args.Cx*2, 3, dilation=1, padding='same'))
23 | for i in range(2)])
24 | self.hc_blocks.extend([norm(mm.HighwayConv1d(args.Cx*2, args.Cx*2, 1, dilation=1, padding='same'))
25 | for i in range(2)])
26 |
27 | def forward(self, L):
28 | y = L
29 | for i in range(len(self.hc_blocks)):
30 | y = self.hc_blocks[i](y)
31 | K, V = y.chunk(2, dim=1) # half size for axis Cx
32 | return K, V
33 |
34 | class AudioEncoder(nn.Module):
35 | """
36 | Audio Encoder
37 | prev_audio: (N, n_mels, Ty/r) Mel-spectrogram (variable length)
38 | Returns:
39 | Q: (N, Cx, Ty/r) Audio Encoding for Query
40 | """
41 |
42 | def __init__(self):
43 | super(AudioEncoder, self).__init__()
44 | self.hc_blocks = nn.ModuleList([norm(mm.CausalConv1d(args.n_mels, args.Cx, 1, activation_fn=torch.relu))])
45 | self.hc_blocks.extend([norm(mm.CausalConv1d(args.Cx, args.Cx, 1, activation_fn=torch.relu))
46 | for _ in range(2)])
47 | self.hc_blocks.extend([norm(mm.CausalHighwayConv1d(args.Cx, args.Cx, 3, dilation=3**i)) # i is in [[0,1,2,3],[0,1,2,3]]
48 | for _ in range(2) for i in range(4)])
49 | self.hc_blocks.extend([norm(mm.CausalHighwayConv1d(args.Cx, args.Cx, 3, dilation=3))
50 | for i in range(2)])
51 | # self.hc_blocks.extend([mm.CausalConv1d(args.Cy, args.Cx, 1, dilation=1, activation_fn=torch.relu)]) # down #filters to dotproduct K, V
52 |
53 | def forward(self, S):
54 | Q = S
55 | for i in range(len(self.hc_blocks)):
56 | Q = self.hc_blocks[i](Q)
57 | return Q
58 |
59 | class DotProductAttention(nn.Module):
60 | """
61 | Dot Product Attention
62 | Args:
63 | K: (N, Cx, Tx)
64 | V: (N, Cx, Tx)
65 | Q: (N, Cx, Ty)
66 | Returns:
67 | R: (N, Cx, Ty)
68 | A: (N, Tx, Ty) alignments
69 | """
70 |
71 | def __init__(self):
72 | super(DotProductAttention, self).__init__()
73 |
74 | def forward(self, K, V, Q):
75 | A = torch.softmax((torch.bmm(K.transpose(1, 2), Q)/np.sqrt(args.Cx)), dim=1) # K.T.dot(Q) -> (N, Tx, Ty)
76 | R = torch.bmm(V, A) # (N, Cx, Ty)
77 | return R, A
78 |
79 | class AudioDecoder(nn.Module):
80 | """
81 | Audio Decoder
82 | Args:
83 | R_: (N, Cx*2, Ty)
84 | Returns:
85 | O: (N, n_mels, Ty)
86 | """
87 | def __init__(self):
88 | super(AudioDecoder, self).__init__()
89 | self.hc_blocks = nn.ModuleList([norm(mm.CausalConv1d(args.Cx*2, args.Cy, 1, activation_fn=torch.relu))])
90 | self.hc_blocks.extend([norm(mm.CausalHighwayConv1d(args.Cy, args.Cy, 3, dilation=3**i))
91 | for i in range(4)])
92 | self.hc_blocks.extend([norm(mm.CausalHighwayConv1d(args.Cy, args.Cy, 3, dilation=1))
93 | for _ in range(2)])
94 | self.hc_blocks.extend([norm(mm.CausalConv1d(args.Cy, args.Cy, 1, dilation=1, activation_fn=torch.relu))
95 | for _ in range(3)])
96 | self.hc_blocks.extend([norm(mm.CausalConv1d(args.Cy, args.n_mels, 1, dilation=1))]) # project channels down to n_mels
97 |
98 | def forward(self, R_):
99 | Y = R_
100 | for i in range(len(self.hc_blocks)):
101 | Y = self.hc_blocks[i](Y)
102 | return torch.sigmoid(Y)
103 |
--------------------------------------------------------------------------------
/prepro.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from config import
ConfigArgs as args 3 | from utils import load_spectrogram, prepro_guided_attention 4 | import os, sys 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | import codecs 10 | import data 11 | 12 | NUM_JOBS = 8 13 | 14 | def f(f_args): 15 | fpath, text = f_args 16 | mel, mag = load_spectrogram(os.path.join(args.data_path, 'wavs', fpath.replace('npy', 'wav'))) 17 | np.save(os.path.join(args.data_path, args.ga_dir, fpath), prepro_guided_attention(len(text), len(mel), g=args.g)) 18 | np.save(os.path.join(args.data_path, args.mel_dir, fpath), mel) 19 | np.save(os.path.join(args.data_path, args.mag_dir, fpath), mag) 20 | return None 21 | 22 | def prepro_signal(): 23 | print('Preprocessing signal') 24 | # Load data 25 | fpaths, texts, _ = data.read_meta(os.path.join(args.data_path, args.meta)) 26 | 27 | # Creates folders 28 | if not os.path.exists(os.path.join(args.data_path, args.mel_dir)): 29 | os.mkdir(os.path.join(args.data_path, args.mel_dir)) 30 | if not os.path.exists(os.path.join(args.data_path, args.mag_dir)): 31 | os.mkdir(os.path.join(args.data_path, args.mag_dir)) 32 | if not os.path.exists(os.path.join(args.data_path, args.ga_dir)): 33 | os.mkdir(os.path.join(args.data_path, args.ga_dir)) 34 | 35 | # Creates pool 36 | p = Pool(NUM_JOBS) 37 | 38 | total_files = len(fpaths) 39 | with tqdm(total=total_files) as pbar: 40 | for _ in tqdm(p.imap_unordered(f, list(zip(fpaths,texts)))): 41 | pbar.update() 42 | 43 | def prepro_meta(): 44 | ## train(95%)/test(5%) split for metadata 45 | print('Preprocessing meta') 46 | # Parse 47 | transcript = os.path.join(args.data_path, 'metadata.csv') 48 | train_transcript = os.path.join(args.data_path, 'meta-train.csv') 49 | test_transcript = os.path.join(args.data_path, 'meta-eval.csv') 50 | 51 | lines = codecs.open(transcript, 'r', 'utf-8').readlines() 52 | train_f = codecs.open(train_transcript, 'w', 'utf-8') 53 | test_f = codecs.open(test_transcript, 'w', 'utf-8') 54 | 55 | test_idx = np.load('lj_eval_idx.npy') 56 | 57 | for idx, line in enumerate(lines): 58 | if idx in test_idx: 59 | test_f.write(line) 60 | else: 61 | train_f.write(line) 62 | print('# of train set: {}, # of test set: {}'.format(1+idx-len(test_idx), len(test_idx))) 63 | print('Complete') 64 | 65 | if __name__ == '__main__': 66 | is_signal = sys.argv[1] 67 | is_meta = sys.argv[2] 68 | print('Signal: {}, Meta: {}'.format(is_signal, is_meta)) 69 | 70 | if is_signal in ['1', 'True']: 71 | prepro_signal() 72 | if is_meta in ['1', 'True']: 73 | prepro_meta() 74 | -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- 1 | from config import ConfigArgs as args 2 | import os, sys 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm, trange 11 | 12 | import numpy as np 13 | import pandas as pd 14 | from model import Text2Mel, SSRN 15 | from data import TextDataset, synth_collate_fn, load_vocab 16 | import utils 17 | from scipy.io.wavfile import write 18 | 19 | 20 | def synthesize(t2m, ssrn, data_loader, batch_size=100): 21 | ''' 22 | DCTTS Architecture 23 | Text --> Text2Mel --> SSRN --> Wav file 24 | ''' 25 | # Text2Mel 26 | idx2char = load_vocab()[-1] 27 | with torch.no_grad(): 28 | print('='*10, ' Text2Mel ', '='*10) 29 | total_mel_hats = 
torch.zeros([len(data_loader.dataset), args.max_Ty, args.n_mels]).to(DEVICE) 30 | mags = torch.zeros([len(data_loader.dataset), args.max_Ty*args.r, args.n_mags]).to(DEVICE) 31 | for step, (texts, _, _) in enumerate(data_loader): 32 | texts = texts.to(DEVICE) 33 | prev_mel_hats = torch.zeros([len(texts), args.max_Ty, args.n_mels]).to(DEVICE) 34 | for t in tqdm(range(args.max_Ty-1), unit='B', ncols=70): 35 | mel_hats, A = t2m(texts, prev_mel_hats) # mel: (N, Ty/r, n_mels) 36 | prev_mel_hats[:, t+1, :] = mel_hats[:, t, :] 37 | total_mel_hats[step*batch_size:(step+1)*batch_size, :, :] = prev_mel_hats 38 | 39 | print('='*10, ' Alignment ', '='*10) 40 | alignments = A.cpu().detach().numpy() 41 | visual_texts = texts.cpu().detach().numpy() 42 | for idx in range(len(alignments)): 43 | text = [idx2char[ch] for ch in visual_texts[idx]] 44 | utils.plot_att(alignments[idx], text, args.global_step, path=os.path.join(args.sampledir, 'A'), name='{}.png'.format(idx)) 45 | print('='*10, ' SSRN ', '='*10) 46 | # Mel --> Mag 47 | mags[step*batch_size:(step+1)*batch_size:, :, :] = \ 48 | ssrn(total_mel_hats[step*batch_size:(step+1)*batch_size, :, :]) # mag: (N, Ty, n_mags) 49 | mags = mags.cpu().detach().numpy() 50 | print('='*10, ' Vocoder ', '='*10) 51 | for idx in trange(len(mags), unit='B', ncols=70): 52 | wav = utils.spectrogram2wav(mags[idx]) 53 | write(os.path.join(args.sampledir, '{}.wav'.format(idx+1)), args.sr, wav) 54 | return None 55 | 56 | def main(): 57 | testset = TextDataset(args.testset) 58 | test_loader = DataLoader(dataset=testset, batch_size=args.test_batch, drop_last=False, 59 | shuffle=False, collate_fn=synth_collate_fn, pin_memory=True) 60 | 61 | t2m = Text2Mel().to(DEVICE) 62 | ssrn = SSRN().to(DEVICE) 63 | 64 | ckpt = pd.read_csv(os.path.join(args.logdir, t2m.name, 'ckpt.csv'), sep=',', header=None) 65 | ckpt.columns = ['models', 'loss'] 66 | ckpt = ckpt.sort_values(by='loss', ascending=True) 67 | state = torch.load(os.path.join(args.logdir, t2m.name, ckpt.models.loc[0])) 68 | t2m.load_state_dict(state['model']) 69 | args.global_step = state['global_step'] 70 | 71 | ckpt = pd.read_csv(os.path.join(args.logdir, ssrn.name, 'ckpt.csv'), sep=',', header=None) 72 | ckpt.columns = ['models', 'loss'] 73 | ckpt = ckpt.sort_values(by='loss', ascending=True) 74 | state = torch.load(os.path.join(args.logdir, ssrn.name, ckpt.models.loc[0])) 75 | ssrn.load_state_dict(state['model']) 76 | 77 | print('All of models are loaded.') 78 | 79 | t2m.eval() 80 | ssrn.eval() 81 | 82 | if not os.path.exists(os.path.join(args.sampledir, 'A')): 83 | os.makedirs(os.path.join(args.sampledir, 'A')) 84 | synthesize(t2m, ssrn, test_loader, args.test_batch) 85 | 86 | if __name__ == '__main__': 87 | gpu_id = int(sys.argv[1]) 88 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 89 | os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_id) 90 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 91 | main() 92 | -------------------------------------------------------------------------------- /test_sents.txt: -------------------------------------------------------------------------------- 1 | Audio samples from "Towards End-to-End Prosody Transfer for Expressive Speech Synthesis with Tacotron" https://google.github.io/tacotron/publications/end_to_end_prosody_transfer/ 2 | 1. How do bureaucrats wrap presents? With lots of red tape. 3 | 2. Why are libraries so strict? They have to go by the book. 4 | 3. Why are fish so smart? Because they hang out in schools so much. 5 | 4. Heaps of things. 
Like fairy bread, how the surf is today and why magpies swoop. 6 | 5. The past, the present, and the future walk into a bar. It was tense. 7 | 6. I usually down a cup of java script. Then I put on nature sounds and run a few strenuous searches to improve my speed 8 | 7. I don't have eyes, but I don't need them to know the vibe in here feels good 9 | 8. What time do you go to the dentist? At tooth-hurty! 10 | 9. Sweet dreams are made of these. Friendly Assistants who work hard to please 11 | 10. You are what you eat. So I guess I'm a whole lot of data and a little bit of pizza recipes. 12 | 11. Men say they know many things; But lo! they have taken wings, The arts and sciences, And a thousand appliances; The wind that blows Is all that any body knows. 13 | 12. Do you prefer chocolate or jelly? Which would you like in your belly? You could make a good case, For a cool ice cream base, But I'd argue against vermicelli 14 | 13. Halloween Edition it is! Remember to follow the moves as I say them. 15 | 14. Why are archaeologists so annoyed? They always have a bone to pick. 16 | 15. That one sailed RIGHT over my head. 17 | 16. Wear your heart on your sleeve. It'll terrify people. -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from config import ConfigArgs as args 2 | import os, sys 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | from torch.optim.lr_scheduler import StepLR 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm, trange 12 | from tensorboardX import SummaryWriter 13 | 14 | import numpy as np 15 | import pandas as pd 16 | from collections import deque 17 | from model import Text2Mel, SSRN 18 | from data import SpeechDataset, collate_fn, t2m_collate_fn, t2m_ga_collate_fn, load_vocab 19 | from utils import att2img, spectrogram2wav, plot_att 20 | 21 | def train(model, data_loader, valid_loader, optimizer, scheduler, batch_size=32, ckpt_dir=None, writer=None, mode='1'): 22 | epochs = 0 23 | global_step = args.global_step 24 | l1_criterion = nn.L1Loss().to(DEVICE) # default average 25 | bd_criterion = nn.BCELoss().to(DEVICE) 26 | model_infos = [('None', 10000.)]*5 27 | first_frames = torch.zeros([batch_size, 1, args.n_mels]).to(DEVICE) # (N, Ty/r, n_mels) 28 | idx2char = load_vocab()[-1] 29 | while global_step < args.max_step: 30 | epoch_loss = 0 31 | for step, (texts, mels, extras) in tqdm(enumerate(data_loader), total=len(data_loader), unit='B', ncols=70, leave=False): 32 | optimizer.zero_grad() 33 | if model.name == 'Text2Mel': 34 | if args.ga_mode: 35 | texts, mels, gas = texts.to(DEVICE), mels.to(DEVICE), extras.to(DEVICE) 36 | else: 37 | texts, mels = texts.to(DEVICE), mels.to(DEVICE) 38 | prev_mels = torch.cat((first_frames, mels[:, :-1, :]), 1) 39 | mels_hat, A = model(texts, prev_mels) # mels_hat: (N, Ty/r, n_mels), A: (N, Tx, Ty/r) 40 | if args.ga_mode: 41 | l1_loss = l1_criterion(mels_hat, mels) 42 | bd_loss = bd_criterion(mels_hat, mels) 43 | att_loss = torch.mean(A*gas) 44 | loss = l1_loss + bd_loss + att_loss 45 | else: 46 | l1_loss = l1_criterion(mels_hat, mels) 47 | bd_loss = bd_criterion(mels_hat, mels) 48 | loss = l1_loss + bd_loss 49 | elif model.name == 'SSRN': 50 | texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), extras.to(DEVICE) 51 | mags_hat = model(mels) # mags_hat: (N, Ty, n_mags) 52 | l1_loss = 
l1_criterion(mags_hat, mags) 53 | bd_loss = bd_criterion(mags_hat, mags) 54 | loss = l1_loss + bd_loss 55 | loss.backward() 56 | nn.utils.clip_grad_norm_(model.parameters(), 2.0) 57 | scheduler.step() 58 | optimizer.step() 59 | epoch_loss += l1_loss.item() 60 | global_step += 1 61 | if global_step % args.save_term == 0: 62 | model.eval() 63 | val_loss = evaluate(model, valid_loader, l1_criterion, writer, global_step, args.test_batch) 64 | model_infos = save_model(model, model_infos, optimizer, scheduler, val_loss, global_step, ckpt_dir) # save best 5 models 65 | model.train() 66 | if args.log_mode: 67 | # Summary 68 | avg_loss = epoch_loss / (len(data_loader)) 69 | writer.add_scalar('train/loss', avg_loss, global_step) 70 | writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step) 71 | if model.name == 'Text2Mel': 72 | alignment = A[0:1].clone().cpu().detach().numpy() 73 | writer.add_image('train/alignments', att2img(alignment), global_step) # (Tx, Ty) 74 | if args.ga_mode: 75 | writer.add_scalar('train/loss_att', att_loss, global_step) 76 | text = texts[0].cpu().detach().numpy() 77 | text = [idx2char[ch] for ch in text] 78 | plot_att(alignment[0], text, global_step, path=os.path.join(args.logdir, model.name, 'A', 'train')) 79 | mel_hat = mels_hat[0:1].transpose(1,2) 80 | mel = mels[0:1].transpose(1, 2) 81 | writer.add_image('train/mel_hat', mel_hat, global_step) 82 | writer.add_image('train/mel', mel, global_step) 83 | else: 84 | mag_hat = mags_hat[0:1].transpose(1, 2) 85 | mag = mags[0:1].transpose(1, 2) 86 | writer.add_image('train/mag_hat', mag_hat, global_step) 87 | writer.add_image('train/mag', mag, global_step) 88 | # print('Training Loss: {}'.format(avg_loss)) 89 | epochs += 1 90 | print('Training complete') 91 | 92 | def evaluate(model, data_loader, criterion, writer, global_step, batch_size=100): 93 | valid_loss = 0. 
94 | A = None 95 | with torch.no_grad(): 96 | for step, (texts, mels, extras) in enumerate(data_loader): 97 | if model.name == 'Text2Mel': 98 | first_frames = torch.zeros([mels.shape[0], 1, args.n_mels]).to(DEVICE) # (N, Ty/r, n_mels) 99 | texts, mels = texts.to(DEVICE), mels.to(DEVICE) 100 | prev_mels = torch.cat((first_frames, mels[:, :-1, :]), 1) 101 | mels_hat, A = model(texts, prev_mels) # mels_hat: (N, Ty/r, n_mels), A: (N, Tx, Ty/r) 102 | loss = criterion(mels_hat, mels) 103 | elif model.name == 'SSRN': 104 | texts, mels, mags = texts.to(DEVICE), mels.to(DEVICE), extras.to(DEVICE) 105 | mags_hat = model(mels) # Predict 106 | loss = criterion(mags_hat, mags) 107 | valid_loss += loss.item() 108 | avg_loss = valid_loss / (len(data_loader)) 109 | writer.add_scalar('eval/loss', avg_loss, global_step) 110 | if model.name == 'Text2Mel': 111 | alignment = A[0:1].clone().cpu().detach().numpy() 112 | writer.add_image('eval/alignments', att2img(alignment), global_step) # (Tx, Ty) 113 | text = texts[0].cpu().detach().numpy() 114 | text = [load_vocab()[-1][ch] for ch in text] 115 | plot_att(alignment[0], text, global_step, path=os.path.join(args.logdir, model.name, 'A')) 116 | mel_hat = mels_hat[0:1].transpose(1,2) 117 | mel = mels[0:1].transpose(1, 2) 118 | writer.add_image('eval/mel_hat', mel_hat, global_step) 119 | writer.add_image('eval/mel', mel, global_step) 120 | else: 121 | mag_hat = mags_hat[0:1].transpose(1, 2) 122 | mag = mags[0:1].transpose(1, 2) 123 | writer.add_image('eval/mag_hat', mag_hat, global_step) 124 | writer.add_image('eval/mag', mag, global_step) 125 | return avg_loss 126 | 127 | def save_model(model, model_infos, optimizer, scheduler, val_loss, global_step, ckpt_dir): 128 | cur_ckpt = 'model-{}k.pth.tar'.format(global_step//1000) 129 | prev_ckpt = 'model-{}k.pth.tar'.format(global_step//1000-(args.save_term//1000)) 130 | state = { 131 | 'global_step': global_step, 132 | 'name': model.name, 133 | 'model': model.state_dict(), 134 | 'loss': val_loss, 135 | 'optimizer': optimizer.state_dict(), 136 | 'scheduler': scheduler.state_dict(), 137 | } 138 | torch.save(state, os.path.join(ckpt_dir, cur_ckpt)) 139 | if prev_ckpt not in dict(model_infos).keys() and os.path.exists(os.path.join(ckpt_dir, prev_ckpt)): 140 | os.remove(os.path.join(ckpt_dir, prev_ckpt)) 141 | if val_loss < model_infos[-1][1]: # save better models 142 | worst_model = os.path.join(ckpt_dir, model_infos[-1][0]) 143 | if os.path.exists(worst_model): 144 | os.remove(worst_model) 145 | model_infos[-1] = (cur_ckpt, float('{:.5f}'.format(val_loss))) 146 | model_infos = sorted(list(model_infos), key=lambda x: x[1]) 147 | pd.DataFrame(model_infos).to_csv(os.path.join(ckpt_dir, 'ckpt.csv'), 148 | sep=',', header=None, index=None) 149 | return model_infos 150 | 151 | def main(network=1): 152 | if network == 1: 153 | model = Text2Mel().to(DEVICE) 154 | elif network == 2: 155 | model = SSRN().to(DEVICE) 156 | print('Model {} is working...'.format(model.name)) 157 | print('{} threads are used...'.format(torch.get_num_threads())) 158 | ckpt_dir = os.path.join(args.logdir, model.name) 159 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 160 | scheduler = StepLR(optimizer, step_size=args.lr_decay_step//10, gamma=0.933) # around 1/2 per decay step 161 | 162 | if not os.path.exists(ckpt_dir): 163 | os.makedirs(os.path.join(ckpt_dir, 'A', 'train')) 164 | else: 165 | print('Already exists. 
Retrain the model.') 166 | ckpt = pd.read_csv(os.path.join(ckpt_dir, 'ckpt.csv'), sep=',', header=None) 167 | ckpt.columns = ['models', 'loss'] 168 | ckpt = ckpt.sort_values(by='loss', ascending=True) 169 | state = torch.load(os.path.join(ckpt_dir, ckpt.models.loc[0])) 170 | model.load_state_dict(state['model']) 171 | args.global_step = state['global_step'] 172 | optimizer.load_state_dict(state['optimizer']) 173 | scheduler.load_state_dict(state['scheduler']) 174 | 175 | # model = torch.nn.DataParallel(model, device_ids=list(range(args.no_gpu))).to(DEVICE) 176 | if model.name == 'Text2Mel': 177 | if args.ga_mode: 178 | cfn_train, cfn_eval = t2m_ga_collate_fn, t2m_collate_fn 179 | else: 180 | cfn_train, cfn_eval = t2m_collate_fn, t2m_collate_fn 181 | else: 182 | cfn_train, cfn_eval = collate_fn, collate_fn 183 | 184 | dataset = SpeechDataset(args.data_path, args.meta_train, model.name, mem_mode=args.mem_mode, ga_mode=args.ga_mode) 185 | validset = SpeechDataset(args.data_path, args.meta_eval, model.name, mem_mode=args.mem_mode) 186 | data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size, 187 | shuffle=True, collate_fn=cfn_train, 188 | drop_last=True, pin_memory=True) 189 | valid_loader = DataLoader(dataset=validset, batch_size=args.test_batch, 190 | shuffle=False, collate_fn=cfn_eval, pin_memory=True) 191 | 192 | writer = SummaryWriter(ckpt_dir) 193 | train(model, data_loader, valid_loader, optimizer, scheduler, 194 | batch_size=args.batch_size, ckpt_dir=ckpt_dir, writer=writer) 195 | return None 196 | 197 | if __name__ == '__main__': 198 | network = int(sys.argv[1]) 199 | gpu_id = int(sys.argv[2]) 200 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 201 | os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_id) 202 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 203 | # Set random seem for reproducibility 204 | seed = 999 205 | random.seed(seed) 206 | np.random.seed(seed) 207 | torch.manual_seed(seed) 208 | main(network=network) 209 | 210 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from config import ConfigArgs as args 2 | import librosa 3 | import numpy as np 4 | import os, sys 5 | from scipy import signal 6 | import copy 7 | import torch 8 | import matplotlib 9 | matplotlib.use('pdf') 10 | import matplotlib.pyplot as plt 11 | 12 | def load_spectrogram(fpath): 13 | wav, sr = librosa.load(fpath, sr=args.sr) 14 | 15 | ## Pre-processing 16 | wav, _ = librosa.effects.trim(wav) 17 | wav = np.append(wav[0], wav[1:] - args.preemph * wav[:-1]) 18 | # STFT 19 | linear = librosa.stft(y=wav, 20 | n_fft=args.n_fft, 21 | hop_length=args.hop_length, 22 | win_length=args.win_length) 23 | 24 | # magnitude spectrogram 25 | mag = np.abs(linear) # (1+n_fft//2, T) 26 | 27 | # mel spectrogram 28 | mel_basis = librosa.filters.mel(args.sr, args.n_fft, args.n_mels) # (n_mels, 1+n_fft//2) 29 | mel = np.dot(mel_basis, mag) # (n_mels, t) 30 | 31 | # to decibel 32 | mel = 20 * np.log10(np.maximum(1e-5, mel)) 33 | mag = 20 * np.log10(np.maximum(1e-5, mag)) 34 | 35 | # normalize 36 | mel = np.clip((mel - args.ref_db + args.max_db) / args.max_db, 1e-8, 1) 37 | mag = np.clip((mag - args.ref_db + args.max_db) / args.max_db, 1e-8, 1) 38 | 39 | # Transpose 40 | mel = mel.T.astype(np.float32) # (T, n_mels) 41 | mag = mag.T.astype(np.float32) # (T, 1+n_fft//2) 42 | 43 | mel, mag = padding_reduction(mel, mag) 44 | return mel, mag 45 | 46 | def 
padding_reduction(mel, mag):
47 | # Padding
48 | t = mel.shape[0]
49 | n_paddings = args.r - (t % args.r) if t % args.r != 0 else 0 # for reduction
50 | mel = np.pad(mel, [[0, n_paddings], [0, 0]], mode="constant")
51 | mag = np.pad(mag, [[0, n_paddings], [0, 0]], mode="constant")
52 | mel = mel[::args.r, :]
53 | return mel, mag
54 |
55 | def spectrogram2wav(mag):
56 | '''Generate a waveform from a (normalized) magnitude spectrogram via Griffin-Lim.'''
57 | # transpose
58 | mag = mag.T
59 |
60 | # de-normalize
61 | mag = (np.clip(mag, 0, 1) * args.max_db) - args.max_db + args.ref_db
62 |
63 | # to amplitude
64 | mag = np.power(10.0, mag * 0.05)
65 |
66 | # wav reconstruction
67 | wav = griffin_lim(mag**args.power)
68 |
69 | # de-preemphasis
70 | wav = signal.lfilter([1], [1, -args.preemph], wav)
71 |
72 | # trim
73 | wav, _ = librosa.effects.trim(wav)
74 |
75 | return wav.astype(np.float32)
76 |
77 | def griffin_lim(spectrogram):
78 | '''
79 | Applies the Griffin-Lim algorithm to reconstruct phase from a magnitude spectrogram.
80 | '''
81 | X_best = copy.deepcopy(spectrogram)
82 | for i in range(args.gl_iter):
83 | X_t = librosa.istft(X_best, args.hop_length, win_length=args.win_length, window="hann")
84 | est = librosa.stft(X_t, args.n_fft, args.hop_length, win_length=args.win_length)
85 | phase = est / np.maximum(1e-8, np.abs(est))
86 | X_best = spectrogram * phase
87 | X_t = librosa.istft(X_best, args.hop_length, win_length=args.win_length, window="hann")
88 | y = np.real(X_t)
89 | return y
90 |
91 | def att2img(A):
92 | '''
93 | Args:
94 | A: (1, Tx, Ty) Tensor
95 | '''
96 | for i in range(A.shape[-1]):
97 | att = A[0, :, i]
98 | local_min, local_max = att.min(), att.max()
99 | A[0, :, i] = (att - local_min) / (local_max - local_min + 1e-8) # min-max normalize each decoder step to [0, 1]
100 | return A
101 |
102 |
103 | def plot_att(A, text, global_step, path='.', name=None):
104 | '''
105 | Args:
106 | A: (Tx, Ty) numpy array
107 | text: (Tx,) list
108 | global_step: scalar
109 | '''
110 | fig, ax = plt.subplots(figsize=(25, 25))
111 | im = ax.imshow(A)
112 | fig.colorbar(im, fraction=0.035, pad=0.02)
113 | fig.suptitle('{} Steps'.format(global_step), fontsize=30)
114 | plt.ylabel('Text', fontsize=22)
115 | plt.xlabel('Time', fontsize=22)
116 | plt.yticks(np.arange(len(text)), text)
117 | if name is not None:
118 | plt.savefig(os.path.join(path, name), format='png')
119 | else:
120 | plt.savefig(os.path.join(
121 | path, 'A-{}.png'.format(global_step)), format='png')
122 | plt.close(fig)
123 |
124 | def prepro_guided_attention(N, T, g=0.2):
125 | W = np.zeros([args.max_Tx, args.max_Ty], dtype=np.float32)
126 | for tx in range(args.max_Tx):
127 | for ty in range(args.max_Ty):
128 | if ty <= T:
129 | W[tx, ty] = 1.0 - np.exp(-0.5 * (ty/T - tx/N)**2 / g**2)
130 | else:
131 | W[tx, ty] = 1.0 - np.exp(-0.5 * ((N-1)/N - tx/N)**2 / (g/2)**2) # padded frames beyond the mel length are penalized more strongly
132 | return W
133 |
--------------------------------------------------------------------------------