├── predict.py ├── .gitignore ├── input ├── LJSpeech-1.1 │ ├── .~lock.metadata.csv# │ └── README ├── char_to_idx.pickle └── idx_to_char.pickle ├── Transformer_tts_model ├── model.png └── TransformerTTSModel.py ├── config.py ├── sampling.py ├── LICENSE ├── README.md ├── utils.py ├── engine.py ├── dataloader.py ├── train.py └── Wavenet └── WaveNet.py /predict.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | *.wav 3 | *.bin 4 | __pycache__ 5 | 6 | LICENSE 7 | README.md 8 | 9 | -------------------------------------------------------------------------------- /input/LJSpeech-1.1/.~lock.metadata.csv#: -------------------------------------------------------------------------------- 1 | ,shivam,shivam,02.12.2020 18:34,file:///home/shivam/.config/libreoffice/4; -------------------------------------------------------------------------------- /input/char_to_idx.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShivamRajSharma/Transformer-Text-To-Speech/HEAD/input/char_to_idx.pickle -------------------------------------------------------------------------------- /input/idx_to_char.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShivamRajSharma/Transformer-Text-To-Speech/HEAD/input/idx_to_char.pickle -------------------------------------------------------------------------------- /Transformer_tts_model/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShivamRajSharma/Transformer-Text-To-Speech/HEAD/Transformer_tts_model/model.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | sample_rate=16000 #Provided in the paper 2 | n_mels = 80 3 | frame_rate = 80 #Provided in the paper 4 | frame_length = 0.05 5 | hop_length = int(sample_rate/frame_rate) 6 | win_length = int(sample_rate*frame_length) 7 | 8 | scaling_factor = 4 9 | min_db_level = -100 10 | 11 | bce_weights = 7 12 | 13 | embed_dims = 512 14 | hidden_dims = 256 15 | heads = 4 16 | forward_expansion = 4 17 | num_layers = 4 18 | dropout = 0.15 19 | max_len = 1024 20 | pad_idx = 0 21 | 22 | Metadata = 'input/LJSpeech-1.1/metadata.csv' 23 | Audio_file_path = 'input/LJSpeech-1.1/wavs/' 24 | 25 | Model_Path = 'model/model.bin' 26 | checkpoint = 'model/checkpoint.bin' 27 | 28 | Batch_Size = 2 29 | Epochs = 40 30 | LR = 3e-4 31 | warmup_steps = 0.2 -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def greedy_decoding(): 4 | pass 5 | 6 | 7 | def sampling_decoding(): 8 | pass 9 | 10 | def mixture_of_log_sampling(y, log_scale_min=-7.0, clamp_log_scale=False): 11 | nr_mix = y.shape[1] // 3 12 | 13 | y = y.transpose(1, 2) 14 | logit_probs = y[:, :, :nr_mix] 15 | 16 | temp = logit_probs.data.new(logit_probs.shape).uniform_(1e-5, 1.0 - 1e-5) 17 | temp = logit_probs.data - torch.log(- torch.log(temp)) 18 | _, argmax = temp.max(dim=-1) 19 | 20 | one_hot = to_one_hot(argmax, nr_mix) 21 | 22 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, 
dim=-1) 23 | log_scales = torch.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1) 24 | if clamp_log_scale: 25 | log_scales = torch.clamp(log_scales, min=log_scale_min) 26 | 27 | u = means.data.new(means.shape).uniform_(1e-5, 1.0 - 1e-5) 28 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 29 | 30 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.) 31 | 32 | return x 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Shivam Raj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer Text To Speech 2 | 3 | A text-to-speech (TTS) system converts normal language text into speech; other systems render symbolic linguistic representations such as phonetic transcriptions into speech. With recent developments in deep learning, it is possible to synthesize natural-sounding speech directly from text. Here, the text is fed into an Encoder-Decoder neural network that outputs a Mel-Spectrogram. The Mel-Spectrogram can then be converted to audio using the ["Griffin-Lim Algorithm"](https://paperswithcode.com/method/griffin-lim-algorithm). Because Griffin-Lim reconstruction falls short of human-like speech quality, a second neural network, [WaveNet](https://deepmind.com/blog/article/wavenet-generative-model-raw-audio), is used as a vocoder: it takes the Mel-Spectrogram as input and produces audio that is nearly indistinguishable from human speech. 4 | 5 | ## Model Architecture 6 | 7 | ### 1. Transformer TTS 8 | 9 | 10 | * An Encoder-Decoder transformer architecture that allows parallel training, unlike the sequential (Seq2Seq) training of [Tacotron-2](https://github.com/NVIDIA/tacotron2). 11 | * Text is sent as input and the model outputs a Mel-Spectrogram. 12 | * Multi-headed attention is employed, with causal masking only on the decoder side. 13 | * Paper : [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895). 14 | 15 | ### 2. Wavenet 16 | 17 | 18 | * The output of the Transformer TTS (a Mel-Spectrogram) is fed into the WaveNet to generate audio samples. 19 | * Unlike Seq2Seq models, WaveNet also allows parallel training.
20 | * Paper : [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499). 21 | 22 | 23 | 24 | 25 | ## Dataset Information 26 | The model was trained on a subset of the LJSpeech dataset, which consists of short audio clips of a single speaker together with their transcriptions. Preprocessing (silence trimming, Mel-Spectrogram extraction, and normalization) was carried out before training the model; a minimal sketch of this step is shown below.
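Below is a minimal, illustrative sketch of the feature-extraction step and the Griffin-Lim baseline mentioned above. It assumes `librosa` and `soundfile` are installed and reuses the parameters from `config.py` (16 kHz sample rate, 80 mel bands, 50 ms windows, 80 frames per second); the wav path is only a placeholder.

```python
import librosa
import soundfile as sf

sample_rate, n_mels = 16000, 80
hop_length = sample_rate // 80        # frame_rate = 80 -> 200 samples per hop
win_length = int(sample_rate * 0.05)  # frame_length = 0.05 s -> 800 samples

# Load and trim one LJSpeech clip (path is illustrative)
audio, _ = librosa.load("input/LJSpeech-1.1/wavs/LJ001-0001.wav", sr=sample_rate)
audio, _ = librosa.effects.trim(audio)

# The Transformer is trained to predict this (dB-scaled) Mel-Spectrogram from text
mel = librosa.feature.melspectrogram(
    y=audio, sr=sample_rate, n_mels=n_mels,
    hop_length=hop_length, win_length=win_length
)
mel_db = librosa.power_to_db(mel)

# Griffin-Lim baseline: invert the power Mel-Spectrogram back to a waveform
recon = librosa.feature.inverse.mel_to_audio(
    librosa.db_to_power(mel_db), sr=sample_rate,
    hop_length=hop_length, win_length=win_length
)
sf.write("griffin_lim_recon.wav", recon, sample_rate)
```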
27 | Dataset : https://keithito.com/LJ-Speech-Dataset/ 28 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import config 2 | import numpy as np 3 | def mulaw_encode(audio_samples): 4 | # Outputs values ranging from 0-255 5 | audio_converted_samples = [] 6 | mu = 255 7 | for audio_sample in audio_samples: 8 | audio_sample = np.sign(audio_sample) * (np.log1p(mu * np.abs(audio_sample)) / np.log1p(mu)) 9 | audio_sample = ((audio_sample + 1) / 2 * mu + 0.5) 10 | audio_converted_samples.append(audio_sample.astype(int)) 11 | audio_converted_samples = np.array(audio_converted_samples) 12 | return audio_converted_samples 13 | 14 | 15 | def mulaw_decode(audio_samples): 16 | audio_converted_samples = [] 17 | mu = 255 18 | for audio_sample in audio_samples: 19 | audio_sample = audio_sample.astype(np.float32) 20 | audio_sample = 2*(audio_sample/mu) - 1 21 | audio_sample = np.sign(audio_sample)*(1.0 / mu)*((1.0 + mu)**(np.abs(audio_sample)) - 1.0) 22 | audio_converted_samples.append(audio_sample) 23 | audio_converted_samples = np.array(audio_converted_samples) 24 | return audio_converted_samples 25 | 26 | 27 | def normalize_(mel): 28 | # Normalizes the mel spectrogram to the range [-4, 4] 29 | # (helps the model converge faster) 30 | mel = np.clip( 31 | (config.scaling_factor)*((mel - config.min_db_level)/-config.min_db_level) - config.scaling_factor, 32 | -config.scaling_factor, config.scaling_factor 33 | ) 34 | return mel 35 | 36 | 37 | 38 | 39 | 40 | if __name__ == "__main__": 41 | import librosa 42 | import numpy as np 43 | a = np.random.uniform(-1, 1, (2, 10)) 44 | a = mulaw_encode(a) 45 | print(a) 46 | # from sklearn.preprocessing import StandardScaler 47 | # audio, sr = librosa.load("../Downloads/WhatsApp Ptt 2021-07-05 at 3.18.15 PM.ogg", sr=16000) 48 | # audio_samples = audio[None, :] 49 | # print(min(audio_samples[0]), max(audio_samples[0])) 50 | # audio_samples = mulaw_encode(audio_samples) 51 | # print(min(audio_samples[0]), max(audio_samples[0])) 52 | # audio_samples = mulaw_decode(audio_samples) 53 | # print(min(audio_samples[0]), max(audio_samples[0])) -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import config 2 | 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | 7 | def loss_fn( 8 | target_mel_spect, 9 | target_end_logits, 10 | pred_mel_spect_post, 11 | pred_mel_spect, 12 | pred_end_logits 13 | ): 14 | mel_loss = nn.L1Loss()(pred_mel_spect, target_mel_spect) + nn.L1Loss()(pred_mel_spect_post, target_mel_spect) 15 | bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(config.bce_weights))(pred_end_logits.squeeze(-1), target_end_logits) 16 | return mel_loss + bce_loss 17 | 18 | 19 | def train_fn(model, dataloader, optimizer, scheduler, device): 20 | running_loss = 0 21 | model.train() 22 | for num, data in tqdm(enumerate(dataloader), total=len(dataloader)): 23 | for p in model.parameters(): 24 | p.grad=None 25 | 26 | end_logits = data['end_logits'].to(device) 27 | mel_spect = data['mel_spect'].to(device) 28 | text_idx = data['text_idx'].to(device) 29 | text = data['original_text'] 30 | mel_mask = data['mel_mask'].to(device) 31 | mel_spect_post_pred, mel_spect_pred, end_logits_pred = model(text_idx, mel_spect[:, :-1], mel_mask[:, :-1]) 32 | loss = loss_fn( 33 | mel_spect[:, 1:], 34 | end_logits[:, :-1], 35 | mel_spect_post_pred, 36 |
mel_spect_pred, 37 | end_logits_pred 38 | ) 39 | 40 | running_loss += loss.item() 41 | loss.backward() 42 | optimizer.step() 43 | scheduler.step() 44 | 45 | epoch_loss = running_loss/len(dataloader) 46 | return epoch_loss 47 | 48 | 49 | def eval_fn(model, dataloder, device): 50 | running_loss = 0 51 | model.eval() 52 | with torch.no_grad(): 53 | for num, data in tqdm(enumerate(dataloder), total=len(dataloder)): 54 | end_logits = data['end_logits'].to(device) 55 | mel_spect = data['mel_spect'].to(device) 56 | text_idx = data['text_idx'].to(device) 57 | text = data['original_text'] 58 | mel_mask = data['mel_mask'].to(device) 59 | mel_spect_post_pred, mel_spect_pred, end_logits_pred = model(text_idx, mel_spect[:, :-1], mel_mask[:, :-1]) 60 | loss = loss_fn( 61 | mel_spect[:, 1:], 62 | end_logits[:, :-1], 63 | mel_spect_post_pred, 64 | mel_spect_pred, 65 | end_logits_pred 66 | ) 67 | 68 | running_loss += loss.item() 69 | epoch_loss = running_loss/len(dataloder) 70 | return epoch_loss 71 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import config 2 | 3 | import os 4 | import pickle 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import librosa 9 | from librosa.feature import melspectrogram 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | 13 | class TransformerLoader(torch.utils.data.Dataset): 14 | def __init__(self, files_name, text_data, mel_transforms=None, normalize=False): 15 | self.files_name = files_name 16 | self.text_data = text_data 17 | self.transforms = mel_transforms 18 | self.normalize = normalize 19 | self.char_to_idx = pickle.load(open('input/char_to_idx.pickle', 'rb')) 20 | 21 | def __len__(self): 22 | return len(self.text_data) 23 | 24 | 25 | def data_preprocess_(self, text): 26 | char_idx = [self.char_to_idx[char] for char in text if char in self.char_to_idx] 27 | return char_idx 28 | 29 | 30 | def normalize_(self, mel): 31 | #Normalizing data between -4 and 4 32 | #Converges even more faster 33 | mel = np.clip( 34 | (config.scaling_factor)*((mel - config.min_db_level)/-config.min_db_level) - config.scaling_factor, 35 | -config.scaling_factor, config.scaling_factor 36 | ) 37 | return mel 38 | 39 | 40 | def __getitem__(self, idx): 41 | file_name = self.files_name[idx] 42 | text = self.text_data[idx] 43 | text_idx = self.data_preprocess_(text) 44 | audio_path = os.path.join(config.Audio_file_path + file_name + '.wav') 45 | 46 | audio_file, _ = librosa.load( 47 | audio_path, 48 | sr=config.sample_rate 49 | ) 50 | 51 | audio_file, _ = librosa.effects.trim(audio_file) 52 | 53 | mel_spect = melspectrogram( 54 | audio_file, 55 | sr=config.sample_rate, 56 | n_mels=config.n_mels, 57 | hop_length=config.hop_length, 58 | win_length=config.win_length 59 | ) 60 | 61 | pre_mel_spect = np.zeros((1, config.n_mels)) 62 | mel_spect = librosa.power_to_db(mel_spect).T 63 | mel_spect = np.concatenate((pre_mel_spect, mel_spect), axis=0) 64 | 65 | if self.normalize: 66 | mel_spect = self.normalize_(mel_spect) 67 | 68 | mel_spect = torch.tensor(mel_spect, dtype=torch.float) 69 | mel_mask = [1]*mel_spect.shape[0] 70 | 71 | end_logits = [0]*(len(mel_spect) - 1) 72 | end_logits += [1] 73 | 74 | if self.transforms: 75 | for transform in self.transforms: 76 | if np.random.randint(0, 11) == 10: 77 | mel_spect = transform(mel_spect).squeeze(0) 78 | 79 | return { 80 | 'original_text' : text, 81 | 'mel_spect' : mel_spect, 82 | 'mel_mask' : 
torch.tensor(mel_mask, dtype=torch.long), 83 | 'text_idx' : torch.tensor(text_idx, dtype=torch.long), 84 | 'end_logits' : torch.tensor(end_logits, dtype=torch.float), 85 | } 86 | 87 | 88 | class MyCollate: 89 | def __init__(self, pad_idx, spect_pad): 90 | self.pad_idx = pad_idx 91 | self.spect_pad =spect_pad 92 | 93 | def __call__(self, batch): 94 | text_idx = [item['text_idx'] for item in batch] 95 | padded_text_idx = pad_sequence( 96 | text_idx, 97 | batch_first=True, 98 | padding_value=self.pad_idx 99 | ) 100 | end_logits = [item['end_logits'] for item in batch] 101 | padded_end_logits = pad_sequence( 102 | end_logits, 103 | batch_first=True, 104 | padding_value=0 105 | ) 106 | original_text = [item['original_text'] for item in batch] 107 | mel_mask = [item['mel_mask'] for item in batch] 108 | padded_mel_mask = pad_sequence( 109 | mel_mask, 110 | batch_first=True, 111 | padding_value=0 112 | ) 113 | mel_spects = [item['mel_spect'] for item in batch] 114 | 115 | batch_size, max_len = padded_mel_mask.shape 116 | 117 | padded_mel_spect = torch.zeros(batch_size, max_len, mel_spects[0].shape[-1]) 118 | 119 | for num,mel_spect in enumerate(mel_spects): 120 | padded_mel_spect[num, :mel_spect.shape[0]] = mel_spect 121 | 122 | return { 123 | 'original_text' : original_text, 124 | 'mel_spect' : padded_mel_spect, 125 | 'mel_mask' : padded_mel_mask, 126 | 'text_idx' : padded_text_idx, 127 | 'end_logits' : padded_end_logits 128 | } 129 | 130 | 131 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import config 5 | import dataloader 6 | import engine 7 | 8 | import sys 9 | 10 | import os 11 | import gc 12 | import transformers 13 | import torch 14 | import torch.nn as nn 15 | import torchaudio 16 | from tqdm import tqdm 17 | from sklearn.model_selection import train_test_split 18 | from Transformer_tts_model.TransformerTTSModel import TransformerTTS 19 | 20 | def preprocess(data): 21 | text_data, audio_files_name = [], [] 22 | for d in data: 23 | audio_name, text = d.split('|')[:2] 24 | text_data.append(text.lower()) 25 | audio_files_name.append(audio_name) 26 | return text_data, audio_files_name 27 | 28 | def train(): 29 | data = open(config.Metadata).read().strip().split('\n')[:10] 30 | text_data, audio_file_name = preprocess(data) 31 | del data 32 | gc.collect() 33 | transforms = [ 34 | torchaudio.transforms.FrequencyMasking(freq_mask_param=15), 35 | torchaudio.transforms.TimeMasking(time_mask_param=35) 36 | ] 37 | 38 | train_text_data, val_text_data, train_audio_file_name, val_audio_file_name = train_test_split( 39 | text_data, 40 | audio_file_name, 41 | test_size=0.2 42 | ) 43 | 44 | train_data = dataloader.TransformerLoader( 45 | files_name=train_audio_file_name, 46 | text_data=train_text_data, 47 | mel_transforms=transforms, 48 | normalize=True 49 | ) 50 | 51 | val_data = dataloader.TransformerLoader( 52 | files_name=val_audio_file_name, 53 | text_data=val_text_data, 54 | normalize=True 55 | ) 56 | 57 | pad_idx = 0 58 | 59 | 60 | train_loader = torch.utils.data.DataLoader( 61 | train_data, 62 | batch_size=config.Batch_Size, 63 | num_workers=1, 64 | pin_memory=True, 65 | collate_fn=dataloader.MyCollate( 66 | pad_idx=pad_idx, 67 | spect_pad=-config.scaling_factor 68 | ) 69 | ) 70 | 71 | val_loader = torch.utils.data.DataLoader( 72 | val_data, 73 | batch_size=config.Batch_Size, 74 | num_workers=1, 75 | 
pin_memory=True, 76 | collate_fn=dataloader.MyCollate( 77 | pad_idx=pad_idx, 78 | spect_pad=-config.scaling_factor 79 | ) 80 | ) 81 | 82 | vocab_size = len(train_data.char_to_idx) + 1 83 | 84 | model = TransformerTTS( 85 | vocab_size=vocab_size, 86 | embed_dims=config.embed_dims, 87 | hidden_dims=config.hidden_dims, 88 | heads=config.heads, 89 | forward_expansion=config.forward_expansion, 90 | num_layers=config.num_layers, 91 | dropout=config.dropout, 92 | mel_dims=config.n_mels, 93 | max_len=config.max_len, 94 | pad_idx=config.pad_idx 95 | ) 96 | # device = torch.device('cuda' if torch.cuda.is_available else 'cpu') 97 | # torch.backends.cudnn.benchmark = True 98 | device = torch.device('cpu') 99 | model = model.to(device) 100 | 101 | optimizer = transformers.AdamW(model.parameters(), lr=config.LR) 102 | 103 | num_training_steps = config.Epochs*len(train_data)//config.Batch_Size 104 | 105 | scheduler = transformers.get_cosine_schedule_with_warmup( 106 | optimizer, 107 | num_warmup_steps=config.warmup_steps*num_training_steps, 108 | num_training_steps=num_training_steps 109 | ) 110 | 111 | epoch_start = 0 112 | 113 | if os.path.exists(config.checkpoint): 114 | checkpoint = torch.load(config.checkpoint) 115 | model.load_state_dict(checkpoint['model_state_dict']) 116 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 117 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 118 | epoch_start = checkpoint['epoch'] 119 | print(f'---------[INFO] Restarting Training from Epoch {epoch_start} -----------\n') 120 | 121 | 122 | 123 | best_loss = 1e10 124 | best_model = model.state_dict() 125 | print('--------- [INFO] STARTING TRAINING ---------\n') 126 | for epoch in range(epoch_start, config.Epochs): 127 | train_loss = engine.train_fn(model, train_loader, optimizer, scheduler, device) 128 | val_loss = engine.eval_fn(model, val_loader, device) 129 | print(f'EPOCH -> {epoch+1}/{config.Epochs} | TRAIN LOSS = {train_loss} | VAL LOSS = {val_loss} | LR = {scheduler.get_lr()[0]} \n') 130 | 131 | torch.save({ 132 | 'epoch' : epoch, 133 | 'model_state_dict' : model.state_dict(), 134 | 'optimizer_state_dict' : optimizer.state_dict(), 135 | 'scheduler_state_dict' : scheduler.state_dict(), 136 | 'loss': val_loss, 137 | }, config.checkpoint) 138 | 139 | if best_loss > val_loss: 140 | best_loss = val_loss 141 | best_model = model.state_dict() 142 | torch.save(best_model, config.Model_Path) 143 | 144 | 145 | if __name__ == "__main__": 146 | train() -------------------------------------------------------------------------------- /input/LJSpeech-1.1/README: -------------------------------------------------------------------------------- 1 | ----------------------------------------------------------------------------- 2 | The LJ Speech Dataset 3 | 4 | Version 1.0 5 | July 5, 2017 6 | https://keithito.com/LJ-Speech-Dataset 7 | ----------------------------------------------------------------------------- 8 | 9 | 10 | OVERVIEW 11 | 12 | This is a public domain speech dataset consisting of 13,100 short audio clips 13 | of a single speaker reading passages from 7 non-fiction books. A transcription 14 | is provided for each clip. Clips vary in length from 1 to 10 seconds and have 15 | a total length of approximately 24 hours. 16 | 17 | The texts were published between 1884 and 1964, and are in the public domain. 18 | The audio was recorded in 2016-17 by the LibriVox project and is also in the 19 | public domain. 20 | 21 | 22 | 23 | FILE FORMAT 24 | 25 | Metadata is provided in metadata.csv. 
This file consists of one record per 26 | line, delimited by the pipe character (0x7c). The fields are: 27 | 28 | 1. ID: this is the name of the corresponding .wav file 29 | 2. Transcription: words spoken by the reader (UTF-8) 30 | 3. Normalized Transcription: transcription with numbers, ordinals, and 31 | monetary units expanded into full words (UTF-8). 32 | 33 | Each audio file is a single-channel 16-bit PCM WAV with a sample rate of 34 | 22050 Hz. 35 | 36 | 37 | 38 | STATISTICS 39 | 40 | Total Clips 13,100 41 | Total Words 225,715 42 | Total Characters 1,308,674 43 | Total Duration 23:55:17 44 | Mean Clip Duration 6.57 sec 45 | Min Clip Duration 1.11 sec 46 | Max Clip Duration 10.10 sec 47 | Mean Words per Clip 17.23 48 | Distinct Words 13,821 49 | 50 | 51 | 52 | MISCELLANEOUS 53 | 54 | The audio clips range in length from approximately 1 second to 10 seconds. 55 | They were segmented automatically based on silences in the recording. Clip 56 | boundaries generally align with sentence or clause boundaries, but not always. 57 | 58 | The text was matched to the audio manually, and a QA pass was done to ensure 59 | that the text accurately matched the words spoken in the audio. 60 | 61 | The original LibriVox recordings were distributed as 128 kbps MP3 files. As a 62 | result, they may contain artifacts introduced by the MP3 encoding. 63 | 64 | The following abbreviations appear in the text. They may be expanded as 65 | follows: 66 | 67 | Abbreviation Expansion 68 | -------------------------- 69 | Mr. Mister 70 | Mrs. Misess (*) 71 | Dr. Doctor 72 | No. Number 73 | St. Saint 74 | Co. Company 75 | Jr. Junior 76 | Maj. Major 77 | Gen. General 78 | Drs. Doctors 79 | Rev. Reverend 80 | Lt. Lieutenant 81 | Hon. Honorable 82 | Sgt. Sergeant 83 | Capt. Captain 84 | Esq. Esquire 85 | Ltd. Limited 86 | Col. Colonel 87 | Ft. Fort 88 | 89 | * there's no standard expansion of "Mrs." 90 | 91 | 92 | 19 of the transcriptions contain non-ASCII characters (for example, LJ016-0257 93 | contains "raison d'être"). 94 | 95 | For more information or to report errors, please email kito@kito.us. 96 | 97 | 98 | 99 | LICENSE 100 | 101 | This dataset is in the public domain in the USA (and likely other countries as 102 | well). There are no restrictions on its use. For more information, please see: 103 | https://librivox.org/pages/public-domain. 104 | 105 | 106 | CHANGELOG 107 | 108 | * 1.0 (July 8, 2017): 109 | Initial release 110 | 111 | * 1.1 (Feb 19, 2018): 112 | Version 1.0 included 30 .wav files with no corresponding annotations in 113 | metadata.csv. These have been removed in version 1.1. Thanks to Rafael Valle 114 | for spotting this. 115 | 116 | 117 | CREDITS 118 | 119 | This dataset consists of excerpts from the following works: 120 | 121 | * Morris, William, et al. Arts and Crafts Essays. 1893. 122 | * Griffiths, Arthur. The Chronicles of Newgate, Vol. 2. 1884. 123 | * Roosevelt, Franklin D. The Fireside Chats of Franklin Delano Roosevelt. 124 | 1933-42. 125 | * Harland, Marion. Marion Harland's Cookery for Beginners. 1893. 126 | * Rolt-Wheeler, Francis. The Science - History of the Universe, Vol. 5: 127 | Biology. 1910. 128 | * Banks, Edgar J. The Seven Wonders of the Ancient World. 1916. 129 | * President's Commission on the Assassination of President Kennedy. Report 130 | of the President's Commission on the Assassination of President Kennedy. 131 | 1964. 132 | 133 | Recordings by Linda Johnson. Alignment and annotation by Keith Ito. All text, 134 | audio, and annotations are in the public domain. 
135 | 136 | There's no requirement to cite this work, but if you'd like to do so, you can 137 | link to: https://keithito.com/LJ-Speech-Dataset 138 | 139 | or use the following: 140 | @misc{ljspeech17, 141 | author = {Keith Ito}, 142 | title = {The LJ Speech Dataset}, 143 | howpublished = {\url{https://keithito.com/LJ-Speech-Dataset/}}, 144 | year = 2017 145 | } 146 | -------------------------------------------------------------------------------- /Wavenet/WaveNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | 8 | def receptive_feild_size(layers_per_stack, kernel_size): 9 | dilations = [2**i for i in range(layers_per_stack)] 10 | receptive_feild_size = (kernel_size-1)*sum(dilations) + 1 11 | return receptive_feild_size 12 | 13 | 14 | def padding_calc(kernel_size, dilation): 15 | padding = (kernel_size - 1)*dilation 16 | return padding 17 | 18 | def mixture_of_log_sampling(y, log_scale_min=-7.0, clamp_log_scale=False): 19 | nr_mix = y.shape[1] // 3 20 | 21 | y = y.transpose(1, 2) 22 | logit_probs = y[:, :, :nr_mix] 23 | 24 | temp = logit_probs.data.new(logit_probs.shape).uniform_(1e-5, 1.0 - 1e-5) 25 | temp = logit_probs.data - torch.log(- torch.log(temp)) 26 | _, argmax = temp.max(dim=-1) 27 | 28 | one_hot = torch.zeros(logit_probs.shape) 29 | one_hot[0, -1, argmax] = 1 30 | 31 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) 32 | log_scales = torch.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1) 33 | if clamp_log_scale: 34 | log_scales = torch.clamp(log_scales, min=log_scale_min) 35 | 36 | u = means.data.new(means.shape).uniform_(1e-5, 1.0 - 1e-5) 37 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 38 | 39 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.) 
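    # The Gumbel-max trick above (logits - log(-log(u)), then argmax) picks one mixture
    # component per timestep; x is then drawn from the selected logistic component by
    # inverse-CDF sampling (mean + scale * logit(u)) and clamped to the waveform range [-1, 1].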
40 | 41 | return x 42 | 43 | 44 | class Conv_layers(nn.Module): 45 | def __init__( 46 | self, 47 | in_channels 48 | ): 49 | super(Conv_layers, self).__init__() 50 | self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=1) 51 | 52 | def forward(self, x): 53 | return nn.ReLU()(self.conv(x)) 54 | 55 | 56 | class UpScalingNet(nn.Module): 57 | def __init__( 58 | self, 59 | hopsize, 60 | factor, 61 | cin_channels, 62 | ): 63 | super(UpScalingNet, self).__init__() 64 | self.hopsize = hopsize 65 | self.factor = factor 66 | self.repetitions = int(math.log(hopsize)/math.log(self.factor)) 67 | self.forward_layers = nn.Sequential(*[ 68 | Conv_layers(cin_channels) 69 | for i in range(self.repetitions) 70 | ]) 71 | 72 | def forward(self, x): 73 | for layer in self.forward_layers: 74 | x = nn.functional.interpolate(x.float(), scale_factor=self.factor) 75 | x = layer(x) 76 | return x 77 | 78 | 79 | class WaveNet(nn.Module): 80 | def __init__( 81 | self, 82 | layers_per_stack=20, 83 | stack=2, 84 | residual_channels=512, 85 | gate_channels=512, 86 | filter_channels=512, 87 | skip_out_channels=512, 88 | l_in_channels=-1, 89 | hopsize=512, 90 | out_channels=256, 91 | scaler_input=True, 92 | if_quantized=False, 93 | include_bias=True, 94 | loss_ = "MOL" 95 | ): 96 | super(WaveNet, self).__init__() 97 | self.layers_per_stack = layers_per_stack 98 | self.stack = stack 99 | self.residual_channels = residual_channels 100 | self.gate_channels = gate_channels 101 | self.skip_out_channels = skip_out_channels 102 | self.hopsize = hopsize 103 | self.scaler_input = scaler_input 104 | self.padding_list = [] 105 | self.out_channels = out_channels 106 | 107 | self.g_gate = nn.ModuleList() 108 | self.g_filter = nn.ModuleList() 109 | 110 | self.l_filter = nn.ModuleList() 111 | self.l_gate = nn.ModuleList() 112 | 113 | self.skip = nn.ModuleList() 114 | self.residual = nn.ModuleList() 115 | 116 | if self.scaler_input: 117 | self.conv1 = nn.Conv1d(1, residual_channels, kernel_size=1, bias=include_bias) 118 | else: 119 | self.conv1 = nn.Conv1d(out_channels, residual_channels, kernel_size=1, bias=include_bias) 120 | 121 | for i in range(stack): 122 | dilation = 1 123 | for layer in range(layers_per_stack): 124 | padding = padding_calc(2, dilation) 125 | self.padding_list.append(padding) 126 | self.g_filter.append( 127 | nn.Conv1d( 128 | residual_channels, 129 | filter_channels, 130 | kernel_size=2, 131 | padding=padding, 132 | dilation=dilation, 133 | bias=include_bias 134 | ) 135 | ) 136 | 137 | self.g_gate.append( 138 | nn.Conv1d( 139 | residual_channels, 140 | gate_channels, 141 | kernel_size=2, 142 | padding=padding, 143 | dilation=dilation, 144 | bias=include_bias 145 | ) 146 | ) 147 | 148 | self.l_gate.append( 149 | nn.Conv1d( 150 | l_in_channels, 151 | gate_channels, 152 | kernel_size=1, 153 | bias=include_bias 154 | ) 155 | ) 156 | 157 | self.l_filter.append( 158 | nn.Conv1d( 159 | l_in_channels, 160 | filter_channels, 161 | kernel_size=1, 162 | bias=include_bias 163 | ) 164 | ) 165 | 166 | self.residual.append( 167 | nn.Conv1d( 168 | gate_channels, 169 | residual_channels, 170 | kernel_size=1, 171 | bias=include_bias 172 | ) 173 | ) 174 | 175 | self.skip.append( 176 | nn.Conv1d( 177 | gate_channels, 178 | skip_out_channels, 179 | kernel_size=1, 180 | bias=include_bias 181 | ) 182 | ) 183 | 184 | dilation *= 2 185 | 186 | self.receptive_field = receptive_feild_size(self.layers_per_stack, kernel_size=2) 187 | 188 | self.upscaling_net = UpScalingNet(hopsize, factor=4, cin_channels=l_in_channels) 189 | 190 | 
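        # Output head (defined next): ReLU + two 1x1 convolutions map the summed skip
        # connections to out_channels, e.g. 256-way categorical logits for mu-law audio
        # or 3 * num_mixtures parameters when the mixture-of-logistics (MOL) loss is used.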
self.last_conv = nn.Sequential( 191 | nn.ReLU(), 192 | nn.Conv1d(skip_out_channels, skip_out_channels, kernel_size=1), 193 | nn.ReLU(), 194 | nn.Conv1d(skip_out_channels, out_channels, kernel_size=1) 195 | ) 196 | 197 | def forward(self, x, local_feature, is_training=True): 198 | B, _, T = x.shape 199 | 200 | if is_training: 201 | local_feature = self.upscaling_net(local_feature) 202 | 203 | skip = 0 204 | 205 | x = self.conv1(x) 206 | 207 | for i in range(self.layers_per_stack*self.stack): 208 | residuals = x 209 | g_f = self.g_filter[i](x)[:, :, :x.shape[-1]] 210 | g_g = self.g_gate[i](x)[:, :, :x.shape[-1]] 211 | 212 | 213 | l_f = self.l_filter[i](local_feature)[:, :, :x.shape[-1]] 214 | l_g = self.l_gate[i](local_feature)[:, :, :x.shape[-1]] 215 | 216 | 217 | f, g = g_f + l_f, g_g + l_g 218 | 219 | 220 | x = torch.tanh(f)*torch.sigmoid(g) 221 | 222 | 223 | s = self.skip[i](x) 224 | 225 | 226 | skip += s 227 | 228 | 229 | x = self.residual[i](x) 230 | 231 | x = (x + residuals)*math.sqrt(0.5) 232 | 233 | x = self.last_conv(skip) 234 | 235 | return x 236 | 237 | 238 | def waveform_generation(self, g, mel_spect, device="cpu", time_scale=100, decoding="greedy"): 239 | 240 | mel_spect = self.upscaling_net(mel_spect) 241 | time_scale = max(time_scale, mel_spect.shape[-1]) 242 | 243 | if self.scaler_input: 244 | global_ = torch.zeros(1, 1, 1, dtype=torch.float) 245 | 246 | else: 247 | global_ = torch.zeros(1, self.out_channels, 1, dtype=torch.float) 248 | 249 | 250 | output = [] 251 | output.append(global_) 252 | 253 | for t in range(time_scale): 254 | global_features = torch.cat(output[max(t-self.receptive_field, 0): t+1], dim=-1).to(device) 255 | local_features = mel_spect[:, :, max(t-self.receptive_field, 0): t+1] 256 | 257 | out = self.forward(global_features, local_features, is_training=False) 258 | 259 | if decoding == "greedy": 260 | out = out.permute(0, 2, 1) 261 | out = out.argmax(dim=-1)[:, -1] 262 | 263 | if self.scaler_input: 264 | output.append(out) 265 | else: 266 | x = torch.zeros(1, self.out_channels, 1) 267 | x[:, out[0], :] = 1 268 | out = x 269 | output.append(out.detach()) 270 | 271 | elif decoding == "greedy_autoregress": 272 | out = out.permute(0, 2, 1) 273 | out = torch.softmax(out, dim=-1)[:, -1].view(-1) 274 | out = np.random.choice(np.arange(256), p=out.cpu().detach().numpy()) 275 | if self.scaler_input: 276 | output.append(torch.tensor(out.detach())) 277 | else: 278 | x = torch.zeros(1, self.out_channels, 1) 279 | x[:, out, :] = 1 280 | out = x 281 | output.append(out.detach()) 282 | 283 | elif (decoding == "MOL") and (self.scaler_input): 284 | out = mixture_of_log_sampling(out)[:, -1].unsqueeze(0).unsqueeze(1) 285 | output.append(out.detach()) 286 | 287 | 288 | return output 289 | 290 | 291 | 292 | if __name__ == "__main__": 293 | local = torch.randn(1, 20, 114, dtype=torch.float) 294 | # global_ = torch.randint(0, 255, (1, 1, 29120), dtype=torch.float) 295 | global_ = torch.randn(1, 1, 29120, dtype=torch.float) 296 | device = torch.device("cpu") 297 | local = local.to(device) 298 | global_ = global_.to(device) 299 | 300 | wavenet = WaveNet( 301 | layers_per_stack=10, 302 | stack=3, 303 | residual_channels=32, 304 | gate_channels=32, 305 | filter_channels=32, 306 | skip_out_channels=256, 307 | l_in_channels=20, 308 | hopsize=256, 309 | scaler_input=False, 310 | out_channels=30, 311 | include_bias=True, 312 | loss_ = "MOL" 313 | ) 314 | 315 | wavenet = wavenet.to(device) 316 | 317 | start = time.time() 318 | out = wavenet.waveform_generation(global_, local, 
device=device, decoding="greedy") 319 | print(f"TIME TAKEN = {time.time() - start}") 320 | print(len(out)) -------------------------------------------------------------------------------- /Transformer_tts_model/TransformerTTSModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SelfAttention(nn.Module): 5 | def __init__(self, embed_dims, heads): 6 | super(SelfAttention, self).__init__() 7 | self.heads = heads 8 | self.embed_dims = embed_dims 9 | self.depth = embed_dims//heads 10 | 11 | self.query = nn.Linear(self.depth, self.depth) 12 | self.key = nn.Linear(self.depth, self.depth) 13 | self.value = nn.Linear(self.depth, self.depth) 14 | 15 | self.fc_out = nn.Linear(self.depth*self.heads*2, self.embed_dims) 16 | 17 | def forward(self, query, key, value, mask): 18 | batch, q_len, k_len, v_len = query.shape[0], query.shape[1], key.shape[1], value.shape[1] 19 | 20 | query = query.reshape(batch, q_len, self.heads, self.depth) 21 | key = key.reshape(batch, k_len, self.heads, self.depth) 22 | value = value.reshape(batch, v_len, self.heads, self.depth) 23 | 24 | query = self.query(query) 25 | key = self.key(key) 26 | value = self.value(value) 27 | 28 | energy = torch.einsum('bqhd, bkhd -> bhqk', [query, key]) 29 | 30 | if mask is not None: 31 | energy = energy.masked_fill(mask==0, float("-1e20")) 32 | 33 | energy = torch.softmax(energy/(self.depth**0.5), dim=-1) 34 | 35 | out = torch.einsum('bhqv, bvhd -> bqhd', [energy, value]) 36 | 37 | out = out.reshape(batch, q_len, self.heads*self.depth) 38 | query = query.reshape(batch, q_len, self.heads*self.depth) 39 | 40 | out = torch.cat([query, out], dim=-1) 41 | out = self.fc_out(out) 42 | 43 | return out 44 | 45 | 46 | class TransformerBlock(nn.Module): 47 | def __init__(self, hidden_dims, heads, dropout, forward_expansion): 48 | super(TransformerBlock, self).__init__() 49 | self.hidden_dims = hidden_dims 50 | self.heads = heads 51 | self.multihead_attention = SelfAttention(hidden_dims, heads) 52 | self.feed_forward = nn.Sequential( 53 | nn.Conv1d(hidden_dims, hidden_dims*forward_expansion, kernel_size=1), 54 | nn.GELU(), 55 | nn.Conv1d(hidden_dims*forward_expansion, hidden_dims, kernel_size=1) 56 | ) 57 | self.dropout = nn.Dropout(dropout) 58 | self.layer_norm1 = nn.LayerNorm(hidden_dims) 59 | self.layer_norm2 = nn.LayerNorm(hidden_dims) 60 | 61 | def forward(self, query, key, value, mask): 62 | attention_out = self.multihead_attention(query, key, value, mask) 63 | add = self.dropout(self.layer_norm1(attention_out + query)) 64 | ffn_in = add.transpose(1, 2) 65 | ffn_out = self.feed_forward(ffn_in) 66 | ffn_out = ffn_out.transpose(1, 2) 67 | out = self.dropout(self.layer_norm2(ffn_out + add)) 68 | return out 69 | 70 | 71 | class EncoderPreNet(nn.Module): 72 | def __init__(self, embed_dims, hidden_dims, dropout): 73 | super(EncoderPreNet, self).__init__() 74 | self.conv1 = nn.Conv1d( 75 | embed_dims, 76 | hidden_dims, 77 | kernel_size=5, 78 | padding=2 79 | ) 80 | 81 | self.conv2 = nn.Conv1d( 82 | hidden_dims, 83 | hidden_dims, 84 | kernel_size=5, 85 | padding=2 86 | ) 87 | 88 | self.conv3 = nn.Conv1d( 89 | hidden_dims, 90 | hidden_dims, 91 | kernel_size=5, 92 | padding=2 93 | ) 94 | 95 | self.batch_norm1 = nn.BatchNorm1d(hidden_dims) 96 | self.batch_norm2 = nn.BatchNorm1d(hidden_dims) 97 | self.batch_norm3 = nn.BatchNorm1d(hidden_dims) 98 | 99 | self.dropout1 = nn.Dropout(dropout) 100 | self.dropout2 = nn.Dropout(dropout) 101 | self.dropout3 = nn.Dropout(dropout)
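        # Encoder prenet: three Conv1d(kernel_size=5) -> BatchNorm -> ReLU -> Dropout blocks
        # over the character embeddings, followed by a final linear projection (defined next),
        # similar to the encoder prenet used in Tacotron 2 / Transformer-TTS.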
102 | 103 | self.fc_out = nn.Linear(hidden_dims, hidden_dims) 104 | 105 | def forward(self, x): 106 | x = x.transpose(1, 2) 107 | x = self.dropout1(torch.relu(self.batch_norm1(self.conv1(x)))) 108 | x = self.dropout2(torch.relu(self.batch_norm2(self.conv2(x)))) 109 | x = self.dropout3(torch.relu(self.batch_norm3(self.conv3(x)))) 110 | x = x.transpose(1, 2) 111 | x = self.fc_out(x) 112 | return x 113 | 114 | class Encoder(nn.Module): 115 | def __init__( 116 | self, 117 | vocab_size, 118 | embed_dims, 119 | hidden_dims, 120 | max_len, 121 | heads, 122 | forward_expansion, 123 | num_layers, 124 | dropout 125 | ): 126 | super(Encoder, self).__init__() 127 | self.token_embed = nn.Embedding(vocab_size, embed_dims) 128 | self.positional_embed = nn.Parameter(torch.zeros(1, max_len, hidden_dims)) 129 | self.prenet = EncoderPreNet(embed_dims, hidden_dims, dropout) 130 | self.dropout = nn.Dropout(dropout) 131 | self.attention_layers = nn.Sequential( 132 | *[ 133 | TransformerBlock( 134 | hidden_dims, 135 | heads, 136 | dropout, 137 | forward_expansion 138 | ) 139 | for _ in range(num_layers) 140 | ] 141 | ) 142 | 143 | def forward(self, x, mask=None): 144 | seq_len = x.shape[1] 145 | token_embed = self.token_embed(x) 146 | positional_embed = self.positional_embed[:, :seq_len, :] 147 | x = self.prenet(token_embed) 148 | x += positional_embed 149 | x = self.dropout(x) 150 | for layer in self.attention_layers: 151 | x = layer(x, x, x, mask) 152 | return x 153 | 154 | 155 | class DecoderPreNet(nn.Module): 156 | def __init__(self, mel_dims, hidden_dims, dropout): 157 | super(DecoderPreNet, self).__init__() 158 | self.fc_out = nn.Sequential( 159 | nn.Linear(mel_dims, hidden_dims), 160 | nn.ReLU(), 161 | nn.Dropout(dropout), 162 | nn.Linear(hidden_dims, hidden_dims), 163 | nn.ReLU(), 164 | nn.Dropout(dropout) 165 | ) 166 | 167 | def forward(self, x): 168 | 169 | return self.fc_out(x) 170 | 171 | 172 | class PostNet(nn.Module): 173 | def __init__(self, mel_dims, hidden_dims, dropout): 174 | #causal padding -> padding = (kernel_size - 1) x dilation 175 | #kernel_size = 5 -> padding = 4 176 | #Exclude the last padding_size output as we want only left padded output 177 | super(PostNet, self).__init__() 178 | self.conv1 = nn.Conv1d(mel_dims, hidden_dims, kernel_size=5, padding=4) 179 | self.batch_norm1 = nn.BatchNorm1d(hidden_dims) 180 | self.dropout1 = nn.Dropout(dropout) 181 | self.conv_list = nn.Sequential( 182 | *[ 183 | nn.Conv1d(hidden_dims, hidden_dims, kernel_size=5, padding=4) 184 | for _ in range(3) 185 | ] 186 | ) 187 | 188 | self.batch_norm_list = nn.Sequential( 189 | *[ 190 | nn.BatchNorm1d(hidden_dims) 191 | for _ in range(3) 192 | ] 193 | ) 194 | 195 | self.dropout_list = nn.Sequential( 196 | *[ 197 | nn.Dropout(dropout) 198 | for _ in range(3) 199 | ] 200 | ) 201 | 202 | self.conv5 = nn.Conv1d(hidden_dims, mel_dims, kernel_size=5, padding=4) 203 | 204 | def forward(self, x): 205 | x = x.transpose(1, 2) 206 | x = self.dropout1(torch.tanh(self.batch_norm1(self.conv1(x)[:, :, :-4]))) 207 | for dropout, batchnorm, conv in zip(self.dropout_list, self.batch_norm_list, self.conv_list): 208 | x = dropout(torch.tanh(batchnorm(conv(x)[:, :, :-4]))) 209 | out = self.conv5(x)[:, :, :-4] 210 | out = out.transpose(1, 2) 211 | return out 212 | 213 | 214 | class DecoderBlock(nn.Module): 215 | def __init__( 216 | self, 217 | embed_dims, 218 | heads, 219 | forward_expansion, 220 | dropout 221 | ): 222 | super(DecoderBlock, self).__init__() 223 | self.causal_masked_attention = SelfAttention(embed_dims, heads) 224 | 
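        # Decoder block: causal (masked) self-attention over the previous mel frames,
        # followed by a TransformerBlock that applies encoder-decoder attention (queries
        # from the decoder, keys/values from the encoder output) and the position-wise
        # feed-forward sublayer.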
self.attention_layer = TransformerBlock( 225 | embed_dims, 226 | heads, 227 | dropout, 228 | forward_expansion 229 | ) 230 | self.dropout = nn.Dropout(dropout) 231 | self.layer_norm = nn.LayerNorm(embed_dims) 232 | 233 | def forward(self, query, key, value, src_mask, causal_mask): 234 | causal_masked_attention = self.causal_masked_attention(query, query, query, causal_mask) 235 | query = self.dropout(self.layer_norm(causal_masked_attention + query)) 236 | out = self.attention_layer(query, key, value, src_mask) 237 | return out 238 | 239 | 240 | class Decoder(nn.Module): 241 | def __init__( 242 | self, 243 | mel_dims, 244 | hidden_dims, 245 | heads, 246 | max_len, 247 | num_layers, 248 | forward_expansion, 249 | dropout 250 | ): 251 | super(Decoder, self).__init__() 252 | self.positional_embed = nn.Parameter(torch.zeros(1, max_len, hidden_dims)) 253 | self.prenet = DecoderPreNet(mel_dims, hidden_dims, dropout) 254 | self.attention_layers = nn.Sequential( 255 | *[ 256 | DecoderBlock( 257 | hidden_dims, 258 | heads, 259 | forward_expansion, 260 | dropout 261 | ) 262 | for _ in range(num_layers) 263 | ] 264 | ) 265 | self.mel_linear = nn.Linear(hidden_dims, mel_dims) 266 | self.stop_linear = nn.Linear(hidden_dims, 1) 267 | self.postnet = PostNet(mel_dims, hidden_dims, dropout) 268 | self.dropout = nn.Dropout(dropout) 269 | 270 | def forward(self, mel, encoder_output, src_mask, casual_mask): 271 | seq_len = mel.shape[1] 272 | prenet_out = self.prenet(mel) 273 | x = self.dropout(prenet_out + self.positional_embed[:, :seq_len, :]) 274 | 275 | for layer in self.attention_layers: 276 | x = layer(x, encoder_output, encoder_output, src_mask, casual_mask) 277 | 278 | stop_linear = self.stop_linear(x) 279 | 280 | mel_linear = self.mel_linear(x) 281 | 282 | postnet = self.postnet(mel_linear) 283 | 284 | out = postnet + mel_linear 285 | 286 | return out, mel_linear, stop_linear 287 | 288 | 289 | class TransformerTTS(nn.Module): 290 | def __init__( 291 | self, 292 | vocab_size, 293 | embed_dims, 294 | hidden_dims, 295 | heads, 296 | forward_expansion, 297 | num_layers, 298 | dropout, 299 | mel_dims, 300 | max_len, 301 | pad_idx 302 | ): 303 | super(TransformerTTS, self).__init__() 304 | self.encoder = Encoder( 305 | vocab_size, 306 | embed_dims, 307 | hidden_dims, 308 | max_len, 309 | heads, 310 | forward_expansion, 311 | num_layers, 312 | dropout 313 | ) 314 | 315 | self.decoder = Decoder( 316 | mel_dims, 317 | hidden_dims, 318 | heads, 319 | max_len, 320 | num_layers, 321 | forward_expansion, 322 | dropout 323 | ) 324 | 325 | self.pad_idx = pad_idx 326 | 327 | def target_mask(self, mel, mel_mask): 328 | seq_len = mel.shape[1] 329 | pad_mask = (mel_mask != self.pad_idx).unsqueeze(1).unsqueeze(3) 330 | causal_mask = torch.tril(torch.ones((1, seq_len, seq_len))).unsqueeze(1) 331 | return pad_mask, causal_mask 332 | 333 | def input_mask(self, x): 334 | mask = (x != self.pad_idx).unsqueeze(1).unsqueeze(2) 335 | return mask 336 | 337 | def forward(self, text_idx, mel, mel_mask): 338 | input_pad_mask = self.input_mask(text_idx) 339 | target_pad_mask, causal_mask = self.target_mask(mel, mel_mask) 340 | encoder_out = self.encoder(text_idx, input_pad_mask) 341 | mel_postout, mel_linear, stop_linear = self.decoder(mel, encoder_out, target_pad_mask, causal_mask) 342 | return mel_postout, mel_linear, stop_linear 343 | 344 | 345 | if __name__ == "__main__": 346 | a = torch.randint(0, 30, (4, 60)) 347 | mel = torch.randn(4, 128, 80) 348 | mask = torch.ones((4, 128)) 349 | model = TransformerTTS( 350 | vocab_size=30, 
351 | embed_dims=512, 352 | hidden_dims=256, 353 | heads=4, 354 | forward_expansion=4, 355 | num_layers=6, 356 | dropout=0.1, 357 | mel_dims=80, 358 | max_len=512, 359 | pad_idx=0 360 | ) 361 | x, y, z = model(a, mel, mask) 362 | print(x.shape, y.shape, z.shape) --------------------------------------------------------------------------------