├── .gitignore ├── LICENSE ├── README.md ├── audio_processing.py ├── commands.sh ├── cpu_gpu_viterbi_benchmark.ipynb ├── extract_alignments.py ├── figures ├── align.png ├── alignments.JPG ├── mdn sample.JPG ├── melspecs.JPG ├── stage0_train_loss.JPG ├── stage0_val_loss.JPG ├── stage1_train_loss.JPG ├── stage1_val_loss.JPG ├── stage2_train_loss.JPG ├── stage2_val_fft_loss.JPG ├── stage2_val_mdn_loss.JPG ├── stage3_train_loss.JPG └── stage3_val_loss.JPG ├── filelists ├── ljs_audio_text_test_filelist.txt ├── ljs_audio_text_train_filelist.txt └── ljs_audio_text_val_filelist.txt ├── generate_samples.ipynb ├── hparams.py ├── index.html ├── inference.ipynb ├── layers.py ├── modules ├── __pycache__ │ ├── init_layer.cpython-36.pyc │ ├── init_layer.cpython-37.pyc │ ├── loss.cpython-36.pyc │ ├── loss.cpython-37.pyc │ ├── model.cpython-36.pyc │ ├── model.cpython-37.pyc │ ├── transformer.cpython-36.pyc │ └── transformer.cpython-37.pyc ├── init_layer.py ├── loss.py ├── model.py └── transformer.py ├── prepare_data.ipynb ├── prepare_stages_benchmark.ipynb ├── requirement.txt ├── stft.py ├── text ├── LICENSE ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── cleaners.cpython-36.pyc │ ├── cleaners.cpython-37.pyc │ ├── cmudict.cpython-36.pyc │ ├── cmudict.cpython-37.pyc │ ├── numbers.cpython-36.pyc │ ├── numbers.cpython-37.pyc │ ├── symbols.cpython-36.pyc │ └── symbols.cpython-37.pyc ├── cleaners.py ├── cmudict.py ├── numbers.py └── symbols.py ├── train.py ├── training_log └── readme.txt ├── utils ├── __pycache__ │ ├── data_utils.cpython-36.pyc │ ├── data_utils.cpython-37.pyc │ ├── plot_image.cpython-36.pyc │ ├── plot_image.cpython-37.pyc │ ├── utils.cpython-36.pyc │ ├── utils.cpython-37.pyc │ ├── writer.cpython-36.pyc │ └── writer.cpython-37.pyc ├── data_utils.py ├── plot_image.py ├── utils.py └── writer.py ├── waveglow ├── .gitmodules ├── LICENSE ├── README.md ├── config.json ├── convert_model.py ├── denoiser.py ├── distributed.py ├── 
glow.py ├── glow_old.py ├── inference.py ├── mel2samp.py ├── requirements.txt ├── train.py └── waveglow_logo.png └── wavs ├── LJ001-0029_phone10000_10.wav ├── LJ001-0029_phone10000_11.wav ├── LJ001-0029_phone10000_12.wav ├── LJ001-0029_phone10000_8.wav ├── LJ001-0029_phone10000_9.wav ├── LJ001-0085_phone10000_10.wav ├── LJ001-0085_phone10000_11.wav ├── LJ001-0085_phone10000_12.wav ├── LJ001-0085_phone10000_8.wav ├── LJ001-0085_phone10000_9.wav ├── LJ002-0106_phone10000_10.wav ├── LJ002-0106_phone10000_11.wav ├── LJ002-0106_phone10000_12.wav ├── LJ002-0106_phone10000_8.wav └── LJ002-0106_phone10000_9.wav /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .ipynb_checkpoints/ 3 | runs/ 4 | training_log/ 5 | nohup.out 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Deepest-Project 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlignTTS 2 | Implementation of the [AlignTTS](https://arxiv.org/abs/2003.01950) 3 | 4 | # Figures 5 | ## Losses 6 | ### stage0 7 | 8 | ### stage1 9 | 10 | ### stage2 11 | 12 | ### stage3 13 | 14 | 15 | ## Alignments 16 | 17 | 18 | ## Melspectrograms 19 | 20 | 21 | ## Losses 22 | ## MDN Sample 23 | 24 | 25 | ## Alignment Sample 26 | 27 | 28 | ## Audio Samples 29 | You can hear the audio samples [here](https://deepest-project.github.io/AlignTTS) 30 | -------------------------------------------------------------------------------- /audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, 8 | n_frames, 9 | hop_length=200, 10 | win_length=800, 11 | n_fft=800, 12 | dtype=np.float32, 13 | norm=None): 14 | """ 15 | # from librosa 0.6 16 | Compute the sum-square envelope of a window function at a given hop length. 17 | 18 | This is used to estimate modulation effects induced by windowing 19 | observations in short-time fourier transforms. 20 | 21 | Parameters 22 | ---------- 23 | window : string, tuple, number, callable, or list-like 24 | Window specification, as in `get_window` 25 | 26 | n_frames : int > 0 27 | The number of analysis frames 28 | 29 | hop_length : int > 0 30 | The number of samples to advance between frames 31 | 32 | win_length : [optional] 33 | The length of the window function. 
By default, this matches `n_fft`. 34 | 35 | n_fft : int > 0 36 | The length of each analysis frame. 37 | 38 | dtype : np.dtype 39 | The data type of the output 40 | 41 | Returns 42 | ------- 43 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 44 | The sum-squared envelope of the window function 45 | """ 46 | if win_length is None: 47 | win_length = n_fft 48 | 49 | n = n_fft + hop_length * (n_frames - 1) 50 | x = np.zeros(n, dtype=dtype) 51 | 52 | # Compute the squared window at the desired length 53 | win_sq = get_window(window, win_length, fftbins=True) 54 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 55 | win_sq = librosa_util.pad_center(win_sq, n_fft) 56 | 57 | # Fill the envelope 58 | for i in range(n_frames): 59 | sample = i * hop_length 60 | x[sample:min(n, sample+n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 61 | return x 62 | 63 | 64 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 65 | """ 66 | PARAMS 67 | ------ 68 | magnitudes: spectrogram magnitudes 69 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 70 | """ 71 | 72 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 73 | angles = angles.astype(np.float32) 74 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 75 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 76 | 77 | for i in range(n_iters): 78 | _, angles = stft_fn.transform(signal) 79 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 80 | return signal 81 | 82 | 83 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 84 | """ 85 | PARAMS 86 | ------ 87 | C: compression factor 88 | """ 89 | return torch.log(torch.clamp(x, min=clip_val) * C) 90 | 91 | 92 | def dynamic_range_decompression(x, C=1): 93 | """ 94 | PARAMS 95 | ------ 96 | C: compression factor used to compress 97 | """ 98 | return torch.exp(x) / C -------------------------------------------------------------------------------- /commands.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --stage=0 && 4 | python extract_alignments.py && 5 | python train.py --stage=1 && 6 | python train.py --stage=2 && 7 | python extract_alignments.py && 8 | python train.py --stage=3 9 | -------------------------------------------------------------------------------- /extract_alignments.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | os.environ["CUDA_VISIBLE_DEVICES"]='0' 3 | 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | import sys 8 | sys.path.append('waveglow/') 9 | 10 | import IPython.display as ipd 11 | import pickle as pkl 12 | import torch 13 | import torch.nn.functional as F 14 | import hparams 15 | from torch.utils.data import DataLoader 16 | from modules.model import Model 17 | from text import text_to_sequence, sequence_to_text 18 | from denoiser import Denoiser 19 | from tqdm import tqdm 20 | import librosa 21 | from modules.loss import MDNLoss 22 | import math 23 | import numpy as np 24 | from datetime import datetime 25 | 26 | def main(): 27 | data_type = 'phone' 28 | checkpoint_path = f"training_log/aligntts/stage0/checkpoint_{hparams.train_steps[0]}" 29 | state_dict = {} 30 | 31 | for k, v in torch.load(checkpoint_path)['state_dict'].items(): 32 | state_dict[k[7:]]=v 33 | 34 | model = Model(hparams).cuda() 35 | model.load_state_dict(state_dict) 36 | _ = model.cuda().eval() 37 | criterion = MDNLoss() 38 | 39 | datasets = ['train', 'val', 'test'] 40 | batch_size=64 41 | 42 | for dataset in datasets: 43 | with open(f'filelists/ljs_audio_text_{dataset}_filelist.txt', 'r') as f: 44 | lines_raw = [line.split('|') for line in f.read().splitlines()] 45 | lines_list = [ lines_raw[batch_size*i:batch_size*(i+1)] 46 | for i in range(len(lines_raw)//batch_size+1)] 47 | 48 | for batch in tqdm(lines_list): 49 | file_list, text_list, mel_list = [], [], [] 50 | 
text_lengths, mel_lengths=[], [] 51 | 52 | for i in range(len(batch)): 53 | file_name, _, text = batch[i] 54 | file_list.append(file_name) 55 | seq = os.path.join('../Dataset/LJSpeech-1.1/preprocessed', 56 | f'{data_type}_seq') 57 | mel = os.path.join('../Dataset/LJSpeech-1.1/preprocessed', 58 | 'melspectrogram') 59 | 60 | seq = torch.from_numpy(np.load(f'{seq}/{file_name}_sequence.npy')) 61 | mel = torch.from_numpy(np.load(f'{mel}/{file_name}_melspectrogram.npy')) 62 | 63 | text_list.append(seq) 64 | mel_list.append(mel) 65 | text_lengths.append(seq.size(0)) 66 | mel_lengths.append(mel.size(1)) 67 | 68 | text_lengths = torch.LongTensor(text_lengths) 69 | mel_lengths = torch.LongTensor(mel_lengths) 70 | text_padded = torch.zeros(len(batch), text_lengths.max().item(), dtype=torch.long) 71 | mel_padded = torch.zeros(len(batch), hparams.n_mel_channels, mel_lengths.max().item()) 72 | 73 | for j in range(len(batch)): 74 | text_padded[j, :text_list[j].size(0)] = text_list[j] 75 | mel_padded[j, :, :mel_list[j].size(1)] = mel_list[j] 76 | 77 | text_padded = text_padded.cuda() 78 | mel_padded = mel_padded.cuda() 79 | mel_padded = (torch.clamp(mel_padded, hparams.min_db, hparams.max_db)-hparams.min_db) / (hparams.max_db-hparams.min_db) 80 | text_lengths = text_lengths.cuda() 81 | mel_lengths = mel_lengths.cuda() 82 | 83 | with torch.no_grad(): 84 | encoder_input = model.Prenet(text_padded) 85 | hidden_states, _ = model.FFT_lower(encoder_input, text_lengths) 86 | mu_sigma = model.get_mu_sigma(hidden_states) 87 | _, log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths) 88 | 89 | align = model.viterbi(log_prob_matrix, text_lengths, mel_lengths).to(torch.long) 90 | alignments = list(torch.split(align,1)) 91 | 92 | for j, (l, t) in enumerate(zip(text_lengths, mel_lengths)): 93 | alignments[j] = alignments[j][0, :l.item(), :t.item()].sum(dim=-1) 94 | np.save(f'../Dataset/LJSpeech-1.1/preprocessed/alignments/{file_list[j]}_alignment.npy', 95 | 
alignments[j].detach().cpu().numpy()) 96 | 97 | print("Alignments Extraction End!!! ({datetime.now()})") 98 | 99 | 100 | if __name__ == '__main__': 101 | p = argparse.ArgumentParser() 102 | p.add_argument('--gpu', type=str, default='0') 103 | p.add_argument('-v', '--verbose', type=str, default='0') 104 | args = p.parse_args() 105 | 106 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 107 | torch.manual_seed(hparams.seed) 108 | torch.cuda.manual_seed(hparams.seed) 109 | 110 | if args.verbose=='0': 111 | import warnings 112 | warnings.filterwarnings("ignore") 113 | 114 | main() -------------------------------------------------------------------------------- /figures/align.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/align.png -------------------------------------------------------------------------------- /figures/alignments.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/alignments.JPG -------------------------------------------------------------------------------- /figures/mdn sample.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/mdn sample.JPG -------------------------------------------------------------------------------- /figures/melspecs.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/melspecs.JPG -------------------------------------------------------------------------------- /figures/stage0_train_loss.JPG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage0_train_loss.JPG -------------------------------------------------------------------------------- /figures/stage0_val_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage0_val_loss.JPG -------------------------------------------------------------------------------- /figures/stage1_train_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage1_train_loss.JPG -------------------------------------------------------------------------------- /figures/stage1_val_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage1_val_loss.JPG -------------------------------------------------------------------------------- /figures/stage2_train_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage2_train_loss.JPG -------------------------------------------------------------------------------- /figures/stage2_val_fft_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage2_val_fft_loss.JPG -------------------------------------------------------------------------------- /figures/stage2_val_mdn_loss.JPG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage2_val_mdn_loss.JPG -------------------------------------------------------------------------------- /figures/stage3_train_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage3_train_loss.JPG -------------------------------------------------------------------------------- /figures/stage3_val_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage3_val_loss.JPG -------------------------------------------------------------------------------- /generate_samples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Import libraries and setup matplotlib" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n", 18 | "\n", 19 | "import warnings\n", 20 | "warnings.filterwarnings(\"ignore\")\n", 21 | "\n", 22 | "import sys\n", 23 | "sys.path.append('waveglow/')\n", 24 | "\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "%matplotlib inline\n", 27 | "\n", 28 | "import IPython.display as ipd\n", 29 | "import pickle as pkl\n", 30 | "from text import *\n", 31 | "import numpy as np\n", 32 | "import torch\n", 33 | "import hparams\n", 34 | "from modules.model import Model\n", 35 | "from denoiser import Denoiser\n", 36 | "import soundfile" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### Text preprocessing" 44 | ] 45 | }, 46 | { 47 | "cell_type": 
"code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "from g2p_en import G2p\n", 53 | "from text.symbols import symbols\n", 54 | "from text.cleaners import custom_english_cleaners\n", 55 | "\n", 56 | "# Mappings from symbol to numeric ID and vice versa:\n", 57 | "symbol_to_id = {s: i for i, s in enumerate(symbols)}\n", 58 | "id_to_symbol = {i: s for i, s in enumerate(symbols)}\n", 59 | "\n", 60 | "g2p = G2p()\n", 61 | "def text2seq(text, data_type='char'):\n", 62 | " text = custom_english_cleaners(text.rstrip())\n", 63 | " if data_type=='phone':\n", 64 | " clean_phone = []\n", 65 | " for s in g2p(text.lower()):\n", 66 | " if '@'+s in symbol_to_id:\n", 67 | " clean_phone.append('@'+s)\n", 68 | " else:\n", 69 | " clean_phone.append(s)\n", 70 | " text = clean_phone\n", 71 | " \n", 72 | " # Append SOS, EOS token\n", 73 | " sequence = [symbol_to_id[c] for c in text]\n", 74 | " sequence = [symbol_to_id['^']] + sequence + [symbol_to_id['~']]\n", 75 | " return sequence" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "### Waveglow" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "code_folding": [] 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "waveglow_path = 'training_log/waveglow_256channels.pt'\n", 94 | "waveglow = torch.load(waveglow_path)['model']\n", 95 | "\n", 96 | "for m in waveglow.modules():\n", 97 | " if 'Conv' in str(type(m)):\n", 98 | " setattr(m, 'padding_mode', 'zeros')\n", 99 | "\n", 100 | "waveglow.cuda().eval()\n", 101 | "for k in waveglow.convinv:\n", 102 | " k.float()\n", 103 | "\n", 104 | "denoiser = Denoiser(waveglow)\n", 105 | "\n", 106 | "with open('filelists/ljs_audio_text_val_filelist.txt', 'r') as f:\n", 107 | " lines = [line.split('|') for line in f.read().splitlines()]" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "scrolled": false 
115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "for data_type in ['phone']:\n", 119 | " for step in ['10000']:\n", 120 | " checkpoint_path = f\"training_log/aligntts/stage3/checkpoint_{step}\"\n", 121 | " state_dict = {}\n", 122 | " for k, v in torch.load(checkpoint_path)['state_dict'].items():\n", 123 | " state_dict[k[7:]]=v\n", 124 | "\n", 125 | " model = Model(hparams).cuda()\n", 126 | " model.load_state_dict(state_dict)\n", 127 | " _ = model.cuda().eval()\n", 128 | "\n", 129 | " for i in [1, 6, 22]:\n", 130 | " file_name, _, text = lines[i]\n", 131 | " sequence = np.array(text2seq(text,data_type))[None, :]\n", 132 | " sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n", 133 | " \n", 134 | " for alpha in [0.8, 0.9, 1.0, 1.1, 1.2]:\n", 135 | " with torch.no_grad():\n", 136 | " melspec, durations = model.inference(sequence, alpha)\n", 137 | " melspec = melspec*(hparams.max_db-hparams.min_db)+hparams.min_db\n", 138 | " audio = waveglow.infer(melspec, sigma=0.666)\n", 139 | "\n", 140 | " soundfile.write(f'wavs/{file_name}_{data_type}{step}_{str(int(10*alpha))}.wav', audio.cpu().numpy()[0].astype(float), 22050)" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Environment (conda_pytorch_p36)", 147 | "language": "python", 148 | "name": "conda_pytorch_p36" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.6.5" 161 | }, 162 | "varInspector": { 163 | "cols": { 164 | "lenName": 16, 165 | "lenType": 16, 166 | "lenVar": 40 167 | }, 168 | "kernels_config": { 169 | "python": { 170 | "delete_cmd_postfix": "", 171 | "delete_cmd_prefix": "del ", 172 | "library": "var_list.py", 173 | "varRefreshCmd": "print(var_dic_list())" 174 | }, 175 | "r": { 176 | 
"delete_cmd_postfix": ") ", 177 | "delete_cmd_prefix": "rm(", 178 | "library": "var_list.r", 179 | "varRefreshCmd": "cat(var_dic_list()) " 180 | } 181 | }, 182 | "types_to_exclude": [ 183 | "module", 184 | "function", 185 | "builtin_function_or_method", 186 | "instance", 187 | "_Feature" 188 | ], 189 | "window_display": false 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 2 194 | } 195 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from text import symbols 2 | 3 | ################################ 4 | # Experiment Parameters # 5 | ################################ 6 | seed=1234 7 | n_gpus=2 8 | output_directory = 'training_log' 9 | log_directory = 'aligntts' 10 | data_path = '../Dataset/LJSpeech-1.1/preprocessed' 11 | 12 | training_files='filelists/ljs_audio_text_train_filelist.txt' 13 | validation_files='filelists/ljs_audio_text_val_filelist.txt' 14 | text_cleaners=['english_cleaners'] 15 | 16 | 17 | ################################ 18 | # Audio Parameters # 19 | ################################ 20 | sampling_rate=22050 21 | filter_length=1024 22 | hop_length=256 23 | win_length=1024 24 | n_mel_channels=80 25 | mel_fmin=0 26 | mel_fmax=8000.0 27 | 28 | ################################ 29 | # Model Parameters # 30 | ################################ 31 | n_symbols=len(symbols) 32 | data_type='phone_seq' # 'phone_seq' 33 | symbols_embedding_dim=256 34 | hidden_dim=256 35 | dprenet_dim=256 36 | postnet_dim=256 37 | ff_dim=1024 38 | n_heads=2 39 | n_layers=6 40 | max_db=2 41 | min_db=-12 42 | 43 | ################################ 44 | # Optimization Hyperparameters # 45 | ################################ 46 | lr=384**-0.5 47 | warmup_steps=4000 48 | grad_clip_thresh=1.0 49 | batch_size=32 50 | accumulation=1 51 | iters_per_validation=1000 52 | iters_per_checkpoint=10000 53 | train_steps = [40000, 40000, 80000, 10000] 54 | 55 
| 56 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | AlignTTS Audio Samples 6 | 7 | 8 | 9 | 10 |
11 |

AlignTTS (phoneme)

12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 28 | 30 | 32 | 34 | 36 | 37 | 38 | 39 | 41 | 43 | 45 | 47 | 49 | 50 | 51 | 52 | 54 | 56 | 58 | 60 | 62 | 63 | 64 |
Steps | alpha=0.8 | alpha=0.9 | alpha=1.0 | alpha=1.1 | alpha=1.2
LJ001-0029
LJ001-0085
LJ002-0106
65 |
66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Import libraries and setup matplotlib" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n", 18 | "\n", 19 | "import warnings\n", 20 | "warnings.filterwarnings(\"ignore\")\n", 21 | "\n", 22 | "import sys\n", 23 | "sys.path.append('waveglow/')\n", 24 | "\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "%matplotlib inline\n", 27 | "\n", 28 | "import IPython.display as ipd\n", 29 | "import pickle as pkl\n", 30 | "import librosa\n", 31 | "from text import *\n", 32 | "import numpy as np\n", 33 | "import torch\n", 34 | "import hparams\n", 35 | "from modules.model import Model\n", 36 | "from denoiser import Denoiser" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### Text preprocessing" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "from g2p_en import G2p\n", 53 | "from text.symbols import symbols\n", 54 | "from text.cleaners import custom_english_cleaners\n", 55 | "\n", 56 | "# Mappings from symbol to numeric ID and vice versa:\n", 57 | "symbol_to_id = {s: i for i, s in enumerate(symbols)}\n", 58 | "id_to_symbol = {i: s for i, s in enumerate(symbols)}\n", 59 | "\n", 60 | "g2p = G2p()\n", 61 | "def text2seq(text, data_type='char'):\n", 62 | " text = custom_english_cleaners(text.rstrip())\n", 63 | " if data_type=='phone':\n", 64 | " clean_phone = []\n", 65 | " for s in g2p(text.lower()):\n", 66 | " if '@'+s in symbol_to_id:\n", 67 | " clean_phone.append('@'+s)\n", 68 | " else:\n", 69 | 
" clean_phone.append(s)\n", 70 | " text = clean_phone\n", 71 | " \n", 72 | " # Append SOS, EOS token\n", 73 | " sequence = [symbol_to_id[c] for c in text]\n", 74 | " sequence = [symbol_to_id['^']] + sequence + [symbol_to_id['~']]\n", 75 | " return sequence" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "### Waveglow" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "code_folding": [] 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "waveglow_path = 'training_log/waveglow_256channels.pt'\n", 94 | "waveglow = torch.load(waveglow_path)['model']\n", 95 | "\n", 96 | "for m in waveglow.modules():\n", 97 | " if 'Conv' in str(type(m)):\n", 98 | " setattr(m, 'padding_mode', 'zeros')\n", 99 | "\n", 100 | "waveglow.cuda().eval()\n", 101 | "for k in waveglow.convinv:\n", 102 | " k.float()\n", 103 | "denoiser = Denoiser(waveglow)\n", 104 | "\n", 105 | "with open('filelists/ljs_audio_text_val_filelist.txt', 'r') as f:\n", 106 | " lines = [line.split('|') for line in f.read().splitlines()]" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "scrolled": false 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "for data_type in ['phone']:\n", 118 | " for step in ['10000']:\n", 119 | " print(f'{data_type}_{step}steps')\n", 120 | " checkpoint_path = f\"training_log/aligntts/stage3/checkpoint_{step}\"\n", 121 | " state_dict = {}\n", 122 | " for k, v in torch.load(checkpoint_path)['state_dict'].items():\n", 123 | " state_dict[k[7:]]=v\n", 124 | "\n", 125 | " model = Model(hparams).cuda()\n", 126 | " model.load_state_dict(state_dict)\n", 127 | " _ = model.cuda().eval()\n", 128 | "\n", 129 | " for i in [1, 6, 22]:\n", 130 | " file_name, _, text = lines[i]\n", 131 | " sequence = np.array(text2seq(text, data_type))[None, :]\n", 132 | " sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n", 133 | "\n", 
134 | " print(f'Text: {text}')\n", 135 | " for alpha in [0.8, 0.9, 1.0, 1.1, 1.2]:\n", 136 | " with torch.no_grad():\n", 137 | " melspec, durations = model.inference(sequence, alpha)\n", 138 | " melspec = melspec*(hparams.max_db-hparams.min_db)+hparams.min_db\n", 139 | " audio = waveglow.infer(melspec, sigma=0.666)\n", 140 | "\n", 141 | " print(f\"alpha: {alpha}\")\n", 142 | " ipd.display(ipd.Audio(audio.cpu().numpy(), rate=hparams.sampling_rate))\n", 143 | "\n", 144 | " if alpha==1.0:\n", 145 | " ticks=[]\n", 146 | " phoneme = sequence_to_text(sequence[0].tolist())\n", 147 | " for i, d in enumerate(durations[0]):\n", 148 | " ticks.extend([phoneme[i]]*int(d))\n", 149 | "\n", 150 | " plt.figure(figsize=(20,5))\n", 151 | " plt.imshow(melspec.detach().cpu()[0], aspect='auto', origin='lower')\n", 152 | " plt.xticks(range(melspec.size(2)), ticks)\n", 153 | " plt.show()" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Environment (conda_pytorch_p36)", 160 | "language": "python", 161 | "name": "conda_pytorch_p36" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.6.5" 174 | }, 175 | "varInspector": { 176 | "cols": { 177 | "lenName": 16, 178 | "lenType": 16, 179 | "lenVar": 40 180 | }, 181 | "kernels_config": { 182 | "python": { 183 | "delete_cmd_postfix": "", 184 | "delete_cmd_prefix": "del ", 185 | "library": "var_list.py", 186 | "varRefreshCmd": "print(var_dic_list())" 187 | }, 188 | "r": { 189 | "delete_cmd_postfix": ") ", 190 | "delete_cmd_prefix": "rm(", 191 | "library": "var_list.r", 192 | "varRefreshCmd": "cat(var_dic_list()) " 193 | } 194 | }, 195 | "types_to_exclude": [ 196 | "module", 197 | "function", 198 | "builtin_function_or_method", 199 | "instance", 200 | "_Feature" 201 | 
# ---- layers.py ----


class LinearNorm(torch.nn.Module):
    """Linear layer whose weight is Xavier-uniform initialised.

    Args:
        in_dim: size of the input feature dimension.
        out_dim: size of the output feature dimension.
        bias: whether the wrapped ``torch.nn.Linear`` has a bias term.
        w_init_gain: non-linearity name handed to
            ``torch.nn.init.calculate_gain`` to scale the initialisation.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    """1-D convolution whose weight is Xavier-uniform initialised.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, bias:
            forwarded to ``torch.nn.Conv1d``.
        padding: explicit padding; when ``None`` it is derived from
            ``kernel_size`` and ``dilation`` (odd kernels only).
        w_init_gain: non-linearity name handed to
            ``torch.nn.init.calculate_gain``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            # Only an odd kernel can be centred with symmetric padding.
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        """Convolve ``signal`` (batch, channels, time) and return the result."""
        return self.conv(signal)


class TacotronSTFT(torch.nn.Module):
    """Mel-spectrogram front end built from an ``STFT`` and a mel filterbank.

    Args:
        filter_length: FFT size, also used as ``n_fft`` of the mel filterbank.
        hop_length: hop between analysis frames (passed to ``STFT``).
        win_length: analysis window length (passed to ``STFT``).
        n_mel_channels: number of mel bands.
        sampling_rate: sampling rate of the input audio.
        mel_fmin: lowest frequency of the mel filterbank.
        mel_fmax: highest frequency of the mel filterbank.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments: recent librosa releases made these parameters
        # keyword-only, while older releases accept keywords as well.
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels,
            fmin=mel_fmin, fmax=mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        # Buffer (not parameter): follows .to()/.cuda() but is never trained.
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        """Return ``dynamic_range_compression(magnitudes)``."""
        return dynamic_range_compression(magnitudes)

    def spectral_de_normalize(self, magnitudes):
        """Return ``dynamic_range_decompression(magnitudes)``."""
        return dynamic_range_decompression(magnitudes)

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        # The phase returned by the STFT is not needed for a mel-spectrogram.
        magnitudes, _ = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/transformer.cpython-37.pyc 
# ---- modules/init_layer.py ----


class Linear(nn.Linear):
    """``nn.Linear`` with Xavier-uniform weight initialisation.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        bias: whether to learn an additive bias.
        w_init_gain: non-linearity name handed to ``nn.init.calculate_gain``.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 bias=True,
                 w_init_gain='linear'):
        super(Linear, self).__init__(in_dim,
                                     out_dim,
                                     bias)
        nn.init.xavier_uniform_(self.weight,
                                gain=nn.init.calculate_gain(w_init_gain))


class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with Xavier-uniform weight initialisation.

    All arguments except ``w_init_gain`` are forwarded to ``nn.Conv1d``;
    ``w_init_gain`` is the non-linearity name handed to
    ``nn.init.calculate_gain``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros',
                 w_init_gain='linear'):
        # Forward by keyword so the call does not rely on the positional
        # order of the optional ``nn.Conv1d`` parameters.
        super(Conv1d, self).__init__(in_channels,
                                     out_channels,
                                     kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation,
                                     groups=groups,
                                     bias=bias,
                                     padding_mode=padding_mode)
        nn.init.xavier_uniform_(self.weight,
                                gain=nn.init.calculate_gain(w_init_gain))


# ---- modules/loss.py ----


class MDNLoss(nn.Module):
    """Negative log-likelihood of a mel-spectrogram under a monotonic
    text-to-frame alignment, summed over all alignments with a forward
    (alpha) recursion in log space.
    """

    def __init__(self):
        super(MDNLoss, self).__init__()

    def forward(self, mu_sigma, melspec, text_lengths, mel_lengths):
        """Compute the loss and the per-(token, frame) log-probabilities.

        Args:
            mu_sigma: (B, L, 2*F) tensor; the first F channels are the
                pre-sigmoid means, the remaining channels are ``log_sigma``.
            melspec: (B, F, T) target mel-spectrogram.
            text_lengths: (B,) integer tensor of valid token counts.
            mel_lengths: (B,) integer tensor of valid frame counts.

        Returns:
            ``(mdn_loss, log_prob_matrix)`` where ``mdn_loss`` is a scalar
            and ``log_prob_matrix`` has shape (B, L, T).
        """
        # mu, sigma: B, L, F / melspec: B, F, T
        B, L, _ = mu_sigma.size()
        # The number of mel channels is taken from the target itself, so the
        # loss does not depend on a global hyper-parameter module.
        n_mel, T = melspec.size(1), melspec.size(2)

        x = melspec.transpose(1, 2).unsqueeze(1)  # B, 1, T, F
        mu = torch.sigmoid(mu_sigma[:, :, :n_mel].unsqueeze(2))  # B, L, 1, F
        log_sigma = mu_sigma[:, :, n_mel:].unsqueeze(2)  # B, L, 1, F

        exponential = -0.5 * torch.sum((x - mu) * (x - mu) / log_sigma.exp() ** 2, dim=-1)  # B, L, T
        # NOTE(review): exp(log_sigma)**2 is used as the variance above, for
        # which the normaliser would be ``log_sigma.sum`` rather than
        # ``0.5 * log_sigma.sum``. Left unchanged because existing
        # checkpoints were trained with this form -- confirm before altering.
        log_prob_matrix = (exponential
                           - (n_mel / 2) * math.log(2 * math.pi)
                           - 0.5 * log_sigma.sum(dim=-1))  # B, L, T

        # -1e30 stands in for log(0): cells that no monotonic path can reach.
        log_alpha = mu_sigma.new_ones(B, L, T) * (-1e30)
        log_alpha[:, 0, 0] = log_prob_matrix[:, 0, 0]

        for t in range(1, T):
            # Candidates for cell (l, t): stay on token l, or advance from l-1.
            prev_step = torch.cat([log_alpha[:, :, t - 1:t],
                                   F.pad(log_alpha[:, :, t - 1:t], (0, 0, 1, -1), value=-1e30)],
                                  dim=-1)
            log_alpha[:, :, t] = torch.logsumexp(prev_step, dim=-1) + log_prob_matrix[:, :, t]

        # Read the total likelihood at the last valid (token, frame) cell.
        batch_idx = torch.arange(B, device=log_alpha.device)
        alpha_last = log_alpha[batch_idx, text_lengths - 1, mel_lengths - 1]
        mdn_loss = -alpha_last.mean()

        return mdn_loss, log_prob_matrix


# ---- modules/model.py ----


class Prenet(nn.Module):
    """Token embedding plus positional encoding and dropout.

    Reads ``hp.n_symbols``, ``hp.symbols_embedding_dim`` and
    ``hp.hidden_dim``; the embedding and positional encoding are added
    together, so the two dimensions must be broadcast-compatible.
    """

    def __init__(self, hp):
        super(Prenet, self).__init__()
        # B, L -> B, L, D
        self.Embedding = nn.Embedding(hp.n_symbols, hp.symbols_embedding_dim)
        self.register_buffer('pe', PositionalEncoding(hp.hidden_dim).pe)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text):
        """Embed ``text`` (B, L) and return a (B, L, D) tensor."""
        L = text.size(1)
        x = self.Embedding(text).transpose(0, 1)  # L, B, D
        x += self.pe[:L].unsqueeze(1)
        x = self.dropout(x).transpose(0, 1)
        return x
# ---- modules/model.py (continued) ----

# Score given to lattice cells that no monotonic path can reach.
_VITERBI_NEG = -1e15


def _viterbi_scores(log_prob_matrix):
    """Return the best-path score of every (token, frame) cell, shape B, L, T."""
    B, L, T = log_prob_matrix.size()
    log_beta = log_prob_matrix.new_ones(B, L, T) * _VITERBI_NEG
    log_beta[:, 0, 0] = log_prob_matrix[:, 0, 0]

    for t in range(1, T):
        stay = log_beta[:, :, t - 1]  # remain on the same token
        advance = F.pad(stay[:, :-1], (1, 0), value=_VITERBI_NEG)  # come from token l-1
        log_beta[:, :, t] = torch.max(stay, advance) + log_prob_matrix[:, :, t]

    return log_beta


def _backtrack(scores, text_lengths, mel_lengths):
    """Walk back from (text_len-1, mel_len-1) and return token indices, shape B, T.

    For items shorter than T the walk over-runs the first frame; those extra
    leading entries are discarded later by ``_path_to_alignment``.
    """
    B, _, T = scores.size()
    batch_idx = torch.arange(B, device=scores.device)
    rows = text_lengths.to(device=scores.device, dtype=torch.long) - 1
    cols = mel_lengths.to(device=scores.device, dtype=torch.long) - 1

    path = [rows]
    for _ in range(T - 1):
        is_go = scores[batch_idx, rows - 1, cols - 1] > scores[batch_idx, rows, cols - 1]
        # At row 0 the index ``rows - 1`` wraps around to the last text row;
        # without this guard the path could step to row -1 and leave frames
        # without any aligned token.
        is_go = is_go & (rows > 0)
        rows = rows - is_go.to(torch.long)
        cols = torch.clamp(cols - 1, min=-1)
        path.append(rows)

    path.reverse()
    return torch.stack(path, -1)


def _path_to_alignment(path, mel_lengths):
    """Turn a (B, T) token-index path into a one-hot alignment of shape B, L, T."""
    T = path.size(1)
    indices = torch.arange(int(path.max().item()) + 1, device=path.device).view(1, 1, -1)  # 1, 1, L
    align = (indices == path.unsqueeze(-1)).to(torch.get_default_dtype())  # B, T, L

    for i in range(align.size(0)):
        # Drop the over-run entries at the front and zero-pad the tail so
        # that frame 0 of every item sits at index 0.
        pad = T - int(mel_lengths[i])
        align[i] = F.pad(align[i], (0, 0, -pad, pad))

    return align.transpose(1, 2)


class FFT(nn.Module):
    """Stack of ``n_layers`` ``TransformerEncoderLayer`` blocks."""

    def __init__(self, hidden_dim, n_heads, ff_dim, n_layers):
        super(FFT, self).__init__()
        self.FFT_layers = nn.ModuleList([TransformerEncoderLayer(d_model=hidden_dim,
                                                                 nhead=n_heads,
                                                                 dim_feedforward=ff_dim)
                                         for _ in range(n_layers)])

    def forward(self, x, lengths):
        """Run the stack on ``x`` (B, L, D).

        Returns the transformed (B, L, D) tensor and the attention maps of
        all layers concatenated along dimension 1.
        """
        # B, L, D -> B, L, D
        alignments = []
        x = x.transpose(0, 1)
        # presumably True at padded positions: it is used as key_padding_mask
        # here and negated before masked_select elsewhere -- TODO confirm.
        mask = get_mask_from_lengths(lengths)
        for layer in self.FFT_layers:
            x, align = layer(x, src_key_padding_mask=mask)
            alignments.append(align.unsqueeze(1))
        alignments = torch.cat(alignments, 1)

        return x.transpose(0, 1), alignments


class DurationPredictor(nn.Module):
    """Predicts one scalar per input token with its own Prenet and 2-layer FFT."""

    def __init__(self, hp):
        super(DurationPredictor, self).__init__()
        self.Prenet = Prenet(hp)
        self.FFT = FFT(hp.hidden_dim, hp.n_heads, hp.ff_dim, 2)
        self.linear = Linear(hp.hidden_dim, 1)

    def forward(self, text, text_lengths):
        """Map ``text`` (B, L) to a (B, L) tensor; ``Model`` exponentiates it."""
        # B, L -> B, L
        encoder_input = self.Prenet(text)
        x = self.FFT(encoder_input, text_lengths)[0]
        x = self.linear(x).squeeze(-1)
        return x


class Model(nn.Module):
    """Text encoder, mixture-density head, mel decoder and duration predictor.

    ``forward`` computes a different training loss depending on ``stage``.
    """

    def __init__(self, hp):
        super(Model, self).__init__()
        self.Prenet = Prenet(hp)
        self.FFT_lower = FFT(hp.hidden_dim, hp.n_heads, hp.ff_dim, hp.n_layers)
        self.FFT_upper = FFT(hp.hidden_dim, hp.n_heads, hp.ff_dim, hp.n_layers)
        self.MDN = nn.Sequential(Linear(hp.hidden_dim, hp.hidden_dim),
                                 nn.LayerNorm(hp.hidden_dim),
                                 nn.ReLU(),
                                 nn.Dropout(0.1),
                                 Linear(hp.hidden_dim, 2*hp.n_mel_channels))
        self.DurationPredictor = DurationPredictor(hp)
        self.Projection = Linear(hp.hidden_dim, hp.n_mel_channels)

    def get_mu_sigma(self, hidden_states):
        """Return the MDN head output, 2*n_mel_channels values per token."""
        mu_sigma = self.MDN(hidden_states)
        return mu_sigma

    def get_duration(self, text, text_lengths):
        """Return per-token durations: exp of the DurationPredictor output."""
        durations = self.DurationPredictor(text, text_lengths).exp()
        return durations

    def get_melspec(self, hidden_states, align, mel_lengths):
        """Expand token states to frames with ``align`` (B, L, T) and decode.

        Returns a sigmoid-squashed mel-spectrogram of shape B, n_mel, T.
        """
        hidden_states_expanded = torch.matmul(align.transpose(1, 2), hidden_states)
        hidden_states_expanded += self.Prenet.pe[:hidden_states_expanded.size(1)].unsqueeze(1).transpose(0, 1)
        mel_out = torch.sigmoid(self.Projection(self.FFT_upper(hidden_states_expanded, mel_lengths)[0]).transpose(1, 2))
        return mel_out

    def forward(self, text, melspec, align, text_lengths, mel_lengths, criterion, stage, log_viterbi=False, cpu_viterbi=False):
        """Return the training loss for ``stage``.

        stage 0: MDN loss from ``criterion``.
        stage 1: L1 mel loss using the supplied ``align``.
        stage 2: MDN loss plus L1 mel loss using a Viterbi alignment computed
            here (on CPU when ``cpu_viterbi``; timing is printed when
            ``log_viterbi``).
        stage 3: MSE between log predicted durations and log durations
            obtained by summing ``align`` over frames.
        Any other ``stage`` value falls through and returns ``None``.
        """
        # Trim batch padding down to the longest item actually present.
        text = text[:, :text_lengths.max().item()]
        melspec = melspec[:, :, :mel_lengths.max().item()]

        if stage == 0:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, _ = criterion(mu_sigma, melspec, text_lengths, mel_lengths)
            return mdn_loss

        elif stage == 1:
            align = align[:, :text_lengths.max().item(), :mel_lengths.max().item()]
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mel_out = self.get_melspec(hidden_states, align, mel_lengths)

            # Compare only the frames selected by the negated length mask.
            mel_mask = ~get_mask_from_lengths(mel_lengths)
            melspec = melspec.masked_select(mel_mask.unsqueeze(1))
            mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
            fft_loss = nn.L1Loss()(mel_out, melspec)

            return fft_loss

        elif stage == 2:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, log_prob_matrix = criterion(mu_sigma, melspec, text_lengths, mel_lengths)

            before = datetime.now()
            if cpu_viterbi:
                align = self.viterbi_cpu(log_prob_matrix, text_lengths.cpu(), mel_lengths.cpu())  # B, T
            else:
                align = self.viterbi(log_prob_matrix, text_lengths, mel_lengths)  # B, T
            after = datetime.now()

            if log_viterbi:
                time_delta = after - before
                print(f'Viterbi took {time_delta.total_seconds()} secs')

            mel_out = self.get_melspec(hidden_states, align, mel_lengths)

            mel_mask = ~get_mask_from_lengths(mel_lengths)
            melspec = melspec.masked_select(mel_mask.unsqueeze(1))
            mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
            fft_loss = nn.L1Loss()(mel_out, melspec)

            return mdn_loss + fft_loss

        elif stage == 3:
            align = align[:, :text_lengths.max().item(), :mel_lengths.max().item()]
            # Only self.DurationPredictor (with its own Prenet/FFT) is used
            # here, so this loss does not reach the main encoder.
            duration_out = self.get_duration(text, text_lengths)  # gradient cut
            duration_target = align.sum(-1)

            duration_mask = ~get_mask_from_lengths(text_lengths)
            duration_target = duration_target.masked_select(duration_mask)
            duration_out = duration_out.masked_select(duration_mask)
            duration_loss = nn.MSELoss()(torch.log(duration_out), torch.log(duration_target))

            return duration_loss

    def inference(self, text, alpha=1.0):
        """Synthesise a mel-spectrogram for a single utterance ``text`` (1, L).

        Predicted durations are multiplied by ``alpha``, rounded, and forced
        to be at least 1. Returns ``(mel_out, durations)``.
        """
        text_lengths = text.new_tensor([text.size(1)])
        encoder_input = self.Prenet(text)
        hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
        durations = self.get_duration(text, text_lengths)
        durations = torch.round(durations*alpha).to(torch.long)
        durations[durations <= 0] = 1
        T = int(durations.sum().item())
        mel_lengths = text.new_tensor([T])
        hidden_states_expanded = torch.repeat_interleave(hidden_states, durations[0], dim=1)
        hidden_states_expanded += self.Prenet.pe[:hidden_states_expanded.size(1)].unsqueeze(1).transpose(0, 1)
        mel_out = torch.sigmoid(self.Projection(self.FFT_upper(hidden_states_expanded, mel_lengths)[0]).transpose(1, 2))

        return mel_out, durations

    def viterbi(self, log_prob_matrix, text_lengths, mel_lengths):
        """Most likely monotonic alignment, computed on the input's device.

        Args:
            log_prob_matrix: (B, L, T) per-(token, frame) log-probabilities.
            text_lengths: (B,) valid token counts.
            mel_lengths: (B,) valid frame counts.

        Returns:
            (B, L', T) one-hot alignment with L' = max(text_lengths); frames
            beyond an item's ``mel_lengths`` are all zero.
        """
        # The result is discrete, so the DP does not need autograd history.
        log_beta = _viterbi_scores(log_prob_matrix.detach())
        path = _backtrack(log_beta, text_lengths, mel_lengths)
        return _path_to_alignment(path, mel_lengths)

    def fast_viterbi(self, log_prob_matrix, text_lengths, mel_lengths):
        """Greedy backtrace on the raw log-probabilities, run on the CPU.

        Unlike ``viterbi`` no best-path scores are accumulated first; the
        returned (B, L', T) alignment stays on the CPU.
        """
        scores = log_prob_matrix.detach().cpu()
        path = _backtrack(scores, text_lengths, mel_lengths)
        return _path_to_alignment(path, mel_lengths)

    def viterbi_cpu(self, log_prob_matrix, text_lengths, mel_lengths):
        """Same result as ``viterbi`` but with the DP executed on the CPU.

        The alignment is moved back to the device of ``log_prob_matrix``.
        """
        original_device = log_prob_matrix.device
        log_beta = _viterbi_scores(log_prob_matrix.detach().cpu())
        path = _backtrack(log_beta, text_lengths, mel_lengths)
        return _path_to_alignment(path, mel_lengths).to(original_device)
# ---- modules/transformer.py ----


class TransformerEncoderLayer(nn.Module):
    """Self-attention + feed-forward block that also returns attention weights.

    Args:
        d_model: model (embedding) dimension.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward sub-layer.
        dropout: dropout probability used in attention and residual paths.
        activation: name used only to pick the init gain of ``linear1``;
            ``forward`` always applies ReLU.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = Linear(d_model, dim_feedforward, w_init_gain=activation)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Return ``(output, attention_weights)`` for ``src``.

        ``src_mask`` and ``src_key_padding_mask`` are passed to the attention
        module as ``attn_mask`` and ``key_padding_mask``.
        """
        src2, enc_align = self.self_attn(src,
                                         src,
                                         src,
                                         attn_mask=src_mask,
                                         key_padding_mask=src_key_padding_mask)
        # Residual connection followed by layer norm (post-norm).
        src = src + self.dropout(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout(src2)
        src = self.norm2(src)

        return src, enc_align


class TransformerDecoderLayer(nn.Module):
    """Self-attention, encoder-decoder attention and feed-forward block.

    Takes the same constructor arguments as ``TransformerEncoderLayer``;
    ``activation`` again only selects the init gain of ``linear1``.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = Linear(d_model, dim_feedforward, w_init_gain=activation)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Return ``(output, self_attention_weights, cross_attention_weights)``."""
        tgt2, dec_align = self.self_attn(tgt,
                                         tgt,
                                         tgt,
                                         attn_mask=tgt_mask,
                                         key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm1(tgt)

        # Queries come from the decoder, keys/values from ``memory``.
        tgt2, enc_dec_align = self.multihead_attn(tgt,
                                                  memory,
                                                  memory,
                                                  attn_mask=memory_mask,
                                                  key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm2(tgt)

        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm3(tgt)

        return tgt, dec_align, enc_dec_align


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding stored in the buffer ``pe``.

    Args:
        d_model: encoding dimension (even or odd).
        max_len: number of positions to pre-compute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pe', self._get_pe_matrix(d_model, max_len))

    def forward(self, x):
        """Add the encoding to ``x``, indexed by position along dimension 0."""
        return x + self.pe[:x.size(0)].unsqueeze(1)

    def _get_pe_matrix(self, d_model, max_len):
        """Build the (max_len, d_model) table: sin on even, cos on odd columns."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model)
        angles = position / div_term  # max_len, ceil(d_model / 2)

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        return pe
as tqdm\n", 46 | "import librosa\n", 47 | "from modules.loss import MDNLoss\n", 48 | "import math\n", 49 | "from multiprocessing import Pool\n", 50 | "import numpy as np\n", 51 | "\n", 52 | "data_type = 'char'\n", 53 | "checkpoint_path = f\"training_log/aligntts/stage0/checkpoint_40000\"\n", 54 | "\n", 55 | "from glob import glob\n", 56 | "\n", 57 | "checkpoint_path = sorted(glob(\"training_log/aligntts/stage0/checkpoint_*\"))[0]\n", 58 | "\n", 59 | "print(checkpoint_path)\n", 60 | "\n", 61 | "\n", 62 | "state_dict = {}\n", 63 | "for k, v in torch.load(checkpoint_path)['state_dict'].items():\n", 64 | " state_dict[k[7:]]=v\n", 65 | "\n", 66 | "\n", 67 | "model = Model(hparams).cuda()\n", 68 | "model.load_state_dict(state_dict)\n", 69 | "_ = model.cuda().eval()\n", 70 | "criterion = MDNLoss()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 2, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "import time" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 3, 85 | "metadata": { 86 | "scrolled": false 87 | }, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "application/vnd.jupyter.widget-view+json": { 92 | "model_id": "17e3fa022cc04cc0916ed226d16297dd", 93 | "version_major": 2, 94 | "version_minor": 0 95 | }, 96 | "text/plain": [ 97 | "HBox(children=(FloatProgress(value=0.0, max=656.0), HTML(value='')))" 98 | ] 99 | }, 100 | "metadata": {}, 101 | "output_type": "display_data" 102 | }, 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "VT Time: 0.758989 / 1.367927 = 55.48%\n", 108 | "IO Time: 0.005540 / 1.367927 = 0.40%\n", 109 | "DL Time: 0.591276 / 1.367927 = 43.22%\n", 110 | "torch.Size([1, 170, 857])\n", 111 | "\n" 112 | ] 113 | }, 114 | { 115 | "data": { 116 | "application/vnd.jupyter.widget-view+json": { 117 | "model_id": "239ca9b5506d4667a6e390b5b6507ae0", 118 | "version_major": 2, 119 | "version_minor": 0 120 | }, 121 | "text/plain": [ 122 | 
"HBox(children=(FloatProgress(value=0.0, max=86.0), HTML(value='')))" 123 | ] 124 | }, 125 | "metadata": {}, 126 | "output_type": "display_data" 127 | }, 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "VT Time: 0.590959 / 0.813022 = 72.69%\n", 133 | "IO Time: 0.005205 / 0.813022 = 0.64%\n", 134 | "DL Time: 0.213748 / 0.813022 = 26.29%\n", 135 | "torch.Size([1, 140, 725])\n", 136 | "\n" 137 | ] 138 | }, 139 | { 140 | "data": { 141 | "application/vnd.jupyter.widget-view+json": { 142 | "model_id": "fc2d69f9b7aa48a5971db851393858cc", 143 | "version_major": 2, 144 | "version_minor": 0 145 | }, 146 | "text/plain": [ 147 | "HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))" 148 | ] 149 | }, 150 | "metadata": {}, 151 | "output_type": "display_data" 152 | }, 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "VT Time: 0.759542 / 1.017219 = 74.67%\n", 158 | "IO Time: 0.024682 / 1.017219 = 2.43%\n", 159 | "DL Time: 0.226169 / 1.017219 = 22.23%\n", 160 | "torch.Size([1, 137, 807])\n", 161 | "\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "datasets = ['train', 'val', 'test']\n", 167 | "batch_size=64\n", 168 | "batch_size = 16\n", 169 | "\n", 170 | "start = time.perf_counter()\n", 171 | "\n", 172 | "for dataset in datasets:\n", 173 | " \n", 174 | " with open(f'filelists/ljs_audio_text_{dataset}_filelist.txt', 'r') as f:\n", 175 | " lines_raw = [line.split('|') for line in f.read().splitlines()]\n", 176 | " lines_list = [ lines_raw[batch_size*i:batch_size*(i+1)] \n", 177 | " for i in range(len(lines_raw)//batch_size+1)]\n", 178 | " \n", 179 | " for batch in tqdm(lines_list):\n", 180 | " \n", 181 | " single_loop_start = time.perf_counter()\n", 182 | " \n", 183 | " file_list, text_list, mel_list = [], [], []\n", 184 | " text_lengths, mel_lengths=[], []\n", 185 | " \n", 186 | " for i in range(len(batch)):\n", 187 | " file_name, _, text = batch[i]\n", 188 | " 
file_list.append(file_name)\n", 189 | " seq_path = os.path.join('../Dataset/LJSpeech-1.1/preprocessed',\n", 190 | " f'{data_type}_seq')\n", 191 | " mel_path = os.path.join('../Dataset/LJSpeech-1.1/preprocessed',\n", 192 | " 'melspectrogram')\n", 193 | " try:\n", 194 | " seq = torch.from_numpy(np.load(f'{seq_path}/{file_name}_sequence.npy'))\n", 195 | " except FileNotFoundError:\n", 196 | " with open(f'{seq_path}/{file_name}_sequence.pkl', 'rb') as f:\n", 197 | " seq = pkl.load(f)\n", 198 | " \n", 199 | " try:\n", 200 | " mel = torch.from_numpy(np.load(f'{mel_path}/{file_name}_melspectrogram.npy'))\n", 201 | " except FileNotFoundError:\n", 202 | " with open(f'{mel_path}/{file_name}_melspectrogram.pkl', 'rb') as f:\n", 203 | " mel = pkl.load(f)\n", 204 | " \n", 205 | " text_list.append(seq)\n", 206 | " mel_list.append(mel)\n", 207 | " text_lengths.append(seq.size(0))\n", 208 | " mel_lengths.append(mel.size(1))\n", 209 | " \n", 210 | " io_time = time.perf_counter()\n", 211 | " \n", 212 | " text_lengths = torch.LongTensor(text_lengths)\n", 213 | " mel_lengths = torch.LongTensor(mel_lengths)\n", 214 | " text_padded = torch.zeros(len(batch), text_lengths.max().item(), dtype=torch.long)\n", 215 | " mel_padded = torch.zeros(len(batch), hparams.n_mel_channels, mel_lengths.max().item())\n", 216 | " \n", 217 | " for j in range(len(batch)):\n", 218 | " text_padded[j, :text_list[j].size(0)] = text_list[j]\n", 219 | " mel_padded[j, :, :mel_list[j].size(1)] = mel_list[j]\n", 220 | " \n", 221 | " text_padded = text_padded.cuda()\n", 222 | " mel_padded = mel_padded.cuda()\n", 223 | " text_lengths = text_lengths.cuda()\n", 224 | " mel_lengths = mel_lengths.cuda()\n", 225 | " \n", 226 | " with torch.no_grad():\n", 227 | " \n", 228 | " model_start = time.perf_counter()\n", 229 | " \n", 230 | " encoder_input = model.Prenet(text_padded)\n", 231 | " hidden_states, _ = model.FFT_lower(encoder_input, text_lengths)\n", 232 | " mu_sigma = model.get_mu_sigma(hidden_states)\n", 233 | " _, 
log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths)\n", 234 | " \n", 235 | " viterbi_start = time.perf_counter()\n", 236 | "\n", 237 | " align = model.viterbi(log_prob_matrix, text_lengths, mel_lengths).to(torch.long)\n", 238 | " alignments = list(torch.split(align,1))\n", 239 | " \n", 240 | " viterbi_end = time.perf_counter()\n", 241 | " \n", 242 | " print('VT Time: ', end=' ')\n", 243 | " print(f'{viterbi_end - viterbi_start:.6f} / {viterbi_end - single_loop_start:.6f} = ' +\n", 244 | " f'{(viterbi_end - viterbi_start) / (viterbi_end - single_loop_start) * 100:5.2f}%')\n", 245 | " \n", 246 | " print('IO Time: ', end=' ')\n", 247 | " print(f'{io_time - single_loop_start:.6f} / {viterbi_end - single_loop_start:.6f} = ' +\n", 248 | " f'{(io_time - single_loop_start) / (viterbi_end - single_loop_start) * 100:5.2f}%')\n", 249 | " \n", 250 | " print('DL Time: ', end=' ')\n", 251 | " print(f'{viterbi_start - model_start:.6f} / {viterbi_end - single_loop_start:.6f} = ' +\n", 252 | " f'{(viterbi_start - model_start) / (viterbi_end - single_loop_start) * 100:5.2f}%')\n", 253 | " \n", 254 | " print(alignments[0].shape)\n", 255 | " \n", 256 | " break\n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | "# for j, (l, t) in enumerate(zip(text_lengths, mel_lengths)):\n", 262 | "# alignments[j] = alignments[j][0, :l.item(), :t.item()].sum(dim=-1)\n", 263 | "# np.save(f'../Dataset/LJSpeech-1.1/preprocessed/alignments/{file_list[j]}_alignment.npy',\n", 264 | "# alignments[j].detach().cpu().numpy())" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 4, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "torch.Size([16, 137, 807])\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "print(log_prob_matrix.shape)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 7, 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | 
"data": { 291 | "text/plain": [ 292 | "" 293 | ] 294 | }, 295 | "execution_count": 7, 296 | "metadata": {}, 297 | "output_type": "execute_result" 298 | }, 299 | { 300 | "data": { 301 | "image/png": "\n", 302 | "text/plain": [ 303 | "
" 304 | ] 305 | }, 306 | "metadata": { 307 | "needs_background": "light" 308 | }, 309 | "output_type": "display_data" 310 | } 311 | ], 312 | "source": [ 313 | "plt.imshow(log_prob_matrix[0, :, :].cpu())" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 11, 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "data": { 323 | "text/plain": [ 324 | "" 325 | ] 326 | }, 327 | "execution_count": 11, 328 | "metadata": {}, 329 | "output_type": "execute_result" 330 | }, 331 | { 332 | "data": { 333 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABBYAAADNCAYAAAAMlPNVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAVa0lEQVR4nO3dbYylZ3kf8P/V9Rsmccyal9petzbKikBRWOgK3FAh15vUJkE4HyA1SZMtcbWqRBNIUgWTfiCthBTUKISoLdIKCKaiBseB2KpoDBgQ7QccbHDB4BA2htjLbrxEmJdChXFy9cN5BsbLzO7sc86ZM3Pm95NGc577POecW7pmzpn973XfT3V3AAAAAMb4e4ueAAAAALB9CRYAAACA0QQLAAAAwGiCBQAAAGA0wQIAAAAwmmABAAAAGG1uwUJVXVtVn6+qI1V147xeBwAAAFic6u7ZP2nVriR/keSnkhxN8okkr+juz838xQAAAICFmVfHwvOTHOnuB7r70STvTnLdnF4LAAAAWJCz5vS8lyZ5aNXx0SQvWO/kJ+/e1Y8+8iNzmgoAAAAw1jfzyN9091PWu39ewUKtMfa4NRdVdSjJoSQ5L+fnW8cvOuUTXnPJvplNDgAAANiYD/Wtf3Wq++cVLBxNctmq4z1Jjq0+obsPJzmcJBfU7j5dcHDHsXuFCwAAALDFzGuPhU8k2VtVV1TVOUmuT3L7nF4LAAAAWJC5dCx092NV9W+T3JFkV5K3d/dnp3nOay7ZlzuO3fu92wAAAMDizWspRLr7/UneP8vnXAkUBAwAAACwNcwtWJinkwOG1WMAAADA5tmWwcKK1WGCLgYAAADYfPPavBEAAADYAZYmWLjmkn2P2+ARAAAAmL+lCRZWCBcAAABg82zrPRbWM4twwV4NAAAAcHpLGSwk0wcD8+p6EFgAAACwTJZuKQQAAACweZa2Y2Fa8+osWK8TQicDAAAA25FgYZOtFyDccexe4QIAAADbjqUQW4SrWQAAALAd6VjYQjYaLuhsAAAAYKvQsQAAAACMpmNhi9lIN4L9GAAAANgqdCxsQ/ZjAAAAYKsQLGxTK+GCgAEAAIBFGr0UoqouS/LOJH8/yd8lOdzdb66q3Unek+TyJF9K8nPd/cj0U+VkK8shVsIFyyMAAADYbNN0LDyW5De6+5lJrkzyqqp6VpIbk9zZ3XuT3DkcAwAAAEtodMdCdx9Pcny4/c2quj/JpUmuS3LVcNpNST6a5LVTzZJTOrlzYVbPBwAAAKczk6tCVNXlSZ6b5K4kTxtCh3T38ap66ixeg9ObVSAwTUAhlAAAANhZpg4WquqHkvxxktd09zeqaqOPO5TkUJKcl/OnnQYzNE04oGsCAABgZ5kqWKiqszMJFd7V3e8dhh+uqouHboWLk5xY67HdfTjJ4SS5oHb3NPNg65hn14SwAQAA
YOsZvXljTVoT3pbk/u7+vVV33Z7k4HD7YJLbxk8PAAAA2Mqm6Vh4YZJfTPKZqlr57+XfSvI7SW6pqhuSPJjk5dNNkZ1ore4EXQwAAABbzzRXhfjfSdbbUOHA2OeF9Ww0bJjl8wMAAHBqM7kqBCzKLMOAWYYUZ0qoAQAAbFeCBRgs8h/3pwo1hA4AAMBWNnrzRgAAAAAdC7AFnKorYaWbQecCAACwFQkWYItbCRQ2ugeEAAIAANhMggXYJjYaGNiEEgAA2EyCBVgyi96EUrgAAAA7i80bAQAAgNEEC8DMXHPJvoUuxQAAADafpRDATG0kXLBcAgAAlodgAZi50wUH9mIAAIDlYSkEsOksmQAAgOUhWAAAAABGsxQCWIiNdi1YMgEAAFubYAFYmI2EBvZjAACArU2wAGxpJ3c2CBkAAGBrmTpYqKpdSe5O8uXufklVXZHk3Ul2J/lkkl/s7kenfR1g51odJsx600dBBQAATGcWmze+Osn9q47fmORN3b03ySNJbpjBawAAAABb0FQdC1W1J8nPJHlDkl+vqkpydZKfH065KclvJ3nLNK8DsGLWHQb2cAAAgOlMuxTi95P8ZpIfHo4vSvK17n5sOD6a5NIpXwNgbta7OoWwAQAANmZ0sFBVL0lyorvvqaqrVobXOLXXefyhJIeS5LycP3YaAFNbK0TQyQAAABszzR4LL0zy0qr6UiabNV6dSQfDhVW1EljsSXJsrQd39+Hu3t/d+8/OuVNMA2D21utkAAAAHm90sNDdr+vuPd19eZLrk3y4u38hyUeSvGw47WCS26aeJQAAALAlzeKqECd7bSYbOR7JZM+Ft83hNQDmbqVrQecCAACsb9rNG5Mk3f3RJB8dbj+Q5PmzeF6ARVvZZ8GeCwAAsLaZBAsAy26aPRcEEgAALDPBAsAGjQ0IdDsAALDM5rHHAgAAALBDCBYA5sylKwEAWGaWQgBsgnmFC5ZYAACwaIIFgE0yjxDgTMIKIQQAAPMgWADYxs4kLFgJIQQMAADMkj0WAAAAgNF0LADsECudCqdaPqGbAQCAMyVYANhhThUe3HHsXuECAABnxFIIAL5n5eoVLo8JAMBGCRYAeJxrLtk3t8tjAgCwfAQLAAAAwGj2WABgTWfStWBfBgCAnUuwAMC6NhoYTLNsQigBALC9CRYAmNo04cB6oYTAAQBge5gqWKiqC5O8Ncmzk3SSX07y+STvSXJ5ki8l+bnufmSqWQKwtNYLEOa1eaTAAgBgtqbdvPHNSf60u38syXOS3J/kxiR3dvfeJHcOxwAAAMASGt2xUFUXJHlRkn+VJN39aJJHq+q6JFcNp92U5KNJXjvNJAHYeebVWbBWJ4QuBgCA8aZZCvH0JF9J8odV9Zwk9yR5dZKndffxJOnu41X11OmnCQCzsVaI4OoXAADjTRMsnJXkeUl+pbvvqqo35wyWPVTVoSSHkuS8nD/FNABgOq5+AQAw3jTBwtEkR7v7ruH41kyChYer6uKhW+HiJCfWenB3H05yOEkuqN09xTwAYFPM4+oX83xNAIDNMHrzxu7+6yQPVdUzhqEDST6X5PYkB4exg0lum2qGAAAAwJY11eUmk/xKkndV1TlJHkjyykzCiluq6oYkDyZ5+ZSvAQDb3tjOg9WdDroXAICtaKpgobvvTbJ/jbsOTPO8AMDE6jDhjmP3ChcAgC1n2o4FAGCTXHPJvlF7NQgjAIB5EiwAwDYyJiRYK4wQNgAAszJ680YAAAAAHQsAsOTW6k5Y6WLQuQAATEuwAAA70EqgIGAAAKZlKQQA7GDXXLJv9KaQAACJYAEAyPevOCFgAADOlGABAAAAGM0eCwBAkh/cd2Gt+wAATiZYAAAeZ72rSAgXAIC1WAoBAJyWPRgAgPXoWAAANuRUSyXWOxcAWH46FgAAAIDRdCwAAGdkI90Iq7sadC8AwHITLAAAM7c6TLDxIwAsN0shAIC5Wtn4EQBYTlMFC1X1a1X12aq6r6purqrz
quqKqrqrqr5QVe+pqnNmNVkAYHsSLgDA8hodLFTVpUl+Ncn+7n52kl1Jrk/yxiRv6u69SR5JcsMsJgoAAABsPdPusXBWkidU1XeTnJ/keJKrk/z8cP9NSX47yVumfB0AYJubZdeCPRsAYOsYHSx095er6neTPJjk/yX5QJJ7knytux8bTjua5NK1Hl9Vh5IcSpLzcv7YaQAA28isAoFTBRRCBwDYXKODhap6UpLrklyR5GtJ/ijJi9c4tdd6fHcfTnI4SS6o3WueAwCwllOFB65CAQCba5rNG38yyRe7+yvd/d0k703yE0kurKqVwGJPkmNTzhEAYMNWllzYLBIANsc0wcKDSa6sqvOrqpIcSPK5JB9J8rLhnINJbptuigAAAMBWNTpY6O67ktya5JNJPjM81+Ekr03y61V1JMlFSd42g3kCAGzYNZfsc4lLANgk1b347Q0uqN39gjqw6GkAAEvIngsAMJ0P9a33dPf+9e6fZikEAMCWp3MBAOZLsAAALD0bOgLA/AgWAAAAgNEECwDAjmBDRwCYj7MWPQEAgM10qnDBJo8AcOYECwDAjrNegOAKEgBw5iyFAAAY2OQRAM6cYAEAAAAYzVIIAIBVLIUAgDOjYwEAAAAYTbAAAAAAjCZYAAAAAEYTLAAAAACjCRYAAACA0QQLAAAAwGiCBQAAAGC00wYLVfX2qjpRVfetGttdVR+sqi8M3580jFdV/UFVHamqT1fV8+Y5eQAAAGCxNtKx8I4k1540dmOSO7t7b5I7h+MkeXGSvcPXoSRvmc00AQAAgK3otMFCd38syVdPGr4uyU3D7ZuS/Oyq8Xf2xMeTXFhVF89qsgAAAMDWMnaPhad19/EkGb4/dRi/NMlDq847Ooz9gKo6VFV3V9Xd3813Rk4DAAAAWKRZb95Ya4z1Wid29+Hu3t/d+8/OuTOeBgAAALAZxgYLD68scRi+nxjGjya5bNV5e5IcGz89AAAAYCsbGyzcnuTgcPtgkttWjf/ScHWIK5N8fWXJBAAAALB8zjrdCVV1c5Krkjy5qo4meX2S30lyS1XdkOTBJC8fTn9/kp9OciTJt5O8cg5zBgAAALaI0wYL3f2Kde46sMa5neRV004KAAAA2B5mvXkjAAAAsIMIFgAAAIDRBAsAAADAaIIFAAAAYDTBAgAAADCaYAEAAAAYTbAAAAAAjCZYAAAAAEYTLAAAAACjCRYAAACA0QQLAAAAwGiCBQAAAGA0wQIAAAAwmmABAAAAGE2wAAAAAIwmWAAAAABGO22wUFVvr6oTVXXfqrH/VFV/XlWfrqr3VdWFq+57XVUdqarPV9U185o4AAAAsHgb6Vh4R5JrTxr7YJJnd/ePJ/mLJK9Lkqp6VpLrk/yj4TH/tap2zWy2AAAAwJZy2mChuz+W5KsnjX2gux8bDj+eZM9w+7ok7+7u73T3F5McSfL8Gc4XAAAA2EJmscfCLyf5n8PtS5M8tOq+o8PYD6iqQ1V1d1Xd/d18ZwbTAAAAADbbVMFCVf37JI8ledfK0Bqn9VqP7e7D3b2/u/efnXOnmQYAAACwIGeNfWBVHUzykiQHunslPDia5LJVp+1Jcmz89AAAAICtbFTHQlVdm+S1SV7a3d9eddftSa6vqnOr6ooke5P82fTTBAAAALai03YsVNXNSa5K8uSqOprk9ZlcBeLcJB+sqiT5eHf/m+7+bFXdkuRzmSyReFV3/+28Jg8AAAAsVn1/FcPiXFC7+wV1YNHTAAAAAE7yob71nu7ev979s7gqBAAAALBDCRYAAACA0QQLAAAAwGiCBQAAAGA0wQIAAAAwmmABAAAAGE2wAAAAAIwmWAAAAABGEywAAAAAowkWAAAAgNEECwAAAMBoggUAAABgNMECAAAAMJpgAQAAABhNsAAAAACMVt296Dmkqr6S5FtJ/mbRc2HTPDnqvdOo+c6i3juPmu8s6r3zqPnOot47z+lq/g+7+ynr3bklgoUkqaq7u3v/oufB5lDv
nUfNdxb13nnUfGdR751HzXcW9d55pq25pRAAAADAaIIFAAAAYLStFCwcXvQE2FTqvfOo+c6i3juPmu8s6r3zqPnOot47z1Q13zJ7LAAAAADbz1bqWAAAAAC2mYUHC1V1bVV9vqqOVNWNi54Ps1FVb6+qE1V136qx3VX1war6wvD9ScN4VdUfDD8Dn66q5y1u5oxRVZdV1Ueq6v6q+mxVvXoYV/MlVVXnVdWfVdX/GWr+H4bxK6rqrqHm76mqc4bxc4fjI8P9ly9y/oxTVbuq6lNV9T+GY/VeYlX1par6TFXdW1V3D2Pe15dUVV1YVbdW1Z8Pn+f/RL2XU1U9Y/i9Xvn6RlW9Rr2XW1X92vA3231VdfPwt9zMPscXGixU1a4k/yXJi5M8K8krqupZi5wTM/OOJNeeNHZjkju7e2+SO4fjZFL/vcPXoSRv2aQ5MjuPJfmN7n5mkiuTvGr4XVbz5fWdJFd393OS7EtybVVdmeSNSd401PyRJDcM59+Q5JHu/tEkbxrOY/t5dZL7Vx2r9/L7Z929b9UlyLyvL683J/nT7v6xJM/J5HddvZdQd39++L3el+QfJ/l2kvdFvZdWVV2a5FeT7O/uZyfZleT6zPBzfNEdC89PcqS7H+juR5O8O8l1C54TM9DdH0vy1ZOGr0ty03D7piQ/u2r8nT3x8SQXVtXFmzNTZqG7j3f3J4fb38zkj5FLo+ZLa6jd/x0Ozx6+OsnVSW4dxk+u+crPwq1JDlRVbdJ0mYGq2pPkZ5K8dTiuqPdO5H19CVXVBUlelORtSdLdj3b316LeO8GBJH/Z3X8V9V52ZyV5QlWdleT8JMczw8/xRQcLlyZ5aNXx0WGM5fS07j6eTP4hmuSpw7ifgyUytEo9N8ldUfOlNrTF35vkRJIPJvnLJF/r7seGU1bX9Xs1H+7/epKLNnfGTOn3k/xmkr8bji+Kei+7TvKBqrqnqg4NY97Xl9PTk3wlyR8Oy53eWlVPjHrvBNcnuXm4rd5Lqru/nOR3kzyYSaDw9ST3ZIaf44sOFtZKPVymYufxc7AkquqHkvxxktd09zdOdeoaY2q+zXT33w5tlHsy6UB75lqnDd/VfBurqpckOdHd96weXuNU9V4uL+zu52XSBv2qqnrRKc5V8+3trCTPS/KW7n5ukm/l+23wa1HvJTCsp39pkj863alrjKn3NjLsl3FdkiuSXJLkiZm8t59s9Of4ooOFo0kuW3W8J8mxBc2F+Xt4pW1q+H5iGPdzsASq6uxMQoV3dfd7h2E13wGGdtmPZrK/xoVDi13y+Lp+r+bD/T+SH1wuxdb1wiQvraovZbJs8epMOhjUe4l197Hh+4lM1l8/P97Xl9XRJEe7+67h+NZMggb1Xm4vTvLJ7n54OFbv5fWTSb7Y3V/p7u8meW+Sn8gMP8cXHSx8IsneYTfKczJpxbl9wXNifm5PcnC4fTDJbavGf2nYcfbKJF9facNiexjWXL0tyf3d/Xur7lLzJVVVT6mqC4fbT8jkA+v+JB9J8rLhtJNrvvKz8LIkH+5u/9uxTXT367p7T3dfnsln9Ye7+xei3kurqp5YVT+8cjvJP09yX7yvL6Xu/uskD1XVM4ahA0k+F/Vedq/I95dBJOq9zB5McmVVnT/83b7yOz6zz/Fa9Od8Vf10Jv/rsSvJ27v7DQudEDNRVTcnuSrJk5M8nOT1Sf4kyS1J/kEmP9wv7+6vDj/c/zmTq0h8O8kru/vuRcybcarqnyb5X0k+k++vv/6tTPZZUPMlVFU/nsmmPrsyCalv6e7/WFVPz+R/tHcn+VSSf9nd36mq85L8t0z23/hqkuu7+4HFzJ5pVNVVSf5dd79EvZfXUNv3DYdnJfnv3f2Gqroo3teXUlXty2Rz1nOSPJDklRne36PeS6eqzs9kDf3Tu/vrw5jf7yVWk0uD/4tMrub2qST/OpO9FGbyOb7wYAEAAADYvha9FAIAAADYxgQLAAAAwGiCBQAAAGA0wQIAAAAw
mmABAAAAGE2wAAAAAIwmWAAAAABGEywAAAAAo/1/hvu7dg2PYBQAAAAASUVORK5CYII=\n", 334 | "text/plain": [ 335 | "
" 336 | ] 337 | }, 338 | "metadata": { 339 | "needs_background": "light" 340 | }, 341 | "output_type": "display_data" 342 | } 343 | ], 344 | "source": [ 345 | "plt.figure(figsize=(18, 18))\n", 346 | "plt.imshow(alignments[0][0, :, :].cpu())" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [] 355 | } 356 | ], 357 | "metadata": { 358 | "kernelspec": { 359 | "display_name": "Python 3", 360 | "language": "python", 361 | "name": "python3" 362 | }, 363 | "language_info": { 364 | "codemirror_mode": { 365 | "name": "ipython", 366 | "version": 3 367 | }, 368 | "file_extension": ".py", 369 | "mimetype": "text/x-python", 370 | "name": "python", 371 | "nbconvert_exporter": "python", 372 | "pygments_lexer": "ipython3", 373 | "version": "3.8.3" 374 | } 375 | }, 376 | "nbformat": 4, 377 | "nbformat_minor": 2 378 | } 379 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.1.1 2 | tensorboard 3 | numpy==1.17.4 4 | inflect==4.0.0 5 | librosa==0.7.1 6 | scipy==1.3.2 7 | Unidecode==1.1.1 8 | pillow==6.2.1 9 | g2p_en==2.0.0 -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 
12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class STFT(torch.nn.Module):
    """Short-time Fourier transform and its inverse, built on 1-D convolutions.

    Adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft.

    The forward and inverse Fourier bases are precomputed once and stored as
    buffers, so they follow the module across ``.to()`` / ``.cuda()`` calls.

    Args:
        filter_length: FFT size; also the length of each convolution kernel.
        hop_length: Stride, in samples, between successive frames.
        win_length: Length of the analysis window, at most ``filter_length``.
        window: Window name passed to ``scipy.signal.get_window``, or ``None``
            to leave the bases unwindowed.
    """

    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep only the non-redundant half of the spectrum and stack the real
        # rows on top of the imaginary rows: shape (2 * cutoff, filter_length).
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            # `size` is passed by keyword: accepted by old librosa releases and
            # required by newer ones, where it is keyword-only.
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Compute magnitude and phase spectrograms.

        Args:
            input_data: 2-D tensor indexed as (batch, samples).

        Returns:
            ``(magnitude, phase)``, each with ``filter_length // 2 + 1``
            channels along dim 1 and one column per frame along dim 2.
            ``phase`` is detached from the autograd graph.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        half = int(self.filter_length / 2)
        # The temporary extra dim gives F.pad the 4-D layout that goes with
        # the 4-element padding tuple.
        input_data = F.pad(
            input_data.unsqueeze(1),
            (half, half, 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        # Buffers never require grad, so no Variable wrapper is needed.
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)

        cutoff = half + 1
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct a waveform from magnitude and phase spectrograms.

        Args:
            magnitude: Magnitudes as returned by :meth:`transform`.
            phase: Phases as returned by :meth:`transform`.

        Returns:
            Tensor indexed as (batch, 1, samples), with the
            ``filter_length // 2`` samples of padding trimmed from each end.
        """
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            # Both helper tensors go to the input's own device. A bare
            # `.cuda()` would pick the default GPU, which is wrong whenever
            # the data lives on another one.
            device = magnitude.device
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]).to(device)
            window_sum = torch.from_numpy(window_sum).to(device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Drop the reflect padding added by transform().
        half = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half:-half]

        return inverse_transform

    def forward(self, input_data):
        """Run ``transform`` then ``inverse`` and return the reconstruction.

        The intermediate spectrograms are kept on ``self.magnitude`` and
        ``self.phase``.
        """
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
def text_to_sequence(text, cleaner_names):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
      text: string to convert to a sequence
      cleaner_names: names of the cleaner functions to run the text through

    Returns:
      List of integers corresponding to the symbols in the text, starting with
      the ID of '^' and ending with the ID of '~'
    '''
    sequence = [_symbol_to_id['^']]

    # Check for curly braces and treat their contents as ARPAbet:
    while text:
        m = _curly_re.match(text)
        if not m:
            sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # Append EOS token
    sequence.append(_symbol_to_id['~'])
    return sequence


def sequence_to_text(sequence):
    '''Converts a sequence of IDs back to a string.

    IDs that are not in the symbol table are skipped.
    '''
    pieces = []
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == '@':
                s = '{%s}' % s[1:]
            pieces.append(s)
    # Adjacent ARPAbet symbols are merged into one braced group.
    return ''.join(pieces).replace('}{', ' ')


def _clean_text(text, cleaner_names):
    '''Runs `text` through each named cleaner from the `cleaners` module, in order.

    Raises:
      AttributeError: if a name does not refer to a cleaner. This is the type
        an unknown name already produced, now with an explicit message.
    '''
    for name in cleaner_names:
        # The None default makes the unknown-cleaner check below reachable.
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise AttributeError('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    '''Maps symbols to IDs, dropping those rejected by _should_keep_symbol.'''
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    '''Maps space-separated ARPAbet symbols to IDs via their '@'-prefixed form.'''
    return _symbols_to_sequence(['@' + s for s in text.split()])


def _should_keep_symbol(s):
    '''Returns True for known symbols other than '_' and '~'.'''
    # Compare by value: `is not` against a string literal tests identity and
    # only worked through interning.
    return s in _symbol_to_id and s != '_' and s != '~'
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/cleaners.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/cmudict.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/cmudict.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/cmudict.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/cmudict.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/numbers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/numbers.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/numbers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/numbers.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/symbols.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-37.pyc: 
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]

# Single-character expansions applied by expand_symbols().
_symbol_expansions = str.maketrans({';': ',', ':': ',', '-': ' ', '&': 'and'})


def expand_abbreviations(text):
    '''Replaces each abbreviation followed by a period with its expansion.'''
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def expand_numbers(text):
    '''Delegates to normalize_numbers.'''
    return normalize_numbers(text)


def lowercase(text):
    '''Returns `text` lowercased.'''
    return text.lower()


def collapse_whitespace(text):
    '''Replaces every run of whitespace with a single space.'''
    return re.sub(_whitespace_re, ' ', text)


def convert_to_ascii(text):
    '''Transliterates `text` with unidecode.'''
    return unidecode(text)


def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterates to ASCII.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    '''Pipeline for English text, including number and abbreviation expansion.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text


# NOTE (kan-bayashi): The following functions are additions; they are not included in the original code.
def remove_unnecessary_symbols(text):
    '''Removes parentheses, square brackets, angle brackets and double quotes.'''
    text = re.sub(r'[\(\)\[\]\<\>\"]+', '', text)
    return text


def expand_symbols(text):
    '''Maps ';' and ':' to ',', '-' to a space and '&' to 'and'.'''
    # One translate pass replaces four re.sub calls whose non-raw patterns
    # ("\;", "\:", "\-", "\&") were invalid string escape sequences.
    return text.translate(_symbol_expansions)


def uppercase(text):
    '''Returns `text` uppercased.'''
    return text.upper()


def custom_english_cleaners(text):
    '''Custom pipeline for English text, including number and abbreviation expansion.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = expand_symbols(text)
    text = remove_unnecessary_symbols(text)
    text = uppercase(text)
    text = collapse_whitespace(text)

    # Some sentences are wrapped in single quotes, e.g.
    # "'NOW FOR YOU, MY POOR FELLOW MORTALS, WHO ARE ABOUT TO SUFFER THE LAST PENALTY OF THE LAW.'"
    # The `text and` guard keeps an input that cleans down to the empty string
    # from raising IndexError.
    if text and text[0] == "'" and text[-1] == "'":
        text = text[1:-1]

    return text
valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)


class CMUDict:
    '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict

    Args:
      file_or_path: a path (opened as latin-1 text) or an iterable of lines
      keep_ambiguous: when False, words with more than one pronunciation are dropped
    '''

    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding='latin-1') as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        '''Returns list of ARPAbet pronunciations of the given word, or None if unknown.'''
        return self._entries.get(word.upper())


# Matches the "(N)" suffix that marks an alternate pronunciation of a word.
_alt_re = re.compile(r'\([0-9]+\)')


def _parse_cmudict(file):
    '''Builds a {word: [pronunciation, ...]} dict from an iterable of lines.

    Only lines starting with an uppercase ASCII letter or an apostrophe are
    treated as entries. Lines without a pronunciation field, and entries with
    a symbol outside `valid_symbols`, are skipped.
    '''
    cmudict = {}
    for line in file:
        if line and ('A' <= line[0] <= 'Z' or line[0] == "'"):
            # Split once on the first whitespace run so that the whole
            # pronunciation is kept, whatever the width of the separator.
            parts = line.split(None, 1)
            if len(parts) < 2:
                continue
            word = re.sub(_alt_re, '', parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                cmudict.setdefault(word, []).append(pronunciation)
    return cmudict


def _get_pronunciation(s):
    '''Returns the symbols of `s` joined by single spaces, or None if any is invalid.'''
    parts = s.split()
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return ' '.join(parts)
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')


def _remove_commas(m):
    """Return the matched number with its thousands separators stripped."""
    digits = m.group(1)
    return digits.replace(',', '')


def _expand_decimal_point(m):
    """Return the matched decimal with '.' spelled out as ' point '."""
    decimal = m.group(1)
    return decimal.replace('.', ' point ')


def _expand_dollars(m):
    """Return the matched dollar amount phrased in dollars and/or cents."""
    amount = m.group(1)
    pieces = amount.split('.')
    if len(pieces) > 2:
        return amount + ' dollars'  # Unexpected format

    whole = int(pieces[0]) if pieces[0] else 0
    fraction = int(pieces[1]) if len(pieces) > 1 and pieces[1] else 0

    # Collect the non-zero components; each picks its singular or plural unit.
    spoken = []
    if whole:
        spoken.append('%s %s' % (whole, 'dollar' if whole == 1 else 'dollars'))
    if fraction:
        spoken.append('%s %s' % (fraction, 'cent' if fraction == 1 else 'cents'))
    if not spoken:
        return 'zero dollars'
    return ', '.join(spoken)


def _expand_ordinal(m):
    """Return the matched ordinal (e.g. '3rd') spelled out by inflect."""
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    """Return the matched integer in words; 1001-2999 get a digit-pair reading."""
    value = int(m.group(0))
    if not (value > 1000 and value < 3000):
        return _inflect.number_to_words(value, andword='')
    if value == 2000:
        return 'two thousand'
    if value > 2000 and value < 2010:
        return 'two thousand ' + _inflect.number_to_words(value % 100)
    if value % 100 == 0:
        return _inflect.number_to_words(value // 100) + ' hundred'
    paired = _inflect.number_to_words(value, andword='', zero='oh', group=2)
    return paired.replace(', ', ' ')


def normalize_numbers(text):
    """Rewrite the numeric expressions in *text*, applying each pass in order."""
    # Order matters: commas go first, currency precedes the generic decimal
    # and integer passes, and plain integers are handled last.
    passes = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in passes:
        text = re.sub(pattern, replacement, text)
    return text
text 73 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _pad = '_' 10 | _sos = '^' 11 | _eos = '~' 12 | _punctuations = " ,.'?!" 13 | _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 14 | 15 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 16 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 17 | 18 | # Export all symbols: 19 | symbols = [_pad, _sos, _eos] + list(_punctuations) + list(_characters) + _arpabet -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modules.model import Model 6 | from modules.loss import MDNLoss 7 | import hparams 8 | from text import * 9 | from utils.utils import * 10 | from utils.writer import get_writer 11 | from torch.utils.tensorboard import SummaryWriter 12 | import math 13 | import matplotlib.pyplot as plt 14 | from datetime import datetime 15 | from glob import glob 16 | 17 | 18 | def validate(model, criterion, val_loader, iteration, writer, stage): 19 | model.eval() 20 | with torch.no_grad(): 21 | n_data, val_loss = 0, 0 22 | for i, batch in enumerate(val_loader): 23 | n_data += len(batch[0]) 24 | 25 | if stage==0: 26 | text_padded, mel_padded, text_lengths, mel_lengths = [ 27 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch 28 | ] 29 | else: 30 | 
text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [ 31 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch 32 | ] 33 | 34 | if stage !=3: 35 | encoder_input = model.module.Prenet(text_padded) 36 | hidden_states, _ = model.module.FFT_lower(encoder_input, text_lengths) 37 | 38 | if stage==0: 39 | mu_sigma = model.module.get_mu_sigma(hidden_states) 40 | loss, log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths) 41 | 42 | elif stage==1: 43 | mel_out = model.module.get_melspec(hidden_states, align_padded, mel_lengths) 44 | mel_mask = ~get_mask_from_lengths(mel_lengths) 45 | mel_padded_selected = mel_padded.masked_select(mel_mask.unsqueeze(1)) 46 | mel_out_selected = mel_out.masked_select(mel_mask.unsqueeze(1)) 47 | loss = nn.L1Loss()(mel_out_selected, mel_padded_selected) 48 | 49 | elif stage==2: 50 | mu_sigma = model.module.get_mu_sigma(hidden_states) 51 | mdn_loss, log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths) 52 | 53 | align = model.module.viterbi(log_prob_matrix, text_lengths, mel_lengths) 54 | mel_out = model.module.get_melspec(hidden_states, align, mel_lengths) 55 | mel_mask = ~get_mask_from_lengths(mel_lengths) 56 | mel_padded_selected = mel_padded.masked_select(mel_mask.unsqueeze(1)) 57 | mel_out_selected = mel_out.masked_select(mel_mask.unsqueeze(1)) 58 | fft_loss = nn.L1Loss()(mel_out_selected, mel_padded_selected) 59 | loss = mdn_loss + fft_loss 60 | 61 | elif stage==3: 62 | duration_out = model.module.get_duration(text_padded, text_lengths) # gradient cut 63 | duration_target = align_padded.sum(-1) 64 | duration_mask = ~get_mask_from_lengths(text_lengths) 65 | duration_out = duration_out.masked_select(duration_mask) 66 | duration_target = duration_target.masked_select(duration_mask) 67 | loss = nn.MSELoss()(torch.log(duration_out), torch.log(duration_target)) 68 | 69 | val_loss += loss.item() * len(batch[0]) 70 | 71 | val_loss /= n_data 72 | 73 | if stage==0: 74 | 
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``0``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- is parsed as ``True``. This
    converter keeps the ``--flag VALUE`` calling convention but interprets the
    value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def _checkpoint_step(path):
    """Return the integer step of a ``checkpoint_<step>`` path, or -1 if it has none."""
    suffix = os.path.basename(path).rsplit('_', 1)[-1]
    try:
        return int(suffix)
    except ValueError:
        return -1


def main(args):
    """Run the training loop for the stage selected by ``args.stage``.

    For any stage other than 0 the model is initialised from the previous
    stage's checkpoint; if the expected final checkpoint is missing, the
    checkpoint with the highest step number in that stage directory is used.

    Args:
        args: parsed command-line namespace; ``stage``, ``log_viterbi`` and
            ``cpu_viterbi`` are read here.

    Raises:
        FileNotFoundError: if a previous-stage checkpoint is required but the
            stage directory contains no ``checkpoint_*`` file.
    """
    train_loader, val_loader, collate_fn = prepare_dataloaders(hparams, stage=args.stage)

    if args.stage != 0:
        prev_stage = args.stage - 1
        checkpoint_path = f"training_log/aligntts/stage{prev_stage}/checkpoint_{hparams.train_steps[prev_stage]}"

        if not os.path.isfile(checkpoint_path):
            print(f'{checkpoint_path} does not exist')
            candidates = glob(f"training_log/aligntts/stage{prev_stage}/checkpoint_*")
            if not candidates:
                raise FileNotFoundError(
                    f'No checkpoint found in training_log/aligntts/stage{prev_stage}')
            # Compare by numeric step: a plain string sort would rank
            # checkpoint_9000 above checkpoint_10000.
            checkpoint_path = max(candidates, key=lambda c: (_checkpoint_step(c), c))
            print(f'Loading {checkpoint_path} instead')

        # Checkpoints are written from the nn.DataParallel wrapper, whose keys
        # carry a 'module.' prefix; strip it only where it is actually present.
        prefix = 'module.'
        state_dict = {}
        for k, v in torch.load(checkpoint_path)['state_dict'].items():
            state_dict[k[len(prefix):] if k.startswith(prefix) else k] = v

        model = Model(hparams).cuda()
        model.load_state_dict(state_dict)
        model = nn.DataParallel(model).cuda()
    else:
        model = nn.DataParallel(Model(hparams)).cuda()

    criterion = MDNLoss()
    writer = get_writer(hparams.output_directory, f'{hparams.log_directory}/stage{args.stage}')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hparams.lr,
                                 betas=(0.9, 0.98),
                                 eps=1e-09)
    iteration, loss = 0, 0
    model.train()

    # One optimizer step is taken every `hparams.accumulation` iterations, so
    # the stage ends after train_steps * accumulation forward/backward passes.
    total_iterations = hparams.train_steps[args.stage] * hparams.accumulation

    print(f'Stage{args.stage} Start!!! ({str(datetime.now())})')
    while True:
        for batch in train_loader:
            if args.stage == 0:
                text_padded, mel_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]
                align_padded = None
            else:
                text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]

            sub_loss = model(text_padded,
                             mel_padded,
                             align_padded,
                             text_lengths,
                             mel_lengths,
                             criterion,
                             stage=args.stage,
                             log_viterbi=args.log_viterbi,
                             cpu_viterbi=args.cpu_viterbi)
            # Scale so the accumulated gradient equals the mean over the
            # accumulation window.
            sub_loss = sub_loss.mean() / hparams.accumulation
            sub_loss.backward()
            loss = loss + sub_loss.item()
            iteration += 1

            print(f'[{str(datetime.now())}] Stage {args.stage} Iter {iteration:<6d} Loss {loss:<8.6f}')

            if iteration % hparams.accumulation == 0:
                lr_scheduling(optimizer, iteration // hparams.accumulation)
                nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
                optimizer.step()
                model.zero_grad()
                writer.add_scalar('Train loss', loss, iteration // hparams.accumulation)
                loss = 0

            if iteration % (hparams.iters_per_validation * hparams.accumulation) == 0:
                validate(model, criterion, val_loader, iteration, writer, args.stage)

            if iteration % (hparams.iters_per_checkpoint * hparams.accumulation) == 0:
                save_checkpoint(model,
                                optimizer,
                                hparams.lr,
                                iteration // hparams.accumulation,
                                filepath=f'{hparams.output_directory}/{hparams.log_directory}/stage{args.stage}')

            if iteration == total_iterations:
                break

        # Leave the epoch loop as well once the step budget is used up.
        if iteration == total_iterations:
            break

    print(f'Stage{args.stage} End!!! ({str(datetime.now())})')


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=str, default='0,1')
    p.add_argument('-v', '--verbose', type=str, default='0')
    p.add_argument('--stage', type=int, required=True)
    # type=bool would turn "--log_viterbi False" into True; parse the text instead.
    p.add_argument('--log_viterbi', type=_str2bool, default=False)
    p.add_argument('--cpu_viterbi', type=_str2bool, default=False)
    args = p.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    if args.verbose == '0':
        import warnings
        warnings.filterwarnings("ignore")

    main(args)
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/data_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_image.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/plot_image.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/plot_image.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/writer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/writer.cpython-36.pyc 
def load_filepaths_and_text(metadata, split="|"):
    """Read a filelist and return one list of fields per line.

    Args:
        metadata: path of a UTF-8 text file.
        split: field separator used within each line.

    Returns:
        A list with one entry per line of the file; each entry is the
        whitespace-stripped line split on ``split``.
    """
    rows = []
    with open(metadata, encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(split)
            rows.append(fields)
    return rows
class TextMelCollate():
    """Collate dataset samples into zero-padded batch tensors.

    Each sample is ``(text, mel)`` when ``stage == 0`` and
    ``(text, mel, align)`` otherwise. Samples are ordered by decreasing text
    length, and mel values are clamped to ``[hparams.min_db, hparams.max_db]``
    and rescaled to ``[0, 1]``.
    """

    def __init__(self, stage):
        self.stage = stage

    def __call__(self, batch):
        """Build the padded batch.

        Returns:
            ``(text_padded, mel_padded, input_lengths, output_lengths)`` for
            stage 0, otherwise
            ``(text_padded, mel_padded, align_padded, input_lengths, output_lengths)``.
        """
        n_samples = len(batch)
        text_sizes = torch.LongTensor([len(sample[0]) for sample in batch])
        input_lengths, order = torch.sort(text_sizes, dim=0, descending=True)

        longest_text = int(input_lengths[0])
        longest_mel = max(sample[1].size(1) for sample in batch)
        n_mels = batch[0][1].size(0)
        with_align = self.stage != 0

        # Right zero-padded containers; every row is filled in the loop below.
        text_padded = torch.zeros(n_samples, longest_text, dtype=torch.long)
        mel_padded = torch.zeros(n_samples, n_mels, longest_mel)
        output_lengths = torch.LongTensor(n_samples)
        if with_align:
            align_padded = torch.zeros(n_samples, longest_text, longest_mel)

        for row, source in enumerate(order.tolist()):
            sample = batch[source]
            text, mel = sample[0], sample[1]
            text_padded[row, :text.size(0)] = text
            mel_padded[row, :, :mel.size(1)] = mel
            output_lengths[row] = mel.size(1)
            if with_align:
                align = sample[2]
                align_padded[row, :align.size(0), :align.size(1)] = align

        # Normalisation is applied to the whole padded tensor, padding included.
        db_range = hparams.max_db - hparams.min_db
        mel_padded = (torch.clamp(mel_padded, hparams.min_db, hparams.max_db) - hparams.min_db) / db_range

        if with_align:
            return text_padded, mel_padded, align_padded, input_lengths, output_lengths
        return text_padded, mel_padded, input_lengths, output_lengths
def get_mask_from_lengths(lengths):
    """Return a boolean padding mask for a batch of sequence lengths.

    Args:
        lengths: tensor of sequence lengths (used as 1-D here: it is expanded
            with ``unsqueeze(1)``).

    Returns:
        Bool tensor of shape ``(len(lengths), max(lengths))`` that is ``True``
        at positions whose index is at or beyond the sequence length (the
        padded positions) and ``False`` inside the sequence.
    """
    max_len = torch.max(lengths).item()
    # Build the index row directly with the dtype/device of `lengths`;
    # `lengths.new_tensor(torch.arange(...))` copy-constructs from a tensor,
    # which makes an extra copy and emits a UserWarning.
    ids = torch.arange(max_len, dtype=lengths.dtype, device=lengths.device)
    return lengths.unsqueeze(1) <= ids
self.add_figure(f'{phase}_alignments', fig, global_step) 40 | 41 | 42 | -------------------------------------------------------------------------------- /waveglow/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tacotron2"] 2 | path = tacotron2 3 | url = http://github.com/NVIDIA/tacotron2 4 | -------------------------------------------------------------------------------- /waveglow/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, NVIDIA Corporation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /waveglow/README.md: -------------------------------------------------------------------------------- 1 | ![WaveGlow](waveglow_logo.png "WaveGLow") 2 | 3 | ## WaveGlow: a Flow-based Generative Network for Speech Synthesis 4 | 5 | ### Ryan Prenger, Rafael Valle, and Bryan Catanzaro 6 | 7 | In our recent [paper], we propose WaveGlow: a flow-based network capable of 8 | generating high quality speech from mel-spectrograms. WaveGlow combines insights 9 | from [Glow] and [WaveNet] in order to provide fast, efficient and high-quality 10 | audio synthesis, without the need for auto-regression. WaveGlow is implemented 11 | using only a single network, trained using only a single cost function: 12 | maximizing the likelihood of the training data, which makes the training 13 | procedure simple and stable. 14 | 15 | Our [PyTorch] implementation produces audio samples at a rate of 2750 16 | kHz on an NVIDIA V100 GPU. Mean Opinion Scores show that it delivers audio 17 | quality as good as the best publicly available WaveNet implementation. 18 | 19 | Visit our [website] for audio samples. 20 | 21 | ## Setup 22 | 23 | 1. Clone our repo and initialize submodule 24 | 25 | ```command 26 | git clone https://github.com/NVIDIA/waveglow.git 27 | cd waveglow 28 | git submodule init 29 | git submodule update 30 | ``` 31 | 32 | 2. 
Install requirements `pip3 install -r requirements.txt` 33 | 34 | 3. Install [Apex] 35 | 36 | 37 | ## Generate audio with our pre-existing model 38 | 39 | 1. Download our [published model] 40 | 2. Download [mel-spectrograms] 41 | 3. Generate audio `python3 inference.py -f <(ls mel_spectrograms/*.pt) -w waveglow_256channels.pt -o . --is_fp16 -s 0.6` 42 | 43 | N.b. use `convert_model.py` to convert your older models to the current model 44 | with fused residual and skip connections. 45 | 46 | ## Train your own model 47 | 48 | 1. Download [LJ Speech Data]. In this example it's in `data/` 49 | 50 | 2. Make a list of the file names to use for training/testing 51 | 52 | ```command 53 | ls data/*.wav | tail -n+10 > train_files.txt 54 | ls data/*.wav | head -n10 > test_files.txt 55 | ``` 56 | 57 | 3. Train your WaveGlow networks 58 | 59 | ```command 60 | mkdir checkpoints 61 | python train.py -c config.json 62 | ``` 63 | 64 | For multi-GPU training replace `train.py` with `distributed.py`. Only tested with single node and NCCL. 65 | 66 | For mixed precision training set `"fp16_run": true` on `config.json`. 67 | 68 | 4. Make test set mel-spectrograms 69 | 70 | `python mel2samp.py -f test_files.txt -o . -c config.json` 71 | 72 | 5. Do inference with your network 73 | 74 | ```command 75 | ls *.pt > mel_files.txt 76 | python3 inference.py -f mel_files.txt -w checkpoints/waveglow_10000 -o . 
--is_fp16 -s 0.6 77 | ``` 78 | 79 | [//]: # (TODO) 80 | [//]: # (PROVIDE INSTRUCTIONS FOR DOWNLOADING LJS) 81 | [pytorch 1.0]: https://github.com/pytorch/pytorch#installation 82 | [website]: https://nv-adlr.github.io/WaveGlow 83 | [paper]: https://arxiv.org/abs/1811.00002 84 | [WaveNet implementation]: https://github.com/r9y9/wavenet_vocoder 85 | [Glow]: https://blog.openai.com/glow/ 86 | [WaveNet]: https://deepmind.com/blog/wavenet-generative-model-raw-audio/ 87 | [PyTorch]: http://pytorch.org 88 | [published model]: https://drive.google.com/file/d/1WsibBTsuRg_SF2Z6L6NFRTT-NjEy1oTx/view?usp=sharing 89 | [mel-spectrograms]: https://drive.google.com/file/d/1g_VXK2lpP9J25dQFhQwx7doWl_p20fXA/view?usp=sharing 90 | [LJ Speech Data]: https://keithito.com/LJ-Speech-Dataset 91 | [Apex]: https://github.com/nvidia/apex 92 | -------------------------------------------------------------------------------- /waveglow/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_config": { 3 | "fp16_run": true, 4 | "output_directory": "checkpoints", 5 | "epochs": 100000, 6 | "learning_rate": 1e-4, 7 | "sigma": 1.0, 8 | "iters_per_checkpoint": 2000, 9 | "batch_size": 12, 10 | "seed": 1234, 11 | "checkpoint_path": "", 12 | "with_tensorboard": false 13 | }, 14 | "data_config": { 15 | "training_files": "train_files.txt", 16 | "segment_length": 16000, 17 | "sampling_rate": 22050, 18 | "filter_length": 1024, 19 | "hop_length": 256, 20 | "win_length": 1024, 21 | "mel_fmin": 0.0, 22 | "mel_fmax": 8000.0 23 | }, 24 | "dist_config": { 25 | "dist_backend": "nccl", 26 | "dist_url": "tcp://localhost:54321" 27 | }, 28 | 29 | "waveglow_config": { 30 | "n_mel_channels": 80, 31 | "n_flows": 12, 32 | "n_group": 8, 33 | "n_early_every": 4, 34 | "n_early_size": 2, 35 | "WN_config": { 36 | "n_layers": 8, 37 | "n_channels": 256, 38 | "kernel_size": 3 39 | } 40 | } 41 | } 42 | 
def _check_model_old_version(model):
    """Return True when the first WN module still has separate ``res_layers``."""
    return hasattr(model.WN[0], 'res_layers')


def update_model(old_model):
    """Convert a model with separate residual/skip convs to fused ``res_skip_layers``.

    A model whose first WN module has no ``res_layers`` attribute is returned
    unchanged (same object). Otherwise a deep copy is converted and returned;
    ``old_model`` itself is left untouched.

    For every layer the weight-norm is removed from the old convolutions,
    their weights and biases are concatenated (residual first, then skip)
    into one 1x1 ``Conv1d``, and weight-norm is re-applied. The last layer
    has no residual part, so it only carries the skip weights.
    """
    if not _check_model_old_version(old_model):
        return old_model

    new_model = copy.deepcopy(old_model)
    for wavenet in new_model.WN:
        wavenet.res_skip_layers = torch.nn.ModuleList()
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        last_index = n_layers - 1

        for i in range(n_layers):
            has_residual = i < last_index
            out_channels = 2 * n_channels if has_residual else n_channels
            fused = torch.nn.Conv1d(n_channels, out_channels, 1)

            skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
            if has_residual:
                res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
                fused.weight = torch.nn.Parameter(
                    torch.cat([res_layer.weight, skip_layer.weight]))
                fused.bias = torch.nn.Parameter(
                    torch.cat([res_layer.bias, skip_layer.bias]))
            else:
                fused.weight = torch.nn.Parameter(skip_layer.weight)
                fused.bias = torch.nn.Parameter(skip_layer.bias)

            fused = torch.nn.utils.weight_norm(fused, name='weight')
            wavenet.res_skip_layers.append(fused)

        # The old per-branch layers are no longer needed on the converted copy.
        del wavenet.res_layers
        del wavenet.skip_layers
    return new_model


if __name__ == '__main__':
    # Usage: convert_model.py <old_model_path> <new_model_path>
    old_model_path, new_model_path = sys.argv[1], sys.argv[2]
    model = torch.load(old_model_path)
    model['model'] = update_model(model['model'])
    torch.save(model, new_model_path)
class Denoiser(torch.nn.Module):
    """ Removes model bias from audio produced with waveglow

    At construction time the vocoder is run once on a constant (``'zeros'``)
    or random (``'normal'``) mel input with ``sigma=0.0``; the first STFT
    frame of the resulting audio is stored as ``bias_spec``. ``forward``
    subtracts a scaled copy of that spectrum from the magnitude spectrum of
    the audio it is given.
    """

    def __init__(self, waveglow, filter_length=1024, n_overlap=4,
                 win_length=1024, mode='zeros'):
        """
        Args:
            waveglow: vocoder exposing ``upsample.weight`` (its dtype/device
                are used for the probe input) and ``infer(mel, sigma=...)``.
            filter_length: STFT filter length.
            n_overlap: hop length is ``filter_length / n_overlap``.
            win_length: STFT window length.
            mode: ``'zeros'`` or ``'normal'`` -- how the probe mel is filled.

        Raises:
            ValueError: if ``mode`` is not one of the supported values.
        """
        super(Denoiser, self).__init__()
        # Validate before building the STFT so a bad argument fails fast and
        # does not touch the GPU.
        if mode not in ('zeros', 'normal'):
            raise ValueError("Mode {} is not supported".format(mode))

        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length/n_overlap),
                         win_length=win_length).cuda()

        make_input = torch.zeros if mode == 'zeros' else torch.randn
        # NOTE(review): (1, 80, 88) is presumably (batch, n_mel_channels,
        # frames); 80 matches "n_mel_channels" in config.json -- confirm.
        mel_input = make_input(
            (1, 80, 88),
            dtype=waveglow.upsample.weight.dtype,
            device=waveglow.upsample.weight.device)

        with torch.no_grad():
            bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        # Keep only the first frame, with a trailing singleton axis so it
        # broadcasts over the last dimension in forward().
        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        """Return ``audio`` with ``strength * bias_spec`` removed from its magnitude spectrum."""
        audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        # Magnitudes cannot be negative; floor at zero after the subtraction.
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
def reduce_tensor(tensor, num_gpus):
    """Return ``tensor`` summed over all processes and divided by ``num_gpus``.

    The input is not modified; the all-reduce runs on a clone. Requires an
    initialised ``torch.distributed`` process group.

    Args:
        tensor: tensor to reduce.
        num_gpus: divisor applied after the sum.
    """
    rt = tensor.clone()
    # `dist.reduce_op` is a deprecated alias; `dist.ReduceOp` is the supported enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt
78 | Returns: 79 | Unflattened dense tensors with sizes same as tensors and values from 80 | flat. 81 | """ 82 | outputs = [] 83 | offset = 0 84 | for tensor in tensors: 85 | numel = tensor.numel() 86 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 87 | offset += numel 88 | return tuple(outputs) 89 | 90 | def apply_gradient_allreduce(module): 91 | """ 92 | Modifies existing model to do gradient allreduce, but doesn't change class 93 | so you don't need "module" 94 | """ 95 | if not hasattr(dist, '_backend'): 96 | module.warn_on_half = True 97 | else: 98 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 99 | 100 | for p in module.state_dict().values(): 101 | if not torch.is_tensor(p): 102 | continue 103 | dist.broadcast(p, 0) 104 | 105 | def allreduce_params(): 106 | if(module.needs_reduction): 107 | module.needs_reduction = False 108 | buckets = {} 109 | for param in module.parameters(): 110 | if param.requires_grad and param.grad is not None: 111 | tp = type(param.data) 112 | if tp not in buckets: 113 | buckets[tp] = [] 114 | buckets[tp].append(param) 115 | if module.warn_on_half: 116 | if torch.cuda.HalfTensor in buckets: 117 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 118 | " It is recommended to use the NCCL backend in this case. 
This currently requires" + 119 | "PyTorch built from top of tree master.") 120 | module.warn_on_half = False 121 | 122 | for tp in buckets: 123 | bucket = buckets[tp] 124 | grads = [param.grad.data for param in bucket] 125 | coalesced = _flatten_dense_tensors(grads) 126 | dist.all_reduce(coalesced) 127 | coalesced /= dist.get_world_size() 128 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 129 | buf.copy_(synced) 130 | 131 | for param in list(module.parameters()): 132 | def allreduce_hook(*unused): 133 | Variable._execution_engine.queue_callback(allreduce_params) 134 | if param.requires_grad: 135 | param.register_hook(allreduce_hook) 136 | dir(param) 137 | 138 | def set_needs_reduction(self, input, output): 139 | self.needs_reduction = True 140 | 141 | module.register_forward_hook(set_needs_reduction) 142 | return module 143 | 144 | 145 | def main(config, stdout_dir, args_str): 146 | args_list = ['train.py'] 147 | args_list += args_str.split(' ') if len(args_str) > 0 else [] 148 | 149 | args_list.append('--config={}'.format(config)) 150 | 151 | num_gpus = torch.cuda.device_count() 152 | args_list.append('--num_gpus={}'.format(num_gpus)) 153 | args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S"))) 154 | 155 | if not os.path.isdir(stdout_dir): 156 | os.makedirs(stdout_dir) 157 | os.chmod(stdout_dir, 0o775) 158 | 159 | workers = [] 160 | 161 | for i in range(num_gpus): 162 | args_list[-2] = '--rank={}'.format(i) 163 | stdout = None if i == 0 else open( 164 | os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w") 165 | print(args_list) 166 | p = subprocess.Popen([str(sys.executable)]+args_list, stdout=stdout) 167 | workers.append(p) 168 | 169 | for p in workers: 170 | p.wait() 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('-c', '--config', type=str, required=True, 176 | help='JSON file for configuration') 177 | parser.add_argument('-s', 
'--stdout_dir', type=str, default=".", 178 | help='directory to save stoud logs') 179 | parser.add_argument( 180 | '-a', '--args_str', type=str, default='', 181 | help='double quoted string with space separated key value pairs') 182 | 183 | args = parser.parse_args() 184 | main(args.config, args.stdout_dir, args.args_str) 185 | -------------------------------------------------------------------------------- /waveglow/glow.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. 
# IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import copy
import torch
from torch.autograd import Variable
import torch.nn.functional as F


# Gated activation: tanh of the first n_channels[0] channels of
# (input_a + input_b) multiplied by the sigmoid of the remaining channels.
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a+input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class WaveGlowLoss(torch.nn.Module):
    """Loss computed from the ``(z, log_s_list, log_det_W_list)`` model output."""

    def __init__(self, sigma=1.0):
        super(WaveGlowLoss, self).__init__()
        self.sigma = sigma

    def forward(self, model_output):
        """Return ``sum(z*z)/(2*sigma^2) - sum(log_s) - sum(log_det_W)``
        divided by the number of elements of ``z``.

        The tensors inside ``log_det_W_list`` are not modified.
        """
        z, log_s_list, log_det_W_list = model_output
        for i, log_s in enumerate(log_s_list):
            if i == 0:
                log_s_total = torch.sum(log_s)
                log_det_W_total = log_det_W_list[i]
            else:
                log_s_total = log_s_total + torch.sum(log_s)
                # Out-of-place add: ``+=`` here would accumulate into
                # log_det_W_list[0] itself and corrupt the caller's list.
                log_det_W_total = log_det_W_total + log_det_W_list[i]

        loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
        return loss/(z.size(0)*z.size(1)*z.size(2))


class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix. If reverse=True it does convolution with
    inverse
    """
    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Sample a random orthonormal matrix to initialize weights
        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]

        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:,0] = -1*W[:,0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W

    def forward(self, z, reverse=False):
        """Apply the 1x1 convolution, or its inverse when ``reverse`` is set.

        Returns ``(z, log_det_W)`` in the forward direction and only ``z`` in
        the reverse direction. The inverse weight is computed on first use
        and cached on ``self.W_inverse``.
        """
        # shape
        batch_size, group_size, n_of_groups = z.size()

        # Drop only the trailing kernel dimension so a 1x1 weight stays 2-D.
        W = self.conv.weight.squeeze(-1)

        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Reverse computation
                W_inverse = W.float().inverse()
                W_inverse = Variable(W_inverse[..., None])
                if z.type() == 'torch.cuda.HalfTensor':
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            return z
        else:
            # Forward computation
            log_det_W = batch_size * n_of_groups * torch.logdet(W)
            z = self.conv(z)
            return z, log_det_W


class WN(torch.nn.Module):
    """
    This is the WaveNet like layer for the affine coupling. The primary difference
    from WaveNet is the convolutions need not be causal. There is also no dilation
    size reset. The dilation only doubles on each layer
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.cond_layers = torch.nn.ModuleList()

        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        for i in range(n_layers):
            dilation = 2 ** i
            padding = int((kernel_size*dilation - dilation)/2)
            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
            cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
            self.cond_layers.append(cond_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        """Run the gated layers over ``(audio, spect)`` and return ``self.end``
        applied to the summed skip activations."""
        audio, spect = forward_input
        audio = self.start(audio)

        # Loop-invariant channel-count tensor expected by the fused kernel.
        n_channels_tensor = torch.IntTensor([self.n_channels])

        for i in range(self.n_layers):
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                self.cond_layers[i](spect),
                n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = res_skip_acts[:,:self.n_channels,:] + audio
                skip_acts = res_skip_acts[:,self.n_channels:,:]
            else:
                skip_acts = res_skip_acts

            if i == 0:
                output = skip_acts
            else:
                output = skip_acts + output
        return self.end(output)


class WaveGlow(torch.nn.Module):
    """Stack of ``n_flows`` invertible 1x1 convolutions and WN affine couplings
    conditioned on an upsampled mel spectrogram."""

    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow, self).__init__()

        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                 n_mel_channels,
                                                 1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()

        n_half = int(n_group/2)

        # Set up layers with the right sizes based on how many dimensions
        # have been output already
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size/2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
        self.n_remaining_channels = n_remaining_channels  # Useful during inference

    def forward(self, forward_input):
        """
        forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
        forward_input[1] = audio: batch x time

        Returns ``(z, log_s_list, log_det_W_list)``.
        """
        spect, audio = forward_input

        # Upsample spectrogram to size of audio
        spect = self.upsample(spect)
        assert(spect.size(2) >= audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, :audio.size(1)]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        output_audio = []
        log_s_list = []
        log_det_W_list = []

        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:,:self.n_early_size,:])
                audio = audio[:,self.n_early_size:,:]

            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s)*audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1],1)

        output_audio.append(audio)
        return torch.cat(output_audio,1), log_s_list, log_det_W_list

    def infer(self, spect, sigma=1.0):
        """Generate audio from ``spect`` by running the flows in reverse on
        Gaussian noise scaled by ``sigma``.

        The noise is allocated with the dtype and device of the upsampled
        ``spect``, so inputs need not live on the current CUDA device.
        """
        spect = self.upsample(spect)
        # trim conv artifacts. maybe pad spec to kernel multiple
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:, :, :-time_cutoff]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = spect.new_empty((spect.size(0),
                                 self.n_remaining_channels,
                                 spect.size(2))).normal_()

        audio = torch.autograd.Variable(sigma*audio)

        for k in reversed(range(self.n_flows)):
            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))
            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b)/torch.exp(s)
            audio = torch.cat([audio_0, audio_1],1)

            audio = self.convinv[k](audio, reverse=True)

            if k % self.n_early_every == 0 and k > 0:
                # Re-insert fresh noise for the channels that were emitted
                # early in the forward direction.
                z = spect.new_empty((spect.size(0), self.n_early_size, spect.size(2))).normal_()
                audio = torch.cat((sigma*z, audio),1)

        audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
        return audio

    @staticmethod
    def remove_weightnorm(model):
        """Strip weight normalisation from every WN sub-layer of ``model``
        and return the same model object."""
        waveglow = model
        for WN in waveglow.WN:
            WN.start = torch.nn.utils.remove_weight_norm(WN.start)
            WN.in_layers = remove(WN.in_layers)
            WN.cond_layers = remove(WN.cond_layers)
            WN.res_skip_layers = remove(WN.res_skip_layers)
        return waveglow


