├── README.md
├── audio.py
├── hparams.py
├── hyperparams.py
├── module.py
├── network.py
├── prepare_data.py
├── preprocess.py
├── synthesis.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
├── train_postnet.py
├── train_transformer.py
├── utils.py
└── visualize_loss
    ├── loss_verification.py
    ├── visualize_all.py
    └── visualize_single.py
/README.md: -------------------------------------------------------------------------------- 1 | # Transformer-TTS 2 | Follows https://github.com/soobinseo/Transformer-TTS
--------------------------------------------------------------------------------
/audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import math 4 | import numpy as np 5 | import scipy.io.wavfile, scipy.signal # wavfile.write and signal.lfilter are used below 6 | import hparams 7 | 8 | 9 | # Finding the end point requires a transposed spectrogram; 10 | # inverting a spectrogram does not. 11 | 12 | 13 | def load_wav(path): 14 | return librosa.core.load(path, sr=hparams.sample_rate)[0] 15 | 16 | 17 | def save_wav(wav, path): 18 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 19 | scipy.io.wavfile.write(path, hparams.sample_rate, wav.astype(np.int16)) 20 | 21 | 22 | def preemphasis(x): 23 | return scipy.signal.lfilter([1, -hparams.preemphasis], [1], x) 24 | 25 | 26 | def inv_preemphasis(x): 27 | return scipy.signal.lfilter([1], [1, -hparams.preemphasis], x) 28 | 29 | 30 | def spectrogram(y): 31 | D = _stft(preemphasis(y)) 32 | S = _amp_to_db(np.abs(D)) - hparams.ref_level_db 33 | return _normalize(S) 34 | 35 | 36 | def inv_spectrogram(spectrogram): 37 | '''Converts spectrogram to waveform using librosa''' 38 | S = _db_to_amp(_denormalize(spectrogram) + 39 | hparams.ref_level_db) # Convert back to linear 40 | # Reconstruct phase 41 | return inv_preemphasis(_griffin_lim(S ** hparams.power)) 42 | 43 | 44 | def melspectrogram(y): 45 | D = _stft(preemphasis(y)) 46 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db 47 | 48 | return _normalize(S) 49 | 50 | 51 | def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8): 52 | window_length = int(hparams.sample_rate * min_silence_sec) 53 | hop_length = int(window_length / 4) 54 | threshold = _db_to_amp(threshold_db) 55 | for x in range(hop_length, len(wav) - window_length, hop_length): 56 | if np.max(wav[x:x+window_length]) < threshold: 57 | return x + hop_length 58 | return len(wav) 59 | 60 | 61 | def _griffin_lim(S): 62 | '''librosa implementation of Griffin-Lim 63 | Based on https://github.com/librosa/librosa/issues/434 64 | ''' 65 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 66 | S_complex = np.abs(S).astype(complex) # np.complex was removed from NumPy; use the builtin 67 | y = _istft(S_complex * angles) 68 | for i in range(hparams.griffin_lim_iters): 69 | angles = np.exp(1j * np.angle(_stft(y))) 70 | y = _istft(S_complex * angles) 71 | return y 72 | 73 | 74 | def _stft(y): 75 | n_fft, hop_length, win_length = _stft_parameters() 76 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 77 | 78 | 79 | def _istft(y): 80 | _, hop_length, win_length = _stft_parameters() 81 | return librosa.istft(y, hop_length=hop_length, win_length=win_length) 82 | 83 | 84 | def _stft_parameters(): 85 | n_fft = (hparams.num_freq - 1) * 2 86 | 87 | # hop_length = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) 88 | # win_length = int(hparams.frame_length_ms / 1000 * hparams.sample_rate) 89 | 90 | hop_length = hparams.hop_length 91 | win_length = hparams.win_length 92 | 93 | # print(hop_length, win_length) 94 | 95 | return
n_fft, hop_length, win_length 96 | 97 | 98 | # Conversions: 99 | _mel_basis = None 100 | 101 | 102 | def _linear_to_mel(spectrogram): 103 | global _mel_basis 104 | if _mel_basis is None: 105 | _mel_basis = _build_mel_basis() 106 | return np.dot(_mel_basis, spectrogram) 107 | 108 | 109 | def _build_mel_basis(): 110 | n_fft = (hparams.num_freq - 1) * 2 111 | return librosa.filters.mel(hparams.sample_rate, n_fft, n_mels=hparams.num_mels) 112 | 113 | 114 | def _amp_to_db(x): 115 | return 20 * np.log10(np.maximum(1e-5, x)) 116 | 117 | 118 | def _db_to_amp(x): 119 | return np.power(10.0, x * 0.05) 120 | 121 | 122 | def _normalize(S): 123 | return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) 124 | 125 | 126 | def _denormalize(S): 127 | return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db 128 | 129 | 130 | # def get_hop_size(): 131 | # hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) 132 | # return hop_size 133 | 134 | 135 | # def get_win_size(): 136 | # win_size = int(hparams.frame_length_ms / 1000 * hparams.sample_rate) 137 | # return win_size 138 | 139 | 140 | _inv_mel_basis = None 141 | 142 | 143 | def _mel_to_linear(mel_spectrogram): 144 | global _inv_mel_basis 145 | if _inv_mel_basis is None: 146 | _inv_mel_basis = np.linalg.pinv(_build_mel_basis()) 147 | return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) 148 | 149 | 150 | def inv_mel_spectrogram(mel_spectrogram): 151 | '''Converts mel spectrogram to waveform using librosa''' 152 | if hparams.signal_normalization: 153 | D = _denormalize(mel_spectrogram) 154 | else: 155 | D = mel_spectrogram 156 | 157 | # Convert back to linear 158 | S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db)) 159 | 160 | return inv_preemphasis(_griffin_lim(S ** hparams.power)) 161 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | # Audio: 2 | num_mels = 80 3 | num_freq = 1025 4 | sample_rate = 22050 5 | 6 | frame_length_ms = 50 7 | frame_shift_ms = 12.5 8 | 9 | hop_length = 256 10 | win_length = 1024 11 | 12 | preemphasis = 0.97 13 | min_level_db = -100 14 | ref_level_db = 20 15 | griffin_lim_iters = 60 16 | power = 1.5 17 | signal_normalization = True 18 | use_lws = False 19 | 20 | # num_mels = 80 21 | outputs_per_step = 1 22 | hidden_size = 256 23 | embedding_size = 512 24 | epochs = 10000 25 | lr = 0.001 26 | save_step = 2000 27 | batch_size = 16 28 | cleaners = ['english_cleaners'] 29 | data_path = './dataset' 30 | checkpoint_path = './model_new' 31 | logger_path = "./logger" 32 | log_step = 10 33 | clear_Time = 20 34 | -------------------------------------------------------------------------------- /hyperparams.py: -------------------------------------------------------------------------------- 1 | # Audio 2 | num_mels = 80 3 | # num_freq = 1024 4 | n_fft = 2048 5 | sr = 22050 6 | # frame_length_ms = 50. 7 | # frame_shift_ms = 12.5 8 | preemphasis = 0.97 9 | frame_shift = 0.0125 # seconds 10 | frame_length = 0.05 # seconds 11 | hop_length = int(sr*frame_shift) # samples. 12 | win_length = int(sr*frame_length) # samples. 
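# With sr = 22050, hop_length = int(22050 * 0.0125) = 275 samples and
# win_length = int(22050 * 0.05) = 1102 samples. Note that hparams.py hard-codes
# hop_length = 256 and win_length = 1024 instead, so audio.py (which reads
# hparams) and utils.py (which reads hyperparams) do not use identical STFT settings.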
13 | n_mels = 80 # Number of Mel banks to generate 14 | power = 1.2 # Exponent for amplifying the predicted magnitude 15 | min_level_db = -100 16 | ref_level_db = 20 17 | hidden_size = 256 18 | embedding_size = 512 19 | max_db = 100 20 | ref_db = 20 21 | 22 | n_iter = 60 23 | # power = 1.5 24 | outputs_per_step = 1 25 | 26 | epochs = 10000 27 | lr = 0.001 28 | save_step = 2000 29 | image_step = 500 30 | batch_size = 8 31 | 32 | cleaners = 'english_cleaners' 33 | 34 | data_path = './LJSpeech-1.1' 35 | checkpoint_path = './checkpoint' 36 | sample_path = './samples' 37 | log_step = 5 38 | -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as t 3 | import torch.nn.functional as F 4 | import math 5 | import hyperparams as hp 6 | from text.symbols import symbols 7 | import numpy as np 8 | import copy 9 | from collections import OrderedDict 10 | 11 | def clones(module, N): 12 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 13 | 14 | class Linear(nn.Module): 15 | """ 16 | Linear Module 17 | """ 18 | def __init__(self, in_dim, out_dim, bias=True, w_init='linear'): 19 | """ 20 | :param in_dim: dimension of input 21 | :param out_dim: dimension of output 22 | :param bias: boolean. if True, bias is included. 23 | :param w_init: str. weight inits with xavier initialization. 24 | """ 25 | super(Linear, self).__init__() 26 | self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) 27 | 28 | nn.init.xavier_uniform_( 29 | self.linear_layer.weight, 30 | gain=nn.init.calculate_gain(w_init)) 31 | 32 | def forward(self, x): 33 | return self.linear_layer(x) 34 | 35 | 36 | class Conv(nn.Module): 37 | """ 38 | Convolution Module 39 | """ 40 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 41 | padding=0, dilation=1, bias=True, w_init='linear'): 42 | """ 43 | :param in_channels: dimension of input 44 | :param out_channels: dimension of output 45 | :param kernel_size: size of kernel 46 | :param stride: size of stride 47 | :param padding: size of padding 48 | :param dilation: dilation rate 49 | :param bias: boolean. if True, bias is included. 50 | :param w_init: str. weight inits with xavier initialization. 51 | """ 52 | super(Conv, self).__init__() 53 | 54 | self.conv = nn.Conv1d(in_channels, out_channels, 55 | kernel_size=kernel_size, stride=stride, 56 | padding=padding, dilation=dilation, 57 | bias=bias) 58 | 59 | nn.init.xavier_uniform_( 60 | self.conv.weight, gain=nn.init.calculate_gain(w_init)) 61 | 62 | def forward(self, x): 63 | x = self.conv(x) 64 | return x 65 | 66 | 67 | class EncoderPrenet(nn.Module): 68 | """ 69 | Pre-network for Encoder consists of convolution networks. 
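Three 1-D convolutions (kernel size 5) over the character embedding, each followed by ReLU, batch norm and dropout, then a linear projection to num_hidden.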
70 | """ 71 | def __init__(self, embedding_size, num_hidden): 72 | super(EncoderPrenet, self).__init__() 73 | self.embedding_size = embedding_size 74 | self.embed = nn.Embedding(len(symbols), embedding_size, padding_idx=0) 75 | 76 | self.conv1 = Conv(in_channels=embedding_size, 77 | out_channels=num_hidden, 78 | kernel_size=5, 79 | padding=int(np.floor(5 / 2)), 80 | w_init='relu') 81 | self.conv2 = Conv(in_channels=num_hidden, 82 | out_channels=num_hidden, 83 | kernel_size=5, 84 | padding=int(np.floor(5 / 2)), 85 | w_init='relu') 86 | 87 | self.conv3 = Conv(in_channels=num_hidden, 88 | out_channels=num_hidden, 89 | kernel_size=5, 90 | padding=int(np.floor(5 / 2)), 91 | w_init='relu') 92 | 93 | self.batch_norm1 = nn.BatchNorm1d(num_hidden) 94 | self.batch_norm2 = nn.BatchNorm1d(num_hidden) 95 | self.batch_norm3 = nn.BatchNorm1d(num_hidden) 96 | 97 | self.dropout1 = nn.Dropout(p=0.2) 98 | self.dropout2 = nn.Dropout(p=0.2) 99 | self.dropout3 = nn.Dropout(p=0.2) 100 | self.projection = Linear(num_hidden, num_hidden) 101 | 102 | def forward(self, input_): 103 | input_ = self.embed(input_) 104 | input_ = input_.transpose(1, 2) 105 | input_ = self.dropout1(self.batch_norm1(t.relu(self.conv1(input_)))) 106 | input_ = self.dropout2(self.batch_norm2(t.relu(self.conv2(input_)))) 107 | input_ = self.dropout3(self.batch_norm3(t.relu(self.conv3(input_)))) 108 | input_ = input_.transpose(1, 2) 109 | input_ = self.projection(input_) 110 | 111 | return input_ 112 | 113 | 114 | class FFN(nn.Module): 115 | """ 116 | Positionwise Feed-Forward Network 117 | """ 118 | 119 | def __init__(self, num_hidden): 120 | """ 121 | :param num_hidden: dimension of hidden 122 | """ 123 | super(FFN, self).__init__() 124 | self.w_1 = Conv(num_hidden, num_hidden * 4, kernel_size=1, w_init='relu') 125 | self.w_2 = Conv(num_hidden * 4, num_hidden, kernel_size=1) 126 | self.dropout = nn.Dropout(p=0.1) 127 | self.layer_norm = nn.LayerNorm(num_hidden) 128 | 129 | def forward(self, input_): 130 | # FFN Network 131 | x = input_.transpose(1, 2) 132 | x = self.w_2(t.relu(self.w_1(x))) 133 | x = x.transpose(1, 2) 134 | 135 | # residual connection 136 | x = x + input_ 137 | 138 | # dropout 139 | x = self.dropout(x) 140 | 141 | # layer normalization 142 | x = self.layer_norm(x) 143 | 144 | return x 145 | 146 | 147 | class PostConvNet(nn.Module): 148 | """ 149 | Post Convolutional Network (mel --> mel) 150 | """ 151 | def __init__(self, num_hidden): 152 | """ 153 | 154 | :param num_hidden: dimension of hidden 155 | """ 156 | super(PostConvNet, self).__init__() 157 | self.conv1 = Conv(in_channels=hp.num_mels * hp.outputs_per_step, 158 | out_channels=num_hidden, 159 | kernel_size=5, 160 | padding=4, 161 | w_init='tanh') 162 | self.conv_list = clones(Conv(in_channels=num_hidden, 163 | out_channels=num_hidden, 164 | kernel_size=5, 165 | padding=4, 166 | w_init='tanh'), 3) 167 | self.conv2 = Conv(in_channels=num_hidden, 168 | out_channels=hp.num_mels * hp.outputs_per_step, 169 | kernel_size=5, 170 | padding=4) 171 | 172 | self.batch_norm_list = clones(nn.BatchNorm1d(num_hidden), 3) 173 | self.pre_batchnorm = nn.BatchNorm1d(num_hidden) 174 | 175 | self.dropout1 = nn.Dropout(p=0.1) 176 | self.dropout_list = nn.ModuleList([nn.Dropout(p=0.1) for _ in range(3)]) 177 | 178 | def forward(self, input_, mask=None): 179 | # Causal Convolution (for auto-regressive) 180 | input_ = self.dropout1(t.tanh(self.pre_batchnorm(self.conv1(input_)[:, :, :-4]))) 181 | for batch_norm, conv, dropout in zip(self.batch_norm_list, self.conv_list, self.dropout_list): 182 
| input_ = dropout(t.tanh(batch_norm(conv(input_)[:, :, :-4]))) 183 | input_ = self.conv2(input_)[:, :, :-4] 184 | return input_ 185 | 186 | 187 | class MultiheadAttention(nn.Module): 188 | """ 189 | Multihead attention mechanism (dot attention) 190 | """ 191 | def __init__(self, num_hidden_k): 192 | """ 193 | :param num_hidden_k: dimension of hidden 194 | """ 195 | super(MultiheadAttention, self).__init__() 196 | 197 | self.num_hidden_k = num_hidden_k 198 | self.attn_dropout = nn.Dropout(p=0.1) 199 | 200 | def forward(self, key, value, query, mask=None, query_mask=None): 201 | # Get attention score 202 | attn = t.bmm(query, key.transpose(1, 2)) 203 | attn = attn / math.sqrt(self.num_hidden_k) 204 | 205 | # Masking to ignore padding (key side) 206 | if mask is not None: 207 | attn = attn.masked_fill(mask, -2 ** 32 + 1) 208 | attn = t.softmax(attn, dim=-1) 209 | else: 210 | attn = t.softmax(attn, dim=-1) 211 | 212 | # Masking to ignore padding (query side) 213 | if query_mask is not None: 214 | attn = attn * query_mask 215 | 216 | # Dropout 217 | attn = self.attn_dropout(attn) 218 | 219 | # Get Context Vector 220 | result = t.bmm(attn, value) 221 | 222 | return result, attn 223 | 224 | 225 | class Attention(nn.Module): 226 | """ 227 | Attention Network 228 | """ 229 | def __init__(self, num_hidden, h=4): 230 | """ 231 | :param num_hidden: dimension of hidden 232 | :param h: num of heads 233 | """ 234 | super(Attention, self).__init__() 235 | 236 | self.num_hidden = num_hidden 237 | self.num_hidden_per_attn = num_hidden // h 238 | self.h = h 239 | 240 | self.key = Linear(num_hidden, num_hidden, bias=False) 241 | self.value = Linear(num_hidden, num_hidden, bias=False) 242 | self.query = Linear(num_hidden, num_hidden, bias=False) 243 | 244 | self.multihead = MultiheadAttention(self.num_hidden_per_attn) 245 | 246 | self.residual_dropout = nn.Dropout(p=0.1) 247 | 248 | self.final_linear = Linear(num_hidden * 2, num_hidden) 249 | 250 | self.layer_norm_1 = nn.LayerNorm(num_hidden) 251 | 252 | def forward(self, memory, decoder_input, mask=None, query_mask=None): 253 | 254 | batch_size = memory.size(0) 255 | seq_k = memory.size(1) 256 | seq_q = decoder_input.size(1) 257 | 258 | # Repeat masks h times 259 | if query_mask is not None: 260 | query_mask = query_mask.unsqueeze(-1).repeat(1, 1, seq_k) 261 | query_mask = query_mask.repeat(self.h, 1, 1) 262 | if mask is not None: 263 | mask = mask.repeat(self.h, 1, 1) 264 | 265 | # Make multihead 266 | key = self.key(memory).view(batch_size, seq_k, self.h, self.num_hidden_per_attn) 267 | value = self.value(memory).view(batch_size, seq_k, self.h, self.num_hidden_per_attn) 268 | query = self.query(decoder_input).view(batch_size, seq_q, self.h, self.num_hidden_per_attn) 269 | 270 | key = key.permute(2, 0, 1, 3).contiguous().view(-1, seq_k, self.num_hidden_per_attn) 271 | value = value.permute(2, 0, 1, 3).contiguous().view(-1, seq_k, self.num_hidden_per_attn) 272 | query = query.permute(2, 0, 1, 3).contiguous().view(-1, seq_q, self.num_hidden_per_attn) 273 | 274 | # Get context vector 275 | result, attns = self.multihead(key, value, query, mask=mask, query_mask=query_mask) 276 | 277 | # Concatenate all multihead context vector 278 | result = result.view(self.h, batch_size, seq_q, self.num_hidden_per_attn) 279 | result = result.permute(1, 2, 0, 3).contiguous().view(batch_size, seq_q, -1) 280 | 281 | # Concatenate context vector with input (most important) 282 | result = t.cat([decoder_input, result], dim=-1) 283 | 284 | # Final linear 285 | result = 
self.final_linear(result) 286 | 287 | # Residual dropout & connection 288 | result = self.residual_dropout(result) 289 | result = result + decoder_input 290 | 291 | # Layer normalization 292 | result = self.layer_norm_1(result) 293 | 294 | return result, attns 295 | 296 | 297 | class Prenet(nn.Module): 298 | """ 299 | Prenet before passing through the network 300 | """ 301 | def __init__(self, input_size, hidden_size, output_size, p=0.5): 302 | """ 303 | :param input_size: dimension of input 304 | :param hidden_size: dimension of hidden unit 305 | :param output_size: dimension of output 306 | """ 307 | super(Prenet, self).__init__() 308 | self.input_size = input_size 309 | self.output_size = output_size 310 | self.hidden_size = hidden_size 311 | self.layer = nn.Sequential(OrderedDict([ 312 | ('fc1', Linear(self.input_size, self.hidden_size)), 313 | ('relu1', nn.ReLU()), 314 | ('dropout1', nn.Dropout(p)), 315 | ('fc2', Linear(self.hidden_size, self.output_size)), 316 | ('relu2', nn.ReLU()), 317 | ('dropout2', nn.Dropout(p)), 318 | ])) 319 | 320 | def forward(self, input_): 321 | 322 | out = self.layer(input_) 323 | 324 | return out 325 | 326 | class CBHG(nn.Module): 327 | """ 328 | CBHG Module 329 | """ 330 | def __init__(self, hidden_size, K=16, projection_size = 256, num_gru_layers=2, max_pool_kernel_size=2, is_post=False): 331 | """ 332 | :param hidden_size: dimension of hidden unit 333 | :param K: # of convolution banks 334 | :param projection_size: dimension of projection unit 335 | :param num_gru_layers: # of layers of GRUcell 336 | :param max_pool_kernel_size: max pooling kernel size 337 | :param is_post: whether post processing or not 338 | """ 339 | super(CBHG, self).__init__() 340 | self.hidden_size = hidden_size 341 | self.projection_size = projection_size 342 | self.convbank_list = nn.ModuleList() 343 | self.convbank_list.append(nn.Conv1d(in_channels=projection_size, 344 | out_channels=hidden_size, 345 | kernel_size=1, 346 | padding=int(np.floor(1/2)))) 347 | 348 | for i in range(2, K+1): 349 | self.convbank_list.append(nn.Conv1d(in_channels=hidden_size, 350 | out_channels=hidden_size, 351 | kernel_size=i, 352 | padding=int(np.floor(i/2)))) 353 | 354 | self.batchnorm_list = nn.ModuleList() 355 | for i in range(1, K+1): 356 | self.batchnorm_list.append(nn.BatchNorm1d(hidden_size)) 357 | 358 | convbank_outdim = hidden_size * K 359 | 360 | self.conv_projection_1 = nn.Conv1d(in_channels=convbank_outdim, 361 | out_channels=hidden_size, 362 | kernel_size=3, 363 | padding=int(np.floor(3 / 2))) 364 | self.conv_projection_2 = nn.Conv1d(in_channels=hidden_size, 365 | out_channels=projection_size, 366 | kernel_size=3, 367 | padding=int(np.floor(3 / 2))) 368 | self.batchnorm_proj_1 = nn.BatchNorm1d(hidden_size) 369 | 370 | self.batchnorm_proj_2 = nn.BatchNorm1d(projection_size) 371 | 372 | 373 | self.max_pool = nn.MaxPool1d(max_pool_kernel_size, stride=1, padding=1) 374 | self.highway = Highwaynet(self.projection_size) 375 | self.gru = nn.GRU(self.projection_size, self.hidden_size // 2, num_layers=num_gru_layers, 376 | batch_first=True, 377 | bidirectional=True) 378 | 379 | 380 | def _conv_fit_dim(self, x, kernel_size=3): 381 | if kernel_size % 2 == 0: 382 | return x[:,:,:-1] 383 | else: 384 | return x 385 | 386 | def forward(self, input_): 387 | 388 | input_ = input_.contiguous() 389 | batch_size = input_.size(0) 390 | total_length = input_.size(-1) 391 | 392 | convbank_list = list() 393 | convbank_input = input_ 394 | 395 | # Convolution bank filters 396 | for k, (conv, batchnorm) in 
enumerate(zip(self.convbank_list, self.batchnorm_list)): 397 | convbank_input = t.relu(batchnorm(self._conv_fit_dim(conv(convbank_input), k+1).contiguous())) 398 | convbank_list.append(convbank_input) 399 | 400 | # Concatenate all features 401 | conv_cat = t.cat(convbank_list, dim=1) 402 | 403 | # Max pooling 404 | conv_cat = self.max_pool(conv_cat)[:,:,:-1] 405 | 406 | # Projection 407 | conv_projection = t.relu(self.batchnorm_proj_1(self._conv_fit_dim(self.conv_projection_1(conv_cat)))) 408 | conv_projection = self.batchnorm_proj_2(self._conv_fit_dim(self.conv_projection_2(conv_projection))) + input_ 409 | 410 | # Highway networks 411 | highway = self.highway.forward(conv_projection.transpose(1,2)) 412 | 413 | 414 | # Bidirectional GRU 415 | 416 | self.gru.flatten_parameters() 417 | out, _ = self.gru(highway) 418 | 419 | return out 420 | 421 | 422 | class Highwaynet(nn.Module): 423 | """ 424 | Highway network 425 | """ 426 | def __init__(self, num_units, num_layers=4): 427 | """ 428 | :param num_units: dimension of hidden unit 429 | :param num_layers: # of highway layers 430 | """ 431 | super(Highwaynet, self).__init__() 432 | self.num_units = num_units 433 | self.num_layers = num_layers 434 | self.gates = nn.ModuleList() 435 | self.linears = nn.ModuleList() 436 | for _ in range(self.num_layers): 437 | self.linears.append(Linear(num_units, num_units)) 438 | self.gates.append(Linear(num_units, num_units)) 439 | 440 | def forward(self, input_): 441 | 442 | out = input_ 443 | 444 | # highway gated function 445 | for fc1, fc2 in zip(self.linears, self.gates): 446 | 447 | h = t.relu(fc1.forward(out)) 448 | t_ = t.sigmoid(fc2.forward(out)) 449 | 450 | c = 1. - t_ 451 | out = h * t_ + out * c 452 | 453 | return out -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | from module import * 2 | from utils import get_positional_table, get_sinusoid_encoding_table 3 | import hyperparams as hp 4 | import copy 5 | 6 | class Encoder(nn.Module): 7 | """ 8 | Encoder Network 9 | """ 10 | def __init__(self, embedding_size, num_hidden): 11 | """ 12 | :param embedding_size: dimension of embedding 13 | :param num_hidden: dimension of hidden 14 | """ 15 | super(Encoder, self).__init__() 16 | self.alpha = nn.Parameter(t.ones(1)) 17 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(1024, num_hidden, padding_idx=0), 18 | freeze=True) 19 | self.pos_dropout = nn.Dropout(p=0.1) 20 | self.encoder_prenet = EncoderPrenet(embedding_size, num_hidden) 21 | self.layers = clones(Attention(num_hidden), 3) 22 | self.ffns = clones(FFN(num_hidden), 3) 23 | 24 | def forward(self, x, pos): 25 | 26 | # Get character mask 27 | if self.training: 28 | c_mask = pos.ne(0).type(t.float) 29 | mask = pos.eq(0).unsqueeze(1).repeat(1, x.size(1), 1) 30 | 31 | else: 32 | c_mask, mask = None, None 33 | 34 | # Encoder pre-network 35 | x = self.encoder_prenet(x) 36 | 37 | # Get positional embedding, apply alpha and add 38 | pos = self.pos_emb(pos) 39 | x = pos * self.alpha + x 40 | 41 | # Positional dropout 42 | x = self.pos_dropout(x) 43 | 44 | # Attention encoder-encoder 45 | attns = list() 46 | for layer, ffn in zip(self.layers, self.ffns): 47 | x, attn = layer(x, x, mask=mask, query_mask=c_mask) 48 | x = ffn(x) 49 | attns.append(attn) 50 | 51 | return x, c_mask, attns 52 | 53 | 54 | class MelDecoder(nn.Module): 55 | """ 56 | Decoder Network 57 | """ 58 | def __init__(self, num_hidden): 59 | """ 
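Mel prenet plus three blocks of (masked self-attention, encoder-decoder attention, FFN), followed by mel and stop-token projections and a causal convolutional postnet.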
60 | :param num_hidden: dimension of hidden 61 | """ 62 | super(MelDecoder, self).__init__() 63 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(1024, num_hidden, padding_idx=0), 64 | freeze=True) 65 | self.pos_dropout = nn.Dropout(p=0.1) 66 | self.alpha = nn.Parameter(t.ones(1)) 67 | self.decoder_prenet = Prenet(hp.num_mels, num_hidden * 2, num_hidden, p=0.2) 68 | self.norm = Linear(num_hidden, num_hidden) 69 | 70 | self.selfattn_layers = clones(Attention(num_hidden), 3) 71 | self.dotattn_layers = clones(Attention(num_hidden), 3) 72 | self.ffns = clones(FFN(num_hidden), 3) 73 | self.mel_linear = Linear(num_hidden, hp.num_mels * hp.outputs_per_step) 74 | self.stop_linear = Linear(num_hidden, 1, w_init='sigmoid') 75 | 76 | self.postconvnet = PostConvNet(num_hidden) 77 | 78 | def forward(self, memory, decoder_input, c_mask, pos): 79 | batch_size = memory.size(0) 80 | decoder_len = decoder_input.size(1) 81 | 82 | # get decoder mask with triangular matrix 83 | if self.training: 84 | m_mask = pos.ne(0).type(t.float) 85 | mask = m_mask.eq(0).unsqueeze(1).repeat(1, decoder_len, 1) 86 | mask = mask + t.triu(t.ones(decoder_len, decoder_len).cuda(), diagonal=1).repeat(batch_size, 1, 1).byte() 87 | mask = mask.gt(0) 88 | zero_mask = c_mask.eq(0).unsqueeze(-1).repeat(1, 1, decoder_len) 89 | zero_mask = zero_mask.transpose(1, 2) 90 | else: 91 | mask = t.triu(t.ones(decoder_len, decoder_len).cuda(), diagonal=1).repeat(batch_size, 1, 1).byte() 92 | mask = mask.gt(0) 93 | m_mask, zero_mask = None, None 94 | 95 | # Decoder pre-network 96 | decoder_input = self.decoder_prenet(decoder_input) 97 | 98 | # Centered position 99 | decoder_input = self.norm(decoder_input) 100 | 101 | # Get positional embedding, apply alpha and add 102 | pos = self.pos_emb(pos) 103 | decoder_input = pos * self.alpha + decoder_input 104 | 105 | # Positional dropout 106 | decoder_input = self.pos_dropout(decoder_input) 107 | 108 | # Attention decoder-decoder, encoder-decoder 109 | attn_dot_list = list() 110 | attn_dec_list = list() 111 | 112 | for selfattn, dotattn, ffn in zip(self.selfattn_layers, self.dotattn_layers, self.ffns): 113 | decoder_input, attn_dec = selfattn(decoder_input, decoder_input, mask=mask, query_mask=m_mask) 114 | decoder_input, attn_dot = dotattn(memory, decoder_input, mask=zero_mask, query_mask=m_mask) 115 | decoder_input = ffn(decoder_input) 116 | attn_dot_list.append(attn_dot) 117 | attn_dec_list.append(attn_dec) 118 | 119 | # Mel linear projection 120 | mel_out = self.mel_linear(decoder_input) 121 | 122 | # Post Mel Network 123 | postnet_input = mel_out.transpose(1, 2) 124 | out = self.postconvnet(postnet_input) 125 | out = postnet_input + out 126 | out = out.transpose(1, 2) 127 | 128 | # Stop tokens 129 | stop_tokens = self.stop_linear(decoder_input) 130 | 131 | return mel_out, out, attn_dot_list, stop_tokens, attn_dec_list 132 | 133 | 134 | class Model(nn.Module): 135 | """ 136 | Transformer Network 137 | """ 138 | def __init__(self): 139 | super(Model, self).__init__() 140 | self.encoder = Encoder(hp.embedding_size, hp.hidden_size) 141 | self.decoder = MelDecoder(hp.hidden_size) 142 | 143 | def forward(self, characters, mel_input, pos_text, pos_mel): 144 | memory, c_mask, attns_enc = self.encoder.forward(characters, pos=pos_text) 145 | mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder.forward(memory, mel_input, c_mask, 146 | pos=pos_mel) 147 | 148 | return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec 149 | 150 | 151 | class 
ModelPostNet(nn.Module): 152 | """ 153 | CBHG Network (mel --> linear) 154 | """ 155 | def __init__(self): 156 | super(ModelPostNet, self).__init__() 157 | self.pre_projection = Conv(hp.n_mels, hp.hidden_size) 158 | self.cbhg = CBHG(hp.hidden_size) 159 | self.post_projection = Conv(hp.hidden_size, (hp.n_fft // 2) + 1) 160 | 161 | def forward(self, mel): 162 | mel = mel.transpose(1, 2) 163 | mel = self.pre_projection(mel) 164 | mel = self.cbhg(mel).transpose(1, 2) 165 | mag_pred = self.post_projection(mel).transpose(1, 2) 166 | 167 | return mag_pred
--------------------------------------------------------------------------------
/prepare_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | from utils import get_spectrograms 6 | import hyperparams as hp 7 | import librosa 8 | 9 | 10 | class PrepareDataset(Dataset): 11 | """LJSpeech dataset.""" 12 | 13 | def __init__(self, csv_file, root_dir): 14 | """ 15 | Args: 16 | csv_file (string): Path to the csv file with annotations. 17 | root_dir (string): Directory with all the wavs. 18 | 19 | """ 20 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None) 21 | self.root_dir = root_dir 22 | 23 | def load_wav(self, filename): 24 | return librosa.load(filename, sr=hp.sr) # hyperparams defines sr, not sample_rate 25 | 26 | def __len__(self): 27 | return len(self.landmarks_frame) 28 | 29 | def __getitem__(self, idx): 30 | wav_name = os.path.join( 31 | self.root_dir, self.landmarks_frame.iloc[idx, 0]) + '.wav' # .iloc replaces the removed DataFrame.ix 32 | mel, _ = get_spectrograms(wav_name) 33 | 34 | np.save(wav_name[:-4] + '.pt', mel) # np.save appends '.npy', producing '<name>.pt.npy' 35 | # np.save(wav_name[:-4] + '.mag', mag) 36 | 37 | sample = {'mel': mel} 38 | 39 | return sample 40 | 41 | 42 | if __name__ == '__main__': 43 | dataset = PrepareDataset(os.path.join( 44 | hp.data_path, 'metadata.csv'), os.path.join(hp.data_path, 'wavs')) 45 | dataloader = DataLoader(dataset, batch_size=1, 46 | drop_last=False, num_workers=6) 47 | from tqdm import tqdm 48 | pbar = tqdm(dataloader) 49 | for d in pbar: 50 | pass 51 |
--------------------------------------------------------------------------------
/preprocess.py: -------------------------------------------------------------------------------- 1 | import hyperparams as hp 2 | import pandas as pd 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | import librosa 6 | import numpy as np 7 | from text import text_to_sequence 8 | import collections.abc 9 | from scipy import signal 10 | import torch as t 11 | import math 12 | 13 | 14 | class LJDatasets(Dataset): 15 | """LJSpeech dataset.""" 16 | 17 | def __init__(self, csv_file, root_dir): 18 | """ 19 | Args: 20 | csv_file (string): Path to the csv file with annotations. 21 | root_dir (string): Directory with all the wavs.
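Mel targets are read from the '<wav name>.pt.npy' files produced by prepare_data.py.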
22 | 23 | """ 24 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None) 25 | self.root_dir = root_dir 26 | 27 | def load_wav(self, filename): 28 | return librosa.load(filename, sr=hp.sr) # hyperparams defines sr, not sample_rate 29 | 30 | def __len__(self): 31 | return len(self.landmarks_frame) 32 | 33 | def __getitem__(self, idx): 34 | wav_name = os.path.join( 35 | self.root_dir, self.landmarks_frame.iloc[idx, 0]) + '.wav' 36 | text = self.landmarks_frame.iloc[idx, 1] # .iloc replaces the removed DataFrame.ix 37 | 38 | text = np.asarray(text_to_sequence( 39 | text, [hp.cleaners]), dtype=np.int32) 40 | mel = np.load(wav_name[:-4] + '.pt.npy') 41 | mel_input = np.concatenate( 42 | [np.zeros([1, hp.num_mels], np.float32), mel[:-1, :]], axis=0) 43 | text_length = len(text) 44 | pos_text = np.arange(1, text_length + 1) 45 | pos_mel = np.arange(1, mel.shape[0] + 1) 46 | 47 | sample = {'text': text, 'mel': mel, 'text_length': text_length, 48 | 'mel_input': mel_input, 'pos_mel': pos_mel, 'pos_text': pos_text} 49 | 50 | return sample 51 | 52 | 53 | class PostDatasets(Dataset): 54 | """LJSpeech dataset.""" 55 | 56 | def __init__(self, csv_file, root_dir): 57 | """ 58 | Args: 59 | csv_file (string): Path to the csv file with annotations. 60 | root_dir (string): Directory with all the wavs. 61 | 62 | """ 63 | self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None) 64 | self.root_dir = root_dir 65 | 66 | def __len__(self): 67 | return len(self.landmarks_frame) 68 | 69 | def __getitem__(self, idx): 70 | wav_name = os.path.join( 71 | self.root_dir, self.landmarks_frame.iloc[idx, 0]) + '.wav' 72 | mel = np.load(wav_name[:-4] + '.pt.npy') 73 | mag = np.load(wav_name[:-4] + '.mag.npy') 74 | sample = {'mel': mel, 'mag': mag} 75 | 76 | return sample 77 | 78 | 79 | def collate_fn_transformer(batch): 80 | 81 | # Puts each data field into a tensor with outer dimension batch size 82 | if isinstance(batch[0], collections.abc.Mapping): 83 | 84 | text = [d['text'] for d in batch] 85 | mel = [d['mel'] for d in batch] 86 | mel_input = [d['mel_input'] for d in batch] 87 | text_length = [d['text_length'] for d in batch] 88 | pos_mel = [d['pos_mel'] for d in batch] 89 | pos_text = [d['pos_text'] for d in batch] 90 | 91 | text = [i for i, _ in sorted( 92 | zip(text, text_length), key=lambda x: x[1], reverse=True)] 93 | # print(text) 94 | # for t in text: 95 | # print(len(t)) 96 | mel = [i for i, _ in sorted( 97 | zip(mel, text_length), key=lambda x: x[1], reverse=True)] 98 | mel_input = [i for i, _ in sorted( 99 | zip(mel_input, text_length), key=lambda x: x[1], reverse=True)] 100 | pos_text = [i for i, _ in sorted( 101 | zip(pos_text, text_length), key=lambda x: x[1], reverse=True)] 102 | pos_mel = [i for i, _ in sorted( 103 | zip(pos_mel, text_length), key=lambda x: x[1], reverse=True)] 104 | text_length = sorted(text_length, reverse=True) 105 | # PAD sequences with largest length of the batch 106 | text = _prepare_data(text).astype(np.int32) 107 | mel = _pad_mel(mel) 108 | mel_input = _pad_mel(mel_input) 109 | pos_mel = _prepare_data(pos_mel).astype(np.int32) 110 | pos_text = _prepare_data(pos_text).astype(np.int32) 111 | 112 | return t.LongTensor(text), t.FloatTensor(mel), t.FloatTensor(mel_input), t.LongTensor(pos_text), t.LongTensor(pos_mel), t.LongTensor(text_length) 113 | 114 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}" 115 | .format(type(batch[0])))) 116 | 117 | 118 | def collate_fn_postnet(batch): 119 | 120 | # Puts each data field into a tensor with outer dimension batch size 121 | if isinstance(batch[0], collections.abc.Mapping): 122
| 123 | mel = [d['mel'] for d in batch] 124 | mag = [d['mag'] for d in batch] 125 | 126 | # PAD sequences with largest length of the batch 127 | mel = _pad_mel(mel) 128 | mag = _pad_mel(mag) 129 | 130 | return t.FloatTensor(mel), t.FloatTensor(mag) 131 | 132 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}" 133 | .format(type(batch[0])))) 134 | 135 | 136 | def _pad_data(x, length): 137 | _pad = 0 138 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) 139 | 140 | 141 | def _prepare_data(inputs): 142 | max_len = max((len(x) for x in inputs)) 143 | return np.stack([_pad_data(x, max_len) for x in inputs]) 144 | 145 | 146 | def _pad_per_step(inputs): 147 | timesteps = inputs.shape[-1] 148 | return np.pad(inputs, [[0, 0], [0, 0], [0, hp.outputs_per_step - (timesteps % hp.outputs_per_step)]], mode='constant', constant_values=0.0) 149 | 150 | 151 | def get_param_size(model): 152 | params = 0 153 | for p in model.parameters(): 154 | tmp = 1 155 | for x in p.size(): 156 | tmp *= x 157 | params += tmp 158 | return params 159 | 160 | 161 | def get_dataset(): 162 | return LJDatasets(os.path.join(hp.data_path, 'metadata.csv'), os.path.join(hp.data_path, 'wavs')) 163 | 164 | 165 | def get_post_dataset(): 166 | return PostDatasets(os.path.join(hp.data_path, 'metadata.csv'), os.path.join(hp.data_path, 'wavs')) 167 | 168 | 169 | def _pad_mel(inputs): 170 | _pad = 0 171 | 172 | def _pad_one(x, max_len): 173 | mel_len = x.shape[0] 174 | return np.pad(x, [[0, max_len - mel_len], [0, 0]], mode='constant', constant_values=_pad) 175 | max_len = max((x.shape[0] for x in inputs)) 176 | return np.stack([_pad_one(x, max_len) for x in inputs]) 177 |
--------------------------------------------------------------------------------
/synthesis.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | from utils import spectrogram2wav 3 | from scipy.io.wavfile import write 4 | import hyperparams as hp 5 | from text import text_to_sequence 6 | import numpy as np 7 | from network import ModelPostNet, Model 8 | from collections import OrderedDict 9 | # from tqdm import tqdm 10 | # import argparse 11 | import matplotlib 12 | import matplotlib.pyplot as plt 13 | import audio 14 | import os 15 | 16 | 17 | def plot_data(data, figsize=(12, 4)): 18 | _, axes = plt.subplots(1, len(data), figsize=figsize) 19 | for i in range(len(data)): 20 | axes[i].imshow(data[i], aspect='auto', 21 | origin='lower', interpolation='none') # 'bottom' is no longer a valid origin in matplotlib 22 | 23 | if not os.path.exists("img"): 24 | os.mkdir("img") 25 | plt.savefig(os.path.join("img", "model_test.jpg")) 26 | 27 | 28 | def load_checkpoint(step, model_name="transformer"): 29 | state_dict = t.load( 30 | './checkpoint/checkpoint_%s_%d.pth.tar' % (model_name, step)) 31 | new_state_dict = OrderedDict() 32 | for k, value in state_dict['model'].items(): 33 | key = k[7:] # strip the 'module.' prefix added by nn.DataParallel 34 | new_state_dict[key] = value 35 | 36 | return new_state_dict 37 | 38 | 39 | def synthesis(text, num): 40 | m = Model() 41 | # m_post = ModelPostNet() 42 | 43 | m.load_state_dict(load_checkpoint(num, "transformer")) 44 | # m_post.load_state_dict(load_checkpoint(args.restore_step2, "postnet")) 45 | 46 | text = np.asarray(text_to_sequence(text, [hp.cleaners])) 47 | text = t.LongTensor(text).unsqueeze(0) 48 | text = text.cuda() 49 | mel_input = t.zeros([1, 1, 80]).cuda() 50 | pos_text = t.arange(1, text.size(1)+1).unsqueeze(0) 51 | pos_text = pos_text.cuda() 52 | 53 | m = m.cuda() 54 | # m_post = m_post.cuda() 55 |
m.train(False) # equivalent to m.eval() 56 | # m_post.train(False) 57 | 58 | # pbar = tqdm(range(args.max_len)) 59 | with t.no_grad(): 60 | for _ in range(1000): 61 | pos_mel = t.arange(1, mel_input.size(1)+1).unsqueeze(0).cuda() 62 | mel_pred, postnet_pred, attn, stop_token, _, attn_dec = m.forward( 63 | text, mel_input, pos_text, pos_mel) 64 | mel_input = t.cat([mel_input, postnet_pred[:, -1:, :]], dim=1) 65 | 66 | # mag_pred = m_post.forward(postnet_pred) 67 | 68 | # wav = spectrogram2wav(mag_pred.squeeze(0).cpu().numpy()) 69 | mel_postnet = postnet_pred[0].cpu().numpy().T 70 | plot_data([mel_postnet for _ in range(2)]) 71 | wav = audio.inv_mel_spectrogram(mel_postnet) 72 | wav = wav[0:audio.find_endpoint(wav)] 73 | audio.save_wav(wav, "result.wav") 74 | 75 | 76 | if __name__ == '__main__': 77 | # Test 78 | synthesis("I am very happy to see you again.", 160000) 79 |
--------------------------------------------------------------------------------
/text/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | import re 4 | from text import cleaners 5 | from text.symbols import symbols 6 | 7 | 8 | 9 | 10 | # Mappings from symbol to numeric ID and vice versa: 11 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 12 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 13 | 14 | # Regular expression matching text enclosed in curly braces: 15 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 16 | 17 | 18 | def text_to_sequence(text, cleaner_names): 19 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 20 | 21 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 22 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 23 | 24 | Args: 25 | text: string to convert to a sequence 26 | cleaner_names: names of the cleaner functions to run the text through 27 | 28 | Returns: 29 | List of integers corresponding to the symbols in the text 30 | ''' 31 | sequence = [] 32 | 33 | # Check for curly braces and treat their contents as ARPAbet: 34 | while len(text): 35 | m = _curly_re.match(text) 36 | if not m: 37 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 38 | break 39 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 40 | sequence += _arpabet_to_sequence(m.group(2)) 41 | text = m.group(3) 42 | 43 | # Append EOS token 44 | sequence.append(_symbol_to_id['~']) 45 | return sequence 46 | 47 | 48 | def sequence_to_text(sequence): 49 | '''Converts a sequence of IDs back to a string''' 50 | result = '' 51 | for symbol_id in sequence: 52 | if symbol_id in _id_to_symbol: 53 | s = _id_to_symbol[symbol_id] 54 | # Enclose ARPAbet back in curly braces: 55 | if len(s) > 1 and s[0] == '@': 56 | s = '{%s}' % s[1:] 57 | result += s 58 | return result.replace('}{', ' ') 59 | 60 | 61 | def _clean_text(text, cleaner_names): 62 | for name in cleaner_names: 63 | cleaner = getattr(cleaners, name) 64 | if not cleaner: 65 | raise Exception('Unknown cleaner: %s' % name) 66 | text = cleaner(text) 67 | return text 68 | 69 | 70 | def _symbols_to_sequence(symbols): 71 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 72 | 73 | 74 | def _arpabet_to_sequence(text): 75 | return _symbols_to_sequence(['@' + s for s in text.split()]) 76 | 77 | 78 | def _should_keep_symbol(s): 79 | return s in _symbol_to_id and s != '_' and s != '~' # compare strings with '!=', not 'is not'
--------------------------------------------------------------------------------
/text/cleaners.py:
-------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | 4 | ''' 5 | Cleaners are transformations that run over the input text at both training and eval time. 6 | 7 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 8 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 9 | 1. "english_cleaners" for English text 10 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 11 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 12 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 13 | the symbols in symbols.py to match your data). 14 | ''' 15 | 16 | import re 17 | from unidecode import unidecode 18 | from .numbers import normalize_numbers 19 | 20 | 21 | # Regular expression matching whitespace: 22 | _whitespace_re = re.compile(r'\s+') 23 | 24 | # List of (regular expression, replacement) pairs for abbreviations: 25 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 26 | ('mrs', 'misess'), 27 | ('mr', 'mister'), 28 | ('dr', 'doctor'), 29 | ('st', 'saint'), 30 | ('co', 'company'), 31 | ('jr', 'junior'), 32 | ('maj', 'major'), 33 | ('gen', 'general'), 34 | ('drs', 'doctors'), 35 | ('rev', 'reverend'), 36 | ('lt', 'lieutenant'), 37 | ('hon', 'honorable'), 38 | ('sgt', 'sergeant'), 39 | ('capt', 'captain'), 40 | ('esq', 'esquire'), 41 | ('ltd', 'limited'), 42 | ('col', 'colonel'), 43 | ('ft', 'fort'), 44 | ]] 45 | 46 | 47 | def expand_abbreviations(text): 48 | for regex, replacement in _abbreviations: 49 | text = re.sub(regex, replacement, text) 50 | return text 51 | 52 | 53 | def expand_numbers(text): 54 | return normalize_numbers(text) 55 | 56 | 57 | def lowercase(text): 58 | return text.lower() 59 | 60 | 61 | def collapse_whitespace(text): 62 | return re.sub(_whitespace_re, ' ', text) 63 | 64 | 65 | def convert_to_ascii(text): 66 | return unidecode(text) 67 | 68 | 69 | def basic_cleaners(text): 70 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 71 | text = lowercase(text) 72 | text = collapse_whitespace(text) 73 | return text 74 | 75 | 76 | def transliteration_cleaners(text): 77 | '''Pipeline for non-English text that transliterates to ASCII.''' 78 | text = convert_to_ascii(text) 79 | text = lowercase(text) 80 | text = collapse_whitespace(text) 81 | return text 82 | 83 | 84 | def english_cleaners(text): 85 | '''Pipeline for English text, including number and abbreviation expansion.''' 86 | text = convert_to_ascii(text) 87 | text = lowercase(text) 88 | text = expand_numbers(text) 89 | text = expand_abbreviations(text) 90 | text = collapse_whitespace(text) 91 | return text 92 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | 4 | import re 5 | 6 | 7 | valid_symbols = [ 8 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 9 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 10 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 11 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 12 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 13 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 
'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 14 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 15 | ] 16 | 17 | _valid_symbol_set = set(valid_symbols) 18 | 19 | 20 | class CMUDict: 21 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 22 | def __init__(self, file_or_path, keep_ambiguous=True): 23 | if isinstance(file_or_path, str): 24 | with open(file_or_path, encoding='latin-1') as f: 25 | entries = _parse_cmudict(f) 26 | else: 27 | entries = _parse_cmudict(file_or_path) 28 | if not keep_ambiguous: 29 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 30 | self._entries = entries 31 | 32 | 33 | def __len__(self): 34 | return len(self._entries) 35 | 36 | 37 | def lookup(self, word): 38 | '''Returns list of ARPAbet pronunciations of the given word.''' 39 | return self._entries.get(word.upper()) 40 | 41 | 42 | 43 | _alt_re = re.compile(r'\([0-9]+\)') 44 | 45 | 46 | def _parse_cmudict(file): 47 | cmudict = {} 48 | for line in file: 49 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 50 | parts = line.split(' ') 51 | word = re.sub(_alt_re, '', parts[0]) 52 | pronunciation = _get_pronunciation(parts[1]) 53 | if pronunciation: 54 | if word in cmudict: 55 | cmudict[word].append(pronunciation) 56 | else: 57 | cmudict[word] = [pronunciation] 58 | return cmudict 59 | 60 | 61 | def _get_pronunciation(s): 62 | parts = s.strip().split(' ') 63 | for part in parts: 64 | if part not in _valid_symbol_set: 65 | return None 66 | return ' '.join(parts) 67 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return 
_inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | 4 | ''' 5 | Defines the set of symbols used in text input to the model. 6 | 7 | The default is a set of ASCII characters that works well for English or text that has been run 8 | through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. 9 | ''' 10 | from text import cmudict 11 | 12 | _pad = '_' 13 | _eos = '~' 14 | _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 18 | 19 | # Export all symbols: 20 | symbols = [_pad, _eos] + list(_characters) + _arpabet 21 | 22 | 23 | if __name__ == '__main__': 24 | print(symbols) -------------------------------------------------------------------------------- /train_postnet.py: -------------------------------------------------------------------------------- 1 | from preprocess import get_post_dataset, DataLoader, collate_fn_postnet 2 | from network import * 3 | from tensorboardX import SummaryWriter 4 | import torchvision.utils as vutils 5 | import os 6 | from tqdm import tqdm 7 | 8 | def adjust_learning_rate(optimizer, step_num, warmup_step=4000): 9 | lr = hp.lr * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5) 10 | for param_group in optimizer.param_groups: 11 | param_group['lr'] = lr 12 | 13 | 14 | def main(): 15 | 16 | dataset = get_post_dataset() 17 | global_step = 0 18 | 19 | m = nn.DataParallel(ModelPostNet().cuda()) 20 | 21 | m.train() 22 | optimizer = t.optim.Adam(m.parameters(), lr=hp.lr) 23 | 24 | writer = SummaryWriter() 25 | 26 | for epoch in range(hp.epochs): 27 | 28 | dataloader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=collate_fn_postnet, drop_last=True, num_workers=8) 29 | pbar = tqdm(dataloader) 30 | for i, data in enumerate(pbar): 31 | pbar.set_description("Processing at epoch %d"%epoch) 32 | global_step += 1 33 | if global_step < 400000: 34 | adjust_learning_rate(optimizer, global_step) 35 | 36 | mel, mag = data 37 | 38 | mel = mel.cuda() 39 | mag = mag.cuda() 40 | 41 | mag_pred = m.forward(mel) 42 | 43 | loss = nn.L1Loss()(mag_pred, mag) 44 | 45 | writer.add_scalars('training_loss',{ 46 | 'loss':loss, 47 | 48 | }, global_step) 49 | 50 | optimizer.zero_grad() 51 | # Calculate gradients 52 | loss.backward() 53 | 54 | nn.utils.clip_grad_norm_(m.parameters(), 1.) 
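# Clip the global gradient norm to 1.0 so one bad batch cannot produce an exploding update.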
55 | 56 | # Update weights 57 | optimizer.step() 58 | 59 | if global_step % hp.save_step == 0: 60 | t.save({'model':m.state_dict(), 61 | 'optimizer':optimizer.state_dict()}, 62 | os.path.join(hp.checkpoint_path,'checkpoint_postnet_%d.pth.tar' % global_step)) 63 | 64 | 65 | 66 | 67 | 68 | if __name__ == '__main__': 69 | main() -------------------------------------------------------------------------------- /train_transformer.py: -------------------------------------------------------------------------------- 1 | from preprocess import get_dataset, DataLoader, collate_fn_transformer 2 | from network import * 3 | # from tensorboardX import SummaryWriter 4 | # import torchvision.utils as vutils 5 | import os 6 | # from tqdm import tqdm 7 | import time 8 | 9 | 10 | def adjust_learning_rate(optimizer, step_num, warmup_step=4000): 11 | lr = hp.lr * warmup_step**0.5 * \ 12 | min(step_num * warmup_step**-1.5, step_num**-0.5) 13 | for param_group in optimizer.param_groups: 14 | param_group['lr'] = lr 15 | 16 | 17 | def main(): 18 | if not os.path.exists("logger"): 19 | os.mkdir("logger") 20 | 21 | dataset = get_dataset() 22 | global_step = 0 23 | 24 | m = nn.DataParallel(Model().cuda()) 25 | num_param = sum(param.numel() for param in m.parameters()) 26 | print('Number of Transformer-TTS Parameters:', num_param) 27 | 28 | m.train() 29 | optimizer = t.optim.Adam(m.parameters(), lr=hp.lr) 30 | 31 | pos_weight = t.FloatTensor([5.]).cuda() 32 | # writer = SummaryWriter() 33 | 34 | for epoch in range(hp.epochs): 35 | 36 | dataloader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True, 37 | collate_fn=collate_fn_transformer, drop_last=True, num_workers=16) 38 | # pbar = tqdm(dataloader) 39 | for i, data in enumerate(dataloader): 40 | # pbar.set_description("Processing at epoch %d"%epoch) 41 | global_step += 1 42 | if global_step < 400000: 43 | adjust_learning_rate(optimizer, global_step) 44 | 45 | character, mel, mel_input, pos_text, pos_mel, _ = data 46 | 47 | stop_tokens = t.abs(pos_mel.ne(0).type(t.float) - 1) 48 | 49 | character = character.cuda() 50 | mel = mel.cuda() 51 | mel_input = mel_input.cuda() 52 | pos_text = pos_text.cuda() 53 | pos_mel = pos_mel.cuda() 54 | # print(mel) 55 | 56 | mel_pred, postnet_pred, attn_probs, stop_preds, attns_enc, attns_dec = m.forward( 57 | character, mel_input, pos_text, pos_mel) 58 | 59 | mel_loss = nn.L1Loss()(mel_pred, mel) 60 | post_mel_loss = nn.L1Loss()(postnet_pred, mel) 61 | 62 | loss = mel_loss + post_mel_loss 63 | 64 | t_l = loss.item() 65 | m_l = mel_loss.item() 66 | m_p_l = post_mel_loss.item() 67 | # s_l = stop_pred_loss.item() 68 | 69 | with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss: 70 | f_total_loss.write(str(t_l)+"\n") 71 | 72 | with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss: 73 | f_mel_loss.write(str(m_l)+"\n") 74 | 75 | with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss: 76 | f_mel_postnet_loss.write(str(m_p_l)+"\n") 77 | 78 | # with open(os.path.join("logger", "stop_pred_loss.txt"), "a") as f_s_loss: 79 | # f_s_loss.write(str(s_l)+"\n") 80 | 81 | # Print 82 | if global_step % hp.log_step == 0: 83 | # Now = time.clock() 84 | 85 | str1 = "Epoch [{}/{}], Step [{}], Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f};".format( 86 | epoch+1, hp.epochs, global_step, mel_loss.item(), post_mel_loss.item()) 87 | str2 = "Total Loss: {:.4f}.".format(loss.item()) 88 | current_learning_rate = 0 89 | for param_group in optimizer.param_groups: 90 | current_learning_rate = 
param_group['lr'] 91 | str3 = "Current Learning Rate is {:.6f}.".format( 92 | current_learning_rate) 93 | # str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format( 94 | # (Now-Start), (total_step-current_step)*np.mean(Time)) 95 | 96 | print("\n" + str1) 97 | print(str2) 98 | print(str3) 99 | # print(str4) 100 | 101 | with open(os.path.join("logger", "logger.txt"), "a") as f_logger: 102 | f_logger.write(str1 + "\n") 103 | f_logger.write(str2 + "\n") 104 | f_logger.write(str3 + "\n") 105 | # f_logger.write(str4 + "\n") 106 | f_logger.write("\n") 107 | 108 | # writer.add_scalars('training_loss',{ 109 | # 'mel_loss':mel_loss, 110 | # 'post_mel_loss':post_mel_loss, 111 | 112 | # }, global_step) 113 | 114 | # writer.add_scalars('alphas',{ 115 | # 'encoder_alpha':m.module.encoder.alpha.data, 116 | # 'decoder_alpha':m.module.decoder.alpha.data, 117 | # }, global_step) 118 | 119 | # if global_step % hp.image_step == 1: 120 | 121 | # for i, prob in enumerate(attn_probs): 122 | 123 | # num_h = prob.size(0) 124 | # for j in range(4): 125 | 126 | # x = vutils.make_grid(prob[j*16] * 255) 127 | # writer.add_image('Attention_%d_0'%global_step, x, i*4+j) 128 | 129 | # for i, prob in enumerate(attns_enc): 130 | # num_h = prob.size(0) 131 | 132 | # for j in range(4): 133 | 134 | # x = vutils.make_grid(prob[j*16] * 255) 135 | # writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j) 136 | 137 | # for i, prob in enumerate(attns_dec): 138 | 139 | # num_h = prob.size(0) 140 | # for j in range(4): 141 | 142 | # x = vutils.make_grid(prob[j*16] * 255) 143 | # writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j) 144 | 145 | optimizer.zero_grad() 146 | # Calculate gradients 147 | loss.backward() 148 | 149 | nn.utils.clip_grad_norm_(m.parameters(), 1.) 150 | 151 | # Update weights 152 | optimizer.step() 153 | 154 | if global_step % hp.save_step == 0: 155 | t.save({'model': m.state_dict(), 156 | 'optimizer': optimizer.state_dict()}, 157 | os.path.join(hp.checkpoint_path, 'checkpoint_transformer_%d.pth.tar' % global_step)) 158 | 159 | 160 | if __name__ == '__main__': 161 | main() 162 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | import os, copy 4 | from scipy import signal 5 | import hyperparams as hp 6 | import torch as t 7 | 8 | def get_spectrograms(fpath): 9 | '''Parse the wave file in `fpath` and 10 | Returns normalized melspectrogram and linear spectrogram. 11 | Args: 12 | fpath: A string. The full path of a sound file. 13 | Returns: 14 | mel: A 2d array of shape (T, n_mels) and dtype of float32. 15 | mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32. 
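Both arrays hold decibel values normalized to (0, 1] via np.clip((x - ref_db + max_db) / max_db, 1e-8, 1).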


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import librosa
import os
import copy
from scipy import signal
import hyperparams as hp
import torch as t

def get_spectrograms(fpath):
    '''Parses the wave file in `fpath` and returns a normalized mel-spectrogram
    and linear spectrogram.
    Args:
        fpath: A string. The full path of a sound file.
    Returns:
        mel: A 2d array of shape (T, n_mels) and dtype of float32.
        mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32.
    '''
    # Loading sound file
    y, sr = librosa.load(fpath, sr=hp.sr)

    # Trimming
    y, _ = librosa.effects.trim(y)

    # Preemphasis
    y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1])

    # stft
    linear = librosa.stft(y=y,
                          n_fft=hp.n_fft,
                          hop_length=hp.hop_length,
                          win_length=hp.win_length)

    # magnitude spectrogram
    mag = np.abs(linear)  # (1+n_fft//2, T)

    # mel spectrogram
    mel_basis = librosa.filters.mel(sr=hp.sr, n_fft=hp.n_fft, n_mels=hp.n_mels)  # (n_mels, 1+n_fft//2)
    mel = np.dot(mel_basis, mag)  # (n_mels, T)

    # to decibel
    mel = 20 * np.log10(np.maximum(1e-5, mel))
    mag = 20 * np.log10(np.maximum(1e-5, mag))

    # normalize
    mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)
    mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)

    # Transpose
    mel = mel.T.astype(np.float32)  # (T, n_mels)
    mag = mag.T.astype(np.float32)  # (T, 1+n_fft//2)

    return mel, mag

def spectrogram2wav(mag):
    '''Generates a waveform from a linear magnitude spectrogram.
    Args:
        mag: A numpy array of (T, 1+n_fft//2)
    Returns:
        wav: A 1-D numpy array.
    '''
    # transpose
    mag = mag.T

    # de-normalize
    mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db

    # to amplitude
    mag = np.power(10.0, mag * 0.05)

    # wav reconstruction
    wav = griffin_lim(mag**hp.power)

    # de-preemphasis
    wav = signal.lfilter([1], [1, -hp.preemphasis], wav)

    # trim
    wav, _ = librosa.effects.trim(wav)

    return wav.astype(np.float32)

def griffin_lim(spectrogram):
    '''Applies the Griffin-Lim algorithm to recover phase from a magnitude
    spectrogram.'''
    X_best = copy.deepcopy(spectrogram)
    for i in range(hp.n_iter):
        X_t = invert_spectrogram(X_best)
        est = librosa.stft(X_t, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
        phase = est / np.maximum(1e-8, np.abs(est))
        X_best = spectrogram * phase
    X_t = invert_spectrogram(X_best)
    y = np.real(X_t)

    return y

def invert_spectrogram(spectrogram):
    '''Applies the inverse STFT.
    Args:
        spectrogram: [1+n_fft//2, t]
    '''
    return librosa.istft(spectrogram, hop_length=hp.hop_length, win_length=hp.win_length, window="hann")

def get_positional_table(d_pos_vec, n_position=1024):
    position_enc = np.array([
        [pos / np.power(10000, 2*i/d_pos_vec) for i in range(d_pos_vec)]
        if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])

    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return t.from_numpy(position_enc).type(t.FloatTensor)

def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return t.FloatTensor(sinusoid_table)
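
# Example usage (a sketch; the sizes are illustrative, not taken from hparams):
#
#   pe_table = get_sinusoid_encoding_table(1024, 256, padding_idx=0)
#   pos_emb = t.nn.Embedding.from_pretrained(pe_table, freeze=True)
#
# gives a frozen positional-embedding layer whose index 0 is an all-zero
# vector reserved for padding.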

def guided_attention(N, T, g=0.2):
    '''Builds the guided-attention weight matrix. Refer to page 3 of the
    guided-attention paper (Tachibana et al., 2017, "Efficiently Trainable
    Text-to-Speech System Based on Deep Convolutional Networks with Guided
    Attention").'''
    W = np.zeros((N, T), dtype=np.float32)
    for n_pos in range(W.shape[0]):
        for t_pos in range(W.shape[1]):
            W[n_pos, t_pos] = 1 - np.exp(-(t_pos / float(T) - n_pos / float(N)) ** 2 / (2 * g * g))
    return W
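
# A sketch of how this penalty is typically applied (it is not wired into the
# training scripts shown here; `attn` is assumed to be an (N, T) attention map
# produced by the decoder):
#
#   W = t.from_numpy(guided_attention(N, T))
#   attn_loss = t.mean(attn * W)  # grows when attention strays off-diagonal
#   loss = loss + attn_loss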
--------------------------------------------------------------------------------
/visualize_loss/loss_verification.py:
--------------------------------------------------------------------------------
import numpy as np

# Sanity check: make sure every line of the loss log parses as a float,
# and report how many entries were written.
loss_arr = np.array(list())
with open("total_loss.txt", "r") as f_loss:
    cnt = 0
    for loss in f_loss.readlines():
        cnt += 1
        loss_arr = np.append(loss_arr, float(loss))
print(cnt)
--------------------------------------------------------------------------------
/visualize_loss/visualize_all.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np


def cut_arr(arr):
    # Clip losses at 2.0 so early spikes do not dominate the plot
    # (equivalent to np.minimum(arr, 2.0)).
    for index in range(np.shape(arr)[0]):
        if arr[index] >= 2.0:
            arr[index] = 2.0

    return arr


def visualize(total_loss_file_name, mel_loss_file_name, mel_postnet_loss_file_name):
    plt.figure()

    total_loss_arr = np.array(list())
    with open(total_loss_file_name, "r") as f_total_loss:
        for index, loss in enumerate(f_total_loss.readlines()):
            if float(loss) > 0.6:
                # Flag the entries whose total loss is unusually large.
                print(index)
            total_loss_arr = np.append(total_loss_arr, float(loss))

    x = np.arange(np.shape(total_loss_arr)[0])
    y = cut_arr(total_loss_arr)

    plt.plot(x, y, color="y", lw=0.7, label="total loss")

    mel_loss_arr = np.array(list())
    with open(mel_loss_file_name, "r") as f_mel_loss:
        for loss in f_mel_loss.readlines():
            mel_loss_arr = np.append(mel_loss_arr, float(loss))

    x = np.arange(np.shape(mel_loss_arr)[0])
    y = cut_arr(mel_loss_arr)

    plt.plot(x, y, color="r", lw=0.7, label="mel loss")

    # mel_postnet_loss_arr = np.array(list())
    # with open(mel_postnet_loss_file_name, "r") as f_mel_postnet_loss:
    #     for loss in f_mel_postnet_loss.readlines():
    #         mel_postnet_loss_arr = np.append(mel_postnet_loss_arr, float(loss))

    # x = np.arange(np.shape(mel_postnet_loss_arr)[0])
    # y = cut_arr(mel_postnet_loss_arr)

    # plt.plot(x, y, color="b", lw=0.7, label="mel postnet loss")

    plt.legend()
    plt.xlabel("training step")
    plt.ylabel("loss")
    plt.title("loss")
    plt.savefig("loss.jpg")


if __name__ == "__main__":
    # Test
    visualize("total_loss.txt", "mel_loss.txt", "mel_postnet_loss.txt")
--------------------------------------------------------------------------------
/visualize_loss/visualize_single.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np


def visualize(file_name, start, end=None):
    # Default `end=None` (rather than -1) so the slice keeps the last entry.
    plt.figure()

    loss_arr = np.array(list())

    with open(file_name, "r") as f_loss:
        for loss in f_loss.readlines():
            loss_arr = np.append(loss_arr, float(loss))

    loss_arr = loss_arr[start:end]

    x = np.arange(np.shape(loss_arr)[0])
    y = loss_arr

    plt.plot(x, y, color="y", lw=0.7)
    plt.xlabel("training step")
    plt.ylabel("loss")
    plt.title("loss")
    plt.savefig("loss_one.jpg")


if __name__ == "__main__":
    # Test
    visualize("total_loss.txt", 1000, 18000)
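
# Other example invocations (a sketch; assumes the loss logs written by
# train_transformer.py are in the current working directory):
#
#   visualize("mel_loss.txt", 0)             # the full mel-loss curve
#   visualize("mel_postnet_loss.txt", 5000)  # from entry 5000 to the end
--------------------------------------------------------------------------------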