├── .gitignore
├── LICENSE
├── README.md
├── hyperparams.py
├── module.py
├── network.py
├── png
│   ├── alphas.png
│   ├── attention.gif
│   ├── attention
│   │   ├── attention_0_0.png
│   │   ├── attention_0_1.png
│   │   ├── attention_0_2.png
│   │   ├── attention_0_3.png
│   │   ├── attention_1_0.png
│   │   ├── attention_1_1.png
│   │   ├── attention_1_2.png
│   │   ├── attention_1_3.png
│   │   ├── attention_2_0.png
│   │   ├── attention_2_1.png
│   │   ├── attention_2_2.png
│   │   └── attention_2_3.png
│   ├── attention_decoder.gif
│   ├── attention_decoder
│   │   ├── attention_dec_0_0.png
│   │   ├── attention_dec_0_1.png
│   │   ├── attention_dec_0_2.png
│   │   ├── attention_dec_0_3.png
│   │   ├── attention_dec_1_0.png
│   │   ├── attention_dec_1_1.png
│   │   ├── attention_dec_1_2.png
│   │   ├── attention_dec_1_3.png
│   │   ├── attention_dec_2_0.png
│   │   ├── attention_dec_2_1.png
│   │   ├── attention_dec_2_2.png
│   │   └── attention_dec_2_3.png
│   ├── attention_encoder.gif
│   ├── attention_encoder
│   │   ├── attention_enc_0_0.png
│   │   ├── attention_enc_0_1.png
│   │   ├── attention_enc_0_2.png
│   │   ├── attention_enc_0_3.png
│   │   ├── attention_enc_1_0.png
│   │   ├── attention_enc_1_1.png
│   │   ├── attention_enc_1_2.png
│   │   ├── attention_enc_1_3.png
│   │   ├── attention_enc_2_0.png
│   │   ├── attention_enc_2_1.png
│   │   ├── attention_enc_2_2.png
│   │   └── attention_enc_2_3.png
│   ├── mel_original.png
│   ├── mel_pred.png
│   ├── model.png
│   └── training_loss.png
├── prepare_data.py
├── preprocess.py
├── requirements.txt
├── samples
│   └── test.wav
├── synthesis.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
├── train_postnet.py
├── train_transformer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pyc
3 | checkpoint*
4 | assets/
5 | .ipynb_checkpoints/
6 | .idea/
7 | data/
8 | ._.DS_Store
9 | runs/
10 | test.ipynb
11 | synthesis.ipynb
12 | log*
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 soobin seo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer-TTS
2 | * A PyTorch implementation of [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895)
3 | * This model trains about 3 to 4 times faster than well-known seq2seq models such as Tacotron, while the quality of the synthesized speech is almost the same. In my experiments, each training step took about 0.5 seconds.
4 | * I did not use a WaveNet vocoder; instead, I trained a post network based on the CBHG module from Tacotron and converted the predicted spectrogram into a raw waveform with the Griffin-Lim algorithm.
5 |
6 |
7 |
8 | ## Requirements
9 | * Install Python 3
10 | * Install PyTorch == 0.4.0
11 | * Install requirements:
12 | ```
13 | pip install -r requirements.txt
14 | ```
15 |
16 | ## Data
17 | * I used the LJSpeech dataset, which consists of pairs of text scripts and wav files. The complete dataset (13,100 pairs) can be downloaded [here](https://keithito.com/LJ-Speech-Dataset/). I referred to https://github.com/keithito/tacotron and https://github.com/Kyubyong/dc_tts for the preprocessing code.
18 |
19 | ## Pretrained Model
20 | * You can download the pretrained models [here](https://drive.google.com/drive/folders/1r1tdgsdtipLossqD9ZfDmxSZb8nMO8Nf) (160k steps for the AR model / 100k steps for the Postnet).
21 | * Place the pretrained models in the checkpoint/ directory.
22 |
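* `synthesis.py` loads checkpoints from the pattern `./checkpoint/checkpoint_<model_name>_<step>.pth.tar`, so the downloaded files are expected to be laid out roughly as below (step numbers correspond to the pretrained checkpoints above):

```
checkpoint/
├── checkpoint_transformer_160000.pth.tar
└── checkpoint_postnet_100000.pth.tar
```
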
23 | ## Attention plots
24 | * A diagonal alignment appeared after about 15k steps. The attention plots below are from 160k steps and show the multi-head attention of all layers. In this experiment, h=4 heads are used in each of the three attention layers, so 12 attention plots are drawn for each of the encoder, the decoder, and the encoder-decoder attention. Except for the decoder, only a few heads showed diagonal alignment.
25 |
26 | ### Self Attention encoder
27 |
28 |
29 | ### Self Attention decoder
30 |
31 |
32 | ### Attention encoder-decoder
33 |
34 |
35 | ## Learning curves & Alphas
36 | * I used the same Noam-style warmup and decay schedule as [Tacotron](https://github.com/Kyubyong/tacotron).
37 |
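* Since `train_transformer.py` is not reproduced in this dump, here is a minimal sketch of a Noam-style schedule consistent with the description above; `warmup_steps=4000` is an assumed value, not one taken from this repository.

```python
import hyperparams as hp

def adjust_learning_rate(optimizer, step_num, warmup_steps=4000):
    # Noam-style schedule: linear warmup for `warmup_steps`, then inverse square-root decay.
    # `step_num` is assumed to start at 1.
    lr = hp.lr * warmup_steps ** 0.5 * min(step_num * warmup_steps ** -1.5, step_num ** -0.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
```
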
38 |
39 |
40 | * The alpha values for the scaled positional encoding behave differently from the paper. In the paper, the encoder alpha increases to about 4, whereas in this experiment it rose slightly at the beginning and then decreased continuously. The decoder alpha decreased steadily from the start.
41 |
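* For reference, the scaled positional encoding is implemented in `network.py` (in both the `Encoder` and the `MelDecoder`) as a single learnable scalar applied to the sinusoidal table before the residual add:

```python
# excerpt from the Encoder in network.py
self.alpha = nn.Parameter(t.ones(1))   # learnable scale, initialized to 1
# ...
pos = self.pos_emb(pos)                # sinusoidal positional embedding
x = pos * self.alpha + x               # scaled positional encoding added to the prenet output
```
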
42 |
43 |
44 | ## Experimental notes
45 | 1. **The learning rate is an important parameter for training.** An initial learning rate of 0.001 with plain exponential decay did not work.
46 | 2. **Gradient clipping is also an important parameter for training.** I clipped the gradients to a norm of 1 (see the sketch after this list).
47 | 3. With the stop token loss, the model did not train.
48 | 4. **It was very important to concatenate the input and context vectors in the Attention mechanism.**
49 |
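A minimal sketch of the optimization step with clipping, assuming a standard PyTorch training loop (`train_transformer.py` is not reproduced in this dump, so the variable names are illustrative):

```python
from torch.nn.utils import clip_grad_norm_

optimizer.zero_grad()
loss.backward()
# Clip the global gradient norm to 1 before updating the parameters.
clip_grad_norm_(model.parameters(), 1.)
optimizer.step()
```
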
50 | ## Generated Samples
51 | * You can listen to some generated samples below. All samples were generated at 160k steps, so I think the model has not fully converged yet. The model also seems to perform worse on long sentences.
52 |
53 | * [sample1](https://soundcloud.com/ksrbpbmcxrzu/160k-0)
54 | * [sample2](https://soundcloud.com/ksrbpbmcxrzu/160k_sample_1)
55 | * [sample3](https://soundcloud.com/ksrbpbmcxrzu/160k_sample_2)
56 |
57 | * The first plot is the predicted mel spectrogram, and the second is the ground truth.
58 |
59 |
60 |
61 | ## File description
62 | * `hyperparams.py` contains all the required hyperparameters.
63 | * `prepare_data.py` preprocesses the wav files into mel and linear spectrograms and saves them to disk for faster training. The text preprocessing code is in the text/ directory.
64 | * `preprocess.py` contains the dataset and preprocessing code used when loading data.
65 | * `module.py` contains the building blocks, including attention, prenet, postnet, and so on.
66 | * `network.py` contains the networks, including the encoder, decoder, and post-processing network.
67 | * `train_transformer.py` trains the autoregressive attention network. (text --> mel)
68 | * `train_postnet.py` trains the post network. (mel --> linear)
69 | * `synthesis.py` generates a TTS sample.
70 |
71 | ## Training the network
72 | * STEP 1. Download and extract the LJSpeech dataset to any directory you want.
73 | * STEP 2. Adjust the hyperparameters in `hyperparams.py`, especially `data_path` (the directory where you extracted the dataset), and the others if necessary.
74 | * STEP 3. Run `prepare_data.py`.
75 | * STEP 4. Run `train_transformer.py`.
76 | * STEP 5. Run `train_postnet.py`.
77 |
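The corresponding commands, assuming they are run from the repository root:

```
python prepare_data.py
python train_transformer.py
python train_postnet.py
```
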
78 | ## Generate TTS wav file
79 | * STEP 1. Run `synthesis.py`. Make sure to pass the correct restore steps for the transformer and postnet checkpoints.
80 |
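For example, with the pretrained checkpoints above (argument names follow `synthesis.py`):

```
python synthesis.py --restore_step1 160000 --restore_step2 100000 --max_len 400
```
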
81 | ## Reference
82 | * Keith ito: https://github.com/keithito/tacotron
83 | * Kyubyong Park: https://github.com/Kyubyong/dc_tts
84 | * jadore801120: https://github.com/jadore801120/attention-is-all-you-need-pytorch/
85 |
86 | ## Comments
87 | * Any comments on the code are always welcome.
88 |
89 |
--------------------------------------------------------------------------------
/hyperparams.py:
--------------------------------------------------------------------------------
1 | # Audio
2 | num_mels = 80
3 | # num_freq = 1024
4 | n_fft = 2048
5 | sr = 22050
6 | # frame_length_ms = 50.
7 | # frame_shift_ms = 12.5
8 | preemphasis = 0.97
9 | frame_shift = 0.0125 # seconds
10 | frame_length = 0.05 # seconds
11 | hop_length = int(sr*frame_shift) # samples.
12 | win_length = int(sr*frame_length) # samples.
13 | n_mels = 80 # Number of Mel banks to generate
14 | power = 1.2 # Exponent for amplifying the predicted magnitude
15 | min_level_db = -100
16 | ref_level_db = 20
17 | hidden_size = 256
18 | embedding_size = 512
19 | max_db = 100
20 | ref_db = 20
21 |
22 | n_iter = 60
23 | # power = 1.5
24 | outputs_per_step = 1
25 |
26 | epochs = 10000
27 | lr = 0.001
28 | save_step = 2000
29 | image_step = 500
30 | batch_size = 32
31 |
32 | cleaners='english_cleaners'
33 |
34 | data_path = './data/LJSpeech-1.1'
35 | checkpoint_path = './checkpoint'
36 | sample_path = './samples'
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch as t
3 | import torch.nn.functional as F
4 | import math
5 | import hyperparams as hp
6 | from text.symbols import symbols
7 | import numpy as np
8 | import copy
9 | from collections import OrderedDict
10 |
11 | def clones(module, N):
12 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
13 |
14 | class Linear(nn.Module):
15 | """
16 | Linear Module
17 | """
18 | def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
19 | """
20 | :param in_dim: dimension of input
21 | :param out_dim: dimension of output
22 | :param bias: boolean. if True, bias is included.
23 | :param w_init: str. weight inits with xavier initialization.
24 | """
25 | super(Linear, self).__init__()
26 | self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
27 |
28 | nn.init.xavier_uniform_(
29 | self.linear_layer.weight,
30 | gain=nn.init.calculate_gain(w_init))
31 |
32 | def forward(self, x):
33 | return self.linear_layer(x)
34 |
35 |
36 | class Conv(nn.Module):
37 | """
38 | Convolution Module
39 | """
40 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
41 | padding=0, dilation=1, bias=True, w_init='linear'):
42 | """
43 | :param in_channels: dimension of input
44 | :param out_channels: dimension of output
45 | :param kernel_size: size of kernel
46 | :param stride: size of stride
47 | :param padding: size of padding
48 | :param dilation: dilation rate
49 | :param bias: boolean. if True, bias is included.
50 | :param w_init: str. weight inits with xavier initialization.
51 | """
52 | super(Conv, self).__init__()
53 |
54 | self.conv = nn.Conv1d(in_channels, out_channels,
55 | kernel_size=kernel_size, stride=stride,
56 | padding=padding, dilation=dilation,
57 | bias=bias)
58 |
59 | nn.init.xavier_uniform_(
60 | self.conv.weight, gain=nn.init.calculate_gain(w_init))
61 |
62 | def forward(self, x):
63 | x = self.conv(x)
64 | return x
65 |
66 |
67 | class EncoderPrenet(nn.Module):
68 | """
69 | Pre-network for Encoder consists of convolution networks.
70 | """
71 | def __init__(self, embedding_size, num_hidden):
72 | super(EncoderPrenet, self).__init__()
73 | self.embedding_size = embedding_size
74 | self.embed = nn.Embedding(len(symbols), embedding_size, padding_idx=0)
75 |
76 | self.conv1 = Conv(in_channels=embedding_size,
77 | out_channels=num_hidden,
78 | kernel_size=5,
79 | padding=int(np.floor(5 / 2)),
80 | w_init='relu')
81 | self.conv2 = Conv(in_channels=num_hidden,
82 | out_channels=num_hidden,
83 | kernel_size=5,
84 | padding=int(np.floor(5 / 2)),
85 | w_init='relu')
86 |
87 | self.conv3 = Conv(in_channels=num_hidden,
88 | out_channels=num_hidden,
89 | kernel_size=5,
90 | padding=int(np.floor(5 / 2)),
91 | w_init='relu')
92 |
93 | self.batch_norm1 = nn.BatchNorm1d(num_hidden)
94 | self.batch_norm2 = nn.BatchNorm1d(num_hidden)
95 | self.batch_norm3 = nn.BatchNorm1d(num_hidden)
96 |
97 | self.dropout1 = nn.Dropout(p=0.2)
98 | self.dropout2 = nn.Dropout(p=0.2)
99 | self.dropout3 = nn.Dropout(p=0.2)
100 | self.projection = Linear(num_hidden, num_hidden)
101 |
102 | def forward(self, input_):
103 | input_ = self.embed(input_)
104 | input_ = input_.transpose(1, 2)
105 | input_ = self.dropout1(t.relu(self.batch_norm1(self.conv1(input_))))
106 | input_ = self.dropout2(t.relu(self.batch_norm2(self.conv2(input_))))
107 | input_ = self.dropout3(t.relu(self.batch_norm3(self.conv3(input_))))
108 | input_ = input_.transpose(1, 2)
109 | input_ = self.projection(input_)
110 |
111 | return input_
112 |
113 |
114 | class FFN(nn.Module):
115 | """
116 | Positionwise Feed-Forward Network
117 | """
118 |
119 | def __init__(self, num_hidden):
120 | """
121 | :param num_hidden: dimension of hidden
122 | """
123 | super(FFN, self).__init__()
124 | self.w_1 = Conv(num_hidden, num_hidden * 4, kernel_size=1, w_init='relu')
125 | self.w_2 = Conv(num_hidden * 4, num_hidden, kernel_size=1)
126 | self.dropout = nn.Dropout(p=0.1)
127 | self.layer_norm = nn.LayerNorm(num_hidden)
128 |
129 | def forward(self, input_):
130 | # FFN Network
131 | x = input_.transpose(1, 2)
132 | x = self.w_2(t.relu(self.w_1(x)))
133 | x = x.transpose(1, 2)
134 |
135 |
136 | # residual connection
137 | x = x + input_
138 |
139 | # dropout
140 | # x = self.dropout(x)
141 |
142 | # layer normalization
143 | x = self.layer_norm(x)
144 |
145 | return x
146 |
147 |
148 | class PostConvNet(nn.Module):
149 | """
150 | Post Convolutional Network (mel --> mel)
151 | """
152 | def __init__(self, num_hidden):
153 | """
154 |
155 | :param num_hidden: dimension of hidden
156 | """
157 | super(PostConvNet, self).__init__()
158 | self.conv1 = Conv(in_channels=hp.num_mels * hp.outputs_per_step,
159 | out_channels=num_hidden,
160 | kernel_size=5,
161 | padding=4,
162 | w_init='tanh')
163 | self.conv_list = clones(Conv(in_channels=num_hidden,
164 | out_channels=num_hidden,
165 | kernel_size=5,
166 | padding=4,
167 | w_init='tanh'), 3)
168 | self.conv2 = Conv(in_channels=num_hidden,
169 | out_channels=hp.num_mels * hp.outputs_per_step,
170 | kernel_size=5,
171 | padding=4)
172 |
173 | self.batch_norm_list = clones(nn.BatchNorm1d(num_hidden), 3)
174 | self.pre_batchnorm = nn.BatchNorm1d(num_hidden)
175 |
176 | self.dropout1 = nn.Dropout(p=0.1)
177 | self.dropout_list = nn.ModuleList([nn.Dropout(p=0.1) for _ in range(3)])
178 |
179 | def forward(self, input_, mask=None):
180 | # Causal Convolution (for auto-regressive)
181 | input_ = self.dropout1(t.tanh(self.pre_batchnorm(self.conv1(input_)[:, :, :-4])))
182 | for batch_norm, conv, dropout in zip(self.batch_norm_list, self.conv_list, self.dropout_list):
183 | input_ = dropout(t.tanh(batch_norm(conv(input_)[:, :, :-4])))
184 | input_ = self.conv2(input_)[:, :, :-4]
185 | return input_
186 |
187 |
188 | class MultiheadAttention(nn.Module):
189 | """
190 | Multihead attention mechanism (dot attention)
191 | """
192 | def __init__(self, num_hidden_k):
193 | """
194 | :param num_hidden_k: dimension of hidden
195 | """
196 | super(MultiheadAttention, self).__init__()
197 |
198 | self.num_hidden_k = num_hidden_k
199 | self.attn_dropout = nn.Dropout(p=0.1)
200 |
201 | def forward(self, key, value, query, mask=None, query_mask=None):
202 | # Get attention score
203 | attn = t.bmm(query, key.transpose(1, 2))
204 | attn = attn / math.sqrt(self.num_hidden_k)
205 |
206 | # Masking to ignore padding (key side)
207 | if mask is not None:
208 | attn = attn.masked_fill(mask, -2 ** 32 + 1)
209 | attn = t.softmax(attn, dim=-1)
210 | else:
211 | attn = t.softmax(attn, dim=-1)
212 |
213 | # Masking to ignore padding (query side)
214 | if query_mask is not None:
215 | attn = attn * query_mask
216 |
217 | # Dropout
218 | # attn = self.attn_dropout(attn)
219 |
220 | # Get Context Vector
221 | result = t.bmm(attn, value)
222 |
223 | return result, attn
224 |
225 |
226 | class Attention(nn.Module):
227 | """
228 | Attention Network
229 | """
230 | def __init__(self, num_hidden, h=4):
231 | """
232 | :param num_hidden: dimension of hidden
233 | :param h: num of heads
234 | """
235 | super(Attention, self).__init__()
236 |
237 | self.num_hidden = num_hidden
238 | self.num_hidden_per_attn = num_hidden // h
239 | self.h = h
240 |
241 | self.key = Linear(num_hidden, num_hidden, bias=False)
242 | self.value = Linear(num_hidden, num_hidden, bias=False)
243 | self.query = Linear(num_hidden, num_hidden, bias=False)
244 |
245 | self.multihead = MultiheadAttention(self.num_hidden_per_attn)
246 |
247 | self.residual_dropout = nn.Dropout(p=0.1)
248 |
249 | self.final_linear = Linear(num_hidden * 2, num_hidden)
250 |
251 | self.layer_norm_1 = nn.LayerNorm(num_hidden)
252 |
253 | def forward(self, memory, decoder_input, mask=None, query_mask=None):
254 |
255 | batch_size = memory.size(0)
256 | seq_k = memory.size(1)
257 | seq_q = decoder_input.size(1)
258 |
259 | # Repeat masks h times
260 | if query_mask is not None:
261 | query_mask = query_mask.unsqueeze(-1).repeat(1, 1, seq_k)
262 | query_mask = query_mask.repeat(self.h, 1, 1)
263 | if mask is not None:
264 | mask = mask.repeat(self.h, 1, 1)
265 |
266 | # Make multihead
267 | key = self.key(memory).view(batch_size, seq_k, self.h, self.num_hidden_per_attn)
268 | value = self.value(memory).view(batch_size, seq_k, self.h, self.num_hidden_per_attn)
269 | query = self.query(decoder_input).view(batch_size, seq_q, self.h, self.num_hidden_per_attn)
270 |
271 | key = key.permute(2, 0, 1, 3).contiguous().view(-1, seq_k, self.num_hidden_per_attn)
272 | value = value.permute(2, 0, 1, 3).contiguous().view(-1, seq_k, self.num_hidden_per_attn)
273 | query = query.permute(2, 0, 1, 3).contiguous().view(-1, seq_q, self.num_hidden_per_attn)
274 |
275 | # Get context vector
276 | result, attns = self.multihead(key, value, query, mask=mask, query_mask=query_mask)
277 |
278 | # Concatenate all multihead context vector
279 | result = result.view(self.h, batch_size, seq_q, self.num_hidden_per_attn)
280 | result = result.permute(1, 2, 0, 3).contiguous().view(batch_size, seq_q, -1)
281 |
282 | # Concatenate context vector with input (most important)
283 | result = t.cat([decoder_input, result], dim=-1)
284 |
285 | # Final linear
286 | result = self.final_linear(result)
287 |
288 | # Residual dropout & connection
289 | result = result + decoder_input
290 |
291 | # result = self.residual_dropout(result)
292 |
293 | # Layer normalization
294 | result = self.layer_norm_1(result)
295 |
296 | return result, attns
297 |
298 |
299 | class Prenet(nn.Module):
300 | """
301 | Prenet before passing through the network
302 | """
303 | def __init__(self, input_size, hidden_size, output_size, p=0.5):
304 | """
305 | :param input_size: dimension of input
306 | :param hidden_size: dimension of hidden unit
307 | :param output_size: dimension of output
308 | """
309 | super(Prenet, self).__init__()
310 | self.input_size = input_size
311 | self.output_size = output_size
312 | self.hidden_size = hidden_size
313 | self.layer = nn.Sequential(OrderedDict([
314 | ('fc1', Linear(self.input_size, self.hidden_size)),
315 | ('relu1', nn.ReLU()),
316 | ('dropout1', nn.Dropout(p)),
317 | ('fc2', Linear(self.hidden_size, self.output_size)),
318 | ('relu2', nn.ReLU()),
319 | ('dropout2', nn.Dropout(p)),
320 | ]))
321 |
322 | def forward(self, input_):
323 |
324 | out = self.layer(input_)
325 |
326 | return out
327 |
328 | class CBHG(nn.Module):
329 | """
330 | CBHG Module
331 | """
332 | def __init__(self, hidden_size, K=16, projection_size = 256, num_gru_layers=2, max_pool_kernel_size=2, is_post=False):
333 | """
334 | :param hidden_size: dimension of hidden unit
335 | :param K: # of convolution banks
336 | :param projection_size: dimension of projection unit
337 | :param num_gru_layers: # of layers of GRUcell
338 | :param max_pool_kernel_size: max pooling kernel size
339 | :param is_post: whether post processing or not
340 | """
341 | super(CBHG, self).__init__()
342 | self.hidden_size = hidden_size
343 | self.projection_size = projection_size
344 | self.convbank_list = nn.ModuleList()
345 | self.convbank_list.append(nn.Conv1d(in_channels=projection_size,
346 | out_channels=hidden_size,
347 | kernel_size=1,
348 | padding=int(np.floor(1/2))))
349 |
350 | for i in range(2, K+1):
351 | self.convbank_list.append(nn.Conv1d(in_channels=hidden_size,
352 | out_channels=hidden_size,
353 | kernel_size=i,
354 | padding=int(np.floor(i/2))))
355 |
356 | self.batchnorm_list = nn.ModuleList()
357 | for i in range(1, K+1):
358 | self.batchnorm_list.append(nn.BatchNorm1d(hidden_size))
359 |
360 | convbank_outdim = hidden_size * K
361 |
362 | self.conv_projection_1 = nn.Conv1d(in_channels=convbank_outdim,
363 | out_channels=hidden_size,
364 | kernel_size=3,
365 | padding=int(np.floor(3 / 2)))
366 | self.conv_projection_2 = nn.Conv1d(in_channels=hidden_size,
367 | out_channels=projection_size,
368 | kernel_size=3,
369 | padding=int(np.floor(3 / 2)))
370 | self.batchnorm_proj_1 = nn.BatchNorm1d(hidden_size)
371 |
372 | self.batchnorm_proj_2 = nn.BatchNorm1d(projection_size)
373 |
374 |
375 | self.max_pool = nn.MaxPool1d(max_pool_kernel_size, stride=1, padding=1)
376 | self.highway = Highwaynet(self.projection_size)
377 | self.gru = nn.GRU(self.projection_size, self.hidden_size // 2, num_layers=num_gru_layers,
378 | batch_first=True,
379 | bidirectional=True)
380 |
381 |
382 | def _conv_fit_dim(self, x, kernel_size=3):
383 | if kernel_size % 2 == 0:
384 | return x[:,:,:-1]
385 | else:
386 | return x
387 |
388 | def forward(self, input_):
389 |
390 | input_ = input_.contiguous()
391 | batch_size = input_.size(0)
392 | total_length = input_.size(-1)
393 |
394 | convbank_list = list()
395 | convbank_input = input_
396 |
397 | # Convolution bank filters
398 | for k, (conv, batchnorm) in enumerate(zip(self.convbank_list, self.batchnorm_list)):
399 | convbank_input = t.relu(batchnorm(self._conv_fit_dim(conv(convbank_input), k+1).contiguous()))
400 | convbank_list.append(convbank_input)
401 |
402 | # Concatenate all features
403 | conv_cat = t.cat(convbank_list, dim=1)
404 |
405 | # Max pooling
406 | conv_cat = self.max_pool(conv_cat)[:,:,:-1]
407 |
408 | # Projection
409 | conv_projection = t.relu(self.batchnorm_proj_1(self._conv_fit_dim(self.conv_projection_1(conv_cat))))
410 | conv_projection = self.batchnorm_proj_2(self._conv_fit_dim(self.conv_projection_2(conv_projection))) + input_
411 |
412 | # Highway networks
413 | highway = self.highway.forward(conv_projection.transpose(1,2))
414 |
415 |
416 | # Bidirectional GRU
417 |
418 | self.gru.flatten_parameters()
419 | out, _ = self.gru(highway)
420 |
421 | return out
422 |
423 |
424 | class Highwaynet(nn.Module):
425 | """
426 | Highway network
427 | """
428 | def __init__(self, num_units, num_layers=4):
429 | """
430 | :param num_units: dimension of hidden unit
431 | :param num_layers: # of highway layers
432 | """
433 | super(Highwaynet, self).__init__()
434 | self.num_units = num_units
435 | self.num_layers = num_layers
436 | self.gates = nn.ModuleList()
437 | self.linears = nn.ModuleList()
438 | for _ in range(self.num_layers):
439 | self.linears.append(Linear(num_units, num_units))
440 | self.gates.append(Linear(num_units, num_units))
441 |
442 | def forward(self, input_):
443 |
444 | out = input_
445 |
446 | # highway gated function
447 | for fc1, fc2 in zip(self.linears, self.gates):
448 |
449 | h = t.relu(fc1.forward(out))
450 | t_ = t.sigmoid(fc2.forward(out))
451 |
452 | c = 1. - t_
453 | out = h * t_ + out * c
454 |
455 | return out
456 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | from module import *
2 | from utils import get_positional_table, get_sinusoid_encoding_table
3 | import hyperparams as hp
4 | import copy
5 |
6 | class Encoder(nn.Module):
7 | """
8 | Encoder Network
9 | """
10 | def __init__(self, embedding_size, num_hidden):
11 | """
12 | :param embedding_size: dimension of embedding
13 | :param num_hidden: dimension of hidden
14 | """
15 | super(Encoder, self).__init__()
16 | self.alpha = nn.Parameter(t.ones(1))
17 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(1024, num_hidden, padding_idx=0),
18 | freeze=True)
19 | self.pos_dropout = nn.Dropout(p=0.1)
20 | self.encoder_prenet = EncoderPrenet(embedding_size, num_hidden)
21 | self.layers = clones(Attention(num_hidden), 3)
22 | self.ffns = clones(FFN(num_hidden), 3)
23 |
24 | def forward(self, x, pos):
25 |
26 | # Get character mask
27 | if self.training:
28 | c_mask = pos.ne(0).type(t.float)
29 | mask = pos.eq(0).unsqueeze(1).repeat(1, x.size(1), 1)
30 |
31 | else:
32 | c_mask, mask = None, None
33 |
34 | # Encoder pre-network
35 | x = self.encoder_prenet(x)
36 |
37 | # Get positional embedding, apply alpha and add
38 | pos = self.pos_emb(pos)
39 | x = pos * self.alpha + x
40 |
41 | # Positional dropout
42 | x = self.pos_dropout(x)
43 |
44 | # Attention encoder-encoder
45 | attns = list()
46 | for layer, ffn in zip(self.layers, self.ffns):
47 | x, attn = layer(x, x, mask=mask, query_mask=c_mask)
48 | x = ffn(x)
49 | attns.append(attn)
50 |
51 | return x, c_mask, attns
52 |
53 |
54 | class MelDecoder(nn.Module):
55 | """
56 | Decoder Network
57 | """
58 | def __init__(self, num_hidden):
59 | """
60 | :param num_hidden: dimension of hidden
61 | """
62 | super(MelDecoder, self).__init__()
63 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(1024, num_hidden, padding_idx=0),
64 | freeze=True)
65 | self.pos_dropout = nn.Dropout(p=0.1)
66 | self.alpha = nn.Parameter(t.ones(1))
67 | self.decoder_prenet = Prenet(hp.num_mels, num_hidden * 2, num_hidden, p=0.2)
68 | self.norm = Linear(num_hidden, num_hidden)
69 |
70 | self.selfattn_layers = clones(Attention(num_hidden), 3)
71 | self.dotattn_layers = clones(Attention(num_hidden), 3)
72 | self.ffns = clones(FFN(num_hidden), 3)
73 | self.mel_linear = Linear(num_hidden, hp.num_mels * hp.outputs_per_step)
74 | self.stop_linear = Linear(num_hidden, 1, w_init='sigmoid')
75 |
76 | self.postconvnet = PostConvNet(num_hidden)
77 |
78 | def forward(self, memory, decoder_input, c_mask, pos):
79 | batch_size = memory.size(0)
80 | decoder_len = decoder_input.size(1)
81 |
82 | # get decoder mask with triangular matrix
83 | if self.training:
84 | m_mask = pos.ne(0).type(t.float)
85 | mask = m_mask.eq(0).unsqueeze(1).repeat(1, decoder_len, 1)
86 | if next(self.parameters()).is_cuda:
87 | mask = mask + t.triu(t.ones(decoder_len, decoder_len).cuda(), diagonal=1).repeat(batch_size, 1, 1).byte()
88 | else:
89 | mask = mask + t.triu(t.ones(decoder_len, decoder_len), diagonal=1).repeat(batch_size, 1, 1).byte()
90 | mask = mask.gt(0)
91 | zero_mask = c_mask.eq(0).unsqueeze(-1).repeat(1, 1, decoder_len)
92 | zero_mask = zero_mask.transpose(1, 2)
93 | else:
94 | if next(self.parameters()).is_cuda:
95 | mask = t.triu(t.ones(decoder_len, decoder_len).cuda(), diagonal=1).repeat(batch_size, 1, 1).byte()
96 | else:
97 | mask = t.triu(t.ones(decoder_len, decoder_len), diagonal=1).repeat(batch_size, 1, 1).byte()
98 | mask = mask.gt(0)
99 | m_mask, zero_mask = None, None
100 |
101 | # Decoder pre-network
102 | decoder_input = self.decoder_prenet(decoder_input)
103 |
104 | # Centered position
105 | decoder_input = self.norm(decoder_input)
106 |
107 | # Get positional embedding, apply alpha and add
108 | pos = self.pos_emb(pos)
109 | decoder_input = pos * self.alpha + decoder_input
110 |
111 | # Positional dropout
112 | decoder_input = self.pos_dropout(decoder_input)
113 |
114 | # Attention decoder-decoder, encoder-decoder
115 | attn_dot_list = list()
116 | attn_dec_list = list()
117 |
118 | for selfattn, dotattn, ffn in zip(self.selfattn_layers, self.dotattn_layers, self.ffns):
119 | decoder_input, attn_dec = selfattn(decoder_input, decoder_input, mask=mask, query_mask=m_mask)
120 | decoder_input, attn_dot = dotattn(memory, decoder_input, mask=zero_mask, query_mask=m_mask)
121 | decoder_input = ffn(decoder_input)
122 | attn_dot_list.append(attn_dot)
123 | attn_dec_list.append(attn_dec)
124 |
125 | # Mel linear projection
126 | mel_out = self.mel_linear(decoder_input)
127 |
128 | # Post Mel Network
129 | postnet_input = mel_out.transpose(1, 2)
130 | out = self.postconvnet(postnet_input)
131 | out = postnet_input + out
132 | out = out.transpose(1, 2)
133 |
134 | # Stop tokens
135 | stop_tokens = self.stop_linear(decoder_input)
136 |
137 | return mel_out, out, attn_dot_list, stop_tokens, attn_dec_list
138 |
139 |
140 | class Model(nn.Module):
141 | """
142 | Transformer Network
143 | """
144 | def __init__(self):
145 | super(Model, self).__init__()
146 | self.encoder = Encoder(hp.embedding_size, hp.hidden_size)
147 | self.decoder = MelDecoder(hp.hidden_size)
148 |
149 | def forward(self, characters, mel_input, pos_text, pos_mel):
150 | memory, c_mask, attns_enc = self.encoder.forward(characters, pos=pos_text)
151 | mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder.forward(memory, mel_input, c_mask,
152 | pos=pos_mel)
153 |
154 | return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
155 |
156 |
157 | class ModelPostNet(nn.Module):
158 | """
159 | CBHG Network (mel --> linear)
160 | """
161 | def __init__(self):
162 | super(ModelPostNet, self).__init__()
163 | self.pre_projection = Conv(hp.n_mels, hp.hidden_size)
164 | self.cbhg = CBHG(hp.hidden_size)
165 | self.post_projection = Conv(hp.hidden_size, (hp.n_fft // 2) + 1)
166 |
167 | def forward(self, mel):
168 | mel = mel.transpose(1, 2)
169 | mel = self.pre_projection(mel)
170 | mel = self.cbhg(mel).transpose(1, 2)
171 | mag_pred = self.post_projection(mel).transpose(1, 2)
172 |
173 | return mag_pred
--------------------------------------------------------------------------------
/png/alphas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/alphas.png
--------------------------------------------------------------------------------
/png/attention.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention.gif
--------------------------------------------------------------------------------
/png/attention/attention_0_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_0_0.png
--------------------------------------------------------------------------------
/png/attention/attention_0_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_0_1.png
--------------------------------------------------------------------------------
/png/attention/attention_0_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_0_2.png
--------------------------------------------------------------------------------
/png/attention/attention_0_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_0_3.png
--------------------------------------------------------------------------------
/png/attention/attention_1_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_1_0.png
--------------------------------------------------------------------------------
/png/attention/attention_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_1_1.png
--------------------------------------------------------------------------------
/png/attention/attention_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_1_2.png
--------------------------------------------------------------------------------
/png/attention/attention_1_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_1_3.png
--------------------------------------------------------------------------------
/png/attention/attention_2_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_2_0.png
--------------------------------------------------------------------------------
/png/attention/attention_2_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_2_1.png
--------------------------------------------------------------------------------
/png/attention/attention_2_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_2_2.png
--------------------------------------------------------------------------------
/png/attention/attention_2_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention/attention_2_3.png
--------------------------------------------------------------------------------
/png/attention_decoder.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder.gif
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_0_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_0_0.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_0_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_0_1.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_0_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_0_2.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_0_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_0_3.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_1_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_1_0.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_1_1.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_1_2.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_1_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_1_3.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_2_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_2_0.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_2_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_2_1.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_2_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_2_2.png
--------------------------------------------------------------------------------
/png/attention_decoder/attention_dec_2_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_decoder/attention_dec_2_3.png
--------------------------------------------------------------------------------
/png/attention_encoder.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder.gif
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_0_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_0_0.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_0_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_0_1.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_0_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_0_2.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_0_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_0_3.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_1_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_1_0.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_1_1.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_1_2.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_1_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_1_3.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_2_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_2_0.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_2_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_2_1.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_2_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_2_2.png
--------------------------------------------------------------------------------
/png/attention_encoder/attention_enc_2_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/attention_encoder/attention_enc_2_3.png
--------------------------------------------------------------------------------
/png/mel_original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/mel_original.png
--------------------------------------------------------------------------------
/png/mel_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/mel_pred.png
--------------------------------------------------------------------------------
/png/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/model.png
--------------------------------------------------------------------------------
/png/training_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/png/training_loss.png
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from torch.utils.data import Dataset, DataLoader
4 | import os
5 | from utils import get_spectrograms
6 | import hyperparams as hp
7 | import librosa
8 |
9 | class PrepareDataset(Dataset):
10 | """LJSpeech dataset."""
11 |
12 | def __init__(self, csv_file, root_dir):
13 | """
14 | Args:
15 | csv_file (string): Path to the csv file with annotations.
16 | root_dir (string): Directory with all the wavs.
17 |
18 | """
19 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None)
20 | self.root_dir = root_dir
21 |
22 | def load_wav(self, filename):
23 | return librosa.load(filename, sr=hp.sr)
24 |
25 | def __len__(self):
26 | return len(self.landmarks_frame)
27 |
28 | def __getitem__(self, idx):
29 | wav_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0]) + '.wav'
30 | mel, mag = get_spectrograms(wav_name)
31 |
32 | np.save(wav_name[:-4] + '.pt', mel)
33 | np.save(wav_name[:-4] + '.mag', mag)
34 |
35 | sample = {'mel':mel, 'mag': mag}
36 |
37 | return sample
38 |
39 | if __name__ == '__main__':
40 | dataset = PrepareDataset(os.path.join(hp.data_path,'metadata.csv'), os.path.join(hp.data_path,'wavs'))
41 | dataloader = DataLoader(dataset, batch_size=1, drop_last=False, num_workers=8)
42 | from tqdm import tqdm
43 | pbar = tqdm(dataloader)
44 | for d in pbar:
45 | pass
46 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import hyperparams as hp
2 | import pandas as pd
3 | from torch.utils.data import Dataset, DataLoader
4 | import os
5 | import librosa
6 | import numpy as np
7 | from text import text_to_sequence
8 | import collections.abc
9 | from scipy import signal
10 | import torch as t
11 | import math
12 |
13 |
14 | class LJDatasets(Dataset):
15 | """LJSpeech dataset."""
16 |
17 | def __init__(self, csv_file, root_dir):
18 | """
19 | Args:
20 | csv_file (string): Path to the csv file with annotations.
21 | root_dir (string): Directory with all the wavs.
22 |
23 | """
24 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None)
25 | self.root_dir = root_dir
26 |
27 | def load_wav(self, filename):
28 | return librosa.load(filename, sr=hp.sr)
29 |
30 | def __len__(self):
31 | return len(self.landmarks_frame)
32 |
33 | def __getitem__(self, idx):
34 | wav_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0]) + '.wav'
35 | text = self.landmarks_frame.iloc[idx, 1]
36 |
37 | text = np.asarray(text_to_sequence(text, [hp.cleaners]), dtype=np.int32)
38 | mel = np.load(wav_name[:-4] + '.pt.npy')
39 | mel_input = np.concatenate([np.zeros([1,hp.num_mels], np.float32), mel[:-1,:]], axis=0)
40 | text_length = len(text)
41 | pos_text = np.arange(1, text_length + 1)
42 | pos_mel = np.arange(1, mel.shape[0] + 1)
43 |
44 | sample = {'text': text, 'mel': mel, 'text_length':text_length, 'mel_input':mel_input, 'pos_mel':pos_mel, 'pos_text':pos_text}
45 |
46 | return sample
47 |
48 | class PostDatasets(Dataset):
49 | """LJSpeech dataset."""
50 |
51 | def __init__(self, csv_file, root_dir):
52 | """
53 | Args:
54 | csv_file (string): Path to the csv file with annotations.
55 | root_dir (string): Directory with all the wavs.
56 |
57 | """
58 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None)
59 | self.root_dir = root_dir
60 |
61 | def __len__(self):
62 | return len(self.landmarks_frame)
63 |
64 | def __getitem__(self, idx):
65 | wav_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0]) + '.wav'
66 | mel = np.load(wav_name[:-4] + '.pt.npy')
67 | mag = np.load(wav_name[:-4] + '.mag.npy')
68 | sample = {'mel':mel, 'mag':mag}
69 |
70 | return sample
71 |
72 | def collate_fn_transformer(batch):
73 |
74 | # Puts each data field into a tensor with outer dimension batch size
75 | if isinstance(batch[0], collections.abc.Mapping):
76 |
77 | text = [d['text'] for d in batch]
78 | mel = [d['mel'] for d in batch]
79 | mel_input = [d['mel_input'] for d in batch]
80 | text_length = [d['text_length'] for d in batch]
81 | pos_mel = [d['pos_mel'] for d in batch]
82 | pos_text= [d['pos_text'] for d in batch]
83 |
84 | text = [i for i,_ in sorted(zip(text, text_length), key=lambda x: x[1], reverse=True)]
85 | mel = [i for i, _ in sorted(zip(mel, text_length), key=lambda x: x[1], reverse=True)]
86 | mel_input = [i for i, _ in sorted(zip(mel_input, text_length), key=lambda x: x[1], reverse=True)]
87 | pos_text = [i for i, _ in sorted(zip(pos_text, text_length), key=lambda x: x[1], reverse=True)]
88 | pos_mel = [i for i, _ in sorted(zip(pos_mel, text_length), key=lambda x: x[1], reverse=True)]
89 | text_length = sorted(text_length, reverse=True)
90 | # PAD sequences with largest length of the batch
91 | text = _prepare_data(text).astype(np.int32)
92 | mel = _pad_mel(mel)
93 | mel_input = _pad_mel(mel_input)
94 | pos_mel = _prepare_data(pos_mel).astype(np.int32)
95 | pos_text = _prepare_data(pos_text).astype(np.int32)
96 |
97 |
98 | return t.LongTensor(text), t.FloatTensor(mel), t.FloatTensor(mel_input), t.LongTensor(pos_text), t.LongTensor(pos_mel), t.LongTensor(text_length)
99 |
100 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
101 | .format(type(batch[0]))))
102 |
103 | def collate_fn_postnet(batch):
104 |
105 | # Puts each data field into a tensor with outer dimension batch size
106 | if isinstance(batch[0], collections.abc.Mapping):
107 |
108 | mel = [d['mel'] for d in batch]
109 | mag = [d['mag'] for d in batch]
110 |
111 | # PAD sequences with largest length of the batch
112 | mel = _pad_mel(mel)
113 | mag = _pad_mel(mag)
114 |
115 | return t.FloatTensor(mel), t.FloatTensor(mag)
116 |
117 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
118 | .format(type(batch[0]))))
119 |
120 | def _pad_data(x, length):
121 | _pad = 0
122 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
123 |
124 | def _prepare_data(inputs):
125 | max_len = max((len(x) for x in inputs))
126 | return np.stack([_pad_data(x, max_len) for x in inputs])
127 |
128 | def _pad_per_step(inputs):
129 | timesteps = inputs.shape[-1]
130 | return np.pad(inputs, [[0,0],[0,0],[0, hp.outputs_per_step - (timesteps % hp.outputs_per_step)]], mode='constant', constant_values=0.0)
131 |
132 | def get_param_size(model):
133 | params = 0
134 | for p in model.parameters():
135 | tmp = 1
136 | for x in p.size():
137 | tmp *= x
138 | params += tmp
139 | return params
140 |
141 | def get_dataset():
142 | return LJDatasets(os.path.join(hp.data_path,'metadata.csv'), os.path.join(hp.data_path,'wavs'))
143 |
144 | def get_post_dataset():
145 | return PostDatasets(os.path.join(hp.data_path,'metadata.csv'), os.path.join(hp.data_path,'wavs'))
146 |
147 | def _pad_mel(inputs):
148 | _pad = 0
149 | def _pad_one(x, max_len):
150 | mel_len = x.shape[0]
151 | return np.pad(x, [[0,max_len - mel_len],[0,0]], mode='constant', constant_values=_pad)
152 | max_len = max((x.shape[0] for x in inputs))
153 | return np.stack([_pad_one(x, max_len) for x in inputs])
154 |
155 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | falcon==1.2.0
2 | inflect==0.2.5
3 | librosa==0.7.1
4 | scipy==1.10.0
5 | Unidecode==0.4.21
6 | pandas
7 | numpy
8 | tensorboardX
9 | tqdm
10 |
--------------------------------------------------------------------------------
/samples/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/Transformer-TTS/87bd9c9afdf98320b3168e73e2b7db14a7fc4b7a/samples/test.wav
--------------------------------------------------------------------------------
/synthesis.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | from utils import spectrogram2wav
3 | from scipy.io.wavfile import write
4 | import hyperparams as hp
5 | from text import text_to_sequence
6 | import numpy as np
7 | from network import ModelPostNet, Model
8 | from collections import OrderedDict
9 | from tqdm import tqdm
10 | import argparse
11 |
12 | def load_checkpoint(step, model_name="transformer"):
13 | state_dict = t.load('./checkpoint/checkpoint_%s_%d.pth.tar'% (model_name, step))
14 | new_state_dict = OrderedDict()
15 | for k, value in state_dict['model'].items():
16 | key = k[7:]
17 | new_state_dict[key] = value
18 |
19 | return new_state_dict
20 |
21 | def synthesis(text, args):
22 | m = Model()
23 | m_post = ModelPostNet()
24 |
25 | m.load_state_dict(load_checkpoint(args.restore_step1, "transformer"))
26 | m_post.load_state_dict(load_checkpoint(args.restore_step2, "postnet"))
27 |
28 | text = np.asarray(text_to_sequence(text, [hp.cleaners]))
29 | text = t.LongTensor(text).unsqueeze(0)
30 | text = text.cuda()
31 | mel_input = t.zeros([1,1, 80]).cuda()
32 | pos_text = t.arange(1, text.size(1)+1).unsqueeze(0)
33 | pos_text = pos_text.cuda()
34 |
35 | m=m.cuda()
36 | m_post = m_post.cuda()
37 | m.train(False)
38 | m_post.train(False)
39 |
40 | pbar = tqdm(range(args.max_len))
41 | with t.no_grad():
42 | for i in pbar:
43 | pos_mel = t.arange(1,mel_input.size(1)+1).unsqueeze(0).cuda()
44 | mel_pred, postnet_pred, attn, stop_token, _, attn_dec = m.forward(text, mel_input, pos_text, pos_mel)
45 | mel_input = t.cat([mel_input, mel_pred[:,-1:,:]], dim=1)
46 |
47 | mag_pred = m_post.forward(postnet_pred)
48 |
49 | wav = spectrogram2wav(mag_pred.squeeze(0).cpu().numpy())
50 | write(hp.sample_path + "/test.wav", hp.sr, wav)
51 |
52 | if __name__ == '__main__':
53 |
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument('--restore_step1', type=int, help='Global step of the transformer checkpoint to restore', default=172000)
56 | parser.add_argument('--restore_step2', type=int, help='Global step of the postnet checkpoint to restore', default=100000)
57 | parser.add_argument('--max_len', type=int, help='Maximum number of mel frames to generate', default=400)
58 |
59 | args = parser.parse_args()
60 | synthesis("Transformer model is so fast!",args)
61 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 | import re
4 | from text import cleaners
5 | from text.symbols import symbols
6 |
7 |
8 |
9 |
10 | # Mappings from symbol to numeric ID and vice versa:
11 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
12 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
13 |
14 | # Regular expression matching text enclosed in curly braces:
15 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
16 |
17 |
18 | def text_to_sequence(text, cleaner_names):
19 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
20 |
21 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
22 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
23 |
24 | Args:
25 | text: string to convert to a sequence
26 | cleaner_names: names of the cleaner functions to run the text through
27 |
28 | Returns:
29 | List of integers corresponding to the symbols in the text
30 | '''
31 | sequence = []
32 |
33 | # Check for curly braces and treat their contents as ARPAbet:
34 | while len(text):
35 | m = _curly_re.match(text)
36 | if not m:
37 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
38 | break
39 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
40 | sequence += _arpabet_to_sequence(m.group(2))
41 | text = m.group(3)
42 |
43 | # Append EOS token
44 | sequence.append(_symbol_to_id['~'])
45 | return sequence
46 |
47 |
48 | def sequence_to_text(sequence):
49 | '''Converts a sequence of IDs back to a string'''
50 | result = ''
51 | for symbol_id in sequence:
52 | if symbol_id in _id_to_symbol:
53 | s = _id_to_symbol[symbol_id]
54 | # Enclose ARPAbet back in curly braces:
55 | if len(s) > 1 and s[0] == '@':
56 | s = '{%s}' % s[1:]
57 | result += s
58 | return result.replace('}{', ' ')
59 |
60 |
61 | def _clean_text(text, cleaner_names):
62 | for name in cleaner_names:
63 | cleaner = getattr(cleaners, name)
64 | if not cleaner:
65 | raise Exception('Unknown cleaner: %s' % name)
66 | text = cleaner(text)
67 | return text
68 |
69 |
70 | def _symbols_to_sequence(symbols):
71 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
72 |
73 |
74 | def _arpabet_to_sequence(text):
75 | return _symbols_to_sequence(['@' + s for s in text.split()])
76 |
77 |
78 | def _should_keep_symbol(s):
79 |   return s in _symbol_to_id and s != '_' and s != '~'
--------------------------------------------------------------------------------
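A minimal usage sketch of the module above (the exact integer ids depend on the symbol order in symbols.py):

    from text import text_to_sequence, sequence_to_text

    seq = text_to_sequence('Turn left on {HH AW1 S} Street.', ['english_cleaners'])
    # seq is a list of symbol ids; text_to_sequence appends the EOS symbol '~' at the end.
    print(sequence_to_text(seq))
    # expected (approximately): 'turn left on {HH AW1 S} street.~'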
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 |
4 | '''
5 | Cleaners are transformations that run over the input text at both training and eval time.
6 |
7 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
8 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
9 | 1. "english_cleaners" for English text
10 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
11 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
12 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
13 | the symbols in symbols.py to match your data).
14 | '''
15 |
16 | import re
17 | from unidecode import unidecode
18 | from .numbers import normalize_numbers
19 |
20 |
21 | # Regular expression matching whitespace:
22 | _whitespace_re = re.compile(r'\s+')
23 |
24 | # List of (regular expression, replacement) pairs for abbreviations:
25 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
26 | ('mrs', 'misess'),
27 | ('mr', 'mister'),
28 | ('dr', 'doctor'),
29 | ('st', 'saint'),
30 | ('co', 'company'),
31 | ('jr', 'junior'),
32 | ('maj', 'major'),
33 | ('gen', 'general'),
34 | ('drs', 'doctors'),
35 | ('rev', 'reverend'),
36 | ('lt', 'lieutenant'),
37 | ('hon', 'honorable'),
38 | ('sgt', 'sergeant'),
39 | ('capt', 'captain'),
40 | ('esq', 'esquire'),
41 | ('ltd', 'limited'),
42 | ('col', 'colonel'),
43 | ('ft', 'fort'),
44 | ]]
45 |
46 |
47 | def expand_abbreviations(text):
48 | for regex, replacement in _abbreviations:
49 | text = re.sub(regex, replacement, text)
50 | return text
51 |
52 |
53 | def expand_numbers(text):
54 | return normalize_numbers(text)
55 |
56 |
57 | def lowercase(text):
58 | return text.lower()
59 |
60 |
61 | def collapse_whitespace(text):
62 | return re.sub(_whitespace_re, ' ', text)
63 |
64 |
65 | def convert_to_ascii(text):
66 | return unidecode(text)
67 |
68 |
69 | def basic_cleaners(text):
70 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
71 | text = lowercase(text)
72 | text = collapse_whitespace(text)
73 | return text
74 |
75 |
76 | def transliteration_cleaners(text):
77 | '''Pipeline for non-English text that transliterates to ASCII.'''
78 | text = convert_to_ascii(text)
79 | text = lowercase(text)
80 | text = collapse_whitespace(text)
81 | return text
82 |
83 |
84 | def english_cleaners(text):
85 | '''Pipeline for English text, including number and abbreviation expansion.'''
86 | text = convert_to_ascii(text)
87 | text = lowercase(text)
88 | text = expand_numbers(text)
89 | text = expand_abbreviations(text)
90 | text = collapse_whitespace(text)
91 | return text
92 |
--------------------------------------------------------------------------------
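For illustration, the English pipeline above composes the helpers in order (ASCII, lowercase, numbers, abbreviations, whitespace); roughly:

    from text.cleaners import english_cleaners

    print(english_cleaners('Mr. Smith spent $15.50 on Dec. 3rd!'))
    # expected (approximately): 'mister smith spent fifteen dollars, fifty cents on dec. third!'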
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 |
4 | import re
5 |
6 |
7 | valid_symbols = [
8 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
9 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
10 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
11 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
12 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
13 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
14 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
15 | ]
16 |
17 | _valid_symbol_set = set(valid_symbols)
18 |
19 |
20 | class CMUDict:
21 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
22 | def __init__(self, file_or_path, keep_ambiguous=True):
23 | if isinstance(file_or_path, str):
24 | with open(file_or_path, encoding='latin-1') as f:
25 | entries = _parse_cmudict(f)
26 | else:
27 | entries = _parse_cmudict(file_or_path)
28 | if not keep_ambiguous:
29 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
30 | self._entries = entries
31 |
32 |
33 | def __len__(self):
34 | return len(self._entries)
35 |
36 |
37 | def lookup(self, word):
38 | '''Returns list of ARPAbet pronunciations of the given word.'''
39 | return self._entries.get(word.upper())
40 |
41 |
42 |
43 | _alt_re = re.compile(r'\([0-9]+\)')
44 |
45 |
46 | def _parse_cmudict(file):
47 | cmudict = {}
48 | for line in file:
49 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
50 |       parts = line.split('  ')  # CMUDict separates the word and pronunciation with two spaces
51 | word = re.sub(_alt_re, '', parts[0])
52 | pronunciation = _get_pronunciation(parts[1])
53 | if pronunciation:
54 | if word in cmudict:
55 | cmudict[word].append(pronunciation)
56 | else:
57 | cmudict[word] = [pronunciation]
58 | return cmudict
59 |
60 |
61 | def _get_pronunciation(s):
62 | parts = s.strip().split(' ')
63 | for part in parts:
64 | if part not in _valid_symbol_set:
65 | return None
66 | return ' '.join(parts)
67 |
--------------------------------------------------------------------------------
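A small lookup sketch, assuming a local copy of the CMU pronouncing dictionary (the dictionary file is not part of this repository; the path below is hypothetical):

    from text.cmudict import CMUDict

    cmu = CMUDict('./data/cmudict-0.7b')   # hypothetical path
    print(len(cmu))             # number of words parsed
    print(cmu.lookup('hello'))  # e.g. ['HH AH0 L OW1', 'HH EH0 L OW1'] if the word is present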
/text/numbers.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 | import inflect
4 | import re
5 |
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
13 | _number_re = re.compile(r'[0-9]+')
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(',', '')
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace('.', ' point ')
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split('.')
27 | if len(parts) > 2:
28 | return match + ' dollars' # Unexpected format
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
33 | cent_unit = 'cent' if cents == 1 else 'cents'
34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
35 | elif dollars:
36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
37 | return '%s %s' % (dollars, dollar_unit)
38 | elif cents:
39 | cent_unit = 'cent' if cents == 1 else 'cents'
40 | return '%s %s' % (cents, cent_unit)
41 | else:
42 | return 'zero dollars'
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return 'two thousand'
54 | elif num > 2000 and num < 2010:
55 | return 'two thousand ' + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + ' hundred'
58 | else:
59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
60 | else:
61 | return _inflect.number_to_words(num, andword='')
62 |
63 |
64 | def normalize_numbers(text):
65 | text = re.sub(_comma_number_re, _remove_commas, text)
66 | text = re.sub(_pounds_re, r'\1 pounds', text)
67 | text = re.sub(_dollars_re, _expand_dollars, text)
68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
69 | text = re.sub(_ordinal_re, _expand_ordinal, text)
70 | text = re.sub(_number_re, _expand_number, text)
71 | return text
72 |
--------------------------------------------------------------------------------
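A quick check of the normalization above:

    from text.numbers import normalize_numbers

    print(normalize_numbers('I paid $3.50 for 1,000 apples on the 2nd of May 2019.'))
    # expected (approximately):
    # 'I paid three dollars, fifty cents for one thousand apples on the second of May twenty nineteen.'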
/text/symbols.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 |
4 | '''
5 | Defines the set of symbols used in text input to the model.
6 |
7 | The default is a set of ASCII characters that works well for English or text that has been run
8 | through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
9 | '''
10 | from text import cmudict
11 |
12 | _pad = '_'
13 | _eos = '~'
14 | _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
15 |
16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
17 | _arpabet = ['@' + s for s in cmudict.valid_symbols]
18 |
19 | # Export all symbols:
20 | symbols = [_pad, _eos] + list(_characters) + _arpabet
21 |
22 |
23 | if __name__ == '__main__':
24 | print(symbols)
--------------------------------------------------------------------------------
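With the defaults above the vocabulary has 149 entries: the pad and EOS symbols, 63 characters, and 84 '@'-prefixed ARPAbet symbols. A quick sanity check:

    from text.symbols import symbols

    print(len(symbols))   # 149 with the default character set
    print(symbols[:5])    # ['_', '~', 'A', 'B', 'C']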
/train_postnet.py:
--------------------------------------------------------------------------------
1 | from preprocess import get_post_dataset, DataLoader, collate_fn_postnet
2 | from network import *
3 | from tensorboardX import SummaryWriter
4 | import torchvision.utils as vutils
5 | import os
6 | from tqdm import tqdm
7 |
8 | def adjust_learning_rate(optimizer, step_num, warmup_step=4000):
9 | lr = hp.lr * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5)
10 | for param_group in optimizer.param_groups:
11 | param_group['lr'] = lr
12 |
13 |
14 | def main():
15 |
16 | dataset = get_post_dataset()
17 | global_step = 0
18 |
19 | m = nn.DataParallel(ModelPostNet().cuda())
20 |
21 | m.train()
22 | optimizer = t.optim.Adam(m.parameters(), lr=hp.lr)
23 |
24 | writer = SummaryWriter()
25 |
26 | for epoch in range(hp.epochs):
27 |
28 | dataloader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=collate_fn_postnet, drop_last=True, num_workers=8)
29 | pbar = tqdm(dataloader)
30 | for i, data in enumerate(pbar):
31 | pbar.set_description("Processing at epoch %d"%epoch)
32 | global_step += 1
33 | if global_step < 400000:
34 | adjust_learning_rate(optimizer, global_step)
35 |
36 | mel, mag = data
37 |
38 | mel = mel.cuda()
39 | mag = mag.cuda()
40 |
41 | mag_pred = m.forward(mel)
42 |
43 | loss = nn.L1Loss()(mag_pred, mag)
44 |
45 | writer.add_scalars('training_loss',{
46 | 'loss':loss,
47 |
48 | }, global_step)
49 |
50 | optimizer.zero_grad()
51 | # Calculate gradients
52 | loss.backward()
53 |
54 | nn.utils.clip_grad_norm_(m.parameters(), 1.)
55 |
56 | # Update weights
57 | optimizer.step()
58 |
59 | if global_step % hp.save_step == 0:
60 | t.save({'model':m.state_dict(),
61 | 'optimizer':optimizer.state_dict()},
62 | os.path.join(hp.checkpoint_path,'checkpoint_postnet_%d.pth.tar' % global_step))
63 |
64 |
65 |
66 |
67 |
68 | if __name__ == '__main__':
69 | main()
--------------------------------------------------------------------------------
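adjust_learning_rate above is a warmup schedule in the style of the Transformer paper: the rate rises linearly for warmup_step steps, peaks at hp.lr, and then decays as step^-0.5. A small sketch of the multiplier applied to hp.lr (values shown are approximate):

    # Shape of the warmup schedule used above, as a multiplier on hp.lr.
    warmup = 4000
    for step in (1000, 4000, 16000, 64000):
        factor = warmup ** 0.5 * min(step * warmup ** -1.5, step ** -0.5)
        print(step, round(factor, 3))
    # expected: 1000 -> 0.25, 4000 -> 1.0 (peak), 16000 -> 0.5, 64000 -> 0.25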
/train_transformer.py:
--------------------------------------------------------------------------------
1 | from preprocess import get_dataset, DataLoader, collate_fn_transformer
2 | from network import *
3 | from tensorboardX import SummaryWriter
4 | import torchvision.utils as vutils
5 | import os
6 | from tqdm import tqdm
7 |
8 | def adjust_learning_rate(optimizer, step_num, warmup_step=4000):
9 | lr = hp.lr * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5)
10 | for param_group in optimizer.param_groups:
11 | param_group['lr'] = lr
12 |
13 |
14 | def main():
15 |
16 | dataset = get_dataset()
17 | global_step = 0
18 |
19 | m = nn.DataParallel(Model().cuda())
20 |
21 | m.train()
22 | optimizer = t.optim.Adam(m.parameters(), lr=hp.lr)
23 |
24 | pos_weight = t.FloatTensor([5.]).cuda()
25 | writer = SummaryWriter()
26 |
27 | for epoch in range(hp.epochs):
28 |
29 | dataloader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=collate_fn_transformer, drop_last=True, num_workers=16)
30 | pbar = tqdm(dataloader)
31 | for i, data in enumerate(pbar):
32 | pbar.set_description("Processing at epoch %d"%epoch)
33 | global_step += 1
34 | if global_step < 400000:
35 | adjust_learning_rate(optimizer, global_step)
36 |
37 | character, mel, mel_input, pos_text, pos_mel, _ = data
38 |
39 | stop_tokens = t.abs(pos_mel.ne(0).type(t.float) - 1)
40 |
41 | character = character.cuda()
42 | mel = mel.cuda()
43 | mel_input = mel_input.cuda()
44 | pos_text = pos_text.cuda()
45 | pos_mel = pos_mel.cuda()
46 |
47 | mel_pred, postnet_pred, attn_probs, stop_preds, attns_enc, attns_dec = m.forward(character, mel_input, pos_text, pos_mel)
48 |
49 | mel_loss = nn.L1Loss()(mel_pred, mel)
50 | post_mel_loss = nn.L1Loss()(postnet_pred, mel)
51 |
52 | loss = mel_loss + post_mel_loss
53 |
54 | writer.add_scalars('training_loss',{
55 | 'mel_loss':mel_loss,
56 | 'post_mel_loss':post_mel_loss,
57 |
58 | }, global_step)
59 |
60 | writer.add_scalars('alphas',{
61 | 'encoder_alpha':m.module.encoder.alpha.data,
62 | 'decoder_alpha':m.module.decoder.alpha.data,
63 | }, global_step)
64 |
65 |
66 | if global_step % hp.image_step == 1:
67 |
68 | for i, prob in enumerate(attn_probs):
69 |
70 | num_h = prob.size(0)
71 | for j in range(4):
72 |
73 | x = vutils.make_grid(prob[j*16] * 255)
74 | writer.add_image('Attention_%d_0'%global_step, x, i*4+j)
75 |
76 | for i, prob in enumerate(attns_enc):
77 | num_h = prob.size(0)
78 |
79 | for j in range(4):
80 |
81 | x = vutils.make_grid(prob[j*16] * 255)
82 | writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j)
83 |
84 | for i, prob in enumerate(attns_dec):
85 |
86 | num_h = prob.size(0)
87 | for j in range(4):
88 |
89 | x = vutils.make_grid(prob[j*16] * 255)
90 | writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j)
91 |
92 | optimizer.zero_grad()
93 | # Calculate gradients
94 | loss.backward()
95 |
96 | nn.utils.clip_grad_norm_(m.parameters(), 1.)
97 |
98 | # Update weights
99 | optimizer.step()
100 |
101 | if global_step % hp.save_step == 0:
102 | t.save({'model':m.state_dict(),
103 | 'optimizer':optimizer.state_dict()},
104 | os.path.join(hp.checkpoint_path,'checkpoint_transformer_%d.pth.tar' % global_step))
105 |
106 |
107 |
108 |
109 |
110 | if __name__ == '__main__':
111 | main()
--------------------------------------------------------------------------------
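A note on stop_tokens in the loop above: they mark padded mel frames (pos_mel == 0) with 1 and real frames with 0, although in this script they (and pos_weight) are prepared but not added to the displayed loss. A tiny illustration, with torch imported as t as in the file:

    import torch as t

    pos_mel = t.tensor([[1, 2, 3, 4, 0, 0]])          # last two frames are padding
    stop_tokens = t.abs(pos_mel.ne(0).type(t.float) - 1)
    print(stop_tokens)                                 # tensor([[0., 0., 0., 0., 1., 1.]])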
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import librosa
3 | import os, copy
4 | from scipy import signal
5 | import hyperparams as hp
6 | import torch as t
7 |
8 | def get_spectrograms(fpath):
9 |     '''Parses the wave file in `fpath` and
10 |     returns the normalized mel spectrogram and linear magnitude spectrogram.
11 | Args:
12 | fpath: A string. The full path of a sound file.
13 | Returns:
14 | mel: A 2d array of shape (T, n_mels) and dtype of float32.
15 | mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32.
16 | '''
17 | # Loading sound file
18 | y, sr = librosa.load(fpath, sr=hp.sr)
19 |
20 | # Trimming
21 | y, _ = librosa.effects.trim(y)
22 |
23 | # Preemphasis
24 | y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1])
25 |
26 | # stft
27 | linear = librosa.stft(y=y,
28 | n_fft=hp.n_fft,
29 | hop_length=hp.hop_length,
30 | win_length=hp.win_length)
31 |
32 | # magnitude spectrogram
33 | mag = np.abs(linear) # (1+n_fft//2, T)
34 |
35 | # mel spectrogram
36 | mel_basis = librosa.filters.mel(hp.sr, hp.n_fft, hp.n_mels) # (n_mels, 1+n_fft//2)
37 | mel = np.dot(mel_basis, mag) # (n_mels, t)
38 |
39 | # to decibel
40 | mel = 20 * np.log10(np.maximum(1e-5, mel))
41 | mag = 20 * np.log10(np.maximum(1e-5, mag))
42 |
43 | # normalize
44 | mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)
45 | mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)
46 |
47 | # Transpose
48 | mel = mel.T.astype(np.float32) # (T, n_mels)
49 | mag = mag.T.astype(np.float32) # (T, 1+n_fft//2)
50 |
51 | return mel, mag
52 |
53 | def spectrogram2wav(mag):
54 | '''# Generate wave file from linear magnitude spectrogram
55 | Args:
56 | mag: A numpy array of (T, 1+n_fft//2)
57 | Returns:
58 | wav: A 1-D numpy array.
59 | '''
60 | # transpose
61 | mag = mag.T
62 |
63 |     # de-normalize
64 | mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db
65 |
66 | # to amplitude
67 | mag = np.power(10.0, mag * 0.05)
68 |
69 | # wav reconstruction
70 | wav = griffin_lim(mag**hp.power)
71 |
72 | # de-preemphasis
73 | wav = signal.lfilter([1], [1, -hp.preemphasis], wav)
74 |
75 | # trim
76 | wav, _ = librosa.effects.trim(wav)
77 |
78 | return wav.astype(np.float32)
79 |
80 | def griffin_lim(spectrogram):
81 |     '''Applies the Griffin-Lim algorithm to estimate the phase and recover a waveform.'''
82 | X_best = copy.deepcopy(spectrogram)
83 | for i in range(hp.n_iter):
84 | X_t = invert_spectrogram(X_best)
85 | est = librosa.stft(X_t, hp.n_fft, hp.hop_length, win_length=hp.win_length)
86 | phase = est / np.maximum(1e-8, np.abs(est))
87 | X_best = spectrogram * phase
88 | X_t = invert_spectrogram(X_best)
89 | y = np.real(X_t)
90 |
91 | return y
92 |
93 | def invert_spectrogram(spectrogram):
94 |     '''Applies the inverse STFT.
95 | Args:
96 | spectrogram: [1+n_fft//2, t]
97 | '''
98 | return librosa.istft(spectrogram, hp.hop_length, win_length=hp.win_length, window="hann")
99 |
100 | def get_positional_table(d_pos_vec, n_position=1024):
101 | position_enc = np.array([
102 | [pos / np.power(10000, 2*i/d_pos_vec) for i in range(d_pos_vec)]
103 | if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
104 |
105 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
106 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
107 | return t.from_numpy(position_enc).type(t.FloatTensor)
108 |
109 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
110 | ''' Sinusoid position encoding table '''
111 |
112 | def cal_angle(position, hid_idx):
113 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
114 |
115 | def get_posi_angle_vec(position):
116 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
117 |
118 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
119 |
120 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
121 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
122 |
123 | if padding_idx is not None:
124 | # zero vector for padding dimension
125 | sinusoid_table[padding_idx] = 0.
126 |
127 | return t.FloatTensor(sinusoid_table)
128 |
129 | def guided_attention(N, T, g=0.2):
130 |     '''Guided attention. Refer to page 3 of the paper.'''
131 | W = np.zeros((N, T), dtype=np.float32)
132 | for n_pos in range(W.shape[0]):
133 | for t_pos in range(W.shape[1]):
134 | W[n_pos, t_pos] = 1 - np.exp(-(t_pos / float(T) - n_pos / float(N)) ** 2 / (2 * g * g))
135 | return W
136 |
--------------------------------------------------------------------------------
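Finally, a small sketch of the guided-attention mask above: entries are near 0 along the diagonal (where n/N ≈ t/T) and approach 1 far from it, which is how it penalizes non-monotonic attention when used as a loss weight:

    import numpy as np
    from utils import guided_attention

    W = guided_attention(N=4, T=4, g=0.2)
    print(np.round(W, 2))
    # expected (approximately):
    # [[0.   0.54 0.96 1.  ]
    #  [0.54 0.   0.54 0.96]
    #  [0.96 0.54 0.   0.54]
    #  [1.   0.96 0.54 0.  ]]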