def remove(conv_list):
    """Return a new ModuleList holding ``conv_list``'s layers with weight
    normalisation removed."""
    new_conv_list = torch.nn.ModuleList()
    for old_conv in conv_list:
        old_conv = torch.nn.utils.remove_weight_norm(old_conv)
        new_conv_list.append(old_conv)
    return new_conv_list

# --------------------------------------------------------------------------------
# /waveglow/glow_old.py:
# --------------------------------------------------------------------------------
import copy
import torch
from glow import Invertible1x1Conv, remove


# Gated activation: tanh of the first n_channels[0] channels of
# (input_a + input_b) multiplied by the sigmoid of the remaining channels.
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a+input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class WN(torch.nn.Module):
    """
    This is the WaveNet like layer for the affine coupling. The primary difference
    from WaveNet is the convolutions need not be causal. There is also no dilation
    size reset. The dilation only doubles on each layer
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.cond_layers = torch.nn.ModuleList()

        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        for i in range(n_layers):
            dilation = 2 ** i
            padding = int((kernel_size*dilation - dilation)/2)
            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
            cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
            self.cond_layers.append(cond_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        """Run the gated layers over ``(audio, spect)`` and return ``self.end``
        applied to the summed skip activations."""
        audio, spect = forward_input
        audio = self.start(audio)

        # Loop-invariant channel-count tensor expected by the fused kernel.
        n_channels_tensor = torch.IntTensor([self.n_channels])

        for i in range(self.n_layers):
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                self.cond_layers[i](spect),
                n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = res_skip_acts[:,:self.n_channels,:] + audio
                skip_acts = res_skip_acts[:,self.n_channels:,:]
            else:
                skip_acts = res_skip_acts

            if i == 0:
                output = skip_acts
            else:
                output = skip_acts + output
        return self.end(output)


class WaveGlow(torch.nn.Module):
    """Earlier WaveGlow variant: inference only (``forward`` returns None),
    with alternating coupling halves and early outputs every 4 flows."""

    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow, self).__init__()

        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                 n_mel_channels,
                                                 1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()

        n_half = int(n_group/2)

        # Set up layers with the right sizes based on how many dimensions
        # have been output already
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size/2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
        self.n_remaining_channels = n_remaining_channels  # Useful during inference

    def forward(self, forward_input):
        # Training is disabled in this variant; the strings below are the
        # unreachable former implementation, retained for reference.
        return None
        """
        forward_input[0] = audio: batch x time
        forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time
        """
        """
        spect, audio = forward_input

        # Upsample spectrogram to size of audio
        spect = self.upsample(spect)
        assert(spect.size(2) >= audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, :audio.size(1)]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        output_audio = []
        s_list = []
        s_conv_list = []

        for k in range(self.n_flows):
            if k%4 == 0 and k > 0:
                output_audio.append(audio[:,:self.n_multi,:])
                audio = audio[:,self.n_multi:,:]

            # project to new basis
            audio, s = self.convinv[k](audio)
            s_conv_list.append(s)

            n_half = int(audio.size(1)/2)
            if k%2 == 0:
                audio_0 = audio[:,:n_half,:]
                audio_1 = audio[:,n_half:,:]
            else:
                audio_1 = audio[:,:n_half,:]
                audio_0 = audio[:,n_half:,:]

            output = self.nn[k]((audio_0, spect))
            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(s)*audio_1 + b
            s_list.append(s)

            if k%2 == 0:
                audio = torch.cat([audio[:,:n_half,:], audio_1],1)
            else:
                audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)
        output_audio.append(audio)
        return torch.cat(output_audio,1), s_list, s_conv_list
        """

    def infer(self, spect, sigma=1.0):
        """Generate audio from ``spect`` by running the flows in reverse on
        Gaussian noise scaled by ``sigma``.

        The noise is allocated with the dtype and device of the upsampled
        ``spect``, so inputs need not live on the current CUDA device.
        """
        spect = self.upsample(spect)
        # trim conv artifacts. maybe pad spec to kernel multiple
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:, :, :-time_cutoff]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = spect.new_empty((spect.size(0),
                                 self.n_remaining_channels,
                                 spect.size(2))).normal_()

        audio = torch.autograd.Variable(sigma*audio)

        for k in reversed(range(self.n_flows)):
            n_half = int(audio.size(1)/2)
            # Even flows transform the second half, odd flows the first half.
            if k%2 == 0:
                audio_0 = audio[:,:n_half,:]
                audio_1 = audio[:,n_half:,:]
            else:
                audio_1 = audio[:,:n_half,:]
                audio_0 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))
            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b)/torch.exp(s)
            if k%2 == 0:
                audio = torch.cat([audio[:,:n_half,:], audio_1],1)
            else:
                audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)

            audio = self.convinv[k](audio, reverse=True)

            if k%4 == 0 and k > 0:
                z = spect.new_empty((spect.size(0),
                                     self.n_early_size,
                                     spect.size(2))).normal_()
                audio = torch.cat((sigma*z, audio),1)

        return audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data

    @staticmethod
    def remove_weightnorm(model):
        """Strip weight normalisation from every WN sub-layer of ``model``
        and return the same model object."""
        waveglow = model
        for WN in waveglow.WN:
            WN.start = torch.nn.utils.remove_weight_norm(WN.start)
            WN.in_layers = remove(WN.in_layers)
            WN.cond_layers = remove(WN.cond_layers)
            WN.res_skip_layers = remove(WN.res_skip_layers)
        return waveglow

# --------------------------------------------------------------------------------
# /waveglow/inference.py:
# --------------------------------------------------------------------------------
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import os
from scipy.io.wavfile import write
import torch
from mel2samp import files_to_list, MAX_WAV_VALUE
from denoiser import Denoiser


def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
         denoiser_strength):
    """Synthesise one 16-bit wav per saved mel spectrogram.

    Arguments:
        mel_files: path of a text file listing the mel tensors to load.
        waveglow_path: checkpoint whose ``'model'`` entry is the WaveGlow model.
        sigma: noise scale passed to ``waveglow.infer``.
        output_dir: directory receiving ``<name>_synthesis.wav`` files;
            created if it does not exist.
        sampling_rate: sample rate written into the wav files.
        is_fp16: run the model in half precision through apex amp.
        denoiser_strength: when > 0, the output is passed through ``Denoiser``
            with this strength.
    """
    mel_files = files_to_list(mel_files)
    waveglow = torch.load(waveglow_path)['model']
    waveglow = waveglow.remove_weightnorm(waveglow)
    waveglow.cuda().eval()
    if is_fp16:
        from apex import amp
        waveglow, _ = amp.initialize(waveglow, [], opt_level="O3")

    if denoiser_strength > 0:
        denoiser = Denoiser(waveglow).cuda()

    # scipy's writer does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    for file_path in mel_files:
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        mel = torch.load(file_path)
        mel = torch.autograd.Variable(mel.cuda())
        mel = torch.unsqueeze(mel, 0)
        mel = mel.half() if is_fp16 else mel
        with torch.no_grad():
            audio = waveglow.infer(mel, sigma=sigma)
            if denoiser_strength > 0:
                audio = denoiser(audio, denoiser_strength)
            # Scale in fp32: half precision cannot represent 32767 exactly
            # and overflows for samples well outside [-1, 1].
            audio = audio.float() * MAX_WAV_VALUE
        audio = audio.squeeze()
        # Clamp to the int16 range so loud samples saturate instead of
        # wrapping around when cast below.
        audio = torch.clamp(audio, min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
        audio = audio.cpu().numpy()
        audio = audio.astype('int16')
        audio_path = os.path.join(
            output_dir, "{}_synthesis.wav".format(file_name))
        write(audio_path, sampling_rate, audio)
        print(audio_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-f', "--filelist_path", required=True)
    parser.add_argument('-w', '--waveglow_path',
                        help='Path to waveglow decoder checkpoint with model')
    parser.add_argument('-o', "--output_dir", required=True)
    parser.add_argument("-s", "--sigma", default=1.0, type=float)
    parser.add_argument("--sampling_rate", default=22050, type=int)
    parser.add_argument("--is_fp16", action="store_true")
    parser.add_argument("-d", "--denoiser_strength", default=0.0, type=float,
                        help='Removes model bias. Start with 0.1 and adjust')

    args = parser.parse_args()

    main(args.filelist_path, args.waveglow_path, args.sigma, args.output_dir,
         args.sampling_rate, args.is_fp16, args.denoiser_strength)

# --------------------------------------------------------------------------------
# /waveglow/mel2samp.py:
# --------------------------------------------------------------------------------
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# *****************************************************************************\
import os
import random
import argparse
import json
import torch
import torch.utils.data
import sys
from scipy.io.wavfile import read

# We're using the audio processing from TacoTron2 to make sure it matches
sys.path.insert(0, 'tacotron2')
from tacotron2.layers import TacotronSTFT

MAX_WAV_VALUE = 32768.0


def files_to_list(filename):
    """
    Read a UTF-8 text file and return its lines, right-stripped, as a list
    """
    with open(filename, encoding='utf-8') as handle:
        return [line.rstrip() for line in handle.readlines()]


def load_wav_to_torch(full_path):
    """
    Read a wav file; return (float tensor of samples, sampling rate)
    """
    rate, samples = read(full_path)
    waveform = torch.from_numpy(samples).float()
    return waveform, rate


class Mel2Samp(torch.utils.data.Dataset):
    """
    Dataset yielding (mel spectrogram, audio) pairs computed from the wav
    files named in a file list.
    """
    def __init__(self, training_files, segment_length, filter_length,
                 hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate

        # Fixed seed: the shuffled file order is the same on every run.
        self.audio_files = files_to_list(training_files)
        random.seed(1234)
        random.shuffle(self.audio_files)

        stft_options = dict(filter_length=filter_length,
                            hop_length=hop_length,
                            win_length=win_length,
                            sampling_rate=sampling_rate,
                            mel_fmin=mel_fmin, mel_fmax=mel_fmax)
        self.stft = TacotronSTFT(**stft_options)

    def get_mel(self, audio):
        """Scale ``audio`` by 1/MAX_WAV_VALUE, add a leading dimension, and
        return the STFT module's mel spectrogram with that dimension removed."""
        batched = (audio / MAX_WAV_VALUE).unsqueeze(0)
        batched = torch.autograd.Variable(batched, requires_grad=False)
        return torch.squeeze(self.stft.mel_spectrogram(batched), 0)

    def __getitem__(self, index):
        # Read audio
        wav_path = self.audio_files[index]
        audio, sampling_rate = load_wav_to_torch(wav_path)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))

        # Take segment: zero-pad short clips, random-crop the rest
        n_samples = audio.size(0)
        if n_samples < self.segment_length:
            audio = torch.nn.functional.pad(audio, (0, self.segment_length - n_samples), 'constant').data
        else:
            first = random.randint(0, n_samples - self.segment_length)
            audio = audio[first:first+self.segment_length]

        mel = self.get_mel(audio)
        return (mel, audio / MAX_WAV_VALUE)

    def __len__(self):
        return len(self.audio_files)


