├── datasets ├── __init__.py ├── .gitignore ├── lj_speech.py ├── data_loader.py └── mb_speech.py ├── models ├── __init__.py ├── ssrn.py ├── layers.py └── text2mel.py ├── .gitignore ├── requirements.txt ├── LICENSE ├── logger.py ├── README.md ├── hparams.py ├── utils.py ├── audio.py ├── synthesize.py ├── train-ssrn.py ├── train-text2mel.py └── dl_and_preprop_dataset.py /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | LJSpeech-1.1/ 2 | MBSpeech-1.0/ 3 | *.tar.gz 4 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .text2mel import Text2Mel 2 | from .ssrn import SSRN 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | __pycache__ 4 | .ipynb_checkpoints 5 | *.ipynb 6 | logdir/ 7 | samples 8 | *.npy 9 | *.tar.bz2 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa>=0.5.1 2 | torch>=0.4 3 | tensorboardX>=1.2 4 | tqdm>=4.15.0 5 | numpy>=1.25.0 6 | scipy 7 | pandas 8 | requests 9 | scikit-image 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Erdene-Ochir Tuguldur 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | """Wrapper class for logging into the TensorBoard and comet.ml""" 2 | __author__ = 'Erdene-Ochir Tuguldur' 3 | __all__ = ['Logger'] 4 | 5 | import os 6 | from tensorboardX import SummaryWriter 7 | 8 | from hparams import HParams as hp 9 | 10 | 11 | class Logger(object): 12 | 13 | def __init__(self, dataset_name, model_name): 14 | self.model_name = model_name 15 | self.project_name = "%s-%s" % (dataset_name, self.model_name) 16 | self.logdir = os.path.join(hp.logdir, self.project_name) 17 | self.writer = SummaryWriter(log_dir=self.logdir) 18 | 19 | def log_step(self, phase, step, loss_dict, image_dict): 20 | if phase == 'train': 21 | if step % 50 == 0: 22 | # self.writer.add_scalar('lr', get_lr(), step) 23 | # self.writer.add_scalar('%s-step/loss' % phase, loss, step) 24 | for key in sorted(loss_dict): 25 | self.writer.add_scalar('%s-step/%s' % (phase, key), loss_dict[key], step) 26 | 27 | if step % 1000 == 0: 28 | for key in sorted(image_dict): 29 | self.writer.add_image('%s/%s' % (self.model_name, key), image_dict[key], step) 30 | 31 | def log_epoch(self, phase, step, loss_dict): 32 | for key in sorted(loss_dict): 33 | self.writer.add_scalar('%s/%s' % (phase, key), loss_dict[key], step) 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | PyTorch implementation of 2 | [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention](https://arxiv.org/abs/1710.08969) 3 | based partially on the following projects: 4 | * https://github.com/Kyubyong/dc_tts (audio pre processing) 5 | * https://github.com/r9y9/deepvoice3_pytorch (data loader sampler) 6 | 7 | ## Online Text-To-Speech Demo 8 | The following notebooks are executable on [https://colab.research.google.com ](https://colab.research.google.com): 9 | * [Mongolian Male Voice TTS Demo](https://colab.research.google.com/github/tugstugi/pytorch-dc-tts/blob/master/notebooks/MongolianTTS.ipynb) 10 | * [English Female Voice TTS Demo (LJ-Speech)](https://colab.research.google.com/github/tugstugi/pytorch-dc-tts/blob/master/notebooks/EnglishTTS.ipynb) 11 | 12 | For audio samples and pretrained models, visit the above notebook links. 13 | 14 | ## Training/Synthesizing English Text-To-Speech 15 | The English TTS uses the [LJ-Speech](https://keithito.com/LJ-Speech-Dataset/) dataset. 16 | 1. Download the dataset: `python dl_and_preprop_dataset.py --dataset=ljspeech` 17 | 2. Train the Text2Mel model: `python train-text2mel.py --dataset=ljspeech` 18 | 3. Train the SSRN model: `python train-ssrn.py --dataset=ljspeech` 19 | 4. Synthesize sentences: `python synthesize.py --dataset=ljspeech` 20 | * The WAV files are saved in the `samples` folder. 21 | 22 | ## Training/Synthesizing Mongolian Text-To-Speech 23 | The Mongolian text-to-speech uses 5 hours audio from the [Mongolian Bible](https://www.bible.com/mn/versions/1590-2013-ariun-bibli-2013). 24 | 1. Download the dataset: `python dl_and_preprop_dataset.py --dataset=mbspeech` 25 | 2. Train the Text2Mel model: `python train-text2mel.py --dataset=mbspeech` 26 | 3. Train the SSRN model: `python train-ssrn.py --dataset=mbspeech` 27 | 4. Synthesize sentences: `python synthesize.py --dataset=mbspeech` 28 | * The WAV files are saved in the `samples` folder. 
29 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | """Hyper parameters.""" 2 | __author__ = 'Erdene-Ochir Tuguldur' 3 | 4 | 5 | class HParams: 6 | """Hyper parameters""" 7 | 8 | disable_progress_bar = False # set True if you don't want the progress bar in the console 9 | 10 | logdir = "logdir" # log dir where the checkpoints and tensorboard files are saved 11 | 12 | # audio.py options, these values are from https://github.com/Kyubyong/dc_tts/blob/master/hyperparams.py 13 | reduction_rate = 4 # melspectrogram reduction rate, don't change because SSRN is using this rate 14 | n_fft = 2048 # fft points (samples) 15 | n_mels = 80 # Number of Mel banks to generate 16 | power = 1.5 # Exponent for amplifying the predicted magnitude 17 | n_iter = 50 # Number of inversion iterations 18 | preemphasis = .97 19 | max_db = 100 20 | ref_db = 20 21 | sr = 22050 # Sampling rate 22 | frame_shift = 0.0125 # seconds 23 | frame_length = 0.05 # seconds 24 | hop_length = int(sr * frame_shift) # samples. =276. 25 | win_length = int(sr * frame_length) # samples. =1102. 26 | max_N = 180 # Maximum number of characters. 27 | max_T = 210 # Maximum number of mel frames. 28 | 29 | e = 128 # embedding dimension 30 | d = 256 # Text2Mel hidden unit dimension 31 | c = 512+128 # SSRN hidden unit dimension 32 | 33 | dropout_rate = 0.05 # dropout 34 | 35 | # Text2Mel network options 36 | text2mel_lr = 0.005 # learning rate 37 | text2mel_max_iteration = 300000 # max train step 38 | text2mel_weight_init = 'none' # 'kaiming', 'xavier' or 'none' 39 | text2mel_normalization = 'layer' # 'layer', 'weight' or 'none' 40 | text2mel_basic_block = 'gated_conv' # 'highway', 'gated_conv' or 'residual' 41 | 42 | # SSRN network options 43 | ssrn_lr = 0.0005 # learning rate 44 | ssrn_max_iteration = 150000 # max train step 45 | ssrn_weight_init = 'kaiming' # 'kaiming', 'xavier' or 'none' 46 | ssrn_normalization = 'weight' # 'layer', 'weight' or 'none' 47 | ssrn_basic_block = 'residual' # 'highway', 'gated_conv' or 'residual' 48 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Utility methods.""" 2 | __author__ = 'Erdene-Ochir Tuguldur' 3 | 4 | import os 5 | import sys 6 | import glob 7 | import torch 8 | import math 9 | import requests 10 | from tqdm import tqdm 11 | from skimage.io import imsave 12 | from skimage import img_as_ubyte 13 | 14 | 15 | def get_last_checkpoint_file_name(logdir): 16 | """Returns the last checkpoint file name in the given log dir path.""" 17 | checkpoints = glob.glob(os.path.join(logdir, '*.pth')) 18 | checkpoints.sort() 19 | if len(checkpoints) == 0: 20 | return None 21 | return checkpoints[-1] 22 | 23 | 24 | def load_checkpoint(checkpoint_file_name, model, optimizer): 25 | """Loads the checkpoint into the given model and optimizer.""" 26 | checkpoint = torch.load(checkpoint_file_name) 27 | model.load_state_dict(checkpoint['state_dict']) 28 | model.float() 29 | if optimizer is not None: 30 | optimizer.load_state_dict(checkpoint['optimizer']) 31 | start_epoch = checkpoint.get('epoch', 0) 32 | global_step = checkpoint.get('global_step', 0) 33 | del checkpoint 34 | print("loaded checkpoint epoch=%d step=%d" % (start_epoch, global_step)) 35 | return start_epoch, global_step 36 | 37 | 38 | def save_checkpoint(logdir, epoch, global_step, model, optimizer): 39 | 
"""Saves the training state into the given log dir path.""" 40 | checkpoint_file_name = os.path.join(logdir, 'step-%03dK.pth' % (global_step // 1000)) 41 | print("saving the checkpoint file '%s'..." % checkpoint_file_name) 42 | checkpoint = { 43 | 'epoch': epoch + 1, 44 | 'global_step': global_step, 45 | 'state_dict': model.state_dict(), 46 | 'optimizer': optimizer.state_dict(), 47 | } 48 | torch.save(checkpoint, checkpoint_file_name) 49 | del checkpoint 50 | 51 | 52 | def download_file(url, file_path): 53 | """Downloads a file from the given URL.""" 54 | print("downloading %s..." % url) 55 | r = requests.get(url, stream=True) 56 | total_size = int(r.headers.get('content-length', 0)) 57 | block_size = 1024 * 1024 58 | wrote = 0 59 | with open(file_path, 'wb') as f: 60 | for data in tqdm(r.iter_content(block_size), total=math.ceil(total_size // block_size), unit='MB'): 61 | wrote = wrote + len(data) 62 | f.write(data) 63 | 64 | if total_size != 0 and wrote != total_size: 65 | print("downloading failed") 66 | sys.exit(1) 67 | 68 | 69 | def save_to_png(file_name, array): 70 | """Save the given numpy array as a PNG file.""" 71 | # from skimage._shared._warnings import expected_warnings 72 | # with expected_warnings(['precision']): 73 | imsave(file_name, img_as_ubyte(array)) 74 | -------------------------------------------------------------------------------- /models/ssrn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara 3 | Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention 4 | https://arxiv.org/abs/1710.08969 5 | 6 | SSRN Network. 7 | """ 8 | __author__ = 'Erdene-Ochir Tuguldur' 9 | __all__ = ['SSRN'] 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from hparams import HParams as hp 15 | from .layers import D, C, HighwayBlock, GatedConvBlock, ResidualBlock 16 | 17 | 18 | def Conv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'): 19 | return C(in_channels, out_channels, kernel_size, dilation, causal=False, 20 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, nonlinearity=nonlinearity) 21 | 22 | 23 | def DeConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'): 24 | return D(in_channels, out_channels, kernel_size, dilation, 25 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, nonlinearity=nonlinearity) 26 | 27 | 28 | def BasicBlock(d, k, delta): 29 | if hp.ssrn_basic_block == 'gated_conv': 30 | return GatedConvBlock(d, k, delta, causal=False, 31 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization) 32 | elif hp.ssrn_basic_block == 'highway': 33 | return HighwayBlock(d, k, delta, causal=False, 34 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization) 35 | else: 36 | return ResidualBlock(d, k, delta, causal=False, 37 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, 38 | widening_factor=1) 39 | 40 | 41 | class SSRN(nn.Module): 42 | def __init__(self, c=hp.c, f=hp.n_mels, f_prime=(1 + hp.n_fft // 2)): 43 | """Spectrogram super-resolution network. 
44 | Args: 45 | c: SSRN dim 46 | f: Number of mel bins 47 | f_prime: full spectrogram dim 48 | Input: 49 | Y: (B, f, T) predicted melspectrograms 50 | Outputs: 51 | Z_logit: logit of Z 52 | Z: (B, f_prime, 4*T) full spectrograms 53 | """ 54 | super(SSRN, self).__init__() 55 | self.layers = nn.Sequential( 56 | Conv(f, c, 1, 1), 57 | 58 | BasicBlock(c, 3, 1), BasicBlock(c, 3, 3), 59 | 60 | DeConv(c, c, 2, 1), BasicBlock(c, 3, 1), BasicBlock(c, 3, 3), 61 | DeConv(c, c, 2, 1), BasicBlock(c, 3, 1), BasicBlock(c, 3, 3), 62 | 63 | Conv(c, 2 * c, 1, 1), 64 | 65 | BasicBlock(2 * c, 3, 1), BasicBlock(2 * c, 3, 1), 66 | 67 | Conv(2 * c, f_prime, 1, 1), 68 | 69 | # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'), 70 | # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'), 71 | BasicBlock(f_prime, 1, 1), 72 | 73 | Conv(f_prime, f_prime, 1, 1) 74 | ) 75 | 76 | def forward(self, x): 77 | Z_logit = self.layers(x) 78 | Z = F.sigmoid(Z_logit) 79 | return Z_logit, Z -------------------------------------------------------------------------------- /datasets/lj_speech.py: -------------------------------------------------------------------------------- 1 | """Data loader for the LJSpeech dataset. See: https://keithito.com/LJ-Speech-Dataset/""" 2 | import os 3 | import re 4 | import codecs 5 | import unicodedata 6 | import numpy as np 7 | 8 | from torch.utils.data import Dataset 9 | 10 | vocab = "PE abcdefghijklmnopqrstuvwxyz'.?" # P: Padding, E: EOS. 11 | char2idx = {char: idx for idx, char in enumerate(vocab)} 12 | idx2char = {idx: char for idx, char in enumerate(vocab)} 13 | 14 | 15 | def text_normalize(text): 16 | text = ''.join(char for char in unicodedata.normalize('NFD', text) 17 | if unicodedata.category(char) != 'Mn') # Strip accents 18 | 19 | text = text.lower() 20 | text = re.sub("[^{}]".format(vocab), " ", text) 21 | text = re.sub("[ ]+", " ", text) 22 | return text 23 | 24 | 25 | def read_metadata(metadata_file): 26 | fnames, text_lengths, texts = [], [], [] 27 | transcript = os.path.join(metadata_file) 28 | lines = codecs.open(transcript, 'r', 'utf-8').readlines() 29 | for line in lines: 30 | fname, _, text = line.strip().split("|") 31 | 32 | fnames.append(fname) 33 | 34 | text = text_normalize(text) + "E" # E: EOS 35 | text = [char2idx[char] for char in text] 36 | text_lengths.append(len(text)) 37 | texts.append(np.array(text, np.longlong)) 38 | 39 | return fnames, text_lengths, texts 40 | 41 | 42 | def get_test_data(sentences, max_n): 43 | normalized_sentences = [text_normalize(line).strip() + "E" for line in sentences] # text normalization, E: EOS 44 | texts = np.zeros((len(normalized_sentences), max_n + 1), np.longlong) 45 | for i, sent in enumerate(normalized_sentences): 46 | texts[i, :len(sent)] = [char2idx[char] for char in sent] 47 | return texts 48 | 49 | 50 | class LJSpeech(Dataset): 51 | def __init__(self, keys, dir_name='LJSpeech-1.1'): 52 | self.keys = keys 53 | self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name) 54 | self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, 'metadata.csv')) 55 | 56 | def slice(self, start, end): 57 | self.fnames = self.fnames[start:end] 58 | self.text_lengths = self.text_lengths[start:end] 59 | self.texts = self.texts[start:end] 60 | 61 | def __len__(self): 62 | return len(self.fnames) 63 | 64 | def __getitem__(self, index): 65 | data = {} 66 | if 'texts' in self.keys: 67 | data['texts'] = self.texts[index] 68 | if 'mels' in self.keys: 69 | # (39, 80) 70 | data['mels'] = 
np.load(os.path.join(self.path, 'mels', "%s.npy" % self.fnames[index])) 71 | if 'mags' in self.keys: 72 | # (39, 80) 73 | data['mags'] = np.load(os.path.join(self.path, 'mags', "%s.npy" % self.fnames[index])) 74 | if 'mel_gates' in self.keys: 75 | data['mel_gates'] = np.ones(data['mels'].shape[0], dtype=np.int64) # TODO: because pre processing! 76 | if 'mag_gates' in self.keys: 77 | data['mag_gates'] = np.ones(data['mags'].shape[0], dtype=np.int64) # TODO: because pre processing! 78 | return data 79 | -------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | """These methods are copied from https://github.com/Kyubyong/dc_tts/""" 2 | 3 | import os 4 | import copy 5 | import librosa 6 | import scipy.io.wavfile 7 | import numpy as np 8 | 9 | from tqdm import tqdm 10 | from scipy import signal 11 | from hparams import HParams as hp 12 | 13 | 14 | def spectrogram2wav(mag): 15 | '''Generate wave file from linear magnitude spectrogram 16 | Args: 17 | mag: A numpy array of (T, 1+n_fft//2) 18 | Returns: 19 | wav: A 1-D numpy array. 20 | ''' 21 | # transpose 22 | mag = mag.T 23 | 24 | # de-normalize 25 | mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db 26 | 27 | # to amplitude 28 | mag = np.power(10.0, mag * 0.05) 29 | 30 | # wav reconstruction 31 | wav = griffin_lim(mag ** hp.power) 32 | 33 | # de-preemphasis 34 | wav = signal.lfilter([1], [1, -hp.preemphasis], wav) 35 | 36 | # trim 37 | wav, _ = librosa.effects.trim(wav) 38 | 39 | return wav.astype(np.float32) 40 | 41 | 42 | def griffin_lim(spectrogram): 43 | '''Applies the Griffin-Lim algorithm.''' 44 | X_best = copy.deepcopy(spectrogram) 45 | for i in range(hp.n_iter): 46 | X_t = invert_spectrogram(X_best) 47 | est = librosa.stft(X_t, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length) 48 | phase = est / np.maximum(1e-8, np.abs(est)) 49 | X_best = spectrogram * phase 50 | X_t = invert_spectrogram(X_best) 51 | y = np.real(X_t) 52 | 53 | return y 54 | 55 | 56 | def invert_spectrogram(spectrogram): 57 | '''Applies the inverse STFT. 58 | Args: 59 | spectrogram: [1+n_fft//2, t] 60 | ''' 61 | return librosa.istft(spectrogram, hop_length=hp.hop_length, win_length=hp.win_length, window="hann") 62 | 63 | 64 | def get_spectrograms(fpath): 65 | '''Parse the wave file in `fpath` and 66 | Returns normalized melspectrogram and linear spectrogram. 67 | Args: 68 | fpath: A string. The full path of a sound file. 69 | Returns: 70 | mel: A 2d array of shape (T, n_mels) and dtype of float32. 71 | mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32.
72 | ''' 73 | # Loading sound file 74 | y, sr = librosa.load(fpath, sr=hp.sr) 75 | 76 | # Trimming 77 | y, _ = librosa.effects.trim(y) 78 | 79 | # Preemphasis 80 | y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1]) 81 | 82 | # stft 83 | linear = librosa.stft(y=y, 84 | n_fft=hp.n_fft, 85 | hop_length=hp.hop_length, 86 | win_length=hp.win_length) 87 | 88 | # magnitude spectrogram 89 | mag = np.abs(linear) # (1+n_fft//2, T) 90 | 91 | # mel spectrogram 92 | mel_basis = librosa.filters.mel(sr=hp.sr, n_fft=hp.n_fft, n_mels=hp.n_mels) # (n_mels, 1+n_fft//2) 93 | mel = np.dot(mel_basis, mag) # (n_mels, t) 94 | 95 | # to decibel 96 | mel = 20 * np.log10(np.maximum(1e-5, mel)) 97 | mag = 20 * np.log10(np.maximum(1e-5, mag)) 98 | 99 | # normalize 100 | mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1) 101 | mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1) 102 | 103 | # Transpose 104 | mel = mel.T.astype(np.float32) # (T, n_mels) 105 | mag = mag.T.astype(np.float32) # (T, 1+n_fft//2) 106 | 107 | return mel, mag 108 | 109 | 110 | def save_to_wav(mag, filename): 111 | """Generate and save an audio file from the given linear spectrogram using Griffin-Lim.""" 112 | wav = spectrogram2wav(mag) 113 | scipy.io.wavfile.write(filename, hp.sr, wav) 114 | 115 | 116 | def preprocess(dataset_path, speech_dataset): 117 | """Preprocess the given dataset.""" 118 | wavs_path = os.path.join(dataset_path, 'wavs') 119 | mels_path = os.path.join(dataset_path, 'mels') 120 | if not os.path.isdir(mels_path): 121 | os.mkdir(mels_path) 122 | mags_path = os.path.join(dataset_path, 'mags') 123 | if not os.path.isdir(mags_path): 124 | os.mkdir(mags_path) 125 | 126 | for fname in tqdm(speech_dataset.fnames): 127 | mel, mag = get_spectrograms(os.path.join(wavs_path, '%s.wav' % fname)) 128 | 129 | t = mel.shape[0] 130 | # Marginal padding for reduction shape sync. 
131 | num_paddings = hp.reduction_rate - (t % hp.reduction_rate) if t % hp.reduction_rate != 0 else 0 132 | mel = np.pad(mel, [[0, num_paddings], [0, 0]], mode="constant") 133 | mag = np.pad(mag, [[0, num_paddings], [0, 0]], mode="constant") 134 | # Reduction 135 | mel = mel[::hp.reduction_rate, :] 136 | 137 | np.save(os.path.join(mels_path, '%s.npy' % fname), mel) 138 | np.save(os.path.join(mags_path, '%s.npy' % fname), mag) 139 | -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Synthetize sentences into speech.""" 3 | __author__ = 'Erdene-Ochir Tuguldur' 4 | 5 | import os 6 | import sys 7 | import argparse 8 | from tqdm import * 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from models import Text2Mel, SSRN 14 | from hparams import HParams as hp 15 | from audio import save_to_wav 16 | from utils import get_last_checkpoint_file_name, load_checkpoint, save_to_png 17 | 18 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add_argument("--dataset", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name') 20 | args = parser.parse_args() 21 | 22 | if args.dataset == 'ljspeech': 23 | from datasets.lj_speech import vocab, get_test_data 24 | 25 | SENTENCES = [ 26 | "The birch canoe slid on the smooth planks.", 27 | "Glue the sheet to the dark blue background.", 28 | "It's easy to tell the depth of a well.", 29 | "These days a chicken leg is a rare dish.", 30 | "Rice is often served in round bowls.", 31 | "The juice of lemons makes fine punch.", 32 | "The box was thrown beside the parked truck.", 33 | "The hogs were fed chopped corn and garbage.", 34 | "Four hours of steady work faced us.", 35 | "Large size in stockings is hard to sell.", 36 | "The boy was there when the sun rose.", 37 | "A rod is used to catch pink salmon.", 38 | "The source of the huge river is the clear spring.", 39 | "Kick the ball straight and follow through.", 40 | "Help the woman get back to her feet.", 41 | "A pot of tea helps to pass the evening.", 42 | "Smoky fires lack flame and heat.", 43 | "The soft cushion broke the man's fall.", 44 | "The salt breeze came across from the sea.", 45 | "The girl at the booth sold fifty bonds." 46 | ] 47 | else: 48 | from datasets.mb_speech import vocab, get_test_data 49 | 50 | SENTENCES = [ 51 | "Нийслэлийн прокурорын газраас төрийн өндөр албан тушаалтнуудад холбогдох зарим эрүүгийн хэргүүдийг шүүхэд шилжүүлэв.", 52 | "Мөнх тэнгэрийн хүчин дор Монгол Улс цэцэглэн хөгжих болтугай.", 53 | "Унасан хүлгээ түрүү магнай, аман хүзүүнд уралдуулж, айрагдуулсан унаач хүүхдүүдэд бэлэг гардууллаа.", 54 | "Албан ёсоор хэлэхэд “Монгол Улсын хэрэг эрхлэх газрын гэгээнтэн” гэж нэрлээд байгаа зүйл огт байхгүй.", 55 | "Сайн чанарын бохирын хоолой зарна.", 56 | "Хараа тэглэх мэс заслын дараа хараа дахин муудах магадлал бага.", 57 | "Ер нь бол хараа тэглэх мэс заслыг гоо сайхны мэс засалтай адилхан гэж зүйрлэж болно.", 58 | "Хашлага даван, зүлэг гэмтээсэн жолоочийн эрхийг хоёр жилээр хасжээ.", 59 | "Монгол хүн бидний сэтгэлийг сорсон орон. Энэ бол миний төрсөн нутаг. Монголын сайхан орон.", 60 | "Постройка крейсера затягивалась из-за проектных неувязок, необходимости." 
61 | ] 62 | 63 | torch.set_grad_enabled(False) 64 | 65 | text2mel = Text2Mel(vocab).eval() 66 | last_checkpoint_file_name = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-text2mel' % args.dataset)) 67 | # last_checkpoint_file_name = 'logdir/%s-text2mel/step-020K.pth' % args.dataset 68 | if last_checkpoint_file_name: 69 | print("loading text2mel checkpoint '%s'..." % last_checkpoint_file_name) 70 | load_checkpoint(last_checkpoint_file_name, text2mel, None) 71 | else: 72 | print("text2mel checkpoint not found") 73 | sys.exit(1) 74 | 75 | ssrn = SSRN().eval() 76 | last_checkpoint_file_name = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-ssrn' % args.dataset)) 77 | # last_checkpoint_file_name = 'logdir/%s-ssrn/step-005K.pth' % args.dataset 78 | if last_checkpoint_file_name: 79 | print("loading ssrn checkpoint '%s'..." % last_checkpoint_file_name) 80 | load_checkpoint(last_checkpoint_file_name, ssrn, None) 81 | else: 82 | print("ssrn checkpoint not found") 83 | sys.exit(1) 84 | 85 | # synthesize one sentence at a time because there is a batch processing bug! 86 | for i in range(len(SENTENCES)): 87 | sentences = [SENTENCES[i]] 88 | 89 | max_N = len(SENTENCES[i]) 90 | L = torch.from_numpy(get_test_data(sentences, max_N)) 91 | zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32)) 92 | Y = zeros 93 | A = None 94 | 95 | for t in tqdm(range(hp.max_T)): 96 | _, Y_t, A = text2mel(L, Y, monotonic_attention=True) 97 | Y = torch.cat((zeros, Y_t), -1) 98 | _, attention = torch.max(A[0, :, -1], 0) 99 | attention = attention.item() 100 | if L[0, attention] == vocab.index('E'): # EOS 101 | break 102 | 103 | _, Z = ssrn(Y) 104 | 105 | Y = Y.cpu().detach().numpy() 106 | A = A.cpu().detach().numpy() 107 | Z = Z.cpu().detach().numpy() 108 | 109 | save_to_png('samples/%d-att.png' % (i + 1), A[0, :, :]) 110 | save_to_png('samples/%d-mel.png' % (i + 1), Y[0, :, :]) 111 | save_to_png('samples/%d-mag.png' % (i + 1), Z[0, :, :]) 112 | save_to_wav(Z[0, :, :].T, 'samples/%d-wav.wav' % (i + 1)) 113 | -------------------------------------------------------------------------------- /datasets/data_loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data.dataloader import default_collate, DataLoader 6 | from torch.utils.data.sampler import Sampler 7 | 8 | __all__ = ['Text2MelDataLoader', 'SSRNDataLoader'] 9 | 10 | 11 | class Text2MelDataLoader(DataLoader): 12 | def __init__(self, text2mel_dataset, batch_size, mode='train', num_workers=8): 13 | if mode == 'train': 14 | text2mel_dataset.slice(0, -batch_size) 15 | elif mode == 'valid': 16 | text2mel_dataset.slice(len(text2mel_dataset) - batch_size, -1) 17 | else: 18 | raise ValueError("mode must be either 'train' or 'valid'") 19 | super().__init__(text2mel_dataset, 20 | batch_size=batch_size, 21 | num_workers=num_workers, 22 | collate_fn=collate_fn, 23 | shuffle=True) 24 | 25 | 26 | class SSRNDataLoader(DataLoader): 27 | def __init__(self, ssrn_dataset, batch_size, mode='train', num_workers=8): 28 | if mode == 'train': 29 | ssrn_dataset.slice(0, -batch_size) 30 | super().__init__(ssrn_dataset, 31 | batch_size=batch_size, 32 | num_workers=num_workers, 33 | collate_fn=collate_fn, 34 | sampler=PartiallyRandomizedSimilarTimeLengthSampler(lengths=ssrn_dataset.text_lengths, 35 | data_source=None, 36 | batch_size=batch_size)) 37 | elif mode == 'valid': 38 | ssrn_dataset.slice(len(ssrn_dataset) - batch_size, -1) 39 | super().__init__(ssrn_dataset, 40 |
batch_size=batch_size, 41 | num_workers=num_workers, 42 | collate_fn=collate_fn, 43 | shuffle=True) 44 | else: 45 | raise ValueError("mode must be either 'train' or 'valid'") 46 | 47 | 48 | def collate_fn(batch): 49 | keys = batch[0].keys() 50 | max_lengths = {key: 0 for key in keys} 51 | collated_batch = {key: [] for key in keys} 52 | 53 | # find out the max lengths 54 | for row in batch: 55 | for key in keys: 56 | max_lengths[key] = max(max_lengths[key], row[key].shape[0]) 57 | 58 | # pad to the max lengths 59 | for row in batch: 60 | for key in keys: 61 | array = row[key] 62 | dim = len(array.shape) 63 | assert dim == 1 or dim == 2 64 | # TODO: because of pre processing, later we want to have (n_mels, T) 65 | if dim == 1: 66 | padded_array = np.pad(array, (0, max_lengths[key] - array.shape[0]), mode='constant') 67 | else: 68 | padded_array = np.pad(array, ((0, max_lengths[key] - array.shape[0]), (0, 0)), mode='constant') 69 | collated_batch[key].append(padded_array) 70 | 71 | # use the default_collate to convert to tensors 72 | for key in keys: 73 | collated_batch[key] = default_collate(collated_batch[key]) 74 | return collated_batch 75 | 76 | 77 | class PartiallyRandomizedSimilarTimeLengthSampler(Sampler): 78 | """Copied from: https://github.com/r9y9/deepvoice3_pytorch/blob/master/train.py. 79 | Partially randomized sampler 80 | 1. Sort by lengths 81 | 2. Pick a small patch and randomize it 82 | 3. Permutate mini-batches 83 | """ 84 | 85 | def __init__(self, lengths, data_source, batch_size=16, batch_group_size=None, permutate=True): 86 | super().__init__(data_source) 87 | self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths)) 88 | self.batch_size = batch_size 89 | if batch_group_size is None: 90 | batch_group_size = min(batch_size * 32, len(self.lengths)) 91 | if batch_group_size % batch_size != 0: 92 | batch_group_size -= batch_group_size % batch_size 93 | 94 | self.batch_group_size = batch_group_size 95 | assert batch_group_size % batch_size == 0 96 | self.permutate = permutate 97 | 98 | def __iter__(self): 99 | indices = self.sorted_indices.clone() 100 | batch_group_size = self.batch_group_size 101 | s, e = 0, 0 102 | for i in range(len(indices) // batch_group_size): 103 | s = i * batch_group_size 104 | e = s + batch_group_size 105 | random.shuffle(indices[s:e]) 106 | 107 | # Permutate batches 108 | if self.permutate: 109 | perm = np.arange(len(indices[:e]) // self.batch_size) 110 | random.shuffle(perm) 111 | indices[:e] = indices[:e].view(-1, self.batch_size)[perm, :].view(-1) 112 | 113 | # Handle last elements 114 | s += batch_group_size 115 | if s < len(indices): 116 | random.shuffle(indices[s:]) 117 | 118 | return iter(indices) 119 | 120 | def __len__(self): 121 | return len(self.sorted_indices) 122 | -------------------------------------------------------------------------------- /train-ssrn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Train the SSRN network. 
See: https://arxiv.org/abs/1710.08969""" 3 | __author__ = 'Erdene-Ochir Tuguldur' 4 | 5 | import sys 6 | import time 7 | import argparse 8 | from tqdm import * 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | # project imports 14 | from models import SSRN 15 | from hparams import HParams as hp 16 | from logger import Logger 17 | from utils import get_last_checkpoint_file_name, load_checkpoint, save_checkpoint 18 | from datasets.data_loader import SSRNDataLoader 19 | 20 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 21 | parser.add_argument("--dataset", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name') 22 | args = parser.parse_args() 23 | 24 | if args.dataset == 'ljspeech': 25 | from datasets.lj_speech import LJSpeech as SpeechDataset 26 | else: 27 | from datasets.mb_speech import MBSpeech as SpeechDataset 28 | 29 | use_gpu = torch.cuda.is_available() 30 | print('use_gpu', use_gpu) 31 | if use_gpu: 32 | torch.backends.cudnn.benchmark = True 33 | 34 | train_data_loader = SSRNDataLoader(ssrn_dataset=SpeechDataset(['mags', 'mels']), batch_size=24, mode='train') 35 | valid_data_loader = SSRNDataLoader(ssrn_dataset=SpeechDataset(['mags', 'mels']), batch_size=24, mode='valid') 36 | 37 | ssrn = SSRN().cuda() 38 | 39 | optimizer = torch.optim.Adam(ssrn.parameters(), lr=hp.ssrn_lr) 40 | 41 | start_timestamp = int(time.time() * 1000) 42 | start_epoch = 0 43 | global_step = 0 44 | 45 | logger = Logger(args.dataset, 'ssrn') 46 | 47 | # load the last checkpoint if exists 48 | last_checkpoint_file_name = get_last_checkpoint_file_name(logger.logdir) 49 | if last_checkpoint_file_name: 50 | print("loading the last checkpoint: %s" % last_checkpoint_file_name) 51 | start_epoch, global_step = load_checkpoint(last_checkpoint_file_name, ssrn, optimizer) 52 | 53 | 54 | def get_lr(): 55 | return optimizer.param_groups[0]['lr'] 56 | 57 | 58 | def lr_decay(step, warmup_steps=1000): 59 | new_lr = hp.ssrn_lr * warmup_steps ** 0.5 * min((step + 1) * warmup_steps ** -1.5, (step + 1) ** -0.5) 60 | optimizer.param_groups[0]['lr'] = new_lr 61 | 62 | 63 | def train(train_epoch, phase='train'): 64 | global global_step 65 | 66 | lr_decay(global_step) 67 | print("epoch %3d with lr=%.02e" % (train_epoch, get_lr())) 68 | 69 | ssrn.train() if phase == 'train' else ssrn.eval() 70 | torch.set_grad_enabled(True) if phase == 'train' else torch.set_grad_enabled(False) 71 | data_loader = train_data_loader if phase == 'train' else valid_data_loader 72 | 73 | it = 0 74 | running_loss = 0.0 75 | running_l1_loss = 0.0 76 | 77 | pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size, disable=hp.disable_progress_bar) 78 | for batch in pbar: 79 | M, S = batch['mags'], batch['mels'] 80 | M = M.permute(0, 2, 1) # TODO: because of pre processing 81 | S = S.permute(0, 2, 1) # TODO: because of pre processing 82 | 83 | M.requires_grad = False 84 | M = M.cuda() 85 | S = S.cuda() 86 | 87 | Z_logit, Z = ssrn(S) 88 | 89 | l1_loss = F.l1_loss(Z, M) 90 | 91 | loss = l1_loss 92 | 93 | if phase == 'train': 94 | lr_decay(global_step) 95 | optimizer.zero_grad() 96 | loss.backward() 97 | optimizer.step() 98 | global_step += 1 99 | 100 | it += 1 101 | 102 | loss = loss.item() 103 | l1_loss = l1_loss.item() 104 | running_loss += loss 105 | running_l1_loss += l1_loss 106 | 107 | if phase == 'train': 108 | # update the progress bar 109 | pbar.set_postfix({ 110 | 'l1': "%.05f" % (running_l1_loss / it) 111 | }) 112 | logger.log_step(phase, 
global_step, {'loss_l1': l1_loss}, 113 | {'mags-true': M[:1, :, :], 'mags-pred': Z[:1, :, :], 'mels': S[:1, :, :]}) 114 | if global_step % 5000 == 0: 115 | # checkpoint at every 5000th step 116 | save_checkpoint(logger.logdir, train_epoch, global_step, ssrn, optimizer) 117 | 118 | epoch_loss = running_loss / it 119 | epoch_l1_loss = running_l1_loss / it 120 | 121 | logger.log_epoch(phase, global_step, {'loss_l1': epoch_l1_loss}) 122 | 123 | return epoch_loss 124 | 125 | 126 | since = time.time() 127 | epoch = start_epoch 128 | while True: 129 | train_epoch_loss = train(epoch, phase='train') 130 | time_elapsed = time.time() - since 131 | time_str = 'total time elapsed: {:.0f}h {:.0f}m {:.0f}s '.format(time_elapsed // 3600, time_elapsed % 3600 // 60, 132 | time_elapsed % 60) 133 | print("train epoch loss %f, step=%d, %s" % (train_epoch_loss, global_step, time_str)) 134 | 135 | valid_epoch_loss = train(epoch, phase='valid') 136 | print("valid epoch loss %f" % valid_epoch_loss) 137 | 138 | epoch += 1 139 | if global_step >= hp.ssrn_max_iteration: 140 | print("max step %d (current step %d) reached, exiting..." % (hp.ssrn_max_iteration, global_step)) 141 | sys.exit(0) 142 | -------------------------------------------------------------------------------- /datasets/mb_speech.py: -------------------------------------------------------------------------------- 1 | """Data loader for the Mongolian Bible dataset.""" 2 | import os 3 | import codecs 4 | import numpy as np 5 | 6 | from torch.utils.data import Dataset 7 | 8 | vocab = "PE абвгдеёжзийклмноөпрстуүфхцчшъыьэюя-.,!?" # P: Padding, E: EOS. 9 | char2idx = {char: idx for idx, char in enumerate(vocab)} 10 | idx2char = {idx: char for idx, char in enumerate(vocab)} 11 | 12 | 13 | def text_normalize(text): 14 | text = text.lower() 15 | # text = text.replace(",", "'") 16 | # text = text.replace("!", "?") 17 | for c in "-—:": 18 | text = text.replace(c, "-") 19 | for c in "()\"«»“”'": 20 | text = text.replace(c, ",") 21 | return text 22 | 23 | 24 | def read_metadata(metadata_file): 25 | fnames, text_lengths, texts = [], [], [] 26 | transcript = os.path.join(metadata_file) 27 | lines = codecs.open(transcript, 'r', 'utf-8').readlines() 28 | for line in lines: 29 | fname, _, text = line.strip().split("|") 30 | 31 | fnames.append(fname) 32 | 33 | text = text_normalize(text) + "E" # E: EOS 34 | text = [char2idx[char] for char in text] 35 | text_lengths.append(len(text)) 36 | texts.append(np.array(text, np.longlong)) 37 | 38 | return fnames, text_lengths, texts 39 | 40 | 41 | def get_test_data(sentences, max_n): 42 | normalized_sentences = [text_normalize(line).strip() + "E" for line in sentences] # text normalization, E: EOS 43 | texts = np.zeros((len(normalized_sentences), max_n + 1), np.longlong) 44 | for i, sent in enumerate(normalized_sentences): 45 | texts[i, :len(sent)] = [char2idx[char] for char in sent] 46 | return texts 47 | 48 | 49 | class MBSpeech(Dataset): 50 | def __init__(self, keys, dir_name='MBSpeech-1.0'): 51 | self.keys = keys 52 | self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name) 53 | self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, 'metadata.csv')) 54 | 55 | def slice(self, start, end): 56 | self.fnames = self.fnames[start:end] 57 | self.text_lengths = self.text_lengths[start:end] 58 | self.texts = self.texts[start:end] 59 | 60 | def __len__(self): 61 | return len(self.fnames) 62 | 63 | def __getitem__(self, index): 64 | data = {} 65 | if 'texts' in self.keys: 66 | 
data['texts'] = self.texts[index] 67 | if 'mels' in self.keys: 68 | # (39, 80) 69 | data['mels'] = np.load(os.path.join(self.path, 'mels', "%s.npy" % self.fnames[index])) 70 | if 'mags' in self.keys: 71 | # (39, 80) 72 | data['mags'] = np.load(os.path.join(self.path, 'mags', "%s.npy" % self.fnames[index])) 73 | if 'mel_gates' in self.keys: 74 | data['mel_gates'] = np.ones(data['mels'].shape[0], dtype=np.int64) # TODO: because pre processing! 75 | if 'mag_gates' in self.keys: 76 | data['mag_gates'] = np.ones(data['mags'].shape[0], dtype=np.int64) # TODO: because pre processing! 77 | return data 78 | 79 | # 80 | # simple method to convert mongolian numbers to text, copied from somewhere 81 | # 82 | 83 | 84 | def number2word(number): 85 | digit_len = len(number) 86 | digit_name = {1: '', 2: 'мянга', 3: 'сая', 4: 'тэрбум', 5: 'их наяд', 6: 'тунамал'} 87 | 88 | if digit_len == 1: 89 | return _last_digit_2_str(number) 90 | if digit_len == 2: 91 | return _2_digits_2_str(number) 92 | if digit_len == 3: 93 | return _3_digits_to_str(number) 94 | if digit_len < 7: 95 | return _3_digits_to_str(number[:-3], False) + ' ' + digit_name[2] + ' ' + _3_digits_to_str(number[-3:]) 96 | 97 | digitgroup = [number[0 if i - 3 < 0 else i - 3:i] for i in reversed(range(len(number), 0, -3))] 98 | count = len(digitgroup) 99 | i = 0 100 | result = '' 101 | while i < count - 1: 102 | result += ' ' + (_3_digits_to_str(digitgroup[i], False) + ' ' + digit_name[count - i]) 103 | i += 1 104 | return result.strip() + ' ' + _3_digits_to_str(digitgroup[-1]) 105 | 106 | 107 | def _1_digit_2_str(digit): 108 | return {'0': '', '1': 'нэгэн', '2': 'хоёр', '3': 'гурван', '4': 'дөрвөн', '5': 'таван', '6': 'зургаан', 109 | '7': 'долоон', '8': 'найман', '9': 'есөн'}[digit] 110 | 111 | 112 | def _last_digit_2_str(digit): 113 | return {'0': 'тэг', '1': 'нэг', '2': 'хоёр', '3': 'гурав', '4': 'дөрөв', '5': 'тав', '6': 'зургаа', '7': 'долоо', 114 | '8': 'найм', '9': 'ес'}[digit] 115 | 116 | 117 | def _2_digits_2_str(digit, is_fina=True): 118 | word2 = {'0': '', '1': 'арван', '2': 'хорин', '3': 'гучин', '4': 'дөчин', '5': 'тавин', '6': 'жаран', '7': 'далан', 119 | '8': 'наян', '9': 'ерэн'} 120 | word2fina = {'10': 'арав', '20': 'хорь', '30': 'гуч', '40': 'дөч', '50': 'тавь', '60': 'жар', '70': 'дал', 121 | '80': 'ная', '90': 'ер'} 122 | if digit[1] == '0': 123 | return word2fina[digit] if is_fina else word2[digit[0]] 124 | digit1 = _last_digit_2_str(digit[1]) if is_fina else _1_digit_2_str(digit[1]) 125 | return (word2[digit[0]] + ' ' + digit1).strip() 126 | 127 | 128 | def _3_digits_to_str(digit, is_fina=True): 129 | digstr = digit.lstrip('0') 130 | if len(digstr) == 0: 131 | return '' 132 | if len(digstr) == 1: 133 | return _1_digit_2_str(digstr) 134 | if len(digstr) == 2: 135 | return _2_digits_2_str(digstr, is_fina) 136 | if digit[-2:] == '00': 137 | return _1_digit_2_str(digit[0]) + ' зуу' if is_fina else _1_digit_2_str(digit[0]) + ' зуун' 138 | else: 139 | return _1_digit_2_str(digit[0]) + ' зуун ' + _2_digits_2_str(digit[-2:], is_fina) 140 | -------------------------------------------------------------------------------- /train-text2mel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Train the Text2Mel network. 
See: https://arxiv.org/abs/1710.08969""" 3 | __author__ = 'Erdene-Ochir Tuguldur' 4 | 5 | import sys 6 | import time 7 | import argparse 8 | from tqdm import * 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | # project imports 16 | from models import Text2Mel 17 | from hparams import HParams as hp 18 | from logger import Logger 19 | from utils import get_last_checkpoint_file_name, load_checkpoint, save_checkpoint 20 | from datasets.data_loader import Text2MelDataLoader 21 | 22 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | parser.add_argument("--dataset", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name') 24 | args = parser.parse_args() 25 | 26 | if args.dataset == 'ljspeech': 27 | from datasets.lj_speech import vocab, LJSpeech as SpeechDataset 28 | else: 29 | from datasets.mb_speech import vocab, MBSpeech as SpeechDataset 30 | 31 | use_gpu = torch.cuda.is_available() 32 | print('use_gpu', use_gpu) 33 | if use_gpu: 34 | torch.backends.cudnn.benchmark = True 35 | 36 | train_data_loader = Text2MelDataLoader(text2mel_dataset=SpeechDataset(['texts', 'mels', 'mel_gates']), batch_size=64, 37 | mode='train') 38 | valid_data_loader = Text2MelDataLoader(text2mel_dataset=SpeechDataset(['texts', 'mels', 'mel_gates']), batch_size=64, 39 | mode='valid') 40 | 41 | text2mel = Text2Mel(vocab).cuda() 42 | 43 | optimizer = torch.optim.Adam(text2mel.parameters(), lr=hp.text2mel_lr) 44 | 45 | start_timestamp = int(time.time() * 1000) 46 | start_epoch = 0 47 | global_step = 0 48 | 49 | logger = Logger(args.dataset, 'text2mel') 50 | 51 | # load the last checkpoint if exists 52 | last_checkpoint_file_name = get_last_checkpoint_file_name(logger.logdir) 53 | if last_checkpoint_file_name: 54 | print("loading the last checkpoint: %s" % last_checkpoint_file_name) 55 | start_epoch, global_step = load_checkpoint(last_checkpoint_file_name, text2mel, optimizer) 56 | 57 | 58 | def get_lr(): 59 | return optimizer.param_groups[0]['lr'] 60 | 61 | 62 | def lr_decay(step, warmup_steps=4000): 63 | new_lr = hp.text2mel_lr * warmup_steps ** 0.5 * min((step + 1) * warmup_steps ** -1.5, (step + 1) ** -0.5) 64 | optimizer.param_groups[0]['lr'] = new_lr 65 | 66 | 67 | def train(train_epoch, phase='train'): 68 | global global_step 69 | 70 | lr_decay(global_step) 71 | print("epoch %3d with lr=%.02e" % (train_epoch, get_lr())) 72 | 73 | text2mel.train() if phase == 'train' else text2mel.eval() 74 | torch.set_grad_enabled(True) if phase == 'train' else torch.set_grad_enabled(False) 75 | data_loader = train_data_loader if phase == 'train' else valid_data_loader 76 | 77 | it = 0 78 | running_loss = 0.0 79 | running_l1_loss = 0.0 80 | running_att_loss = 0.0 81 | 82 | pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size, disable=hp.disable_progress_bar) 83 | for batch in pbar: 84 | L, S, gates = batch['texts'], batch['mels'], batch['mel_gates'] 85 | S = S.permute(0, 2, 1) # TODO: because of pre processing 86 | 87 | B, N = L.size() # batch size and text count 88 | _, n_mels, T = S.size() # number of melspectrogram bins and time 89 | 90 | assert gates.size(0) == B # TODO: later remove 91 | assert gates.size(1) == T 92 | 93 | S_shifted = torch.cat((S[:, :, 1:], torch.zeros(B, n_mels, 1)), 2) 94 | 95 | S.requires_grad = False 96 | S_shifted.requires_grad = False 97 | gates.requires_grad = False 98 | 99 | def W_nt(_, n, t, g=0.2): 100 | return 1.0 - np.exp(-((n / float(N) - t / float(T)) ** 
2) / (2 * g ** 2)) 101 | 102 | W = np.fromfunction(W_nt, (B, N, T), dtype=np.float32) 103 | W = torch.from_numpy(W) 104 | 105 | L = L.cuda() 106 | S = S.cuda() 107 | S_shifted = S_shifted.cuda() 108 | W = W.cuda() 109 | gates = gates.cuda() 110 | 111 | Y_logit, Y, A = text2mel(L, S) 112 | 113 | l1_loss = F.l1_loss(Y, S_shifted) 114 | masks = gates.reshape(B, 1, T).float() 115 | att_loss = (A * W * masks).mean() 116 | 117 | loss = l1_loss + att_loss 118 | 119 | if phase == 'train': 120 | lr_decay(global_step) 121 | optimizer.zero_grad() 122 | loss.backward() 123 | optimizer.step() 124 | global_step += 1 125 | 126 | it += 1 127 | 128 | loss, l1_loss, att_loss = loss.item(), l1_loss.item(), att_loss.item() 129 | running_loss += loss 130 | running_l1_loss += l1_loss 131 | running_att_loss += att_loss 132 | 133 | if phase == 'train': 134 | # update the progress bar 135 | pbar.set_postfix({ 136 | 'l1': "%.05f" % (running_l1_loss / it), 137 | 'att': "%.05f" % (running_att_loss / it) 138 | }) 139 | logger.log_step(phase, global_step, {'loss_l1': l1_loss, 'loss_att': att_loss}, 140 | {'mels-true': S[:1, :, :], 'mels-pred': Y[:1, :, :], 'attention': A[:1, :, :]}) 141 | if global_step % 5000 == 0: 142 | # checkpoint at every 5000th step 143 | save_checkpoint(logger.logdir, train_epoch, global_step, text2mel, optimizer) 144 | 145 | epoch_loss = running_loss / it 146 | epoch_l1_loss = running_l1_loss / it 147 | epoch_att_loss = running_att_loss / it 148 | 149 | logger.log_epoch(phase, global_step, {'loss_l1': epoch_l1_loss, 'loss_att': epoch_att_loss}) 150 | 151 | return epoch_loss 152 | 153 | 154 | since = time.time() 155 | epoch = start_epoch 156 | while True: 157 | train_epoch_loss = train(epoch, phase='train') 158 | time_elapsed = time.time() - since 159 | time_str = 'total time elapsed: {:.0f}h {:.0f}m {:.0f}s '.format(time_elapsed // 3600, time_elapsed % 3600 // 60, 160 | time_elapsed % 60) 161 | print("train epoch loss %f, step=%d, %s" % (train_epoch_loss, global_step, time_str)) 162 | 163 | valid_epoch_loss = train(epoch, phase='valid') 164 | print("valid epoch loss %f" % valid_epoch_loss) 165 | 166 | epoch += 1 167 | if global_step >= hp.text2mel_max_iteration: 168 | print("max step %d (current step %d) reached, exiting..." 
% (hp.text2mel_max_iteration, global_step)) 169 | sys.exit(0) 170 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | __all__ = ['E', 'D', 'C', 'HighwayBlock', 'GatedConvBlock', 'ResidualBlock'] 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from hparams import HParams as hp 8 | 9 | 10 | class LayerNorm(nn.LayerNorm): 11 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 12 | """Layer Norm.""" 13 | super(LayerNorm, self).__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine) 14 | 15 | def forward(self, x): 16 | x = x.permute(0, 2, 1) # PyTorch LayerNorm seems to be expect (B, T, C) 17 | y = super(LayerNorm, self).forward(x) 18 | y = y.permute(0, 2, 1) # reverse 19 | return y 20 | 21 | 22 | class D(nn.Module): 23 | def __init__(self, in_channels, out_channels, kernel_size, dilation, weight_init='none', normalization='weight', nonlinearity='linear'): 24 | """1D Deconvolution.""" 25 | super(D, self).__init__() 26 | self.deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, 27 | stride=2, # paper: stride of deconvolution is always 2 28 | dilation=dilation) 29 | 30 | if normalization == 'weight': 31 | self.deconv = nn.utils.weight_norm(self.deconv) 32 | elif normalization == 'layer': 33 | self.layer_norm = LayerNorm(out_channels) 34 | 35 | self.nonlinearity = nonlinearity 36 | if weight_init == 'kaiming': 37 | nn.init.kaiming_normal_(self.deconv.weight, mode='fan_out', nonlinearity=nonlinearity) 38 | elif weight_init == 'xavier': 39 | nn.init.xavier_uniform_(self.deconv.weight, nn.init.calculate_gain(nonlinearity)) 40 | 41 | def forward(self, x, output_size=None): 42 | y = self.deconv(x, output_size=output_size) 43 | if hasattr(self, 'layer_norm'): 44 | y = self.layer_norm(y) 45 | y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True) 46 | if self.nonlinearity == 'relu': 47 | y = F.relu(y, inplace=True) 48 | return y 49 | 50 | 51 | class C(nn.Module): 52 | def __init__(self, in_channels, out_channels, kernel_size, dilation, causal=False, weight_init='none', normalization='weight', nonlinearity='linear'): 53 | """1D convolution. 54 | The argument 'causal' indicates whether the causal convolution should be used or not. 55 | """ 56 | super(C, self).__init__() 57 | self.causal = causal 58 | if causal: 59 | self.padding = (kernel_size - 1) * dilation 60 | else: 61 | self.padding = (kernel_size - 1) * dilation // 2 62 | 63 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, 64 | stride=1, # paper: 'The stride of convolution is always 1.' 
65 | padding=self.padding, dilation=dilation) 66 | 67 | if normalization == 'weight': 68 | self.conv = nn.utils.weight_norm(self.conv) 69 | elif normalization == 'layer': 70 | self.layer_norm = LayerNorm(out_channels) 71 | 72 | self.nonlinearity = nonlinearity 73 | if weight_init == 'kaiming': 74 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity=nonlinearity) 75 | elif weight_init == 'xavier': 76 | nn.init.xavier_uniform_(self.conv.weight, nn.init.calculate_gain(nonlinearity)) 77 | 78 | def forward(self, x): 79 | y = self.conv(x) 80 | padding = self.padding 81 | if self.causal and padding > 0: 82 | y = y[:, :, :-padding] 83 | 84 | if hasattr(self, 'layer_norm'): 85 | y = self.layer_norm(y) 86 | y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True) 87 | if self.nonlinearity == 'relu': 88 | y = F.relu(y, inplace=True) 89 | return y 90 | 91 | 92 | class E(nn.Module): 93 | def __init__(self, num_embeddings, embedding_dim): 94 | super(E, self).__init__() 95 | self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0) 96 | 97 | def forward(self, x): 98 | return self.embedding(x) 99 | 100 | 101 | class HighwayBlock(nn.Module): 102 | def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'): 103 | """Highway Network like layer: https://arxiv.org/abs/1505.00387 104 | The input and output shapes remain same. 105 | Args: 106 | d: input channel 107 | k: kernel size 108 | delta: dilation 109 | causal: causal convolution or not 110 | """ 111 | super(HighwayBlock, self).__init__() 112 | self.d = d 113 | self.C = C(in_channels=d, out_channels=2 * d, kernel_size=k, dilation=delta, causal=causal, weight_init=weight_init, normalization=normalization) 114 | 115 | def forward(self, x): 116 | L = self.C(x) 117 | H1 = L[:, :self.d, :] 118 | H2 = L[:, self.d:, :] 119 | sigH1 = F.sigmoid(H1) 120 | return sigH1 * H2 + (1 - sigH1) * x 121 | 122 | 123 | class GatedConvBlock(nn.Module): 124 | def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'): 125 | """Gated convolutional layer: https://arxiv.org/abs/1612.08083 126 | The input and output shapes remain same. 127 | Args: 128 | d: input channel 129 | k: kernel size 130 | delta: dilation 131 | causal: causal convolution or not 132 | """ 133 | super(GatedConvBlock, self).__init__() 134 | self.C = C(in_channels=d, out_channels=2 * d, kernel_size=k, dilation=delta, causal=causal, 135 | weight_init=weight_init, normalization=normalization) 136 | self.glu = nn.GLU(dim=1) 137 | 138 | def forward(self, x): 139 | L = self.C(x) 140 | return self.glu(L) + x 141 | 142 | 143 | class ResidualBlock(nn.Module): 144 | def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight', 145 | widening_factor=2): 146 | """Residual block: https://arxiv.org/abs/1512.03385 147 | The input and output shapes remain same. 
148 | Args: 149 | d: input channel 150 | k: kernel size 151 | delta: dilation 152 | causal: causal convolution or not 153 | """ 154 | super(ResidualBlock, self).__init__() 155 | self.C1 = C(in_channels=d, out_channels=widening_factor * d, kernel_size=k, dilation=delta, causal=causal, 156 | weight_init=weight_init, normalization=normalization, nonlinearity='relu') 157 | self.C2 = C(in_channels=widening_factor * d, out_channels=d, kernel_size=k, dilation=delta, causal=causal, 158 | weight_init=weight_init, normalization=normalization, nonlinearity='relu') 159 | 160 | def forward(self, x): 161 | return self.C2(self.C1(x)) + x 162 | -------------------------------------------------------------------------------- /dl_and_preprop_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Download and preprocess datasets. Supported datasets are: 3 | * English female: LJSpeech (https://keithito.com/LJ-Speech-Dataset/) 4 | * Mongolian male: MBSpeech (Mongolian Bible) 5 | """ 6 | __author__ = 'Erdene-Ochir Tuguldur' 7 | 8 | import os 9 | import sys 10 | import csv 11 | import time 12 | import argparse 13 | import fnmatch 14 | import librosa 15 | import pandas as pd 16 | 17 | from hparams import HParams as hp 18 | from zipfile import ZipFile 19 | from audio import preprocess 20 | from utils import download_file 21 | from datasets.mb_speech import MBSpeech 22 | from datasets.lj_speech import LJSpeech 23 | 24 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 25 | parser.add_argument("--dataset", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name') 26 | args = parser.parse_args() 27 | 28 | if args.dataset == 'ljspeech': 29 | dataset_file_name = 'LJSpeech-1.1.tar.bz2' 30 | datasets_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets') 31 | dataset_path = os.path.join(datasets_path, 'LJSpeech-1.1') 32 | 33 | if os.path.isdir(dataset_path) and False: 34 | print("LJSpeech dataset folder already exists") 35 | sys.exit(0) 36 | else: 37 | dataset_file_path = os.path.join(datasets_path, dataset_file_name) 38 | if not os.path.isfile(dataset_file_path): 39 | url = "http://data.keithito.com/data/speech/%s" % dataset_file_name 40 | download_file(url, dataset_file_path) 41 | else: 42 | print("'%s' already exists" % dataset_file_name) 43 | 44 | print("extracting '%s'..." 
% dataset_file_name) 45 | os.system('cd %s; tar xvjf %s' % (datasets_path, dataset_file_name)) 46 | 47 | # pre process 48 | print("pre processing...") 49 | lj_speech = LJSpeech([]) 50 | preprocess(dataset_path, lj_speech) 51 | elif args.dataset == 'mbspeech': 52 | dataset_name = 'MBSpeech-1.0' 53 | datasets_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets') 54 | dataset_path = os.path.join(datasets_path, dataset_name) 55 | 56 | if os.path.isdir(dataset_path) and False: 57 | print("MBSpeech dataset folder already exists") 58 | sys.exit(0) 59 | else: 60 | bible_books = ['01_Genesis', '02_Exodus', '03_Leviticus'] 61 | for bible_book_name in bible_books: 62 | bible_book_file_name = '%s.zip' % bible_book_name 63 | bible_book_file_path = os.path.join(datasets_path, bible_book_file_name) 64 | if not os.path.isfile(bible_book_file_path): 65 | url = "https://s3.us-east-2.amazonaws.com/bible.davarpartners.com/Mongolian/" + bible_book_file_name 66 | download_file(url, bible_book_file_path) 67 | else: 68 | print("'%s' already exists" % bible_book_file_name) 69 | 70 | print("extracting '%s'..." % bible_book_file_name) 71 | zipfile = ZipFile(bible_book_file_path) 72 | zipfile.extractall(datasets_path) 73 | 74 | dataset_csv_file_path = os.path.join(datasets_path, '%s-csv.zip' % dataset_name) 75 | dataset_csv_extracted_path = os.path.join(datasets_path, '%s-csv' % dataset_name) 76 | if not os.path.isfile(dataset_csv_file_path): 77 | url = "https://www.dropbox.com/s/dafueq0w278lbz6/%s-csv.zip?dl=1" % dataset_name 78 | download_file(url, dataset_csv_file_path) 79 | else: 80 | print("'%s' already exists" % dataset_csv_file_path) 81 | 82 | print("extracting '%s'..." % dataset_csv_file_path) 83 | zipfile = ZipFile(dataset_csv_file_path) 84 | zipfile.extractall(datasets_path) 85 | 86 | sample_rate = 44100 # original sample rate 87 | total_duration_s = 0 88 | 89 | if not os.path.isdir(dataset_path): 90 | os.mkdir(dataset_path) 91 | wavs_path = os.path.join(dataset_path, 'wavs') 92 | if not os.path.isdir(wavs_path): 93 | os.mkdir(wavs_path) 94 | 95 | metadata_csv = open(os.path.join(dataset_path, 'metadata.csv'), 'w') 96 | metadata_csv_writer = csv.writer(metadata_csv, delimiter='|') 97 | 98 | 99 | def _normalize(s): 100 | """remove leading '-'""" 101 | s = s.strip() 102 | if s[0] == '—' or s[0] == '-': 103 | s = s[1:].strip() 104 | return s 105 | 106 | 107 | def _get_mp3_file(book_name, chapter): 108 | book_download_path = os.path.join(datasets_path, book_name) 109 | wildcard = "*%02d - DPI.mp3" % chapter 110 | for file_name in os.listdir(book_download_path): 111 | if fnmatch.fnmatch(file_name, wildcard): 112 | return os.path.join(book_download_path, file_name) 113 | return None 114 | 115 | 116 | def _convert_mp3_to_wav(book_name, book_nr): 117 | global total_duration_s 118 | chapter = 1 119 | while True: 120 | try: 121 | i = 0 122 | chapter_csv_file_name = os.path.join(dataset_csv_extracted_path, "%s_%02d.csv" % (book_name, chapter)) 123 | df = pd.read_csv(chapter_csv_file_name, sep="|") 124 | print("processing %s..." % chapter_csv_file_name) 125 | mp3_file = _get_mp3_file(book_name, chapter) 126 | print("processing %s..." 
127 |                     assert mp3_file is not None
128 |                     samples, sr = librosa.load(mp3_file, sr=sample_rate, mono=True)
129 |                     assert sr == sample_rate
130 | 
131 |                     for index, row in df.iterrows():
132 |                         start, end, sentence = row['start'], row['end'], row['sentence']
133 |                         assert end > start
134 |                         duration = end - start
135 |                         duration_s = int(duration / sample_rate)
136 |                         if duration_s > 10:
137 |                             continue # only audios shorter than 10s
138 | 
139 |                         total_duration_s += duration_s
140 |                         i += 1
141 |                         sentence = _normalize(sentence)
142 |                         fn = "MB%d%02d-%04d" % (book_nr, chapter, i)
143 |                         metadata_csv_writer.writerow([fn, sentence, sentence]) # same format as LJSpeech
144 |                         wav = samples[start:end]
145 |                         wav = librosa.resample(wav, sample_rate, hp.sr) # use same sample rate as LJSpeech
146 |                         librosa.output.write_wav(os.path.join(wavs_path, fn + ".wav"), wav, hp.sr)
147 | 
148 |                     chapter += 1
149 |                 except FileNotFoundError:
150 |                     break
151 | 
152 | 
153 |         _convert_mp3_to_wav('01_Genesis', 1)
154 |         _convert_mp3_to_wav('02_Exodus', 2)
155 |         _convert_mp3_to_wav('03_Leviticus', 3)
156 |         metadata_csv.close()
157 |         print("total audio duration: %s" % (time.strftime('%H:%M:%S', time.gmtime(total_duration_s))))
158 | 
159 |         # pre process
160 |         print("pre processing...")
161 |         mb_speech = MBSpeech([])
162 |         preprocess(dataset_path, mb_speech)
163 | 
--------------------------------------------------------------------------------
/models/text2mel.py:
--------------------------------------------------------------------------------
1 | """
2 | Hideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara
3 | Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention
4 | https://arxiv.org/abs/1710.08969
5 | 
6 | Text2Mel Network.
7 | """
8 | __author__ = 'Erdene-Ochir Tuguldur'
9 | __all__ = ['Text2Mel']
10 | 
11 | import numpy as np
12 | 
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | 
17 | from hparams import HParams as hp
18 | from .layers import E, C, HighwayBlock, GatedConvBlock, ResidualBlock
19 | 
20 | 
21 | def Conv(in_channels, out_channels, kernel_size, dilation, causal=False, nonlinearity='linear'):
22 |     return C(in_channels, out_channels, kernel_size, dilation, causal=causal,
23 |              weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization, nonlinearity=nonlinearity)
24 | 
25 | 
26 | def BasicBlock(d, k, delta, causal=False):
27 |     if hp.text2mel_basic_block == 'gated_conv':
28 |         return GatedConvBlock(d, k, delta, causal=causal,
29 |                               weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization)
30 |     elif hp.text2mel_basic_block == 'highway':
31 |         return HighwayBlock(d, k, delta, causal=causal,
32 |                             weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization)
33 |     else:
34 |         return ResidualBlock(d, k, delta, causal=causal,
35 |                              weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization,
36 |                              widening_factor=2)
37 | 
38 | 
39 | def CausalConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):
40 |     return Conv(in_channels, out_channels, kernel_size, dilation, causal=True, nonlinearity=nonlinearity)
41 | 
42 | 
43 | def CausalBasicBlock(d, k, delta):
44 |     return BasicBlock(d, k, delta, causal=True)
45 | 
46 | 
47 | class TextEnc(nn.Module):
48 | 
49 |     def __init__(self, vocab, e=hp.e, d=hp.d):
50 |         """Text encoder network.
51 |         Args:
52 |             vocab: vocabulary
53 |             e: embedding dim
54 |             d: Text2Mel dim
55 |         Input:
56 |             L: (B, N) text inputs
57 |         Outputs:
58 |             K: (B, d, N) keys
59 |             V: (B, d, N) values
60 |         """
61 |         super(TextEnc, self).__init__()
62 |         self.d = d
63 |         self.embedding = E(len(vocab), e)
64 | 
65 |         self.layers = nn.Sequential(
66 |             Conv(e, 2 * d, 1, 1, nonlinearity='relu'),
67 |             Conv(2 * d, 2 * d, 1, 1),
68 | 
69 |             BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 3), BasicBlock(2 * d, 3, 9), BasicBlock(2 * d, 3, 27),
70 |             BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 3), BasicBlock(2 * d, 3, 9), BasicBlock(2 * d, 3, 27),
71 | 
72 |             BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 1),
73 | 
74 |             BasicBlock(2 * d, 1, 1), BasicBlock(2 * d, 1, 1)
75 |         )
76 | 
77 |     def forward(self, x):
78 |         out = self.embedding(x)
79 |         out = out.permute(0, 2, 1) # change to (B, e, N)
80 |         out = self.layers(out) # (B, 2*d, N)
81 |         K = out[:, :self.d, :] # (B, d, N)
82 |         V = out[:, self.d:, :] # (B, d, N)
83 |         return K, V
84 | 
85 | 
86 | class AudioEnc(nn.Module):
87 |     def __init__(self, d=hp.d, f=hp.n_mels):
88 |         """Audio encoder network.
89 |         Args:
90 |             d: Text2Mel dim
91 |             f: Number of mel bins
92 |         Input:
93 |             S: (B, f, T) melspectrograms
94 |         Output:
95 |             Q: (B, d, T) queries
96 |         """
97 |         super(AudioEnc, self).__init__()
98 |         self.layers = nn.Sequential(
99 |             CausalConv(f, d, 1, 1, nonlinearity='relu'),
100 |             CausalConv(d, d, 1, 1, nonlinearity='relu'),
101 |             CausalConv(d, d, 1, 1),
102 | 
103 |             CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27),
104 |             CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27),
105 | 
106 |             CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 3),
107 |         )
108 | 
109 |     def forward(self, x):
110 |         return self.layers(x)
111 | 
112 | 
113 | class AudioDec(nn.Module):
114 |     def __init__(self, d=hp.d, f=hp.n_mels):
115 |         """Audio decoder network.
116 |         Args:
117 |             d: Text2Mel dim
118 |             f: Number of mel bins
119 |         Input:
120 |             R_prime: (B, 2d, T) concatenation of [V*Attention, Q]; the paper notes "we found it beneficial in our pilot study."
121 |         Output:
122 |             Y: (B, f, T)
123 |         """
124 |         super(AudioDec, self).__init__()
125 |         self.layers = nn.Sequential(
126 |             CausalConv(2 * d, d, 1, 1),
127 | 
128 |             CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27),
129 | 
130 |             CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 1),
131 | 
132 |             # CausalConv(d, d, 1, 1, nonlinearity='relu'),
133 |             # CausalConv(d, d, 1, 1, nonlinearity='relu'),
134 |             CausalBasicBlock(d, 1, 1),
135 |             CausalConv(d, d, 1, 1, nonlinearity='relu'),
136 | 
137 |             CausalConv(d, f, 1, 1)
138 |         )
139 | 
140 |     def forward(self, x):
141 |         return self.layers(x)
142 | 
143 | 
144 | class Text2Mel(nn.Module):
145 |     def __init__(self, vocab, d=hp.d):
146 |         """Text to melspectrogram network.
147 |         Args:
148 |             vocab: vocabulary
149 |             d: Text2Mel dim
150 |         Input:
151 |             L: (B, N) text inputs
152 |             S: (B, f, T) melspectrograms
153 |         Outputs:
154 |             Y_logit: logit of Y
155 |             Y: predicted melspectrograms
156 |             A: (B, N, T) attention matrix
157 |         """
158 |         super(Text2Mel, self).__init__()
159 |         self.d = d
160 |         self.text_enc = TextEnc(vocab)
161 |         self.audio_enc = AudioEnc()
162 |         self.audio_dec = AudioDec()
163 | 
164 |     def forward(self, L, S, monotonic_attention=False):
165 |         K, V = self.text_enc(L)
166 |         Q = self.audio_enc(S)
167 |         A = torch.bmm(K.permute(0, 2, 1), Q) / np.sqrt(self.d) # scaled dot-product attention scores (B, N, T)
168 | 
169 |         if monotonic_attention:
170 |             # TODO: vectorize instead of loops
171 |             B, N, T = A.size()
172 |             for i in range(B):
173 |                 prva = -1 # previously attended text position
174 |                 for t in range(T):
175 |                     _, n = torch.max(A[i, :, t], 0) # currently most attended text position
176 |                     if not (-1 <= n - prva <= 3):
177 |                         A[i, :, t] = -2 ** 20 # large negative value, ~zero weight after softmax
178 |                         A[i, min(N - 1, prva + 1), t] = 1 # force attention to advance to the next position
179 |                     _, prva = torch.max(A[i, :, t], 0)
180 | 
181 |         A = F.softmax(A, dim=1)
182 |         R = torch.bmm(V, A) # (B, d, T) attention-weighted text values
183 |         R_prime = torch.cat((R, Q), 1) # (B, 2d, T)
184 |         Y_logit = self.audio_dec(R_prime)
185 |         Y = torch.sigmoid(Y_logit)
186 |         return Y_logit, Y, A
187 | 
--------------------------------------------------------------------------------
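
For reference, the MBSpeech preprocessing in dl_and_preprop_dataset.py writes metadata.csv in the same pipe-delimited layout as LJSpeech: each row is `<utterance id>|<sentence>|<sentence>`, with ids such as MB101-0001 built from the book number, chapter and utterance counter. A minimal sketch of reading that file back (illustrative only, not part of the repository; the relative path assumes the commands are run from the repository root):

    import csv

    # each row was written above as [fn, sentence, sentence], separated by '|'
    with open('datasets/MBSpeech-1.0/metadata.csv', newline='') as f:
        for fn, sentence, _normalized in csv.reader(f, delimiter='|'):
            print(fn, sentence)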
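In Text2Mel.forward, the attention step is ordinary scaled dot-product attention between text keys and audio queries. The following standalone sketch (with arbitrary sizes; d would normally come from hparams.py) reproduces just that computation to make the batched shapes explicit:

    import numpy as np
    import torch
    import torch.nn.functional as F

    B, d, N, T = 2, 256, 30, 50              # batch, channels, text length, mel frames (arbitrary)
    K = torch.randn(B, d, N)                 # keys from TextEnc
    V = torch.randn(B, d, N)                 # values from TextEnc
    Q = torch.randn(B, d, T)                 # queries from AudioEnc

    A = F.softmax(torch.bmm(K.permute(0, 2, 1), Q) / np.sqrt(d), dim=1)  # (B, N, T), each column sums to 1
    R = torch.bmm(V, A)                      # (B, d, T) attention-weighted text values
    R_prime = torch.cat((R, Q), 1)           # (B, 2*d, T), the input expected by AudioDec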
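A quick way to sanity-check the whole Text2Mel module is a forward pass on dummy tensors. The vocabulary string and the 80 mel bins below are assumptions for illustration (the real values come from the dataset modules and hparams.py), so treat this as a sketch rather than a supported entry point:

    import torch
    from models.text2mel import Text2Mel

    vocab = "PE abcdefghijklmnopqrstuvwxyz'.?"   # assumed vocabulary; the real one is defined with the datasets
    model = Text2Mel(vocab).eval()

    B, N, T = 2, 30, 50                          # batch, text length, mel frames
    L = torch.randint(0, len(vocab), (B, N))     # dummy character ids
    S = torch.zeros(B, 80, T)                    # dummy mels, assuming hp.n_mels == 80

    with torch.no_grad():
        Y_logit, Y, A = model(L, S)
    print(Y.shape, A.shape)                      # expected: (B, 80, T) and (B, N, T)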