├── requirements.txt
├── text
│   ├── symbols.py
│   ├── LICENSE
│   ├── cmudict.py
│   ├── numbers.py
│   ├── __init__.py
│   └── cleaners.py
├── utils
│   ├── util.py
│   ├── plot.py
│   ├── logger.py
│   ├── audio.py
│   └── dataset.py
├── LICENSE
├── model
│   ├── layers.py
│   └── model.py
├── hparams.py
├── mkgta.py
├── inference.py
├── inference.ipynb
├── README.md
└── train.py
/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pillow 4 | inflect 5 | librosa 6 | Unidecode 7 | matplotlib 8 | tensorboardX 9 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? ' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | 14 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 15 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 19 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from hparams import hparams as hps 4 | 5 | def mode(obj, model = False): 6 | if model and hps.is_cuda: 7 | obj = obj.cuda() 8 | elif hps.is_cuda: 9 | obj = obj.cuda(non_blocking = hps.pin_mem) 10 | return obj 11 | 12 | def to_arr(var): 13 | return var.cpu().detach().numpy().astype(np.float32) 14 | 15 | def get_mask_from_lengths(lengths, pad = False): 16 | max_len = torch.max(lengths).item() 17 | if pad and max_len%hps.n_frames_per_step != 0: 18 | max_len += hps.n_frames_per_step - max_len%hps.n_frames_per_step 19 | assert max_len%hps.n_frames_per_step == 0 20 | ids = torch.arange(0, max_len, out = torch.LongTensor(max_len)) 21 | ids = mode(ids) 22 | mask = (ids < lengths.unsqueeze(1)) 23 | return mask 24 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 BogiHsu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import numpy as np 4 | import matplotlib.pylab as plt 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 
9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data.transpose(2, 0, 1) 12 | 13 | 14 | def plot_alignment_to_numpy(alignment, info=None): 15 | fig, ax = plt.subplots(figsize=(6, 4)) 16 | im = ax.imshow(alignment, aspect='auto', origin='lower', 17 | interpolation='none') 18 | fig.colorbar(im, ax=ax) 19 | xlabel = 'Decoder timestep' 20 | if info is not None: 21 | xlabel += '\n\n' + info 22 | plt.xlabel(xlabel) 23 | plt.ylabel('Encoder timestep') 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | return data 30 | 31 | 32 | def plot_spectrogram_to_numpy(spectrogram): 33 | fig, ax = plt.subplots(figsize=(12, 3)) 34 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 35 | interpolation='none') 36 | plt.colorbar(im, ax=ax) 37 | plt.xlabel("Frames") 38 | plt.ylabel("Channels") 39 | plt.tight_layout() 40 | 41 | fig.canvas.draw() 42 | data = save_figure_to_numpy(fig) 43 | plt.close() 44 | return data 45 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LinearNorm(torch.nn.Module): 5 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 6 | super(LinearNorm, self).__init__() 7 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 8 | 9 | torch.nn.init.xavier_uniform_( 10 | self.linear_layer.weight, 11 | gain=torch.nn.init.calculate_gain(w_init_gain)) 12 | 13 | def forward(self, x): 14 | return self.linear_layer(x) 15 | 16 | 17 | class ConvNorm(torch.nn.Module): 18 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 19 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 20 | super(ConvNorm, self).__init__() 21 | if padding is None: 22 | assert(kernel_size % 2 == 1) 23 | padding = int(dilation * (kernel_size - 1) / 2) 24 | 25 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 26 | kernel_size=kernel_size, stride=stride, 27 | padding=padding, dilation=dilation, 28 | bias=bias) 29 | 30 | torch.nn.init.xavier_uniform_( 31 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 32 | 33 | def forward(self, signal): 34 | conv_signal = self.conv(signal) 35 | return conv_signal 36 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from text import symbols 2 | 3 | 4 | class hparams: 5 | seed = 0 6 | 7 | ################################ 8 | # Data Parameters # 9 | ################################ 10 | text_cleaners=['english_cleaners'] 11 | 12 | ################################ 13 | # Audio # 14 | ################################ 15 | num_mels = 80 16 | num_freq = 513 17 | sample_rate = 22050 18 | frame_shift = 256 19 | frame_length = 1024 20 | fmin = 0 21 | fmax = 8000 22 | power = 1.5 23 | gl_iters = 30 24 | 25 | ################################ 26 | # Train # 27 | ################################ 28 | is_cuda = True 29 | pin_mem = True 30 | n_workers = 4 31 | prep = True 32 | pth = 'lj-22k.pkl' 33 | lr = 2e-3 34 | betas = (0.9, 0.999) 35 | eps = 1e-6 36 | sch = True 37 | sch_step = 4000 38 | max_iter = 200e3 39 | batch_size = 16 40 | iters_per_log = 10 41 | iters_per_sample = 500 42 | iters_per_ckpt = 10000 43 | weight_decay = 1e-6 44 | grad_clip_thresh = 1.0 45 | 
eg_text = 'OMAK is a thinking process which considers things always positively.' 46 | 47 | ################################ 48 | # Model Parameters # 49 | ################################ 50 | n_symbols = len(symbols) 51 | symbols_embedding_dim = 512 52 | 53 | # Encoder parameters 54 | encoder_kernel_size = 5 55 | encoder_n_convolutions = 3 56 | encoder_embedding_dim = 512 57 | 58 | # Decoder parameters 59 | n_frames_per_step = 3 60 | decoder_rnn_dim = 1024 61 | prenet_dim = 256 62 | max_decoder_ratio = 10 63 | gate_threshold = 0.5 64 | p_attention_dropout = 0.1 65 | p_decoder_dropout = 0.1 66 | 67 | # Attention parameters 68 | attention_rnn_dim = 1024 69 | attention_dim = 128 70 | 71 | # Location Layer parameters 72 | attention_location_n_filters = 32 73 | attention_location_kernel_size = 31 74 | 75 | # Mel-post processing network parameters 76 | postnet_embedding_dim = 512 77 | postnet_kernel_size = 5 78 | postnet_n_convolutions = 5 79 | 80 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = 
re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | ''' 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | if not m: 34 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 35 | break 36 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 37 | sequence += _arpabet_to_sequence(m.group(2)) 38 | text = m.group(3) 39 | 40 | return sequence 41 | 42 | 43 | def sequence_to_text(sequence): 44 | '''Converts a sequence of IDs back to a string''' 45 | result = '' 46 | for symbol_id in sequence: 47 | if symbol_id in _id_to_symbol: 48 | s = _id_to_symbol[symbol_id] 49 | # Enclose ARPAbet back in curly braces: 50 | if len(s) > 1 and s[0] == '@': 51 | s = '{%s}' % s[1:] 52 | result += s 53 | return result.replace('}{', ' ') 54 | 55 | 56 | def _clean_text(text, cleaner_names): 57 | for name in cleaner_names: 58 | cleaner = getattr(cleaners, name) 59 | if not cleaner: 60 | raise Exception('Unknown cleaner: %s' % name) 61 | text = cleaner(text) 62 | return text 63 | 64 | 65 | def _symbols_to_sequence(symbols): 66 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 67 | 68 | 69 | def _arpabet_to_sequence(text): 70 | return _symbols_to_sequence(['@' + s for s in text.split()]) 71 | 72 | 73 | def _should_keep_symbol(s): 74 | return s in _symbol_to_id and s != '_' and s != '~' 75 | -------------------------------------------------------------------------------- /mkgta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pylab as plt 6 | from text import text_to_sequence 7 | from model.model import Tacotron2 8 | from hparams import hparams as hps 9 | from utils.util import mode, to_var, to_arr 10 | from utils.audio import load_wav, save_wav, melspectrogram 11 | 12 | 13 | def files_to_list(fdir = 'data'): 14 | f_list = [] 15 | with open(os.path.join(fdir, 'metadata.csv'), encoding = 'utf-8') as f: 16 | for 
line in f: 17 | parts = line.strip().split('|') 18 | wav_path = os.path.join(fdir, 'wavs', '%s.wav' % parts[0]) 19 | f_list.append([wav_path, parts[1]]) 20 | return f_list 21 | 22 | 23 | def load_model(ckpt_pth): 24 | ckpt_dict = torch.load(ckpt_pth) 25 | model = Tacotron2() 26 | model.load_state_dict(ckpt_dict['model']) 27 | model = mode(model, True).eval() 28 | model.decoder.train() 29 | model.postnet.train() 30 | return model 31 | 32 | 33 | def infer(wav_path, text, model): 34 | sequence = text_to_sequence(text, hps.text_cleaners) 35 | sequence = to_var(torch.IntTensor(sequence)[None, :]).long() 36 | mel = melspectrogram(load_wav(wav_path)) 37 | mel_in = to_var(torch.Tensor([mel])) 38 | r = mel_in.shape[2]%hps.n_frames_per_step 39 | if r != 0: 40 | mel_in = mel_in[:, :, :-r] 41 | sequence = torch.cat([sequence, sequence], 0) 42 | mel_in = torch.cat([mel_in, mel_in], 0) 43 | _, mel_outputs_postnet, _, _ = model.teacher_infer(sequence, mel_in) 44 | ret = mel 45 | if r != 0: 46 | ret[:, :-r] = to_arr(mel_outputs_postnet[0]) 47 | else: 48 | ret = to_arr(mel_outputs_postnet[0]) 49 | return ret 50 | 51 | 52 | def save_mel(res, pth, name): 53 | out = os.path.join(pth, name) 54 | np.save(out, res) 55 | 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('-c', '--ckpt_pth', type = str, default = '', 60 | required = True, help = 'path to load checkpoints') 61 | parser.add_argument('-n', '--npy_pth', type = str, default = 'dump', 62 | help = 'path to save mels') 63 | 64 | args = parser.parse_args() 65 | 66 | torch.backends.cudnn.enabled = True 67 | torch.backends.cudnn.benchmark = False 68 | model = load_model(args.ckpt_pth) 69 | flist = files_to_list() 70 | for x in flist: 71 | ret = infer(x[0], x[1], model) 72 | name = x[0].split('/')[-1].split('.wav')[0] 73 | if args.npy_pth != '': 74 | save_mel(ret, args.npy_pth, name) 75 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.util import to_arr 3 | from hparams import hparams as hps 4 | from tensorboardX import SummaryWriter 5 | from utils.audio import inv_melspectrogram 6 | from utils.plot import plot_alignment_to_numpy, plot_spectrogram_to_numpy 7 | 8 | 9 | class Tacotron2Logger(SummaryWriter): 10 | def __init__(self, logdir): 11 | super(Tacotron2Logger, self).__init__(logdir, flush_secs = 5) 12 | 13 | def log_training(self, items, grad_norm, learning_rate, iteration): 14 | self.add_scalar('loss.mel', items[0], iteration) 15 | self.add_scalar('loss.gate', items[1], iteration) 16 | self.add_scalar('grad.norm', grad_norm, iteration) 17 | self.add_scalar('learning.rate', learning_rate, iteration) 18 | 19 | def sample_train(self, outputs, iteration): 20 | mel_outputs = to_arr(outputs[0][0]) 21 | mel_outputs_postnet = to_arr(outputs[1][0]) 22 | alignments = to_arr(outputs[3][0]).T 23 | 24 | # plot alignment, mel and postnet output 25 | self.add_image( 26 | 'train.align', 27 | plot_alignment_to_numpy(alignments), 28 | iteration) 29 | self.add_image( 30 | 'train.mel', 31 | plot_spectrogram_to_numpy(mel_outputs), 32 | iteration) 33 | self.add_image( 34 | 'train.mel_post', 35 | plot_spectrogram_to_numpy(mel_outputs_postnet), 36 | iteration) 37 | 38 | def sample_infer(self, outputs, iteration): 39 | mel_outputs = to_arr(outputs[0][0]) 40 | mel_outputs_postnet = to_arr(outputs[1][0]) 41 | alignments = to_arr(outputs[2][0]).T 42 | 43 | # plot 
alignment, mel and postnet output 44 | self.add_image( 45 | 'infer.align', 46 | plot_alignment_to_numpy(alignments), 47 | iteration) 48 | self.add_image( 49 | 'infer.mel', 50 | plot_spectrogram_to_numpy(mel_outputs), 51 | iteration) 52 | self.add_image( 53 | 'infer.mel_post', 54 | plot_spectrogram_to_numpy(mel_outputs_postnet), 55 | iteration) 56 | 57 | # save audio 58 | wav = inv_melspectrogram(mel_outputs) 59 | wav_postnet = inv_melspectrogram(mel_outputs_postnet) 60 | self.add_audio('infer.wav', wav, iteration, hps.sample_rate) 61 | self.add_audio('infer.wav_post', wav_postnet, iteration, hps.sample_rate) 62 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from .numbers import normalize_numbers 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including number and abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_numbers(text) 88 | text = expand_abbreviations(text) 89 | text = collapse_whitespace(text) 90 | return text 91 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | import matplotlib.pylab as plt 5 | from text import text_to_sequence 6 | from model.model import Tacotron2 7 | from hparams import hparams as hps 8 | from utils.util import mode, to_arr 9 | from utils.audio import save_wav, inv_melspectrogram 10 | 11 | 12 | def load_model(ckpt_pth): 13 | ckpt_dict = torch.load(ckpt_pth) 14 | model = Tacotron2() 15 | model.load_state_dict(ckpt_dict['model']) 16 | model = mode(model, True).eval() 17 | return model 18 | 19 | 20 | def infer(text, model): 21 | sequence = text_to_sequence(text, hps.text_cleaners) 22 | sequence = mode(torch.IntTensor(sequence)[None, :]).long() 23 | mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence) 24 | return (mel_outputs, mel_outputs_postnet, alignments) 25 | 26 | 27 | def plot_data(data, figsize = (16, 4)): 28 | fig, axes = plt.subplots(1, len(data), figsize = figsize) 29 | for i in range(len(data)): 30 | axes[i].imshow(data[i], aspect = 'auto', origin = 'bottom') 31 | 32 | 33 | def plot(output, pth): 34 | mel_outputs, mel_outputs_postnet, alignments = output 35 | plot_data((to_arr(mel_outputs[0]), 36 | to_arr(mel_outputs_postnet[0]), 37 | to_arr(alignments[0]).T)) 38 | plt.savefig(pth+'.png') 39 | 40 | 41 | def audio(output, pth): 42 | mel_outputs, mel_outputs_postnet, _ = output 43 | wav_postnet = inv_melspectrogram(to_arr(mel_outputs_postnet[0])) 44 | save_wav(wav_postnet, pth+'.wav') 45 | 46 | 47 | def save_mel(output, pth): 48 | mel_outputs, mel_outputs_postnet, _ = output 49 | np.save(pth+'.npy', to_arr(mel_outputs_postnet)) 50 | 51 | 52 | if __name__ == '__main__': 53 | 
parser = argparse.ArgumentParser() 54 | parser.add_argument('-c', '--ckpt_pth', type = str, default = '', 55 | required = True, help = 'path to load checkpoints') 56 | parser.add_argument('-i', '--img_pth', type = str, default = '', 57 | help = 'path to save images') 58 | parser.add_argument('-w', '--wav_pth', type = str, default = '', 59 | help = 'path to save wavs') 60 | parser.add_argument('-n', '--npy_pth', type = str, default = '', 61 | help = 'path to save mels') 62 | parser.add_argument('-t', '--text', type = str, default = 'Tacotron is awesome.', 63 | help = 'text to synthesize') 64 | 65 | args = parser.parse_args() 66 | 67 | torch.backends.cudnn.enabled = True 68 | torch.backends.cudnn.benchmark = False 69 | model = load_model(args.ckpt_pth) 70 | output = infer(args.text, model) 71 | if args.img_pth != '': 72 | plot(output, args.img_pth) 73 | if args.wav_pth != '': 74 | audio(output, args.wav_pth) 75 | if args.npy_pth != '': 76 | save_mel(output, args.npy_pth) 77 | -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import librosa 3 | import numpy as np 4 | from scipy.io import wavfile 5 | from librosa.util import normalize 6 | from hparams import hparams as hps 7 | MAX_WAV_VALUE = 32768.0 8 | _mel_basis = None 9 | 10 | 11 | def load_wav(path): 12 | sr, wav = wavfile.read(path) 13 | assert sr == hps.sample_rate 14 | return normalize(wav/MAX_WAV_VALUE)*0.95 15 | 16 | 17 | def save_wav(wav, path): 18 | wav *= MAX_WAV_VALUE 19 | wavfile.write(path, hps.sample_rate, wav.astype(np.int16)) 20 | 21 | 22 | def spectrogram(y): 23 | D = _stft(y) 24 | S = _amp_to_db(np.abs(D)) 25 | return S 26 | 27 | 28 | def inv_spectrogram(S): 29 | S = _db_to_amp(S) 30 | return _griffin_lim(S ** hps.power) 31 | 32 | 33 | def melspectrogram(y): 34 | D = _stft(y) 35 | S = _amp_to_db(_linear_to_mel(np.abs(D))) 36 | return S 37 | 38 | 39 | def inv_melspectrogram(mel): 40 | mel = _db_to_amp(mel) 41 | S = _mel_to_linear(mel) 42 | return _griffin_lim(S**hps.power) 43 | 44 | 45 | def _griffin_lim(S): 46 | '''librosa implementation of Griffin-Lim 47 | Based on https://github.com/librosa/librosa/issues/434 48 | ''' 49 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 50 | S_complex = np.abs(S).astype(np.complex) 51 | y = _istft(S_complex * angles) 52 | for i in range(hps.gl_iters): 53 | angles = np.exp(1j * np.angle(_stft(y))) 54 | y = _istft(S_complex * angles) 55 | return np.clip(y, a_max = 1, a_min = -1) 56 | 57 | 58 | # Conversions: 59 | def _stft(y): 60 | n_fft, hop_length, win_length = _stft_parameters() 61 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, 62 | win_length=win_length, pad_mode='reflect') 63 | 64 | 65 | def _istft(y): 66 | _, hop_length, win_length = _stft_parameters() 67 | return librosa.istft(y, hop_length=hop_length, win_length=win_length) 68 | 69 | 70 | def _stft_parameters(): 71 | return (hps.num_freq - 1) * 2, hps.frame_shift, hps.frame_length 72 | 73 | 74 | def _linear_to_mel(spectrogram): 75 | global _mel_basis 76 | if _mel_basis is None: 77 | _mel_basis = _build_mel_basis() 78 | return np.dot(_mel_basis, spectrogram) 79 | 80 | 81 | def _mel_to_linear(spectrogram): 82 | global _mel_basis 83 | if _mel_basis is None: 84 | _mel_basis = _build_mel_basis() 85 | inv_mel_basis = np.linalg.pinv(_mel_basis) 86 | inverse = np.dot(inv_mel_basis, spectrogram) 87 | inverse = np.maximum(1e-10, inverse) 88 | return inverse 89 | 90 | 91 | 
def _build_mel_basis(): 92 | n_fft = (hps.num_freq - 1) * 2 93 | return librosa.filters.mel(hps.sample_rate, n_fft, n_mels=hps.num_mels, fmin = hps.fmin, fmax = hps.fmax) 94 | 95 | 96 | def _amp_to_db(x): 97 | return np.log(np.maximum(1e-5, x)) 98 | 99 | 100 | def _db_to_amp(x): 101 | return np.exp(x) 102 | -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","source":["# Text-to-Speech with Tacotron2 and HifiGAN\n","\n","This is an English female voice TTS demo using open source projects [BogiHsu/Tacotron2-PyTorch](https://github.com/BogiHsu/Tacotron2-PyTorch) and [jik876/hifi-gan](https://github.com/jik876/hifi-gan).\n","\n","Please enable GPU acceleration in Colab before you start running the code."],"metadata":{"id":"3qb1435p6OEz"}},{"cell_type":"markdown","metadata":{"id":"3rof5nB25HD4"},"source":["## Set up environment"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"OBB0-ki_5HD7"},"outputs":[],"source":["!git clone https://github.com/BogiHsu/Tacotron2-PyTorch.git\n","!pip install -r Tacotron2-PyTorch/requirements.txt\n","!git clone https://github.com/jik876/hifi-gan.git\n","!mkdir -p mel_files"]},{"cell_type":"markdown","metadata":{"id":"whZkDydI5HD9"},"source":["## Download Tacotron2 pretrained model"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"EbbOs-sT5HD9"},"outputs":[],"source":["!wget https://github.com/BogiHsu/Tacotron2-PyTorch/releases/download/lj-200k-b512/ckpt_200000 -O Tacotron2-PyTorch/ckpt_200000"]},{"cell_type":"markdown","metadata":{"id":"OTcGBEfG5HD-"},"source":["## Download HifiGAN pretrained model"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gH6vR6045HD-"},"outputs":[],"source":["import gdown\n","dl = {'hifi-gan/config.json': 'https://drive.google.com/u/1/uc?id=1aDh576AEYA5eTjhx7sew1qcCM_Y526jc&export=download',\n"," 'hifi-gan/generator_v1': 'https://drive.google.com/u/1/uc?id=14NENd4equCBLyyCSke114Mv6YR_j_uFs&export=download'}\n","for k in dl:\n"," gdown.download(dl[k], k)"]},{"cell_type":"markdown","metadata":{"id":"DrOB1Ku15HD_"},"source":["## Tacotron2 inference"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"z44MMC_y5HD_"},"outputs":[],"source":["%cd Tacotron2-PyTorch\n","!python3 inference.py --ckpt_pth=ckpt_200000 --text='Tacotron is awesome.' 
--npy_pth=../mel_files/1\n","%cd ../"]},{"cell_type":"markdown","metadata":{"id":"9cHVnw4e5HEA"},"source":["## HifiGAN inference"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KmMdKHKL5HEA"},"outputs":[],"source":["%cd hifi-gan\n","!python3 inference_e2e.py --checkpoint_file generator_v1 --input_mels_dir ../mel_files --output_dir ../wav_files\n","%cd ../"]},{"cell_type":"markdown","metadata":{"id":"M2ElzCby5HEB"},"source":["## Play synthesized wav file"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9SGCUtQC5HEB"},"outputs":[],"source":["import IPython\n","IPython.display.Audio('wav_files/1_generated_e2e.wav')"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ws_nYQk35HEB"},"outputs":[],"source":[""]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.10"},"colab":{"name":"inference.ipynb","provenance":[],"collapsed_sections":[]},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tacotron2-PyTorch 2 | Yet another PyTorch implementation of [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](https://arxiv.org/pdf/1712.05884.pdf). The project is largely based on [these](#References). I made some modifications to improve the speed and performance of both training and inference. 3 | 4 | ## TODO 5 | - [x] Add Colab demo. 6 | - [x] Update README. 7 | - [x] Upload pretrained models. 8 | - [x] Compatible with [WaveGlow](https://github.com/NVIDIA/waveglow) and [Hifi-GAN](https://github.com/jik876/hifi-gan). 9 | 10 | ## Requirements 11 | - Python >= 3.5.2 12 | - torch >= 1.0.0 13 | - numpy 14 | - scipy 15 | - pillow 16 | - inflect 17 | - librosa 18 | - Unidecode 19 | - matplotlib 20 | - tensorboardX 21 | 22 | ## Preprocessing 23 | Currently only [LJ Speech](https://keithito.com/LJ-Speech-Dataset/) is supported. You can modify `hparams.py` for different sampling rates. `prep` decides whether to preprocess all utterances before training or to preprocess them on the fly during training. `pth` specifies the path to store the preprocessed data. 24 | 25 | ## Training 26 | 1. For training Tacotron2, run the following command. 27 | ```bash 28 | python3 train.py \ 29 | --data_dir=<data_dir> \ 30 | --ckpt_dir=<ckpt_dir> 31 | ``` 32 | 33 | 2. If you have multiple GPUs, try [distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility). 34 | ```bash 35 | python -m torch.distributed.launch --nproc_per_node <num_gpu> train.py \ 36 | --data_dir=<data_dir> \ 37 | --ckpt_dir=<ckpt_dir> 38 | ``` 39 | Note that the effective training batch size will become `<num_gpu>` times larger. 40 | 41 | 3. For training using a pretrained model, run the following command. 42 | ```bash 43 | python3 train.py \ 44 | --data_dir=<data_dir> \ 45 | --ckpt_dir=<ckpt_dir> \ 46 | --ckpt_pth=<ckpt_pth> 47 | ``` 48 | 49 | 4. For using TensorBoard (optional), run the following command. 50 | ```bash 51 | python3 train.py \ 52 | --data_dir=<data_dir> \ 53 | --ckpt_dir=<ckpt_dir> \ 54 | --log_dir=<log_dir> 55 | ``` 56 | You can find alignment images and synthesized audio clips during training. The text to synthesize can be set in `hparams.py`. 57 | 
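The warm-up behaviour controlled by `sch` and `sch_step` in `hparams.py` is easiest to see numerically. Below is a minimal sketch (not part of the repository) of the Noam-style schedule that `train.py` builds with `torch.optim.lr_scheduler.LambdaLR` when `sch` is enabled; the helper name `noam_scale` is invented for illustration, and the constants mirror the defaults in `hparams.py`.

```python
# Illustrative sketch of the warm-up schedule used in train.py (not a repo file).
# The learning rate rises for the first sch_step iterations, peaks at hps.lr,
# then decays roughly as 1/sqrt(step).
lr, sch_step = 2e-3, 4000  # defaults of hps.lr and hps.sch_step in hparams.py

def noam_scale(step, warmup = sch_step):
    # Same lambda that train.py passes to torch.optim.lr_scheduler.LambdaLR.
    return warmup**0.5 * min((step + 1) * warmup**-1.5, (step + 1)**-0.5)

for step in (0, 1000, 3999, 20000, 200000):
    print('iter %6d: lr = %.2e' % (step, lr * noam_scale(step)))
```

With the defaults, the learning rate climbs to its peak of `2e-3` around iteration 4000 and then decays roughly as the inverse square root of the iteration count.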
58 | ## Inference 59 | - For synthesizing wav files, run the following command. 60 | 61 | ```bash 62 | python3 inference.py \ 63 | --ckpt_pth=<ckpt_pth> \ 64 | --img_pth=<img_pth> \ 65 | --npy_pth=<npy_pth> \ 66 | --wav_pth=<wav_pth> \ 67 | --text=<text> 68 | ``` 69 | 70 | ## Pretrained Model 71 | You can download pretrained models from [Releases](https://github.com/BogiHsu/Tacotron2-PyTorch/releases). The hyperparameters used for training are included there as well. All the models were trained using 8 GPUs. 72 | 73 | ## Vocoder 74 | A vocoder is not implemented, but the model is compatible with [WaveGlow](https://github.com/NVIDIA/waveglow) and [Hifi-GAN](https://github.com/jik876/hifi-gan). Check the Colab demo for more information. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/BogiHsu/Tacotron2-PyTorch/blob/master/inference.ipynb) 75 | 76 | ## References 77 | This project is largely based on the works below. 78 | - [Tacotron2 by NVIDIA](https://github.com/NVIDIA/tacotron2) 79 | - [Tacotron by r9y9](https://github.com/r9y9/tacotron_pytorch) 80 | - [Tacotron by keithito](https://github.com/keithito/tacotron) 81 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as np 5 | from text import text_to_sequence 6 | from hparams import hparams as hps 7 | from torch.utils.data import Dataset 8 | from utils.audio import load_wav, melspectrogram 9 | 10 | 11 | def files_to_list(fdir): 12 | f_list = [] 13 | with open(os.path.join(fdir, 'metadata.csv'), encoding = 'utf-8') as f: 14 | for line in f: 15 | parts = line.strip().split('|') 16 | wav_path = os.path.join(fdir, 'wavs', '%s.wav' % parts[0]) 17 | if hps.prep: 18 | f_list.append(get_mel_text_pair(parts[1], wav_path)) 19 | else: 20 | f_list.append([parts[1], wav_path]) 21 | if hps.prep and hps.pth is not None: 22 | with open(hps.pth, 'wb') as w: 23 | pickle.dump(f_list, w) 24 | return f_list 25 | 26 | 27 | class ljdataset(Dataset): 28 | def __init__(self, fdir): 29 | if hps.prep and hps.pth is not None and os.path.isfile(hps.pth): 30 | with open(hps.pth, 'rb') as r: 31 | self.f_list = pickle.load(r) 32 | else: 33 | self.f_list = files_to_list(fdir) 34 | 35 | def __getitem__(self, index): 36 | text, mel = self.f_list[index] if hps.prep \ 37 | else get_mel_text_pair(*self.f_list[index]) 38 | return text, mel 39 | 40 | def __len__(self): 41 | return len(self.f_list) 42 | 43 | 44 | def get_mel_text_pair(text, wav_path): 45 | text = get_text(text) 46 | mel = get_mel(wav_path) 47 | return (text, mel) 48 | 49 | def get_text(text): 50 | return torch.IntTensor(text_to_sequence(text, hps.text_cleaners)) 51 | 52 | def get_mel(wav_path): 53 | wav = load_wav(wav_path) 54 | return torch.Tensor(melspectrogram(wav).astype(np.float32)) 55 | 56 | 57 | class ljcollate(): 58 | def __init__(self, n_frames_per_step): 59 | self.n_frames_per_step = n_frames_per_step 60 | 61 | def __call__(self, batch): 62 | # Right zero-pad all one-hot text sequences to max input length 63 | input_lengths, ids_sorted_decreasing = torch.sort( 64 | torch.LongTensor([len(x[0]) for x in batch]), 65 | dim=0, descending=True) 66 | max_input_len = input_lengths[0] 67 | 68 | text_padded = torch.LongTensor(len(batch), max_input_len) 69 | text_padded.zero_() 70 | for i in range(len(ids_sorted_decreasing)): 71 | text = batch[ids_sorted_decreasing[i]][0] 72 | text_padded[i, :text.size(0)] = text 73 | 74 | # Right zero-pad mel-spec 75 | num_mels =
batch[0][1].size(0) 76 | max_target_len = max([x[1].size(1) for x in batch]) 77 | if max_target_len % self.n_frames_per_step != 0: 78 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step 79 | assert max_target_len % self.n_frames_per_step == 0 80 | 81 | # include mel padded and gate padded 82 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 83 | mel_padded.zero_() 84 | gate_padded = torch.FloatTensor(len(batch), max_target_len) 85 | gate_padded.zero_() 86 | output_lengths = torch.LongTensor(len(batch)) 87 | for i in range(len(ids_sorted_decreasing)): 88 | mel = batch[ids_sorted_decreasing[i]][1] 89 | mel_padded[i, :, :mel.size(1)] = mel 90 | gate_padded[i, mel.size(1)-1:] = 1 91 | output_lengths[i] = mel.size(1) 92 | 93 | return text_padded, input_lengths, mel_padded, gate_padded, output_lengths 94 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from inference import infer 7 | from utils.util import mode 8 | from hparams import hparams as hps 9 | from utils.logger import Tacotron2Logger 10 | from utils.dataset import ljdataset, ljcollate 11 | from model.model import Tacotron2, Tacotron2Loss 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | np.random.seed(hps.seed) 14 | torch.manual_seed(hps.seed) 15 | torch.cuda.manual_seed(hps.seed) 16 | 17 | 18 | def prepare_dataloaders(fdir, n_gpu): 19 | trainset = ljdataset(fdir) 20 | collate_fn = ljcollate(hps.n_frames_per_step) 21 | sampler = DistributedSampler(trainset) if n_gpu > 1 else None 22 | train_loader = DataLoader(trainset, num_workers = hps.n_workers, shuffle = n_gpu == 1, 23 | batch_size = hps.batch_size, pin_memory = hps.pin_mem, 24 | drop_last = True, collate_fn = collate_fn, sampler = sampler) 25 | return train_loader 26 | 27 | 28 | def load_checkpoint(ckpt_pth, model, optimizer, device, n_gpu): 29 | ckpt_dict = torch.load(ckpt_pth, map_location = device) 30 | (model.module if n_gpu > 1 else model).load_state_dict(ckpt_dict['model']) 31 | optimizer.load_state_dict(ckpt_dict['optimizer']) 32 | iteration = ckpt_dict['iteration'] 33 | return model, optimizer, iteration 34 | 35 | 36 | def save_checkpoint(model, optimizer, iteration, ckpt_pth, n_gpu): 37 | torch.save({'model': (model.module if n_gpu > 1 else model).state_dict(), 38 | 'optimizer': optimizer.state_dict(), 39 | 'iteration': iteration}, ckpt_pth) 40 | 41 | 42 | def train(args): 43 | # setup env 44 | rank = local_rank = 0 45 | n_gpu = 1 46 | if 'WORLD_SIZE' in os.environ: 47 | os.environ['OMP_NUM_THREADS'] = str(hps.n_workers) 48 | rank = int(os.environ['RANK']) 49 | local_rank = int(os.environ['LOCAL_RANK']) 50 | n_gpu = int(os.environ['WORLD_SIZE']) 51 | torch.distributed.init_process_group( 52 | backend = 'nccl', rank = local_rank, world_size = n_gpu) 53 | torch.cuda.set_device(local_rank) 54 | device = torch.device('cuda:{:d}'.format(local_rank)) 55 | 56 | # build model 57 | model = Tacotron2() 58 | mode(model, True) 59 | if n_gpu > 1: 60 | model = torch.nn.parallel.DistributedDataParallel( 61 | model, device_ids = [local_rank]) 62 | optimizer = torch.optim.Adam(model.parameters(), lr = hps.lr, 63 | betas = hps.betas, eps = hps.eps, 64 | weight_decay = hps.weight_decay) 65 | criterion = Tacotron2Loss() 66 | 67 | # load checkpoint 68 | iteration = 1 69 | if args.ckpt_pth != '': 70 | 
model, optimizer, iteration = load_checkpoint(args.ckpt_pth, model, optimizer, device, n_gpu) 71 | iteration += 1 72 | 73 | # get scheduler 74 | if hps.sch: 75 | lr_lambda = lambda step: hps.sch_step**0.5*min((step+1)*hps.sch_step**-1.5, (step+1)**-0.5) 76 | if args.ckpt_pth != '': 77 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch = iteration) 78 | else: 79 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 80 | 81 | # make dataset 82 | train_loader = prepare_dataloaders(args.data_dir, n_gpu) 83 | 84 | if rank == 0: 85 | # get logger ready 86 | if args.log_dir != '': 87 | if not os.path.isdir(args.log_dir): 88 | os.makedirs(args.log_dir) 89 | os.chmod(args.log_dir, 0o775) 90 | logger = Tacotron2Logger(args.log_dir) 91 | 92 | # get ckpt_dir ready 93 | if args.ckpt_dir != '' and not os.path.isdir(args.ckpt_dir): 94 | os.makedirs(args.ckpt_dir) 95 | os.chmod(args.ckpt_dir, 0o775) 96 | 97 | model.train() 98 | # ================ MAIN TRAINNIG LOOP! =================== 99 | epoch = 0 100 | while iteration <= hps.max_iter: 101 | if n_gpu > 1: 102 | train_loader.sampler.set_epoch(epoch) 103 | for batch in train_loader: 104 | if iteration > hps.max_iter: 105 | break 106 | start = time.perf_counter() 107 | x, y = (model.module if n_gpu > 1 else model).parse_batch(batch) 108 | y_pred = model(x) 109 | 110 | # loss 111 | loss, items = criterion(y_pred, y) 112 | 113 | # zero grad 114 | model.zero_grad() 115 | 116 | # backward, grad_norm, and update 117 | loss.backward() 118 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hps.grad_clip_thresh) 119 | optimizer.step() 120 | if hps.sch: 121 | scheduler.step() 122 | 123 | dur = time.perf_counter()-start 124 | if rank == 0: 125 | # info 126 | print('Iter: {} Mel Loss: {:.2e} Gate Loss: {:.2e} Grad Norm: {:.2e} {:.1f}s/it'.format( 127 | iteration, items[0], items[1], grad_norm, dur)) 128 | 129 | # log 130 | if args.log_dir != '' and (iteration % hps.iters_per_log == 0): 131 | learning_rate = optimizer.param_groups[0]['lr'] 132 | logger.log_training(items, grad_norm, learning_rate, iteration) 133 | 134 | # sample 135 | if args.log_dir != '' and (iteration % hps.iters_per_sample == 0): 136 | model.eval() 137 | output = infer(hps.eg_text, model.module if n_gpu > 1 else model) 138 | model.train() 139 | logger.sample_train(y_pred, iteration) 140 | logger.sample_infer(output, iteration) 141 | 142 | # save ckpt 143 | if args.ckpt_dir != '' and (iteration % hps.iters_per_ckpt == 0): 144 | ckpt_pth = os.path.join(args.ckpt_dir, 'ckpt_{}'.format(iteration)) 145 | save_checkpoint(model, optimizer, iteration, ckpt_pth, n_gpu) 146 | 147 | iteration += 1 148 | epoch += 1 149 | 150 | if rank == 0 and args.log_dir != '': 151 | logger.close() 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | # path 157 | parser.add_argument('-d', '--data_dir', type = str, default = 'data', 158 | help = 'directory to load data') 159 | parser.add_argument('-l', '--log_dir', type = str, default = 'log', 160 | help = 'directory to save tensorboard logs') 161 | parser.add_argument('-cd', '--ckpt_dir', type = str, default = 'ckpt', 162 | help = 'directory to save checkpoints') 163 | parser.add_argument('-cp', '--ckpt_pth', type = str, default = '', 164 | help = 'path to load checkpoints') 165 | 166 | args = parser.parse_args() 167 | 168 | torch.backends.cudnn.enabled = True 169 | torch.backends.cudnn.benchmark = False 170 | train(args) 171 | 
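Before the model definition below, here is a short, hypothetical usage sketch (not a repository file) of the text front end that `utils/dataset.py`, `mkgta.py`, and `inference.py` all rely on; it assumes the repository root is the working directory and the packages from `requirements.txt` are installed.

```python
# Illustrative usage of the text front end (this snippet is not a repo file).
# Curly braces embed raw ARPAbet, as described in text/__init__.py; the rest
# of the sentence goes through the cleaners listed in hps.text_cleaners.
from text import text_to_sequence, sequence_to_text

seq = text_to_sequence('Turn left on {HH AW1 S S T AH0 N} Street.',
                       ['english_cleaners'])
print(seq)                    # list of integer symbol IDs fed to the embedding layer
print(sequence_to_text(seq))  # 'turn left on {HH AW1 S S T AH0 N} street.'
```

The IDs index into `symbols` from `text/symbols.py`, which is also what `n_symbols` in `hparams.py` counts.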
-------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from math import sqrt 4 | from hparams import hparams as hps 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | from model.layers import ConvNorm, LinearNorm 8 | from utils.util import mode, get_mask_from_lengths 9 | 10 | 11 | class Tacotron2Loss(nn.Module): 12 | def __init__(self): 13 | super(Tacotron2Loss, self).__init__() 14 | self.loss = nn.MSELoss(reduction = 'none') 15 | 16 | def forward(self, model_outputs, targets): 17 | mel_out, mel_out_postnet, gate_out, _ = model_outputs 18 | gate_out = gate_out.view(-1, 1) 19 | 20 | mel_target, gate_target, output_lengths = targets 21 | mel_target.requires_grad = False 22 | gate_target.requires_grad = False 23 | output_lengths.requires_grad = False 24 | slice = torch.arange(0, gate_target.size(1), hps.n_frames_per_step) 25 | gate_target = gate_target[:, slice].view(-1, 1) 26 | mel_mask = ~get_mask_from_lengths(output_lengths.data, True) 27 | 28 | mel_loss = self.loss(mel_out, mel_target) + \ 29 | self.loss(mel_out_postnet, mel_target) 30 | mel_loss = mel_loss.sum(1).masked_fill_(mel_mask, 0.)/mel_loss.size(1) 31 | mel_loss = mel_loss.sum()/output_lengths.sum() 32 | 33 | gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 34 | return mel_loss+gate_loss, (mel_loss.item(), gate_loss.item()) 35 | 36 | 37 | class LocationLayer(nn.Module): 38 | def __init__(self, attention_n_filters, attention_kernel_size, 39 | attention_dim): 40 | super(LocationLayer, self).__init__() 41 | padding = int((attention_kernel_size - 1) / 2) 42 | self.location_conv = ConvNorm(2, attention_n_filters, 43 | kernel_size=attention_kernel_size, 44 | padding=padding, bias=False, stride=1, 45 | dilation=1) 46 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 47 | bias=False, w_init_gain='tanh') 48 | 49 | def forward(self, attention_weights_cat): 50 | processed_attention = self.location_conv(attention_weights_cat) 51 | processed_attention = processed_attention.transpose(1, 2) 52 | processed_attention = self.location_dense(processed_attention) 53 | return processed_attention 54 | 55 | 56 | class Attention(nn.Module): 57 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 58 | attention_location_n_filters, attention_location_kernel_size): 59 | super(Attention, self).__init__() 60 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 61 | bias=False, w_init_gain='tanh') 62 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 63 | w_init_gain='tanh') 64 | self.v = LinearNorm(attention_dim, 1, bias=False) 65 | self.location_layer = LocationLayer(attention_location_n_filters, 66 | attention_location_kernel_size, 67 | attention_dim) 68 | self.score_mask_value = -float('inf') 69 | 70 | def get_alignment_energies(self, query, processed_memory, 71 | attention_weights_cat): 72 | ''' 73 | PARAMS 74 | ------ 75 | query: decoder output (batch, num_mels * n_frames_per_step) 76 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 77 | attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time) 78 | 79 | RETURNS 80 | ------- 81 | alignment (batch, max_time) 82 | ''' 83 | 84 | processed_query = self.query_layer(query.unsqueeze(1)) 85 | processed_attention_weights = self.location_layer(attention_weights_cat) 86 | energies = self.v(torch.tanh( 87 | processed_query + processed_attention_weights + processed_memory)) 88 | 89 | energies = energies.squeeze(-1) 90 | return energies 91 | 92 | def forward(self, attention_hidden_state, memory, processed_memory, 93 | attention_weights_cat, mask): 94 | ''' 95 | PARAMS 96 | ------ 97 | attention_hidden_state: attention rnn last output 98 | memory: encoder outputs 99 | processed_memory: processed encoder outputs 100 | attention_weights_cat: previous and cummulative attention weights 101 | mask: binary mask for padded data 102 | ''' 103 | alignment = self.get_alignment_energies( 104 | attention_hidden_state, processed_memory, attention_weights_cat) 105 | 106 | if mask is not None: 107 | alignment.data.masked_fill_(mask, self.score_mask_value) 108 | 109 | attention_weights = F.softmax(alignment, dim=1) 110 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 111 | attention_context = attention_context.squeeze(1) 112 | return attention_context, attention_weights 113 | 114 | 115 | class Prenet(nn.Module): 116 | def __init__(self, in_dim, sizes): 117 | super(Prenet, self).__init__() 118 | in_sizes = [in_dim] + sizes[:-1] 119 | self.layers = nn.ModuleList( 120 | [LinearNorm(in_size, out_size, bias=False) 121 | for (in_size, out_size) in zip(in_sizes, sizes)]) 122 | 123 | def forward(self, x): 124 | for linear in self.layers: 125 | x = F.dropout(F.relu(linear(x)), p=0.5, training=True) 126 | return x 127 | 128 | 129 | class Postnet(nn.Module): 130 | '''Postnet 131 | - Five 1-d convolution with 512 channels and kernel size 5 132 | ''' 133 | 134 | def __init__(self): 135 | super(Postnet, self).__init__() 136 | self.convolutions = nn.ModuleList() 137 | 138 | self.convolutions.append( 139 | nn.Sequential( 140 | ConvNorm(hps.num_mels, hps.postnet_embedding_dim, 141 | kernel_size=hps.postnet_kernel_size, stride=1, 142 | padding=int((hps.postnet_kernel_size - 1) / 2), 143 | dilation=1, w_init_gain='tanh'), 144 | nn.BatchNorm1d(hps.postnet_embedding_dim)) 145 | ) 146 | 147 | for i in range(1, hps.postnet_n_convolutions - 1): 148 | self.convolutions.append( 149 | nn.Sequential( 150 | ConvNorm(hps.postnet_embedding_dim, 151 | hps.postnet_embedding_dim, 152 | kernel_size=hps.postnet_kernel_size, stride=1, 153 | padding=int((hps.postnet_kernel_size - 1) / 2), 154 | dilation=1, w_init_gain='tanh'), 155 | nn.BatchNorm1d(hps.postnet_embedding_dim)) 156 | ) 157 | 158 | self.convolutions.append( 159 | nn.Sequential( 160 | ConvNorm(hps.postnet_embedding_dim, hps.num_mels, 161 | kernel_size=hps.postnet_kernel_size, stride=1, 162 | padding=int((hps.postnet_kernel_size - 1) / 2), 163 | dilation=1, w_init_gain='linear'), 164 | nn.BatchNorm1d(hps.num_mels)) 165 | ) 166 | 167 | def forward(self, x): 168 | for i in range(len(self.convolutions) - 1): 169 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 170 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 171 | return x 172 | 173 | 174 | class Encoder(nn.Module): 175 | '''Encoder module: 176 | - Three 1-d convolution banks 177 | - Bidirectional LSTM 178 | ''' 179 | def __init__(self): 180 | super(Encoder, self).__init__() 181 | 182 | convolutions = [] 183 | for i in range(hps.encoder_n_convolutions): 184 | conv_layer = nn.Sequential( 185 | 
ConvNorm(hps.symbols_embedding_dim if i == 0 \ 186 | else hps.encoder_embedding_dim, 187 | hps.encoder_embedding_dim, 188 | kernel_size=hps.encoder_kernel_size, stride=1, 189 | padding=int((hps.encoder_kernel_size - 1) / 2), 190 | dilation=1, w_init_gain='relu'), 191 | nn.BatchNorm1d(hps.encoder_embedding_dim)) 192 | convolutions.append(conv_layer) 193 | self.convolutions = nn.ModuleList(convolutions) 194 | 195 | self.lstm = nn.LSTM(hps.encoder_embedding_dim, 196 | int(hps.encoder_embedding_dim / 2), 1, 197 | batch_first=True, bidirectional=True) 198 | 199 | def forward(self, x, input_lengths): 200 | for conv in self.convolutions: 201 | x = F.dropout(F.relu(conv(x)), 0.5, self.training) 202 | 203 | x = x.transpose(1, 2) 204 | 205 | # pytorch tensor are not reversible, hence the conversion 206 | input_lengths = input_lengths.cpu().numpy() 207 | x = nn.utils.rnn.pack_padded_sequence( 208 | x, input_lengths, batch_first=True) 209 | 210 | self.lstm.flatten_parameters() 211 | outputs, _ = self.lstm(x) 212 | 213 | outputs, _ = nn.utils.rnn.pad_packed_sequence( 214 | outputs, batch_first=True) 215 | return outputs 216 | 217 | def inference(self, x): 218 | for conv in self.convolutions: 219 | x = F.dropout(F.relu(conv(x)), 0.5, self.training) 220 | 221 | x = x.transpose(1, 2) 222 | 223 | self.lstm.flatten_parameters() 224 | outputs, _ = self.lstm(x) 225 | return outputs 226 | 227 | 228 | class Decoder(nn.Module): 229 | def __init__(self): 230 | super(Decoder, self).__init__() 231 | self.prenet = Prenet( 232 | hps.num_mels * hps.n_frames_per_step, 233 | [hps.prenet_dim, hps.prenet_dim]) 234 | 235 | self.attention_rnn = nn.LSTMCell( 236 | hps.prenet_dim + hps.encoder_embedding_dim, 237 | hps.attention_rnn_dim) 238 | 239 | self.attention_layer = Attention( 240 | hps.attention_rnn_dim, hps.encoder_embedding_dim, 241 | hps.attention_dim, hps.attention_location_n_filters, 242 | hps.attention_location_kernel_size) 243 | 244 | self.decoder_rnn = nn.LSTMCell( 245 | hps.attention_rnn_dim + hps.encoder_embedding_dim, 246 | hps.decoder_rnn_dim, 1) 247 | 248 | self.linear_projection = LinearNorm( 249 | hps.decoder_rnn_dim + hps.encoder_embedding_dim, 250 | hps.num_mels * hps.n_frames_per_step) 251 | 252 | self.gate_layer = LinearNorm( 253 | hps.decoder_rnn_dim + hps.encoder_embedding_dim, 1, 254 | bias=True, w_init_gain='sigmoid') 255 | 256 | def get_go_frame(self, memory): 257 | ''' Gets all zeros frames to use as first decoder input 258 | PARAMS 259 | ------ 260 | memory: decoder outputs 261 | 262 | RETURNS 263 | ------- 264 | decoder_input: all zeros frames 265 | ''' 266 | B = memory.size(0) 267 | decoder_input = Variable(memory.data.new( 268 | B, hps.num_mels * hps.n_frames_per_step).zero_()) 269 | return decoder_input 270 | 271 | def initialize_decoder_states(self, memory, mask): 272 | ''' Initializes attention rnn states, decoder rnn states, attention 273 | weights, attention cumulative weights, attention context, stores memory 274 | and stores processed memory 275 | PARAMS 276 | ------ 277 | memory: Encoder outputs 278 | mask: Mask for padded data if training, expects None for inference 279 | ''' 280 | B = memory.size(0) 281 | MAX_TIME = memory.size(1) 282 | 283 | self.attention_hidden = Variable(memory.data.new( 284 | B, hps.attention_rnn_dim).zero_()) 285 | self.attention_cell = Variable(memory.data.new( 286 | B, hps.attention_rnn_dim).zero_()) 287 | 288 | self.decoder_hidden = Variable(memory.data.new( 289 | B, hps.decoder_rnn_dim).zero_()) 290 | self.decoder_cell = Variable(memory.data.new( 291 | B, 
292 | 
293 |         self.attention_weights = Variable(memory.data.new(
294 |             B, MAX_TIME).zero_())
295 |         self.attention_weights_cum = Variable(memory.data.new(
296 |             B, MAX_TIME).zero_())
297 |         self.attention_context = Variable(memory.data.new(
298 |             B, hps.encoder_embedding_dim).zero_())
299 | 
300 |         self.memory = memory
301 |         self.processed_memory = self.attention_layer.memory_layer(memory)
302 |         self.mask = mask
303 | 
304 |     def parse_decoder_inputs(self, decoder_inputs):
305 |         ''' Prepares decoder inputs, i.e. mel outputs
306 |         PARAMS
307 |         ------
308 |         decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
309 | 
310 |         RETURNS
311 |         -------
312 |         inputs: processed decoder inputs
313 | 
314 |         '''
315 |         # (B, num_mels, T_out) -> (B, T_out, num_mels)
316 |         decoder_inputs = decoder_inputs.transpose(1, 2).contiguous()
317 |         decoder_inputs = decoder_inputs.view(
318 |             decoder_inputs.size(0),
319 |             int(decoder_inputs.size(1)/hps.n_frames_per_step), -1)
320 |         # (B, T_out, num_mels) -> (T_out, B, num_mels)
321 |         decoder_inputs = decoder_inputs.transpose(0, 1)
322 |         return decoder_inputs
323 | 
324 |     def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
325 |         ''' Prepares decoder outputs for output
326 |         PARAMS
327 |         ------
328 |         mel_outputs: list of per-step mel outputs
329 |         gate_outputs: gate output energies
330 |         alignments: list of per-step attention weights
331 | 
332 |         RETURNS
333 |         -------
334 |         mel_outputs: mel outputs as a (B, num_mels, T_out) tensor
335 |         gate_outputs: gate output energies
336 |         alignments: attention weights as a (B, T_out, T_in) tensor
337 |         '''
338 |         # (T_out, B, T_in) -> (B, T_out, T_in)
339 |         alignments = torch.stack(alignments).transpose(0, 1)
340 |         # (T_out, B) -> (B, T_out)
341 |         gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
342 |         gate_outputs = gate_outputs.contiguous()
343 |         # (T_out, B, num_mels) -> (B, T_out, num_mels)
344 |         mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
345 |         # decouple frames per step
346 |         mel_outputs = mel_outputs.view(
347 |             mel_outputs.size(0), -1, hps.num_mels)
348 |         # (B, T_out, num_mels) -> (B, num_mels, T_out)
349 |         mel_outputs = mel_outputs.transpose(1, 2)
350 |         return mel_outputs, gate_outputs, alignments
351 | 
352 |     def decode(self, decoder_input):
353 |         ''' Decoder step using stored states, attention and memory
354 |         PARAMS
355 |         ------
356 |         decoder_input: previous mel output
357 | 
358 |         RETURNS
359 |         -------
360 |         mel_output: predicted mel frame(s) for this step
361 |         gate_output: gate output energies
362 |         attention_weights: attention weights for this step
363 |         '''
364 |         cell_input = torch.cat((decoder_input, self.attention_context), -1)
365 |         self.attention_hidden, self.attention_cell = self.attention_rnn(
366 |             cell_input, (self.attention_hidden, self.attention_cell))
367 |         self.attention_hidden = F.dropout(
368 |             self.attention_hidden, hps.p_attention_dropout, self.training)
369 | 
370 |         attention_weights_cat = torch.cat(
371 |             (self.attention_weights.unsqueeze(1),
372 |              self.attention_weights_cum.unsqueeze(1)), dim=1)
373 |         self.attention_context, self.attention_weights = self.attention_layer(
374 |             self.attention_hidden, self.memory, self.processed_memory,
375 |             attention_weights_cat, self.mask)
376 | 
377 |         self.attention_weights_cum += self.attention_weights
378 |         decoder_input = torch.cat(
379 |             (self.attention_hidden, self.attention_context), -1)
380 |         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
381 |             decoder_input, (self.decoder_hidden, self.decoder_cell))
382 |         self.decoder_hidden = F.dropout(
383 |             self.decoder_hidden, hps.p_decoder_dropout, self.training)
384 | 
385 |         decoder_hidden_attention_context = torch.cat(
386 |             (self.decoder_hidden, self.attention_context), dim=1)
387 |         decoder_output = self.linear_projection(
388 |             decoder_hidden_attention_context)
389 | 
390 |         gate_prediction = self.gate_layer(decoder_hidden_attention_context)
391 |         return decoder_output, gate_prediction, self.attention_weights
392 | 
393 |     def forward(self, memory, decoder_inputs, memory_lengths):
394 |         ''' Decoder forward pass for training
395 |         PARAMS
396 |         ------
397 |         memory: Encoder outputs
398 |         decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
399 |         memory_lengths: Encoder output lengths for attention masking.
400 | 
401 |         RETURNS
402 |         -------
403 |         mel_outputs: mel outputs from the decoder
404 |         gate_outputs: gate outputs from the decoder
405 |         alignments: sequence of attention weights from the decoder
406 |         '''
407 |         decoder_input = self.get_go_frame(memory).unsqueeze(0)
408 |         decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
409 |         decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
410 |         decoder_inputs = self.prenet(decoder_inputs)
411 | 
412 |         self.initialize_decoder_states(
413 |             memory, mask=~get_mask_from_lengths(memory_lengths))
414 | 
415 |         mel_outputs, gate_outputs, alignments = [], [], []
416 |         while len(mel_outputs) < decoder_inputs.size(0) - 1:
417 |             decoder_input = decoder_inputs[len(mel_outputs)]
418 |             mel_output, gate_output, attention_weights = self.decode(
419 |                 decoder_input)
420 |             mel_outputs += [mel_output.squeeze(1)]
421 |             gate_outputs += [gate_output.squeeze()]
422 |             alignments += [attention_weights]
423 |         mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
424 |             mel_outputs, gate_outputs, alignments)
425 |         return mel_outputs, gate_outputs, alignments
426 | 
427 |     def inference(self, memory):
428 |         ''' Decoder inference
429 |         PARAMS
430 |         ------
431 |         memory: Encoder outputs
432 | 
433 |         RETURNS
434 |         -------
435 |         mel_outputs: mel outputs from the decoder
436 |         gate_outputs: gate outputs from the decoder
437 |         alignments: sequence of attention weights from the decoder
438 |         '''
439 |         decoder_input = self.get_go_frame(memory)
440 | 
441 |         self.initialize_decoder_states(memory, mask=None)
442 | 
443 |         mel_outputs, gate_outputs, alignments = [], [], []
444 |         while True:
445 |             decoder_input = self.prenet(decoder_input)
446 |             mel_output, gate_output, alignment = self.decode(decoder_input)
447 | 
448 |             mel_outputs += [mel_output.squeeze(1)]
449 |             gate_outputs += [gate_output]
450 |             alignments += [alignment]
451 | 
452 |             if torch.sigmoid(gate_output.data) > hps.gate_threshold:
453 |                 print('Terminated by gate.')
454 |                 break
455 |             elif hps.n_frames_per_step*len(mel_outputs)/alignment.shape[1] \
456 |                     >= hps.max_decoder_ratio:
457 |                 print('Warning: Reached max decoder steps.')
458 |                 break
459 | 
460 |             decoder_input = mel_output
461 | 
462 |         mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
463 |             mel_outputs, gate_outputs, alignments)
464 |         return mel_outputs, gate_outputs, alignments
465 | 
466 | 
467 | class Tacotron2(nn.Module):
468 |     def __init__(self):
469 |         super(Tacotron2, self).__init__()
470 |         self.embedding = nn.Embedding(
471 |             hps.n_symbols, hps.symbols_embedding_dim)
472 |         std = sqrt(2.0/(hps.n_symbols+hps.symbols_embedding_dim))
473 |         val = sqrt(3.0)*std
474 |         self.embedding.weight.data.uniform_(-val, val)
475 |         self.encoder = Encoder()
476 |         self.decoder = Decoder()
477 |         self.postnet = Postnet()
478 | 
479 |     def parse_batch(self, batch):
480 |         text_padded, input_lengths, mel_padded, gate_padded, output_lengths = batch
481 |         text_padded = mode(text_padded).long()
482 |         input_lengths = mode(input_lengths).long()
483 |         max_len = torch.max(input_lengths.data).item()
484 |         mel_padded = mode(mel_padded).float()
485 |         gate_padded = mode(gate_padded).float()
486 |         output_lengths = mode(output_lengths).long()
487 |         return (
488 |             (text_padded, input_lengths, mel_padded, max_len, output_lengths),
489 |             (mel_padded, gate_padded, output_lengths))
490 | 
491 |     def parse_output(self, outputs, output_lengths=None):
492 |         if output_lengths is not None:
493 |             mask = ~get_mask_from_lengths(output_lengths, True) # (B, T)
494 |             mask = mask.expand(hps.num_mels, mask.size(0), mask.size(1)) # (80, B, T)
495 |             mask = mask.permute(1, 0, 2) # (B, 80, T)
496 | 
497 |             outputs[0].data.masked_fill_(mask, 0.0) # (B, 80, T)
498 |             outputs[1].data.masked_fill_(mask, 0.0) # (B, 80, T)
499 |             slice = torch.arange(0, mask.size(2), hps.n_frames_per_step)
500 |             outputs[2].data.masked_fill_(mask[:, 0, slice], 1e3) # gate energies (B, T//n_frames_per_step)
501 |         return outputs
502 | 
503 |     def forward(self, inputs):
504 |         text_inputs, text_lengths, mels, max_len, output_lengths = inputs
505 |         text_lengths, output_lengths = text_lengths.data, output_lengths.data
506 | 
507 |         embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
508 | 
509 |         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
510 | 
511 |         mel_outputs, gate_outputs, alignments = self.decoder(
512 |             encoder_outputs, mels, memory_lengths=text_lengths)
513 | 
514 |         mel_outputs_postnet = self.postnet(mel_outputs)
515 |         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
516 |         return self.parse_output(
517 |             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
518 |             output_lengths)
519 | 
520 |     def inference(self, inputs):
521 |         embedded_inputs = self.embedding(inputs).transpose(1, 2)
522 |         encoder_outputs = self.encoder.inference(embedded_inputs)
523 |         mel_outputs, gate_outputs, alignments = self.decoder.inference(
524 |             encoder_outputs)
525 | 
526 |         mel_outputs_postnet = self.postnet(mel_outputs)
527 |         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
528 | 
529 |         outputs = self.parse_output(
530 |             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
531 |         return outputs
532 | 
533 |     def teacher_infer(self, inputs, mels):
534 |         il, _ = torch.sort(torch.LongTensor([len(x) for x in inputs]),
535 |                            dim = 0, descending = True)
536 |         text_lengths = mode(il)
537 | 
538 |         embedded_inputs = self.embedding(inputs).transpose(1, 2)
539 | 
540 |         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
541 | 
542 |         mel_outputs, gate_outputs, alignments = self.decoder(
543 |             encoder_outputs, mels, memory_lengths=text_lengths)
544 | 
545 |         mel_outputs_postnet = self.postnet(mel_outputs)
546 |         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
547 |         return self.parse_output(
548 |             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
549 | 
--------------------------------------------------------------------------------
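Usage sketch (illustrative, not a file from the repository): Tacotron2 above chains the symbol embedding, Encoder, Decoder and Postnet; parse_batch() moves a padded training batch to the right device, forward() runs the teacher-forced training pass, and inference() runs the gate-terminated autoregressive loop. The snippet below shows one plausible way to synthesize a mel spectrogram from raw text, assuming the keithito-style text_to_sequence helper exported by the /text package; the checkpoint filename, the 'model' state-dict key and the 'english_cleaners' name are assumptions, not facts taken from this repository.

# Illustrative inference sketch; adjust the checkpoint path, the state-dict key
# and the cleaner name to your own setup.
import torch
from model.model import Tacotron2
from text import text_to_sequence          # keithito-style helper in /text
from utils.util import mode, to_arr        # device placement and numpy conversion

model = mode(Tacotron2(), model=True)      # moves the model to GPU when hps.is_cuda
ckpt = torch.load('tacotron2.pt', map_location='cpu')   # hypothetical checkpoint file
model.load_state_dict(ckpt['model'])                    # assumed state-dict key
model.eval()

# Text -> symbol ids -> (1, T) LongTensor on the same device as the model.
ids = text_to_sequence('Hello world.', ['english_cleaners'])
ids = mode(torch.LongTensor(ids).unsqueeze(0))

with torch.no_grad():
    mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model.inference(ids)

# (num_mels, T_out) spectrogram, ready for a vocoder or Griffin-Lim reconstruction.
mel = to_arr(mel_outputs_postnet)[0]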