# ===================================================================
# Takes directory of clean audio and makes directory of spectrograms
# Useful for making test sets
# ===================================================================
if __name__ == "__main__":
    # Get defaults so it can work with no Sacred
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', "--filelist_path", required=True)
    parser.add_argument('-c', '--config', type=str,
                        help='JSON file for configuration')
    parser.add_argument('-o', '--output_dir', type=str,
                        help='Output directory')
    args = parser.parse_args()

    with open(args.config) as f:
        data_config = json.loads(f.read())["data_config"]
    mel2samp = Mel2Samp(**data_config)

    filepaths = files_to_list(args.filelist_path)

    # Make directory if it doesn't exist
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)
        os.chmod(args.output_dir, 0o775)

    for filepath in filepaths:
        audio, sr = load_wav_to_torch(filepath)
        melspectrogram = mel2samp.get_mel(audio)
        new_filepath = args.output_dir + '/' + os.path.basename(filepath) + '.pt'
        print(new_filepath)
        torch.save(melspectrogram, new_filepath)

# --------------------------------------------------------------------------------
# /waveglow/requirements.txt:
# --------------------------------------------------------------------------------
# torch==1.0
# matplotlib==2.1.0
# tensorflow
# numpy==1.13.3
# inflect==0.2.5
# librosa==0.6.0
# scipy==1.0.0
# tensorboardX==1.1
# Unidecode==1.0.22
# pillow
# --------------------------------------------------------------------------------
# /waveglow/train.py:
# --------------------------------------------------------------------------------
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# *****************************************************************************
import argparse
import json
import os
import torch

