├── .gitignore ├── LICENSE ├── README.md ├── hyperparams.py ├── module.py ├── network.py ├── png ├── alphas.png ├── attention.gif ├── attention │ ├── attention_0_0.png │ ├── attention_0_1.png │ ├── attention_0_2.png │ ├── attention_0_3.png │ ├── attention_1_0.png │ ├── attention_1_1.png │ ├── attention_1_2.png │ ├── attention_1_3.png │ ├── attention_2_0.png │ ├── attention_2_1.png │ ├── attention_2_2.png │ └── attention_2_3.png ├── attention_decoder.gif ├── attention_decoder │ ├── attention_dec_0_0.png │ ├── attention_dec_0_1.png │ ├── attention_dec_0_2.png │ ├── attention_dec_0_3.png │ ├── attention_dec_1_0.png │ ├── attention_dec_1_1.png │ ├── attention_dec_1_2.png │ ├── attention_dec_1_3.png │ ├── attention_dec_2_0.png │ ├── attention_dec_2_1.png │ ├── attention_dec_2_2.png │ └── attention_dec_2_3.png ├── attention_encoder.gif ├── attention_encoder │ ├── attention_enc_0_0.png │ ├── attention_enc_0_1.png │ ├── attention_enc_0_2.png │ ├── attention_enc_0_3.png │ ├── attention_enc_1_0.png │ ├── attention_enc_1_1.png │ ├── attention_enc_1_2.png │ ├── attention_enc_1_3.png │ ├── attention_enc_2_0.png │ ├── attention_enc_2_1.png │ ├── attention_enc_2_2.png │ └── attention_enc_2_3.png ├── mel_original.png ├── mel_pred.png ├── model.png └── training_loss.png ├── prepare_data.py ├── preprocess.py ├── requirements.txt ├── samples └── test.wav ├── synthesis.py ├── text ├── __init__.py ├── cleaners.py ├── cmudict.py ├── numbers.py └── symbols.py ├── train_postnet.py ├── train_transformer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | checkpoint* 4 | assets/ 5 | .ipynb_checkpoints/ 6 | .idea/ 7 | data/ 8 | ._.DS_Store 9 | runs/ 10 | test.ipynb 11 | synthesis.ipynb 12 | log* 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 soobin seo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer-TTS 2 | * A PyTorch implementation of [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895) 3 | * This model can be trained about 3 to 4 times faster than well-known seq2seq models such as Tacotron, and the quality of the synthesized speech is almost the same. In my experiments it took about 0.5 seconds per training step. 4 | * I did not use a WaveNet vocoder; instead I trained a post network based on the CBHG module from Tacotron and converted the predicted spectrogram into a raw waveform with the Griffin-Lim algorithm. 5 | 6 | 7 | 8 | ## Requirements 9 | * Install Python 3 10 | * Install PyTorch == 0.4.0 11 | * Install requirements: 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## Data 17 | * I used the LJSpeech dataset, which consists of pairs of text transcripts and wav files. The complete dataset (13,100 pairs) can be downloaded [here](https://keithito.com/LJ-Speech-Dataset/). I referred to https://github.com/keithito/tacotron and https://github.com/Kyubyong/dc_tts for the preprocessing code. 18 | 19 | ## Pretrained Model 20 | * You can download the pretrained models [here](https://drive.google.com/drive/folders/1r1tdgsdtipLossqD9ZfDmxSZb8nMO8Nf) (160k steps for the AR model / 100k steps for the Postnet). 21 | * Place the pretrained models in the checkpoint/ directory. 22 | 23 | ## Attention plots 24 | * A diagonal alignment appeared after about 15k steps. The attention plots below are from the 160k-step checkpoint and show the multihead attention of all layers. In this experiment, h=4 heads were used in each of the three attention layers, so 12 attention plots are drawn for each of the encoder, decoder, and encoder-decoder attention. Except in the decoder, only a few heads showed diagonal alignment. 25 | 26 | ### Self Attention encoder 27 | 28 | 29 | ### Self Attention decoder 30 | 31 | 32 | ### Attention encoder-decoder 33 | 34 | 35 | ## Learning curves & Alphas 36 | * I used the same Noam-style warmup and decay schedule as [Tacotron](https://github.com/Kyubyong/tacotron); a code sketch of this schedule appears after the generated samples below. 37 | 38 | 39 | 40 | * The alpha values of the scaled positional encoding behave differently from the paper. In the paper, the encoder alpha increases to about 4, whereas in this experiment it increased slightly at the beginning and then decreased continuously. The decoder alpha decreased steadily from the beginning. 41 | 42 | 43 | 44 | ## Experimental notes 45 | 1. **The learning rate is an important parameter for training.** An initial learning rate of 0.001 with plain exponential decay did not work; the Noam-style warmup/decay schedule did. 46 | 2. **Gradient clipping is also important for training.** I clipped the gradients to a norm of 1. 47 | 3. With the stop-token loss included, the model did not train. 48 | 4. **It was very important to concatenate the input and context vectors in the attention mechanism.** 49 | 50 | ## Generated Samples 51 | * You can listen to some generated samples below. All samples are from the 160k-step checkpoint, so I think the model has not fully converged yet; performance also seems to degrade on long sentences. 52 | 53 | * [sample1](https://soundcloud.com/ksrbpbmcxrzu/160k-0) 54 | * [sample2](https://soundcloud.com/ksrbpbmcxrzu/160k_sample_1) 55 | * [sample3](https://soundcloud.com/ksrbpbmcxrzu/160k_sample_2) 56 | 57 | * The first plot is the predicted mel spectrogram and the second is the ground truth.
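* To make experimental notes 1 and 2 concrete, here is a minimal sketch of a Noam-style warmup/decay schedule and gradient clipping applied inside a single training step. It is an illustration only, not the exact code from `train_transformer.py`; the `warmup_steps=4000` value and the `noam_learning_rate`/`training_step` helper names are assumptions.

```python
import torch as t

def noam_learning_rate(step, init_lr=0.001, warmup_steps=4000):
    # Linear warmup followed by inverse-square-root decay (Noam schedule).
    # warmup_steps=4000 is an assumed value; the actual training script may differ.
    step = max(step, 1)
    return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)

def training_step(model, optimizer, batch, global_step):
    # Set the learning rate for this step before the optimizer update.
    for param_group in optimizer.param_groups:
        param_group['lr'] = noam_learning_rate(global_step)

    # Batch layout follows collate_fn_transformer in preprocess.py.
    characters, mel, mel_input, pos_text, pos_mel, _ = batch
    mel_pred, postnet_pred, _, _, _, _ = model(characters, mel_input, pos_text, pos_mel)

    # Simple L1 reconstruction losses on the decoder and post-net mel outputs (a sketch;
    # the stop-token loss is omitted, per experimental note 3).
    loss = t.nn.functional.l1_loss(mel_pred, mel) + t.nn.functional.l1_loss(postnet_pred, mel)

    optimizer.zero_grad()
    loss.backward()
    # Clip gradients to norm 1 (experimental note 2) before stepping.
    t.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss.item()
```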
58 | 59 | 60 | 61 | ## File description 62 | * `hyperparams.py` includes all the hyperparameters that are needed. 63 | * `prepare_data.py` preprocesses wav files into mel and linear spectrograms and saves them to disk for faster training. The text preprocessing code is in the text/ directory. 64 | * `preprocess.py` contains the preprocessing code used when loading data. 65 | * `module.py` contains the building-block modules, including attention, prenet, postnet and so on. 66 | * `network.py` contains the networks, including the encoder, decoder and post-processing network. 67 | * `train_transformer.py` trains the autoregressive attention network (text --> mel). 68 | * `train_postnet.py` trains the post network (mel --> linear). 69 | * `synthesis.py` generates a TTS sample. 70 | 71 | ## Training the network 72 | * STEP 1. Download and extract the LJSpeech data to any directory you want. 73 | * STEP 2. Adjust the hyperparameters in `hyperparams.py`, especially 'data_path', the directory where you extracted the dataset, and the others if necessary. 74 | * STEP 3. Run `prepare_data.py`. 75 | * STEP 4. Run `train_transformer.py`. 76 | * STEP 5. Run `train_postnet.py`. 77 | 78 | ## Generate TTS wav file 79 | * STEP 1. Run `synthesis.py`. Make sure the restore steps (`--restore_step1` for the transformer, `--restore_step2` for the postnet) point to existing checkpoints. 80 | 81 | ## Reference 82 | * Keith Ito: https://github.com/keithito/tacotron 83 | * Kyubyong Park: https://github.com/Kyubyong/dc_tts 84 | * jadore801120: https://github.com/jadore801120/attention-is-all-you-need-pytorch/ 85 | 86 | ## Comments 87 | * Any comments on the code are always welcome. 88 | 89 | -------------------------------------------------------------------------------- /hyperparams.py: -------------------------------------------------------------------------------- 1 | # Audio 2 | num_mels = 80 3 | # num_freq = 1024 4 | n_fft = 2048 5 | sr = 22050 6 | # frame_length_ms = 50. 7 | # frame_shift_ms = 12.5 8 | preemphasis = 0.97 9 | frame_shift = 0.0125 # seconds 10 | frame_length = 0.05 # seconds 11 | hop_length = int(sr*frame_shift) # samples. 12 | win_length = int(sr*frame_length) # samples. 13 | n_mels = 80 # Number of Mel banks to generate 14 | power = 1.2 # Exponent for amplifying the predicted magnitude 15 | min_level_db = -100 16 | ref_level_db = 20 17 | hidden_size = 256 18 | embedding_size = 512 19 | max_db = 100 20 | ref_db = 20 21 | 22 | n_iter = 60 23 | # power = 1.5 24 | outputs_per_step = 1 25 | 26 | epochs = 10000 27 | lr = 0.001 28 | save_step = 2000 29 | image_step = 500 30 | batch_size = 32 31 | 32 | cleaners='english_cleaners' 33 | 34 | data_path = './data/LJSpeech-1.1' 35 | checkpoint_path = './checkpoint' 36 | sample_path = './samples' -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as t 3 | import torch.nn.functional as F 4 | import math 5 | import hyperparams as hp 6 | from text.symbols import symbols 7 | import numpy as np 8 | import copy 9 | from collections import OrderedDict 10 | 11 | def clones(module, N): 12 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 13 | 14 | class Linear(nn.Module): 15 | """ 16 | Linear Module 17 | """ 18 | def __init__(self, in_dim, out_dim, bias=True, w_init='linear'): 19 | """ 20 | :param in_dim: dimension of input 21 | :param out_dim: dimension of output 22 | :param bias: boolean. if True, bias is included. 23 | :param w_init: str. weight inits with xavier initialization.
24 | """ 25 | super(Linear, self).__init__() 26 | self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) 27 | 28 | nn.init.xavier_uniform_( 29 | self.linear_layer.weight, 30 | gain=nn.init.calculate_gain(w_init)) 31 | 32 | def forward(self, x): 33 | return self.linear_layer(x) 34 | 35 | 36 | class Conv(nn.Module): 37 | """ 38 | Convolution Module 39 | """ 40 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 41 | padding=0, dilation=1, bias=True, w_init='linear'): 42 | """ 43 | :param in_channels: dimension of input 44 | :param out_channels: dimension of output 45 | :param kernel_size: size of kernel 46 | :param stride: size of stride 47 | :param padding: size of padding 48 | :param dilation: dilation rate 49 | :param bias: boolean. if True, bias is included. 50 | :param w_init: str. weight inits with xavier initialization. 51 | """ 52 | super(Conv, self).__init__() 53 | 54 | self.conv = nn.Conv1d(in_channels, out_channels, 55 | kernel_size=kernel_size, stride=stride, 56 | padding=padding, dilation=dilation, 57 | bias=bias) 58 | 59 | nn.init.xavier_uniform_( 60 | self.conv.weight, gain=nn.init.calculate_gain(w_init)) 61 | 62 | def forward(self, x): 63 | x = self.conv(x) 64 | return x 65 | 66 | 67 | class EncoderPrenet(nn.Module): 68 | """ 69 | Pre-network for Encoder consists of convolution networks. 70 | """ 71 | def __init__(self, embedding_size, num_hidden): 72 | super(EncoderPrenet, self).__init__() 73 | self.embedding_size = embedding_size 74 | self.embed = nn.Embedding(len(symbols), embedding_size, padding_idx=0) 75 | 76 | self.conv1 = Conv(in_channels=embedding_size, 77 | out_channels=num_hidden, 78 | kernel_size=5, 79 | padding=int(np.floor(5 / 2)), 80 | w_init='relu') 81 | self.conv2 = Conv(in_channels=num_hidden, 82 | out_channels=num_hidden, 83 | kernel_size=5, 84 | padding=int(np.floor(5 / 2)), 85 | w_init='relu') 86 | 87 | self.conv3 = Conv(in_channels=num_hidden, 88 | out_channels=num_hidden, 89 | kernel_size=5, 90 | padding=int(np.floor(5 / 2)), 91 | w_init='relu') 92 | 93 | self.batch_norm1 = nn.BatchNorm1d(num_hidden) 94 | self.batch_norm2 = nn.BatchNorm1d(num_hidden) 95 | self.batch_norm3 = nn.BatchNorm1d(num_hidden) 96 | 97 | self.dropout1 = nn.Dropout(p=0.2) 98 | self.dropout2 = nn.Dropout(p=0.2) 99 | self.dropout3 = nn.Dropout(p=0.2) 100 | self.projection = Linear(num_hidden, num_hidden) 101 | 102 | def forward(self, input_): 103 | input_ = self.embed(input_) 104 | input_ = input_.transpose(1, 2) 105 | input_ = self.dropout1(t.relu(self.batch_norm1(self.conv1(input_)))) 106 | input_ = self.dropout2(t.relu(self.batch_norm2(self.conv2(input_)))) 107 | input_ = self.dropout3(t.relu(self.batch_norm3(self.conv3(input_)))) 108 | input_ = input_.transpose(1, 2) 109 | input_ = self.projection(input_) 110 | 111 | return input_ 112 | 113 | 114 | class FFN(nn.Module): 115 | """ 116 | Positionwise Feed-Forward Network 117 | """ 118 | 119 | def __init__(self, num_hidden): 120 | """ 121 | :param num_hidden: dimension of hidden 122 | """ 123 | super(FFN, self).__init__() 124 | self.w_1 = Conv(num_hidden, num_hidden * 4, kernel_size=1, w_init='relu') 125 | self.w_2 = Conv(num_hidden * 4, num_hidden, kernel_size=1) 126 | self.dropout = nn.Dropout(p=0.1) 127 | self.layer_norm = nn.LayerNorm(num_hidden) 128 | 129 | def forward(self, input_): 130 | # FFN Network 131 | x = input_.transpose(1, 2) 132 | x = self.w_2(t.relu(self.w_1(x))) 133 | x = x.transpose(1, 2) 134 | 135 | 136 | # residual connection 137 | x = x + input_ 138 | 139 | # dropout 140 | # x = 
self.dropout(x) 141 | 142 | # layer normalization 143 | x = self.layer_norm(x) 144 | 145 | return x 146 | 147 | 148 | class PostConvNet(nn.Module): 149 | """ 150 | Post Convolutional Network (mel --> mel) 151 | """ 152 | def __init__(self, num_hidden): 153 | """ 154 | 155 | :param num_hidden: dimension of hidden 156 | """ 157 | super(PostConvNet, self).__init__() 158 | self.conv1 = Conv(in_channels=hp.num_mels * hp.outputs_per_step, 159 | out_channels=num_hidden, 160 | kernel_size=5, 161 | padding=4, 162 | w_init='tanh') 163 | self.conv_list = clones(Conv(in_channels=num_hidden, 164 | out_channels=num_hidden, 165 | kernel_size=5, 166 | padding=4, 167 | w_init='tanh'), 3) 168 | self.conv2 = Conv(in_channels=num_hidden, 169 | out_channels=hp.num_mels * hp.outputs_per_step, 170 | kernel_size=5, 171 | padding=4) 172 | 173 | self.batch_norm_list = clones(nn.BatchNorm1d(num_hidden), 3) 174 | self.pre_batchnorm = nn.BatchNorm1d(num_hidden) 175 | 176 | self.dropout1 = nn.Dropout(p=0.1) 177 | self.dropout_list = nn.ModuleList([nn.Dropout(p=0.1) for _ in range(3)]) 178 | 179 | def forward(self, input_, mask=None): 180 | # Causal Convolution (for auto-regressive) 181 | input_ = self.dropout1(t.tanh(self.pre_batchnorm(self.conv1(input_)[:, :, :-4]))) 182 | for batch_norm, conv, dropout in zip(self.batch_norm_list, self.conv_list, self.dropout_list): 183 | input_ = dropout(t.tanh(batch_norm(conv(input_)[:, :, :-4]))) 184 | input_ = self.conv2(input_)[:, :, :-4] 185 | return input_ 186 | 187 | 188 | class MultiheadAttention(nn.Module): 189 | """ 190 | Multihead attention mechanism (dot attention) 191 | """ 192 | def __init__(self, num_hidden_k): 193 | """ 194 | :param num_hidden_k: dimension of hidden 195 | """ 196 | super(MultiheadAttention, self).__init__() 197 | 198 | self.num_hidden_k = num_hidden_k 199 | self.attn_dropout = nn.Dropout(p=0.1) 200 | 201 | def forward(self, key, value, query, mask=None, query_mask=None): 202 | # Get attention score 203 | attn = t.bmm(query, key.transpose(1, 2)) 204 | attn = attn / math.sqrt(self.num_hidden_k) 205 | 206 | # Masking to ignore padding (key side) 207 | if mask is not None: 208 | attn = attn.masked_fill(mask, -2 ** 32 + 1) 209 | attn = t.softmax(attn, dim=-1) 210 | else: 211 | attn = t.softmax(attn, dim=-1) 212 | 213 | # Masking to ignore padding (query side) 214 | if query_mask is not None: 215 | attn = attn * query_mask 216 | 217 | # Dropout 218 | # attn = self.attn_dropout(attn) 219 | 220 | # Get Context Vector 221 | result = t.bmm(attn, value) 222 | 223 | return result, attn 224 | 225 | 226 | class Attention(nn.Module): 227 | """ 228 | Attention Network 229 | """ 230 | def __init__(self, num_hidden, h=4): 231 | """ 232 | :param num_hidden: dimension of hidden 233 | :param h: num of heads 234 | """ 235 | super(Attention, self).__init__() 236 | 237 | self.num_hidden = num_hidden 238 | self.num_hidden_per_attn = num_hidden // h 239 | self.h = h 240 | 241 | self.key = Linear(num_hidden, num_hidden, bias=False) 242 | self.value = Linear(num_hidden, num_hidden, bias=False) 243 | self.query = Linear(num_hidden, num_hidden, bias=False) 244 | 245 | self.multihead = MultiheadAttention(self.num_hidden_per_attn) 246 | 247 | self.residual_dropout = nn.Dropout(p=0.1) 248 | 249 | self.final_linear = Linear(num_hidden * 2, num_hidden) 250 | 251 | self.layer_norm_1 = nn.LayerNorm(num_hidden) 252 | 253 | def forward(self, memory, decoder_input, mask=None, query_mask=None): 254 | 255 | batch_size = memory.size(0) 256 | seq_k = memory.size(1) 257 | seq_q = 
decoder_input.size(1) 258 | 259 | # Repeat masks h times 260 | if query_mask is not None: 261 | query_mask = query_mask.unsqueeze(-1).repeat(1, 1, seq_k) 262 | query_mask = query_mask.repeat(self.h, 1, 1) 263 | if mask is not None: 264 | mask = mask.repeat(self.h, 1, 1) 265 | 266 | # Make multihead 267 | key = self.key(memory).view(batch_size, seq_k, self.h, self.num_hidden_per_attn) 268 | value = self.value(memory).view(batch_size, seq_k, self.h, self.num_hidden_per_attn) 269 | query = self.query(decoder_input).view(batch_size, seq_q, self.h, self.num_hidden_per_attn) 270 | 271 | key = key.permute(2, 0, 1, 3).contiguous().view(-1, seq_k, self.num_hidden_per_attn) 272 | value = value.permute(2, 0, 1, 3).contiguous().view(-1, seq_k, self.num_hidden_per_attn) 273 | query = query.permute(2, 0, 1, 3).contiguous().view(-1, seq_q, self.num_hidden_per_attn) 274 | 275 | # Get context vector 276 | result, attns = self.multihead(key, value, query, mask=mask, query_mask=query_mask) 277 | 278 | # Concatenate all multihead context vector 279 | result = result.view(self.h, batch_size, seq_q, self.num_hidden_per_attn) 280 | result = result.permute(1, 2, 0, 3).contiguous().view(batch_size, seq_q, -1) 281 | 282 | # Concatenate context vector with input (most important) 283 | result = t.cat([decoder_input, result], dim=-1) 284 | 285 | # Final linear 286 | result = self.final_linear(result) 287 | 288 | # Residual dropout & connection 289 | result = result + decoder_input 290 | 291 | # result = self.residual_dropout(result) 292 | 293 | # Layer normalization 294 | result = self.layer_norm_1(result) 295 | 296 | return result, attns 297 | 298 | 299 | class Prenet(nn.Module): 300 | """ 301 | Prenet before passing through the network 302 | """ 303 | def __init__(self, input_size, hidden_size, output_size, p=0.5): 304 | """ 305 | :param input_size: dimension of input 306 | :param hidden_size: dimension of hidden unit 307 | :param output_size: dimension of output 308 | """ 309 | super(Prenet, self).__init__() 310 | self.input_size = input_size 311 | self.output_size = output_size 312 | self.hidden_size = hidden_size 313 | self.layer = nn.Sequential(OrderedDict([ 314 | ('fc1', Linear(self.input_size, self.hidden_size)), 315 | ('relu1', nn.ReLU()), 316 | ('dropout1', nn.Dropout(p)), 317 | ('fc2', Linear(self.hidden_size, self.output_size)), 318 | ('relu2', nn.ReLU()), 319 | ('dropout2', nn.Dropout(p)), 320 | ])) 321 | 322 | def forward(self, input_): 323 | 324 | out = self.layer(input_) 325 | 326 | return out 327 | 328 | class CBHG(nn.Module): 329 | """ 330 | CBHG Module 331 | """ 332 | def __init__(self, hidden_size, K=16, projection_size = 256, num_gru_layers=2, max_pool_kernel_size=2, is_post=False): 333 | """ 334 | :param hidden_size: dimension of hidden unit 335 | :param K: # of convolution banks 336 | :param projection_size: dimension of projection unit 337 | :param num_gru_layers: # of layers of GRUcell 338 | :param max_pool_kernel_size: max pooling kernel size 339 | :param is_post: whether post processing or not 340 | """ 341 | super(CBHG, self).__init__() 342 | self.hidden_size = hidden_size 343 | self.projection_size = projection_size 344 | self.convbank_list = nn.ModuleList() 345 | self.convbank_list.append(nn.Conv1d(in_channels=projection_size, 346 | out_channels=hidden_size, 347 | kernel_size=1, 348 | padding=int(np.floor(1/2)))) 349 | 350 | for i in range(2, K+1): 351 | self.convbank_list.append(nn.Conv1d(in_channels=hidden_size, 352 | out_channels=hidden_size, 353 | kernel_size=i, 354 | 
padding=int(np.floor(i/2)))) 355 | 356 | self.batchnorm_list = nn.ModuleList() 357 | for i in range(1, K+1): 358 | self.batchnorm_list.append(nn.BatchNorm1d(hidden_size)) 359 | 360 | convbank_outdim = hidden_size * K 361 | 362 | self.conv_projection_1 = nn.Conv1d(in_channels=convbank_outdim, 363 | out_channels=hidden_size, 364 | kernel_size=3, 365 | padding=int(np.floor(3 / 2))) 366 | self.conv_projection_2 = nn.Conv1d(in_channels=hidden_size, 367 | out_channels=projection_size, 368 | kernel_size=3, 369 | padding=int(np.floor(3 / 2))) 370 | self.batchnorm_proj_1 = nn.BatchNorm1d(hidden_size) 371 | 372 | self.batchnorm_proj_2 = nn.BatchNorm1d(projection_size) 373 | 374 | 375 | self.max_pool = nn.MaxPool1d(max_pool_kernel_size, stride=1, padding=1) 376 | self.highway = Highwaynet(self.projection_size) 377 | self.gru = nn.GRU(self.projection_size, self.hidden_size // 2, num_layers=num_gru_layers, 378 | batch_first=True, 379 | bidirectional=True) 380 | 381 | 382 | def _conv_fit_dim(self, x, kernel_size=3): 383 | if kernel_size % 2 == 0: 384 | return x[:,:,:-1] 385 | else: 386 | return x 387 | 388 | def forward(self, input_): 389 | 390 | input_ = input_.contiguous() 391 | batch_size = input_.size(0) 392 | total_length = input_.size(-1) 393 | 394 | convbank_list = list() 395 | convbank_input = input_ 396 | 397 | # Convolution bank filters 398 | for k, (conv, batchnorm) in enumerate(zip(self.convbank_list, self.batchnorm_list)): 399 | convbank_input = t.relu(batchnorm(self._conv_fit_dim(conv(convbank_input), k+1).contiguous())) 400 | convbank_list.append(convbank_input) 401 | 402 | # Concatenate all features 403 | conv_cat = t.cat(convbank_list, dim=1) 404 | 405 | # Max pooling 406 | conv_cat = self.max_pool(conv_cat)[:,:,:-1] 407 | 408 | # Projection 409 | conv_projection = t.relu(self.batchnorm_proj_1(self._conv_fit_dim(self.conv_projection_1(conv_cat)))) 410 | conv_projection = self.batchnorm_proj_2(self._conv_fit_dim(self.conv_projection_2(conv_projection))) + input_ 411 | 412 | # Highway networks 413 | highway = self.highway.forward(conv_projection.transpose(1,2)) 414 | 415 | 416 | # Bidirectional GRU 417 | 418 | self.gru.flatten_parameters() 419 | out, _ = self.gru(highway) 420 | 421 | return out 422 | 423 | 424 | class Highwaynet(nn.Module): 425 | """ 426 | Highway network 427 | """ 428 | def __init__(self, num_units, num_layers=4): 429 | """ 430 | :param num_units: dimension of hidden unit 431 | :param num_layers: # of highway layers 432 | """ 433 | super(Highwaynet, self).__init__() 434 | self.num_units = num_units 435 | self.num_layers = num_layers 436 | self.gates = nn.ModuleList() 437 | self.linears = nn.ModuleList() 438 | for _ in range(self.num_layers): 439 | self.linears.append(Linear(num_units, num_units)) 440 | self.gates.append(Linear(num_units, num_units)) 441 | 442 | def forward(self, input_): 443 | 444 | out = input_ 445 | 446 | # highway gated function 447 | for fc1, fc2 in zip(self.linears, self.gates): 448 | 449 | h = t.relu(fc1.forward(out)) 450 | t_ = t.sigmoid(fc2.forward(out)) 451 | 452 | c = 1. 
- t_ 453 | out = h * t_ + out * c 454 | 455 | return out 456 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | from module import * 2 | from utils import get_positional_table, get_sinusoid_encoding_table 3 | import hyperparams as hp 4 | import copy 5 | 6 | class Encoder(nn.Module): 7 | """ 8 | Encoder Network 9 | """ 10 | def __init__(self, embedding_size, num_hidden): 11 | """ 12 | :param embedding_size: dimension of embedding 13 | :param num_hidden: dimension of hidden 14 | """ 15 | super(Encoder, self).__init__() 16 | self.alpha = nn.Parameter(t.ones(1)) 17 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(1024, num_hidden, padding_idx=0), 18 | freeze=True) 19 | self.pos_dropout = nn.Dropout(p=0.1) 20 | self.encoder_prenet = EncoderPrenet(embedding_size, num_hidden) 21 | self.layers = clones(Attention(num_hidden), 3) 22 | self.ffns = clones(FFN(num_hidden), 3) 23 | 24 | def forward(self, x, pos): 25 | 26 | # Get character mask 27 | if self.training: 28 | c_mask = pos.ne(0).type(t.float) 29 | mask = pos.eq(0).unsqueeze(1).repeat(1, x.size(1), 1) 30 | 31 | else: 32 | c_mask, mask = None, None 33 | 34 | # Encoder pre-network 35 | x = self.encoder_prenet(x) 36 | 37 | # Get positional embedding, apply alpha and add 38 | pos = self.pos_emb(pos) 39 | x = pos * self.alpha + x 40 | 41 | # Positional dropout 42 | x = self.pos_dropout(x) 43 | 44 | # Attention encoder-encoder 45 | attns = list() 46 | for layer, ffn in zip(self.layers, self.ffns): 47 | x, attn = layer(x, x, mask=mask, query_mask=c_mask) 48 | x = ffn(x) 49 | attns.append(attn) 50 | 51 | return x, c_mask, attns 52 | 53 | 54 | class MelDecoder(nn.Module): 55 | """ 56 | Decoder Network 57 | """ 58 | def __init__(self, num_hidden): 59 | """ 60 | :param num_hidden: dimension of hidden 61 | """ 62 | super(MelDecoder, self).__init__() 63 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(1024, num_hidden, padding_idx=0), 64 | freeze=True) 65 | self.pos_dropout = nn.Dropout(p=0.1) 66 | self.alpha = nn.Parameter(t.ones(1)) 67 | self.decoder_prenet = Prenet(hp.num_mels, num_hidden * 2, num_hidden, p=0.2) 68 | self.norm = Linear(num_hidden, num_hidden) 69 | 70 | self.selfattn_layers = clones(Attention(num_hidden), 3) 71 | self.dotattn_layers = clones(Attention(num_hidden), 3) 72 | self.ffns = clones(FFN(num_hidden), 3) 73 | self.mel_linear = Linear(num_hidden, hp.num_mels * hp.outputs_per_step) 74 | self.stop_linear = Linear(num_hidden, 1, w_init='sigmoid') 75 | 76 | self.postconvnet = PostConvNet(num_hidden) 77 | 78 | def forward(self, memory, decoder_input, c_mask, pos): 79 | batch_size = memory.size(0) 80 | decoder_len = decoder_input.size(1) 81 | 82 | # get decoder mask with triangular matrix 83 | if self.training: 84 | m_mask = pos.ne(0).type(t.float) 85 | mask = m_mask.eq(0).unsqueeze(1).repeat(1, decoder_len, 1) 86 | if next(self.parameters()).is_cuda: 87 | mask = mask + t.triu(t.ones(decoder_len, decoder_len).cuda(), diagonal=1).repeat(batch_size, 1, 1).byte() 88 | else: 89 | mask = mask + t.triu(t.ones(decoder_len, decoder_len), diagonal=1).repeat(batch_size, 1, 1).byte() 90 | mask = mask.gt(0) 91 | zero_mask = c_mask.eq(0).unsqueeze(-1).repeat(1, 1, decoder_len) 92 | zero_mask = zero_mask.transpose(1, 2) 93 | else: 94 | if next(self.parameters()).is_cuda: 95 | mask = t.triu(t.ones(decoder_len, decoder_len).cuda(), diagonal=1).repeat(batch_size, 1, 1).byte() 96 
| else: 97 | mask = t.triu(t.ones(decoder_len, decoder_len), diagonal=1).repeat(batch_size, 1, 1).byte() 98 | mask = mask.gt(0) 99 | m_mask, zero_mask = None, None 100 | 101 | # Decoder pre-network 102 | decoder_input = self.decoder_prenet(decoder_input) 103 | 104 | # Centered position 105 | decoder_input = self.norm(decoder_input) 106 | 107 | # Get positional embedding, apply alpha and add 108 | pos = self.pos_emb(pos) 109 | decoder_input = pos * self.alpha + decoder_input 110 | 111 | # Positional dropout 112 | decoder_input = self.pos_dropout(decoder_input) 113 | 114 | # Attention decoder-decoder, encoder-decoder 115 | attn_dot_list = list() 116 | attn_dec_list = list() 117 | 118 | for selfattn, dotattn, ffn in zip(self.selfattn_layers, self.dotattn_layers, self.ffns): 119 | decoder_input, attn_dec = selfattn(decoder_input, decoder_input, mask=mask, query_mask=m_mask) 120 | decoder_input, attn_dot = dotattn(memory, decoder_input, mask=zero_mask, query_mask=m_mask) 121 | decoder_input = ffn(decoder_input) 122 | attn_dot_list.append(attn_dot) 123 | attn_dec_list.append(attn_dec) 124 | 125 | # Mel linear projection 126 | mel_out = self.mel_linear(decoder_input) 127 | 128 | # Post Mel Network 129 | postnet_input = mel_out.transpose(1, 2) 130 | out = self.postconvnet(postnet_input) 131 | out = postnet_input + out 132 | out = out.transpose(1, 2) 133 | 134 | # Stop tokens 135 | stop_tokens = self.stop_linear(decoder_input) 136 | 137 | return mel_out, out, attn_dot_list, stop_tokens, attn_dec_list 138 | 139 | 140 | class Model(nn.Module): 141 | """ 142 | Transformer Network 143 | """ 144 | def __init__(self): 145 | super(Model, self).__init__() 146 | self.encoder = Encoder(hp.embedding_size, hp.hidden_size) 147 | self.decoder = MelDecoder(hp.hidden_size) 148 | 149 | def forward(self, characters, mel_input, pos_text, pos_mel): 150 | memory, c_mask, attns_enc = self.encoder.forward(characters, pos=pos_text) 151 | mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder.forward(memory, mel_input, c_mask, 152 | pos=pos_mel) 153 | 154 | return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec 155 | 156 | 157 | class ModelPostNet(nn.Module): 158 | """ 159 | CBHG Network (mel --> linear) 160 | """ 161 | def __init__(self): 162 | super(ModelPostNet, self).__init__() 163 | self.pre_projection = Conv(hp.n_mels, hp.hidden_size) 164 | self.cbhg = CBHG(hp.hidden_size) 165 | self.post_projection = Conv(hp.hidden_size, (hp.n_fft // 2) + 1) 166 | 167 | def forward(self, mel): 168 | mel = mel.transpose(1, 2) 169 | mel = self.pre_projection(mel) 170 | mel = self.cbhg(mel).transpose(1, 2) 171 | mag_pred = self.post_projection(mel).transpose(1, 2) 172 | 173 | return mag_pred -------------------------------------------------------------------------------- /png/alphas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/alphas.png -------------------------------------------------------------------------------- /png/attention.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention.gif -------------------------------------------------------------------------------- /png/attention/attention_0_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_0_0.png
-------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | from utils import get_spectrograms 6 | import hyperparams as hp 7 | import librosa 8 | 9 | class PrepareDataset(Dataset): 10 | """LJSpeech dataset.""" 11 | 12 | def __init__(self, csv_file, root_dir): 13 | """ 14 | Args: 15 | csv_file (string): Path to the csv file with annotations. 16 | root_dir (string): Directory with all the wavs. 17 | 18 | """ 19 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None) 20 | self.root_dir = root_dir 21 | 22 | def load_wav(self, filename): 23 | return librosa.load(filename, sr=hp.sample_rate) 24 | 25 | def __len__(self): 26 | return len(self.landmarks_frame) 27 | 28 | def __getitem__(self, idx): 29 | wav_name = os.path.join(self.root_dir, self.landmarks_frame.ix[idx, 0]) + '.wav' 30 | mel, mag = get_spectrograms(wav_name) 31 | 32 | np.save(wav_name[:-4] + '.pt', mel) 33 | np.save(wav_name[:-4] + '.mag', mag) 34 | 35 | sample = {'mel':mel, 'mag': mag} 36 | 37 | return sample 38 | 39 | if __name__ == '__main__': 40 | dataset = PrepareDataset(os.path.join(hp.data_path,'metadata.csv'), os.path.join(hp.data_path,'wavs')) 41 | dataloader = DataLoader(dataset, batch_size=1, drop_last=False, num_workers=8) 42 | from tqdm import tqdm 43 | pbar = tqdm(dataloader) 44 | for d in pbar: 45 | pass 46 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import hyperparams as hp 2 | import pandas as pd 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | import librosa 6 | import numpy as np 7 | from text import text_to_sequence 8 | import collections 9 | from scipy import signal 10 | import torch as t 11 | import math 12 | 13 | 14 | class LJDatasets(Dataset): 15 | """LJSpeech dataset.""" 16 | 17 | def __init__(self, csv_file, root_dir): 18 | """ 19 | Args: 20 | csv_file (string): Path to the csv file with annotations. 21 | root_dir (string): Directory with all the wavs. 22 | 23 | """ 24 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None) 25 | self.root_dir = root_dir 26 | 27 | def load_wav(self, filename): 28 | return librosa.load(filename, sr=hp.sample_rate) 29 | 30 | def __len__(self): 31 | return len(self.landmarks_frame) 32 | 33 | def __getitem__(self, idx): 34 | wav_name = os.path.join(self.root_dir, self.landmarks_frame.ix[idx, 0]) + '.wav' 35 | text = self.landmarks_frame.ix[idx, 1] 36 | 37 | text = np.asarray(text_to_sequence(text, [hp.cleaners]), dtype=np.int32) 38 | mel = np.load(wav_name[:-4] + '.pt.npy') 39 | mel_input = np.concatenate([np.zeros([1,hp.num_mels], np.float32), mel[:-1,:]], axis=0) 40 | text_length = len(text) 41 | pos_text = np.arange(1, text_length + 1) 42 | pos_mel = np.arange(1, mel.shape[0] + 1) 43 | 44 | sample = {'text': text, 'mel': mel, 'text_length':text_length, 'mel_input':mel_input, 'pos_mel':pos_mel, 'pos_text':pos_text} 45 | 46 | return sample 47 | 48 | class PostDatasets(Dataset): 49 | """LJSpeech dataset.""" 50 | 51 | def __init__(self, csv_file, root_dir): 52 | """ 53 | Args: 54 | csv_file (string): Path to the csv file with annotations. 55 | root_dir (string): Directory with all the wavs. 
56 | 57 | """ 58 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None) 59 | self.root_dir = root_dir 60 | 61 | def __len__(self): 62 | return len(self.landmarks_frame) 63 | 64 | def __getitem__(self, idx): 65 | wav_name = os.path.join(self.root_dir, self.landmarks_frame.ix[idx, 0]) + '.wav' 66 | mel = np.load(wav_name[:-4] + '.pt.npy') 67 | mag = np.load(wav_name[:-4] + '.mag.npy') 68 | sample = {'mel':mel, 'mag':mag} 69 | 70 | return sample 71 | 72 | def collate_fn_transformer(batch): 73 | 74 | # Puts each data field into a tensor with outer dimension batch size 75 | if isinstance(batch[0], collections.Mapping): 76 | 77 | text = [d['text'] for d in batch] 78 | mel = [d['mel'] for d in batch] 79 | mel_input = [d['mel_input'] for d in batch] 80 | text_length = [d['text_length'] for d in batch] 81 | pos_mel = [d['pos_mel'] for d in batch] 82 | pos_text= [d['pos_text'] for d in batch] 83 | 84 | text = [i for i,_ in sorted(zip(text, text_length), key=lambda x: x[1], reverse=True)] 85 | mel = [i for i, _ in sorted(zip(mel, text_length), key=lambda x: x[1], reverse=True)] 86 | mel_input = [i for i, _ in sorted(zip(mel_input, text_length), key=lambda x: x[1], reverse=True)] 87 | pos_text = [i for i, _ in sorted(zip(pos_text, text_length), key=lambda x: x[1], reverse=True)] 88 | pos_mel = [i for i, _ in sorted(zip(pos_mel, text_length), key=lambda x: x[1], reverse=True)] 89 | text_length = sorted(text_length, reverse=True) 90 | # PAD sequences with largest length of the batch 91 | text = _prepare_data(text).astype(np.int32) 92 | mel = _pad_mel(mel) 93 | mel_input = _pad_mel(mel_input) 94 | pos_mel = _prepare_data(pos_mel).astype(np.int32) 95 | pos_text = _prepare_data(pos_text).astype(np.int32) 96 | 97 | 98 | return t.LongTensor(text), t.FloatTensor(mel), t.FloatTensor(mel_input), t.LongTensor(pos_text), t.LongTensor(pos_mel), t.LongTensor(text_length) 99 | 100 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}" 101 | .format(type(batch[0])))) 102 | 103 | def collate_fn_postnet(batch): 104 | 105 | # Puts each data field into a tensor with outer dimension batch size 106 | if isinstance(batch[0], collections.Mapping): 107 | 108 | mel = [d['mel'] for d in batch] 109 | mag = [d['mag'] for d in batch] 110 | 111 | # PAD sequences with largest length of the batch 112 | mel = _pad_mel(mel) 113 | mag = _pad_mel(mag) 114 | 115 | return t.FloatTensor(mel), t.FloatTensor(mag) 116 | 117 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}" 118 | .format(type(batch[0])))) 119 | 120 | def _pad_data(x, length): 121 | _pad = 0 122 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) 123 | 124 | def _prepare_data(inputs): 125 | max_len = max((len(x) for x in inputs)) 126 | return np.stack([_pad_data(x, max_len) for x in inputs]) 127 | 128 | def _pad_per_step(inputs): 129 | timesteps = inputs.shape[-1] 130 | return np.pad(inputs, [[0,0],[0,0],[0, hp.outputs_per_step - (timesteps % hp.outputs_per_step)]], mode='constant', constant_values=0.0) 131 | 132 | def get_param_size(model): 133 | params = 0 134 | for p in model.parameters(): 135 | tmp = 1 136 | for x in p.size(): 137 | tmp *= x 138 | params += tmp 139 | return params 140 | 141 | def get_dataset(): 142 | return LJDatasets(os.path.join(hp.data_path,'metadata.csv'), os.path.join(hp.data_path,'wavs')) 143 | 144 | def get_post_dataset(): 145 | return PostDatasets(os.path.join(hp.data_path,'metadata.csv'), os.path.join(hp.data_path,'wavs')) 146 | 147 | def 
_pad_mel(inputs): 148 | _pad = 0 149 | def _pad_one(x, max_len): 150 | mel_len = x.shape[0] 151 | return np.pad(x, [[0,max_len - mel_len],[0,0]], mode='constant', constant_values=_pad) 152 | max_len = max((x.shape[0] for x in inputs)) 153 | return np.stack([_pad_one(x, max_len) for x in inputs]) 154 | 155 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | falcon==1.2.0 2 | inflect==0.2.5 3 | librosa==0.7.1 4 | scipy==1.10.0 5 | Unidecode==0.4.21 6 | pandas 7 | numpy 8 | tensorboardX 9 | tqdm 10 | -------------------------------------------------------------------------------- /samples/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/samples/test.wav -------------------------------------------------------------------------------- /synthesis.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | from utils import spectrogram2wav 3 | from scipy.io.wavfile import write 4 | import hyperparams as hp 5 | from text import text_to_sequence 6 | import numpy as np 7 | from network import ModelPostNet, Model 8 | from collections import OrderedDict 9 | from tqdm import tqdm 10 | import argparse 11 | 12 | def load_checkpoint(step, model_name="transformer"): 13 | state_dict = t.load('./checkpoint/checkpoint_%s_%d.pth.tar'% (model_name, step)) 14 | new_state_dict = OrderedDict() 15 | for k, value in state_dict['model'].items(): 16 | key = k[7:] 17 | new_state_dict[key] = value 18 | 19 | return new_state_dict 20 | 21 | def synthesis(text, args): 22 | m = Model() 23 | m_post = ModelPostNet() 24 | 25 | m.load_state_dict(load_checkpoint(args.restore_step1, "transformer")) 26 | m_post.load_state_dict(load_checkpoint(args.restore_step2, "postnet")) 27 | 28 | text = np.asarray(text_to_sequence(text, [hp.cleaners])) 29 | text = t.LongTensor(text).unsqueeze(0) 30 | text = text.cuda() 31 | mel_input = t.zeros([1,1, 80]).cuda() 32 | pos_text = t.arange(1, text.size(1)+1).unsqueeze(0) 33 | pos_text = pos_text.cuda() 34 | 35 | m=m.cuda() 36 | m_post = m_post.cuda() 37 | m.train(False) 38 | m_post.train(False) 39 | 40 | pbar = tqdm(range(args.max_len)) 41 | with t.no_grad(): 42 | for i in pbar: 43 | pos_mel = t.arange(1,mel_input.size(1)+1).unsqueeze(0).cuda() 44 | mel_pred, postnet_pred, attn, stop_token, _, attn_dec = m.forward(text, mel_input, pos_text, pos_mel) 45 | mel_input = t.cat([mel_input, mel_pred[:,-1:,:]], dim=1) 46 | 47 | mag_pred = m_post.forward(postnet_pred) 48 | 49 | wav = spectrogram2wav(mag_pred.squeeze(0).cpu().numpy()) 50 | write(hp.sample_path + "/test.wav", hp.sr, wav) 51 | 52 | if __name__ == '__main__': 53 | 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument('--restore_step1', type=int, help='Global step to restore checkpoint', default=172000) 56 | parser.add_argument('--restore_step2', type=int, help='Global step to restore checkpoint', default=100000) 57 | parser.add_argument('--max_len', type=int, help='Global step to restore checkpoint', default=400) 58 | 59 | args = parser.parse_args() 60 | synthesis("Transformer model is so fast!",args) 61 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 
2 | 3 | import re 4 | from text import cleaners 5 | from text.symbols import symbols 6 | 7 | 8 | 9 | 10 | # Mappings from symbol to numeric ID and vice versa: 11 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 12 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 13 | 14 | # Regular expression matching text enclosed in curly braces: 15 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 16 | 17 | 18 | def text_to_sequence(text, cleaner_names): 19 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 20 | 21 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 22 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 23 | 24 | Args: 25 | text: string to convert to a sequence 26 | cleaner_names: names of the cleaner functions to run the text through 27 | 28 | Returns: 29 | List of integers corresponding to the symbols in the text 30 | ''' 31 | sequence = [] 32 | 33 | # Check for curly braces and treat their contents as ARPAbet: 34 | while len(text): 35 | m = _curly_re.match(text) 36 | if not m: 37 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 38 | break 39 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 40 | sequence += _arpabet_to_sequence(m.group(2)) 41 | text = m.group(3) 42 | 43 | # Append EOS token 44 | sequence.append(_symbol_to_id['~']) 45 | return sequence 46 | 47 | 48 | def sequence_to_text(sequence): 49 | '''Converts a sequence of IDs back to a string''' 50 | result = '' 51 | for symbol_id in sequence: 52 | if symbol_id in _id_to_symbol: 53 | s = _id_to_symbol[symbol_id] 54 | # Enclose ARPAbet back in curly braces: 55 | if len(s) > 1 and s[0] == '@': 56 | s = '{%s}' % s[1:] 57 | result += s 58 | return result.replace('}{', ' ') 59 | 60 | 61 | def _clean_text(text, cleaner_names): 62 | for name in cleaner_names: 63 | cleaner = getattr(cleaners, name) 64 | if not cleaner: 65 | raise Exception('Unknown cleaner: %s' % name) 66 | text = cleaner(text) 67 | return text 68 | 69 | 70 | def _symbols_to_sequence(symbols): 71 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 72 | 73 | 74 | def _arpabet_to_sequence(text): 75 | return _symbols_to_sequence(['@' + s for s in text.split()]) 76 | 77 | 78 | def _should_keep_symbol(s): 79 | return s in _symbol_to_id and s is not '_' and s is not '~' -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | 4 | ''' 5 | Cleaners are transformations that run over the input text at both training and eval time. 6 | 7 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 8 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 9 | 1. "english_cleaners" for English text 10 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 11 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 12 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 13 | the symbols in symbols.py to match your data). 
14 | ''' 15 | 16 | import re 17 | from unidecode import unidecode 18 | from .numbers import normalize_numbers 19 | 20 | 21 | # Regular expression matching whitespace: 22 | _whitespace_re = re.compile(r'\s+') 23 | 24 | # List of (regular expression, replacement) pairs for abbreviations: 25 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 26 | ('mrs', 'misess'), 27 | ('mr', 'mister'), 28 | ('dr', 'doctor'), 29 | ('st', 'saint'), 30 | ('co', 'company'), 31 | ('jr', 'junior'), 32 | ('maj', 'major'), 33 | ('gen', 'general'), 34 | ('drs', 'doctors'), 35 | ('rev', 'reverend'), 36 | ('lt', 'lieutenant'), 37 | ('hon', 'honorable'), 38 | ('sgt', 'sergeant'), 39 | ('capt', 'captain'), 40 | ('esq', 'esquire'), 41 | ('ltd', 'limited'), 42 | ('col', 'colonel'), 43 | ('ft', 'fort'), 44 | ]] 45 | 46 | 47 | def expand_abbreviations(text): 48 | for regex, replacement in _abbreviations: 49 | text = re.sub(regex, replacement, text) 50 | return text 51 | 52 | 53 | def expand_numbers(text): 54 | return normalize_numbers(text) 55 | 56 | 57 | def lowercase(text): 58 | return text.lower() 59 | 60 | 61 | def collapse_whitespace(text): 62 | return re.sub(_whitespace_re, ' ', text) 63 | 64 | 65 | def convert_to_ascii(text): 66 | return unidecode(text) 67 | 68 | 69 | def basic_cleaners(text): 70 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 71 | text = lowercase(text) 72 | text = collapse_whitespace(text) 73 | return text 74 | 75 | 76 | def transliteration_cleaners(text): 77 | '''Pipeline for non-English text that transliterates to ASCII.''' 78 | text = convert_to_ascii(text) 79 | text = lowercase(text) 80 | text = collapse_whitespace(text) 81 | return text 82 | 83 | 84 | def english_cleaners(text): 85 | '''Pipeline for English text, including number and abbreviation expansion.''' 86 | text = convert_to_ascii(text) 87 | text = lowercase(text) 88 | text = expand_numbers(text) 89 | text = expand_abbreviations(text) 90 | text = collapse_whitespace(text) 91 | return text 92 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | 4 | import re 5 | 6 | 7 | valid_symbols = [ 8 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 9 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 10 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 11 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 12 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 13 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 14 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 15 | ] 16 | 17 | _valid_symbol_set = set(valid_symbols) 18 | 19 | 20 | class CMUDict: 21 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 22 | def __init__(self, file_or_path, keep_ambiguous=True): 23 | if isinstance(file_or_path, str): 24 | with open(file_or_path, encoding='latin-1') as f: 25 | entries = _parse_cmudict(f) 26 | else: 27 | entries = _parse_cmudict(file_or_path) 28 | if not keep_ambiguous: 29 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 30 | self._entries = entries 31 | 32 | 33 | def __len__(self): 34 | return len(self._entries) 35 | 36 | 37 | def lookup(self, word): 38 | '''Returns list of ARPAbet pronunciations of the given word.''' 39 | return self._entries.get(word.upper()) 40 | 41 | 42 | 43 | _alt_re = re.compile(r'\([0-9]+\)') 44 | 45 | 46 | def _parse_cmudict(file): 47 | cmudict = {} 48 | for line in file: 49 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 50 | parts = line.split(' ') 51 | word = re.sub(_alt_re, '', parts[0]) 52 | pronunciation = _get_pronunciation(parts[1]) 53 | if pronunciation: 54 | if word in cmudict: 55 | cmudict[word].append(pronunciation) 56 | else: 57 | cmudict[word] = [pronunciation] 58 | return cmudict 59 | 60 | 61 | def _get_pronunciation(s): 62 | parts = s.strip().split(' ') 63 | for part in parts: 64 | if part not in _valid_symbol_set: 65 | return None 66 | return ' '.join(parts) 67 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 
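# Note: the order of the remaining substitutions matters; currency and decimals are expanded
# before the generic _number_re pass, e.g. "£100" -> "100 pounds" -> "one hundred pounds".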
66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | 4 | ''' 5 | Defines the set of symbols used in text input to the model. 6 | 7 | The default is a set of ASCII characters that works well for English or text that has been run 8 | through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. 9 | ''' 10 | from text import cmudict 11 | 12 | _pad = '_' 13 | _eos = '~' 14 | _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 18 | 19 | # Export all symbols: 20 | symbols = [_pad, _eos] + list(_characters) + _arpabet 21 | 22 | 23 | if __name__ == '__main__': 24 | print(symbols) -------------------------------------------------------------------------------- /train_postnet.py: -------------------------------------------------------------------------------- 1 | from preprocess import get_post_dataset, DataLoader, collate_fn_postnet 2 | from network import * 3 | from tensorboardX import SummaryWriter 4 | import torchvision.utils as vutils 5 | import os 6 | from tqdm import tqdm 7 | 8 | def adjust_learning_rate(optimizer, step_num, warmup_step=4000): 9 | lr = hp.lr * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5) 10 | for param_group in optimizer.param_groups: 11 | param_group['lr'] = lr 12 | 13 | 14 | def main(): 15 | 16 | dataset = get_post_dataset() 17 | global_step = 0 18 | 19 | m = nn.DataParallel(ModelPostNet().cuda()) 20 | 21 | m.train() 22 | optimizer = t.optim.Adam(m.parameters(), lr=hp.lr) 23 | 24 | writer = SummaryWriter() 25 | 26 | for epoch in range(hp.epochs): 27 | 28 | dataloader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=collate_fn_postnet, drop_last=True, num_workers=8) 29 | pbar = tqdm(dataloader) 30 | for i, data in enumerate(pbar): 31 | pbar.set_description("Processing at epoch %d"%epoch) 32 | global_step += 1 33 | if global_step < 400000: 34 | adjust_learning_rate(optimizer, global_step) 35 | 36 | mel, mag = data 37 | 38 | mel = mel.cuda() 39 | mag = mag.cuda() 40 | 41 | mag_pred = m.forward(mel) 42 | 43 | loss = nn.L1Loss()(mag_pred, mag) 44 | 45 | writer.add_scalars('training_loss',{ 46 | 'loss':loss, 47 | 48 | }, global_step) 49 | 50 | optimizer.zero_grad() 51 | # Calculate gradients 52 | loss.backward() 53 | 54 | nn.utils.clip_grad_norm_(m.parameters(), 1.) 
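# clip_grad_norm_ rescales the gradients in place so their global L2 norm is at most 1.0
# (returning the pre-clipping norm); it must run after loss.backward() and before optimizer.step().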
55 | 56 | # Update weights 57 | optimizer.step() 58 | 59 | if global_step % hp.save_step == 0: 60 | t.save({'model':m.state_dict(), 61 | 'optimizer':optimizer.state_dict()}, 62 | os.path.join(hp.checkpoint_path,'checkpoint_postnet_%d.pth.tar' % global_step)) 63 | 64 | 65 | 66 | 67 | 68 | if __name__ == '__main__': 69 | main() -------------------------------------------------------------------------------- /train_transformer.py: -------------------------------------------------------------------------------- 1 | from preprocess import get_dataset, DataLoader, collate_fn_transformer 2 | from network import * 3 | from tensorboardX import SummaryWriter 4 | import torchvision.utils as vutils 5 | import os 6 | from tqdm import tqdm 7 | 8 | def adjust_learning_rate(optimizer, step_num, warmup_step=4000): 9 | lr = hp.lr * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5) 10 | for param_group in optimizer.param_groups: 11 | param_group['lr'] = lr 12 | 13 | 14 | def main(): 15 | 16 | dataset = get_dataset() 17 | global_step = 0 18 | 19 | m = nn.DataParallel(Model().cuda()) 20 | 21 | m.train() 22 | optimizer = t.optim.Adam(m.parameters(), lr=hp.lr) 23 | 24 | pos_weight = t.FloatTensor([5.]).cuda() 25 | writer = SummaryWriter() 26 | 27 | for epoch in range(hp.epochs): 28 | 29 | dataloader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=collate_fn_transformer, drop_last=True, num_workers=16) 30 | pbar = tqdm(dataloader) 31 | for i, data in enumerate(pbar): 32 | pbar.set_description("Processing at epoch %d"%epoch) 33 | global_step += 1 34 | if global_step < 400000: 35 | adjust_learning_rate(optimizer, global_step) 36 | 37 | character, mel, mel_input, pos_text, pos_mel, _ = data 38 | 39 | stop_tokens = t.abs(pos_mel.ne(0).type(t.float) - 1) 40 | 41 | character = character.cuda() 42 | mel = mel.cuda() 43 | mel_input = mel_input.cuda() 44 | pos_text = pos_text.cuda() 45 | pos_mel = pos_mel.cuda() 46 | 47 | mel_pred, postnet_pred, attn_probs, stop_preds, attns_enc, attns_dec = m.forward(character, mel_input, pos_text, pos_mel) 48 | 49 | mel_loss = nn.L1Loss()(mel_pred, mel) 50 | post_mel_loss = nn.L1Loss()(postnet_pred, mel) 51 | 52 | loss = mel_loss + post_mel_loss 53 | 54 | writer.add_scalars('training_loss',{ 55 | 'mel_loss':mel_loss, 56 | 'post_mel_loss':post_mel_loss, 57 | 58 | }, global_step) 59 | 60 | writer.add_scalars('alphas',{ 61 | 'encoder_alpha':m.module.encoder.alpha.data, 62 | 'decoder_alpha':m.module.decoder.alpha.data, 63 | }, global_step) 64 | 65 | 66 | if global_step % hp.image_step == 1: 67 | 68 | for i, prob in enumerate(attn_probs): 69 | 70 | num_h = prob.size(0) 71 | for j in range(4): 72 | 73 | x = vutils.make_grid(prob[j*16] * 255) 74 | writer.add_image('Attention_%d_0'%global_step, x, i*4+j) 75 | 76 | for i, prob in enumerate(attns_enc): 77 | num_h = prob.size(0) 78 | 79 | for j in range(4): 80 | 81 | x = vutils.make_grid(prob[j*16] * 255) 82 | writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j) 83 | 84 | for i, prob in enumerate(attns_dec): 85 | 86 | num_h = prob.size(0) 87 | for j in range(4): 88 | 89 | x = vutils.make_grid(prob[j*16] * 255) 90 | writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j) 91 | 92 | optimizer.zero_grad() 93 | # Calculate gradients 94 | loss.backward() 95 | 96 | nn.utils.clip_grad_norm_(m.parameters(), 1.) 
97 | 98 | # Update weights 99 | optimizer.step() 100 | 101 | if global_step % hp.save_step == 0: 102 | t.save({'model':m.state_dict(), 103 | 'optimizer':optimizer.state_dict()}, 104 | os.path.join(hp.checkpoint_path,'checkpoint_transformer_%d.pth.tar' % global_step)) 105 | 106 | 107 | 108 | 109 | 110 | if __name__ == '__main__': 111 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | import os, copy 4 | from scipy import signal 5 | import hyperparams as hp 6 | import torch as t 7 | 8 | def get_spectrograms(fpath): 9 | '''Parse the wave file in `fpath` and 10 | Returns normalized melspectrogram and linear spectrogram. 11 | Args: 12 | fpath: A string. The full path of a sound file. 13 | Returns: 14 | mel: A 2d array of shape (T, n_mels) and dtype of float32. 15 | mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32. 16 | ''' 17 | # Loading sound file 18 | y, sr = librosa.load(fpath, sr=hp.sr) 19 | 20 | # Trimming 21 | y, _ = librosa.effects.trim(y) 22 | 23 | # Preemphasis 24 | y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1]) 25 | 26 | # stft 27 | linear = librosa.stft(y=y, 28 | n_fft=hp.n_fft, 29 | hop_length=hp.hop_length, 30 | win_length=hp.win_length) 31 | 32 | # magnitude spectrogram 33 | mag = np.abs(linear) # (1+n_fft//2, T) 34 | 35 | # mel spectrogram 36 | mel_basis = librosa.filters.mel(hp.sr, hp.n_fft, hp.n_mels) # (n_mels, 1+n_fft//2) 37 | mel = np.dot(mel_basis, mag) # (n_mels, t) 38 | 39 | # to decibel 40 | mel = 20 * np.log10(np.maximum(1e-5, mel)) 41 | mag = 20 * np.log10(np.maximum(1e-5, mag)) 42 | 43 | # normalize 44 | mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1) 45 | mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1) 46 | 47 | # Transpose 48 | mel = mel.T.astype(np.float32) # (T, n_mels) 49 | mag = mag.T.astype(np.float32) # (T, 1+n_fft//2) 50 | 51 | return mel, mag 52 | 53 | def spectrogram2wav(mag): 54 | '''# Generate wave file from linear magnitude spectrogram 55 | Args: 56 | mag: A numpy array of (T, 1+n_fft//2) 57 | Returns: 58 | wav: A 1-D numpy array. 59 | ''' 60 | # transpose 61 | mag = mag.T 62 | 63 | # de-noramlize 64 | mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db 65 | 66 | # to amplitude 67 | mag = np.power(10.0, mag * 0.05) 68 | 69 | # wav reconstruction 70 | wav = griffin_lim(mag**hp.power) 71 | 72 | # de-preemphasis 73 | wav = signal.lfilter([1], [1, -hp.preemphasis], wav) 74 | 75 | # trim 76 | wav, _ = librosa.effects.trim(wav) 77 | 78 | return wav.astype(np.float32) 79 | 80 | def griffin_lim(spectrogram): 81 | '''Applies Griffin-Lim's raw.''' 82 | X_best = copy.deepcopy(spectrogram) 83 | for i in range(hp.n_iter): 84 | X_t = invert_spectrogram(X_best) 85 | est = librosa.stft(X_t, hp.n_fft, hp.hop_length, win_length=hp.win_length) 86 | phase = est / np.maximum(1e-8, np.abs(est)) 87 | X_best = spectrogram * phase 88 | X_t = invert_spectrogram(X_best) 89 | y = np.real(X_t) 90 | 91 | return y 92 | 93 | def invert_spectrogram(spectrogram): 94 | '''Applies inverse fft. 
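Uses the same hop_length / win_length as the forward STFT in get_spectrograms, with a Hann window.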
95 | Args: 96 | spectrogram: [1+n_fft//2, t] 97 | ''' 98 | return librosa.istft(spectrogram, hp.hop_length, win_length=hp.win_length, window="hann") 99 | 100 | def get_positional_table(d_pos_vec, n_position=1024): 101 | position_enc = np.array([ 102 | [pos / np.power(10000, 2*i/d_pos_vec) for i in range(d_pos_vec)] 103 | if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)]) 104 | 105 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i 106 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1 107 | return t.from_numpy(position_enc).type(t.FloatTensor) 108 | 109 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 110 | ''' Sinusoid position encoding table ''' 111 | 112 | def cal_angle(position, hid_idx): 113 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 114 | 115 | def get_posi_angle_vec(position): 116 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 117 | 118 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 119 | 120 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 121 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 122 | 123 | if padding_idx is not None: 124 | # zero vector for padding dimension 125 | sinusoid_table[padding_idx] = 0. 126 | 127 | return t.FloatTensor(sinusoid_table) 128 | 129 | def guided_attention(N, T, g=0.2): 130 | '''Guided attention. Refer to page 3 on the paper.''' 131 | W = np.zeros((N, T), dtype=np.float32) 132 | for n_pos in range(W.shape[0]): 133 | for t_pos in range(W.shape[1]): 134 | W[n_pos, t_pos] = 1 - np.exp(-(t_pos / float(T) - n_pos / float(N)) ** 2 / (2 * g * g)) 135 | return W 136 | --------------------------------------------------------------------------------