├── requirements.txt
├── text
│   ├── symbols.py
│   ├── LICENSE
│   ├── cmudict.py
│   ├── numbers.py
│   ├── __init__.py
│   └── cleaners.py
├── utils
│   ├── util.py
│   ├── plot.py
│   ├── logger.py
│   ├── audio.py
│   └── dataset.py
├── LICENSE
├── model
│   ├── layers.py
│   └── model.py
├── hparams.py
├── mkgta.py
├── inference.py
├── inference.ipynb
├── README.md
└── train.py
/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pillow 4 | inflect 5 | librosa 6 | Unidecode 7 | matplotlib 8 | tensorboardX 9 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? ' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | 14 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 15 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 19 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from hparams import hparams as hps 4 | 5 | def mode(obj, model = False): 6 | if model and hps.is_cuda: 7 | obj = obj.cuda() 8 | elif hps.is_cuda: 9 | obj = obj.cuda(non_blocking = hps.pin_mem) 10 | return obj 11 | 12 | def to_arr(var): 13 | return var.cpu().detach().numpy().astype(np.float32) 14 | 15 | def get_mask_from_lengths(lengths, pad = False): 16 | max_len = torch.max(lengths).item() 17 | if pad and max_len%hps.n_frames_per_step != 0: 18 | max_len += hps.n_frames_per_step - max_len%hps.n_frames_per_step 19 | assert max_len%hps.n_frames_per_step == 0 20 | ids = torch.arange(0, max_len, out = torch.LongTensor(max_len)) 21 | ids = mode(ids) 22 | mask = (ids < lengths.unsqueeze(1)) 23 | return mask 24 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 BogiHsu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import numpy as np 4 | import matplotlib.pylab as plt 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 
9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data.transpose(2, 0, 1) 12 | 13 | 14 | def plot_alignment_to_numpy(alignment, info=None): 15 | fig, ax = plt.subplots(figsize=(6, 4)) 16 | im = ax.imshow(alignment, aspect='auto', origin='lower', 17 | interpolation='none') 18 | fig.colorbar(im, ax=ax) 19 | xlabel = 'Decoder timestep' 20 | if info is not None: 21 | xlabel += '\n\n' + info 22 | plt.xlabel(xlabel) 23 | plt.ylabel('Encoder timestep') 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | return data 30 | 31 | 32 | def plot_spectrogram_to_numpy(spectrogram): 33 | fig, ax = plt.subplots(figsize=(12, 3)) 34 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 35 | interpolation='none') 36 | plt.colorbar(im, ax=ax) 37 | plt.xlabel("Frames") 38 | plt.ylabel("Channels") 39 | plt.tight_layout() 40 | 41 | fig.canvas.draw() 42 | data = save_figure_to_numpy(fig) 43 | plt.close() 44 | return data 45 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LinearNorm(torch.nn.Module): 5 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 6 | super(LinearNorm, self).__init__() 7 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 8 | 9 | torch.nn.init.xavier_uniform_( 10 | self.linear_layer.weight, 11 | gain=torch.nn.init.calculate_gain(w_init_gain)) 12 | 13 | def forward(self, x): 14 | return self.linear_layer(x) 15 | 16 | 17 | class ConvNorm(torch.nn.Module): 18 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 19 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 20 | super(ConvNorm, self).__init__() 21 | if padding is None: 22 | assert(kernel_size % 2 == 1) 23 | padding = int(dilation * (kernel_size - 1) / 2) 24 | 25 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 26 | kernel_size=kernel_size, stride=stride, 27 | padding=padding, dilation=dilation, 28 | bias=bias) 29 | 30 | torch.nn.init.xavier_uniform_( 31 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 32 | 33 | def forward(self, signal): 34 | conv_signal = self.conv(signal) 35 | return conv_signal 36 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from text import symbols 2 | 3 | 4 | class hparams: 5 | seed = 0 6 | 7 | ################################ 8 | # Data Parameters # 9 | ################################ 10 | text_cleaners=['english_cleaners'] 11 | 12 | ################################ 13 | # Audio # 14 | ################################ 15 | num_mels = 80 16 | num_freq = 513 17 | sample_rate = 22050 18 | frame_shift = 256 19 | frame_length = 1024 20 | fmin = 0 21 | fmax = 8000 22 | power = 1.5 23 | gl_iters = 30 24 | 25 | ################################ 26 | # Train # 27 | ################################ 28 | is_cuda = True 29 | pin_mem = True 30 | n_workers = 4 31 | prep = True 32 | pth = 'lj-22k.pkl' 33 | lr = 2e-3 34 | betas = (0.9, 0.999) 35 | eps = 1e-6 36 | sch = True 37 | sch_step = 4000 38 | max_iter = 200e3 39 | batch_size = 16 40 | iters_per_log = 10 41 | iters_per_sample = 500 42 | iters_per_ckpt = 10000 43 | weight_decay = 1e-6 44 | grad_clip_thresh = 1.0 45 | 
eg_text = 'OMAK is a thinking process which considers things always positively.' 46 | 47 | ################################ 48 | # Model Parameters # 49 | ################################ 50 | n_symbols = len(symbols) 51 | symbols_embedding_dim = 512 52 | 53 | # Encoder parameters 54 | encoder_kernel_size = 5 55 | encoder_n_convolutions = 3 56 | encoder_embedding_dim = 512 57 | 58 | # Decoder parameters 59 | n_frames_per_step = 3 60 | decoder_rnn_dim = 1024 61 | prenet_dim = 256 62 | max_decoder_ratio = 10 63 | gate_threshold = 0.5 64 | p_attention_dropout = 0.1 65 | p_decoder_dropout = 0.1 66 | 67 | # Attention parameters 68 | attention_rnn_dim = 1024 69 | attention_dim = 128 70 | 71 | # Location Layer parameters 72 | attention_location_n_filters = 32 73 | attention_location_kernel_size = 31 74 | 75 | # Mel-post processing network parameters 76 | postnet_embedding_dim = 512 77 | postnet_kernel_size = 5 78 | postnet_n_convolutions = 5 79 | 80 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = 
re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | ''' 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | if not m: 34 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 35 | break 36 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 37 | sequence += _arpabet_to_sequence(m.group(2)) 38 | text = m.group(3) 39 | 40 | return sequence 41 | 42 | 43 | def sequence_to_text(sequence): 44 | '''Converts a sequence of IDs back to a string''' 45 | result = '' 46 | for symbol_id in sequence: 47 | if symbol_id in _id_to_symbol: 48 | s = _id_to_symbol[symbol_id] 49 | # Enclose ARPAbet back in curly braces: 50 | if len(s) > 1 and s[0] == '@': 51 | s = '{%s}' % s[1:] 52 | result += s 53 | return result.replace('}{', ' ') 54 | 55 | 56 | def _clean_text(text, cleaner_names): 57 | for name in cleaner_names: 58 | cleaner = getattr(cleaners, name) 59 | if not cleaner: 60 | raise Exception('Unknown cleaner: %s' % name) 61 | text = cleaner(text) 62 | return text 63 | 64 | 65 | def _symbols_to_sequence(symbols): 66 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 67 | 68 | 69 | def _arpabet_to_sequence(text): 70 | return _symbols_to_sequence(['@' + s for s in text.split()]) 71 | 72 | 73 | def _should_keep_symbol(s): 74 | return s in _symbol_to_id and s != '_' and s != '~' 75 | -------------------------------------------------------------------------------- /mkgta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pylab as plt 6 | from text import text_to_sequence 7 | from model.model import Tacotron2 8 | from hparams import hparams as hps 9 | from utils.util import mode, to_var, to_arr 10 | from utils.audio import load_wav, save_wav, melspectrogram 11 | 12 | 13 | def files_to_list(fdir = 'data'): 14 | f_list = [] 15 | with open(os.path.join(fdir, 'metadata.csv'), encoding = 'utf-8') as f: 16 | for 
line in f: 17 | parts = line.strip().split('|') 18 | wav_path = os.path.join(fdir, 'wavs', '%s.wav' % parts[0]) 19 | f_list.append([wav_path, parts[1]]) 20 | return f_list 21 | 22 | 23 | def load_model(ckpt_pth): 24 | ckpt_dict = torch.load(ckpt_pth) 25 | model = Tacotron2() 26 | model.load_state_dict(ckpt_dict['model']) 27 | model = mode(model, True).eval() 28 | model.decoder.train() 29 | model.postnet.train() 30 | return model 31 | 32 | 33 | def infer(wav_path, text, model): 34 | sequence = text_to_sequence(text, hps.text_cleaners) 35 | sequence = to_var(torch.IntTensor(sequence)[None, :]).long() 36 | mel = melspectrogram(load_wav(wav_path)) 37 | mel_in = to_var(torch.Tensor([mel])) 38 | r = mel_in.shape[2]%hps.n_frames_per_step 39 | if r != 0: 40 | mel_in = mel_in[:, :, :-r] 41 | sequence = torch.cat([sequence, sequence], 0) 42 | mel_in = torch.cat([mel_in, mel_in], 0) 43 | _, mel_outputs_postnet, _, _ = model.teacher_infer(sequence, mel_in) 44 | ret = mel 45 | if r != 0: 46 | ret[:, :-r] = to_arr(mel_outputs_postnet[0]) 47 | else: 48 | ret = to_arr(mel_outputs_postnet[0]) 49 | return ret 50 | 51 | 52 | def save_mel(res, pth, name): 53 | out = os.path.join(pth, name) 54 | np.save(out, res) 55 | 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('-c', '--ckpt_pth', type = str, default = '', 60 | required = True, help = 'path to load checkpoints') 61 | parser.add_argument('-n', '--npy_pth', type = str, default = 'dump', 62 | help = 'path to save mels') 63 | 64 | args = parser.parse_args() 65 | 66 | torch.backends.cudnn.enabled = True 67 | torch.backends.cudnn.benchmark = False 68 | model = load_model(args.ckpt_pth) 69 | flist = files_to_list() 70 | for x in flist: 71 | ret = infer(x[0], x[1], model) 72 | name = x[0].split('/')[-1].split('.wav')[0] 73 | if args.npy_pth != '': 74 | save_mel(ret, args.npy_pth, name) 75 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.util import to_arr 3 | from hparams import hparams as hps 4 | from tensorboardX import SummaryWriter 5 | from utils.audio import inv_melspectrogram 6 | from utils.plot import plot_alignment_to_numpy, plot_spectrogram_to_numpy 7 | 8 | 9 | class Tacotron2Logger(SummaryWriter): 10 | def __init__(self, logdir): 11 | super(Tacotron2Logger, self).__init__(logdir, flush_secs = 5) 12 | 13 | def log_training(self, items, grad_norm, learning_rate, iteration): 14 | self.add_scalar('loss.mel', items[0], iteration) 15 | self.add_scalar('loss.gate', items[1], iteration) 16 | self.add_scalar('grad.norm', grad_norm, iteration) 17 | self.add_scalar('learning.rate', learning_rate, iteration) 18 | 19 | def sample_train(self, outputs, iteration): 20 | mel_outputs = to_arr(outputs[0][0]) 21 | mel_outputs_postnet = to_arr(outputs[1][0]) 22 | alignments = to_arr(outputs[3][0]).T 23 | 24 | # plot alignment, mel and postnet output 25 | self.add_image( 26 | 'train.align', 27 | plot_alignment_to_numpy(alignments), 28 | iteration) 29 | self.add_image( 30 | 'train.mel', 31 | plot_spectrogram_to_numpy(mel_outputs), 32 | iteration) 33 | self.add_image( 34 | 'train.mel_post', 35 | plot_spectrogram_to_numpy(mel_outputs_postnet), 36 | iteration) 37 | 38 | def sample_infer(self, outputs, iteration): 39 | mel_outputs = to_arr(outputs[0][0]) 40 | mel_outputs_postnet = to_arr(outputs[1][0]) 41 | alignments = to_arr(outputs[2][0]).T 42 | 43 | # plot 
alignment, mel and postnet output 44 | self.add_image( 45 | 'infer.align', 46 | plot_alignment_to_numpy(alignments), 47 | iteration) 48 | self.add_image( 49 | 'infer.mel', 50 | plot_spectrogram_to_numpy(mel_outputs), 51 | iteration) 52 | self.add_image( 53 | 'infer.mel_post', 54 | plot_spectrogram_to_numpy(mel_outputs_postnet), 55 | iteration) 56 | 57 | # save audio 58 | wav = inv_melspectrogram(mel_outputs) 59 | wav_postnet = inv_melspectrogram(mel_outputs_postnet) 60 | self.add_audio('infer.wav', wav, iteration, hps.sample_rate) 61 | self.add_audio('infer.wav_post', wav_postnet, iteration, hps.sample_rate) 62 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from .numbers import normalize_numbers 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including number and abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_numbers(text) 88 | text = expand_abbreviations(text) 89 | text = collapse_whitespace(text) 90 | return text 91 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | import matplotlib.pylab as plt 5 | from text import text_to_sequence 6 | from model.model import Tacotron2 7 | from hparams import hparams as hps 8 | from utils.util import mode, to_arr 9 | from utils.audio import save_wav, inv_melspectrogram 10 | 11 | 12 | def load_model(ckpt_pth): 13 | ckpt_dict = torch.load(ckpt_pth) 14 | model = Tacotron2() 15 | model.load_state_dict(ckpt_dict['model']) 16 | model = mode(model, True).eval() 17 | return model 18 | 19 | 20 | def infer(text, model): 21 | sequence = text_to_sequence(text, hps.text_cleaners) 22 | sequence = mode(torch.IntTensor(sequence)[None, :]).long() 23 | mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence) 24 | return (mel_outputs, mel_outputs_postnet, alignments) 25 | 26 | 27 | def plot_data(data, figsize = (16, 4)): 28 | fig, axes = plt.subplots(1, len(data), figsize = figsize) 29 | for i in range(len(data)): 30 | axes[i].imshow(data[i], aspect = 'auto', origin = 'bottom') 31 | 32 | 33 | def plot(output, pth): 34 | mel_outputs, mel_outputs_postnet, alignments = output 35 | plot_data((to_arr(mel_outputs[0]), 36 | to_arr(mel_outputs_postnet[0]), 37 | to_arr(alignments[0]).T)) 38 | plt.savefig(pth+'.png') 39 | 40 | 41 | def audio(output, pth): 42 | mel_outputs, mel_outputs_postnet, _ = output 43 | wav_postnet = inv_melspectrogram(to_arr(mel_outputs_postnet[0])) 44 | save_wav(wav_postnet, pth+'.wav') 45 | 46 | 47 | def save_mel(output, pth): 48 | mel_outputs, mel_outputs_postnet, _ = output 49 | np.save(pth+'.npy', to_arr(mel_outputs_postnet)) 50 | 51 | 52 | if __name__ == '__main__': 53 | 
parser = argparse.ArgumentParser() 54 | parser.add_argument('-c', '--ckpt_pth', type = str, default = '', 55 | required = True, help = 'path to load checkpoints') 56 | parser.add_argument('-i', '--img_pth', type = str, default = '', 57 | help = 'path to save images') 58 | parser.add_argument('-w', '--wav_pth', type = str, default = '', 59 | help = 'path to save wavs') 60 | parser.add_argument('-n', '--npy_pth', type = str, default = '', 61 | help = 'path to save mels') 62 | parser.add_argument('-t', '--text', type = str, default = 'Tacotron is awesome.', 63 | help = 'text to synthesize') 64 | 65 | args = parser.parse_args() 66 | 67 | torch.backends.cudnn.enabled = True 68 | torch.backends.cudnn.benchmark = False 69 | model = load_model(args.ckpt_pth) 70 | output = infer(args.text, model) 71 | if args.img_pth != '': 72 | plot(output, args.img_pth) 73 | if args.wav_pth != '': 74 | audio(output, args.wav_pth) 75 | if args.npy_pth != '': 76 | save_mel(output, args.npy_pth) 77 | -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import librosa 3 | import numpy as np 4 | from scipy.io import wavfile 5 | from librosa.util import normalize 6 | from hparams import hparams as hps 7 | MAX_WAV_VALUE = 32768.0 8 | _mel_basis = None 9 | 10 | 11 | def load_wav(path): 12 | sr, wav = wavfile.read(path) 13 | assert sr == hps.sample_rate 14 | return normalize(wav/MAX_WAV_VALUE)*0.95 15 | 16 | 17 | def save_wav(wav, path): 18 | wav *= MAX_WAV_VALUE 19 | wavfile.write(path, hps.sample_rate, wav.astype(np.int16)) 20 | 21 | 22 | def spectrogram(y): 23 | D = _stft(y) 24 | S = _amp_to_db(np.abs(D)) 25 | return S 26 | 27 | 28 | def inv_spectrogram(S): 29 | S = _db_to_amp(S) 30 | return _griffin_lim(S ** hps.power) 31 | 32 | 33 | def melspectrogram(y): 34 | D = _stft(y) 35 | S = _amp_to_db(_linear_to_mel(np.abs(D))) 36 | return S 37 | 38 | 39 | def inv_melspectrogram(mel): 40 | mel = _db_to_amp(mel) 41 | S = _mel_to_linear(mel) 42 | return _griffin_lim(S**hps.power) 43 | 44 | 45 | def _griffin_lim(S): 46 | '''librosa implementation of Griffin-Lim 47 | Based on https://github.com/librosa/librosa/issues/434 48 | ''' 49 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 50 | S_complex = np.abs(S).astype(np.complex) 51 | y = _istft(S_complex * angles) 52 | for i in range(hps.gl_iters): 53 | angles = np.exp(1j * np.angle(_stft(y))) 54 | y = _istft(S_complex * angles) 55 | return np.clip(y, a_max = 1, a_min = -1) 56 | 57 | 58 | # Conversions: 59 | def _stft(y): 60 | n_fft, hop_length, win_length = _stft_parameters() 61 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, 62 | win_length=win_length, pad_mode='reflect') 63 | 64 | 65 | def _istft(y): 66 | _, hop_length, win_length = _stft_parameters() 67 | return librosa.istft(y, hop_length=hop_length, win_length=win_length) 68 | 69 | 70 | def _stft_parameters(): 71 | return (hps.num_freq - 1) * 2, hps.frame_shift, hps.frame_length 72 | 73 | 74 | def _linear_to_mel(spectrogram): 75 | global _mel_basis 76 | if _mel_basis is None: 77 | _mel_basis = _build_mel_basis() 78 | return np.dot(_mel_basis, spectrogram) 79 | 80 | 81 | def _mel_to_linear(spectrogram): 82 | global _mel_basis 83 | if _mel_basis is None: 84 | _mel_basis = _build_mel_basis() 85 | inv_mel_basis = np.linalg.pinv(_mel_basis) 86 | inverse = np.dot(inv_mel_basis, spectrogram) 87 | inverse = np.maximum(1e-10, inverse) 88 | return inverse 89 | 90 | 91 | 
def _build_mel_basis(): 92 | n_fft = (hps.num_freq - 1) * 2 93 | return librosa.filters.mel(hps.sample_rate, n_fft, n_mels=hps.num_mels, fmin = hps.fmin, fmax = hps.fmax) 94 | 95 | 96 | def _amp_to_db(x): 97 | return np.log(np.maximum(1e-5, x)) 98 | 99 | 100 | def _db_to_amp(x): 101 | return np.exp(x) 102 | -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","source":["# Text-to-Speech with Tacotron2 and HifiGAN\n","\n","This is an English female voice TTS demo using open source projects [BogiHsu/Tacotron2-PyTorch](https://github.com/BogiHsu/Tacotron2-PyTorch) and [jik876/hifi-gan](https://github.com/jik876/hifi-gan).\n","\n","Please enable GPU acceleration in Colab before you start running the code."],"metadata":{"id":"3qb1435p6OEz"}},{"cell_type":"markdown","metadata":{"id":"3rof5nB25HD4"},"source":["## Set up environment"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"OBB0-ki_5HD7"},"outputs":[],"source":["!git clone https://github.com/BogiHsu/Tacotron2-PyTorch.git\n","!pip install -r Tacotron2-PyTorch/requirements.txt\n","!git clone https://github.com/jik876/hifi-gan.git\n","!mkdir -p mel_files"]},{"cell_type":"markdown","metadata":{"id":"whZkDydI5HD9"},"source":["## Download Tacotron2 pretrained model"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"EbbOs-sT5HD9"},"outputs":[],"source":["!wget https://github.com/BogiHsu/Tacotron2-PyTorch/releases/download/lj-200k-b512/ckpt_200000 -O Tacotron2-PyTorch/ckpt_200000"]},{"cell_type":"markdown","metadata":{"id":"OTcGBEfG5HD-"},"source":["## Download HifiGAN pretrained model"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gH6vR6045HD-"},"outputs":[],"source":["import gdown\n","dl = {'hifi-gan/config.json': 'https://drive.google.com/u/1/uc?id=1aDh576AEYA5eTjhx7sew1qcCM_Y526jc&export=download',\n"," 'hifi-gan/generator_v1': 'https://drive.google.com/u/1/uc?id=14NENd4equCBLyyCSke114Mv6YR_j_uFs&export=download'}\n","for k in dl:\n"," gdown.download(dl[k], k)"]},{"cell_type":"markdown","metadata":{"id":"DrOB1Ku15HD_"},"source":["## Tacotron2 inference"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"z44MMC_y5HD_"},"outputs":[],"source":["%cd Tacotron2-PyTorch\n","!python3 inference.py --ckpt_pth=ckpt_200000 --text='Tacotron is awesome.' 
--npy_pth=../mel_files/1\n","%cd ../"]},{"cell_type":"markdown","metadata":{"id":"9cHVnw4e5HEA"},"source":["## HifiGAN inference"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KmMdKHKL5HEA"},"outputs":[],"source":["%cd hifi-gan\n","!python3 inference_e2e.py --checkpoint_file generator_v1 --input_mels_dir ../mel_files --output_dir ../wav_files\n","%cd ../"]},{"cell_type":"markdown","metadata":{"id":"M2ElzCby5HEB"},"source":["## Play synthesized wav file"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9SGCUtQC5HEB"},"outputs":[],"source":["import IPython\n","IPython.display.Audio('wav_files/1_generated_e2e.wav')"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ws_nYQk35HEB"},"outputs":[],"source":[""]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.10"},"colab":{"name":"inference.ipynb","provenance":[],"collapsed_sections":[]},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tacotron2-PyTorch 2 | Yet another PyTorch implementation of [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](https://arxiv.org/pdf/1712.05884.pdf). The project is largely based on [these](#References). I made some modifications to improve the speed and performance of both training and inference. 3 | 4 | ## TODO 5 | - [x] Add Colab demo. 6 | - [x] Update README. 7 | - [x] Upload pretrained models. 8 | - [x] Compatible with [WaveGlow](https://github.com/NVIDIA/waveglow) and [Hifi-GAN](https://github.com/jik876/hifi-gan). 9 | 10 | ## Requirements 11 | - Python >= 3.5.2 12 | - torch >= 1.0.0 13 | - numpy 14 | - scipy 15 | - pillow 16 | - inflect 17 | - librosa 18 | - Unidecode 19 | - matplotlib 20 | - tensorboardX 21 | 22 | ## Preprocessing 23 | Currently only [LJ Speech](https://keithito.com/LJ-Speech-Dataset/) is supported. You can modify `hparams.py` for different sampling rates. `prep` decides whether to preprocess all utterances before training or to preprocess them on the fly during training. `pth` specifies the path to store the preprocessed data. 24 | 25 | ## Training 26 | 1. To train Tacotron2, run the following command. 27 | ```bash 28 | python3 train.py \ 29 | --data_dir=