├── .gitignore
├── LICENSE
├── README.md
├── audio_processing.py
├── commands.sh
├── cpu_gpu_viterbi_benchmark.ipynb
├── extract_alignments.py
├── figures
│   ├── align.png
│   ├── alignments.JPG
│   ├── mdn sample.JPG
│   ├── melspecs.JPG
│   ├── stage0_train_loss.JPG
│   ├── stage0_val_loss.JPG
│   ├── stage1_train_loss.JPG
│   ├── stage1_val_loss.JPG
│   ├── stage2_train_loss.JPG
│   ├── stage2_val_fft_loss.JPG
│   ├── stage2_val_mdn_loss.JPG
│   ├── stage3_train_loss.JPG
│   └── stage3_val_loss.JPG
├── filelists
│   ├── ljs_audio_text_test_filelist.txt
│   ├── ljs_audio_text_train_filelist.txt
│   └── ljs_audio_text_val_filelist.txt
├── generate_samples.ipynb
├── hparams.py
├── index.html
├── inference.ipynb
├── layers.py
├── modules
│   ├── __pycache__
│   │   ├── init_layer.cpython-36.pyc
│   │   ├── init_layer.cpython-37.pyc
│   │   ├── loss.cpython-36.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── model.cpython-36.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── transformer.cpython-36.pyc
│   │   └── transformer.cpython-37.pyc
│   ├── init_layer.py
│   ├── loss.py
│   ├── model.py
│   └── transformer.py
├── prepare_data.ipynb
├── prepare_stages_benchmark.ipynb
├── requirement.txt
├── stft.py
├── text
│   ├── LICENSE
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── cleaners.cpython-36.pyc
│   │   ├── cleaners.cpython-37.pyc
│   │   ├── cmudict.cpython-36.pyc
│   │   ├── cmudict.cpython-37.pyc
│   │   ├── numbers.cpython-36.pyc
│   │   ├── numbers.cpython-37.pyc
│   │   ├── symbols.cpython-36.pyc
│   │   └── symbols.cpython-37.pyc
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
├── train.py
├── training_log
│   └── readme.txt
├── utils
│   ├── __pycache__
│   │   ├── data_utils.cpython-36.pyc
│   │   ├── data_utils.cpython-37.pyc
│   │   ├── plot_image.cpython-36.pyc
│   │   ├── plot_image.cpython-37.pyc
│   │   ├── utils.cpython-36.pyc
│   │   ├── utils.cpython-37.pyc
│   │   ├── writer.cpython-36.pyc
│   │   └── writer.cpython-37.pyc
│   ├── data_utils.py
│   ├── plot_image.py
│   ├── utils.py
│   └── writer.py
├── waveglow
│   ├── .gitmodules
│   ├── LICENSE
│   ├── README.md
│   ├── config.json
│   ├── convert_model.py
│   ├── denoiser.py
│   ├── distributed.py
│   ├── glow.py
│   ├── glow_old.py
│   ├── inference.py
│   ├── mel2samp.py
│   ├── requirements.txt
│   ├── train.py
│   └── waveglow_logo.png
└── wavs
    ├── LJ001-0029_phone10000_10.wav
    ├── LJ001-0029_phone10000_11.wav
    ├── LJ001-0029_phone10000_12.wav
    ├── LJ001-0029_phone10000_8.wav
    ├── LJ001-0029_phone10000_9.wav
    ├── LJ001-0085_phone10000_10.wav
    ├── LJ001-0085_phone10000_11.wav
    ├── LJ001-0085_phone10000_12.wav
    ├── LJ001-0085_phone10000_8.wav
    ├── LJ001-0085_phone10000_9.wav
    ├── LJ002-0106_phone10000_10.wav
    ├── LJ002-0106_phone10000_11.wav
    ├── LJ002-0106_phone10000_12.wav
    ├── LJ002-0106_phone10000_8.wav
    └── LJ002-0106_phone10000_9.wav
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .ipynb_checkpoints/
3 | runs/
4 | training_log/
5 | nohup.out
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Deepest-Project
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AlignTTS
2 | An implementation of [AlignTTS](https://arxiv.org/abs/2003.01950).
3 | 
4 | # Figures
5 | ## Losses
6 | ### stage0
7 | ![stage0 train loss](figures/stage0_train_loss.JPG)
8 | ![stage0 val loss](figures/stage0_val_loss.JPG)
9 | 
10 | ### stage1
11 | ![stage1 train loss](figures/stage1_train_loss.JPG)
12 | ![stage1 val loss](figures/stage1_val_loss.JPG)
13 | 
14 | ### stage2
15 | ![stage2 train loss](figures/stage2_train_loss.JPG)
16 | ![stage2 val FFT loss](figures/stage2_val_fft_loss.JPG)
17 | ![stage2 val MDN loss](figures/stage2_val_mdn_loss.JPG)
18 | 
19 | ### stage3
20 | ![stage3 train loss](figures/stage3_train_loss.JPG)
21 | ![stage3 val loss](figures/stage3_val_loss.JPG)
22 | 
23 | ## Alignments
24 | ![alignments](figures/alignments.JPG)
25 | 
26 | ## Melspectrograms
27 | ![melspectrograms](figures/melspecs.JPG)
28 | 
29 | ## MDN Sample
30 | ![MDN sample](figures/mdn%20sample.JPG)
31 | 
32 | ## Alignment Sample
33 | ![alignment sample](figures/align.png)
34 | 
35 | ## Audio Samples
36 | You can listen to the audio samples [here](https://deepest-project.github.io/AlignTTS).
37 | 
--------------------------------------------------------------------------------
/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.signal import get_window
4 | import librosa.util as librosa_util
5 |
6 |
7 | def window_sumsquare(window,
8 | n_frames,
9 | hop_length=200,
10 | win_length=800,
11 | n_fft=800,
12 | dtype=np.float32,
13 | norm=None):
14 | """
15 | # from librosa 0.6
16 | Compute the sum-square envelope of a window function at a given hop length.
17 |
18 | This is used to estimate modulation effects induced by windowing
19 | observations in short-time fourier transforms.
20 |
21 | Parameters
22 | ----------
23 | window : string, tuple, number, callable, or list-like
24 | Window specification, as in `get_window`
25 |
26 | n_frames : int > 0
27 | The number of analysis frames
28 |
29 | hop_length : int > 0
30 | The number of samples to advance between frames
31 |
32 | win_length : [optional]
33 | The length of the window function. By default, this matches `n_fft`.
34 |
35 | n_fft : int > 0
36 | The length of each analysis frame.
37 |
38 | dtype : np.dtype
39 | The data type of the output
40 |
41 | Returns
42 | -------
43 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
44 | The sum-squared envelope of the window function
45 | """
46 | if win_length is None:
47 | win_length = n_fft
48 |
49 | n = n_fft + hop_length * (n_frames - 1)
50 | x = np.zeros(n, dtype=dtype)
51 |
52 | # Compute the squared window at the desired length
53 | win_sq = get_window(window, win_length, fftbins=True)
54 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2
55 | win_sq = librosa_util.pad_center(win_sq, n_fft)
56 |
57 | # Fill the envelope
58 | for i in range(n_frames):
59 | sample = i * hop_length
60 | x[sample:min(n, sample+n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
61 | return x
62 |
63 |
64 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
65 | """
66 | PARAMS
67 | ------
68 | magnitudes: spectrogram magnitudes
69 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
70 | """
71 |
72 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
73 | angles = angles.astype(np.float32)
74 | angles = torch.autograd.Variable(torch.from_numpy(angles))
75 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
76 |
77 | for i in range(n_iters):
78 | _, angles = stft_fn.transform(signal)
79 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
80 | return signal
81 |
82 |
83 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
84 | """
85 | PARAMS
86 | ------
87 | C: compression factor
88 | """
89 | return torch.log(torch.clamp(x, min=clip_val) * C)
90 |
91 |
92 | def dynamic_range_decompression(x, C=1):
93 | """
94 | PARAMS
95 | ------
96 | C: compression factor used to compress
97 | """
98 | return torch.exp(x) / C
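99 | 
100 | 
101 | if __name__ == '__main__':
102 |     # Minimal usage sketch with made-up sizes: the sum-square window envelope
103 |     # returned by window_sumsquare has length n_fft + hop_length * (n_frames - 1),
104 |     # and dynamic_range_decompression inverts dynamic_range_compression.
105 |     envelope = window_sumsquare('hann', n_frames=10, hop_length=256,
106 |                                 win_length=1024, n_fft=1024)
107 |     assert envelope.shape[0] == 1024 + 256 * (10 - 1)
108 | 
109 |     mel = torch.rand(4, 80, 100) + 1e-3
110 |     mel_roundtrip = dynamic_range_decompression(dynamic_range_compression(mel))
111 |     print(torch.allclose(mel, mel_roundtrip, atol=1e-4))  # True
112 | 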
--------------------------------------------------------------------------------
/commands.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --stage=0 &&
4 | python extract_alignments.py &&
5 | python train.py --stage=1 &&
6 | python train.py --stage=2 &&
7 | python extract_alignments.py &&
8 | python train.py --stage=3
9 |
--------------------------------------------------------------------------------
/extract_alignments.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | os.environ["CUDA_VISIBLE_DEVICES"]='0'
3 |
4 | import warnings
5 | warnings.filterwarnings("ignore")
6 |
7 | import sys
8 | sys.path.append('waveglow/')
9 |
10 | import IPython.display as ipd
11 | import pickle as pkl
12 | import torch
13 | import torch.nn.functional as F
14 | import hparams
15 | from torch.utils.data import DataLoader
16 | from modules.model import Model
17 | from text import text_to_sequence, sequence_to_text
18 | from denoiser import Denoiser
19 | from tqdm import tqdm
20 | import librosa
21 | from modules.loss import MDNLoss
22 | import math
23 | import numpy as np
24 | from datetime import datetime
25 |
26 | def main():
27 | data_type = 'phone'
28 | checkpoint_path = f"training_log/aligntts/stage0/checkpoint_{hparams.train_steps[0]}"
29 | state_dict = {}
30 |
31 | for k, v in torch.load(checkpoint_path)['state_dict'].items():
32 | state_dict[k[7:]]=v
33 |
34 | model = Model(hparams).cuda()
35 | model.load_state_dict(state_dict)
36 | _ = model.cuda().eval()
37 | criterion = MDNLoss()
38 |
39 | datasets = ['train', 'val', 'test']
40 | batch_size=64
41 |
42 | for dataset in datasets:
43 | with open(f'filelists/ljs_audio_text_{dataset}_filelist.txt', 'r') as f:
44 | lines_raw = [line.split('|') for line in f.read().splitlines()]
45 | lines_list = [lines_raw[batch_size*i:batch_size*(i+1)]
46 | for i in range(math.ceil(len(lines_raw)/batch_size))]
47 |
48 | for batch in tqdm(lines_list):
49 | file_list, text_list, mel_list = [], [], []
50 | text_lengths, mel_lengths=[], []
51 |
52 | for i in range(len(batch)):
53 | file_name, _, text = batch[i]
54 | file_list.append(file_name)
55 | seq = os.path.join('../Dataset/LJSpeech-1.1/preprocessed',
56 | f'{data_type}_seq')
57 | mel = os.path.join('../Dataset/LJSpeech-1.1/preprocessed',
58 | 'melspectrogram')
59 |
60 | seq = torch.from_numpy(np.load(f'{seq}/{file_name}_sequence.npy'))
61 | mel = torch.from_numpy(np.load(f'{mel}/{file_name}_melspectrogram.npy'))
62 |
63 | text_list.append(seq)
64 | mel_list.append(mel)
65 | text_lengths.append(seq.size(0))
66 | mel_lengths.append(mel.size(1))
67 |
68 | text_lengths = torch.LongTensor(text_lengths)
69 | mel_lengths = torch.LongTensor(mel_lengths)
70 | text_padded = torch.zeros(len(batch), text_lengths.max().item(), dtype=torch.long)
71 | mel_padded = torch.zeros(len(batch), hparams.n_mel_channels, mel_lengths.max().item())
72 |
73 | for j in range(len(batch)):
74 | text_padded[j, :text_list[j].size(0)] = text_list[j]
75 | mel_padded[j, :, :mel_list[j].size(1)] = mel_list[j]
76 |
77 | text_padded = text_padded.cuda()
78 | mel_padded = mel_padded.cuda()
79 | mel_padded = (torch.clamp(mel_padded, hparams.min_db, hparams.max_db)-hparams.min_db) / (hparams.max_db-hparams.min_db)
80 | text_lengths = text_lengths.cuda()
81 | mel_lengths = mel_lengths.cuda()
82 |
83 | with torch.no_grad():
84 | encoder_input = model.Prenet(text_padded)
85 | hidden_states, _ = model.FFT_lower(encoder_input, text_lengths)
86 | mu_sigma = model.get_mu_sigma(hidden_states)
87 | _, log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths)
88 |
89 | align = model.viterbi(log_prob_matrix, text_lengths, mel_lengths).to(torch.long)
90 | alignments = list(torch.split(align,1))
91 |
92 | for j, (l, t) in enumerate(zip(text_lengths, mel_lengths)):
93 | alignments[j] = alignments[j][0, :l.item(), :t.item()].sum(dim=-1)
94 | np.save(f'../Dataset/LJSpeech-1.1/preprocessed/alignments/{file_list[j]}_alignment.npy',
95 | alignments[j].detach().cpu().numpy())
96 |
97 | print(f"Alignments Extraction End!!! ({datetime.now()})")
98 |
99 |
100 | if __name__ == '__main__':
101 | p = argparse.ArgumentParser()
102 | p.add_argument('--gpu', type=str, default='0')
103 | p.add_argument('-v', '--verbose', type=str, default='0')
104 | args = p.parse_args()
105 |
106 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
107 | torch.manual_seed(hparams.seed)
108 | torch.cuda.manual_seed(hparams.seed)
109 |
110 | if args.verbose=='0':
111 | import warnings
112 | warnings.filterwarnings("ignore")
113 |
114 | main()
--------------------------------------------------------------------------------
/figures/align.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/align.png
--------------------------------------------------------------------------------
/figures/alignments.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/alignments.JPG
--------------------------------------------------------------------------------
/figures/mdn sample.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/mdn sample.JPG
--------------------------------------------------------------------------------
/figures/melspecs.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/melspecs.JPG
--------------------------------------------------------------------------------
/figures/stage0_train_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage0_train_loss.JPG
--------------------------------------------------------------------------------
/figures/stage0_val_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage0_val_loss.JPG
--------------------------------------------------------------------------------
/figures/stage1_train_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage1_train_loss.JPG
--------------------------------------------------------------------------------
/figures/stage1_val_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage1_val_loss.JPG
--------------------------------------------------------------------------------
/figures/stage2_train_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage2_train_loss.JPG
--------------------------------------------------------------------------------
/figures/stage2_val_fft_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage2_val_fft_loss.JPG
--------------------------------------------------------------------------------
/figures/stage2_val_mdn_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage2_val_mdn_loss.JPG
--------------------------------------------------------------------------------
/figures/stage3_train_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage3_train_loss.JPG
--------------------------------------------------------------------------------
/figures/stage3_val_loss.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/figures/stage3_val_loss.JPG
--------------------------------------------------------------------------------
/generate_samples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Import libraries and setup matplotlib"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import os\n",
17 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
18 | "\n",
19 | "import warnings\n",
20 | "warnings.filterwarnings(\"ignore\")\n",
21 | "\n",
22 | "import sys\n",
23 | "sys.path.append('waveglow/')\n",
24 | "\n",
25 | "import matplotlib.pyplot as plt\n",
26 | "%matplotlib inline\n",
27 | "\n",
28 | "import IPython.display as ipd\n",
29 | "import pickle as pkl\n",
30 | "from text import *\n",
31 | "import numpy as np\n",
32 | "import torch\n",
33 | "import hparams\n",
34 | "from modules.model import Model\n",
35 | "from denoiser import Denoiser\n",
36 | "import soundfile"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "### Text preprocessing"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "from g2p_en import G2p\n",
53 | "from text.symbols import symbols\n",
54 | "from text.cleaners import custom_english_cleaners\n",
55 | "\n",
56 | "# Mappings from symbol to numeric ID and vice versa:\n",
57 | "symbol_to_id = {s: i for i, s in enumerate(symbols)}\n",
58 | "id_to_symbol = {i: s for i, s in enumerate(symbols)}\n",
59 | "\n",
60 | "g2p = G2p()\n",
61 | "def text2seq(text, data_type='char'):\n",
62 | " text = custom_english_cleaners(text.rstrip())\n",
63 | " if data_type=='phone':\n",
64 | " clean_phone = []\n",
65 | " for s in g2p(text.lower()):\n",
66 | " if '@'+s in symbol_to_id:\n",
67 | " clean_phone.append('@'+s)\n",
68 | " else:\n",
69 | " clean_phone.append(s)\n",
70 | " text = clean_phone\n",
71 | " \n",
72 | " # Append SOS, EOS token\n",
73 | " sequence = [symbol_to_id[c] for c in text]\n",
74 | " sequence = [symbol_to_id['^']] + sequence + [symbol_to_id['~']]\n",
75 | " return sequence"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {},
81 | "source": [
82 | "### Waveglow"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {
89 | "code_folding": []
90 | },
91 | "outputs": [],
92 | "source": [
93 | "waveglow_path = 'training_log/waveglow_256channels.pt'\n",
94 | "waveglow = torch.load(waveglow_path)['model']\n",
95 | "\n",
96 | "for m in waveglow.modules():\n",
97 | " if 'Conv' in str(type(m)):\n",
98 | " setattr(m, 'padding_mode', 'zeros')\n",
99 | "\n",
100 | "waveglow.cuda().eval()\n",
101 | "for k in waveglow.convinv:\n",
102 | " k.float()\n",
103 | "\n",
104 | "denoiser = Denoiser(waveglow)\n",
105 | "\n",
106 | "with open('filelists/ljs_audio_text_val_filelist.txt', 'r') as f:\n",
107 | " lines = [line.split('|') for line in f.read().splitlines()]"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {
114 | "scrolled": false
115 | },
116 | "outputs": [],
117 | "source": [
118 | "for data_type in ['phone']:\n",
119 | " for step in ['10000']:\n",
120 | " checkpoint_path = f\"training_log/aligntts/stage3/checkpoint_{step}\"\n",
121 | " state_dict = {}\n",
122 | " for k, v in torch.load(checkpoint_path)['state_dict'].items():\n",
123 | " state_dict[k[7:]]=v\n",
124 | "\n",
125 | " model = Model(hparams).cuda()\n",
126 | " model.load_state_dict(state_dict)\n",
127 | " _ = model.cuda().eval()\n",
128 | "\n",
129 | " for i in [1, 6, 22]:\n",
130 | " file_name, _, text = lines[i]\n",
131 | " sequence = np.array(text2seq(text,data_type))[None, :]\n",
132 | " sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n",
133 | " \n",
134 | " for alpha in [0.8, 0.9, 1.0, 1.1, 1.2]:\n",
135 | " with torch.no_grad():\n",
136 | " melspec, durations = model.inference(sequence, alpha)\n",
137 | " melspec = melspec*(hparams.max_db-hparams.min_db)+hparams.min_db\n",
138 | " audio = waveglow.infer(melspec, sigma=0.666)\n",
139 | "\n",
140 | " soundfile.write(f'wavs/{file_name}_{data_type}{step}_{str(int(10*alpha))}.wav', audio.cpu().numpy()[0].astype(float), 22050)"
141 | ]
142 | }
143 | ],
144 | "metadata": {
145 | "kernelspec": {
146 | "display_name": "Environment (conda_pytorch_p36)",
147 | "language": "python",
148 | "name": "conda_pytorch_p36"
149 | },
150 | "language_info": {
151 | "codemirror_mode": {
152 | "name": "ipython",
153 | "version": 3
154 | },
155 | "file_extension": ".py",
156 | "mimetype": "text/x-python",
157 | "name": "python",
158 | "nbconvert_exporter": "python",
159 | "pygments_lexer": "ipython3",
160 | "version": "3.6.5"
161 | },
162 | "varInspector": {
163 | "cols": {
164 | "lenName": 16,
165 | "lenType": 16,
166 | "lenVar": 40
167 | },
168 | "kernels_config": {
169 | "python": {
170 | "delete_cmd_postfix": "",
171 | "delete_cmd_prefix": "del ",
172 | "library": "var_list.py",
173 | "varRefreshCmd": "print(var_dic_list())"
174 | },
175 | "r": {
176 | "delete_cmd_postfix": ") ",
177 | "delete_cmd_prefix": "rm(",
178 | "library": "var_list.r",
179 | "varRefreshCmd": "cat(var_dic_list()) "
180 | }
181 | },
182 | "types_to_exclude": [
183 | "module",
184 | "function",
185 | "builtin_function_or_method",
186 | "instance",
187 | "_Feature"
188 | ],
189 | "window_display": false
190 | }
191 | },
192 | "nbformat": 4,
193 | "nbformat_minor": 2
194 | }
195 |
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | from text import symbols
2 |
3 | ################################
4 | # Experiment Parameters #
5 | ################################
6 | seed=1234
7 | n_gpus=2
8 | output_directory = 'training_log'
9 | log_directory = 'aligntts'
10 | data_path = '../Dataset/LJSpeech-1.1/preprocessed'
11 |
12 | training_files='filelists/ljs_audio_text_train_filelist.txt'
13 | validation_files='filelists/ljs_audio_text_val_filelist.txt'
14 | text_cleaners=['english_cleaners']
15 |
16 |
17 | ################################
18 | # Audio Parameters #
19 | ################################
20 | sampling_rate=22050
21 | filter_length=1024
22 | hop_length=256
23 | win_length=1024
24 | n_mel_channels=80
25 | mel_fmin=0
26 | mel_fmax=8000.0
27 |
28 | ################################
29 | # Model Parameters #
30 | ################################
31 | n_symbols=len(symbols)
32 | data_type='phone_seq' # or 'char_seq'
33 | symbols_embedding_dim=256
34 | hidden_dim=256
35 | dprenet_dim=256
36 | postnet_dim=256
37 | ff_dim=1024
38 | n_heads=2
39 | n_layers=6
40 | max_db=2
41 | min_db=-12
42 |
43 | ################################
44 | # Optimization Hyperparameters #
45 | ################################
46 | lr=384**-0.5
47 | warmup_steps=4000
48 | grad_clip_thresh=1.0
49 | batch_size=32
50 | accumulation=1
51 | iters_per_validation=1000
52 | iters_per_checkpoint=10000
53 | train_steps = [40000, 40000, 80000, 10000]
54 |
55 |
56 |
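57 | 
58 | if __name__ == '__main__':
59 |     # Illustrative sanity check of the audio settings (derived numbers only):
60 |     # with hop_length=256 at 22050 Hz, one mel frame covers about 11.6 ms,
61 |     # i.e. roughly 86 frames per second of audio.
62 |     print(f'{1000 * hop_length / sampling_rate:.1f} ms per mel frame')
63 |     print(f'{sampling_rate / hop_length:.1f} mel frames per second')
64 | 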
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 | AlignTTS Audio Samples
2 | 
3 | AlignTTS (phoneme)
4 | 
5 | Steps      | alpha=0.8 | alpha=0.9 | alpha=1.0 | alpha=1.1 | alpha=1.2
6 | LJ001-0029 | audio     | audio     | audio     | audio     | audio
7 | LJ001-0085 | audio     | audio     | audio     | audio     | audio
8 | LJ002-0106 | audio     | audio     | audio     | audio     | audio
--------------------------------------------------------------------------------
/inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Import libraries and setup matplotlib"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import os\n",
17 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
18 | "\n",
19 | "import warnings\n",
20 | "warnings.filterwarnings(\"ignore\")\n",
21 | "\n",
22 | "import sys\n",
23 | "sys.path.append('waveglow/')\n",
24 | "\n",
25 | "import matplotlib.pyplot as plt\n",
26 | "%matplotlib inline\n",
27 | "\n",
28 | "import IPython.display as ipd\n",
29 | "import pickle as pkl\n",
30 | "import librosa\n",
31 | "from text import *\n",
32 | "import numpy as np\n",
33 | "import torch\n",
34 | "import hparams\n",
35 | "from modules.model import Model\n",
36 | "from denoiser import Denoiser"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "### Text preprocessing"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "from g2p_en import G2p\n",
53 | "from text.symbols import symbols\n",
54 | "from text.cleaners import custom_english_cleaners\n",
55 | "\n",
56 | "# Mappings from symbol to numeric ID and vice versa:\n",
57 | "symbol_to_id = {s: i for i, s in enumerate(symbols)}\n",
58 | "id_to_symbol = {i: s for i, s in enumerate(symbols)}\n",
59 | "\n",
60 | "g2p = G2p()\n",
61 | "def text2seq(text, data_type='char'):\n",
62 | " text = custom_english_cleaners(text.rstrip())\n",
63 | " if data_type=='phone':\n",
64 | " clean_phone = []\n",
65 | " for s in g2p(text.lower()):\n",
66 | " if '@'+s in symbol_to_id:\n",
67 | " clean_phone.append('@'+s)\n",
68 | " else:\n",
69 | " clean_phone.append(s)\n",
70 | " text = clean_phone\n",
71 | " \n",
72 | " # Append SOS, EOS token\n",
73 | " sequence = [symbol_to_id[c] for c in text]\n",
74 | " sequence = [symbol_to_id['^']] + sequence + [symbol_to_id['~']]\n",
75 | " return sequence"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {},
81 | "source": [
82 | "### Waveglow"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {
89 | "code_folding": []
90 | },
91 | "outputs": [],
92 | "source": [
93 | "waveglow_path = 'training_log/waveglow_256channels.pt'\n",
94 | "waveglow = torch.load(waveglow_path)['model']\n",
95 | "\n",
96 | "for m in waveglow.modules():\n",
97 | " if 'Conv' in str(type(m)):\n",
98 | " setattr(m, 'padding_mode', 'zeros')\n",
99 | "\n",
100 | "waveglow.cuda().eval()\n",
101 | "for k in waveglow.convinv:\n",
102 | " k.float()\n",
103 | "denoiser = Denoiser(waveglow)\n",
104 | "\n",
105 | "with open('filelists/ljs_audio_text_val_filelist.txt', 'r') as f:\n",
106 | " lines = [line.split('|') for line in f.read().splitlines()]"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {
113 | "scrolled": false
114 | },
115 | "outputs": [],
116 | "source": [
117 | "for data_type in ['phone']:\n",
118 | " for step in ['10000']:\n",
119 | " print(f'{data_type}_{step}steps')\n",
120 | " checkpoint_path = f\"training_log/aligntts/stage3/checkpoint_{step}\"\n",
121 | " state_dict = {}\n",
122 | " for k, v in torch.load(checkpoint_path)['state_dict'].items():\n",
123 | " state_dict[k[7:]]=v\n",
124 | "\n",
125 | " model = Model(hparams).cuda()\n",
126 | " model.load_state_dict(state_dict)\n",
127 | " _ = model.cuda().eval()\n",
128 | "\n",
129 | " for i in [1, 6, 22]:\n",
130 | " file_name, _, text = lines[i]\n",
131 | " sequence = np.array(text2seq(text, data_type))[None, :]\n",
132 | " sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n",
133 | "\n",
134 | " print(f'Text: {text}')\n",
135 | " for alpha in [0.8, 0.9, 1.0, 1.1, 1.2]:\n",
136 | " with torch.no_grad():\n",
137 | " melspec, durations = model.inference(sequence, alpha)\n",
138 | " melspec = melspec*(hparams.max_db-hparams.min_db)+hparams.min_db\n",
139 | " audio = waveglow.infer(melspec, sigma=0.666)\n",
140 | "\n",
141 | " print(f\"alpha: {alpha}\")\n",
142 | " ipd.display(ipd.Audio(audio.cpu().numpy(), rate=hparams.sampling_rate))\n",
143 | "\n",
144 | " if alpha==1.0:\n",
145 | " ticks=[]\n",
146 | " phoneme = sequence_to_text(sequence[0].tolist())\n",
147 | " for i, d in enumerate(durations[0]):\n",
148 | " ticks.extend([phoneme[i]]*int(d))\n",
149 | "\n",
150 | " plt.figure(figsize=(20,5))\n",
151 | " plt.imshow(melspec.detach().cpu()[0], aspect='auto', origin='lower')\n",
152 | " plt.xticks(range(melspec.size(2)), ticks)\n",
153 | " plt.show()"
154 | ]
155 | }
156 | ],
157 | "metadata": {
158 | "kernelspec": {
159 | "display_name": "Environment (conda_pytorch_p36)",
160 | "language": "python",
161 | "name": "conda_pytorch_p36"
162 | },
163 | "language_info": {
164 | "codemirror_mode": {
165 | "name": "ipython",
166 | "version": 3
167 | },
168 | "file_extension": ".py",
169 | "mimetype": "text/x-python",
170 | "name": "python",
171 | "nbconvert_exporter": "python",
172 | "pygments_lexer": "ipython3",
173 | "version": "3.6.5"
174 | },
175 | "varInspector": {
176 | "cols": {
177 | "lenName": 16,
178 | "lenType": 16,
179 | "lenVar": 40
180 | },
181 | "kernels_config": {
182 | "python": {
183 | "delete_cmd_postfix": "",
184 | "delete_cmd_prefix": "del ",
185 | "library": "var_list.py",
186 | "varRefreshCmd": "print(var_dic_list())"
187 | },
188 | "r": {
189 | "delete_cmd_postfix": ") ",
190 | "delete_cmd_prefix": "rm(",
191 | "library": "var_list.r",
192 | "varRefreshCmd": "cat(var_dic_list()) "
193 | }
194 | },
195 | "types_to_exclude": [
196 | "module",
197 | "function",
198 | "builtin_function_or_method",
199 | "instance",
200 | "_Feature"
201 | ],
202 | "window_display": false
203 | }
204 | },
205 | "nbformat": 4,
206 | "nbformat_minor": 2
207 | }
208 |
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from librosa.filters import mel as librosa_mel_fn
3 | from audio_processing import dynamic_range_compression
4 | from audio_processing import dynamic_range_decompression
5 | from stft import STFT
6 |
7 |
8 | class LinearNorm(torch.nn.Module):
9 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
10 | super(LinearNorm, self).__init__()
11 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
12 |
13 | torch.nn.init.xavier_uniform_(
14 | self.linear_layer.weight,
15 | gain=torch.nn.init.calculate_gain(w_init_gain))
16 |
17 | def forward(self, x):
18 | return self.linear_layer(x)
19 |
20 |
21 | class ConvNorm(torch.nn.Module):
22 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
23 | padding=None, dilation=1, bias=True, w_init_gain='linear'):
24 | super(ConvNorm, self).__init__()
25 | if padding is None:
26 | assert(kernel_size % 2 == 1)
27 | padding = int(dilation * (kernel_size - 1) / 2)
28 |
29 | self.conv = torch.nn.Conv1d(in_channels, out_channels,
30 | kernel_size=kernel_size, stride=stride,
31 | padding=padding, dilation=dilation,
32 | bias=bias)
33 |
34 | torch.nn.init.xavier_uniform_(
35 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
36 |
37 | def forward(self, signal):
38 | conv_signal = self.conv(signal)
39 | return conv_signal
40 |
41 |
42 | class TacotronSTFT(torch.nn.Module):
43 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
44 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
45 | mel_fmax=8000.0):
46 | super(TacotronSTFT, self).__init__()
47 | self.n_mel_channels = n_mel_channels
48 | self.sampling_rate = sampling_rate
49 | self.stft_fn = STFT(filter_length, hop_length, win_length)
50 | mel_basis = librosa_mel_fn(
51 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
52 | mel_basis = torch.from_numpy(mel_basis).float()
53 | self.register_buffer('mel_basis', mel_basis)
54 |
55 | def spectral_normalize(self, magnitudes):
56 | output = dynamic_range_compression(magnitudes)
57 | return output
58 |
59 | def spectral_de_normalize(self, magnitudes):
60 | output = dynamic_range_decompression(magnitudes)
61 | return output
62 |
63 | def mel_spectrogram(self, y):
64 | """Computes mel-spectrograms from a batch of waves
65 | PARAMS
66 | ------
67 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
68 | RETURNS
69 | -------
70 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
71 | """
72 | assert(torch.min(y.data) >= -1)
73 | assert(torch.max(y.data) <= 1)
74 |
75 | magnitudes, phases = self.stft_fn.transform(y)
76 | magnitudes = magnitudes.data
77 | mel_output = torch.matmul(self.mel_basis, magnitudes)
78 | mel_output = self.spectral_normalize(mel_output)
79 | return mel_output
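80 | 
81 | 
82 | if __name__ == '__main__':
83 |     # Minimal sketch (dummy input, default hyperparameters): one second of a
84 |     # fake waveform scaled into [-1, 1] at 22050 Hz is converted into a
85 |     # log-compressed mel-spectrogram of shape (B, n_mel_channels, T).
86 |     taco_stft = TacotronSTFT()
87 |     fake_wav = torch.rand(1, 22050) * 2 - 1
88 |     mel = taco_stft.mel_spectrogram(fake_wav)
89 |     print(mel.shape)  # (1, 80, T), where T depends on hop_length=256
90 | 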
--------------------------------------------------------------------------------
/modules/__pycache__/init_layer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/init_layer.cpython-36.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/init_layer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/init_layer.cpython-37.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/transformer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/transformer.cpython-36.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/modules/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/modules/init_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Linear(nn.Linear):
7 | def __init__(self,
8 | in_dim,
9 | out_dim,
10 | bias=True,
11 | w_init_gain='linear'):
12 | super(Linear, self).__init__(in_dim,
13 | out_dim,
14 | bias)
15 | nn.init.xavier_uniform_(self.weight,
16 | gain=nn.init.calculate_gain(w_init_gain))
17 |
18 |
19 | class Conv1d(nn.Conv1d):
20 | def __init__(self,
21 | in_channels,
22 | out_channels,
23 | kernel_size,
24 | stride=1,
25 | padding=0,
26 | dilation=1,
27 | groups=1,
28 | bias=True,
29 | padding_mode='zeros',
30 | w_init_gain='linear'):
31 | super(Conv1d, self).__init__(in_channels,
32 | out_channels,
33 | kernel_size,
34 | stride,
35 | padding,
36 | dilation,
37 | groups,
38 | bias,
39 | padding_mode)
40 | nn.init.xavier_uniform_(self.weight,
41 | gain=nn.init.calculate_gain(w_init_gain))
--------------------------------------------------------------------------------
/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import hparams as hp
5 | from utils.utils import get_mask_from_lengths
6 | import math
7 |
8 |
9 | class MDNLoss(nn.Module):
10 | def __init__(self):
11 | super(MDNLoss, self).__init__()
12 |
13 | def forward(self, mu_sigma, melspec, text_lengths, mel_lengths):
14 | # mu, sigma: B, L, F / melspec: B, F, T
15 | B, L, _ = mu_sigma.size()
16 | T = melspec.size(2)
17 |
18 | x = melspec.transpose(1,2).unsqueeze(1) # B, 1, T, F
19 | mu = torch.sigmoid(mu_sigma[:, :, :hp.n_mel_channels].unsqueeze(2)) # B, L, 1, F
20 | log_sigma = mu_sigma[:, :, hp.n_mel_channels:].unsqueeze(2) # B, L, 1, F
21 |
22 | exponential = -0.5*torch.sum((x-mu)*(x-mu)/log_sigma.exp()**2, dim=-1) # B, L, T
23 | log_prob_matrix = exponential - (hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi)) - 0.5 * log_sigma.sum(dim=-1)
24 | log_alpha = mu_sigma.new_ones(B, L, T)*(-1e30)
25 | log_alpha[:,0, 0] = log_prob_matrix[:,0, 0]
26 |
27 | for t in range(1, T):
28 | prev_step = torch.cat([log_alpha[:, :, t-1:t], F.pad(log_alpha[:, :, t-1:t], (0,0,1,-1), value=-1e30)], dim=-1)
29 | log_alpha[:, :, t] = torch.logsumexp(prev_step+1e-30, dim=-1)+log_prob_matrix[:, :, t]
30 |
31 | alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
32 | mdn_loss = -alpha_last.mean()
33 |
34 | return mdn_loss, log_prob_matrix
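35 | 
36 | 
37 | if __name__ == '__main__':
38 |     # Shape-check sketch with random tensors (sizes are made up); run from the
39 |     # repo root, e.g. `python -m modules.loss`. mu_sigma packs the per-symbol
40 |     # mean and log-sigma over the mel bins, so its last dim is 2*n_mel_channels.
41 |     B, L, T = 2, 5, 12
42 |     mu_sigma = torch.randn(B, L, 2 * hp.n_mel_channels)
43 |     melspec = torch.rand(B, hp.n_mel_channels, T)
44 |     text_lengths = torch.tensor([5, 4])
45 |     mel_lengths = torch.tensor([12, 9])
46 |     mdn_loss, log_prob_matrix = MDNLoss()(mu_sigma, melspec, text_lengths, mel_lengths)
47 |     print(mdn_loss.item(), log_prob_matrix.shape)  # scalar loss, (B, L, T)
48 | 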
--------------------------------------------------------------------------------
/modules/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .init_layer import *
5 | from .transformer import *
6 | from utils.utils import get_mask_from_lengths
7 |
8 | from datetime import datetime
9 | from time import sleep
10 |
11 | class Prenet(nn.Module):
12 | def __init__(self, hp):
13 | super(Prenet, self).__init__()
14 | # B, L -> B, L, D
15 | self.Embedding = nn.Embedding(hp.n_symbols, hp.symbols_embedding_dim)
16 | self.register_buffer('pe', PositionalEncoding(hp.hidden_dim).pe)
17 | self.dropout = nn.Dropout(0.1)
18 |
19 | def forward(self, text):
20 | B, L = text.size(0), text.size(1)
21 | x = self.Embedding(text).transpose(0,1)
22 | x += self.pe[:L].unsqueeze(1)
23 | x = self.dropout(x).transpose(0,1)
24 | return x
25 |
26 |
27 | class FFT(nn.Module):
28 | def __init__(self, hidden_dim, n_heads, ff_dim, n_layers):
29 | super(FFT, self).__init__()
30 | self.FFT_layers = nn.ModuleList([TransformerEncoderLayer(d_model=hidden_dim,
31 | nhead=n_heads,
32 | dim_feedforward=ff_dim)
33 | for _ in range(n_layers)])
34 | def forward(self, x, lengths):
35 | # B, L, D -> B, L, D
36 | alignments = []
37 | x = x.transpose(0,1)
38 | mask = get_mask_from_lengths(lengths)
39 | for layer in self.FFT_layers:
40 | x, align = layer(x, src_key_padding_mask=mask)
41 | alignments.append(align.unsqueeze(1))
42 | alignments = torch.cat(alignments, 1)
43 |
44 | return x.transpose(0,1), alignments
45 |
46 |
47 | class DurationPredictor(nn.Module):
48 | def __init__(self, hp):
49 | super(DurationPredictor, self).__init__()
50 | self.Prenet = Prenet(hp)
51 | self.FFT = FFT(hp.hidden_dim, hp.n_heads, hp.ff_dim, 2)
52 | self.linear = Linear(hp.hidden_dim, 1)
53 |
54 | def forward(self, text, text_lengths):
55 | # B, L -> B, L
56 | encoder_input = self.Prenet(text)
57 | x = self.FFT(encoder_input, text_lengths)[0]
58 | x = self.linear(x).squeeze(-1)
59 | return x
60 |
61 |
62 | class Model(nn.Module):
63 | def __init__(self, hp):
64 | super(Model, self).__init__()
65 | self.Prenet = Prenet(hp)
66 | self.FFT_lower = FFT(hp.hidden_dim, hp.n_heads, hp.ff_dim, hp.n_layers)
67 | self.FFT_upper = FFT(hp.hidden_dim, hp.n_heads, hp.ff_dim, hp.n_layers)
68 | self.MDN = nn.Sequential(Linear(hp.hidden_dim, hp.hidden_dim),
69 | nn.LayerNorm(hp.hidden_dim),
70 | nn.ReLU(),
71 | nn.Dropout(0.1),
72 | Linear(hp.hidden_dim, 2*hp.n_mel_channels))
73 | self.DurationPredictor = DurationPredictor(hp)
74 | self.Projection = Linear(hp.hidden_dim, hp.n_mel_channels)
75 |
76 | def get_mu_sigma(self, hidden_states):
77 | mu_sigma = self.MDN(hidden_states)
78 | return mu_sigma
79 |
80 | def get_duration(self, text, text_lengths):
81 | durations = self.DurationPredictor(text, text_lengths).exp()
82 | return durations
83 |
84 | def get_melspec(self, hidden_states, align, mel_lengths):
85 | hidden_states_expanded = torch.matmul(align.transpose(1,2), hidden_states)
86 | hidden_states_expanded += self.Prenet.pe[:hidden_states_expanded.size(1)].unsqueeze(1).transpose(0,1)
87 | mel_out = torch.sigmoid(self.Projection(self.FFT_upper(hidden_states_expanded, mel_lengths)[0]).transpose(1,2))
88 | return mel_out
89 |
90 | def forward(self, text, melspec, align, text_lengths, mel_lengths, criterion, stage, log_viterbi=False, cpu_viterbi=False):
91 | text = text[:,:text_lengths.max().item()]
92 | melspec = melspec[:,:,:mel_lengths.max().item()]
93 |
94 | if stage==0:
95 | encoder_input = self.Prenet(text)
96 | hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
97 | mu_sigma = self.get_mu_sigma(hidden_states)
98 | mdn_loss, _ = criterion(mu_sigma, melspec, text_lengths, mel_lengths)
99 | return mdn_loss
100 |
101 | elif stage==1:
102 | align = align[:, :text_lengths.max().item(), :mel_lengths.max().item()]
103 | encoder_input = self.Prenet(text)
104 | hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
105 | mel_out = self.get_melspec(hidden_states, align, mel_lengths)
106 |
107 | mel_mask = ~get_mask_from_lengths(mel_lengths)
108 | melspec = melspec.masked_select(mel_mask.unsqueeze(1))
109 | mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
110 | fft_loss = nn.L1Loss()(mel_out, melspec)
111 |
112 | return fft_loss
113 |
114 | elif stage==2:
115 | encoder_input = self.Prenet(text)
116 | hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
117 | mu_sigma = self.get_mu_sigma(hidden_states)
118 | mdn_loss, log_prob_matrix = criterion(mu_sigma, melspec, text_lengths, mel_lengths)
119 |
120 | before = datetime.now()
121 | if cpu_viterbi:
122 | align = self.viterbi_cpu(log_prob_matrix, text_lengths.cpu(), mel_lengths.cpu()) # B, T
123 | else:
124 | align = self.viterbi(log_prob_matrix, text_lengths, mel_lengths) # B, T
125 | after = datetime.now()
126 |
127 | if log_viterbi:
128 | time_delta = after - before
129 | print(f'Viterbi took {time_delta.total_seconds()} secs')
130 |
131 | mel_out = self.get_melspec(hidden_states, align, mel_lengths)
132 |
133 | mel_mask = ~get_mask_from_lengths(mel_lengths)
134 | melspec = melspec.masked_select(mel_mask.unsqueeze(1))
135 | mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
136 | fft_loss = nn.L1Loss()(mel_out, melspec)
137 |
138 | return mdn_loss + fft_loss
139 |
140 | elif stage==3:
141 | align = align[:, :text_lengths.max().item(), :mel_lengths.max().item()]
142 | duration_out = self.get_duration(text, text_lengths) # gradient cut
143 | duration_target = align.sum(-1)
144 |
145 | duration_mask = ~get_mask_from_lengths(text_lengths)
146 | duration_target = duration_target.masked_select(duration_mask)
147 | duration_out = duration_out.masked_select(duration_mask)
148 | duration_loss = nn.MSELoss()(torch.log(duration_out), torch.log(duration_target))
149 |
150 | return duration_loss
151 |
152 |
153 | def inference(self, text, alpha=1.0):
154 | text_lengths = text.new_tensor([text.size(1)])
155 | encoder_input = self.Prenet(text)
156 | hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
157 | durations = self.get_duration(text, text_lengths)
158 | durations = torch.round(durations*alpha).to(torch.long)
159 | durations[durations<=0]=1
160 | T=int(durations.sum().item())
161 | mel_lengths = text.new_tensor([T])
162 | hidden_states_expanded = torch.repeat_interleave(hidden_states, durations[0], dim=1)
163 | hidden_states_expanded += self.Prenet.pe[:hidden_states_expanded.size(1)].unsqueeze(1).transpose(0,1)
164 | mel_out = torch.sigmoid(self.Projection(self.FFT_upper(hidden_states_expanded, mel_lengths)[0]).transpose(1,2))
165 |
166 | return mel_out, durations
167 |
168 |
169 | def viterbi(self, log_prob_matrix, text_lengths, mel_lengths):
170 | B, L, T = log_prob_matrix.size()
171 | log_beta = log_prob_matrix.new_ones(B, L, T)*(-1e15)
172 | log_beta[:, 0, 0] = log_prob_matrix[:, 0, 0]
173 |
174 | for t in range(1, T):
175 | prev_step = torch.cat([log_beta[:, :, t-1:t], F.pad(log_beta[:, :, t-1:t], (0,0,1,-1), value=-1e15)], dim=-1).max(dim=-1)[0]
176 | log_beta[:, :, t] = prev_step+log_prob_matrix[:, :, t]
177 |
178 | curr_rows = text_lengths-1
179 | curr_cols = mel_lengths-1
180 | path = [curr_rows*1.0]
181 | for _ in range(T-1):
182 | is_go = log_beta[torch.arange(B), (curr_rows-1).to(torch.long), (curr_cols-1).to(torch.long)]\
183 | > log_beta[torch.arange(B), (curr_rows).to(torch.long), (curr_cols-1).to(torch.long)]
184 | curr_rows = F.relu(curr_rows-1.0*is_go+1.0)-1.0
185 | curr_cols = F.relu(curr_cols-1+1.0)-1.0
186 | path.append(curr_rows*1.0)
187 |
188 | path.reverse()
189 | path = torch.stack(path, -1)
190 |
191 | indices = path.new_tensor(torch.arange(path.max()+1).view(1,1,-1)) # 1, 1, L
192 | align = 1.0*(path.new_tensor(indices==path.unsqueeze(-1))) # B, T, L
193 |
194 | for i in range(align.size(0)):
195 | pad= T-mel_lengths[i]
196 | align[i] = F.pad(align[i], (0,0,-pad,pad))
197 |
198 | return align.transpose(1,2)
199 |
200 | def fast_viterbi(self, log_prob_matrix, text_lengths, mel_lengths):
201 | B, L, T = log_prob_matrix.size()
202 |
203 | _log_prob_matrix = log_prob_matrix.cpu()
204 |
205 | curr_rows = text_lengths.cpu().to(torch.long)-1
206 | curr_cols = mel_lengths.cpu().to(torch.long)-1
207 |
208 | path = [curr_rows*1]
209 |
210 | for _ in range(T-1):
211 | # print(curr_rows-1)
212 | # print(curr_cols-1)
213 | is_go = _log_prob_matrix[torch.arange(B), curr_rows-1, curr_cols-1]\
214 | > _log_prob_matrix[torch.arange(B), curr_rows, curr_cols-1]
215 | # curr_rows = F.relu(curr_rows-1*is_go+1)-1
216 | # curr_cols = F.relu(curr_cols)-1
217 | curr_rows = F.relu(curr_rows-1*is_go+1)-1
218 | curr_cols = F.relu(curr_cols-1+1)-1
219 | path.append(curr_rows*1)
220 |
221 | path.reverse()
222 | path = torch.stack(path, -1)
223 |
224 | indices = path.new_tensor(torch.arange(path.max()+1).view(1,1,-1)) # 1, 1, L
225 | align = 1.0*(path.new_tensor(indices==path.unsqueeze(-1))) # B, T, L
226 |
227 | for i in range(align.size(0)):
228 | pad= T-mel_lengths[i]
229 | align[i] = F.pad(align[i], (0,0,-pad,pad))
230 |
231 | return align.transpose(1,2)
232 |
233 | def viterbi_cpu(self, log_prob_matrix, text_lengths, mel_lengths):
234 |
235 | original_device = log_prob_matrix.device
236 |
237 | B, L, T = log_prob_matrix.size()
238 |
239 | _log_prob_matrix = log_prob_matrix.cpu()
240 |
241 | log_beta = _log_prob_matrix.new_ones(B, L, T)*(-1e15)
242 | log_beta[:, 0, 0] = _log_prob_matrix[:, 0, 0]
243 |
244 | for t in range(1, T):
245 | prev_step = torch.cat([log_beta[:, :, t-1:t], F.pad(log_beta[:, :, t-1:t], (0,0,1,-1), value=-1e15)], dim=-1).max(dim=-1)[0]
246 | log_beta[:, :, t] = prev_step+_log_prob_matrix[:, :, t]
247 |
248 | curr_rows = text_lengths-1
249 | curr_cols = mel_lengths-1
250 | path = [curr_rows*1]
251 | for _ in range(T-1):
252 | is_go = log_beta[torch.arange(B), curr_rows-1, curr_cols-1]\
253 | > log_beta[torch.arange(B), curr_rows, curr_cols-1]
254 | curr_rows = F.relu(curr_rows - 1 * is_go + 1) - 1
255 | curr_cols = F.relu(curr_cols) - 1
256 | path.append(curr_rows*1)
257 |
258 | path.reverse()
259 | path = torch.stack(path, -1)
260 |
261 | indices = path.new_tensor(torch.arange(path.max()+1).view(1,1,-1)) # 1, 1, L
262 | align = 1.0*(path.new_tensor(indices==path.unsqueeze(-1))) # B, T, L
263 |
264 | for i in range(align.size(0)):
265 | pad= T-mel_lengths[i]
266 | align[i] = F.pad(align[i], (0,0,-pad,pad))
267 |
268 | return align.transpose(1,2).to(original_device)
269 |
270 |
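271 | 
272 | if __name__ == '__main__':
273 |     # Toy sanity check of the monotonic Viterbi alignment with random scores
274 |     # (sizes are made up); run from the repo root, e.g. `python -m modules.model`.
275 |     import hparams
276 |     model = Model(hparams)
277 |     B, L, T = 2, 4, 9
278 |     log_prob_matrix = torch.randn(B, L, T)
279 |     text_lengths = torch.tensor([4, 3])
280 |     mel_lengths = torch.tensor([9, 7])
281 |     align = model.viterbi(log_prob_matrix, text_lengths, mel_lengths)  # B, L, T
282 |     # Summing the 0/1 alignment over the time axis gives per-symbol durations
283 |     # (this is what extract_alignments.py saves for each utterance).
284 |     durations = align.sum(dim=-1)
285 |     print(align.shape, durations.shape)  # torch.Size([2, 4, 9]) torch.Size([2, 4])
286 | 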
--------------------------------------------------------------------------------
/modules/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .init_layer import *
5 |
6 |
7 | class TransformerEncoderLayer(nn.Module):
8 | def __init__(self,
9 | d_model,
10 | nhead,
11 | dim_feedforward=2048,
12 | dropout=0.1,
13 | activation="relu"):
14 | super(TransformerEncoderLayer, self).__init__()
15 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
16 |
17 | self.linear1 = Linear(d_model, dim_feedforward, w_init_gain=activation)
18 | self.linear2 = Linear(dim_feedforward, d_model)
19 |
20 | self.norm1 = nn.LayerNorm(d_model)
21 | self.norm2 = nn.LayerNorm(d_model)
22 |
23 | self.dropout = nn.Dropout(dropout)
24 |
25 | def forward(self, src, src_mask=None, src_key_padding_mask=None):
26 | src2, enc_align = self.self_attn(src,
27 | src,
28 | src,
29 | attn_mask=src_mask,
30 | key_padding_mask=src_key_padding_mask)
31 | src = src + self.dropout(src2)
32 | src = self.norm1(src)
33 |
34 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
35 | src = src + self.dropout(src2)
36 | src = self.norm2(src)
37 |
38 | return src, enc_align
39 |
40 |
41 | class TransformerDecoderLayer(nn.Module):
42 | def __init__(self,
43 | d_model,
44 | nhead,
45 | dim_feedforward=2048,
46 | dropout=0.1,
47 | activation="relu"):
48 | super(TransformerDecoderLayer, self).__init__()
49 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
50 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
51 |
52 | self.linear1 = Linear(d_model, dim_feedforward, w_init_gain=activation)
53 | self.linear2 = Linear(dim_feedforward, d_model)
54 |
55 | self.norm1 = nn.LayerNorm(d_model)
56 | self.norm2 = nn.LayerNorm(d_model)
57 | self.norm3 = nn.LayerNorm(d_model)
58 |
59 | self.dropout = nn.Dropout(dropout)
60 |
61 | def forward(self,
62 | tgt,
63 | memory,
64 | tgt_mask=None,
65 | memory_mask=None,
66 | tgt_key_padding_mask=None,
67 | memory_key_padding_mask=None):
68 | tgt2, dec_align = self.self_attn(tgt,
69 | tgt,
70 | tgt,
71 | attn_mask=tgt_mask,
72 | key_padding_mask=tgt_key_padding_mask)
73 | tgt = tgt + self.dropout(tgt2)
74 | tgt = self.norm1(tgt)
75 |
76 | tgt2, enc_dec_align = self.multihead_attn(tgt,
77 | memory,
78 | memory,
79 | attn_mask=memory_mask,
80 | key_padding_mask=memory_key_padding_mask)
81 | tgt = tgt + self.dropout(tgt2)
82 | tgt = self.norm2(tgt)
83 |
84 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
85 | tgt = tgt + self.dropout(tgt2)
86 | tgt = self.norm3(tgt)
87 |
88 | return tgt, dec_align, enc_dec_align
89 |
90 |
91 | class PositionalEncoding(nn.Module):
92 | def __init__(self, d_model, max_len=5000):
93 | super(PositionalEncoding, self).__init__()
94 | self.register_buffer('pe', self._get_pe_matrix(d_model, max_len))
95 |
96 | def forward(self, x):
97 | return x + self.pe[:x.size(0)].unsqueeze(1)
98 |
99 | def _get_pe_matrix(self, d_model, max_len):
100 | pe = torch.zeros(max_len, d_model)
101 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
102 | div_term = torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model)
103 |
104 | pe[:, 0::2] = torch.sin(position / div_term)
105 | pe[:, 1::2] = torch.cos(position / div_term)
106 |
107 | return pe
108 |
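109 | 
110 | 
111 | if __name__ == '__main__':
112 |     # Shape sketch with made-up sizes; run from the repo root, e.g.
113 |     # `python -m modules.transformer`. The layers follow nn.MultiheadAttention's
114 |     # sequence-first layout: inputs are (L, B, d_model), and each layer also
115 |     # returns its attention weights of shape (B, L, L).
116 |     layer = TransformerEncoderLayer(d_model=256, nhead=2, dim_feedforward=1024)
117 |     src = torch.randn(7, 3, 256)
118 |     out, enc_align = layer(src)
119 |     print(out.shape, enc_align.shape)  # torch.Size([7, 3, 256]) torch.Size([3, 7, 7])
120 | 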
--------------------------------------------------------------------------------
/prepare_stages_benchmark.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Run this code twice: once after stage 0 and once after stage 2"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [
15 | {
16 | "name": "stdout",
17 | "output_type": "stream",
18 | "text": [
19 | "training_log/aligntts/stage0/checkpoint_1000\n"
20 | ]
21 | }
22 | ],
23 | "source": [
24 | "import os\n",
25 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]='0'\n",
26 | "\n",
27 | "import warnings\n",
28 | "warnings.filterwarnings(\"ignore\")\n",
29 | "\n",
30 | "import sys\n",
31 | "sys.path.append('waveglow/')\n",
32 | "\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "%matplotlib inline\n",
35 | "\n",
36 | "import IPython.display as ipd\n",
37 | "import pickle as pkl\n",
38 | "import torch\n",
39 | "import torch.nn.functional as F\n",
40 | "import hparams\n",
41 | "from torch.utils.data import DataLoader\n",
42 | "from modules.model import Model\n",
43 | "from text import text_to_sequence, sequence_to_text\n",
44 | "from denoiser import Denoiser\n",
45 | "from tqdm import tqdm_notebook as tqdm\n",
46 | "import librosa\n",
47 | "from modules.loss import MDNLoss\n",
48 | "import math\n",
49 | "from multiprocessing import Pool\n",
50 | "import numpy as np\n",
51 | "\n",
52 | "data_type = 'char'\n",
53 | "checkpoint_path = f\"training_log/aligntts/stage0/checkpoint_40000\"\n",
54 | "\n",
55 | "from glob import glob\n",
56 | "\n",
57 | "checkpoint_path = sorted(glob(\"training_log/aligntts/stage0/checkpoint_*\"))[0]\n",
58 | "\n",
59 | "print(checkpoint_path)\n",
60 | "\n",
61 | "\n",
62 | "state_dict = {}\n",
63 | "for k, v in torch.load(checkpoint_path)['state_dict'].items():\n",
64 | " state_dict[k[7:]]=v\n",
65 | "\n",
66 | "\n",
67 | "model = Model(hparams).cuda()\n",
68 | "model.load_state_dict(state_dict)\n",
69 | "_ = model.cuda().eval()\n",
70 | "criterion = MDNLoss()"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 2,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "import time"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 3,
85 | "metadata": {
86 | "scrolled": false
87 | },
88 | "outputs": [
89 | {
90 | "data": {
91 | "application/vnd.jupyter.widget-view+json": {
92 | "model_id": "17e3fa022cc04cc0916ed226d16297dd",
93 | "version_major": 2,
94 | "version_minor": 0
95 | },
96 | "text/plain": [
97 | "HBox(children=(FloatProgress(value=0.0, max=656.0), HTML(value='')))"
98 | ]
99 | },
100 | "metadata": {},
101 | "output_type": "display_data"
102 | },
103 | {
104 | "name": "stdout",
105 | "output_type": "stream",
106 | "text": [
107 | "VT Time: 0.758989 / 1.367927 = 55.48%\n",
108 | "IO Time: 0.005540 / 1.367927 = 0.40%\n",
109 | "DL Time: 0.591276 / 1.367927 = 43.22%\n",
110 | "torch.Size([1, 170, 857])\n",
111 | "\n"
112 | ]
113 | },
114 | {
115 | "data": {
116 | "application/vnd.jupyter.widget-view+json": {
117 | "model_id": "239ca9b5506d4667a6e390b5b6507ae0",
118 | "version_major": 2,
119 | "version_minor": 0
120 | },
121 | "text/plain": [
122 | "HBox(children=(FloatProgress(value=0.0, max=86.0), HTML(value='')))"
123 | ]
124 | },
125 | "metadata": {},
126 | "output_type": "display_data"
127 | },
128 | {
129 | "name": "stdout",
130 | "output_type": "stream",
131 | "text": [
132 | "VT Time: 0.590959 / 0.813022 = 72.69%\n",
133 | "IO Time: 0.005205 / 0.813022 = 0.64%\n",
134 | "DL Time: 0.213748 / 0.813022 = 26.29%\n",
135 | "torch.Size([1, 140, 725])\n",
136 | "\n"
137 | ]
138 | },
139 | {
140 | "data": {
141 | "application/vnd.jupyter.widget-view+json": {
142 | "model_id": "fc2d69f9b7aa48a5971db851393858cc",
143 | "version_major": 2,
144 | "version_minor": 0
145 | },
146 | "text/plain": [
147 | "HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))"
148 | ]
149 | },
150 | "metadata": {},
151 | "output_type": "display_data"
152 | },
153 | {
154 | "name": "stdout",
155 | "output_type": "stream",
156 | "text": [
157 | "VT Time: 0.759542 / 1.017219 = 74.67%\n",
158 | "IO Time: 0.024682 / 1.017219 = 2.43%\n",
159 | "DL Time: 0.226169 / 1.017219 = 22.23%\n",
160 | "torch.Size([1, 137, 807])\n",
161 | "\n"
162 | ]
163 | }
164 | ],
165 | "source": [
166 | "datasets = ['train', 'val', 'test']\n",
167 | "batch_size = 16\n",
169 | "\n",
170 | "start = time.perf_counter()\n",
171 | "\n",
172 | "for dataset in datasets:\n",
173 | " \n",
174 | " with open(f'filelists/ljs_audio_text_{dataset}_filelist.txt', 'r') as f:\n",
175 | " lines_raw = [line.split('|') for line in f.read().splitlines()]\n",
176 | " lines_list = [ lines_raw[batch_size*i:batch_size*(i+1)] \n",
177 | " for i in range(len(lines_raw)//batch_size+1)]\n",
178 | " \n",
179 | " for batch in tqdm(lines_list):\n",
180 | " \n",
181 | " single_loop_start = time.perf_counter()\n",
182 | " \n",
183 | " file_list, text_list, mel_list = [], [], []\n",
184 | " text_lengths, mel_lengths=[], []\n",
185 | " \n",
186 | " for i in range(len(batch)):\n",
187 | " file_name, _, text = batch[i]\n",
188 | " file_list.append(file_name)\n",
189 | " seq_path = os.path.join('../Dataset/LJSpeech-1.1/preprocessed',\n",
190 | " f'{data_type}_seq')\n",
191 | " mel_path = os.path.join('../Dataset/LJSpeech-1.1/preprocessed',\n",
192 | " 'melspectrogram')\n",
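   | "            # preprocessed features are stored either as .npy arrays or pickled tensors, so fall back to .pkl if the .npy is missing\n",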
193 | " try:\n",
194 | " seq = torch.from_numpy(np.load(f'{seq_path}/{file_name}_sequence.npy'))\n",
195 | " except FileNotFoundError:\n",
196 | " with open(f'{seq_path}/{file_name}_sequence.pkl', 'rb') as f:\n",
197 | " seq = pkl.load(f)\n",
198 | " \n",
199 | " try:\n",
200 | " mel = torch.from_numpy(np.load(f'{mel_path}/{file_name}_melspectrogram.npy'))\n",
201 | " except FileNotFoundError:\n",
202 | " with open(f'{mel_path}/{file_name}_melspectrogram.pkl', 'rb') as f:\n",
203 | " mel = pkl.load(f)\n",
204 | " \n",
205 | " text_list.append(seq)\n",
206 | " mel_list.append(mel)\n",
207 | " text_lengths.append(seq.size(0))\n",
208 | " mel_lengths.append(mel.size(1))\n",
209 | " \n",
210 | " io_time = time.perf_counter()\n",
211 | " \n",
212 | " text_lengths = torch.LongTensor(text_lengths)\n",
213 | " mel_lengths = torch.LongTensor(mel_lengths)\n",
214 | " text_padded = torch.zeros(len(batch), text_lengths.max().item(), dtype=torch.long)\n",
215 | " mel_padded = torch.zeros(len(batch), hparams.n_mel_channels, mel_lengths.max().item())\n",
216 | " \n",
217 | " for j in range(len(batch)):\n",
218 | " text_padded[j, :text_list[j].size(0)] = text_list[j]\n",
219 | " mel_padded[j, :, :mel_list[j].size(1)] = mel_list[j]\n",
220 | " \n",
221 | " text_padded = text_padded.cuda()\n",
222 | " mel_padded = mel_padded.cuda()\n",
223 | " text_lengths = text_lengths.cuda()\n",
224 | " mel_lengths = mel_lengths.cuda()\n",
225 | " \n",
226 | " with torch.no_grad():\n",
227 | " \n",
228 | " model_start = time.perf_counter()\n",
229 | " \n",
230 | " encoder_input = model.Prenet(text_padded)\n",
231 | " hidden_states, _ = model.FFT_lower(encoder_input, text_lengths)\n",
232 | " mu_sigma = model.get_mu_sigma(hidden_states)\n",
233 | " _, log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths)\n",
234 | " \n",
235 | " viterbi_start = time.perf_counter()\n",
236 | "\n",
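   | "                # Viterbi decoding over the MDN log-probability matrix yields the monotonic text-to-mel alignment\n",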
237 | " align = model.viterbi(log_prob_matrix, text_lengths, mel_lengths).to(torch.long)\n",
238 | " alignments = list(torch.split(align,1))\n",
239 | " \n",
240 | " viterbi_end = time.perf_counter()\n",
241 | " \n",
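   | "            # VT = Viterbi decoding, IO = loading .npy/.pkl files, DL = model forward pass (Prenet, FFT, mu/sigma, MDN loss)\n",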
242 | " print('VT Time: ', end=' ')\n",
243 | " print(f'{viterbi_end - viterbi_start:.6f} / {viterbi_end - single_loop_start:.6f} = ' +\n",
244 | " f'{(viterbi_end - viterbi_start) / (viterbi_end - single_loop_start) * 100:5.2f}%')\n",
245 | " \n",
246 | " print('IO Time: ', end=' ')\n",
247 | " print(f'{io_time - single_loop_start:.6f} / {viterbi_end - single_loop_start:.6f} = ' +\n",
248 | " f'{(io_time - single_loop_start) / (viterbi_end - single_loop_start) * 100:5.2f}%')\n",
249 | " \n",
250 | " print('DL Time: ', end=' ')\n",
251 | " print(f'{viterbi_start - model_start:.6f} / {viterbi_end - single_loop_start:.6f} = ' +\n",
252 | " f'{(viterbi_start - model_start) / (viterbi_end - single_loop_start) * 100:5.2f}%')\n",
253 | " \n",
254 | " print(alignments[0].shape)\n",
255 | " \n",
256 | " break\n",
257 | " \n",
258 | " \n",
259 | " \n",
260 | " \n",
261 | "# for j, (l, t) in enumerate(zip(text_lengths, mel_lengths)):\n",
262 | "# alignments[j] = alignments[j][0, :l.item(), :t.item()].sum(dim=-1)\n",
263 | "# np.save(f'../Dataset/LJSpeech-1.1/preprocessed/alignments/{file_list[j]}_alignment.npy',\n",
264 | "# alignments[j].detach().cpu().numpy())"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 4,
270 | "metadata": {},
271 | "outputs": [
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "torch.Size([16, 137, 807])\n"
277 | ]
278 | }
279 | ],
280 | "source": [
281 | "print(log_prob_matrix.shape)"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 7,
287 | "metadata": {},
288 | "outputs": [
289 | {
290 | "data": {
291 | "text/plain": [
292 | ""
293 | ]
294 | },
295 | "execution_count": 7,
296 | "metadata": {},
297 | "output_type": "execute_result"
298 | },
299 | {
300 | "data": {
301 | "image/png": "<base64-encoded PNG omitted: plt.imshow heatmap of log_prob_matrix[0]>\n",
302 | "text/plain": [
303 | ""
304 | ]
305 | },
306 | "metadata": {
307 | "needs_background": "light"
308 | },
309 | "output_type": "display_data"
310 | }
311 | ],
312 | "source": [
313 | "plt.imshow(log_prob_matrix[0, :, :].cpu())"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 11,
319 | "metadata": {},
320 | "outputs": [
321 | {
322 | "data": {
323 | "text/plain": [
324 | ""
325 | ]
326 | },
327 | "execution_count": 11,
328 | "metadata": {},
329 | "output_type": "execute_result"
330 | },
331 | {
332 | "data": {
333 | "image/png": "<base64-encoded PNG omitted: plt.imshow heatmap of alignments[0][0] on an 18x18 in. figure>\n",
334 | "text/plain": [
335 | ""
336 | ]
337 | },
338 | "metadata": {
339 | "needs_background": "light"
340 | },
341 | "output_type": "display_data"
342 | }
343 | ],
344 | "source": [
345 | "plt.figure(figsize=(18, 18))\n",
346 | "plt.imshow(alignments[0][0, :, :].cpu())"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "metadata": {},
353 | "outputs": [],
354 | "source": []
355 | }
356 | ],
357 | "metadata": {
358 | "kernelspec": {
359 | "display_name": "Python 3",
360 | "language": "python",
361 | "name": "python3"
362 | },
363 | "language_info": {
364 | "codemirror_mode": {
365 | "name": "ipython",
366 | "version": 3
367 | },
368 | "file_extension": ".py",
369 | "mimetype": "text/x-python",
370 | "name": "python",
371 | "nbconvert_exporter": "python",
372 | "pygments_lexer": "ipython3",
373 | "version": "3.8.3"
374 | }
375 | },
376 | "nbformat": 4,
377 | "nbformat_minor": 2
378 | }
379 |
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.1.1
2 | tensorboard
3 | numpy==1.17.4
4 | inflect==4.0.0
5 | librosa==0.7.1
6 | scipy==1.3.2
7 | Unidecode==1.1.1
8 | pillow==6.2.1
9 | g2p_en==2.0.0
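10 | # torch and tqdm are also required by the training and benchmark scripts (install separately, matching your CUDA setup)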
--------------------------------------------------------------------------------
/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 |
4 | Copyright (c) 2017, Prem Seetharaman
5 | All rights reserved.
6 |
7 | * Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice,
11 | this list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice, this
14 | list of conditions and the following disclaimer in the
15 | documentation and/or other materials provided with the distribution.
16 |
17 | * Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived from this
19 | software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | """
32 |
33 | import torch
34 | import numpy as np
35 | import torch.nn.functional as F
36 | from torch.autograd import Variable
37 | from scipy.signal import get_window
38 | from librosa.util import pad_center, tiny
39 | from audio_processing import window_sumsquare
40 |
41 |
42 | class STFT(torch.nn.Module):
43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
44 | def __init__(self, filter_length=800, hop_length=200, win_length=800,
45 | window='hann'):
46 | super(STFT, self).__init__()
47 | self.filter_length = filter_length
48 | self.hop_length = hop_length
49 | self.win_length = win_length
50 | self.window = window
51 | self.forward_transform = None
52 | scale = self.filter_length / self.hop_length
53 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
54 |
55 | cutoff = int((self.filter_length / 2 + 1))
56 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
57 | np.imag(fourier_basis[:cutoff, :])])
58 |
59 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
60 | inverse_basis = torch.FloatTensor(
61 | np.linalg.pinv(scale * fourier_basis).T[:, None, :])
62 |
63 | if window is not None:
64 | assert(filter_length >= win_length)
65 | # get window and zero center pad it to filter_length
66 | fft_window = get_window(window, win_length, fftbins=True)
67 | fft_window = pad_center(fft_window, filter_length)
68 | fft_window = torch.from_numpy(fft_window).float()
69 |
70 | # window the bases
71 | forward_basis *= fft_window
72 | inverse_basis *= fft_window
73 |
74 | self.register_buffer('forward_basis', forward_basis.float())
75 | self.register_buffer('inverse_basis', inverse_basis.float())
76 |
77 | def transform(self, input_data):
78 | num_batches = input_data.size(0)
79 | num_samples = input_data.size(1)
80 |
81 | self.num_samples = num_samples
82 |
83 | # similar to librosa, reflect-pad the input
84 | input_data = input_data.view(num_batches, 1, num_samples)
85 | input_data = F.pad(
86 | input_data.unsqueeze(1),
87 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
88 | mode='reflect')
89 | input_data = input_data.squeeze(1)
90 |
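   |         # the STFT is computed as a strided 1-D convolution with the windowed Fourier basis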
91 | forward_transform = F.conv1d(
92 | input_data,
93 | Variable(self.forward_basis, requires_grad=False),
94 | stride=self.hop_length,
95 | padding=0)
96 |
97 | cutoff = int((self.filter_length / 2) + 1)
98 | real_part = forward_transform[:, :cutoff, :]
99 | imag_part = forward_transform[:, cutoff:, :]
100 |
101 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
102 | phase = torch.autograd.Variable(
103 | torch.atan2(imag_part.data, real_part.data))
104 |
105 | return magnitude, phase
106 |
107 | def inverse(self, magnitude, phase):
108 | recombine_magnitude_phase = torch.cat(
109 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
110 |
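   |         # inverse STFT via transposed convolution with the (windowed) pseudo-inverse of the Fourier basis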
111 | inverse_transform = F.conv_transpose1d(
112 | recombine_magnitude_phase,
113 | Variable(self.inverse_basis, requires_grad=False),
114 | stride=self.hop_length,
115 | padding=0)
116 |
117 | if self.window is not None:
118 | window_sum = window_sumsquare(
119 | self.window, magnitude.size(-1), hop_length=self.hop_length,
120 | win_length=self.win_length, n_fft=self.filter_length,
121 | dtype=np.float32)
122 | # remove modulation effects
123 | approx_nonzero_indices = torch.from_numpy(
124 | np.where(window_sum > tiny(window_sum))[0])
125 | window_sum = torch.autograd.Variable(
126 | torch.from_numpy(window_sum), requires_grad=False)
127 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
128 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
129 |
130 | # scale by hop ratio
131 | inverse_transform *= float(self.filter_length) / self.hop_length
132 |
133 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
134 |         inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
135 |
136 | return inverse_transform
137 |
138 | def forward(self, input_data):
139 | self.magnitude, self.phase = self.transform(input_data)
140 | reconstruction = self.inverse(self.magnitude, self.phase)
141 | return reconstruction
142 |
--------------------------------------------------------------------------------
/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | # -*- coding: utf-8 -*-
3 |
4 | import re
5 | from text import cleaners
6 | from text.symbols import symbols
7 |
8 | # Mappings from symbol to numeric ID and vice versa:
9 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
10 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
11 |
12 | # Regular expression matching text enclosed in curly braces:
13 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
14 |
15 |
16 | def text_to_sequence(text, cleaner_names):
17 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
18 |
19 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
20 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
21 |
22 | Args:
23 | text: string to convert to a sequence
24 | cleaner_names: names of the cleaner functions to run the text through
25 |
26 | Returns:
27 | List of integers corresponding to the symbols in the text
28 | '''
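   |     # every sequence starts with the '^' symbol; the '~' EOS symbol is appended at the end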
29 | sequence = [_symbol_to_id['^']]
30 |
31 | # Check for curly braces and treat their contents as ARPAbet:
32 | while len(text):
33 | m = _curly_re.match(text)
34 | if not m:
35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
36 | break
37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
38 | sequence += _arpabet_to_sequence(m.group(2))
39 | text = m.group(3)
40 |
41 | # Append EOS token
42 | sequence.append(_symbol_to_id['~'])
43 | return sequence
44 |
45 |
46 | def sequence_to_text(sequence):
47 | '''Converts a sequence of IDs back to a string'''
48 | result = ''
49 | for symbol_id in sequence:
50 | if symbol_id in _id_to_symbol:
51 | s = _id_to_symbol[symbol_id]
52 | # Enclose ARPAbet back in curly braces:
53 | if len(s) > 1 and s[0] == '@':
54 | s = '{%s}' % s[1:]
55 | result += s
56 | return result.replace('}{', ' ')
57 |
58 |
59 | def _clean_text(text, cleaner_names):
60 | for name in cleaner_names:
61 | cleaner = getattr(cleaners, name)
62 | if not cleaner:
63 | raise Exception('Unknown cleaner: %s' % name)
64 | text = cleaner(text)
65 | return text
66 |
67 |
68 | def _symbols_to_sequence(symbols):
69 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
70 |
71 |
72 | def _arpabet_to_sequence(text):
73 | return _symbols_to_sequence(['@' + s for s in text.split()])
74 |
75 |
76 | def _should_keep_symbol(s):
77 |     return s in _symbol_to_id and s != '_' and s != '~'
78 |
--------------------------------------------------------------------------------
/text/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/text/__pycache__/cleaners.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/cleaners.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/cleaners.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/cleaners.cpython-37.pyc
--------------------------------------------------------------------------------
/text/__pycache__/cmudict.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/cmudict.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/cmudict.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/cmudict.cpython-37.pyc
--------------------------------------------------------------------------------
/text/__pycache__/numbers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/numbers.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/numbers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/numbers.cpython-37.pyc
--------------------------------------------------------------------------------
/text/__pycache__/symbols.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/symbols.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/symbols.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/text/__pycache__/symbols.cpython-37.pyc
--------------------------------------------------------------------------------
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """This file is derived from https://github.com/keithito/tacotron.
2 |
3 | Cleaners are transformations that run over the input text at both training and eval time.
4 |
5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
7 | 1. "english_cleaners" for English text
8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
11 | the symbols in symbols.py to match your data).
12 | """
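   | # e.g. setting the "cleaners" hyperparameter to "english_cleaners" runs english_cleaners() below on each input text.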
13 |
14 | import re
15 |
16 | from unidecode import unidecode
17 |
18 | from text.numbers import normalize_numbers
19 |
20 | # Regular expression matching whitespace:
21 | _whitespace_re = re.compile(r'\s+')
22 |
23 | # List of (regular expression, replacement) pairs for abbreviations:
24 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
25 | ('mrs', 'misess'),
26 | ('mr', 'mister'),
27 | ('dr', 'doctor'),
28 | ('st', 'saint'),
29 | ('co', 'company'),
30 | ('jr', 'junior'),
31 | ('maj', 'major'),
32 | ('gen', 'general'),
33 | ('drs', 'doctors'),
34 | ('rev', 'reverend'),
35 | ('lt', 'lieutenant'),
36 | ('hon', 'honorable'),
37 | ('sgt', 'sergeant'),
38 | ('capt', 'captain'),
39 | ('esq', 'esquire'),
40 | ('ltd', 'limited'),
41 | ('col', 'colonel'),
42 | ('ft', 'fort'),
43 | ]]
44 |
45 |
46 | def expand_abbreviations(text):
47 | for regex, replacement in _abbreviations:
48 | text = re.sub(regex, replacement, text)
49 | return text
50 |
51 |
52 | def expand_numbers(text):
53 | return normalize_numbers(text)
54 |
55 |
56 | def lowercase(text):
57 | return text.lower()
58 |
59 |
60 | def collapse_whitespace(text):
61 | return re.sub(_whitespace_re, ' ', text)
62 |
63 |
64 | def convert_to_ascii(text):
65 | return unidecode(text)
66 |
67 |
68 | def basic_cleaners(text):
69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
70 | text = lowercase(text)
71 | text = collapse_whitespace(text)
72 | return text
73 |
74 |
75 | def transliteration_cleaners(text):
76 | '''Pipeline for non-English text that transliterates to ASCII.'''
77 | text = convert_to_ascii(text)
78 | text = lowercase(text)
79 | text = collapse_whitespace(text)
80 | return text
81 |
82 |
83 | def english_cleaners(text):
84 | '''Pipeline for English text, including number and abbreviation expansion.'''
85 | text = convert_to_ascii(text)
86 | text = lowercase(text)
87 | text = expand_numbers(text)
88 | text = expand_abbreviations(text)
89 | text = collapse_whitespace(text)
90 | return text
91 |
92 |
93 | # NOTE (kan-bayashi): The following functions are additionally defined; they are not included in the original code.
94 | def remove_unnecessary_symbols(text):
95 | # added
96 | text = re.sub(r'[\(\)\[\]\<\>\"]+', '', text)
97 | return text
98 |
99 |
100 | def expand_symbols(text):
101 | # added
102 |     text = re.sub(";", ",", text)
103 |     text = re.sub(":", ",", text)
104 |     text = re.sub("-", " ", text)
105 |     text = re.sub("&", "and", text)
106 | return text
107 |
108 |
109 | def uppercase(text):
110 | # added
111 | return text.upper()
112 |
113 |
114 | def custom_english_cleaners(text):
115 | '''Custom pipeline for English text, including number and abbreviation expansion.'''
116 | text = convert_to_ascii(text)
117 | text = lowercase(text)
118 | text = expand_numbers(text)
119 | text = expand_abbreviations(text)
120 | text = expand_symbols(text)
121 | text = remove_unnecessary_symbols(text)
122 | text = uppercase(text)
123 | text = collapse_whitespace(text)
124 |
125 | # There is an exception (I found it!)
126 | # "'NOW FOR YOU, MY POOR FELLOW MORTALS, WHO ARE ABOUT TO SUFFER THE LAST PENALTY OF THE LAW.'"
127 | if text[0]=="'" and text[-1]=="'":
128 | text = text[1:-1]
129 |
130 | return text
131 |
--------------------------------------------------------------------------------
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
14 | ]
15 |
16 | _valid_symbol_set = set(valid_symbols)
17 |
18 |
19 | class CMUDict:
20 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
21 | def __init__(self, file_or_path, keep_ambiguous=True):
22 | if isinstance(file_or_path, str):
23 | with open(file_or_path, encoding='latin-1') as f:
24 | entries = _parse_cmudict(f)
25 | else:
26 | entries = _parse_cmudict(file_or_path)
27 | if not keep_ambiguous:
28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
29 | self._entries = entries
30 |
31 |
32 | def __len__(self):
33 | return len(self._entries)
34 |
35 |
36 | def lookup(self, word):
37 | '''Returns list of ARPAbet pronunciations of the given word.'''
38 | return self._entries.get(word.upper())
39 |
40 |
41 |
42 | _alt_re = re.compile(r'\([0-9]+\)')
43 |
44 |
45 | def _parse_cmudict(file):
46 | cmudict = {}
47 | for line in file:
48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
49 | parts = line.split(' ')
50 | word = re.sub(_alt_re, '', parts[0])
51 | pronunciation = _get_pronunciation(parts[1])
52 | if pronunciation:
53 | if word in cmudict:
54 | cmudict[word].append(pronunciation)
55 | else:
56 | cmudict[word] = [pronunciation]
57 | return cmudict
58 |
59 |
60 | def _get_pronunciation(s):
61 | parts = s.strip().split(' ')
62 | for part in parts:
63 | if part not in _valid_symbol_set:
64 | return None
65 | return ' '.join(parts)
66 |
--------------------------------------------------------------------------------
/text/numbers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """ from https://github.com/keithito/tacotron """
3 |
4 | import inflect
5 | import re
6 |
7 |
8 | _inflect = inflect.engine()
9 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
10 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
11 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
12 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
13 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
14 | _number_re = re.compile(r'[0-9]+')
15 |
16 |
17 | def _remove_commas(m):
18 | return m.group(1).replace(',', '')
19 |
20 |
21 | def _expand_decimal_point(m):
22 | return m.group(1).replace('.', ' point ')
23 |
24 |
25 | def _expand_dollars(m):
26 | match = m.group(1)
27 | parts = match.split('.')
28 | if len(parts) > 2:
29 | return match + ' dollars' # Unexpected format
30 | dollars = int(parts[0]) if parts[0] else 0
31 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
32 | if dollars and cents:
33 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
34 | cent_unit = 'cent' if cents == 1 else 'cents'
35 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
36 | elif dollars:
37 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
38 | return '%s %s' % (dollars, dollar_unit)
39 | elif cents:
40 | cent_unit = 'cent' if cents == 1 else 'cents'
41 | return '%s %s' % (cents, cent_unit)
42 | else:
43 | return 'zero dollars'
44 |
45 |
46 | def _expand_ordinal(m):
47 | return _inflect.number_to_words(m.group(0))
48 |
49 |
50 | def _expand_number(m):
51 | num = int(m.group(0))
52 | if num > 1000 and num < 3000:
53 | if num == 2000:
54 | return 'two thousand'
55 | elif num > 2000 and num < 2010:
56 | return 'two thousand ' + _inflect.number_to_words(num % 100)
57 | elif num % 100 == 0:
58 | return _inflect.number_to_words(num // 100) + ' hundred'
59 | else:
60 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
61 | else:
62 | return _inflect.number_to_words(num, andword='')
63 |
64 |
65 | def normalize_numbers(text):
66 | text = re.sub(_comma_number_re, _remove_commas, text)
67 | text = re.sub(_pounds_re, r'\1 pounds', text)
68 | text = re.sub(_dollars_re, _expand_dollars, text)
69 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
70 | text = re.sub(_ordinal_re, _expand_ordinal, text)
71 | text = re.sub(_number_re, _expand_number, text)
72 | return text
73 |
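A quick illustration of how `normalize_numbers` chains the substitutions above (the sample string and the quoted output are illustrative):

```python
# Illustrative only: currency, plain counts, and year-like numbers take different branches.
from text.numbers import normalize_numbers

print(normalize_numbers("I paid $3.50 for 15 tickets in 2008."))
# Roughly: "I paid three dollars, fifty cents for fifteen tickets in two thousand eight."
```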
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Defines the set of symbols used in text input to the model.
5 |
6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7 | from text import cmudict
8 |
9 | _pad = '_'
10 | _sos = '^'
11 | _eos = '~'
12 | _punctuations = " ,.'?!"
13 | _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
14 |
15 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
16 | _arpabet = ['@' + s for s in cmudict.valid_symbols]
17 |
18 | # Export all symbols:
19 | symbols = [_pad, _sos, _eos] + list(_punctuations) + list(_characters) + _arpabet
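Downstream code typically needs index lookups over this list; a minimal sketch (the `symbol_to_id` / `id_to_symbol` names are illustrative, not defined in this file):

```python
# Illustrative only: build forward and inverse lookup tables over the exported symbols.
from text.symbols import symbols

symbol_to_id = {s: i for i, s in enumerate(symbols)}
id_to_symbol = {i: s for i, s in enumerate(symbols)}
print(len(symbols), symbol_to_id['~'], id_to_symbol[0])  # vocabulary size, eos index, pad symbol
```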
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modules.model import Model
6 | from modules.loss import MDNLoss
7 | import hparams
8 | from text import *
9 | from utils.utils import *
10 | from utils.writer import get_writer
11 | from torch.utils.tensorboard import SummaryWriter
12 | import math
13 | import matplotlib.pyplot as plt
14 | from datetime import datetime
15 | from glob import glob
16 |
17 |
18 | def validate(model, criterion, val_loader, iteration, writer, stage):
19 | model.eval()
20 | with torch.no_grad():
21 | n_data, val_loss = 0, 0
22 | for i, batch in enumerate(val_loader):
23 | n_data += len(batch[0])
24 |
25 | if stage==0:
26 | text_padded, mel_padded, text_lengths, mel_lengths = [
27 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch
28 | ]
29 | else:
30 | text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
31 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch
32 | ]
33 |
34 | if stage !=3:
35 | encoder_input = model.module.Prenet(text_padded)
36 | hidden_states, _ = model.module.FFT_lower(encoder_input, text_lengths)
37 |
38 | if stage==0:
39 | mu_sigma = model.module.get_mu_sigma(hidden_states)
40 | loss, log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths)
41 |
42 | elif stage==1:
43 | mel_out = model.module.get_melspec(hidden_states, align_padded, mel_lengths)
44 | mel_mask = ~get_mask_from_lengths(mel_lengths)
45 | mel_padded_selected = mel_padded.masked_select(mel_mask.unsqueeze(1))
46 | mel_out_selected = mel_out.masked_select(mel_mask.unsqueeze(1))
47 | loss = nn.L1Loss()(mel_out_selected, mel_padded_selected)
48 |
49 | elif stage==2:
50 | mu_sigma = model.module.get_mu_sigma(hidden_states)
51 | mdn_loss, log_prob_matrix = criterion(mu_sigma, mel_padded, text_lengths, mel_lengths)
52 |
53 | align = model.module.viterbi(log_prob_matrix, text_lengths, mel_lengths)
54 | mel_out = model.module.get_melspec(hidden_states, align, mel_lengths)
55 | mel_mask = ~get_mask_from_lengths(mel_lengths)
56 | mel_padded_selected = mel_padded.masked_select(mel_mask.unsqueeze(1))
57 | mel_out_selected = mel_out.masked_select(mel_mask.unsqueeze(1))
58 | fft_loss = nn.L1Loss()(mel_out_selected, mel_padded_selected)
59 | loss = mdn_loss + fft_loss
60 |
61 | elif stage==3:
62 | duration_out = model.module.get_duration(text_padded, text_lengths) # gradient cut
63 | duration_target = align_padded.sum(-1)
64 | duration_mask = ~get_mask_from_lengths(text_lengths)
65 | duration_out = duration_out.masked_select(duration_mask)
66 | duration_target = duration_target.masked_select(duration_mask)
67 | loss = nn.MSELoss()(torch.log(duration_out), torch.log(duration_target))
68 |
69 | val_loss += loss.item() * len(batch[0])
70 |
71 | val_loss /= n_data
72 |
73 | if stage==0:
74 | writer.add_scalar('Validation loss', val_loss, iteration//hparams.accumulation)
75 |
76 | align = model.module.viterbi(log_prob_matrix[0:1], text_lengths[0:1], mel_lengths[0:1]) # 1, L, T
77 | mel_out = torch.matmul(align[0].t(), mu_sigma[0, :, :hparams.n_mel_channels]).t() # F, T
78 |
79 | writer.add_image('Validation_alignments', align.detach().cpu(), iteration//hparams.accumulation)
80 | writer.add_specs(mel_padded[0].detach().cpu(),
81 | mel_out.detach().cpu(),
82 | iteration//hparams.accumulation, 'Validation')
83 | elif stage==1:
84 | writer.add_scalar('Validation loss', val_loss, iteration//hparams.accumulation)
85 | writer.add_specs(mel_padded[0].detach().cpu(),
86 | mel_out[0].detach().cpu(),
87 | iteration//hparams.accumulation, 'Validation')
88 | elif stage==2:
89 | writer.add_scalar('Validation mdn_loss', mdn_loss.item(), iteration//hparams.accumulation)
90 | writer.add_scalar('Validation fft_loss', fft_loss.item(), iteration//hparams.accumulation)
91 | writer.add_image('Validation_alignments',
92 | align[0:1, :text_lengths[0], :mel_lengths[0]].detach().cpu(),
93 | iteration//hparams.accumulation)
94 | writer.add_specs(mel_padded[0].detach().cpu(),
95 | mel_out[0].detach().cpu(),
96 | iteration//hparams.accumulation, 'Validation')
97 | elif stage==3:
98 | writer.add_scalar('Validation loss', val_loss, iteration//hparams.accumulation)
99 |
100 | model.train()
101 |
102 |
103 | def main(args):
104 | train_loader, val_loader, collate_fn = prepare_dataloaders(hparams, stage=args.stage)
105 |
106 | if args.stage!=0:
107 | checkpoint_path = f"training_log/aligntts/stage{args.stage-1}/checkpoint_{hparams.train_steps[args.stage-1]}"
108 |
109 | if not os.path.isfile(checkpoint_path):
110 | print(f'{checkpoint_path} does not exist')
111 | checkpoint_path = sorted(glob(f"training_log/aligntts/stage{args.stage-1}/checkpoint_*"))[-1]
112 | print(f'Loading {checkpoint_path} instead')
113 |
114 | state_dict = {}
115 | for k, v in torch.load(checkpoint_path)['state_dict'].items():
116 | state_dict[k[7:]]=v
117 |
118 | model = Model(hparams).cuda()
119 | model.load_state_dict(state_dict)
120 | model = nn.DataParallel(model).cuda()
121 | else:
122 | model = nn.DataParallel(Model(hparams)).cuda()
123 |
124 | criterion = MDNLoss()
125 | writer = get_writer(hparams.output_directory, f'{hparams.log_directory}/stage{args.stage}')
126 | optimizer = torch.optim.Adam(model.parameters(),
127 | lr=hparams.lr,
128 | betas=(0.9, 0.98),
129 | eps=1e-09)
130 | iteration, loss = 0, 0
131 | model.train()
132 |
133 | print(f'Stage{args.stage} Start!!! ({str(datetime.now())})')
134 | while True:
135 | for i, batch in enumerate(train_loader):
136 | if args.stage==0:
137 | text_padded, mel_padded, text_lengths, mel_lengths = [
138 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch
139 | ]
140 | align_padded=None
141 | else:
142 | text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
143 | reorder_batch(x, hparams.n_gpus).cuda() for x in batch
144 | ]
145 |
146 | sub_loss = model(text_padded,
147 | mel_padded,
148 | align_padded,
149 | text_lengths,
150 | mel_lengths,
151 | criterion,
152 | stage=args.stage,
153 | log_viterbi=args.log_viterbi,
154 | cpu_viterbi=args.cpu_viterbi)
155 | sub_loss = sub_loss.mean()/hparams.accumulation
156 | sub_loss.backward()
157 | loss = loss+sub_loss.item()
158 | iteration += 1
159 |
160 | print(f'[{str(datetime.now())}] Stage {args.stage} Iter {iteration:<6d} Loss {loss:<8.6f}')
161 |
162 | if iteration%hparams.accumulation == 0:
163 | lr_scheduling(optimizer, iteration//hparams.accumulation)
164 | nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
165 | optimizer.step()
166 | model.zero_grad()
167 | writer.add_scalar('Train loss', loss, iteration//hparams.accumulation)
168 | loss=0
169 |
170 | if iteration%(hparams.iters_per_validation*hparams.accumulation)==0:
171 | validate(model, criterion, val_loader, iteration, writer, args.stage)
172 |
173 | if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0:
174 | save_checkpoint(model,
175 | optimizer,
176 | hparams.lr,
177 | iteration//hparams.accumulation,
178 | filepath=f'{hparams.output_directory}/{hparams.log_directory}/stage{args.stage}')
179 |
180 | if iteration==(hparams.train_steps[args.stage]*hparams.accumulation):
181 | break
182 |
183 | if iteration==(hparams.train_steps[args.stage]*hparams.accumulation):
184 | break
185 |
186 | print(f'Stage{args.stage} End!!! ({str(datetime.now())})')
187 |
188 |
189 | if __name__ == '__main__':
190 | p = argparse.ArgumentParser()
191 | p.add_argument('--gpu', type=str, default='0,1')
192 | p.add_argument('-v', '--verbose', type=str, default='0')
193 | p.add_argument('--stage', type=int, required=True)
194 |     p.add_argument('--log_viterbi', action='store_true')
195 |     p.add_argument('--cpu_viterbi', action='store_true')
196 | args = p.parse_args()
197 |
198 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
199 | torch.manual_seed(hparams.seed)
200 | torch.cuda.manual_seed(hparams.seed)
201 |
202 | if args.verbose=='0':
203 | import warnings
204 | warnings.filterwarnings("ignore")
205 |
206 | main(args)
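For reference, a stage-by-stage invocation consistent with the argument parser above (the GPU ids are an assumption; each stage after 0 resumes from the previous stage's checkpoint, as implemented in `main`):

```command
python train.py --stage 0 --gpu 0,1   # MDN alignment training
python train.py --stage 1 --gpu 0,1   # mel decoder training on fixed alignments
python train.py --stage 2 --gpu 0,1   # joint MDN + mel fine-tuning
python train.py --stage 3 --gpu 0,1   # duration predictor training
```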
--------------------------------------------------------------------------------
/training_log/readme.txt:
--------------------------------------------------------------------------------
1 | Download the pre-trained WaveGlow vocoder checkpoint 'waveglow_256channels.pt' (linked from https://github.com/NVIDIA/waveglow) into this directory.
2 | Create an 'aligntts' subdirectory here (mkdir aligntts); training checkpoints and TensorBoard logs for each stage are written into it.
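A minimal sketch of the setup described above (paths assume the repository root; the checkpoint download location is whatever the NVIDIA/waveglow page currently links to):

```command
# from the repository root
mkdir -p training_log/aligntts
# place the pre-trained vocoder checkpoint here, e.g.:
cp /path/to/waveglow_256channels.pt training_log/
```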
--------------------------------------------------------------------------------
/utils/__pycache__/data_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/data_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/data_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/data_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/plot_image.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/plot_image.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/plot_image.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/plot_image.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/writer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/writer.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/writer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/utils/__pycache__/writer.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import hparams
4 | import torch
5 | import torch.utils.data
6 | import torch.nn.functional as F
7 | import os
8 | import pickle as pkl
9 |
10 | from text import text_to_sequence
11 |
12 |
13 | def load_filepaths_and_text(metadata, split="|"):
14 | with open(metadata, encoding='utf-8') as f:
15 | filepaths_and_text = [line.strip().split(split) for line in f]
16 | return filepaths_and_text
17 |
18 |
19 | class TextMelSet(torch.utils.data.Dataset):
20 | def __init__(self, audiopaths_and_text, hparams, stage):
21 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
22 | self.data_type=hparams.data_type
23 | self.stage=stage
24 |
25 | self.text_dataset = []
26 | self.align_dataset = []
27 | seq_path = os.path.join(hparams.data_path, self.data_type)
28 | align_path = os.path.join(hparams.data_path, 'alignments')
29 | for data in self.audiopaths_and_text:
30 | file_name = data[0][:10]
31 | text = torch.from_numpy(np.load(f'{seq_path}/{file_name}_sequence.npy'))
32 | self.text_dataset.append(text)
33 |
34 | if stage !=0:
35 | align = torch.from_numpy(np.load(f'{align_path}/{file_name}_alignment.npy'))
36 | self.align_dataset.append(align)
37 |
38 |
39 | def get_mel_text_pair(self, index):
40 | file_name = self.audiopaths_and_text[index][0][:10]
41 |
42 | text = self.text_dataset[index]
43 | mel_path = os.path.join(hparams.data_path, 'melspectrogram')
44 | mel = torch.from_numpy(np.load(f'{mel_path}/{file_name}_melspectrogram.npy'))
45 |
46 | if self.stage == 0:
47 | return (text, mel)
48 |
49 | else:
50 | align = self.align_dataset[index]
51 | align = torch.repeat_interleave(torch.eye(len(align)).to(torch.long),
52 | align,
53 | dim=1)
54 | return (text, mel, align)
55 |
56 | def __getitem__(self, index):
57 | return self.get_mel_text_pair(index)
58 |
59 | def __len__(self):
60 | return len(self.audiopaths_and_text)
61 |
62 |
63 | class TextMelCollate():
64 | def __init__(self, stage):
65 | self.stage=stage
66 | return
67 |
68 | def __call__(self, batch):
69 | # Right zero-pad all one-hot text sequences to max input length
70 | input_lengths, ids_sorted_decreasing = torch.sort(
71 | torch.LongTensor([len(x[0]) for x in batch]),
72 | dim=0, descending=True)
73 | max_input_len = input_lengths[0]
74 | max_target_len = max([x[1].size(1) for x in batch])
75 | num_mels = batch[0][1].size(0)
76 |
77 | if self.stage==0:
78 | text_padded = torch.zeros(len(batch), max_input_len, dtype=torch.long)
79 | mel_padded = torch.zeros(len(batch), num_mels, max_target_len)
80 | output_lengths = torch.LongTensor(len(batch))
81 |
82 | for i in range(len(ids_sorted_decreasing)):
83 | text = batch[ids_sorted_decreasing[i]][0]
84 | text_padded[i, :text.size(0)] = text
85 | mel = batch[ids_sorted_decreasing[i]][1]
86 | mel_padded[i, :, :mel.size(1)] = mel
87 | output_lengths[i] = mel.size(1)
88 |
89 | mel_padded = (torch.clamp(mel_padded, hparams.min_db, hparams.max_db)-hparams.min_db) / (hparams.max_db-hparams.min_db)
90 |
91 | return text_padded, mel_padded, input_lengths, output_lengths
92 |
93 |
94 | else:
95 | text_padded = torch.zeros(len(batch), max_input_len, dtype=torch.long)
96 | mel_padded = torch.zeros(len(batch), num_mels, max_target_len)
97 | align_padded = torch.zeros(len(batch), max_input_len, max_target_len)
98 | output_lengths = torch.LongTensor(len(batch))
99 |
100 | for i in range(len(ids_sorted_decreasing)):
101 | text = batch[ids_sorted_decreasing[i]][0]
102 | text_padded[i, :text.size(0)] = text
103 | mel = batch[ids_sorted_decreasing[i]][1]
104 | mel_padded[i, :, :mel.size(1)] = mel
105 | output_lengths[i] = mel.size(1)
106 | align = batch[ids_sorted_decreasing[i]][2]
107 | align_padded[i, :align.size(0), :align.size(1)] = align
108 |
109 | mel_padded = (torch.clamp(mel_padded, hparams.min_db, hparams.max_db)-hparams.min_db) / (hparams.max_db-hparams.min_db)
110 |
111 | return text_padded, mel_padded, align_padded, input_lengths, output_lengths
112 |
113 |
114 |
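The `repeat_interleave(torch.eye(...))` trick in `TextMelSet.get_mel_text_pair` turns a per-token duration vector into a hard alignment matrix; a small worked example (illustrative values only):

```python
# Illustrative only: expand durations [2, 1, 3] into a token-by-frame alignment matrix.
import torch

durations = torch.tensor([2, 1, 3])  # frames assigned to each input token
align = torch.repeat_interleave(torch.eye(len(durations)).to(torch.long),
                                durations, dim=1)
print(align)
# token 0 -> frames 0-1, token 1 -> frame 2, token 2 -> frames 3-5:
# [[1, 1, 0, 0, 0, 0],
#  [0, 0, 1, 0, 0, 0],
#  [0, 0, 0, 1, 1, 1]]
```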
--------------------------------------------------------------------------------
/utils/plot_image.py:
--------------------------------------------------------------------------------
1 | from text import *
2 | import torch
3 | import hparams
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def plot_melspec(mel_target, mel_out):
8 | fig, axes = plt.subplots(2, 1, figsize=(20,20))
9 |
10 | axes[0].imshow(mel_target,
11 | origin='lower',
12 | aspect='auto')
13 |
14 | axes[1].imshow(mel_out,
15 | origin='lower',
16 | aspect='auto')
17 |
18 | return fig
19 |
20 |
21 | def plot_alignments(alignments, text, mel_lengths, text_lengths, att_type):
22 | fig, axes = plt.subplots(hparams.n_layers, hparams.n_heads, figsize=(5*hparams.n_heads,5*hparams.n_layers))
23 | L, T = text_lengths[-1], mel_lengths[-1]
24 | n_layers, n_heads = alignments.size(1), alignments.size(2)
25 |
26 | for layer in range(n_layers):
27 | for head in range(n_heads):
28 | if att_type=='enc':
29 | align = alignments[-1, layer, head].contiguous()
30 | axes[layer,head].imshow(align[:L, :L], aspect='auto')
31 | axes[layer,head].xaxis.tick_top()
32 |
33 | elif att_type=='dec':
34 | align = alignments[-1, layer, head].contiguous()
35 | axes[layer,head].imshow(align[:T, :T], aspect='auto')
36 | axes[layer,head].xaxis.tick_top()
37 |
38 | elif att_type=='enc_dec':
39 | align = alignments[-1, layer, head].transpose(0,1).contiguous()
40 | axes[layer,head].imshow(align[:L, :T], origin='lower', aspect='auto')
41 |
42 | return fig
43 |
44 |
45 | def plot_gate(gate_out):
46 | fig = plt.figure(figsize=(10,5))
47 | plt.plot(torch.sigmoid(gate_out[-1]))
48 | return fig
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import hparams
2 | from torch.utils.data import DataLoader
3 | from .data_utils import TextMelSet, TextMelCollate
4 | import torch
5 | from text import *
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | def prepare_dataloaders(hparams, stage):
10 | # Get data, data loaders and collate function ready
11 | trainset = TextMelSet(hparams.training_files, hparams, stage)
12 | valset = TextMelSet(hparams.validation_files, hparams, stage)
13 | collate_fn = TextMelCollate(stage)
14 |
15 | train_loader = DataLoader(trainset,
16 | shuffle=True,
17 | batch_size=hparams.batch_size,
18 | drop_last=True,
19 | collate_fn=collate_fn)
20 |
21 | val_loader = DataLoader(valset,
22 | batch_size=hparams.batch_size//hparams.n_gpus,
23 | collate_fn=collate_fn)
24 |
25 | return train_loader, val_loader, collate_fn
26 |
27 |
28 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
29 | print(f"Saving model and optimizer state at iteration {iteration} to {filepath}")
30 | torch.save({'iteration': iteration,
31 | 'state_dict': model.state_dict(),
32 | 'optimizer': optimizer.state_dict(),
33 | 'learning_rate': learning_rate}, f'{filepath}/checkpoint_{iteration}')
34 |
35 |
36 | def lr_scheduling(opt, step, init_lr=hparams.lr, warmup_steps=hparams.warmup_steps):
37 | opt.param_groups[0]['lr'] = init_lr * min(step ** -0.5, step * warmup_steps ** -1.5)
38 | return
39 |
40 |
41 | def get_mask_from_lengths(lengths):
42 | max_len = torch.max(lengths).item()
43 | ids = lengths.new_tensor(torch.arange(0, max_len))
44 | mask = (lengths.unsqueeze(1) <= ids).to(torch.bool)
45 | return mask
46 |
47 |
48 | def reorder_batch(x, n_gpus, base=0):
49 | assert (x.size(0)%n_gpus)==0, 'Batch size must be a multiple of the number of GPUs.'
50 | base = base%n_gpus
51 | new_x = list(torch.zeros_like(x).chunk(n_gpus))
52 | for i in range(base, base+n_gpus):
53 | new_x[i%n_gpus] = x[i-base::n_gpus]
54 |
55 | new_x = torch.cat(new_x, dim=0)
56 |
57 | return new_x
58 |
59 |
60 |
61 | def decode_text(padded_text, text_lengths, batch_idx=0):
62 | text = padded_text[batch_idx]
63 | text_len = text_lengths[batch_idx]
64 | text = ''.join([symbols[ci] for i, ci in enumerate(text) if i < text_len])
65 |
66 | return text
67 |
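Note that `get_mask_from_lengths` marks the padded positions with `True`, which is why callers in `train.py` invert it with `~`; a small check (values are illustrative):

```python
# Illustrative only: True marks padding beyond each sequence's length.
import torch
from utils.utils import get_mask_from_lengths

print(get_mask_from_lengths(torch.tensor([3, 1, 2])))
# row-wise: [F, F, F], [F, T, T], [F, F, T]
```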
--------------------------------------------------------------------------------
/utils/writer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.tensorboard import SummaryWriter
3 | from .plot_image import *
4 |
5 | def get_writer(output_directory, log_directory):
6 | logging_path=f'{output_directory}/{log_directory}'
7 |
8 | if os.path.exists(logging_path):
9 | writer = TTSWriter(logging_path)
10 | # raise Exception('The experiment already exists')
11 | print(f'The experiment {logging_path} already exists!')
12 | else:
13 | os.makedirs(logging_path)
14 | writer = TTSWriter(logging_path)
15 |
16 | return writer
17 |
18 |
19 | class TTSWriter(SummaryWriter):
20 | def __init__(self, log_dir):
21 | super(TTSWriter, self).__init__(log_dir)
22 |
23 | def add_specs(self, mel_target, mel_out, global_step, phase):
24 | fig, axes = plt.subplots(2, 1, figsize=(20,20))
25 |
26 | axes[0].imshow(mel_target,
27 | origin='lower',
28 | aspect='auto')
29 |
30 | axes[1].imshow(mel_out,
31 | origin='lower',
32 | aspect='auto')
33 |
34 | self.add_figure(f'{phase}_melspec', fig, global_step)
35 |
36 | def add_alignments(self, alignment, global_step, phase):
37 |         fig = plt.figure(figsize=(20, 10))
38 | plt.imshow(alignment, origin='lower', aspect='auto')
39 | self.add_figure(f'{phase}_alignments', fig, global_step)
40 |
41 |
42 |
--------------------------------------------------------------------------------
/waveglow/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "tacotron2"]
2 | path = tacotron2
3 | url = http://github.com/NVIDIA/tacotron2
4 |
--------------------------------------------------------------------------------
/waveglow/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2018, NVIDIA Corporation
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/waveglow/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | ## WaveGlow: a Flow-based Generative Network for Speech Synthesis
4 |
5 | ### Ryan Prenger, Rafael Valle, and Bryan Catanzaro
6 |
7 | In our recent [paper], we propose WaveGlow: a flow-based network capable of
8 | generating high quality speech from mel-spectrograms. WaveGlow combines insights
9 | from [Glow] and [WaveNet] in order to provide fast, efficient and high-quality
10 | audio synthesis, without the need for auto-regression. WaveGlow is implemented
11 | using only a single network, trained using only a single cost function:
12 | maximizing the likelihood of the training data, which makes the training
13 | procedure simple and stable.
14 |
15 | Our [PyTorch] implementation produces audio samples at a rate of 2750
16 | kHz on an NVIDIA V100 GPU. Mean Opinion Scores show that it delivers audio
17 | quality as good as the best publicly available WaveNet implementation.
18 |
19 | Visit our [website] for audio samples.
20 |
21 | ## Setup
22 |
23 | 1. Clone our repo and initialize submodule
24 |
25 | ```command
26 | git clone https://github.com/NVIDIA/waveglow.git
27 | cd waveglow
28 | git submodule init
29 | git submodule update
30 | ```
31 |
32 | 2. Install requirements `pip3 install -r requirements.txt`
33 |
34 | 3. Install [Apex]
35 |
36 |
37 | ## Generate audio with our pre-existing model
38 |
39 | 1. Download our [published model]
40 | 2. Download [mel-spectrograms]
41 | 3. Generate audio `python3 inference.py -f <(ls mel_spectrograms/*.pt) -w waveglow_256channels.pt -o . --is_fp16 -s 0.6`
42 |
43 | N.b. use `convert_model.py` to convert your older models to the current model
44 | with fused residual and skip connections.
45 |
46 | ## Train your own model
47 |
48 | 1. Download [LJ Speech Data]. In this example it's in `data/`
49 |
50 | 2. Make a list of the file names to use for training/testing
51 |
52 | ```command
53 | ls data/*.wav | tail -n+10 > train_files.txt
54 | ls data/*.wav | head -n10 > test_files.txt
55 | ```
56 |
57 | 3. Train your WaveGlow networks
58 |
59 | ```command
60 | mkdir checkpoints
61 | python train.py -c config.json
62 | ```
63 |
64 | For multi-GPU training, replace `train.py` with `distributed.py`; this has only been tested with a single node and NCCL (a minimal invocation is sketched after this list).
65 |
66 | For mixed precision training, set `"fp16_run": true` in `config.json`.
67 |
68 | 4. Make test set mel-spectrograms
69 |
70 | `python mel2samp.py -f test_files.txt -o . -c config.json`
71 |
72 | 5. Do inference with your network
73 |
74 | ```command
75 | ls *.pt > mel_files.txt
76 | python3 inference.py -f mel_files.txt -w checkpoints/waveglow_10000 -o . --is_fp16 -s 0.6
77 | ```
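A minimal sketch of the multi-GPU and mixed-precision variants mentioned in step 3 (the flag follows `distributed.py`'s argument parser; `distributed.py` spawns one `train.py` worker per visible GPU):

```command
python distributed.py -c config.json
# for mixed precision, set "fp16_run": true in the "train_config" section of config.json first
```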
78 |
79 | [//]: # (TODO)
80 | [//]: # (PROVIDE INSTRUCTIONS FOR DOWNLOADING LJS)
81 | [pytorch 1.0]: https://github.com/pytorch/pytorch#installation
82 | [website]: https://nv-adlr.github.io/WaveGlow
83 | [paper]: https://arxiv.org/abs/1811.00002
84 | [WaveNet implementation]: https://github.com/r9y9/wavenet_vocoder
85 | [Glow]: https://blog.openai.com/glow/
86 | [WaveNet]: https://deepmind.com/blog/wavenet-generative-model-raw-audio/
87 | [PyTorch]: http://pytorch.org
88 | [published model]: https://drive.google.com/file/d/1WsibBTsuRg_SF2Z6L6NFRTT-NjEy1oTx/view?usp=sharing
89 | [mel-spectrograms]: https://drive.google.com/file/d/1g_VXK2lpP9J25dQFhQwx7doWl_p20fXA/view?usp=sharing
90 | [LJ Speech Data]: https://keithito.com/LJ-Speech-Dataset
91 | [Apex]: https://github.com/nvidia/apex
92 |
--------------------------------------------------------------------------------
/waveglow/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_config": {
3 | "fp16_run": true,
4 | "output_directory": "checkpoints",
5 | "epochs": 100000,
6 | "learning_rate": 1e-4,
7 | "sigma": 1.0,
8 | "iters_per_checkpoint": 2000,
9 | "batch_size": 12,
10 | "seed": 1234,
11 | "checkpoint_path": "",
12 | "with_tensorboard": false
13 | },
14 | "data_config": {
15 | "training_files": "train_files.txt",
16 | "segment_length": 16000,
17 | "sampling_rate": 22050,
18 | "filter_length": 1024,
19 | "hop_length": 256,
20 | "win_length": 1024,
21 | "mel_fmin": 0.0,
22 | "mel_fmax": 8000.0
23 | },
24 | "dist_config": {
25 | "dist_backend": "nccl",
26 | "dist_url": "tcp://localhost:54321"
27 | },
28 |
29 | "waveglow_config": {
30 | "n_mel_channels": 80,
31 | "n_flows": 12,
32 | "n_group": 8,
33 | "n_early_every": 4,
34 | "n_early_size": 2,
35 | "WN_config": {
36 | "n_layers": 8,
37 | "n_channels": 256,
38 | "kernel_size": 3
39 | }
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/waveglow/convert_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import copy
3 | import torch
4 |
5 | def _check_model_old_version(model):
6 | if hasattr(model.WN[0], 'res_layers'):
7 | return True
8 | else:
9 | return False
10 |
11 | def update_model(old_model):
12 | if not _check_model_old_version(old_model):
13 | return old_model
14 | new_model = copy.deepcopy(old_model)
15 | for idx in range(0, len(new_model.WN)):
16 | wavenet = new_model.WN[idx]
17 | wavenet.res_skip_layers = torch.nn.ModuleList()
18 | n_channels = wavenet.n_channels
19 | n_layers = wavenet.n_layers
20 | for i in range(0, n_layers):
21 | if i < n_layers - 1:
22 | res_skip_channels = 2*n_channels
23 | else:
24 | res_skip_channels = n_channels
25 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
26 | skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
27 | if i < n_layers - 1:
28 | res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
29 | res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
30 | res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
31 | else:
32 | res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
33 | res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
34 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
35 | wavenet.res_skip_layers.append(res_skip_layer)
36 | del wavenet.res_layers
37 | del wavenet.skip_layers
38 | return new_model
39 |
40 | if __name__ == '__main__':
41 | old_model_path = sys.argv[1]
42 | new_model_path = sys.argv[2]
43 | model = torch.load(old_model_path)
44 | model['model'] = update_model(model['model'])
45 | torch.save(model, new_model_path)
46 |
47 |
--------------------------------------------------------------------------------
/waveglow/denoiser.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('tacotron2')
3 | import torch
4 | from layers import STFT
5 |
6 |
7 | class Denoiser(torch.nn.Module):
8 | """ Removes model bias from audio produced with waveglow """
9 |
10 | def __init__(self, waveglow, filter_length=1024, n_overlap=4,
11 | win_length=1024, mode='zeros'):
12 | super(Denoiser, self).__init__()
13 | self.stft = STFT(filter_length=filter_length,
14 | hop_length=int(filter_length/n_overlap),
15 | win_length=win_length).cuda()
16 | if mode == 'zeros':
17 | mel_input = torch.zeros(
18 | (1, 80, 88),
19 | dtype=waveglow.upsample.weight.dtype,
20 | device=waveglow.upsample.weight.device)
21 | elif mode == 'normal':
22 | mel_input = torch.randn(
23 | (1, 80, 88),
24 | dtype=waveglow.upsample.weight.dtype,
25 | device=waveglow.upsample.weight.device)
26 | else:
27 |             raise Exception("Mode {} is not supported".format(mode))
28 |
29 | with torch.no_grad():
30 | bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
31 | bias_spec, _ = self.stft.transform(bias_audio)
32 |
33 | self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
34 |
35 | def forward(self, audio, strength=0.1):
36 | audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
37 | audio_spec_denoised = audio_spec - self.bias_spec * strength
38 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
39 | audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
40 | return audio_denoised
41 |
--------------------------------------------------------------------------------
/waveglow/distributed.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import os
28 | import sys
29 | import time
30 | import subprocess
31 | import argparse
32 |
33 | import torch
34 | import torch.distributed as dist
35 | from torch.autograd import Variable
36 |
37 | def reduce_tensor(tensor, num_gpus):
38 | rt = tensor.clone()
39 | dist.all_reduce(rt, op=dist.reduce_op.SUM)
40 | rt /= num_gpus
41 | return rt
42 |
43 | def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
44 | assert torch.cuda.is_available(), "Distributed mode requires CUDA."
45 | print("Initializing Distributed")
46 |
47 | # Set cuda device so everything is done on the right GPU.
48 | torch.cuda.set_device(rank % torch.cuda.device_count())
49 |
50 | # Initialize distributed communication
51 | dist.init_process_group(dist_backend, init_method=dist_url,
52 | world_size=num_gpus, rank=rank,
53 | group_name=group_name)
54 |
55 | def _flatten_dense_tensors(tensors):
56 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
57 | same dense type.
58 | Since inputs are dense, the resulting tensor will be a concatenated 1D
59 | buffer. Element-wise operation on this buffer will be equivalent to
60 | operating individually.
61 | Arguments:
62 | tensors (Iterable[Tensor]): dense tensors to flatten.
63 | Returns:
64 | A contiguous 1D buffer containing input tensors.
65 | """
66 | if len(tensors) == 1:
67 | return tensors[0].contiguous().view(-1)
68 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
69 | return flat
70 |
71 | def _unflatten_dense_tensors(flat, tensors):
72 | """View a flat buffer using the sizes of tensors. Assume that tensors are of
73 | same dense type, and that flat is given by _flatten_dense_tensors.
74 | Arguments:
75 | flat (Tensor): flattened dense tensors to unflatten.
76 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
77 | unflatten flat.
78 | Returns:
79 | Unflattened dense tensors with sizes same as tensors and values from
80 | flat.
81 | """
82 | outputs = []
83 | offset = 0
84 | for tensor in tensors:
85 | numel = tensor.numel()
86 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
87 | offset += numel
88 | return tuple(outputs)
89 |
90 | def apply_gradient_allreduce(module):
91 | """
92 | Modifies existing model to do gradient allreduce, but doesn't change class
93 | so you don't need "module"
94 | """
95 | if not hasattr(dist, '_backend'):
96 | module.warn_on_half = True
97 | else:
98 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
99 |
100 | for p in module.state_dict().values():
101 | if not torch.is_tensor(p):
102 | continue
103 | dist.broadcast(p, 0)
104 |
105 | def allreduce_params():
106 | if(module.needs_reduction):
107 | module.needs_reduction = False
108 | buckets = {}
109 | for param in module.parameters():
110 | if param.requires_grad and param.grad is not None:
111 | tp = type(param.data)
112 | if tp not in buckets:
113 | buckets[tp] = []
114 | buckets[tp].append(param)
115 | if module.warn_on_half:
116 | if torch.cuda.HalfTensor in buckets:
117 | print("WARNING: gloo dist backend for half parameters may be extremely slow." +
118 |                           " It is recommended to use the NCCL backend in this case. This currently requires " +
119 |                           "PyTorch built from top of tree master.")
120 | module.warn_on_half = False
121 |
122 | for tp in buckets:
123 | bucket = buckets[tp]
124 | grads = [param.grad.data for param in bucket]
125 | coalesced = _flatten_dense_tensors(grads)
126 | dist.all_reduce(coalesced)
127 | coalesced /= dist.get_world_size()
128 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
129 | buf.copy_(synced)
130 |
131 | for param in list(module.parameters()):
132 | def allreduce_hook(*unused):
133 | Variable._execution_engine.queue_callback(allreduce_params)
134 | if param.requires_grad:
135 | param.register_hook(allreduce_hook)
136 | dir(param)
137 |
138 | def set_needs_reduction(self, input, output):
139 | self.needs_reduction = True
140 |
141 | module.register_forward_hook(set_needs_reduction)
142 | return module
143 |
144 |
145 | def main(config, stdout_dir, args_str):
146 | args_list = ['train.py']
147 | args_list += args_str.split(' ') if len(args_str) > 0 else []
148 |
149 | args_list.append('--config={}'.format(config))
150 |
151 | num_gpus = torch.cuda.device_count()
152 | args_list.append('--num_gpus={}'.format(num_gpus))
153 | args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S")))
154 |
155 | if not os.path.isdir(stdout_dir):
156 | os.makedirs(stdout_dir)
157 | os.chmod(stdout_dir, 0o775)
158 |
159 | workers = []
160 |
161 | for i in range(num_gpus):
162 | args_list[-2] = '--rank={}'.format(i)
163 | stdout = None if i == 0 else open(
164 | os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w")
165 | print(args_list)
166 | p = subprocess.Popen([str(sys.executable)]+args_list, stdout=stdout)
167 | workers.append(p)
168 |
169 | for p in workers:
170 | p.wait()
171 |
172 |
173 | if __name__ == '__main__':
174 | parser = argparse.ArgumentParser()
175 | parser.add_argument('-c', '--config', type=str, required=True,
176 | help='JSON file for configuration')
177 | parser.add_argument('-s', '--stdout_dir', type=str, default=".",
178 |                         help='directory to save stdout logs')
179 | parser.add_argument(
180 | '-a', '--args_str', type=str, default='',
181 | help='double quoted string with space separated key value pairs')
182 |
183 | args = parser.parse_args()
184 | main(args.config, args.stdout_dir, args.args_str)
185 |
--------------------------------------------------------------------------------
/waveglow/glow.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import copy
28 | import torch
29 | from torch.autograd import Variable
30 | import torch.nn.functional as F
31 |
32 |
33 | @torch.jit.script
34 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
35 | n_channels_int = n_channels[0]
36 | in_act = input_a+input_b
37 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
38 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
39 | acts = t_act * s_act
40 | return acts
41 |
42 |
43 | class WaveGlowLoss(torch.nn.Module):
44 | def __init__(self, sigma=1.0):
45 | super(WaveGlowLoss, self).__init__()
46 | self.sigma = sigma
47 |
48 | def forward(self, model_output):
49 | z, log_s_list, log_det_W_list = model_output
50 | for i, log_s in enumerate(log_s_list):
51 | if i == 0:
52 | log_s_total = torch.sum(log_s)
53 | log_det_W_total = log_det_W_list[i]
54 | else:
55 | log_s_total = log_s_total + torch.sum(log_s)
56 | log_det_W_total += log_det_W_list[i]
57 |
58 | loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
59 | return loss/(z.size(0)*z.size(1)*z.size(2))
60 |
61 |
62 | class Invertible1x1Conv(torch.nn.Module):
63 | """
64 | The layer outputs both the convolution, and the log determinant
65 | of its weight matrix. If reverse=True it does convolution with
66 | inverse
67 | """
68 | def __init__(self, c):
69 | super(Invertible1x1Conv, self).__init__()
70 | self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
71 | bias=False)
72 |
73 | # Sample a random orthonormal matrix to initialize weights
74 | W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
75 |
76 | # Ensure determinant is 1.0 not -1.0
77 | if torch.det(W) < 0:
78 | W[:,0] = -1*W[:,0]
79 | W = W.view(c, c, 1)
80 | self.conv.weight.data = W
81 |
82 | def forward(self, z, reverse=False):
83 | # shape
84 | batch_size, group_size, n_of_groups = z.size()
85 |
86 | W = self.conv.weight.squeeze()
87 |
88 | if reverse:
89 | if not hasattr(self, 'W_inverse'):
90 | # Reverse computation
91 | W_inverse = W.float().inverse()
92 | W_inverse = Variable(W_inverse[..., None])
93 | if z.type() == 'torch.cuda.HalfTensor':
94 | W_inverse = W_inverse.half()
95 | self.W_inverse = W_inverse
96 | z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
97 | return z
98 | else:
99 | # Forward computation
100 | log_det_W = batch_size * n_of_groups * torch.logdet(W)
101 | z = self.conv(z)
102 | return z, log_det_W
103 |
104 |
105 | class WN(torch.nn.Module):
106 | """
107 | This is the WaveNet like layer for the affine coupling. The primary difference
108 | from WaveNet is the convolutions need not be causal. There is also no dilation
109 | size reset. The dilation only doubles on each layer
110 | """
111 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
112 | kernel_size):
113 | super(WN, self).__init__()
114 | assert(kernel_size % 2 == 1)
115 | assert(n_channels % 2 == 0)
116 | self.n_layers = n_layers
117 | self.n_channels = n_channels
118 | self.in_layers = torch.nn.ModuleList()
119 | self.res_skip_layers = torch.nn.ModuleList()
120 | self.cond_layers = torch.nn.ModuleList()
121 |
122 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
123 | start = torch.nn.utils.weight_norm(start, name='weight')
124 | self.start = start
125 |
126 | # Initializing last layer to 0 makes the affine coupling layers
127 | # do nothing at first. This helps with training stability
128 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
129 | end.weight.data.zero_()
130 | end.bias.data.zero_()
131 | self.end = end
132 |
133 | for i in range(n_layers):
134 | dilation = 2 ** i
135 | padding = int((kernel_size*dilation - dilation)/2)
136 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
137 | dilation=dilation, padding=padding)
138 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
139 | self.in_layers.append(in_layer)
140 |
141 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
142 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
143 | self.cond_layers.append(cond_layer)
144 |
145 | # last one is not necessary
146 | if i < n_layers - 1:
147 | res_skip_channels = 2*n_channels
148 | else:
149 | res_skip_channels = n_channels
150 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
151 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
152 | self.res_skip_layers.append(res_skip_layer)
153 |
154 | def forward(self, forward_input):
155 | audio, spect = forward_input
156 | audio = self.start(audio)
157 |
158 | for i in range(self.n_layers):
159 | acts = fused_add_tanh_sigmoid_multiply(
160 | self.in_layers[i](audio),
161 | self.cond_layers[i](spect),
162 | torch.IntTensor([self.n_channels]))
163 |
164 | res_skip_acts = self.res_skip_layers[i](acts)
165 | if i < self.n_layers - 1:
166 | audio = res_skip_acts[:,:self.n_channels,:] + audio
167 | skip_acts = res_skip_acts[:,self.n_channels:,:]
168 | else:
169 | skip_acts = res_skip_acts
170 |
171 | if i == 0:
172 | output = skip_acts
173 | else:
174 | output = skip_acts + output
175 | return self.end(output)
176 |
177 |
178 | class WaveGlow(torch.nn.Module):
179 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
180 | n_early_size, WN_config):
181 | super(WaveGlow, self).__init__()
182 |
183 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
184 | n_mel_channels,
185 | 1024, stride=256)
186 | assert(n_group % 2 == 0)
187 | self.n_flows = n_flows
188 | self.n_group = n_group
189 | self.n_early_every = n_early_every
190 | self.n_early_size = n_early_size
191 | self.WN = torch.nn.ModuleList()
192 | self.convinv = torch.nn.ModuleList()
193 |
194 | n_half = int(n_group/2)
195 |
196 | # Set up layers with the right sizes based on how many dimensions
197 | # have been output already
198 | n_remaining_channels = n_group
199 | for k in range(n_flows):
200 | if k % self.n_early_every == 0 and k > 0:
201 | n_half = n_half - int(self.n_early_size/2)
202 | n_remaining_channels = n_remaining_channels - self.n_early_size
203 | self.convinv.append(Invertible1x1Conv(n_remaining_channels))
204 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
205 | self.n_remaining_channels = n_remaining_channels # Useful during inference
206 |
207 | def forward(self, forward_input):
208 | """
209 | forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
210 | forward_input[1] = audio: batch x time
211 | """
212 | spect, audio = forward_input
213 |
214 | # Upsample spectrogram to size of audio
215 | spect = self.upsample(spect)
216 | assert(spect.size(2) >= audio.size(1))
217 | if spect.size(2) > audio.size(1):
218 | spect = spect[:, :, :audio.size(1)]
219 |
220 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
221 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
222 |
223 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
224 | output_audio = []
225 | log_s_list = []
226 | log_det_W_list = []
227 |
228 | for k in range(self.n_flows):
229 | if k % self.n_early_every == 0 and k > 0:
230 | output_audio.append(audio[:,:self.n_early_size,:])
231 | audio = audio[:,self.n_early_size:,:]
232 |
233 | audio, log_det_W = self.convinv[k](audio)
234 | log_det_W_list.append(log_det_W)
235 |
236 | n_half = int(audio.size(1)/2)
237 | audio_0 = audio[:,:n_half,:]
238 | audio_1 = audio[:,n_half:,:]
239 |
240 | output = self.WN[k]((audio_0, spect))
241 | log_s = output[:, n_half:, :]
242 | b = output[:, :n_half, :]
243 | audio_1 = torch.exp(log_s)*audio_1 + b
244 | log_s_list.append(log_s)
245 |
246 | audio = torch.cat([audio_0, audio_1],1)
247 |
248 | output_audio.append(audio)
249 | return torch.cat(output_audio,1), log_s_list, log_det_W_list
250 |
251 | def infer(self, spect, sigma=1.0):
252 | spect = self.upsample(spect)
253 | # trim conv artifacts. maybe pad spec to kernel multiple
254 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
255 | spect = spect[:, :, :-time_cutoff]
256 |
257 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
258 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
259 |
260 | if spect.type() == 'torch.cuda.HalfTensor':
261 | audio = torch.cuda.HalfTensor(spect.size(0),
262 | self.n_remaining_channels,
263 | spect.size(2)).normal_()
264 | else:
265 | audio = torch.cuda.FloatTensor(spect.size(0),
266 | self.n_remaining_channels,
267 | spect.size(2)).normal_()
268 |
269 | audio = torch.autograd.Variable(sigma*audio)
270 |
271 | for k in reversed(range(self.n_flows)):
272 | n_half = int(audio.size(1)/2)
273 | audio_0 = audio[:,:n_half,:]
274 | audio_1 = audio[:,n_half:,:]
275 |
276 | output = self.WN[k]((audio_0, spect))
277 | s = output[:, n_half:, :]
278 | b = output[:, :n_half, :]
279 | audio_1 = (audio_1 - b)/torch.exp(s)
280 | audio = torch.cat([audio_0, audio_1],1)
281 |
282 | audio = self.convinv[k](audio, reverse=True)
283 |
284 | if k % self.n_early_every == 0 and k > 0:
285 | if spect.type() == 'torch.cuda.HalfTensor':
286 | z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
287 | else:
288 | z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
289 | audio = torch.cat((sigma*z, audio),1)
290 |
291 | audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
292 | return audio
293 |
294 | @staticmethod
295 | def remove_weightnorm(model):
296 | waveglow = model
297 | for WN in waveglow.WN:
298 | WN.start = torch.nn.utils.remove_weight_norm(WN.start)
299 | WN.in_layers = remove(WN.in_layers)
300 | WN.cond_layers = remove(WN.cond_layers)
301 | WN.res_skip_layers = remove(WN.res_skip_layers)
302 | return waveglow
303 |
304 |
305 | def remove(conv_list):
306 | new_conv_list = torch.nn.ModuleList()
307 | for old_conv in conv_list:
308 | old_conv = torch.nn.utils.remove_weight_norm(old_conv)
309 | new_conv_list.append(old_conv)
310 | return new_conv_list
311 |
--------------------------------------------------------------------------------
/waveglow/glow_old.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from glow import Invertible1x1Conv, remove
4 |
5 |
6 | @torch.jit.script
7 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
8 | n_channels_int = n_channels[0]
9 | in_act = input_a+input_b
10 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
11 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
12 | acts = t_act * s_act
13 | return acts
14 |
15 |
16 | class WN(torch.nn.Module):
17 | """
18 | This is the WaveNet like layer for the affine coupling. The primary difference
19 | from WaveNet is the convolutions need not be causal. There is also no dilation
20 | size reset. The dilation only doubles on each layer
21 | """
22 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
23 | kernel_size):
24 | super(WN, self).__init__()
25 | assert(kernel_size % 2 == 1)
26 | assert(n_channels % 2 == 0)
27 | self.n_layers = n_layers
28 | self.n_channels = n_channels
29 | self.in_layers = torch.nn.ModuleList()
30 | self.res_skip_layers = torch.nn.ModuleList()
31 | self.cond_layers = torch.nn.ModuleList()
32 |
33 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
34 | start = torch.nn.utils.weight_norm(start, name='weight')
35 | self.start = start
36 |
37 | # Initializing last layer to 0 makes the affine coupling layers
38 | # do nothing at first. This helps with training stability
39 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
40 | end.weight.data.zero_()
41 | end.bias.data.zero_()
42 | self.end = end
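# With end.weight and end.bias zeroed, WN initially outputs all zeros, so the coupling's
# s = 0 and b = 0 and each flow starts out as the identity map (see the comment above).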
43 |
44 | for i in range(n_layers):
45 | dilation = 2 ** i
46 | padding = int((kernel_size*dilation - dilation)/2)
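# padding = dilation*(kernel_size - 1)/2, i.e. "same" padding for the odd kernel_size
# asserted above, so the time dimension is preserved through every dilated convolution.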
47 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
48 | dilation=dilation, padding=padding)
49 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
50 | self.in_layers.append(in_layer)
51 |
52 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
53 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
54 | self.cond_layers.append(cond_layer)
55 |
56 | # last one is not necessary
57 | if i < n_layers - 1:
58 | res_skip_channels = 2*n_channels
59 | else:
60 | res_skip_channels = n_channels
61 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
62 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
63 | self.res_skip_layers.append(res_skip_layer)
64 |
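# forward() returns self.end(output), a tensor with 2*n_in_channels channels that the
# caller (WaveGlow) splits into a bias b (first half) and a log-scale s (second half)
# for the affine coupling.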
65 | def forward(self, forward_input):
66 | audio, spect = forward_input
67 | audio = self.start(audio)
68 |
69 | for i in range(self.n_layers):
70 | acts = fused_add_tanh_sigmoid_multiply(
71 | self.in_layers[i](audio),
72 | self.cond_layers[i](spect),
73 | torch.IntTensor([self.n_channels]))
74 |
75 | res_skip_acts = self.res_skip_layers[i](acts)
76 | if i < self.n_layers - 1:
77 | audio = res_skip_acts[:,:self.n_channels,:] + audio
78 | skip_acts = res_skip_acts[:,self.n_channels:,:]
79 | else:
80 | skip_acts = res_skip_acts
81 |
82 | if i == 0:
83 | output = skip_acts
84 | else:
85 | output = skip_acts + output
86 | return self.end(output)
87 |
88 |
89 | class WaveGlow(torch.nn.Module):
90 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
91 | n_early_size, WN_config):
92 | super(WaveGlow, self).__init__()
93 |
94 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
95 | n_mel_channels,
96 | 1024, stride=256)
97 | assert(n_group % 2 == 0)
98 | self.n_flows = n_flows
99 | self.n_group = n_group
100 | self.n_early_every = n_early_every
101 | self.n_early_size = n_early_size
102 | self.WN = torch.nn.ModuleList()
103 | self.convinv = torch.nn.ModuleList()
104 |
105 | n_half = int(n_group/2)
106 |
107 | # Set up layers with the right sizes based on how many dimensions
108 | # have been output already
109 | n_remaining_channels = n_group
110 | for k in range(n_flows):
111 | if k % self.n_early_every == 0 and k > 0:
112 | n_half = n_half - int(self.n_early_size/2)
113 | n_remaining_channels = n_remaining_channels - self.n_early_size
114 | self.convinv.append(Invertible1x1Conv(n_remaining_channels))
115 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
116 | self.n_remaining_channels = n_remaining_channels # Useful during inference
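# Every n_early_every flows (for k > 0), n_early_size channels are emitted early, so both
# n_half and n_remaining_channels shrink as k grows; the final n_remaining_channels is the
# number of noise channels sampled at the start of infer().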
117 |
118 | def forward(self, forward_input):
119 |         return None  # training-time forward is intentionally disabled here; the original implementation is preserved below as an unreachable string literal
120 | """
121 | forward_input[0] = audio: batch x time
122 | forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time
123 | """
124 | """
125 | spect, audio = forward_input
126 |
127 | # Upsample spectrogram to size of audio
128 | spect = self.upsample(spect)
129 | assert(spect.size(2) >= audio.size(1))
130 | if spect.size(2) > audio.size(1):
131 | spect = spect[:, :, :audio.size(1)]
132 |
133 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
134 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
135 |
136 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
137 | output_audio = []
138 | s_list = []
139 | s_conv_list = []
140 |
141 | for k in range(self.n_flows):
142 | if k%4 == 0 and k > 0:
143 | output_audio.append(audio[:,:self.n_multi,:])
144 | audio = audio[:,self.n_multi:,:]
145 |
146 | # project to new basis
147 | audio, s = self.convinv[k](audio)
148 | s_conv_list.append(s)
149 |
150 | n_half = int(audio.size(1)/2)
151 | if k%2 == 0:
152 | audio_0 = audio[:,:n_half,:]
153 | audio_1 = audio[:,n_half:,:]
154 | else:
155 | audio_1 = audio[:,:n_half,:]
156 | audio_0 = audio[:,n_half:,:]
157 |
158 | output = self.nn[k]((audio_0, spect))
159 | s = output[:, n_half:, :]
160 | b = output[:, :n_half, :]
161 | audio_1 = torch.exp(s)*audio_1 + b
162 | s_list.append(s)
163 |
164 | if k%2 == 0:
165 | audio = torch.cat([audio[:,:n_half,:], audio_1],1)
166 | else:
167 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)
168 | output_audio.append(audio)
169 | return torch.cat(output_audio,1), s_list, s_conv_list
170 | """
171 |
172 | def infer(self, spect, sigma=1.0):
173 | spect = self.upsample(spect)
174 |         # Trim transposed-conv edge artifacts (TODO: could instead pad the spectrogram to a multiple of the kernel size)
175 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
176 | spect = spect[:, :, :-time_cutoff]
177 |
178 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
179 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
180 |
181 | if spect.type() == 'torch.cuda.HalfTensor':
182 | audio = torch.cuda.HalfTensor(spect.size(0),
183 | self.n_remaining_channels,
184 | spect.size(2)).normal_()
185 | else:
186 | audio = torch.cuda.FloatTensor(spect.size(0),
187 | self.n_remaining_channels,
188 | spect.size(2)).normal_()
189 |
190 | audio = torch.autograd.Variable(sigma*audio)
191 |
192 | for k in reversed(range(self.n_flows)):
193 | n_half = int(audio.size(1)/2)
194 | if k%2 == 0:
195 | audio_0 = audio[:,:n_half,:]
196 | audio_1 = audio[:,n_half:,:]
197 | else:
198 | audio_1 = audio[:,:n_half,:]
199 | audio_0 = audio[:,n_half:,:]
200 |
201 | output = self.WN[k]((audio_0, spect))
202 | s = output[:, n_half:, :]
203 | b = output[:, :n_half, :]
204 | audio_1 = (audio_1 - b)/torch.exp(s)
205 | if k%2 == 0:
206 | audio = torch.cat([audio[:,:n_half,:], audio_1],1)
207 | else:
208 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)
209 |
210 | audio = self.convinv[k](audio, reverse=True)
211 |
212 | if k%4 == 0 and k > 0:
213 | if spect.type() == 'torch.cuda.HalfTensor':
214 | z = torch.cuda.HalfTensor(spect.size(0),
215 | self.n_early_size,
216 | spect.size(2)).normal_()
217 | else:
218 | z = torch.cuda.FloatTensor(spect.size(0),
219 | self.n_early_size,
220 | spect.size(2)).normal_()
221 | audio = torch.cat((sigma*z, audio),1)
222 |
223 | return audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
224 |
225 | @staticmethod
226 | def remove_weightnorm(model):
227 | waveglow = model
228 | for WN in waveglow.WN:
229 | WN.start = torch.nn.utils.remove_weight_norm(WN.start)
230 | WN.in_layers = remove(WN.in_layers)
231 | WN.cond_layers = remove(WN.cond_layers)
232 | WN.res_skip_layers = remove(WN.res_skip_layers)
233 | return waveglow
234 |
--------------------------------------------------------------------------------
/waveglow/inference.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 | # ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import os
28 | from scipy.io.wavfile import write
29 | import torch
30 | from mel2samp import files_to_list, MAX_WAV_VALUE
31 | from denoiser import Denoiser
32 |
33 |
34 | def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
35 | denoiser_strength):
36 | mel_files = files_to_list(mel_files)
37 | waveglow = torch.load(waveglow_path)['model']
38 | waveglow = waveglow.remove_weightnorm(waveglow)
39 | waveglow.cuda().eval()
40 | if is_fp16:
41 | from apex import amp
42 | waveglow, _ = amp.initialize(waveglow, [], opt_level="O3")
43 |
44 | if denoiser_strength > 0:
45 | denoiser = Denoiser(waveglow).cuda()
46 |
47 | for i, file_path in enumerate(mel_files):
48 | file_name = os.path.splitext(os.path.basename(file_path))[0]
49 | mel = torch.load(file_path)
50 | mel = torch.autograd.Variable(mel.cuda())
51 | mel = torch.unsqueeze(mel, 0)
52 | mel = mel.half() if is_fp16 else mel
53 | with torch.no_grad():
54 | audio = waveglow.infer(mel, sigma=sigma)
55 | if denoiser_strength > 0:
56 | audio = denoiser(audio, denoiser_strength)
57 | audio = audio * MAX_WAV_VALUE
58 | audio = audio.squeeze()
59 | audio = audio.cpu().numpy()
60 | audio = audio.astype('int16')
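# The model outputs audio roughly in [-1, 1]; scaling by MAX_WAV_VALUE (32768) and casting
# to int16 produces standard 16-bit PCM samples for scipy.io.wavfile.write below.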
61 | audio_path = os.path.join(
62 | output_dir, "{}_synthesis.wav".format(file_name))
63 | write(audio_path, sampling_rate, audio)
64 | print(audio_path)
65 |
66 |
67 | if __name__ == "__main__":
68 | import argparse
69 |
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('-f', "--filelist_path", required=True)
72 | parser.add_argument('-w', '--waveglow_path',
73 | help='Path to waveglow decoder checkpoint with model')
74 | parser.add_argument('-o', "--output_dir", required=True)
75 | parser.add_argument("-s", "--sigma", default=1.0, type=float)
76 | parser.add_argument("--sampling_rate", default=22050, type=int)
77 | parser.add_argument("--is_fp16", action="store_true")
78 | parser.add_argument("-d", "--denoiser_strength", default=0.0, type=float,
79 | help='Removes model bias. Start with 0.1 and adjust')
80 |
81 | args = parser.parse_args()
82 |
83 | main(args.filelist_path, args.waveglow_path, args.sigma, args.output_dir,
84 | args.sampling_rate, args.is_fp16, args.denoiser_strength)
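# Example invocation (file and checkpoint names below are placeholders, not files shipped with this repo):
#   python inference.py -f mel_files.txt -w waveglow_checkpoint.pt -o output_wavs -s 0.6 -d 0.1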
85 |
--------------------------------------------------------------------------------
/waveglow/mel2samp.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import os
28 | import random
29 | import argparse
30 | import json
31 | import torch
32 | import torch.utils.data
33 | import sys
34 | from scipy.io.wavfile import read
35 |
36 | # We're using the audio processing from Tacotron 2 to make sure it matches
37 | sys.path.insert(0, 'tacotron2')
38 | from tacotron2.layers import TacotronSTFT
39 |
40 | MAX_WAV_VALUE = 32768.0
41 |
42 | def files_to_list(filename):
43 | """
44 | Takes a text file of filenames and makes a list of filenames
45 | """
46 | with open(filename, encoding='utf-8') as f:
47 | files = f.readlines()
48 |
49 | files = [f.rstrip() for f in files]
50 | return files
51 |
52 | def load_wav_to_torch(full_path):
53 | """
54 |     Loads wav data into a torch tensor (returned together with its sampling rate)
55 | """
56 | sampling_rate, data = read(full_path)
57 | return torch.from_numpy(data).float(), sampling_rate
58 |
59 |
60 | class Mel2Samp(torch.utils.data.Dataset):
61 | """
62 | This is the main class that calculates the spectrogram and returns the
63 | spectrogram, audio pair.
64 | """
65 | def __init__(self, training_files, segment_length, filter_length,
66 | hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
67 | self.audio_files = files_to_list(training_files)
68 | random.seed(1234)
69 | random.shuffle(self.audio_files)
70 | self.stft = TacotronSTFT(filter_length=filter_length,
71 | hop_length=hop_length,
72 | win_length=win_length,
73 | sampling_rate=sampling_rate,
74 | mel_fmin=mel_fmin, mel_fmax=mel_fmax)
75 | self.segment_length = segment_length
76 | self.sampling_rate = sampling_rate
77 |
78 | def get_mel(self, audio):
79 | audio_norm = audio / MAX_WAV_VALUE
80 | audio_norm = audio_norm.unsqueeze(0)
81 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
82 | melspec = self.stft.mel_spectrogram(audio_norm)
83 | melspec = torch.squeeze(melspec, 0)
84 | return melspec
85 |
86 | def __getitem__(self, index):
87 | # Read audio
88 | filename = self.audio_files[index]
89 | audio, sampling_rate = load_wav_to_torch(filename)
90 | if sampling_rate != self.sampling_rate:
91 | raise ValueError("{} SR doesn't match target {} SR".format(
92 | sampling_rate, self.sampling_rate))
93 |
94 | # Take segment
95 | if audio.size(0) >= self.segment_length:
96 | max_audio_start = audio.size(0) - self.segment_length
97 | audio_start = random.randint(0, max_audio_start)
98 | audio = audio[audio_start:audio_start+self.segment_length]
99 | else:
100 | audio = torch.nn.functional.pad(audio, (0, self.segment_length - audio.size(0)), 'constant').data
101 |
102 | mel = self.get_mel(audio)
103 | audio = audio / MAX_WAV_VALUE
104 |
105 | return (mel, audio)
106 |
107 | def __len__(self):
108 | return len(self.audio_files)
109 |
110 | # ===================================================================
111 | # Takes directory of clean audio and makes directory of spectrograms
112 | # Useful for making test sets
113 | # ===================================================================
114 | if __name__ == "__main__":
115 | # Get defaults so it can work with no Sacred
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument('-f', "--filelist_path", required=True)
118 | parser.add_argument('-c', '--config', type=str,
119 | help='JSON file for configuration')
120 | parser.add_argument('-o', '--output_dir', type=str,
121 | help='Output directory')
122 | args = parser.parse_args()
123 |
124 | with open(args.config) as f:
125 | data = f.read()
126 | data_config = json.loads(data)["data_config"]
127 | mel2samp = Mel2Samp(**data_config)
128 |
129 | filepaths = files_to_list(args.filelist_path)
130 |
131 | # Make directory if it doesn't exist
132 | if not os.path.isdir(args.output_dir):
133 | os.makedirs(args.output_dir)
134 | os.chmod(args.output_dir, 0o775)
135 |
136 | for filepath in filepaths:
137 | audio, sr = load_wav_to_torch(filepath)
138 | melspectrogram = mel2samp.get_mel(audio)
139 | filename = os.path.basename(filepath)
140 | new_filepath = args.output_dir + '/' + filename + '.pt'
141 | print(new_filepath)
142 | torch.save(melspectrogram, new_filepath)
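# Example invocation for precomputing mel spectrograms (the wav file list is a placeholder;
# config.json ships with this waveglow/ directory):
#   python mel2samp.py -f test_wav_files.txt -c config.json -o mel_dir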
143 |
--------------------------------------------------------------------------------
/waveglow/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0
2 | matplotlib==2.1.0
3 | tensorflow
4 | numpy==1.13.3
5 | inflect==0.2.5
6 | librosa==0.6.0
7 | scipy==1.0.0
8 | tensorboardX==1.1
9 | Unidecode==1.0.22
10 | pillow
11 |
--------------------------------------------------------------------------------
/waveglow/train.py:
--------------------------------------------------------------------------------
1 | # *****************************************************************************
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the
12 | # names of its contributors may be used to endorse or promote products
13 | # derived from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # *****************************************************************************
27 | import argparse
28 | import json
29 | import os
30 | import torch
31 |
32 | #=====START: ADDED FOR DISTRIBUTED======
33 | from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor
34 | from torch.utils.data.distributed import DistributedSampler
35 | #=====END: ADDED FOR DISTRIBUTED======
36 |
37 | from torch.utils.data import DataLoader
38 | from glow import WaveGlow, WaveGlowLoss
39 | from mel2samp import Mel2Samp
40 |
41 | def load_checkpoint(checkpoint_path, model, optimizer):
42 | assert os.path.isfile(checkpoint_path)
43 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
44 | iteration = checkpoint_dict['iteration']
45 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
46 | model_for_loading = checkpoint_dict['model']
47 | model.load_state_dict(model_for_loading.state_dict())
48 | print("Loaded checkpoint '{}' (iteration {})" .format(
49 | checkpoint_path, iteration))
50 | return model, optimizer, iteration
51 |
52 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
53 | print("Saving model and optimizer state at iteration {} to {}".format(
54 | iteration, filepath))
55 | model_for_saving = WaveGlow(**waveglow_config).cuda()
56 | model_for_saving.load_state_dict(model.state_dict())
57 | torch.save({'model': model_for_saving,
58 | 'iteration': iteration,
59 | 'optimizer': optimizer.state_dict(),
60 | 'learning_rate': learning_rate}, filepath)
61 |
62 | def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
63 | sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
64 | checkpoint_path, with_tensorboard):
65 | torch.manual_seed(seed)
66 | torch.cuda.manual_seed(seed)
67 | #=====START: ADDED FOR DISTRIBUTED======
68 | if num_gpus > 1:
69 | init_distributed(rank, num_gpus, group_name, **dist_config)
70 | #=====END: ADDED FOR DISTRIBUTED======
71 |
72 | criterion = WaveGlowLoss(sigma)
73 | model = WaveGlow(**waveglow_config).cuda()
74 |
75 | #=====START: ADDED FOR DISTRIBUTED======
76 | if num_gpus > 1:
77 | model = apply_gradient_allreduce(model)
78 | #=====END: ADDED FOR DISTRIBUTED======
79 |
80 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
81 |
82 | if fp16_run:
83 | from apex import amp
84 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
85 |
86 | # Load checkpoint if one exists
87 | iteration = 0
88 | if checkpoint_path != "":
89 | model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
90 | optimizer)
91 |         iteration += 1  # resume training from the iteration after the loaded checkpoint
92 |
93 | trainset = Mel2Samp(**data_config)
94 | # =====START: ADDED FOR DISTRIBUTED======
95 | train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
96 | # =====END: ADDED FOR DISTRIBUTED======
97 | train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
98 | sampler=train_sampler,
99 | batch_size=batch_size,
100 | pin_memory=False,
101 | drop_last=True)
102 |
103 | # Get shared output_directory ready
104 | if rank == 0:
105 | if not os.path.isdir(output_directory):
106 | os.makedirs(output_directory)
107 | os.chmod(output_directory, 0o775)
108 | print("output directory", output_directory)
109 |
110 | if with_tensorboard and rank == 0:
111 | from tensorboardX import SummaryWriter
112 | logger = SummaryWriter(os.path.join(output_directory, 'logs'))
113 |
114 | model.train()
115 | epoch_offset = max(0, int(iteration / len(train_loader)))
116 |     # ================ MAIN TRAINING LOOP! ===================
117 | for epoch in range(epoch_offset, epochs):
118 | print("Epoch: {}".format(epoch))
119 | for i, batch in enumerate(train_loader):
120 | model.zero_grad()
121 |
122 | mel, audio = batch
123 | mel = torch.autograd.Variable(mel.cuda())
124 | audio = torch.autograd.Variable(audio.cuda())
125 | outputs = model((mel, audio))
126 |
127 | loss = criterion(outputs)
128 | if num_gpus > 1:
129 | reduced_loss = reduce_tensor(loss.data, num_gpus).item()
130 | else:
131 | reduced_loss = loss.item()
132 |
133 | if fp16_run:
134 | with amp.scale_loss(loss, optimizer) as scaled_loss:
135 | scaled_loss.backward()
136 | else:
137 | loss.backward()
138 |
139 | optimizer.step()
140 |
141 | print("{}:\t{:.9f}".format(iteration, reduced_loss))
142 | if with_tensorboard and rank == 0:
143 | logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch)
144 |
145 | if (iteration % iters_per_checkpoint == 0):
146 | if rank == 0:
147 | checkpoint_path = "{}/waveglow_{}".format(
148 | output_directory, iteration)
149 | save_checkpoint(model, optimizer, learning_rate, iteration,
150 | checkpoint_path)
151 |
152 | iteration += 1
153 |
154 | if __name__ == "__main__":
155 | parser = argparse.ArgumentParser()
156 | parser.add_argument('-c', '--config', type=str,
157 | help='JSON file for configuration')
158 | parser.add_argument('-r', '--rank', type=int, default=0,
159 | help='rank of process for distributed')
160 | parser.add_argument('-g', '--group_name', type=str, default='',
161 | help='name of group for distributed')
162 | args = parser.parse_args()
163 |
164 |     # Parse configs. Module-level globals keep this script simple.
165 | with open(args.config) as f:
166 | data = f.read()
167 | config = json.loads(data)
168 | train_config = config["train_config"]
169 | global data_config
170 | data_config = config["data_config"]
171 | global dist_config
172 | dist_config = config["dist_config"]
173 | global waveglow_config
174 | waveglow_config = config["waveglow_config"]
175 |
176 | num_gpus = torch.cuda.device_count()
177 | if num_gpus > 1:
178 | if args.group_name == '':
179 | print("WARNING: Multiple GPUs detected but no distributed group set")
180 | print("Only running 1 GPU. Use distributed.py for multiple GPUs")
181 | num_gpus = 1
182 |
183 | if num_gpus == 1 and args.rank != 0:
184 | raise Exception("Doing single GPU training on rank > 0")
185 |
186 | torch.backends.cudnn.enabled = True
187 | torch.backends.cudnn.benchmark = False
188 | train(num_gpus, args.rank, args.group_name, **train_config)
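# Example invocation (single GPU; config.json ships with this waveglow/ directory):
#   python train.py -c config.json
# For multi-GPU training, see distributed.py, as noted in the warning above.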
189 |
--------------------------------------------------------------------------------
/waveglow/waveglow_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/waveglow/waveglow_logo.png
--------------------------------------------------------------------------------
/wavs/LJ001-0029_phone10000_10.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_10.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0029_phone10000_11.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_11.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0029_phone10000_12.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_12.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0029_phone10000_8.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_8.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0029_phone10000_9.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0029_phone10000_9.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0085_phone10000_10.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_10.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0085_phone10000_11.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_11.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0085_phone10000_12.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_12.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0085_phone10000_8.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_8.wav
--------------------------------------------------------------------------------
/wavs/LJ001-0085_phone10000_9.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ001-0085_phone10000_9.wav
--------------------------------------------------------------------------------
/wavs/LJ002-0106_phone10000_10.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_10.wav
--------------------------------------------------------------------------------
/wavs/LJ002-0106_phone10000_11.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_11.wav
--------------------------------------------------------------------------------
/wavs/LJ002-0106_phone10000_12.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_12.wav
--------------------------------------------------------------------------------
/wavs/LJ002-0106_phone10000_8.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_8.wav
--------------------------------------------------------------------------------
/wavs/LJ002-0106_phone10000_9.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Deepest-Project/AlignTTS/ed9c29d845f65ceb44c87f293b2919b9bbc6a6de/wavs/LJ002-0106_phone10000_9.wav
--------------------------------------------------------------------------------