├── LICENSE ├── README.md ├── ckpt ├── checkpoint_step500000.pth ├── step500000_alignment.png ├── step500000_predicted.wav ├── step500000_predicted_spectrogram.png └── step500000_target_spectrogram.png ├── config.py ├── data ├── LJSpeech-1.1 │ └── metadata.csv ├── meta │ └── meta_text.txt └── test_transcripts.txt ├── dataloader.py ├── model ├── attention.py ├── loss.py └── tacotron.py ├── preprocess.py ├── requirements.txt ├── result ├── 500000 │ ├── 1.wav │ ├── 1_alignment.png │ ├── 1_all.png │ ├── 1_spectrogram.png │ ├── 2.wav │ ├── 2_alignment.png │ ├── 2_all.png │ ├── 2_spectrogram.png │ ├── 3.wav │ ├── 3_alignment.png │ ├── 3_all.png │ ├── 3_spectrogram.png │ ├── 4.wav │ ├── 4_alignment.png │ ├── 4_all.png │ ├── 4_spectrogram.png │ ├── 5.wav │ ├── 5_alignment.png │ ├── 5_all.png │ ├── 5_spectrogram.png │ ├── 6.wav │ ├── 6_alignment.png │ ├── 6_all.png │ ├── 6_spectrogram.png │ ├── 7.wav │ ├── 7_alignment.png │ ├── 7_all.png │ └── 7_spectrogram.png └── checkpoint_step500000_all.png ├── test.py ├── train.py └── utils ├── audio.py ├── data.py ├── plot.py └── text ├── __init__.py ├── cleaners.py ├── cmudict.py ├── numbers.py └── symbols.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ting-Wei, Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tacotron 2 | A Pytorch implementation of Google's [Tacotron](https://arxiv.org/pdf/1703.10135.pdf) speech synthesis network. 3 | 4 | This implementation also includes the **Location-Sensitive Attention** and the **Stop Token** features from [Tacotron 2](https://arxiv.org/pdf/1712.05884.pdf). 5 | 6 | Furthermore, the model is trained on the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/), with trained model provided. 7 | 8 | 9 | 10 | Audio samples can be found in the [result](result/500000) directory. 11 | 12 | ## Introduction 13 | This implementation is based on [r9y9/tacotron_pytorch](https://github.com/r9y9/tacotron_pytorch), the main differences are: 14 | * Adds **Location-Sensitive Attention** and the **Stop Token** from the [Tacotron 2](https://arxiv.org/pdf/1712.05884.pdf) paper. 15 | This can greatly reduce the amount of time and data required to train a model. 
16 | * Removes all TensorFlow dependencies that [r9y9](https://github.com/r9y9/tacotron_pytorch) uses; this implementation now **runs on PyTorch and PyTorch only**.
17 | * Adds a [loss](model/loss.py) module, and uses L2 (MSE) loss instead of L1 loss.
18 | * Adds a [data loader](dataloader.py) module.
19 | * Incorporates the LJ Speech data preprocessing script from [keithito](https://github.com/keithito/tacotron).
20 | * Code refactoring and optimization for easier debugging and extension in the future.
21 | 
22 | Furthermore, some differences from the original [Tacotron](https://arxiv.org/pdf/1703.10135.pdf) paper are:
23 | * Predict r=5 non-overlapping consecutive output frames at each decoder step instead of r=2.
24 | * Feed all r frames to the next decoder input step instead of just the last frame of the r frames.
25 | * Scale the loss on predicted linear spectrograms so that the lower frequencies that correspond to human speech (0 to 3000 Hz) weigh more.
26 | * Do not use a loss mask in sequence-to-sequence learning; this forces the model to learn when to stop synthesis.
27 | * Disable bias for the 1-dimensional convolution units in the CBHG module.
28 | These implementation details help the model's convergence.
29 | 
30 | Audio quality isn't as good as Google's demo yet, but hopefully it will improve eventually. Pull requests are welcome!
31 | 
32 | 
33 | ## Quick Start
34 | 
35 | ### Setup
36 | * Clone this repo: `git clone git@github.com:andi611/Tacotron-Pytorch.git`
37 | * CD into this repo: `cd Tacotron-Pytorch`
38 | 
39 | ### Installing dependencies
40 | 
41 | 1. Install Python 3.
42 | 
43 | 2. Install the latest version of **[Pytorch](https://pytorch.org/get-started/locally/)** according to your platform. For better
44 | performance, install with GPU support (CUDA) if viable. This code works with Pytorch 0.4 and later.
45 | 
46 | 3. Install [requirements](requirements.txt):
47 | ```
48 | pip3 install -r requirements.txt
49 | ```
50 | *Warning: you need to install torch according to your platform. The version listed in [requirements.txt](requirements.txt) is the PyTorch version used when this project was built.*
51 | 
52 | 
53 | ### Training
54 | 
55 | 1. **Download the LJ Speech dataset.**
56 | * [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
57 | 
58 | You can use other datasets if you convert them to the right format. See [TRAINING_DATA.md](https://github.com/keithito/tacotron/blob/master/TRAINING_DATA.md) for more info.
59 | 
60 | 2. **Unpack the dataset into `~/Tacotron-Pytorch/data`**
61 | 
62 | After unpacking, your tree should look like this for LJ Speech:
63 | ```
64 | 	|- Tacotron-Pytorch
65 | 		|- data
66 | 			|- LJSpeech-1.1
67 | 				|- metadata.csv
68 | 				|- wavs
69 | ```
70 | 
71 | 3. **Preprocess the LJ Speech dataset and make model-ready meta files using [preprocess.py](preprocess.py):**
72 | ```
73 | python3 preprocess.py --mode make
74 | ```
75 | 
76 | After preprocessing, your tree will look like this:
77 | ```
78 | 	|- Tacotron-Pytorch
79 | 		|- data
80 | 			|- LJSpeech-1.1 (The downloaded dataset)
81 | 				|- metadata.csv
82 | 				|- wavs
83 | 			|- meta (generated by preprocessing)
84 | 				|- meta_text.txt
85 | 				|- meta_mel_xxxxx.npy ...
86 | 				|- meta_spec_xxxxx.npy ...
87 | 			|- test_transcripts.txt (provided)
88 | ```
89 | 
90 | 4. **Train a model using [train.py](train.py)**
91 | ```
92 | python3 train.py --ckpt_dir ckpt/ --log_dir log/
93 | ```
94 | 
95 | Restore training from a previous checkpoint:
96 | ```
97 | python3 train.py --ckpt_dir ckpt/ --log_dir log/ --model_name 500000
98 | ```
99 | 
100 | Tunable hyperparameters are found in [config.py](config.py). 
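As a quick reference, here is a minimal sketch of tweaking a few of these values programmatically (the attribute names come from `config.py` in this repo; the values are placeholders rather than recommended settings, and editing `config.py` directly works just as well):
```
# Hypothetical override sketch -- the attributes are defined in config.py,
# the values below are placeholders, not recommendations.
from config import config

config.batch_size = 16                   # default: 8
config.attention = 'LocationSensitive'   # or 'Bahdanau'
config.checkpoint_interval = 2000        # how often checkpoints/audio/alignments are dumped
```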
101 | 102 | You can adjust these parameters and setting by editing the file, the default hyperparameters are recommended for LJ Speech. 103 | 104 | 5. **Monitor with Tensorboard** (OPTIONAL) 105 | ``` 106 | tensorboard --logdir 'path to log_dir' 107 | ``` 108 | 109 | The trainer dumps audio and alignments every 2000 steps by default. You can find these in `tacotron/ckpt/`. 110 | 111 | 112 | ### Testing: Using a pre-trained model and [test.py](test.py) 113 | * **Run the testing environment with interactive mode**: 114 | ``` 115 | python3 test.py --interactive --plot --model_name 500000 116 | ``` 117 | * **Run the testing algorithm on a set of transcripts** (Results can be found in the [result/500000](result/500000) directory) : 118 | ``` 119 | python3 test.py --plot --model_name 500000 --test_file_path ./data/test_transcripts.txt 120 | ``` 121 | 122 | 123 | ## Acknowledgement 124 | Credits to Ryuichi Yamamoto for a wonderful Pytorch [implementation](https://github.com/r9y9/tacotron_pytorch) of Tacotron, which this work is mainly based on. This work is also inspired by [NVIDIA's](https://github.com/NVIDIA/tacotron2) Tacotron 2 PyTorch implementation. 125 | 126 | ## TODO 127 | * Add more configurable hparams 128 | 129 | -------------------------------------------------------------------------------- /ckpt/checkpoint_step500000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/ckpt/checkpoint_step500000.pth -------------------------------------------------------------------------------- /ckpt/step500000_alignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/ckpt/step500000_alignment.png -------------------------------------------------------------------------------- /ckpt/step500000_predicted.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/ckpt/step500000_predicted.wav -------------------------------------------------------------------------------- /ckpt/step500000_predicted_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/ckpt/step500000_predicted_spectrogram.png -------------------------------------------------------------------------------- /ckpt/step500000_target_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/ckpt/step500000_target_spectrogram.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ config.py ] 4 | # Synopsis [ configurations ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), NTUEE, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | 
############### 13 | import argparse 14 | from multiprocessing import cpu_count 15 | 16 | 17 | ######################## 18 | # MODEL CONFIGURATIONS # 19 | ######################## 20 | class configurations(object): 21 | 22 | def __init__(self): 23 | self.get_audio_config() 24 | self.get_model_config() 25 | self.get_loss_config() 26 | self.get_dataloader_config() 27 | self.get_training_config() 28 | self.get_testing_config() 29 | 30 | def get_audio_config(self): 31 | self.num_mels = 80 32 | self.num_freq = 1025 33 | self.sample_rate = 22050 34 | self.frame_length_ms = 50 35 | self.frame_shift_ms = 12.5 36 | self.preemphasis = 0.97 37 | self.min_level_db = -100 38 | self.ref_level_db = 20 39 | self.hop_length = 250 40 | 41 | def get_model_config(self): 42 | self.embedding_dim = 256 43 | self.outputs_per_step = 5 44 | self.padding_idx = None 45 | self.attention = 'LocationSensitive' # or 'Bahdanau' 46 | self.use_mask = False 47 | 48 | def get_loss_config(self): 49 | self.prior_freq = 3000 50 | self.prior_weight = 0.5 51 | self.gate_coefficient = 0.1 52 | 53 | def get_dataloader_config(self): 54 | self.pin_memory = True 55 | self.num_workers = cpu_count() # or just set 2 56 | 57 | def get_training_config(self): 58 | self.batch_size = 8 59 | self.adam_beta1 = 0.9 60 | self.adam_beta2 = 0.999 61 | self.initial_learning_rate = 0.002 62 | self.decay_learning_rate = True 63 | self.max_epochs = 1000 64 | self.max_steps = 500000 65 | self.weight_decay = 0.0 66 | self.clip_thresh = 1.0 67 | self.checkpoint_interval = 2000 68 | 69 | def get_testing_config(self): 70 | self.max_iters = 200 71 | self.max_decoder_steps = 500 72 | self.griffin_lim_iters = 60 73 | self.power = 1.5 # Power to raise magnitudes to prior to Griffin-Lim 74 | 75 | config = configurations() 76 | 77 | 78 | ########################### 79 | # TRAINING CONFIGURATIONS # 80 | ########################### 81 | def get_training_args(): 82 | parser = argparse.ArgumentParser(description='training arguments') 83 | 84 | parser.add_argument('--ckpt_dir', type=str, default='./ckpt', help='Directory where to save model checkpoints') 85 | parser.add_argument('--model_name', type=str, default=None, help='Restore model from checkpoint path if name is given') 86 | parser.add_argument('--data_root', type=str, default='./data/meta', help='Directory that contains preprocessed model-ready features') 87 | parser.add_argument('--meta_text', type=str, default='meta_text.txt', help='Model-ready training transcripts') 88 | parser.add_argument('--log_dir', type=str, default=None, help='Directory for log summary writer to write in') 89 | parser.add_argument('--log_comment', type=str, default=None, help='Comment to add to the directory for log summary writer') 90 | 91 | args = parser.parse_args() 92 | return args 93 | 94 | 95 | ############################# 96 | # PREPROCESS CONFIGURATIONS # 97 | ############################# 98 | def get_preprocess_args(): 99 | parser = argparse.ArgumentParser(description='preprocess arguments') 100 | 101 | parser.add_argument('--mode', choices=['make', 'analyze', 'all'], default='all', help='what to preprocess') 102 | parser.add_argument('--num_workers', type=int, default=cpu_count(), help='multi-thread processing') 103 | parser.add_argument('--file_suffix', type=str, default='wav', help='audio filename extension') 104 | 105 | meta_path = parser.add_argument_group('meta_path') 106 | meta_path.add_argument('--meta_dir', type=str, default='./data/meta/', help='path to the model-ready training acoustic features') 107 | 
meta_path.add_argument('--meta_text', type=str, default='meta_text.txt', help='name of the model-ready training transcripts') 108 | 109 | input_path = parser.add_argument_group('input_path') 110 | input_path.add_argument('--text_input_path', type=str, default='./data/LJSpeech-1.1/metadata.csv', help='path to the original training text data') 111 | input_path.add_argument('--audio_input_dir', type=str, default='./data/LJSpeech-1.1/wavs/', help='path to the original training audio data') 112 | 113 | args = parser.parse_args() 114 | return args 115 | 116 | 117 | ####################### 118 | # TEST CONFIGURATIONS # 119 | ####################### 120 | def get_test_args(): 121 | parser = argparse.ArgumentParser(description='testing arguments') 122 | 123 | parser.add_argument('--plot', action='store_true', help='whether to plot') 124 | parser.add_argument('--interactive', action='store_true', help='whether to test in an interactive mode') 125 | 126 | path_parser = parser.add_argument_group('path') 127 | path_parser.add_argument('--result_dir', type=str, default='./result/', help='path to output test results') 128 | path_parser.add_argument('--ckpt_dir', type=str, default='./ckpt/', help='path to the directory where model checkpoints are saved') 129 | path_parser.add_argument('--checkpoint_name', type=str, default='checkpoint_step', help='model name prefix for checkpoint files') 130 | path_parser.add_argument('--model_name', type=str, default='500000', help='model step name for checkpoint files') 131 | path_parser.add_argument('--test_file_path', type=str, default='./data/test_transcripts.txt', help='path to the input test transcripts') 132 | 133 | args = parser.parse_args() 134 | return args 135 | 136 | -------------------------------------------------------------------------------- /data/test_transcripts.txt: -------------------------------------------------------------------------------- 1 | Generative adversarial network or variational auto-encoder. 2 | He has read the whole thing. 3 | He reads books. 4 | Thisss isrealy awhsome. 5 | This is your personal assistant, Google Home. 6 | The quick brown fox jumps over the lazy dog. 7 | Does the quick brown fox jump over the lazy dog? 
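Before the dataloader implementation that follows, here is a minimal usage sketch (paths are hypothetical; it assumes `python3 preprocess.py --mode make` has already produced `./data/meta/meta_text.txt` and the `.npy` features it references):
```
# Usage sketch for the Dataloader defined in dataloader.py below.
from dataloader import Dataloader

loader = Dataloader(data_root='./data/meta', meta_text='meta_text.txt')
for x, mel, linear, gate, input_lengths in loader:
    # x: (B, T_text) character ids, batch sorted by decreasing text length
    # mel: (B, T_frames, num_mels), linear: (B, T_frames, num_freq)
    # gate: (B, T_frames) stop-token targets; input_lengths: sorted numpy array
    break
```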
-------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ dataloader.py ] 4 | # Synopsis [ data loader for the Tacotron model ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import os 14 | import numpy as np 15 | #----------------# 16 | import torch 17 | from torch.utils import data 18 | from torch.autograd import Variable 19 | #---------------------------------# 20 | from config import config 21 | from utils.text import text_to_sequence 22 | #-------------------------------------# 23 | from nnmnkwii.datasets import FileSourceDataset, FileDataSource 24 | 25 | 26 | #################### 27 | # TEXT DATA SOURCE # 28 | #################### 29 | class TextDataSource(FileDataSource): 30 | def __init__(self, data_root, meta_text): 31 | self.data_root = data_root 32 | self.meta_text = meta_text 33 | #self._cleaner_names = [x.strip() for x in hparams.cleaners.split(',')] 34 | 35 | def collect_files(self): 36 | meta = os.path.join(self.data_root, self.meta_text) 37 | with open(meta, 'r', encoding='utf-8') as f: 38 | lines = f.readlines() 39 | lines = list(map(lambda l: l.split("|")[-1][:-1], lines)) 40 | return lines 41 | 42 | def collect_features(self, text): 43 | return np.asarray(text_to_sequence(text), dtype=np.int32) 44 | 45 | 46 | ################### 47 | # NPY DATA SOURCE # 48 | ################### 49 | class _NPYDataSource(FileDataSource): 50 | def __init__(self, col, data_root, meta_text): 51 | self.col = col 52 | self.data_root = data_root 53 | self.meta_text = meta_text 54 | 55 | def collect_files(self): 56 | meta = os.path.join(self.data_root, self.meta_text) 57 | with open(meta, 'r', encoding='utf-8') as f: 58 | lines = f.readlines() 59 | lines = list(map(lambda l: l.split("|")[self.col], lines)) 60 | paths = list(map(lambda f: os.path.join(self.data_root, f), lines)) 61 | return paths 62 | 63 | def collect_features(self, path): 64 | return np.load(path) 65 | 66 | 67 | ######################## 68 | # MEL SPEC DATA SOURCE # 69 | ######################## 70 | class MelSpecDataSource(_NPYDataSource): 71 | def __init__(self, data_root, meta_text): 72 | super(MelSpecDataSource, self).__init__(1, data_root, meta_text) 73 | 74 | 75 | ########################### 76 | # LINEAR SPEC DATA SOURCE # 77 | ########################### 78 | class LinearSpecDataSource(_NPYDataSource): 79 | def __init__(self, data_root, meta_text): 80 | super(LinearSpecDataSource, self).__init__(0, data_root, meta_text) 81 | 82 | 83 | ####################### 84 | # PYTORCH DATA SOURCE # 85 | ####################### 86 | class PyTorchDatasetWrapper(object): 87 | def __init__(self, X, Mel, Y): 88 | self.X = X 89 | self.Mel = Mel 90 | self.Y = Y 91 | 92 | def __getitem__(self, idx): 93 | return self.X[idx], self.Mel[idx], self.Y[idx] 94 | 95 | def __len__(self): 96 | return len(self.X) 97 | 98 | 99 | ############## 100 | # COLLATE FN # 101 | ############## 102 | """ 103 | Create batch 104 | """ 105 | def collate_fn(batch): 106 | def _pad(seq, max_len): 107 | return np.pad(seq, (0, max_len - len(seq)), mode='constant', 
constant_values=0) 108 | 109 | def _pad_2d(x, max_len): 110 | return np.pad(x, [(0, max_len - len(x)), (0, 0)], mode="constant", constant_values=0) 111 | 112 | r = config.outputs_per_step 113 | input_lengths = [len(x[0]) for x in batch] 114 | 115 | max_input_len = np.max(input_lengths) 116 | max_target_len = np.max([len(x[1]) for x in batch]) + 1 # Add single zeros frame at least, so plus 1 117 | 118 | if max_target_len % r != 0: 119 | max_target_len += r - max_target_len % r 120 | assert max_target_len % r == 0 121 | 122 | input_lengths = torch.LongTensor(input_lengths) 123 | sorted_lengths, indices = torch.sort(input_lengths.view(-1), dim=0, descending=True) 124 | sorted_lengths = sorted_lengths.long().numpy() 125 | 126 | x_batch = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.int) 127 | x_batch = torch.LongTensor(x_batch) 128 | 129 | mel_batch = np.array([_pad_2d(x[1], max_target_len) for x in batch], dtype=np.float32) 130 | mel_batch = torch.FloatTensor(mel_batch) 131 | 132 | y_batch = np.array([_pad_2d(x[2], max_target_len) for x in batch], dtype=np.float32) 133 | y_batch = torch.FloatTensor(y_batch) 134 | 135 | gate_batch = torch.FloatTensor(len(batch), max_target_len).zero_() 136 | for i, x in enumerate(batch): gate_batch[i, len(x[1])-1:] = 1 137 | 138 | x_batch, mel_batch, y_batch, gate_batch, = Variable(x_batch[indices]), Variable(mel_batch[indices]), Variable(y_batch[indices]), Variable(gate_batch[indices]) 139 | return x_batch, mel_batch, y_batch, gate_batch, sorted_lengths 140 | 141 | 142 | ############### 143 | # DATA LOADER # 144 | ############### 145 | """ 146 | Create dataloader 147 | """ 148 | def Dataloader(data_root, meta_text): 149 | 150 | # Input dataset definitions 151 | X = FileSourceDataset(TextDataSource(data_root, meta_text)) 152 | Mel = FileSourceDataset(MelSpecDataSource(data_root, meta_text)) 153 | Y = FileSourceDataset(LinearSpecDataSource(data_root, meta_text)) 154 | 155 | # Dataset and Dataloader setup 156 | dataset = PyTorchDatasetWrapper(X, Mel, Y) 157 | data_loader = data.DataLoader(dataset, 158 | batch_size=config.batch_size, 159 | num_workers=config.num_workers, 160 | shuffle=True, 161 | collate_fn=collate_fn, 162 | pin_memory=config.pin_memory) 163 | return data_loader 164 | 165 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ attention.py ] 4 | # Synopsis [ Sequence to sequence attention module for Tacotron ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import torch 14 | from torch.autograd import Variable 15 | from torch import nn 16 | from torch.nn import functional as F 17 | 18 | 19 | ###################### 20 | # BAHDANAU ATTENTION # 21 | ###################### 22 | class BahdanauAttention(nn.Module): 23 | def __init__(self, dim): 24 | super(BahdanauAttention, self).__init__() 25 | self.query_layer = nn.Linear(dim, dim, bias=False) 26 | self.tanh = nn.Tanh() 27 | self.v = nn.Linear(dim, 1, bias=False) 28 | 29 | """ 30 | Args: 31 | query: (batch, 1, dim) or (batch, dim) 32 | processed_memory: (batch, max_time, dim) 33 | """ 
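	# Note (added for clarity): the alignment returned below is an unnormalized score of
	# shape (batch, max_time); softmax normalization and the context-vector computation
	# happen later in AttentionRNN.forward.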
34 | def forward(self, query, processed_memory): 35 | if query.dim() == 2: 36 | query = query.unsqueeze(1) # insert time-axis for broadcasting 37 | 38 | processed_query = self.query_layer(query) # (batch, 1, dim) 39 | alignment = self.v(self.tanh(processed_query + processed_memory)) # (batch, max_time, 1) 40 | alignment = alignment.squeeze(-1) # (batch, max_time) 41 | 42 | return alignment 43 | 44 | 45 | ################ 46 | # LINEAR LAYER # 47 | ################ 48 | class LinearNorm(nn.Module): 49 | 50 | def __init__(self, 51 | in_dim, 52 | out_dim, 53 | bias=True, 54 | w_init_gain='linear'): 55 | 56 | super(LinearNorm, self).__init__() 57 | self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) 58 | 59 | nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain)) 60 | 61 | def forward(self, x): 62 | return self.linear_layer(x) 63 | 64 | 65 | ##################### 66 | # CONVOLUTION LAYER # 67 | ##################### 68 | class ConvNorm(nn.Module): 69 | 70 | def __init__(self, 71 | in_channels, 72 | out_channels, 73 | kernel_size=1, 74 | stride=1, 75 | padding=None, 76 | dilation=1, 77 | bias=True, 78 | w_init_gain='linear'): 79 | 80 | super(ConvNorm, self).__init__() 81 | if padding is None: 82 | assert(kernel_size % 2 == 1) 83 | padding = int(dilation * (kernel_size - 1) / 2) 84 | 85 | self.conv = nn.Conv1d(in_channels, out_channels, 86 | kernel_size=kernel_size, stride=stride, 87 | padding=padding, dilation=dilation, 88 | bias=bias) 89 | 90 | nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) 91 | 92 | def forward(self, signal): 93 | conv_signal = self.conv(signal) 94 | return conv_signal 95 | 96 | 97 | ################## 98 | # LOCATION LAYER # 99 | ################## 100 | class LocationLayer(nn.Module): 101 | 102 | def __init__(self, 103 | attention_n_filters, 104 | attention_kernel_size, 105 | attention_dim): 106 | 107 | super(LocationLayer, self).__init__() 108 | padding = int((attention_kernel_size - 1) / 2) 109 | self.location_conv = ConvNorm(2, attention_n_filters, 110 | kernel_size=attention_kernel_size, 111 | padding=padding, 112 | bias=False, 113 | stride=1, 114 | dilation=1) 115 | self.location_dense = LinearNorm(attention_n_filters, 116 | attention_dim, 117 | bias=False, 118 | w_init_gain='tanh') 119 | 120 | def forward(self, attention_weights_cat): 121 | processed_attention = self.location_conv(attention_weights_cat) 122 | processed_attention = processed_attention.transpose(1, 2) 123 | processed_attention = self.location_dense(processed_attention) 124 | return processed_attention 125 | 126 | 127 | ################################ 128 | # LOCATION SENSITIVE ATTENTION # 129 | ################################ 130 | class LocationSensitiveAttention(nn.Module): 131 | 132 | def __init__(self, 133 | dim, 134 | attention_location_n_filters=32, 135 | attention_location_kernel_size=31): 136 | 137 | super(LocationSensitiveAttention, self).__init__() 138 | self.query_layer = LinearNorm(dim, dim, bias=False, w_init_gain='tanh') 139 | self.location_layer = LocationLayer(attention_location_n_filters, 140 | attention_location_kernel_size, 141 | dim) 142 | self.tanh = nn.Tanh() 143 | self.v = LinearNorm(dim, 1, bias=False) 144 | 145 | """ 146 | Args: 147 | query: (batch, 1, dim) or (batch, dim) 148 | processed_memory: (batch, max_time, dim) 149 | attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time) 150 | """ 151 | def forward(self, 152 | query, 153 | processed_memory, 154 | attention_weights_cat): 155 | 156 | if query.dim() == 2: 157 | query = query.unsqueeze(1) # insert time-axis for broadcasting 158 | 159 | processed_query = self.query_layer(query) # (batch, 1, dim) 160 | processed_attention_weights = self.location_layer(attention_weights_cat) 161 | alignment = self.v(self.tanh(processed_query + processed_attention_weights + processed_memory)) # (batch, max_time, 1) 162 | alignment = alignment.squeeze(-1) # (batch, max_time) 163 | 164 | return alignment 165 | 166 | 167 | ################# 168 | # ATTENTION RNN # 169 | ################# 170 | class AttentionRNN(nn.Module): 171 | 172 | def __init__(self, 173 | rnn_cell, 174 | attention_mechanism, 175 | attention, 176 | score_mask_value=-float("inf")): 177 | 178 | super(AttentionRNN, self).__init__() 179 | 180 | self.rnn_cell = rnn_cell 181 | self.attention_mechanism = attention_mechanism 182 | self.attention = attention 183 | if self.attention == 'Bahdanau': 184 | self.memory_layer = nn.Linear(256, 256, bias=False) 185 | elif self.attention == 'LocationSensitive': 186 | self.memory_layer = LinearNorm(256, 256, bias=False, w_init_gain='tanh') 187 | self.score_mask_value = score_mask_value 188 | 189 | def forward(self, 190 | query, 191 | attention, 192 | cell_state, 193 | memory, 194 | attention_weights_cat=None, 195 | processed_memory=None, 196 | mask=None, 197 | memory_lengths=None): 198 | 199 | if self.attention == 'LocationSensitive' and attention_weights_cat is None: 200 | raise RuntimeError('Missing input: attention_weights_cat') 201 | if processed_memory is None: 202 | processed_memory = memory 203 | if memory_lengths is not None and mask is None: 204 | mask = get_mask_from_lengths(memory, memory_lengths) 205 | 206 | cell_input = torch.cat((query, attention), -1) # Concat input query and previous attention context 207 | cell_output = self.rnn_cell(cell_input, cell_state) # Feed it to RNN 208 | 209 | if self.attention == 'Bahdanau': 210 | alignment = self.attention_mechanism(cell_output, processed_memory) # Alignment: (batch, max_time) 211 | elif self.attention == 'LocationSensitive': 212 | alignment = self.attention_mechanism(cell_output, processed_memory, attention_weights_cat) 213 | 214 | if mask is not None: 215 | mask = mask.view(query.size(0), -1) 216 | alignment.data.masked_fill_(mask, self.score_mask_value) 217 | 218 | alignment = F.softmax(alignment, dim=1) # Normalize attention weight 219 | attention = torch.bmm(alignment.unsqueeze(1), memory) # Attention context vector: (batch, 1, dim) 220 | attention = attention.squeeze(1) # (batch, dim) 221 | 222 | return cell_output, attention, alignment 223 | 224 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ loss.py ] 4 | # Synopsis [ Loss for the Tacotron model ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import torch 14 | from torch import nn 15 | from config import config 16 | 17 | 18 | ################# 19 | # TACOTRON LOSS # 20 | 
################# 21 | class TacotronLoss(nn.Module): 22 | def __init__(self): 23 | super(TacotronLoss, self).__init__() 24 | 25 | self.sample_rate = config.sample_rate 26 | self.linear_dim = config.num_freq 27 | self.prior_freq = config.prior_freq 28 | self.prior_weight = config.prior_weight 29 | self.gate_coefficient = config.gate_coefficient 30 | 31 | self.criterion = nn.MSELoss() 32 | self.criterion_gate = nn.BCEWithLogitsLoss() 33 | 34 | def forward(self, model_output, targets): 35 | mel_outputs, mel = model_output[0], targets[0] 36 | linear_outputs, linear = model_output[1], targets[1] 37 | gate_outputs, gate = model_output[2], targets[2] 38 | 39 | mel.requires_grad = False 40 | linear.requires_grad = False 41 | gate.requires_grad = False 42 | 43 | mel_loss = self.criterion(mel_outputs, mel) 44 | n_priority_freq = int(self.prior_freq / (self.sample_rate * 0.5) * self.linear_dim) 45 | linear_loss = (1 - self.prior_weight) * self.criterion(linear_outputs, linear) + self.prior_weight * self.criterion(linear_outputs[:, :, :n_priority_freq], linear[:, :, :n_priority_freq]) 46 | gate_loss = self.gate_coefficient * self.criterion_gate(gate_outputs, gate) 47 | 48 | loss = mel_loss + linear_loss + gate_loss 49 | losses = [loss, mel_loss, linear_loss, gate_loss] 50 | return losses 51 | 52 | 53 | ######################### 54 | # GET MASK FROM LENGTHS # 55 | ######################### 56 | """ 57 | Get mask tensor from list of length 58 | 59 | Args: 60 | memory: (batch, max_time, dim) 61 | memory_lengths: array like 62 | """ 63 | def get_rnn_mask_from_lengths(memory, memory_lengths): 64 | mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_() 65 | for idx, l in enumerate(memory_lengths): 66 | mask[idx][:l] = 1 67 | return ~mask 68 | 69 | -------------------------------------------------------------------------------- /model/tacotron.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ tacotron.py ] 4 | # Synopsis [ Tacotron model in Pytorch ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import torch 14 | from torch import nn 15 | from torch.autograd import Variable 16 | from model.attention import BahdanauAttention, LocationSensitiveAttention 17 | from model.attention import AttentionRNN 18 | from model.loss import get_rnn_mask_from_lengths 19 | 20 | 21 | ########## 22 | # PRENET # 23 | ########## 24 | class Prenet(nn.Module): 25 | 26 | def __init__(self, in_dim, sizes=[256, 128]): 27 | 28 | super(Prenet, self).__init__() 29 | in_sizes = [in_dim] + sizes[:-1] 30 | self.layers = nn.ModuleList([nn.Linear(in_size, out_size) for (in_size, out_size) in zip(in_sizes, sizes)]) 31 | self.relu = nn.ReLU() 32 | self.dropout = nn.Dropout(0.5) 33 | 34 | def forward(self, inputs): 35 | for linear in self.layers: 36 | inputs = self.dropout(self.relu(linear(inputs))) 37 | return inputs 38 | 39 | 40 | ##################### 41 | # BATCH NORM CONV1D # 42 | ##################### 43 | class BatchNormConv1d(nn.Module): 44 | 45 | def __init__(self, in_dim, out_dim, kernel_size, stride, padding, activation=None): 46 | 47 | super(BatchNormConv1d, self).__init__() 48 | self.conv1d = 
nn.Conv1d(in_dim, 49 | out_dim, 50 | kernel_size=kernel_size, 51 | stride=stride, 52 | padding=padding, 53 | bias=False) 54 | 55 | self.bn = nn.BatchNorm1d(out_dim, momentum=0.99, eps=1e-3) # Following tensorflow's default parameters 56 | self.activation = activation 57 | 58 | def forward(self, x): 59 | x = self.conv1d(x) 60 | if self.activation is not None: 61 | x = self.activation(x) 62 | return self.bn(x) 63 | 64 | 65 | ########### 66 | # HIGHWAY # 67 | ########### 68 | class Highway(nn.Module): 69 | 70 | def __init__(self, in_size, out_size): 71 | 72 | super(Highway, self).__init__() 73 | self.H = nn.Linear(in_size, out_size) 74 | self.H.bias.data.zero_() 75 | self.T = nn.Linear(in_size, out_size) 76 | self.T.bias.data.fill_(-1) 77 | self.relu = nn.ReLU() 78 | self.sigmoid = nn.Sigmoid() 79 | 80 | def forward(self, inputs): 81 | H = self.relu(self.H(inputs)) 82 | T = self.sigmoid(self.T(inputs)) 83 | return H * T + inputs * (1.0 - T) 84 | 85 | 86 | ############### 87 | # CBHG MODULE # 88 | ############### 89 | """ 90 | CBHG module: a recurrent neural network composed of: 91 | - 1-d convolution banks 92 | - Highway networks + residual connections 93 | - Bidirectional gated recurrent units 94 | """ 95 | class CBHG(nn.Module): 96 | 97 | def __init__(self, in_dim, K=16, projections=[128, 128]): 98 | 99 | super(CBHG, self).__init__() 100 | self.in_dim = in_dim 101 | self.relu = nn.ReLU() 102 | self.conv1d_banks = nn.ModuleList( 103 | [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1, 104 | padding=k // 2, activation=self.relu) 105 | for k in range(1, K + 1)]) 106 | self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) 107 | 108 | in_sizes = [K * in_dim] + projections[:-1] 109 | activations = [self.relu] * (len(projections) - 1) + [None] 110 | self.conv1d_projections = nn.ModuleList( 111 | [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, 112 | padding=1, activation=ac) 113 | for (in_size, out_size, ac) in zip( 114 | in_sizes, projections, activations)]) 115 | 116 | self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False) 117 | self.highways = nn.ModuleList( 118 | [Highway(in_dim, in_dim) for _ in range(4)]) 119 | 120 | self.gru = nn.GRU( 121 | in_dim, in_dim, 1, batch_first=True, bidirectional=True) 122 | 123 | def forward(self, inputs, input_lengths=None): 124 | 125 | x = inputs # (B, T_in, in_dim) 126 | 127 | if x.size(-1) == self.in_dim: # Needed to perform conv1d on time-axis: (B, in_dim, T_in) 128 | x = x.transpose(1, 2) 129 | 130 | T = x.size(-1) 131 | 132 | x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1) # (B, in_dim*K, T_in) -> Concat conv1d bank outputs 133 | assert x.size(1) == self.in_dim * len(self.conv1d_banks) 134 | x = self.max_pool1d(x)[:, :, :T] 135 | 136 | for conv1d in self.conv1d_projections: 137 | x = conv1d(x) 138 | 139 | x = x.transpose(1, 2) # (B, T_in, in_dim) -> Back to the original shape 140 | 141 | if x.size(-1) != self.in_dim: 142 | x = self.pre_highway(x) 143 | 144 | 145 | x += inputs # Residual connection 146 | for highway in self.highways: 147 | x = highway(x) 148 | 149 | if input_lengths is not None: 150 | x = nn.utils.rnn.pack_padded_sequence( 151 | x, input_lengths, batch_first=True) 152 | 153 | outputs, _ = self.gru(x) # (B, T_in, in_dim*2) 154 | 155 | if input_lengths is not None: 156 | outputs, _ = nn.utils.rnn.pad_packed_sequence( 157 | outputs, batch_first=True) 158 | 159 | return outputs 160 | 161 | 162 | ########### 163 | # ENCODER # 164 | ########### 165 | class Encoder(nn.Module): 
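	# Added note: character-level encoder -- a prenet (in_dim -> 256 -> 128) followed by
	# CBHG(128, K=16); the CBHG's bidirectional GRU makes the encoder output (B, T_in, 256).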
166 | 
167 | 	def __init__(self, in_dim):
168 | 
169 | 		super(Encoder, self).__init__()
170 | 		self.prenet = Prenet(in_dim, sizes=[256, 128])
171 | 		self.cbhg = CBHG(128, K=16, projections=[128, 128])
172 | 
173 | 	def forward(self, inputs, input_lengths=None):
174 | 		inputs = self.prenet(inputs)
175 | 		return self.cbhg(inputs, input_lengths)
176 | 
177 | 
178 | ###########
179 | # DECODER #
180 | ###########
181 | class Decoder(nn.Module):
182 | 
183 | 	def __init__(self, in_dim, r, attention):
184 | 
185 | 		super(Decoder, self).__init__()
186 | 		self.in_dim = in_dim
187 | 		self.r = r
188 | 		self.prenet = Prenet(in_dim * r, sizes=[256, 128])
189 | 		# (prenet_out + attention context) -> output
190 | 
191 | 		self.attention = attention
192 | 		if self.attention == 'Bahdanau':
193 | 			self.attention_rnn = AttentionRNN(nn.GRUCell(256 + 128, 256),
194 | 											  BahdanauAttention(256),
195 | 											  attention='Bahdanau')
196 | 		elif self.attention == 'LocationSensitive':
197 | 			self.attention_rnn = AttentionRNN(nn.GRUCell(256 + 128, 256),
198 | 											  LocationSensitiveAttention(256),
199 | 											  attention='LocationSensitive')
200 | 		else: raise NotImplementedError
201 | 
202 | 		self.project_to_decoder_in = nn.Linear(512, 256)
203 | 		self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)])
204 | 
205 | 		self.proj_to_mel = nn.Linear(256, in_dim * self.r)
206 | 		self.proj_to_gate = nn.Linear(256, 1 * self.r)
207 | 		self.sigmoid = nn.Sigmoid()
208 | 
209 | 		self.max_decoder_steps = 800
210 | 
211 | 
212 | 	def initialize_decoder_states(self, encoder_outputs, processed_memory, memory_lengths):
213 | 		B = encoder_outputs.size(0)
214 | 		MAX_TIME = encoder_outputs.size(1)
215 | 
216 | 		self.initial_input = Variable(encoder_outputs.data.new(B, self.in_dim * self.r).zero_()) # go frames
217 | 
218 | 		self.attention_rnn_hidden = Variable(encoder_outputs.data.new(B, 256).zero_()) # Init decoder states
219 | 		self.decoder_rnn_hiddens = [Variable(encoder_outputs.data.new(B, 256).zero_()) for _ in range(len(self.decoder_rnns))]
220 | 		self.current_attention = Variable(encoder_outputs.data.new(B, 256).zero_())
221 | 
222 | 		if self.attention == 'LocationSensitive':
223 | 			self.attention_weights = Variable(encoder_outputs.data.new(B, MAX_TIME).zero_())
224 | 			self.attention_weights_cum = Variable(encoder_outputs.data.new(B, MAX_TIME).zero_())
225 | 
226 | 		if memory_lengths is not None:
227 | 			self.mask = get_rnn_mask_from_lengths(processed_memory, memory_lengths)
228 | 		else:
229 | 			self.mask = None
230 | 
231 | 
232 | 	"""
233 | 	Decoder forward step.
234 | 
235 | 	If decoder inputs are not given (e.g., at testing time), greedy decoding is adopted, as noted in the Tacotron paper.
236 | 
237 | 	Args:
238 | 		encoder_outputs: Encoder outputs. (B, T_encoder, dim)
239 | 		inputs: Decoder inputs. i.e., mel-spectrogram. If None (at eval-time), decoder outputs are used as decoder inputs.
240 | 		memory_lengths: Encoder output (memory) lengths. If not None, used for attention masking. 
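	Returns (shapes inferred from the stacking/transpose at the end of this function):
		outputs: decoder outputs, (B, T_decoder, in_dim * r), i.e. r frames per decoder step
		alignments: attention weights per decoder step, (B, T_decoder, T_encoder)
		gates: stop-token activations, roughly (B, T_decoder, r); reshaped to (B, T_decoder * r) in Tacotron.forward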
241 | """ 242 | def forward(self, encoder_outputs, inputs=None, memory_lengths=None): 243 | 244 | greedy = inputs is None # Run greedy decoding if inputs is None 245 | 246 | if inputs is not None: 247 | if inputs.size(-1) == self.in_dim: 248 | inputs = inputs.view(encoder_outputs.size(0), inputs.size(1) // self.r, -1) # Grouping multiple frames if necessary 249 | assert inputs.size(-1) == self.in_dim * self.r 250 | T_decoder = inputs.size(1) 251 | inputs = inputs.transpose(0, 1) # Time first (T_decoder, B, in_dim) 252 | 253 | 254 | processed_memory = self.attention_rnn.memory_layer(encoder_outputs) 255 | self.initialize_decoder_states(encoder_outputs, processed_memory, memory_lengths) 256 | 257 | t = 0 258 | gates = [] 259 | outputs = [] 260 | alignments = [] 261 | current_input = self.initial_input 262 | 263 | while True: 264 | if t > 0: current_input = outputs[-1] if greedy else inputs[t - 1] 265 | 266 | current_input = self.prenet(current_input) # Prenet 267 | 268 | if self.attention == 'Bahdanau': 269 | self.attention_rnn_hidden, self.current_attention, alignment = self.attention_rnn(current_input, 270 | self.current_attention, 271 | self.attention_rnn_hidden, 272 | encoder_outputs, 273 | processed_memory=processed_memory, 274 | mask=self.mask) 275 | elif self.attention == 'LocationSensitive': 276 | self.attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1) 277 | 278 | self.attention_rnn_hidden, self.current_attention, alignment = self.attention_rnn(current_input, 279 | self.current_attention, 280 | self.attention_rnn_hidden, 281 | encoder_outputs, 282 | self.attention_weights_cat, 283 | processed_memory=processed_memory, 284 | mask=self.mask) 285 | self.attention_weights_cum += self.attention_weights 286 | 287 | decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.current_attention), -1)) # Concat RNN output and attention context vector 288 | 289 | # Pass through the decoder RNNs 290 | for idx in range(len(self.decoder_rnns)): 291 | self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx]) 292 | decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input # Residual connectinon 293 | 294 | output = decoder_input 295 | gate = self.sigmoid(self.proj_to_gate(output)).squeeze() 296 | output = self.proj_to_mel(output) 297 | 298 | outputs += [output] 299 | alignments += [alignment] 300 | gates += [gate] 301 | 302 | t += 1 303 | 304 | # testing 305 | if greedy: 306 | if t > 1 and is_end_of_gates(gate): 307 | print('Terminated by gate!') 308 | break 309 | elif t > 1 and is_end_of_frames(output): 310 | print('Terminated by silent frames!') 311 | break 312 | elif t > self.max_decoder_steps: 313 | print('Warning! 
doesn\'t seems to be converged') 314 | break 315 | # training 316 | else: 317 | if t >= T_decoder: 318 | break 319 | 320 | assert greedy or len(outputs) == T_decoder 321 | 322 | # Back to batch first: (T_out, B) -> (B, T_out) 323 | alignments = torch.stack(alignments).transpose(0, 1) 324 | outputs = torch.stack(outputs).transpose(0, 1).contiguous() 325 | gates = torch.stack(gates).transpose(0, 1).contiguous() 326 | 327 | return outputs, alignments, gates 328 | 329 | 330 | def is_end_of_gates(gate, thd=0.5): 331 | return (gate.data >= thd).all() 332 | 333 | 334 | def is_end_of_frames(output, eps=0.2): 335 | return (output.data <= eps).all() 336 | 337 | 338 | ############ 339 | # TACOTRON # 340 | ############ 341 | class Tacotron(nn.Module): 342 | 343 | def __init__(self, n_vocab, embedding_dim=256, mel_dim=80, linear_dim=1025, 344 | r=5, padding_idx=None, attention='Bahdanau', use_mask=False): 345 | 346 | super(Tacotron, self).__init__() 347 | self.mel_dim = mel_dim 348 | self.linear_dim = linear_dim 349 | self.use_mask = use_mask 350 | 351 | self.embedding = nn.Embedding(n_vocab, embedding_dim, padding_idx=padding_idx) 352 | self.embedding.weight.data.normal_(0, 0.3) # Trying smaller std 353 | 354 | self.encoder = Encoder(embedding_dim) 355 | self.decoder = Decoder(mel_dim, r, attention) 356 | 357 | self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim]) 358 | self.last_linear = nn.Linear(mel_dim * 2, linear_dim) 359 | 360 | def forward(self, inputs, targets=None, input_lengths=None): 361 | B = inputs.size(0) 362 | 363 | inputs = self.embedding(inputs) 364 | 365 | encoder_outputs = self.encoder(inputs, input_lengths) # (B, T', in_dim) 366 | 367 | if self.use_mask: memory_lengths = input_lengths 368 | else: memory_lengths = None 369 | 370 | mel_outputs, alignments, gate_outputs = self.decoder(encoder_outputs, targets, memory_lengths=memory_lengths) # (B, T', mel_dim*r) 371 | 372 | # Post net processing below 373 | 374 | mel_outputs = mel_outputs.view(B, -1, self.mel_dim) # Reshape: (B, T, mel_dim) 375 | gate_outputs = gate_outputs.view(B, -1) # Reshape: (B, T) 376 | 377 | linear_outputs = self.postnet(mel_outputs) 378 | linear_outputs = self.last_linear(linear_outputs) 379 | 380 | return mel_outputs, linear_outputs, gate_outputs, alignments 381 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ preprocess.py ] 4 | # Synopsis [ preprocess text transcripts and audio speech for the Tacotron model ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import os 14 | import glob 15 | import librosa 16 | import argparse 17 | from utils import data 18 | from tqdm import tqdm 19 | from config import config, get_preprocess_args 20 | 21 | 22 | ############# 23 | # MAKE META # 24 | ############# 25 | def make_meta(text_input_path, audio_input_dir, meta_dir, meta_text, file_suffix, num_workers, frame_shift_ms): 26 | os.makedirs(meta_dir, exist_ok=True) 27 | metadata = data.build_from_path(text_input_path, audio_input_dir, meta_dir, file_suffix, num_workers, tqdm=tqdm) 28 | 
data.write_meta_data(metadata, meta_dir, meta_text, frame_shift_ms) 29 | 30 | 31 | #################### 32 | # DATASET ANALYSIS # 33 | #################### 34 | def dataset_analysis(wav_dir, file_suffix): 35 | 36 | audios = sorted(glob.glob(os.path.join(wav_dir, '*.' + file_suffix))) 37 | print('Training data count: ', len(audios)) 38 | 39 | duration = 0.0 40 | max_d = 0 41 | min_d = 60 42 | for audio in tqdm(audios): 43 | y, sr = librosa.load(audio) 44 | d = librosa.get_duration(y=y, sr=sr) 45 | if d > max_d: max_d = d 46 | if d < min_d: min_d = d 47 | duration += d 48 | 49 | print('Sample rate: ', sr) 50 | print('Speech total length (hr): ', duration / 60**2) 51 | print('Max duration (seconds): ', max_d) 52 | print('Min duration (seconds): ', min_d) 53 | print('Average duration (seconds): ', duration / len(audios)) 54 | 55 | 56 | ######## 57 | # MAIN # 58 | ######## 59 | def main(): 60 | 61 | args = get_preprocess_args() 62 | 63 | if args.mode == 'all' or args.mode == 'make' or args.mode == 'analyze': 64 | 65 | #---preprocess text and data to be model ready---# 66 | if args.mode == 'all' or args.mode == 'make': 67 | make_meta(args.text_input_path, args.audio_input_dir, args.meta_dir, args.meta_text, args.file_suffix, args.num_workers, config.frame_shift_ms) 68 | 69 | #---dataset analyze---# 70 | if args.mode == 'all' or args.mode == 'analyze': 71 | dataset_analysis(args.audio_input_dir, args.file_suffix) 72 | 73 | else: 74 | raise RuntimeError('Invalid mode!') 75 | 76 | 77 | 78 | if __name__ == '__main__': 79 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Note: you need to install torch and tensorflow / tensorflow-gpu depending on your platform. Here we list the Pytorch and tensorflow version that we use when we built this project. 
2 | librosa==0.6.2 3 | matplotlib==2.2.3 4 | nnmnkwii==0.0.17 5 | numpy==1.15.4 6 | pycuda==2018.1.1 7 | scipy==1.2.0 8 | tensorboardX==1.5 9 | torch==0.4.1 10 | tqdm==4.29.0 -------------------------------------------------------------------------------- /result/500000/1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/1.wav -------------------------------------------------------------------------------- /result/500000/1_alignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/1_alignment.png -------------------------------------------------------------------------------- /result/500000/1_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/1_all.png -------------------------------------------------------------------------------- /result/500000/1_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/1_spectrogram.png -------------------------------------------------------------------------------- /result/500000/2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/2.wav -------------------------------------------------------------------------------- /result/500000/2_alignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/2_alignment.png -------------------------------------------------------------------------------- /result/500000/2_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/2_all.png -------------------------------------------------------------------------------- /result/500000/2_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/2_spectrogram.png -------------------------------------------------------------------------------- /result/500000/3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/3.wav -------------------------------------------------------------------------------- /result/500000/3_alignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/3_alignment.png -------------------------------------------------------------------------------- /result/500000/3_all.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/3_all.png -------------------------------------------------------------------------------- /result/500000/3_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/3_spectrogram.png -------------------------------------------------------------------------------- /result/500000/4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/4.wav -------------------------------------------------------------------------------- /result/500000/4_alignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/4_alignment.png -------------------------------------------------------------------------------- /result/500000/4_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/4_all.png -------------------------------------------------------------------------------- /result/500000/4_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/4_spectrogram.png -------------------------------------------------------------------------------- /result/500000/5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/5.wav -------------------------------------------------------------------------------- /result/500000/5_alignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/5_alignment.png -------------------------------------------------------------------------------- /result/500000/5_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/5_all.png -------------------------------------------------------------------------------- /result/500000/5_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/5_spectrogram.png -------------------------------------------------------------------------------- /result/500000/6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/6.wav -------------------------------------------------------------------------------- /result/500000/6_alignment.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/6_alignment.png -------------------------------------------------------------------------------- /result/500000/6_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/6_all.png -------------------------------------------------------------------------------- /result/500000/6_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/6_spectrogram.png -------------------------------------------------------------------------------- /result/500000/7.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/7.wav -------------------------------------------------------------------------------- /result/500000/7_alignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/7_alignment.png -------------------------------------------------------------------------------- /result/500000/7_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/7_all.png -------------------------------------------------------------------------------- /result/500000/7_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/500000/7_spectrogram.png -------------------------------------------------------------------------------- /result/checkpoint_step500000_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andi611/TTS-Tacotron-Pytorch/d6004eb491abd231249a62038eb66ebb5713eabd/result/checkpoint_step500000_all.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ test.py ] 4 | # Synopsis [ Testing algorithms for a trained Tacotron model ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import os 14 | import sys 15 | import nltk 16 | import argparse 17 | import librosa 18 | import librosa.display 19 | import numpy as np 20 | from tqdm import tqdm 21 | #--------------------------------# 22 | import torch 23 | from torch.autograd import Variable 24 | #--------------------------------# 25 | from utils import audio 26 | from utils.text import text_to_sequence, symbols 27 
| from utils.plot import test_visualize, plot_alignment 28 | #--------------------------------# 29 | from model.tacotron import Tacotron 30 | from config import config, get_test_args 31 | 32 | 33 | ############ 34 | # CONSTANT # 35 | ############ 36 | USE_CUDA = torch.cuda.is_available() 37 | 38 | 39 | ################## 40 | # TEXT TO SPEECH # 41 | ################## 42 | def tts(model, text): 43 | """Convert text to speech waveform given a Tacotron model. 44 | """ 45 | if USE_CUDA: 46 | model = model.cuda() 47 | 48 | # NOTE: dropout in the decoder should be activated for generalization! 49 | # model.decoder.eval() 50 | model.encoder.eval() 51 | model.postnet.eval() 52 | 53 | sequence = np.array(text_to_sequence(text)) 54 | sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0) 55 | if USE_CUDA: 56 | sequence = sequence.cuda() 57 | 58 | # Greedy decoding 59 | mel_outputs, linear_outputs, gate_outputs, alignments = model(sequence) 60 | 61 | linear_output = linear_outputs[0].cpu().data.numpy() 62 | spectrogram = audio._denormalize(linear_output) 63 | alignment = alignments[0].cpu().data.numpy() 64 | 65 | # Predicted audio signal 66 | waveform = audio.inv_spectrogram(linear_output.T) 67 | 68 | return waveform, alignment, spectrogram 69 | 70 | 71 | #################### 72 | # SYNTHESIS SPEECH # 73 | #################### 74 | def synthesis_speech(model, text, figures=True, path=None): 75 | waveform, alignment, spectrogram = tts(model, text) 76 | if figures: 77 | test_visualize(alignment, spectrogram, path) 78 | librosa.output.write_wav(path + '.wav', waveform, config.sample_rate) 79 | 80 | 81 | ######## 82 | # MAIN # 83 | ######## 84 | def main(): 85 | 86 | #---initialize---# 87 | args = get_test_args() 88 | 89 | model = Tacotron(n_vocab=len(symbols), 90 | embedding_dim=config.embedding_dim, 91 | mel_dim=config.num_mels, 92 | linear_dim=config.num_freq, 93 | r=config.outputs_per_step, 94 | padding_idx=config.padding_idx, 95 | attention=config.attention, 96 | use_mask=config.use_mask) 97 | 98 | #---handle path---# 99 | checkpoint_path = os.path.join(args.ckpt_dir, args.checkpoint_name + args.model_name + '.pth') 100 | os.makedirs(args.result_dir, exist_ok=True) 101 | 102 | #---load and set model---# 103 | print('Loading model: ', checkpoint_path) 104 | checkpoint = torch.load(checkpoint_path) 105 | model.load_state_dict(checkpoint["state_dict"]) 106 | model.decoder.max_decoder_steps = config.max_decoder_steps # Set large max_decoder steps to handle long sentence outputs 107 | 108 | if args.interactive == True: 109 | output_name = args.result_dir + args.model_name 110 | 111 | #---testing loop---# 112 | while True: 113 | try: 114 | text = str(input('< Tacotron > Text to speech: ')) 115 | print('Model input: ', text) 116 | synthesis_speech(model, text=text, figures=args.plot, path=output_name) 117 | except KeyboardInterrupt: 118 | print() 119 | print('Terminating!') 120 | break 121 | 122 | elif args.interactive == False: 123 | output_name = args.result_dir + args.model_name + '/' 124 | os.makedirs(output_name, exist_ok=True) 125 | 126 | #---testing flow---# 127 | with open(args.test_file_path, 'r', encoding='utf-8') as f: 128 | 129 | lines = f.readlines() 130 | for idx, line in enumerate(lines): 131 | print("{}: {} - ({} chars)".format(idx+1, line, len(line))) 132 | synthesis_speech(model, text=line, figures=args.plot, path=output_name+str(idx+1)) 133 | 134 | print("Finished! 
Check out {} for generated audio samples.".format(output_name)) 135 | 136 | else: 137 | raise RuntimeError('Invalid mode!!!') 138 | 139 | sys.exit(0) 140 | 141 | if __name__ == "__main__": 142 | main() 143 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ train.py ] 4 | # Synopsis [ Trainining script for Tacotron speech synthesis model ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | """ 11 | Usage: train.py [options] 12 | 13 | Options: 14 | --ckpt_dir Directory where to save model checkpoints [default: checkpoints]. 15 | --model_name Restore model from checkpoint path if name is given. 16 | --data_root Directory contains preprocessed features. 17 | --meta_text Name of the model-ready training transcript. 18 | --log_dir Directory for log summary writer to write in. 19 | --log_comment Comment for log summary writer. 20 | -h, --help Show this help message and exit 21 | """ 22 | 23 | 24 | ############### 25 | # IMPORTATION # 26 | ############### 27 | import os 28 | import sys 29 | import time 30 | #----------------# 31 | import numpy as np 32 | #---------------------# 33 | from utils import audio 34 | from utils.plot import plot_alignment, plot_spectrogram 35 | from utils.text import symbols 36 | #----------------------------------------------# 37 | import torch 38 | from torch import nn 39 | from torch import optim 40 | import torch.backends.cudnn as cudnn 41 | #----------------------------------------# 42 | from model.tacotron import Tacotron 43 | from model.loss import TacotronLoss 44 | from config import config, get_training_args 45 | from dataloader import Dataloader 46 | #------------------------------------------# 47 | from tensorboardX import SummaryWriter 48 | 49 | 50 | #################### 51 | # GLOBAL VARIABLES # 52 | #################### 53 | global_step = 0 54 | global_epoch = 0 55 | USE_CUDA = torch.cuda.is_available() 56 | if USE_CUDA: 57 | cudnn.benchmark = False 58 | 59 | 60 | ####################### 61 | # LEARNING RATE DECAY # 62 | ####################### 63 | def _learning_rate_decay(init_lr, global_step): 64 | warmup_steps = 6000.0 65 | step = global_step + 1. 
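# The assignment on the next line is a Noam-style warmup schedule: the learning rate ramps up
# roughly as init_lr * step / warmup_steps while step <= warmup_steps, then decays as
# init_lr * sqrt(warmup_steps / step). With the default init_lr = 0.002 and warmup_steps = 6000
# it peaks at 0.002 around step 6k and has fallen back to roughly 2e-4 by step 500k.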
66 | lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, step**-0.5) 67 | return lr 68 | 69 | 70 | ############### 71 | # SAVE STATES # 72 | ############### 73 | def save_states(global_step, mel_outputs, linear_outputs, attn, y, checkpoint_dir=None): 74 | 75 | 76 | idx = 1 # idx = np.random.randint(0, len(mel_outputs)) 77 | 78 | # Alignment 79 | path = os.path.join(checkpoint_dir, "step{}_alignment.png".format(global_step)) 80 | alignment = attn[idx].cpu().data.numpy() # alignment = attn[idx].cpu().data.numpy()[:, :input_length] 81 | plot_alignment(alignment.T, path, info="tacotron, step={}".format(global_step)) 82 | 83 | # Predicted spectrogram 84 | path = os.path.join(checkpoint_dir, "step{}_predicted_spectrogram.png".format(global_step)) 85 | linear_output = linear_outputs[idx].cpu().data.numpy() 86 | plot_spectrogram(linear_output, path) 87 | 88 | # Predicted audio signal 89 | signal = audio.inv_spectrogram(linear_output.T) 90 | path = os.path.join(checkpoint_dir, "step{}_predicted.wav".format(global_step)) 91 | audio.save_wav(signal, path) 92 | 93 | # Target spectrogram 94 | path = os.path.join(checkpoint_dir, "step{}_target_spectrogram.png".format(global_step)) 95 | linear_output = y[idx].cpu().data.numpy() 96 | plot_spectrogram(linear_output, path) 97 | 98 | 99 | ################### 100 | # SAVE CHECKPOINT # 101 | ################### 102 | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch): 103 | checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_step{}.pth".format(global_step)) 104 | torch.save({"state_dict": model.state_dict(), 105 | "optimizer": optimizer.state_dict(), 106 | "global_step": step, 107 | "global_epoch": epoch,}, 108 | checkpoint_path) 109 | 110 | 111 | ################# 112 | # TACOTRON STEP # 113 | ################# 114 | """ 115 | One step of training: Train a single batch of data on Tacotron 116 | """ 117 | def tacotron_step(model, optimizer, criterion, 118 | x, mel, y, gate, sorted_lengths, 119 | init_lr, clip_thresh, global_step): 120 | 121 | #---decay learning rate---# 122 | current_lr = _learning_rate_decay(init_lr, global_step) 123 | for param_group in optimizer.param_groups: 124 | param_group['lr'] = current_lr 125 | 126 | #---feed data---# 127 | if USE_CUDA: 128 | x, mel, y, gate, = x.cuda(), mel.cuda(), y.cuda(), gate.cuda() 129 | mel_outputs, linear_outputs, gate_outputs, attn = model(x, mel, input_lengths=sorted_lengths) 130 | 131 | losses = criterion([mel_outputs, linear_outputs, gate_outputs], [mel, y, gate]) 132 | 133 | #---log loss---# 134 | loss, total_L = losses[0], losses[0].item() 135 | mel_loss, mel_L = losses[1], losses[1].item(), 136 | linear_loss, linear_L = losses[2], losses[2].item() 137 | gate_loss, gate_L = losses[3], losses[3].item() 138 | 139 | #---update model---# 140 | optimizer.zero_grad() 141 | loss.backward() 142 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_thresh) 143 | optimizer.step() 144 | 145 | #---wrap up returns---# 146 | Ms = { 'mel_outputs' : mel_outputs, 147 | 'linear_outputs' : linear_outputs, 148 | 'attn' : attn, 149 | 'sorted_lengths' : sorted_lengths, 150 | 'grad_norm' : grad_norm, 151 | 'current_lr' : current_lr } 152 | Ls = { 'total_L': total_L, 153 | 'mel_L' : mel_L, 154 | 'linear_L' : linear_L, 155 | 'gate_L' : gate_L } 156 | 157 | return model, optimizer, Ms, Ls 158 | 159 | 160 | ######### 161 | # TRAIN # 162 | ######### 163 | """ 164 | Main training loop 165 | """ 166 | def train(model, 167 | optimizer, 168 | dataloader, 169 | init_lr=0.002, 
170 | log_dir=None, 171 | log_comment=None, 172 | checkpoint_dir=None, 173 | checkpoint_interval=None, 174 | max_epochs=None, 175 | max_steps=None, 176 | clip_thresh=1.0): 177 | 178 | if USE_CUDA: 179 | model = model.cuda() 180 | 181 | model.train() 182 | criterion = TacotronLoss() 183 | 184 | if log_dir != None: 185 | writer = SummaryWriter(log_dir) 186 | elif log_comment != None: 187 | writer = SummaryWriter(comment=log_comment) 188 | else: 189 | writer = SummaryWriter() 190 | 191 | global global_step, global_epoch 192 | 193 | while global_epoch < max_epochs and global_step < max_steps: 194 | 195 | start = time.time() 196 | 197 | for x, mel, y, gate, sorted_lengths in dataloader: 198 | 199 | model, optimizer, Ms, Rs = tacotron_step(model, optimizer, criterion, 200 | x, mel, y, gate, sorted_lengths, 201 | init_lr, clip_thresh, global_step) 202 | 203 | mel_outputs = Ms['mel_outputs'] 204 | linear_outputs = Ms['linear_outputs'] 205 | attn = Ms['attn'] 206 | sorted_lengths = Ms['sorted_lengths'] 207 | grad_norm = Ms['grad_norm'] 208 | current_lr = Ms['current_lr'] 209 | 210 | total_L = Rs['total_L'] 211 | mel_L = Rs['mel_L'] 212 | linear_L = Rs['linear_L'] 213 | gate_L = Rs['gate_L'] 214 | 215 | duration = time.time() - start 216 | if global_step > 0 and global_step % checkpoint_interval == 0: 217 | try: 218 | save_states(global_step, mel_outputs, linear_outputs, attn, y, checkpoint_dir) 219 | save_checkpoint(model, optimizer, global_step, checkpoint_dir, global_epoch) 220 | except Exception: 221 | print() 222 | print('An error has occurred during saving! Please check and handle it manually!') 223 | pass 224 | log = '[{}] total_L: {:.3f}, mel_L: {:.3f}, lin_L: {:.3f}, gate_L: {:.3f}, grad: {:.3f}, lr: {:.5f}, t: {:.2f}s, saved: T'.format(global_step, total_L, mel_L, linear_L, gate_L, grad_norm, current_lr, duration) 225 | print(log) 226 | elif global_step % 5 == 0: 227 | log = '[{}] total_L: {:.3f}, mel_L: {:.3f}, lin_L: {:.3f}, gate_L: {:.3f}, grad: {:.3f}, lr: {:.5f}, t: {:.2f}s, saved: F'.format(global_step, total_L, mel_L, linear_L, gate_L, grad_norm, current_lr, duration) 228 | print(log, end='\r') 229 | 230 | # Logs 231 | writer.add_scalar('total_loss', total_L, global_step) 232 | writer.add_scalar('mel_loss', mel_L, global_step) 233 | writer.add_scalar('linear_loss', linear_L, global_step) 234 | writer.add_scalar('gate_loss', gate_L, global_step) 235 | writer.add_scalar('grad_norm', grad_norm, global_step) 236 | writer.add_scalar('learning_rate', current_lr, global_step) 237 | 238 | global_step += 1 239 | start = time.time() 240 | 241 | global_epoch += 1 242 | 243 | ######################## 244 | # WARM FROM CHECKPOINT # 245 | ######################## 246 | """ 247 | Initialize training from a pre-trained model checkpoint (.pth) 248 | 249 | Args: 250 | checkpoint_dir, model_name: e.g. ckpt/ and 200000 (loads ckpt/checkpoint_step200000.pth) 251 | model: Pytorch model 252 | optimizer: Pytorch optimizer 253 | """ 254 | def warm_from_ckpt(checkpoint_dir, model_name, model, optimizer): 255 | checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_step{}.pth".format(model_name)) 256 | print('[Trainer] - Warming up! 
Load checkpoint from: {}'.format(checkpoint_path)) 257 | 258 | checkpoint = torch.load(checkpoint_path) 259 | model.load_state_dict(checkpoint['state_dict']) 260 | 261 | optimizer.load_state_dict(checkpoint['optimizer']) 262 | for state in optimizer.state.values(): 263 | for k, v in state.items(): 264 | if torch.is_tensor(v): 265 | state[k] = v.cuda() 266 | try: 267 | global global_step, global_epoch 268 | global_step = checkpoint['global_step'] 269 | global_epoch = checkpoint['global_epoch'] 270 | except: 271 | print('[Trainer] - Warning: global step and global epoch unable to restore!') 272 | sys.exit(0) 273 | 274 | return model, optimizer 275 | 276 | 277 | ####################### 278 | # INITIALIZE TRAINING # 279 | ####################### 280 | """ 281 | Setup and prepare for Tacotron training. 282 | """ 283 | def initialize_training(data_root, meta_text, checkpoint_dir=None, model_name=None): 284 | 285 | dataloader = Dataloader(data_root, meta_text) 286 | 287 | model = Tacotron(n_vocab=len(symbols), 288 | embedding_dim=config.embedding_dim, 289 | mel_dim=config.num_mels, 290 | linear_dim=config.num_freq, 291 | r=config.outputs_per_step, 292 | padding_idx=config.padding_idx, 293 | attention=config.attention, 294 | use_mask=config.use_mask) 295 | 296 | optimizer = optim.Adam(model.parameters(), 297 | lr=config.initial_learning_rate, 298 | betas=(config.adam_beta1, config.adam_beta2), 299 | weight_decay=config.weight_decay) 300 | 301 | # Load checkpoint 302 | if model_name != None: 303 | model, optimizer = warm_from_ckpt(checkpoint_dir, model_name, model, optimizer) 304 | 305 | return model, optimizer, dataloader 306 | 307 | 308 | ######## 309 | # MAIN # 310 | ######## 311 | def main(): 312 | 313 | args = get_training_args() 314 | 315 | os.makedirs(args.ckpt_dir, exist_ok=True) 316 | 317 | model, optimizer, dataloader = initialize_training(args.data_root, args.meta_text, args.ckpt_dir, args.model_name) 318 | 319 | # Train! 
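# All loop hyperparameters handed to train() below (initial_learning_rate, checkpoint_interval,
# max_epochs, max_steps, clip_thresh) are read from config.py, and KeyboardInterrupt is caught
# so that stopping training with Ctrl-C still exits cleanly.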
320 | try: 321 | train(model, optimizer, dataloader, 322 | init_lr=config.initial_learning_rate, 323 | log_dir=args.log_dir, 324 | log_comment=args.log_comment, 325 | checkpoint_dir=args.ckpt_dir, 326 | checkpoint_interval=config.checkpoint_interval, 327 | max_epochs=config.max_epochs, 328 | max_steps=config.max_steps, 329 | clip_thresh=config.clip_thresh) 330 | except KeyboardInterrupt: 331 | pass 332 | 333 | print() 334 | print('[Trainer] - Finished!') 335 | sys.exit(0) 336 | 337 | 338 | if __name__ == '__main__': 339 | main() -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ audio.py ] 4 | # Synopsis [ audio utility functions ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import math 14 | import numpy as np 15 | import librosa 16 | import librosa.filters 17 | from scipy import signal 18 | from config import config 19 | from scipy.io import wavfile 20 | 21 | 22 | ############# 23 | # FUNCTIONS # 24 | ############# 25 | 26 | 27 | def load_wav(path): 28 | return librosa.core.load(path, sr=config.sample_rate)[0] 29 | 30 | 31 | def save_wav(wav, path): 32 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 33 | wavfile.write(path, config.sample_rate, wav.astype(np.int16)) 34 | 35 | 36 | def preemphasis(x): 37 | return signal.lfilter([1, -config.preemphasis], [1], x) 38 | 39 | 40 | def inv_preemphasis(x): 41 | return signal.lfilter([1], [1, -config.preemphasis], x) 42 | 43 | 44 | def spectrogram(y): 45 | D = _stft(preemphasis(y)) 46 | S = _amp_to_db(np.abs(D)) - config.ref_level_db 47 | return _normalize(S) 48 | 49 | 50 | def inv_spectrogram(spectrogram): 51 | """ 52 | Converts spectrogram to waveform using librosa 53 | """ 54 | S = _db_to_amp(_denormalize(spectrogram) + config.ref_level_db) # Convert back to linear 55 | return inv_preemphasis(_griffin_lim(S ** config.power)) # Reconstruct phase 56 | 57 | 58 | def melspectrogram(y): 59 | D = _stft(preemphasis(y)) 60 | S = _amp_to_db(_linear_to_mel(np.abs(D))) 61 | return _normalize(S) 62 | 63 | 64 | def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8): 65 | window_length = int(config.sample_rate * min_silence_sec) 66 | hop_length = int(window_length / 4) 67 | threshold = _db_to_amp(threshold_db) 68 | for x in range(hop_length, len(wav) - window_length, hop_length): 69 | if np.max(wav[x:x+window_length]) < threshold: 70 | return x + hop_length 71 | return len(wav) 72 | 73 | 74 | def _griffin_lim(S): 75 | """ 76 | librosa implementation of Griffin-Lim 77 | Based on https://github.com/librosa/librosa/issues/434 78 | """ 79 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 80 | S_complex = np.abs(S).astype(np.complex) 81 | y = _istft(S_complex * angles) 82 | for i in range(config.griffin_lim_iters): 83 | angles = np.exp(1j * np.angle(_stft(y))) 84 | y = _istft(S_complex * angles) 85 | return y 86 | 87 | 88 | def _stft(y): 89 | n_fft, hop_length, win_length = _stft_parameters() 90 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 91 | 92 | 93 | def _istft(y): 94 | _, hop_length, win_length = 
_stft_parameters() 95 | return librosa.istft(y, hop_length=hop_length, win_length=win_length) 96 | 97 | 98 | def _stft_parameters(): 99 | n_fft = (config.num_freq - 1) * 2 100 | hop_length = int(config.frame_shift_ms / 1000 * config.sample_rate) 101 | win_length = int(config.frame_length_ms / 1000 * config.sample_rate) 102 | return n_fft, hop_length, win_length 103 | 104 | 105 | ######################## 106 | # CONVERSION FUNCTIONS # 107 | ######################## 108 | 109 | 110 | _mel_basis = None 111 | 112 | def _linear_to_mel(spectrogram): 113 | global _mel_basis 114 | if _mel_basis is None: 115 | _mel_basis = _build_mel_basis() 116 | return np.dot(_mel_basis, spectrogram) 117 | 118 | def _build_mel_basis(): 119 | n_fft = (config.num_freq - 1) * 2 120 | return librosa.filters.mel(config.sample_rate, n_fft, n_mels=config.num_mels) 121 | 122 | def _amp_to_db(x): 123 | return 20 * np.log10(np.maximum(1e-5, x)) 124 | 125 | def _db_to_amp(x): 126 | return np.power(10.0, x * 0.05) 127 | 128 | def _normalize(S): 129 | return np.clip((S - config.min_level_db) / -config.min_level_db, 0, 1) 130 | 131 | def _denormalize(S): 132 | return (np.clip(S, 0, 1) * -config.min_level_db) + config.min_level_db 133 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ data.py ] 4 | # Synopsis [ utility functions for preprocess.py ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import os 14 | import numpy as np 15 | from . import audio 16 | from functools import partial 17 | from concurrent.futures import ProcessPoolExecutor 18 | 19 | 20 | ################### 21 | # WRITE META DATA # 22 | ################### 23 | def write_meta_data(metadata, out_dir, meta_text, frame_shift_ms): 24 | with open(os.path.join(out_dir, meta_text), 'w', encoding='utf-8') as f: 25 | for m in metadata: 26 | f.write('|'.join([str(x) for x in m]) + '\n') 27 | frames = sum([m[2] for m in metadata]) 28 | hours = frames * frame_shift_ms / (3600 * 1000) 29 | print('Wrote %d utterances, %d frames (%.2f hours)' % (len(metadata), frames, hours)) 30 | print('Max input length: %d' % max(len(m[3]) for m in metadata)) 31 | print('Max output length: %d' % max(m[2] for m in metadata)) 32 | 33 | 34 | ################### 35 | # BUILD FROM PATH # 36 | ################### 37 | """ 38 | Preprocesses the dataset from given input paths into a given output directory. 39 | Use ProcessPoolExecutor to parallize across processes. This is just an optimization and you 40 | can omit it and just call _process_utterance on each input if you want. 41 | 42 | Args: 43 | transcript_path: The path to the transcript file 44 | wav_dir: The directory where the audio is contained 45 | out_dir: The directory to write the output into 46 | num_workers: Optional number of worker processes to parallelize across 47 | tqdm: You can optionally pass tqdm to get a nice progress bar 48 | 49 | Returns: 50 | A list of tuples describing the training examples. 
This should be written to meta_text.txt by 'write_meta_data' 51 | """ 52 | def build_from_path(transcript_path, wav_dir, out_dir, file_suffix, num_workers=1, tqdm=lambda x: x): 53 | executor = ProcessPoolExecutor(max_workers=num_workers) 54 | futures = [] 55 | index = 1 56 | with open(transcript_path, encoding='utf-8') as f: 57 | for line in f: 58 | tokens = line.strip().split('|') 59 | wav_path = os.path.join(wav_dir, '%s.%s' % (tokens[0], file_suffix)) 60 | text = tokens[1] 61 | futures.append(executor.submit(partial(_process_utterance, out_dir, index, wav_path, text))) 62 | index += 1 63 | return [future.result() for future in tqdm(futures)] 64 | 65 | 66 | ##################### 67 | # PROCESS UTTERANCE # 68 | ##################### 69 | """ 70 | Preprocesses a single utterance audio/text pair. 71 | 72 | This writes the mel and linear scale spectrograms to disk and returns a tuple to write 73 | to the meta_text.txt file. 74 | 75 | Args: 76 | out_dir: The directory to write the spectrograms into 77 | index: The numeric index to use in the spectrogram filenames. 78 | wav_path: Path to the audio file containing the speech input 79 | text: The text spoken in the input audio file 80 | 81 | Returns: 82 | A (spectrogram_filename, mel_filename, n_frames, text) tuple to write to meta_text.txt 83 | """ 84 | def _process_utterance(out_dir, index, wav_path, text): 85 | 86 | # Load the audio to a numpy array: 87 | wav = audio.load_wav(wav_path) 88 | 89 | # Compute the linear-scale spectrogram from the wav: 90 | spectrogram = audio.spectrogram(wav).astype(np.float32) 91 | n_frames = spectrogram.shape[1] 92 | 93 | # Compute a mel-scale spectrogram from the wav: 94 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32) 95 | 96 | # Write the spectrograms to disk: 97 | spectrogram_filename = 'meta_spec_%05d.npy' % index 98 | mel_filename = 'meta_mel_%05d.npy' % index 99 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) 100 | np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False) 101 | 102 | # Return a tuple describing this training example: 103 | return (spectrogram_filename, mel_filename, n_frames, text) 104 | 105 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # FileName [ plot.py ] 4 | # Synopsis [ plot utility functions ] 5 | # Author [ Ting-Wei Liu (Andi611) ] 6 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 7 | """*********************************************************************************************""" 8 | 9 | 10 | ############### 11 | # IMPORTATION # 12 | ############### 13 | import numpy as np 14 | import librosa.display 15 | from . 
import audio 16 | from config import config 17 | import matplotlib.pyplot as plt 18 | plt.switch_backend('agg') 19 | 20 | 21 | ############# 22 | # CONSTANTS # 23 | ############# 24 | fs = config.sample_rate 25 | win = config.frame_length_ms 26 | hop = config.frame_shift_ms 27 | nfft = (config.num_freq - 1) * 2 28 | hop_length = config.hop_length 29 | 30 | 31 | ################## 32 | # PLOT ALIGNMENT # 33 | ################## 34 | def plot_alignment(alignment, path, info=None): 35 | plt.gcf().clear() 36 | fig, ax = plt.subplots() 37 | im = ax.imshow( 38 | alignment, 39 | aspect='auto', 40 | origin='lower', 41 | interpolation='none') 42 | fig.colorbar(im, ax=ax) 43 | xlabel = 'Decoder timestep' 44 | if info is not None: 45 | xlabel += '\n\n' + info 46 | plt.xlabel(xlabel) 47 | plt.ylabel('Encoder timestep') 48 | plt.tight_layout() 49 | plt.savefig(path, dpi=300, format='png') 50 | plt.close() 51 | 52 | 53 | #################### 54 | # PLOT SPECTROGRAM # 55 | #################### 56 | def plot_spectrogram(linear_output, path): 57 | spectrogram = audio._denormalize(linear_output) 58 | plt.gcf().clear() 59 | plt.figure(figsize=(16, 10)) 60 | plt.imshow(spectrogram.T, aspect="auto", origin="lower") 61 | plt.colorbar() 62 | plt.tight_layout() 63 | plt.savefig(path, dpi=300, format="png") 64 | plt.close() 65 | 66 | 67 | ################## 68 | # TEST VISUALIZE # 69 | ################## 70 | def test_visualize(alignment, spectrogram, path): 71 | 72 | _save_alignment(alignment, path) 73 | _save_spectrogram(spectrogram, path) 74 | label_fontsize = 16 75 | plt.gcf().clear() 76 | plt.figure(figsize=(16,16)) 77 | 78 | plt.subplot(2,1,1) 79 | plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None) 80 | plt.xlabel("Decoder timestamp", fontsize=label_fontsize) 81 | plt.ylabel("Encoder timestamp", fontsize=label_fontsize) 82 | plt.colorbar() 83 | 84 | plt.subplot(2,1,2) 85 | librosa.display.specshow(spectrogram.T, sr=fs, 86 | hop_length=hop_length, x_axis="time", y_axis="linear") 87 | plt.xlabel("Time", fontsize=label_fontsize) 88 | plt.ylabel("Hz", fontsize=label_fontsize) 89 | plt.tight_layout() 90 | plt.colorbar() 91 | 92 | plt.savefig(path + '_all.png', dpi=300, format='png') 93 | plt.close() 94 | 95 | 96 | ################## 97 | # SAVE ALIGNMENT # 98 | ################## 99 | def _save_alignment(alignment, path): 100 | plt.gcf().clear() 101 | plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None) 102 | plt.xlabel("Decoder timestamp") 103 | plt.ylabel("Encoder timestamp") 104 | plt.colorbar() 105 | plt.savefig(path + '_alignment.png', dpi=300, format='png') 106 | 107 | 108 | #################### 109 | # SAVE SPECTROGRAM # 110 | #################### 111 | def _save_spectrogram(spectrogram, path): 112 | plt.gcf().clear() # Clear current previous figure 113 | cmap = plt.get_cmap('jet') 114 | t = win + np.arange(spectrogram.shape[0]) * hop 115 | f = np.arange(spectrogram.shape[1]) * fs / nfft 116 | plt.pcolormesh(t, f, spectrogram.T, cmap=cmap) 117 | plt.xlabel('Time (sec)') 118 | plt.ylabel('Frequency (Hz)') 119 | plt.colorbar() 120 | plt.savefig(path + '_spectrogram.png', dpi=300, format='png') 121 | -------------------------------------------------------------------------------- /utils/text/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | from . 
import cleaners 3 | from .symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | # Regular expression matching text enclosed in curly braces: 11 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 12 | 13 | 14 | def text_to_sequence(text): 15 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 16 | 17 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 18 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 19 | 20 | Args: 21 | text: string to convert to a sequence 22 | 23 | 24 | Returns: 25 | List of integers corresponding to the symbols in the text 26 | ''' 27 | sequence = [] 28 | 29 | # Check for curly braces and treat their contents as ARPAbet: 30 | while len(text): 31 | m = _curly_re.match(text) 32 | if not m: 33 | sequence += _symbols_to_sequence(text) 34 | break 35 | sequence += _symbols_to_sequence(m.group(1)) 36 | sequence += _arpabet_to_sequence(m.group(2)) 37 | text = m.group(3) 38 | 39 | # Append EOS token 40 | sequence.append(_symbol_to_id['~']) 41 | return sequence 42 | 43 | 44 | def sequence_to_text(sequence): 45 | '''Converts a sequence of IDs back to a string''' 46 | result = '' 47 | for symbol_id in sequence: 48 | if symbol_id in _id_to_symbol: 49 | s = _id_to_symbol[symbol_id] 50 | # Enclose ARPAbet back in curly braces: 51 | if len(s) > 1 and s[0] == '@': 52 | s = '{%s}' % s[1:] 53 | result += s 54 | return result.replace('}{', ' ') 55 | 56 | 57 | def _clean_text(text, cleaner_names): 58 | for name in cleaner_names: 59 | cleaner = getattr(cleaners, name) 60 | if not cleaner: 61 | raise Exception('Unknown cleaner: %s' % name) 62 | text = cleaner(text) 63 | return text 64 | 65 | 66 | def _symbols_to_sequence(symbols): 67 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 68 | 69 | 70 | def _arpabet_to_sequence(text): 71 | return _symbols_to_sequence(['@' + s for s in text.split()]) 72 | 73 | 74 | def _should_keep_symbol(s): 75 | return s in _symbol_to_id and s != '_' and s != '~' 76 | -------------------------------------------------------------------------------- /utils/text/cleaners.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Cleaners are transformations that run over the input text at both training and eval time. 3 | 4 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 5 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 6 | 1. "english_cleaners" for English text 7 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 8 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 9 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 10 | the symbols in symbols.py to match your data). 11 | ''' 12 | 13 | import re 14 | from unidecode import unidecode 15 | from .numbers import normalize_numbers 16 | 17 | 18 | # Regular expression matching whitespace: 19 | _whitespace_re = re.compile(r'\s+') 20 | 21 | # List of (regular expression, replacement) pairs for abbreviations: 22 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 23 | ('mrs', 'misess'), 24 | ('mr', 'mister'), 25 | ('dr', 'doctor'), 26 | ('st', 'saint'), 27 | ('co', 'company'), 28 | ('jr', 'junior'), 29 | ('maj', 'major'), 30 | ('gen', 'general'), 31 | ('drs', 'doctors'), 32 | ('rev', 'reverend'), 33 | ('lt', 'lieutenant'), 34 | ('hon', 'honorable'), 35 | ('sgt', 'sergeant'), 36 | ('capt', 'captain'), 37 | ('esq', 'esquire'), 38 | ('ltd', 'limited'), 39 | ('col', 'colonel'), 40 | ('ft', 'fort'), 41 | ]] 42 | 43 | 44 | def expand_abbreviations(text): 45 | for regex, replacement in _abbreviations: 46 | text = re.sub(regex, replacement, text) 47 | return text 48 | 49 | 50 | def expand_numbers(text): 51 | return normalize_numbers(text) 52 | 53 | 54 | def lowercase(text): 55 | return text.lower() 56 | 57 | 58 | def collapse_whitespace(text): 59 | return re.sub(_whitespace_re, ' ', text) 60 | 61 | 62 | def convert_to_ascii(text): 63 | return unidecode(text) 64 | 65 | 66 | def basic_cleaners(text): 67 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 68 | text = lowercase(text) 69 | text = collapse_whitespace(text) 70 | return text 71 | 72 | 73 | def transliteration_cleaners(text): 74 | '''Pipeline for non-English text that transliterates to ASCII.''' 75 | text = convert_to_ascii(text) 76 | text = lowercase(text) 77 | text = collapse_whitespace(text) 78 | return text 79 | 80 | 81 | def english_cleaners(text): 82 | '''Pipeline for English text, including number and abbreviation expansion.''' 83 | text = convert_to_ascii(text) 84 | text = lowercase(text) 85 | text = expand_numbers(text) 86 | text = expand_abbreviations(text) 87 | text = collapse_whitespace(text) 88 | return text 89 | -------------------------------------------------------------------------------- /utils/text/cmudict.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | valid_symbols = [ 5 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 6 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 7 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 8 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 9 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 10 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 11 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 12 | ] 13 | 14 | _valid_symbol_set = set(valid_symbols) 15 | 16 | 17 | class CMUDict: 18 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 19 | def __init__(self, file_or_path, keep_ambiguous=True): 20 | if isinstance(file_or_path, str): 21 | with open(file_or_path, encoding='latin-1') as f: 22 | entries = _parse_cmudict(f) 23 | else: 24 | entries = _parse_cmudict(file_or_path) 25 | if not keep_ambiguous: 26 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 27 | self._entries = entries 28 | 29 | 30 | def __len__(self): 31 | return len(self._entries) 32 | 33 | 34 | def lookup(self, word): 35 | '''Returns list of ARPAbet pronunciations of the given word.''' 36 | return self._entries.get(word.upper()) 37 | 38 | 39 | 40 | _alt_re = re.compile(r'\([0-9]+\)') 41 | 42 | 43 | def _parse_cmudict(file): 44 | cmudict = {} 45 | for line in file: 46 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 47 | parts = line.split(' ') 48 | word = re.sub(_alt_re, '', parts[0]) 49 | pronunciation = _get_pronunciation(parts[1]) 50 | if pronunciation: 51 | if word in cmudict: 52 | cmudict[word].append(pronunciation) 53 | else: 54 | cmudict[word] = [pronunciation] 55 | return cmudict 56 | 57 | 58 | def _get_pronunciation(s): 59 | parts = s.strip().split(' ') 60 | for part in parts: 61 | if part not in _valid_symbol_set: 62 | return None 63 | return ' '.join(parts) 64 | -------------------------------------------------------------------------------- /utils/text/numbers.py: -------------------------------------------------------------------------------- 1 | import inflect 2 | import re 3 | 4 | 5 | _inflect = inflect.engine() 6 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 7 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 8 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 9 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 10 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 11 | _number_re = re.compile(r'[0-9]+') 12 | 13 | 14 | def _remove_commas(m): 15 | return m.group(1).replace(',', '') 16 | 17 | 18 | def _expand_decimal_point(m): 19 | return m.group(1).replace('.', ' point ') 20 | 21 | 22 | def _expand_dollars(m): 23 | match = m.group(1) 24 | parts = match.split('.') 25 | if len(parts) > 2: 26 | return match + ' dollars' # Unexpected format 27 | dollars = int(parts[0]) if parts[0] else 0 28 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 29 | if dollars and cents: 30 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 31 | cent_unit = 'cent' if cents == 1 else 'cents' 32 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 33 | elif dollars: 34 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 35 | return '%s %s' % (dollars, dollar_unit) 36 | elif cents: 37 | cent_unit = 'cent' if cents == 1 else 'cents' 38 | return '%s %s' % (cents, cent_unit) 39 | else: 40 | return 'zero dollars' 41 | 42 | 43 | def _expand_ordinal(m): 44 | return _inflect.number_to_words(m.group(0)) 45 | 46 | 47 | def _expand_number(m): 48 | num = int(m.group(0)) 49 | if num > 1000 and num < 3000: 50 | if num == 2000: 51 | return 'two thousand' 52 | elif num > 2000 and num < 2010: 53 | return 'two thousand ' + _inflect.number_to_words(num % 100) 54 | elif num % 100 == 0: 55 | return _inflect.number_to_words(num // 100) + ' hundred' 56 | else: 57 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 58 | else: 59 | return _inflect.number_to_words(num, andword='') 60 | 61 | 62 | def normalize_numbers(text): 63 | text = re.sub(_comma_number_re, _remove_commas, text) 64 | text = 
re.sub(_pounds_re, r'\1 pounds', text) 65 | text = re.sub(_dollars_re, _expand_dollars, text) 66 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 67 | text = re.sub(_ordinal_re, _expand_ordinal, text) 68 | text = re.sub(_number_re, _expand_number, text) 69 | return text 70 | -------------------------------------------------------------------------------- /utils/text/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 3 | 4 | The default is a set of ASCII characters that works well for English or text that has been run 5 | through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. 6 | ''' 7 | from . import cmudict 8 | 9 | _pad = '_' 10 | _eos = '~' 11 | _characters = '1234ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' 12 | 13 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 14 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 15 | 16 | # Export all symbols: 17 | symbols = [_pad, _eos] + list(_characters) + _arpabet 18 | --------------------------------------------------------------------------------
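For reference, a minimal sketch of how the utils/text front-end above can be exercised on its own (an illustration, not a file in the repository; it assumes the repository root is the working directory and that the unidecode and inflect packages it imports are installed):

# Plain characters are looked up via symbols.py; anything inside {...} is parsed as ARPAbet.
from utils.text import text_to_sequence, sequence_to_text

ids = text_to_sequence('Turn left on {HH AW1 S S T AH0 N} Street.')
print(ids)                    # a list of integer symbol IDs, ending with the EOS id for '~'
print(sequence_to_text(ids))  # round-trips to the original text, with a trailing '~' appended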