├── predict.py
├── .gitignore
├── input
│   ├── LJSpeech-1.1
│   │   ├── .~lock.metadata.csv#
│   │   └── README
│   ├── char_to_idx.pickle
│   └── idx_to_char.pickle
├── Transformer_tts_model
│   ├── model.png
│   └── TransformerTTSModel.py
├── config.py
├── sampling.py
├── LICENSE
├── README.md
├── utils.py
├── engine.py
├── dataloader.py
├── train.py
└── Wavenet
    └── WaveNet.py
/predict.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.csv
2 | *.wav
3 | *.bin
4 | __pycache__
5 |
6 | LICENSE
7 | README.md
8 |
9 |
--------------------------------------------------------------------------------
/input/LJSpeech-1.1/.~lock.metadata.csv#:
--------------------------------------------------------------------------------
1 | ,shivam,shivam,02.12.2020 18:34,file:///home/shivam/.config/libreoffice/4;
--------------------------------------------------------------------------------
/input/char_to_idx.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShivamRajSharma/Transformer-Text-To-Speech/HEAD/input/char_to_idx.pickle
--------------------------------------------------------------------------------
/input/idx_to_char.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShivamRajSharma/Transformer-Text-To-Speech/HEAD/input/idx_to_char.pickle
--------------------------------------------------------------------------------
/Transformer_tts_model/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShivamRajSharma/Transformer-Text-To-Speech/HEAD/Transformer_tts_model/model.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | sample_rate = 16000  # Provided in the paper
2 | n_mels = 80
3 | frame_rate = 80  # Provided in the paper
4 | frame_length = 0.05  # Window length in seconds
5 | hop_length = int(sample_rate/frame_rate)  # 200 samples
6 | win_length = int(sample_rate*frame_length)  # 800 samples
7 |
8 | scaling_factor = 4
9 | min_db_level = -100
10 |
11 | bce_weights = 7
12 |
13 | embed_dims = 512
14 | hidden_dims = 256
15 | heads = 4
16 | forward_expansion = 4
17 | num_layers = 4
18 | dropout = 0.15
19 | max_len = 1024
20 | pad_idx = 0
21 |
22 | Metadata = 'input/LJSpeech-1.1/metadata.csv'
23 | Audio_file_path = 'input/LJSpeech-1.1/wavs/'
24 |
25 | Model_Path = 'model/model.bin'
26 | checkpoint = 'model/checkpoint.bin'
27 |
28 | Batch_Size = 2
29 | Epochs = 40
30 | LR = 3e-4
31 | warmup_steps = 0.2  # Fraction of total training steps used for LR warmup
--------------------------------------------------------------------------------
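For reference, the frame parameters in config.py translate into concrete STFT settings. The small check below is illustrative only (not part of the repo) and simply re-derives the hop/window sizes and the number of mel frames produced per second of 16 kHz audio:

```python
# Illustrative check of the derived STFT parameters in config.py.
sample_rate = 16000
frame_rate = 80
frame_length = 0.05

hop_length = int(sample_rate / frame_rate)       # 200 samples = 12.5 ms per hop
win_length = int(sample_rate * frame_length)     # 800 samples = 50 ms window

frames_per_second = sample_rate // hop_length    # 80 mel frames per second of audio
print(hop_length, win_length, frames_per_second)  # 200 800 80
```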
/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def greedy_decoding():
4 | pass
5 |
6 |
7 | def sampling_decoding():
8 | pass
9 |
10 | def mixture_of_log_sampling(y, log_scale_min=-7.0, clamp_log_scale=False):
11 | nr_mix = y.shape[1] // 3
12 |
13 | y = y.transpose(1, 2)
14 | logit_probs = y[:, :, :nr_mix]
15 |
16 | temp = logit_probs.data.new(logit_probs.shape).uniform_(1e-5, 1.0 - 1e-5)
17 | temp = logit_probs.data - torch.log(- torch.log(temp))
18 | _, argmax = temp.max(dim=-1)
19 |
20 | one_hot = to_one_hot(argmax, nr_mix)
21 |
22 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
23 | log_scales = torch.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1)
24 | if clamp_log_scale:
25 | log_scales = torch.clamp(log_scales, min=log_scale_min)
26 |
27 | u = means.data.new(means.shape).uniform_(1e-5, 1.0 - 1e-5)
28 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
29 |
30 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
31 |
32 | return x
33 |
--------------------------------------------------------------------------------
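Note that `mixture_of_log_sampling` above calls a `to_one_hot` helper that is neither defined nor imported in sampling.py. A minimal sketch of what such a helper could look like (an assumption about the intended behaviour, mirroring the one-hot selection of the sampled mixture component used in Wavenet/WaveNet.py) is:

```python
import torch

def to_one_hot(indices, depth):
    # indices: LongTensor of shape (batch, time) holding the sampled mixture
    # component per step; returns a float one-hot tensor of shape (batch, time, depth).
    one_hot = torch.zeros(indices.shape + (depth,), device=indices.device)
    one_hot.scatter_(-1, indices.unsqueeze(-1), 1.0)
    return one_hot
```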
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Shivam Raj
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer Text To Speech
2 |
3 | A text-to-speech (TTS) system converts normal language text into speech; other systems render symbolic linguistic representations such as phonetic transcriptions into speech. With recent developments in deep learning, it is possible to convert text into a natural, human-sounding voice. The text is fed into an Encoder-Decoder neural network that outputs a Mel-Spectrogram. The Mel-Spectrogram can then be turned into audio with the ["Griffin-Lim Algorithm"](https://paperswithcode.com/method/griffin-lim-algorithm), but because Griffin-Lim struggles to reach human-like speech quality, a second neural network, [WaveNet](https://deepmind.com/blog/article/wavenet-generative-model-raw-audio), is fed the Mel-Spectrogram instead and produces audio that is hard to distinguish from a real human voice.
4 |
5 | ## Model Architecture
6 |
7 | ### 1. Transformer TTS
8 |
9 |
10 | * An Encoder-Decoder transformer architecture that allows parallel training, instead of the sequential Seq2Seq training used in [Tacotron-2](https://github.com/NVIDIA/tacotron2).
11 | * Text is sent as input and the model outputs a Mel-Spectrogram.
12 | * Multi-headed attention is employed, with causal masking only on the decoder side.
13 | * Paper : [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895).
14 |
15 | ### 2. Wavenet
16 |
17 |
18 | * The output of the Transformer TTS model (a Mel-Spectrogram) is fed into the WaveNet to generate audio samples.
19 | * Unlike Seq2Seq models, WaveNet also allows parallel training.
20 | * Paper : [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499).
21 |
22 |
23 |
24 |
25 | ## Dataset Information
26 | The model was trained on a subset of the LJ Speech dataset. Preprocessing was carried out before training the model.
27 | Dataset : https://keithito.com/LJ-Speech-Dataset/
28 |
--------------------------------------------------------------------------------
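The README above describes turning the predicted Mel-Spectrogram into audio with the Griffin-Lim algorithm. A minimal sketch of that step using librosa's built-in mel inversion (which runs Griffin-Lim internally), assuming a de-normalized mel spectrogram shaped (n_mels, frames) and the hop/window sizes from config.py:

```python
import librosa

def mel_to_wav(mel_db, sample_rate=16000, hop_length=200, win_length=800):
    # mel_db: mel spectrogram in dB, shaped (n_mels, frames) -- i.e. the transpose
    # of what dataloader.py stores, with any [-4, 4] normalization already undone.
    mel_power = librosa.db_to_power(mel_db)
    return librosa.feature.inverse.mel_to_audio(
        mel_power,
        sr=sample_rate,
        hop_length=hop_length,
        win_length=win_length,
        n_iter=60,  # Griffin-Lim iterations
    )
```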
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import config
3 | def mulaw_encode(audio_samples):
4 | # Outputs values ranging from 0-255
5 | audio_converted_samples = []
6 | mu = 255
7 | for audio_sample in audio_samples:
8 | audio_sample = np.sign(audio_sample) * (np.log1p(mu * np.abs(audio_sample)) / np.log1p(mu))
9 | audio_sample = ((audio_sample + 1) / 2 * mu + 0.5)
10 | audio_converted_samples.append(audio_sample.astype(int))
11 | audio_converted_samples = np.array(audio_converted_samples)
12 | return audio_converted_samples
13 |
14 |
15 | def mulaw_decode(audio_samples):  # Inverse of mulaw_encode: maps 0-255 back to samples in [-1, 1]
16 | audio_converted_samples = []
17 | mu = 255
18 | for audio_sample in audio_samples:
19 | audio_sample = audio_sample.astype(np.float32)
20 | audio_sample = 2*(audio_sample/mu) - 1
21 | audio_sample = np.sign(audio_sample)*(1.0 / mu)*((1.0 + mu)**(np.abs(audio_sample)) - 1.0)
22 | audio_converted_samples.append(audio_sample)
23 | audio_converted_samples = np.array(audio_converted_samples)
24 | return audio_converted_samples
25 |
26 |
27 | def normalize_(mel):
28 |     # Normalize data to [-scaling_factor, scaling_factor] (i.e. [-4, 4]),
29 |     # which helps the model converge faster
30 | mel = np.clip(
31 | (config.scaling_factor)*((mel - config.min_db_level)/-config.min_db_level) - config.scaling_factor,
32 | -config.scaling_factor, config.scaling_factor
33 | )
34 | return mel
35 |
36 |
37 |
38 |
39 |
40 | if __name__ == "__main__":
41 | import librosa
42 | import numpy as np
43 | a = np.random.uniform(-1, 1, (2, 10))
44 | a = mulaw_encode(a)
45 | print(a)
46 | # from sklearn.preprocessing import StandardScaler
47 | # audio, sr = librosa.load("../Downloads/WhatsApp Ptt 2021-07-05 at 3.18.15 PM.ogg", sr=16000)
48 | # audio_samples = audio[None, :]
49 | # print(min(audio_samples[0]), max(audio_samples[0]))
50 | # audio_samples = mulaw_encode(audio_samples)
51 | # print(min(audio_samples[0]), max(audio_samples[0]))
52 | # audio_samples = mulaw_decode(audio_samples)
53 | # print(min(audio_samples[0]), max(audio_samples[0]))
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | import config
2 |
3 | import torch
4 | import torch.nn as nn
5 | from tqdm import tqdm
6 |
7 | def loss_fn(
8 | target_mel_spect,
9 | target_end_logits,
10 | pred_mel_spect_post,
11 | pred_mel_spect,
12 | pred_end_logits
13 | ):
14 | mel_loss = nn.L1Loss()(pred_mel_spect, target_mel_spect) + nn.L1Loss()(pred_mel_spect_post, target_mel_spect)
15 | bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(config.bce_weights))(pred_end_logits.squeeze(-1), target_end_logits)
16 | return mel_loss + bce_loss
17 |
18 |
19 | def train_fn(model, dataloader, optimizer, scheduler, device):
20 | running_loss = 0
21 | model.train()
22 | for num, data in tqdm(enumerate(dataloader), total=len(dataloader)):
23 | for p in model.parameters():
24 |             p.grad = None  # Cheaper alternative to optimizer.zero_grad()
25 |
26 | end_logits = data['end_logits'].to(device)
27 | mel_spect = data['mel_spect'].to(device)
28 | text_idx = data['text_idx'].to(device)
29 | text = data['original_text']
30 | mel_mask = data['mel_mask'].to(device)
31 | mel_spect_post_pred, mel_spect_pred, end_logits_pred = model(text_idx, mel_spect[:, :-1], mel_mask[:, :-1])
32 | loss = loss_fn(
33 | mel_spect[:, 1:],
34 |             end_logits[:, 1:],  # Align stop targets with the shifted mel targets
35 | mel_spect_post_pred,
36 | mel_spect_pred,
37 | end_logits_pred
38 | )
39 |
40 | running_loss += loss.item()
41 | loss.backward()
42 | optimizer.step()
43 | scheduler.step()
44 |
45 | epoch_loss = running_loss/len(dataloader)
46 | return epoch_loss
47 |
48 |
49 | def eval_fn(model, dataloader, device):
50 |     running_loss = 0
51 |     model.eval()
52 |     with torch.no_grad():
53 |         for num, data in tqdm(enumerate(dataloader), total=len(dataloader)):
54 |             end_logits = data['end_logits'].to(device)
55 |             mel_spect = data['mel_spect'].to(device)
56 |             text_idx = data['text_idx'].to(device)
57 |             text = data['original_text']
58 |             mel_mask = data['mel_mask'].to(device)
59 |             mel_spect_post_pred, mel_spect_pred, end_logits_pred = model(text_idx, mel_spect[:, :-1], mel_mask[:, :-1])
60 |             loss = loss_fn(
61 |                 mel_spect[:, 1:],
62 |                 end_logits[:, 1:],  # Align stop targets with the shifted mel targets
63 |                 mel_spect_post_pred,
64 |                 mel_spect_pred,
65 |                 end_logits_pred
66 |             )
67 | 
68 |             running_loss += loss.item()
69 |     epoch_loss = running_loss/len(dataloader)
70 |     return epoch_loss
71 |
--------------------------------------------------------------------------------
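train_fn and eval_fn above use a one-step teacher-forcing shift: the decoder sees frames 0..T-1 and is trained against frames 1..T, with the stop targets shifted the same way. A small illustrative snippet (not part of the repo) of that alignment:

```python
import torch

T, n_mels = 5, 3
mel = torch.randn(1, T + 1, n_mels)             # zero frame prepended by the dataloader
end = torch.tensor([[0., 0., 0., 0., 0., 1.]])  # stop flag on the last real frame

decoder_input = mel[:, :-1]   # frames 0..T-1, shape (1, T, n_mels)
mel_target    = mel[:, 1:]    # frames 1..T,   shape (1, T, n_mels)
stop_target   = end[:, 1:]    # shifted the same way, so the final "1" is actually predicted
```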
/dataloader.py:
--------------------------------------------------------------------------------
1 | import config
2 |
3 | import os
4 | import pickle
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | import librosa
9 | from librosa.feature import melspectrogram
10 | from torch.nn.utils.rnn import pad_sequence
11 |
12 |
13 | class TransformerLoader(torch.utils.data.Dataset):
14 | def __init__(self, files_name, text_data, mel_transforms=None, normalize=False):
15 | self.files_name = files_name
16 | self.text_data = text_data
17 | self.transforms = mel_transforms
18 | self.normalize = normalize
19 | self.char_to_idx = pickle.load(open('input/char_to_idx.pickle', 'rb'))
20 |
21 | def __len__(self):
22 | return len(self.text_data)
23 |
24 |
25 | def data_preprocess_(self, text):
26 | char_idx = [self.char_to_idx[char] for char in text if char in self.char_to_idx]
27 | return char_idx
28 |
29 |
30 | def normalize_(self, mel):
31 |         # Normalize data to [-scaling_factor, scaling_factor] (i.e. [-4, 4]),
32 |         # which helps the model converge faster
33 | mel = np.clip(
34 | (config.scaling_factor)*((mel - config.min_db_level)/-config.min_db_level) - config.scaling_factor,
35 | -config.scaling_factor, config.scaling_factor
36 | )
37 | return mel
38 |
39 |
40 | def __getitem__(self, idx):
41 | file_name = self.files_name[idx]
42 | text = self.text_data[idx]
43 | text_idx = self.data_preprocess_(text)
44 | audio_path = os.path.join(config.Audio_file_path + file_name + '.wav')
45 |
46 | audio_file, _ = librosa.load(
47 | audio_path,
48 | sr=config.sample_rate
49 | )
50 |
51 | audio_file, _ = librosa.effects.trim(audio_file)
52 |
53 | mel_spect = melspectrogram(
54 |             y=audio_file,
55 | sr=config.sample_rate,
56 | n_mels=config.n_mels,
57 | hop_length=config.hop_length,
58 | win_length=config.win_length
59 | )
60 |
61 | pre_mel_spect = np.zeros((1, config.n_mels))
62 | mel_spect = librosa.power_to_db(mel_spect).T
63 | mel_spect = np.concatenate((pre_mel_spect, mel_spect), axis=0)
64 |
65 | if self.normalize:
66 | mel_spect = self.normalize_(mel_spect)
67 |
68 | mel_spect = torch.tensor(mel_spect, dtype=torch.float)
69 | mel_mask = [1]*mel_spect.shape[0]
70 |
71 | end_logits = [0]*(len(mel_spect) - 1)
72 | end_logits += [1]
73 |
74 | if self.transforms:
75 | for transform in self.transforms:
76 | if np.random.randint(0, 11) == 10:
77 | mel_spect = transform(mel_spect).squeeze(0)
78 |
79 | return {
80 | 'original_text' : text,
81 | 'mel_spect' : mel_spect,
82 | 'mel_mask' : torch.tensor(mel_mask, dtype=torch.long),
83 | 'text_idx' : torch.tensor(text_idx, dtype=torch.long),
84 | 'end_logits' : torch.tensor(end_logits, dtype=torch.float),
85 | }
86 |
87 |
88 | class MyCollate:
89 | def __init__(self, pad_idx, spect_pad):
90 | self.pad_idx = pad_idx
91 |         self.spect_pad = spect_pad
92 |
93 | def __call__(self, batch):
94 | text_idx = [item['text_idx'] for item in batch]
95 | padded_text_idx = pad_sequence(
96 | text_idx,
97 | batch_first=True,
98 | padding_value=self.pad_idx
99 | )
100 | end_logits = [item['end_logits'] for item in batch]
101 | padded_end_logits = pad_sequence(
102 | end_logits,
103 | batch_first=True,
104 | padding_value=0
105 | )
106 | original_text = [item['original_text'] for item in batch]
107 | mel_mask = [item['mel_mask'] for item in batch]
108 | padded_mel_mask = pad_sequence(
109 | mel_mask,
110 | batch_first=True,
111 | padding_value=0
112 | )
113 | mel_spects = [item['mel_spect'] for item in batch]
114 |
115 | batch_size, max_len = padded_mel_mask.shape
116 |
117 |         padded_mel_spect = torch.full((batch_size, max_len, mel_spects[0].shape[-1]), float(self.spect_pad))
118 |
119 | for num,mel_spect in enumerate(mel_spects):
120 | padded_mel_spect[num, :mel_spect.shape[0]] = mel_spect
121 |
122 | return {
123 | 'original_text' : original_text,
124 | 'mel_spect' : padded_mel_spect,
125 | 'mel_mask' : padded_mel_mask,
126 | 'text_idx' : padded_text_idx,
127 | 'end_logits' : padded_end_logits
128 | }
129 |
130 |
131 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings('ignore')
3 |
4 | import config
5 | import dataloader
6 | import engine
7 |
8 | import sys
9 |
10 | import os
11 | import gc
12 | import transformers
13 | import torch
14 | import torch.nn as nn
15 | import torchaudio
16 | from tqdm import tqdm
17 | from sklearn.model_selection import train_test_split
18 | from Transformer_tts_model.TransformerTTSModel import TransformerTTS
19 |
20 | def preprocess(data):
21 | text_data, audio_files_name = [], []
22 | for d in data:
23 | audio_name, text = d.split('|')[:2]
24 | text_data.append(text.lower())
25 | audio_files_name.append(audio_name)
26 | return text_data, audio_files_name
27 |
28 | def train():
29 |     data = open(config.Metadata).read().strip().split('\n')[:10]  # Only the first 10 metadata rows are used here; drop the slice to train on the full dataset
30 | text_data, audio_file_name = preprocess(data)
31 | del data
32 | gc.collect()
33 | transforms = [
34 | torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
35 | torchaudio.transforms.TimeMasking(time_mask_param=35)
36 | ]
37 |
38 | train_text_data, val_text_data, train_audio_file_name, val_audio_file_name = train_test_split(
39 | text_data,
40 | audio_file_name,
41 | test_size=0.2
42 | )
43 |
44 | train_data = dataloader.TransformerLoader(
45 | files_name=train_audio_file_name,
46 | text_data=train_text_data,
47 | mel_transforms=transforms,
48 | normalize=True
49 | )
50 |
51 | val_data = dataloader.TransformerLoader(
52 | files_name=val_audio_file_name,
53 | text_data=val_text_data,
54 | normalize=True
55 | )
56 |
57 | pad_idx = 0
58 |
59 |
60 | train_loader = torch.utils.data.DataLoader(
61 | train_data,
62 | batch_size=config.Batch_Size,
63 | num_workers=1,
64 | pin_memory=True,
65 | collate_fn=dataloader.MyCollate(
66 | pad_idx=pad_idx,
67 | spect_pad=-config.scaling_factor
68 | )
69 | )
70 |
71 | val_loader = torch.utils.data.DataLoader(
72 | val_data,
73 | batch_size=config.Batch_Size,
74 | num_workers=1,
75 | pin_memory=True,
76 | collate_fn=dataloader.MyCollate(
77 | pad_idx=pad_idx,
78 | spect_pad=-config.scaling_factor
79 | )
80 | )
81 |
82 | vocab_size = len(train_data.char_to_idx) + 1
83 |
84 | model = TransformerTTS(
85 | vocab_size=vocab_size,
86 | embed_dims=config.embed_dims,
87 | hidden_dims=config.hidden_dims,
88 | heads=config.heads,
89 | forward_expansion=config.forward_expansion,
90 | num_layers=config.num_layers,
91 | dropout=config.dropout,
92 | mel_dims=config.n_mels,
93 | max_len=config.max_len,
94 | pad_idx=config.pad_idx
95 | )
96 |     # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
97 | # torch.backends.cudnn.benchmark = True
98 | device = torch.device('cpu')
99 | model = model.to(device)
100 |
101 | optimizer = transformers.AdamW(model.parameters(), lr=config.LR)
102 |
103 | num_training_steps = config.Epochs*len(train_data)//config.Batch_Size
104 |
105 | scheduler = transformers.get_cosine_schedule_with_warmup(
106 | optimizer,
107 |         num_warmup_steps=int(config.warmup_steps*num_training_steps),
108 | num_training_steps=num_training_steps
109 | )
110 |
111 | epoch_start = 0
112 |
113 | if os.path.exists(config.checkpoint):
114 | checkpoint = torch.load(config.checkpoint)
115 | model.load_state_dict(checkpoint['model_state_dict'])
116 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
117 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
118 | epoch_start = checkpoint['epoch']
119 | print(f'---------[INFO] Restarting Training from Epoch {epoch_start} -----------\n')
120 |
121 |
122 |
123 | best_loss = 1e10
124 | best_model = model.state_dict()
125 | print('--------- [INFO] STARTING TRAINING ---------\n')
126 | for epoch in range(epoch_start, config.Epochs):
127 | train_loss = engine.train_fn(model, train_loader, optimizer, scheduler, device)
128 | val_loss = engine.eval_fn(model, val_loader, device)
129 |         print(f'EPOCH -> {epoch+1}/{config.Epochs} | TRAIN LOSS = {train_loss} | VAL LOSS = {val_loss} | LR = {scheduler.get_last_lr()[0]} \n')
130 |
131 | torch.save({
132 | 'epoch' : epoch,
133 | 'model_state_dict' : model.state_dict(),
134 | 'optimizer_state_dict' : optimizer.state_dict(),
135 | 'scheduler_state_dict' : scheduler.state_dict(),
136 | 'loss': val_loss,
137 | }, config.checkpoint)
138 |
139 | if best_loss > val_loss:
140 | best_loss = val_loss
141 | best_model = model.state_dict()
142 | torch.save(best_model, config.Model_Path)
143 |
144 |
145 | if __name__ == "__main__":
146 | train()
--------------------------------------------------------------------------------
/input/LJSpeech-1.1/README:
--------------------------------------------------------------------------------
1 | -----------------------------------------------------------------------------
2 | The LJ Speech Dataset
3 |
4 | Version 1.0
5 | July 5, 2017
6 | https://keithito.com/LJ-Speech-Dataset
7 | -----------------------------------------------------------------------------
8 |
9 |
10 | OVERVIEW
11 |
12 | This is a public domain speech dataset consisting of 13,100 short audio clips
13 | of a single speaker reading passages from 7 non-fiction books. A transcription
14 | is provided for each clip. Clips vary in length from 1 to 10 seconds and have
15 | a total length of approximately 24 hours.
16 |
17 | The texts were published between 1884 and 1964, and are in the public domain.
18 | The audio was recorded in 2016-17 by the LibriVox project and is also in the
19 | public domain.
20 |
21 |
22 |
23 | FILE FORMAT
24 |
25 | Metadata is provided in metadata.csv. This file consists of one record per
26 | line, delimited by the pipe character (0x7c). The fields are:
27 |
28 | 1. ID: this is the name of the corresponding .wav file
29 | 2. Transcription: words spoken by the reader (UTF-8)
30 | 3. Normalized Transcription: transcription with numbers, ordinals, and
31 | monetary units expanded into full words (UTF-8).
32 |
33 | Each audio file is a single-channel 16-bit PCM WAV with a sample rate of
34 | 22050 Hz.
35 |
36 |
37 |
38 | STATISTICS
39 |
40 | Total Clips 13,100
41 | Total Words 225,715
42 | Total Characters 1,308,674
43 | Total Duration 23:55:17
44 | Mean Clip Duration 6.57 sec
45 | Min Clip Duration 1.11 sec
46 | Max Clip Duration 10.10 sec
47 | Mean Words per Clip 17.23
48 | Distinct Words 13,821
49 |
50 |
51 |
52 | MISCELLANEOUS
53 |
54 | The audio clips range in length from approximately 1 second to 10 seconds.
55 | They were segmented automatically based on silences in the recording. Clip
56 | boundaries generally align with sentence or clause boundaries, but not always.
57 |
58 | The text was matched to the audio manually, and a QA pass was done to ensure
59 | that the text accurately matched the words spoken in the audio.
60 |
61 | The original LibriVox recordings were distributed as 128 kbps MP3 files. As a
62 | result, they may contain artifacts introduced by the MP3 encoding.
63 |
64 | The following abbreviations appear in the text. They may be expanded as
65 | follows:
66 |
67 | Abbreviation Expansion
68 | --------------------------
69 | Mr. Mister
70 | Mrs. Misess (*)
71 | Dr. Doctor
72 | No. Number
73 | St. Saint
74 | Co. Company
75 | Jr. Junior
76 | Maj. Major
77 | Gen. General
78 | Drs. Doctors
79 | Rev. Reverend
80 | Lt. Lieutenant
81 | Hon. Honorable
82 | Sgt. Sergeant
83 | Capt. Captain
84 | Esq. Esquire
85 | Ltd. Limited
86 | Col. Colonel
87 | Ft. Fort
88 |
89 | * there's no standard expansion of "Mrs."
90 |
91 |
92 | 19 of the transcriptions contain non-ASCII characters (for example, LJ016-0257
93 | contains "raison d'être").
94 |
95 | For more information or to report errors, please email kito@kito.us.
96 |
97 |
98 |
99 | LICENSE
100 |
101 | This dataset is in the public domain in the USA (and likely other countries as
102 | well). There are no restrictions on its use. For more information, please see:
103 | https://librivox.org/pages/public-domain.
104 |
105 |
106 | CHANGELOG
107 |
108 | * 1.0 (July 8, 2017):
109 | Initial release
110 |
111 | * 1.1 (Feb 19, 2018):
112 | Version 1.0 included 30 .wav files with no corresponding annotations in
113 | metadata.csv. These have been removed in version 1.1. Thanks to Rafael Valle
114 | for spotting this.
115 |
116 |
117 | CREDITS
118 |
119 | This dataset consists of excerpts from the following works:
120 |
121 | * Morris, William, et al. Arts and Crafts Essays. 1893.
122 | * Griffiths, Arthur. The Chronicles of Newgate, Vol. 2. 1884.
123 | * Roosevelt, Franklin D. The Fireside Chats of Franklin Delano Roosevelt.
124 | 1933-42.
125 | * Harland, Marion. Marion Harland's Cookery for Beginners. 1893.
126 | * Rolt-Wheeler, Francis. The Science - History of the Universe, Vol. 5:
127 | Biology. 1910.
128 | * Banks, Edgar J. The Seven Wonders of the Ancient World. 1916.
129 | * President's Commission on the Assassination of President Kennedy. Report
130 | of the President's Commission on the Assassination of President Kennedy.
131 | 1964.
132 |
133 | Recordings by Linda Johnson. Alignment and annotation by Keith Ito. All text,
134 | audio, and annotations are in the public domain.
135 |
136 | There's no requirement to cite this work, but if you'd like to do so, you can
137 | link to: https://keithito.com/LJ-Speech-Dataset
138 |
139 | or use the following:
140 | @misc{ljspeech17,
141 | author = {Keith Ito},
142 | title = {The LJ Speech Dataset},
143 | howpublished = {\url{https://keithito.com/LJ-Speech-Dataset/}},
144 | year = 2017
145 | }
146 |
--------------------------------------------------------------------------------
/Wavenet/WaveNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 |
8 | def receptive_field_size(layers_per_stack, kernel_size):
9 |     dilations = [2**i for i in range(layers_per_stack)]
10 |     receptive_field = (kernel_size-1)*sum(dilations) + 1
11 |     return receptive_field
12 |
13 |
14 | def padding_calc(kernel_size, dilation):
15 | padding = (kernel_size - 1)*dilation
16 | return padding
17 |
18 | def mixture_of_log_sampling(y, log_scale_min=-7.0, clamp_log_scale=False):
19 | nr_mix = y.shape[1] // 3
20 |
21 | y = y.transpose(1, 2)
22 | logit_probs = y[:, :, :nr_mix]
23 |
24 | temp = logit_probs.data.new(logit_probs.shape).uniform_(1e-5, 1.0 - 1e-5)
25 | temp = logit_probs.data - torch.log(- torch.log(temp))
26 | _, argmax = temp.max(dim=-1)
27 |
28 |     one_hot = torch.zeros_like(logit_probs)  # One-hot over the sampled mixture component
29 |     one_hot.scatter_(-1, argmax.unsqueeze(-1), 1.0)
30 |
31 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
32 | log_scales = torch.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1)
33 | if clamp_log_scale:
34 | log_scales = torch.clamp(log_scales, min=log_scale_min)
35 |
36 | u = means.data.new(means.shape).uniform_(1e-5, 1.0 - 1e-5)
37 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
38 |
39 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
40 |
41 | return x
42 |
43 |
44 | class Conv_layers(nn.Module):
45 | def __init__(
46 | self,
47 | in_channels
48 | ):
49 | super(Conv_layers, self).__init__()
50 | self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=1)
51 |
52 | def forward(self, x):
53 | return nn.ReLU()(self.conv(x))
54 |
55 |
56 | class UpScalingNet(nn.Module):
57 | def __init__(
58 | self,
59 | hopsize,
60 | factor,
61 | cin_channels,
62 | ):
63 | super(UpScalingNet, self).__init__()
64 | self.hopsize = hopsize
65 | self.factor = factor
66 | self.repetitions = int(math.log(hopsize)/math.log(self.factor))
67 | self.forward_layers = nn.Sequential(*[
68 | Conv_layers(cin_channels)
69 | for i in range(self.repetitions)
70 | ])
71 |
72 | def forward(self, x):
73 | for layer in self.forward_layers:
74 | x = nn.functional.interpolate(x.float(), scale_factor=self.factor)
75 | x = layer(x)
76 | return x
77 |
78 |
79 | class WaveNet(nn.Module):
80 | def __init__(
81 | self,
82 | layers_per_stack=20,
83 | stack=2,
84 | residual_channels=512,
85 | gate_channels=512,
86 | filter_channels=512,
87 | skip_out_channels=512,
88 | l_in_channels=-1,
89 | hopsize=512,
90 | out_channels=256,
91 | scaler_input=True,
92 | if_quantized=False,
93 | include_bias=True,
94 | loss_ = "MOL"
95 | ):
96 | super(WaveNet, self).__init__()
97 | self.layers_per_stack = layers_per_stack
98 | self.stack = stack
99 | self.residual_channels = residual_channels
100 | self.gate_channels = gate_channels
101 | self.skip_out_channels = skip_out_channels
102 | self.hopsize = hopsize
103 | self.scaler_input = scaler_input
104 | self.padding_list = []
105 | self.out_channels = out_channels
106 |
107 | self.g_gate = nn.ModuleList()
108 | self.g_filter = nn.ModuleList()
109 |
110 | self.l_filter = nn.ModuleList()
111 | self.l_gate = nn.ModuleList()
112 |
113 | self.skip = nn.ModuleList()
114 | self.residual = nn.ModuleList()
115 |
116 | if self.scaler_input:
117 | self.conv1 = nn.Conv1d(1, residual_channels, kernel_size=1, bias=include_bias)
118 | else:
119 | self.conv1 = nn.Conv1d(out_channels, residual_channels, kernel_size=1, bias=include_bias)
120 |
121 | for i in range(stack):
122 | dilation = 1
123 | for layer in range(layers_per_stack):
124 | padding = padding_calc(2, dilation)
125 | self.padding_list.append(padding)
126 | self.g_filter.append(
127 | nn.Conv1d(
128 | residual_channels,
129 | filter_channels,
130 | kernel_size=2,
131 | padding=padding,
132 | dilation=dilation,
133 | bias=include_bias
134 | )
135 | )
136 |
137 | self.g_gate.append(
138 | nn.Conv1d(
139 | residual_channels,
140 | gate_channels,
141 | kernel_size=2,
142 | padding=padding,
143 | dilation=dilation,
144 | bias=include_bias
145 | )
146 | )
147 |
148 | self.l_gate.append(
149 | nn.Conv1d(
150 | l_in_channels,
151 | gate_channels,
152 | kernel_size=1,
153 | bias=include_bias
154 | )
155 | )
156 |
157 | self.l_filter.append(
158 | nn.Conv1d(
159 | l_in_channels,
160 | filter_channels,
161 | kernel_size=1,
162 | bias=include_bias
163 | )
164 | )
165 |
166 | self.residual.append(
167 | nn.Conv1d(
168 | gate_channels,
169 | residual_channels,
170 | kernel_size=1,
171 | bias=include_bias
172 | )
173 | )
174 |
175 | self.skip.append(
176 | nn.Conv1d(
177 | gate_channels,
178 | skip_out_channels,
179 | kernel_size=1,
180 | bias=include_bias
181 | )
182 | )
183 |
184 | dilation *= 2
185 |
186 |         self.receptive_field = receptive_field_size(self.layers_per_stack, kernel_size=2)
187 |
188 | self.upscaling_net = UpScalingNet(hopsize, factor=4, cin_channels=l_in_channels)
189 |
190 | self.last_conv = nn.Sequential(
191 | nn.ReLU(),
192 | nn.Conv1d(skip_out_channels, skip_out_channels, kernel_size=1),
193 | nn.ReLU(),
194 | nn.Conv1d(skip_out_channels, out_channels, kernel_size=1)
195 | )
196 |
197 | def forward(self, x, local_feature, is_training=True):
198 | B, _, T = x.shape
199 |
200 | if is_training:
201 | local_feature = self.upscaling_net(local_feature)
202 |
203 | skip = 0
204 |
205 | x = self.conv1(x)
206 |
207 | for i in range(self.layers_per_stack*self.stack):
208 | residuals = x
209 | g_f = self.g_filter[i](x)[:, :, :x.shape[-1]]
210 | g_g = self.g_gate[i](x)[:, :, :x.shape[-1]]
211 |
212 |
213 | l_f = self.l_filter[i](local_feature)[:, :, :x.shape[-1]]
214 | l_g = self.l_gate[i](local_feature)[:, :, :x.shape[-1]]
215 |
216 |
217 | f, g = g_f + l_f, g_g + l_g
218 |
219 |
220 | x = torch.tanh(f)*torch.sigmoid(g)
221 |
222 |
223 | s = self.skip[i](x)
224 |
225 |
226 | skip += s
227 |
228 |
229 | x = self.residual[i](x)
230 |
231 | x = (x + residuals)*math.sqrt(0.5)
232 |
233 | x = self.last_conv(skip)
234 |
235 | return x
236 |
237 |
238 | def waveform_generation(self, g, mel_spect, device="cpu", time_scale=100, decoding="greedy"):
239 |
240 | mel_spect = self.upscaling_net(mel_spect)
241 | time_scale = max(time_scale, mel_spect.shape[-1])
242 |
243 | if self.scaler_input:
244 | global_ = torch.zeros(1, 1, 1, dtype=torch.float)
245 |
246 | else:
247 | global_ = torch.zeros(1, self.out_channels, 1, dtype=torch.float)
248 |
249 |
250 | output = []
251 | output.append(global_)
252 |
253 | for t in range(time_scale):
254 | global_features = torch.cat(output[max(t-self.receptive_field, 0): t+1], dim=-1).to(device)
255 | local_features = mel_spect[:, :, max(t-self.receptive_field, 0): t+1]
256 |
257 | out = self.forward(global_features, local_features, is_training=False)
258 |
259 | if decoding == "greedy":
260 | out = out.permute(0, 2, 1)
261 | out = out.argmax(dim=-1)[:, -1]
262 |
263 | if self.scaler_input:
264 |                     output.append(out.view(1, 1, 1).float())
265 | else:
266 | x = torch.zeros(1, self.out_channels, 1)
267 | x[:, out[0], :] = 1
268 | out = x
269 | output.append(out.detach())
270 |
271 | elif decoding == "greedy_autoregress":
272 | out = out.permute(0, 2, 1)
273 | out = torch.softmax(out, dim=-1)[:, -1].view(-1)
274 | out = np.random.choice(np.arange(256), p=out.cpu().detach().numpy())
275 | if self.scaler_input:
276 |                     output.append(torch.tensor(out, dtype=torch.float).view(1, 1, 1))
277 | else:
278 | x = torch.zeros(1, self.out_channels, 1)
279 | x[:, out, :] = 1
280 | out = x
281 | output.append(out.detach())
282 |
283 | elif (decoding == "MOL") and (self.scaler_input):
284 | out = mixture_of_log_sampling(out)[:, -1].unsqueeze(0).unsqueeze(1)
285 | output.append(out.detach())
286 |
287 |
288 | return output
289 |
290 |
291 |
292 | if __name__ == "__main__":
293 | local = torch.randn(1, 20, 114, dtype=torch.float)
294 | # global_ = torch.randint(0, 255, (1, 1, 29120), dtype=torch.float)
295 | global_ = torch.randn(1, 1, 29120, dtype=torch.float)
296 | device = torch.device("cpu")
297 | local = local.to(device)
298 | global_ = global_.to(device)
299 |
300 | wavenet = WaveNet(
301 | layers_per_stack=10,
302 | stack=3,
303 | residual_channels=32,
304 | gate_channels=32,
305 | filter_channels=32,
306 | skip_out_channels=256,
307 | l_in_channels=20,
308 | hopsize=256,
309 | scaler_input=False,
310 | out_channels=30,
311 | include_bias=True,
312 | loss_ = "MOL"
313 | )
314 |
315 | wavenet = wavenet.to(device)
316 |
317 | start = time.time()
318 | out = wavenet.waveform_generation(global_, local, device=device, decoding="greedy")
319 | print(f"TIME TAKEN = {time.time() - start}")
320 |     print(torch.cat(out, dim=-1).shape)  # out is a list of per-step tensors
--------------------------------------------------------------------------------
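For reference, the receptive-field helper at the top of Wavenet/WaveNet.py works out as follows for the `__main__` configuration above (10 layers per stack, kernel size 2); note that it only accounts for a single stack of dilated convolutions:

```python
# Illustrative calculation, not part of the repo.
layers_per_stack, kernel_size = 10, 2
dilations = [2 ** i for i in range(layers_per_stack)]     # 1, 2, 4, ..., 512
receptive_field = (kernel_size - 1) * sum(dilations) + 1
print(receptive_field)  # 1024 samples
```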
/Transformer_tts_model/TransformerTTSModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class SelfAttention(nn.Module):
5 | def __init__(self, embed_dims, heads):
6 | super(SelfAttention, self).__init__()
7 | self.heads = heads
8 | self.embed_dims = embed_dims
9 | self.depth = embed_dims//heads
10 |
11 | self.query = nn.Linear(self.depth, self.depth)
12 | self.key = nn.Linear(self.depth, self.depth)
13 | self.value = nn.Linear(self.depth, self.depth)
14 |
15 | self.fc_out = nn.Linear(self.depth*self.heads*2, self.embed_dims)
16 |
17 | def forward(self, query, key, value, mask):
18 | batch, q_len, k_len, v_len = query.shape[0], query.shape[1], key.shape[1], value.shape[1]
19 |
20 | query = query.reshape(batch, q_len, self.heads, self.depth)
21 | key = key.reshape(batch, k_len, self.heads, self.depth)
22 | value = value.reshape(batch, v_len, self.heads, self.depth)
23 |
24 | query = self.query(query)
25 | key = self.key(key)
26 | value = self.value(value)
27 |
28 | energy = torch.einsum('bqhd, bkhd -> bhqk', [query, key])
29 |
30 | if mask is not None:
31 |             energy = energy.masked_fill(mask==0, float("-1e20"))
32 | 
33 |         energy = torch.softmax(energy/(self.depth**0.5), dim=-1)
34 |
35 | out = torch.einsum('bhqv, bvhd -> bqhd', [energy, value])
36 |
37 | out = out.reshape(batch, q_len, self.heads*self.depth)
38 | query = query.reshape(batch, q_len, self.heads*self.depth)
39 |
40 | out = torch.cat([query, out], dim=-1)
41 | out = self.fc_out(out)
42 |
43 | return out
44 |
45 |
46 | class TransformerBlock(nn.Module):
47 | def __init__(self, hidden_dims, heads, dropout, forward_expansion):
48 | super(TransformerBlock, self).__init__()
49 | self.hidden_dims = hidden_dims
50 | self.heads = heads
51 | self.multihead_attention = SelfAttention(hidden_dims, heads)
52 | self.feed_forward = nn.Sequential(
53 | nn.Conv1d(hidden_dims, hidden_dims*forward_expansion, kernel_size=1),
54 | nn.GELU(),
55 | nn.Conv1d(hidden_dims*forward_expansion, hidden_dims, kernel_size=1)
56 | )
57 | self.dropout = nn.Dropout(dropout)
58 | self.layer_norm1 = nn.LayerNorm(hidden_dims)
59 | self.layer_norm2 = nn.LayerNorm(hidden_dims)
60 |
61 | def forward(self, query, key, value, mask):
62 | attention_out = self.multihead_attention(query, key, value, mask)
63 | add = self.dropout(self.layer_norm1(attention_out + query))
64 | ffn_in = add.transpose(1, 2)
65 | ffn_out = self.feed_forward(ffn_in)
66 | ffn_out = ffn_out.transpose(1, 2)
67 | out = self.dropout(self.layer_norm2(ffn_out + add))
68 | return out
69 |
70 |
71 | class EncoderPreNet(nn.Module):
72 | def __init__(self, embed_dims, hidden_dims, dropout):
73 | super(EncoderPreNet, self).__init__()
74 | self.conv1 = nn.Conv1d(
75 | embed_dims,
76 | hidden_dims,
77 | kernel_size=5,
78 | padding=2
79 | )
80 |
81 | self.conv2 = nn.Conv1d(
82 | hidden_dims,
83 | hidden_dims,
84 | kernel_size=5,
85 | padding=2
86 | )
87 |
88 | self.conv3 = nn.Conv1d(
89 | hidden_dims,
90 | hidden_dims,
91 | kernel_size=5,
92 | padding=2
93 | )
94 |
95 | self.batch_norm1 = nn.BatchNorm1d(hidden_dims)
96 | self.batch_norm2 = nn.BatchNorm1d(hidden_dims)
97 | self.batch_norm3 = nn.BatchNorm1d(hidden_dims)
98 |
99 | self.dropout1 = nn.Dropout(dropout)
100 | self.dropout2 = nn.Dropout(dropout)
101 | self.dropout3 = nn.Dropout(dropout)
102 |
103 | self.fc_out = nn.Linear(hidden_dims, hidden_dims)
104 |
105 | def forward(self, x):
106 | x = x.transpose(1, 2)
107 | x = self.dropout1(torch.relu(self.batch_norm1(self.conv1(x))))
108 | x = self.dropout2(torch.relu(self.batch_norm2(self.conv2(x))))
109 | x = self.dropout3(torch.relu(self.batch_norm3(self.conv3(x))))
110 | x = x.transpose(1, 2)
111 | x = self.fc_out(x)
112 | return x
113 |
114 | class Encoder(nn.Module):
115 | def __init__(
116 | self,
117 | vocab_size,
118 | embed_dims,
119 | hidden_dims,
120 | max_len,
121 | heads,
122 | forward_expansion,
123 | num_layers,
124 | dropout
125 | ):
126 | super(Encoder, self).__init__()
127 | self.token_embed = nn.Embedding(vocab_size, embed_dims)
128 | self.positional_embed = nn.Parameter(torch.zeros(1, max_len, hidden_dims))
129 | self.prenet = EncoderPreNet(embed_dims, hidden_dims, dropout)
130 | self.dropout = nn.Dropout(dropout)
131 | self.attention_layers = nn.Sequential(
132 | *[
133 | TransformerBlock(
134 | hidden_dims,
135 | heads,
136 | dropout,
137 | forward_expansion
138 | )
139 | for _ in range(num_layers)
140 | ]
141 | )
142 |
143 | def forward(self, x, mask=None):
144 | seq_len = x.shape[1]
145 | token_embed = self.token_embed(x)
146 | positional_embed = self.positional_embed[:, :seq_len, :]
147 | x = self.prenet(token_embed)
148 | x += positional_embed
149 | x = self.dropout(x)
150 | for layer in self.attention_layers:
151 | x = layer(x, x, x, mask)
152 | return x
153 |
154 |
155 | class DecoderPreNet(nn.Module):
156 | def __init__(self, mel_dims, hidden_dims, dropout):
157 | super(DecoderPreNet, self).__init__()
158 | self.fc_out = nn.Sequential(
159 | nn.Linear(mel_dims, hidden_dims),
160 | nn.ReLU(),
161 | nn.Dropout(dropout),
162 | nn.Linear(hidden_dims, hidden_dims),
163 | nn.ReLU(),
164 | nn.Dropout(dropout)
165 | )
166 |
167 | def forward(self, x):
168 |
169 | return self.fc_out(x)
170 |
171 |
172 | class PostNet(nn.Module):
173 | def __init__(self, mel_dims, hidden_dims, dropout):
174 | #causal padding -> padding = (kernel_size - 1) x dilation
175 | #kernel_size = 5 -> padding = 4
176 | #Exclude the last padding_size output as we want only left padded output
177 | super(PostNet, self).__init__()
178 | self.conv1 = nn.Conv1d(mel_dims, hidden_dims, kernel_size=5, padding=4)
179 | self.batch_norm1 = nn.BatchNorm1d(hidden_dims)
180 | self.dropout1 = nn.Dropout(dropout)
181 | self.conv_list = nn.Sequential(
182 | *[
183 | nn.Conv1d(hidden_dims, hidden_dims, kernel_size=5, padding=4)
184 | for _ in range(3)
185 | ]
186 | )
187 |
188 | self.batch_norm_list = nn.Sequential(
189 | *[
190 | nn.BatchNorm1d(hidden_dims)
191 | for _ in range(3)
192 | ]
193 | )
194 |
195 | self.dropout_list = nn.Sequential(
196 | *[
197 | nn.Dropout(dropout)
198 | for _ in range(3)
199 | ]
200 | )
201 |
202 | self.conv5 = nn.Conv1d(hidden_dims, mel_dims, kernel_size=5, padding=4)
203 |
204 | def forward(self, x):
205 | x = x.transpose(1, 2)
206 | x = self.dropout1(torch.tanh(self.batch_norm1(self.conv1(x)[:, :, :-4])))
207 | for dropout, batchnorm, conv in zip(self.dropout_list, self.batch_norm_list, self.conv_list):
208 | x = dropout(torch.tanh(batchnorm(conv(x)[:, :, :-4])))
209 | out = self.conv5(x)[:, :, :-4]
210 | out = out.transpose(1, 2)
211 | return out
212 |
213 |
214 | class DecoderBlock(nn.Module):
215 | def __init__(
216 | self,
217 | embed_dims,
218 | heads,
219 | forward_expansion,
220 | dropout
221 | ):
222 | super(DecoderBlock, self).__init__()
223 | self.causal_masked_attention = SelfAttention(embed_dims, heads)
224 | self.attention_layer = TransformerBlock(
225 | embed_dims,
226 | heads,
227 | dropout,
228 | forward_expansion
229 | )
230 | self.dropout = nn.Dropout(dropout)
231 | self.layer_norm = nn.LayerNorm(embed_dims)
232 |
233 | def forward(self, query, key, value, src_mask, causal_mask):
234 | causal_masked_attention = self.causal_masked_attention(query, query, query, causal_mask)
235 | query = self.dropout(self.layer_norm(causal_masked_attention + query))
236 | out = self.attention_layer(query, key, value, src_mask)
237 | return out
238 |
239 |
240 | class Decoder(nn.Module):
241 | def __init__(
242 | self,
243 | mel_dims,
244 | hidden_dims,
245 | heads,
246 | max_len,
247 | num_layers,
248 | forward_expansion,
249 | dropout
250 | ):
251 | super(Decoder, self).__init__()
252 | self.positional_embed = nn.Parameter(torch.zeros(1, max_len, hidden_dims))
253 | self.prenet = DecoderPreNet(mel_dims, hidden_dims, dropout)
254 | self.attention_layers = nn.Sequential(
255 | *[
256 | DecoderBlock(
257 | hidden_dims,
258 | heads,
259 | forward_expansion,
260 | dropout
261 | )
262 | for _ in range(num_layers)
263 | ]
264 | )
265 | self.mel_linear = nn.Linear(hidden_dims, mel_dims)
266 | self.stop_linear = nn.Linear(hidden_dims, 1)
267 | self.postnet = PostNet(mel_dims, hidden_dims, dropout)
268 | self.dropout = nn.Dropout(dropout)
269 |
270 |     def forward(self, mel, encoder_output, src_mask, causal_mask):
271 | seq_len = mel.shape[1]
272 | prenet_out = self.prenet(mel)
273 | x = self.dropout(prenet_out + self.positional_embed[:, :seq_len, :])
274 |
275 | for layer in self.attention_layers:
276 |             x = layer(x, encoder_output, encoder_output, src_mask, causal_mask)
277 |
278 | stop_linear = self.stop_linear(x)
279 |
280 | mel_linear = self.mel_linear(x)
281 |
282 | postnet = self.postnet(mel_linear)
283 |
284 | out = postnet + mel_linear
285 |
286 | return out, mel_linear, stop_linear
287 |
288 |
289 | class TransformerTTS(nn.Module):
290 | def __init__(
291 | self,
292 | vocab_size,
293 | embed_dims,
294 | hidden_dims,
295 | heads,
296 | forward_expansion,
297 | num_layers,
298 | dropout,
299 | mel_dims,
300 | max_len,
301 | pad_idx
302 | ):
303 | super(TransformerTTS, self).__init__()
304 | self.encoder = Encoder(
305 | vocab_size,
306 | embed_dims,
307 | hidden_dims,
308 | max_len,
309 | heads,
310 | forward_expansion,
311 | num_layers,
312 | dropout
313 | )
314 |
315 | self.decoder = Decoder(
316 | mel_dims,
317 | hidden_dims,
318 | heads,
319 | max_len,
320 | num_layers,
321 | forward_expansion,
322 | dropout
323 | )
324 |
325 | self.pad_idx = pad_idx
326 |
327 | def target_mask(self, mel, mel_mask):
328 | seq_len = mel.shape[1]
329 | pad_mask = (mel_mask != self.pad_idx).unsqueeze(1).unsqueeze(3)
330 | causal_mask = torch.tril(torch.ones((1, seq_len, seq_len))).unsqueeze(1)
331 | return pad_mask, causal_mask
332 |
333 | def input_mask(self, x):
334 | mask = (x != self.pad_idx).unsqueeze(1).unsqueeze(2)
335 | return mask
336 |
337 | def forward(self, text_idx, mel, mel_mask):
338 | input_pad_mask = self.input_mask(text_idx)
339 | target_pad_mask, causal_mask = self.target_mask(mel, mel_mask)
340 | encoder_out = self.encoder(text_idx, input_pad_mask)
341 | mel_postout, mel_linear, stop_linear = self.decoder(mel, encoder_out, target_pad_mask, causal_mask)
342 | return mel_postout, mel_linear, stop_linear
343 |
344 |
345 | if __name__ == "__main__":
346 | a = torch.randint(0, 30, (4, 60))
347 | mel = torch.randn(4, 128, 80)
348 | mask = torch.ones((4, 128))
349 | model = TransformerTTS(
350 | vocab_size=30,
351 | embed_dims=512,
352 | hidden_dims=256,
353 | heads=4,
354 | forward_expansion=4,
355 | num_layers=6,
356 | dropout=0.1,
357 | mel_dims=80,
358 | max_len=512,
359 | pad_idx=0
360 | )
361 | x, y, z = model(a, mel, mask)
362 | print(x.shape, y.shape, z.shape)
--------------------------------------------------------------------------------