├── Configs └── config.yml ├── Data ├── train_list.txt └── val_list.txt ├── LICENSE ├── README.md ├── layers.py ├── meldataset.py ├── models.py ├── optimizers.py ├── text_utils.py ├── train.py ├── trainer.py ├── utils.py └── word_index_dict.txt /Configs/config.yml: -------------------------------------------------------------------------------- 1 | log_dir: "Checkpoint" 2 | save_freq: 10 3 | device: "cuda" 4 | epochs: 200 5 | batch_size: 64 6 | pretrained_model: "" 7 | train_data: "Data/train_list.txt" 8 | val_data: "Data/val_list.txt" 9 | 10 | preprocess_parasm: 11 | sr: 24000 12 | spect_params: 13 | n_fft: 2048 14 | win_length: 1200 15 | hop_length: 300 16 | mel_params: 17 | n_mels: 80 18 | 19 | model_params: 20 | input_dim: 80 21 | hidden_dim: 256 22 | n_token: 80 23 | token_embedding_dim: 256 24 | 25 | optimizer_params: 26 | lr: 0.0005 27 | -------------------------------------------------------------------------------- /Data/val_list.txt: -------------------------------------------------------------------------------- 1 | LJSpeech-1.1/wavs/LJ047-0044.wav|Oswald was, however, willing to discuss his contacts with Soviet authorities. He denied having any involvement with Soviet intelligence agencies|0 2 | LJSpeech-1.1/wavs/LJ031-0038.wav|The first physician to see the President at Parkland Hospital was Dr. Charles J. Carrico, a resident in general surgery.|0 3 | LJSpeech-1.1/wavs/LJ049-0026.wav|On occasion the Secret Service has been permitted to have an agent riding in the passenger compartment with the President.|0 4 | LJSpeech-1.1/wavs/LJ009-0114.wav|Mr. Wakefield winds up his graphic but somewhat sensational account by describing another religious service, which may appropriately be inserted here.|0 5 | LJSpeech-1.1/wavs/LJ028-0506.wav|A modern artist would have difficulty in doing such accurate work.|0 6 | LJSpeech-1.1/wavs/LJ031-0070.wav|Dr. Clark, who most closely observed the head wound,|0 7 | LJSpeech-1.1/wavs/LJ034-0198.wav|Euins, who was on the southwest corner of Elm and Houston Streets testified that he could not describe the man he saw in the window.|0 8 | LJSpeech-1.1/wavs/LJ026-0068.wav|Energy enters the plant, to a small extent,|0 9 | LJSpeech-1.1/wavs/LJ004-0096.wav|the fatal consequences whereof might be prevented if the justices of the peace were duly authorized|0 10 | LJSpeech-1.1/wavs/LJ018-0239.wav|His disappearance gave color and substance to evil reports already in circulation that the will and conveyance above referred to|0 11 | LJSpeech-1.1/wavs/LJ036-0103.wav|The police asked him whether he could pick out his passenger from the lineup.|0 12 | LJSpeech-1.1/wavs/LJ017-0131.wav|even when the high sheriff had told him there was no possibility of a reprieve, and within a few hours of execution.|0 13 | LJSpeech-1.1/wavs/LJ046-0184.wav|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes.|0 14 | LJSpeech-1.1/wavs/LJ014-0263.wav|When other pleasures palled he took a theatre, and posed as a munificent patron of the dramatic art.|0 15 | LJSpeech-1.1/wavs/LJ012-0235.wav|While they were in a state of insensibility the murder was committed.|0 16 | LJSpeech-1.1/wavs/LJ036-0077.wav|Roger D. Craig, a deputy sheriff of Dallas County,|0 17 | LJSpeech-1.1/wavs/LJ013-0164.wav|who came from his room ready dressed, a suspicious circumstance, as he was always late in the morning.|0 18 | LJSpeech-1.1/wavs/LJ031-0202.wav|Mrs. 
Kennedy chose the hospital in Bethesda for the autopsy because the President had served in the Navy.|0 19 | LJSpeech-1.1/wavs/LJ021-0145.wav|From those willing to join in establishing this hoped-for period of peace,|0 20 | LJSpeech-1.1/wavs/LJ021-0066.wav|together with a great increase in the payrolls, there has come a substantial rise in the total of industrial profits|0 21 | LJSpeech-1.1/wavs/LJ009-0238.wav|After this the sheriffs sent for another rope, but the spectators interfered, and the man was carried back to jail.|0 22 | LJSpeech-1.1/wavs/LJ035-0019.wav|drove to the northwest corner of Elm and Houston, and parked approximately ten feet from the traffic signal.|0 23 | LJSpeech-1.1/wavs/LJ017-0044.wav|and the deepest anxiety was felt that the crime, if crime there had been, should be brought home to its perpetrator.|0 24 | LJSpeech-1.1/wavs/LJ016-0020.wav|He never reached the cistern, but fell back into the yard, injuring his legs severely.|0 25 | LJSpeech-1.1/wavs/LJ008-0294.wav|nearly indefinitely deferred.|0 26 | LJSpeech-1.1/wavs/LJ012-0035.wav|the number and names on watches, were carefully removed or obliterated after the goods passed out of his hands.|0 27 | LJSpeech-1.1/wavs/LJ016-0179.wav|contracted with sheriffs and conveners to work by the job.|0 28 | LJSpeech-1.1/wavs/LJ016-0138.wav|at a distance from the prison.|0 29 | LJSpeech-1.1/wavs/LJ027-0052.wav|These principles of homology are essential to a correct interpretation of the facts of morphology.|0 30 | LJSpeech-1.1/wavs/LJ014-0010.wav|yet he could not overcome the strange fascination it had for him, and remained by the side of the corpse till the stretcher came.|0 31 | LJSpeech-1.1/wavs/LJ033-0047.wav|I noticed when I went out that the light was on, end quote,|0 32 | LJSpeech-1.1/wavs/LJ040-0027.wav|He was never satisfied with anything.|0 33 | LJSpeech-1.1/wavs/LJ048-0228.wav|and others who were present say that no agent was inebriated or acted improperly.|0 34 | LJSpeech-1.1/wavs/LJ003-0111.wav|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity.|0 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Aaron (Yinghao) Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AuxiliaryASR 2 | This repo contains the training code for Phoneme-level ASR for Voice Conversion (VC) and TTS (Text-Mel Alignment) used in [StarGANv2-VC](https://github.com/yl4579/StarGANv2-VC) and [StyleTTS](https://github.com/yl4579/StyleTTS). 3 | 4 | ## Pre-requisites 5 | 1. Python >= 3.7 6 | 2. Clone this repository: 7 | ```bash 8 | git clone https://github.com/yl4579/AuxiliaryASR.git 9 | cd AuxiliaryASR 10 | ``` 11 | 3. Install Python requirements: 12 | ```bash 13 | pip install SoundFile torchaudio torch jiwer pyyaml click matplotlib g2p_en librosa 14 | ``` 15 | 4. Prepare your own dataset and put the `train_list.txt` and `val_list.txt` in the `Data` folder (see the Training section for more details). 16 | 17 | ## Training 18 | ```bash 19 | python train.py --config_path ./Configs/config.yml 20 | ``` 21 | Please specify the training and validation data in the `config.yml` file. The data list format needs to be `filename.wav|label|speaker_number`; see [train_list.txt](https://github.com/yl4579/AuxiliaryASR/blob/main/Data/train_list.txt) as an example (a subset of LJSpeech). Note that `speaker_number` can just be `0` for ASR, but it is useful to set a meaningful number for TTS training (if you need to use this repo for StyleTTS). 22 | 23 | Checkpoints and Tensorboard logs will be saved in `log_dir`. To speed up training, you may want to make `batch_size` as large as your GPU RAM allows. However, please note that `batch_size = 64` takes around 10 GB of GPU RAM. 24 | 25 | ### Languages 26 | This repo is set up for English with the [g2p_en](https://github.com/Kyubyong/g2p) package, but it can be trained on other languages. If you would like to train on datasets in other languages, you will need to modify the [meldataset.py](https://github.com/yl4579/AuxiliaryASR/blob/main/meldataset.py#L86-L93) file (L86-93) to use your own phonemizer. You also need to replace the vocabulary file ([word_index_dict.txt](https://github.com/yl4579/AuxiliaryASR/blob/main/word_index_dict.txt)) and update `n_token` in `config.yml` to reflect the number of tokens. A recommended phonemizer for other languages is [phonemizer](https://github.com/bootphon/phonemizer); a sketch of such a replacement is given at the end of this README. 27 | 28 | ## References 29 | - [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2) 30 | - [kan-bayashi/ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) 31 | 32 | ## Acknowledgement 33 | The author would like to thank [@tosaka-m](https://github.com/tosaka-m) for his great repository and valuable discussions.
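### Example: plugging in `phonemizer`
A minimal sketch of what the `g2p_en` call in `meldataset.py` (L86-93) could be replaced with for a non-English dataset. It assumes the `phonemizer` package and an `espeak-ng` installation; the function name, the target language (`es`), and the `-` phone separator are illustrative choices rather than part of this repo.
```python
from phonemizer import phonemize
from phonemizer.separator import Separator

def phonemize_text(text, language='es'):
    """Return a list of phoneme tokens with ' ' marking word boundaries,
    mirroring the token list that g2p_en produces for English."""
    phonemized = phonemize(text,
                           language=language,
                           backend='espeak',  # requires espeak-ng on the system
                           separator=Separator(phone='-', word=' '),
                           strip=True)
    tokens = []
    for word in phonemized.split(' '):
        tokens.extend(p for p in word.split('-') if p)
        tokens.append(' ')   # word boundary, i.e. the blank/silence symbol
    return tokens[:-1]       # drop the trailing boundary
```
Every symbol this function can emit must appear in `word_index_dict.txt`, and `n_token` in `config.yml` must equal the number of entries in that file (80 for the English dictionary shipped here).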
34 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from typing import Optional, Any 5 | from torch import Tensor 6 | import torch.nn.functional as F 7 | import torchaudio 8 | import torchaudio.functional as audio_F 9 | 10 | import random 11 | random.seed(0) 12 | 13 | 14 | def _get_activation_fn(activ): 15 | if activ == 'relu': 16 | return nn.ReLU() 17 | elif activ == 'lrelu': 18 | return nn.LeakyReLU(0.2) 19 | elif activ == 'swish': 20 | return lambda x: x*torch.sigmoid(x) 21 | else: 22 | raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ) 23 | 24 | class LinearNorm(torch.nn.Module): 25 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 26 | super(LinearNorm, self).__init__() 27 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 28 | 29 | torch.nn.init.xavier_uniform_( 30 | self.linear_layer.weight, 31 | gain=torch.nn.init.calculate_gain(w_init_gain)) 32 | 33 | def forward(self, x): 34 | return self.linear_layer(x) 35 | 36 | 37 | class ConvNorm(torch.nn.Module): 38 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 39 | padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): 40 | super(ConvNorm, self).__init__() 41 | if padding is None: 42 | assert(kernel_size % 2 == 1) 43 | padding = int(dilation * (kernel_size - 1) / 2) 44 | 45 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 46 | kernel_size=kernel_size, stride=stride, 47 | padding=padding, dilation=dilation, 48 | bias=bias) 49 | 50 | torch.nn.init.xavier_uniform_( 51 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) 52 | 53 | def forward(self, signal): 54 | conv_signal = self.conv(signal) 55 | return conv_signal 56 | 57 | class CausualConv(nn.Module): 58 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None): 59 | super(CausualConv, self).__init__() 60 | if padding is None: 61 | assert(kernel_size % 2 == 1) 62 | padding = int(dilation * (kernel_size - 1) / 2) * 2 63 | else: 64 | self.padding = padding * 2 65 | self.conv = nn.Conv1d(in_channels, out_channels, 66 | kernel_size=kernel_size, stride=stride, 67 | padding=self.padding, 68 | dilation=dilation, 69 | bias=bias) 70 | 71 | torch.nn.init.xavier_uniform_( 72 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) 73 | 74 | def forward(self, x): 75 | x = self.conv(x) 76 | x = x[:, :, :-self.padding] 77 | return x 78 | 79 | class CausualBlock(nn.Module): 80 | def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'): 81 | super(CausualBlock, self).__init__() 82 | self.blocks = nn.ModuleList([ 83 | self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p) 84 | for i in range(n_conv)]) 85 | 86 | def forward(self, x): 87 | for block in self.blocks: 88 | res = x 89 | x = block(x) 90 | x += res 91 | return x 92 | 93 | def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2): 94 | layers = [ 95 | CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), 96 | _get_activation_fn(activ), 97 | nn.BatchNorm1d(hidden_dim), 98 | nn.Dropout(p=dropout_p), 99 | CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), 100 | _get_activation_fn(activ), 101 | 
nn.Dropout(p=dropout_p) 102 | ] 103 | return nn.Sequential(*layers) 104 | 105 | class ConvBlock(nn.Module): 106 | def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'): 107 | super().__init__() 108 | self._n_groups = 8 109 | self.blocks = nn.ModuleList([ 110 | self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p) 111 | for i in range(n_conv)]) 112 | 113 | 114 | def forward(self, x): 115 | for block in self.blocks: 116 | res = x 117 | x = block(x) 118 | x += res 119 | return x 120 | 121 | def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2): 122 | layers = [ 123 | ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), 124 | _get_activation_fn(activ), 125 | nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim), 126 | nn.Dropout(p=dropout_p), 127 | ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), 128 | _get_activation_fn(activ), 129 | nn.Dropout(p=dropout_p) 130 | ] 131 | return nn.Sequential(*layers) 132 | 133 | class LocationLayer(nn.Module): 134 | def __init__(self, attention_n_filters, attention_kernel_size, 135 | attention_dim): 136 | super(LocationLayer, self).__init__() 137 | padding = int((attention_kernel_size - 1) / 2) 138 | self.location_conv = ConvNorm(2, attention_n_filters, 139 | kernel_size=attention_kernel_size, 140 | padding=padding, bias=False, stride=1, 141 | dilation=1) 142 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 143 | bias=False, w_init_gain='tanh') 144 | 145 | def forward(self, attention_weights_cat): 146 | processed_attention = self.location_conv(attention_weights_cat) 147 | processed_attention = processed_attention.transpose(1, 2) 148 | processed_attention = self.location_dense(processed_attention) 149 | return processed_attention 150 | 151 | 152 | class Attention(nn.Module): 153 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 154 | attention_location_n_filters, attention_location_kernel_size): 155 | super(Attention, self).__init__() 156 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 157 | bias=False, w_init_gain='tanh') 158 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 159 | w_init_gain='tanh') 160 | self.v = LinearNorm(attention_dim, 1, bias=False) 161 | self.location_layer = LocationLayer(attention_location_n_filters, 162 | attention_location_kernel_size, 163 | attention_dim) 164 | self.score_mask_value = -float("inf") 165 | 166 | def get_alignment_energies(self, query, processed_memory, 167 | attention_weights_cat): 168 | """ 169 | PARAMS 170 | ------ 171 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 172 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 173 | attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time) 174 | RETURNS 175 | ------- 176 | alignment (batch, max_time) 177 | """ 178 | 179 | processed_query = self.query_layer(query.unsqueeze(1)) 180 | processed_attention_weights = self.location_layer(attention_weights_cat) 181 | energies = self.v(torch.tanh( 182 | processed_query + processed_attention_weights + processed_memory)) 183 | 184 | energies = energies.squeeze(-1) 185 | return energies 186 | 187 | def forward(self, attention_hidden_state, memory, processed_memory, 188 | attention_weights_cat, mask): 189 | """ 190 | PARAMS 191 | ------ 192 | attention_hidden_state: attention rnn last output 193 | memory: encoder outputs 194 | processed_memory: processed encoder outputs 195 | attention_weights_cat: previous and cummulative attention weights 196 | mask: binary mask for padded data 197 | """ 198 | alignment = self.get_alignment_energies( 199 | attention_hidden_state, processed_memory, attention_weights_cat) 200 | 201 | if mask is not None: 202 | alignment.data.masked_fill_(mask, self.score_mask_value) 203 | 204 | attention_weights = F.softmax(alignment, dim=1) 205 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 206 | attention_context = attention_context.squeeze(1) 207 | 208 | return attention_context, attention_weights 209 | 210 | 211 | class ForwardAttentionV2(nn.Module): 212 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 213 | attention_location_n_filters, attention_location_kernel_size): 214 | super(ForwardAttentionV2, self).__init__() 215 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 216 | bias=False, w_init_gain='tanh') 217 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 218 | w_init_gain='tanh') 219 | self.v = LinearNorm(attention_dim, 1, bias=False) 220 | self.location_layer = LocationLayer(attention_location_n_filters, 221 | attention_location_kernel_size, 222 | attention_dim) 223 | self.score_mask_value = -float(1e20) 224 | 225 | def get_alignment_energies(self, query, processed_memory, 226 | attention_weights_cat): 227 | """ 228 | PARAMS 229 | ------ 230 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 231 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 232 | attention_weights_cat: prev. 
and cumulative att weights (B, 2, max_time) 233 | RETURNS 234 | ------- 235 | alignment (batch, max_time) 236 | """ 237 | 238 | processed_query = self.query_layer(query.unsqueeze(1)) 239 | processed_attention_weights = self.location_layer(attention_weights_cat) 240 | energies = self.v(torch.tanh( 241 | processed_query + processed_attention_weights + processed_memory)) 242 | 243 | energies = energies.squeeze(-1) 244 | return energies 245 | 246 | def forward(self, attention_hidden_state, memory, processed_memory, 247 | attention_weights_cat, mask, log_alpha): 248 | """ 249 | PARAMS 250 | ------ 251 | attention_hidden_state: attention rnn last output 252 | memory: encoder outputs 253 | processed_memory: processed encoder outputs 254 | attention_weights_cat: previous and cummulative attention weights 255 | mask: binary mask for padded data 256 | """ 257 | log_energy = self.get_alignment_energies( 258 | attention_hidden_state, processed_memory, attention_weights_cat) 259 | 260 | #log_energy = 261 | 262 | if mask is not None: 263 | log_energy.data.masked_fill_(mask, self.score_mask_value) 264 | 265 | #attention_weights = F.softmax(alignment, dim=1) 266 | 267 | #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] 268 | #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] 269 | 270 | #log_total_score = log_alpha + content_score 271 | 272 | #previous_attention_weights = attention_weights_cat[:,0,:] 273 | 274 | log_alpha_shift_padded = [] 275 | max_time = log_energy.size(1) 276 | for sft in range(2): 277 | shifted = log_alpha[:,:max_time-sft] 278 | shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value) 279 | log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) 280 | 281 | biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2) 282 | 283 | log_alpha_new = biased + log_energy 284 | 285 | attention_weights = F.softmax(log_alpha_new, dim=1) 286 | 287 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 288 | attention_context = attention_context.squeeze(1) 289 | 290 | return attention_context, attention_weights, log_alpha_new 291 | 292 | 293 | class PhaseShuffle2d(nn.Module): 294 | def __init__(self, n=2): 295 | super(PhaseShuffle2d, self).__init__() 296 | self.n = n 297 | self.random = random.Random(1) 298 | 299 | def forward(self, x, move=None): 300 | # x.size = (B, C, M, L) 301 | if move is None: 302 | move = self.random.randint(-self.n, self.n) 303 | 304 | if move == 0: 305 | return x 306 | else: 307 | left = x[:, :, :, :move] 308 | right = x[:, :, :, move:] 309 | shuffled = torch.cat([right, left], dim=3) 310 | return shuffled 311 | 312 | class PhaseShuffle1d(nn.Module): 313 | def __init__(self, n=2): 314 | super(PhaseShuffle1d, self).__init__() 315 | self.n = n 316 | self.random = random.Random(1) 317 | 318 | def forward(self, x, move=None): 319 | # x.size = (B, C, M, L) 320 | if move is None: 321 | move = self.random.randint(-self.n, self.n) 322 | 323 | if move == 0: 324 | return x 325 | else: 326 | left = x[:, :, :move] 327 | right = x[:, :, move:] 328 | shuffled = torch.cat([right, left], dim=2) 329 | 330 | return shuffled 331 | 332 | class MFCC(nn.Module): 333 | def __init__(self, n_mfcc=40, n_mels=80): 334 | super(MFCC, self).__init__() 335 | self.n_mfcc = n_mfcc 336 | self.n_mels = n_mels 337 | self.norm = 'ortho' 338 | dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm) 339 | self.register_buffer('dct_mat', dct_mat) 340 | 341 | def forward(self, mel_specgram): 342 | if 
len(mel_specgram.shape) == 2: 343 | mel_specgram = mel_specgram.unsqueeze(0) 344 | unsqueezed = True 345 | else: 346 | unsqueezed = False 347 | # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) 348 | # -> (channel, time, n_mfcc).tranpose(...) 349 | mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) 350 | 351 | # unpack batch 352 | if unsqueezed: 353 | mfcc = mfcc.squeeze(0) 354 | return mfcc 355 | -------------------------------------------------------------------------------- /meldataset.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | 3 | import os 4 | import os.path as osp 5 | import time 6 | import random 7 | import numpy as np 8 | import random 9 | import soundfile as sf 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | import torchaudio 15 | from torch.utils.data import DataLoader 16 | 17 | from g2p_en import G2p 18 | 19 | import logging 20 | logger = logging.getLogger(__name__) 21 | logger.setLevel(logging.DEBUG) 22 | from text_utils import TextCleaner 23 | np.random.seed(1) 24 | random.seed(1) 25 | DEFAULT_DICT_PATH = osp.join(osp.dirname(__file__), 'word_index_dict.txt') 26 | SPECT_PARAMS = { 27 | "n_fft": 2048, 28 | "win_length": 1200, 29 | "hop_length": 300 30 | } 31 | MEL_PARAMS = { 32 | "n_mels": 80, 33 | "n_fft": 2048, 34 | "win_length": 1200, 35 | "hop_length": 300 36 | } 37 | 38 | class MelDataset(torch.utils.data.Dataset): 39 | def __init__(self, 40 | data_list, 41 | dict_path=DEFAULT_DICT_PATH, 42 | sr=24000 43 | ): 44 | 45 | spect_params = SPECT_PARAMS 46 | mel_params = MEL_PARAMS 47 | 48 | _data_list = [l[:-1].split('|') for l in data_list] 49 | self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list] 50 | self.text_cleaner = TextCleaner(dict_path) 51 | self.sr = sr 52 | 53 | self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS) 54 | self.mean, self.std = -4, 4 55 | 56 | self.g2p = G2p() 57 | 58 | def __len__(self): 59 | return len(self.data_list) 60 | 61 | def __getitem__(self, idx): 62 | data = self.data_list[idx] 63 | wave, text_tensor, speaker_id = self._load_tensor(data) 64 | wave_tensor = torch.from_numpy(wave).float() 65 | mel_tensor = self.to_melspec(wave_tensor) 66 | 67 | if (text_tensor.size(0)+1) >= (mel_tensor.size(1) // 3): 68 | mel_tensor = F.interpolate( 69 | mel_tensor.unsqueeze(0), size=(text_tensor.size(0)+1)*3, align_corners=False, 70 | mode='linear').squeeze(0) 71 | 72 | acoustic_feature = (torch.log(1e-5 + mel_tensor) - self.mean)/self.std 73 | 74 | length_feature = acoustic_feature.size(1) 75 | acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)] 76 | 77 | return wave_tensor, acoustic_feature, text_tensor, data[0] 78 | 79 | def _load_tensor(self, data): 80 | wave_path, text, speaker_id = data 81 | speaker_id = int(speaker_id) 82 | wave, sr = sf.read(wave_path) 83 | 84 | # phonemize the text 85 | ps = self.g2p(text.replace('-', ' ')) 86 | if "'" in ps: 87 | ps.remove("'") 88 | text = self.text_cleaner(ps) 89 | blank_index = self.text_cleaner.word_index_dictionary[" "] 90 | text.insert(0, blank_index) # add a blank at the beginning (silence) 91 | text.append(blank_index) # add a blank at the end (silence) 92 | 93 | text = torch.LongTensor(text) 94 | 95 | return wave, text, speaker_id 96 | 97 | 98 | 99 | 100 | class Collater(object): 101 | """ 102 | Args: 103 | return_wave (bool): if true, will return the wave data along with spectrogram. 
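    Returns:
        texts, input_lengths, mels, output_lengths (plus paths and waves when
        return_wave is True), zero-padded to the longest item in the batch and
        sorted by decreasing mel length.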
104 | """ 105 | 106 | def __init__(self, return_wave=False): 107 | self.text_pad_index = 0 108 | self.return_wave = return_wave 109 | 110 | def __call__(self, batch): 111 | batch_size = len(batch) 112 | 113 | # sort by mel length 114 | lengths = [b[1].shape[1] for b in batch] 115 | batch_indexes = np.argsort(lengths)[::-1] 116 | batch = [batch[bid] for bid in batch_indexes] 117 | 118 | nmels = batch[0][1].size(0) 119 | max_mel_length = max([b[1].shape[1] for b in batch]) 120 | max_text_length = max([b[2].shape[0] for b in batch]) 121 | 122 | mels = torch.zeros((batch_size, nmels, max_mel_length)).float() 123 | texts = torch.zeros((batch_size, max_text_length)).long() 124 | input_lengths = torch.zeros(batch_size).long() 125 | output_lengths = torch.zeros(batch_size).long() 126 | paths = ['' for _ in range(batch_size)] 127 | for bid, (_, mel, text, path) in enumerate(batch): 128 | mel_size = mel.size(1) 129 | text_size = text.size(0) 130 | mels[bid, :, :mel_size] = mel 131 | texts[bid, :text_size] = text 132 | input_lengths[bid] = text_size 133 | output_lengths[bid] = mel_size 134 | paths[bid] = path 135 | assert(text_size < (mel_size//2)) 136 | 137 | if self.return_wave: 138 | waves = [b[0] for b in batch] 139 | return texts, input_lengths, mels, output_lengths, paths, waves 140 | 141 | return texts, input_lengths, mels, output_lengths 142 | 143 | 144 | 145 | def build_dataloader(path_list, 146 | validation=False, 147 | batch_size=4, 148 | num_workers=1, 149 | device='cpu', 150 | collate_config={}, 151 | dataset_config={}): 152 | 153 | dataset = MelDataset(path_list, **dataset_config) 154 | collate_fn = Collater(**collate_config) 155 | data_loader = DataLoader(dataset, 156 | batch_size=batch_size, 157 | shuffle=(not validation), 158 | num_workers=num_workers, 159 | drop_last=(not validation), 160 | collate_fn=collate_fn, 161 | pin_memory=(device != 'cpu')) 162 | 163 | return data_loader 164 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import TransformerEncoder 5 | import torch.nn.functional as F 6 | from layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock 7 | 8 | def build_model(model_params={}, model_type='asr'): 9 | model = ASRCNN(**model_params) 10 | return model 11 | 12 | 13 | class ASRCNN(nn.Module): 14 | def __init__(self, 15 | input_dim=80, 16 | hidden_dim=256, 17 | n_token=35, 18 | n_layers=6, 19 | token_embedding_dim=256, 20 | 21 | ): 22 | super().__init__() 23 | self.n_token = n_token 24 | self.n_down = 1 25 | self.to_mfcc = MFCC() 26 | self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2) 27 | self.cnns = nn.Sequential( 28 | *[nn.Sequential( 29 | ConvBlock(hidden_dim), 30 | nn.GroupNorm(num_groups=1, num_channels=hidden_dim) 31 | ) for n in range(n_layers)]) 32 | self.projection = ConvNorm(hidden_dim, hidden_dim // 2) 33 | self.ctc_linear = nn.Sequential( 34 | LinearNorm(hidden_dim//2, hidden_dim), 35 | nn.ReLU(), 36 | LinearNorm(hidden_dim, n_token)) 37 | self.asr_s2s = ASRS2S( 38 | embedding_dim=token_embedding_dim, 39 | hidden_dim=hidden_dim//2, 40 | n_token=n_token) 41 | 42 | def forward(self, x, src_key_padding_mask=None, text_input=None): 43 | x = self.to_mfcc(x) 44 | x = self.init_cnn(x) 45 | x = self.cnns(x) 46 | 47 | x = self.projection(x) 48 | x = x.transpose(1, 2) 49 | ctc_logit = self.ctc_linear(x) 50 | if text_input is 
not None: 51 | _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input) 52 | return ctc_logit, s2s_logit, s2s_attn 53 | else: 54 | return ctc_logit 55 | 56 | def get_feature(self, x): 57 | x = self.to_mfcc(x) 58 | x = self.init_cnn(x) 59 | x = self.cnns(x) 60 | x = self.instance_norm(x) 61 | x = self.projection(x) 62 | return x 63 | 64 | def length_to_mask(self, lengths): 65 | mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) 66 | mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device) 67 | return mask 68 | 69 | def get_future_mask(self, out_length, unmask_future_steps=0): 70 | """ 71 | Args: 72 | out_length (int): returned mask shape is (out_length, out_length). 73 | unmask_futre_steps (int): unmasking future step size. 74 | Return: 75 | mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False 76 | """ 77 | index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1) 78 | mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps) 79 | return mask 80 | 81 | class ASRS2S(nn.Module): 82 | def __init__(self, 83 | embedding_dim=256, 84 | hidden_dim=512, 85 | n_location_filters=32, 86 | location_kernel_size=63, 87 | n_token=40): 88 | super(ASRS2S, self).__init__() 89 | self.embedding = nn.Embedding(n_token, embedding_dim) 90 | val_range = math.sqrt(6 / hidden_dim) 91 | self.embedding.weight.data.uniform_(-val_range, val_range) 92 | 93 | self.decoder_rnn_dim = hidden_dim 94 | self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token) 95 | self.attention_layer = Attention( 96 | self.decoder_rnn_dim, 97 | hidden_dim, 98 | hidden_dim, 99 | n_location_filters, 100 | location_kernel_size 101 | ) 102 | self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim) 103 | self.project_to_hidden = nn.Sequential( 104 | LinearNorm(self.decoder_rnn_dim * 2, hidden_dim), 105 | nn.Tanh()) 106 | self.sos = 1 107 | self.eos = 2 108 | 109 | def initialize_decoder_states(self, memory, mask): 110 | """ 111 | moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim) 112 | """ 113 | B, L, H = memory.shape 114 | self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory) 115 | self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory) 116 | self.attention_weights = torch.zeros((B, L)).type_as(memory) 117 | self.attention_weights_cum = torch.zeros((B, L)).type_as(memory) 118 | self.attention_context = torch.zeros((B, H)).type_as(memory) 119 | self.memory = memory 120 | self.processed_memory = self.attention_layer.memory_layer(memory) 121 | self.mask = mask 122 | self.unk_index = 3 123 | self.random_mask = 0.1 124 | 125 | def forward(self, memory, memory_mask, text_input): 126 | """ 127 | moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim) 128 | moemory_mask.shape = (B, L, ) 129 | texts_input.shape = (B, T) 130 | """ 131 | self.initialize_decoder_states(memory, memory_mask) 132 | # text random mask 133 | random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device) 134 | _text_input = text_input.clone() 135 | _text_input.masked_fill_(random_mask, self.unk_index) 136 | decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel] 137 | start_embedding = self.embedding( 138 | torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device)) 139 | decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0) 140 | 
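        # Teacher-forced decoding: the (randomly masked) ground-truth token embeddings
        # are fed to the decoder one step at a time, collecting the hidden state,
        # output logits and attention weights for every step.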
141 | hidden_outputs, logit_outputs, alignments = [], [], [] 142 | while len(hidden_outputs) < decoder_inputs.size(0): 143 | 144 | decoder_input = decoder_inputs[len(hidden_outputs)] 145 | hidden, logit, attention_weights = self.decode(decoder_input) 146 | hidden_outputs += [hidden] 147 | logit_outputs += [logit] 148 | alignments += [attention_weights] 149 | 150 | hidden_outputs, logit_outputs, alignments = \ 151 | self.parse_decoder_outputs( 152 | hidden_outputs, logit_outputs, alignments) 153 | 154 | return hidden_outputs, logit_outputs, alignments 155 | 156 | 157 | def decode(self, decoder_input): 158 | 159 | cell_input = torch.cat((decoder_input, self.attention_context), -1) 160 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 161 | cell_input, 162 | (self.decoder_hidden, self.decoder_cell)) 163 | 164 | attention_weights_cat = torch.cat( 165 | (self.attention_weights.unsqueeze(1), 166 | self.attention_weights_cum.unsqueeze(1)),dim=1) 167 | 168 | self.attention_context, self.attention_weights = self.attention_layer( 169 | self.decoder_hidden, 170 | self.memory, 171 | self.processed_memory, 172 | attention_weights_cat, 173 | self.mask) 174 | 175 | self.attention_weights_cum += self.attention_weights 176 | 177 | hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1) 178 | hidden = self.project_to_hidden(hidden_and_context) 179 | 180 | # dropout to increasing g 181 | logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training)) 182 | 183 | return hidden, logit, self.attention_weights 184 | 185 | def parse_decoder_outputs(self, hidden, logit, alignments): 186 | 187 | # -> [B, T_out + 1, max_time] 188 | alignments = torch.stack(alignments).transpose(0,1) 189 | # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols] 190 | logit = torch.stack(logit).transpose(0, 1).contiguous() 191 | hidden = torch.stack(hidden).transpose(0, 1).contiguous() 192 | 193 | return hidden, logit, alignments 194 | -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os, sys 3 | import os.path as osp 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.optim import Optimizer 8 | from functools import reduce 9 | from torch.optim import AdamW 10 | 11 | class MultiOptimizer: 12 | def __init__(self, optimizers={}, schedulers={}): 13 | self.optimizers = optimizers 14 | self.schedulers = schedulers 15 | self.keys = list(optimizers.keys()) 16 | self.param_groups = reduce(lambda x,y: x+y, [v.param_groups for v in self.optimizers.values()]) 17 | 18 | def state_dict(self): 19 | state_dicts = [(key, self.optimizers[key].state_dict())\ 20 | for key in self.keys] 21 | return state_dicts 22 | 23 | def load_state_dict(self, state_dict): 24 | for key, val in state_dict: 25 | try: 26 | self.optimizers[key].load_state_dict(val) 27 | except: 28 | print("Unloaded %s" % key) 29 | 30 | 31 | def step(self, key=None): 32 | if key is not None: 33 | self.optimizers[key].step() 34 | else: 35 | _ = [self.optimizers[key].step() for key in self.keys] 36 | 37 | def zero_grad(self, key=None): 38 | if key is not None: 39 | self.optimizers[key].zero_grad() 40 | else: 41 | _ = [self.optimizers[key].zero_grad() for key in self.keys] 42 | 43 | def scheduler(self, *args, key=None): 44 | if key is not None: 45 | self.schedulers[key].step(*args) 46 | else: 47 | _ = [self.schedulers[key].step(*args) for key in self.keys] 48 | 49 | 
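# build_optimizer wires a single AdamW optimizer to a OneCycleLR schedule.
# train.py calls it with a dict holding 'params' (the model parameters),
# 'optimizer_params' (lr, weight_decay) and 'scheduler_params'
# (max_lr, epochs, steps_per_epoch, pct_start); see _define_optimizer below.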
50 | def build_optimizer(parameters): 51 | optimizer, scheduler = _define_optimizer(parameters) 52 | return optimizer, scheduler 53 | 54 | def _define_optimizer(params): 55 | optimizer_params = params['optimizer_params'] 56 | sch_params = params['scheduler_params'] 57 | optimizer = AdamW( 58 | params['params'], 59 | lr=optimizer_params.get('lr', 1e-4), 60 | weight_decay=optimizer_params.get('weight_decay', 5e-4), 61 | betas=(0.9, 0.98), 62 | eps=1e-9) 63 | scheduler = _define_scheduler(optimizer, sch_params) 64 | return optimizer, scheduler 65 | 66 | def _define_scheduler(optimizer, params): 67 | print(params) 68 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 69 | optimizer, 70 | max_lr=params.get('max_lr', 5e-4), 71 | epochs=params.get('epochs', 200), 72 | steps_per_epoch=params.get('steps_per_epoch', 1000), 73 | pct_start=params.get('pct_start', 0.0), 74 | final_div_factor=5) 75 | 76 | return scheduler 77 | 78 | def build_multi_optimizer(parameters_dict, scheduler_params): 79 | optim = dict([(key, AdamW(params, lr=1e-4, weight_decay=1e-6, betas=(0.9, 0.98), eps=1e-9)) 80 | for key, params in parameters_dict.items()]) 81 | 82 | schedulers = dict([(key, _define_scheduler(opt, scheduler_params)) \ 83 | for key, opt in optim.items()]) 84 | 85 | multi_optim = MultiOptimizer(optim, schedulers) 86 | return multi_optim 87 | -------------------------------------------------------------------------------- /text_utils.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import os.path as osp 4 | import pandas as pd 5 | 6 | DEFAULT_DICT_PATH = osp.join('word_index_dict.txt') 7 | class TextCleaner: 8 | def __init__(self, word_index_dict_path=DEFAULT_DICT_PATH): 9 | self.word_index_dictionary = self.load_dictionary(word_index_dict_path) 10 | 11 | def __call__(self, text): 12 | indexes = [] 13 | for char in text: 14 | try: 15 | indexes.append(self.word_index_dictionary[char]) 16 | except KeyError: 17 | print(char) 18 | return indexes 19 | 20 | def load_dictionary(self, path): 21 | csv = pd.read_csv(path, header=None).values 22 | word_index_dict = {word: index for word, index in csv} 23 | return word_index_dict 24 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from meldataset import build_dataloader 2 | from optimizers import build_optimizer 3 | from utils import * 4 | from models import build_model 5 | from trainer import Trainer 6 | 7 | import os 8 | import os.path as osp 9 | import re 10 | import sys 11 | import yaml 12 | import shutil 13 | import numpy as np 14 | import torch 15 | from torch.utils.tensorboard import SummaryWriter 16 | import click 17 | 18 | import logging 19 | from logging import StreamHandler 20 | logger = logging.getLogger(__name__) 21 | logger.setLevel(logging.DEBUG) 22 | handler = StreamHandler() 23 | handler.setLevel(logging.DEBUG) 24 | logger.addHandler(handler) 25 | 26 | torch.backends.cudnn.benchmark = True 27 | 28 | @click.command() 29 | @click.option('-p', '--config_path', default='./Configs/config.yml', type=str) 30 | def main(config_path): 31 | config = yaml.safe_load(open(config_path)) 32 | log_dir = config['log_dir'] 33 | if not osp.exists(log_dir): os.mkdir(log_dir) 34 | shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path))) 35 | 36 | writer = SummaryWriter(log_dir + "/tensorboard") 37 | 38 | # write logs 39 | file_handler = 
logging.FileHandler(osp.join(log_dir, 'train.log')) 40 | file_handler.setLevel(logging.DEBUG) 41 | file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s')) 42 | logger.addHandler(file_handler) 43 | 44 | batch_size = config.get('batch_size', 10) 45 | device = config.get('device', 'cpu') 46 | epochs = config.get('epochs', 1000) 47 | save_freq = config.get('save_freq', 20) 48 | train_path = config.get('train_data', None) 49 | val_path = config.get('val_data', None) 50 | 51 | train_list, val_list = get_data_path_list(train_path, val_path) 52 | train_dataloader = build_dataloader(train_list, 53 | batch_size=batch_size, 54 | num_workers=8, 55 | dataset_config=config.get('dataset_params', {}), 56 | device=device) 57 | 58 | val_dataloader = build_dataloader(val_list, 59 | batch_size=batch_size, 60 | validation=True, 61 | num_workers=2, 62 | device=device, 63 | dataset_config=config.get('dataset_params', {})) 64 | 65 | model = build_model(model_params=config['model_params'] or {}) 66 | 67 | scheduler_params = { 68 | "max_lr": float(config['optimizer_params'].get('lr', 5e-4)), 69 | "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)), 70 | "epochs": epochs, 71 | "steps_per_epoch": len(train_dataloader), 72 | } 73 | 74 | model.to(device) 75 | optimizer, scheduler = build_optimizer( 76 | {"params": model.parameters(), "optimizer_params":{}, "scheduler_params": scheduler_params}) 77 | 78 | blank_index = train_dataloader.dataset.text_cleaner.word_index_dictionary[" "] # get blank index 79 | 80 | criterion = build_criterion(critic_params={ 81 | 'ctc': {'blank': blank_index}, 82 | }) 83 | 84 | trainer = Trainer(model=model, 85 | criterion=criterion, 86 | optimizer=optimizer, 87 | scheduler=scheduler, 88 | device=device, 89 | train_dataloader=train_dataloader, 90 | val_dataloader=val_dataloader, 91 | logger=logger) 92 | 93 | if config.get('pretrained_model', '') != '': 94 | trainer.load_checkpoint(config['pretrained_model'], 95 | load_only_params=config.get('load_only_params', True)) 96 | 97 | for epoch in range(1, epochs+1): 98 | train_results = trainer._train_epoch() 99 | eval_results = trainer._eval_epoch() 100 | results = train_results.copy() 101 | results.update(eval_results) 102 | logger.info('--- epoch %d ---' % epoch) 103 | for key, value in results.items(): 104 | if isinstance(value, float): 105 | logger.info('%-15s: %.4f' % (key, value)) 106 | writer.add_scalar(key, value, epoch) 107 | else: 108 | for v in value: 109 | writer.add_figure('eval_attn', plot_image(v), epoch) 110 | if (epoch % save_freq) == 0: 111 | trainer.save_checkpoint(osp.join(log_dir, 'epoch_%05d.pth' % epoch)) 112 | 113 | return 0 114 | 115 | if __name__=="__main__": 116 | main() -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import os.path as osp 5 | import sys 6 | import time 7 | from collections import defaultdict 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from PIL import Image 13 | from tqdm import tqdm 14 | 15 | from utils import calc_wer 16 | 17 | import logging 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.DEBUG) 20 | 21 | from utils import * 22 | 23 | class Trainer(object): 24 | def __init__(self, 25 | model=None, 26 | criterion=None, 27 | optimizer=None, 28 | scheduler=None, 29 | config={}, 30 | device=torch.device("cpu"), 31 | logger=logger, 
32 | train_dataloader=None, 33 | val_dataloader=None, 34 | initial_steps=0, 35 | initial_epochs=0): 36 | 37 | self.steps = initial_steps 38 | self.epochs = initial_epochs 39 | self.model = model 40 | self.criterion = criterion 41 | self.optimizer = optimizer 42 | self.scheduler = scheduler 43 | self.train_dataloader = train_dataloader 44 | self.val_dataloader = val_dataloader 45 | self.config = config 46 | self.device = device 47 | self.finish_train = False 48 | self.logger = logger 49 | self.fp16_run = False 50 | 51 | def save_checkpoint(self, checkpoint_path): 52 | """Save checkpoint. 53 | Args: 54 | checkpoint_path (str): Checkpoint path to be saved. 55 | """ 56 | state_dict = { 57 | "optimizer": self.optimizer.state_dict(), 58 | "scheduler": self.scheduler.state_dict(), 59 | "steps": self.steps, 60 | "epochs": self.epochs, 61 | } 62 | state_dict["model"] = self.model.state_dict() 63 | 64 | if not os.path.exists(os.path.dirname(checkpoint_path)): 65 | os.makedirs(os.path.dirname(checkpoint_path)) 66 | torch.save(state_dict, checkpoint_path) 67 | 68 | def load_checkpoint(self, checkpoint_path, load_only_params=False): 69 | """Load checkpoint. 70 | 71 | Args: 72 | checkpoint_path (str): Checkpoint path to be loaded. 73 | load_only_params (bool): Whether to load only model parameters. 74 | 75 | """ 76 | state_dict = torch.load(checkpoint_path, map_location="cpu") 77 | self._load(state_dict["model"], self.model) 78 | 79 | if not load_only_params: 80 | self.steps = state_dict["steps"] 81 | self.epochs = state_dict["epochs"] 82 | self.optimizer.load_state_dict(state_dict["optimizer"]) 83 | 84 | # overwrite schedular argument parameters 85 | state_dict["scheduler"].update(**self.config.get("scheduler_params", {})) 86 | self.scheduler.load_state_dict(state_dict["scheduler"]) 87 | 88 | def _load(self, states, model, force_load=True): 89 | model_states = model.state_dict() 90 | for key, val in states.items(): 91 | try: 92 | if key not in model_states: 93 | continue 94 | if isinstance(val, nn.Parameter): 95 | val = val.data 96 | 97 | if val.shape != model_states[key].shape: 98 | self.logger.info("%s does not have same shape" % key) 99 | print(val.shape, model_states[key].shape) 100 | if not force_load: 101 | continue 102 | 103 | min_shape = np.minimum(np.array(val.shape), np.array(model_states[key].shape)) 104 | slices = [slice(0, min_index) for min_index in min_shape] 105 | model_states[key][slices].copy_(val[slices]) 106 | else: 107 | model_states[key].copy_(val) 108 | except: 109 | self.logger.info("not exist :%s" % key) 110 | print("not exist ", key) 111 | 112 | @staticmethod 113 | def get_gradient_norm(model): 114 | total_norm = 0 115 | for p in model.parameters(): 116 | param_norm = p.grad.data.norm(2) 117 | total_norm += param_norm.item() ** 2 118 | 119 | total_norm = np.sqrt(total_norm) 120 | return total_norm 121 | 122 | @staticmethod 123 | def length_to_mask(lengths): 124 | mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) 125 | mask = torch.gt(mask+1, lengths.unsqueeze(1)) 126 | return mask 127 | 128 | def _get_lr(self): 129 | for param_group in self.optimizer.param_groups: 130 | lr = param_group['lr'] 131 | break 132 | return lr 133 | 134 | @staticmethod 135 | def get_image(arrs): 136 | pil_images = [] 137 | height = 0 138 | width = 0 139 | for arr in arrs: 140 | uint_arr = (((arr - arr.min()) / (arr.max() - arr.min())) * 255).astype(np.uint8) 141 | pil_image = Image.fromarray(uint_arr) 142 | pil_images.append(pil_image) 143 | height += 
uint_arr.shape[0] 144 | width = max(width, uint_arr.shape[1]) 145 | 146 | palette = Image.new('L', (width, height)) 147 | curr_heigth = 0 148 | for pil_image in pil_images: 149 | palette.paste(pil_image, (0, curr_heigth)) 150 | curr_heigth += pil_image.size[1] 151 | 152 | return palette 153 | 154 | def run(self, batch): 155 | self.optimizer.zero_grad() 156 | batch = [b.to(self.device) for b in batch] 157 | text_input, text_input_length, mel_input, mel_input_length = batch 158 | mel_input_length = mel_input_length // (2 ** self.model.n_down) 159 | future_mask = self.model.get_future_mask( 160 | mel_input.size(2)//(2**self.model.n_down), unmask_future_steps=0).to(self.device) 161 | mel_mask = self.model.length_to_mask(mel_input_length) 162 | text_mask = self.model.length_to_mask(text_input_length) 163 | ppgs, s2s_pred, s2s_attn = self.model( 164 | mel_input, src_key_padding_mask=mel_mask, text_input=text_input) 165 | 166 | loss_ctc = self.criterion['ctc'](ppgs.log_softmax(dim=2).transpose(0, 1), 167 | text_input, mel_input_length, text_input_length) 168 | 169 | loss_s2s = 0 170 | for _s2s_pred, _text_input, _text_length in zip(s2s_pred, text_input, text_input_length): 171 | loss_s2s += self.criterion['ce'](_s2s_pred[:_text_length], _text_input[:_text_length]) 172 | loss_s2s /= text_input.size(0) 173 | 174 | loss = loss_ctc + loss_s2s 175 | loss.backward() 176 | torch.nn.utils.clip_grad_value_(self.model.parameters(), 5) 177 | self.optimizer.step() 178 | self.scheduler.step() 179 | return {'loss': loss.item(), 180 | 'ctc': loss_ctc.item(), 181 | 's2s': loss_s2s.item()} 182 | 183 | def _train_epoch(self): 184 | train_losses = defaultdict(list) 185 | self.model.train() 186 | for train_steps_per_epoch, batch in enumerate(tqdm(self.train_dataloader, desc="[train]"), 1): 187 | losses = self.run(batch) 188 | for key, value in losses.items(): 189 | train_losses["train/%s" % key].append(value) 190 | 191 | train_losses = {key: np.mean(value) for key, value in train_losses.items()} 192 | train_losses['train/learning_rate'] = self._get_lr() 193 | return train_losses 194 | 195 | @torch.no_grad() 196 | def _eval_epoch(self): 197 | self.model.eval() 198 | eval_losses = defaultdict(list) 199 | eval_images = defaultdict(list) 200 | for eval_steps_per_epoch, batch in enumerate(tqdm(self.val_dataloader, desc="[eval]"), 1): 201 | batch = [b.to(self.device) for b in batch] 202 | text_input, text_input_length, mel_input, mel_input_length = batch 203 | mel_input_length = mel_input_length // (2 ** self.model.n_down) 204 | future_mask = self.model.get_future_mask( 205 | mel_input.size(2)//(2**self.model.n_down), unmask_future_steps=0).to(self.device) 206 | mel_mask = self.model.length_to_mask(mel_input_length) 207 | text_mask = self.model.length_to_mask(text_input_length) 208 | ppgs, s2s_pred, s2s_attn = self.model( 209 | mel_input, src_key_padding_mask=mel_mask, text_input=text_input) 210 | loss_ctc = self.criterion['ctc'](ppgs.log_softmax(dim=2).transpose(0, 1), 211 | text_input, mel_input_length, text_input_length) 212 | loss_s2s = 0 213 | for _s2s_pred, _text_input, _text_length in zip(s2s_pred, text_input, text_input_length): 214 | loss_s2s += self.criterion['ce'](_s2s_pred[:_text_length], _text_input[:_text_length]) 215 | loss_s2s /= text_input.size(0) 216 | loss = loss_ctc + loss_s2s 217 | 218 | eval_losses["eval/ctc"].append(loss_ctc.item()) 219 | eval_losses["eval/s2s"].append(loss_s2s.item()) 220 | eval_losses["eval/loss"].append(loss.item()) 221 | 222 | _, amax_ppgs = torch.max(ppgs, dim=2) 223 | wers = 
[calc_wer(target[:text_length], 224 | pred[:mel_length], 225 | ignore_indexes=list(range(5))) \ 226 | for target, pred, text_length, mel_length in zip( 227 | text_input.cpu(), amax_ppgs.cpu(), text_input_length.cpu(), mel_input_length.cpu())] 228 | eval_losses["eval/wer"].extend(wers) 229 | 230 | _, amax_s2s = torch.max(s2s_pred, dim=2) 231 | acc = [torch.eq(target[:length], pred[:length]).float().mean().item() \ 232 | for target, pred, length in zip(text_input.cpu(), amax_s2s.cpu(), text_input_length.cpu())] 233 | eval_losses["eval/acc"].extend(acc) 234 | 235 | if eval_steps_per_epoch <= 2: 236 | eval_images["eval/image"].append( 237 | self.get_image([s2s_attn[0].cpu().numpy()])) 238 | 239 | eval_losses = {key: np.mean(value) for key, value in eval_losses.items()} 240 | eval_losses.update(eval_images) 241 | return eval_losses -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | from collections import defaultdict 6 | 7 | import matplotlib 8 | import numpy as np 9 | import soundfile as sf 10 | import torch 11 | from torch import nn 12 | import jiwer 13 | 14 | import matplotlib.pylab as plt 15 | 16 | def calc_wer(target, pred, ignore_indexes=[0]): 17 | target_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(target))))) 18 | pred_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(pred))))) 19 | target_str = ' '.join(target_chars) 20 | pred_str = ' '.join(pred_chars) 21 | error = jiwer.wer(target_str, pred_str) 22 | return error 23 | 24 | def drop_duplicated(chars): 25 | ret_chars = [chars[0]] 26 | for prev, curr in zip(chars[:-1], chars[1:]): 27 | if prev != curr: 28 | ret_chars.append(curr) 29 | return ret_chars 30 | 31 | def build_criterion(critic_params={}): 32 | criterion = { 33 | "ce": nn.CrossEntropyLoss(ignore_index=-1), 34 | "ctc": torch.nn.CTCLoss(**critic_params.get('ctc', {})), 35 | } 36 | return criterion 37 | 38 | def get_data_path_list(train_path=None, val_path=None): 39 | if train_path is None: 40 | train_path = "Data/train_list.txt" 41 | if val_path is None: 42 | val_path = "Data/val_list.txt" 43 | 44 | with open(train_path, 'r') as f: 45 | train_list = f.readlines() 46 | with open(val_path, 'r') as f: 47 | val_list = f.readlines() 48 | 49 | return train_list, val_list 50 | 51 | 52 | def plot_image(image): 53 | fig, ax = plt.subplots(figsize=(10, 2)) 54 | im = ax.imshow(image, aspect="auto", origin="lower", 55 | interpolation='none') 56 | 57 | fig.canvas.draw() 58 | plt.close() 59 | 60 | return fig -------------------------------------------------------------------------------- /word_index_dict.txt: -------------------------------------------------------------------------------- 1 | "",0 2 | "",1 3 | "",2 4 | "",3 5 | " ",4 6 | ",",5 7 | ".",6 8 | ":",7 9 | "!",8 10 | "?",9 11 | "AA0",10 12 | "AA1",11 13 | "AA2",12 14 | "AE0",13 15 | "AE1",14 16 | "AE2",15 17 | "AH0",16 18 | "AH1",17 19 | "AH2",18 20 | "AO0",19 21 | "AO1",20 22 | "AO2",21 23 | "AW0",22 24 | "AW1",23 25 | "AW2",24 26 | "AY0",25 27 | "AY1",26 28 | "AY2",27 29 | "B",28 30 | "CH",29 31 | "D",30 32 | "DH",31 33 | "EH0",32 34 | "EH1",33 35 | "EH2",34 36 | "ER0",35 37 | "ER1",36 38 | "ER2",37 39 | "EY0",38 40 | "EY1",39 41 | "EY2",40 42 | "F",41 43 | "G",42 44 | "HH",43 45 | "IH0",44 46 | "IH1",45 47 | "IH2",46 48 | "IY0",47 49 | "IY1",48 50 | "IY2",49 51 
| "JH",50 52 | "K",51 53 | "L",52 54 | "M",53 55 | "N",54 56 | "NG",55 57 | "OW0",56 58 | "OW1",57 59 | "OW2",58 60 | "OY0",59 61 | "OY1",60 62 | "OY2",61 63 | "P",62 64 | "R",63 65 | "S",64 66 | "SH",65 67 | "T",66 68 | "TH",67 69 | "UH0",68 70 | "UH1",69 71 | "UH2",70 72 | "UW",71 73 | "UW0",72 74 | "UW1",73 75 | "UW2",74 76 | "V",75 77 | "W",76 78 | "Y",77 79 | "Z",78 80 | "ZH",79 81 | --------------------------------------------------------------------------------