├── Audio ├── __init__.py ├── audio_processing.py ├── hparams.py ├── stft.py └── tools.py ├── README.md ├── data └── ljspeech.py ├── data_utils.py ├── glow.py ├── hparams.py ├── img └── model_test.jpg ├── inference.py ├── layers.py ├── loss_function.py ├── network.py ├── preprocess.py ├── results ├── Generative adversarial network or variational auto-encoder.76000waveglow.wav ├── I am very happy to see you again!76000griffin_lim.wav ├── I am very happy to see you again!76000waveglow.wav ├── Jack is a little goose. He has a lovely hat. He likes wearing it very much. But when he sits, his hat can’t stay on his head.76000waveglow.wav ├── Lift humanity with cognitive artificial intelligence platforms.76000waveglow.wav └── Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.76000waveglow.wav ├── text ├── __init__.py ├── cleaners.py ├── cmudict.py ├── numbers.py └── symbols.py ├── train.py ├── utils.py └── waveglow ├── __init__.py ├── convert_model.py ├── glow.py ├── inference.py └── mel2samp.py /Audio/__init__.py: -------------------------------------------------------------------------------- 1 | import Audio.hparams 2 | import Audio.tools 3 | import Audio.stft 4 | import Audio.audio_processing 5 | -------------------------------------------------------------------------------- /Audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | 13 | This is used to estimate modulation effects induced by windowing 14 | observations in short-time fourier transforms. 15 | 16 | Parameters 17 | ---------- 18 | window : string, tuple, number, callable, or list-like 19 | Window specification, as in `get_window` 20 | 21 | n_frames : int > 0 22 | The number of analysis frames 23 | 24 | hop_length : int > 0 25 | The number of samples to advance between frames 26 | 27 | win_length : [optional] 28 | The length of the window function. By default, this matches `n_fft`. 29 | 30 | n_fft : int > 0 31 | The length of each analysis frame. 
32 | 33 | dtype : np.dtype 34 | The data type of the output 35 | 36 | Returns 37 | ------- 38 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 39 | The sum-squared envelope of the window function 40 | """ 41 | if win_length is None: 42 | win_length = n_fft 43 | 44 | n = n_fft + hop_length * (n_frames - 1) 45 | x = np.zeros(n, dtype=dtype) 46 | 47 | # Compute the squared window at the desired length 48 | win_sq = get_window(window, win_length, fftbins=True) 49 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 50 | win_sq = librosa_util.pad_center(win_sq, n_fft) 51 | 52 | # Fill the envelope 53 | for i in range(n_frames): 54 | sample = i * hop_length 55 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | 78 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 79 | """ 80 | PARAMS 81 | ------ 82 | C: compression factor 83 | """ 84 | return torch.log(torch.clamp(x, min=clip_val) * C) 85 | 86 | 87 | def dynamic_range_decompression(x, C=1): 88 | """ 89 | PARAMS 90 | ------ 91 | C: compression factor used to compress 92 | """ 93 | return torch.exp(x) / C 94 | -------------------------------------------------------------------------------- /Audio/hparams.py: -------------------------------------------------------------------------------- 1 | max_wav_value = 32768.0 2 | sampling_rate = 22050 3 | filter_length = 1024 4 | hop_length = 256 5 | win_length = 1024 6 | n_mel_channels = 80 7 | mel_fmin = 0.0 8 | mel_fmax = 8000.0 9 | -------------------------------------------------------------------------------- /Audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | from scipy.signal import get_window 7 | from librosa.util import pad_center, tiny 8 | from librosa.filters import mel as librosa_mel_fn 9 | 10 | from Audio.audio_processing import dynamic_range_compression 11 | from Audio.audio_processing import dynamic_range_decompression 12 | from Audio.audio_processing import window_sumsquare 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 19 | window='hann'): 20 | super(STFT, self).__init__() 21 | self.filter_length = filter_length 22 | self.hop_length = hop_length 23 | self.win_length = win_length 24 | self.window = window 25 | self.forward_transform = None 26 | scale = self.filter_length / self.hop_length 27 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 28 | 29 | cutoff = int((self.filter_length / 2 + 1)) 30 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 31 | np.imag(fourier_basis[:cutoff, :])]) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 
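        # forward_basis stacks the real and imaginary DFT rows so that a strided
        # conv1d computes the STFT; the pseudo-inverse built below provides the
        # matching synthesis filters used by conv_transpose1d in inverse().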
34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 36 | 37 | if window is not None: 38 | assert(filter_length >= win_length) 39 | # get window and zero center pad it to filter_length 40 | fft_window = get_window(window, win_length, fftbins=True) 41 | fft_window = pad_center(fft_window, filter_length) 42 | fft_window = torch.from_numpy(fft_window).float() 43 | 44 | # window the bases 45 | forward_basis *= fft_window 46 | inverse_basis *= fft_window 47 | 48 | self.register_buffer('forward_basis', forward_basis.float()) 49 | self.register_buffer('inverse_basis', inverse_basis.float()) 50 | 51 | def transform(self, input_data): 52 | num_batches = input_data.size(0) 53 | num_samples = input_data.size(1) 54 | 55 | self.num_samples = num_samples 56 | 57 | # similar to librosa, reflect-pad the input 58 | input_data = input_data.view(num_batches, 1, num_samples) 59 | input_data = F.pad( 60 | input_data.unsqueeze(1), 61 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 62 | mode='reflect') 63 | input_data = input_data.squeeze(1) 64 | 65 | forward_transform = F.conv1d( 66 | input_data.cuda(), 67 | Variable(self.forward_basis, requires_grad=False).cuda(), 68 | stride=self.hop_length, 69 | padding=0).cpu() 70 | 71 | cutoff = int((self.filter_length / 2) + 1) 72 | real_part = forward_transform[:, :cutoff, :] 73 | imag_part = forward_transform[:, cutoff:, :] 74 | 75 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 76 | phase = torch.autograd.Variable( 77 | torch.atan2(imag_part.data, real_part.data)) 78 | 79 | return magnitude, phase 80 | 81 | def inverse(self, magnitude, phase): 82 | recombine_magnitude_phase = torch.cat( 83 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 84 | 85 | inverse_transform = F.conv_transpose1d( 86 | recombine_magnitude_phase, 87 | Variable(self.inverse_basis, requires_grad=False), 88 | stride=self.hop_length, 89 | padding=0) 90 | 91 | if self.window is not None: 92 | window_sum = window_sumsquare( 93 | self.window, magnitude.size(-1), hop_length=self.hop_length, 94 | win_length=self.win_length, n_fft=self.filter_length, 95 | dtype=np.float32) 96 | # remove modulation effects 97 | approx_nonzero_indices = torch.from_numpy( 98 | np.where(window_sum > tiny(window_sum))[0]) 99 | window_sum = torch.autograd.Variable( 100 | torch.from_numpy(window_sum), requires_grad=False) 101 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 102 | inverse_transform[:, :, 103 | approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 104 | 105 | # scale by hop ratio 106 | inverse_transform *= float(self.filter_length) / self.hop_length 107 | 108 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 109 | inverse_transform = inverse_transform[:, 110 | :, :-int(self.filter_length/2):] 111 | 112 | return inverse_transform 113 | 114 | def forward(self, input_data): 115 | self.magnitude, self.phase = self.transform(input_data) 116 | reconstruction = self.inverse(self.magnitude, self.phase) 117 | return reconstruction 118 | 119 | 120 | class TacotronSTFT(torch.nn.Module): 121 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 122 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 123 | mel_fmax=8000.0): 124 | super(TacotronSTFT, self).__init__() 125 | self.n_mel_channels = n_mel_channels 126 | self.sampling_rate = sampling_rate 127 | self.stft_fn = STFT(filter_length, hop_length, win_length) 128 | mel_basis = librosa_mel_fn( 129 | sampling_rate, 
filter_length, n_mel_channels, mel_fmin, mel_fmax) 130 | mel_basis = torch.from_numpy(mel_basis).float() 131 | self.register_buffer('mel_basis', mel_basis) 132 | 133 | def spectral_normalize(self, magnitudes): 134 | output = dynamic_range_compression(magnitudes) 135 | return output 136 | 137 | def spectral_de_normalize(self, magnitudes): 138 | output = dynamic_range_decompression(magnitudes) 139 | return output 140 | 141 | def mel_spectrogram(self, y): 142 | """Computes mel-spectrograms from a batch of waves 143 | PARAMS 144 | ------ 145 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 146 | 147 | RETURNS 148 | ------- 149 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 150 | """ 151 | assert(torch.min(y.data) >= -1) 152 | assert(torch.max(y.data) <= 1) 153 | 154 | magnitudes, phases = self.stft_fn.transform(y) 155 | magnitudes = magnitudes.data 156 | mel_output = torch.matmul(self.mel_basis, magnitudes) 157 | mel_output = self.spectral_normalize(mel_output) 158 | return mel_output 159 | -------------------------------------------------------------------------------- /Audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import read 4 | from scipy.io.wavfile import write 5 | 6 | import Audio.stft as stft 7 | import Audio.hparams as hparams 8 | from Audio.audio_processing import griffin_lim 9 | 10 | 11 | _stft = stft.TacotronSTFT( 12 | hparams.filter_length, hparams.hop_length, hparams.win_length, 13 | hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, 14 | hparams.mel_fmax) 15 | 16 | 17 | # def _normalize(S): 18 | # return np.clip((S + 100.0) / 100.0, 0, 1) 19 | 20 | 21 | # def _denormalize(S): 22 | # return (np.clip(S, 0, 1) * 100.0) - 100.0 23 | 24 | 25 | def load_wav_to_torch(full_path): 26 | sampling_rate, data = read(full_path) 27 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 28 | 29 | 30 | def get_mel(filename): 31 | audio, sampling_rate = load_wav_to_torch(filename) 32 | if sampling_rate != _stft.sampling_rate: 33 | raise ValueError("{} {} SR doesn't match target {} SR".format( 34 | sampling_rate, _stft.sampling_rate)) 35 | audio_norm = audio / hparams.max_wav_value 36 | audio_norm = audio_norm.unsqueeze(0) 37 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 38 | melspec = _stft.mel_spectrogram(audio_norm) 39 | melspec = torch.squeeze(melspec, 0) 40 | # melspec = torch.from_numpy(_normalize(melspec.numpy())) 41 | 42 | return melspec 43 | 44 | 45 | def inv_mel_spec(mel, out_filename, griffin_iters=60): 46 | mel = torch.stack([mel]) 47 | # mel = torch.stack([torch.from_numpy(_denormalize(mel.numpy()))]) 48 | mel_decompress = _stft.spectral_de_normalize(mel) 49 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 50 | spec_from_mel_scaling = 1000 51 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 52 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 53 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 54 | 55 | audio = griffin_lim(torch.autograd.Variable( 56 | spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters) 57 | 58 | audio = audio.squeeze() 59 | audio = audio.cpu().numpy() 60 | audio_path = out_filename 61 | write(audio_path, hparams.sampling_rate, audio) 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
Tacotron2-Pytorch 2 | Support DataParallel 3 | 4 | ## Training 5 | 1. Put LJSpeech dataset in `data` 6 | 2. Run `python preprocess.py` 7 | 3. modify hyperparameters in `hparams.py` 8 | 4. Run `python train.py` 9 | 10 | ## Inference 11 | 1. Put [Nvidia pretrained waveglow model](https://drive.google.com/file/d/1WsibBTsuRg_SF2Z6L6NFRTT-NjEy1oTx/view?usp=sharing) in `waveglow/pre_trained_model` 12 | 2. Run `python inference.py` -------------------------------------------------------------------------------- /data/ljspeech.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import Audio 4 | 5 | 6 | def build_from_path(in_dir, out_dir): 7 | index = 1 8 | out = list() 9 | 10 | with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f: 11 | for line in f: 12 | parts = line.strip().split('|') 13 | wav_path = os.path.join(in_dir, 'wavs', '%s.wav' % parts[0]) 14 | text = parts[2] 15 | out.append(_process_utterance(out_dir, index, wav_path, text)) 16 | 17 | if index % 100 == 0: 18 | print("Done %d" % index) 19 | index = index + 1 20 | 21 | return out 22 | 23 | 24 | def _process_utterance(out_dir, index, wav_path, text): 25 | # Compute a mel-scale spectrogram from the wav: 26 | mel_spectrogram = Audio.tools.get_mel(wav_path).numpy().astype(np.float32) 27 | # print(mel_spectrogram) 28 | 29 | # Write the spectrograms to disk: 30 | mel_filename = 'ljspeech-mel-%05d.npy' % index 31 | np.save(os.path.join(out_dir, mel_filename), 32 | mel_spectrogram.T, allow_pickle=False) 33 | 34 | return text 35 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import math 4 | import numpy as np 5 | import os 6 | 7 | from text import text_to_sequence 8 | import hparams 9 | 10 | device = torch.device('cuda'if torch.cuda.is_available()else 'cpu') 11 | 12 | 13 | class Tacotron2Dataset(Dataset): 14 | """ LJSpeech """ 15 | 16 | def __init__(self, dataset_path=hparams.dataset_path): 17 | self.dataset_path = dataset_path 18 | self.text_path = os.path.join(self.dataset_path, "train.txt") 19 | self.text = process_text(self.text_path) 20 | 21 | def __len__(self): 22 | return len(self.text) 23 | 24 | def __getitem__(self, idx): 25 | index = idx + 1 26 | mel_name = os.path.join( 27 | self.dataset_path, "ljspeech-mel-%05d.npy" % index) 28 | mel_target = np.load(mel_name) 29 | 30 | character = self.text[idx][0:len(self.text[idx])-1] 31 | character = text_to_sequence(character, hparams.text_cleaners) 32 | character = np.array(character) 33 | 34 | stop_token = np.array([0. for _ in range(mel_target.shape[0])]) 35 | stop_token[-1] = 1. 
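        # Stop-token target: one entry per mel frame, zero everywhere except the
        # final frame, which marks where the decoder gate should predict "stop".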
36 | 37 | sample = {"text": character, "mel_target": mel_target, 38 | "stop_token": stop_token} 39 | 40 | return sample 41 | 42 | 43 | def process_text(train_text_path): 44 | with open(train_text_path, "r", encoding="utf-8") as f: 45 | txt = [] 46 | for line in f.readlines(): 47 | txt.append(line) 48 | 49 | return txt 50 | 51 | 52 | def _process(batch, cut_list): 53 | texts = [batch[ind]["text"] for ind in cut_list] 54 | mel_targets = [batch[ind]["mel_target"] for ind in cut_list] 55 | stop_tokens = [batch[ind]["stop_token"] for ind in cut_list] 56 | 57 | length_text = np.array([]) 58 | for text in texts: 59 | length_text = np.append(length_text, text.shape[0]) 60 | 61 | length_mel = np.array([]) 62 | for mel in mel_targets: 63 | length_mel = np.append(length_mel, mel.shape[0]) 64 | 65 | texts = pad_normal(texts) 66 | stop_tokens = pad_normal(stop_tokens, PAD=1.) 67 | mel_targets = pad_mel(mel_targets) 68 | 69 | out = {"text": texts, "mel_target": mel_targets, "stop_token": stop_tokens, 70 | "length_mel": length_mel, "length_text": length_text} 71 | 72 | return out 73 | 74 | 75 | def collate_fn(batch): 76 | len_arr = np.array([d["text"].shape[0] for d in batch]) 77 | index_arr = np.argsort(-len_arr) 78 | batchsize = len(batch) 79 | real_batchsize = int(math.sqrt(batchsize)) 80 | 81 | cut_list = list() 82 | for i in range(real_batchsize): 83 | cut_list.append(index_arr[i*real_batchsize:(i+1)*real_batchsize]) 84 | 85 | output = list() 86 | for i in range(real_batchsize): 87 | output.append(_process(batch, cut_list[i])) 88 | 89 | return output 90 | 91 | 92 | def pad_normal(inputs, PAD=0): 93 | 94 | def pad_data(x, length, PAD): 95 | x_padded = np.pad( 96 | x, (0, length - x.shape[0]), mode='constant', constant_values=PAD) 97 | return x_padded 98 | 99 | max_len = max((len(x) for x in inputs)) 100 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 101 | 102 | return padded 103 | 104 | 105 | def pad_mel(inputs): 106 | 107 | def pad(x, max_len): 108 | PAD = 0 109 | if np.shape(x)[0] > max_len: 110 | raise ValueError("not max_len") 111 | 112 | s = np.shape(x)[1] 113 | x_padded = np.pad(x, (0, max_len - np.shape(x) 114 | [0]), mode='constant', constant_values=PAD) 115 | return x_padded[:, :s] 116 | 117 | max_len = max(np.shape(x)[0] for x in inputs) 118 | mel_output = np.stack([pad(x, max_len) for x in inputs]) 119 | 120 | return mel_output 121 | -------------------------------------------------------------------------------- /glow.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 
14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import copy 28 | import torch 29 | from torch.autograd import Variable 30 | import torch.nn.functional as F 31 | 32 | 33 | @torch.jit.script 34 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 35 | n_channels_int = n_channels[0] 36 | in_act = input_a+input_b 37 | t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :]) 38 | s_act = torch.nn.functional.sigmoid(in_act[:, n_channels_int:, :]) 39 | acts = t_act * s_act 40 | return acts 41 | 42 | 43 | class WaveGlowLoss(torch.nn.Module): 44 | def __init__(self, sigma=1.0): 45 | super(WaveGlowLoss, self).__init__() 46 | self.sigma = sigma 47 | 48 | def forward(self, model_output): 49 | z, log_s_list, log_det_W_list = model_output 50 | for i, log_s in enumerate(log_s_list): 51 | if i == 0: 52 | log_s_total = torch.sum(log_s) 53 | log_det_W_total = log_det_W_list[i] 54 | else: 55 | log_s_total = log_s_total + torch.sum(log_s) 56 | log_det_W_total += log_det_W_list[i] 57 | 58 | loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total 59 | return loss/(z.size(0)*z.size(1)*z.size(2)) 60 | 61 | 62 | class Invertible1x1Conv(torch.nn.Module): 63 | """ 64 | The layer outputs both the convolution, and the log determinant 65 | of its weight matrix. If reverse=True it does convolution with 66 | inverse 67 | """ 68 | def __init__(self, c): 69 | super(Invertible1x1Conv, self).__init__() 70 | self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, 71 | bias=False) 72 | 73 | # Sample a random orthonormal matrix to initialize weights 74 | W = torch.qr(torch.FloatTensor(c, c).normal_())[0] 75 | 76 | # Ensure determinant is 1.0 not -1.0 77 | if torch.det(W) < 0: 78 | W[:,0] = -1*W[:,0] 79 | W = W.view(c, c, 1) 80 | self.conv.weight.data = W 81 | 82 | def forward(self, z, reverse=False): 83 | # shape 84 | batch_size, group_size, n_of_groups = z.size() 85 | 86 | W = self.conv.weight.squeeze() 87 | 88 | if reverse: 89 | if not hasattr(self, 'W_inverse'): 90 | # Reverse computation 91 | W_inverse = W.inverse() 92 | W_inverse = Variable(W_inverse[..., None]) 93 | if z.type() == 'torch.cuda.HalfTensor': 94 | W_inverse = W_inverse.half() 95 | self.W_inverse = W_inverse 96 | z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) 97 | return z 98 | else: 99 | # Forward computation 100 | log_det_W = batch_size * n_of_groups * torch.logdet(W) 101 | z = self.conv(z) 102 | return z, log_det_W 103 | 104 | 105 | class WN(torch.nn.Module): 106 | """ 107 | This is the WaveNet like layer for the affine coupling. The primary difference 108 | from WaveNet is the convolutions need not be causal. 
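    As in WaveNet, each layer uses a gated tanh/sigmoid activation, here
    conditioned on the mel-spectrogram.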
There is also no dilation 109 | size reset. The dilation only doubles on each layer 110 | """ 111 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, 112 | kernel_size): 113 | super(WN, self).__init__() 114 | assert(kernel_size % 2 == 1) 115 | assert(n_channels % 2 == 0) 116 | self.n_layers = n_layers 117 | self.n_channels = n_channels 118 | self.in_layers = torch.nn.ModuleList() 119 | self.res_skip_layers = torch.nn.ModuleList() 120 | self.cond_layers = torch.nn.ModuleList() 121 | 122 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1) 123 | start = torch.nn.utils.weight_norm(start, name='weight') 124 | self.start = start 125 | 126 | # Initializing last layer to 0 makes the affine coupling layers 127 | # do nothing at first. This helps with training stability 128 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1) 129 | end.weight.data.zero_() 130 | end.bias.data.zero_() 131 | self.end = end 132 | 133 | for i in range(n_layers): 134 | dilation = 2 ** i 135 | padding = int((kernel_size*dilation - dilation)/2) 136 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size, 137 | dilation=dilation, padding=padding) 138 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 139 | self.in_layers.append(in_layer) 140 | 141 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1) 142 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 143 | self.cond_layers.append(cond_layer) 144 | 145 | # last one is not necessary 146 | if i < n_layers - 1: 147 | res_skip_channels = 2*n_channels 148 | else: 149 | res_skip_channels = n_channels 150 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 151 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 152 | self.res_skip_layers.append(res_skip_layer) 153 | 154 | def forward(self, forward_input): 155 | audio, spect = forward_input 156 | audio = self.start(audio) 157 | 158 | for i in range(self.n_layers): 159 | acts = fused_add_tanh_sigmoid_multiply( 160 | self.in_layers[i](audio), 161 | self.cond_layers[i](spect), 162 | torch.IntTensor([self.n_channels])) 163 | 164 | res_skip_acts = self.res_skip_layers[i](acts) 165 | if i < self.n_layers - 1: 166 | audio = res_skip_acts[:,:self.n_channels,:] + audio 167 | skip_acts = res_skip_acts[:,self.n_channels:,:] 168 | else: 169 | skip_acts = res_skip_acts 170 | 171 | if i == 0: 172 | output = skip_acts 173 | else: 174 | output = skip_acts + output 175 | return self.end(output) 176 | 177 | 178 | class WaveGlow(torch.nn.Module): 179 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, 180 | n_early_size, WN_config): 181 | super(WaveGlow, self).__init__() 182 | 183 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, 184 | n_mel_channels, 185 | 1024, stride=256) 186 | assert(n_group % 2 == 0) 187 | self.n_flows = n_flows 188 | self.n_group = n_group 189 | self.n_early_every = n_early_every 190 | self.n_early_size = n_early_size 191 | self.WN = torch.nn.ModuleList() 192 | self.convinv = torch.nn.ModuleList() 193 | 194 | n_half = int(n_group/2) 195 | 196 | # Set up layers with the right sizes based on how many dimensions 197 | # have been output already 198 | n_remaining_channels = n_group 199 | for k in range(n_flows): 200 | if k % self.n_early_every == 0 and k > 0: 201 | n_half = n_half - int(self.n_early_size/2) 202 | n_remaining_channels = n_remaining_channels - self.n_early_size 203 | self.convinv.append(Invertible1x1Conv(n_remaining_channels)) 204 | self.WN.append(WN(n_half, 
n_mel_channels*n_group, **WN_config)) 205 | self.n_remaining_channels = n_remaining_channels # Useful during inference 206 | 207 | def forward(self, forward_input): 208 | """ 209 | forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames 210 | forward_input[1] = audio: batch x time 211 | """ 212 | spect, audio = forward_input 213 | 214 | # Upsample spectrogram to size of audio 215 | spect = self.upsample(spect) 216 | assert(spect.size(2) >= audio.size(1)) 217 | if spect.size(2) > audio.size(1): 218 | spect = spect[:, :, :audio.size(1)] 219 | 220 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 221 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 222 | 223 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1) 224 | output_audio = [] 225 | log_s_list = [] 226 | log_det_W_list = [] 227 | 228 | for k in range(self.n_flows): 229 | if k % self.n_early_every == 0 and k > 0: 230 | output_audio.append(audio[:,:self.n_early_size,:]) 231 | audio = audio[:,self.n_early_size:,:] 232 | 233 | audio, log_det_W = self.convinv[k](audio) 234 | log_det_W_list.append(log_det_W) 235 | 236 | n_half = int(audio.size(1)/2) 237 | audio_0 = audio[:,:n_half,:] 238 | audio_1 = audio[:,n_half:,:] 239 | 240 | output = self.WN[k]((audio_0, spect)) 241 | log_s = output[:, n_half:, :] 242 | b = output[:, :n_half, :] 243 | audio_1 = torch.exp(log_s)*audio_1 + b 244 | log_s_list.append(log_s) 245 | 246 | audio = torch.cat([audio_0, audio_1],1) 247 | 248 | output_audio.append(audio) 249 | return torch.cat(output_audio,1), log_s_list, log_det_W_list 250 | 251 | def infer(self, spect, sigma=1.0): 252 | spect = self.upsample(spect) 253 | # trim conv artifacts. maybe pad spec to kernel multiple 254 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] 255 | spect = spect[:, :, :-time_cutoff] 256 | 257 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 258 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 259 | 260 | if spect.type() == 'torch.cuda.HalfTensor': 261 | audio = torch.cuda.HalfTensor(spect.size(0), 262 | self.n_remaining_channels, 263 | spect.size(2)).normal_() 264 | else: 265 | audio = torch.cuda.FloatTensor(spect.size(0), 266 | self.n_remaining_channels, 267 | spect.size(2)).normal_() 268 | 269 | audio = torch.autograd.Variable(sigma*audio) 270 | 271 | for k in reversed(range(self.n_flows)): 272 | n_half = int(audio.size(1)/2) 273 | audio_0 = audio[:,:n_half,:] 274 | audio_1 = audio[:,n_half:,:] 275 | 276 | output = self.WN[k]((audio_0, spect)) 277 | s = output[:, n_half:, :] 278 | b = output[:, :n_half, :] 279 | audio_1 = (audio_1 - b)/torch.exp(s) 280 | audio = torch.cat([audio_0, audio_1],1) 281 | 282 | audio = self.convinv[k](audio, reverse=True) 283 | 284 | if k % self.n_early_every == 0 and k > 0: 285 | if spect.type() == 'torch.cuda.HalfTensor': 286 | z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() 287 | else: 288 | z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() 289 | audio = torch.cat((sigma*z, audio),1) 290 | 291 | audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data 292 | return audio 293 | 294 | @staticmethod 295 | def remove_weightnorm(model): 296 | waveglow = model 297 | for WN in waveglow.WN: 298 | WN.start = torch.nn.utils.remove_weight_norm(WN.start) 299 | WN.in_layers = remove(WN.in_layers) 300 | WN.cond_layers = remove(WN.cond_layers) 301 | 
WN.res_skip_layers = remove(WN.res_skip_layers) 302 | return waveglow 303 | 304 | 305 | def remove(conv_list): 306 | new_conv_list = torch.nn.ModuleList() 307 | for old_conv in conv_list: 308 | old_conv = torch.nn.utils.remove_weight_norm(old_conv) 309 | new_conv_list.append(old_conv) 310 | return new_conv_list 311 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from text import symbols 2 | 3 | # Audio: 4 | num_mels = 80 5 | 6 | # Text 7 | text_cleaners = ['english_cleaners'] 8 | 9 | # Mel 10 | n_mel_channels = 80 11 | n_frames_per_step = 1 12 | 13 | # PostNet 14 | postnet_embedding_dim = 512 15 | postnet_kernel_size = 5 16 | postnet_n_convolutions = 5 17 | 18 | # PreNet 19 | prenet_dim = 256 20 | 21 | # Encoder 22 | encoder_n_convolutions = 3 23 | encoder_embedding_dim = 512 24 | # encoder_embedding_dim = 1024 25 | encoder_kernel_size = 5 26 | 27 | # Decoder 28 | attention_rnn_dim = 1024 29 | decoder_rnn_dim = 1024 30 | max_decoder_steps = 1000 31 | gate_threshold = 0.5 32 | p_attention_dropout = 0.1 33 | p_decoder_dropout = 0.1 34 | attention_location_kernel_size = 31 35 | attention_location_n_filters = 32 36 | attention_dim = 128 37 | 38 | # Tacotron2 39 | mask_padding = True 40 | n_symbols = len(symbols) 41 | symbols_embedding_dim = 512 42 | # symbols_embedding_dim = 1024 43 | 44 | # Train 45 | batch_size = 32 46 | epochs = 10000 47 | dataset_path = "./dataset" 48 | checkpoint_path = "./model_new" 49 | logger_path = "./logger" 50 | learning_rate = 1e-3 51 | weight_decay = 1e-6 52 | grad_clip_thresh = 1.0 53 | decay_step = [500000, 1000000, 2000000] 54 | save_step = 200 55 | log_step = 5 56 | clear_Time = 20 57 | -------------------------------------------------------------------------------- /img/model_test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcmyz/Tacotron2-Pytorch/a2d8b4712499e5696338521b10807a244df92922/img/model_test.jpg -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import os 6 | import numpy as np 7 | 8 | import waveglow 9 | import glow 10 | from network import Tacotron2 11 | from text import text_to_sequence 12 | import hparams as hp 13 | import Audio 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def plot_data(data, figsize=(12, 4)): 19 | _, axes = plt.subplots(1, len(data), figsize=figsize) 20 | for i in range(len(data)): 21 | axes[i].imshow(data[i], aspect='auto', 22 | origin='bottom', interpolation='none') 23 | 24 | if not os.path.exists("img"): 25 | os.mkdir("img") 26 | plt.savefig(os.path.join("img", "model_test.jpg")) 27 | 28 | 29 | def get_model(num): 30 | checkpoint_path = "checkpoint_" + str(num) + ".pth.tar" 31 | model = nn.DataParallel(Tacotron2(hp)).to(device) 32 | model.load_state_dict(torch.load(os.path.join( 33 | hp.checkpoint_path, checkpoint_path))['model']) 34 | model.eval() 35 | 36 | return model 37 | 38 | 39 | def synthesis(model, text): 40 | with torch.no_grad(): 41 | sequence = np.array(text_to_sequence( 42 | text, ['english_cleaners']))[None, :] 43 | sequence = torch.autograd.Variable( 44 | torch.from_numpy(sequence)).cuda().long() 45 | 46 | mel_outputs, 
mel_outputs_postnet, _, alignments = model.module.inference( 47 | sequence) 48 | 49 | return mel_outputs[0].cpu(), mel_outputs_postnet[0].cpu(), mel_outputs_postnet 50 | 51 | 52 | if __name__ == "__main__": 53 | # Test 54 | num = 76000 55 | model = get_model(num) 56 | # checkpoint_path = "tacotron2_statedict.pt" 57 | # model = nn.DataParallel(Tacotron2(hp)).to(device) 58 | # model.load_state_dict(torch.load(os.path.join( 59 | # hp.checkpoint_path, checkpoint_path))['state_dict']) 60 | # model.eval() 61 | text = "I am very happy to see you again!" 62 | mel, mel_postnet, mel_postnet_torch = synthesis(model, text) 63 | if not os.path.exists("results"): 64 | os.mkdir("results") 65 | Audio.tools.inv_mel_spec(mel_postnet, os.path.join( 66 | "results", text + str(num) + "griffin_lim.wav")) 67 | plot_data([mel.numpy(), mel_postnet.numpy()]) 68 | 69 | waveglow_path = os.path.join("waveglow", "pre_trained_model") 70 | waveglow_path = os.path.join(waveglow_path, "waveglow_256channels.pt") 71 | wave_glow = torch.load(waveglow_path)['model'] 72 | wave_glow = wave_glow.remove_weightnorm(wave_glow) 73 | wave_glow.cuda().eval() 74 | for m in wave_glow.modules(): 75 | if 'Conv' in str(type(m)): 76 | setattr(m, 'padding_mode', 'zeros') 77 | waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join( 78 | "results", text + str(num) + "waveglow.wav")) 79 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LinearNorm(torch.nn.Module): 5 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 6 | super(LinearNorm, self).__init__() 7 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 8 | 9 | torch.nn.init.xavier_uniform_( 10 | self.linear_layer.weight, 11 | gain=torch.nn.init.calculate_gain(w_init_gain)) 12 | 13 | def forward(self, x): 14 | return self.linear_layer(x) 15 | 16 | 17 | class ConvNorm(torch.nn.Module): 18 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 19 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 20 | super(ConvNorm, self).__init__() 21 | if padding is None: 22 | assert(kernel_size % 2 == 1) 23 | padding = int(dilation * (kernel_size - 1) / 2) 24 | 25 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 26 | kernel_size=kernel_size, stride=stride, 27 | padding=padding, dilation=dilation, 28 | bias=bias) 29 | 30 | torch.nn.init.xavier_uniform_( 31 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 32 | 33 | def forward(self, signal): 34 | conv_signal = self.conv(signal) 35 | return conv_signal 36 | -------------------------------------------------------------------------------- /loss_function.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Tacotron2Loss(nn.Module): 5 | """Tacotron2 Loss""" 6 | 7 | def __init__(self): 8 | super(Tacotron2Loss, self).__init__() 9 | 10 | def forward(self, model_output, targets): 11 | mel_target, gate_target = targets[0], targets[1] 12 | mel_target.requires_grad = False 13 | gate_target.requires_grad = False 14 | 15 | mel_out, mel_out_postnet, gate_out, _ = model_output 16 | # print(mel_out.size()) 17 | 18 | mel_loss = nn.MSELoss()(mel_out, mel_target) 19 | mel_postnet_loss = nn.MSELoss()(mel_out_postnet, mel_target) 20 | gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 21 | 22 | return mel_loss, mel_postnet_loss, 
gate_loss 23 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from math import sqrt 7 | 8 | from layers import ConvNorm, LinearNorm 9 | from utils import to_gpu, get_mask_from_lengths 10 | import hparams 11 | 12 | 13 | class LocationLayer(nn.Module): 14 | def __init__(self, attention_n_filters, attention_kernel_size, 15 | attention_dim): 16 | super(LocationLayer, self).__init__() 17 | padding = int((attention_kernel_size - 1) / 2) 18 | self.location_conv = ConvNorm(2, attention_n_filters, 19 | kernel_size=attention_kernel_size, 20 | padding=padding, bias=False, stride=1, 21 | dilation=1) 22 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 23 | bias=False, w_init_gain='tanh') 24 | 25 | def forward(self, attention_weights_cat): 26 | processed_attention = self.location_conv(attention_weights_cat) 27 | processed_attention = processed_attention.transpose(1, 2) 28 | processed_attention = self.location_dense(processed_attention) 29 | return processed_attention 30 | 31 | 32 | class Attention(nn.Module): 33 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 34 | attention_location_n_filters, attention_location_kernel_size): 35 | super(Attention, self).__init__() 36 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 37 | bias=False, w_init_gain='tanh') 38 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 39 | w_init_gain='tanh') 40 | self.v = LinearNorm(attention_dim, 1, bias=False) 41 | self.location_layer = LocationLayer(attention_location_n_filters, 42 | attention_location_kernel_size, 43 | attention_dim) 44 | self.score_mask_value = -float("inf") 45 | 46 | def get_alignment_energies(self, query, processed_memory, 47 | attention_weights_cat): 48 | """ 49 | PARAMS 50 | ------ 51 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 52 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 53 | attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time) 54 | 55 | RETURNS 56 | ------- 57 | alignment (batch, max_time) 58 | """ 59 | 60 | processed_query = self.query_layer(query.unsqueeze(1)) 61 | processed_attention_weights = self.location_layer( 62 | attention_weights_cat) 63 | energies = self.v(torch.tanh( 64 | processed_query + processed_attention_weights + processed_memory)) 65 | 66 | energies = energies.squeeze(-1) 67 | return energies 68 | 69 | def forward(self, attention_hidden_state, memory, processed_memory, 70 | attention_weights_cat, mask): 71 | """ 72 | PARAMS 73 | ------ 74 | attention_hidden_state: attention rnn last output 75 | memory: encoder outputs 76 | processed_memory: processed encoder outputs 77 | attention_weights_cat: previous and cummulative attention weights 78 | mask: binary mask for padded data 79 | """ 80 | 81 | alignment = self.get_alignment_energies( 82 | attention_hidden_state, processed_memory, attention_weights_cat) 83 | 84 | if mask is not None: 85 | alignment.data.masked_fill_(mask, self.score_mask_value) 86 | 87 | attention_weights = F.softmax(alignment, dim=1) 88 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 89 | attention_context = attention_context.squeeze(1) 90 | 91 | return attention_context, attention_weights 92 | 93 | 94 | class Prenet(nn.Module): 95 | def __init__(self, in_dim, sizes): 96 | super(Prenet, self).__init__() 97 | in_sizes = [in_dim] + sizes[:-1] 98 | self.layers = nn.ModuleList( 99 | [LinearNorm(in_size, out_size, bias=False) 100 | for (in_size, out_size) in zip(in_sizes, sizes)]) 101 | 102 | def forward(self, x): 103 | for linear in self.layers: 104 | x = F.dropout(F.relu(linear(x)), p=0.5, training=True) 105 | return x 106 | 107 | 108 | class Postnet(nn.Module): 109 | """Postnet 110 | - Five 1-d convolution with 512 channels and kernel size 5 111 | """ 112 | 113 | def __init__(self, hparams): 114 | super(Postnet, self).__init__() 115 | self.convolutions = nn.ModuleList() 116 | 117 | self.convolutions.append( 118 | nn.Sequential( 119 | ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim, 120 | kernel_size=hparams.postnet_kernel_size, stride=1, 121 | padding=int((hparams.postnet_kernel_size - 1) / 2), 122 | dilation=1, w_init_gain='tanh'), 123 | nn.BatchNorm1d(hparams.postnet_embedding_dim)) 124 | ) 125 | 126 | for i in range(1, hparams.postnet_n_convolutions - 1): 127 | self.convolutions.append( 128 | nn.Sequential( 129 | ConvNorm(hparams.postnet_embedding_dim, 130 | hparams.postnet_embedding_dim, 131 | kernel_size=hparams.postnet_kernel_size, stride=1, 132 | padding=int( 133 | (hparams.postnet_kernel_size - 1) / 2), 134 | dilation=1, w_init_gain='tanh'), 135 | nn.BatchNorm1d(hparams.postnet_embedding_dim)) 136 | ) 137 | 138 | self.convolutions.append( 139 | nn.Sequential( 140 | ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels, 141 | kernel_size=hparams.postnet_kernel_size, stride=1, 142 | padding=int((hparams.postnet_kernel_size - 1) / 2), 143 | dilation=1, w_init_gain='linear'), 144 | nn.BatchNorm1d(hparams.n_mel_channels)) 145 | ) 146 | 147 | def forward(self, x): 148 | for i in range(len(self.convolutions) - 1): 149 | x = F.dropout(torch.tanh( 150 | self.convolutions[i](x)), 0.5, self.training) 151 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 152 | 153 | return x 154 | 155 | 156 | class Encoder(nn.Module): 157 | """Encoder module: 158 | - Three 1-d convolution banks 159 | - Bidirectional LSTM 160 | """ 161 | 162 | def __init__(self, hparams): 163 | super(Encoder, self).__init__() 164 
| 165 | convolutions = [] 166 | for _ in range(hparams.encoder_n_convolutions): 167 | conv_layer = nn.Sequential( 168 | ConvNorm(hparams.encoder_embedding_dim, 169 | hparams.encoder_embedding_dim, 170 | kernel_size=hparams.encoder_kernel_size, stride=1, 171 | padding=int((hparams.encoder_kernel_size - 1) / 2), 172 | dilation=1, w_init_gain='relu'), 173 | nn.BatchNorm1d(hparams.encoder_embedding_dim)) 174 | convolutions.append(conv_layer) 175 | self.convolutions = nn.ModuleList(convolutions) 176 | 177 | self.lstm = nn.LSTM(hparams.encoder_embedding_dim, 178 | int(hparams.encoder_embedding_dim / 2), 1, 179 | batch_first=True, bidirectional=True) 180 | 181 | def pad_again(self, x, max_len): 182 | # print(x.size()) 183 | out = F.pad(x, (0, 0, 0, max_len-x.size(1))) 184 | # print(out.size()) 185 | # print(out == x) 186 | return out 187 | 188 | def forward(self, x, input_lengths, max_len): 189 | for conv in self.convolutions: 190 | x = F.dropout(F.relu(conv(x)), 0.5, self.training) 191 | 192 | x = x.transpose(1, 2) 193 | 194 | # pytorch tensor are not reversible, hence the conversion 195 | input_lengths = input_lengths.cpu().numpy() 196 | x = nn.utils.rnn.pack_padded_sequence( 197 | x, input_lengths, batch_first=True) 198 | 199 | self.lstm.flatten_parameters() 200 | outputs, _ = self.lstm(x) 201 | 202 | outputs, _ = nn.utils.rnn.pad_packed_sequence( 203 | outputs, batch_first=True) 204 | outputs = self.pad_again(outputs, max_len) 205 | 206 | return outputs 207 | 208 | def inference(self, x): 209 | for conv in self.convolutions: 210 | x = F.dropout(F.relu(conv(x)), 0.5, self.training) 211 | 212 | x = x.transpose(1, 2) 213 | 214 | self.lstm.flatten_parameters() 215 | outputs, _ = self.lstm(x) 216 | 217 | return outputs 218 | 219 | 220 | class Decoder(nn.Module): 221 | def __init__(self, hparams): 222 | super(Decoder, self).__init__() 223 | self.n_mel_channels = hparams.n_mel_channels 224 | self.n_frames_per_step = hparams.n_frames_per_step 225 | self.encoder_embedding_dim = hparams.encoder_embedding_dim 226 | self.attention_rnn_dim = hparams.attention_rnn_dim 227 | self.decoder_rnn_dim = hparams.decoder_rnn_dim 228 | self.prenet_dim = hparams.prenet_dim 229 | self.max_decoder_steps = hparams.max_decoder_steps 230 | self.gate_threshold = hparams.gate_threshold 231 | self.p_attention_dropout = hparams.p_attention_dropout 232 | self.p_decoder_dropout = hparams.p_decoder_dropout 233 | 234 | self.prenet = Prenet( 235 | hparams.n_mel_channels * hparams.n_frames_per_step, 236 | [hparams.prenet_dim, hparams.prenet_dim]) 237 | 238 | self.attention_rnn = nn.LSTMCell( 239 | hparams.prenet_dim + hparams.encoder_embedding_dim, 240 | hparams.attention_rnn_dim) 241 | 242 | self.attention_layer = Attention( 243 | hparams.attention_rnn_dim, hparams.encoder_embedding_dim, 244 | hparams.attention_dim, hparams.attention_location_n_filters, 245 | hparams.attention_location_kernel_size) 246 | 247 | self.decoder_rnn = nn.LSTMCell( 248 | hparams.attention_rnn_dim + hparams.encoder_embedding_dim, 249 | hparams.decoder_rnn_dim, 1) 250 | 251 | self.linear_projection = LinearNorm( 252 | hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 253 | hparams.n_mel_channels * hparams.n_frames_per_step) 254 | 255 | self.gate_layer = LinearNorm( 256 | hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1, 257 | bias=True, w_init_gain='sigmoid') 258 | 259 | def get_go_frame(self, memory): 260 | """ Gets all zeros frames to use as first decoder input 261 | PARAMS 262 | ------ 263 | memory: decoder outputs 264 | 265 | RETURNS 
266 | ------- 267 | decoder_input: all zeros frames 268 | """ 269 | 270 | B = memory.size(0) 271 | decoder_input = Variable(memory.data.new( 272 | B, self.n_mel_channels * self.n_frames_per_step).zero_()) 273 | return decoder_input 274 | 275 | def initialize_decoder_states(self, memory, mask): 276 | """ Initializes attention rnn states, decoder rnn states, attention 277 | weights, attention cumulative weights, attention context, stores memory 278 | and stores processed memory 279 | PARAMS 280 | ------ 281 | memory: Encoder outputs 282 | mask: Mask for padded data if training, expects None for inference 283 | """ 284 | 285 | B = memory.size(0) 286 | MAX_TIME = memory.size(1) 287 | 288 | self.attention_hidden = Variable(memory.data.new( 289 | B, self.attention_rnn_dim).zero_()) 290 | self.attention_cell = Variable(memory.data.new( 291 | B, self.attention_rnn_dim).zero_()) 292 | 293 | self.decoder_hidden = Variable(memory.data.new( 294 | B, self.decoder_rnn_dim).zero_()) 295 | self.decoder_cell = Variable(memory.data.new( 296 | B, self.decoder_rnn_dim).zero_()) 297 | 298 | self.attention_weights = Variable(memory.data.new( 299 | B, MAX_TIME).zero_()) 300 | self.attention_weights_cum = Variable(memory.data.new( 301 | B, MAX_TIME).zero_()) 302 | self.attention_context = Variable(memory.data.new( 303 | B, self.encoder_embedding_dim).zero_()) 304 | 305 | self.memory = memory 306 | self.processed_memory = self.attention_layer.memory_layer(memory) 307 | self.mask = mask 308 | 309 | def parse_decoder_inputs(self, decoder_inputs): 310 | """ Prepares decoder inputs, i.e. mel outputs 311 | PARAMS 312 | ------ 313 | decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs 314 | 315 | RETURNS 316 | ------- 317 | inputs: processed decoder inputs 318 | """ 319 | 320 | # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) 321 | decoder_inputs = decoder_inputs.transpose(1, 2) 322 | decoder_inputs = decoder_inputs.view( 323 | decoder_inputs.size(0), 324 | int(decoder_inputs.size(1)/self.n_frames_per_step), -1) 325 | # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) 326 | decoder_inputs = decoder_inputs.transpose(0, 1) 327 | return decoder_inputs 328 | 329 | def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): 330 | """ Prepares decoder outputs for output 331 | PARAMS 332 | ------ 333 | mel_outputs: 334 | gate_outputs: gate output energies 335 | alignments: 336 | 337 | RETURNS 338 | ------- 339 | mel_outputs: 340 | gate_outpust: gate output energies 341 | alignments: 342 | """ 343 | 344 | # (T_out, B) -> (B, T_out) 345 | alignments = torch.stack(alignments).transpose(0, 1) 346 | # (T_out, B) -> (B, T_out) 347 | gate_outputs = torch.stack(gate_outputs).transpose(0, 1) 348 | gate_outputs = gate_outputs.contiguous() 349 | # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) 350 | mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() 351 | # decouple frames per step 352 | mel_outputs = mel_outputs.view( 353 | mel_outputs.size(0), -1, self.n_mel_channels) 354 | # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) 355 | mel_outputs = mel_outputs.transpose(1, 2) 356 | 357 | return mel_outputs, gate_outputs, alignments 358 | 359 | def decode(self, decoder_input): 360 | """ Decoder step using stored states, attention and memory 361 | PARAMS 362 | ------ 363 | decoder_input: previous mel output 364 | 365 | RETURNS 366 | ------- 367 | mel_output: 368 | gate_output: gate output energies 369 | attention_weights: 370 | """ 371 | 372 | cell_input = 
torch.cat((decoder_input, self.attention_context), -1) 373 | self.attention_hidden, self.attention_cell = self.attention_rnn( 374 | cell_input, (self.attention_hidden, self.attention_cell)) 375 | self.attention_hidden = F.dropout( 376 | self.attention_hidden, self.p_attention_dropout, self.training) 377 | self.attention_cell = F.dropout( 378 | self.attention_cell, self.p_attention_dropout, self.training) 379 | 380 | attention_weights_cat = torch.cat( 381 | (self.attention_weights.unsqueeze(1), 382 | self.attention_weights_cum.unsqueeze(1)), dim=1) 383 | self.attention_context, self.attention_weights = self.attention_layer( 384 | self.attention_hidden, self.memory, self.processed_memory, 385 | attention_weights_cat, self.mask) 386 | 387 | self.attention_weights_cum += self.attention_weights 388 | decoder_input = torch.cat( 389 | (self.attention_hidden, self.attention_context), -1) 390 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 391 | decoder_input, (self.decoder_hidden, self.decoder_cell)) 392 | self.decoder_hidden = F.dropout( 393 | self.decoder_hidden, self.p_decoder_dropout, self.training) 394 | self.decoder_cell = F.dropout( 395 | self.decoder_cell, self.p_decoder_dropout, self.training) 396 | 397 | decoder_hidden_attention_context = torch.cat( 398 | (self.decoder_hidden, self.attention_context), dim=1) 399 | decoder_output = self.linear_projection( 400 | decoder_hidden_attention_context) 401 | 402 | gate_prediction = self.gate_layer(decoder_hidden_attention_context) 403 | return decoder_output, gate_prediction, self.attention_weights 404 | 405 | def forward(self, memory, decoder_inputs, memory_lengths, max_len): 406 | """ Decoder forward pass for training 407 | PARAMS 408 | ------ 409 | memory: Encoder outputs 410 | decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs 411 | memory_lengths: Encoder output lengths for attention masking. 
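        max_len: maximum encoder output length, used to size the attention mask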
412 | 413 | RETURNS 414 | ------- 415 | mel_outputs: mel outputs from the decoder 416 | gate_outputs: gate outputs from the decoder 417 | alignments: sequence of attention weights from the decoder 418 | """ 419 | 420 | decoder_input = self.get_go_frame(memory).unsqueeze(0) 421 | # print(decoder_input.size()) 422 | decoder_inputs = self.parse_decoder_inputs(decoder_inputs) 423 | decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) 424 | decoder_inputs = self.prenet(decoder_inputs) 425 | 426 | self.initialize_decoder_states( 427 | memory, mask=~get_mask_from_lengths(memory_lengths, max_len)) 428 | 429 | mel_outputs, gate_outputs, alignments = [], [], [] 430 | while len(mel_outputs) < decoder_inputs.size(0) - 1: 431 | decoder_input = decoder_inputs[len(mel_outputs)] 432 | mel_output, gate_output, attention_weights = self.decode( 433 | decoder_input) 434 | mel_outputs += [mel_output.squeeze(1)] 435 | gate_outputs += [gate_output.squeeze()] 436 | alignments += [attention_weights] 437 | 438 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 439 | mel_outputs, gate_outputs, alignments) 440 | 441 | return mel_outputs, gate_outputs, alignments 442 | 443 | def inference(self, memory): 444 | """ Decoder inference 445 | PARAMS 446 | ------ 447 | memory: Encoder outputs 448 | 449 | RETURNS 450 | ------- 451 | mel_outputs: mel outputs from the decoder 452 | gate_outputs: gate outputs from the decoder 453 | alignments: sequence of attention weights from the decoder 454 | """ 455 | 456 | decoder_input = self.get_go_frame(memory) 457 | 458 | self.initialize_decoder_states(memory, mask=None) 459 | 460 | mel_outputs, gate_outputs, alignments = [], [], [] 461 | while True: 462 | decoder_input = self.prenet(decoder_input) 463 | mel_output, gate_output, alignment = self.decode(decoder_input) 464 | 465 | mel_outputs += [mel_output.squeeze(1)] 466 | gate_outputs += [gate_output] 467 | alignments += [alignment] 468 | 469 | if torch.sigmoid(gate_output.data) > self.gate_threshold: 470 | break 471 | elif len(mel_outputs) == self.max_decoder_steps: 472 | print("Warning! 
Reached max decoder steps") 473 | break 474 | 475 | decoder_input = mel_output 476 | 477 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 478 | mel_outputs, gate_outputs, alignments) 479 | 480 | return mel_outputs, gate_outputs, alignments 481 | 482 | 483 | class Tacotron2(nn.Module): 484 | def __init__(self, hparams): 485 | super(Tacotron2, self).__init__() 486 | self.mask_padding = hparams.mask_padding 487 | self.n_mel_channels = hparams.n_mel_channels 488 | self.n_frames_per_step = hparams.n_frames_per_step 489 | self.embedding = nn.Embedding( 490 | hparams.n_symbols, hparams.symbols_embedding_dim) 491 | std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) 492 | val = sqrt(3.0) * std # uniform bounds for std 493 | self.embedding.weight.data.uniform_(-val, val) 494 | self.encoder = Encoder(hparams) 495 | self.decoder = Decoder(hparams) 496 | self.postnet = Postnet(hparams) 497 | 498 | def parse_batch(self, batch): 499 | text_padded, input_lengths, mel_padded, gate_padded, \ 500 | output_lengths = batch 501 | text_padded = to_gpu(text_padded).long() 502 | input_lengths = to_gpu(input_lengths).long() 503 | max_input_len = torch.max(input_lengths.data).item() 504 | max_output_len = torch.max(output_lengths.data).item() 505 | mel_padded = to_gpu(mel_padded).float() 506 | gate_padded = to_gpu(gate_padded).float() 507 | output_lengths = to_gpu(output_lengths).long() 508 | 509 | return ( 510 | (text_padded, input_lengths, mel_padded, 511 | (max_input_len, max_output_len), output_lengths), 512 | (mel_padded, gate_padded)) 513 | 514 | def parse_output(self, outputs, output_lengths=None, max_len=None): 515 | if self.mask_padding and output_lengths is not None: 516 | mask = ~get_mask_from_lengths(output_lengths, max_len) 517 | mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) 518 | mask = mask.permute(1, 0, 2) 519 | 520 | outputs[0].data.masked_fill_(mask, 0.0) 521 | outputs[1].data.masked_fill_(mask, 0.0) 522 | outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies 523 | 524 | return outputs 525 | 526 | def forward(self, inputs, target_max_len=0): 527 | inputs, input_lengths, targets, (max_input_len, 528 | max_output_len), output_lengths = inputs 529 | input_lengths, output_lengths = input_lengths.data, output_lengths.data 530 | 531 | embedded_inputs = self.embedding(inputs).transpose(1, 2) 532 | # print(embedded_inputs) 533 | 534 | encoder_outputs = self.encoder( 535 | embedded_inputs, input_lengths, max_input_len) 536 | # print(encoder_outputs) 537 | 538 | mel_outputs, gate_outputs, alignments = self.decoder( 539 | encoder_outputs, targets, input_lengths, max_input_len) 540 | 541 | # print(mel_outputs.size()) 542 | # print(gate_outputs.size()) 543 | # print(alignments.size()) 544 | # print(gate_outputs) 545 | 546 | mel_outputs_postnet = self.postnet(mel_outputs) 547 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 548 | 549 | # print(max_len) 550 | 551 | return self.parse_output( 552 | [mel_outputs, mel_outputs_postnet, gate_outputs, alignments], 553 | output_lengths, 554 | max_output_len) 555 | 556 | def inference(self, inputs): 557 | embedded_inputs = self.embedding(inputs).transpose(1, 2) 558 | encoder_outputs = self.encoder.inference(embedded_inputs) 559 | mel_outputs, gate_outputs, alignments = self.decoder.inference( 560 | encoder_outputs) 561 | 562 | mel_outputs_postnet = self.postnet(mel_outputs) 563 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 564 | 565 | outputs = self.parse_output( 566 | 
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments]) 567 | 568 | return outputs 569 | 570 | 571 | if __name__ == "__main__": 572 | 573 | model = Tacotron2(hparams) 574 | print(model) 575 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from data import ljspeech 4 | 5 | 6 | def preprocess_ljspeech(filename): 7 | in_dir = filename 8 | out_dir = "dataset" 9 | if not os.path.exists(out_dir): 10 | os.makedirs(out_dir, exist_ok=True) 11 | metadata = ljspeech.build_from_path(in_dir, out_dir) 12 | write_metadata(metadata, out_dir) 13 | 14 | 15 | def write_metadata(metadata, out_dir): 16 | with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f: 17 | for m in metadata: 18 | f.write(m + '\n') 19 | 20 | 21 | def main(): 22 | path = os.path.join("data", "LJSpeech-1.1") 23 | preprocess_ljspeech(path) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /results/Generative adversarial network or variational auto-encoder.76000waveglow.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcmyz/Tacotron2-Pytorch/a2d8b4712499e5696338521b10807a244df92922/results/Generative adversarial network or variational auto-encoder.76000waveglow.wav -------------------------------------------------------------------------------- /results/I am very happy to see you again!76000griffin_lim.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcmyz/Tacotron2-Pytorch/a2d8b4712499e5696338521b10807a244df92922/results/I am very happy to see you again!76000griffin_lim.wav -------------------------------------------------------------------------------- /results/I am very happy to see you again!76000waveglow.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcmyz/Tacotron2-Pytorch/a2d8b4712499e5696338521b10807a244df92922/results/I am very happy to see you again!76000waveglow.wav -------------------------------------------------------------------------------- /results/Jack is a little goose. He has a lovely hat. He likes wearing it very much. But when he sits, his hat can’t stay on his head.76000waveglow.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcmyz/Tacotron2-Pytorch/a2d8b4712499e5696338521b10807a244df92922/results/Jack is a little goose. He has a lovely hat. He likes wearing it very much. 
But when he sits, his hat can’t stay on his head.76000waveglow.wav -------------------------------------------------------------------------------- /results/Lift humanity with cognitive artificial intelligence platforms.76000waveglow.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcmyz/Tacotron2-Pytorch/a2d8b4712499e5696338521b10807a244df92922/results/Lift humanity with cognitive artificial intelligence platforms.76000waveglow.wav -------------------------------------------------------------------------------- /results/Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.76000waveglow.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcmyz/Tacotron2-Pytorch/a2d8b4712499e5696338521b10807a244df92922/results/Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.76000waveglow.wav -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
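A rough usage sketch, assuming the "english_cleaners" pipeline from text/cleaners.py
    is the cleaner configured for training (the exact integer IDs returned depend on the
    symbol table in text/symbols.py):

      text_to_sequence('Turn left on Houston Street.', ['english_cleaners'])
      text_to_sequence('Turn left on {HH AW1 S S T AH0 N} Street.', ['english_cleaners'])

    Both calls return a list of symbol IDs; in the second call the braced span is mapped
    through _arpabet_to_sequence instead of being cleaned character by character.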
20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | ''' 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | if not m: 34 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 35 | break 36 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 37 | sequence += _arpabet_to_sequence(m.group(2)) 38 | text = m.group(3) 39 | 40 | return sequence 41 | 42 | 43 | def sequence_to_text(sequence): 44 | '''Converts a sequence of IDs back to a string''' 45 | result = '' 46 | for symbol_id in sequence: 47 | if symbol_id in _id_to_symbol: 48 | s = _id_to_symbol[symbol_id] 49 | # Enclose ARPAbet back in curly braces: 50 | if len(s) > 1 and s[0] == '@': 51 | s = '{%s}' % s[1:] 52 | result += s 53 | return result.replace('}{', ' ') 54 | 55 | 56 | def _clean_text(text, cleaner_names): 57 | for name in cleaner_names: 58 | cleaner = getattr(cleaners, name) 59 | if not cleaner: 60 | raise Exception('Unknown cleaner: %s' % name) 61 | text = cleaner(text) 62 | return text 63 | 64 | 65 | def _symbols_to_sequence(symbols): 66 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 67 | 68 | 69 | def _arpabet_to_sequence(text): 70 | return _symbols_to_sequence(['@' + s for s in text.split()]) 71 | 72 | 73 | def _should_keep_symbol(s): 74 | return s in _symbol_to_id and s is not '_' and s is not '~' 75 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from .numbers import normalize_numbers 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including number and abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_numbers(text) 88 | text = expand_abbreviations(text) 89 | text = collapse_whitespace(text) 90 | return text 91 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
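A rough usage sketch (the dictionary file path below is hypothetical, and lookup()
    upper-cases its argument internally):

      cmu = CMUDict('data/cmudict-0.7b')
      cmu.lookup('hello')   # -> ['HH AH0 L OW1'], or None for out-of-vocabulary words

    Reference: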
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = 
re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? ' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | 14 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 15 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 19 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import optim 4 | 5 | from network import Tacotron2 6 | from data_utils import DataLoader, collate_fn 7 | from data_utils import Tacotron2Dataset 8 | from loss_function import Tacotron2Loss 9 | import hparams as hp 10 | 11 | from multiprocessing import cpu_count 12 | import numpy as np 13 | import argparse 14 | import os 15 | import time 16 | 17 | 18 | cuda_available = torch.cuda.is_available() 19 | 20 | 21 | def main(args): 22 | # Get device 23 | device = torch.device('cuda'if torch.cuda.is_available()else 'cpu') 24 | 25 | # Define model 26 | model = nn.DataParallel(Tacotron2(hp)).to(device) 27 | # model = Tacotron2(hp).to(device) 28 | print("Model Have Been Defined") 29 | num_param = sum(param.numel() for param in model.parameters()) 30 | print('Number of Tacotron Parameters:', num_param) 31 | 32 | # Get dataset 33 | dataset = Tacotron2Dataset() 34 | 35 | # Optimizer 36 | optimizer = torch.optim.Adam( 37 | model.parameters(), lr=hp.learning_rate, weight_decay=hp.weight_decay) 38 | 39 | # Criterion 40 | criterion = Tacotron2Loss() 41 | 42 | # Load checkpoint if exists 43 | try: 44 | checkpoint = torch.load(os.path.join( 45 | hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)) 46 | model.load_state_dict(checkpoint['model']) 47 | optimizer.load_state_dict(checkpoint['optimizer']) 48 | print("\n---Model Restored at Step %d---\n" % args.restore_step) 49 | 50 | except: 51 | print("\n---Start New Training---\n") 52 | if not os.path.exists(hp.checkpoint_path): 53 | os.mkdir(hp.checkpoint_path) 54 | 55 | # Init logger 56 | if not os.path.exists(hp.logger_path): 57 | os.mkdir(hp.logger_path) 58 | 59 | # Define Some Information 60 | Time = np.array([]) 61 | Start = time.clock() 62 | 63 | # Training 64 | model = model.train() 65 | 66 | for epoch in range(hp.epochs): 67 | # Get training loader 68 | training_loader = DataLoader(dataset, 69 | batch_size=hp.batch_size**2, 70 | shuffle=True, 71 | collate_fn=collate_fn, 72 | drop_last=True, 73 | 
num_workers=cpu_count()) 74 | total_step = hp.epochs * len(training_loader) * hp.batch_size 75 | 76 | for i, batchs in enumerate(training_loader): 77 | for j, data_of_batch in enumerate(batchs): 78 | start_time = time.clock() 79 | 80 | current_step = i * hp.batch_size + j + args.restore_step + \ 81 | epoch * len(training_loader)*hp.batch_size + 1 82 | 83 | # Init 84 | optimizer.zero_grad() 85 | 86 | # Get Data 87 | character = torch.from_numpy( 88 | data_of_batch["text"]).long().to(device) 89 | mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to( 90 | device).contiguous().transpose(1, 2) 91 | stop_target = torch.from_numpy( 92 | data_of_batch["stop_token"]).float().to(device) 93 | input_lengths = torch.from_numpy( 94 | data_of_batch["length_text"]).int().to(device) 95 | output_lengths = torch.from_numpy( 96 | data_of_batch["length_mel"]).int().to(device) 97 | # print(mel_target.size()) 98 | # print(mel_target) 99 | 100 | # Forward 101 | batch = character, input_lengths, mel_target, stop_target, output_lengths 102 | 103 | x, y = model.module.parse_batch(batch) 104 | y_pred = model(x) 105 | 106 | # Cal Loss 107 | mel_loss, mel_postnet_loss, stop_pred_loss = criterion( 108 | y_pred, y) 109 | total_loss = mel_loss + mel_postnet_loss + stop_pred_loss 110 | 111 | # Logger 112 | t_l = total_loss.item() 113 | m_l = mel_loss.item() 114 | m_p_l = mel_postnet_loss.item() 115 | s_l = stop_pred_loss.item() 116 | 117 | with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss: 118 | f_total_loss.write(str(t_l)+"\n") 119 | 120 | with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss: 121 | f_mel_loss.write(str(m_l)+"\n") 122 | 123 | with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss: 124 | f_mel_postnet_loss.write(str(m_p_l)+"\n") 125 | 126 | with open(os.path.join("logger", "stop_pred_loss.txt"), "a") as f_s_loss: 127 | f_s_loss.write(str(s_l)+"\n") 128 | 129 | # Backward 130 | total_loss.backward() 131 | 132 | # Clipping gradients to avoid gradient explosion 133 | nn.utils.clip_grad_norm_(model.parameters(), 1.) 134 | 135 | # Update weights 136 | optimizer.step() 137 | adjust_learning_rate(optimizer, current_step) 138 | 139 | # Print 140 | if current_step % hp.log_step == 0: 141 | Now = time.clock() 142 | 143 | str1 = "Epoch [{}/{}], Step [{}/{}], Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f};".format( 144 | epoch+1, hp.epochs, current_step, total_step, mel_loss.item(), mel_postnet_loss.item()) 145 | str2 = "Stop Predicted Loss: {:.4f}, Total Loss: {:.4f}.".format( 146 | stop_pred_loss.item(), total_loss.item()) 147 | str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format( 148 | (Now-Start), (total_step-current_step)*np.mean(Time)) 149 | 150 | print("\n" + str1) 151 | print(str2) 152 | print(str3) 153 | 154 | with open(os.path.join("logger", "logger.txt"), "a") as f_logger: 155 | f_logger.write(str1 + "\n") 156 | f_logger.write(str2 + "\n") 157 | f_logger.write(str3 + "\n") 158 | f_logger.write("\n") 159 | 160 | if current_step % hp.save_step == 0: 161 | torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict( 162 | )}, os.path.join(hp.checkpoint_path, 'checkpoint_%d.pth.tar' % current_step)) 163 | print("save model at step %d ..." 
% current_step) 164 | 165 | end_time = time.clock() 166 | Time = np.append(Time, end_time - start_time) 167 | if len(Time) == hp.clear_Time: 168 | temp_value = np.mean(Time) 169 | Time = np.delete( 170 | Time, [i for i in range(len(Time))], axis=None) 171 | Time = np.append(Time, temp_value) 172 | 173 | 174 | def adjust_learning_rate(optimizer, step): 175 | if step == 500000: 176 | for param_group in optimizer.param_groups: 177 | param_group['lr'] = 0.0005 178 | 179 | elif step == 1000000: 180 | for param_group in optimizer.param_groups: 181 | param_group['lr'] = 0.0003 182 | 183 | elif step == 2000000: 184 | for param_group in optimizer.param_groups: 185 | param_group['lr'] = 0.0001 186 | 187 | return optimizer 188 | 189 | 190 | if __name__ == "__main__": 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument('--restore_step', type=int, default=0) 193 | args = parser.parse_args() 194 | 195 | main(args) 196 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io.wavfile import read 3 | import torch 4 | 5 | 6 | def get_mask_from_lengths(lengths, max_len=None): 7 | if max_len == None: 8 | max_len = torch.max(lengths).item() 9 | 10 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 11 | mask = (ids < lengths.unsqueeze(1)).byte() 12 | 13 | return mask 14 | 15 | 16 | def to_gpu(x): 17 | x = x.contiguous() 18 | 19 | if torch.cuda.is_available(): 20 | x = x.cuda(non_blocking=True) 21 | return torch.autograd.Variable(x) 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | out = get_mask_from_lengths(torch.arange( 27 | 0, 12, out=torch.cuda.LongTensor(12))) 28 | print(~out) 29 | -------------------------------------------------------------------------------- /waveglow/__init__.py: -------------------------------------------------------------------------------- 1 | import waveglow.inference 2 | import waveglow.mel2samp 3 | import waveglow.glow 4 | -------------------------------------------------------------------------------- /waveglow/convert_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | 5 | def _check_model_old_version(model): 6 | if hasattr(model.WN[0], 'res_layers'): 7 | return True 8 | else: 9 | return False 10 | 11 | def update_model(old_model): 12 | if not _check_model_old_version(old_model): 13 | return old_model 14 | new_model = copy.deepcopy(old_model) 15 | for idx in range(0, len(new_model.WN)): 16 | wavenet = new_model.WN[idx] 17 | wavenet.res_skip_layers = torch.nn.ModuleList() 18 | n_channels = wavenet.n_channels 19 | n_layers = wavenet.n_layers 20 | for i in range(0, n_layers): 21 | if i < n_layers - 1: 22 | res_skip_channels = 2*n_channels 23 | else: 24 | res_skip_channels = n_channels 25 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 26 | skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i]) 27 | if i < n_layers - 1: 28 | res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i]) 29 | res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight])) 30 | res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias])) 31 | else: 32 | res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight) 33 | res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias) 34 | res_skip_layer = 
torch.nn.utils.weight_norm(res_skip_layer, name='weight') 35 | wavenet.res_skip_layers.append(res_skip_layer) 36 | del wavenet.res_layers 37 | del wavenet.skip_layers 38 | return new_model 39 | 40 | if __name__ == '__main__': 41 | old_model_path = sys.argv[1] 42 | new_model_path = sys.argv[2] 43 | model = torch.load(old_model_path) 44 | model['model'] = update_model(model['model']) 45 | torch.save(model, new_model_path) 46 | 47 | -------------------------------------------------------------------------------- /waveglow/glow.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
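(The checkpoint converter above is a small CLI: it fuses each WN block's separate
res_layers/skip_layers convolutions into the res_skip_layers layout expected by the
current glow.py and writes a new checkpoint, e.g.
"python waveglow/convert_model.py old_waveglow.pt new_waveglow.pt", where both .pt
names are only placeholders.)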
25 | # 26 | # ***************************************************************************** 27 | import copy 28 | import torch 29 | from torch.autograd import Variable 30 | import torch.nn.functional as F 31 | 32 | 33 | @torch.jit.script 34 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 35 | n_channels_int = n_channels[0] 36 | in_act = input_a+input_b 37 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 38 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 39 | acts = t_act * s_act 40 | return acts 41 | 42 | 43 | class WaveGlowLoss(torch.nn.Module): 44 | def __init__(self, sigma=1.0): 45 | super(WaveGlowLoss, self).__init__() 46 | self.sigma = sigma 47 | 48 | def forward(self, model_output): 49 | z, log_s_list, log_det_W_list = model_output 50 | for i, log_s in enumerate(log_s_list): 51 | if i == 0: 52 | log_s_total = torch.sum(log_s) 53 | log_det_W_total = log_det_W_list[i] 54 | else: 55 | log_s_total = log_s_total + torch.sum(log_s) 56 | log_det_W_total += log_det_W_list[i] 57 | 58 | loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total 59 | return loss/(z.size(0)*z.size(1)*z.size(2)) 60 | 61 | 62 | class Invertible1x1Conv(torch.nn.Module): 63 | """ 64 | The layer outputs both the convolution, and the log determinant 65 | of its weight matrix. If reverse=True it does convolution with 66 | inverse 67 | """ 68 | def __init__(self, c): 69 | super(Invertible1x1Conv, self).__init__() 70 | self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, 71 | bias=False) 72 | 73 | # Sample a random orthonormal matrix to initialize weights 74 | W = torch.qr(torch.FloatTensor(c, c).normal_())[0] 75 | 76 | # Ensure determinant is 1.0 not -1.0 77 | if torch.det(W) < 0: 78 | W[:,0] = -1*W[:,0] 79 | W = W.view(c, c, 1) 80 | self.conv.weight.data = W 81 | 82 | def forward(self, z, reverse=False): 83 | # shape 84 | batch_size, group_size, n_of_groups = z.size() 85 | 86 | W = self.conv.weight.squeeze() 87 | 88 | if reverse: 89 | if not hasattr(self, 'W_inverse'): 90 | # Reverse computation 91 | W_inverse = W.float().inverse() 92 | W_inverse = Variable(W_inverse[..., None]) 93 | if z.type() == 'torch.cuda.HalfTensor': 94 | W_inverse = W_inverse.half() 95 | self.W_inverse = W_inverse 96 | z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) 97 | return z 98 | else: 99 | # Forward computation 100 | log_det_W = batch_size * n_of_groups * torch.logdet(W) 101 | z = self.conv(z) 102 | return z, log_det_W 103 | 104 | 105 | class WN(torch.nn.Module): 106 | """ 107 | This is the WaveNet like layer for the affine coupling. The primary difference 108 | from WaveNet is the convolutions need not be causal. There is also no dilation 109 | size reset. The dilation only doubles on each layer 110 | """ 111 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, 112 | kernel_size): 113 | super(WN, self).__init__() 114 | assert(kernel_size % 2 == 1) 115 | assert(n_channels % 2 == 0) 116 | self.n_layers = n_layers 117 | self.n_channels = n_channels 118 | self.in_layers = torch.nn.ModuleList() 119 | self.res_skip_layers = torch.nn.ModuleList() 120 | self.cond_layers = torch.nn.ModuleList() 121 | 122 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1) 123 | start = torch.nn.utils.weight_norm(start, name='weight') 124 | self.start = start 125 | 126 | # Initializing last layer to 0 makes the affine coupling layers 127 | # do nothing at first. 
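# (A rough sketch of why, in terms of the affine coupling applied in WaveGlow.forward
# below: with end.weight and end.bias zeroed, WN returns all zeros, so log_s = 0 and
# b = 0, and audio_1 = exp(log_s) * audio_1 + b leaves audio_1 unchanged, i.e. the
# coupling starts out as an identity transform.)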
This helps with training stability 128 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1) 129 | end.weight.data.zero_() 130 | end.bias.data.zero_() 131 | self.end = end 132 | 133 | for i in range(n_layers): 134 | dilation = 2 ** i 135 | padding = int((kernel_size*dilation - dilation)/2) 136 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size, 137 | dilation=dilation, padding=padding) 138 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 139 | self.in_layers.append(in_layer) 140 | 141 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1) 142 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 143 | self.cond_layers.append(cond_layer) 144 | 145 | # last one is not necessary 146 | if i < n_layers - 1: 147 | res_skip_channels = 2*n_channels 148 | else: 149 | res_skip_channels = n_channels 150 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 151 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 152 | self.res_skip_layers.append(res_skip_layer) 153 | 154 | def forward(self, forward_input): 155 | audio, spect = forward_input 156 | audio = self.start(audio) 157 | 158 | for i in range(self.n_layers): 159 | acts = fused_add_tanh_sigmoid_multiply( 160 | self.in_layers[i](audio), 161 | self.cond_layers[i](spect), 162 | torch.IntTensor([self.n_channels])) 163 | 164 | res_skip_acts = self.res_skip_layers[i](acts) 165 | if i < self.n_layers - 1: 166 | audio = res_skip_acts[:,:self.n_channels,:] + audio 167 | skip_acts = res_skip_acts[:,self.n_channels:,:] 168 | else: 169 | skip_acts = res_skip_acts 170 | 171 | if i == 0: 172 | output = skip_acts 173 | else: 174 | output = skip_acts + output 175 | return self.end(output) 176 | 177 | 178 | class WaveGlow(torch.nn.Module): 179 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, 180 | n_early_size, WN_config): 181 | super(WaveGlow, self).__init__() 182 | 183 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, 184 | n_mel_channels, 185 | 1024, stride=256) 186 | assert(n_group % 2 == 0) 187 | self.n_flows = n_flows 188 | self.n_group = n_group 189 | self.n_early_every = n_early_every 190 | self.n_early_size = n_early_size 191 | self.WN = torch.nn.ModuleList() 192 | self.convinv = torch.nn.ModuleList() 193 | 194 | n_half = int(n_group/2) 195 | 196 | # Set up layers with the right sizes based on how many dimensions 197 | # have been output already 198 | n_remaining_channels = n_group 199 | for k in range(n_flows): 200 | if k % self.n_early_every == 0 and k > 0: 201 | n_half = n_half - int(self.n_early_size/2) 202 | n_remaining_channels = n_remaining_channels - self.n_early_size 203 | self.convinv.append(Invertible1x1Conv(n_remaining_channels)) 204 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config)) 205 | self.n_remaining_channels = n_remaining_channels # Useful during inference 206 | 207 | def forward(self, forward_input): 208 | """ 209 | forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames 210 | forward_input[1] = audio: batch x time 211 | """ 212 | spect, audio = forward_input 213 | 214 | # Upsample spectrogram to size of audio 215 | spect = self.upsample(spect) 216 | assert(spect.size(2) >= audio.size(1)) 217 | if spect.size(2) > audio.size(1): 218 | spect = spect[:, :, :audio.size(1)] 219 | 220 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 221 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 222 | 223 | audio = audio.unfold(1, 
self.n_group, self.n_group).permute(0, 2, 1) 224 | output_audio = [] 225 | log_s_list = [] 226 | log_det_W_list = [] 227 | 228 | for k in range(self.n_flows): 229 | if k % self.n_early_every == 0 and k > 0: 230 | output_audio.append(audio[:,:self.n_early_size,:]) 231 | audio = audio[:,self.n_early_size:,:] 232 | 233 | audio, log_det_W = self.convinv[k](audio) 234 | log_det_W_list.append(log_det_W) 235 | 236 | n_half = int(audio.size(1)/2) 237 | audio_0 = audio[:,:n_half,:] 238 | audio_1 = audio[:,n_half:,:] 239 | 240 | output = self.WN[k]((audio_0, spect)) 241 | log_s = output[:, n_half:, :] 242 | b = output[:, :n_half, :] 243 | audio_1 = torch.exp(log_s)*audio_1 + b 244 | log_s_list.append(log_s) 245 | 246 | audio = torch.cat([audio_0, audio_1],1) 247 | 248 | output_audio.append(audio) 249 | return torch.cat(output_audio,1), log_s_list, log_det_W_list 250 | 251 | def infer(self, spect, sigma=1.0): 252 | spect = self.upsample(spect) 253 | # trim conv artifacts. maybe pad spec to kernel multiple 254 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] 255 | spect = spect[:, :, :-time_cutoff] 256 | 257 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 258 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 259 | 260 | if spect.type() == 'torch.cuda.HalfTensor': 261 | audio = torch.cuda.HalfTensor(spect.size(0), 262 | self.n_remaining_channels, 263 | spect.size(2)).normal_() 264 | else: 265 | audio = torch.cuda.FloatTensor(spect.size(0), 266 | self.n_remaining_channels, 267 | spect.size(2)).normal_() 268 | 269 | audio = torch.autograd.Variable(sigma*audio) 270 | 271 | for k in reversed(range(self.n_flows)): 272 | n_half = int(audio.size(1)/2) 273 | audio_0 = audio[:,:n_half,:] 274 | audio_1 = audio[:,n_half:,:] 275 | 276 | output = self.WN[k]((audio_0, spect)) 277 | s = output[:, n_half:, :] 278 | b = output[:, :n_half, :] 279 | audio_1 = (audio_1 - b)/torch.exp(s) 280 | audio = torch.cat([audio_0, audio_1],1) 281 | 282 | audio = self.convinv[k](audio, reverse=True) 283 | 284 | if k % self.n_early_every == 0 and k > 0: 285 | if spect.type() == 'torch.cuda.HalfTensor': 286 | z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() 287 | else: 288 | z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() 289 | audio = torch.cat((sigma*z, audio),1) 290 | 291 | audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data 292 | return audio 293 | 294 | @staticmethod 295 | def remove_weightnorm(model): 296 | waveglow = model 297 | for WN in waveglow.WN: 298 | WN.start = torch.nn.utils.remove_weight_norm(WN.start) 299 | WN.in_layers = remove(WN.in_layers) 300 | WN.cond_layers = remove(WN.cond_layers) 301 | WN.res_skip_layers = remove(WN.res_skip_layers) 302 | return waveglow 303 | 304 | 305 | def remove(conv_list): 306 | new_conv_list = torch.nn.ModuleList() 307 | for old_conv in conv_list: 308 | old_conv = torch.nn.utils.remove_weight_norm(old_conv) 309 | new_conv_list.append(old_conv) 310 | return new_conv_list 311 | -------------------------------------------------------------------------------- /waveglow/inference.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
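A rough training-side sketch of how the pieces of glow.py fit together; the flow and
WN sizes below are the commonly published WaveGlow settings and are assumptions here,
not values read from this repository:

  model = WaveGlow(n_mel_channels=80, n_flows=12, n_group=8,
                   n_early_every=4, n_early_size=2,
                   WN_config={'n_layers': 8, 'n_channels': 256, 'kernel_size': 3}).cuda()
  criterion = WaveGlowLoss(sigma=1.0)
  # mel: (batch, 80, frames) and audio: (batch, samples) cut from the same segments,
  # with audio no longer than the upsampled spectrogram (see the assert in forward()).
  z_and_jacobians = model((mel.cuda(), audio.cuda()))
  loss = criterion(z_and_jacobians)   # negative log-likelihood under a Gaussian prior on z
  loss.backward()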
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | # ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import os 28 | from scipy.io.wavfile import write 29 | import torch 30 | from waveglow.mel2samp import files_to_list, MAX_WAV_VALUE 31 | # from denoiser import Denoiser 32 | 33 | 34 | def inference(mel, waveglow, audio_path, sigma=1.0, sampling_rate=22050): 35 | with torch.no_grad(): 36 | audio = waveglow.infer(mel, sigma=sigma) 37 | audio = audio * MAX_WAV_VALUE 38 | audio = audio.squeeze() 39 | audio = audio.cpu().numpy() 40 | audio = audio.astype('int16') 41 | write(audio_path, sampling_rate, audio) 42 | # print(audio_path) 43 | 44 | 45 | # if __name__ == "__main__": 46 | # # Test 47 | # waveglow = torch.load(waveglow_path)['model'] 48 | # waveglow = waveglow.remove_weightnorm(waveglow) 49 | # waveglow.cuda().eval() 50 | -------------------------------------------------------------------------------- /waveglow/mel2samp.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 
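A minimal sketch of calling the inference helper above; the checkpoint path is
hypothetical and the loading lines mirror the commented-out test block at the end of
this file:

  waveglow = torch.load('waveglow_256channels.pt')['model']
  waveglow = waveglow.remove_weightnorm(waveglow)
  waveglow.cuda().eval()
  # mel: a (1, n_mel_channels, frames) CUDA tensor, e.g. from Tacotron2.inference
  inference(mel, waveglow, 'results/sample.wav', sigma=1.0, sampling_rate=22050)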
14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # *****************************************************************************\ 27 | # from tacotron2.layers import TacotronSTFT 28 | import os 29 | import random 30 | import argparse 31 | import json 32 | import torch 33 | import torch.utils.data 34 | import sys 35 | from scipy.io.wavfile import read 36 | 37 | # We're using the audio processing from TacoTron2 to make sure it matches 38 | sys.path.insert(0, 'tacotron2') 39 | 40 | MAX_WAV_VALUE = 32768.0 41 | 42 | 43 | def files_to_list(filename): 44 | """ 45 | Takes a text file of filenames and makes a list of filenames 46 | """ 47 | with open(filename, encoding='utf-8') as f: 48 | files = f.readlines() 49 | 50 | files = [f.rstrip() for f in files] 51 | return files 52 | 53 | 54 | # def load_wav_to_torch(full_path): 55 | # """ 56 | # Loads wavdata into torch array 57 | # """ 58 | # sampling_rate, data = read(full_path) 59 | # return torch.from_numpy(data).float(), sampling_rate 60 | 61 | 62 | # class Mel2Samp(torch.utils.data.Dataset): 63 | # """ 64 | # This is the main class that calculates the spectrogram and returns the 65 | # spectrogram, audio pair. 
66 | # """ 67 | 68 | # def __init__(self, training_files, segment_length, filter_length, 69 | # hop_length, win_length, sampling_rate, mel_fmin, mel_fmax): 70 | # self.audio_files = files_to_list(training_files) 71 | # random.seed(1234) 72 | # random.shuffle(self.audio_files) 73 | # self.stft = TacotronSTFT(filter_length=filter_length, 74 | # hop_length=hop_length, 75 | # win_length=win_length, 76 | # sampling_rate=sampling_rate, 77 | # mel_fmin=mel_fmin, mel_fmax=mel_fmax) 78 | # self.segment_length = segment_length 79 | # self.sampling_rate = sampling_rate 80 | 81 | # def get_mel(self, audio): 82 | # audio_norm = audio / MAX_WAV_VALUE 83 | # audio_norm = audio_norm.unsqueeze(0) 84 | # audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 85 | # melspec = self.stft.mel_spectrogram(audio_norm) 86 | # melspec = torch.squeeze(melspec, 0) 87 | # return melspec 88 | 89 | # def __getitem__(self, index): 90 | # # Read audio 91 | # filename = self.audio_files[index] 92 | # audio, sampling_rate = load_wav_to_torch(filename) 93 | # if sampling_rate != self.sampling_rate: 94 | # raise ValueError("{} SR doesn't match target {} SR".format( 95 | # sampling_rate, self.sampling_rate)) 96 | 97 | # # Take segment 98 | # if audio.size(0) >= self.segment_length: 99 | # max_audio_start = audio.size(0) - self.segment_length 100 | # audio_start = random.randint(0, max_audio_start) 101 | # audio = audio[audio_start:audio_start+self.segment_length] 102 | # else: 103 | # audio = torch.nn.functional.pad( 104 | # audio, (0, self.segment_length - audio.size(0)), 'constant').data 105 | 106 | # mel = self.get_mel(audio) 107 | # audio = audio / MAX_WAV_VALUE 108 | 109 | # return (mel, audio) 110 | 111 | # def __len__(self): 112 | # return len(self.audio_files) 113 | 114 | 115 | # # =================================================================== 116 | # # Takes directory of clean audio and makes directory of spectrograms 117 | # # Useful for making test sets 118 | # # =================================================================== 119 | # if __name__ == "__main__": 120 | # # Get defaults so it can work with no Sacred 121 | # parser = argparse.ArgumentParser() 122 | # parser.add_argument('-f', "--filelist_path", required=True) 123 | # parser.add_argument('-c', '--config', type=str, 124 | # help='JSON file for configuration') 125 | # parser.add_argument('-o', '--output_dir', type=str, 126 | # help='Output directory') 127 | # args = parser.parse_args() 128 | 129 | # with open(args.config) as f: 130 | # data = f.read() 131 | # data_config = json.loads(data)["data_config"] 132 | # mel2samp = Mel2Samp(**data_config) 133 | 134 | # filepaths = files_to_list(args.filelist_path) 135 | 136 | # # Make directory if it doesn't exist 137 | # if not os.path.isdir(args.output_dir): 138 | # os.makedirs(args.output_dir) 139 | # os.chmod(args.output_dir, 0o775) 140 | 141 | # for filepath in filepaths: 142 | # audio, sr = load_wav_to_torch(filepath) 143 | # melspectrogram = mel2samp.get_mel(audio) 144 | # filename = os.path.basename(filepath) 145 | # new_filepath = args.output_dir + '/' + filename + '.pt' 146 | # print(new_filepath) 147 | # torch.save(melspectrogram, new_filepath) 148 | --------------------------------------------------------------------------------
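Putting the pieces together, a rough end-to-end synthesis sketch (this is not a file
from the repository; the checkpoint paths, the cleaner name and the DataParallel
unwrapping are assumptions):

  import torch
  import numpy as np

  import hparams as hp
  from network import Tacotron2
  from text import text_to_sequence
  from waveglow.inference import inference as waveglow_inference

  # Text to mel. train.py wraps the model in nn.DataParallel, so the saved keys carry
  # a "module." prefix that is stripped here before loading.
  model = Tacotron2(hp).cuda().eval()
  checkpoint = torch.load('checkpoint_76000.pth.tar')          # hypothetical path
  state_dict = {k.replace('module.', '', 1): v for k, v in checkpoint['model'].items()}
  model.load_state_dict(state_dict)

  sequence = np.array(text_to_sequence('I am very happy to see you again!',
                                       ['english_cleaners']))[None, :]
  sequence = torch.from_numpy(sequence).long().cuda()
  with torch.no_grad():
      mel, mel_postnet, gate, alignments = model.inference(sequence)

  # Mel to waveform with the WaveGlow vocoder.
  waveglow = torch.load('waveglow_256channels.pt')['model']    # hypothetical path
  waveglow = waveglow.remove_weightnorm(waveglow)
  waveglow.cuda().eval()
  waveglow_inference(mel_postnet, waveglow, 'results/sample_waveglow.wav')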