├── .gitignore ├── LICENSE ├── README.md ├── audios └── LJ001-0007.wav ├── config.py ├── datasets ├── __init__.py ├── audio │ ├── __init__.py │ ├── augment.py │ ├── stft.py │ └── util.py ├── text │ ├── __init__.py │ ├── cleaners.py │ ├── numbers.py │ └── sequence.py └── text2mel.py ├── demo.py ├── helpers └── logger.py ├── models ├── __init__.py ├── layers.py ├── losses.py ├── optimizers.py └── tacotron.py ├── requirements.txt ├── synthesize.py ├── tests ├── test_librosa.py ├── test_number.py ├── test_pinyin.py ├── test_sequence.py └── test_stft.py ├── train.py └── utils ├── __init__.py ├── common.py ├── hparam.py └── plot.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Projects 2 | .vscode/ 3 | data/ 4 | logdir/ 5 | references/ 6 | wandb/ 7 | 000-* 8 | *.out 9 | 10 | # Git 11 | .git/ 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Atomicoo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tacotron-2 2 | 3 | A PyTorch implementation of Tacotron-2, the model proposed in the paper [Natural TTS Synthesis By Conditioning Wavenet On Mel Spectrogram Predictions](https://arxiv.org/pdf/1712.05884.pdf). (Work in progress.) 4 | 5 | ## Directory Structure 6 | 7 | ``` 8 | . 
9 | |--- audios/ 10 | |--- datasets/ # dataset-related code 11 | |--- audio/ 12 | |--- text/ 13 | |--- helpers/ # helper classes 14 | |--- models/ # model-related code 15 | |--- layers.py 16 | |--- losses.py 17 | |--- optimizers.py 18 | |--- tacotron.py 19 | |--- tests/ # test code 20 | |--- utils/ # common utilities 21 | |--- .gitignore 22 | |--- LICENSE 23 | |--- README.md # documentation (this file) 24 | |--- requirements.txt # dependency list 25 | |--- train.py # training script 26 | |--- synthesize.py # synthesis script 27 | ``` 28 | 29 | ## Dataset 30 | 31 | - [BZNSYP Dataset](https://www.data-baker.com/open_source.html) 32 | 33 | ## Quick Start 34 | 35 | **Step (1)**: Clone the repository 36 | 37 | ```shell 38 | $ git clone https://github.com/atomicoo/Tacotron2-PyTorch.git 39 | ``` 40 | 41 | **Step (2)**: Install the dependencies 42 | 43 | ```shell 44 | $ conda create -n Tacotron2 python=3.7.9 45 | $ conda activate Tacotron2 46 | $ pip install -r requirements.txt 47 | ``` 48 | 49 | **Step (3)**: Synthesize speech 50 | 51 | ```shell 52 | $ python synthesize.py 53 | ``` 54 | 55 | ## How to Train 56 | 57 | -------------------------------------------------------------------------------- /audios/LJ001-0007.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomicoo/Tacotron2-PyTorch/71eadf2abafa201648dc16cbfb742ab0032077c9/audios/LJ001-0007.wav -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | 4 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | device = torch.device('cpu') 6 | 7 | meta_file = osp.join('data', 'BZNSYP', 'ProsodyLabeling', '000001-010000.txt') 8 | wave_folder = osp.join('data', 'BZNSYP', 'Wave') 9 | 10 | num_train = 9900 11 | num_valid = 100 12 | 13 | ################################ 14 | # Experiment Parameters # 15 | ################################ 16 | experiment = 'exp1' 17 | logdir = osp.join('.', 'logdir') 18 | epochs = 500 19 | iters_per_checkpoint = 1000 20 | seed = 2021 21 | dynamic_loss_scaling = True 22 | fp16_run = False 23 | distributed_run = False 24 | 25 | ################################ 26 | # Data Parameters # 27 | ################################ 28 | load_mel_from_disk = False 29 | train_files = osp.join('data', 'filelists', 'bznsyp_audio_text_train_filelist.txt') 30 | valid_files = osp.join('data', 'filelists', 'bznsyp_audio_text_valid_filelist.txt') 31 | 32 | ################################ 33 | # Text Parameters # 34 | ################################ 35 | use_phonemes = False 36 | graphemes_or_phonemes = list("abcdefghijklmnopqrstuvwxyz12345") 37 | specials = list([]) # list(["", ""]) 38 | punctuations = list([".", ",", "?", "!", " ", "-"]) 39 | 40 | ################################ 41 | # Audio Parameters # 42 | ################################ 43 | max_wav_value = 32768.0 44 | sampling_rate = 48000 # 22050 45 | filter_length = 1024 46 | hop_length = 256 47 | win_length = 1024 48 | n_mel_channels = 80 49 | mel_fmin = 0.0 50 | mel_fmax = 8000.0 51 | 52 | ################################ 53 | # Model Parameters # 54 | ################################ 55 | n_symbols = len(graphemes_or_phonemes+specials+punctuations) 56 | symbols_embedding_dim = 512 57 | 58 | # Encoder parameters 59 | encoder_kernel_size = 5 60 | encoder_n_convolutions = 3 61 | encoder_embedding_dim = 512 62 | 63 | # Decoder parameters 64 | n_frames_per_step = 1 # currently only 1 is supported 65 | decoder_rnn_dim = 1024 66 | prenet_dim = 256 67 | max_decoder_steps = 1000 68 | gate_threshold = 0.5 69 | 
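# Dropout applied to the attention-RNN and decoder-RNN hidden states at train time (see Decoder.decode in models/tacotron.py).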
p_attention_dropout = 0.1 70 | p_decoder_dropout = 0.1 71 | 72 | # Attention parameters 73 | attention_rnn_dim = 1024 74 | attention_dim = 128 75 | 76 | # Location Layer parameters 77 | attention_location_n_filters = 32 78 | attention_location_kernel_size = 31 79 | 80 | # Mel-post processing network parameters 81 | postnet_embedding_dim = 512 82 | postnet_kernel_size = 5 83 | postnet_n_convolutions = 5 84 | 85 | ################################ 86 | # Optimization Hyperparameters # 87 | ################################ 88 | learning_rate = 1e-3 89 | weight_decay = 1e-6 90 | batch_size = 64 91 | mask_padding = True # set model's padded outputs to padded values 92 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .text2mel import Text2MelDataset, Text2MelDataLoader 2 | -------------------------------------------------------------------------------- /datasets/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomicoo/Tacotron2-PyTorch/71eadf2abafa201648dc16cbfb742ab0032077c9/datasets/audio/__init__.py -------------------------------------------------------------------------------- /datasets/audio/augment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from random import sample, randrange 3 | 4 | 5 | def add_random_noise(specs, std_dev): 6 | """Add noise from Normal(0, std_dev) 7 | 8 | :param specs: 9 | :param std_dev: 10 | :return: 11 | """ 12 | if not std_dev: return specs 13 | return specs + std_dev * torch.randn(specs.shape).to(specs.device) 14 | 15 | 16 | def degrade_some(model, specs, texts, tlens, ratio, repeat=1): 17 | """Replace some spectrograms in batch by their generated equivalent 18 | 19 | Ideally, run this after adding random noise 20 | so that the generated spectrograms are slightly degenerated. 21 | 22 | :param ratio: How many percent of spectrograms in batch to degrade (0,1) 23 | :param repeat: How many times to degrade 24 | :return: 25 | """ 26 | if not ratio: return specs 27 | if not repeat: return specs 28 | 29 | idx = sample(range(len(specs)), int(ratio * len(specs))) 30 | 31 | with torch.no_grad(): 32 | s = specs 33 | for i in range(repeat): 34 | s, *_ = model((texts, tlens, specs, True)) 35 | 36 | specs[idx] = s[idx] 37 | 38 | return specs 39 | 40 | 41 | def replace_frames_with_random(specs, ratio, distrib=torch.rand): 42 | """ 43 | 44 | Each spectrogram gets different frames degraded. 45 | To use normal noise, set distrib=lambda shape: mean + std_dev * torch.randn(x) 46 | 47 | :param specs: 48 | :param ratio: between 0,1 - how many percent of frames to degrade 49 | :param distrib: default torch.rand -> [0, 1 uniform] 50 | :return: 51 | """ 52 | if not ratio: return specs 53 | 54 | t = specs.shape[1] 55 | num_frames = int(t * ratio) 56 | idx = [sample(range(t), num_frames) for i in range(len(specs))] # different for each spec. 
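# Each spectrogram's sampled frame indices are overwritten below with frames drawn from `distrib`.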
57 | 58 | for s, _ in enumerate(specs): 59 | rnd_frames = distrib((num_frames, specs.shape[-1])).to(specs.device) 60 | specs[s, idx[s]] = rnd_frames 61 | 62 | return specs 63 | 64 | 65 | def frame_dropout(specs, ratio): 66 | """Replace random frames with zeros 67 | 68 | :param specs: 69 | :param ratio: 70 | :return: 71 | """ 72 | return replace_frames_with_random(specs, ratio, distrib=lambda shape: torch.zeros(shape)) 73 | 74 | 75 | def random_patches(specs1, specs2, width, slen): 76 | """Create random patches from spectrograms 77 | 78 | :param specs: (batch, time, channels) 79 | :param width: int 80 | :param slen: list of int 81 | :return: patches (batch, width, channels) 82 | """ 83 | 84 | idx = [randrange(l - width) for l in slen] 85 | patches1, patches2 = [s[i:i+width] for s, i in zip(specs1, idx)], [s[i:i+width] for s, i in zip(specs2, idx)] 86 | return torch.stack(patches1), torch.stack(patches2) 87 | -------------------------------------------------------------------------------- /datasets/audio/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | """ 32 | 33 | import numpy as np 34 | import torch 35 | import torch.nn as nn 36 | import torch.nn.functional as F 37 | from torch.autograd import Variable 38 | from scipy.signal import get_window 39 | from librosa.util import pad_center, tiny 40 | from .util import window_sumsquare, dynamic_range_compression, dynamic_range_decompression 41 | from librosa.filters import mel as librosa_mel_fn 42 | 43 | 44 | class STFT(nn.Module): 45 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 46 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 47 | window='hann'): 48 | super(STFT, self).__init__() 49 | self.filter_length = filter_length 50 | self.hop_length = hop_length 51 | self.win_length = win_length 52 | self.window = window 53 | self.forward_transform = None 54 | scale = self.filter_length / self.hop_length 55 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 56 | 57 | cutoff = int((self.filter_length / 2 + 1)) 58 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 59 | np.imag(fourier_basis[:cutoff, :])]) 60 | 61 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 62 | inverse_basis = torch.FloatTensor( 63 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 64 | 65 | if window is not None: 66 | assert(filter_length >= win_length) 67 | # get window and zero center pad it to filter_length 68 | fft_window = get_window(window, win_length, fftbins=True) 69 | fft_window = pad_center(fft_window, filter_length) 70 | fft_window = torch.from_numpy(fft_window).float() 71 | 72 | # window the bases 73 | forward_basis *= fft_window 74 | inverse_basis *= fft_window 75 | 76 | self.register_buffer('forward_basis', forward_basis.float()) 77 | self.register_buffer('inverse_basis', inverse_basis.float()) 78 | 79 | def transform(self, input_data): 80 | num_batches = input_data.size(0) 81 | num_samples = input_data.size(1) 82 | 83 | self.num_samples = num_samples 84 | 85 | # similar to librosa, reflect-pad the input 86 | input_data = input_data.view(num_batches, 1, num_samples) 87 | input_data = F.pad( 88 | input_data.unsqueeze(1), 89 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 90 | mode='reflect') 91 | input_data = input_data.squeeze(1) 92 | 93 | # https://github.com/NVIDIA/tacotron2/issues/125 94 | forward_transform = F.conv1d( 95 | input_data.cuda(), 96 | Variable(self.forward_basis, requires_grad=False).cuda(), 97 | stride=self.hop_length, 98 | padding=0).cpu() 99 | 100 | cutoff = int((self.filter_length / 2) + 1) 101 | real_part = forward_transform[:, :cutoff, :] 102 | imag_part = forward_transform[:, cutoff:, :] 103 | 104 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 105 | phase = Variable(torch.atan2(imag_part.data, real_part.data)) 106 | 107 | return magnitude, phase 108 | 109 | def inverse(self, magnitude, phase): 110 | recombine_magnitude_phase = torch.cat( 111 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 112 | 113 | inverse_transform = F.conv_transpose1d( 114 | recombine_magnitude_phase, 115 | Variable(self.inverse_basis, requires_grad=False), 116 | stride=self.hop_length, 117 | padding=0) 118 | 119 | if self.window is not None: 120 | window_sum = window_sumsquare( 121 | self.window, magnitude.size(-1), hop_length=self.hop_length, 122 | win_length=self.win_length, n_fft=self.filter_length, 123 | dtype=np.float32) 124 | # remove modulation effects 125 | approx_nonzero_indices = torch.from_numpy( 126 | np.where(window_sum > tiny(window_sum))[0]) 127 | window_sum = 
Variable( 128 | torch.from_numpy(window_sum), requires_grad=False) 129 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 130 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 131 | 132 | # scale by hop ratio 133 | inverse_transform *= float(self.filter_length) / self.hop_length 134 | 135 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 136 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 137 | 138 | return inverse_transform 139 | 140 | def forward(self, input_data): 141 | self.magnitude, self.phase = self.transform(input_data) 142 | reconstruction = self.inverse(self.magnitude, self.phase) 143 | return reconstruction 144 | 145 | 146 | class TacotronSTFT(nn.Module): 147 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 148 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 149 | mel_fmax=None): 150 | super(TacotronSTFT, self).__init__() 151 | self.n_mel_channels = n_mel_channels 152 | self.sampling_rate = sampling_rate 153 | self.stft_fn = STFT(filter_length, hop_length, win_length) 154 | mel_basis = librosa_mel_fn( 155 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 156 | mel_basis = torch.from_numpy(mel_basis).float() 157 | self.register_buffer('mel_basis', mel_basis) 158 | 159 | def spectral_normalize(self, magnitudes): 160 | output = dynamic_range_compression(magnitudes) 161 | return output 162 | 163 | def spectral_de_normalize(self, magnitudes): 164 | output = dynamic_range_decompression(magnitudes) 165 | return output 166 | 167 | def mel_spectrogram(self, y): 168 | """Computes mel-spectrograms from a batch of waves 169 | PARAMS 170 | ------ 171 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 172 | 173 | RETURNS 174 | ------- 175 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 176 | """ 177 | assert(torch.min(y.data) >= -1) 178 | assert(torch.max(y.data) <= 1) 179 | 180 | magnitudes, phases = self.stft_fn.transform(y) 181 | magnitudes = magnitudes.data 182 | mel_output = torch.matmul(self.mel_basis, magnitudes) 183 | mel_output = self.spectral_normalize(mel_output) 184 | return mel_output 185 | -------------------------------------------------------------------------------- /datasets/audio/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | 13 | This is used to estimate modulation effects induced by windowing 14 | observations in short-time fourier transforms. 15 | 16 | Parameters 17 | ---------- 18 | window : string, tuple, number, callable, or list-like 19 | Window specification, as in `get_window` 20 | 21 | n_frames : int > 0 22 | The number of analysis frames 23 | 24 | hop_length : int > 0 25 | The number of samples to advance between frames 26 | 27 | win_length : [optional] 28 | The length of the window function. By default, this matches `n_fft`. 29 | 30 | n_fft : int > 0 31 | The length of each analysis frame. 
32 | 33 | dtype : np.dtype 34 | The data type of the output 35 | 36 | Returns 37 | ------- 38 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 39 | The sum-squared envelope of the window function 40 | """ 41 | if win_length is None: 42 | win_length = n_fft 43 | 44 | n = n_fft + hop_length * (n_frames - 1) 45 | x = np.zeros(n, dtype=dtype) 46 | 47 | # Compute the squared window at the desired length 48 | win_sq = get_window(window, win_length, fftbins=True) 49 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 50 | win_sq = librosa_util.pad_center(win_sq, n_fft) 51 | 52 | # Fill the envelope 53 | for i in range(n_frames): 54 | sample = i * hop_length 55 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | 78 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 79 | """ 80 | PARAMS 81 | ------ 82 | C: compression factor 83 | """ 84 | return torch.log(torch.clamp(x, min=clip_val) * C) 85 | 86 | 87 | def dynamic_range_decompression(x, C=1): 88 | """ 89 | PARAMS 90 | ------ 91 | C: compression factor used to compress 92 | """ 93 | return torch.exp(x) / C 94 | -------------------------------------------------------------------------------- /datasets/text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomicoo/Tacotron2-PyTorch/71eadf2abafa201648dc16cbfb742ab0032077c9/datasets/text/__init__.py -------------------------------------------------------------------------------- /datasets/text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from .numbers import normalize_numbers 4 | 5 | # Regular expression matching whitespace: 6 | _whitespace_re = re.compile(r'\s+') 7 | 8 | 9 | def expand_numbers(text): 10 | return normalize_numbers(text) 11 | 12 | 13 | def collapse_whitespace(text): 14 | return re.sub(_whitespace_re, ' ', text) 15 | 16 | 17 | def chinese_cleaners(text): 18 | text = expand_numbers(text) 19 | text = collapse_whitespace(text) 20 | return text 21 | -------------------------------------------------------------------------------- /datasets/text/numbers.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import re 3 | 4 | _fraction_re = re.compile(r'([0-9]+\/[0-9]+)') 5 | _decimal_re = re.compile(r'([0-9]+\.[0-9]+)') 6 | _number_re = re.compile(r'[0-9]+') 7 | 8 | 9 | def _expand_fraction(m): 10 | numerator, denominator = m.group(1).split('/') 11 | return '{}分之{}'.format(denominator, numerator) 12 | 13 | def _expand_decimal(m): 14 | return m.group(1).replace('.', '点') 15 | 16 | 17 | def _expand_number(m, big=False, simp=True, o=False, twoalt=True): 18 | num = int(m.group(0)) 19 | """ 20 | Converts numbers to Chinese representations. 21 | `big` : use financial characters. 
22 | `simp` : use simplified characters instead of traditional characters. 23 | `o` : use 〇 for zero. 24 | `twoalt`: use 两/兩 for two when appropriate. 25 | Note that `o` and `twoalt` is ignored when `big` is used, 26 | and `twoalt` is ignored when `o` is used for formal representations. 27 | """ 28 | # check num first 29 | nd = str(num) 30 | if abs(float(nd)) >= 1e48: 31 | raise ValueError('number out of range') 32 | elif 'e' in nd: 33 | raise ValueError('scientific notation is not supported') 34 | c_symbol = '正负点' if simp else '正負點' 35 | if o: # formal 36 | twoalt = False 37 | if big: 38 | c_basic = '零壹贰叁肆伍陆柒捌玖' if simp else '零壹貳參肆伍陸柒捌玖' 39 | c_unit1 = '拾佰仟' 40 | c_twoalt = '贰' if simp else '貳' 41 | else: 42 | c_basic = '〇一二三四五六七八九' if o else '零一二三四五六七八九' 43 | c_unit1 = '十百千' 44 | if twoalt: 45 | c_twoalt = '两' if simp else '兩' 46 | else: 47 | c_twoalt = '二' 48 | c_unit2 = '万亿兆京垓秭穰沟涧正载' if simp else '萬億兆京垓秭穰溝澗正載' 49 | revuniq = lambda l: ''.join(k for k, g in itertools.groupby(reversed(l))) 50 | nd = str(num) 51 | result = [] 52 | if nd[0] == '+': 53 | result.append(c_symbol[0]) 54 | elif nd[0] == '-': 55 | result.append(c_symbol[1]) 56 | if '.' in nd: 57 | integer, remainder = nd.lstrip('+-').split('.') 58 | else: 59 | integer, remainder = nd.lstrip('+-'), None 60 | if int(integer): 61 | splitted = [integer[max(i - 4, 0):i] 62 | for i in range(len(integer), 0, -4)] 63 | intresult = [] 64 | for nu, unit in enumerate(splitted): 65 | # special cases 66 | if int(unit) == 0: # 0000 67 | intresult.append(c_basic[0]) 68 | continue 69 | elif nu > 0 and int(unit) == 2: # 0002 70 | intresult.append(c_twoalt + c_unit2[nu - 1]) 71 | continue 72 | ulist = [] 73 | unit = unit.zfill(4) 74 | for nc, ch in enumerate(reversed(unit)): 75 | if ch == '0': 76 | if ulist: # ???0 77 | ulist.append(c_basic[0]) 78 | elif nc == 0: 79 | ulist.append(c_basic[int(ch)]) 80 | elif nc == 1 and ch == '1' and unit[1] == '0': 81 | # special case for tens 82 | # edit the 'elif' if you don't like 83 | # 十四, 三千零十四, 三千三百一十四 84 | ulist.append(c_unit1[0]) 85 | elif nc > 1 and ch == '2': 86 | ulist.append(c_twoalt + c_unit1[nc - 1]) 87 | else: 88 | ulist.append(c_basic[int(ch)] + c_unit1[nc - 1]) 89 | ustr = revuniq(ulist) 90 | if nu == 0: 91 | intresult.append(ustr) 92 | else: 93 | intresult.append(ustr + c_unit2[nu - 1]) 94 | result.append(revuniq(intresult).strip(c_basic[0])) 95 | else: 96 | result.append(c_basic[0]) 97 | if remainder: 98 | result.append(c_symbol[2]) 99 | result.append(''.join(c_basic[int(ch)] for ch in remainder)) 100 | return ''.join(result) 101 | 102 | 103 | def normalize_numbers(text): 104 | text = re.sub(_fraction_re, _expand_fraction, text) 105 | # print(text) 106 | text = re.sub(_decimal_re, _expand_decimal, text) 107 | # print(text) 108 | text = re.sub(_number_re, _expand_number, text) 109 | return text 110 | -------------------------------------------------------------------------------- /datasets/text/sequence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Sequence(nn.Module): 6 | def __init__(self, graphemes_or_phonemes=[], use_phonemes=True, 7 | specials=[], punctuations=[]): 8 | super(Sequence, self).__init__() 9 | self.phonemize = use_phonemes 10 | self.specials = specials 11 | self.graphemes_or_phonemes = graphemes_or_phonemes 12 | self.punctuations = punctuations 13 | self.units = self.specials + graphemes_or_phonemes + self.punctuations 14 | 15 | self.txt2idx = {txt: idx for idx, txt in 
enumerate(self.units)} 16 | self.idx2txt = {idx: txt for idx, txt in enumerate(self.units)} 17 | 18 | def text_to_sequence(self, text): 19 | # text = chinese_cleaners(text) 20 | sequence = torch.IntTensor([self.txt2idx[ch] for ch in text]) 21 | return sequence 22 | 23 | def sequence_to_text(self, sequence): 24 | text = [self.idx2txt[idx] for idx in sequence] 25 | return text 26 | 27 | -------------------------------------------------------------------------------- /datasets/text2mel.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | import librosa 4 | import numpy as np 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | from datasets.text.sequence import Sequence 10 | from datasets.audio.stft import TacotronSTFT 11 | 12 | 13 | def load_filepaths_and_text(filename, split="|"): 14 | with open(filename, encoding='utf-8') as f: 15 | filepaths_and_text = [line.strip().split(split) for line in f] 16 | return filepaths_and_text 17 | 18 | 19 | def load_wav_to_torch(full_path, sampling_rate=None): 20 | y, sr = librosa.core.load(full_path, sampling_rate) 21 | yt, _ = librosa.effects.trim(y) 22 | return torch.FloatTensor(yt.astype(np.float32)), sr 23 | 24 | 25 | class Text2MelDataset(Dataset): 26 | """ 27 | 1) loads audio,text pairs 28 | 2) normalizes text and converts them to sequences of one-hot vectors 29 | 3) computes mel-spectrograms from audio files. 30 | """ 31 | 32 | def __init__(self, filepaths_and_text, hparams): 33 | # self.max_dataset_size = hparams.max_dataset_size 34 | self.filepaths_and_text = load_filepaths_and_text(filepaths_and_text) 35 | self.seq = Sequence(graphemes_or_phonemes=hparams.graphemes_or_phonemes, 36 | use_phonemes=hparams.use_phonemes, 37 | specials=hparams.specials, 38 | punctuations=hparams.punctuations) 39 | self.sampling_rate = hparams.sampling_rate 40 | # self.max_wav_value = hparams.max_wav_value 41 | self.load_mel_from_disk = hparams.load_mel_from_disk 42 | self.stft = TacotronSTFT(filter_length=hparams.filter_length, 43 | hop_length=hparams.hop_length, 44 | win_length=hparams.win_length, 45 | n_mel_channels=hparams.n_mel_channels, 46 | sampling_rate=hparams.sampling_rate, 47 | mel_fmin=hparams.mel_fmin, 48 | mel_fmax=hparams.mel_fmax) 49 | random.seed(2021) 50 | random.shuffle(self.filepaths_and_text) 51 | 52 | def get_spec_text_pair(self, filepath_and_text): 53 | # separate filename and text 54 | filepath, text = filepath_and_text 55 | text = self.get_text(text) 56 | spec = self.get_spec(filepath) 57 | return (text, spec) 58 | 59 | def get_spec(self, filename): 60 | if self.load_mel_from_disk: 61 | melspec = torch.from_numpy(np.load(filename)) 62 | else: 63 | audio, sampling_rate = load_wav_to_torch(filename) 64 | if sampling_rate != self.stft.sampling_rate: 65 | raise ValueError("{} SR doesn't match target {} SR".format( 66 | sampling_rate, self.stft.sampling_rate)) 67 | # audio_norm = audio / self.max_wav_value 68 | audio_norm = audio.unsqueeze(0) 69 | audio_norm = Variable(audio_norm, requires_grad=False) 70 | melspec = self.stft.mel_spectrogram(audio_norm) 71 | melspec = melspec.squeeze(0) 72 | 73 | return melspec 74 | 75 | def get_text(self, text): 76 | text = self.seq.text_to_sequence(text) 77 | return text 78 | 79 | def __getitem__(self, index): 80 | return self.get_spec_text_pair(self.filepaths_and_text[index]) 81 | 82 | def __len__(self): 83 | return len(self.filepaths_and_text) 84 | 85 | 86 | class 
Text2MelDataLoader(DataLoader): 87 | def __init__(self, text2mel_dataset, hparams, \ 88 | shuffle=True, num_workers=0 if sys.platform.startswith('win') else 8, **kwargs): 89 | collate_fn = Text2MelCollate(n_frames_per_step=hparams.n_frames_per_step) 90 | super(Text2MelDataLoader, self).__init__( 91 | dataset=text2mel_dataset, batch_size=hparams.batch_size, 92 | shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn, **kwargs) 93 | 94 | 95 | class Text2MelCollate: 96 | """ Zero-pads model inputs and targets based on number of frames per step 97 | """ 98 | 99 | def __init__(self, n_frames_per_step): 100 | self.n_frames_per_step = n_frames_per_step 101 | 102 | def __call__(self, batch): 103 | """Collate's training batch from normalized text and mel-spectrogram 104 | PARAMS 105 | ------ 106 | batch: [text_normalized, mel_normalized] 107 | """ 108 | batch_size = len(batch) 109 | # Right zero-pad all one-hot text sequences to max input length 110 | text_lengths, ids_sorted = \ 111 | torch.LongTensor([len(x[0]) for x in batch]).sort(dim=0, descending=True) 112 | max_text_len = text_lengths[0] 113 | 114 | text_padded = torch.LongTensor(batch_size, max_text_len) 115 | text_padded.zero_() 116 | for i in range(len(ids_sorted)): 117 | text = batch[ids_sorted[i]][0] 118 | text_padded[i, :text.size(0)] = text 119 | 120 | # Right zero-pad melspec 121 | n_mel_channels = batch[0][1].size(0) 122 | max_spec_len = max([x[1].size(1) for x in batch]) 123 | if max_spec_len % self.n_frames_per_step != 0: 124 | max_spec_len += self.n_frames_per_step - max_spec_len % self.n_frames_per_step 125 | assert max_spec_len % self.n_frames_per_step == 0 126 | 127 | # include mel padded and gate padded 128 | spec_padded = torch.FloatTensor(batch_size, n_mel_channels, max_spec_len) 129 | spec_padded.zero_() 130 | gate_padded = torch.FloatTensor(batch_size, max_spec_len) 131 | gate_padded.zero_() 132 | spec_lengths = torch.LongTensor(batch_size) 133 | for i in range(len(ids_sorted)): 134 | spec = batch[ids_sorted[i]][1] 135 | spec_padded[i, :, :spec.size(1)] = spec 136 | gate_padded[i, spec.size(1) - 1:] = 1 137 | spec_lengths[i] = spec.size(1) 138 | 139 | return text_padded, text_lengths, spec_padded, gate_padded, spec_lengths 140 | 141 | 142 | if __name__ == '__main__': 143 | import config 144 | 145 | collate_fn = Text2MelCollate(config.n_frames_per_step) 146 | 147 | train_dataset = Text2MelDataset(config.train_files, config) 148 | print('len(train_dataset): ' + str(len(train_dataset))) 149 | 150 | valid_dataset = Text2MelDataset(config.valid_files, config) 151 | print('len(valid_dataset): ' + str(len(valid_dataset))) 152 | 153 | text, spec = valid_dataset[0] 154 | print('type(spec): ' + str(type(spec))) 155 | 156 | text_lengths = [] 157 | spec_lengths = [] 158 | 159 | for data in valid_dataset: 160 | text, spec = data 161 | text = valid_dataset.seq.sequence_to_text(text.numpy().tolist()) 162 | text = ''.join(text) 163 | spec = spec.numpy() 164 | 165 | print('text: ' + str(text)) 166 | print('spec.size: ' + str(spec.size)) 167 | text_lengths.append(len(text)) 168 | spec_lengths.append(spec.size) 169 | # print('np.mean(spec): ' + str(np.mean(spec))) 170 | # print('np.max(spec): ' + str(np.max(spec))) 171 | # print('np.min(spec): ' + str(np.min(spec))) 172 | 173 | print('np.mean(text_lengths): ' + str(np.mean(text_lengths))) 174 | print('np.mean(spec_lengths): ' + str(np.mean(spec_lengths))) 175 | 176 | train_loader = Text2MelDataLoader(train_dataset, config, shuffle=True) 177 | print('len(train_loader): ' + 
str(len(train_loader))) 178 | 179 | valid_loader = Text2MelDataLoader(valid_dataset, config, shuffle=False) 180 | print('len(valid_loader): ' + str(len(valid_loader))) 181 | 182 | batch = next(iter(valid_loader)) 183 | print('type(batch): ' + str(type(batch))) 184 | print('batch[0].size(): ' + str(batch[0].size())) 185 | print('batch[1].size(): ' + str(batch[1].size())) 186 | print('batch[2].size(): ' + str(batch[2].size())) 187 | print('batch[3].size(): ' + str(batch[3].size())) 188 | print('batch[4].size(): ' + str(batch[4].size())) 189 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomicoo/Tacotron2-PyTorch/71eadf2abafa201648dc16cbfb742ab0032077c9/demo.py -------------------------------------------------------------------------------- /helpers/logger.py: -------------------------------------------------------------------------------- 1 | """Wrapper class for logging into the TensorBoard, comet.ml and wandb""" 2 | __author__ = 'Atomicoo' 3 | __all__ = ['Logger'] 4 | 5 | import os 6 | from tensorboardX import SummaryWriter 7 | try: 8 | import wandb 9 | except ImportError: 10 | wandb = None 11 | 12 | 13 | class Logger(object): 14 | 15 | def __init__(self, logdir, experiment, model_name, wandb_info=None): 16 | self.model_name = model_name 17 | self.project_name = "%s-%s" % (experiment, self.model_name) 18 | self.logdir = os.path.join(logdir, self.project_name) 19 | self.writer = SummaryWriter(log_dir=self.logdir) 20 | self.wandb = None if wandb_info is None else wandb 21 | if self.wandb and self.wandb.run is None: 22 | self.wandb.init(**wandb_info) 23 | 24 | def log_model(self, model): 25 | self.writer.add_graph(model) 26 | if self.wandb is not None: 27 | self.wandb.watch(model) 28 | 29 | def log_step(self, phase, step, scalar_dict, figure_dict=None): 30 | if phase == 'train': 31 | if step % 2 == 0: 32 | # self.writer.add_scalar('lr', get_lr(), step) 33 | # self.writer.add_scalar('%s-step/loss' % phase, loss, step) 34 | for key in sorted(scalar_dict): 35 | self.writer.add_scalar(f"{phase}-step/{key}", scalar_dict[key], step) 36 | if self.wandb is not None: 37 | self.wandb.log(scalar_dict) 38 | 39 | if step % 10 == 0: 40 | for key in sorted(figure_dict): 41 | self.writer.add_figure(f"{self.model_name}/{key}", figure_dict[key], step) 42 | if self.wandb is not None: 43 | self.wandb.log({k: self.wandb.Image(v) for k,v in figure_dict.items()}) 44 | 45 | def log_epoch(self, phase, epoch, scalar_dict): 46 | for key in sorted(scalar_dict): 47 | self.writer.add_scalar(f"{phase}/{key}", scalar_dict[key], epoch) 48 | if self.wandb is not None: 49 | self.wandb.log(scalar_dict) 50 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .tacotron import Tacotron2 2 | # from .losses import Tacotron2Loss 3 | # from .optimizers import Tacotron2Optimizer 4 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LinearNorm(nn.Module): 6 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 7 | super(LinearNorm, self).__init__() 8 | self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) 9 | 10 | 
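# Xavier (Glorot) uniform initialization of the linear weight, scaled by the gain of the nonlinearity named in `w_init_gain`.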
nn.init.xavier_uniform_( 11 | self.linear_layer.weight, 12 | gain=nn.init.calculate_gain(w_init_gain)) 13 | 14 | def forward(self, x): 15 | return self.linear_layer(x) 16 | 17 | 18 | class ConvNorm(nn.Module): 19 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 20 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 21 | super(ConvNorm, self).__init__() 22 | if padding is None: 23 | assert (kernel_size % 2 == 1) 24 | padding = int(dilation * (kernel_size - 1) / 2) 25 | 26 | self.conv = nn.Conv1d(in_channels, out_channels, 27 | kernel_size=kernel_size, stride=stride, 28 | padding=padding, dilation=dilation, 29 | bias=bias) 30 | 31 | nn.init.xavier_uniform_( 32 | self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) 33 | 34 | def forward(self, signal): 35 | conv_signal = self.conv(signal) 36 | return conv_signal 37 | 38 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class AverageMeter: 5 | """Keeps track of most recent, average, sum, and count of a metric.""" 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | class Tacotron2Loss(nn.Module): 24 | def __init__(self): 25 | super(Tacotron2Loss, self).__init__() 26 | 27 | def forward(self, model_output, targets): 28 | spec_target, gate_target = targets[0], targets[1] 29 | spec_target.requires_grad_(False) 30 | gate_target.requires_grad_(False) 31 | gate_target = gate_target.view(-1, 1) 32 | 33 | spec_out, spec_out_postnet, gate_out, _ = model_output 34 | gate_out = gate_out.view(-1, 1) 35 | spec_loss = nn.MSELoss()(spec_out, spec_target) + \ 36 | nn.MSELoss()(spec_out_postnet, spec_target) 37 | gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 38 | return spec_loss + gate_loss 39 | -------------------------------------------------------------------------------- /models/optimizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Tacotron2Optimizer(object): 5 | """A simple wrapper class for learning rate scheduling""" 6 | 7 | def __init__(self, optimizer, max_lr=1e-3, min_lr=1e-5, warmup_steps=45000, k=0.0001): 8 | self.optimizer = optimizer 9 | self.max_lr = max_lr 10 | self.min_lr = min_lr 11 | self.warmup_steps = warmup_steps 12 | self.k = k 13 | self.step_num = 0 14 | self.lr = self.max_lr 15 | 16 | def zero_grad(self): 17 | self.optimizer.zero_grad() 18 | 19 | def step(self): 20 | self._update_lr() 21 | self.optimizer.step() 22 | 23 | def _update_lr(self): 24 | self.step_num += 1 25 | if self.step_num > self.warmup_steps: 26 | self.lr = self.max_lr * np.exp(-1.0 * self.k * (self.step_num - self.warmup_steps)) 27 | self.lr = max(self.lr, self.min_lr) 28 | for param_group in self.optimizer.param_groups: 29 | param_group['lr'] = self.lr 30 | -------------------------------------------------------------------------------- /models/tacotron.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | 8 | from .layers import ConvNorm, 
LinearNorm 9 | 10 | 11 | def get_mask_from_lengths(lengths): 12 | max_len = torch.max(lengths).item() 13 | ids = torch.arange(0, max_len, out=torch.LongTensor(max_len).to(lengths.device)) 14 | mask = (ids < lengths.unsqueeze(1)).bool() 15 | return mask 16 | 17 | 18 | class LocationLayer(nn.Module): 19 | def __init__(self, attention_n_filters, attention_kernel_size, 20 | attention_dim): 21 | super(LocationLayer, self).__init__() 22 | padding = int((attention_kernel_size - 1) / 2) 23 | self.location_conv = ConvNorm(2, attention_n_filters, 24 | kernel_size=attention_kernel_size, 25 | padding=padding, bias=False, stride=1, 26 | dilation=1) 27 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 28 | bias=False, w_init_gain='tanh') 29 | 30 | def forward(self, attention_weights_cat): 31 | processed_attention = self.location_conv(attention_weights_cat) 32 | processed_attention = processed_attention.transpose(1, 2) 33 | processed_attention = self.location_dense(processed_attention) 34 | return processed_attention 35 | 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 39 | attention_location_n_filters, attention_location_kernel_size): 40 | super(Attention, self).__init__() 41 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 42 | bias=False, w_init_gain='tanh') 43 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 44 | w_init_gain='tanh') 45 | self.v = LinearNorm(attention_dim, 1, bias=False) 46 | self.location_layer = LocationLayer(attention_location_n_filters, 47 | attention_location_kernel_size, 48 | attention_dim) 49 | self.score_mask_value = -float("inf") 50 | 51 | def get_alignment_energies(self, query, processed_memory, 52 | attention_weights_cat): 53 | """ 54 | PARAMS 55 | ------ 56 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 57 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 58 | attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time) 59 | RETURNS 60 | ------- 61 | alignment (batch, max_time) 62 | """ 63 | 64 | processed_query = self.query_layer(query.unsqueeze(1)) 65 | processed_attention_weights = self.location_layer(attention_weights_cat) 66 | energies = self.v(torch.tanh( 67 | processed_query + processed_attention_weights + processed_memory)) 68 | 69 | energies = energies.squeeze(-1) 70 | return energies 71 | 72 | def forward(self, attention_hidden_state, memory, processed_memory, 73 | attention_weights_cat, mask): 74 | """ 75 | PARAMS 76 | ------ 77 | attention_hidden_state: attention rnn last output 78 | memory: encoder outputs 79 | processed_memory: processed encoder outputs 80 | attention_weights_cat: previous and cummulative attention weights 81 | mask: binary mask for padded data 82 | """ 83 | alignment = self.get_alignment_energies( 84 | attention_hidden_state, processed_memory, attention_weights_cat) 85 | 86 | if mask is not None: 87 | alignment.data.masked_fill_(mask, self.score_mask_value) 88 | 89 | attention_weights = F.softmax(alignment, dim=1) 90 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 91 | attention_context = attention_context.squeeze(1) 92 | 93 | return attention_context, attention_weights 94 | 95 | 96 | class Prenet(nn.Module): 97 | def __init__(self, in_dim, sizes): 98 | super(Prenet, self).__init__() 99 | in_sizes = [in_dim] + sizes[:-1] 100 | self.layers = nn.ModuleList( 101 | [LinearNorm(in_size, out_size, bias=False) 102 | for (in_size, out_size) in zip(in_sizes, sizes)]) 103 | 104 | def forward(self, x): 105 | for linear in self.layers: 106 | x = F.dropout(F.relu(linear(x)), p=0.5, training=True) 107 | return x 108 | 109 | 110 | class Postnet(nn.Module): 111 | """Postnet 112 | - Five 1-d convolution with 512 channels and kernel size 5 113 | """ 114 | 115 | def __init__(self, hparams): 116 | super(Postnet, self).__init__() 117 | self.convolutions = nn.ModuleList() 118 | 119 | self.convolutions.append( 120 | nn.Sequential( 121 | ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim, 122 | kernel_size=hparams.postnet_kernel_size, stride=1, 123 | padding=int((hparams.postnet_kernel_size - 1) / 2), 124 | dilation=1, w_init_gain='tanh'), 125 | nn.BatchNorm1d(hparams.postnet_embedding_dim)) 126 | ) 127 | 128 | for i in range(1, hparams.postnet_n_convolutions - 1): 129 | self.convolutions.append( 130 | nn.Sequential( 131 | ConvNorm(hparams.postnet_embedding_dim, 132 | hparams.postnet_embedding_dim, 133 | kernel_size=hparams.postnet_kernel_size, stride=1, 134 | padding=int((hparams.postnet_kernel_size - 1) / 2), 135 | dilation=1, w_init_gain='tanh'), 136 | nn.BatchNorm1d(hparams.postnet_embedding_dim)) 137 | ) 138 | 139 | self.convolutions.append( 140 | nn.Sequential( 141 | ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels, 142 | kernel_size=hparams.postnet_kernel_size, stride=1, 143 | padding=int((hparams.postnet_kernel_size - 1) / 2), 144 | dilation=1, w_init_gain='linear'), 145 | nn.BatchNorm1d(hparams.n_mel_channels)) 146 | ) 147 | 148 | def forward(self, x): 149 | for i in range(len(self.convolutions) - 1): 150 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 151 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 152 | 153 | return x 154 | 155 | 156 | class Encoder(nn.Module): 157 | """Encoder module: 158 | - Three 1-d convolution banks 159 | - Bidirectional LSTM 160 | """ 161 | 162 | def __init__(self, hparams): 163 | super(Encoder, self).__init__() 164 | 165 | convolutions = [] 
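# Build encoder_n_convolutions stacked (ConvNorm -> BatchNorm1d) blocks; the ReLU activations and dropout are applied in forward()/inference().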
166 | for _ in range(hparams.encoder_n_convolutions): 167 | conv_layer = nn.Sequential( 168 | ConvNorm(hparams.encoder_embedding_dim, 169 | hparams.encoder_embedding_dim, 170 | kernel_size=hparams.encoder_kernel_size, stride=1, 171 | padding=int((hparams.encoder_kernel_size - 1) / 2), 172 | dilation=1, w_init_gain='relu'), 173 | nn.BatchNorm1d(hparams.encoder_embedding_dim)) 174 | convolutions.append(conv_layer) 175 | self.convolutions = nn.ModuleList(convolutions) 176 | 177 | self.lstm = nn.LSTM(hparams.encoder_embedding_dim, 178 | int(hparams.encoder_embedding_dim / 2), 1, 179 | batch_first=True, bidirectional=True) 180 | 181 | def forward(self, x, input_lengths): 182 | for conv in self.convolutions: 183 | x = F.dropout(F.relu(conv(x)), 0.5, self.training) 184 | 185 | x = x.transpose(1, 2) 186 | 187 | # pytorch tensor are not reversible, hence the conversion 188 | input_lengths = input_lengths.cpu().numpy() 189 | x = nn.utils.rnn.pack_padded_sequence( 190 | x, input_lengths, batch_first=True) 191 | 192 | self.lstm.flatten_parameters() 193 | outputs, _ = self.lstm(x) 194 | 195 | outputs, _ = nn.utils.rnn.pad_packed_sequence( 196 | outputs, batch_first=True) 197 | 198 | return outputs 199 | 200 | def inference(self, x): 201 | for conv in self.convolutions: 202 | x = F.dropout(F.relu(conv(x)), 0.5, self.training) 203 | 204 | x = x.transpose(1, 2) 205 | 206 | self.lstm.flatten_parameters() 207 | outputs, _ = self.lstm(x) 208 | 209 | return outputs 210 | 211 | 212 | class Decoder(nn.Module): 213 | def __init__(self, hparams): 214 | super(Decoder, self).__init__() 215 | self.n_mel_channels = hparams.n_mel_channels 216 | self.n_frames_per_step = hparams.n_frames_per_step 217 | self.encoder_embedding_dim = hparams.encoder_embedding_dim 218 | self.attention_rnn_dim = hparams.attention_rnn_dim 219 | self.decoder_rnn_dim = hparams.decoder_rnn_dim 220 | self.prenet_dim = hparams.prenet_dim 221 | self.max_decoder_steps = hparams.max_decoder_steps 222 | self.gate_threshold = hparams.gate_threshold 223 | self.p_attention_dropout = hparams.p_attention_dropout 224 | self.p_decoder_dropout = hparams.p_decoder_dropout 225 | 226 | self.prenet = Prenet( 227 | hparams.n_mel_channels * hparams.n_frames_per_step, 228 | [hparams.prenet_dim, hparams.prenet_dim]) 229 | 230 | self.attention_rnn = nn.LSTMCell( 231 | hparams.prenet_dim + hparams.encoder_embedding_dim, 232 | hparams.attention_rnn_dim) 233 | 234 | self.attention_layer = Attention( 235 | hparams.attention_rnn_dim, hparams.encoder_embedding_dim, 236 | hparams.attention_dim, hparams.attention_location_n_filters, 237 | hparams.attention_location_kernel_size) 238 | 239 | self.decoder_rnn = nn.LSTMCell( 240 | hparams.attention_rnn_dim + hparams.encoder_embedding_dim, 241 | hparams.decoder_rnn_dim, 1) 242 | 243 | self.linear_projection = LinearNorm( 244 | hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 245 | hparams.n_mel_channels * hparams.n_frames_per_step) 246 | 247 | self.gate_layer = LinearNorm( 248 | hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1, 249 | bias=True, w_init_gain='sigmoid') 250 | 251 | def get_go_frame(self, memory): 252 | """ Gets all zeros frames to use as first decoder input 253 | PARAMS 254 | ------ 255 | memory: decoder outputs 256 | RETURNS 257 | ------- 258 | decoder_input: all zeros frames 259 | """ 260 | B = memory.size(0) 261 | decoder_input = Variable(memory.data.new( 262 | B, self.n_mel_channels * self.n_frames_per_step).zero_()) 263 | return decoder_input 264 | 265 | def initialize_decoder_states(self, 
memory, mask): 266 | """ Initializes attention rnn states, decoder rnn states, attention 267 | weights, attention cumulative weights, attention context, stores memory 268 | and stores processed memory 269 | PARAMS 270 | ------ 271 | memory: Encoder outputs 272 | mask: Mask for padded data if training, expects None for inference 273 | """ 274 | B = memory.size(0) 275 | MAX_TIME = memory.size(1) 276 | 277 | self.attention_hidden = Variable(memory.data.new( 278 | B, self.attention_rnn_dim).zero_()) 279 | self.attention_cell = Variable(memory.data.new( 280 | B, self.attention_rnn_dim).zero_()) 281 | 282 | self.decoder_hidden = Variable(memory.data.new( 283 | B, self.decoder_rnn_dim).zero_()) 284 | self.decoder_cell = Variable(memory.data.new( 285 | B, self.decoder_rnn_dim).zero_()) 286 | 287 | self.attention_weights = Variable(memory.data.new( 288 | B, MAX_TIME).zero_()) 289 | self.attention_weights_cum = Variable(memory.data.new( 290 | B, MAX_TIME).zero_()) 291 | self.attention_context = Variable(memory.data.new( 292 | B, self.encoder_embedding_dim).zero_()) 293 | 294 | self.memory = memory 295 | self.processed_memory = self.attention_layer.memory_layer(memory) 296 | self.mask = mask 297 | 298 | def parse_decoder_inputs(self, decoder_inputs): 299 | """ Prepares decoder inputs, i.e. mel outputs 300 | PARAMS 301 | ------ 302 | decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs 303 | RETURNS 304 | ------- 305 | inputs: processed decoder inputs 306 | """ 307 | # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) 308 | decoder_inputs = decoder_inputs.transpose(1, 2) 309 | decoder_inputs = decoder_inputs.view( 310 | decoder_inputs.size(0), 311 | int(decoder_inputs.size(1) / self.n_frames_per_step), -1) 312 | # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) 313 | decoder_inputs = decoder_inputs.transpose(0, 1) 314 | return decoder_inputs 315 | 316 | def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): 317 | """ Prepares decoder outputs for output 318 | PARAMS 319 | ------ 320 | mel_outputs: 321 | gate_outputs: gate output energies 322 | alignments: 323 | RETURNS 324 | ------- 325 | mel_outputs: 326 | gate_outpust: gate output energies 327 | alignments: 328 | """ 329 | # (T_out, B) -> (B, T_out) 330 | alignments = torch.stack(alignments).transpose(0, 1) 331 | # (T_out, B) -> (B, T_out) 332 | gate_outputs = torch.stack(gate_outputs).transpose(0, 1) 333 | gate_outputs = gate_outputs.contiguous() 334 | # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) 335 | mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() 336 | # decouple frames per step 337 | mel_outputs = mel_outputs.view( 338 | mel_outputs.size(0), -1, self.n_mel_channels) 339 | # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) 340 | mel_outputs = mel_outputs.transpose(1, 2) 341 | 342 | return mel_outputs, gate_outputs, alignments 343 | 344 | def decode(self, decoder_input): 345 | """ Decoder step using stored states, attention and memory 346 | PARAMS 347 | ------ 348 | decoder_input: previous mel output 349 | RETURNS 350 | ------- 351 | mel_output: 352 | gate_output: gate output energies 353 | attention_weights: 354 | """ 355 | cell_input = torch.cat((decoder_input, self.attention_context), -1) 356 | self.attention_hidden, self.attention_cell = self.attention_rnn( 357 | cell_input, (self.attention_hidden, self.attention_cell)) 358 | self.attention_hidden = F.dropout( 359 | self.attention_hidden, self.p_attention_dropout, self.training) 360 | 361 | 
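# Stack the previous and cumulative attention weights into the (B, 2, T_in) tensor expected by the location-sensitive attention layer.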
attention_weights_cat = torch.cat( 362 | (self.attention_weights.unsqueeze(1), 363 | self.attention_weights_cum.unsqueeze(1)), dim=1) 364 | self.attention_context, self.attention_weights = self.attention_layer( 365 | self.attention_hidden, self.memory, self.processed_memory, 366 | attention_weights_cat, self.mask) 367 | 368 | self.attention_weights_cum += self.attention_weights 369 | decoder_input = torch.cat( 370 | (self.attention_hidden, self.attention_context), -1) 371 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 372 | decoder_input, (self.decoder_hidden, self.decoder_cell)) 373 | self.decoder_hidden = F.dropout( 374 | self.decoder_hidden, self.p_decoder_dropout, self.training) 375 | 376 | decoder_hidden_attention_context = torch.cat( 377 | (self.decoder_hidden, self.attention_context), dim=1) 378 | decoder_output = self.linear_projection( 379 | decoder_hidden_attention_context) 380 | 381 | gate_prediction = self.gate_layer(decoder_hidden_attention_context) 382 | return decoder_output, gate_prediction, self.attention_weights 383 | 384 | def forward(self, memory, decoder_inputs, memory_lengths): 385 | """ Decoder forward pass for training 386 | PARAMS 387 | ------ 388 | memory: Encoder outputs 389 | decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs 390 | memory_lengths: Encoder output lengths for attention masking. 391 | RETURNS 392 | ------- 393 | mel_outputs: mel outputs from the decoder 394 | gate_outputs: gate outputs from the decoder 395 | alignments: sequence of attention weights from the decoder 396 | """ 397 | 398 | decoder_input = self.get_go_frame(memory).unsqueeze(0) 399 | decoder_inputs = self.parse_decoder_inputs(decoder_inputs) 400 | decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) 401 | decoder_inputs = self.prenet(decoder_inputs) 402 | 403 | self.initialize_decoder_states( 404 | memory, mask = ~get_mask_from_lengths(memory_lengths)) 405 | 406 | mel_outputs, gate_outputs, alignments = [], [], [] 407 | while len(mel_outputs) < decoder_inputs.size(0) - 1: 408 | decoder_input = decoder_inputs[len(mel_outputs)] 409 | mel_output, gate_output, attention_weights = self.decode( 410 | decoder_input) 411 | mel_outputs += [mel_output.squeeze(1)] 412 | gate_outputs += [gate_output.squeeze()] 413 | alignments += [attention_weights] 414 | 415 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 416 | mel_outputs, gate_outputs, alignments) 417 | 418 | return mel_outputs, gate_outputs, alignments 419 | 420 | def inference(self, memory): 421 | """ Decoder inference 422 | PARAMS 423 | ------ 424 | memory: Encoder outputs 425 | RETURNS 426 | ------- 427 | mel_outputs: mel outputs from the decoder 428 | gate_outputs: gate outputs from the decoder 429 | alignments: sequence of attention weights from the decoder 430 | """ 431 | decoder_input = self.get_go_frame(memory) 432 | 433 | self.initialize_decoder_states(memory, mask=None) 434 | 435 | mel_outputs, gate_outputs, alignments = [], [], [] 436 | while True: 437 | decoder_input = self.prenet(decoder_input) 438 | mel_output, gate_output, alignment = self.decode(decoder_input) 439 | 440 | mel_outputs += [mel_output.squeeze(1)] 441 | gate_outputs += [gate_output] 442 | alignments += [alignment] 443 | 444 | if torch.sigmoid(gate_output.data) > self.gate_threshold: 445 | break 446 | elif len(mel_outputs) == self.max_decoder_steps: 447 | print("Warning! 
Reached max decoder steps") 448 | break 449 | 450 | decoder_input = mel_output 451 | 452 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 453 | mel_outputs, gate_outputs, alignments) 454 | 455 | return mel_outputs, gate_outputs, alignments 456 | 457 | 458 | class Tacotron2(nn.Module): 459 | def __init__(self, hparams): 460 | super(Tacotron2, self).__init__() 461 | self.device = hparams.device 462 | self.mask_padding = hparams.mask_padding 463 | self.fp16_run = hparams.fp16_run 464 | self.n_mel_channels = hparams.n_mel_channels 465 | self.n_frames_per_step = hparams.n_frames_per_step 466 | self.embedding = nn.Embedding( 467 | hparams.n_symbols, hparams.symbols_embedding_dim) 468 | std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) 469 | val = sqrt(3.0) * std # uniform bounds for std 470 | self.embedding.weight.data.uniform_(-val, val) 471 | self.encoder = Encoder(hparams) 472 | self.decoder = Decoder(hparams) 473 | self.postnet = Postnet(hparams) 474 | 475 | def parse_batch(self, batch): 476 | text_padded, input_lengths, mel_padded, gate_padded, \ 477 | output_lengths = batch 478 | text_padded = text_padded.long().to(self.device) 479 | input_lengths = input_lengths.long().to(self.device) 480 | max_len = torch.max(input_lengths.data).item() 481 | mel_padded = mel_padded.float().to(self.device) 482 | gate_padded = gate_padded.float().to(self.device) 483 | output_lengths = output_lengths.long().to(self.device) 484 | 485 | return ( 486 | (text_padded, input_lengths, mel_padded, max_len, output_lengths), 487 | (mel_padded, gate_padded)) 488 | 489 | def parse_output(self, outputs, output_lengths=None): 490 | if self.mask_padding and output_lengths is not None: 491 | mask = ~get_mask_from_lengths(output_lengths) 492 | mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) 493 | mask = mask.permute(1, 0, 2) 494 | 495 | outputs[0].data.masked_fill_(mask, 0.0) 496 | outputs[1].data.masked_fill_(mask, 0.0) 497 | outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies 498 | 499 | return outputs 500 | 501 | def forward(self, inputs): 502 | text_inputs, text_lengths, mels, max_len, output_lengths = inputs 503 | text_lengths, output_lengths = text_lengths.data, output_lengths.data 504 | 505 | embedded_inputs = self.embedding(text_inputs).transpose(1, 2) 506 | 507 | encoder_outputs = self.encoder(embedded_inputs, text_lengths) 508 | 509 | mel_outputs, gate_outputs, alignments = self.decoder( 510 | encoder_outputs, mels, memory_lengths=text_lengths) 511 | 512 | mel_outputs_postnet = self.postnet(mel_outputs) 513 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 514 | 515 | return self.parse_output( 516 | [mel_outputs, mel_outputs_postnet, gate_outputs, alignments], 517 | output_lengths) 518 | 519 | def inference(self, inputs): 520 | embedded_inputs = self.embedding(inputs).transpose(1, 2) 521 | encoder_outputs = self.encoder.inference(embedded_inputs) 522 | mel_outputs, gate_outputs, alignments = self.decoder.inference( 523 | encoder_outputs) 524 | 525 | mel_outputs_postnet = self.postnet(mel_outputs) 526 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 527 | 528 | outputs = self.parse_output( 529 | [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]) 530 | 531 | return outputs 532 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.7.2 2 | torch==1.5.0 3 | tensorboardX==2.1 4 
| soundfile==0.10.3.post1 5 | pinyin==0.4.0 6 | -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomicoo/Tacotron2-PyTorch/71eadf2abafa201648dc16cbfb742ab0032077c9/synthesize.py -------------------------------------------------------------------------------- /tests/test_librosa.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import soundfile 4 | 5 | fullpath = '../audios/LJ001-0007.wav' 6 | sampling_rate = 22050 7 | 8 | y, sr = librosa.core.load(fullpath, sampling_rate) 9 | print(y.shape) 10 | 11 | print('np.mean(y): ' + str(np.mean(y))) 12 | print('np.max(y): ' + str(np.max(y))) 13 | print('np.min(y): ' + str(np.min(y))) 14 | 15 | soundfile.write('test.wav', y, sampling_rate) 16 | 17 | y, sr = librosa.core.load('test.wav') 18 | # Trim the beginning and ending silence 19 | yt, index = librosa.effects.trim(y) 20 | print(index) 21 | # Print the durations 22 | print(librosa.core.get_duration(y), librosa.core.get_duration(yt)) 23 | print(len(y), len(yt)) 24 | soundfile.write('test2.wav', y, sampling_rate) 25 | -------------------------------------------------------------------------------- /tests/test_number.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | from datasets.text.numbers import normalize_numbers 4 | 5 | if __name__ == '__main__': 6 | num = '12345.123' 7 | ch = normalize_numbers(num) 8 | print(ch) 9 | 10 | num = '12/123' 11 | ch = normalize_numbers(num) 12 | print(ch) 13 | -------------------------------------------------------------------------------- /tests/test_pinyin.py: -------------------------------------------------------------------------------- 1 | import pinyin 2 | 3 | text = "必须树立公共交通优先发展的理念" 4 | text = pinyin.get(text, format="numerical", delimiter=" ") 5 | print(text) 6 | -------------------------------------------------------------------------------- /tests/test_sequence.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | sys.path.append('..') 4 | from datasets.text.sequence import Sequence 5 | 6 | seq = Sequence(graphemes_or_phonemes=list('abcdefghijklmnopqrstuvwxyz12345 ')) 7 | sequence = seq.text_to_sequence('bi4 xu1 shu4 li4 gong1 gong4 jiao1 tong1 you1 xian1 fa1 zhan3 de5 li3 nian4') 8 | print(sequence) 9 | -------------------------------------------------------------------------------- /tests/test_stft.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import librosa 4 | sys.path.append('..') 5 | from datasets.audio.stft import TacotronSTFT 6 | from utils.plot import plot_spectrogram 7 | 8 | fullpath = '../audios/LJ001-0007.wav' 9 | 10 | filter_length = 1024 11 | hop_length = 256 12 | win_length = 1024 13 | n_mel_channels = 80 14 | sampling_rate = 22050 15 | mel_fmin = 0.0 # 80.0 16 | mel_fmax = 8000.0 # 7600.0 17 | 18 | stft = TacotronSTFT(filter_length=filter_length, 19 | hop_length=hop_length, 20 | win_length=win_length, 21 | n_mel_channels=n_mel_channels, 22 | sampling_rate=sampling_rate, 23 | mel_fmin=mel_fmin, 24 | mel_fmax=mel_fmax) 25 | 26 | wav, sr = librosa.load(fullpath, sr=None) 27 | 28 | assert sr == sampling_rate 29 | 30 | wav = torch.from_numpy(wav).unsqueeze(0) 31 | mel = 
stft.mel_spectrogram(wav).squeeze(0).t() 32 | 33 | print(mel.size()) 34 | plot_spectrogram(pred_spectrogram=mel, save_img=True, path='test.png') 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | # from torch import nn 6 | import torch.optim as optim 7 | from tqdm import tqdm 8 | 9 | import config 10 | from datasets import Text2MelDataset, Text2MelDataLoader 11 | from models import Tacotron2 12 | from models.losses import Tacotron2Loss, AverageMeter 13 | from models.optimizers import Tacotron2Optimizer 14 | from helpers.logger import Logger 15 | from utils.common import save_checkpoint, load_checkpoint 16 | 17 | 18 | def run(args): 19 | torch.manual_seed(args.seed) 20 | np.random.seed(args.seed) 21 | logdir = args.logdir 22 | checkpoint = args.checkpoint 23 | start_epoch = 0 24 | best_loss = float('inf') 25 | epochs_since_improvement = 0 26 | 27 | # Initialize / load checkpoint 28 | if checkpoint is None: 29 | # model 30 | model = Tacotron2(config) 31 | # optimizer 32 | optimizer = Tacotron2Optimizer( 33 | optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2, betas=(0.9, 0.999), eps=1e-6)) 34 | 35 | else: 36 | start_epoch, epochs_since_improvement, model, optimizer, best_loss = load_checkpoint(logdir, checkpoint) 37 | 38 | logger = Logger(config.logdir, config.experiment, 'tacotron2') 39 | 40 | # Move to GPU, if available 41 | model = model.to(config.device) 42 | 43 | criterion = Tacotron2Loss() 44 | 45 | # Custom dataloaders 46 | train_dataset = Text2MelDataset(config.train_files, config) 47 | train_loader = Text2MelDataLoader(train_dataset, config, shuffle=True, 48 | num_workers=args.num_workers, pin_memory=True) 49 | valid_dataset = Text2MelDataset(config.valid_files, config) 50 | valid_loader = Text2MelDataLoader(valid_dataset, config, shuffle=False, 51 | num_workers=args.num_workers, pin_memory=True) 52 | 53 | # Epochs 54 | for epoch in range(start_epoch, args.epochs): 55 | # One epoch's training 56 | train_loss = train(train_loader=train_loader, 57 | model=model, 58 | optimizer=optimizer, 59 | criterion=criterion, 60 | epoch=epoch, 61 | logger=logger) 62 | 63 | lr = optimizer.lr 64 | print('\nLearning rate: {}'.format(lr)) 65 | step_num = optimizer.step_num 66 | print('Step num: {}\n'.format(step_num)) 67 | 68 | scalar_dict = { 'train_epoch_loss': train_loss, 69 | 'learning_rate': lr } 70 | logger.log_epoch('train', epoch, scalar_dict=scalar_dict) 71 | 72 | # One epoch's validation 73 | valid_loss = valid(valid_loader=valid_loader, 74 | model=model, 75 | criterion=criterion, 76 | logger=logger) 77 | 78 | # Check if there was an improvement 79 | is_best = valid_loss < best_loss 80 | best_loss = min(valid_loss, best_loss) 81 | if not is_best: 82 | epochs_since_improvement += 1 83 | print("\nEpochs since last improvement: {}\n".format(epochs_since_improvement)) 84 | else: 85 | epochs_since_improvement = 0 86 | 87 | scalar_dict = { 'valid_epoch_loss': valid_loss } 88 | logger.log_epoch('valid', epoch, scalar_dict=scalar_dict) 89 | 90 | # Save checkpoint 91 | if epoch % args.save_freq == 0: 92 | save_checkpoint(logdir, epoch, epochs_since_improvement, model, optimizer, best_loss, is_best) 93 | 94 | # # alignments 95 | # img_align = test(model, optimizer.step_num, valid_loss) 96 | # writer.add_image('model/alignment', img_align, epoch, dataformats='HWC') 97 | 98 | 99 | def 
train(train_loader, model, optimizer, criterion, epoch, logger): 100 | model.train() # train mode (dropout and batchnorm is used) 101 | 102 | losses = AverageMeter() 103 | 104 | # Batches 105 | for i, batch in enumerate(train_loader): 106 | model.zero_grad() 107 | x, y = model.parse_batch(batch) 108 | 109 | # Forward prop. 110 | y_pred = model(x) 111 | 112 | loss = criterion(y_pred, y) 113 | 114 | # Back prop. 115 | optimizer.zero_grad() 116 | loss.backward() 117 | 118 | # Update weights 119 | optimizer.step() 120 | 121 | # Keep track of metrics 122 | losses.update(loss.item()) 123 | 124 | # Print status 125 | if i % args.print_freq == 0: 126 | print('Epoch: [{0}][{1}/{2}]\t' 127 | 'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(epoch, i, len(train_loader), loss=losses)) 128 | 129 | scalar_dict = { 'train_step_loss': loss.item() } 130 | logger.log_step('train', optimizer.step_num, scalar_dict) 131 | 132 | return losses.avg 133 | 134 | 135 | def valid(valid_loader, model, criterion, logger): 136 | model.eval() 137 | 138 | losses = AverageMeter() 139 | 140 | # Batches 141 | for batch in tqdm(valid_loader): 142 | model.zero_grad() 143 | x, y = model.parse_batch(batch) 144 | 145 | # Forward prop. 146 | y_pred = model(x) 147 | 148 | loss = criterion(y_pred, y) 149 | 150 | # Keep track of metrics 151 | losses.update(loss.item()) 152 | 153 | # Print status 154 | print('\nValid Loss {loss.val:.4f} ({loss.avg:.4f})\n'.format(loss=losses)) 155 | 156 | return losses.avg 157 | 158 | 159 | def parse_args(): 160 | parser = argparse.ArgumentParser(description='Tacotron2') 161 | parser.add_argument('--epochs', default=10000, type=int) 162 | parser.add_argument('--max_grad_norm', default=1, type=float, help='Gradient norm threshold to clip') 163 | # minibatch 164 | parser.add_argument('--batch_size', default=4, type=int) 165 | parser.add_argument('--num_workers', default=0, type=int, help='Number of workers to generate minibatch') 166 | # logging 167 | parser.add_argument('--logdir', default='logdir', type=str, help='Logging directory') 168 | parser.add_argument('--print_freq', default=1, type=int, help='Frequency of printing training information') 169 | parser.add_argument('--save_freq', default=1, type=int, help='Frequency of saving model checkpoint') 170 | # optimizer 171 | parser.add_argument('--lr', default=1e-3, type=float, help='Init learning rate') 172 | parser.add_argument('--l2', default=1e-6, type=float, help='weight decay (L2)') 173 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint') 174 | # others 175 | parser.add_argument('--seed', type=int, default=2021, help='random seed') 176 | args = parser.parse_args() 177 | return args 178 | 179 | 180 | def main(): 181 | global args 182 | args = parse_args() 183 | run(args) 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomicoo/Tacotron2-PyTorch/71eadf2abafa201648dc16cbfb742ab0032077c9/utils/__init__.py -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path as osp 4 | import math 5 | import time 6 | import requests 7 | import zipfile, tarfile, gzip 8 | import torch 9 | from glob import glob 10 | from tqdm import tqdm 11 | 12 | 
13 | def download_file(url, filepath):
14 |     """Downloads a file from the given URL."""
15 |     print("Downloading %s..." % url)
16 |     r = requests.get(url, stream=True)
17 |     total_size = int(r.headers.get('content-length', 0))
18 |     block_size = 1024 * 1024
19 |     wrote = 0
20 |     with open(filepath, 'wb') as f:
21 |         for data in tqdm(r.iter_content(block_size), total=math.ceil(total_size / block_size), unit='MB'):
22 |             wrote = wrote + len(data)
23 |             f.write(data)
24 | 
25 |     if total_size != 0 and wrote != total_size:
26 |         print("Downloading failed")
27 |         sys.exit(1)
28 | 
29 | def extract_gzfile(filepath, dstdir='data'):
30 |     os.makedirs(dstdir, exist_ok=True)
31 |     filename = osp.basename(filepath)
32 |     print('Extracting {}...'.format(filename))
33 |     gz = gzip.GzipFile(filepath, 'r')
34 |     filename = filename.replace('.gz', '')
35 |     open(osp.join(dstdir, filename), 'wb').write(gz.read())
36 |     gz.close()
37 | 
38 | def extract_zipfile(filepath, dstdir='data'):
39 |     os.makedirs(dstdir, exist_ok=True)
40 |     filename = osp.basename(filepath)
41 |     print('Extracting {}...'.format(filename))
42 |     zip = zipfile.ZipFile(filepath, 'r')
43 |     zip.extractall(dstdir)
44 |     zip.close()
45 | 
46 | def extract_tarfile(filepath, dstdir='data'):
47 |     os.makedirs(dstdir, exist_ok=True)
48 |     filename = osp.basename(filepath)
49 |     print('Extracting {}...'.format(filename))
50 |     tar = tarfile.open(filepath, 'r')
51 |     tar.extractall(dstdir)
52 |     tar.close()
53 | 
54 | def get_last_checkpoint(dstdir):
55 |     """Returns the last checkpoint file name in the given dstdir path."""
56 |     checkpoints = glob(osp.join(dstdir, '*.pth'))
57 |     checkpoints.sort()
58 |     if len(checkpoints) == 0:
59 |         return None
60 |     return checkpoints[-1]
61 | 
62 | def save_checkpoint(logdir, epoch, epochs_since_improvement, model, optimizer, loss, is_best):
63 |     state_dict = {
64 |         'epoch': epoch,
65 |         'epochs_since_improvement': epochs_since_improvement,
66 |         'loss': loss,
67 |         'model': model,
68 |         'optimizer': optimizer
69 |     }
70 |     checkpoint_file_name = 'final_checkpoint.pth'
71 |     torch.save(state_dict, osp.join(logdir, checkpoint_file_name))
72 |     print(f"Saved the checkpoint (epoch={epoch:04d}) to '{checkpoint_file_name}'")
73 |     # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
74 |     if is_best:
75 |         torch.save(state_dict, osp.join(logdir, 'best_checkpoint.pth'))
76 |         print(f"Saved the checkpoint (epoch={epoch:04d}) to 'best_checkpoint.pth'")
77 | 
78 | def load_checkpoint(logdir, checkpoint_file_name=None):
79 |     """Loads the checkpoint into the given model and optimizer."""
80 |     checkpoint_file_name = checkpoint_file_name \
81 |         if checkpoint_file_name is not None else 'final_checkpoint.pth'
82 |     checkpoint = torch.load(osp.join(logdir, checkpoint_file_name))
83 |     epoch = checkpoint['epoch'] + 1
84 |     epochs_since_improvement = checkpoint['epochs_since_improvement']
85 |     loss = checkpoint['loss']
86 |     model = checkpoint['model']
87 |     optimizer = checkpoint['optimizer']
88 |     print(f"Loaded the checkpoint (epoch={epoch:04d}) from '{checkpoint_file_name}'")
89 |     return epoch, epochs_since_improvement, model, optimizer, loss
90 | 
91 | 
--------------------------------------------------------------------------------
/utils/hparam.py:
--------------------------------------------------------------------------------
1 | class HParams:
2 |     def __init__(self):
3 |         self.n_mel_channels = None
4 |         self.dynamic_loss_scaling = True
5 |         self.fp16_run = False
6 |         self.distributed_run = False
7 | 
8 |
################################ 9 | # Data Parameters # 10 | ################################ 11 | self.load_mel_from_disk = False 12 | 13 | ################################ 14 | # Audio Parameters # 15 | ################################ 16 | self.max_wav_value = 32768.0 17 | self.sampling_rate = 22050 18 | self.filter_length = 1024 19 | self.hop_length = 256 20 | self.win_length = 1024 21 | self.n_mel_channels = 80 22 | self.mel_fmin = 0.0 23 | self.mel_fmax = 8000.0 24 | 25 | ################################ 26 | # Model Parameters # 27 | ################################ 28 | self.n_symbols = 35 29 | self.symbols_embedding_dim = 512 30 | 31 | # Encoder parameters 32 | self.encoder_kernel_size = 5 33 | self.encoder_n_convolutions = 3 34 | self.encoder_embedding_dim = 512 35 | 36 | # Decoder parameters 37 | self.n_frames_per_step = 1 # currently only 1 is supported 38 | self.decoder_rnn_dim = 1024 39 | self.prenet_dim = 256 40 | self.max_decoder_steps = 1000 41 | self.gate_threshold = 0.5 42 | self.p_attention_dropout = 0.1 43 | self.p_decoder_dropout = 0.1 44 | 45 | # Attention parameters 46 | self.attention_rnn_dim = 1024 47 | self.attention_dim = 128 48 | 49 | # Location Layer parameters 50 | self.attention_location_n_filters = 32 51 | self.attention_location_kernel_size = 31 52 | 53 | # Mel-post processing network parameters 54 | self.postnet_embedding_dim = 512 55 | self.postnet_kernel_size = 5 56 | self.postnet_n_convolutions = 5 57 | 58 | ################################ 59 | # Optimization Hyperparameters # 60 | ################################ 61 | self.learning_rate = 1e-3 62 | self.weight_decay = 1e-6 63 | self.batch_size = 64 64 | self.mask_padding = True # set model's padded outputs to padded values 65 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | 5 | import numpy as np 6 | 7 | 8 | def split_title_line(title_text, max_words=5): 9 | """ 10 | A function that splits any string based on specific character 11 | (returning it with the string), with maximum number of words on it 12 | """ 13 | seq = title_text.split() 14 | return '\n'.join([' '.join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)]) 15 | 16 | def plot_alignment(alignment, title=None, split_title=False, max_len=None, 17 | save_img=False, path=None): 18 | if max_len is not None: 19 | alignment = alignment[:, :max_len] 20 | 21 | fig = plt.figure(figsize=(8, 6)) 22 | ax = fig.add_subplot(111) 23 | 24 | im = ax.imshow( 25 | alignment, 26 | aspect='auto', 27 | origin='lower', 28 | interpolation='none') 29 | fig.colorbar(im, ax=ax) 30 | xlabel = 'Encoder timestep' 31 | 32 | if split_title: 33 | title = split_title_line(title) 34 | 35 | plt.xlabel(xlabel) 36 | plt.title(title) 37 | plt.ylabel('Decoder timestep') 38 | plt.tight_layout() 39 | if save_img: 40 | assert path is not None, "The 'path' must be not None when 'save_img' is True." 
41 | plt.savefig(path, format='png') 42 | plt.close() 43 | 44 | return fig 45 | 46 | 47 | def plot_spectrogram(pred_spectrogram, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False, 48 | save_img=False, path=None): 49 | if max_len is not None: 50 | target_spectrogram = target_spectrogram[:max_len] 51 | pred_spectrogram = pred_spectrogram[:max_len] 52 | 53 | if split_title: 54 | title = split_title_line(title) 55 | 56 | fig = plt.figure(figsize=(10, 8)) 57 | # Set common labels 58 | fig.text(0.5, 0.18, title, horizontalalignment='center', fontsize=16) 59 | 60 | #target spectrogram subplot 61 | if target_spectrogram is not None: 62 | ax1 = fig.add_subplot(311) 63 | ax2 = fig.add_subplot(312) 64 | 65 | if auto_aspect: 66 | im = ax1.imshow(np.rot90(target_spectrogram), aspect='auto', interpolation='none') 67 | else: 68 | im = ax1.imshow(np.rot90(target_spectrogram), interpolation='none') 69 | ax1.set_title('Target Mel-Spectrogram') 70 | fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', ax=ax1) 71 | ax2.set_title('Predicted Mel-Spectrogram') 72 | else: 73 | ax2 = fig.add_subplot(211) 74 | 75 | if auto_aspect: 76 | im = ax2.imshow(np.rot90(pred_spectrogram), aspect='auto', interpolation='none') 77 | else: 78 | im = ax2.imshow(np.rot90(pred_spectrogram), interpolation='none') 79 | fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', ax=ax2) 80 | 81 | plt.tight_layout() 82 | if save_img: 83 | assert path is not None, "The 'path' must be not None when 'save_img' is True." 84 | plt.savefig(path, format='png') 85 | plt.close() 86 | 87 | return fig 88 | --------------------------------------------------------------------------------
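For readers skimming this dump, the modules above compose into a single text-to-mel pipeline: `utils/hparam.py` holds the hyperparameters, `models/tacotron.py` defines `Tacotron2` (embedding, encoder, attention-based decoder, postnet), and `train.py` / `synthesize.py` wire them to the data loaders. The snippet below is a minimal smoke-test sketch of that pipeline, not part of the repository: it assumes an untrained model, patches a `device` field onto `HParams` (the class above does not define one; the real scripts read it from `config.py`, which is not reproduced here), and feeds dummy token IDs in place of a pinyin sequence encoded with `datasets.text.sequence.Sequence`.

```python
import torch

from utils.hparam import HParams
from models import Tacotron2

hparams = HParams()
hparams.device = torch.device('cpu')   # assumed field; HParams itself does not define it
hparams.max_decoder_steps = 50         # keep the untrained smoke test short

model = Tacotron2(hparams).to(hparams.device)
model.eval()

# Dummy token IDs standing in for a pinyin sequence encoded with
# datasets.text.sequence.Sequence (see tests/test_sequence.py).
text_ids = torch.randint(0, hparams.n_symbols, (1, 20), dtype=torch.long)

with torch.no_grad():
    mel, mel_postnet, gate, alignments = model.inference(text_ids)

print(mel_postnet.shape)   # (1, n_mel_channels, T_out)
print(alignments.shape)    # (1, T_out, 20)
```

With a trained checkpoint one would instead restore the model via `utils.common.load_checkpoint` and convert real text through the `datasets.text` helpers, as `synthesize.py` does.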