#=====START: ADDED FOR DISTRIBUTED======
from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor
from torch.utils.data.distributed import DistributedSampler
#=====END: ADDED FOR DISTRIBUTED======

from torch.utils.data import DataLoader
from glow import WaveGlow, WaveGlowLoss
from mel2samp import Mel2Samp


def load_checkpoint(checkpoint_path, model, optimizer):
    """
    Restores model weights and optimizer state from a saved checkpoint.

    The checkpoint is expected to be the dict written by save_checkpoint:
    ``'model'`` holds a full model object (its ``state_dict()`` is copied
    into ``model``), ``'optimizer'`` an optimizer state dict and
    ``'iteration'`` the iteration at which it was saved.

    Returns ``(model, optimizer, iteration)``.
    Raises FileNotFoundError if ``checkpoint_path`` is not an existing file.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            "Checkpoint file not found: '{}'".format(checkpoint_path))
    # NOTE(security): torch.load unpickles a whole model object here; only
    # load checkpoints from trusted sources.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}' (iteration {})" .format(
          checkpoint_path, iteration))
    return model, optimizer, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """
    Writes model, optimizer state, iteration and learning rate to ``filepath``.

    The weights are copied into a freshly constructed WaveGlow (built from the
    module-level ``waveglow_config`` set in the ``__main__`` section) and that
    object is what gets saved, rather than ``model`` itself.
    """
    print("Saving model and optimizer state at iteration {} to {}".format(
          iteration, filepath))
    model_for_saving = WaveGlow(**waveglow_config).cuda()
    model_for_saving.load_state_dict(model.state_dict())
    torch.save({'model': model_for_saving,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)


def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard):
    """
    Runs the WaveGlow training loop.

    Relies on the module-level ``data_config``, ``dist_config`` and
    ``waveglow_config`` dicts populated in the ``__main__`` section.

    num_gpus / rank / group_name: distributed setup; distributed training is
        initialised only when ``num_gpus > 1``. Rank 0 owns the output
        directory, tensorboard logging and checkpoint writing.
    output_directory: where checkpoints (``waveglow_<iteration>``) and the
        ``logs`` directory are written.
    epochs: upper bound (exclusive) of the epoch counter.
    learning_rate: Adam learning rate.
    sigma: passed to WaveGlowLoss.
    iters_per_checkpoint: a checkpoint is saved whenever the iteration
        counter is a multiple of this value.
    batch_size: DataLoader batch size (incomplete last batches are dropped).
    seed: seed for the torch CPU and CUDA RNGs.
    fp16_run: if true, use apex amp (opt level 'O1').
    checkpoint_path: checkpoint to resume from, or "" to start from scratch.
    with_tensorboard: if true, rank 0 logs the training loss via tensorboardX.

    Raises ValueError if the dataset yields no full batch.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END: ADDED FOR DISTRIBUTED======

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config).cuda()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END: ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2Samp(**data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    # =====END: ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    batches_per_epoch = len(train_loader)
    if batches_per_epoch == 0:
        # drop_last=True discards a short batch, so a dataset smaller than
        # batch_size yields nothing; fail clearly instead of dividing by zero.
        raise ValueError(
            "Training set yields no full batch (batch_size={})".format(
                batch_size))

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory, exist_ok=True)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, 'logs'))

    model.train()
    # When resuming, skip the epochs already covered by the loaded iteration.
    epoch_offset = max(0, iteration // batches_per_epoch)
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        if train_sampler is not None:
            # DistributedSampler derives its shuffle from the epoch number;
            # without this every epoch would reuse the same ordering.
            train_sampler.set_epoch(epoch)
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            mel, audio = batch
            mel = torch.autograd.Variable(mel.cuda())
            audio = torch.autograd.Variable(audio.cuda())
            outputs = model((mel, audio))

            loss = criterion(outputs)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if with_tensorboard and rank == 0:
                logger.add_scalar('training_loss', reduced_loss,
                                  i + batches_per_epoch * epoch)

            if (iteration % iters_per_checkpoint == 0):
                if rank == 0:
                    # Separate name so the `checkpoint_path` argument (the
                    # checkpoint to resume from) is not overwritten.
                    save_path = "{}/waveglow_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    save_path)

            iteration += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The config file is opened unconditionally below, so it is required:
    # omitting it used to fail with an opaque TypeError from open(None).
    parser.add_argument('-c', '--config', type=str, required=True,
                        help='JSON file for configuration')
    parser.add_argument('-r', '--rank', type=int, default=0,
                        help='rank of process for distributed')
    parser.add_argument('-g', '--group_name', type=str, default='',
                        help='name of group for distributed')
    args = parser.parse_args()

    # Parse configs.  Globals nicer in this case
    # (assignments at module level are already module globals, which is how
    # train() and save_checkpoint() see them; no `global` statement needed).
    with open(args.config) as f:
        config = json.load(f)
    train_config = config["train_config"]
    data_config = config["data_config"]
    dist_config = config["dist_config"]
    waveglow_config = config["waveglow_config"]

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        if args.group_name == '':
            print("WARNING: Multiple GPUs detected but no distributed group set")
            print("Only running 1 GPU.  Use distributed.py for multiple GPUs")
            num_gpus = 1

    if num_gpus == 1 and args.rank != 0:
        raise Exception("Doing single GPU training on rank > 0")

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    train(num_gpus, args.rank, args.group_name, **train_config)

# --------------------------------------------------------------------------------
# /waveglow/waveglow_logo.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/waveglow/waveglow_logo.png
# --------------------------------------------------------------------------------
# /wavs/LJ001-0029_phone10000_10.wav:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_10.wav
# --------------------------------------------------------------------------------
# /wavs/LJ001-0029_phone10000_11.wav:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_11.wav
# --------------------------------------------------------------------------------
/wavs/LJ001-0029_phone10000_12.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_12.wav -------------------------------------------------------------------------------- /wavs/LJ001-0029_phone10000_8.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_8.wav -------------------------------------------------------------------------------- /wavs/LJ001-0029_phone10000_9.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_9.wav -------------------------------------------------------------------------------- /wavs/LJ001-0085_phone10000_10.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_10.wav -------------------------------------------------------------------------------- /wavs/LJ001-0085_phone10000_11.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_11.wav -------------------------------------------------------------------------------- /wavs/LJ001-0085_phone10000_12.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_12.wav -------------------------------------------------------------------------------- 
/wavs/LJ001-0085_phone10000_8.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_8.wav -------------------------------------------------------------------------------- /wavs/LJ001-0085_phone10000_9.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_9.wav -------------------------------------------------------------------------------- /wavs/LJ002-0106_phone10000_10.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_10.wav -------------------------------------------------------------------------------- /wavs/LJ002-0106_phone10000_11.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_11.wav -------------------------------------------------------------------------------- /wavs/LJ002-0106_phone10000_12.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_12.wav -------------------------------------------------------------------------------- /wavs/LJ002-0106_phone10000_8.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_8.wav -------------------------------------------------------------------------------- 
/wavs/LJ002-0106_phone10000_9.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_9.wav --------------------------------------------------------------------------------