├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE ├── README.md ├── audio_processing.py ├── data_utils.py ├── distributed.py ├── hparams.py ├── inference.ipynb ├── layers.py ├── logger.py ├── loss_function.py ├── loss_scaler.py ├── model.py ├── multiproc.py ├── plotting_utils.py ├── requirements.txt ├── stft.py ├── text └── __init__.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | *.DS_Store 29 | .DS_Store 30 | .idea/* 31 | MANIFEST 32 | 33 | config_v2.json 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | *.ipynb 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "waveglow"] 2 | path = waveglow 3 | url = https://github.com/NVIDIA/waveglow 4 | branch = master 5 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:nightly-devel-cuda10.0-cudnn7 2 | ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH} 3 | 4 | RUN apt-get update -y 5 | 6 | RUN pip install numpy scipy matplotlib librosa==0.6.0 tensorflow tensorboardX inflect==0.2.5 Unidecode==1.0.22 pillow jupyter 7 | 8 | ADD apex /apex/ 9 | WORKDIR /apex/ 10 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, NVIDIA Corporation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tacotron2 for Korean (taKotron2) 2 | 3 | - Code borrow from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2) 4 | - Modify for Korean TTS System (see [text/\_\_init\_\_.py](https://github.com/sooftware/nvidia-tacotron2/blob/master/text/__init__.py)) 5 | - Normalize with NFKD 6 | - [g2pK](https://github.com/Kyubyong/g2pK) 7 | - Add learning rate scheduler (transformer style) 8 | - Add [Wandb](https://wandb.ai/) monitoring 9 | - Add generate mel-spectrogram and alignment monitoring. 10 | 11 | 12 | -------------------------------------------------------------------------------- /audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | 13 | This is used to estimate modulation effects induced by windowing 14 | observations in short-time fourier transforms. 15 | 16 | Parameters 17 | ---------- 18 | window : string, tuple, number, callable, or list-like 19 | Window specification, as in `get_window` 20 | 21 | n_frames : int > 0 22 | The number of analysis frames 23 | 24 | hop_length : int > 0 25 | The number of samples to advance between frames 26 | 27 | win_length : [optional] 28 | The length of the window function. By default, this matches `n_fft`. 29 | 30 | n_fft : int > 0 31 | The length of each analysis frame. 
32 | 33 | dtype : np.dtype 34 | The data type of the output 35 | 36 | Returns 37 | ------- 38 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 39 | The sum-squared envelope of the window function 40 | """ 41 | if win_length is None: 42 | win_length = n_fft 43 | 44 | n = n_fft + hop_length * (n_frames - 1) 45 | x = np.zeros(n, dtype=dtype) 46 | 47 | # Compute the squared window at the desired length 48 | win_sq = get_window(window, win_length, fftbins=True) 49 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 50 | win_sq = librosa_util.pad_center(win_sq, n_fft) 51 | 52 | # Fill the envelope 53 | for i in range(n_frames): 54 | sample = i * hop_length 55 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | 78 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 79 | """ 80 | PARAMS 81 | ------ 82 | C: compression factor 83 | """ 84 | return torch.log(torch.clamp(x, min=clip_val) * C) 85 | 86 | 87 | def dynamic_range_decompression(x, C=1): 88 | """ 89 | PARAMS 90 | ------ 91 | C: compression factor used to compress 92 | """ 93 | return torch.exp(x) / C 94 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | 6 | import layers 7 | from utils import load_wav_to_torch, load_filepaths_and_text 8 | from text import text_to_sequence 9 | 10 | 11 | class TextMelLoader(torch.utils.data.Dataset): 12 | """ 13 | 1) loads audio,text pairs 14 | 2) normalizes text and converts them to sequences of one-hot vectors 15 | 3) computes mel-spectrograms from audio files. 
16 | """ 17 | def __init__(self, audiopaths_and_text, hparams): 18 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 19 | self.max_wav_value = hparams.max_wav_value 20 | self.sampling_rate = hparams.sampling_rate 21 | self.load_mel_from_disk = hparams.load_mel_from_disk 22 | self.stft = layers.TacotronSTFT( 23 | hparams.filter_length, hparams.hop_length, hparams.win_length, 24 | hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, 25 | hparams.mel_fmax) 26 | random.seed(hparams.seed) 27 | random.shuffle(self.audiopaths_and_text) 28 | 29 | def get_mel_text_pair(self, audiopath_and_text): 30 | # separate filename and text 31 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1] 32 | text = self.get_text(text) 33 | mel = self.get_mel(audiopath) 34 | return (text, mel) 35 | 36 | def get_mel(self, filename): 37 | if not self.load_mel_from_disk: 38 | audio, sampling_rate = load_wav_to_torch(filename) 39 | if sampling_rate != self.stft.sampling_rate: 40 | raise ValueError("{} {} SR doesn't match target {} SR".format( 41 | sampling_rate, self.stft.sampling_rate)) 42 | audio_norm = audio / self.max_wav_value 43 | audio_norm = audio_norm.unsqueeze(0) 44 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 45 | melspec = self.stft.mel_spectrogram(audio_norm) 46 | melspec = torch.squeeze(melspec, 0) 47 | else: 48 | melspec = torch.from_numpy(np.load(filename)) 49 | assert melspec.size(0) == self.stft.n_mel_channels, ( 50 | 'Mel dimension mismatch: given {}, expected {}'.format( 51 | melspec.size(0), self.stft.n_mel_channels)) 52 | 53 | return melspec 54 | 55 | def get_text(self, text): 56 | text_norm = torch.IntTensor(text_to_sequence(text)) 57 | return text_norm 58 | 59 | def __getitem__(self, index): 60 | return self.get_mel_text_pair(self.audiopaths_and_text[index]) 61 | 62 | def __len__(self): 63 | return len(self.audiopaths_and_text) 64 | 65 | 66 | class TextMelCollate(): 67 | """ Zero-pads model inputs and targets based on number of frames per setep 68 | """ 69 | def __init__(self, n_frames_per_step): 70 | self.n_frames_per_step = n_frames_per_step 71 | 72 | def __call__(self, batch): 73 | """Collate's training batch from normalized text and mel-spectrogram 74 | PARAMS 75 | ------ 76 | batch: [text_normalized, mel_normalized] 77 | """ 78 | # Right zero-pad all one-hot text sequences to max input length 79 | input_lengths, ids_sorted_decreasing = torch.sort( 80 | torch.LongTensor([len(x[0]) for x in batch]), 81 | dim=0, descending=True) 82 | max_input_len = input_lengths[0] 83 | 84 | text_padded = torch.LongTensor(len(batch), max_input_len) 85 | text_padded.zero_() 86 | for i in range(len(ids_sorted_decreasing)): 87 | text = batch[ids_sorted_decreasing[i]][0] 88 | text_padded[i, :text.size(0)] = text 89 | 90 | # Right zero-pad mel-spec 91 | num_mels = batch[0][1].size(0) 92 | max_target_len = max([x[1].size(1) for x in batch]) 93 | if max_target_len % self.n_frames_per_step != 0: 94 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step 95 | assert max_target_len % self.n_frames_per_step == 0 96 | 97 | # include mel padded and gate padded 98 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 99 | mel_padded.zero_() 100 | gate_padded = torch.FloatTensor(len(batch), max_target_len) 101 | gate_padded.zero_() 102 | output_lengths = torch.LongTensor(len(batch)) 103 | for i in range(len(ids_sorted_decreasing)): 104 | mel = batch[ids_sorted_decreasing[i]][1] 105 | mel_padded[i, :, 
:mel.size(1)] = mel 106 | gate_padded[i, mel.size(1)-1:] = 1 107 | output_lengths[i] = mel.size(1) 108 | 109 | return text_padded, input_lengths, mel_padded, gate_padded, \ 110 | output_lengths 111 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn.modules import Module 4 | from torch.autograd import Variable 5 | 6 | def _flatten_dense_tensors(tensors): 7 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 8 | same dense type. 9 | Since inputs are dense, the resulting tensor will be a concatenated 1D 10 | buffer. Element-wise operation on this buffer will be equivalent to 11 | operating individually. 12 | Arguments: 13 | tensors (Iterable[Tensor]): dense tensors to flatten. 14 | Returns: 15 | A contiguous 1D buffer containing input tensors. 16 | """ 17 | if len(tensors) == 1: 18 | return tensors[0].contiguous().view(-1) 19 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) 20 | return flat 21 | 22 | def _unflatten_dense_tensors(flat, tensors): 23 | """View a flat buffer using the sizes of tensors. Assume that tensors are of 24 | same dense type, and that flat is given by _flatten_dense_tensors. 25 | Arguments: 26 | flat (Tensor): flattened dense tensors to unflatten. 27 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 28 | unflatten flat. 29 | Returns: 30 | Unflattened dense tensors with sizes same as tensors and values from 31 | flat. 32 | """ 33 | outputs = [] 34 | offset = 0 35 | for tensor in tensors: 36 | numel = tensor.numel() 37 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 38 | offset += numel 39 | return tuple(outputs) 40 | 41 | 42 | ''' 43 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py 44 | launcher included with this example. It assumes that your run is using multiprocess with 1 45 | GPU/process, that the model is on the correct device, and that torch.set_device has been 46 | used to set the device. 47 | 48 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, 49 | and will be allreduced at the finish of the backward pass. 50 | ''' 51 | class DistributedDataParallel(Module): 52 | 53 | def __init__(self, module): 54 | super(DistributedDataParallel, self).__init__() 55 | #fallback for PyTorch 0.3 56 | if not hasattr(dist, '_backend'): 57 | self.warn_on_half = True 58 | else: 59 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 60 | 61 | self.module = module 62 | 63 | for p in self.module.state_dict().values(): 64 | if not torch.is_tensor(p): 65 | continue 66 | dist.broadcast(p, 0) 67 | 68 | def allreduce_params(): 69 | if(self.needs_reduction): 70 | self.needs_reduction = False 71 | buckets = {} 72 | for param in self.module.parameters(): 73 | if param.requires_grad and param.grad is not None: 74 | tp = type(param.data) 75 | if tp not in buckets: 76 | buckets[tp] = [] 77 | buckets[tp].append(param) 78 | if self.warn_on_half: 79 | if torch.cuda.HalfTensor in buckets: 80 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 81 | " It is recommended to use the NCCL backend in this case. 
This currently requires" + 82 | "PyTorch built from top of tree master.") 83 | self.warn_on_half = False 84 | 85 | for tp in buckets: 86 | bucket = buckets[tp] 87 | grads = [param.grad.data for param in bucket] 88 | coalesced = _flatten_dense_tensors(grads) 89 | dist.all_reduce(coalesced) 90 | coalesced /= dist.get_world_size() 91 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 92 | buf.copy_(synced) 93 | 94 | for param in list(self.module.parameters()): 95 | def allreduce_hook(*unused): 96 | param._execution_engine.queue_callback(allreduce_params) 97 | if param.requires_grad: 98 | param.register_hook(allreduce_hook) 99 | 100 | def forward(self, *inputs, **kwargs): 101 | self.needs_reduction = True 102 | return self.module(*inputs, **kwargs) 103 | 104 | ''' 105 | def _sync_buffers(self): 106 | buffers = list(self.module._all_buffers()) 107 | if len(buffers) > 0: 108 | # cross-node buffer sync 109 | flat_buffers = _flatten_dense_tensors(buffers) 110 | dist.broadcast(flat_buffers, 0) 111 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 112 | buf.copy_(synced) 113 | def train(self, mode=True): 114 | # Clear NCCL communicator and CUDA event cache of the default group ID, 115 | # These cache will be recreated at the later call. This is currently a 116 | # work-around for a potential NCCL deadlock. 117 | if dist._backend == dist.dist_backend.NCCL: 118 | dist._clear_group_cache() 119 | super(DistributedDataParallel, self).train(mode) 120 | self.module.train(mode) 121 | ''' 122 | ''' 123 | Modifies existing model to do gradient allreduce, but doesn't change class 124 | so you don't need "module" 125 | ''' 126 | def apply_gradient_allreduce(module): 127 | if not hasattr(dist, '_backend'): 128 | module.warn_on_half = True 129 | else: 130 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 131 | 132 | for p in module.state_dict().values(): 133 | if not torch.is_tensor(p): 134 | continue 135 | dist.broadcast(p, 0) 136 | 137 | def allreduce_params(): 138 | if(module.needs_reduction): 139 | module.needs_reduction = False 140 | buckets = {} 141 | for param in module.parameters(): 142 | if param.requires_grad and param.grad is not None: 143 | tp = param.data.dtype 144 | if tp not in buckets: 145 | buckets[tp] = [] 146 | buckets[tp].append(param) 147 | if module.warn_on_half: 148 | if torch.cuda.HalfTensor in buckets: 149 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 150 | " It is recommended to use the NCCL backend in this case. 
This currently requires" + 151 | "PyTorch built from top of tree master.") 152 | module.warn_on_half = False 153 | 154 | for tp in buckets: 155 | bucket = buckets[tp] 156 | grads = [param.grad.data for param in bucket] 157 | coalesced = _flatten_dense_tensors(grads) 158 | dist.all_reduce(coalesced) 159 | coalesced /= dist.get_world_size() 160 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 161 | buf.copy_(synced) 162 | 163 | for param in list(module.parameters()): 164 | def allreduce_hook(*unused): 165 | Variable._execution_engine.queue_callback(allreduce_params) 166 | if param.requires_grad: 167 | param.register_hook(allreduce_hook) 168 | 169 | def set_needs_reduction(self, input, output): 170 | self.needs_reduction = True 171 | 172 | module.register_forward_hook(set_needs_reduction) 173 | return module 174 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from text import VOCAB_DICT 3 | 4 | 5 | def create_hparams(hparams_string=None, verbose=False): 6 | """Create model hyperparameters. Parse nondefault from given string.""" 7 | 8 | hparams = tf.contrib.training.HParams( 9 | ################################ 10 | # Experiment Parameters # 11 | ################################ 12 | epochs=500, 13 | iters_per_checkpoint=1000, 14 | seed=1234, 15 | dynamic_loss_scaling=True, 16 | fp16_run=False, 17 | distributed_run=False, 18 | dist_backend="nccl", 19 | dist_url="tcp://localhost:54321", 20 | cudnn_enabled=True, 21 | cudnn_benchmark=False, 22 | ignore_layers=['embedding.weight'], 23 | 24 | ################################ 25 | # Data Parameters # 26 | ################################ 27 | load_mel_from_disk=False, 28 | training_files='data/train_filelist.txt', 29 | validation_files='data/valid_filelist.txt', 30 | 31 | ################################ 32 | # Audio Parameters # 33 | ################################ 34 | max_wav_value=32768.0, 35 | sampling_rate=22050, 36 | filter_length=1024, 37 | hop_length=256, 38 | win_length=1024, 39 | n_mel_channels=80, 40 | mel_fmin=0.0, 41 | mel_fmax=8000.0, 42 | 43 | ################################ 44 | # Model Parameters # 45 | ################################ 46 | n_symbols=len(VOCAB_DICT.keys()), 47 | symbols_embedding_dim=512, 48 | 49 | # Encoder parameters 50 | encoder_kernel_size=5, 51 | encoder_n_convolutions=3, 52 | encoder_embedding_dim=512, 53 | 54 | # Decoder parameters 55 | n_frames_per_step=1, # currently only 1 is supported 56 | decoder_rnn_dim=1024, 57 | prenet_dim=256, 58 | max_decoder_steps=1000, 59 | gate_threshold=0.5, 60 | p_attention_dropout=0.1, 61 | p_decoder_dropout=0.1, 62 | 63 | # Attention parameters 64 | attention_rnn_dim=1024, 65 | attention_dim=128, 66 | 67 | # Location Layer parameters 68 | attention_location_n_filters=32, 69 | attention_location_kernel_size=31, 70 | 71 | # Mel-post processing network parameters 72 | postnet_embedding_dim=512, 73 | postnet_kernel_size=5, 74 | postnet_n_convolutions=5, 75 | 76 | ################################ 77 | # Optimization Hyperparameters # 78 | ################################ 79 | use_saved_learning_rate=False, 80 | learning_rate=1e-3, 81 | weight_decay=1e-6, 82 | grad_clip_thresh=1.0, 83 | batch_size=32, 84 | mask_padding=True # set model's padded outputs to padded values 85 | ) 86 | 87 | if hparams_string: 88 | tf.logging.info('Parsing command line hparams: %s', hparams_string) 89 | 
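# hparams_string, when given, is a comma-separated list of name=value overrides (e.g. "batch_size=16,learning_rate=5e-4") parsed on top of the defaults above.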
hparams.parse(hparams_string) 90 | 91 | if verbose: 92 | tf.logging.info('Final parsed hparams: %s', hparams.values()) 93 | 94 | return hparams 95 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from librosa.filters import mel as librosa_mel_fn 3 | from audio_processing import dynamic_range_compression 4 | from audio_processing import dynamic_range_decompression 5 | from stft import STFT 6 | 7 | 8 | class LinearNorm(torch.nn.Module): 9 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 10 | super(LinearNorm, self).__init__() 11 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 12 | 13 | torch.nn.init.xavier_uniform_( 14 | self.linear_layer.weight, 15 | gain=torch.nn.init.calculate_gain(w_init_gain)) 16 | 17 | def forward(self, x): 18 | return self.linear_layer(x) 19 | 20 | 21 | class ConvNorm(torch.nn.Module): 22 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 23 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 24 | super(ConvNorm, self).__init__() 25 | if padding is None: 26 | assert(kernel_size % 2 == 1) 27 | padding = int(dilation * (kernel_size - 1) / 2) 28 | 29 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 30 | kernel_size=kernel_size, stride=stride, 31 | padding=padding, dilation=dilation, 32 | bias=bias) 33 | 34 | torch.nn.init.xavier_uniform_( 35 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 36 | 37 | def forward(self, signal): 38 | conv_signal = self.conv(signal) 39 | return conv_signal 40 | 41 | 42 | class TacotronSTFT(torch.nn.Module): 43 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 44 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 45 | mel_fmax=8000.0): 46 | super(TacotronSTFT, self).__init__() 47 | self.n_mel_channels = n_mel_channels 48 | self.sampling_rate = sampling_rate 49 | self.stft_fn = STFT(filter_length, hop_length, win_length) 50 | mel_basis = librosa_mel_fn( 51 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 52 | mel_basis = torch.from_numpy(mel_basis).float() 53 | self.register_buffer('mel_basis', mel_basis) 54 | 55 | def spectral_normalize(self, magnitudes): 56 | output = dynamic_range_compression(magnitudes) 57 | return output 58 | 59 | def spectral_de_normalize(self, magnitudes): 60 | output = dynamic_range_decompression(magnitudes) 61 | return output 62 | 63 | def mel_spectrogram(self, y): 64 | """Computes mel-spectrograms from a batch of waves 65 | PARAMS 66 | ------ 67 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 68 | 69 | RETURNS 70 | ------- 71 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 72 | """ 73 | assert(torch.min(y.data) >= -1) 74 | assert(torch.max(y.data) <= 1) 75 | 76 | magnitudes, phases = self.stft_fn.transform(y) 77 | magnitudes = magnitudes.data 78 | mel_output = torch.matmul(self.mel_basis, magnitudes) 79 | mel_output = self.spectral_normalize(mel_output) 80 | return mel_output 81 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy 5 | from plotting_utils import plot_gate_outputs_to_numpy 6 | 7 | 8 | 
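# Tacotron2Logger is a thin wrapper around SummaryWriter: it logs scalar training metrics every iteration and, at validation time, parameter histograms plus alignment, target/predicted mel-spectrogram, and gate plots.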
class Tacotron2Logger(SummaryWriter): 9 | def __init__(self, logdir): 10 | super(Tacotron2Logger, self).__init__(logdir) 11 | 12 | def log_training(self, reduced_loss, grad_norm, learning_rate, duration, 13 | iteration): 14 | self.add_scalar("training.loss", reduced_loss, iteration) 15 | self.add_scalar("grad.norm", grad_norm, iteration) 16 | self.add_scalar("learning.rate", learning_rate, iteration) 17 | self.add_scalar("duration", duration, iteration) 18 | 19 | def log_validation(self, reduced_loss, model, y, y_pred, iteration): 20 | self.add_scalar("validation.loss", reduced_loss, iteration) 21 | _, mel_outputs, gate_outputs, alignments = y_pred 22 | mel_targets, gate_targets = y 23 | 24 | # plot distribution of parameters 25 | for tag, value in model.named_parameters(): 26 | tag = tag.replace('.', '/') 27 | self.add_histogram(tag, value.data.cpu().numpy(), iteration) 28 | 29 | # plot alignment, mel target and predicted, gate target and predicted 30 | idx = random.randint(0, alignments.size(0) - 1) 31 | self.add_image( 32 | "alignment", 33 | plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T), 34 | iteration, dataformats='HWC') 35 | self.add_image( 36 | "mel_target", 37 | plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()), 38 | iteration, dataformats='HWC') 39 | self.add_image( 40 | "mel_predicted", 41 | plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()), 42 | iteration, dataformats='HWC') 43 | self.add_image( 44 | "gate", 45 | plot_gate_outputs_to_numpy( 46 | gate_targets[idx].data.cpu().numpy(), 47 | torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()), 48 | iteration, dataformats='HWC') 49 | -------------------------------------------------------------------------------- /loss_function.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Tacotron2Loss(nn.Module): 5 | def __init__(self): 6 | super(Tacotron2Loss, self).__init__() 7 | 8 | def forward(self, model_output, targets): 9 | mel_target, gate_target = targets[0], targets[1] 10 | mel_target.requires_grad = False 11 | gate_target.requires_grad = False 12 | gate_target = gate_target.view(-1, 1) 13 | 14 | mel_out, mel_out_postnet, gate_out, _ = model_output 15 | gate_out = gate_out.view(-1, 1) 16 | mel_loss = nn.MSELoss()(mel_out, mel_target) + \ 17 | nn.MSELoss()(mel_out_postnet, mel_target) 18 | gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 19 | return mel_loss + gate_loss 20 | -------------------------------------------------------------------------------- /loss_scaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LossScaler: 4 | 5 | def __init__(self, scale=1): 6 | self.cur_scale = scale 7 | 8 | # `params` is a list / generator of torch.Variable 9 | def has_overflow(self, params): 10 | return False 11 | 12 | # `x` is a torch.Tensor 13 | def _has_inf_or_nan(x): 14 | return False 15 | 16 | # `overflow` is boolean indicating whether we overflowed in gradient 17 | def update_scale(self, overflow): 18 | pass 19 | 20 | @property 21 | def loss_scale(self): 22 | return self.cur_scale 23 | 24 | def scale_gradient(self, module, grad_in, grad_out): 25 | return tuple(self.loss_scale * g for g in grad_in) 26 | 27 | def backward(self, loss): 28 | scaled_loss = loss*self.loss_scale 29 | scaled_loss.backward() 30 | 31 | class DynamicLossScaler: 32 | 33 | def __init__(self, 34 | init_scale=2**32, 35 | scale_factor=2., 36 | scale_window=1000): 37 | 
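# Dynamic loss scaling for mixed-precision training: start from a large scale, shrink it (never below 1) whenever gradients overflow, and grow it again after scale_window consecutive overflow-free iterations.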
self.cur_scale = init_scale 38 | self.cur_iter = 0 39 | self.last_overflow_iter = -1 40 | self.scale_factor = scale_factor 41 | self.scale_window = scale_window 42 | 43 | # `params` is a list / generator of torch.Variable 44 | def has_overflow(self, params): 45 | # return False 46 | for p in params: 47 | if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): 48 | return True 49 | 50 | return False 51 | 52 | # `x` is a torch.Tensor 53 | def _has_inf_or_nan(x): 54 | cpu_sum = float(x.float().sum()) 55 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 56 | return True 57 | return False 58 | 59 | # `overflow` is boolean indicating whether we overflowed in gradient 60 | def update_scale(self, overflow): 61 | if overflow: 62 | #self.cur_scale /= self.scale_factor 63 | self.cur_scale = max(self.cur_scale/self.scale_factor, 1) 64 | self.last_overflow_iter = self.cur_iter 65 | else: 66 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: 67 | self.cur_scale *= self.scale_factor 68 | # self.cur_scale = 1 69 | self.cur_iter += 1 70 | 71 | @property 72 | def loss_scale(self): 73 | return self.cur_scale 74 | 75 | def scale_gradient(self, module, grad_in, grad_out): 76 | return tuple(self.loss_scale * g for g in grad_in) 77 | 78 | def backward(self, loss): 79 | scaled_loss = loss*self.loss_scale 80 | scaled_loss.backward() 81 | 82 | ############################################################## 83 | # Example usage below here -- assuming it's in a separate file 84 | ############################################################## 85 | if __name__ == "__main__": 86 | import torch 87 | from torch.autograd import Variable 88 | from dynamic_loss_scaler import DynamicLossScaler 89 | 90 | # N is batch size; D_in is input dimension; 91 | # H is hidden dimension; D_out is output dimension. 92 | N, D_in, H, D_out = 64, 1000, 100, 10 93 | 94 | # Create random Tensors to hold inputs and outputs, and wrap them in Variables. 95 | x = Variable(torch.randn(N, D_in), requires_grad=False) 96 | y = Variable(torch.randn(N, D_out), requires_grad=False) 97 | 98 | w1 = Variable(torch.randn(D_in, H), requires_grad=True) 99 | w2 = Variable(torch.randn(H, D_out), requires_grad=True) 100 | parameters = [w1, w2] 101 | 102 | learning_rate = 1e-6 103 | optimizer = torch.optim.SGD(parameters, lr=learning_rate) 104 | loss_scaler = DynamicLossScaler() 105 | 106 | for t in range(500): 107 | y_pred = x.mm(w1).clamp(min=0).mm(w2) 108 | loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale 109 | print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) 110 | print('Iter {} scaled loss: {}'.format(t, loss.data[0])) 111 | print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) 112 | 113 | # Run backprop 114 | optimizer.zero_grad() 115 | loss.backward() 116 | 117 | # Check for overflow 118 | has_overflow = DynamicLossScaler.has_overflow(parameters) 119 | 120 | # If no overflow, unscale grad and update as usual 121 | if not has_overflow: 122 | for param in parameters: 123 | param.grad.data.mul_(1. 
/ loss_scaler.loss_scale) 124 | optimizer.step() 125 | # Otherwise, don't do anything -- ie, skip iteration 126 | else: 127 | print('OVERFLOW!') 128 | 129 | # Update loss scale for next iteration 130 | loss_scaler.update_scale(has_overflow) 131 | 132 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import torch 3 | from torch.autograd import Variable 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from layers import ConvNorm, LinearNorm 7 | from utils import to_gpu, get_mask_from_lengths 8 | 9 | 10 | class LocationLayer(nn.Module): 11 | def __init__(self, attention_n_filters, attention_kernel_size, 12 | attention_dim): 13 | super(LocationLayer, self).__init__() 14 | padding = int((attention_kernel_size - 1) / 2) 15 | self.location_conv = ConvNorm(2, attention_n_filters, 16 | kernel_size=attention_kernel_size, 17 | padding=padding, bias=False, stride=1, 18 | dilation=1) 19 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 20 | bias=False, w_init_gain='tanh') 21 | 22 | def forward(self, attention_weights_cat): 23 | processed_attention = self.location_conv(attention_weights_cat) 24 | processed_attention = processed_attention.transpose(1, 2) 25 | processed_attention = self.location_dense(processed_attention) 26 | return processed_attention 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 31 | attention_location_n_filters, attention_location_kernel_size): 32 | super(Attention, self).__init__() 33 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 34 | bias=False, w_init_gain='tanh') 35 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 36 | w_init_gain='tanh') 37 | self.v = LinearNorm(attention_dim, 1, bias=False) 38 | self.location_layer = LocationLayer(attention_location_n_filters, 39 | attention_location_kernel_size, 40 | attention_dim) 41 | self.score_mask_value = -float("inf") 42 | 43 | def get_alignment_energies(self, query, processed_memory, 44 | attention_weights_cat): 45 | """ 46 | PARAMS 47 | ------ 48 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 49 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 50 | attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time) 51 | 52 | RETURNS 53 | ------- 54 | alignment (batch, max_time) 55 | """ 56 | 57 | processed_query = self.query_layer(query.unsqueeze(1)) 58 | processed_attention_weights = self.location_layer(attention_weights_cat) 59 | energies = self.v(torch.tanh( 60 | processed_query + processed_attention_weights + processed_memory)) 61 | 62 | energies = energies.squeeze(-1) 63 | return energies 64 | 65 | def forward(self, attention_hidden_state, memory, processed_memory, 66 | attention_weights_cat, mask): 67 | """ 68 | PARAMS 69 | ------ 70 | attention_hidden_state: attention rnn last output 71 | memory: encoder outputs 72 | processed_memory: processed encoder outputs 73 | attention_weights_cat: previous and cummulative attention weights 74 | mask: binary mask for padded data 75 | """ 76 | alignment = self.get_alignment_energies( 77 | attention_hidden_state, processed_memory, attention_weights_cat) 78 | 79 | if mask is not None: 80 | alignment.data.masked_fill_(mask, self.score_mask_value) 81 | 82 | attention_weights = F.softmax(alignment, dim=1) 83 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 84 | attention_context = attention_context.squeeze(1) 85 | 86 | return attention_context, attention_weights 87 | 88 | 89 | class Prenet(nn.Module): 90 | def __init__(self, in_dim, sizes): 91 | super(Prenet, self).__init__() 92 | in_sizes = [in_dim] + sizes[:-1] 93 | self.layers = nn.ModuleList( 94 | [LinearNorm(in_size, out_size, bias=False) 95 | for (in_size, out_size) in zip(in_sizes, sizes)]) 96 | 97 | def forward(self, x): 98 | for linear in self.layers: 99 | x = F.dropout(F.relu(linear(x)), p=0.5, training=True) 100 | return x 101 | 102 | 103 | class Postnet(nn.Module): 104 | """Postnet 105 | - Five 1-d convolution with 512 channels and kernel size 5 106 | """ 107 | 108 | def __init__(self, hparams): 109 | super(Postnet, self).__init__() 110 | self.convolutions = nn.ModuleList() 111 | 112 | self.convolutions.append( 113 | nn.Sequential( 114 | ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim, 115 | kernel_size=hparams.postnet_kernel_size, stride=1, 116 | padding=int((hparams.postnet_kernel_size - 1) / 2), 117 | dilation=1, w_init_gain='tanh'), 118 | nn.BatchNorm1d(hparams.postnet_embedding_dim)) 119 | ) 120 | 121 | for i in range(1, hparams.postnet_n_convolutions - 1): 122 | self.convolutions.append( 123 | nn.Sequential( 124 | ConvNorm(hparams.postnet_embedding_dim, 125 | hparams.postnet_embedding_dim, 126 | kernel_size=hparams.postnet_kernel_size, stride=1, 127 | padding=int((hparams.postnet_kernel_size - 1) / 2), 128 | dilation=1, w_init_gain='tanh'), 129 | nn.BatchNorm1d(hparams.postnet_embedding_dim)) 130 | ) 131 | 132 | self.convolutions.append( 133 | nn.Sequential( 134 | ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels, 135 | kernel_size=hparams.postnet_kernel_size, stride=1, 136 | padding=int((hparams.postnet_kernel_size - 1) / 2), 137 | dilation=1, w_init_gain='linear'), 138 | nn.BatchNorm1d(hparams.n_mel_channels)) 139 | ) 140 | 141 | def forward(self, x): 142 | for i in range(len(self.convolutions) - 1): 143 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 144 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 145 | 146 | return x 147 | 148 | 149 | class Encoder(nn.Module): 150 | """Encoder module: 151 | - Three 1-d convolution banks 152 | - Bidirectional LSTM 153 | """ 154 | def __init__(self, hparams): 155 | super(Encoder, self).__init__() 156 | 157 | convolutions = [] 158 | for 
_ in range(hparams.encoder_n_convolutions): 159 | conv_layer = nn.Sequential( 160 | ConvNorm(hparams.encoder_embedding_dim, 161 | hparams.encoder_embedding_dim, 162 | kernel_size=hparams.encoder_kernel_size, stride=1, 163 | padding=int((hparams.encoder_kernel_size - 1) / 2), 164 | dilation=1, w_init_gain='relu'), 165 | nn.BatchNorm1d(hparams.encoder_embedding_dim)) 166 | convolutions.append(conv_layer) 167 | self.convolutions = nn.ModuleList(convolutions) 168 | 169 | self.lstm = nn.LSTM(hparams.encoder_embedding_dim, 170 | int(hparams.encoder_embedding_dim / 2), 1, 171 | batch_first=True, bidirectional=True) 172 | 173 | def forward(self, x, input_lengths): 174 | for conv in self.convolutions: 175 | x = F.dropout(F.relu(conv(x)), 0.5, self.training) 176 | 177 | x = x.transpose(1, 2) 178 | 179 | # pytorch tensor are not reversible, hence the conversion 180 | input_lengths = input_lengths.cpu().numpy() 181 | x = nn.utils.rnn.pack_padded_sequence( 182 | x, input_lengths, batch_first=True) 183 | 184 | self.lstm.flatten_parameters() 185 | outputs, _ = self.lstm(x) 186 | 187 | outputs, _ = nn.utils.rnn.pad_packed_sequence( 188 | outputs, batch_first=True) 189 | 190 | return outputs 191 | 192 | def inference(self, x): 193 | for conv in self.convolutions: 194 | x = F.dropout(F.relu(conv(x)), 0.5, self.training) 195 | 196 | x = x.transpose(1, 2) 197 | 198 | self.lstm.flatten_parameters() 199 | outputs, _ = self.lstm(x) 200 | 201 | return outputs 202 | 203 | 204 | class Decoder(nn.Module): 205 | def __init__(self, hparams): 206 | super(Decoder, self).__init__() 207 | self.n_mel_channels = hparams.n_mel_channels 208 | self.n_frames_per_step = hparams.n_frames_per_step 209 | self.encoder_embedding_dim = hparams.encoder_embedding_dim 210 | self.attention_rnn_dim = hparams.attention_rnn_dim 211 | self.decoder_rnn_dim = hparams.decoder_rnn_dim 212 | self.prenet_dim = hparams.prenet_dim 213 | self.max_decoder_steps = hparams.max_decoder_steps 214 | self.gate_threshold = hparams.gate_threshold 215 | self.p_attention_dropout = hparams.p_attention_dropout 216 | self.p_decoder_dropout = hparams.p_decoder_dropout 217 | 218 | self.prenet = Prenet( 219 | hparams.n_mel_channels * hparams.n_frames_per_step, 220 | [hparams.prenet_dim, hparams.prenet_dim]) 221 | 222 | self.attention_rnn = nn.LSTMCell( 223 | hparams.prenet_dim + hparams.encoder_embedding_dim, 224 | hparams.attention_rnn_dim) 225 | 226 | self.attention_layer = Attention( 227 | hparams.attention_rnn_dim, hparams.encoder_embedding_dim, 228 | hparams.attention_dim, hparams.attention_location_n_filters, 229 | hparams.attention_location_kernel_size) 230 | 231 | self.decoder_rnn = nn.LSTMCell( 232 | hparams.attention_rnn_dim + hparams.encoder_embedding_dim, 233 | hparams.decoder_rnn_dim, 1) 234 | 235 | self.linear_projection = LinearNorm( 236 | hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 237 | hparams.n_mel_channels * hparams.n_frames_per_step) 238 | 239 | self.gate_layer = LinearNorm( 240 | hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1, 241 | bias=True, w_init_gain='sigmoid') 242 | 243 | def get_go_frame(self, memory): 244 | """ Gets all zeros frames to use as first decoder input 245 | PARAMS 246 | ------ 247 | memory: decoder outputs 248 | 249 | RETURNS 250 | ------- 251 | decoder_input: all zeros frames 252 | """ 253 | B = memory.size(0) 254 | decoder_input = Variable(memory.data.new( 255 | B, self.n_mel_channels * self.n_frames_per_step).zero_()) 256 | return decoder_input 257 | 258 | def initialize_decoder_states(self, memory, 
mask): 259 | """ Initializes attention rnn states, decoder rnn states, attention 260 | weights, attention cumulative weights, attention context, stores memory 261 | and stores processed memory 262 | PARAMS 263 | ------ 264 | memory: Encoder outputs 265 | mask: Mask for padded data if training, expects None for inference 266 | """ 267 | B = memory.size(0) 268 | MAX_TIME = memory.size(1) 269 | 270 | self.attention_hidden = Variable(memory.data.new( 271 | B, self.attention_rnn_dim).zero_()) 272 | self.attention_cell = Variable(memory.data.new( 273 | B, self.attention_rnn_dim).zero_()) 274 | 275 | self.decoder_hidden = Variable(memory.data.new( 276 | B, self.decoder_rnn_dim).zero_()) 277 | self.decoder_cell = Variable(memory.data.new( 278 | B, self.decoder_rnn_dim).zero_()) 279 | 280 | self.attention_weights = Variable(memory.data.new( 281 | B, MAX_TIME).zero_()) 282 | self.attention_weights_cum = Variable(memory.data.new( 283 | B, MAX_TIME).zero_()) 284 | self.attention_context = Variable(memory.data.new( 285 | B, self.encoder_embedding_dim).zero_()) 286 | 287 | self.memory = memory 288 | self.processed_memory = self.attention_layer.memory_layer(memory) 289 | self.mask = mask 290 | 291 | def parse_decoder_inputs(self, decoder_inputs): 292 | """ Prepares decoder inputs, i.e. mel outputs 293 | PARAMS 294 | ------ 295 | decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs 296 | 297 | RETURNS 298 | ------- 299 | inputs: processed decoder inputs 300 | 301 | """ 302 | # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) 303 | decoder_inputs = decoder_inputs.transpose(1, 2) 304 | decoder_inputs = decoder_inputs.view( 305 | decoder_inputs.size(0), 306 | int(decoder_inputs.size(1)/self.n_frames_per_step), -1) 307 | # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) 308 | decoder_inputs = decoder_inputs.transpose(0, 1) 309 | return decoder_inputs 310 | 311 | def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): 312 | """ Prepares decoder outputs for output 313 | PARAMS 314 | ------ 315 | mel_outputs: 316 | gate_outputs: gate output energies 317 | alignments: 318 | 319 | RETURNS 320 | ------- 321 | mel_outputs: 322 | gate_outpust: gate output energies 323 | alignments: 324 | """ 325 | # (T_out, B) -> (B, T_out) 326 | alignments = torch.stack(alignments).transpose(0, 1) 327 | # (T_out, B) -> (B, T_out) 328 | gate_outputs = torch.stack(gate_outputs).transpose(0, 1) 329 | gate_outputs = gate_outputs.contiguous() 330 | # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) 331 | mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() 332 | # decouple frames per step 333 | mel_outputs = mel_outputs.view( 334 | mel_outputs.size(0), -1, self.n_mel_channels) 335 | # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) 336 | mel_outputs = mel_outputs.transpose(1, 2) 337 | 338 | return mel_outputs, gate_outputs, alignments 339 | 340 | def decode(self, decoder_input): 341 | """ Decoder step using stored states, attention and memory 342 | PARAMS 343 | ------ 344 | decoder_input: previous mel output 345 | 346 | RETURNS 347 | ------- 348 | mel_output: 349 | gate_output: gate output energies 350 | attention_weights: 351 | """ 352 | cell_input = torch.cat((decoder_input, self.attention_context), -1) 353 | self.attention_hidden, self.attention_cell = self.attention_rnn( 354 | cell_input, (self.attention_hidden, self.attention_cell)) 355 | self.attention_hidden = F.dropout( 356 | self.attention_hidden, self.p_attention_dropout, self.training) 357 | 
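# Location-sensitive attention: stack the previous and cumulative attention weights into a (B, 2, T_in) tensor so the attention layer can condition on the alignment history.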
358 | attention_weights_cat = torch.cat( 359 | (self.attention_weights.unsqueeze(1), 360 | self.attention_weights_cum.unsqueeze(1)), dim=1) 361 | self.attention_context, self.attention_weights = self.attention_layer( 362 | self.attention_hidden, self.memory, self.processed_memory, 363 | attention_weights_cat, self.mask) 364 | 365 | self.attention_weights_cum += self.attention_weights 366 | decoder_input = torch.cat( 367 | (self.attention_hidden, self.attention_context), -1) 368 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 369 | decoder_input, (self.decoder_hidden, self.decoder_cell)) 370 | self.decoder_hidden = F.dropout( 371 | self.decoder_hidden, self.p_decoder_dropout, self.training) 372 | 373 | decoder_hidden_attention_context = torch.cat( 374 | (self.decoder_hidden, self.attention_context), dim=1) 375 | decoder_output = self.linear_projection( 376 | decoder_hidden_attention_context) 377 | 378 | gate_prediction = self.gate_layer(decoder_hidden_attention_context) 379 | return decoder_output, gate_prediction, self.attention_weights 380 | 381 | def forward(self, memory, decoder_inputs, memory_lengths): 382 | """ Decoder forward pass for training 383 | PARAMS 384 | ------ 385 | memory: Encoder outputs 386 | decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs 387 | memory_lengths: Encoder output lengths for attention masking. 388 | 389 | RETURNS 390 | ------- 391 | mel_outputs: mel outputs from the decoder 392 | gate_outputs: gate outputs from the decoder 393 | alignments: sequence of attention weights from the decoder 394 | """ 395 | 396 | decoder_input = self.get_go_frame(memory).unsqueeze(0) 397 | decoder_inputs = self.parse_decoder_inputs(decoder_inputs) 398 | decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) 399 | decoder_inputs = self.prenet(decoder_inputs) 400 | 401 | self.initialize_decoder_states( 402 | memory, mask=~get_mask_from_lengths(memory_lengths)) 403 | 404 | mel_outputs, gate_outputs, alignments = [], [], [] 405 | while len(mel_outputs) < decoder_inputs.size(0) - 1: 406 | decoder_input = decoder_inputs[len(mel_outputs)] 407 | mel_output, gate_output, attention_weights = self.decode( 408 | decoder_input) 409 | mel_outputs += [mel_output.squeeze(1)] 410 | gate_outputs += [gate_output.squeeze(1)] 411 | alignments += [attention_weights] 412 | 413 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 414 | mel_outputs, gate_outputs, alignments) 415 | 416 | return mel_outputs, gate_outputs, alignments 417 | 418 | def inference(self, memory): 419 | """ Decoder inference 420 | PARAMS 421 | ------ 422 | memory: Encoder outputs 423 | 424 | RETURNS 425 | ------- 426 | mel_outputs: mel outputs from the decoder 427 | gate_outputs: gate outputs from the decoder 428 | alignments: sequence of attention weights from the decoder 429 | """ 430 | decoder_input = self.get_go_frame(memory) 431 | 432 | self.initialize_decoder_states(memory, mask=None) 433 | 434 | mel_outputs, gate_outputs, alignments = [], [], [] 435 | while True: 436 | decoder_input = self.prenet(decoder_input) 437 | mel_output, gate_output, alignment = self.decode(decoder_input) 438 | 439 | mel_outputs += [mel_output.squeeze(1)] 440 | gate_outputs += [gate_output] 441 | alignments += [alignment] 442 | 443 | if torch.sigmoid(gate_output.data) > self.gate_threshold: 444 | break 445 | elif len(mel_outputs) == self.max_decoder_steps: 446 | print("Warning! 
Reached max decoder steps") 447 | break 448 | 449 | decoder_input = mel_output 450 | 451 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 452 | mel_outputs, gate_outputs, alignments) 453 | 454 | return mel_outputs, gate_outputs, alignments 455 | 456 | 457 | class Tacotron2(nn.Module): 458 | def __init__(self, hparams): 459 | super(Tacotron2, self).__init__() 460 | self.mask_padding = hparams.mask_padding 461 | self.fp16_run = hparams.fp16_run 462 | self.n_mel_channels = hparams.n_mel_channels 463 | self.n_frames_per_step = hparams.n_frames_per_step 464 | self.embedding = nn.Embedding( 465 | hparams.n_symbols, hparams.symbols_embedding_dim) 466 | std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) 467 | val = sqrt(3.0) * std # uniform bounds for std 468 | self.embedding.weight.data.uniform_(-val, val) 469 | self.encoder = Encoder(hparams) 470 | self.decoder = Decoder(hparams) 471 | self.postnet = Postnet(hparams) 472 | 473 | def parse_batch(self, batch): 474 | text_padded, input_lengths, mel_padded, gate_padded, \ 475 | output_lengths = batch 476 | text_padded = to_gpu(text_padded).long() 477 | input_lengths = to_gpu(input_lengths).long() 478 | max_len = torch.max(input_lengths.data).item() 479 | mel_padded = to_gpu(mel_padded).float() 480 | gate_padded = to_gpu(gate_padded).float() 481 | output_lengths = to_gpu(output_lengths).long() 482 | 483 | return ( 484 | (text_padded, input_lengths, mel_padded, max_len, output_lengths), 485 | (mel_padded, gate_padded)) 486 | 487 | def parse_output(self, outputs, output_lengths=None): 488 | if self.mask_padding and output_lengths is not None: 489 | mask = ~get_mask_from_lengths(output_lengths) 490 | mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) 491 | mask = mask.permute(1, 0, 2) 492 | 493 | outputs[0].data.masked_fill_(mask, 0.0) 494 | outputs[1].data.masked_fill_(mask, 0.0) 495 | outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies 496 | 497 | return outputs 498 | 499 | def forward(self, inputs): 500 | text_inputs, text_lengths, mels, max_len, output_lengths = inputs 501 | text_lengths, output_lengths = text_lengths.data, output_lengths.data 502 | 503 | embedded_inputs = self.embedding(text_inputs).transpose(1, 2) 504 | 505 | encoder_outputs = self.encoder(embedded_inputs, text_lengths) 506 | 507 | mel_outputs, gate_outputs, alignments = self.decoder( 508 | encoder_outputs, mels, memory_lengths=text_lengths) 509 | 510 | mel_outputs_postnet = self.postnet(mel_outputs) 511 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 512 | 513 | return self.parse_output([mel_outputs, mel_outputs_postnet, gate_outputs, alignments], output_lengths) 514 | 515 | def inference(self, inputs): 516 | embedded_inputs = self.embedding(inputs).transpose(1, 2) 517 | encoder_outputs = self.encoder.inference(embedded_inputs) 518 | mel_outputs, gate_outputs, alignments = self.decoder.inference( 519 | encoder_outputs) 520 | 521 | mel_outputs_postnet = self.postnet(mel_outputs) 522 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 523 | 524 | outputs = self.parse_output([mel_outputs, mel_outputs_postnet, gate_outputs, alignments]) 525 | 526 | return outputs 527 | -------------------------------------------------------------------------------- /multiproc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import sys 4 | import subprocess 5 | 6 | argslist = list(sys.argv)[1:] 7 | num_gpus = torch.cuda.device_count() 8 | 
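# Launch one copy of the training script per visible GPU, forwarding the original command-line arguments plus --n_gpus, a shared --group_name, and a per-process --rank.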
argslist.append('--n_gpus={}'.format(num_gpus)) 9 | workers = [] 10 | job_id = time.strftime("%Y_%m_%d-%H%M%S") 11 | argslist.append("--group_name=group_{}".format(job_id)) 12 | 13 | for i in range(num_gpus): 14 | argslist.append('--rank={}'.format(i)) 15 | stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), 16 | "w") 17 | print(argslist) 18 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 19 | workers.append(p) 20 | argslist = argslist[:-1] 21 | 22 | for p in workers: 23 | p.wait() 24 | -------------------------------------------------------------------------------- /plotting_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | import numpy as np 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data 12 | 13 | 14 | def plot_alignment_to_numpy(alignment, info=None): 15 | fig, ax = plt.subplots(figsize=(6, 4)) 16 | im = ax.imshow(alignment, aspect='auto', origin='lower', 17 | interpolation='none') 18 | fig.colorbar(im, ax=ax) 19 | xlabel = 'Decoder timestep' 20 | if info is not None: 21 | xlabel += '\n\n' + info 22 | plt.xlabel(xlabel) 23 | plt.ylabel('Encoder timestep') 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | return data 30 | 31 | 32 | def plot_spectrogram_to_numpy(spectrogram): 33 | fig, ax = plt.subplots(figsize=(12, 3)) 34 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 35 | interpolation='none') 36 | plt.colorbar(im, ax=ax) 37 | plt.xlabel("Frames") 38 | plt.ylabel("Channels") 39 | plt.tight_layout() 40 | 41 | fig.canvas.draw() 42 | data = save_figure_to_numpy(fig) 43 | plt.close() 44 | return data 45 | 46 | 47 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 48 | fig, ax = plt.subplots(figsize=(12, 3)) 49 | ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5, 50 | color='green', marker='+', s=1, label='target') 51 | ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5, 52 | color='red', marker='.', s=1, label='predicted') 53 | 54 | plt.xlabel("Frames (Green target, Red predicted)") 55 | plt.ylabel("Gate State") 56 | plt.tight_layout() 57 | 58 | fig.canvas.draw() 59 | data = save_figure_to_numpy(fig) 60 | plt.close() 61 | return data 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.1.0 2 | tensorflow==1.15.2 3 | numpy==1.13.3 4 | inflect==0.2.5 5 | librosa==0.6.0 6 | scipy==1.0.0 7 | Unidecode==1.0.22 8 | pillow 9 | -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 
12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | """ 32 | 33 | import torch 34 | import numpy as np 35 | import torch.nn.functional as F 36 | from torch.autograd import Variable 37 | from scipy.signal import get_window 38 | from librosa.util import pad_center, tiny 39 | from audio_processing import window_sumsquare 40 | 41 | 42 | class STFT(torch.nn.Module): 43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 44 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 45 | window='hann'): 46 | super(STFT, self).__init__() 47 | self.filter_length = filter_length 48 | self.hop_length = hop_length 49 | self.win_length = win_length 50 | self.window = window 51 | self.forward_transform = None 52 | scale = self.filter_length / self.hop_length 53 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 54 | 55 | cutoff = int((self.filter_length / 2 + 1)) 56 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 57 | np.imag(fourier_basis[:cutoff, :])]) 58 | 59 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 60 | inverse_basis = torch.FloatTensor( 61 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 62 | 63 | if window is not None: 64 | assert(filter_length >= win_length) 65 | # get window and zero center pad it to filter_length 66 | fft_window = get_window(window, win_length, fftbins=True) 67 | fft_window = pad_center(fft_window, filter_length) 68 | fft_window = torch.from_numpy(fft_window).float() 69 | 70 | # window the bases 71 | forward_basis *= fft_window 72 | inverse_basis *= fft_window 73 | 74 | self.register_buffer('forward_basis', forward_basis.float()) 75 | self.register_buffer('inverse_basis', inverse_basis.float()) 76 | 77 | def transform(self, input_data): 78 | num_batches = input_data.size(0) 79 | num_samples = input_data.size(1) 80 | 81 | self.num_samples = num_samples 82 | 83 | # similar to librosa, reflect-pad the input 84 | input_data = input_data.view(num_batches, 1, num_samples) 85 | input_data = F.pad( 86 | input_data.unsqueeze(1), 87 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 88 | mode='reflect') 89 | input_data = input_data.squeeze(1) 90 | 91 | forward_transform = F.conv1d( 92 | input_data, 93 | Variable(self.forward_basis, requires_grad=False), 94 | stride=self.hop_length, 95 | padding=0) 96 | 97 
| cutoff = int((self.filter_length / 2) + 1) 98 | real_part = forward_transform[:, :cutoff, :] 99 | imag_part = forward_transform[:, cutoff:, :] 100 | 101 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 102 | phase = torch.autograd.Variable( 103 | torch.atan2(imag_part.data, real_part.data)) 104 | 105 | return magnitude, phase 106 | 107 | def inverse(self, magnitude, phase): 108 | recombine_magnitude_phase = torch.cat( 109 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 110 | 111 | inverse_transform = F.conv_transpose1d( 112 | recombine_magnitude_phase, 113 | Variable(self.inverse_basis, requires_grad=False), 114 | stride=self.hop_length, 115 | padding=0) 116 | 117 | if self.window is not None: 118 | window_sum = window_sumsquare( 119 | self.window, magnitude.size(-1), hop_length=self.hop_length, 120 | win_length=self.win_length, n_fft=self.filter_length, dtype=np.float32) 121 | # remove modulation effects 122 | approx_nonzero_indices = torch.from_numpy( 123 | np.where(window_sum > tiny(window_sum))[0]) 124 | window_sum = torch.autograd.Variable( 125 | torch.from_numpy(window_sum), requires_grad=False) 126 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 127 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 128 | 129 | # scale by hop ratio 130 | inverse_transform *= float(self.filter_length) / self.hop_length 131 | 132 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 133 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 134 | 135 | return inverse_transform 136 | 137 | def forward(self, input_data): 138 | self.magnitude, self.phase = self.transform(input_data) 139 | reconstruction = self.inverse(self.magnitude, self.phase) 140 | return reconstruction 141 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | from g2pk import G2p 4 | 5 | CHOSUNGS = "".join([chr(_) for _ in range(0x1100, 0x1113)]) 6 | JOONGSUNGS = "".join([chr(_) for _ in range(0x1161, 0x1176)]) 7 | JONGSUNGS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)]) 8 | SPECIALS = " ?!" 
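# CHOSUNGS/JOONGSUNGS/JONGSUNGS above enumerate the 19 initial, 21 medial and 27 final
# Hangul conjoining jamo (U+1100-U+1112, U+1161-U+1175, U+11A8-U+11C2); together with
# SPECIALS they make up the symbol inventory collected into ALL_VOCABS below.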
9 | 10 | ALL_VOCABS = "".join([ 11 | CHOSUNGS, 12 | JOONGSUNGS, 13 | JONGSUNGS, 14 | SPECIALS 15 | ]) 16 | VOCAB_DICT = { 17 | "_": 0, 18 | "~": 1, 19 | } 20 | 21 | for idx, v in enumerate(ALL_VOCABS): 22 | VOCAB_DICT[v] = idx + 2 23 | 24 | g2p = G2p() 25 | 26 | 27 | def normalize(text): 28 | text = unicodedata.normalize('NFKD', text) 29 | text = text.upper() 30 | regex = unicodedata.normalize('NFKD', r"[^ \u11A8-\u11FF\u1100-\u115E\u1161-\u11A7?!]") 31 | text = re.sub(regex, '', text) 32 | text = re.sub(' +', ' ', text) 33 | text = text.strip() 34 | return text 35 | 36 | 37 | def tokenize(text, encoding: bool = True): 38 | tokens = list() 39 | 40 | for t in text: 41 | if encoding: 42 | tokens.append(VOCAB_DICT[t]) 43 | else: 44 | tokens.append(t) 45 | 46 | if encoding: 47 | tokens.append(VOCAB_DICT['~']) 48 | else: 49 | tokens.append('~') 50 | 51 | return tokens 52 | 53 | 54 | def text_to_sequence(text): 55 | text = g2p(text) 56 | text = normalize(text) 57 | tokens = tokenize(text, encoding=True) 58 | return tokens 59 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import math 5 | import wandb 6 | import matplotlib.pyplot as plt 7 | from numpy import finfo 8 | 9 | import torch 10 | from distributed import apply_gradient_allreduce 11 | import torch.distributed as dist 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.utils.data import DataLoader 14 | 15 | from model import Tacotron2 16 | from data_utils import TextMelLoader, TextMelCollate 17 | from loss_function import Tacotron2Loss 18 | from logger import Tacotron2Logger 19 | from hparams import create_hparams 20 | from text import text_to_sequence 21 | 22 | wandb.init(project='taKotron2') 23 | 24 | 25 | def reduce_tensor(tensor, n_gpus): 26 | rt = tensor.clone() 27 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 28 | rt /= n_gpus 29 | return rt 30 | 31 | 32 | def init_distributed(hparams, n_gpus, rank, group_name): 33 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 34 | print("Initializing Distributed") 35 | 36 | # Set cuda device so everything is done on the right GPU. 
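    # (rank maps to a local device via modulo: e.g. with 4 visible GPUs per node, global rank 5 is assigned cuda:1)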
37 | torch.cuda.set_device(rank % torch.cuda.device_count()) 38 | 39 | # Initialize distributed communication 40 | dist.init_process_group( 41 | backend=hparams.dist_backend, init_method=hparams.dist_url, 42 | world_size=n_gpus, rank=rank, group_name=group_name) 43 | 44 | print("Done initializing distributed") 45 | 46 | 47 | def get_lr(optimizer): 48 | for param_group in optimizer.param_groups: 49 | return param_group['lr'] 50 | 51 | 52 | def prepare_dataloaders(hparams): 53 | # Get data, data loaders and collate function ready 54 | trainset = TextMelLoader(hparams.training_files, hparams) 55 | valset = TextMelLoader(hparams.validation_files, hparams) 56 | collate_fn = TextMelCollate(hparams.n_frames_per_step) 57 | 58 | if hparams.distributed_run: 59 | train_sampler = DistributedSampler(trainset) 60 | shuffle = False 61 | else: 62 | train_sampler = None 63 | shuffle = True 64 | 65 | train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle, 66 | sampler=train_sampler, 67 | batch_size=hparams.batch_size, pin_memory=False, 68 | drop_last=True, collate_fn=collate_fn) 69 | return train_loader, valset, collate_fn 70 | 71 | 72 | def prepare_directories_and_logger(output_directory, log_directory, rank): 73 | if rank == 0: 74 | if not os.path.isdir(output_directory): 75 | os.makedirs(output_directory) 76 | os.chmod(output_directory, 0o775) 77 | logger = Tacotron2Logger(os.path.join(output_directory, log_directory)) 78 | else: 79 | logger = None 80 | return logger 81 | 82 | 83 | def load_model(hparams): 84 | model = Tacotron2(hparams).cuda() 85 | if hparams.fp16_run: 86 | model.decoder.attention_layer.score_mask_value = finfo('float16').min 87 | 88 | if hparams.distributed_run: 89 | model = apply_gradient_allreduce(model) 90 | 91 | return model 92 | 93 | 94 | def warm_start_model(checkpoint_path, model, ignore_layers): 95 | assert os.path.isfile(checkpoint_path) 96 | print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) 97 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 98 | model_dict = checkpoint_dict['state_dict'] 99 | if len(ignore_layers) > 0: 100 | model_dict = {k: v for k, v in model_dict.items() 101 | if k not in ignore_layers} 102 | dummy_dict = model.state_dict() 103 | dummy_dict.update(model_dict) 104 | model_dict = dummy_dict 105 | model.load_state_dict(model_dict) 106 | return model 107 | 108 | 109 | def load_checkpoint(checkpoint_path, model, optimizer): 110 | assert os.path.isfile(checkpoint_path) 111 | print("Loading checkpoint '{}'".format(checkpoint_path)) 112 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 113 | model.load_state_dict(checkpoint_dict['state_dict']) 114 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 115 | learning_rate = checkpoint_dict['learning_rate'] 116 | iteration = checkpoint_dict['iteration'] 117 | print("Loaded checkpoint '{}' from iteration {}" .format( 118 | checkpoint_path, iteration)) 119 | return model, optimizer, learning_rate, iteration 120 | 121 | 122 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): 123 | print("Saving model and optimizer state at iteration {} to {}".format( 124 | iteration, filepath)) 125 | torch.save({'iteration': iteration, 126 | 'state_dict': model.state_dict(), 127 | 'optimizer': optimizer.state_dict(), 128 | 'learning_rate': learning_rate}, filepath) 129 | 130 | 131 | def validate(model, criterion, valset, iteration, batch_size, n_gpus, 132 | collate_fn, logger, distributed_run, rank): 133 | """Handles all the validation 
scoring and printing""" 134 | model.eval() 135 | with torch.no_grad(): 136 | val_sampler = DistributedSampler(valset) if distributed_run else None 137 | val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1, 138 | shuffle=False, batch_size=batch_size, 139 | pin_memory=False, collate_fn=collate_fn) 140 | 141 | val_loss = 0.0 142 | for i, batch in enumerate(val_loader): 143 | x, y = model.parse_batch(batch) 144 | y_pred = model(x) 145 | loss = criterion(y_pred, y) 146 | if distributed_run: 147 | reduced_val_loss = reduce_tensor(loss.data, n_gpus).item() 148 | else: 149 | reduced_val_loss = loss.item() 150 | val_loss += reduced_val_loss 151 | val_loss = val_loss / (i + 1) 152 | wandb.log({"val_loss": val_loss}) 153 | 154 | model.train() 155 | if rank == 0: 156 | print("Validation loss {}: {:9f} ".format(iteration, val_loss)) 157 | logger.log_validation(val_loss, model, y, y_pred, iteration) 158 | 159 | 160 | def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, 161 | rank, group_name, hparams, image_directory, test_text): 162 | """Training and validation logging results to tensorboard and stdout 163 | 164 | Params 165 | ------ 166 | output_directory (string): directory to save checkpoints 167 | log_directory (string) directory to save tensorboard logs 168 | checkpoint_path(string): checkpoint path 169 | n_gpus (int): number of gpus 170 | rank (int): rank of current gpu 171 | hparams (object): comma separated list of "name=value" pairs. 172 | """ 173 | if hparams.distributed_run: 174 | init_distributed(hparams, n_gpus, rank, group_name) 175 | 176 | torch.manual_seed(hparams.seed) 177 | torch.cuda.manual_seed(hparams.seed) 178 | 179 | model = load_model(hparams) 180 | learning_rate = hparams.learning_rate 181 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-03, weight_decay=hparams.weight_decay) 182 | 183 | if hparams.fp16_run: 184 | from apex import amp 185 | model, optimizer = amp.initialize( 186 | model, optimizer, opt_level='O2') 187 | 188 | if hparams.distributed_run: 189 | model = apply_gradient_allreduce(model) 190 | 191 | criterion = Tacotron2Loss() 192 | 193 | logger = prepare_directories_and_logger( 194 | output_directory, log_directory, rank) 195 | 196 | train_loader, valset, collate_fn = prepare_dataloaders(hparams) 197 | 198 | # Load checkpoint if one exists 199 | iteration = 0 200 | epoch_offset = 0 201 | if checkpoint_path is not None: 202 | if warm_start: 203 | model = warm_start_model( 204 | checkpoint_path, model, hparams.ignore_layers) 205 | else: 206 | model, optimizer, _learning_rate, iteration = load_checkpoint( 207 | checkpoint_path, model, optimizer) 208 | if hparams.use_saved_learning_rate: 209 | learning_rate = _learning_rate 210 | iteration += 1 # next iteration is iteration + 1 211 | epoch_offset = max(0, int(iteration / len(train_loader))) 212 | 213 | model.train() 214 | is_overflow = False 215 | # ================ MAIN TRAINNIG LOOP! 
=================== 216 | for epoch in range(epoch_offset, hparams.epochs): 217 | print("Epoch: {}".format(epoch)) 218 | for i, batch in enumerate(train_loader): 219 | start = time.perf_counter() 220 | for param_group in optimizer.param_groups: 221 | param_group['lr'] = learning_rate 222 | 223 | model.zero_grad() 224 | x, y = model.parse_batch(batch) 225 | y_pred = model(x) 226 | 227 | loss = criterion(y_pred, y) 228 | 229 | # Log generated melspectrogram & alignments 230 | if i == 0: 231 | model.eval() 232 | 233 | if not os.path.exists(os.path.join(image_directory, f"epoch_{epoch}")): 234 | os.mkdir(os.path.join(image_directory, f"epoch_{epoch}")) 235 | 236 | sequence = text_to_sequence(test_text) 237 | sequence = torch.LongTensor(sequence).unsqueeze(0).cuda() 238 | 239 | mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence) 240 | 241 | mel_outputs = mel_outputs.float().data.cpu().numpy()[0] 242 | mel_outputs_postnet = mel_outputs_postnet.float().data.cpu().numpy()[0] 243 | alignments = alignments.float().data.cpu().numpy()[0].T 244 | 245 | plt.imshow(mel_outputs, aspect='auto', origin='lower', interpolation='none') 246 | plt.savefig(os.path.join(image_directory, f"epoch_{epoch}", 'mel_outputs.png'), figsize=(16, 4)) 247 | 248 | plt.imshow(mel_outputs_postnet, aspect='auto', origin='lower', interpolation='none') 249 | plt.savefig(os.path.join(image_directory, f"epoch_{epoch}", 'mel_outputs_postnet.png'), figsize=(16, 4)) 250 | 251 | plt.imshow(alignments, aspect='auto', origin='lower', interpolation='none') 252 | plt.savefig(os.path.join(image_directory, f"epoch_{epoch}", 'alignments.png'), figsize=(16, 4)) 253 | 254 | wandb.log({ 255 | "Mel": [ 256 | wandb.Image(os.path.join(image_directory, f"epoch_{epoch}", 'mel_outputs.png')) 257 | ], 258 | "Mel-PostNet": [ 259 | wandb.Image(os.path.join(image_directory, f"epoch_{epoch}", 'mel_outputs_postnet.png')) 260 | ], 261 | "Alignments": [ 262 | wandb.Image(os.path.join(image_directory, f"epoch_{epoch}", 'alignments.png')) 263 | ], 264 | }) 265 | model.train() 266 | 267 | if hparams.distributed_run: 268 | reduced_loss = reduce_tensor(loss.data, n_gpus).item() 269 | else: 270 | reduced_loss = loss.item() 271 | if hparams.fp16_run: 272 | with amp.scale_loss(loss, optimizer) as scaled_loss: 273 | scaled_loss.backward() 274 | else: 275 | loss.backward() 276 | 277 | if hparams.fp16_run: 278 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), hparams.grad_clip_thresh) 279 | is_overflow = math.isnan(grad_norm) 280 | else: 281 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh) 282 | 283 | optimizer.step() 284 | 285 | if not is_overflow and rank == 0: 286 | duration = time.perf_counter() - start 287 | print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( 288 | iteration, reduced_loss, grad_norm, duration)) 289 | wandb.log({ 290 | "train_loss": reduced_loss, 291 | "grad_norm": grad_norm, 292 | "lr": get_lr(optimizer), 293 | }) 294 | logger.log_training(reduced_loss, grad_norm, learning_rate, duration, iteration) 295 | 296 | if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0): 297 | validate(model, criterion, valset, iteration, 298 | hparams.batch_size, n_gpus, collate_fn, logger, 299 | hparams.distributed_run, rank) 300 | if rank == 0: 301 | checkpoint_path = os.path.join( 302 | output_directory, "checkpoint_{}".format(iteration)) 303 | save_checkpoint(model, optimizer, learning_rate, iteration, 304 | checkpoint_path) 305 | 306 | iteration 
+= 1 307 | 308 | 309 | if __name__ == '__main__': 310 | parser = argparse.ArgumentParser() 311 | parser.add_argument('-o', '--output_directory', type=str, 312 | help='directory to save checkpoints') 313 | parser.add_argument('-l', '--log_directory', type=str, 314 | help='directory to save tensorboard logs') 315 | parser.add_argument('-c', '--checkpoint_path', type=str, default=None, 316 | required=False, help='checkpoint path') 317 | parser.add_argument('--image_directory', type=str, default='images') 318 | parser.add_argument('--warm_start', action='store_true', 319 | help='load model weights only, ignore specified layers') 320 | parser.add_argument('--n_gpus', type=int, default=1, 321 | required=False, help='number of gpus') 322 | parser.add_argument('--rank', type=int, default=0, 323 | required=False, help='rank of current gpu') 324 | parser.add_argument('--group_name', type=str, default='group_name', 325 | required=False, help='Distributed group name') 326 | parser.add_argument('--hparams', type=str, 327 | required=False, help='comma separated name=value pairs') 328 | parser.add_argument('--test_text', type=str, 329 | required=False, default='알겠습니다') 330 | 331 | args = parser.parse_args() 332 | hparams = create_hparams(args.hparams) 333 | 334 | torch.backends.cudnn.enabled = hparams.cudnn_enabled 335 | torch.backends.cudnn.benchmark = hparams.cudnn_benchmark 336 | 337 | print("FP16 Run:", hparams.fp16_run) 338 | print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling) 339 | print("Distributed Run:", hparams.distributed_run) 340 | print("cuDNN Enabled:", hparams.cudnn_enabled) 341 | print("cuDNN Benchmark:", hparams.cudnn_benchmark) 342 | 343 | train(args.output_directory, args.log_directory, args.checkpoint_path, 344 | args.warm_start, args.n_gpus, args.rank, args.group_name, hparams, args.image_directory, 345 | test_text=args.test_text) 346 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io.wavfile import read 3 | import torch 4 | 5 | 6 | def get_mask_from_lengths(lengths): 7 | max_len = torch.max(lengths).item() 8 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 9 | mask = (ids < lengths.unsqueeze(1)).bool() 10 | return mask 11 | 12 | 13 | def load_wav_to_torch(full_path): 14 | sampling_rate, data = read(full_path) 15 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 16 | 17 | 18 | def load_filepaths_and_text(filename, split="|"): 19 | with open(filename, encoding='utf-8') as f: 20 | filepaths_and_text = [line.strip().split(split) for line in f] 21 | return filepaths_and_text 22 | 23 | 24 | def to_gpu(x): 25 | x = x.contiguous() 26 | 27 | if torch.cuda.is_available(): 28 | x = x.cuda(non_blocking=True) 29 | return torch.autograd.Variable(x) 30 | --------------------------------------------------------------------------------
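Illustrative usage sketch (not a file in the repository): a minimal example of how the pieces above fit together at inference time. It assumes a checkpoint written by save_checkpoint() in train.py, the Tacotron2 and create_hparams interfaces imported there, and the g2pk-based text_to_sequence() from text/__init__.py; the checkpoint path is a hypothetical placeholder.

import torch
from hparams import create_hparams
from model import Tacotron2
from text import text_to_sequence

# Build the model the same way load_model() in train.py does,
# minus the apex/distributed wrapping.
hparams = create_hparams(None)
model = Tacotron2(hparams).cuda()

# Load a checkpoint produced by save_checkpoint(); the path below is a placeholder.
checkpoint_path = "outdir/checkpoint_10000"
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['state_dict'])
model.eval()

# Encode Korean text with the g2pk front end and run inference,
# mirroring the image-logging branch of the training loop.
sequence = torch.LongTensor(text_to_sequence("알겠습니다")).unsqueeze(0).cuda()
with torch.no_grad():
    mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)

# mel_outputs_postnet is the spectrogram that would then be passed to the
# WaveGlow vocoder (the waveglow submodule) to synthesize a waveform.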