├── README.md ├── audio_processing.py ├── figures ├── enc_dec_align.JPG ├── gate_out.JPG ├── melspec.JPG ├── train_bce_loss.JPG ├── train_guide_loss.JPG ├── train_mel_loss.JPG ├── val_bce_loss.JPG ├── val_guide_loss.JPG └── val_mel_loss.JPG ├── filelists ├── ljs_audio_text_test_filelist.txt ├── ljs_audio_text_train_filelist.txt └── ljs_audio_text_val_filelist.txt ├── generate_samples.ipynb ├── hparams.py ├── index.html ├── inference.ipynb ├── layers.py ├── modules ├── __pycache__ │ ├── init_layer.cpython-37.pyc │ ├── loss.cpython-37.pyc │ ├── model.cpython-37.pyc │ └── transformer.cpython-37.pyc ├── init_layer.py ├── loss.py ├── model.py └── transformer.py ├── prepare_data.ipynb ├── stft.py ├── text ├── LICENSE ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── cleaners.cpython-36.pyc │ ├── cleaners.cpython-37.pyc │ ├── cmudict.cpython-37.pyc │ ├── numbers.cpython-37.pyc │ └── symbols.cpython-37.pyc ├── cleaners.py ├── cmudict.py ├── numbers.py └── symbols.py ├── train-gae.py ├── train-graphtts.py ├── train.py ├── utils ├── __pycache__ │ ├── data_utils.cpython-37.pyc │ ├── plot_image.cpython-37.pyc │ ├── utils.cpython-37.pyc │ └── writer.cpython-37.pyc ├── data_utils.py ├── plot_image.py ├── utils.py └── writer.py ├── waveglow ├── .gitmodules ├── LICENSE ├── README.md ├── config.json ├── convert_model.py ├── denoiser.py ├── distributed.py ├── glow.py ├── glow_old.py ├── inference.py ├── mel2samp.py ├── requirements.txt ├── train.py └── waveglow_logo.png └── wavs ├── gae-char ├── LJ001-0029_char100000.wav ├── LJ001-0029_char20000.wav ├── LJ001-0029_char200000.wav ├── LJ001-0029_char50000.wav ├── LJ001-0085_char100000.wav ├── LJ001-0085_char20000.wav ├── LJ001-0085_char200000.wav ├── LJ001-0085_char50000.wav ├── LJ002-0106_char100000.wav ├── LJ002-0106_char20000.wav ├── LJ002-0106_char200000.wav └── LJ002-0106_char50000.wav ├── graph-tts-char ├── LJ001-0029_char100000.wav ├── LJ001-0029_char20000.wav ├── LJ001-0029_char200000.wav ├── LJ001-0029_char50000.wav ├── LJ001-0085_char100000.wav ├── LJ001-0085_char20000.wav ├── LJ001-0085_char200000.wav ├── LJ001-0085_char50000.wav ├── LJ002-0106_char100000.wav ├── LJ002-0106_char20000.wav ├── LJ002-0106_char200000.wav └── LJ002-0106_char50000.wav └── graph-tts-char_iter5 ├── LJ001-0029_char100000.wav ├── LJ001-0029_char20000.wav ├── LJ001-0029_char200000.wav ├── LJ001-0029_char50000.wav ├── LJ001-0085_char100000.wav ├── LJ001-0085_char20000.wav ├── LJ001-0085_char200000.wav ├── LJ001-0085_char50000.wav ├── LJ002-0106_char100000.wav ├── LJ002-0106_char20000.wav ├── LJ002-0106_char200000.wav └── LJ002-0106_char50000.wav /README.md: -------------------------------------------------------------------------------- 1 | # Graph-TTS 2 | - Implementation of ["GraphTTS: graph-to-sequence modelling in neural text-to-speech"](https://arxiv.org/abs/2003.01924) 3 | - I failed to generate the plausible speech :( 4 | 5 | ## Training 6 | 1. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/) 7 | 2. Make `preprocessed` folder in LJSpeech directory and make `char_seq` & `phone_seq` & `melspectrogram` folder in it 8 | 3. Set `data_path` in `hparams.py` as the LJSpeech folder 9 | 4. Using `prepare_data.ipynb`, prepare melspectrogram and text (converted into indices) tensors. 10 | 5. 
`python train.py` 11 | 12 | ## Training curve (Orange: transformer-tts / Navy: graph-tts / Red: graph-tts-iter5 / Blue: gae) 13 | - Stop prediction loss (train / val) 14 | 15 | - Guided attention loss (train / val) 16 | 17 | - L1 loss (train / val) 18 | 19 | 20 | ## Alignments 21 | - Encoder-Decoder Alignments 22 | 23 | 24 | - Melspectrogram 25 | 26 | 27 | - Stop prediction 28 | 29 | 30 | ## Audio Samples 31 | You can listen to the audio samples [here](https://leeyoonhyung.github.io/GraphTTS/). 32 | You can also listen to the audio samples obtained from Transformer-TTS [here](https://leeyoonhyung.github.io/Transformer-TTS/). 33 | 34 | ## Notice 35 | 1. Unlike the original paper, I didn't use an encoder prenet, following [espnet](https://github.com/espnet/espnet). 36 | 2. I apply an additional ["guided attention loss"](https://arxiv.org/pdf/1710.08969.pdf) to two heads of the last two layers. 37 | 3. Batch size matters, so I use gradient accumulation to keep the effective batch size large. 38 | 4. You can also use DataParallel; change `n_gpus`, `batch_size`, and `accumulation` accordingly. 39 | 5. To draw attention plots for each head, I changed the return value of `torch.nn.functional.multi_head_attention_forward()`: 40 | ```python 41 | #before 42 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 43 | 44 | #after 45 | return attn_output, attn_output_weights 46 | ``` 47 | 6. Among the `num_layers*num_heads` attention matrices, the one with the highest focus rate is saved. 48 | 49 | ## Reference 50 | 1. NVIDIA/tacotron2: https://github.com/NVIDIA/tacotron2 51 | 2. espnet/espnet: https://github.com/espnet/espnet 52 | -------------------------------------------------------------------------------- /audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, 8 | n_frames, 9 | hop_length=200, 10 | win_length=800, 11 | n_fft=800, 12 | dtype=np.float32, 13 | norm=None): 14 | """ 15 | # from librosa 0.6 16 | Compute the sum-square envelope of a window function at a given hop length. 17 | 18 | This is used to estimate modulation effects induced by windowing 19 | observations in short-time Fourier transforms. 20 | 21 | Parameters 22 | ---------- 23 | window : string, tuple, number, callable, or list-like 24 | Window specification, as in `get_window` 25 | 26 | n_frames : int > 0 27 | The number of analysis frames 28 | 29 | hop_length : int > 0 30 | The number of samples to advance between frames 31 | 32 | win_length : [optional] 33 | The length of the window function. By default, this matches `n_fft`. 34 | 35 | n_fft : int > 0 36 | The length of each analysis frame.
37 | 38 | dtype : np.dtype 39 | The data type of the output 40 | 41 | Returns 42 | ------- 43 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 44 | The sum-squared envelope of the window function 45 | """ 46 | if win_length is None: 47 | win_length = n_fft 48 | 49 | n = n_fft + hop_length * (n_frames - 1) 50 | x = np.zeros(n, dtype=dtype) 51 | 52 | # Compute the squared window at the desired length 53 | win_sq = get_window(window, win_length, fftbins=True) 54 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 55 | win_sq = librosa_util.pad_center(win_sq, n_fft) 56 | 57 | # Fill the envelope 58 | for i in range(n_frames): 59 | sample = i * hop_length 60 | x[sample:min(n, sample+n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 61 | return x 62 | 63 | 64 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 65 | """ 66 | PARAMS 67 | ------ 68 | magnitudes: spectrogram magnitudes 69 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 70 | """ 71 | 72 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 73 | angles = angles.astype(np.float32) 74 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 75 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 76 | 77 | for i in range(n_iters): 78 | _, angles = stft_fn.transform(signal) 79 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 80 | return signal 81 | 82 | 83 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 84 | """ 85 | PARAMS 86 | ------ 87 | C: compression factor 88 | """ 89 | return torch.log(torch.clamp(x, min=clip_val) * C) 90 | 91 | 92 | def dynamic_range_decompression(x, C=1): 93 | """ 94 | PARAMS 95 | ------ 96 | C: compression factor used to compress 97 | """ 98 | return torch.exp(x) / C -------------------------------------------------------------------------------- /figures/enc_dec_align.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/enc_dec_align.JPG -------------------------------------------------------------------------------- /figures/gate_out.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/gate_out.JPG -------------------------------------------------------------------------------- /figures/melspec.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/melspec.JPG -------------------------------------------------------------------------------- /figures/train_bce_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/train_bce_loss.JPG -------------------------------------------------------------------------------- /figures/train_guide_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/train_guide_loss.JPG -------------------------------------------------------------------------------- /figures/train_mel_loss.JPG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/train_mel_loss.JPG -------------------------------------------------------------------------------- /figures/val_bce_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/val_bce_loss.JPG -------------------------------------------------------------------------------- /figures/val_guide_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/val_guide_loss.JPG -------------------------------------------------------------------------------- /figures/val_mel_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/val_mel_loss.JPG -------------------------------------------------------------------------------- /generate_samples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Import libraries and setup matplotlib" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 18 | "\n", 19 | "import warnings\n", 20 | "warnings.filterwarnings(\"ignore\")\n", 21 | "\n", 22 | "import sys\n", 23 | "sys.path.append('waveglow/')\n", 24 | "\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "%matplotlib inline\n", 27 | "\n", 28 | "import IPython.display as ipd\n", 29 | "import pickle as pkl\n", 30 | "from text import *\n", 31 | "import numpy as np\n", 32 | "import torch\n", 33 | "import hparams\n", 34 | "from modules.model import GAE, GraphTTS\n", 35 | "from denoiser import Denoiser\n", 36 | "import soundfile" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### Text preprocessing" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "from g2p_en import G2p\n", 53 | "from text.symbols import symbols\n", 54 | "from text.cleaners import custom_english_cleaners\n", 55 | "\n", 56 | "# Mappings from symbol to numeric ID and vice versa:\n", 57 | "symbol_to_id = {s: i for i, s in enumerate(symbols)}\n", 58 | "id_to_symbol = {i: s for i, s in enumerate(symbols)}\n", 59 | "\n", 60 | "g2p = G2p()\n", 61 | "\n", 62 | "def text2seq(text, data_type='char'):\n", 63 | " text = custom_english_cleaners(text.rstrip())\n", 64 | " if data_type=='phone':\n", 65 | " clean_phone = []\n", 66 | " for s in g2p(text.lower()):\n", 67 | " if '@'+s in symbol_to_id:\n", 68 | " clean_phone.append('@'+s)\n", 69 | " else:\n", 70 | " clean_phone.append(s)\n", 71 | " text = clean_phone\n", 72 | " \n", 73 | " # Append SOS, EOS token\n", 74 | " sequence = [symbol_to_id[c] for c in text]\n", 75 | " sequence = [symbol_to_id['^']] + sequence + [symbol_to_id['~']]\n", 76 | " return sequence\n", 77 | "\n", 78 | "\n", 79 | "def create_adjacency_matrix(char_seq):\n", 80 | " n_nodes=char_seq.size(1)\n", 81 | " n_edge_types=3\n", 82 | " \n", 83 | " a = np.zeros([n_edge_types, n_nodes, n_nodes], dtype=np.int)\n", 84 | " \n", 
85 | " a[0] = np.eye(n_nodes, k=1, dtype=int) + np.eye(n_nodes, k=-1, dtype=int)\n", 86 | " a[2] = 1\n", 87 | " \n", 88 | " white_spaces = (char_seq==symbol_to_id[' ']).nonzero()\n", 89 | " start = torch.cat([white_spaces.new_tensor(torch.tensor([0])), white_spaces[1:].view(-1)])\n", 90 | " end = torch.cat([white_spaces[1:].view(-1), white_spaces.new_tensor(torch.tensor([n_nodes]))])\n", 91 | " for i in range(len(start)):\n", 92 | " a[1, start[i]:end[i], start[i]:end[i]]=1\n", 93 | " \n", 94 | " return torch.from_numpy(a).unsqueeze(0)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "### Waveglow" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 3, 107 | "metadata": { 108 | "code_folding": [] 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "waveglow_path = 'training_log/waveglow_256channels.pt'\n", 113 | "waveglow = torch.load(waveglow_path)['model']\n", 114 | "\n", 115 | "for m in waveglow.modules():\n", 116 | " if 'Conv' in str(type(m)):\n", 117 | " setattr(m, 'padding_mode', 'zeros')\n", 118 | "\n", 119 | "waveglow.cuda().eval()\n", 120 | "for k in waveglow.convinv:\n", 121 | " k.float()\n", 122 | "\n", 123 | "denoiser = Denoiser(waveglow)\n", 124 | "\n", 125 | "with open('filelists/ljs_audio_text_val_filelist.txt', 'r') as f:\n", 126 | " lines = [line.split('|') for line in f.read().splitlines()]" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 4, 132 | "metadata": { 133 | "scrolled": false 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "data_type='char'\n", 138 | "for test_case in ['graph-tts-char', 'graph-tts-char_iter5', 'gae-char']:\n", 139 | " for step in ['20000', '50000', '100000', '200000']:\n", 140 | " checkpoint_path = f\"training_log/{test_case}/checkpoint_{step}\"\n", 141 | " state_dict = {}\n", 142 | " for k, v in torch.load(checkpoint_path)['state_dict'].items():\n", 143 | " state_dict[k[7:]]=v\n", 144 | " \n", 145 | " if 'gae' in test_case:\n", 146 | " model = GAE(hparams).cuda()\n", 147 | " model.load_state_dict(state_dict)\n", 148 | " _ = model.cuda().eval()\n", 149 | " else:\n", 150 | " model = GraphTTS(hparams).cuda()\n", 151 | " model.load_state_dict(state_dict)\n", 152 | " _ = model.cuda().eval()\n", 153 | "\n", 154 | " for i in [1, 6, 22]:\n", 155 | " file_name, _, text = lines[i]\n", 156 | " sequence = np.array(text2seq(text,data_type))[None, :]\n", 157 | " sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n", 158 | " adj_matrix = create_adjacency_matrix(sequence).cuda().long()\n", 159 | " adj_matrix = torch.cat([adj_matrix, adj_matrix], dim=1)\n", 160 | " adj_matrix = adj_matrix.transpose(1, 2).reshape(-1, sequence.size(1), sequence.size(1)*3*2)\n", 161 | "\n", 162 | " with torch.no_grad():\n", 163 | " melspec, dec_alignments, enc_dec_alignments, stop = model.inference(sequence,\n", 164 | " adj_matrix,\n", 165 | " max_len=1024)\n", 166 | " melspec = melspec[:,:,:len(stop)]\n", 167 | " audio = waveglow.infer(melspec, sigma=0.666)\n", 168 | "\n", 169 | " soundfile.write(f'wavs/{test_case}/{file_name}_{data_type}{step}.wav', audio.cpu().numpy()[0].astype(float), 22050)" 170 | ] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": "Python [conda env:LYH] *", 176 | "language": "python", 177 | "name": "conda-env-LYH-py" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 
186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.7.5" 190 | }, 191 | "varInspector": { 192 | "cols": { 193 | "lenName": 16, 194 | "lenType": 16, 195 | "lenVar": 40 196 | }, 197 | "kernels_config": { 198 | "python": { 199 | "delete_cmd_postfix": "", 200 | "delete_cmd_prefix": "del ", 201 | "library": "var_list.py", 202 | "varRefreshCmd": "print(var_dic_list())" 203 | }, 204 | "r": { 205 | "delete_cmd_postfix": ") ", 206 | "delete_cmd_prefix": "rm(", 207 | "library": "var_list.r", 208 | "varRefreshCmd": "cat(var_dic_list()) " 209 | } 210 | }, 211 | "types_to_exclude": [ 212 | "module", 213 | "function", 214 | "builtin_function_or_method", 215 | "instance", 216 | "_Feature" 217 | ], 218 | "window_display": false 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 2 223 | } 224 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from text import symbols 2 | 3 | ################################ 4 | # Experiment Parameters # 5 | ################################ 6 | seed=1234 7 | n_gpus=2 8 | output_directory = 'training_log' 9 | log_directory = 'graph-tts-char' 10 | data_path = '/media/disk1/lyh/LJSpeech-1.1/preprocessed' 11 | 12 | training_files='filelists/ljs_audio_text_train_filelist.txt' 13 | validation_files='filelists/ljs_audio_text_val_filelist.txt' 14 | text_cleaners=['english_cleaners'] 15 | 16 | 17 | ################################ 18 | # Audio Parameters # 19 | ################################ 20 | sampling_rate=22050 21 | filter_length=1024 22 | hop_length=256 23 | win_length=1024 24 | n_mel_channels=80 25 | mel_fmin=0 26 | mel_fmax=8000.0 27 | 28 | ################################ 29 | # Model Parameters # 30 | ################################ 31 | n_symbols=len(symbols) 32 | data_type='char_seq' # 'phone_seq' 33 | symbols_embedding_dim=256 34 | hidden_dim=256 35 | dprenet_dim=256 36 | postnet_dim=256 37 | ff_dim=1024 38 | n_heads=4 39 | n_layers=6 40 | n_postnet_layers=5 41 | iterations=1 42 | 43 | ################################ 44 | # Optimization Hyperparameters # 45 | ################################ 46 | lr=384**-0.5 47 | warmup_steps=4000 48 | grad_clip_thresh=1.0 49 | batch_size=32 50 | accumulation=2 51 | iters_per_validation=2000 52 | iters_per_checkpoint=10000 53 | train_steps = 200000 -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Grpah-TTS Audio Samples 7 | 8 | 9 | 10 | 11 |
[index.html markup was stripped in this dump; only the page text survives. The page presents three groups of audio samples (graph-tts, graph-tts-iter5, and gae), each shown as a table whose rows are the samples LJ001-0029, LJ001-0085, and LJ002-0106 and whose columns are the checkpoints at steps 20000, 50000, 100000, and 200000, with one audio player per cell; the corresponding files are under wavs/.]
157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from librosa.filters import mel as librosa_mel_fn 3 | from audio_processing import dynamic_range_compression 4 | from audio_processing import dynamic_range_decompression 5 | from stft import STFT 6 | 7 | 8 | class LinearNorm(torch.nn.Module): 9 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 10 | super(LinearNorm, self).__init__() 11 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 12 | 13 | torch.nn.init.xavier_uniform_( 14 | self.linear_layer.weight, 15 | gain=torch.nn.init.calculate_gain(w_init_gain)) 16 | 17 | def forward(self, x): 18 | return self.linear_layer(x) 19 | 20 | 21 | class ConvNorm(torch.nn.Module): 22 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 23 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 24 | super(ConvNorm, self).__init__() 25 | if padding is None: 26 | assert(kernel_size % 2 == 1) 27 | padding = int(dilation * (kernel_size - 1) / 2) 28 | 29 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 30 | kernel_size=kernel_size, stride=stride, 31 | padding=padding, dilation=dilation, 32 | bias=bias) 33 | 34 | torch.nn.init.xavier_uniform_( 35 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 36 | 37 | def forward(self, signal): 38 | conv_signal = self.conv(signal) 39 | return conv_signal 40 | 41 | 42 | class TacotronSTFT(torch.nn.Module): 43 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 44 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 45 | mel_fmax=8000.0): 46 | super(TacotronSTFT, self).__init__() 47 | self.n_mel_channels = n_mel_channels 48 | self.sampling_rate = sampling_rate 49 | self.stft_fn = STFT(filter_length, hop_length, win_length) 50 | mel_basis = librosa_mel_fn( 51 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 52 | mel_basis = torch.from_numpy(mel_basis).float() 53 | self.register_buffer('mel_basis', mel_basis) 54 | 55 | def spectral_normalize(self, magnitudes): 56 | output = dynamic_range_compression(magnitudes) 57 | return output 58 | 59 | def spectral_de_normalize(self, magnitudes): 60 | output = dynamic_range_decompression(magnitudes) 61 | return output 62 | 63 | def mel_spectrogram(self, y): 64 | """Computes mel-spectrograms from a batch of waves 65 | PARAMS 66 | ------ 67 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 68 | RETURNS 69 | ------- 70 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 71 | """ 72 | assert(torch.min(y.data) >= -1) 73 | assert(torch.max(y.data) <= 1) 74 | 75 | magnitudes, phases = self.stft_fn.transform(y) 76 | magnitudes = magnitudes.data 77 | mel_output = torch.matmul(self.mel_basis, magnitudes) 78 | mel_output = self.spectral_normalize(mel_output) 79 | return mel_output -------------------------------------------------------------------------------- /modules/__pycache__/init_layer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/modules/__pycache__/init_layer.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/loss.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/modules/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/modules/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/modules/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /modules/init_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Linear(nn.Linear): 7 | def __init__(self, 8 | in_dim, 9 | out_dim, 10 | bias=True, 11 | w_init_gain='linear'): 12 | super(Linear, self).__init__(in_dim, 13 | out_dim, 14 | bias) 15 | nn.init.xavier_uniform_(self.weight, 16 | gain=nn.init.calculate_gain(w_init_gain)) 17 | 18 | 19 | class Conv1d(nn.Conv1d): 20 | def __init__(self, 21 | in_channels, 22 | out_channels, 23 | kernel_size, 24 | stride=1, 25 | padding=0, 26 | dilation=1, 27 | groups=1, 28 | bias=True, 29 | padding_mode='zeros', 30 | w_init_gain='linear'): 31 | super(Conv1d, self).__init__(in_channels, 32 | out_channels, 33 | kernel_size, 34 | stride, 35 | padding, 36 | dilation, 37 | groups, 38 | bias, 39 | padding_mode) 40 | nn.init.xavier_uniform_(self.weight, 41 | gain=nn.init.calculate_gain(w_init_gain)) -------------------------------------------------------------------------------- /modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.utils import get_mask_from_lengths 4 | 5 | 6 | class TransformerLoss(nn.Module): 7 | def __init__(self): 8 | super(TransformerLoss, self).__init__() 9 | 10 | def forward(self, pred, target, guide): 11 | mel_out, mel_out_post, gate_out = pred 12 | mel_target, gate_target = target 13 | alignments, text_lengths, mel_lengths = guide 14 | 15 | mask = ~get_mask_from_lengths(mel_lengths) 16 | 17 | mel_target = mel_target.masked_select(mask.unsqueeze(1)) 18 | mel_out_post = mel_out_post.masked_select(mask.unsqueeze(1)) 19 | mel_out = mel_out.masked_select(mask.unsqueeze(1)) 20 | 21 | gate_target = gate_target.masked_select(mask) 22 | gate_out = gate_out.masked_select(mask) 23 | 24 | mel_loss = nn.L1Loss()(mel_out, mel_target) + nn.L1Loss()(mel_out_post, mel_target) 25 | bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(5.0))(gate_out, gate_target) 26 | guide_loss = self.guide_loss(alignments, text_lengths, mel_lengths) 27 | 28 | return mel_loss, bce_loss, guide_loss 29 | 30 | 31 | def guide_loss(self, alignments, text_lengths, mel_lengths): 32 | B, n_layers, n_heads, T, L = alignments.size() 33 | 34 | # B, T, L 35 | W = alignments.new_zeros(B, T, L) 36 | mask = alignments.new_zeros(B, T, L) 37 | 38 | for i, (t, l) in enumerate(zip(mel_lengths, text_lengths)): 39 | mel_seq = alignments.new_tensor( 
torch.arange(t).to(torch.float32).unsqueeze(-1)/t ) 40 | text_seq = alignments.new_tensor( torch.arange(l).to(torch.float32).unsqueeze(0)/l ) 41 | x = torch.pow(mel_seq-text_seq, 2) 42 | W[i, :t, :l] += alignments.new_tensor(1-torch.exp(-3.125*x)) 43 | mask[i, :t, :l] = 1 44 | 45 | # Apply guided_loss to 2 heads of the last 2 layers 46 | applied_align = alignments[:, -2:, :2] 47 | losses = applied_align*(W.unsqueeze(1).unsqueeze(1)) 48 | 49 | return torch.mean(losses.masked_select(mask.unsqueeze(1).unsqueeze(1).to(torch.bool))) 50 | -------------------------------------------------------------------------------- /modules/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .init_layer import * 5 | from .transformer import * 6 | from utils.utils import get_mask_from_lengths 7 | 8 | 9 | class CBAD(nn.Module): 10 | def __init__(self, 11 | in_dim, 12 | out_dim, 13 | kernel_size, 14 | stride, 15 | padding, 16 | bias, 17 | activation, 18 | dropout): 19 | super(CBAD, self).__init__() 20 | self.conv = Conv1d(in_dim, 21 | out_dim, 22 | kernel_size=kernel_size, 23 | stride=stride, 24 | padding=padding, 25 | bias=bias, 26 | w_init_gain=activation) 27 | 28 | self.bn = nn.BatchNorm1d(out_dim) 29 | 30 | if activation == 'relu': 31 | self.activation = nn.ReLU() 32 | elif activation == 'tanh': 33 | self.activation = nn.Tanh() 34 | 35 | self.dropout = nn.Dropout(dropout) 36 | 37 | def forward(self, x): 38 | x = self.conv(x) 39 | x = self.bn(x) 40 | x = self.activation(x) 41 | out = self.dropout(x) 42 | 43 | return out 44 | 45 | class Prenet_D(nn.Module): 46 | def __init__(self, hp): 47 | super(Prenet_D, self).__init__() 48 | self.linear1 = Linear(hp.n_mel_channels, 49 | hp.dprenet_dim, 50 | w_init_gain='relu') 51 | self.linear2 = Linear(hp.dprenet_dim, hp.dprenet_dim, w_init_gain='relu') 52 | self.linear3 = Linear(hp.dprenet_dim, hp.hidden_dim) 53 | 54 | def forward(self, x): 55 | # Set training==True following tacotron2 56 | x = F.dropout(F.relu(self.linear1(x)), p=0.5, training=True) 57 | x = F.dropout(F.relu(self.linear2(x)), p=0.5, training=True) 58 | x = self.linear3(x) 59 | return x 60 | 61 | class PostNet(nn.Module): 62 | def __init__(self, hp): 63 | super(PostNet, self).__init__() 64 | conv_list = [CBAD(in_dim=hp.n_mel_channels, 65 | out_dim=hp.postnet_dim, 66 | kernel_size=5, 67 | stride=1, 68 | padding=2, 69 | bias=False, 70 | activation='tanh', 71 | dropout=0.5)] 72 | 73 | for _ in range(hp.n_postnet_layers-2): 74 | conv_list.append(CBAD(in_dim=hp.postnet_dim, 75 | out_dim=hp.postnet_dim, 76 | kernel_size=5, 77 | stride=1, 78 | padding=2, 79 | bias=False, 80 | activation='tanh', 81 | dropout=0.5)) 82 | 83 | conv_list.append(nn.Sequential(nn.Conv1d(hp.postnet_dim, 84 | hp.n_mel_channels, 85 | kernel_size=5, 86 | padding=2, 87 | bias=False), 88 | nn.BatchNorm1d(hp.n_mel_channels), 89 | nn.Dropout(0.5))) 90 | 91 | self.conv=nn.ModuleList(conv_list) 92 | 93 | def forward(self, x): 94 | for conv in self.conv: 95 | x = conv(x) 96 | return x 97 | 98 | class GraphEncoder(nn.Module): 99 | def __init__(self, hp): 100 | super(GraphEncoder, self).__init__() 101 | self.iterations = hp.iterations 102 | self.forward_edges = Linear(hp.hidden_dim, hp.hidden_dim*3, bias=False) 103 | self.backward_edges = Linear(hp.hidden_dim, hp.hidden_dim*3, bias=False) 104 | 105 | self.reset_gate = nn.Sequential( 106 | Linear(hp.hidden_dim*3, hp.hidden_dim, w_init_gain='sigmoid'), 107 | nn.Sigmoid() 108 
| ) 109 | self.update_gate = nn.Sequential( 110 | Linear(hp.hidden_dim*3, hp.hidden_dim, w_init_gain='sigmoid'), 111 | nn.Sigmoid() 112 | ) 113 | self.tansform = nn.Sequential( 114 | Linear(hp.hidden_dim*3, hp.hidden_dim, w_init_gain='tanh'), 115 | nn.Tanh() 116 | ) 117 | 118 | def forward(self, x, adj_matrix): 119 | x = x.transpose(0,1) 120 | adj_matrix = adj_matrix / adj_matrix.sum(dim=-1).unsqueeze(-1) # B, N, 6N 121 | A_in = adj_matrix[:, :, :adj_matrix.size(2)//2].to(torch.float) # B, N, 3N 122 | A_out = adj_matrix[:, :, adj_matrix.size(2)//2:].to(torch.float) # B, N, 3N 123 | 124 | for k in range(self.iterations): 125 | H_in = torch.cat(torch.chunk(self.forward_edges(x), chunks=3, dim=-1), dim=1) 126 | H_out = torch.cat(torch.chunk(self.backward_edges(x), chunks=3, dim=-1), dim=1) 127 | 128 | a_in = torch.bmm(A_in, H_in) 129 | a_out = torch.bmm(A_out, H_out) 130 | a = torch.cat((a_in, a_out, x), dim=-1) 131 | 132 | r = self.reset_gate(a) 133 | z = self.update_gate(a) 134 | joined_input = torch.cat((a_in, a_out, r * x), dim=-1) 135 | h_hat = self.tansform(joined_input) 136 | 137 | x = (1 - z) * x + z * h_hat 138 | 139 | return x.transpose(0,1) 140 | 141 | 142 | class GraphTTS(nn.Module): 143 | def __init__(self, hp): 144 | super(GraphTTS, self).__init__() 145 | self.hp = hp 146 | self.Embedding = nn.Embedding(hp.n_symbols, hp.symbols_embedding_dim) 147 | self.Prenet_D = Prenet_D(hp) 148 | 149 | self.register_buffer('pe', PositionalEncoding(hp.hidden_dim).pe) 150 | self.dropout = nn.Dropout(0.1) 151 | 152 | self.Encoder = GraphEncoder(hp) 153 | self.Decoder = nn.ModuleList([TransformerDecoderLayer(d_model=hp.hidden_dim, 154 | nhead=hp.n_heads, 155 | dim_feedforward=hp.ff_dim) 156 | for _ in range(hp.n_layers)]) 157 | 158 | self.Projection = Linear(hp.hidden_dim, hp.n_mel_channels) 159 | self.Postnet = PostNet(hp) 160 | self.Stop = nn.Linear(hp.n_mel_channels, 1) 161 | 162 | 163 | def outputs(self, text, adj_matrix, melspec, text_lengths, mel_lengths): 164 | ### Size ### 165 | B, L, T = text.size(0), text.size(1), melspec.size(2) 166 | adj_matrix = torch.cat([adj_matrix, adj_matrix], dim=1) 167 | adj_matrix = adj_matrix.transpose(1, 2).reshape(-1, text_lengths.max().item(), text_lengths.max().item()*3*2) 168 | 169 | ### Prepare Encoder Input ### 170 | encoder_input = self.Embedding(text).transpose(0,1) 171 | encoder_input += self.pe[:L].unsqueeze(1) 172 | encoder_input = self.dropout(encoder_input) 173 | memory = self.Encoder(encoder_input, adj_matrix) 174 | 175 | ### Prepare Decoder Input ### 176 | mel_input = F.pad(melspec, (1,-1)).transpose(1,2) 177 | decoder_input = self.Prenet_D(mel_input).transpose(0,1) 178 | decoder_input += self.pe[:T].unsqueeze(1) 179 | decoder_input = self.dropout(decoder_input) 180 | 181 | ### Prepare Masks ### 182 | text_mask = get_mask_from_lengths(text_lengths) 183 | mel_mask = get_mask_from_lengths(mel_lengths) 184 | diag_mask = torch.triu(melspec.new_ones(T,T)).transpose(0, 1) 185 | diag_mask[diag_mask == 0] = -float('inf') 186 | diag_mask[diag_mask == 1] = 0 187 | 188 | ### Decoding ### 189 | tgt = decoder_input 190 | dec_alignments, enc_dec_alignments = [], [] 191 | for layer in self.Decoder: 192 | tgt, dec_align, enc_dec_align = layer(tgt, 193 | memory, 194 | tgt_mask=diag_mask, 195 | tgt_key_padding_mask=mel_mask, 196 | memory_key_padding_mask=text_mask) 197 | dec_alignments.append(dec_align.unsqueeze(1)) 198 | enc_dec_alignments.append(enc_dec_align.unsqueeze(1)) 199 | dec_alignments = torch.cat(dec_alignments, 1) 200 | enc_dec_alignments = 
torch.cat(enc_dec_alignments, 1) 201 | 202 | ### Projection + PostNet ### 203 | mel_out = self.Projection(tgt.transpose(0, 1)).transpose(1, 2) 204 | mel_out_post = self.Postnet(mel_out) + mel_out 205 | 206 | gate_out = self.Stop(mel_out.transpose(1, 2)).squeeze(-1) 207 | 208 | return mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out 209 | 210 | 211 | def forward(self, text, adj_matrix, melspec, gate, text_lengths, mel_lengths, criterion): 212 | ### Size ### 213 | text = text[:,:text_lengths.max().item()] 214 | adj_matrix = adj_matrix[:, :, :text_lengths.max().item(), :text_lengths.max().item()] 215 | 216 | melspec = melspec[:,:,:mel_lengths.max().item()] 217 | gate = gate[:, :mel_lengths.max().item()] 218 | outputs = self.outputs(text, adj_matrix, melspec, text_lengths, mel_lengths) 219 | 220 | mel_out, mel_out_post = outputs[0], outputs[1] 221 | enc_dec_alignments = outputs[3] 222 | gate_out=outputs[4] 223 | 224 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out), 225 | (melspec, gate), 226 | (enc_dec_alignments, text_lengths, mel_lengths)) 227 | 228 | return mel_loss, bce_loss, guide_loss 229 | 230 | 231 | def inference(self, text, adj_matrix, max_len=1024): 232 | ### Size & Length ### 233 | (B, L), T = text.size(), max_len 234 | 235 | ### Prepare Inputs ### 236 | encoder_input = self.Embedding(text).transpose(0,1).contiguous() 237 | encoder_input += self.pe[:L].unsqueeze(1) 238 | memory = self.Encoder(encoder_input, adj_matrix) 239 | 240 | ### Prepare Masks ### 241 | text_mask = text.new_zeros(1, L).to(torch.bool) 242 | mel_mask = text.new_zeros(1, T).to(torch.bool) 243 | diag_mask = torch.triu(text.new_ones(T, T)).transpose(0, 1).contiguous() 244 | diag_mask[diag_mask == 0] = -1e9 245 | diag_mask[diag_mask == 1] = 0 246 | 247 | ### Transformer Decoder ### 248 | mel_input = text.new_zeros(1, 249 | self.hp.n_mel_channels, 250 | max_len).to(torch.float32) 251 | dec_alignments = text.new_zeros(self.hp.n_layers, 252 | self.hp.n_heads, 253 | max_len, 254 | max_len).to(torch.float32) 255 | enc_dec_alignments = text.new_zeros(self.hp.n_layers, 256 | self.hp.n_heads, 257 | max_len, 258 | text.size(1)).to(torch.float32) 259 | 260 | ### Generation ### 261 | stop=[] 262 | for i in range(max_len): 263 | tgt = self.Prenet_D(mel_input.transpose(1,2).contiguous()).transpose(0,1).contiguous() 264 | tgt += self.pe[:T].unsqueeze(1) 265 | 266 | for j, layer in enumerate(self.Decoder): 267 | tgt, dec_align, enc_dec_align = layer(tgt, 268 | memory, 269 | tgt_mask=diag_mask, 270 | tgt_key_padding_mask=mel_mask, 271 | memory_key_padding_mask=text_mask) 272 | dec_alignments[j, :, i] = dec_align[0, :, i] 273 | enc_dec_alignments[j, :, i] = enc_dec_align[0, :, i] 274 | 275 | mel_out = self.Projection(tgt.transpose(0,1).contiguous()) 276 | stop.append(torch.sigmoid(self.Stop(mel_out[:,i]))[0,0].item()) 277 | 278 | if i < max_len - 1: 279 | mel_input[0, :, i+1] = mel_out[0, i] 280 | 281 | if stop[-1]>0.5: 282 | break 283 | 284 | mel_out_post = self.Postnet(mel_out.transpose(1, 2).contiguous()) 285 | mel_out_post = mel_out_post.transpose(1, 2).contiguous() + mel_out 286 | mel_out_post = mel_out_post.transpose(1, 2).contiguous() 287 | 288 | return mel_out_post, dec_alignments, enc_dec_alignments, stop 289 | 290 | 291 | class GAE(nn.Module): 292 | def __init__(self, hp): 293 | super(GAE, self).__init__() 294 | self.hp = hp 295 | self.Embedding = nn.Embedding(hp.n_symbols, hp.symbols_embedding_dim) 296 | self.Prenet_D = Prenet_D(hp) 297 | 298 | self.register_buffer('pe', 
PositionalEncoding(hp.hidden_dim).pe) 299 | self.dropout = nn.Dropout(0.1) 300 | 301 | self.Encoder1 = nn.ModuleList([TransformerEncoderLayer(d_model=hp.hidden_dim, 302 | nhead=hp.n_heads, 303 | dim_feedforward=hp.ff_dim) 304 | for _ in range(hp.n_layers)]) 305 | 306 | self.Encoder2 = GraphEncoder(hp) 307 | self.linear = nn.Linear(hp.hidden_dim*2, hp.hidden_dim) 308 | self.Decoder = nn.ModuleList([TransformerDecoderLayer(d_model=hp.hidden_dim, 309 | nhead=hp.n_heads, 310 | dim_feedforward=hp.ff_dim) 311 | for _ in range(hp.n_layers)]) 312 | 313 | self.Projection = Linear(hp.hidden_dim, hp.n_mel_channels) 314 | self.Postnet = PostNet(hp) 315 | self.Stop = nn.Linear(hp.n_mel_channels, 1) 316 | 317 | 318 | def outputs(self, text, adj_matrix, melspec, text_lengths, mel_lengths): 319 | ### Size ### 320 | B, L, T = text.size(0), text.size(1), melspec.size(2) 321 | adj_matrix = torch.cat([adj_matrix, adj_matrix], dim=1) 322 | adj_matrix = adj_matrix.transpose(1, 2).reshape(-1, text_lengths.max().item(), text_lengths.max().item()*3*2) 323 | 324 | ### Prepare Encoder Input ### 325 | encoder_input = self.Embedding(text).transpose(0,1) 326 | encoder_input += self.pe[:L].unsqueeze(1) 327 | encoder_input = self.dropout(encoder_input) 328 | 329 | ### Transformer Encoder ### 330 | memory1 = encoder_input 331 | enc_alignments = [] 332 | text_mask = get_mask_from_lengths(text_lengths) 333 | for layer in self.Encoder1: 334 | memory1, enc_align = layer(memory1, src_key_padding_mask=text_mask) 335 | enc_alignments.append(enc_align.unsqueeze(1)) 336 | enc_alignments = torch.cat(enc_alignments, 1) 337 | 338 | memory2 = self.Encoder2(encoder_input, adj_matrix) 339 | memory = self.linear(torch.cat([memory1, memory2], dim=-1)) 340 | 341 | ### Prepare Decoder Input ### 342 | mel_input = F.pad(melspec, (1,-1)).transpose(1,2) 343 | decoder_input = self.Prenet_D(mel_input).transpose(0,1) 344 | decoder_input += self.pe[:T].unsqueeze(1) 345 | decoder_input = self.dropout(decoder_input) 346 | 347 | ### Prepare Masks ### 348 | mel_mask = get_mask_from_lengths(mel_lengths) 349 | diag_mask = torch.triu(melspec.new_ones(T,T)).transpose(0, 1) 350 | diag_mask[diag_mask == 0] = -float('inf') 351 | diag_mask[diag_mask == 1] = 0 352 | 353 | ### Decoding ### 354 | tgt = decoder_input 355 | dec_alignments, enc_dec_alignments = [], [] 356 | for layer in self.Decoder: 357 | tgt, dec_align, enc_dec_align = layer(tgt, 358 | memory, 359 | tgt_mask=diag_mask, 360 | tgt_key_padding_mask=mel_mask, 361 | memory_key_padding_mask=text_mask) 362 | dec_alignments.append(dec_align.unsqueeze(1)) 363 | enc_dec_alignments.append(enc_dec_align.unsqueeze(1)) 364 | dec_alignments = torch.cat(dec_alignments, 1) 365 | enc_dec_alignments = torch.cat(enc_dec_alignments, 1) 366 | 367 | ### Projection + PostNet ### 368 | mel_out = self.Projection(tgt.transpose(0, 1)).transpose(1, 2) 369 | mel_out_post = self.Postnet(mel_out) + mel_out 370 | 371 | gate_out = self.Stop(mel_out.transpose(1, 2)).squeeze(-1) 372 | 373 | return mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out 374 | 375 | 376 | def forward(self, text, adj_matrix, melspec, gate, text_lengths, mel_lengths, criterion): 377 | ### Size ### 378 | text = text[:,:text_lengths.max().item()] 379 | adj_matrix = adj_matrix[:, :, :text_lengths.max().item(), :text_lengths.max().item()] 380 | 381 | melspec = melspec[:,:,:mel_lengths.max().item()] 382 | gate = gate[:, :mel_lengths.max().item()] 383 | outputs = self.outputs(text, adj_matrix, melspec, text_lengths, mel_lengths) 384 | 385 | 
mel_out, mel_out_post = outputs[0], outputs[1] 386 | enc_dec_alignments = outputs[3] 387 | gate_out=outputs[4] 388 | 389 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out), 390 | (melspec, gate), 391 | (enc_dec_alignments, text_lengths, mel_lengths)) 392 | 393 | return mel_loss, bce_loss, guide_loss 394 | 395 | 396 | def inference(self, text,adj_matrix, max_len=1024): 397 | ### Size & Length ### 398 | (B, L), T = text.size(), max_len 399 | 400 | ### Prepare Inputs ### 401 | encoder_input = self.Embedding(text).transpose(0,1).contiguous() 402 | encoder_input += self.pe[:L].unsqueeze(1) 403 | 404 | memory1 = encoder_input 405 | enc_alignments = [] 406 | text_mask = text.new_zeros(1, L).to(torch.bool) 407 | for layer in self.Encoder1: 408 | memory1, enc_align = layer(memory1, src_key_padding_mask=text_mask) 409 | enc_alignments.append(enc_align.unsqueeze(1)) 410 | enc_alignments = torch.cat(enc_alignments, 1) 411 | 412 | memory2 = self.Encoder2(encoder_input, adj_matrix) 413 | memory = self.linear(torch.cat([memory1, memory2], dim=-1)) 414 | 415 | ### Prepare Masks ### 416 | mel_mask = text.new_zeros(1, T).to(torch.bool) 417 | diag_mask = torch.triu(text.new_ones(T, T)).transpose(0, 1).contiguous() 418 | diag_mask[diag_mask == 0] = -1e9 419 | diag_mask[diag_mask == 1] = 0 420 | 421 | ### Transformer Decoder ### 422 | mel_input = text.new_zeros(1, 423 | self.hp.n_mel_channels, 424 | max_len).to(torch.float32) 425 | dec_alignments = text.new_zeros(self.hp.n_layers, 426 | self.hp.n_heads, 427 | max_len, 428 | max_len).to(torch.float32) 429 | enc_dec_alignments = text.new_zeros(self.hp.n_layers, 430 | self.hp.n_heads, 431 | max_len, 432 | text.size(1)).to(torch.float32) 433 | 434 | ### Generation ### 435 | stop=[] 436 | for i in range(max_len): 437 | tgt = self.Prenet_D(mel_input.transpose(1,2).contiguous()).transpose(0,1).contiguous() 438 | tgt += self.pe[:T].unsqueeze(1) 439 | 440 | for j, layer in enumerate(self.Decoder): 441 | tgt, dec_align, enc_dec_align = layer(tgt, 442 | memory, 443 | tgt_mask=diag_mask, 444 | tgt_key_padding_mask=mel_mask, 445 | memory_key_padding_mask=text_mask) 446 | dec_alignments[j, :, i] = dec_align[0, :, i] 447 | enc_dec_alignments[j, :, i] = enc_dec_align[0, :, i] 448 | 449 | mel_out = self.Projection(tgt.transpose(0,1).contiguous()) 450 | stop.append(torch.sigmoid(self.Stop(mel_out[:,i]))[0,0].item()) 451 | 452 | if i < max_len - 1: 453 | mel_input[0, :, i+1] = mel_out[0, i] 454 | 455 | if stop[-1]>0.5: 456 | break 457 | 458 | mel_out_post = self.Postnet(mel_out.transpose(1, 2).contiguous()) 459 | mel_out_post = mel_out_post.transpose(1, 2).contiguous() + mel_out 460 | mel_out_post = mel_out_post.transpose(1, 2).contiguous() 461 | 462 | return mel_out_post, dec_alignments, enc_dec_alignments, stop -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .init_layer import * 5 | 6 | 7 | class TransformerEncoderLayer(nn.Module): 8 | def __init__(self, 9 | d_model, 10 | nhead, 11 | dim_feedforward=2048, 12 | dropout=0.1, 13 | activation="relu"): 14 | super(TransformerEncoderLayer, self).__init__() 15 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 16 | 17 | self.linear1 = Linear(d_model, dim_feedforward, w_init_gain=activation) 18 | self.linear2 = Linear(dim_feedforward, d_model) 19 | 20 | 
self.norm1 = nn.LayerNorm(d_model) 21 | self.norm2 = nn.LayerNorm(d_model) 22 | 23 | self.dropout = nn.Dropout(dropout) 24 | 25 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 26 | src2, enc_align = self.self_attn(src, 27 | src, 28 | src, 29 | attn_mask=src_mask, 30 | key_padding_mask=src_key_padding_mask) 31 | src = src + self.dropout(src2) 32 | src = self.norm1(src) 33 | 34 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) 35 | src = src + self.dropout(src2) 36 | src = self.norm2(src) 37 | 38 | return src, enc_align 39 | 40 | 41 | class TransformerDecoderLayer(nn.Module): 42 | def __init__(self, 43 | d_model, 44 | nhead, 45 | dim_feedforward=2048, 46 | dropout=0.1, 47 | activation="relu"): 48 | super(TransformerDecoderLayer, self).__init__() 49 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 50 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 51 | 52 | self.linear1 = Linear(d_model, dim_feedforward, w_init_gain=activation) 53 | self.linear2 = Linear(dim_feedforward, d_model) 54 | 55 | self.norm1 = nn.LayerNorm(d_model) 56 | self.norm2 = nn.LayerNorm(d_model) 57 | self.norm3 = nn.LayerNorm(d_model) 58 | 59 | self.dropout = nn.Dropout(dropout) 60 | 61 | def forward(self, 62 | tgt, 63 | memory, 64 | tgt_mask=None, 65 | memory_mask=None, 66 | tgt_key_padding_mask=None, 67 | memory_key_padding_mask=None): 68 | tgt2, dec_align = self.self_attn(tgt, 69 | tgt, 70 | tgt, 71 | attn_mask=tgt_mask, 72 | key_padding_mask=tgt_key_padding_mask) 73 | tgt = tgt + self.dropout(tgt2) 74 | tgt = self.norm1(tgt) 75 | 76 | tgt2, enc_dec_align = self.multihead_attn(tgt, 77 | memory, 78 | memory, 79 | attn_mask=memory_mask, 80 | key_padding_mask=memory_key_padding_mask) 81 | tgt = tgt + self.dropout(tgt2) 82 | tgt = self.norm2(tgt) 83 | 84 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt)))) 85 | tgt = tgt + self.dropout(tgt2) 86 | tgt = self.norm3(tgt) 87 | 88 | return tgt, dec_align, enc_dec_align 89 | 90 | 91 | class PositionalEncoding(nn.Module): 92 | def __init__(self, d_model, max_len=5000): 93 | super(PositionalEncoding, self).__init__() 94 | self.register_buffer('pe', self._get_pe_matrix(d_model, max_len)) 95 | 96 | def forward(self, x): 97 | return x + self.pe[:x.size(0)].unsqueeze(1) 98 | 99 | def _get_pe_matrix(self, d_model, max_len): 100 | pe = torch.zeros(max_len, d_model) 101 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 102 | div_term = torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model) 103 | 104 | pe[:, 0::2] = torch.sin(position / div_term) 105 | pe[:, 1::2] = torch.cos(position / div_term) 106 | 107 | return pe 108 | -------------------------------------------------------------------------------- /prepare_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Import libraries, metadata" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2019-12-18T06:01:56.263558Z", 16 | "start_time": "2019-12-18T06:01:51.717351Z" 17 | } 18 | }, 19 | "outputs": [ 20 | { 21 | "name": "stderr", 22 | "output_type": "stream", 23 | "text": [ 24 | "/home/lyh/anaconda3/envs/LYH/lib/python3.7/site-packages/librosa/util/decorators.py:9: NumbaDeprecationWarning: \u001b[1mAn import was requested from a module that has moved location.\n", 25 | "Import requested from: 
'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.\u001b[0m\n", 26 | " from numba.decorators import jit as optional_jit\n", 27 | "/home/lyh/anaconda3/envs/LYH/lib/python3.7/site-packages/librosa/util/decorators.py:9: NumbaDeprecationWarning: \u001b[1mAn import was requested from a module that has moved location.\n", 28 | "Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.\u001b[0m\n", 29 | " from numba.decorators import jit as optional_jit\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import os\n", 35 | "import librosa\n", 36 | "from librosa.filters import mel as librosa_mel_fn\n", 37 | "import pickle as pkl\n", 38 | "import IPython.display as ipd\n", 39 | "from tqdm.notebook import tqdm\n", 40 | "import torch\n", 41 | "import numpy as np\n", 42 | "import codecs\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "%matplotlib inline\n", 45 | "\n", 46 | "from g2p_en import G2p\n", 47 | "from text import *\n", 48 | "from text import cmudict\n", 49 | "from text.cleaners import custom_english_cleaners\n", 50 | "from text.symbols import symbols\n", 51 | "\n", 52 | "# Mappings from symbol to numeric ID and vice versa:\n", 53 | "symbol_to_id = {s: i for i, s in enumerate(symbols)}\n", 54 | "id_to_symbol = {i: s for i, s in enumerate(symbols)}\n", 55 | "\n", 56 | "csv_file = '/media/disk1/lyh/LJSpeech-1.1/metadata.csv'\n", 57 | "root_dir = '/media/disk1/lyh/LJSpeech-1.1/wavs'\n", 58 | "data_dir = '/media/disk1/lyh/LJSpeech-1.1/preprocessed'\n", 59 | "\n", 60 | "g2p = G2p()\n", 61 | "metadata={}\n", 62 | "with codecs.open(csv_file, 'r', 'utf-8') as fid:\n", 63 | " for line in fid.readlines():\n", 64 | " id, _, text = line.split(\"|\")\n", 65 | " \n", 66 | " clean_char = custom_english_cleaners(text.rstrip())\n", 67 | " clean_phone = []\n", 68 | " for s in g2p(clean_char.lower()):\n", 69 | " if '@'+s in symbol_to_id:\n", 70 | " clean_phone.append('@'+s)\n", 71 | " else:\n", 72 | " clean_phone.append(s)\n", 73 | " \n", 74 | " metadata[id]={'char':clean_char,\n", 75 | " 'phone':clean_phone}" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "### Others" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "scrolled": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "from layers import TacotronSTFT\n", 94 | "stft = TacotronSTFT()\n", 95 | "\n", 96 | "def text2seq(text):\n", 97 | " sequence=[symbol_to_id['^']]\n", 98 | " sequence.extend([symbol_to_id[c] for c in text])\n", 99 | " sequence.append(symbol_to_id['~'])\n", 100 | " return sequence\n", 101 | "\n", 102 | "def create_adjacency_matrix(char_seq):\n", 103 | " n_nodes=len(char_seq)\n", 104 | " n_edge_types=3\n", 105 | " \n", 106 | " a = np.zeros([n_edge_types, n_nodes, n_nodes], dtype=np.int)\n", 107 | " \n", 108 | " a[0] = np.eye(n_nodes, k=1, dtype=int) + np.eye(n_nodes, k=-1, dtype=int)\n", 109 | " a[2] = 1\n", 110 | " \n", 111 | " white_spaces = (char_seq==symbol_to_id[' ']).nonzero()\n", 112 | " start = torch.cat([torch.tensor([0]), white_spaces[1:].view(-1)])\n", 113 | " end = torch.cat([white_spaces[1:].view(-1), torch.tensor([n_nodes])])\n", 114 | " for i in range(len(start)):\n", 115 | " a[1, start[i]:end[i], start[i]:end[i]]=1\n", 116 | " \n", 117 | " # a = np.concatenate((a, a), axis=0)\n", 118 | " # a = 
np.transpose(a, (1, 0, 2)).reshape(n_nodes, n_nodes*n_edge_types*2)\n", 119 | " return a\n", 120 | "\n", 121 | "def get_mel(filename):\n", 122 | " wav, sr = librosa.load(filename, sr=22050)\n", 123 | " wav = torch.FloatTensor(wav.astype(np.float32))\n", 124 | " melspec = stft.mel_spectrogram(wav.unsqueeze(0))\n", 125 | " return melspec.squeeze(0), wav\n", 126 | "\n", 127 | "def save_file(fname):\n", 128 | " wav_name = os.path.join(root_dir, fname) + '.wav'\n", 129 | " text = metadata[fname]['char']\n", 130 | " char_seq = torch.LongTensor( text2seq(metadata[fname]['char']) )\n", 131 | " phone_seq = torch.LongTensor( text2seq(metadata[fname]['phone']) )\n", 132 | " adj_matrix = create_adjacency_matrix(char_seq)\n", 133 | " melspec, wav = get_mel(wav_name)\n", 134 | " \n", 135 | " with open(f'{data_dir}/char_seq/{fname}_sequence.pkl', 'wb') as f:\n", 136 | " pkl.dump(char_seq, f)\n", 137 | " with open(f'{data_dir}/phone_seq/{fname}_sequence.pkl', 'wb') as f:\n", 138 | " pkl.dump(phone_seq, f)\n", 139 | " with open(f'{data_dir}/adj_matrix/{fname}_adj_matrix.pkl', 'wb') as f:\n", 140 | " pkl.dump(adj_matrix, f)\n", 141 | " with open(f'{data_dir}/melspectrogram/{fname}_melspectrogram.pkl', 'wb') as f:\n", 142 | " pkl.dump(melspec, f)\n", 143 | " \n", 144 | " return text, char_seq, phone_seq, adj_matrix, melspec, wav" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "### Save and Inspect Data" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": { 158 | "ExecuteTime": { 159 | "start_time": "2019-12-18T06:01:23.402Z" 160 | }, 161 | "code_folding": [], 162 | "scrolled": true 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "for k in tqdm(metadata.keys()):\n", 167 | " text, char_seq, phone_seq, adj_matrix, melspec, wav = save_file(k)\n", 168 | " if k == 'LJ001-0019':\n", 169 | " print(\"Text:\")\n", 170 | " print(text)\n", 171 | " print()\n", 172 | " print(\"Melspectrogram:\")\n", 173 | " plt.figure(figsize=(16,4))\n", 174 | " plt.imshow(melspec, aspect='auto', origin='lower')\n", 175 | " plt.show()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 2, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "application/vnd.jupyter.widget-view+json": { 186 | "model_id": "34e6927c94f245669e35ed76ef0fbff0", 187 | "version_major": 2, 188 | "version_minor": 0 189 | }, 190 | "text/plain": [ 191 | "HBox(children=(FloatProgress(value=0.0, max=13100.0), HTML(value='')))" 192 | ] 193 | }, 194 | "metadata": {}, 195 | "output_type": "display_data" 196 | }, 197 | { 198 | "name": "stderr", 199 | "output_type": "stream", 200 | "text": [ 201 | "/home/lyh/anaconda3/envs/LYH/lib/python3.7/site-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n", 202 | " warnings.warn(\"pickle support for Storage will be removed in 1.5. 
Use `torch.save` instead\", FutureWarning)\n" 203 | ] 204 | }, 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "def text2seq(text):\n", 215 | " sequence=[symbol_to_id['^']]\n", 216 | " sequence.extend([symbol_to_id[c] for c in text])\n", 217 | " sequence.append(symbol_to_id['~'])\n", 218 | " return sequence\n", 219 | "\n", 220 | "def create_adjacency_matrix(char_seq):\n", 221 | " n_nodes=len(char_seq)\n", 222 | " n_edge_types=3\n", 223 | " \n", 224 | " a = np.zeros([n_edge_types, n_nodes, n_nodes], dtype=np.int)\n", 225 | " \n", 226 | " a[0] = np.eye(n_nodes, k=1, dtype=int) + np.eye(n_nodes, k=-1, dtype=int)\n", 227 | " a[2] = 1\n", 228 | " \n", 229 | " white_spaces = (char_seq==symbol_to_id[' ']).nonzero()\n", 230 | " start = torch.cat([torch.tensor([0]), white_spaces[1:].view(-1)])\n", 231 | " end = torch.cat([white_spaces[1:].view(-1), torch.tensor([n_nodes])])\n", 232 | " for i in range(len(start)):\n", 233 | " a[1, start[i]:end[i], start[i]:end[i]]=1\n", 234 | " \n", 235 | " # a = np.concatenate((a, a), axis=0)\n", 236 | " # a = np.transpose(a, (1, 0, 2)).reshape(n_nodes, n_nodes*n_edge_types*2)\n", 237 | " return a\n", 238 | "\n", 239 | "def save_file(fname):\n", 240 | " wav_name = os.path.join(root_dir, fname) + '.wav'\n", 241 | " text = metadata[fname]['char']\n", 242 | " char_seq = torch.LongTensor( text2seq(metadata[fname]['char']) )\n", 243 | " adj_matrix = torch.LongTensor(create_adjacency_matrix(char_seq))\n", 244 | " \n", 245 | " with open(f'{data_dir}/adj_matrix/{fname}_adj_matrix.pkl', 'wb') as f:\n", 246 | " pkl.dump(adj_matrix, f)\n", 247 | "\n", 248 | "\n", 249 | "for k in tqdm(metadata.keys()):\n", 250 | " save_file(k)" 251 | ] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "Python [conda env:LYH] *", 257 | "language": "python", 258 | "name": "conda-env-LYH-py" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.7.5" 271 | }, 272 | "varInspector": { 273 | "cols": { 274 | "lenName": 16, 275 | "lenType": 16, 276 | "lenVar": 40 277 | }, 278 | "kernels_config": { 279 | "python": { 280 | "delete_cmd_postfix": "", 281 | "delete_cmd_prefix": "del ", 282 | "library": "var_list.py", 283 | "varRefreshCmd": "print(var_dic_list())" 284 | }, 285 | "r": { 286 | "delete_cmd_postfix": ") ", 287 | "delete_cmd_prefix": "rm(", 288 | "library": "var_list.r", 289 | "varRefreshCmd": "cat(var_dic_list()) " 290 | } 291 | }, 292 | "types_to_exclude": [ 293 | "module", 294 | "function", 295 | "builtin_function_or_method", 296 | "instance", 297 | "_Feature" 298 | ], 299 | "window_display": false 300 | } 301 | }, 302 | "nbformat": 4, 303 | "nbformat_minor": 2 304 | } 305 | -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 
6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | """ 32 | 33 | import torch 34 | import numpy as np 35 | import torch.nn.functional as F 36 | from torch.autograd import Variable 37 | from scipy.signal import get_window 38 | from librosa.util import pad_center, tiny 39 | from audio_processing import window_sumsquare 40 | 41 | 42 | class STFT(torch.nn.Module): 43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 44 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 45 | window='hann'): 46 | super(STFT, self).__init__() 47 | self.filter_length = filter_length 48 | self.hop_length = hop_length 49 | self.win_length = win_length 50 | self.window = window 51 | self.forward_transform = None 52 | scale = self.filter_length / self.hop_length 53 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 54 | 55 | cutoff = int((self.filter_length / 2 + 1)) 56 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 57 | np.imag(fourier_basis[:cutoff, :])]) 58 | 59 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 60 | inverse_basis = torch.FloatTensor( 61 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 62 | 63 | if window is not None: 64 | assert(filter_length >= win_length) 65 | # get window and zero center pad it to filter_length 66 | fft_window = get_window(window, win_length, fftbins=True) 67 | fft_window = pad_center(fft_window, filter_length) 68 | fft_window = torch.from_numpy(fft_window).float() 69 | 70 | # window the bases 71 | forward_basis *= fft_window 72 | inverse_basis *= fft_window 73 | 74 | self.register_buffer('forward_basis', forward_basis.float()) 75 | self.register_buffer('inverse_basis', inverse_basis.float()) 76 | 77 | def transform(self, input_data): 78 | num_batches = input_data.size(0) 79 | num_samples = input_data.size(1) 80 | 81 | self.num_samples = num_samples 82 | 83 | # similar to librosa, reflect-pad the input 84 | input_data = input_data.view(num_batches, 1, num_samples) 85 | input_data = F.pad( 86 | input_data.unsqueeze(1), 
87 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 88 | mode='reflect') 89 | input_data = input_data.squeeze(1) 90 | 91 | forward_transform = F.conv1d( 92 | input_data, 93 | Variable(self.forward_basis, requires_grad=False), 94 | stride=self.hop_length, 95 | padding=0) 96 | 97 | cutoff = int((self.filter_length / 2) + 1) 98 | real_part = forward_transform[:, :cutoff, :] 99 | imag_part = forward_transform[:, cutoff:, :] 100 | 101 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 102 | phase = torch.autograd.Variable( 103 | torch.atan2(imag_part.data, real_part.data)) 104 | 105 | return magnitude, phase 106 | 107 | def inverse(self, magnitude, phase): 108 | recombine_magnitude_phase = torch.cat( 109 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 110 | 111 | inverse_transform = F.conv_transpose1d( 112 | recombine_magnitude_phase, 113 | Variable(self.inverse_basis, requires_grad=False), 114 | stride=self.hop_length, 115 | padding=0) 116 | 117 | if self.window is not None: 118 | window_sum = window_sumsquare( 119 | self.window, magnitude.size(-1), hop_length=self.hop_length, 120 | win_length=self.win_length, n_fft=self.filter_length, 121 | dtype=np.float32) 122 | # remove modulation effects 123 | approx_nonzero_indices = torch.from_numpy( 124 | np.where(window_sum > tiny(window_sum))[0]) 125 | window_sum = torch.autograd.Variable( 126 | torch.from_numpy(window_sum), requires_grad=False) 127 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 128 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 129 | 130 | # scale by hop ratio 131 | inverse_transform *= float(self.filter_length) / self.hop_length 132 | 133 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 134 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 135 | 136 | return inverse_transform 137 | 138 | def forward(self, input_data): 139 | self.magnitude, self.phase = self.transform(input_data) 140 | reconstruction = self.inverse(self.magnitude, self.phase) 141 | return reconstruction 142 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | # -*- coding: utf-8 -*- 3 | 4 | import re 5 | from text import cleaners 6 | from text.symbols import symbols 7 | 8 | # Mappings from symbol to numeric ID and vice versa: 9 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 10 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 11 | 12 | # Regular expression matching text enclosed in curly braces: 13 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 14 | 15 | 16 | def text_to_sequence(text, cleaner_names): 17 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 18 | 19 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 20 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 21 | 22 | Args: 23 | text: string to convert to a sequence 24 | cleaner_names: names of the cleaner functions to run the text through 25 | 26 | Returns: 27 | List of integers corresponding to the symbols in the text 28 | ''' 29 | sequence = [_symbol_to_id['^']] 30 | 31 | # Check for curly braces and treat their contents as ARPAbet: 32 | while len(text): 33 | m = _curly_re.match(text) 34 | if not m: 35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 36 | break 37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 38 | sequence += _arpabet_to_sequence(m.group(2)) 39 | text = m.group(3) 40 | 41 | # Append EOS token 42 | sequence.append(_symbol_to_id['~']) 43 | return sequence 44 | 45 | 46 | def sequence_to_text(sequence): 47 | '''Converts a sequence of IDs back to a string''' 48 | result = '' 49 | for symbol_id in sequence: 50 | if symbol_id in _id_to_symbol: 51 | s = _id_to_symbol[symbol_id] 52 | # Enclose ARPAbet back in curly braces: 53 | if len(s) > 1 and s[0] == '@': 54 | s = '{%s}' % s[1:] 55 | result += s 56 | return result.replace('}{', ' ') 57 | 58 | 59 | def _clean_text(text, cleaner_names): 60 | for name in cleaner_names: 61 | cleaner = getattr(cleaners, name) 62 | if not cleaner: 63 | raise Exception('Unknown cleaner: %s' % name) 64 | text = cleaner(text) 65 | return text 66 | 67 | 68 | def _symbols_to_sequence(symbols): 69 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 70 | 71 | 72 | def _arpabet_to_sequence(text): 73 | return _symbols_to_sequence(['@' + s for s in text.split()]) 74 | 75 | 76 | def _should_keep_symbol(s): 77 | return s in _symbol_to_id and s is not '_' and s is not '~' 78 | -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/cleaners.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/cleaners.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/cmudict.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/cmudict.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/numbers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/numbers.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/symbols.cpython-37.pyc -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """This file is derived from https://github.com/keithito/tacotron. 2 | 3 | Cleaners are transformations that run over the input text at both training and eval time. 4 | 5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 7 | 1. "english_cleaners" for English text 8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 11 | the symbols in symbols.py to match your data). 12 | """ 13 | 14 | import re 15 | 16 | from unidecode import unidecode 17 | 18 | from text.numbers import normalize_numbers 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including number and abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_numbers(text) 88 | text = expand_abbreviations(text) 89 | text = collapse_whitespace(text) 90 | return text 91 | 92 | 93 | # NOTE (kan-bayashi): Following functions additionally defined, not inclueded in original codes. 94 | def remove_unnecessary_symbols(text): 95 | # added 96 | text = re.sub(r'[\(\)\[\]\<\>\"]+', '', text) 97 | return text 98 | 99 | 100 | def expand_symbols(text): 101 | # added 102 | text = re.sub("\;", ",", text) 103 | text = re.sub("\:", ",", text) 104 | text = re.sub("\-", " ", text) 105 | text = re.sub("\&", "and", text) 106 | return text 107 | 108 | 109 | def uppercase(text): 110 | # added 111 | return text.upper() 112 | 113 | 114 | def custom_english_cleaners(text): 115 | '''Custom pipeline for English text, including number and abbreviation expansion.''' 116 | text = convert_to_ascii(text) 117 | text = lowercase(text) 118 | text = expand_numbers(text) 119 | text = expand_abbreviations(text) 120 | text = expand_symbols(text) 121 | text = remove_unnecessary_symbols(text) 122 | text = uppercase(text) 123 | text = collapse_whitespace(text) 124 | 125 | # There is an exception (I found it!) 
126 | # "'NOW FOR YOU, MY POOR FELLOW MORTALS, WHO ARE ABOUT TO SUFFER THE LAST PENALTY OF THE LAW.'" 127 | if text[0]=="'" and text[-1]=="'": 128 | text = text[1:-1] 129 | 130 | return text 131 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ from https://github.com/keithito/tacotron """ 3 | 4 | import inflect 5 | import re 6 | 7 | 8 | _inflect = inflect.engine() 9 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 10 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 11 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 12 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 13 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 14 | _number_re = re.compile(r'[0-9]+') 15 | 16 | 17 | def _remove_commas(m): 18 | return m.group(1).replace(',', '') 19 | 20 | 21 | def _expand_decimal_point(m): 22 | return m.group(1).replace('.', ' point ') 23 | 24 | 25 | def _expand_dollars(m): 26 | match = m.group(1) 27 | parts = match.split('.') 28 | if len(parts) > 2: 29 | return match + ' dollars' # Unexpected format 30 | dollars = int(parts[0]) if parts[0] else 0 31 | cents = 
int(parts[1]) if len(parts) > 1 and parts[1] else 0 32 | if dollars and cents: 33 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 34 | cent_unit = 'cent' if cents == 1 else 'cents' 35 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 36 | elif dollars: 37 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 38 | return '%s %s' % (dollars, dollar_unit) 39 | elif cents: 40 | cent_unit = 'cent' if cents == 1 else 'cents' 41 | return '%s %s' % (cents, cent_unit) 42 | else: 43 | return 'zero dollars' 44 | 45 | 46 | def _expand_ordinal(m): 47 | return _inflect.number_to_words(m.group(0)) 48 | 49 | 50 | def _expand_number(m): 51 | num = int(m.group(0)) 52 | if num > 1000 and num < 3000: 53 | if num == 2000: 54 | return 'two thousand' 55 | elif num > 2000 and num < 2010: 56 | return 'two thousand ' + _inflect.number_to_words(num % 100) 57 | elif num % 100 == 0: 58 | return _inflect.number_to_words(num // 100) + ' hundred' 59 | else: 60 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 61 | else: 62 | return _inflect.number_to_words(num, andword='') 63 | 64 | 65 | def normalize_numbers(text): 66 | text = re.sub(_comma_number_re, _remove_commas, text) 67 | text = re.sub(_pounds_re, r'\1 pounds', text) 68 | text = re.sub(_dollars_re, _expand_dollars, text) 69 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 70 | text = re.sub(_ordinal_re, _expand_ordinal, text) 71 | text = re.sub(_number_re, _expand_number, text) 72 | return text 73 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _pad = '_' 10 | _sos = '^' 11 | _eos = '~' 12 | _punctuations = " ,.'?!" 
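# (Added note, not part of the original file.) With the definitions above and below,
# the exported `symbols` list at the bottom of this file works out to
#   ['_', '^', '~', ' ', ',', '.', "'", '?', '!', 'A', ..., 'Z', '@AA', '@AA0', ...]
# so id 0 is the padding token, 1 is the start token '^', 2 is the end token '~',
# ids 3-8 are punctuation, ids 9-34 are A-Z, and the remaining ids are '@'-prefixed
# ARPAbet phones. These are the ids produced by text2seq in prepare_data.ipynb and
# by text_to_sequence in text/__init__.py.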
13 | _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 14 | 15 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 16 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 17 | 18 | # Export all symbols: 19 | symbols = [_pad, _sos, _eos] + list(_punctuations) + list(_characters) + _arpabet -------------------------------------------------------------------------------- /train-gae.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modules.model import GAE 6 | from modules.loss import TransformerLoss 7 | import hparams 8 | from text import * 9 | from utils.utils import * 10 | from utils.writer import get_writer 11 | 12 | 13 | def validate(model, criterion, val_loader, iteration, writer): 14 | model.eval() 15 | with torch.no_grad(): 16 | n_data, val_loss = 0, 0 17 | for i, batch in enumerate(val_loader): 18 | n_data += len(batch[0]) 19 | text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [ 20 | x.cuda() for x in batch 21 | ] 22 | 23 | mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out = model.module.outputs(text_padded, 24 | adj_padded, 25 | mel_padded, 26 | text_lengths, 27 | mel_lengths) 28 | 29 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out), 30 | (mel_padded, gate_padded), 31 | (enc_dec_alignments, text_lengths, mel_lengths)) 32 | 33 | loss = torch.mean(mel_loss+bce_loss+guide_loss) 34 | val_loss += loss.item() * len(batch[0]) 35 | 36 | val_loss /= n_data 37 | 38 | writer.add_losses(mel_loss.item(), 39 | bce_loss.item(), 40 | guide_loss.item(), 41 | iteration//hparams.accumulation, 'Validation') 42 | 43 | writer.add_specs(mel_padded.detach().cpu(), 44 | mel_out.detach().cpu(), 45 | mel_out_post.detach().cpu(), 46 | mel_lengths.detach().cpu(), 47 | iteration//hparams.accumulation, 'Validation') 48 | 49 | enc_alignments=dec_alignments.new_zeros(enc_dec_alignments.size(0), 50 | enc_dec_alignments.size(1), 51 | enc_dec_alignments.size(2), 52 | enc_dec_alignments.size(4), 53 | enc_dec_alignments.size(4)) 54 | writer.add_alignments(enc_alignments.detach().cpu(), 55 | dec_alignments.detach().cpu(), 56 | enc_dec_alignments.detach().cpu(), 57 | text_padded.detach().cpu(), 58 | mel_lengths.detach().cpu(), 59 | text_lengths.detach().cpu(), 60 | iteration//hparams.accumulation, 'Validation') 61 | 62 | writer.add_gates(gate_out.detach().cpu(), 63 | iteration//hparams.accumulation, 'Validation') 64 | model.train() 65 | 66 | 67 | 68 | def main(): 69 | train_loader, val_loader, collate_fn = prepare_dataloaders(hparams) 70 | model = nn.DataParallel(GAE(hparams)).cuda() 71 | optimizer = torch.optim.Adam(model.parameters(), 72 | lr=hparams.lr, 73 | betas=(0.9, 0.98), 74 | eps=1e-09) 75 | criterion = TransformerLoss() 76 | writer = get_writer(hparams.output_directory, hparams.log_directory) 77 | 78 | iteration, loss = 0, 0 79 | model.train() 80 | print("Training Start!!!") 81 | while iteration < (hparams.train_steps*hparams.accumulation): 82 | for i, batch in enumerate(train_loader): 83 | text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [ 84 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch 85 | ] 86 | 87 | mel_loss, bce_loss, guide_loss = model(text_padded, 88 | adj_padded, 89 | mel_padded, 90 | gate_padded, 91 | text_lengths, 92 | mel_lengths, 93 | criterion) 94 | 95 | mel_loss, bce_loss, guide_loss=[ 96 | 
torch.mean(x) for x in [mel_loss, bce_loss, guide_loss] 97 | ] 98 | sub_loss = (mel_loss+bce_loss+guide_loss)/hparams.accumulation 99 | sub_loss.backward() 100 | loss = loss+sub_loss.item() 101 | 102 | iteration += 1 103 | if iteration%hparams.accumulation == 0: 104 | lr_scheduling(optimizer, iteration//hparams.accumulation) 105 | nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh) 106 | optimizer.step() 107 | model.zero_grad() 108 | writer.add_losses(mel_loss.item(), 109 | bce_loss.item(), 110 | guide_loss.item(), 111 | iteration//hparams.accumulation, 'Train') 112 | loss=0 113 | 114 | 115 | if iteration%(hparams.iters_per_validation*hparams.accumulation)==0: 116 | validate(model, criterion, val_loader, iteration, writer) 117 | 118 | if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0: 119 | save_checkpoint(model, 120 | optimizer, 121 | hparams.lr, 122 | iteration//hparams.accumulation, 123 | filepath=f'{hparams.output_directory}/{hparams.log_directory}') 124 | 125 | if iteration==(hparams.train_steps*hparams.accumulation): 126 | break 127 | 128 | 129 | if __name__ == '__main__': 130 | p = argparse.ArgumentParser() 131 | p.add_argument('--gpu', type=str, default='0,1') 132 | p.add_argument('-v', '--verbose', type=str, default='0') 133 | args = p.parse_args() 134 | hparams.log_directory = 'gae-char' 135 | hparams.iterations = 1 136 | 137 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 138 | torch.manual_seed(hparams.seed) 139 | torch.cuda.manual_seed(hparams.seed) 140 | 141 | if args.verbose=='0': 142 | import warnings 143 | warnings.filterwarnings("ignore") 144 | 145 | main() -------------------------------------------------------------------------------- /train-graphtts.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modules.model import GraphTTS 6 | from modules.loss import TransformerLoss 7 | import hparams 8 | from text import * 9 | from utils.utils import * 10 | from utils.writer import get_writer 11 | 12 | 13 | def validate(model, criterion, val_loader, iteration, writer): 14 | model.eval() 15 | with torch.no_grad(): 16 | n_data, val_loss = 0, 0 17 | for i, batch in enumerate(val_loader): 18 | n_data += len(batch[0]) 19 | text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [ 20 | x.cuda() for x in batch 21 | ] 22 | 23 | mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out = model.module.outputs(text_padded, 24 | adj_padded, 25 | mel_padded, 26 | text_lengths, 27 | mel_lengths) 28 | 29 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out), 30 | (mel_padded, gate_padded), 31 | (enc_dec_alignments, text_lengths, mel_lengths)) 32 | 33 | loss = torch.mean(mel_loss+bce_loss+guide_loss) 34 | val_loss += loss.item() * len(batch[0]) 35 | 36 | val_loss /= n_data 37 | 38 | writer.add_losses(mel_loss.item(), 39 | bce_loss.item(), 40 | guide_loss.item(), 41 | iteration//hparams.accumulation, 'Validation') 42 | 43 | writer.add_specs(mel_padded.detach().cpu(), 44 | mel_out.detach().cpu(), 45 | mel_out_post.detach().cpu(), 46 | mel_lengths.detach().cpu(), 47 | iteration//hparams.accumulation, 'Validation') 48 | 49 | enc_alignments=dec_alignments.new_zeros(enc_dec_alignments.size(0), 50 | enc_dec_alignments.size(1), 51 | enc_dec_alignments.size(2), 52 | enc_dec_alignments.size(4), 53 | enc_dec_alignments.size(4)) 54 | 
writer.add_alignments(enc_alignments.detach().cpu(), 55 | dec_alignments.detach().cpu(), 56 | enc_dec_alignments.detach().cpu(), 57 | text_padded.detach().cpu(), 58 | mel_lengths.detach().cpu(), 59 | text_lengths.detach().cpu(), 60 | iteration//hparams.accumulation, 'Validation') 61 | 62 | writer.add_gates(gate_out.detach().cpu(), 63 | iteration//hparams.accumulation, 'Validation') 64 | model.train() 65 | 66 | 67 | 68 | def main(): 69 | train_loader, val_loader, collate_fn = prepare_dataloaders(hparams) 70 | model = nn.DataParallel(GraphTTS(hparams)).cuda() 71 | optimizer = torch.optim.Adam(model.parameters(), 72 | lr=hparams.lr, 73 | betas=(0.9, 0.98), 74 | eps=1e-09) 75 | criterion = TransformerLoss() 76 | writer = get_writer(hparams.output_directory, hparams.log_directory) 77 | 78 | iteration, loss = 0, 0 79 | model.train() 80 | print("Training Start!!!") 81 | while iteration < (hparams.train_steps*hparams.accumulation): 82 | for i, batch in enumerate(train_loader): 83 | text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [ 84 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch 85 | ] 86 | 87 | mel_loss, bce_loss, guide_loss = model(text_padded, 88 | adj_padded, 89 | mel_padded, 90 | gate_padded, 91 | text_lengths, 92 | mel_lengths, 93 | criterion) 94 | 95 | mel_loss, bce_loss, guide_loss=[ 96 | torch.mean(x) for x in [mel_loss, bce_loss, guide_loss] 97 | ] 98 | sub_loss = (mel_loss+bce_loss+guide_loss)/hparams.accumulation 99 | sub_loss.backward() 100 | loss = loss+sub_loss.item() 101 | 102 | iteration += 1 103 | if iteration%hparams.accumulation == 0: 104 | lr_scheduling(optimizer, iteration//hparams.accumulation) 105 | nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh) 106 | optimizer.step() 107 | model.zero_grad() 108 | writer.add_losses(mel_loss.item(), 109 | bce_loss.item(), 110 | guide_loss.item(), 111 | iteration//hparams.accumulation, 'Train') 112 | loss=0 113 | 114 | 115 | if iteration%(hparams.iters_per_validation*hparams.accumulation)==0: 116 | validate(model, criterion, val_loader, iteration, writer) 117 | 118 | if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0: 119 | save_checkpoint(model, 120 | optimizer, 121 | hparams.lr, 122 | iteration//hparams.accumulation, 123 | filepath=f'{hparams.output_directory}/{hparams.log_directory}') 124 | 125 | if iteration==(hparams.train_steps*hparams.accumulation): 126 | break 127 | 128 | 129 | if __name__ == '__main__': 130 | p = argparse.ArgumentParser() 131 | p.add_argument('--gpu', type=str, default='0,1') 132 | p.add_argument('-v', '--verbose', type=str, default='0') 133 | args = p.parse_args() 134 | hparams.log_directory = 'graph-tts-char_iter5' 135 | hparams.iterations = 5 136 | 137 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 138 | torch.manual_seed(hparams.seed) 139 | torch.cuda.manual_seed(hparams.seed) 140 | 141 | if args.verbose=='0': 142 | import warnings 143 | warnings.filterwarnings("ignore") 144 | 145 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modules.model import Model 6 | from modules.loss import TransformerLoss 7 | import hparams 8 | from text import * 9 | from utils.utils import * 10 | from utils.writer import get_writer 11 | 12 | 13 | def validate(model, criterion, val_loader, iteration, writer): 14 | 
model.eval() 15 | with torch.no_grad(): 16 | n_data, val_loss = 0, 0 17 | for i, batch in enumerate(val_loader): 18 | n_data += len(batch[0]) 19 | text_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [ 20 | x.cuda() for x in batch 21 | ] 22 | 23 | mel_out, mel_out_post,\ 24 | enc_alignments, dec_alignments, enc_dec_alignments, gate_out = model.module.outputs(text_padded, 25 | mel_padded, 26 | text_lengths, 27 | mel_lengths) 28 | 29 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out), 30 | (mel_padded, gate_padded), 31 | (enc_dec_alignments, text_lengths, mel_lengths)) 32 | 33 | loss = torch.mean(mel_loss+bce_loss+guide_loss) 34 | val_loss += loss.item() * len(batch[0]) 35 | 36 | val_loss /= n_data 37 | 38 | writer.add_losses(mel_loss.item(), 39 | bce_loss.item(), 40 | guide_loss.item(), 41 | iteration//hparams.accumulation, 'Validation') 42 | 43 | writer.add_specs(mel_padded.detach().cpu(), 44 | mel_out.detach().cpu(), 45 | mel_out_post.detach().cpu(), 46 | mel_lengths.detach().cpu(), 47 | iteration//hparams.accumulation, 'Validation') 48 | 49 | writer.add_alignments(enc_alignments.detach().cpu(), 50 | dec_alignments.detach().cpu(), 51 | enc_dec_alignments.detach().cpu(), 52 | text_padded.detach().cpu(), 53 | mel_lengths.detach().cpu(), 54 | text_lengths.detach().cpu(), 55 | iteration//hparams.accumulation, 'Validation') 56 | 57 | writer.add_gates(gate_out.detach().cpu(), 58 | iteration//hparams.accumulation, 'Validation') 59 | model.train() 60 | 61 | 62 | 63 | def main(): 64 | train_loader, val_loader, collate_fn = prepare_dataloaders(hparams) 65 | model = nn.DataParallel(Model(hparams)).cuda() 66 | optimizer = torch.optim.Adam(model.parameters(), 67 | lr=hparams.lr, 68 | betas=(0.9, 0.98), 69 | eps=1e-09) 70 | criterion = TransformerLoss() 71 | writer = get_writer(hparams.output_directory, hparams.log_directory) 72 | 73 | iteration, loss = 0, 0 74 | model.train() 75 | print("Training Start!!!") 76 | while iteration < (hparams.train_steps*hparams.accumulation): 77 | for i, batch in enumerate(train_loader): 78 | text_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [ 79 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch 80 | ] 81 | 82 | mel_loss, bce_loss, guide_loss = model(text_padded, 83 | mel_padded, 84 | gate_padded, 85 | text_lengths, 86 | mel_lengths, 87 | criterion) 88 | 89 | mel_loss, bce_loss, guide_loss=[ 90 | torch.mean(x) for x in [mel_loss, bce_loss, guide_loss] 91 | ] 92 | sub_loss = (mel_loss+bce_loss+guide_loss)/hparams.accumulation 93 | sub_loss.backward() 94 | loss = loss+sub_loss.item() 95 | 96 | iteration += 1 97 | if iteration%hparams.accumulation == 0: 98 | lr_scheduling(optimizer, iteration//hparams.accumulation) 99 | nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh) 100 | optimizer.step() 101 | model.zero_grad() 102 | writer.add_losses(mel_loss.item(), 103 | bce_loss.item(), 104 | guide_loss.item(), 105 | iteration//hparams.accumulation, 'Train') 106 | loss=0 107 | 108 | 109 | if iteration%(hparams.iters_per_validation*hparams.accumulation)==0: 110 | validate(model, criterion, val_loader, iteration, writer) 111 | 112 | if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0: 113 | save_checkpoint(model, 114 | optimizer, 115 | hparams.lr, 116 | iteration//hparams.accumulation, 117 | filepath=f'{hparams.output_directory}/{hparams.log_directory}') 118 | 119 | if iteration==(hparams.train_steps*hparams.accumulation): 120 | break 121 | 122 | 123 | if __name__ == '__main__': 124 | 
p = argparse.ArgumentParser() 125 | p.add_argument('--gpu', type=str, default='0,1') 126 | p.add_argument('-v', '--verbose', type=str, default='0') 127 | args = p.parse_args() 128 | 129 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 130 | torch.manual_seed(hparams.seed) 131 | torch.cuda.manual_seed(hparams.seed) 132 | 133 | if args.verbose=='0': 134 | import warnings 135 | warnings.filterwarnings("ignore") 136 | 137 | main() -------------------------------------------------------------------------------- /utils/__pycache__/data_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/utils/__pycache__/data_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/utils/__pycache__/plot_image.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/writer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/utils/__pycache__/writer.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import hparams 4 | import torch 5 | import torch.utils.data 6 | import torch.nn.functional as F 7 | import os 8 | import pickle as pkl 9 | 10 | from text import text_to_sequence 11 | 12 | 13 | def load_filepaths_and_text(metadata, split="|"): 14 | with open(metadata, encoding='utf-8') as f: 15 | filepaths_and_text = [line.strip().split(split) for line in f] 16 | return filepaths_and_text 17 | 18 | 19 | class TextMelSet(torch.utils.data.Dataset): 20 | def __init__(self, audiopaths_and_text, hparams): 21 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 22 | random.seed(1234) 23 | random.shuffle(self.audiopaths_and_text) 24 | self.data_type=hparams.data_type 25 | 26 | def get_mel_text_pair(self, audiopath_and_text): 27 | # separate filename and text 28 | file_name = audiopath_and_text[0][:10] 29 | seq_path = os.path.join(hparams.data_path, self.data_type) 30 | adj_path = os.path.join(hparams.data_path, 'adj_matrix') 31 | mel_path = os.path.join(hparams.data_path, 'melspectrogram') 32 | 33 | with open(f'{seq_path}/{file_name}_sequence.pkl', 'rb') as f: 34 | text = pkl.load(f) 35 | with open(f'{adj_path}/{file_name}_adj_matrix.pkl', 'rb') as f: 36 | adj = pkl.load(f) 37 | with open(f'{mel_path}/{file_name}_melspectrogram.pkl', 'rb') as f: 38 | mel = pkl.load(f) 39 | 40 | return (text, adj, mel) 41 | 42 | def __getitem__(self, index): 43 | return self.get_mel_text_pair(self.audiopaths_and_text[index]) 44 | 45 | def __len__(self): 46 | return len(self.audiopaths_and_text) 47 | 
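# --- Added illustration (not part of the original file) ----------------------------
# A minimal sketch of what one TextMelSet item looks like, assuming the pickles
# written by prepare_data.ipynb exist under hparams.data_path and hparams.data_type
# points at 'char_seq' (or 'phone_seq'). Guarded so it only runs when this module is
# executed directly:
if __name__ == '__main__':
    dataset = TextMelSet(hparams.training_files, hparams)
    text, adj, mel = dataset[0]
    print(text.shape)  # (L,)         symbol ids, with '^'/'~' as start/end tokens
    print(adj.shape)   # (3, L, L)    one slice per edge type: neighbouring chars,
                       #              chars inside the same word, fully connected
    print(mel.shape)   # (n_mels, T)  mel-spectrogram frames
# ------------------------------------------------------------------------------------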
48 | 49 | class TextMelCollate(): 50 | def __init__(self): 51 | return 52 | 53 | def __call__(self, batch): 54 | # Right zero-pad all one-hot text sequences to max input length 55 | input_lengths, ids_sorted_decreasing = torch.sort( 56 | torch.LongTensor([len(x[0]) for x in batch]), 57 | dim=0, descending=True) 58 | max_input_len = input_lengths[0] 59 | 60 | text_padded = torch.zeros(len(batch), max_input_len, dtype=torch.long) 61 | adj_padded = torch.zeros(len(batch), 3, max_input_len, max_input_len, dtype=torch.long) 62 | for i in range(len(ids_sorted_decreasing)): 63 | text = batch[ids_sorted_decreasing[i]][0] 64 | adj = batch[ids_sorted_decreasing[i]][1] 65 | text_padded[i, :text.size(0)] = text 66 | adj_padded[i, :, :adj.size(1), :adj.size(1)] = adj 67 | 68 | # Right zero-pad 69 | num_mels = batch[0][2].size(0) 70 | max_target_len = max([x[2].size(1) for x in batch]) 71 | 72 | # include Spec padded and gate padded 73 | mel_padded = torch.zeros(len(batch), num_mels, max_target_len) 74 | gate_padded = torch.zeros(len(batch), max_target_len) 75 | output_lengths = torch.LongTensor(len(batch)) 76 | for i in range(len(ids_sorted_decreasing)): 77 | mel = batch[ids_sorted_decreasing[i]][2] 78 | mel_padded[i, :, :mel.size(1)] = mel 79 | gate_padded[i, mel.size(1)-1:] = 1 80 | output_lengths[i] = mel.size(1) 81 | 82 | return text_padded, adj_padded, input_lengths, mel_padded, output_lengths, gate_padded -------------------------------------------------------------------------------- /utils/plot_image.py: -------------------------------------------------------------------------------- 1 | from text import * 2 | import torch 3 | import hparams 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def plot_melspec(target, melspec, melspec_post, mel_lengths): 8 | fig, axes = plt.subplots(3, 1, figsize=(20,30)) 9 | T = mel_lengths[-1] 10 | 11 | axes[0].imshow(target[-1][:,:T], 12 | origin='lower', 13 | aspect='auto') 14 | 15 | axes[1].imshow(melspec[-1][:,:T], 16 | origin='lower', 17 | aspect='auto') 18 | 19 | axes[2].imshow(melspec_post[-1][:,:T], 20 | origin='lower', 21 | aspect='auto') 22 | 23 | return fig 24 | 25 | 26 | def plot_alignments(alignments, text, mel_lengths, text_lengths, att_type): 27 | fig, axes = plt.subplots(hparams.n_layers, hparams.n_heads, figsize=(5*hparams.n_heads,5*hparams.n_layers)) 28 | L, T = text_lengths[-1], mel_lengths[-1] 29 | n_layers, n_heads = alignments.size(1), alignments.size(2) 30 | 31 | for layer in range(n_layers): 32 | for head in range(n_heads): 33 | if att_type=='enc': 34 | align = alignments[-1, layer, head].contiguous() 35 | axes[layer,head].imshow(align[:L, :L], aspect='auto') 36 | axes[layer,head].xaxis.tick_top() 37 | 38 | elif att_type=='dec': 39 | align = alignments[-1, layer, head].contiguous() 40 | axes[layer,head].imshow(align[:T, :T], aspect='auto') 41 | axes[layer,head].xaxis.tick_top() 42 | 43 | elif att_type=='enc_dec': 44 | align = alignments[-1, layer, head].transpose(0,1).contiguous() 45 | axes[layer,head].imshow(align[:L, :T], origin='lower', aspect='auto') 46 | 47 | return fig 48 | 49 | 50 | def plot_gate(gate_out): 51 | fig = plt.figure(figsize=(10,5)) 52 | plt.plot(torch.sigmoid(gate_out[-1])) 53 | return fig -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import hparams 2 | from torch.utils.data import DataLoader 3 | from .data_utils import TextMelSet, TextMelCollate 4 | import torch 5 | from text import * 
6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def prepare_dataloaders(hparams): 10 | # Get data, data loaders and collate function ready 11 | trainset = TextMelSet(hparams.training_files, hparams) 12 | valset = TextMelSet(hparams.validation_files, hparams) 13 | collate_fn = TextMelCollate() 14 | 15 | train_loader = DataLoader(trainset, 16 | num_workers=hparams.n_gpus-1, 17 | shuffle=True, 18 | batch_size=hparams.batch_size, 19 | drop_last=True, 20 | collate_fn=collate_fn) 21 | 22 | val_loader = DataLoader(valset, 23 | batch_size=hparams.batch_size//hparams.n_gpus, 24 | collate_fn=collate_fn) 25 | 26 | return train_loader, val_loader, collate_fn 27 | 28 | 29 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): 30 | print(f"Saving model and optimizer state at iteration {iteration} to {filepath}") 31 | torch.save({'iteration': iteration, 32 | 'state_dict': model.state_dict(), 33 | 'optimizer': optimizer.state_dict(), 34 | 'learning_rate': learning_rate}, f'{filepath}/checkpoint_{iteration}') 35 | 36 | 37 | def lr_scheduling(opt, step, init_lr=hparams.lr, warmup_steps=hparams.warmup_steps): 38 | opt.param_groups[0]['lr'] = init_lr * min(step ** -0.5, step * warmup_steps ** -1.5) 39 | return 40 | 41 | 42 | def get_mask_from_lengths(lengths): 43 | max_len = torch.max(lengths).item() 44 | ids = lengths.new_tensor(torch.arange(0, max_len)) 45 | mask = (lengths.unsqueeze(1) <= ids).to(torch.bool) 46 | return mask 47 | 48 | 49 | def reorder_batch(x, n_gpus): 50 | assert (x.size(0)%n_gpus)==0, 'Batch size must be a multiple of the number of GPUs.' 51 | new_x = x.new_zeros(x.size()) 52 | chunk_size = x.size(0)//n_gpus 53 | 54 | for i in range(n_gpus): 55 | new_x[i::n_gpus] = x[i*chunk_size:(i+1)*chunk_size] 56 | 57 | return new_x 58 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.tensorboard import SummaryWriter 3 | from .plot_image import * 4 | 5 | def get_writer(output_directory, log_directory): 6 | logging_path=f'{output_directory}/{log_directory}' 7 | 8 | if os.path.exists(logging_path): 9 | raise Exception('The experiment already exists') 10 | else: 11 | os.mkdir(logging_path) 12 | writer = TTSWriter(logging_path) 13 | 14 | return writer 15 | 16 | 17 | class TTSWriter(SummaryWriter): 18 | def __init__(self, log_dir): 19 | super(TTSWriter, self).__init__(log_dir) 20 | 21 | def add_losses(self, mel_loss, bce_loss, guide_loss, global_step, phase): 22 | self.add_scalar(f'{phase}_mel_loss', mel_loss, global_step) 23 | self.add_scalar(f'{phase}_bce_loss', bce_loss, global_step) 24 | self.add_scalar(f'{phase}_guide_loss', guide_loss, global_step) 25 | 26 | def add_specs(self, mel_padded, mel_out, mel_out_post, mel_lengths, global_step, phase): 27 | mel_fig = plot_melspec(mel_padded, mel_out, mel_out_post, mel_lengths) 28 | self.add_figure(f'{phase}_melspec', mel_fig, global_step) 29 | 30 | def add_alignments(self, enc_alignments, dec_alignments, enc_dec_alignments, 31 | text_padded, mel_lengths, text_lengths, global_step, phase): 32 | enc_align_fig = plot_alignments(enc_alignments, text_padded, mel_lengths, text_lengths, 'enc') 33 | self.add_figure(f'{phase}_enc_alignments', enc_align_fig, global_step) 34 | 35 | dec_align_fig = plot_alignments(dec_alignments, text_padded, mel_lengths, text_lengths, 'dec') 36 | self.add_figure(f'{phase}_dec_alignments', dec_align_fig, global_step) 37 | 38 | enc_dec_align_fig = 
plot_alignments(enc_dec_alignments, text_padded, mel_lengths, text_lengths, 'enc_dec') 39 | self.add_figure(f'{phase}_enc_dec_alignments', enc_dec_align_fig, global_step) 40 | 41 | def add_gates(self, gate_out, global_step, phase): 42 | gate_fig = plot_gate(gate_out) 43 | self.add_figure(f'{phase}_gate_out', gate_fig, global_step) -------------------------------------------------------------------------------- /waveglow/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tacotron2"] 2 | path = tacotron2 3 | url = http://github.com/NVIDIA/tacotron2 4 | -------------------------------------------------------------------------------- /waveglow/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, NVIDIA Corporation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /waveglow/README.md: -------------------------------------------------------------------------------- 1 | ![WaveGlow](waveglow_logo.png "WaveGLow") 2 | 3 | ## WaveGlow: a Flow-based Generative Network for Speech Synthesis 4 | 5 | ### Ryan Prenger, Rafael Valle, and Bryan Catanzaro 6 | 7 | In our recent [paper], we propose WaveGlow: a flow-based network capable of 8 | generating high quality speech from mel-spectrograms. WaveGlow combines insights 9 | from [Glow] and [WaveNet] in order to provide fast, efficient and high-quality 10 | audio synthesis, without the need for auto-regression. WaveGlow is implemented 11 | using only a single network, trained using only a single cost function: 12 | maximizing the likelihood of the training data, which makes the training 13 | procedure simple and stable. 14 | 15 | Our [PyTorch] implementation produces audio samples at a rate of 2750 16 | kHz on an NVIDIA V100 GPU. 
Mean Opinion Scores show that it delivers audio 17 | quality as good as the best publicly available WaveNet implementation. 18 | 19 | Visit our [website] for audio samples. 20 | 21 | ## Setup 22 | 23 | 1. Clone our repo and initialize submodule 24 | 25 | ```command 26 | git clone https://github.com/NVIDIA/waveglow.git 27 | cd waveglow 28 | git submodule init 29 | git submodule update 30 | ``` 31 | 32 | 2. Install requirements `pip3 install -r requirements.txt` 33 | 34 | 3. Install [Apex] 35 | 36 | 37 | ## Generate audio with our pre-existing model 38 | 39 | 1. Download our [published model] 40 | 2. Download [mel-spectrograms] 41 | 3. Generate audio `python3 inference.py -f <(ls mel_spectrograms/*.pt) -w waveglow_256channels.pt -o . --is_fp16 -s 0.6` 42 | 43 | N.b. use `convert_model.py` to convert your older models to the current model 44 | with fused residual and skip connections. 45 | 46 | ## Train your own model 47 | 48 | 1. Download [LJ Speech Data]. In this example it's in `data/` 49 | 50 | 2. Make a list of the file names to use for training/testing 51 | 52 | ```command 53 | ls data/*.wav | tail -n+10 > train_files.txt 54 | ls data/*.wav | head -n10 > test_files.txt 55 | ``` 56 | 57 | 3. Train your WaveGlow networks 58 | 59 | ```command 60 | mkdir checkpoints 61 | python train.py -c config.json 62 | ``` 63 | 64 | For multi-GPU training replace `train.py` with `distributed.py`. Only tested with single node and NCCL. 65 | 66 | For mixed precision training set `"fp16_run": true` on `config.json`. 67 | 68 | 4. Make test set mel-spectrograms 69 | 70 | `python mel2samp.py -f test_files.txt -o . -c config.json` 71 | 72 | 5. Do inference with your network 73 | 74 | ```command 75 | ls *.pt > mel_files.txt 76 | python3 inference.py -f mel_files.txt -w checkpoints/waveglow_10000 -o . 
--is_fp16 -s 0.6 77 | ``` 78 | 79 | [//]: # (TODO) 80 | [//]: # (PROVIDE INSTRUCTIONS FOR DOWNLOADING LJS) 81 | [pytorch 1.0]: https://github.com/pytorch/pytorch#installation 82 | [website]: https://nv-adlr.github.io/WaveGlow 83 | [paper]: https://arxiv.org/abs/1811.00002 84 | [WaveNet implementation]: https://github.com/r9y9/wavenet_vocoder 85 | [Glow]: https://blog.openai.com/glow/ 86 | [WaveNet]: https://deepmind.com/blog/wavenet-generative-model-raw-audio/ 87 | [PyTorch]: http://pytorch.org 88 | [published model]: https://drive.google.com/file/d/1WsibBTsuRg_SF2Z6L6NFRTT-NjEy1oTx/view?usp=sharing 89 | [mel-spectrograms]: https://drive.google.com/file/d/1g_VXK2lpP9J25dQFhQwx7doWl_p20fXA/view?usp=sharing 90 | [LJ Speech Data]: https://keithito.com/LJ-Speech-Dataset 91 | [Apex]: https://github.com/nvidia/apex 92 | -------------------------------------------------------------------------------- /waveglow/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_config": { 3 | "fp16_run": true, 4 | "output_directory": "checkpoints", 5 | "epochs": 100000, 6 | "learning_rate": 1e-4, 7 | "sigma": 1.0, 8 | "iters_per_checkpoint": 2000, 9 | "batch_size": 12, 10 | "seed": 1234, 11 | "checkpoint_path": "", 12 | "with_tensorboard": false 13 | }, 14 | "data_config": { 15 | "training_files": "train_files.txt", 16 | "segment_length": 16000, 17 | "sampling_rate": 22050, 18 | "filter_length": 1024, 19 | "hop_length": 256, 20 | "win_length": 1024, 21 | "mel_fmin": 0.0, 22 | "mel_fmax": 8000.0 23 | }, 24 | "dist_config": { 25 | "dist_backend": "nccl", 26 | "dist_url": "tcp://localhost:54321" 27 | }, 28 | 29 | "waveglow_config": { 30 | "n_mel_channels": 80, 31 | "n_flows": 12, 32 | "n_group": 8, 33 | "n_early_every": 4, 34 | "n_early_size": 2, 35 | "WN_config": { 36 | "n_layers": 8, 37 | "n_channels": 256, 38 | "kernel_size": 3 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /waveglow/convert_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | 5 | def _check_model_old_version(model): 6 | if hasattr(model.WN[0], 'res_layers'): 7 | return True 8 | else: 9 | return False 10 | 11 | def update_model(old_model): 12 | if not _check_model_old_version(old_model): 13 | return old_model 14 | new_model = copy.deepcopy(old_model) 15 | for idx in range(0, len(new_model.WN)): 16 | wavenet = new_model.WN[idx] 17 | wavenet.res_skip_layers = torch.nn.ModuleList() 18 | n_channels = wavenet.n_channels 19 | n_layers = wavenet.n_layers 20 | for i in range(0, n_layers): 21 | if i < n_layers - 1: 22 | res_skip_channels = 2*n_channels 23 | else: 24 | res_skip_channels = n_channels 25 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 26 | skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i]) 27 | if i < n_layers - 1: 28 | res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i]) 29 | res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight])) 30 | res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias])) 31 | else: 32 | res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight) 33 | res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias) 34 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 35 | wavenet.res_skip_layers.append(res_skip_layer) 36 | del wavenet.res_layers 37 | del 
wavenet.skip_layers 38 | return new_model 39 | 40 | if __name__ == '__main__': 41 | old_model_path = sys.argv[1] 42 | new_model_path = sys.argv[2] 43 | model = torch.load(old_model_path) 44 | model['model'] = update_model(model['model']) 45 | torch.save(model, new_model_path) 46 | 47 | -------------------------------------------------------------------------------- /waveglow/denoiser.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('tacotron2') 3 | import torch 4 | from layers import STFT 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """ Removes model bias from audio produced with waveglow """ 9 | 10 | def __init__(self, waveglow, filter_length=1024, n_overlap=4, 11 | win_length=1024, mode='zeros'): 12 | super(Denoiser, self).__init__() 13 | self.stft = STFT(filter_length=filter_length, 14 | hop_length=int(filter_length/n_overlap), 15 | win_length=win_length).cuda() 16 | if mode == 'zeros': 17 | mel_input = torch.zeros( 18 | (1, 80, 88), 19 | dtype=waveglow.upsample.weight.dtype, 20 | device=waveglow.upsample.weight.device) 21 | elif mode == 'normal': 22 | mel_input = torch.randn( 23 | (1, 80, 88), 24 | dtype=waveglow.upsample.weight.dtype, 25 | device=waveglow.upsample.weight.device) 26 | else: 27 | raise Exception("Mode {} if not supported".format(mode)) 28 | 29 | with torch.no_grad(): 30 | bias_audio = waveglow.infer(mel_input, sigma=0.0).float() 31 | bias_spec, _ = self.stft.transform(bias_audio) 32 | 33 | self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None]) 34 | 35 | def forward(self, audio, strength=0.1): 36 | audio_spec, audio_angles = self.stft.transform(audio.cuda().float()) 37 | audio_spec_denoised = audio_spec - self.bias_spec * strength 38 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 39 | audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles) 40 | return audio_denoised 41 | -------------------------------------------------------------------------------- /waveglow/distributed.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import os 28 | import sys 29 | import time 30 | import subprocess 31 | import argparse 32 | 33 | import torch 34 | import torch.distributed as dist 35 | from torch.autograd import Variable 36 | 37 | def reduce_tensor(tensor, num_gpus): 38 | rt = tensor.clone() 39 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 40 | rt /= num_gpus 41 | return rt 42 | 43 | def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): 44 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 45 | print("Initializing Distributed") 46 | 47 | # Set cuda device so everything is done on the right GPU. 48 | torch.cuda.set_device(rank % torch.cuda.device_count()) 49 | 50 | # Initialize distributed communication 51 | dist.init_process_group(dist_backend, init_method=dist_url, 52 | world_size=num_gpus, rank=rank, 53 | group_name=group_name) 54 | 55 | def _flatten_dense_tensors(tensors): 56 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 57 | same dense type. 58 | Since inputs are dense, the resulting tensor will be a concatenated 1D 59 | buffer. Element-wise operation on this buffer will be equivalent to 60 | operating individually. 61 | Arguments: 62 | tensors (Iterable[Tensor]): dense tensors to flatten. 63 | Returns: 64 | A contiguous 1D buffer containing input tensors. 65 | """ 66 | if len(tensors) == 1: 67 | return tensors[0].contiguous().view(-1) 68 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) 69 | return flat 70 | 71 | def _unflatten_dense_tensors(flat, tensors): 72 | """View a flat buffer using the sizes of tensors. Assume that tensors are of 73 | same dense type, and that flat is given by _flatten_dense_tensors. 74 | Arguments: 75 | flat (Tensor): flattened dense tensors to unflatten. 76 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 77 | unflatten flat. 78 | Returns: 79 | Unflattened dense tensors with sizes same as tensors and values from 80 | flat. 
81 | """ 82 | outputs = [] 83 | offset = 0 84 | for tensor in tensors: 85 | numel = tensor.numel() 86 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 87 | offset += numel 88 | return tuple(outputs) 89 | 90 | def apply_gradient_allreduce(module): 91 | """ 92 | Modifies existing model to do gradient allreduce, but doesn't change class 93 | so you don't need "module" 94 | """ 95 | if not hasattr(dist, '_backend'): 96 | module.warn_on_half = True 97 | else: 98 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 99 | 100 | for p in module.state_dict().values(): 101 | if not torch.is_tensor(p): 102 | continue 103 | dist.broadcast(p, 0) 104 | 105 | def allreduce_params(): 106 | if(module.needs_reduction): 107 | module.needs_reduction = False 108 | buckets = {} 109 | for param in module.parameters(): 110 | if param.requires_grad and param.grad is not None: 111 | tp = type(param.data) 112 | if tp not in buckets: 113 | buckets[tp] = [] 114 | buckets[tp].append(param) 115 | if module.warn_on_half: 116 | if torch.cuda.HalfTensor in buckets: 117 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 118 | " It is recommended to use the NCCL backend in this case. This currently requires" + 119 | "PyTorch built from top of tree master.") 120 | module.warn_on_half = False 121 | 122 | for tp in buckets: 123 | bucket = buckets[tp] 124 | grads = [param.grad.data for param in bucket] 125 | coalesced = _flatten_dense_tensors(grads) 126 | dist.all_reduce(coalesced) 127 | coalesced /= dist.get_world_size() 128 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 129 | buf.copy_(synced) 130 | 131 | for param in list(module.parameters()): 132 | def allreduce_hook(*unused): 133 | Variable._execution_engine.queue_callback(allreduce_params) 134 | if param.requires_grad: 135 | param.register_hook(allreduce_hook) 136 | dir(param) 137 | 138 | def set_needs_reduction(self, input, output): 139 | self.needs_reduction = True 140 | 141 | module.register_forward_hook(set_needs_reduction) 142 | return module 143 | 144 | 145 | def main(config, stdout_dir, args_str): 146 | args_list = ['train.py'] 147 | args_list += args_str.split(' ') if len(args_str) > 0 else [] 148 | 149 | args_list.append('--config={}'.format(config)) 150 | 151 | num_gpus = torch.cuda.device_count() 152 | args_list.append('--num_gpus={}'.format(num_gpus)) 153 | args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S"))) 154 | 155 | if not os.path.isdir(stdout_dir): 156 | os.makedirs(stdout_dir) 157 | os.chmod(stdout_dir, 0o775) 158 | 159 | workers = [] 160 | 161 | for i in range(num_gpus): 162 | args_list[-2] = '--rank={}'.format(i) 163 | stdout = None if i == 0 else open( 164 | os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w") 165 | print(args_list) 166 | p = subprocess.Popen([str(sys.executable)]+args_list, stdout=stdout) 167 | workers.append(p) 168 | 169 | for p in workers: 170 | p.wait() 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('-c', '--config', type=str, required=True, 176 | help='JSON file for configuration') 177 | parser.add_argument('-s', '--stdout_dir', type=str, default=".", 178 | help='directory to save stoud logs') 179 | parser.add_argument( 180 | '-a', '--args_str', type=str, default='', 181 | help='double quoted string with space separated key value pairs') 182 | 183 | args = parser.parse_args() 184 | main(args.config, args.stdout_dir, args.args_str) 
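
A minimal sketch of the round trip that `_flatten_dense_tensors` and `_unflatten_dense_tensors` above are meant to provide for the bucketed all-reduce in `apply_gradient_allreduce` (this example is not part of the WaveGlow sources; it assumes `waveglow/distributed.py` is importable as `distributed`, e.g. when run from inside the `waveglow` directory):

```python
# Sketch only: exercises the flatten/unflatten contract used by
# apply_gradient_allreduce. Assumes waveglow/distributed.py is importable
# as `distributed`; not part of the repository.
import torch
from distributed import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.randn(3, 4), torch.randn(5)]     # per-parameter gradient buffers
flat = _flatten_dense_tensors(grads)            # one contiguous 1-D buffer (17 elements)
flat /= 2.0                                     # element-wise op, e.g. averaging over 2 GPUs
synced = _unflatten_dense_tensors(flat, grads)  # views shaped like the originals
for buf, new in zip(grads, synced):
    buf.copy_(new)                              # write the reduced values back in place
```
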
185 | -------------------------------------------------------------------------------- /waveglow/glow.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import copy 28 | import torch 29 | from torch.autograd import Variable 30 | import torch.nn.functional as F 31 | 32 | 33 | @torch.jit.script 34 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 35 | n_channels_int = n_channels[0] 36 | in_act = input_a+input_b 37 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 38 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 39 | acts = t_act * s_act 40 | return acts 41 | 42 | 43 | class WaveGlowLoss(torch.nn.Module): 44 | def __init__(self, sigma=1.0): 45 | super(WaveGlowLoss, self).__init__() 46 | self.sigma = sigma 47 | 48 | def forward(self, model_output): 49 | z, log_s_list, log_det_W_list = model_output 50 | for i, log_s in enumerate(log_s_list): 51 | if i == 0: 52 | log_s_total = torch.sum(log_s) 53 | log_det_W_total = log_det_W_list[i] 54 | else: 55 | log_s_total = log_s_total + torch.sum(log_s) 56 | log_det_W_total += log_det_W_list[i] 57 | 58 | loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total 59 | return loss/(z.size(0)*z.size(1)*z.size(2)) 60 | 61 | 62 | class Invertible1x1Conv(torch.nn.Module): 63 | """ 64 | The layer outputs both the convolution, and the log determinant 65 | of its weight matrix. 
If reverse=True it does convolution with 66 | inverse 67 | """ 68 | def __init__(self, c): 69 | super(Invertible1x1Conv, self).__init__() 70 | self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, 71 | bias=False) 72 | 73 | # Sample a random orthonormal matrix to initialize weights 74 | W = torch.qr(torch.FloatTensor(c, c).normal_())[0] 75 | 76 | # Ensure determinant is 1.0 not -1.0 77 | if torch.det(W) < 0: 78 | W[:,0] = -1*W[:,0] 79 | W = W.view(c, c, 1) 80 | self.conv.weight.data = W 81 | 82 | def forward(self, z, reverse=False): 83 | # shape 84 | batch_size, group_size, n_of_groups = z.size() 85 | 86 | W = self.conv.weight.squeeze() 87 | 88 | if reverse: 89 | if not hasattr(self, 'W_inverse'): 90 | # Reverse computation 91 | W_inverse = W.float().inverse() 92 | W_inverse = Variable(W_inverse[..., None]) 93 | if z.type() == 'torch.cuda.HalfTensor': 94 | W_inverse = W_inverse.half() 95 | self.W_inverse = W_inverse 96 | z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) 97 | return z 98 | else: 99 | # Forward computation 100 | log_det_W = batch_size * n_of_groups * torch.logdet(W) 101 | z = self.conv(z) 102 | return z, log_det_W 103 | 104 | 105 | class WN(torch.nn.Module): 106 | """ 107 | This is the WaveNet like layer for the affine coupling. The primary difference 108 | from WaveNet is the convolutions need not be causal. There is also no dilation 109 | size reset. The dilation only doubles on each layer 110 | """ 111 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, 112 | kernel_size): 113 | super(WN, self).__init__() 114 | assert(kernel_size % 2 == 1) 115 | assert(n_channels % 2 == 0) 116 | self.n_layers = n_layers 117 | self.n_channels = n_channels 118 | self.in_layers = torch.nn.ModuleList() 119 | self.res_skip_layers = torch.nn.ModuleList() 120 | self.cond_layers = torch.nn.ModuleList() 121 | 122 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1) 123 | start = torch.nn.utils.weight_norm(start, name='weight') 124 | self.start = start 125 | 126 | # Initializing last layer to 0 makes the affine coupling layers 127 | # do nothing at first. 
This helps with training stability 128 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1) 129 | end.weight.data.zero_() 130 | end.bias.data.zero_() 131 | self.end = end 132 | 133 | for i in range(n_layers): 134 | dilation = 2 ** i 135 | padding = int((kernel_size*dilation - dilation)/2) 136 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size, 137 | dilation=dilation, padding=padding) 138 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 139 | self.in_layers.append(in_layer) 140 | 141 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1) 142 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 143 | self.cond_layers.append(cond_layer) 144 | 145 | # last one is not necessary 146 | if i < n_layers - 1: 147 | res_skip_channels = 2*n_channels 148 | else: 149 | res_skip_channels = n_channels 150 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 151 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 152 | self.res_skip_layers.append(res_skip_layer) 153 | 154 | def forward(self, forward_input): 155 | audio, spect = forward_input 156 | audio = self.start(audio) 157 | 158 | for i in range(self.n_layers): 159 | acts = fused_add_tanh_sigmoid_multiply( 160 | self.in_layers[i](audio), 161 | self.cond_layers[i](spect), 162 | torch.IntTensor([self.n_channels])) 163 | 164 | res_skip_acts = self.res_skip_layers[i](acts) 165 | if i < self.n_layers - 1: 166 | audio = res_skip_acts[:,:self.n_channels,:] + audio 167 | skip_acts = res_skip_acts[:,self.n_channels:,:] 168 | else: 169 | skip_acts = res_skip_acts 170 | 171 | if i == 0: 172 | output = skip_acts 173 | else: 174 | output = skip_acts + output 175 | return self.end(output) 176 | 177 | 178 | class WaveGlow(torch.nn.Module): 179 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, 180 | n_early_size, WN_config): 181 | super(WaveGlow, self).__init__() 182 | 183 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, 184 | n_mel_channels, 185 | 1024, stride=256) 186 | assert(n_group % 2 == 0) 187 | self.n_flows = n_flows 188 | self.n_group = n_group 189 | self.n_early_every = n_early_every 190 | self.n_early_size = n_early_size 191 | self.WN = torch.nn.ModuleList() 192 | self.convinv = torch.nn.ModuleList() 193 | 194 | n_half = int(n_group/2) 195 | 196 | # Set up layers with the right sizes based on how many dimensions 197 | # have been output already 198 | n_remaining_channels = n_group 199 | for k in range(n_flows): 200 | if k % self.n_early_every == 0 and k > 0: 201 | n_half = n_half - int(self.n_early_size/2) 202 | n_remaining_channels = n_remaining_channels - self.n_early_size 203 | self.convinv.append(Invertible1x1Conv(n_remaining_channels)) 204 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config)) 205 | self.n_remaining_channels = n_remaining_channels # Useful during inference 206 | 207 | def forward(self, forward_input): 208 | """ 209 | forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames 210 | forward_input[1] = audio: batch x time 211 | """ 212 | spect, audio = forward_input 213 | 214 | # Upsample spectrogram to size of audio 215 | spect = self.upsample(spect) 216 | assert(spect.size(2) >= audio.size(1)) 217 | if spect.size(2) > audio.size(1): 218 | spect = spect[:, :, :audio.size(1)] 219 | 220 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 221 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 222 | 223 | audio = audio.unfold(1, 
self.n_group, self.n_group).permute(0, 2, 1) 224 | output_audio = [] 225 | log_s_list = [] 226 | log_det_W_list = [] 227 | 228 | for k in range(self.n_flows): 229 | if k % self.n_early_every == 0 and k > 0: 230 | output_audio.append(audio[:,:self.n_early_size,:]) 231 | audio = audio[:,self.n_early_size:,:] 232 | 233 | audio, log_det_W = self.convinv[k](audio) 234 | log_det_W_list.append(log_det_W) 235 | 236 | n_half = int(audio.size(1)/2) 237 | audio_0 = audio[:,:n_half,:] 238 | audio_1 = audio[:,n_half:,:] 239 | 240 | output = self.WN[k]((audio_0, spect)) 241 | log_s = output[:, n_half:, :] 242 | b = output[:, :n_half, :] 243 | audio_1 = torch.exp(log_s)*audio_1 + b 244 | log_s_list.append(log_s) 245 | 246 | audio = torch.cat([audio_0, audio_1],1) 247 | 248 | output_audio.append(audio) 249 | return torch.cat(output_audio,1), log_s_list, log_det_W_list 250 | 251 | def infer(self, spect, sigma=1.0): 252 | spect = self.upsample(spect) 253 | # trim conv artifacts. maybe pad spec to kernel multiple 254 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] 255 | spect = spect[:, :, :-time_cutoff] 256 | 257 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 258 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 259 | 260 | if spect.type() == 'torch.cuda.HalfTensor': 261 | audio = torch.cuda.HalfTensor(spect.size(0), 262 | self.n_remaining_channels, 263 | spect.size(2)).normal_() 264 | else: 265 | audio = torch.cuda.FloatTensor(spect.size(0), 266 | self.n_remaining_channels, 267 | spect.size(2)).normal_() 268 | 269 | audio = torch.autograd.Variable(sigma*audio) 270 | 271 | for k in reversed(range(self.n_flows)): 272 | n_half = int(audio.size(1)/2) 273 | audio_0 = audio[:,:n_half,:] 274 | audio_1 = audio[:,n_half:,:] 275 | 276 | output = self.WN[k]((audio_0, spect)) 277 | s = output[:, n_half:, :] 278 | b = output[:, :n_half, :] 279 | audio_1 = (audio_1 - b)/torch.exp(s) 280 | audio = torch.cat([audio_0, audio_1],1) 281 | 282 | audio = self.convinv[k](audio, reverse=True) 283 | 284 | if k % self.n_early_every == 0 and k > 0: 285 | if spect.type() == 'torch.cuda.HalfTensor': 286 | z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() 287 | else: 288 | z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() 289 | audio = torch.cat((sigma*z, audio),1) 290 | 291 | audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data 292 | return audio 293 | 294 | @staticmethod 295 | def remove_weightnorm(model): 296 | waveglow = model 297 | for WN in waveglow.WN: 298 | WN.start = torch.nn.utils.remove_weight_norm(WN.start) 299 | WN.in_layers = remove(WN.in_layers) 300 | WN.cond_layers = remove(WN.cond_layers) 301 | WN.res_skip_layers = remove(WN.res_skip_layers) 302 | return waveglow 303 | 304 | 305 | def remove(conv_list): 306 | new_conv_list = torch.nn.ModuleList() 307 | for old_conv in conv_list: 308 | old_conv = torch.nn.utils.remove_weight_norm(old_conv) 309 | new_conv_list.append(old_conv) 310 | return new_conv_list 311 | -------------------------------------------------------------------------------- /waveglow/glow_old.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from glow import Invertible1x1Conv, remove 4 | 5 | 6 | @torch.jit.script 7 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 8 | n_channels_int = n_channels[0] 9 | in_act = input_a+input_b 10 | t_act = 
torch.tanh(in_act[:, :n_channels_int, :]) 11 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 12 | acts = t_act * s_act 13 | return acts 14 | 15 | 16 | class WN(torch.nn.Module): 17 | """ 18 | This is the WaveNet like layer for the affine coupling. The primary difference 19 | from WaveNet is the convolutions need not be causal. There is also no dilation 20 | size reset. The dilation only doubles on each layer 21 | """ 22 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, 23 | kernel_size): 24 | super(WN, self).__init__() 25 | assert(kernel_size % 2 == 1) 26 | assert(n_channels % 2 == 0) 27 | self.n_layers = n_layers 28 | self.n_channels = n_channels 29 | self.in_layers = torch.nn.ModuleList() 30 | self.res_skip_layers = torch.nn.ModuleList() 31 | self.cond_layers = torch.nn.ModuleList() 32 | 33 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1) 34 | start = torch.nn.utils.weight_norm(start, name='weight') 35 | self.start = start 36 | 37 | # Initializing last layer to 0 makes the affine coupling layers 38 | # do nothing at first. This helps with training stability 39 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1) 40 | end.weight.data.zero_() 41 | end.bias.data.zero_() 42 | self.end = end 43 | 44 | for i in range(n_layers): 45 | dilation = 2 ** i 46 | padding = int((kernel_size*dilation - dilation)/2) 47 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size, 48 | dilation=dilation, padding=padding) 49 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 50 | self.in_layers.append(in_layer) 51 | 52 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1) 53 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 54 | self.cond_layers.append(cond_layer) 55 | 56 | # last one is not necessary 57 | if i < n_layers - 1: 58 | res_skip_channels = 2*n_channels 59 | else: 60 | res_skip_channels = n_channels 61 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 62 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 63 | self.res_skip_layers.append(res_skip_layer) 64 | 65 | def forward(self, forward_input): 66 | audio, spect = forward_input 67 | audio = self.start(audio) 68 | 69 | for i in range(self.n_layers): 70 | acts = fused_add_tanh_sigmoid_multiply( 71 | self.in_layers[i](audio), 72 | self.cond_layers[i](spect), 73 | torch.IntTensor([self.n_channels])) 74 | 75 | res_skip_acts = self.res_skip_layers[i](acts) 76 | if i < self.n_layers - 1: 77 | audio = res_skip_acts[:,:self.n_channels,:] + audio 78 | skip_acts = res_skip_acts[:,self.n_channels:,:] 79 | else: 80 | skip_acts = res_skip_acts 81 | 82 | if i == 0: 83 | output = skip_acts 84 | else: 85 | output = skip_acts + output 86 | return self.end(output) 87 | 88 | 89 | class WaveGlow(torch.nn.Module): 90 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, 91 | n_early_size, WN_config): 92 | super(WaveGlow, self).__init__() 93 | 94 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, 95 | n_mel_channels, 96 | 1024, stride=256) 97 | assert(n_group % 2 == 0) 98 | self.n_flows = n_flows 99 | self.n_group = n_group 100 | self.n_early_every = n_early_every 101 | self.n_early_size = n_early_size 102 | self.WN = torch.nn.ModuleList() 103 | self.convinv = torch.nn.ModuleList() 104 | 105 | n_half = int(n_group/2) 106 | 107 | # Set up layers with the right sizes based on how many dimensions 108 | # have been output already 109 | n_remaining_channels = n_group 110 | for k in range(n_flows): 111 | if k % 
self.n_early_every == 0 and k > 0: 112 | n_half = n_half - int(self.n_early_size/2) 113 | n_remaining_channels = n_remaining_channels - self.n_early_size 114 | self.convinv.append(Invertible1x1Conv(n_remaining_channels)) 115 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config)) 116 | self.n_remaining_channels = n_remaining_channels # Useful during inference 117 | 118 | def forward(self, forward_input): 119 | return None 120 | """ 121 | forward_input[0] = audio: batch x time 122 | forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time 123 | """ 124 | """ 125 | spect, audio = forward_input 126 | 127 | # Upsample spectrogram to size of audio 128 | spect = self.upsample(spect) 129 | assert(spect.size(2) >= audio.size(1)) 130 | if spect.size(2) > audio.size(1): 131 | spect = spect[:, :, :audio.size(1)] 132 | 133 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 134 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 135 | 136 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1) 137 | output_audio = [] 138 | s_list = [] 139 | s_conv_list = [] 140 | 141 | for k in range(self.n_flows): 142 | if k%4 == 0 and k > 0: 143 | output_audio.append(audio[:,:self.n_multi,:]) 144 | audio = audio[:,self.n_multi:,:] 145 | 146 | # project to new basis 147 | audio, s = self.convinv[k](audio) 148 | s_conv_list.append(s) 149 | 150 | n_half = int(audio.size(1)/2) 151 | if k%2 == 0: 152 | audio_0 = audio[:,:n_half,:] 153 | audio_1 = audio[:,n_half:,:] 154 | else: 155 | audio_1 = audio[:,:n_half,:] 156 | audio_0 = audio[:,n_half:,:] 157 | 158 | output = self.nn[k]((audio_0, spect)) 159 | s = output[:, n_half:, :] 160 | b = output[:, :n_half, :] 161 | audio_1 = torch.exp(s)*audio_1 + b 162 | s_list.append(s) 163 | 164 | if k%2 == 0: 165 | audio = torch.cat([audio[:,:n_half,:], audio_1],1) 166 | else: 167 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1) 168 | output_audio.append(audio) 169 | return torch.cat(output_audio,1), s_list, s_conv_list 170 | """ 171 | 172 | def infer(self, spect, sigma=1.0): 173 | spect = self.upsample(spect) 174 | # trim conv artifacts. 
maybe pad spec to kernel multiple 175 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] 176 | spect = spect[:, :, :-time_cutoff] 177 | 178 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 179 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 180 | 181 | if spect.type() == 'torch.cuda.HalfTensor': 182 | audio = torch.cuda.HalfTensor(spect.size(0), 183 | self.n_remaining_channels, 184 | spect.size(2)).normal_() 185 | else: 186 | audio = torch.cuda.FloatTensor(spect.size(0), 187 | self.n_remaining_channels, 188 | spect.size(2)).normal_() 189 | 190 | audio = torch.autograd.Variable(sigma*audio) 191 | 192 | for k in reversed(range(self.n_flows)): 193 | n_half = int(audio.size(1)/2) 194 | if k%2 == 0: 195 | audio_0 = audio[:,:n_half,:] 196 | audio_1 = audio[:,n_half:,:] 197 | else: 198 | audio_1 = audio[:,:n_half,:] 199 | audio_0 = audio[:,n_half:,:] 200 | 201 | output = self.WN[k]((audio_0, spect)) 202 | s = output[:, n_half:, :] 203 | b = output[:, :n_half, :] 204 | audio_1 = (audio_1 - b)/torch.exp(s) 205 | if k%2 == 0: 206 | audio = torch.cat([audio[:,:n_half,:], audio_1],1) 207 | else: 208 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1) 209 | 210 | audio = self.convinv[k](audio, reverse=True) 211 | 212 | if k%4 == 0 and k > 0: 213 | if spect.type() == 'torch.cuda.HalfTensor': 214 | z = torch.cuda.HalfTensor(spect.size(0), 215 | self.n_early_size, 216 | spect.size(2)).normal_() 217 | else: 218 | z = torch.cuda.FloatTensor(spect.size(0), 219 | self.n_early_size, 220 | spect.size(2)).normal_() 221 | audio = torch.cat((sigma*z, audio),1) 222 | 223 | return audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data 224 | 225 | @staticmethod 226 | def remove_weightnorm(model): 227 | waveglow = model 228 | for WN in waveglow.WN: 229 | WN.start = torch.nn.utils.remove_weight_norm(WN.start) 230 | WN.in_layers = remove(WN.in_layers) 231 | WN.cond_layers = remove(WN.cond_layers) 232 | WN.res_skip_layers = remove(WN.res_skip_layers) 233 | return waveglow 234 | -------------------------------------------------------------------------------- /waveglow/inference.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | # ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import os 28 | from scipy.io.wavfile import write 29 | import torch 30 | from mel2samp import files_to_list, MAX_WAV_VALUE 31 | from denoiser import Denoiser 32 | 33 | 34 | def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16, 35 | denoiser_strength): 36 | mel_files = files_to_list(mel_files) 37 | waveglow = torch.load(waveglow_path)['model'] 38 | waveglow = waveglow.remove_weightnorm(waveglow) 39 | waveglow.cuda().eval() 40 | if is_fp16: 41 | from apex import amp 42 | waveglow, _ = amp.initialize(waveglow, [], opt_level="O3") 43 | 44 | if denoiser_strength > 0: 45 | denoiser = Denoiser(waveglow).cuda() 46 | 47 | for i, file_path in enumerate(mel_files): 48 | file_name = os.path.splitext(os.path.basename(file_path))[0] 49 | mel = torch.load(file_path) 50 | mel = torch.autograd.Variable(mel.cuda()) 51 | mel = torch.unsqueeze(mel, 0) 52 | mel = mel.half() if is_fp16 else mel 53 | with torch.no_grad(): 54 | audio = waveglow.infer(mel, sigma=sigma) 55 | if denoiser_strength > 0: 56 | audio = denoiser(audio, denoiser_strength) 57 | audio = audio * MAX_WAV_VALUE 58 | audio = audio.squeeze() 59 | audio = audio.cpu().numpy() 60 | audio = audio.astype('int16') 61 | audio_path = os.path.join( 62 | output_dir, "{}_synthesis.wav".format(file_name)) 63 | write(audio_path, sampling_rate, audio) 64 | print(audio_path) 65 | 66 | 67 | if __name__ == "__main__": 68 | import argparse 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('-f', "--filelist_path", required=True) 72 | parser.add_argument('-w', '--waveglow_path', 73 | help='Path to waveglow decoder checkpoint with model') 74 | parser.add_argument('-o', "--output_dir", required=True) 75 | parser.add_argument("-s", "--sigma", default=1.0, type=float) 76 | parser.add_argument("--sampling_rate", default=22050, type=int) 77 | parser.add_argument("--is_fp16", action="store_true") 78 | parser.add_argument("-d", "--denoiser_strength", default=0.0, type=float, 79 | help='Removes model bias. Start with 0.1 and adjust') 80 | 81 | args = parser.parse_args() 82 | 83 | main(args.filelist_path, args.waveglow_path, args.sigma, args.output_dir, 84 | args.sampling_rate, args.is_fp16, args.denoiser_strength) 85 | -------------------------------------------------------------------------------- /waveglow/mel2samp.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 
8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # *****************************************************************************\ 27 | import os 28 | import random 29 | import argparse 30 | import json 31 | import torch 32 | import torch.utils.data 33 | import sys 34 | from scipy.io.wavfile import read 35 | 36 | # We're using the audio processing from TacoTron2 to make sure it matches 37 | sys.path.insert(0, 'tacotron2') 38 | from tacotron2.layers import TacotronSTFT 39 | 40 | MAX_WAV_VALUE = 32768.0 41 | 42 | def files_to_list(filename): 43 | """ 44 | Takes a text file of filenames and makes a list of filenames 45 | """ 46 | with open(filename, encoding='utf-8') as f: 47 | files = f.readlines() 48 | 49 | files = [f.rstrip() for f in files] 50 | return files 51 | 52 | def load_wav_to_torch(full_path): 53 | """ 54 | Loads wavdata into torch array 55 | """ 56 | sampling_rate, data = read(full_path) 57 | return torch.from_numpy(data).float(), sampling_rate 58 | 59 | 60 | class Mel2Samp(torch.utils.data.Dataset): 61 | """ 62 | This is the main class that calculates the spectrogram and returns the 63 | spectrogram, audio pair. 
64 | """ 65 | def __init__(self, training_files, segment_length, filter_length, 66 | hop_length, win_length, sampling_rate, mel_fmin, mel_fmax): 67 | self.audio_files = files_to_list(training_files) 68 | random.seed(1234) 69 | random.shuffle(self.audio_files) 70 | self.stft = TacotronSTFT(filter_length=filter_length, 71 | hop_length=hop_length, 72 | win_length=win_length, 73 | sampling_rate=sampling_rate, 74 | mel_fmin=mel_fmin, mel_fmax=mel_fmax) 75 | self.segment_length = segment_length 76 | self.sampling_rate = sampling_rate 77 | 78 | def get_mel(self, audio): 79 | audio_norm = audio / MAX_WAV_VALUE 80 | audio_norm = audio_norm.unsqueeze(0) 81 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 82 | melspec = self.stft.mel_spectrogram(audio_norm) 83 | melspec = torch.squeeze(melspec, 0) 84 | return melspec 85 | 86 | def __getitem__(self, index): 87 | # Read audio 88 | filename = self.audio_files[index] 89 | audio, sampling_rate = load_wav_to_torch(filename) 90 | if sampling_rate != self.sampling_rate: 91 | raise ValueError("{} SR doesn't match target {} SR".format( 92 | sampling_rate, self.sampling_rate)) 93 | 94 | # Take segment 95 | if audio.size(0) >= self.segment_length: 96 | max_audio_start = audio.size(0) - self.segment_length 97 | audio_start = random.randint(0, max_audio_start) 98 | audio = audio[audio_start:audio_start+self.segment_length] 99 | else: 100 | audio = torch.nn.functional.pad(audio, (0, self.segment_length - audio.size(0)), 'constant').data 101 | 102 | mel = self.get_mel(audio) 103 | audio = audio / MAX_WAV_VALUE 104 | 105 | return (mel, audio) 106 | 107 | def __len__(self): 108 | return len(self.audio_files) 109 | 110 | # =================================================================== 111 | # Takes directory of clean audio and makes directory of spectrograms 112 | # Useful for making test sets 113 | # =================================================================== 114 | if __name__ == "__main__": 115 | # Get defaults so it can work with no Sacred 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('-f', "--filelist_path", required=True) 118 | parser.add_argument('-c', '--config', type=str, 119 | help='JSON file for configuration') 120 | parser.add_argument('-o', '--output_dir', type=str, 121 | help='Output directory') 122 | args = parser.parse_args() 123 | 124 | with open(args.config) as f: 125 | data = f.read() 126 | data_config = json.loads(data)["data_config"] 127 | mel2samp = Mel2Samp(**data_config) 128 | 129 | filepaths = files_to_list(args.filelist_path) 130 | 131 | # Make directory if it doesn't exist 132 | if not os.path.isdir(args.output_dir): 133 | os.makedirs(args.output_dir) 134 | os.chmod(args.output_dir, 0o775) 135 | 136 | for filepath in filepaths: 137 | audio, sr = load_wav_to_torch(filepath) 138 | melspectrogram = mel2samp.get_mel(audio) 139 | filename = os.path.basename(filepath) 140 | new_filepath = args.output_dir + '/' + filename + '.pt' 141 | print(new_filepath) 142 | torch.save(melspectrogram, new_filepath) 143 | -------------------------------------------------------------------------------- /waveglow/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0 2 | matplotlib==2.1.0 3 | tensorflow 4 | numpy==1.13.3 5 | inflect==0.2.5 6 | librosa==0.6.0 7 | scipy==1.0.0 8 | tensorboardX==1.1 9 | Unidecode==1.0.22 10 | pillow 11 | -------------------------------------------------------------------------------- /waveglow/train.py: 
-------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | # 26 | # ***************************************************************************** 27 | import argparse 28 | import json 29 | import os 30 | import torch 31 | 32 | #=====START: ADDED FOR DISTRIBUTED====== 33 | from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor 34 | from torch.utils.data.distributed import DistributedSampler 35 | #=====END: ADDED FOR DISTRIBUTED====== 36 | 37 | from torch.utils.data import DataLoader 38 | from glow import WaveGlow, WaveGlowLoss 39 | from mel2samp import Mel2Samp 40 | 41 | def load_checkpoint(checkpoint_path, model, optimizer): 42 | assert os.path.isfile(checkpoint_path) 43 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 44 | iteration = checkpoint_dict['iteration'] 45 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 46 | model_for_loading = checkpoint_dict['model'] 47 | model.load_state_dict(model_for_loading.state_dict()) 48 | print("Loaded checkpoint '{}' (iteration {})" .format( 49 | checkpoint_path, iteration)) 50 | return model, optimizer, iteration 51 | 52 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): 53 | print("Saving model and optimizer state at iteration {} to {}".format( 54 | iteration, filepath)) 55 | model_for_saving = WaveGlow(**waveglow_config).cuda() 56 | model_for_saving.load_state_dict(model.state_dict()) 57 | torch.save({'model': model_for_saving, 58 | 'iteration': iteration, 59 | 'optimizer': optimizer.state_dict(), 60 | 'learning_rate': learning_rate}, filepath) 61 | 62 | def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate, 63 | sigma, iters_per_checkpoint, batch_size, seed, fp16_run, 64 | checkpoint_path, with_tensorboard): 65 | torch.manual_seed(seed) 66 | torch.cuda.manual_seed(seed) 67 | #=====START: ADDED FOR DISTRIBUTED====== 68 | if num_gpus > 1: 69 | init_distributed(rank, num_gpus, group_name, **dist_config) 70 | #=====END: ADDED FOR DISTRIBUTED====== 71 | 72 | criterion = WaveGlowLoss(sigma) 73 | model = WaveGlow(**waveglow_config).cuda() 74 | 75 | #=====START: ADDED FOR DISTRIBUTED====== 76 | if num_gpus > 1: 77 | model = apply_gradient_allreduce(model) 78 | #=====END: ADDED FOR DISTRIBUTED====== 79 | 80 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 81 | 82 | if fp16_run: 83 | from apex import amp 84 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 85 | 86 | # Load checkpoint if one exists 87 | iteration = 0 88 | if checkpoint_path != "": 89 | model, optimizer, iteration = load_checkpoint(checkpoint_path, model, 90 | optimizer) 91 | iteration += 1 # next iteration is iteration + 1 92 | 93 | trainset = Mel2Samp(**data_config) 94 | # =====START: ADDED FOR DISTRIBUTED====== 95 | train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None 96 | # =====END: ADDED FOR DISTRIBUTED====== 97 | train_loader = DataLoader(trainset, num_workers=1, shuffle=False, 98 | sampler=train_sampler, 99 | batch_size=batch_size, 100 | pin_memory=False, 101 | drop_last=True) 102 | 103 | # Get shared output_directory ready 104 | if rank == 0: 105 | if not os.path.isdir(output_directory): 106 | os.makedirs(output_directory) 107 | os.chmod(output_directory, 0o775) 108 | print("output directory", output_directory) 109 | 110 | if with_tensorboard and rank == 0: 111 | from tensorboardX import SummaryWriter 112 | logger = SummaryWriter(os.path.join(output_directory, 'logs')) 113 | 114 | model.train() 115 | epoch_offset = max(0, int(iteration / len(train_loader))) 116 | # 
================ MAIN TRAINNIG LOOP! =================== 117 | for epoch in range(epoch_offset, epochs): 118 | print("Epoch: {}".format(epoch)) 119 | for i, batch in enumerate(train_loader): 120 | model.zero_grad() 121 | 122 | mel, audio = batch 123 | mel = torch.autograd.Variable(mel.cuda()) 124 | audio = torch.autograd.Variable(audio.cuda()) 125 | outputs = model((mel, audio)) 126 | 127 | loss = criterion(outputs) 128 | if num_gpus > 1: 129 | reduced_loss = reduce_tensor(loss.data, num_gpus).item() 130 | else: 131 | reduced_loss = loss.item() 132 | 133 | if fp16_run: 134 | with amp.scale_loss(loss, optimizer) as scaled_loss: 135 | scaled_loss.backward() 136 | else: 137 | loss.backward() 138 | 139 | optimizer.step() 140 | 141 | print("{}:\t{:.9f}".format(iteration, reduced_loss)) 142 | if with_tensorboard and rank == 0: 143 | logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch) 144 | 145 | if (iteration % iters_per_checkpoint == 0): 146 | if rank == 0: 147 | checkpoint_path = "{}/waveglow_{}".format( 148 | output_directory, iteration) 149 | save_checkpoint(model, optimizer, learning_rate, iteration, 150 | checkpoint_path) 151 | 152 | iteration += 1 153 | 154 | if __name__ == "__main__": 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('-c', '--config', type=str, 157 | help='JSON file for configuration') 158 | parser.add_argument('-r', '--rank', type=int, default=0, 159 | help='rank of process for distributed') 160 | parser.add_argument('-g', '--group_name', type=str, default='', 161 | help='name of group for distributed') 162 | args = parser.parse_args() 163 | 164 | # Parse configs. Globals nicer in this case 165 | with open(args.config) as f: 166 | data = f.read() 167 | config = json.loads(data) 168 | train_config = config["train_config"] 169 | global data_config 170 | data_config = config["data_config"] 171 | global dist_config 172 | dist_config = config["dist_config"] 173 | global waveglow_config 174 | waveglow_config = config["waveglow_config"] 175 | 176 | num_gpus = torch.cuda.device_count() 177 | if num_gpus > 1: 178 | if args.group_name == '': 179 | print("WARNING: Multiple GPUs detected but no distributed group set") 180 | print("Only running 1 GPU. 
Use distributed.py for multiple GPUs") 181 | num_gpus = 1 182 | 183 | if num_gpus == 1 and args.rank != 0: 184 | raise Exception("Doing single GPU training on rank > 0") 185 | 186 | torch.backends.cudnn.enabled = True 187 | torch.backends.cudnn.benchmark = False 188 | train(num_gpus, args.rank, args.group_name, **train_config) 189 | -------------------------------------------------------------------------------- /waveglow/waveglow_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/waveglow/waveglow_logo.png -------------------------------------------------------------------------------- /wavs/gae-char/LJ001-0029_char100000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0029_char100000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ001-0029_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0029_char20000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ001-0029_char200000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0029_char200000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ001-0029_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0029_char50000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ001-0085_char100000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0085_char100000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ001-0085_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0085_char20000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ001-0085_char200000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0085_char200000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ001-0085_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0085_char50000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ002-0106_char100000.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ002-0106_char100000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ002-0106_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ002-0106_char20000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ002-0106_char200000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ002-0106_char200000.wav -------------------------------------------------------------------------------- /wavs/gae-char/LJ002-0106_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ002-0106_char50000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ001-0029_char100000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0029_char100000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ001-0029_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0029_char20000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ001-0029_char200000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0029_char200000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ001-0029_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0029_char50000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ001-0085_char100000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0085_char100000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ001-0085_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0085_char20000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ001-0085_char200000.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0085_char200000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ001-0085_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0085_char50000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ002-0106_char100000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ002-0106_char100000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ002-0106_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ002-0106_char20000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ002-0106_char200000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ002-0106_char200000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char/LJ002-0106_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ002-0106_char50000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ001-0029_char100000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0029_char100000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ001-0029_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0029_char20000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ001-0029_char200000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0029_char200000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ001-0029_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0029_char50000.wav -------------------------------------------------------------------------------- 
/wavs/graph-tts-char_iter5/LJ001-0085_char100000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0085_char100000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ001-0085_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0085_char20000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ001-0085_char200000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0085_char200000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ001-0085_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0085_char50000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ002-0106_char100000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ002-0106_char100000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ002-0106_char20000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ002-0106_char20000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ002-0106_char200000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ002-0106_char200000.wav -------------------------------------------------------------------------------- /wavs/graph-tts-char_iter5/LJ002-0106_char50000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ002-0106_char50000.wav --------------------------------------------------------------------------------
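
A trained WaveGlow checkpoint (e.g. the published model linked in the WaveGlow README) is driven through `waveglow/inference.py` above. The following is a minimal sketch of the same steps done programmatically for a single saved mel-spectrogram tensor; it is not part of the repository, `waveglow.pt` and `LJ001-0029_mel.pt` are placeholder paths, and `sigma=0.6` follows the `-s 0.6` value shown in the WaveGlow README:

```python
# Sketch only: mirrors the steps in waveglow/inference.py for one mel tensor.
# Placeholder paths; requires a CUDA-capable GPU, like the original script.
import torch
from scipy.io.wavfile import write

MAX_WAV_VALUE = 32768.0  # same constant as in waveglow/mel2samp.py

waveglow = torch.load('waveglow.pt')['model']      # checkpoint stores the module itself
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow.cuda().eval()

mel = torch.load('LJ001-0029_mel.pt')              # (80, frames) mel saved by mel2samp.py
mel = mel.cuda().unsqueeze(0)                      # add batch dimension -> (1, 80, frames)
with torch.no_grad():
    audio = waveglow.infer(mel, sigma=0.6)

audio = (audio * MAX_WAV_VALUE).squeeze().cpu().numpy().astype('int16')
write('LJ001-0029_synthesis.wav', 22050, audio)    # 22050 Hz per data_config in config.json
```
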