├── README.md
├── audio_processing.py
├── figures
│   ├── enc_dec_align.JPG
│   ├── gate_out.JPG
│   ├── melspec.JPG
│   ├── train_bce_loss.JPG
│   ├── train_guide_loss.JPG
│   ├── train_mel_loss.JPG
│   ├── val_bce_loss.JPG
│   ├── val_guide_loss.JPG
│   └── val_mel_loss.JPG
├── filelists
│   ├── ljs_audio_text_test_filelist.txt
│   ├── ljs_audio_text_train_filelist.txt
│   └── ljs_audio_text_val_filelist.txt
├── generate_samples.ipynb
├── hparams.py
├── index.html
├── inference.ipynb
├── layers.py
├── modules
│   ├── __pycache__
│   │   ├── init_layer.cpython-37.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── model.cpython-37.pyc
│   │   └── transformer.cpython-37.pyc
│   ├── init_layer.py
│   ├── loss.py
│   ├── model.py
│   └── transformer.py
├── prepare_data.ipynb
├── stft.py
├── text
│   ├── LICENSE
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── cleaners.cpython-36.pyc
│   │   ├── cleaners.cpython-37.pyc
│   │   ├── cmudict.cpython-37.pyc
│   │   ├── numbers.cpython-37.pyc
│   │   └── symbols.cpython-37.pyc
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
├── train-gae.py
├── train-graphtts.py
├── train.py
├── utils
│   ├── __pycache__
│   │   ├── data_utils.cpython-37.pyc
│   │   ├── plot_image.cpython-37.pyc
│   │   ├── utils.cpython-37.pyc
│   │   └── writer.cpython-37.pyc
│   ├── data_utils.py
│   ├── plot_image.py
│   ├── utils.py
│   └── writer.py
├── waveglow
│   ├── .gitmodules
│   ├── LICENSE
│   ├── README.md
│   ├── config.json
│   ├── convert_model.py
│   ├── denoiser.py
│   ├── distributed.py
│   ├── glow.py
│   ├── glow_old.py
│   ├── inference.py
│   ├── mel2samp.py
│   ├── requirements.txt
│   ├── train.py
│   └── waveglow_logo.png
└── wavs
    ├── gae-char
    │   ├── LJ001-0029_char100000.wav
    │   ├── LJ001-0029_char20000.wav
    │   ├── LJ001-0029_char200000.wav
    │   ├── LJ001-0029_char50000.wav
    │   ├── LJ001-0085_char100000.wav
    │   ├── LJ001-0085_char20000.wav
    │   ├── LJ001-0085_char200000.wav
    │   ├── LJ001-0085_char50000.wav
    │   ├── LJ002-0106_char100000.wav
    │   ├── LJ002-0106_char20000.wav
    │   ├── LJ002-0106_char200000.wav
    │   └── LJ002-0106_char50000.wav
    ├── graph-tts-char
    │   ├── LJ001-0029_char100000.wav
    │   ├── LJ001-0029_char20000.wav
    │   ├── LJ001-0029_char200000.wav
    │   ├── LJ001-0029_char50000.wav
    │   ├── LJ001-0085_char100000.wav
    │   ├── LJ001-0085_char20000.wav
    │   ├── LJ001-0085_char200000.wav
    │   ├── LJ001-0085_char50000.wav
    │   ├── LJ002-0106_char100000.wav
    │   ├── LJ002-0106_char20000.wav
    │   ├── LJ002-0106_char200000.wav
    │   └── LJ002-0106_char50000.wav
    └── graph-tts-char_iter5
        ├── LJ001-0029_char100000.wav
        ├── LJ001-0029_char20000.wav
        ├── LJ001-0029_char200000.wav
        ├── LJ001-0029_char50000.wav
        ├── LJ001-0085_char100000.wav
        ├── LJ001-0085_char20000.wav
        ├── LJ001-0085_char200000.wav
        ├── LJ001-0085_char50000.wav
        ├── LJ002-0106_char100000.wav
        ├── LJ002-0106_char20000.wav
        ├── LJ002-0106_char200000.wav
        └── LJ002-0106_char50000.wav
/README.md:
--------------------------------------------------------------------------------
1 | # Graph-TTS
2 | - Implementation of ["GraphTTS: graph-to-sequence modelling in neural text-to-speech"](https://arxiv.org/abs/2003.01924)
3 | - I failed to generate plausible speech :(
4 |
5 | ## Training
6 | 1. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/)
7 | 2. Make a `preprocessed` folder in the LJSpeech directory and create `char_seq`, `phone_seq`, and `melspectrogram` folders inside it (see the sketch below this list)
8 | 3. Set `data_path` in `hparams.py` to the `preprocessed` folder
9 | 4. Using `prepare_data.ipynb`, prepare the melspectrogram and text tensors (text converted into index sequences)
10 | 5. `python train.py`
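
A minimal sketch of step 2, assuming placeholder paths (the `adj_matrix` folder is not listed above, but it is also written by `prepare_data.ipynb`):

```python
import os

# Placeholder path: point this at your own LJSpeech-1.1 directory.
ljspeech_dir = '/path/to/LJSpeech-1.1'
preprocessed_dir = os.path.join(ljspeech_dir, 'preprocessed')

# char_seq / phone_seq / melspectrogram are required by the steps above;
# adj_matrix is also written by prepare_data.ipynb.
for sub in ['char_seq', 'phone_seq', 'melspectrogram', 'adj_matrix']:
    os.makedirs(os.path.join(preprocessed_dir, sub), exist_ok=True)
```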
11 |
12 | ## Training curve (Orange: transformer-tts / Navy: graph-tts / Red: graph-tts-iter5 / Blue: gae)
13 | - Stop prediction loss (train / val)
14 |
15 | - Guided attention loss (train / val)
16 |
17 | - L1 loss (train / val)
18 |
19 |
20 | ## Alignments
21 | - Encoder-Decoder Alignments
22 |
23 |
24 | - Melspectrogram
25 |
26 |
27 | - Stop prediction
28 |
29 |
30 | ## Audio Samples
31 | You can hear the audio samples [here](https://leeyoonhyung.github.io/GraphTTS/).
32 | You can also hear the audio samples obtained from Transformer-TTS [here](https://leeyoonhyung.github.io/Transformer-TTS/).
33 |
34 | ## Notice
35 | 1. Unlike the original paper, I didn't use an encoder prenet, following [espnet](https://github.com/espnet/espnet)
36 | 2. I apply an additional ["guided attention loss"](https://arxiv.org/pdf/1710.08969.pdf) to two heads of the last two layers
37 | 3. Batch size is important, so I use gradient accumulation (see the sketch below item 4)
38 | 4. You can also use DataParallel; change `n_gpus`, `batch_size`, and `accumulation` appropriately.
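
A rough sketch of the gradient-accumulation loop from items 3-4 (names are illustrative; `model`, `criterion`, `optimizer`, and `train_loader` are assumed to be built as in `train.py`):

```python
import torch
import hparams

optimizer.zero_grad()
for step, batch in enumerate(train_loader):
    # model(...) returns the three losses directly (see modules/model.py)
    mel_loss, bce_loss, guide_loss = model(*batch, criterion)
    loss = (mel_loss + bce_loss + guide_loss) / hparams.accumulation
    loss.backward()  # gradients add up over `accumulation` mini-batches
    if (step + 1) % hparams.accumulation == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
        optimizer.step()  # effective batch size = batch_size * accumulation
        optimizer.zero_grad()
```
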
39 | 5. To draw attention plots for every head, I changed the return values of `torch.nn.functional.multi_head_attention_forward()`
40 | ```python
41 | #before
42 | return attn_output, attn_output_weights.sum(dim=1) / num_heads
43 |
44 | #after
45 | return attn_output, attn_output_weights
46 | ```
47 | 6. Among the `num_layers*num_heads` attention matrices, the one with the highest focus rate is saved (see the sketch below).
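
A minimal sketch of that selection, assuming the common focus-rate definition (the mean over decoder steps of each step's maximum attention weight); the function names here are illustrative:

```python
import torch

def focus_rate(alignment):
    # alignment: (T_dec, T_enc) attention map of a single head
    return alignment.max(dim=-1)[0].mean()

def pick_most_focused(enc_dec_alignments):
    # enc_dec_alignments: (n_layers, n_heads, T_dec, T_enc), as returned by inference()
    n_heads = enc_dec_alignments.size(1)
    rates = enc_dec_alignments.max(dim=-1)[0].mean(dim=-1)  # focus rate per (layer, head)
    layer, head = divmod(rates.view(-1).argmax().item(), n_heads)
    return enc_dec_alignments[layer, head], (layer, head)
```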
48 |
49 | ## Reference
50 | 1. NVIDIA/tacotron2: https://github.com/NVIDIA/tacotron2
51 | 2. espnet/espnet: https://github.com/espnet/espnet
52 |
--------------------------------------------------------------------------------
/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.signal import get_window
4 | import librosa.util as librosa_util
5 |
6 |
7 | def window_sumsquare(window,
8 | n_frames,
9 | hop_length=200,
10 | win_length=800,
11 | n_fft=800,
12 | dtype=np.float32,
13 | norm=None):
14 | """
15 | # from librosa 0.6
16 | Compute the sum-square envelope of a window function at a given hop length.
17 |
18 | This is used to estimate modulation effects induced by windowing
19 | observations in short-time fourier transforms.
20 |
21 | Parameters
22 | ----------
23 | window : string, tuple, number, callable, or list-like
24 | Window specification, as in `get_window`
25 |
26 | n_frames : int > 0
27 | The number of analysis frames
28 |
29 | hop_length : int > 0
30 | The number of samples to advance between frames
31 |
32 | win_length : [optional]
33 | The length of the window function. By default, this matches `n_fft`.
34 |
35 | n_fft : int > 0
36 | The length of each analysis frame.
37 |
38 | dtype : np.dtype
39 | The data type of the output
40 |
41 | Returns
42 | -------
43 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
44 | The sum-squared envelope of the window function
45 | """
46 | if win_length is None:
47 | win_length = n_fft
48 |
49 | n = n_fft + hop_length * (n_frames - 1)
50 | x = np.zeros(n, dtype=dtype)
51 |
52 | # Compute the squared window at the desired length
53 | win_sq = get_window(window, win_length, fftbins=True)
54 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2
55 | win_sq = librosa_util.pad_center(win_sq, n_fft)
56 |
57 | # Fill the envelope
58 | for i in range(n_frames):
59 | sample = i * hop_length
60 | x[sample:min(n, sample+n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
61 | return x
62 |
63 |
64 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
65 | """
66 | PARAMS
67 | ------
68 | magnitudes: spectrogram magnitudes
69 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
70 | """
71 |
72 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
73 | angles = angles.astype(np.float32)
74 | angles = torch.autograd.Variable(torch.from_numpy(angles))
75 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
76 |
77 | for i in range(n_iters):
78 | _, angles = stft_fn.transform(signal)
79 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
80 | return signal
81 |
82 |
83 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
84 | """
85 | PARAMS
86 | ------
87 | C: compression factor
88 | """
89 | return torch.log(torch.clamp(x, min=clip_val) * C)
90 |
91 |
92 | def dynamic_range_decompression(x, C=1):
93 | """
94 | PARAMS
95 | ------
96 | C: compression factor used to compress
97 | """
98 | return torch.exp(x) / C
--------------------------------------------------------------------------------
/figures/enc_dec_align.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/enc_dec_align.JPG
--------------------------------------------------------------------------------
/figures/gate_out.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/gate_out.JPG
--------------------------------------------------------------------------------
/figures/melspec.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/melspec.JPG
--------------------------------------------------------------------------------
/figures/train_bce_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/train_bce_loss.JPG
--------------------------------------------------------------------------------
/figures/train_guide_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/train_guide_loss.JPG
--------------------------------------------------------------------------------
/figures/train_mel_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/train_mel_loss.JPG
--------------------------------------------------------------------------------
/figures/val_bce_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/val_bce_loss.JPG
--------------------------------------------------------------------------------
/figures/val_guide_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/val_guide_loss.JPG
--------------------------------------------------------------------------------
/figures/val_mel_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/figures/val_mel_loss.JPG
--------------------------------------------------------------------------------
/generate_samples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Import libraries and setup matplotlib"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import os\n",
17 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
18 | "\n",
19 | "import warnings\n",
20 | "warnings.filterwarnings(\"ignore\")\n",
21 | "\n",
22 | "import sys\n",
23 | "sys.path.append('waveglow/')\n",
24 | "\n",
25 | "import matplotlib.pyplot as plt\n",
26 | "%matplotlib inline\n",
27 | "\n",
28 | "import IPython.display as ipd\n",
29 | "import pickle as pkl\n",
30 | "from text import *\n",
31 | "import numpy as np\n",
32 | "import torch\n",
33 | "import hparams\n",
34 | "from modules.model import GAE, GraphTTS\n",
35 | "from denoiser import Denoiser\n",
36 | "import soundfile"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "### Text preprocessing"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "from g2p_en import G2p\n",
53 | "from text.symbols import symbols\n",
54 | "from text.cleaners import custom_english_cleaners\n",
55 | "\n",
56 | "# Mappings from symbol to numeric ID and vice versa:\n",
57 | "symbol_to_id = {s: i for i, s in enumerate(symbols)}\n",
58 | "id_to_symbol = {i: s for i, s in enumerate(symbols)}\n",
59 | "\n",
60 | "g2p = G2p()\n",
61 | "\n",
62 | "def text2seq(text, data_type='char'):\n",
63 | " text = custom_english_cleaners(text.rstrip())\n",
64 | " if data_type=='phone':\n",
65 | " clean_phone = []\n",
66 | " for s in g2p(text.lower()):\n",
67 | " if '@'+s in symbol_to_id:\n",
68 | " clean_phone.append('@'+s)\n",
69 | " else:\n",
70 | " clean_phone.append(s)\n",
71 | " text = clean_phone\n",
72 | " \n",
73 | " # Append SOS, EOS token\n",
74 | " sequence = [symbol_to_id[c] for c in text]\n",
75 | " sequence = [symbol_to_id['^']] + sequence + [symbol_to_id['~']]\n",
76 | " return sequence\n",
77 | "\n",
78 | "\n",
79 | "def create_adjacency_matrix(char_seq):\n",
80 | " n_nodes=char_seq.size(1)\n",
81 | " n_edge_types=3\n",
82 | " \n",
83 | " a = np.zeros([n_edge_types, n_nodes, n_nodes], dtype=np.int)\n",
84 | " \n",
85 | " a[0] = np.eye(n_nodes, k=1, dtype=int) + np.eye(n_nodes, k=-1, dtype=int)\n",
86 | " a[2] = 1\n",
87 | " \n",
88 | " white_spaces = (char_seq==symbol_to_id[' ']).nonzero()\n",
89 | " start = torch.cat([white_spaces.new_tensor(torch.tensor([0])), white_spaces[1:].view(-1)])\n",
90 | " end = torch.cat([white_spaces[1:].view(-1), white_spaces.new_tensor(torch.tensor([n_nodes]))])\n",
91 | " for i in range(len(start)):\n",
92 | " a[1, start[i]:end[i], start[i]:end[i]]=1\n",
93 | " \n",
94 | " return torch.from_numpy(a).unsqueeze(0)"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {},
100 | "source": [
101 | "### Waveglow"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 3,
107 | "metadata": {
108 | "code_folding": []
109 | },
110 | "outputs": [],
111 | "source": [
112 | "waveglow_path = 'training_log/waveglow_256channels.pt'\n",
113 | "waveglow = torch.load(waveglow_path)['model']\n",
114 | "\n",
115 | "for m in waveglow.modules():\n",
116 | " if 'Conv' in str(type(m)):\n",
117 | " setattr(m, 'padding_mode', 'zeros')\n",
118 | "\n",
119 | "waveglow.cuda().eval()\n",
120 | "for k in waveglow.convinv:\n",
121 | " k.float()\n",
122 | "\n",
123 | "denoiser = Denoiser(waveglow)\n",
124 | "\n",
125 | "with open('filelists/ljs_audio_text_val_filelist.txt', 'r') as f:\n",
126 | " lines = [line.split('|') for line in f.read().splitlines()]"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 4,
132 | "metadata": {
133 | "scrolled": false
134 | },
135 | "outputs": [],
136 | "source": [
137 | "data_type='char'\n",
138 | "for test_case in ['graph-tts-char', 'graph-tts-char_iter5', 'gae-char']:\n",
139 | " for step in ['20000', '50000', '100000', '200000']:\n",
140 | " checkpoint_path = f\"training_log/{test_case}/checkpoint_{step}\"\n",
141 | " state_dict = {}\n",
142 | " for k, v in torch.load(checkpoint_path)['state_dict'].items():\n",
143 | " state_dict[k[7:]]=v\n",
144 | " \n",
145 | " if 'gae' in test_case:\n",
146 | " model = GAE(hparams).cuda()\n",
147 | " model.load_state_dict(state_dict)\n",
148 | " _ = model.cuda().eval()\n",
149 | " else:\n",
150 | " model = GraphTTS(hparams).cuda()\n",
151 | " model.load_state_dict(state_dict)\n",
152 | " _ = model.cuda().eval()\n",
153 | "\n",
154 | " for i in [1, 6, 22]:\n",
155 | " file_name, _, text = lines[i]\n",
156 | " sequence = np.array(text2seq(text,data_type))[None, :]\n",
157 | " sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n",
158 | " adj_matrix = create_adjacency_matrix(sequence).cuda().long()\n",
159 | " adj_matrix = torch.cat([adj_matrix, adj_matrix], dim=1)\n",
160 | " adj_matrix = adj_matrix.transpose(1, 2).reshape(-1, sequence.size(1), sequence.size(1)*3*2)\n",
161 | "\n",
162 | " with torch.no_grad():\n",
163 | " melspec, dec_alignments, enc_dec_alignments, stop = model.inference(sequence,\n",
164 | " adj_matrix,\n",
165 | " max_len=1024)\n",
166 | " melspec = melspec[:,:,:len(stop)]\n",
167 | " audio = waveglow.infer(melspec, sigma=0.666)\n",
168 | "\n",
169 | " soundfile.write(f'wavs/{test_case}/{file_name}_{data_type}{step}.wav', audio.cpu().numpy()[0].astype(float), 22050)"
170 | ]
171 | }
172 | ],
173 | "metadata": {
174 | "kernelspec": {
175 | "display_name": "Python [conda env:LYH] *",
176 | "language": "python",
177 | "name": "conda-env-LYH-py"
178 | },
179 | "language_info": {
180 | "codemirror_mode": {
181 | "name": "ipython",
182 | "version": 3
183 | },
184 | "file_extension": ".py",
185 | "mimetype": "text/x-python",
186 | "name": "python",
187 | "nbconvert_exporter": "python",
188 | "pygments_lexer": "ipython3",
189 | "version": "3.7.5"
190 | },
191 | "varInspector": {
192 | "cols": {
193 | "lenName": 16,
194 | "lenType": 16,
195 | "lenVar": 40
196 | },
197 | "kernels_config": {
198 | "python": {
199 | "delete_cmd_postfix": "",
200 | "delete_cmd_prefix": "del ",
201 | "library": "var_list.py",
202 | "varRefreshCmd": "print(var_dic_list())"
203 | },
204 | "r": {
205 | "delete_cmd_postfix": ") ",
206 | "delete_cmd_prefix": "rm(",
207 | "library": "var_list.r",
208 | "varRefreshCmd": "cat(var_dic_list()) "
209 | }
210 | },
211 | "types_to_exclude": [
212 | "module",
213 | "function",
214 | "builtin_function_or_method",
215 | "instance",
216 | "_Feature"
217 | ],
218 | "window_display": false
219 | }
220 | },
221 | "nbformat": 4,
222 | "nbformat_minor": 2
223 | }
224 |
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | from text import symbols
2 |
3 | ################################
4 | # Experiment Parameters #
5 | ################################
6 | seed=1234
7 | n_gpus=2
8 | output_directory = 'training_log'
9 | log_directory = 'graph-tts-char'
10 | data_path = '/media/disk1/lyh/LJSpeech-1.1/preprocessed'
11 |
12 | training_files='filelists/ljs_audio_text_train_filelist.txt'
13 | validation_files='filelists/ljs_audio_text_val_filelist.txt'
14 | text_cleaners=['english_cleaners']
15 |
16 |
17 | ################################
18 | # Audio Parameters #
19 | ################################
20 | sampling_rate=22050
21 | filter_length=1024
22 | hop_length=256
23 | win_length=1024
24 | n_mel_channels=80
25 | mel_fmin=0
26 | mel_fmax=8000.0
27 |
28 | ################################
29 | # Model Parameters #
30 | ################################
31 | n_symbols=len(symbols)
32 | data_type='char_seq' # 'phone_seq'
33 | symbols_embedding_dim=256
34 | hidden_dim=256
35 | dprenet_dim=256
36 | postnet_dim=256
37 | ff_dim=1024
38 | n_heads=4
39 | n_layers=6
40 | n_postnet_layers=5
41 | iterations=1
42 |
43 | ################################
44 | # Optimization Hyperparameters #
45 | ################################
46 | lr=384**-0.5
47 | warmup_steps=4000
48 | grad_clip_thresh=1.0
49 | batch_size=32
50 | accumulation=2
51 | iters_per_validation=2000
52 | iters_per_checkpoint=10000
53 | train_steps = 200000
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 | <!-- index.html: audio-sample demo page; the markup was largely lost in extraction, so only
2 |      the recoverable structure is kept here. Title: "Graph-TTS Audio Samples". The page holds
3 |      three tables (graph-tts, graph-tts-iter5, gae), each with one row per utterance
4 |      (LJ001-0029, LJ001-0085, LJ002-0106) and one audio player per checkpoint
5 |      (steps 20000 / 50000 / 100000 / 200000), pointing at the corresponding files under wavs/. -->
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from librosa.filters import mel as librosa_mel_fn
3 | from audio_processing import dynamic_range_compression
4 | from audio_processing import dynamic_range_decompression
5 | from stft import STFT
6 |
7 |
8 | class LinearNorm(torch.nn.Module):
9 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
10 | super(LinearNorm, self).__init__()
11 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
12 |
13 | torch.nn.init.xavier_uniform_(
14 | self.linear_layer.weight,
15 | gain=torch.nn.init.calculate_gain(w_init_gain))
16 |
17 | def forward(self, x):
18 | return self.linear_layer(x)
19 |
20 |
21 | class ConvNorm(torch.nn.Module):
22 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
23 | padding=None, dilation=1, bias=True, w_init_gain='linear'):
24 | super(ConvNorm, self).__init__()
25 | if padding is None:
26 | assert(kernel_size % 2 == 1)
27 | padding = int(dilation * (kernel_size - 1) / 2)
28 |
29 | self.conv = torch.nn.Conv1d(in_channels, out_channels,
30 | kernel_size=kernel_size, stride=stride,
31 | padding=padding, dilation=dilation,
32 | bias=bias)
33 |
34 | torch.nn.init.xavier_uniform_(
35 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
36 |
37 | def forward(self, signal):
38 | conv_signal = self.conv(signal)
39 | return conv_signal
40 |
41 |
42 | class TacotronSTFT(torch.nn.Module):
43 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
44 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
45 | mel_fmax=8000.0):
46 | super(TacotronSTFT, self).__init__()
47 | self.n_mel_channels = n_mel_channels
48 | self.sampling_rate = sampling_rate
49 | self.stft_fn = STFT(filter_length, hop_length, win_length)
50 | mel_basis = librosa_mel_fn(
51 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
52 | mel_basis = torch.from_numpy(mel_basis).float()
53 | self.register_buffer('mel_basis', mel_basis)
54 |
55 | def spectral_normalize(self, magnitudes):
56 | output = dynamic_range_compression(magnitudes)
57 | return output
58 |
59 | def spectral_de_normalize(self, magnitudes):
60 | output = dynamic_range_decompression(magnitudes)
61 | return output
62 |
63 | def mel_spectrogram(self, y):
64 | """Computes mel-spectrograms from a batch of waves
65 | PARAMS
66 | ------
67 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
68 | RETURNS
69 | -------
70 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
71 | """
72 | assert(torch.min(y.data) >= -1)
73 | assert(torch.max(y.data) <= 1)
74 |
75 | magnitudes, phases = self.stft_fn.transform(y)
76 | magnitudes = magnitudes.data
77 | mel_output = torch.matmul(self.mel_basis, magnitudes)
78 | mel_output = self.spectral_normalize(mel_output)
79 | return mel_output
--------------------------------------------------------------------------------
/modules/__pycache__/init_layer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/modules/__pycache__/init_layer.cpython-37.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/modules/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/modules/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/modules/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/modules/init_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Linear(nn.Linear):
7 | def __init__(self,
8 | in_dim,
9 | out_dim,
10 | bias=True,
11 | w_init_gain='linear'):
12 | super(Linear, self).__init__(in_dim,
13 | out_dim,
14 | bias)
15 | nn.init.xavier_uniform_(self.weight,
16 | gain=nn.init.calculate_gain(w_init_gain))
17 |
18 |
19 | class Conv1d(nn.Conv1d):
20 | def __init__(self,
21 | in_channels,
22 | out_channels,
23 | kernel_size,
24 | stride=1,
25 | padding=0,
26 | dilation=1,
27 | groups=1,
28 | bias=True,
29 | padding_mode='zeros',
30 | w_init_gain='linear'):
31 | super(Conv1d, self).__init__(in_channels,
32 | out_channels,
33 | kernel_size,
34 | stride,
35 | padding,
36 | dilation,
37 | groups,
38 | bias,
39 | padding_mode)
40 | nn.init.xavier_uniform_(self.weight,
41 | gain=nn.init.calculate_gain(w_init_gain))
--------------------------------------------------------------------------------
/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils.utils import get_mask_from_lengths
4 |
5 |
6 | class TransformerLoss(nn.Module):
7 | def __init__(self):
8 | super(TransformerLoss, self).__init__()
9 |
10 | def forward(self, pred, target, guide):
11 | mel_out, mel_out_post, gate_out = pred
12 | mel_target, gate_target = target
13 | alignments, text_lengths, mel_lengths = guide
14 |
15 | mask = ~get_mask_from_lengths(mel_lengths)
16 |
17 | mel_target = mel_target.masked_select(mask.unsqueeze(1))
18 | mel_out_post = mel_out_post.masked_select(mask.unsqueeze(1))
19 | mel_out = mel_out.masked_select(mask.unsqueeze(1))
20 |
21 | gate_target = gate_target.masked_select(mask)
22 | gate_out = gate_out.masked_select(mask)
23 |
24 | mel_loss = nn.L1Loss()(mel_out, mel_target) + nn.L1Loss()(mel_out_post, mel_target)
25 | bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(5.0))(gate_out, gate_target)
26 | guide_loss = self.guide_loss(alignments, text_lengths, mel_lengths)
27 |
28 | return mel_loss, bce_loss, guide_loss
29 |
30 |
31 | def guide_loss(self, alignments, text_lengths, mel_lengths):
32 | B, n_layers, n_heads, T, L = alignments.size()
33 |
34 | # B, T, L
35 | W = alignments.new_zeros(B, T, L)
36 | mask = alignments.new_zeros(B, T, L)
37 |
38 | for i, (t, l) in enumerate(zip(mel_lengths, text_lengths)):
39 | mel_seq = alignments.new_tensor( torch.arange(t).to(torch.float32).unsqueeze(-1)/t )
40 | text_seq = alignments.new_tensor( torch.arange(l).to(torch.float32).unsqueeze(0)/l )
41 | x = torch.pow(mel_seq-text_seq, 2)
42 | W[i, :t, :l] += alignments.new_tensor(1-torch.exp(-3.125*x))
43 | mask[i, :t, :l] = 1
44 |
45 | # Apply guided_loss to 2 heads of the last 2 layers
46 | applied_align = alignments[:, -2:, :2]
47 | losses = applied_align*(W.unsqueeze(1).unsqueeze(1))
48 |
49 | return torch.mean(losses.masked_select(mask.unsqueeze(1).unsqueeze(1).to(torch.bool)))
50 |
--------------------------------------------------------------------------------
/modules/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .init_layer import *
5 | from .transformer import *
6 | from utils.utils import get_mask_from_lengths
7 |
8 |
9 | class CBAD(nn.Module):
10 | def __init__(self,
11 | in_dim,
12 | out_dim,
13 | kernel_size,
14 | stride,
15 | padding,
16 | bias,
17 | activation,
18 | dropout):
19 | super(CBAD, self).__init__()
20 | self.conv = Conv1d(in_dim,
21 | out_dim,
22 | kernel_size=kernel_size,
23 | stride=stride,
24 | padding=padding,
25 | bias=bias,
26 | w_init_gain=activation)
27 |
28 | self.bn = nn.BatchNorm1d(out_dim)
29 |
30 | if activation == 'relu':
31 | self.activation = nn.ReLU()
32 | elif activation == 'tanh':
33 | self.activation = nn.Tanh()
34 |
35 | self.dropout = nn.Dropout(dropout)
36 |
37 | def forward(self, x):
38 | x = self.conv(x)
39 | x = self.bn(x)
40 | x = self.activation(x)
41 | out = self.dropout(x)
42 |
43 | return out
44 |
45 | class Prenet_D(nn.Module):
46 | def __init__(self, hp):
47 | super(Prenet_D, self).__init__()
48 | self.linear1 = Linear(hp.n_mel_channels,
49 | hp.dprenet_dim,
50 | w_init_gain='relu')
51 | self.linear2 = Linear(hp.dprenet_dim, hp.dprenet_dim, w_init_gain='relu')
52 | self.linear3 = Linear(hp.dprenet_dim, hp.hidden_dim)
53 |
54 | def forward(self, x):
55 | # Set training==True following tacotron2
56 | x = F.dropout(F.relu(self.linear1(x)), p=0.5, training=True)
57 | x = F.dropout(F.relu(self.linear2(x)), p=0.5, training=True)
58 | x = self.linear3(x)
59 | return x
60 |
61 | class PostNet(nn.Module):
62 | def __init__(self, hp):
63 | super(PostNet, self).__init__()
64 | conv_list = [CBAD(in_dim=hp.n_mel_channels,
65 | out_dim=hp.postnet_dim,
66 | kernel_size=5,
67 | stride=1,
68 | padding=2,
69 | bias=False,
70 | activation='tanh',
71 | dropout=0.5)]
72 |
73 | for _ in range(hp.n_postnet_layers-2):
74 | conv_list.append(CBAD(in_dim=hp.postnet_dim,
75 | out_dim=hp.postnet_dim,
76 | kernel_size=5,
77 | stride=1,
78 | padding=2,
79 | bias=False,
80 | activation='tanh',
81 | dropout=0.5))
82 |
83 | conv_list.append(nn.Sequential(nn.Conv1d(hp.postnet_dim,
84 | hp.n_mel_channels,
85 | kernel_size=5,
86 | padding=2,
87 | bias=False),
88 | nn.BatchNorm1d(hp.n_mel_channels),
89 | nn.Dropout(0.5)))
90 |
91 | self.conv=nn.ModuleList(conv_list)
92 |
93 | def forward(self, x):
94 | for conv in self.conv:
95 | x = conv(x)
96 | return x
97 |
98 | class GraphEncoder(nn.Module):
99 | def __init__(self, hp):
100 | super(GraphEncoder, self).__init__()
101 | self.iterations = hp.iterations
102 | self.forward_edges = Linear(hp.hidden_dim, hp.hidden_dim*3, bias=False)
103 | self.backward_edges = Linear(hp.hidden_dim, hp.hidden_dim*3, bias=False)
104 |
105 | self.reset_gate = nn.Sequential(
106 | Linear(hp.hidden_dim*3, hp.hidden_dim, w_init_gain='sigmoid'),
107 | nn.Sigmoid()
108 | )
109 | self.update_gate = nn.Sequential(
110 | Linear(hp.hidden_dim*3, hp.hidden_dim, w_init_gain='sigmoid'),
111 | nn.Sigmoid()
112 | )
113 | self.tansform = nn.Sequential(
114 | Linear(hp.hidden_dim*3, hp.hidden_dim, w_init_gain='tanh'),
115 | nn.Tanh()
116 | )
117 |
118 | def forward(self, x, adj_matrix):
119 | x = x.transpose(0,1)
120 | adj_matrix = adj_matrix / adj_matrix.sum(dim=-1).unsqueeze(-1) # B, N, 6N
121 | A_in = adj_matrix[:, :, :adj_matrix.size(2)//2].to(torch.float) # B, N, 3N
122 | A_out = adj_matrix[:, :, adj_matrix.size(2)//2:].to(torch.float) # B, N, 3N
123 |
124 | for k in range(self.iterations):
125 | H_in = torch.cat(torch.chunk(self.forward_edges(x), chunks=3, dim=-1), dim=1)
126 | H_out = torch.cat(torch.chunk(self.backward_edges(x), chunks=3, dim=-1), dim=1)
127 |
128 | a_in = torch.bmm(A_in, H_in)
129 | a_out = torch.bmm(A_out, H_out)
130 | a = torch.cat((a_in, a_out, x), dim=-1)
131 |
132 | r = self.reset_gate(a)
133 | z = self.update_gate(a)
134 | joined_input = torch.cat((a_in, a_out, r * x), dim=-1)
135 | h_hat = self.tansform(joined_input)
136 |
137 | x = (1 - z) * x + z * h_hat
138 |
139 | return x.transpose(0,1)
140 |
141 |
142 | class GraphTTS(nn.Module):
143 | def __init__(self, hp):
144 | super(GraphTTS, self).__init__()
145 | self.hp = hp
146 | self.Embedding = nn.Embedding(hp.n_symbols, hp.symbols_embedding_dim)
147 | self.Prenet_D = Prenet_D(hp)
148 |
149 | self.register_buffer('pe', PositionalEncoding(hp.hidden_dim).pe)
150 | self.dropout = nn.Dropout(0.1)
151 |
152 | self.Encoder = GraphEncoder(hp)
153 | self.Decoder = nn.ModuleList([TransformerDecoderLayer(d_model=hp.hidden_dim,
154 | nhead=hp.n_heads,
155 | dim_feedforward=hp.ff_dim)
156 | for _ in range(hp.n_layers)])
157 |
158 | self.Projection = Linear(hp.hidden_dim, hp.n_mel_channels)
159 | self.Postnet = PostNet(hp)
160 | self.Stop = nn.Linear(hp.n_mel_channels, 1)
161 |
162 |
163 | def outputs(self, text, adj_matrix, melspec, text_lengths, mel_lengths):
164 | ### Size ###
165 | B, L, T = text.size(0), text.size(1), melspec.size(2)
166 | adj_matrix = torch.cat([adj_matrix, adj_matrix], dim=1)
167 | adj_matrix = adj_matrix.transpose(1, 2).reshape(-1, text_lengths.max().item(), text_lengths.max().item()*3*2)
168 |
169 | ### Prepare Encoder Input ###
170 | encoder_input = self.Embedding(text).transpose(0,1)
171 | encoder_input += self.pe[:L].unsqueeze(1)
172 | encoder_input = self.dropout(encoder_input)
173 | memory = self.Encoder(encoder_input, adj_matrix)
174 |
175 | ### Prepare Decoder Input ###
176 | mel_input = F.pad(melspec, (1,-1)).transpose(1,2)
177 | decoder_input = self.Prenet_D(mel_input).transpose(0,1)
178 | decoder_input += self.pe[:T].unsqueeze(1)
179 | decoder_input = self.dropout(decoder_input)
180 |
181 | ### Prepare Masks ###
182 | text_mask = get_mask_from_lengths(text_lengths)
183 | mel_mask = get_mask_from_lengths(mel_lengths)
184 | diag_mask = torch.triu(melspec.new_ones(T,T)).transpose(0, 1)
185 | diag_mask[diag_mask == 0] = -float('inf')
186 | diag_mask[diag_mask == 1] = 0
187 |
188 | ### Decoding ###
189 | tgt = decoder_input
190 | dec_alignments, enc_dec_alignments = [], []
191 | for layer in self.Decoder:
192 | tgt, dec_align, enc_dec_align = layer(tgt,
193 | memory,
194 | tgt_mask=diag_mask,
195 | tgt_key_padding_mask=mel_mask,
196 | memory_key_padding_mask=text_mask)
197 | dec_alignments.append(dec_align.unsqueeze(1))
198 | enc_dec_alignments.append(enc_dec_align.unsqueeze(1))
199 | dec_alignments = torch.cat(dec_alignments, 1)
200 | enc_dec_alignments = torch.cat(enc_dec_alignments, 1)
201 |
202 | ### Projection + PostNet ###
203 | mel_out = self.Projection(tgt.transpose(0, 1)).transpose(1, 2)
204 | mel_out_post = self.Postnet(mel_out) + mel_out
205 |
206 | gate_out = self.Stop(mel_out.transpose(1, 2)).squeeze(-1)
207 |
208 | return mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out
209 |
210 |
211 | def forward(self, text, adj_matrix, melspec, gate, text_lengths, mel_lengths, criterion):
212 | ### Size ###
213 | text = text[:,:text_lengths.max().item()]
214 | adj_matrix = adj_matrix[:, :, :text_lengths.max().item(), :text_lengths.max().item()]
215 |
216 | melspec = melspec[:,:,:mel_lengths.max().item()]
217 | gate = gate[:, :mel_lengths.max().item()]
218 | outputs = self.outputs(text, adj_matrix, melspec, text_lengths, mel_lengths)
219 |
220 | mel_out, mel_out_post = outputs[0], outputs[1]
221 | enc_dec_alignments = outputs[3]
222 | gate_out=outputs[4]
223 |
224 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out),
225 | (melspec, gate),
226 | (enc_dec_alignments, text_lengths, mel_lengths))
227 |
228 | return mel_loss, bce_loss, guide_loss
229 |
230 |
231 | def inference(self, text, adj_matrix, max_len=1024):
232 | ### Size & Length ###
233 | (B, L), T = text.size(), max_len
234 |
235 | ### Prepare Inputs ###
236 | encoder_input = self.Embedding(text).transpose(0,1).contiguous()
237 | encoder_input += self.pe[:L].unsqueeze(1)
238 | memory = self.Encoder(encoder_input, adj_matrix)
239 |
240 | ### Prepare Masks ###
241 | text_mask = text.new_zeros(1, L).to(torch.bool)
242 | mel_mask = text.new_zeros(1, T).to(torch.bool)
243 | diag_mask = torch.triu(text.new_ones(T, T)).transpose(0, 1).contiguous()
244 | diag_mask[diag_mask == 0] = -1e9
245 | diag_mask[diag_mask == 1] = 0
246 |
247 | ### Transformer Decoder ###
248 | mel_input = text.new_zeros(1,
249 | self.hp.n_mel_channels,
250 | max_len).to(torch.float32)
251 | dec_alignments = text.new_zeros(self.hp.n_layers,
252 | self.hp.n_heads,
253 | max_len,
254 | max_len).to(torch.float32)
255 | enc_dec_alignments = text.new_zeros(self.hp.n_layers,
256 | self.hp.n_heads,
257 | max_len,
258 | text.size(1)).to(torch.float32)
259 |
260 | ### Generation ###
261 | stop=[]
262 | for i in range(max_len):
263 | tgt = self.Prenet_D(mel_input.transpose(1,2).contiguous()).transpose(0,1).contiguous()
264 | tgt += self.pe[:T].unsqueeze(1)
265 |
266 | for j, layer in enumerate(self.Decoder):
267 | tgt, dec_align, enc_dec_align = layer(tgt,
268 | memory,
269 | tgt_mask=diag_mask,
270 | tgt_key_padding_mask=mel_mask,
271 | memory_key_padding_mask=text_mask)
272 | dec_alignments[j, :, i] = dec_align[0, :, i]
273 | enc_dec_alignments[j, :, i] = enc_dec_align[0, :, i]
274 |
275 | mel_out = self.Projection(tgt.transpose(0,1).contiguous())
276 | stop.append(torch.sigmoid(self.Stop(mel_out[:,i]))[0,0].item())
277 |
278 | if i < max_len - 1:
279 | mel_input[0, :, i+1] = mel_out[0, i]
280 |
281 | if stop[-1]>0.5:
282 | break
283 |
284 | mel_out_post = self.Postnet(mel_out.transpose(1, 2).contiguous())
285 | mel_out_post = mel_out_post.transpose(1, 2).contiguous() + mel_out
286 | mel_out_post = mel_out_post.transpose(1, 2).contiguous()
287 |
288 | return mel_out_post, dec_alignments, enc_dec_alignments, stop
289 |
290 |
291 | class GAE(nn.Module):
292 | def __init__(self, hp):
293 | super(GAE, self).__init__()
294 | self.hp = hp
295 | self.Embedding = nn.Embedding(hp.n_symbols, hp.symbols_embedding_dim)
296 | self.Prenet_D = Prenet_D(hp)
297 |
298 | self.register_buffer('pe', PositionalEncoding(hp.hidden_dim).pe)
299 | self.dropout = nn.Dropout(0.1)
300 |
301 | self.Encoder1 = nn.ModuleList([TransformerEncoderLayer(d_model=hp.hidden_dim,
302 | nhead=hp.n_heads,
303 | dim_feedforward=hp.ff_dim)
304 | for _ in range(hp.n_layers)])
305 |
306 | self.Encoder2 = GraphEncoder(hp)
307 | self.linear = nn.Linear(hp.hidden_dim*2, hp.hidden_dim)
308 | self.Decoder = nn.ModuleList([TransformerDecoderLayer(d_model=hp.hidden_dim,
309 | nhead=hp.n_heads,
310 | dim_feedforward=hp.ff_dim)
311 | for _ in range(hp.n_layers)])
312 |
313 | self.Projection = Linear(hp.hidden_dim, hp.n_mel_channels)
314 | self.Postnet = PostNet(hp)
315 | self.Stop = nn.Linear(hp.n_mel_channels, 1)
316 |
317 |
318 | def outputs(self, text, adj_matrix, melspec, text_lengths, mel_lengths):
319 | ### Size ###
320 | B, L, T = text.size(0), text.size(1), melspec.size(2)
321 | adj_matrix = torch.cat([adj_matrix, adj_matrix], dim=1)
322 | adj_matrix = adj_matrix.transpose(1, 2).reshape(-1, text_lengths.max().item(), text_lengths.max().item()*3*2)
323 |
324 | ### Prepare Encoder Input ###
325 | encoder_input = self.Embedding(text).transpose(0,1)
326 | encoder_input += self.pe[:L].unsqueeze(1)
327 | encoder_input = self.dropout(encoder_input)
328 |
329 | ### Transformer Encoder ###
330 | memory1 = encoder_input
331 | enc_alignments = []
332 | text_mask = get_mask_from_lengths(text_lengths)
333 | for layer in self.Encoder1:
334 | memory1, enc_align = layer(memory1, src_key_padding_mask=text_mask)
335 | enc_alignments.append(enc_align.unsqueeze(1))
336 | enc_alignments = torch.cat(enc_alignments, 1)
337 |
338 | memory2 = self.Encoder2(encoder_input, adj_matrix)
339 | memory = self.linear(torch.cat([memory1, memory2], dim=-1))
340 |
341 | ### Prepare Decoder Input ###
342 | mel_input = F.pad(melspec, (1,-1)).transpose(1,2)
343 | decoder_input = self.Prenet_D(mel_input).transpose(0,1)
344 | decoder_input += self.pe[:T].unsqueeze(1)
345 | decoder_input = self.dropout(decoder_input)
346 |
347 | ### Prepare Masks ###
348 | mel_mask = get_mask_from_lengths(mel_lengths)
349 | diag_mask = torch.triu(melspec.new_ones(T,T)).transpose(0, 1)
350 | diag_mask[diag_mask == 0] = -float('inf')
351 | diag_mask[diag_mask == 1] = 0
352 |
353 | ### Decoding ###
354 | tgt = decoder_input
355 | dec_alignments, enc_dec_alignments = [], []
356 | for layer in self.Decoder:
357 | tgt, dec_align, enc_dec_align = layer(tgt,
358 | memory,
359 | tgt_mask=diag_mask,
360 | tgt_key_padding_mask=mel_mask,
361 | memory_key_padding_mask=text_mask)
362 | dec_alignments.append(dec_align.unsqueeze(1))
363 | enc_dec_alignments.append(enc_dec_align.unsqueeze(1))
364 | dec_alignments = torch.cat(dec_alignments, 1)
365 | enc_dec_alignments = torch.cat(enc_dec_alignments, 1)
366 |
367 | ### Projection + PostNet ###
368 | mel_out = self.Projection(tgt.transpose(0, 1)).transpose(1, 2)
369 | mel_out_post = self.Postnet(mel_out) + mel_out
370 |
371 | gate_out = self.Stop(mel_out.transpose(1, 2)).squeeze(-1)
372 |
373 | return mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out
374 |
375 |
376 | def forward(self, text, adj_matrix, melspec, gate, text_lengths, mel_lengths, criterion):
377 | ### Size ###
378 | text = text[:,:text_lengths.max().item()]
379 | adj_matrix = adj_matrix[:, :, :text_lengths.max().item(), :text_lengths.max().item()]
380 |
381 | melspec = melspec[:,:,:mel_lengths.max().item()]
382 | gate = gate[:, :mel_lengths.max().item()]
383 | outputs = self.outputs(text, adj_matrix, melspec, text_lengths, mel_lengths)
384 |
385 | mel_out, mel_out_post = outputs[0], outputs[1]
386 | enc_dec_alignments = outputs[3]
387 | gate_out=outputs[4]
388 |
389 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out),
390 | (melspec, gate),
391 | (enc_dec_alignments, text_lengths, mel_lengths))
392 |
393 | return mel_loss, bce_loss, guide_loss
394 |
395 |
396 | def inference(self, text,adj_matrix, max_len=1024):
397 | ### Size & Length ###
398 | (B, L), T = text.size(), max_len
399 |
400 | ### Prepare Inputs ###
401 | encoder_input = self.Embedding(text).transpose(0,1).contiguous()
402 | encoder_input += self.pe[:L].unsqueeze(1)
403 |
404 | memory1 = encoder_input
405 | enc_alignments = []
406 | text_mask = text.new_zeros(1, L).to(torch.bool)
407 | for layer in self.Encoder1:
408 | memory1, enc_align = layer(memory1, src_key_padding_mask=text_mask)
409 | enc_alignments.append(enc_align.unsqueeze(1))
410 | enc_alignments = torch.cat(enc_alignments, 1)
411 |
412 | memory2 = self.Encoder2(encoder_input, adj_matrix)
413 | memory = self.linear(torch.cat([memory1, memory2], dim=-1))
414 |
415 | ### Prepare Masks ###
416 | mel_mask = text.new_zeros(1, T).to(torch.bool)
417 | diag_mask = torch.triu(text.new_ones(T, T)).transpose(0, 1).contiguous()
418 | diag_mask[diag_mask == 0] = -1e9
419 | diag_mask[diag_mask == 1] = 0
420 |
421 | ### Transformer Decoder ###
422 | mel_input = text.new_zeros(1,
423 | self.hp.n_mel_channels,
424 | max_len).to(torch.float32)
425 | dec_alignments = text.new_zeros(self.hp.n_layers,
426 | self.hp.n_heads,
427 | max_len,
428 | max_len).to(torch.float32)
429 | enc_dec_alignments = text.new_zeros(self.hp.n_layers,
430 | self.hp.n_heads,
431 | max_len,
432 | text.size(1)).to(torch.float32)
433 |
434 | ### Generation ###
435 | stop=[]
436 | for i in range(max_len):
437 | tgt = self.Prenet_D(mel_input.transpose(1,2).contiguous()).transpose(0,1).contiguous()
438 | tgt += self.pe[:T].unsqueeze(1)
439 |
440 | for j, layer in enumerate(self.Decoder):
441 | tgt, dec_align, enc_dec_align = layer(tgt,
442 | memory,
443 | tgt_mask=diag_mask,
444 | tgt_key_padding_mask=mel_mask,
445 | memory_key_padding_mask=text_mask)
446 | dec_alignments[j, :, i] = dec_align[0, :, i]
447 | enc_dec_alignments[j, :, i] = enc_dec_align[0, :, i]
448 |
449 | mel_out = self.Projection(tgt.transpose(0,1).contiguous())
450 | stop.append(torch.sigmoid(self.Stop(mel_out[:,i]))[0,0].item())
451 |
452 | if i < max_len - 1:
453 | mel_input[0, :, i+1] = mel_out[0, i]
454 |
455 | if stop[-1]>0.5:
456 | break
457 |
458 | mel_out_post = self.Postnet(mel_out.transpose(1, 2).contiguous())
459 | mel_out_post = mel_out_post.transpose(1, 2).contiguous() + mel_out
460 | mel_out_post = mel_out_post.transpose(1, 2).contiguous()
461 |
462 | return mel_out_post, dec_alignments, enc_dec_alignments, stop
--------------------------------------------------------------------------------
/modules/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .init_layer import *
5 |
6 |
7 | class TransformerEncoderLayer(nn.Module):
8 | def __init__(self,
9 | d_model,
10 | nhead,
11 | dim_feedforward=2048,
12 | dropout=0.1,
13 | activation="relu"):
14 | super(TransformerEncoderLayer, self).__init__()
15 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
16 |
17 | self.linear1 = Linear(d_model, dim_feedforward, w_init_gain=activation)
18 | self.linear2 = Linear(dim_feedforward, d_model)
19 |
20 | self.norm1 = nn.LayerNorm(d_model)
21 | self.norm2 = nn.LayerNorm(d_model)
22 |
23 | self.dropout = nn.Dropout(dropout)
24 |
25 | def forward(self, src, src_mask=None, src_key_padding_mask=None):
26 | src2, enc_align = self.self_attn(src,
27 | src,
28 | src,
29 | attn_mask=src_mask,
30 | key_padding_mask=src_key_padding_mask)
31 | src = src + self.dropout(src2)
32 | src = self.norm1(src)
33 |
34 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
35 | src = src + self.dropout(src2)
36 | src = self.norm2(src)
37 |
38 | return src, enc_align
39 |
40 |
41 | class TransformerDecoderLayer(nn.Module):
42 | def __init__(self,
43 | d_model,
44 | nhead,
45 | dim_feedforward=2048,
46 | dropout=0.1,
47 | activation="relu"):
48 | super(TransformerDecoderLayer, self).__init__()
49 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
50 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
51 |
52 | self.linear1 = Linear(d_model, dim_feedforward, w_init_gain=activation)
53 | self.linear2 = Linear(dim_feedforward, d_model)
54 |
55 | self.norm1 = nn.LayerNorm(d_model)
56 | self.norm2 = nn.LayerNorm(d_model)
57 | self.norm3 = nn.LayerNorm(d_model)
58 |
59 | self.dropout = nn.Dropout(dropout)
60 |
61 | def forward(self,
62 | tgt,
63 | memory,
64 | tgt_mask=None,
65 | memory_mask=None,
66 | tgt_key_padding_mask=None,
67 | memory_key_padding_mask=None):
68 | tgt2, dec_align = self.self_attn(tgt,
69 | tgt,
70 | tgt,
71 | attn_mask=tgt_mask,
72 | key_padding_mask=tgt_key_padding_mask)
73 | tgt = tgt + self.dropout(tgt2)
74 | tgt = self.norm1(tgt)
75 |
76 | tgt2, enc_dec_align = self.multihead_attn(tgt,
77 | memory,
78 | memory,
79 | attn_mask=memory_mask,
80 | key_padding_mask=memory_key_padding_mask)
81 | tgt = tgt + self.dropout(tgt2)
82 | tgt = self.norm2(tgt)
83 |
84 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
85 | tgt = tgt + self.dropout(tgt2)
86 | tgt = self.norm3(tgt)
87 |
88 | return tgt, dec_align, enc_dec_align
89 |
90 |
91 | class PositionalEncoding(nn.Module):
92 | def __init__(self, d_model, max_len=5000):
93 | super(PositionalEncoding, self).__init__()
94 | self.register_buffer('pe', self._get_pe_matrix(d_model, max_len))
95 |
96 | def forward(self, x):
97 | return x + self.pe[:x.size(0)].unsqueeze(1)
98 |
99 | def _get_pe_matrix(self, d_model, max_len):
100 | pe = torch.zeros(max_len, d_model)
101 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
102 | div_term = torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model)
103 |
104 | pe[:, 0::2] = torch.sin(position / div_term)
105 | pe[:, 1::2] = torch.cos(position / div_term)
106 |
107 | return pe
108 |
--------------------------------------------------------------------------------
/prepare_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### Import libraries, metadata"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {
14 | "ExecuteTime": {
15 | "end_time": "2019-12-18T06:01:56.263558Z",
16 | "start_time": "2019-12-18T06:01:51.717351Z"
17 | }
18 | },
19 | "outputs": [
20 | {
21 | "name": "stderr",
22 | "output_type": "stream",
23 | "text": [
24 | "/home/lyh/anaconda3/envs/LYH/lib/python3.7/site-packages/librosa/util/decorators.py:9: NumbaDeprecationWarning: \u001b[1mAn import was requested from a module that has moved location.\n",
25 | "Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.\u001b[0m\n",
26 | " from numba.decorators import jit as optional_jit\n",
27 | "/home/lyh/anaconda3/envs/LYH/lib/python3.7/site-packages/librosa/util/decorators.py:9: NumbaDeprecationWarning: \u001b[1mAn import was requested from a module that has moved location.\n",
28 | "Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.\u001b[0m\n",
29 | " from numba.decorators import jit as optional_jit\n"
30 | ]
31 | }
32 | ],
33 | "source": [
34 | "import os\n",
35 | "import librosa\n",
36 | "from librosa.filters import mel as librosa_mel_fn\n",
37 | "import pickle as pkl\n",
38 | "import IPython.display as ipd\n",
39 | "from tqdm.notebook import tqdm\n",
40 | "import torch\n",
41 | "import numpy as np\n",
42 | "import codecs\n",
43 | "import matplotlib.pyplot as plt\n",
44 | "%matplotlib inline\n",
45 | "\n",
46 | "from g2p_en import G2p\n",
47 | "from text import *\n",
48 | "from text import cmudict\n",
49 | "from text.cleaners import custom_english_cleaners\n",
50 | "from text.symbols import symbols\n",
51 | "\n",
52 | "# Mappings from symbol to numeric ID and vice versa:\n",
53 | "symbol_to_id = {s: i for i, s in enumerate(symbols)}\n",
54 | "id_to_symbol = {i: s for i, s in enumerate(symbols)}\n",
55 | "\n",
56 | "csv_file = '/media/disk1/lyh/LJSpeech-1.1/metadata.csv'\n",
57 | "root_dir = '/media/disk1/lyh/LJSpeech-1.1/wavs'\n",
58 | "data_dir = '/media/disk1/lyh/LJSpeech-1.1/preprocessed'\n",
59 | "\n",
60 | "g2p = G2p()\n",
61 | "metadata={}\n",
62 | "with codecs.open(csv_file, 'r', 'utf-8') as fid:\n",
63 | " for line in fid.readlines():\n",
64 | " id, _, text = line.split(\"|\")\n",
65 | " \n",
66 | " clean_char = custom_english_cleaners(text.rstrip())\n",
67 | " clean_phone = []\n",
68 | " for s in g2p(clean_char.lower()):\n",
69 | " if '@'+s in symbol_to_id:\n",
70 | " clean_phone.append('@'+s)\n",
71 | " else:\n",
72 | " clean_phone.append(s)\n",
73 | " \n",
74 | " metadata[id]={'char':clean_char,\n",
75 | " 'phone':clean_phone}"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {},
81 | "source": [
82 | "### Others"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {
89 | "scrolled": true
90 | },
91 | "outputs": [],
92 | "source": [
93 | "from layers import TacotronSTFT\n",
94 | "stft = TacotronSTFT()\n",
95 | "\n",
96 | "def text2seq(text):\n",
97 | " sequence=[symbol_to_id['^']]\n",
98 | " sequence.extend([symbol_to_id[c] for c in text])\n",
99 | " sequence.append(symbol_to_id['~'])\n",
100 | " return sequence\n",
101 | "\n",
102 | "def create_adjacency_matrix(char_seq):\n",
103 | " n_nodes=len(char_seq)\n",
104 | " n_edge_types=3\n",
105 | " \n",
106 | " a = np.zeros([n_edge_types, n_nodes, n_nodes], dtype=np.int)\n",
107 | " \n",
108 | " a[0] = np.eye(n_nodes, k=1, dtype=int) + np.eye(n_nodes, k=-1, dtype=int)\n",
109 | " a[2] = 1\n",
110 | " \n",
111 | " white_spaces = (char_seq==symbol_to_id[' ']).nonzero()\n",
112 | " start = torch.cat([torch.tensor([0]), white_spaces[1:].view(-1)])\n",
113 | " end = torch.cat([white_spaces[1:].view(-1), torch.tensor([n_nodes])])\n",
114 | " for i in range(len(start)):\n",
115 | " a[1, start[i]:end[i], start[i]:end[i]]=1\n",
116 | " \n",
117 | " # a = np.concatenate((a, a), axis=0)\n",
118 | " # a = np.transpose(a, (1, 0, 2)).reshape(n_nodes, n_nodes*n_edge_types*2)\n",
119 | " return a\n",
120 | "\n",
121 | "def get_mel(filename):\n",
122 | " wav, sr = librosa.load(filename, sr=22050)\n",
123 | " wav = torch.FloatTensor(wav.astype(np.float32))\n",
124 | " melspec = stft.mel_spectrogram(wav.unsqueeze(0))\n",
125 | " return melspec.squeeze(0), wav\n",
126 | "\n",
127 | "def save_file(fname):\n",
128 | " wav_name = os.path.join(root_dir, fname) + '.wav'\n",
129 | " text = metadata[fname]['char']\n",
130 | " char_seq = torch.LongTensor( text2seq(metadata[fname]['char']) )\n",
131 | " phone_seq = torch.LongTensor( text2seq(metadata[fname]['phone']) )\n",
132 | " adj_matrix = create_adjacency_matrix(char_seq)\n",
133 | " melspec, wav = get_mel(wav_name)\n",
134 | " \n",
135 | " with open(f'{data_dir}/char_seq/{fname}_sequence.pkl', 'wb') as f:\n",
136 | " pkl.dump(char_seq, f)\n",
137 | " with open(f'{data_dir}/phone_seq/{fname}_sequence.pkl', 'wb') as f:\n",
138 | " pkl.dump(phone_seq, f)\n",
139 | " with open(f'{data_dir}/adj_matrix/{fname}_adj_matrix.pkl', 'wb') as f:\n",
140 | " pkl.dump(adj_matrix, f)\n",
141 | " with open(f'{data_dir}/melspectrogram/{fname}_melspectrogram.pkl', 'wb') as f:\n",
142 | " pkl.dump(melspec, f)\n",
143 | " \n",
144 | " return text, char_seq, phone_seq, adj_matrix, melspec, wav"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "### Save and Inspect Data"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {
158 | "ExecuteTime": {
159 | "start_time": "2019-12-18T06:01:23.402Z"
160 | },
161 | "code_folding": [],
162 | "scrolled": true
163 | },
164 | "outputs": [],
165 | "source": [
166 | "for k in tqdm(metadata.keys()):\n",
167 | " text, char_seq, phone_seq, adj_matrix, melspec, wav = save_file(k)\n",
168 | " if k == 'LJ001-0019':\n",
169 | " print(\"Text:\")\n",
170 | " print(text)\n",
171 | " print()\n",
172 | " print(\"Melspectrogram:\")\n",
173 | " plt.figure(figsize=(16,4))\n",
174 | " plt.imshow(melspec, aspect='auto', origin='lower')\n",
175 | " plt.show()"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 2,
181 | "metadata": {},
182 | "outputs": [
183 | {
184 | "data": {
185 | "application/vnd.jupyter.widget-view+json": {
186 | "model_id": "34e6927c94f245669e35ed76ef0fbff0",
187 | "version_major": 2,
188 | "version_minor": 0
189 | },
190 | "text/plain": [
191 | "HBox(children=(FloatProgress(value=0.0, max=13100.0), HTML(value='')))"
192 | ]
193 | },
194 | "metadata": {},
195 | "output_type": "display_data"
196 | },
197 | {
198 | "name": "stderr",
199 | "output_type": "stream",
200 | "text": [
201 | "/home/lyh/anaconda3/envs/LYH/lib/python3.7/site-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n",
202 | " warnings.warn(\"pickle support for Storage will be removed in 1.5. Use `torch.save` instead\", FutureWarning)\n"
203 | ]
204 | },
205 | {
206 | "name": "stdout",
207 | "output_type": "stream",
208 | "text": [
209 | "\n"
210 | ]
211 | }
212 | ],
213 | "source": [
214 | "def text2seq(text):\n",
215 | " sequence=[symbol_to_id['^']]\n",
216 | " sequence.extend([symbol_to_id[c] for c in text])\n",
217 | " sequence.append(symbol_to_id['~'])\n",
218 | " return sequence\n",
219 | "\n",
220 | "def create_adjacency_matrix(char_seq):\n",
221 | " n_nodes=len(char_seq)\n",
222 | " n_edge_types=3\n",
223 | " \n",
224 | " a = np.zeros([n_edge_types, n_nodes, n_nodes], dtype=np.int)\n",
225 | " \n",
226 | " a[0] = np.eye(n_nodes, k=1, dtype=int) + np.eye(n_nodes, k=-1, dtype=int)\n",
227 | " a[2] = 1\n",
228 | " \n",
229 | " white_spaces = (char_seq==symbol_to_id[' ']).nonzero()\n",
230 | " start = torch.cat([torch.tensor([0]), white_spaces[1:].view(-1)])\n",
231 | " end = torch.cat([white_spaces[1:].view(-1), torch.tensor([n_nodes])])\n",
232 | " for i in range(len(start)):\n",
233 | " a[1, start[i]:end[i], start[i]:end[i]]=1\n",
234 | " \n",
235 | " # a = np.concatenate((a, a), axis=0)\n",
236 | " # a = np.transpose(a, (1, 0, 2)).reshape(n_nodes, n_nodes*n_edge_types*2)\n",
237 | " return a\n",
238 | "\n",
239 | "def save_file(fname):\n",
240 | " wav_name = os.path.join(root_dir, fname) + '.wav'\n",
241 | " text = metadata[fname]['char']\n",
242 | " char_seq = torch.LongTensor( text2seq(metadata[fname]['char']) )\n",
243 | " adj_matrix = torch.LongTensor(create_adjacency_matrix(char_seq))\n",
244 | " \n",
245 | " with open(f'{data_dir}/adj_matrix/{fname}_adj_matrix.pkl', 'wb') as f:\n",
246 | " pkl.dump(adj_matrix, f)\n",
247 | "\n",
248 | "\n",
249 | "for k in tqdm(metadata.keys()):\n",
250 | " save_file(k)"
251 | ]
252 | }
253 | ],
254 | "metadata": {
255 | "kernelspec": {
256 | "display_name": "Python [conda env:LYH] *",
257 | "language": "python",
258 | "name": "conda-env-LYH-py"
259 | },
260 | "language_info": {
261 | "codemirror_mode": {
262 | "name": "ipython",
263 | "version": 3
264 | },
265 | "file_extension": ".py",
266 | "mimetype": "text/x-python",
267 | "name": "python",
268 | "nbconvert_exporter": "python",
269 | "pygments_lexer": "ipython3",
270 | "version": "3.7.5"
271 | },
272 | "varInspector": {
273 | "cols": {
274 | "lenName": 16,
275 | "lenType": 16,
276 | "lenVar": 40
277 | },
278 | "kernels_config": {
279 | "python": {
280 | "delete_cmd_postfix": "",
281 | "delete_cmd_prefix": "del ",
282 | "library": "var_list.py",
283 | "varRefreshCmd": "print(var_dic_list())"
284 | },
285 | "r": {
286 | "delete_cmd_postfix": ") ",
287 | "delete_cmd_prefix": "rm(",
288 | "library": "var_list.r",
289 | "varRefreshCmd": "cat(var_dic_list()) "
290 | }
291 | },
292 | "types_to_exclude": [
293 | "module",
294 | "function",
295 | "builtin_function_or_method",
296 | "instance",
297 | "_Feature"
298 | ],
299 | "window_display": false
300 | }
301 | },
302 | "nbformat": 4,
303 | "nbformat_minor": 2
304 | }
305 |
--------------------------------------------------------------------------------
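The adjacency construction in `prepare_data.ipynb` gives every character three kinds of edges: type 0 links immediate neighbours, type 1 links characters belonging to the same word, and type 2 connects every pair of nodes. The sketch below is a minimal NumPy-only illustration of that idea (it treats every whitespace index as a word boundary; the notebook itself works on `torch.LongTensor` sequences and pickles one matrix per utterance):

```python
import numpy as np

def toy_adjacency(char_ids, space_id):
    """Illustrative 3-edge-type adjacency, mirroring the GraphTTS data prep."""
    n = len(char_ids)
    a = np.zeros((3, n, n), dtype=int)
    a[0] = np.eye(n, k=1, dtype=int) + np.eye(n, k=-1, dtype=int)  # sequential neighbours
    a[2] = 1                                                       # fully-connected edges
    bounds = [0] + [i for i, c in enumerate(char_ids) if c == space_id] + [n]
    for s, e in zip(bounds[:-1], bounds[1:]):                      # intra-word blocks
        a[1, s:e, s:e] = 1
    return a

# e.g. the ids of "AB C" with space_id = 3 -> a (3, 4, 4) array of edge indicators
print(toy_adjacency([4, 5, 3, 6], space_id=3)[1])
```
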
/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 |
4 | Copyright (c) 2017, Prem Seetharaman
5 | All rights reserved.
6 |
7 | * Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice,
11 | this list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice, this
14 | list of conditions and the following disclaimer in the
15 | documentation and/or other materials provided with the distribution.
16 |
17 | * Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived from this
19 | software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | """
32 |
33 | import torch
34 | import numpy as np
35 | import torch.nn.functional as F
36 | from torch.autograd import Variable
37 | from scipy.signal import get_window
38 | from librosa.util import pad_center, tiny
39 | from audio_processing import window_sumsquare
40 |
41 |
42 | class STFT(torch.nn.Module):
43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
44 | def __init__(self, filter_length=800, hop_length=200, win_length=800,
45 | window='hann'):
46 | super(STFT, self).__init__()
47 | self.filter_length = filter_length
48 | self.hop_length = hop_length
49 | self.win_length = win_length
50 | self.window = window
51 | self.forward_transform = None
52 | scale = self.filter_length / self.hop_length
53 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
54 |
55 | cutoff = int((self.filter_length / 2 + 1))
56 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
57 | np.imag(fourier_basis[:cutoff, :])])
58 |
59 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
60 | inverse_basis = torch.FloatTensor(
61 | np.linalg.pinv(scale * fourier_basis).T[:, None, :])
62 |
63 | if window is not None:
64 | assert(filter_length >= win_length)
65 | # get window and zero center pad it to filter_length
66 | fft_window = get_window(window, win_length, fftbins=True)
67 | fft_window = pad_center(fft_window, filter_length)
68 | fft_window = torch.from_numpy(fft_window).float()
69 |
70 | # window the bases
71 | forward_basis *= fft_window
72 | inverse_basis *= fft_window
73 |
74 | self.register_buffer('forward_basis', forward_basis.float())
75 | self.register_buffer('inverse_basis', inverse_basis.float())
76 |
77 | def transform(self, input_data):
78 | num_batches = input_data.size(0)
79 | num_samples = input_data.size(1)
80 |
81 | self.num_samples = num_samples
82 |
83 | # similar to librosa, reflect-pad the input
84 | input_data = input_data.view(num_batches, 1, num_samples)
85 | input_data = F.pad(
86 | input_data.unsqueeze(1),
87 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
88 | mode='reflect')
89 | input_data = input_data.squeeze(1)
90 |
91 | forward_transform = F.conv1d(
92 | input_data,
93 | Variable(self.forward_basis, requires_grad=False),
94 | stride=self.hop_length,
95 | padding=0)
96 |
97 | cutoff = int((self.filter_length / 2) + 1)
98 | real_part = forward_transform[:, :cutoff, :]
99 | imag_part = forward_transform[:, cutoff:, :]
100 |
101 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
102 | phase = torch.autograd.Variable(
103 | torch.atan2(imag_part.data, real_part.data))
104 |
105 | return magnitude, phase
106 |
107 | def inverse(self, magnitude, phase):
108 | recombine_magnitude_phase = torch.cat(
109 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
110 |
111 | inverse_transform = F.conv_transpose1d(
112 | recombine_magnitude_phase,
113 | Variable(self.inverse_basis, requires_grad=False),
114 | stride=self.hop_length,
115 | padding=0)
116 |
117 | if self.window is not None:
118 | window_sum = window_sumsquare(
119 | self.window, magnitude.size(-1), hop_length=self.hop_length,
120 | win_length=self.win_length, n_fft=self.filter_length,
121 | dtype=np.float32)
122 | # remove modulation effects
123 | approx_nonzero_indices = torch.from_numpy(
124 | np.where(window_sum > tiny(window_sum))[0])
125 | window_sum = torch.autograd.Variable(
126 | torch.from_numpy(window_sum), requires_grad=False)
127 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
128 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
129 |
130 | # scale by hop ratio
131 | inverse_transform *= float(self.filter_length) / self.hop_length
132 |
133 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
134 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
135 |
136 | return inverse_transform
137 |
138 | def forward(self, input_data):
139 | self.magnitude, self.phase = self.transform(input_data)
140 | reconstruction = self.inverse(self.magnitude, self.phase)
141 | return reconstruction
142 |
--------------------------------------------------------------------------------
/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | # -*- coding: utf-8 -*-
3 |
4 | import re
5 | from text import cleaners
6 | from text.symbols import symbols
7 |
8 | # Mappings from symbol to numeric ID and vice versa:
9 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
10 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
11 |
12 | # Regular expression matching text enclosed in curly braces:
13 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
14 |
15 |
16 | def text_to_sequence(text, cleaner_names):
17 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
18 |
19 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
20 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
21 |
22 | Args:
23 | text: string to convert to a sequence
24 | cleaner_names: names of the cleaner functions to run the text through
25 |
26 | Returns:
27 | List of integers corresponding to the symbols in the text
28 | '''
29 | sequence = [_symbol_to_id['^']]
30 |
31 | # Check for curly braces and treat their contents as ARPAbet:
32 | while len(text):
33 | m = _curly_re.match(text)
34 | if not m:
35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
36 | break
37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
38 | sequence += _arpabet_to_sequence(m.group(2))
39 | text = m.group(3)
40 |
41 | # Append EOS token
42 | sequence.append(_symbol_to_id['~'])
43 | return sequence
44 |
45 |
46 | def sequence_to_text(sequence):
47 | '''Converts a sequence of IDs back to a string'''
48 | result = ''
49 | for symbol_id in sequence:
50 | if symbol_id in _id_to_symbol:
51 | s = _id_to_symbol[symbol_id]
52 | # Enclose ARPAbet back in curly braces:
53 | if len(s) > 1 and s[0] == '@':
54 | s = '{%s}' % s[1:]
55 | result += s
56 | return result.replace('}{', ' ')
57 |
58 |
59 | def _clean_text(text, cleaner_names):
60 | for name in cleaner_names:
61 | cleaner = getattr(cleaners, name)
62 | if not cleaner:
63 | raise Exception('Unknown cleaner: %s' % name)
64 | text = cleaner(text)
65 | return text
66 |
67 |
68 | def _symbols_to_sequence(symbols):
69 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
70 |
71 |
72 | def _arpabet_to_sequence(text):
73 | return _symbols_to_sequence(['@' + s for s in text.split()])
74 |
75 |
76 | def _should_keep_symbol(s):
77 |     return s in _symbol_to_id and s != '_' and s != '~'
78 |
--------------------------------------------------------------------------------
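`text_to_sequence` frames the symbol IDs with the SOS (`^`) and EOS (`~`) tokens and treats anything inside curly braces as ARPAbet. A usage sketch (assuming the repository root is on `PYTHONPATH`; `custom_english_cleaners` is chosen here because the symbol set in `text/symbols.py` contains upper-case letters only):

```python
from text import text_to_sequence, sequence_to_text

ids = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.",
                       ["custom_english_cleaners"])
print(ids)                    # symbol IDs starting with '^' (SOS) and ending with '~' (EOS)
print(sequence_to_text(ids))  # roughly "^TURN LEFT ON {HH AW1 S S T AH0 N} STREET.~"
```
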
/text/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/text/__pycache__/cleaners.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/cleaners.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/cleaners.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/cleaners.cpython-37.pyc
--------------------------------------------------------------------------------
/text/__pycache__/cmudict.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/cmudict.cpython-37.pyc
--------------------------------------------------------------------------------
/text/__pycache__/numbers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/numbers.cpython-37.pyc
--------------------------------------------------------------------------------
/text/__pycache__/symbols.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/text/__pycache__/symbols.cpython-37.pyc
--------------------------------------------------------------------------------
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """This file is derived from https://github.com/keithito/tacotron.
2 |
3 | Cleaners are transformations that run over the input text at both training and eval time.
4 |
5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
7 | 1. "english_cleaners" for English text
8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
11 | the symbols in symbols.py to match your data).
12 | """
13 |
14 | import re
15 |
16 | from unidecode import unidecode
17 |
18 | from text.numbers import normalize_numbers
19 |
20 | # Regular expression matching whitespace:
21 | _whitespace_re = re.compile(r'\s+')
22 |
23 | # List of (regular expression, replacement) pairs for abbreviations:
24 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
25 | ('mrs', 'misess'),
26 | ('mr', 'mister'),
27 | ('dr', 'doctor'),
28 | ('st', 'saint'),
29 | ('co', 'company'),
30 | ('jr', 'junior'),
31 | ('maj', 'major'),
32 | ('gen', 'general'),
33 | ('drs', 'doctors'),
34 | ('rev', 'reverend'),
35 | ('lt', 'lieutenant'),
36 | ('hon', 'honorable'),
37 | ('sgt', 'sergeant'),
38 | ('capt', 'captain'),
39 | ('esq', 'esquire'),
40 | ('ltd', 'limited'),
41 | ('col', 'colonel'),
42 | ('ft', 'fort'),
43 | ]]
44 |
45 |
46 | def expand_abbreviations(text):
47 | for regex, replacement in _abbreviations:
48 | text = re.sub(regex, replacement, text)
49 | return text
50 |
51 |
52 | def expand_numbers(text):
53 | return normalize_numbers(text)
54 |
55 |
56 | def lowercase(text):
57 | return text.lower()
58 |
59 |
60 | def collapse_whitespace(text):
61 | return re.sub(_whitespace_re, ' ', text)
62 |
63 |
64 | def convert_to_ascii(text):
65 | return unidecode(text)
66 |
67 |
68 | def basic_cleaners(text):
69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
70 | text = lowercase(text)
71 | text = collapse_whitespace(text)
72 | return text
73 |
74 |
75 | def transliteration_cleaners(text):
76 | '''Pipeline for non-English text that transliterates to ASCII.'''
77 | text = convert_to_ascii(text)
78 | text = lowercase(text)
79 | text = collapse_whitespace(text)
80 | return text
81 |
82 |
83 | def english_cleaners(text):
84 | '''Pipeline for English text, including number and abbreviation expansion.'''
85 | text = convert_to_ascii(text)
86 | text = lowercase(text)
87 | text = expand_numbers(text)
88 | text = expand_abbreviations(text)
89 | text = collapse_whitespace(text)
90 | return text
91 |
92 |
93 | # NOTE (kan-bayashi): The following functions are additionally defined and are not included in the original code.
94 | def remove_unnecessary_symbols(text):
95 | # added
96 | text = re.sub(r'[\(\)\[\]\<\>\"]+', '', text)
97 | return text
98 |
99 |
100 | def expand_symbols(text):
101 | # added
102 |     text = re.sub(";", ",", text)
103 |     text = re.sub(":", ",", text)
104 |     text = re.sub("-", " ", text)
105 |     text = re.sub("&", "and", text)
106 | return text
107 |
108 |
109 | def uppercase(text):
110 | # added
111 | return text.upper()
112 |
113 |
114 | def custom_english_cleaners(text):
115 | '''Custom pipeline for English text, including number and abbreviation expansion.'''
116 | text = convert_to_ascii(text)
117 | text = lowercase(text)
118 | text = expand_numbers(text)
119 | text = expand_abbreviations(text)
120 | text = expand_symbols(text)
121 | text = remove_unnecessary_symbols(text)
122 | text = uppercase(text)
123 | text = collapse_whitespace(text)
124 |
125 |     # One exception in LJSpeech where the whole text is wrapped in single quotes:
126 | # "'NOW FOR YOU, MY POOR FELLOW MORTALS, WHO ARE ABOUT TO SUFFER THE LAST PENALTY OF THE LAW.'"
127 | if text[0]=="'" and text[-1]=="'":
128 | text = text[1:-1]
129 |
130 | return text
131 |
--------------------------------------------------------------------------------
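`custom_english_cleaners` expands numbers and abbreviations, normalises a few symbols, and upper-cases the result so it matches the character set in `text/symbols.py`. An illustrative run (output shown approximately):

```python
from text.cleaners import custom_english_cleaners

print(custom_english_cleaners("Dr. Smith paid $12 on the 2nd & left."))
# -> "DOCTOR SMITH PAID TWELVE DOLLARS ON THE SECOND AND LEFT."
```
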
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
14 | ]
15 |
16 | _valid_symbol_set = set(valid_symbols)
17 |
18 |
19 | class CMUDict:
20 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
21 | def __init__(self, file_or_path, keep_ambiguous=True):
22 | if isinstance(file_or_path, str):
23 | with open(file_or_path, encoding='latin-1') as f:
24 | entries = _parse_cmudict(f)
25 | else:
26 | entries = _parse_cmudict(file_or_path)
27 | if not keep_ambiguous:
28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
29 | self._entries = entries
30 |
31 |
32 | def __len__(self):
33 | return len(self._entries)
34 |
35 |
36 | def lookup(self, word):
37 | '''Returns list of ARPAbet pronunciations of the given word.'''
38 | return self._entries.get(word.upper())
39 |
40 |
41 |
42 | _alt_re = re.compile(r'\([0-9]+\)')
43 |
44 |
45 | def _parse_cmudict(file):
46 | cmudict = {}
47 | for line in file:
48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
49 | parts = line.split(' ')
50 | word = re.sub(_alt_re, '', parts[0])
51 | pronunciation = _get_pronunciation(parts[1])
52 | if pronunciation:
53 | if word in cmudict:
54 | cmudict[word].append(pronunciation)
55 | else:
56 | cmudict[word] = [pronunciation]
57 | return cmudict
58 |
59 |
60 | def _get_pronunciation(s):
61 | parts = s.strip().split(' ')
62 | for part in parts:
63 | if part not in _valid_symbol_set:
64 | return None
65 | return ' '.join(parts)
66 |
--------------------------------------------------------------------------------
/text/numbers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """ from https://github.com/keithito/tacotron """
3 |
4 | import inflect
5 | import re
6 |
7 |
8 | _inflect = inflect.engine()
9 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
10 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
11 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
12 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
13 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
14 | _number_re = re.compile(r'[0-9]+')
15 |
16 |
17 | def _remove_commas(m):
18 | return m.group(1).replace(',', '')
19 |
20 |
21 | def _expand_decimal_point(m):
22 | return m.group(1).replace('.', ' point ')
23 |
24 |
25 | def _expand_dollars(m):
26 | match = m.group(1)
27 | parts = match.split('.')
28 | if len(parts) > 2:
29 | return match + ' dollars' # Unexpected format
30 | dollars = int(parts[0]) if parts[0] else 0
31 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
32 | if dollars and cents:
33 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
34 | cent_unit = 'cent' if cents == 1 else 'cents'
35 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
36 | elif dollars:
37 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
38 | return '%s %s' % (dollars, dollar_unit)
39 | elif cents:
40 | cent_unit = 'cent' if cents == 1 else 'cents'
41 | return '%s %s' % (cents, cent_unit)
42 | else:
43 | return 'zero dollars'
44 |
45 |
46 | def _expand_ordinal(m):
47 | return _inflect.number_to_words(m.group(0))
48 |
49 |
50 | def _expand_number(m):
51 | num = int(m.group(0))
52 | if num > 1000 and num < 3000:
53 | if num == 2000:
54 | return 'two thousand'
55 | elif num > 2000 and num < 2010:
56 | return 'two thousand ' + _inflect.number_to_words(num % 100)
57 | elif num % 100 == 0:
58 | return _inflect.number_to_words(num // 100) + ' hundred'
59 | else:
60 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
61 | else:
62 | return _inflect.number_to_words(num, andword='')
63 |
64 |
65 | def normalize_numbers(text):
66 | text = re.sub(_comma_number_re, _remove_commas, text)
67 | text = re.sub(_pounds_re, r'\1 pounds', text)
68 | text = re.sub(_dollars_re, _expand_dollars, text)
69 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
70 | text = re.sub(_ordinal_re, _expand_ordinal, text)
71 | text = re.sub(_number_re, _expand_number, text)
72 | return text
73 |
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Defines the set of symbols used in text input to the model.
5 |
6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7 | from text import cmudict
8 |
9 | _pad = '_'
10 | _sos = '^'
11 | _eos = '~'
12 | _punctuations = " ,.'?!"
13 | _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
14 |
15 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
16 | _arpabet = ['@' + s for s in cmudict.valid_symbols]
17 |
18 | # Export all symbols:
19 | symbols = [_pad, _sos, _eos] + list(_punctuations) + list(_characters) + _arpabet
--------------------------------------------------------------------------------
/train-gae.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modules.model import GAE
6 | from modules.loss import TransformerLoss
7 | import hparams
8 | from text import *
9 | from utils.utils import *
10 | from utils.writer import get_writer
11 |
12 |
13 | def validate(model, criterion, val_loader, iteration, writer):
14 | model.eval()
15 | with torch.no_grad():
16 | n_data, val_loss = 0, 0
17 | for i, batch in enumerate(val_loader):
18 | n_data += len(batch[0])
19 | text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
20 | x.cuda() for x in batch
21 | ]
22 |
23 | mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out = model.module.outputs(text_padded,
24 | adj_padded,
25 | mel_padded,
26 | text_lengths,
27 | mel_lengths)
28 |
29 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out),
30 | (mel_padded, gate_padded),
31 | (enc_dec_alignments, text_lengths, mel_lengths))
32 |
33 | loss = torch.mean(mel_loss+bce_loss+guide_loss)
34 | val_loss += loss.item() * len(batch[0])
35 |
36 | val_loss /= n_data
37 |
38 | writer.add_losses(mel_loss.item(),
39 | bce_loss.item(),
40 | guide_loss.item(),
41 | iteration//hparams.accumulation, 'Validation')
42 |
43 | writer.add_specs(mel_padded.detach().cpu(),
44 | mel_out.detach().cpu(),
45 | mel_out_post.detach().cpu(),
46 | mel_lengths.detach().cpu(),
47 | iteration//hparams.accumulation, 'Validation')
48 |
49 | enc_alignments=dec_alignments.new_zeros(enc_dec_alignments.size(0),
50 | enc_dec_alignments.size(1),
51 | enc_dec_alignments.size(2),
52 | enc_dec_alignments.size(4),
53 | enc_dec_alignments.size(4))
54 | writer.add_alignments(enc_alignments.detach().cpu(),
55 | dec_alignments.detach().cpu(),
56 | enc_dec_alignments.detach().cpu(),
57 | text_padded.detach().cpu(),
58 | mel_lengths.detach().cpu(),
59 | text_lengths.detach().cpu(),
60 | iteration//hparams.accumulation, 'Validation')
61 |
62 | writer.add_gates(gate_out.detach().cpu(),
63 | iteration//hparams.accumulation, 'Validation')
64 | model.train()
65 |
66 |
67 |
68 | def main():
69 | train_loader, val_loader, collate_fn = prepare_dataloaders(hparams)
70 | model = nn.DataParallel(GAE(hparams)).cuda()
71 | optimizer = torch.optim.Adam(model.parameters(),
72 | lr=hparams.lr,
73 | betas=(0.9, 0.98),
74 | eps=1e-09)
75 | criterion = TransformerLoss()
76 | writer = get_writer(hparams.output_directory, hparams.log_directory)
77 |
78 | iteration, loss = 0, 0
79 | model.train()
80 | print("Training Start!!!")
81 | while iteration < (hparams.train_steps*hparams.accumulation):
82 | for i, batch in enumerate(train_loader):
83 | text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
84 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch
85 | ]
86 |
87 | mel_loss, bce_loss, guide_loss = model(text_padded,
88 | adj_padded,
89 | mel_padded,
90 | gate_padded,
91 | text_lengths,
92 | mel_lengths,
93 | criterion)
94 |
95 | mel_loss, bce_loss, guide_loss=[
96 | torch.mean(x) for x in [mel_loss, bce_loss, guide_loss]
97 | ]
98 | sub_loss = (mel_loss+bce_loss+guide_loss)/hparams.accumulation
99 | sub_loss.backward()
100 | loss = loss+sub_loss.item()
101 |
102 | iteration += 1
103 | if iteration%hparams.accumulation == 0:
104 | lr_scheduling(optimizer, iteration//hparams.accumulation)
105 | nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
106 | optimizer.step()
107 | model.zero_grad()
108 | writer.add_losses(mel_loss.item(),
109 | bce_loss.item(),
110 | guide_loss.item(),
111 | iteration//hparams.accumulation, 'Train')
112 | loss=0
113 |
114 |
115 | if iteration%(hparams.iters_per_validation*hparams.accumulation)==0:
116 | validate(model, criterion, val_loader, iteration, writer)
117 |
118 | if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0:
119 | save_checkpoint(model,
120 | optimizer,
121 | hparams.lr,
122 | iteration//hparams.accumulation,
123 | filepath=f'{hparams.output_directory}/{hparams.log_directory}')
124 |
125 | if iteration==(hparams.train_steps*hparams.accumulation):
126 | break
127 |
128 |
129 | if __name__ == '__main__':
130 | p = argparse.ArgumentParser()
131 | p.add_argument('--gpu', type=str, default='0,1')
132 | p.add_argument('-v', '--verbose', type=str, default='0')
133 | args = p.parse_args()
134 | hparams.log_directory = 'gae-char'
135 | hparams.iterations = 1
136 |
137 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
138 | torch.manual_seed(hparams.seed)
139 | torch.cuda.manual_seed(hparams.seed)
140 |
141 | if args.verbose=='0':
142 | import warnings
143 | warnings.filterwarnings("ignore")
144 |
145 | main()
--------------------------------------------------------------------------------
/train-graphtts.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modules.model import GraphTTS
6 | from modules.loss import TransformerLoss
7 | import hparams
8 | from text import *
9 | from utils.utils import *
10 | from utils.writer import get_writer
11 |
12 |
13 | def validate(model, criterion, val_loader, iteration, writer):
14 | model.eval()
15 | with torch.no_grad():
16 | n_data, val_loss = 0, 0
17 | for i, batch in enumerate(val_loader):
18 | n_data += len(batch[0])
19 | text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
20 | x.cuda() for x in batch
21 | ]
22 |
23 | mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out = model.module.outputs(text_padded,
24 | adj_padded,
25 | mel_padded,
26 | text_lengths,
27 | mel_lengths)
28 |
29 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out),
30 | (mel_padded, gate_padded),
31 | (enc_dec_alignments, text_lengths, mel_lengths))
32 |
33 | loss = torch.mean(mel_loss+bce_loss+guide_loss)
34 | val_loss += loss.item() * len(batch[0])
35 |
36 | val_loss /= n_data
37 |
38 | writer.add_losses(mel_loss.item(),
39 | bce_loss.item(),
40 | guide_loss.item(),
41 | iteration//hparams.accumulation, 'Validation')
42 |
43 | writer.add_specs(mel_padded.detach().cpu(),
44 | mel_out.detach().cpu(),
45 | mel_out_post.detach().cpu(),
46 | mel_lengths.detach().cpu(),
47 | iteration//hparams.accumulation, 'Validation')
48 |
49 | enc_alignments=dec_alignments.new_zeros(enc_dec_alignments.size(0),
50 | enc_dec_alignments.size(1),
51 | enc_dec_alignments.size(2),
52 | enc_dec_alignments.size(4),
53 | enc_dec_alignments.size(4))
54 | writer.add_alignments(enc_alignments.detach().cpu(),
55 | dec_alignments.detach().cpu(),
56 | enc_dec_alignments.detach().cpu(),
57 | text_padded.detach().cpu(),
58 | mel_lengths.detach().cpu(),
59 | text_lengths.detach().cpu(),
60 | iteration//hparams.accumulation, 'Validation')
61 |
62 | writer.add_gates(gate_out.detach().cpu(),
63 | iteration//hparams.accumulation, 'Validation')
64 | model.train()
65 |
66 |
67 |
68 | def main():
69 | train_loader, val_loader, collate_fn = prepare_dataloaders(hparams)
70 | model = nn.DataParallel(GraphTTS(hparams)).cuda()
71 | optimizer = torch.optim.Adam(model.parameters(),
72 | lr=hparams.lr,
73 | betas=(0.9, 0.98),
74 | eps=1e-09)
75 | criterion = TransformerLoss()
76 | writer = get_writer(hparams.output_directory, hparams.log_directory)
77 |
78 | iteration, loss = 0, 0
79 | model.train()
80 | print("Training Start!!!")
81 | while iteration < (hparams.train_steps*hparams.accumulation):
82 | for i, batch in enumerate(train_loader):
83 | text_padded, adj_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
84 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch
85 | ]
86 |
87 | mel_loss, bce_loss, guide_loss = model(text_padded,
88 | adj_padded,
89 | mel_padded,
90 | gate_padded,
91 | text_lengths,
92 | mel_lengths,
93 | criterion)
94 |
95 | mel_loss, bce_loss, guide_loss=[
96 | torch.mean(x) for x in [mel_loss, bce_loss, guide_loss]
97 | ]
98 | sub_loss = (mel_loss+bce_loss+guide_loss)/hparams.accumulation
99 | sub_loss.backward()
100 | loss = loss+sub_loss.item()
101 |
102 | iteration += 1
103 | if iteration%hparams.accumulation == 0:
104 | lr_scheduling(optimizer, iteration//hparams.accumulation)
105 | nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
106 | optimizer.step()
107 | model.zero_grad()
108 | writer.add_losses(mel_loss.item(),
109 | bce_loss.item(),
110 | guide_loss.item(),
111 | iteration//hparams.accumulation, 'Train')
112 | loss=0
113 |
114 |
115 | if iteration%(hparams.iters_per_validation*hparams.accumulation)==0:
116 | validate(model, criterion, val_loader, iteration, writer)
117 |
118 | if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0:
119 | save_checkpoint(model,
120 | optimizer,
121 | hparams.lr,
122 | iteration//hparams.accumulation,
123 | filepath=f'{hparams.output_directory}/{hparams.log_directory}')
124 |
125 | if iteration==(hparams.train_steps*hparams.accumulation):
126 | break
127 |
128 |
129 | if __name__ == '__main__':
130 | p = argparse.ArgumentParser()
131 | p.add_argument('--gpu', type=str, default='0,1')
132 | p.add_argument('-v', '--verbose', type=str, default='0')
133 | args = p.parse_args()
134 | hparams.log_directory = 'graph-tts-char_iter5'
135 | hparams.iterations = 5
136 |
137 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
138 | torch.manual_seed(hparams.seed)
139 | torch.cuda.manual_seed(hparams.seed)
140 |
141 | if args.verbose=='0':
142 | import warnings
143 | warnings.filterwarnings("ignore")
144 |
145 | main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modules.model import Model
6 | from modules.loss import TransformerLoss
7 | import hparams
8 | from text import *
9 | from utils.utils import *
10 | from utils.writer import get_writer
11 |
12 |
13 | def validate(model, criterion, val_loader, iteration, writer):
14 | model.eval()
15 | with torch.no_grad():
16 | n_data, val_loss = 0, 0
17 | for i, batch in enumerate(val_loader):
18 | n_data += len(batch[0])
19 | text_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
20 | x.cuda() for x in batch
21 | ]
22 |
23 | mel_out, mel_out_post,\
24 | enc_alignments, dec_alignments, enc_dec_alignments, gate_out = model.module.outputs(text_padded,
25 | mel_padded,
26 | text_lengths,
27 | mel_lengths)
28 |
29 | mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out),
30 | (mel_padded, gate_padded),
31 | (enc_dec_alignments, text_lengths, mel_lengths))
32 |
33 | loss = torch.mean(mel_loss+bce_loss+guide_loss)
34 | val_loss += loss.item() * len(batch[0])
35 |
36 | val_loss /= n_data
37 |
38 | writer.add_losses(mel_loss.item(),
39 | bce_loss.item(),
40 | guide_loss.item(),
41 | iteration//hparams.accumulation, 'Validation')
42 |
43 | writer.add_specs(mel_padded.detach().cpu(),
44 | mel_out.detach().cpu(),
45 | mel_out_post.detach().cpu(),
46 | mel_lengths.detach().cpu(),
47 | iteration//hparams.accumulation, 'Validation')
48 |
49 | writer.add_alignments(enc_alignments.detach().cpu(),
50 | dec_alignments.detach().cpu(),
51 | enc_dec_alignments.detach().cpu(),
52 | text_padded.detach().cpu(),
53 | mel_lengths.detach().cpu(),
54 | text_lengths.detach().cpu(),
55 | iteration//hparams.accumulation, 'Validation')
56 |
57 | writer.add_gates(gate_out.detach().cpu(),
58 | iteration//hparams.accumulation, 'Validation')
59 | model.train()
60 |
61 |
62 |
63 | def main():
64 | train_loader, val_loader, collate_fn = prepare_dataloaders(hparams)
65 | model = nn.DataParallel(Model(hparams)).cuda()
66 | optimizer = torch.optim.Adam(model.parameters(),
67 | lr=hparams.lr,
68 | betas=(0.9, 0.98),
69 | eps=1e-09)
70 | criterion = TransformerLoss()
71 | writer = get_writer(hparams.output_directory, hparams.log_directory)
72 |
73 | iteration, loss = 0, 0
74 | model.train()
75 | print("Training Start!!!")
76 | while iteration < (hparams.train_steps*hparams.accumulation):
77 | for i, batch in enumerate(train_loader):
78 | text_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
79 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch
80 | ]
81 |
82 | mel_loss, bce_loss, guide_loss = model(text_padded,
83 | mel_padded,
84 | gate_padded,
85 | text_lengths,
86 | mel_lengths,
87 | criterion)
88 |
89 | mel_loss, bce_loss, guide_loss=[
90 | torch.mean(x) for x in [mel_loss, bce_loss, guide_loss]
91 | ]
92 | sub_loss = (mel_loss+bce_loss+guide_loss)/hparams.accumulation
93 | sub_loss.backward()
94 | loss = loss+sub_loss.item()
95 |
96 | iteration += 1
97 | if iteration%hparams.accumulation == 0:
98 | lr_scheduling(optimizer, iteration//hparams.accumulation)
99 | nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
100 | optimizer.step()
101 | model.zero_grad()
102 | writer.add_losses(mel_loss.item(),
103 | bce_loss.item(),
104 | guide_loss.item(),
105 | iteration//hparams.accumulation, 'Train')
106 | loss=0
107 |
108 |
109 | if iteration%(hparams.iters_per_validation*hparams.accumulation)==0:
110 | validate(model, criterion, val_loader, iteration, writer)
111 |
112 | if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0:
113 | save_checkpoint(model,
114 | optimizer,
115 | hparams.lr,
116 | iteration//hparams.accumulation,
117 | filepath=f'{hparams.output_directory}/{hparams.log_directory}')
118 |
119 | if iteration==(hparams.train_steps*hparams.accumulation):
120 | break
121 |
122 |
123 | if __name__ == '__main__':
124 | p = argparse.ArgumentParser()
125 | p.add_argument('--gpu', type=str, default='0,1')
126 | p.add_argument('-v', '--verbose', type=str, default='0')
127 | args = p.parse_args()
128 |
129 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
130 | torch.manual_seed(hparams.seed)
131 | torch.cuda.manual_seed(hparams.seed)
132 |
133 | if args.verbose=='0':
134 | import warnings
135 | warnings.filterwarnings("ignore")
136 |
137 | main()
--------------------------------------------------------------------------------
/utils/__pycache__/data_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/utils/__pycache__/data_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/plot_image.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/utils/__pycache__/plot_image.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/writer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/utils/__pycache__/writer.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import hparams
4 | import torch
5 | import torch.utils.data
6 | import torch.nn.functional as F
7 | import os
8 | import pickle as pkl
9 |
10 | from text import text_to_sequence
11 |
12 |
13 | def load_filepaths_and_text(metadata, split="|"):
14 | with open(metadata, encoding='utf-8') as f:
15 | filepaths_and_text = [line.strip().split(split) for line in f]
16 | return filepaths_and_text
17 |
18 |
19 | class TextMelSet(torch.utils.data.Dataset):
20 | def __init__(self, audiopaths_and_text, hparams):
21 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
22 | random.seed(1234)
23 | random.shuffle(self.audiopaths_and_text)
24 | self.data_type=hparams.data_type
25 |
26 | def get_mel_text_pair(self, audiopath_and_text):
27 | # separate filename and text
28 | file_name = audiopath_and_text[0][:10]
29 | seq_path = os.path.join(hparams.data_path, self.data_type)
30 | adj_path = os.path.join(hparams.data_path, 'adj_matrix')
31 | mel_path = os.path.join(hparams.data_path, 'melspectrogram')
32 |
33 | with open(f'{seq_path}/{file_name}_sequence.pkl', 'rb') as f:
34 | text = pkl.load(f)
35 | with open(f'{adj_path}/{file_name}_adj_matrix.pkl', 'rb') as f:
36 | adj = pkl.load(f)
37 | with open(f'{mel_path}/{file_name}_melspectrogram.pkl', 'rb') as f:
38 | mel = pkl.load(f)
39 |
40 | return (text, adj, mel)
41 |
42 | def __getitem__(self, index):
43 | return self.get_mel_text_pair(self.audiopaths_and_text[index])
44 |
45 | def __len__(self):
46 | return len(self.audiopaths_and_text)
47 |
48 |
49 | class TextMelCollate():
50 | def __init__(self):
51 | return
52 |
53 | def __call__(self, batch):
54 | # Right zero-pad all one-hot text sequences to max input length
55 | input_lengths, ids_sorted_decreasing = torch.sort(
56 | torch.LongTensor([len(x[0]) for x in batch]),
57 | dim=0, descending=True)
58 | max_input_len = input_lengths[0]
59 |
60 | text_padded = torch.zeros(len(batch), max_input_len, dtype=torch.long)
61 | adj_padded = torch.zeros(len(batch), 3, max_input_len, max_input_len, dtype=torch.long)
62 | for i in range(len(ids_sorted_decreasing)):
63 | text = batch[ids_sorted_decreasing[i]][0]
64 | adj = batch[ids_sorted_decreasing[i]][1]
65 | text_padded[i, :text.size(0)] = text
66 | adj_padded[i, :, :adj.size(1), :adj.size(1)] = adj
67 |
68 | # Right zero-pad
69 | num_mels = batch[0][2].size(0)
70 | max_target_len = max([x[2].size(1) for x in batch])
71 |
72 | # include Spec padded and gate padded
73 | mel_padded = torch.zeros(len(batch), num_mels, max_target_len)
74 | gate_padded = torch.zeros(len(batch), max_target_len)
75 | output_lengths = torch.LongTensor(len(batch))
76 | for i in range(len(ids_sorted_decreasing)):
77 | mel = batch[ids_sorted_decreasing[i]][2]
78 | mel_padded[i, :, :mel.size(1)] = mel
79 | gate_padded[i, mel.size(1)-1:] = 1
80 | output_lengths[i] = mel.size(1)
81 |
82 | return text_padded, adj_padded, input_lengths, mel_padded, output_lengths, gate_padded
--------------------------------------------------------------------------------
/utils/plot_image.py:
--------------------------------------------------------------------------------
1 | from text import *
2 | import torch
3 | import hparams
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def plot_melspec(target, melspec, melspec_post, mel_lengths):
8 | fig, axes = plt.subplots(3, 1, figsize=(20,30))
9 | T = mel_lengths[-1]
10 |
11 | axes[0].imshow(target[-1][:,:T],
12 | origin='lower',
13 | aspect='auto')
14 |
15 | axes[1].imshow(melspec[-1][:,:T],
16 | origin='lower',
17 | aspect='auto')
18 |
19 | axes[2].imshow(melspec_post[-1][:,:T],
20 | origin='lower',
21 | aspect='auto')
22 |
23 | return fig
24 |
25 |
26 | def plot_alignments(alignments, text, mel_lengths, text_lengths, att_type):
27 | fig, axes = plt.subplots(hparams.n_layers, hparams.n_heads, figsize=(5*hparams.n_heads,5*hparams.n_layers))
28 | L, T = text_lengths[-1], mel_lengths[-1]
29 | n_layers, n_heads = alignments.size(1), alignments.size(2)
30 |
31 | for layer in range(n_layers):
32 | for head in range(n_heads):
33 | if att_type=='enc':
34 | align = alignments[-1, layer, head].contiguous()
35 | axes[layer,head].imshow(align[:L, :L], aspect='auto')
36 | axes[layer,head].xaxis.tick_top()
37 |
38 | elif att_type=='dec':
39 | align = alignments[-1, layer, head].contiguous()
40 | axes[layer,head].imshow(align[:T, :T], aspect='auto')
41 | axes[layer,head].xaxis.tick_top()
42 |
43 | elif att_type=='enc_dec':
44 | align = alignments[-1, layer, head].transpose(0,1).contiguous()
45 | axes[layer,head].imshow(align[:L, :T], origin='lower', aspect='auto')
46 |
47 | return fig
48 |
49 |
50 | def plot_gate(gate_out):
51 | fig = plt.figure(figsize=(10,5))
52 | plt.plot(torch.sigmoid(gate_out[-1]))
53 | return fig
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import hparams
2 | from torch.utils.data import DataLoader
3 | from .data_utils import TextMelSet, TextMelCollate
4 | import torch
5 | from text import *
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | def prepare_dataloaders(hparams):
10 | # Get data, data loaders and collate function ready
11 | trainset = TextMelSet(hparams.training_files, hparams)
12 | valset = TextMelSet(hparams.validation_files, hparams)
13 | collate_fn = TextMelCollate()
14 |
15 | train_loader = DataLoader(trainset,
16 | num_workers=hparams.n_gpus-1,
17 | shuffle=True,
18 | batch_size=hparams.batch_size,
19 | drop_last=True,
20 | collate_fn=collate_fn)
21 |
22 | val_loader = DataLoader(valset,
23 | batch_size=hparams.batch_size//hparams.n_gpus,
24 | collate_fn=collate_fn)
25 |
26 | return train_loader, val_loader, collate_fn
27 |
28 |
29 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
30 | print(f"Saving model and optimizer state at iteration {iteration} to {filepath}")
31 | torch.save({'iteration': iteration,
32 | 'state_dict': model.state_dict(),
33 | 'optimizer': optimizer.state_dict(),
34 | 'learning_rate': learning_rate}, f'{filepath}/checkpoint_{iteration}')
35 |
36 |
37 | def lr_scheduling(opt, step, init_lr=hparams.lr, warmup_steps=hparams.warmup_steps):
38 | opt.param_groups[0]['lr'] = init_lr * min(step ** -0.5, step * warmup_steps ** -1.5)
39 | return
40 |
41 |
42 | def get_mask_from_lengths(lengths):
43 | max_len = torch.max(lengths).item()
44 | ids = lengths.new_tensor(torch.arange(0, max_len))
45 | mask = (lengths.unsqueeze(1) <= ids).to(torch.bool)
46 | return mask
47 |
48 |
49 | def reorder_batch(x, n_gpus):
50 | assert (x.size(0)%n_gpus)==0, 'Batch size must be a multiple of the number of GPUs.'
51 | new_x = x.new_zeros(x.size())
52 | chunk_size = x.size(0)//n_gpus
53 |
54 | for i in range(n_gpus):
55 | new_x[i::n_gpus] = x[i*chunk_size:(i+1)*chunk_size]
56 |
57 | return new_x
58 |
--------------------------------------------------------------------------------
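`lr_scheduling` in `utils/utils.py` is the Transformer-style (Noam) warmup schedule: the learning rate grows for `warmup_steps` updates and then decays as the inverse square root of the step. A quick numerical check with illustrative values (not the repository's actual hyperparameters):

```python
init_lr, warmup_steps = 1e-3, 4000   # illustrative values only
lr = lambda step: init_lr * min(step ** -0.5, step * warmup_steps ** -1.5)
print(lr(1), lr(warmup_steps), lr(4 * warmup_steps))  # peaks at warmup_steps, then ~1/sqrt(step)
```
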
/utils/writer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.tensorboard import SummaryWriter
3 | from .plot_image import *
4 |
5 | def get_writer(output_directory, log_directory):
6 | logging_path=f'{output_directory}/{log_directory}'
7 |
8 | if os.path.exists(logging_path):
9 | raise Exception('The experiment already exists')
10 | else:
11 | os.mkdir(logging_path)
12 | writer = TTSWriter(logging_path)
13 |
14 | return writer
15 |
16 |
17 | class TTSWriter(SummaryWriter):
18 | def __init__(self, log_dir):
19 | super(TTSWriter, self).__init__(log_dir)
20 |
21 | def add_losses(self, mel_loss, bce_loss, guide_loss, global_step, phase):
22 | self.add_scalar(f'{phase}_mel_loss', mel_loss, global_step)
23 | self.add_scalar(f'{phase}_bce_loss', bce_loss, global_step)
24 | self.add_scalar(f'{phase}_guide_loss', guide_loss, global_step)
25 |
26 | def add_specs(self, mel_padded, mel_out, mel_out_post, mel_lengths, global_step, phase):
27 | mel_fig = plot_melspec(mel_padded, mel_out, mel_out_post, mel_lengths)
28 | self.add_figure(f'{phase}_melspec', mel_fig, global_step)
29 |
30 | def add_alignments(self, enc_alignments, dec_alignments, enc_dec_alignments,
31 | text_padded, mel_lengths, text_lengths, global_step, phase):
32 | enc_align_fig = plot_alignments(enc_alignments, text_padded, mel_lengths, text_lengths, 'enc')
33 | self.add_figure(f'{phase}_enc_alignments', enc_align_fig, global_step)
34 |
35 | dec_align_fig = plot_alignments(dec_alignments, text_padded, mel_lengths, text_lengths, 'dec')
36 | self.add_figure(f'{phase}_dec_alignments', dec_align_fig, global_step)
37 |
38 | enc_dec_align_fig = plot_alignments(enc_dec_alignments, text_padded, mel_lengths, text_lengths, 'enc_dec')
39 | self.add_figure(f'{phase}_enc_dec_alignments', enc_dec_align_fig, global_step)
40 |
41 | def add_gates(self, gate_out, global_step, phase):
42 | gate_fig = plot_gate(gate_out)
43 | self.add_figure(f'{phase}_gate_out', gate_fig, global_step)
--------------------------------------------------------------------------------
/waveglow/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "tacotron2"]
2 | path = tacotron2
3 | url = http://github.com/NVIDIA/tacotron2
4 |
--------------------------------------------------------------------------------
/waveglow/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2018, NVIDIA Corporation
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/waveglow/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | ## WaveGlow: a Flow-based Generative Network for Speech Synthesis
4 |
5 | ### Ryan Prenger, Rafael Valle, and Bryan Catanzaro
6 |
7 | In our recent [paper], we propose WaveGlow: a flow-based network capable of
8 | generating high quality speech from mel-spectrograms. WaveGlow combines insights
9 | from [Glow] and [WaveNet] in order to provide fast, efficient and high-quality
10 | audio synthesis, without the need for auto-regression. WaveGlow is implemented
11 | using only a single network, trained using only a single cost function:
12 | maximizing the likelihood of the training data, which makes the training
13 | procedure simple and stable.
14 |
15 | Our [PyTorch] implementation produces audio samples at a rate of 2750
16 | kHz on an NVIDIA V100 GPU. Mean Opinion Scores show that it delivers audio
17 | quality as good as the best publicly available WaveNet implementation.
18 |
19 | Visit our [website] for audio samples.
20 |
21 | ## Setup
22 |
23 | 1. Clone our repo and initialize submodule
24 |
25 | ```command
26 | git clone https://github.com/NVIDIA/waveglow.git
27 | cd waveglow
28 | git submodule init
29 | git submodule update
30 | ```
31 |
32 | 2. Install requirements `pip3 install -r requirements.txt`
33 |
34 | 3. Install [Apex]
35 |
36 |
37 | ## Generate audio with our pre-existing model
38 |
39 | 1. Download our [published model]
40 | 2. Download [mel-spectrograms]
41 | 3. Generate audio `python3 inference.py -f <(ls mel_spectrograms/*.pt) -w waveglow_256channels.pt -o . --is_fp16 -s 0.6`
42 |
43 | N.b. use `convert_model.py` to convert your older models to the current model
44 | with fused residual and skip connections.
45 |
46 | ## Train your own model
47 |
48 | 1. Download [LJ Speech Data]. In this example it's in `data/`
49 |
50 | 2. Make a list of the file names to use for training/testing
51 |
52 | ```command
53 | ls data/*.wav | tail -n+10 > train_files.txt
54 | ls data/*.wav | head -n10 > test_files.txt
55 | ```
56 |
57 | 3. Train your WaveGlow networks
58 |
59 | ```command
60 | mkdir checkpoints
61 | python train.py -c config.json
62 | ```
63 |
64 | For multi-GPU training replace `train.py` with `distributed.py`. Only tested with single node and NCCL.
65 |
66 | For mixed precision training, set `"fp16_run": true` in `config.json`.
67 |
68 | 4. Make test set mel-spectrograms
69 |
70 | `python mel2samp.py -f test_files.txt -o . -c config.json`
71 |
72 | 5. Do inference with your network
73 |
74 | ```command
75 | ls *.pt > mel_files.txt
76 | python3 inference.py -f mel_files.txt -w checkpoints/waveglow_10000 -o . --is_fp16 -s 0.6
77 | ```
78 |
79 | [//]: # (TODO)
80 | [//]: # (PROVIDE INSTRUCTIONS FOR DOWNLOADING LJS)
81 | [pytorch 1.0]: https://github.com/pytorch/pytorch#installation
82 | [website]: https://nv-adlr.github.io/WaveGlow
83 | [paper]: https://arxiv.org/abs/1811.00002
84 | [WaveNet implementation]: https://github.com/r9y9/wavenet_vocoder
85 | [Glow]: https://blog.openai.com/glow/
86 | [WaveNet]: https://deepmind.com/blog/wavenet-generative-model-raw-audio/
87 | [PyTorch]: http://pytorch.org
88 | [published model]: https://drive.google.com/file/d/1WsibBTsuRg_SF2Z6L6NFRTT-NjEy1oTx/view?usp=sharing
89 | [mel-spectrograms]: https://drive.google.com/file/d/1g_VXK2lpP9J25dQFhQwx7doWl_p20fXA/view?usp=sharing
90 | [LJ Speech Data]: https://keithito.com/LJ-Speech-Dataset
91 | [Apex]: https://github.com/nvidia/apex
92 |
--------------------------------------------------------------------------------
/waveglow/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_config": {
3 | "fp16_run": true,
4 | "output_directory": "checkpoints",
5 | "epochs": 100000,
6 | "learning_rate": 1e-4,
7 | "sigma": 1.0,
8 | "iters_per_checkpoint": 2000,
9 | "batch_size": 12,
10 | "seed": 1234,
11 | "checkpoint_path": "",
12 | "with_tensorboard": false
13 | },
14 | "data_config": {
15 | "training_files": "train_files.txt",
16 | "segment_length": 16000,
17 | "sampling_rate": 22050,
18 | "filter_length": 1024,
19 | "hop_length": 256,
20 | "win_length": 1024,
21 | "mel_fmin": 0.0,
22 | "mel_fmax": 8000.0
23 | },
24 | "dist_config": {
25 | "dist_backend": "nccl",
26 | "dist_url": "tcp://localhost:54321"
27 | },
28 |
29 | "waveglow_config": {
30 | "n_mel_channels": 80,
31 | "n_flows": 12,
32 | "n_group": 8,
33 | "n_early_every": 4,
34 | "n_early_size": 2,
35 | "WN_config": {
36 | "n_layers": 8,
37 | "n_channels": 256,
38 | "kernel_size": 3
39 | }
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/waveglow/convert_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import copy
3 | import torch
4 |
5 | def _check_model_old_version(model):
6 | if hasattr(model.WN[0], 'res_layers'):
7 | return True
8 | else:
9 | return False
10 |
11 | def update_model(old_model):
12 | if not _check_model_old_version(old_model):
13 | return old_model
14 | new_model = copy.deepcopy(old_model)
15 | for idx in range(0, len(new_model.WN)):
16 | wavenet = new_model.WN[idx]
17 | wavenet.res_skip_layers = torch.nn.ModuleList()
18 | n_channels = wavenet.n_channels
19 | n_layers = wavenet.n_layers
20 | for i in range(0, n_layers):
21 | if i < n_layers - 1:
22 | res_skip_channels = 2*n_channels
23 | else:
24 | res_skip_channels = n_channels
25 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
26 | skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
27 | if i < n_layers - 1:
28 | res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
29 | res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
30 | res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
31 | else:
32 | res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
33 | res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
34 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
35 | wavenet.res_skip_layers.append(res_skip_layer)
36 | del wavenet.res_layers
37 | del wavenet.skip_layers
38 | return new_model
39 |
40 | if __name__ == '__main__':
41 | old_model_path = sys.argv[1]
42 | new_model_path = sys.argv[2]
43 | model = torch.load(old_model_path)
44 | model['model'] = update_model(model['model'])
45 | torch.save(model, new_model_path)
46 |
47 |
--------------------------------------------------------------------------------
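Per its `__main__` block, `convert_model.py` takes the old and new checkpoint paths as positional arguments; a usage sketch (checkpoint file names are hypothetical):

```bash
# Fuses the old res_layers/skip_layers weights into the res_skip_layers layout expected by glow.py.
python3 convert_model.py waveglow_old.pt waveglow_converted.pt
```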
/waveglow/denoiser.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('tacotron2')
3 | import torch
4 | from layers import STFT
5 |
6 |
7 | class Denoiser(torch.nn.Module):
8 | """ Removes model bias from audio produced with waveglow """
9 |
10 | def __init__(self, waveglow, filter_length=1024, n_overlap=4,
11 | win_length=1024, mode='zeros'):
12 | super(Denoiser, self).__init__()
13 | self.stft = STFT(filter_length=filter_length,
14 | hop_length=int(filter_length/n_overlap),
15 | win_length=win_length).cuda()
16 | if mode == 'zeros':
17 | mel_input = torch.zeros(
18 | (1, 80, 88),
19 | dtype=waveglow.upsample.weight.dtype,
20 | device=waveglow.upsample.weight.device)
21 | elif mode == 'normal':
22 | mel_input = torch.randn(
23 | (1, 80, 88),
24 | dtype=waveglow.upsample.weight.dtype,
25 | device=waveglow.upsample.weight.device)
26 | else:
27 | raise Exception("Mode {} is not supported".format(mode))
28 |
29 | with torch.no_grad():
30 | bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
31 | bias_spec, _ = self.stft.transform(bias_audio)
32 |
33 | self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
34 |
35 | def forward(self, audio, strength=0.1):
36 | audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
37 | audio_spec_denoised = audio_spec - self.bias_spec * strength
38 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
39 | audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
40 | return audio_denoised
41 |
--------------------------------------------------------------------------------
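A minimal sketch of how `Denoiser` is wired around a loaded WaveGlow model, mirroring the optional denoising path in `inference.py` further below (the checkpoint path and the `mel` tensor are placeholders):

```python
# Sketch: bias removal after WaveGlow inference (see waveglow/inference.py for the full flow).
import torch
from denoiser import Denoiser

waveglow = torch.load("checkpoints/waveglow_10000")['model']
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow.cuda().eval()

denoiser = Denoiser(waveglow).cuda()            # builds the bias spectrum from a sigma=0 pass
with torch.no_grad():
    audio = waveglow.infer(mel, sigma=0.6)      # mel: 1 x 80 x frames tensor on the GPU
    audio = denoiser(audio, strength=0.1)       # 0.1 is the suggested starting strength
```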
/waveglow/distributed.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import os
28 | import sys
29 | import time
30 | import subprocess
31 | import argparse
32 |
33 | import torch
34 | import torch.distributed as dist
35 | from torch.autograd import Variable
36 |
37 | def reduce_tensor(tensor, num_gpus):
38 | rt = tensor.clone()
39 | dist.all_reduce(rt, op=dist.reduce_op.SUM)
40 | rt /= num_gpus
41 | return rt
42 |
43 | def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
44 | assert torch.cuda.is_available(), "Distributed mode requires CUDA."
45 | print("Initializing Distributed")
46 |
47 | # Set cuda device so everything is done on the right GPU.
48 | torch.cuda.set_device(rank % torch.cuda.device_count())
49 |
50 | # Initialize distributed communication
51 | dist.init_process_group(dist_backend, init_method=dist_url,
52 | world_size=num_gpus, rank=rank,
53 | group_name=group_name)
54 |
55 | def _flatten_dense_tensors(tensors):
56 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
57 | same dense type.
58 | Since inputs are dense, the resulting tensor will be a concatenated 1D
59 | buffer. Element-wise operation on this buffer will be equivalent to
60 | operating individually.
61 | Arguments:
62 | tensors (Iterable[Tensor]): dense tensors to flatten.
63 | Returns:
64 | A contiguous 1D buffer containing input tensors.
65 | """
66 | if len(tensors) == 1:
67 | return tensors[0].contiguous().view(-1)
68 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
69 | return flat
70 |
71 | def _unflatten_dense_tensors(flat, tensors):
72 | """View a flat buffer using the sizes of tensors. Assume that tensors are of
73 | same dense type, and that flat is given by _flatten_dense_tensors.
74 | Arguments:
75 | flat (Tensor): flattened dense tensors to unflatten.
76 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
77 | unflatten flat.
78 | Returns:
79 | Unflattened dense tensors with sizes same as tensors and values from
80 | flat.
81 | """
82 | outputs = []
83 | offset = 0
84 | for tensor in tensors:
85 | numel = tensor.numel()
86 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
87 | offset += numel
88 | return tuple(outputs)
89 |
90 | def apply_gradient_allreduce(module):
91 | """
92 | Modifies an existing model to do gradient allreduce in place, without wrapping it
93 | in a new class, so you don't need to go through a "module" attribute afterwards
94 | """
95 | if not hasattr(dist, '_backend'):
96 | module.warn_on_half = True
97 | else:
98 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
99 |
100 | for p in module.state_dict().values():
101 | if not torch.is_tensor(p):
102 | continue
103 | dist.broadcast(p, 0)
104 |
105 | def allreduce_params():
106 | if(module.needs_reduction):
107 | module.needs_reduction = False
108 | buckets = {}
109 | for param in module.parameters():
110 | if param.requires_grad and param.grad is not None:
111 | tp = type(param.data)
112 | if tp not in buckets:
113 | buckets[tp] = []
114 | buckets[tp].append(param)
115 | if module.warn_on_half:
116 | if torch.cuda.HalfTensor in buckets:
117 | print("WARNING: gloo dist backend for half parameters may be extremely slow." +
118 | " It is recommended to use the NCCL backend in this case. This currently requires" +
119 | "PyTorch built from top of tree master.")
120 | module.warn_on_half = False
121 |
122 | for tp in buckets:
123 | bucket = buckets[tp]
124 | grads = [param.grad.data for param in bucket]
125 | coalesced = _flatten_dense_tensors(grads)
126 | dist.all_reduce(coalesced)
127 | coalesced /= dist.get_world_size()
128 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
129 | buf.copy_(synced)
130 |
131 | for param in list(module.parameters()):
132 | def allreduce_hook(*unused):
133 | Variable._execution_engine.queue_callback(allreduce_params)
134 | if param.requires_grad:
135 | param.register_hook(allreduce_hook)
136 | dir(param)
137 |
138 | def set_needs_reduction(self, input, output):
139 | self.needs_reduction = True
140 |
141 | module.register_forward_hook(set_needs_reduction)
142 | return module
143 |
144 |
145 | def main(config, stdout_dir, args_str):
146 | args_list = ['train.py']
147 | args_list += args_str.split(' ') if len(args_str) > 0 else []
148 |
149 | args_list.append('--config={}'.format(config))
150 |
151 | num_gpus = torch.cuda.device_count()
152 | args_list.append('--num_gpus={}'.format(num_gpus))
153 | args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S")))
154 |
155 | if not os.path.isdir(stdout_dir):
156 | os.makedirs(stdout_dir)
157 | os.chmod(stdout_dir, 0o775)
158 |
159 | workers = []
160 |
161 | for i in range(num_gpus):
162 | args_list[-2] = '--rank={}'.format(i)
163 | stdout = None if i == 0 else open(
164 | os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w")
165 | print(args_list)
166 | p = subprocess.Popen([str(sys.executable)]+args_list, stdout=stdout)
167 | workers.append(p)
168 |
169 | for p in workers:
170 | p.wait()
171 |
172 |
173 | if __name__ == '__main__':
174 | parser = argparse.ArgumentParser()
175 | parser.add_argument('-c', '--config', type=str, required=True,
176 | help='JSON file for configuration')
177 | parser.add_argument('-s', '--stdout_dir', type=str, default=".",
178 | help='directory to save stdout logs')
179 | parser.add_argument(
180 | '-a', '--args_str', type=str, default='',
181 | help='double quoted string with space separated key value pairs')
182 |
183 | args = parser.parse_args()
184 | main(args.config, args.stdout_dir, args.args_str)
185 |
--------------------------------------------------------------------------------
/waveglow/glow.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import copy
28 | import torch
29 | from torch.autograd import Variable
30 | import torch.nn.functional as F
31 |
32 |
33 | @torch.jit.script
34 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
35 | n_channels_int = n_channels[0]
36 | in_act = input_a+input_b
37 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
38 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
39 | acts = t_act * s_act
40 | return acts
41 |
42 |
43 | class WaveGlowLoss(torch.nn.Module):
44 | def __init__(self, sigma=1.0):
45 | super(WaveGlowLoss, self).__init__()
46 | self.sigma = sigma
47 |
48 | def forward(self, model_output):
49 | z, log_s_list, log_det_W_list = model_output
50 | for i, log_s in enumerate(log_s_list):
51 | if i == 0:
52 | log_s_total = torch.sum(log_s)
53 | log_det_W_total = log_det_W_list[i]
54 | else:
55 | log_s_total = log_s_total + torch.sum(log_s)
56 | log_det_W_total += log_det_W_list[i]
57 |
58 | loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
59 | return loss/(z.size(0)*z.size(1)*z.size(2))
60 |
61 |
62 | class Invertible1x1Conv(torch.nn.Module):
63 | """
64 | The layer outputs both the convolution, and the log determinant
65 | of its weight matrix. If reverse=True it does the convolution with the
66 | inverse of the weight matrix
67 | """
68 | def __init__(self, c):
69 | super(Invertible1x1Conv, self).__init__()
70 | self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
71 | bias=False)
72 |
73 | # Sample a random orthonormal matrix to initialize weights
74 | W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
75 |
76 | # Ensure determinant is 1.0 not -1.0
77 | if torch.det(W) < 0:
78 | W[:,0] = -1*W[:,0]
79 | W = W.view(c, c, 1)
80 | self.conv.weight.data = W
81 |
82 | def forward(self, z, reverse=False):
83 | # shape
84 | batch_size, group_size, n_of_groups = z.size()
85 |
86 | W = self.conv.weight.squeeze()
87 |
88 | if reverse:
89 | if not hasattr(self, 'W_inverse'):
90 | # Reverse computation
91 | W_inverse = W.float().inverse()
92 | W_inverse = Variable(W_inverse[..., None])
93 | if z.type() == 'torch.cuda.HalfTensor':
94 | W_inverse = W_inverse.half()
95 | self.W_inverse = W_inverse
96 | z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
97 | return z
98 | else:
99 | # Forward computation
100 | log_det_W = batch_size * n_of_groups * torch.logdet(W)
101 | z = self.conv(z)
102 | return z, log_det_W
103 |
104 |
105 | class WN(torch.nn.Module):
106 | """
107 | This is the WaveNet-like layer for the affine coupling. The primary difference
108 | from WaveNet is the convolutions need not be causal. There is also no dilation
109 | size reset. The dilation only doubles on each layer
110 | """
111 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
112 | kernel_size):
113 | super(WN, self).__init__()
114 | assert(kernel_size % 2 == 1)
115 | assert(n_channels % 2 == 0)
116 | self.n_layers = n_layers
117 | self.n_channels = n_channels
118 | self.in_layers = torch.nn.ModuleList()
119 | self.res_skip_layers = torch.nn.ModuleList()
120 | self.cond_layers = torch.nn.ModuleList()
121 |
122 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
123 | start = torch.nn.utils.weight_norm(start, name='weight')
124 | self.start = start
125 |
126 | # Initializing last layer to 0 makes the affine coupling layers
127 | # do nothing at first. This helps with training stability
128 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
129 | end.weight.data.zero_()
130 | end.bias.data.zero_()
131 | self.end = end
132 |
133 | for i in range(n_layers):
134 | dilation = 2 ** i
135 | padding = int((kernel_size*dilation - dilation)/2)
136 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
137 | dilation=dilation, padding=padding)
138 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
139 | self.in_layers.append(in_layer)
140 |
141 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
142 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
143 | self.cond_layers.append(cond_layer)
144 |
145 | # last one is not necessary
146 | if i < n_layers - 1:
147 | res_skip_channels = 2*n_channels
148 | else:
149 | res_skip_channels = n_channels
150 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
151 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
152 | self.res_skip_layers.append(res_skip_layer)
153 |
154 | def forward(self, forward_input):
155 | audio, spect = forward_input
156 | audio = self.start(audio)
157 |
158 | for i in range(self.n_layers):
159 | acts = fused_add_tanh_sigmoid_multiply(
160 | self.in_layers[i](audio),
161 | self.cond_layers[i](spect),
162 | torch.IntTensor([self.n_channels]))
163 |
164 | res_skip_acts = self.res_skip_layers[i](acts)
165 | if i < self.n_layers - 1:
166 | audio = res_skip_acts[:,:self.n_channels,:] + audio
167 | skip_acts = res_skip_acts[:,self.n_channels:,:]
168 | else:
169 | skip_acts = res_skip_acts
170 |
171 | if i == 0:
172 | output = skip_acts
173 | else:
174 | output = skip_acts + output
175 | return self.end(output)
176 |
177 |
178 | class WaveGlow(torch.nn.Module):
179 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
180 | n_early_size, WN_config):
181 | super(WaveGlow, self).__init__()
182 |
183 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
184 | n_mel_channels,
185 | 1024, stride=256)
186 | assert(n_group % 2 == 0)
187 | self.n_flows = n_flows
188 | self.n_group = n_group
189 | self.n_early_every = n_early_every
190 | self.n_early_size = n_early_size
191 | self.WN = torch.nn.ModuleList()
192 | self.convinv = torch.nn.ModuleList()
193 |
194 | n_half = int(n_group/2)
195 |
196 | # Set up layers with the right sizes based on how many dimensions
197 | # have been output already
198 | n_remaining_channels = n_group
199 | for k in range(n_flows):
200 | if k % self.n_early_every == 0 and k > 0:
201 | n_half = n_half - int(self.n_early_size/2)
202 | n_remaining_channels = n_remaining_channels - self.n_early_size
203 | self.convinv.append(Invertible1x1Conv(n_remaining_channels))
204 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
205 | self.n_remaining_channels = n_remaining_channels # Useful during inference
206 |
207 | def forward(self, forward_input):
208 | """
209 | forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
210 | forward_input[1] = audio: batch x time
211 | """
212 | spect, audio = forward_input
213 |
214 | # Upsample spectrogram to size of audio
215 | spect = self.upsample(spect)
216 | assert(spect.size(2) >= audio.size(1))
217 | if spect.size(2) > audio.size(1):
218 | spect = spect[:, :, :audio.size(1)]
219 |
220 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
221 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
222 |
223 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
224 | output_audio = []
225 | log_s_list = []
226 | log_det_W_list = []
227 |
228 | for k in range(self.n_flows):
229 | if k % self.n_early_every == 0 and k > 0:
230 | output_audio.append(audio[:,:self.n_early_size,:])
231 | audio = audio[:,self.n_early_size:,:]
232 |
233 | audio, log_det_W = self.convinv[k](audio)
234 | log_det_W_list.append(log_det_W)
235 |
236 | n_half = int(audio.size(1)/2)
237 | audio_0 = audio[:,:n_half,:]
238 | audio_1 = audio[:,n_half:,:]
239 |
240 | output = self.WN[k]((audio_0, spect))
241 | log_s = output[:, n_half:, :]
242 | b = output[:, :n_half, :]
243 | audio_1 = torch.exp(log_s)*audio_1 + b
244 | log_s_list.append(log_s)
245 |
246 | audio = torch.cat([audio_0, audio_1],1)
247 |
248 | output_audio.append(audio)
249 | return torch.cat(output_audio,1), log_s_list, log_det_W_list
250 |
251 | def infer(self, spect, sigma=1.0):
252 | spect = self.upsample(spect)
253 | # trim conv artifacts. maybe pad spec to kernel multiple
254 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
255 | spect = spect[:, :, :-time_cutoff]
256 |
257 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
258 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
259 |
260 | if spect.type() == 'torch.cuda.HalfTensor':
261 | audio = torch.cuda.HalfTensor(spect.size(0),
262 | self.n_remaining_channels,
263 | spect.size(2)).normal_()
264 | else:
265 | audio = torch.cuda.FloatTensor(spect.size(0),
266 | self.n_remaining_channels,
267 | spect.size(2)).normal_()
268 |
269 | audio = torch.autograd.Variable(sigma*audio)
270 |
271 | for k in reversed(range(self.n_flows)):
272 | n_half = int(audio.size(1)/2)
273 | audio_0 = audio[:,:n_half,:]
274 | audio_1 = audio[:,n_half:,:]
275 |
276 | output = self.WN[k]((audio_0, spect))
277 | s = output[:, n_half:, :]
278 | b = output[:, :n_half, :]
279 | audio_1 = (audio_1 - b)/torch.exp(s)
280 | audio = torch.cat([audio_0, audio_1],1)
281 |
282 | audio = self.convinv[k](audio, reverse=True)
283 |
284 | if k % self.n_early_every == 0 and k > 0:
285 | if spect.type() == 'torch.cuda.HalfTensor':
286 | z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
287 | else:
288 | z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
289 | audio = torch.cat((sigma*z, audio),1)
290 |
291 | audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
292 | return audio
293 |
294 | @staticmethod
295 | def remove_weightnorm(model):
296 | waveglow = model
297 | for WN in waveglow.WN:
298 | WN.start = torch.nn.utils.remove_weight_norm(WN.start)
299 | WN.in_layers = remove(WN.in_layers)
300 | WN.cond_layers = remove(WN.cond_layers)
301 | WN.res_skip_layers = remove(WN.res_skip_layers)
302 | return waveglow
303 |
304 |
305 | def remove(conv_list):
306 | new_conv_list = torch.nn.ModuleList()
307 | for old_conv in conv_list:
308 | old_conv = torch.nn.utils.remove_weight_norm(old_conv)
309 | new_conv_list.append(old_conv)
310 | return new_conv_list
311 |
--------------------------------------------------------------------------------
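For orientation, `WaveGlowLoss.forward` above computes the flow's negative log-likelihood under a zero-mean Gaussian prior with standard deviation `sigma`; in the notation of the code, with `B, C, T = z.size()` and the constant Gaussian normalizer dropped (as in the implementation):

$$
\mathcal{L} \;=\; \frac{1}{B\,C\,T}\left(\frac{\lVert z\rVert_2^2}{2\sigma^2}
\;-\;\sum_{k}\sum\log s_k
\;-\;\sum_{k}\log\lvert\det W_k\rvert\right)
$$

where $s_k$ are the affine-coupling scales returned by each `WN` block and $W_k$ the weights of the corresponding `Invertible1x1Conv` (whose `log_det_W` already carries the batch-and-time factor).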
/waveglow/glow_old.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from glow import Invertible1x1Conv, remove
4 |
5 |
6 | @torch.jit.script
7 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
8 | n_channels_int = n_channels[0]
9 | in_act = input_a+input_b
10 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
11 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
12 | acts = t_act * s_act
13 | return acts
14 |
15 |
16 | class WN(torch.nn.Module):
17 | """
18 | This is the WaveNet-like layer for the affine coupling. The primary difference
19 | from WaveNet is the convolutions need not be causal. There is also no dilation
20 | size reset. The dilation only doubles on each layer
21 | """
22 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
23 | kernel_size):
24 | super(WN, self).__init__()
25 | assert(kernel_size % 2 == 1)
26 | assert(n_channels % 2 == 0)
27 | self.n_layers = n_layers
28 | self.n_channels = n_channels
29 | self.in_layers = torch.nn.ModuleList()
30 | self.res_skip_layers = torch.nn.ModuleList()
31 | self.cond_layers = torch.nn.ModuleList()
32 |
33 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
34 | start = torch.nn.utils.weight_norm(start, name='weight')
35 | self.start = start
36 |
37 | # Initializing last layer to 0 makes the affine coupling layers
38 | # do nothing at first. This helps with training stability
39 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
40 | end.weight.data.zero_()
41 | end.bias.data.zero_()
42 | self.end = end
43 |
44 | for i in range(n_layers):
45 | dilation = 2 ** i
46 | padding = int((kernel_size*dilation - dilation)/2)
47 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
48 | dilation=dilation, padding=padding)
49 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
50 | self.in_layers.append(in_layer)
51 |
52 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
53 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
54 | self.cond_layers.append(cond_layer)
55 |
56 | # last one is not necessary
57 | if i < n_layers - 1:
58 | res_skip_channels = 2*n_channels
59 | else:
60 | res_skip_channels = n_channels
61 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
62 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
63 | self.res_skip_layers.append(res_skip_layer)
64 |
65 | def forward(self, forward_input):
66 | audio, spect = forward_input
67 | audio = self.start(audio)
68 |
69 | for i in range(self.n_layers):
70 | acts = fused_add_tanh_sigmoid_multiply(
71 | self.in_layers[i](audio),
72 | self.cond_layers[i](spect),
73 | torch.IntTensor([self.n_channels]))
74 |
75 | res_skip_acts = self.res_skip_layers[i](acts)
76 | if i < self.n_layers - 1:
77 | audio = res_skip_acts[:,:self.n_channels,:] + audio
78 | skip_acts = res_skip_acts[:,self.n_channels:,:]
79 | else:
80 | skip_acts = res_skip_acts
81 |
82 | if i == 0:
83 | output = skip_acts
84 | else:
85 | output = skip_acts + output
86 | return self.end(output)
87 |
88 |
89 | class WaveGlow(torch.nn.Module):
90 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
91 | n_early_size, WN_config):
92 | super(WaveGlow, self).__init__()
93 |
94 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
95 | n_mel_channels,
96 | 1024, stride=256)
97 | assert(n_group % 2 == 0)
98 | self.n_flows = n_flows
99 | self.n_group = n_group
100 | self.n_early_every = n_early_every
101 | self.n_early_size = n_early_size
102 | self.WN = torch.nn.ModuleList()
103 | self.convinv = torch.nn.ModuleList()
104 |
105 | n_half = int(n_group/2)
106 |
107 | # Set up layers with the right sizes based on how many dimensions
108 | # have been output already
109 | n_remaining_channels = n_group
110 | for k in range(n_flows):
111 | if k % self.n_early_every == 0 and k > 0:
112 | n_half = n_half - int(self.n_early_size/2)
113 | n_remaining_channels = n_remaining_channels - self.n_early_size
114 | self.convinv.append(Invertible1x1Conv(n_remaining_channels))
115 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
116 | self.n_remaining_channels = n_remaining_channels # Useful during inference
117 |
118 | def forward(self, forward_input):
119 | return None
120 | """
121 | forward_input[0] = audio: batch x time
122 | forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time
123 | """
124 | """
125 | spect, audio = forward_input
126 |
127 | # Upsample spectrogram to size of audio
128 | spect = self.upsample(spect)
129 | assert(spect.size(2) >= audio.size(1))
130 | if spect.size(2) > audio.size(1):
131 | spect = spect[:, :, :audio.size(1)]
132 |
133 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
134 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
135 |
136 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
137 | output_audio = []
138 | s_list = []
139 | s_conv_list = []
140 |
141 | for k in range(self.n_flows):
142 | if k%4 == 0 and k > 0:
143 | output_audio.append(audio[:,:self.n_multi,:])
144 | audio = audio[:,self.n_multi:,:]
145 |
146 | # project to new basis
147 | audio, s = self.convinv[k](audio)
148 | s_conv_list.append(s)
149 |
150 | n_half = int(audio.size(1)/2)
151 | if k%2 == 0:
152 | audio_0 = audio[:,:n_half,:]
153 | audio_1 = audio[:,n_half:,:]
154 | else:
155 | audio_1 = audio[:,:n_half,:]
156 | audio_0 = audio[:,n_half:,:]
157 |
158 | output = self.nn[k]((audio_0, spect))
159 | s = output[:, n_half:, :]
160 | b = output[:, :n_half, :]
161 | audio_1 = torch.exp(s)*audio_1 + b
162 | s_list.append(s)
163 |
164 | if k%2 == 0:
165 | audio = torch.cat([audio[:,:n_half,:], audio_1],1)
166 | else:
167 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)
168 | output_audio.append(audio)
169 | return torch.cat(output_audio,1), s_list, s_conv_list
170 | """
171 |
172 | def infer(self, spect, sigma=1.0):
173 | spect = self.upsample(spect)
174 | # trim conv artifacts. maybe pad spec to kernel multiple
175 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
176 | spect = spect[:, :, :-time_cutoff]
177 |
178 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
179 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
180 |
181 | if spect.type() == 'torch.cuda.HalfTensor':
182 | audio = torch.cuda.HalfTensor(spect.size(0),
183 | self.n_remaining_channels,
184 | spect.size(2)).normal_()
185 | else:
186 | audio = torch.cuda.FloatTensor(spect.size(0),
187 | self.n_remaining_channels,
188 | spect.size(2)).normal_()
189 |
190 | audio = torch.autograd.Variable(sigma*audio)
191 |
192 | for k in reversed(range(self.n_flows)):
193 | n_half = int(audio.size(1)/2)
194 | if k%2 == 0:
195 | audio_0 = audio[:,:n_half,:]
196 | audio_1 = audio[:,n_half:,:]
197 | else:
198 | audio_1 = audio[:,:n_half,:]
199 | audio_0 = audio[:,n_half:,:]
200 |
201 | output = self.WN[k]((audio_0, spect))
202 | s = output[:, n_half:, :]
203 | b = output[:, :n_half, :]
204 | audio_1 = (audio_1 - b)/torch.exp(s)
205 | if k%2 == 0:
206 | audio = torch.cat([audio[:,:n_half,:], audio_1],1)
207 | else:
208 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)
209 |
210 | audio = self.convinv[k](audio, reverse=True)
211 |
212 | if k%4 == 0 and k > 0:
213 | if spect.type() == 'torch.cuda.HalfTensor':
214 | z = torch.cuda.HalfTensor(spect.size(0),
215 | self.n_early_size,
216 | spect.size(2)).normal_()
217 | else:
218 | z = torch.cuda.FloatTensor(spect.size(0),
219 | self.n_early_size,
220 | spect.size(2)).normal_()
221 | audio = torch.cat((sigma*z, audio),1)
222 |
223 | return audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
224 |
225 | @staticmethod
226 | def remove_weightnorm(model):
227 | waveglow = model
228 | for WN in waveglow.WN:
229 | WN.start = torch.nn.utils.remove_weight_norm(WN.start)
230 | WN.in_layers = remove(WN.in_layers)
231 | WN.cond_layers = remove(WN.cond_layers)
232 | WN.res_skip_layers = remove(WN.res_skip_layers)
233 | return waveglow
234 |
--------------------------------------------------------------------------------
/waveglow/inference.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 | # ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import os
28 | from scipy.io.wavfile import write
29 | import torch
30 | from mel2samp import files_to_list, MAX_WAV_VALUE
31 | from denoiser import Denoiser
32 |
33 |
34 | def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
35 | denoiser_strength):
36 | mel_files = files_to_list(mel_files)
37 | waveglow = torch.load(waveglow_path)['model']
38 | waveglow = waveglow.remove_weightnorm(waveglow)
39 | waveglow.cuda().eval()
40 | if is_fp16:
41 | from apex import amp
42 | waveglow, _ = amp.initialize(waveglow, [], opt_level="O3")
43 |
44 | if denoiser_strength > 0:
45 | denoiser = Denoiser(waveglow).cuda()
46 |
47 | for i, file_path in enumerate(mel_files):
48 | file_name = os.path.splitext(os.path.basename(file_path))[0]
49 | mel = torch.load(file_path)
50 | mel = torch.autograd.Variable(mel.cuda())
51 | mel = torch.unsqueeze(mel, 0)
52 | mel = mel.half() if is_fp16 else mel
53 | with torch.no_grad():
54 | audio = waveglow.infer(mel, sigma=sigma)
55 | if denoiser_strength > 0:
56 | audio = denoiser(audio, denoiser_strength)
57 | audio = audio * MAX_WAV_VALUE
58 | audio = audio.squeeze()
59 | audio = audio.cpu().numpy()
60 | audio = audio.astype('int16')
61 | audio_path = os.path.join(
62 | output_dir, "{}_synthesis.wav".format(file_name))
63 | write(audio_path, sampling_rate, audio)
64 | print(audio_path)
65 |
66 |
67 | if __name__ == "__main__":
68 | import argparse
69 |
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('-f', "--filelist_path", required=True)
72 | parser.add_argument('-w', '--waveglow_path',
73 | help='Path to waveglow decoder checkpoint with model')
74 | parser.add_argument('-o', "--output_dir", required=True)
75 | parser.add_argument("-s", "--sigma", default=1.0, type=float)
76 | parser.add_argument("--sampling_rate", default=22050, type=int)
77 | parser.add_argument("--is_fp16", action="store_true")
78 | parser.add_argument("-d", "--denoiser_strength", default=0.0, type=float,
79 | help='Removes model bias. Start with 0.1 and adjust')
80 |
81 | args = parser.parse_args()
82 |
83 | main(args.filelist_path, args.waveglow_path, args.sigma, args.output_dir,
84 | args.sampling_rate, args.is_fp16, args.denoiser_strength)
85 |
--------------------------------------------------------------------------------
/waveglow/mel2samp.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import os
28 | import random
29 | import argparse
30 | import json
31 | import torch
32 | import torch.utils.data
33 | import sys
34 | from scipy.io.wavfile import read
35 |
36 | # We're using the audio processing from TacoTron2 to make sure it matches
37 | sys.path.insert(0, 'tacotron2')
38 | from tacotron2.layers import TacotronSTFT
39 |
40 | MAX_WAV_VALUE = 32768.0
41 |
42 | def files_to_list(filename):
43 | """
44 | Takes a text file of filenames and makes a list of filenames
45 | """
46 | with open(filename, encoding='utf-8') as f:
47 | files = f.readlines()
48 |
49 | files = [f.rstrip() for f in files]
50 | return files
51 |
52 | def load_wav_to_torch(full_path):
53 | """
54 | Loads wavdata into torch array
55 | """
56 | sampling_rate, data = read(full_path)
57 | return torch.from_numpy(data).float(), sampling_rate
58 |
59 |
60 | class Mel2Samp(torch.utils.data.Dataset):
61 | """
62 | This is the main class that calculates the spectrogram and returns the
63 | spectrogram, audio pair.
64 | """
65 | def __init__(self, training_files, segment_length, filter_length,
66 | hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
67 | self.audio_files = files_to_list(training_files)
68 | random.seed(1234)
69 | random.shuffle(self.audio_files)
70 | self.stft = TacotronSTFT(filter_length=filter_length,
71 | hop_length=hop_length,
72 | win_length=win_length,
73 | sampling_rate=sampling_rate,
74 | mel_fmin=mel_fmin, mel_fmax=mel_fmax)
75 | self.segment_length = segment_length
76 | self.sampling_rate = sampling_rate
77 |
78 | def get_mel(self, audio):
79 | audio_norm = audio / MAX_WAV_VALUE
80 | audio_norm = audio_norm.unsqueeze(0)
81 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
82 | melspec = self.stft.mel_spectrogram(audio_norm)
83 | melspec = torch.squeeze(melspec, 0)
84 | return melspec
85 |
86 | def __getitem__(self, index):
87 | # Read audio
88 | filename = self.audio_files[index]
89 | audio, sampling_rate = load_wav_to_torch(filename)
90 | if sampling_rate != self.sampling_rate:
91 | raise ValueError("{} SR doesn't match target {} SR".format(
92 | sampling_rate, self.sampling_rate))
93 |
94 | # Take segment
95 | if audio.size(0) >= self.segment_length:
96 | max_audio_start = audio.size(0) - self.segment_length
97 | audio_start = random.randint(0, max_audio_start)
98 | audio = audio[audio_start:audio_start+self.segment_length]
99 | else:
100 | audio = torch.nn.functional.pad(audio, (0, self.segment_length - audio.size(0)), 'constant').data
101 |
102 | mel = self.get_mel(audio)
103 | audio = audio / MAX_WAV_VALUE
104 |
105 | return (mel, audio)
106 |
107 | def __len__(self):
108 | return len(self.audio_files)
109 |
110 | # ===================================================================
111 | # Takes directory of clean audio and makes directory of spectrograms
112 | # Useful for making test sets
113 | # ===================================================================
114 | if __name__ == "__main__":
115 | # Get defaults so it can work with no Sacred
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument('-f', "--filelist_path", required=True)
118 | parser.add_argument('-c', '--config', type=str,
119 | help='JSON file for configuration')
120 | parser.add_argument('-o', '--output_dir', type=str,
121 | help='Output directory')
122 | args = parser.parse_args()
123 |
124 | with open(args.config) as f:
125 | data = f.read()
126 | data_config = json.loads(data)["data_config"]
127 | mel2samp = Mel2Samp(**data_config)
128 |
129 | filepaths = files_to_list(args.filelist_path)
130 |
131 | # Make directory if it doesn't exist
132 | if not os.path.isdir(args.output_dir):
133 | os.makedirs(args.output_dir)
134 | os.chmod(args.output_dir, 0o775)
135 |
136 | for filepath in filepaths:
137 | audio, sr = load_wav_to_torch(filepath)
138 | melspectrogram = mel2samp.get_mel(audio)
139 | filename = os.path.basename(filepath)
140 | new_filepath = args.output_dir + '/' + filename + '.pt'
141 | print(new_filepath)
142 | torch.save(melspectrogram, new_filepath)
143 |
--------------------------------------------------------------------------------
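Run as a script, `mel2samp.py` precomputes mel spectrograms for a list of wavs, which is one way to produce the mel `.pt` files consumed by `inference.py`; a usage sketch (the filelist and output directory names are hypothetical):

```bash
# Writes one <name>.wav.pt tensor per input wav, using the data_config section of config.json.
python3 mel2samp.py -f test_files.txt -c config.json -o mels/
```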
/waveglow/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0
2 | matplotlib==2.1.0
3 | tensorflow
4 | numpy==1.13.3
5 | inflect==0.2.5
6 | librosa==0.6.0
7 | scipy==1.0.0
8 | tensorboardX==1.1
9 | Unidecode==1.0.22
10 | pillow
11 |
--------------------------------------------------------------------------------
/waveglow/train.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import argparse
28 | import json
29 | import os
30 | import torch
31 |
32 | #=====START: ADDED FOR DISTRIBUTED======
33 | from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor
34 | from torch.utils.data.distributed import DistributedSampler
35 | #=====END: ADDED FOR DISTRIBUTED======
36 |
37 | from torch.utils.data import DataLoader
38 | from glow import WaveGlow, WaveGlowLoss
39 | from mel2samp import Mel2Samp
40 |
41 | def load_checkpoint(checkpoint_path, model, optimizer):
42 | assert os.path.isfile(checkpoint_path)
43 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
44 | iteration = checkpoint_dict['iteration']
45 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
46 | model_for_loading = checkpoint_dict['model']
47 | model.load_state_dict(model_for_loading.state_dict())
48 | print("Loaded checkpoint '{}' (iteration {})" .format(
49 | checkpoint_path, iteration))
50 | return model, optimizer, iteration
51 |
52 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
53 | print("Saving model and optimizer state at iteration {} to {}".format(
54 | iteration, filepath))
55 | model_for_saving = WaveGlow(**waveglow_config).cuda()
56 | model_for_saving.load_state_dict(model.state_dict())
57 | torch.save({'model': model_for_saving,
58 | 'iteration': iteration,
59 | 'optimizer': optimizer.state_dict(),
60 | 'learning_rate': learning_rate}, filepath)
61 |
62 | def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
63 | sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
64 | checkpoint_path, with_tensorboard):
65 | torch.manual_seed(seed)
66 | torch.cuda.manual_seed(seed)
67 | #=====START: ADDED FOR DISTRIBUTED======
68 | if num_gpus > 1:
69 | init_distributed(rank, num_gpus, group_name, **dist_config)
70 | #=====END: ADDED FOR DISTRIBUTED======
71 |
72 | criterion = WaveGlowLoss(sigma)
73 | model = WaveGlow(**waveglow_config).cuda()
74 |
75 | #=====START: ADDED FOR DISTRIBUTED======
76 | if num_gpus > 1:
77 | model = apply_gradient_allreduce(model)
78 | #=====END: ADDED FOR DISTRIBUTED======
79 |
80 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
81 |
82 | if fp16_run:
83 | from apex import amp
84 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
85 |
86 | # Load checkpoint if one exists
87 | iteration = 0
88 | if checkpoint_path != "":
89 | model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
90 | optimizer)
91 | iteration += 1 # next iteration is iteration + 1
92 |
93 | trainset = Mel2Samp(**data_config)
94 | # =====START: ADDED FOR DISTRIBUTED======
95 | train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
96 | # =====END: ADDED FOR DISTRIBUTED======
97 | train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
98 | sampler=train_sampler,
99 | batch_size=batch_size,
100 | pin_memory=False,
101 | drop_last=True)
102 |
103 | # Get shared output_directory ready
104 | if rank == 0:
105 | if not os.path.isdir(output_directory):
106 | os.makedirs(output_directory)
107 | os.chmod(output_directory, 0o775)
108 | print("output directory", output_directory)
109 |
110 | if with_tensorboard and rank == 0:
111 | from tensorboardX import SummaryWriter
112 | logger = SummaryWriter(os.path.join(output_directory, 'logs'))
113 |
114 | model.train()
115 | epoch_offset = max(0, int(iteration / len(train_loader)))
116 | # ================ MAIN TRAINING LOOP! ===================
117 | for epoch in range(epoch_offset, epochs):
118 | print("Epoch: {}".format(epoch))
119 | for i, batch in enumerate(train_loader):
120 | model.zero_grad()
121 |
122 | mel, audio = batch
123 | mel = torch.autograd.Variable(mel.cuda())
124 | audio = torch.autograd.Variable(audio.cuda())
125 | outputs = model((mel, audio))
126 |
127 | loss = criterion(outputs)
128 | if num_gpus > 1:
129 | reduced_loss = reduce_tensor(loss.data, num_gpus).item()
130 | else:
131 | reduced_loss = loss.item()
132 |
133 | if fp16_run:
134 | with amp.scale_loss(loss, optimizer) as scaled_loss:
135 | scaled_loss.backward()
136 | else:
137 | loss.backward()
138 |
139 | optimizer.step()
140 |
141 | print("{}:\t{:.9f}".format(iteration, reduced_loss))
142 | if with_tensorboard and rank == 0:
143 | logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch)
144 |
145 | if (iteration % iters_per_checkpoint == 0):
146 | if rank == 0:
147 | checkpoint_path = "{}/waveglow_{}".format(
148 | output_directory, iteration)
149 | save_checkpoint(model, optimizer, learning_rate, iteration,
150 | checkpoint_path)
151 |
152 | iteration += 1
153 |
154 | if __name__ == "__main__":
155 | parser = argparse.ArgumentParser()
156 | parser.add_argument('-c', '--config', type=str,
157 | help='JSON file for configuration')
158 | parser.add_argument('-r', '--rank', type=int, default=0,
159 | help='rank of process for distributed')
160 | parser.add_argument('-g', '--group_name', type=str, default='',
161 | help='name of group for distributed')
162 | args = parser.parse_args()
163 |
164 | # Parse configs. Globals nicer in this case
165 | with open(args.config) as f:
166 | data = f.read()
167 | config = json.loads(data)
168 | train_config = config["train_config"]
169 | global data_config
170 | data_config = config["data_config"]
171 | global dist_config
172 | dist_config = config["dist_config"]
173 | global waveglow_config
174 | waveglow_config = config["waveglow_config"]
175 |
176 | num_gpus = torch.cuda.device_count()
177 | if num_gpus > 1:
178 | if args.group_name == '':
179 | print("WARNING: Multiple GPUs detected but no distributed group set")
180 | print("Only running 1 GPU. Use distributed.py for multiple GPUs")
181 | num_gpus = 1
182 |
183 | if num_gpus == 1 and args.rank != 0:
184 | raise Exception("Doing single GPU training on rank > 0")
185 |
186 | torch.backends.cudnn.enabled = True
187 | torch.backends.cudnn.benchmark = False
188 | train(num_gpus, args.rank, args.group_name, **train_config)
189 |
--------------------------------------------------------------------------------
/waveglow/waveglow_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/waveglow/waveglow_logo.png
--------------------------------------------------------------------------------
/wavs/gae-char/LJ001-0029_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0029_char100000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ001-0029_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0029_char20000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ001-0029_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0029_char200000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ001-0029_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0029_char50000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ001-0085_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0085_char100000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ001-0085_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0085_char20000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ001-0085_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0085_char200000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ001-0085_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ001-0085_char50000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ002-0106_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ002-0106_char100000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ002-0106_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ002-0106_char20000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ002-0106_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ002-0106_char200000.wav
--------------------------------------------------------------------------------
/wavs/gae-char/LJ002-0106_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/gae-char/LJ002-0106_char50000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ001-0029_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0029_char100000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ001-0029_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0029_char20000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ001-0029_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0029_char200000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ001-0029_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0029_char50000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ001-0085_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0085_char100000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ001-0085_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0085_char20000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ001-0085_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0085_char200000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ001-0085_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ001-0085_char50000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ002-0106_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ002-0106_char100000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ002-0106_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ002-0106_char20000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ002-0106_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ002-0106_char200000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char/LJ002-0106_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char/LJ002-0106_char50000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ001-0029_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0029_char100000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ001-0029_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0029_char20000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ001-0029_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0029_char200000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ001-0029_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0029_char50000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ001-0085_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0085_char100000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ001-0085_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0085_char20000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ001-0085_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0085_char200000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ001-0085_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ001-0085_char50000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ002-0106_char100000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ002-0106_char100000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ002-0106_char20000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ002-0106_char20000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ002-0106_char200000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ002-0106_char200000.wav
--------------------------------------------------------------------------------
/wavs/graph-tts-char_iter5/LJ002-0106_char50000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/1b354eadd37a9eea804b5dffd6b74cfb21ef0628/wavs/graph-tts-char_iter5/LJ002-0106_char50000.wav
--------------------------------------------------------------------------------
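The wav entries above are audio samples exported at different checkpoints (the `_char20000` … `_char200000` suffixes presumably mark the training step) for the `graph-tts-char` and `graph-tts-char_iter5` variants. As a convenience only, here is a minimal sketch of fetching one of the linked samples and inspecting it with the Python standard library; the chosen URL is copied from the listing above, and the snippet is not part of the repository itself.

```python
# Hypothetical helper (not in the repo): download one linked sample and
# report its length using only the Python standard library.
import io
import urllib.request
import wave

SAMPLE_URL = (
    "https://raw.githubusercontent.com/LEEYOONHYUNG/GraphTTS/"
    "1b354eadd37a9eea804b5dffd6b74cfb21ef0628/"
    "wavs/graph-tts-char/LJ001-0029_char200000.wav"
)

# Fetch the raw wav bytes from GitHub.
with urllib.request.urlopen(SAMPLE_URL) as resp:
    data = resp.read()

# Parse the wav header to get frame count and sample rate.
with wave.open(io.BytesIO(data)) as wav_file:
    frames = wav_file.getnframes()
    rate = wav_file.getframerate()
    print(f"{frames} frames at {rate} Hz ({frames / rate:.2f} s)")
```

Any audio library (e.g. `librosa` or `soundfile`) could be substituted for `wave` if you want the waveform as an array rather than just the metadata.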