├── .gitignore
├── LICENSE
├── README.md
├── attentions.py
├── commons.py
├── configs
│   ├── ljs_base.json
│   ├── ljs_nosdp.json
│   └── vctk_base.json
├── data_utils.py
├── filelists
│   ├── ljs_audio_text_test_filelist.txt
│   ├── ljs_audio_text_test_filelist.txt.cleaned
│   ├── ljs_audio_text_train_filelist.txt
│   ├── ljs_audio_text_train_filelist.txt.cleaned
│   ├── ljs_audio_text_val_filelist.txt
│   ├── ljs_audio_text_val_filelist.txt.cleaned
│   ├── vctk_audio_sid_text_test_filelist.txt
│   ├── vctk_audio_sid_text_test_filelist.txt.cleaned
│   ├── vctk_audio_sid_text_train_filelist.txt
│   ├── vctk_audio_sid_text_train_filelist.txt.cleaned
│   ├── vctk_audio_sid_text_val_filelist.txt
│   └── vctk_audio_sid_text_val_filelist.txt.cleaned
├── inference.ipynb
├── losses.py
├── mel_processing.py
├── models.py
├── modules.py
├── monotonic_align
│   ├── __init__.py
│   ├── core.pyx
│   └── setup.py
├── preprocess.py
├── requirements.txt
├── resources
│   ├── fig_1a.png
│   ├── fig_1b.png
│   └── training.png
├── text
│   ├── LICENSE
│   ├── __init__.py
│   ├── cleaners.py
│   └── symbols.py
├── train.py
├── train_ms.py
├── transforms.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | DUMMY1
2 | DUMMY2
3 | DUMMY3
4 | logs
5 | __pycache__
6 | .ipynb_checkpoints
7 | .*.swp
8 |
9 | build
10 | *.c
11 | monotonic_align/monotonic_align
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jaehyeon Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech
2 |
3 | ### Jaehyeon Kim, Jungil Kong, and Juhee Son
4 |
5 | In our recent [paper](https://arxiv.org/abs/2106.06103), we propose VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech.
6 |
7 | Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. In this work, we present a parallel end-to-end TTS method that generates more natural sounding audio than current two-stage models. Our method adopts variational inference augmented with normalizing flows and an adversarial training process, which improves the expressive power of generative modeling. We also propose a stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the natural one-to-many relationship in which a text input can be spoken in multiple ways with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly available TTS systems and achieves a MOS comparable to ground truth.
8 |
9 | Visit our [demo](https://jaywalnut310.github.io/vits-demo/index.html) for audio samples.
10 |
11 | We also provide the [pretrained models](https://drive.google.com/drive/folders/1ksarh-cJf3F5eKJjLVWY0X1j1qsQqiS2?usp=sharing).
12 |
13 | ** Update note: Thanks to [Rishikesh (ऋषिकेश)](https://github.com/jaywalnut310/vits/issues/1), our interactive TTS demo is now available on [Colab Notebook](https://colab.research.google.com/drive/1CO61pZizDj7en71NQG_aqqKdGaA_SaBf?usp=sharing).
14 |
15 |
16 |
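17 | | VITS at training | VITS at inference |
18 | | --- | --- |
19 | | ![VITS at training](resources/fig_1a.png) | ![VITS at inference](resources/fig_1b.png) |
20 |
21 |
22 |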
23 |
24 |
25 |
26 |
27 | ## Pre-requisites
28 | 0. Python >= 3.6
29 | 0. Clone this repository
30 | 0. Install the Python requirements. Please refer to [requirements.txt](requirements.txt)
31 | 1. You may need to install espeak first: `apt-get install espeak`
32 | 0. Download datasets
33 | 1. Download and extract the LJ Speech dataset, then rename or create a link to the dataset folder: `ln -s /path/to/LJSpeech-1.1/wavs DUMMY1`
34 | 1. For the multi-speaker setting, download and extract the VCTK dataset, and downsample the wav files to 22050 Hz. Then rename or create a link to the dataset folder: `ln -s /path/to/VCTK-Corpus/downsampled_wavs DUMMY2`
35 | 0. Build Monotonic Alignment Search and run preprocessing if you use your own datasets.
36 | ```sh
37 | # Cython-version Monotonic Alignment Search
38 | cd monotonic_align
39 | python setup.py build_ext --inplace
40 |
41 | # Preprocessing (g2p) for your own datasets. Preprocessed phonemes for LJ Speech and VCTK have already been provided.
42 | # python preprocess.py --text_index 1 --filelists filelists/ljs_audio_text_train_filelist.txt filelists/ljs_audio_text_val_filelist.txt filelists/ljs_audio_text_test_filelist.txt
43 | # python preprocess.py --text_index 2 --filelists filelists/vctk_audio_sid_text_train_filelist.txt filelists/vctk_audio_sid_text_val_filelist.txt filelists/vctk_audio_sid_text_test_filelist.txt
44 | ```
45 |
46 |
47 | ## Training Example
48 | ```sh
49 | # LJ Speech
50 | python train.py -c configs/ljs_base.json -m ljs_base
51 |
52 | # VCTK
53 | python train_ms.py -c configs/vctk_base.json -m vctk_base
54 | ```
55 |
56 |
57 | ## Inference Example
58 | See [inference.ipynb](inference.ipynb)
59 |
--------------------------------------------------------------------------------
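The preprocessing commands in the README pass `--text_index 1` for the LJ Speech filelists and `--text_index 2` for the VCTK ones. The sketch below illustrates why the index would differ, under the assumption that each filelist line is a `|`-separated record (`path|text` for LJ Speech, `path|speaker_id|text` for VCTK). The separator, the helper name `split_filelist_line`, and the sample records are illustrative assumptions and are not taken from the repository.

```python
def split_filelist_line(line, text_index, sep="|"):
    """Split one filelist record and return (fields, text).

    Hypothetical helper: assumes `sep`-separated fields with the
    transcript stored at position `text_index`.
    """
    fields = line.rstrip("\n").split(sep)
    return fields, fields[text_index]


# Made-up records in the two assumed layouts.
ljs_line = "DUMMY1/sample_0001.wav|Hello world.\n"
vctk_line = "DUMMY2/speaker_a/sample_0001.wav|7|Hello world.\n"

_, ljs_text = split_filelist_line(ljs_line, text_index=1)
vctk_fields, vctk_text = split_filelist_line(vctk_line, text_index=2)

print(ljs_text)
print(vctk_fields[1], vctk_text)
```

With a speaker id occupying position 1 in the multi-speaker layout, the transcript moves to position 2, which matches the two `--text_index` values in the README.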
/attentions.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 | import commons
9 | import modules
10 | from modules import LayerNorm
11 |
12 |
13 | class Encoder(nn.Module):
14 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
15 | super().__init__()
16 | self.hidden_channels = hidden_channels
17 | self.filter_channels = filter_channels
18 | self.n_heads = n_heads
19 | self.n_layers = n_layers
20 | self.kernel_size = kernel_size
21 | self.p_dropout = p_dropout
22 | self.window_size = window_size
23 |
24 | self.drop = nn.Dropout(p_dropout)
25 | self.attn_layers = nn.ModuleList()
26 | self.norm_layers_1 = nn.ModuleList()
27 | self.ffn_layers = nn.ModuleList()
28 | self.norm_layers_2 = nn.ModuleList()
29 | for i in range(self.n_layers):
30 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
31 | self.norm_layers_1.append(LayerNorm(hidden_channels))
32 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
33 | self.norm_layers_2.append(LayerNorm(hidden_channels))
34 |
35 | def forward(self, x, x_mask):
36 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
37 | x = x * x_mask
38 | for i in range(self.n_layers):
39 | y = self.attn_layers[i](x, x, attn_mask)
40 | y = self.drop(y)
41 | x = self.norm_layers_1[i](x + y)
42 |
43 | y = self.ffn_layers[i](x, x_mask)
44 | y = self.drop(y)
45 | x = self.norm_layers_2[i](x + y)
46 | x = x * x_mask
47 | return x
48 |
49 |
50 | class Decoder(nn.Module):
51 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
52 | super().__init__()
53 | self.hidden_channels = hidden_channels
54 | self.filter_channels = filter_channels
55 | self.n_heads = n_heads
56 | self.n_layers = n_layers
57 | self.kernel_size = kernel_size
58 | self.p_dropout = p_dropout
59 | self.proximal_bias = proximal_bias
60 | self.proximal_init = proximal_init
61 |
62 | self.drop = nn.Dropout(p_dropout)
63 | self.self_attn_layers = nn.ModuleList()
64 | self.norm_layers_0 = nn.ModuleList()
65 | self.encdec_attn_layers = nn.ModuleList()
66 | self.norm_layers_1 = nn.ModuleList()
67 | self.ffn_layers = nn.ModuleList()
68 | self.norm_layers_2 = nn.ModuleList()
69 | for i in range(self.n_layers):
70 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
71 | self.norm_layers_0.append(LayerNorm(hidden_channels))
72 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
73 | self.norm_layers_1.append(LayerNorm(hidden_channels))
74 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
75 | self.norm_layers_2.append(LayerNorm(hidden_channels))
76 |
77 | def forward(self, x, x_mask, h, h_mask):
78 | """
79 | x: decoder input
80 | h: encoder output
81 | """
82 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
83 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
84 | x = x * x_mask
85 | for i in range(self.n_layers):
86 | y = self.self_attn_layers[i](x, x, self_attn_mask)
87 | y = self.drop(y)
88 | x = self.norm_layers_0[i](x + y)
89 |
90 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
91 | y = self.drop(y)
92 | x = self.norm_layers_1[i](x + y)
93 |
94 | y = self.ffn_layers[i](x, x_mask)
95 | y = self.drop(y)
96 | x = self.norm_layers_2[i](x + y)
97 | x = x * x_mask
98 | return x
99 |
100 |
101 | class MultiHeadAttention(nn.Module):
102 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
103 | super().__init__()
104 | assert channels % n_heads == 0
105 |
106 | self.channels = channels
107 | self.out_channels = out_channels
108 | self.n_heads = n_heads
109 | self.p_dropout = p_dropout
110 | self.window_size = window_size
111 | self.heads_share = heads_share
112 | self.block_length = block_length
113 | self.proximal_bias = proximal_bias
114 | self.proximal_init = proximal_init
115 | self.attn = None
116 |
117 | self.k_channels = channels // n_heads
118 | self.conv_q = nn.Conv1d(channels, channels, 1)
119 | self.conv_k = nn.Conv1d(channels, channels, 1)
120 | self.conv_v = nn.Conv1d(channels, channels, 1)
121 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
122 | self.drop = nn.Dropout(p_dropout)
123 |
124 | if window_size is not None:
125 | n_heads_rel = 1 if heads_share else n_heads
126 | rel_stddev = self.k_channels**-0.5
127 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
128 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
129 |
130 | nn.init.xavier_uniform_(self.conv_q.weight)
131 | nn.init.xavier_uniform_(self.conv_k.weight)
132 | nn.init.xavier_uniform_(self.conv_v.weight)
133 | if proximal_init:
134 | with torch.no_grad():
135 | self.conv_k.weight.copy_(self.conv_q.weight)
136 | self.conv_k.bias.copy_(self.conv_q.bias)
137 |
138 | def forward(self, x, c, attn_mask=None):
139 | q = self.conv_q(x)
140 | k = self.conv_k(c)
141 | v = self.conv_v(c)
142 |
143 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
144 |
145 | x = self.conv_o(x)
146 | return x
147 |
148 | def attention(self, query, key, value, mask=None):
149 | # reshape [b, d, t] -> [b, n_h, t, d_k]
150 | b, d, t_s, t_t = (*key.size(), query.size(2))
151 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
152 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
153 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
154 |
155 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
156 | if self.window_size is not None:
157 | assert t_s == t_t, "Relative attention is only available for self-attention."
158 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
159 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
160 | scores_local = self._relative_position_to_absolute_position(rel_logits)
161 | scores = scores + scores_local
162 | if self.proximal_bias:
163 | assert t_s == t_t, "Proximal bias is only available for self-attention."
164 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
165 | if mask is not None:
166 | scores = scores.masked_fill(mask == 0, -1e4)
167 | if self.block_length is not None:
168 | assert t_s == t_t, "Local attention is only available for self-attention."
169 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
170 | scores = scores.masked_fill(block_mask == 0, -1e4)
171 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
172 | p_attn = self.drop(p_attn)
173 | output = torch.matmul(p_attn, value)
174 | if self.window_size is not None:
175 | relative_weights = self._absolute_position_to_relative_position(p_attn)
176 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
177 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
178 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
179 | return output, p_attn
180 |
181 | def _matmul_with_relative_values(self, x, y):
182 | """
183 | x: [b, h, l, m]
184 | y: [h or 1, m, d]
185 | ret: [b, h, l, d]
186 | """
187 | ret = torch.matmul(x, y.unsqueeze(0))
188 | return ret
189 |
190 | def _matmul_with_relative_keys(self, x, y):
191 | """
192 | x: [b, h, l, d]
193 | y: [h or 1, m, d]
194 | ret: [b, h, l, m]
195 | """
196 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
197 | return ret
198 |
199 | def _get_relative_embeddings(self, relative_embeddings, length):
200 | max_relative_position = 2 * self.window_size + 1
201 | # Pad first before slice to avoid using cond ops.
202 | pad_length = max(length - (self.window_size + 1), 0)
203 | slice_start_position = max((self.window_size + 1) - length, 0)
204 | slice_end_position = slice_start_position + 2 * length - 1
205 | if pad_length > 0:
206 | padded_relative_embeddings = F.pad(
207 | relative_embeddings,
208 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
209 | else:
210 | padded_relative_embeddings = relative_embeddings
211 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
212 | return used_relative_embeddings
213 |
214 | def _relative_position_to_absolute_position(self, x):
215 | """
216 | x: [b, h, l, 2*l-1]
217 | ret: [b, h, l, l]
218 | """
219 | batch, heads, length, _ = x.size()
220 | # Concat columns of pad to shift from relative to absolute indexing.
221 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
222 |
223 | # Concat extra elements so that they add up to shape (len+1, 2*len-1).
224 | x_flat = x.view([batch, heads, length * 2 * length])
225 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
226 |
227 | # Reshape and slice out the padded elements.
228 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
229 | return x_final
230 |
231 | def _absolute_position_to_relative_position(self, x):
232 | """
233 | x: [b, h, l, l]
234 | ret: [b, h, l, 2*l-1]
235 | """
236 | batch, heads, length, _ = x.size()
237 | # pad along the column axis
238 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
239 | x_flat = x.view([batch, heads, length**2 + length*(length -1)])
240 | # add 0's in the beginning that will skew the elements after reshape
241 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
242 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
243 | return x_final
244 |
245 | def _attention_bias_proximal(self, length):
246 | """Bias for self-attention to encourage attention to close positions.
247 | Args:
248 | length: an integer scalar.
249 | Returns:
250 | a Tensor with shape [1, 1, length, length]
251 | """
252 | r = torch.arange(length, dtype=torch.float32)
253 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
254 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
255 |
256 |
257 | class FFN(nn.Module):
258 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
259 | super().__init__()
260 | self.in_channels = in_channels
261 | self.out_channels = out_channels
262 | self.filter_channels = filter_channels
263 | self.kernel_size = kernel_size
264 | self.p_dropout = p_dropout
265 | self.activation = activation
266 | self.causal = causal
267 |
268 | if causal:
269 | self.padding = self._causal_padding
270 | else:
271 | self.padding = self._same_padding
272 |
273 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
274 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
275 | self.drop = nn.Dropout(p_dropout)
276 |
277 | def forward(self, x, x_mask):
278 | x = self.conv_1(self.padding(x * x_mask))
279 | if self.activation == "gelu":
280 | x = x * torch.sigmoid(1.702 * x)
281 | else:
282 | x = torch.relu(x)
283 | x = self.drop(x)
284 | x = self.conv_2(self.padding(x * x_mask))
285 | return x * x_mask
286 |
287 | def _causal_padding(self, x):
288 | if self.kernel_size == 1:
289 | return x
290 | pad_l = self.kernel_size - 1
291 | pad_r = 0
292 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
293 | x = F.pad(x, commons.convert_pad_shape(padding))
294 | return x
295 |
296 | def _same_padding(self, x):
297 | if self.kernel_size == 1:
298 | return x
299 | pad_l = (self.kernel_size - 1) // 2
300 | pad_r = self.kernel_size // 2
301 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
302 | x = F.pad(x, commons.convert_pad_shape(padding))
303 | return x
304 |
--------------------------------------------------------------------------------
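The pad, flatten, pad, reshape sequence in `_relative_position_to_absolute_position` can be hard to follow from tensor shapes alone. Below is a minimal pure-Python sketch of the same steps for a single `[l, 2*l-1]` table, without batch or head dimensions. The function name `relative_to_absolute` and the labelled sample data are illustrative assumptions; this is a restatement for study, not the repository's implementation.

```python
def relative_to_absolute(rel):
    """Convert an [l, 2*l-1] table of relative-position scores into [l, l].

    rel[i][k] holds the score of query i for relative offset k - (l - 1).
    The result satisfies out[i][j] == rel[i][j - i + l - 1].
    """
    length = len(rel)
    # 1) Pad one extra column, giving rows of width 2*l.
    padded = [row + [0] for row in rel]
    # 2) Flatten and append l-1 zeros so the total is (l+1) * (2*l-1).
    flat = [v for row in padded for v in row] + [0] * (length - 1)
    # 3) Re-read the buffer with row width 2*l-1; each row shifts by one.
    width = 2 * length - 1
    rows = [flat[r * width:(r + 1) * width] for r in range(length + 1)]
    # 4) Keep the first l rows and the last l columns.
    return [row[length - 1:] for row in rows[:length]]


length = 3
# Label every entry "query:offset" so the shift is easy to see.
rel = [[f"{i}:{k - (length - 1):+d}" for k in range(2 * length - 1)]
       for i in range(length)]
for row in relative_to_absolute(rel):
    print(row)
```

Because the re-read row width is one less than the padded width, every successive row starts one element later in the buffer, which is what lines each relative offset up with its absolute key position.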
/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | def init_weights(m, mean=0.0, std=0.01):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | m.weight.data.normal_(mean, std)
12 |
13 |
14 | def get_padding(kernel_size, dilation=1):
15 | return int((kernel_size*dilation - dilation)/2)
16 |
17 |
18 | def convert_pad_shape(pad_shape):
19 | l = pad_shape[::-1]
20 | pad_shape = [item for sublist in l for item in sublist]
21 | return pad_shape
22 |
23 |
24 | def intersperse(lst, item):
25 | result = [item] * (len(lst) * 2 + 1)
26 | result[1::2] = lst
27 | return result
28 |
29 |
30 | def kl_divergence(m_p, logs_p, m_q, logs_q):
31 | """KL(P||Q)"""
32 | kl = (logs_q - logs_p) - 0.5
33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34 | return kl
35 |
36 |
37 | def rand_gumbel(shape):
38 | """Sample from the Gumbel distribution, protect from overflows."""
39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40 | return -torch.log(-torch.log(uniform_samples))
41 |
42 |
43 | def rand_gumbel_like(x):
44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45 | return g
46 |
47 |
48 | def slice_segments(x, ids_str, segment_size=4):
49 | ret = torch.zeros_like(x[:, :, :segment_size])
50 | for i in range(x.size(0)):
51 | idx_str = ids_str[i]
52 | idx_end = idx_str + segment_size
53 | ret[i] = x[i, :, idx_str:idx_end]
54 | return ret
55 |
56 |
57 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
58 | b, d, t = x.size()
59 | if x_lengths is None:
60 | x_lengths = t
61 | ids_str_max = x_lengths - segment_size + 1
62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63 | ret = slice_segments(x, ids_str, segment_size)
64 | return ret, ids_str
65 |
66 |
67 | def get_timing_signal_1d(
68 | length, channels, min_timescale=1.0, max_timescale=1.0e4):
69 | position = torch.arange(length, dtype=torch.float)
70 | num_timescales = channels // 2
71 | log_timescale_increment = (
72 | math.log(float(max_timescale) / float(min_timescale)) /
73 | (num_timescales - 1))
74 | inv_timescales = min_timescale * torch.exp(
75 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78 | signal = F.pad(signal, [0, 0, 0, channels % 2])
79 | signal = signal.view(1, channels, length)
80 | return signal
81 |
82 |
83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84 | b, channels, length = x.size()
85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86 | return x + signal.to(dtype=x.dtype, device=x.device)
87 |
88 |
89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90 | b, channels, length = x.size()
91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93 |
94 |
95 | def subsequent_mask(length):
96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97 | return mask
98 |
99 |
100 | @torch.jit.script
101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102 | n_channels_int = n_channels[0]
103 | in_act = input_a + input_b
104 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106 | acts = t_act * s_act
107 | return acts
108 |
109 |
110 | def convert_pad_shape(pad_shape):
111 | l = pad_shape[::-1]
112 | pad_shape = [item for sublist in l for item in sublist]
113 | return pad_shape
114 |
115 |
116 | def shift_1d(x):
117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118 | return x
119 |
120 |
121 | def sequence_mask(length, max_length=None):
122 | if max_length is None:
123 | max_length = length.max()
124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125 | return x.unsqueeze(0) < length.unsqueeze(1)
126 |
127 |
128 | def generate_path(duration, mask):
129 | """
130 | duration: [b, 1, t_x]
131 | mask: [b, 1, t_y, t_x]
132 | """
133 | device = duration.device
134 |
135 | b, _, t_y, t_x = mask.shape
136 | cum_duration = torch.cumsum(duration, -1)
137 |
138 | cum_duration_flat = cum_duration.view(b * t_x)
139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
140 | path = path.view(b, t_x, t_y)
141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
142 | path = path.unsqueeze(1).transpose(2,3) * mask
143 | return path
144 |
145 |
146 | def clip_grad_value_(parameters, clip_value, norm_type=2):
147 | if isinstance(parameters, torch.Tensor):
148 | parameters = [parameters]
149 | parameters = list(filter(lambda p: p.grad is not None, parameters))
150 | norm_type = float(norm_type)
151 | if clip_value is not None:
152 | clip_value = float(clip_value)
153 |
154 | total_norm = 0
155 | for p in parameters:
156 | param_norm = p.grad.data.norm(norm_type)
157 | total_norm += param_norm.item() ** norm_type
158 | if clip_value is not None:
159 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
160 | total_norm = total_norm ** (1. / norm_type)
161 | return total_norm
162 |
--------------------------------------------------------------------------------
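`generate_path` in `commons.py` turns per-token durations into a hard monotonic alignment using a cumulative sum, a sequence mask, and a shifted subtraction. The following pure-Python sketch walks through those three steps for one utterance, assuming integer durations and omitting the final multiplication by `mask`. The name `durations_to_path` and the sample durations are illustrative assumptions.

```python
from itertools import accumulate


def durations_to_path(durations, n_frames):
    """Return a [n_frames][n_tokens] 0/1 matrix of a monotonic alignment.

    Token x owns the frames y with cum[x-1] <= y < cum[x], where cum is the
    running sum of durations. Frames past the total duration stay all-zero.
    """
    cum = list(accumulate(durations))
    # Step 1: "sequence mask" per token: 1 for every frame below cum[x].
    below = [[1 if y < c else 0 for y in range(n_frames)] for c in cum]
    # Step 2: subtract the previous token's mask so only its own span is left.
    prev = [[0] * n_frames] + below[:-1]
    own = [[b - p for b, p in zip(b_row, p_row)]
           for b_row, p_row in zip(below, prev)]
    # Step 3: transpose to frames-by-tokens.
    return [list(col) for col in zip(*own)]


path = durations_to_path([2, 1, 3], n_frames=6)
for frame in path:
    print(frame)
```

Each printed row is one output frame, and the single 1 in it marks the token that frame is aligned to; a token with duration 0 simply gets no frames.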
/configs/ljs_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 64,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned",
21 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned",
22 | "text_cleaners":["english_cleaners2"],
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 22050,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null,
31 | "add_blank": true,
32 | "n_speakers": 0,
33 | "cleaned_text": true
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
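The `learning_rate` and `lr_decay` entries above suggest an exponentially decaying schedule. As a rough sketch, assuming the rate is multiplied by `lr_decay` once per epoch (the training script is not shown in this section, so that step granularity is an assumption, as is the helper name `lr_at_epoch`), the effective learning rate can be estimated like this:

```python
learning_rate = 2e-4   # "learning_rate" in configs/ljs_base.json
lr_decay = 0.999875    # "lr_decay" in configs/ljs_base.json


def lr_at_epoch(epoch, base_lr=learning_rate, gamma=lr_decay):
    """Learning rate after `epoch` decay steps (assumed: one step per epoch)."""
    return base_lr * gamma ** epoch


for epoch in (0, 1000, 5000, 20000):
    print(epoch, f"{lr_at_epoch(epoch):.3e}")
```

Under this assumption the decay is slow: the rate drops by roughly 12% over the first 1000 epochs.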
/configs/ljs_nosdp.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 64,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned",
21 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned",
22 | "text_cleaners":["english_cleaners2"],
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 22050,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null,
31 | "add_blank": true,
32 | "n_speakers": 0,
33 | "cleaned_text": true
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false,
51 | "use_sdp": false
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/configs/vctk_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 10000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 64,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned",
21 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned",
22 | "text_cleaners":["english_cleaners2"],
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 22050,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null,
31 | "add_blank": true,
32 | "n_speakers": 109,
33 | "cleaned_text": true
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false,
51 | "gin_channels": 256
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
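Several values in these configs are numerically related. The sketch below checks two of those relationships using only numbers copied from the files above: the product of `upsample_rates` against `hop_length`, and the number of spectrogram frames covered by one `segment_size` window. Reading the first as "the waveform decoder expands each frame into `hop_length` samples" is an interpretation, not something stated in this section.

```python
import json
from functools import reduce
from operator import mul

# Values copied from the configs above; all three files agree on them.
config = json.loads("""
{
  "train": {"segment_size": 8192},
  "data": {"sampling_rate": 22050, "hop_length": 256},
  "model": {"upsample_rates": [8, 8, 2, 2]}
}
""")

hop_length = config["data"]["hop_length"]
total_upsampling = reduce(mul, config["model"]["upsample_rates"], 1)
frames_per_segment = config["train"]["segment_size"] // hop_length
segment_seconds = config["train"]["segment_size"] / config["data"]["sampling_rate"]

print(total_upsampling == hop_length)
print(frames_per_segment)
print(round(segment_seconds, 3))
```

If you change `hop_length` or `upsample_rates` for your own dataset, a check like this is a cheap way to confirm the two still agree before starting a long training run.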
/data_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.utils.data
7 |
8 | import commons
9 | from mel_processing import spectrogram_torch
10 | from utils import load_wav_to_torch, load_filepaths_and_text
11 | from text import text_to_sequence, cleaned_text_to_sequence
12 |
13 |
14 | class TextAudioLoader(torch.utils.data.Dataset):
15 | """
16 | 1) loads audio, text pairs
17 | 2) normalizes text and converts them to sequences of integers
18 | 3) computes spectrograms from audio files.
19 | """
20 | def __init__(self, audiopaths_and_text, hparams):
21 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
22 | self.text_cleaners = hparams.text_cleaners
23 | self.max_wav_value = hparams.max_wav_value
24 | self.sampling_rate = hparams.sampling_rate
25 | self.filter_length = hparams.filter_length
26 | self.hop_length = hparams.hop_length
27 | self.win_length = hparams.win_length
28 | self.sampling_rate = hparams.sampling_rate
29 |
30 | self.cleaned_text = getattr(hparams, "cleaned_text", False)
31 |
32 | self.add_blank = hparams.add_blank
33 | self.min_text_len = getattr(hparams, "min_text_len", 1)
34 | self.max_text_len = getattr(hparams, "max_text_len", 190)
35 |
36 | random.seed(1234)
37 | random.shuffle(self.audiopaths_and_text)
38 | self._filter()
39 |
40 |
41 | def _filter(self):
42 | """
43 | Filter text & store spec lengths
44 | """
45 | # Store spectrogram lengths for Bucketing
46 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
47 | # spec_length = wav_length // hop_length
48 |
49 | audiopaths_and_text_new = []
50 | lengths = []
51 | for audiopath, text in self.audiopaths_and_text:
52 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
53 | audiopaths_and_text_new.append([audiopath, text])
54 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
55 | self.audiopaths_and_text = audiopaths_and_text_new
56 | self.lengths = lengths
57 |
58 | def get_audio_text_pair(self, audiopath_and_text):
59 | # separate filename and text
60 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
61 | text = self.get_text(text)
62 | spec, wav = self.get_audio(audiopath)
63 | return (text, spec, wav)
64 |
65 | def get_audio(self, filename):
66 | audio, sampling_rate = load_wav_to_torch(filename)
67 | if sampling_rate != self.sampling_rate:
68 | raise ValueError("{} SR doesn't match target {} SR".format(
69 | sampling_rate, self.sampling_rate))
70 | audio_norm = audio / self.max_wav_value
71 | audio_norm = audio_norm.unsqueeze(0)
72 | spec_filename = filename.replace(".wav", ".spec.pt")
73 | if os.path.exists(spec_filename):
74 | spec = torch.load(spec_filename)
75 | else:
76 | spec = spectrogram_torch(audio_norm, self.filter_length,
77 | self.sampling_rate, self.hop_length, self.win_length,
78 | center=False)
79 | spec = torch.squeeze(spec, 0)
80 | torch.save(spec, spec_filename)
81 | return spec, audio_norm
82 |
83 | def get_text(self, text):
84 | if self.cleaned_text:
85 | text_norm = cleaned_text_to_sequence(text)
86 | else:
87 | text_norm = text_to_sequence(text, self.text_cleaners)
88 | if self.add_blank:
89 | text_norm = commons.intersperse(text_norm, 0)
90 | text_norm = torch.LongTensor(text_norm)
91 | return text_norm
92 |
93 | def __getitem__(self, index):
94 | return self.get_audio_text_pair(self.audiopaths_and_text[index])
95 |
96 | def __len__(self):
97 | return len(self.audiopaths_and_text)
98 |
99 |
100 | class TextAudioCollate():
101 | """ Zero-pads model inputs and targets
102 | """
103 | def __init__(self, return_ids=False):
104 | self.return_ids = return_ids
105 |
106 | def __call__(self, batch):
107 | """Collates a training batch from normalized text and audio
108 | PARAMS
109 | ------
110 | batch: [text_normalized, spec_normalized, wav_normalized]
111 | """
112 | # Sort batch indices by spectrogram length (descending); the loop below right zero-pads text, spec and wav to the batch maximum
113 | _, ids_sorted_decreasing = torch.sort(
114 | torch.LongTensor([x[1].size(1) for x in batch]),
115 | dim=0, descending=True)
116 |
117 | max_text_len = max([len(x[0]) for x in batch])
118 | max_spec_len = max([x[1].size(1) for x in batch])
119 | max_wav_len = max([x[2].size(1) for x in batch])
120 |
121 | text_lengths = torch.LongTensor(len(batch))
122 | spec_lengths = torch.LongTensor(len(batch))
123 | wav_lengths = torch.LongTensor(len(batch))
124 |
125 | text_padded = torch.LongTensor(len(batch), max_text_len)
126 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
127 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
128 | text_padded.zero_()
129 | spec_padded.zero_()
130 | wav_padded.zero_()
131 | for i in range(len(ids_sorted_decreasing)):
132 | row = batch[ids_sorted_decreasing[i]]
133 |
134 | text = row[0]
135 | text_padded[i, :text.size(0)] = text
136 | text_lengths[i] = text.size(0)
137 |
138 | spec = row[1]
139 | spec_padded[i, :, :spec.size(1)] = spec
140 | spec_lengths[i] = spec.size(1)
141 |
142 | wav = row[2]
143 | wav_padded[i, :, :wav.size(1)] = wav
144 | wav_lengths[i] = wav.size(1)
145 |
146 | if self.return_ids:
147 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, ids_sorted_decreasing
148 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths
149 |
150 |
151 | """Multi speaker version"""
152 | class TextAudioSpeakerLoader(torch.utils.data.Dataset):
153 | """
154 | 1) loads audio, speaker_id, text pairs
155 | 2) normalizes text and converts them to sequences of integers
156 | 3) computes spectrograms from audio files.
157 | """
158 | def __init__(self, audiopaths_sid_text, hparams):
159 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
160 | self.text_cleaners = hparams.text_cleaners
161 | self.max_wav_value = hparams.max_wav_value
162 | self.sampling_rate = hparams.sampling_rate
163 | self.filter_length = hparams.filter_length
164 | self.hop_length = hparams.hop_length
165 | self.win_length = hparams.win_length
166 | self.sampling_rate = hparams.sampling_rate
167 |
168 | self.cleaned_text = getattr(hparams, "cleaned_text", False)
169 |
170 | self.add_blank = hparams.add_blank
171 | self.min_text_len = getattr(hparams, "min_text_len", 1)
172 | self.max_text_len = getattr(hparams, "max_text_len", 190)
173 |
174 | random.seed(1234)
175 | random.shuffle(self.audiopaths_sid_text)
176 | self._filter()
177 |
178 | def _filter(self):
179 | """
180 | Filter text & store spec lengths
181 | """
182 | # Store spectrogram lengths for Bucketing
183 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
184 | # spec_length = wav_length // hop_length
185 |
186 | audiopaths_sid_text_new = []
187 | lengths = []
188 | for audiopath, sid, text in self.audiopaths_sid_text:
189 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
190 | audiopaths_sid_text_new.append([audiopath, sid, text])
191 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
192 | self.audiopaths_sid_text = audiopaths_sid_text_new
193 | self.lengths = lengths
194 |
195 | def get_audio_text_speaker_pair(self, audiopath_sid_text):
196 | # separate filename, speaker_id and text
197 | audiopath, sid, text = audiopath_sid_text[0], audiopath_sid_text[1], audiopath_sid_text[2]
198 | text = self.get_text(text)
199 | spec, wav = self.get_audio(audiopath)
200 | sid = self.get_sid(sid)
201 | return (text, spec, wav, sid)
202 |
203 | def get_audio(self, filename):
204 | audio, sampling_rate = load_wav_to_torch(filename)
205 | if sampling_rate != self.sampling_rate:
206 | raise ValueError("{} SR doesn't match target {} SR".format(
207 | sampling_rate, self.sampling_rate))
208 | audio_norm = audio / self.max_wav_value
209 | audio_norm = audio_norm.unsqueeze(0)
210 | spec_filename = filename.replace(".wav", ".spec.pt")
211 | if os.path.exists(spec_filename):
212 | spec = torch.load(spec_filename)
213 | else:
214 | spec = spectrogram_torch(audio_norm, self.filter_length,
215 | self.sampling_rate, self.hop_length, self.win_length,
216 | center=False)
217 | spec = torch.squeeze(spec, 0)
218 | torch.save(spec, spec_filename)
219 | return spec, audio_norm
220 |
221 | def get_text(self, text):
222 | if self.cleaned_text:
223 | text_norm = cleaned_text_to_sequence(text)
224 | else:
225 | text_norm = text_to_sequence(text, self.text_cleaners)
226 | if self.add_blank:
227 | text_norm = commons.intersperse(text_norm, 0)
228 | text_norm = torch.LongTensor(text_norm)
229 | return text_norm
230 |
231 | def get_sid(self, sid):
232 | sid = torch.LongTensor([int(sid)])
233 | return sid
234 |
235 | def __getitem__(self, index):
236 | return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
237 |
238 | def __len__(self):
239 | return len(self.audiopaths_sid_text)
240 |
241 |
242 | class TextAudioSpeakerCollate():
243 | """ Zero-pads model inputs and targets
244 | """
245 | def __init__(self, return_ids=False):
246 | self.return_ids = return_ids
247 |
248 | def __call__(self, batch):
249 | """Collates a training batch from normalized text, audio and speaker identities
250 | PARAMS
251 | ------
252 | batch: [text_normalized, spec_normalized, wav_normalized, sid]
253 | """
254 | # Sort batch indices by spectrogram length (descending); the loop below right zero-pads text, spec and wav to the batch maximum
255 | _, ids_sorted_decreasing = torch.sort(
256 | torch.LongTensor([x[1].size(1) for x in batch]),
257 | dim=0, descending=True)
258 |
259 | max_text_len = max([len(x[0]) for x in batch])
260 | max_spec_len = max([x[1].size(1) for x in batch])
261 | max_wav_len = max([x[2].size(1) for x in batch])
262 |
263 | text_lengths = torch.LongTensor(len(batch))
264 | spec_lengths = torch.LongTensor(len(batch))
265 | wav_lengths = torch.LongTensor(len(batch))
266 | sid = torch.LongTensor(len(batch))
267 |
268 | text_padded = torch.LongTensor(len(batch), max_text_len)
269 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
270 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
271 | text_padded.zero_()
272 | spec_padded.zero_()
273 | wav_padded.zero_()
274 | for i in range(len(ids_sorted_decreasing)):
275 | row = batch[ids_sorted_decreasing[i]]
276 |
277 | text = row[0]
278 | text_padded[i, :text.size(0)] = text
279 | text_lengths[i] = text.size(0)
280 |
281 | spec = row[1]
282 | spec_padded[i, :, :spec.size(1)] = spec
283 | spec_lengths[i] = spec.size(1)
284 |
285 | wav = row[2]
286 | wav_padded[i, :, :wav.size(1)] = wav
287 | wav_lengths[i] = wav.size(1)
288 |
289 | sid[i] = row[3]
290 |
291 | if self.return_ids:
292 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing
293 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid
294 |
295 |
296 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
297 | """
298 | Maintain similar input lengths in a batch.
299 | Length groups are specified by boundaries.
300 | Ex) boundaries = [b1, b2, b3] -> every batch is drawn entirely from either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
301 |
302 | Samples whose lengths fall outside the boundaries are discarded.
303 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
304 | """
305 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
306 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
307 | self.lengths = dataset.lengths
308 | self.batch_size = batch_size
309 | self.boundaries = boundaries
310 |
311 | self.buckets, self.num_samples_per_bucket = self._create_buckets()
312 | self.total_size = sum(self.num_samples_per_bucket)
313 | self.num_samples = self.total_size // self.num_replicas
314 |
315 | def _create_buckets(self):
316 | buckets = [[] for _ in range(len(self.boundaries) - 1)]
317 | for i in range(len(self.lengths)):
318 | length = self.lengths[i]
319 | idx_bucket = self._bisect(length)
320 | if idx_bucket != -1:
321 | buckets[idx_bucket].append(i)
322 |
323 | for i in range(len(buckets) - 1, -1, -1):
324 | if len(buckets[i]) == 0:
325 | buckets.pop(i)
326 | self.boundaries.pop(i+1)
327 |
328 | num_samples_per_bucket = []
329 | for i in range(len(buckets)):
330 | len_bucket = len(buckets[i])
331 | total_batch_size = self.num_replicas * self.batch_size
332 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
333 | num_samples_per_bucket.append(len_bucket + rem)
334 | return buckets, num_samples_per_bucket
335 |
336 | def __iter__(self):
337 | # deterministically shuffle based on epoch
338 | g = torch.Generator()
339 | g.manual_seed(self.epoch)
340 |
341 | indices = []
342 | if self.shuffle:
343 | for bucket in self.buckets:
344 | indices.append(torch.randperm(len(bucket), generator=g).tolist())
345 | else:
346 | for bucket in self.buckets:
347 | indices.append(list(range(len(bucket))))
348 |
349 | batches = []
350 | for i in range(len(self.buckets)):
351 | bucket = self.buckets[i]
352 | len_bucket = len(bucket)
353 | ids_bucket = indices[i]
354 | num_samples_bucket = self.num_samples_per_bucket[i]
355 |
356 | # add extra samples to make it evenly divisible
357 | rem = num_samples_bucket - len_bucket
358 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
359 |
360 | # subsample
361 | ids_bucket = ids_bucket[self.rank::self.num_replicas]
362 |
363 | # batching
364 | for j in range(len(ids_bucket) // self.batch_size):
365 | batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]]
366 | batches.append(batch)
367 |
368 | if self.shuffle:
369 | batch_ids = torch.randperm(len(batches), generator=g).tolist()
370 | batches = [batches[i] for i in batch_ids]
371 | self.batches = batches
372 |
373 | assert len(self.batches) * self.batch_size == self.num_samples
374 | return iter(self.batches)
375 |
376 | def _bisect(self, x, lo=0, hi=None):
377 | if hi is None:
378 | hi = len(self.boundaries) - 1
379 |
380 | if hi > lo:
381 | mid = (hi + lo) // 2
382 | if self.boundaries[mid] < x and x <= self.boundaries[mid+1]:
383 | return mid
384 | elif x <= self.boundaries[mid]:
385 | return self._bisect(x, lo, mid)
386 | else:
387 | return self._bisect(x, mid + 1, hi)
388 | else:
389 | return -1
390 |
391 | def __len__(self):
392 | return self.num_samples // self.batch_size
393 |
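The length estimate in `_filter` and the bucket arithmetic in `DistributedBucketSampler` can be checked in isolation. The following is a minimal sketch, not the repository's code: the helper names `estimate_spec_length`, `bucket_index` and `padded_bucket_size` are hypothetical, the example values are made up, and the standard-library `bisect.bisect_left` stands in for the recursive `_bisect` method while following the same `b_i < length <= b_{i+1}` rule stated in the docstring.

```python
import bisect


def estimate_spec_length(file_size_bytes, hop_length):
    """Approximate the number of spectrogram frames for 16-bit mono PCM.

    Two bytes per sample are assumed and the WAV header is ignored,
    matching the comment in `_filter`.
    """
    return file_size_bytes // (2 * hop_length)


def bucket_index(length, boundaries):
    """Return i such that boundaries[i] < length <= boundaries[i + 1], else -1."""
    idx = bisect.bisect_left(boundaries, length) - 1
    if idx < 0 or idx >= len(boundaries) - 1:
        return -1  # outside the boundaries: the sample would be discarded
    return idx


def padded_bucket_size(len_bucket, num_replicas, batch_size):
    """Round a bucket size up to a multiple of num_replicas * batch_size."""
    total_batch_size = num_replicas * batch_size
    rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
    return len_bucket + rem


# Hypothetical values chosen only for illustration.
boundaries = [32, 300, 400]
print(estimate_spec_length(1_048_576, hop_length=256))  # 2048
print([bucket_index(n, boundaries) for n in (32, 33, 300, 301, 400, 401)])
# [-1, 0, 0, 1, 1, -1]
print(padded_bucket_size(10, num_replicas=2, batch_size=4))  # 16
```

Padding each bucket to a multiple of `num_replicas * batch_size` is what lets every rank receive the same number of full batches, which the `assert` in `__iter__` relies on.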
--------------------------------------------------------------------------------
/filelists/ljs_audio_text_val_filelist.txt:
--------------------------------------------------------------------------------
1 | DUMMY1/LJ022-0023.wav|The overwhelming majority of people in this country know how to sift the wheat from the chaff in what they hear and what they read.
2 | DUMMY1/LJ043-0030.wav|If somebody did that to me, a lousy trick like that, to take my wife away, and all the furniture, I would be mad as hell, too.
3 | DUMMY1/LJ005-0201.wav|as is shown by the report of the Commissioners to inquire into the state of the municipal corporations in eighteen thirty-five.
4 | DUMMY1/LJ001-0110.wav|Even the Caslon type when enlarged shows great shortcomings in this respect:
5 | DUMMY1/LJ003-0345.wav|All the committee could do in this respect was to throw the responsibility on others.
6 | DUMMY1/LJ007-0154.wav|These pungent and well-grounded strictures applied with still greater force to the unconvicted prisoner, the man who came to the prison innocent, and still uncontaminated,
7 | DUMMY1/LJ018-0098.wav|and recognized as one of the frequenters of the bogus law-stationers. His arrest led to that of others.
8 | DUMMY1/LJ047-0044.wav|Oswald was, however, willing to discuss his contacts with Soviet authorities. He denied having any involvement with Soviet intelligence agencies
9 | DUMMY1/LJ031-0038.wav|The first physician to see the President at Parkland Hospital was Dr. Charles J. Carrico, a resident in general surgery.
10 | DUMMY1/LJ048-0194.wav|during the morning of November twenty-two prior to the motorcade.
11 | DUMMY1/LJ049-0026.wav|On occasion the Secret Service has been permitted to have an agent riding in the passenger compartment with the President.
12 | DUMMY1/LJ004-0152.wav|although at Mr. Buxton's visit a new jail was in process of erection, the first step towards reform since Howard's visitation in seventeen seventy-four.
13 | DUMMY1/LJ008-0278.wav|or theirs might be one of many, and it might be considered necessary to "make an example."
14 | DUMMY1/LJ043-0002.wav|The Warren Commission Report. By The President's Commission on the Assassination of President Kennedy. Chapter seven. Lee Harvey Oswald:
15 | DUMMY1/LJ009-0114.wav|Mr. Wakefield winds up his graphic but somewhat sensational account by describing another religious service, which may appropriately be inserted here.
16 | DUMMY1/LJ028-0506.wav|A modern artist would have difficulty in doing such accurate work.
17 | DUMMY1/LJ050-0168.wav|with the particular purposes of the agency involved. The Commission recognizes that this is a controversial area
18 | DUMMY1/LJ039-0223.wav|Oswald's Marine training in marksmanship, his other rifle experience and his established familiarity with this particular weapon
19 | DUMMY1/LJ029-0032.wav|According to O'Donnell, quote, we had a motorcade wherever we went, end quote.
20 | DUMMY1/LJ031-0070.wav|Dr. Clark, who most closely observed the head wound,
21 | DUMMY1/LJ034-0198.wav|Euins, who was on the southwest corner of Elm and Houston Streets testified that he could not describe the man he saw in the window.
22 | DUMMY1/LJ026-0068.wav|Energy enters the plant, to a small extent,
23 | DUMMY1/LJ039-0075.wav|once you know that you must put the crosshairs on the target and that is all that is necessary.
24 | DUMMY1/LJ004-0096.wav|the fatal consequences whereof might be prevented if the justices of the peace were duly authorized
25 | DUMMY1/LJ005-0014.wav|Speaking on a debate on prison matters, he declared that
26 | DUMMY1/LJ012-0161.wav|he was reported to have fallen away to a shadow.
27 | DUMMY1/LJ018-0239.wav|His disappearance gave color and substance to evil reports already in circulation that the will and conveyance above referred to
28 | DUMMY1/LJ019-0257.wav|Here the tread-wheel was in use, there cellular cranks, or hard-labor machines.
29 | DUMMY1/LJ028-0008.wav|you tap gently with your heel upon the shoulder of the dromedary to urge her on.
30 | DUMMY1/LJ024-0083.wav|This plan of mine is no attack on the Court;
31 | DUMMY1/LJ042-0129.wav|No night clubs or bowling alleys, no places of recreation except the trade union dances. I have had enough.
32 | DUMMY1/LJ036-0103.wav|The police asked him whether he could pick out his passenger from the lineup.
33 | DUMMY1/LJ046-0058.wav|During his Presidency, Franklin D. Roosevelt made almost four hundred journeys and traveled more than three hundred fifty thousand miles.
34 | DUMMY1/LJ014-0076.wav|He was seen afterwards smoking and talking with his hosts in their back parlor, and never seen again alive.
35 | DUMMY1/LJ002-0043.wav|long narrow rooms -- one thirty-six feet, six twenty-three feet, and the eighth eighteen,
36 | DUMMY1/LJ009-0076.wav|We come to the sermon.
37 | DUMMY1/LJ017-0131.wav|even when the high sheriff had told him there was no possibility of a reprieve, and within a few hours of execution.
38 | DUMMY1/LJ046-0184.wav|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes.
39 | DUMMY1/LJ014-0263.wav|When other pleasures palled he took a theatre, and posed as a munificent patron of the dramatic art.
40 | DUMMY1/LJ042-0096.wav|(old exchange rate) in addition to his factory salary of approximately equal amount
41 | DUMMY1/LJ049-0050.wav|Hill had both feet on the car and was climbing aboard to assist President and Mrs. Kennedy.
42 | DUMMY1/LJ019-0186.wav|seeing that since the establishment of the Central Criminal Court, Newgate received prisoners for trial from several counties,
43 | DUMMY1/LJ028-0307.wav|then let twenty days pass, and at the end of that time station near the Chaldasan gates a body of four thousand.
44 | DUMMY1/LJ012-0235.wav|While they were in a state of insensibility the murder was committed.
45 | DUMMY1/LJ034-0053.wav|reached the same conclusion as Latona that the prints found on the cartons were those of Lee Harvey Oswald.
46 | DUMMY1/LJ014-0030.wav|These were damnatory facts which well supported the prosecution.
47 | DUMMY1/LJ015-0203.wav|but were the precautions too minute, the vigilance too close to be eluded or overcome?
48 | DUMMY1/LJ028-0093.wav|but his scribe wrote it in the manner customary for the scribes of those days to write of their royal masters.
49 | DUMMY1/LJ002-0018.wav|The inadequacy of the jail was noticed and reported upon again and again by the grand juries of the city of London,
50 | DUMMY1/LJ028-0275.wav|At last, in the twentieth month,
51 | DUMMY1/LJ012-0042.wav|which he kept concealed in a hiding-place with a trap-door just under his bed.
52 | DUMMY1/LJ011-0096.wav|He married a lady also belonging to the Society of Friends, who brought him a large fortune, which, and his own money, he put into a city firm,
53 | DUMMY1/LJ036-0077.wav|Roger D. Craig, a deputy sheriff of Dallas County,
54 | DUMMY1/LJ016-0318.wav|Other officials, great lawyers, governors of prisons, and chaplains supported this view.
55 | DUMMY1/LJ013-0164.wav|who came from his room ready dressed, a suspicious circumstance, as he was always late in the morning.
56 | DUMMY1/LJ027-0141.wav|is closely reproduced in the life-history of existing deer. Or, in other words,
57 | DUMMY1/LJ028-0335.wav|accordingly they committed to him the command of their whole army, and put the keys of their city into his hands.
58 | DUMMY1/LJ031-0202.wav|Mrs. Kennedy chose the hospital in Bethesda for the autopsy because the President had served in the Navy.
59 | DUMMY1/LJ021-0145.wav|From those willing to join in establishing this hoped-for period of peace,
60 | DUMMY1/LJ016-0288.wav|"Müller, Müller, He's the man," till a diversion was created by the appearance of the gallows, which was received with continuous yells.
61 | DUMMY1/LJ028-0081.wav|Years later, when the archaeologists could readily distinguish the false from the true,
62 | DUMMY1/LJ018-0081.wav|his defense being that he had intended to commit suicide, but that, on the appearance of this officer who had wronged him,
63 | DUMMY1/LJ021-0066.wav|together with a great increase in the payrolls, there has come a substantial rise in the total of industrial profits
64 | DUMMY1/LJ009-0238.wav|After this the sheriffs sent for another rope, but the spectators interfered, and the man was carried back to jail.
65 | DUMMY1/LJ005-0079.wav|and improve the morals of the prisoners, and shall insure the proper measure of punishment to convicted offenders.
66 | DUMMY1/LJ035-0019.wav|drove to the northwest corner of Elm and Houston, and parked approximately ten feet from the traffic signal.
67 | DUMMY1/LJ036-0174.wav|This is the approximate time he entered the roominghouse, according to Earlene Roberts, the housekeeper there.
68 | DUMMY1/LJ046-0146.wav|The criteria in effect prior to November twenty-two, nineteen sixty-three, for determining whether to accept material for the PRS general files
69 | DUMMY1/LJ017-0044.wav|and the deepest anxiety was felt that the crime, if crime there had been, should be brought home to its perpetrator.
70 | DUMMY1/LJ017-0070.wav|but his sporting operations did not prosper, and he became a needy man, always driven to desperate straits for cash.
71 | DUMMY1/LJ014-0020.wav|He was soon afterwards arrested on suspicion, and a search of his lodgings brought to light several garments saturated with blood;
72 | DUMMY1/LJ016-0020.wav|He never reached the cistern, but fell back into the yard, injuring his legs severely.
73 | DUMMY1/LJ045-0230.wav|when he was finally apprehended in the Texas Theatre. Although it is not fully corroborated by others who were present,
74 | DUMMY1/LJ035-0129.wav|and she must have run down the stairs ahead of Oswald and would probably have seen or heard him.
75 | DUMMY1/LJ008-0307.wav|afterwards express a wish to murder the Recorder for having kept them so long in suspense.
76 | DUMMY1/LJ008-0294.wav|nearly indefinitely deferred.
77 | DUMMY1/LJ047-0148.wav|On October twenty-five,
78 | DUMMY1/LJ008-0111.wav|They entered a "stone cold room," and were presently joined by the prisoner.
79 | DUMMY1/LJ034-0042.wav|that he could only testify with certainty that the print was less than three days old.
80 | DUMMY1/LJ037-0234.wav|Mrs. Mary Brock, the wife of a mechanic who worked at the station, was there at the time and she saw a white male,
81 | DUMMY1/LJ040-0002.wav|Chapter seven. Lee Harvey Oswald: Background and Possible Motives, Part one.
82 | DUMMY1/LJ045-0140.wav|The arguments he used to justify his use of the alias suggest that Oswald may have come to think that the whole world was becoming involved
83 | DUMMY1/LJ012-0035.wav|the number and names on watches, were carefully removed or obliterated after the goods passed out of his hands.
84 | DUMMY1/LJ012-0250.wav|On the seventh July, eighteen thirty-seven,
85 | DUMMY1/LJ016-0179.wav|contracted with sheriffs and conveners to work by the job.
86 | DUMMY1/LJ016-0138.wav|at a distance from the prison.
87 | DUMMY1/LJ027-0052.wav|These principles of homology are essential to a correct interpretation of the facts of morphology.
88 | DUMMY1/LJ031-0134.wav|On one occasion Mrs. Johnson, accompanied by two Secret Service agents, left the room to see Mrs. Kennedy and Mrs. Connally.
89 | DUMMY1/LJ019-0273.wav|which Sir Joshua Jebb told the committee he considered the proper elements of penal discipline.
90 | DUMMY1/LJ014-0110.wav|At the first the boxes were impounded, opened, and found to contain many of O'Connor's effects.
91 | DUMMY1/LJ034-0160.wav|on Brennan's subsequent certain identification of Lee Harvey Oswald as the man he saw fire the rifle.
92 | DUMMY1/LJ038-0199.wav|eleven. If I am alive and taken prisoner,
93 | DUMMY1/LJ014-0010.wav|yet he could not overcome the strange fascination it had for him, and remained by the side of the corpse till the stretcher came.
94 | DUMMY1/LJ033-0047.wav|I noticed when I went out that the light was on, end quote,
95 | DUMMY1/LJ040-0027.wav|He was never satisfied with anything.
96 | DUMMY1/LJ048-0228.wav|and others who were present say that no agent was inebriated or acted improperly.
97 | DUMMY1/LJ003-0111.wav|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity.
98 | DUMMY1/LJ008-0258.wav|Let me retrace my steps, and speak more in detail of the treatment of the condemned in those bloodthirsty and brutally indifferent days,
99 | DUMMY1/LJ029-0022.wav|The original plan called for the President to spend only one day in the State, making whirlwind visits to Dallas, Fort Worth, San Antonio, and Houston.
100 | DUMMY1/LJ004-0045.wav|Mr. Sturges Bourne, Sir James Mackintosh, Sir James Scarlett, and William Wilberforce.
101 |
--------------------------------------------------------------------------------
/filelists/ljs_audio_text_val_filelist.txt.cleaned:
--------------------------------------------------------------------------------
1 | DUMMY1/LJ022-0023.wav|ðɪ ˌoʊvɚwˈɛlmɪŋ mədʒˈɔːɹɪɾi ʌv pˈiːpəl ɪn ðɪs kˈʌntɹi nˈoʊ hˌaʊ tə sˈɪft ðə wˈiːt fɹʌmðə tʃˈæf ɪn wˌʌt ðeɪ hˈɪɹ ænd wˌʌt ðeɪ ɹˈiːd.
2 | DUMMY1/LJ043-0030.wav|ɪf sˈʌmbɑːdi dˈɪd ðˈæt tə mˌiː, ɐ lˈaʊsi tɹˈɪk lˈaɪk ðˈæt, tə tˈeɪk maɪ wˈaɪf ɐwˈeɪ, ænd ˈɔːl ðə fˈɜːnɪtʃɚ, ˈaɪ wʊd biː mˈæd æz hˈɛl, tˈuː.
3 | DUMMY1/LJ005-0201.wav|ˌæzˌɪz ʃˈoʊn baɪ ðə ɹɪpˈoːɹt ʌvðə kəmˈɪʃənɚz tʊ ɪnkwˈaɪɚɹ ˌɪntʊ ðə stˈeɪt ʌvðə mjuːnˈɪsɪpəl kˌɔːɹpɚɹˈeɪʃənz ɪn eɪtˈiːn θˈɜːɾifˈaɪv.
4 | DUMMY1/LJ001-0110.wav|ˈiːvən ðə kˈæslɑːn tˈaɪp wɛn ɛnlˈɑːɹdʒd ʃˈoʊz ɡɹˈeɪt ʃˈɔːɹtkʌmɪŋz ɪn ðɪs ɹɪspˈɛkt:
5 | DUMMY1/LJ003-0345.wav|ˈɔːl ðə kəmˈɪɾi kʊd dˈuː ɪn ðɪs ɹɪspˈɛkt wʌz tə θɹˈoʊ ðə ɹɪspˌɑːnsəbˈɪlɪɾi ˌɑːn ˈʌðɚz.
6 | DUMMY1/LJ007-0154.wav|ðiːz pˈʌndʒənt ænd wˈɛlɡɹˈaʊndᵻd stɹˈɪktʃɚz ɐplˈaɪd wɪð stˈɪl ɡɹˈeɪɾɚ fˈoːɹs tə ðɪ ʌnkənvˈɪktᵻd pɹˈɪzənɚ, ðə mˈæn hˌuː kˈeɪm tə ðə pɹˈɪzən ˈɪnəsənt, ænd stˈɪl ʌnkəntˈæmᵻnˌeɪɾᵻd,
7 | DUMMY1/LJ018-0098.wav|ænd ɹˈɛkəɡnˌaɪzd æz wˈʌn ʌvðə fɹˈiːkwɛntɚz ʌvðə bˈoʊɡəs lˈɔːstˈeɪʃənɚz. hɪz ɐɹˈɛst lˈɛd tə ðæt ʌv ˈʌðɚz.
8 | DUMMY1/LJ047-0044.wav|ˈɑːswəld wʌz, haʊˈɛvɚ, wˈɪlɪŋ tə dɪskˈʌs hɪz kˈɑːntækts wɪð sˈoʊviət ɐθˈɔːɹɪɾiz. hiː dɪnˈaɪd hˌævɪŋ ˌɛni ɪnvˈɑːlvmənt wɪð sˈoʊviət ɪntˈɛlɪdʒəns ˈeɪdʒənsiz
9 | DUMMY1/LJ031-0038.wav|ðə fˈɜːst fɪzˈɪʃən tə sˈiː ðə pɹˈɛzɪdənt æt pˈɑːɹklənd hˈɑːspɪɾəl wʌz dˈɑːktɚ tʃˈɑːɹlz dʒˈeɪ. kˈæɹɪkˌoʊ, ɐ ɹˈɛzɪdənt ɪn dʒˈɛnɚɹəl sˈɜːdʒɚɹi.
10 | DUMMY1/LJ048-0194.wav|dˈʊɹɪŋ ðə mˈɔːɹnɪŋ ʌv noʊvˈɛmbɚ twˈɛntitˈuː pɹˈaɪɚ tə ðə mˈoʊɾɚkˌeɪd.
11 | DUMMY1/LJ049-0026.wav|ˌɑːn əkˈeɪʒən ðə sˈiːkɹət sˈɜːvɪs hɐzbɪn pɚmˈɪɾᵻd tə hæv ɐn ˈeɪdʒənt ɹˈaɪdɪŋ ɪnðə pˈæsɪndʒɚ kəmpˈɑːɹtmənt wɪððə pɹˈɛzɪdənt.
12 | DUMMY1/LJ004-0152.wav|ɑːlðˈoʊ æt mˈɪstɚ bˈʌkstənz vˈɪzɪt ɐ nˈuː dʒˈeɪl wʌz ɪn pɹˈɑːsɛs ʌv ɪɹˈɛkʃən, ðə fˈɜːst stˈɛp tʊwˈɔːɹdz ɹɪfˈɔːɹm sˈɪns hˈaʊɚdz vˌɪzɪtˈeɪʃən ɪn sˌɛvəntˈiːn sˈɛvəntifˈoːɹ.
13 | DUMMY1/LJ008-0278.wav|ɔːɹ ðˈɛɹz mˌaɪt biː wˈʌn ʌv mˈɛni, ænd ɪt mˌaɪt biː kənsˈɪdɚd nˈɛsəsɚɹi tuː "mˌeɪk ɐn ɛɡzˈæmpəl."
14 | DUMMY1/LJ043-0002.wav|ðə wˈɔːɹən kəmˈɪʃən ɹɪpˈoːɹt. baɪ ðə pɹˈɛzɪdənts kəmˈɪʃən ɑːnðɪ ɐsˌæsᵻnˈeɪʃən ʌv pɹˈɛzɪdənt kˈɛnədi. tʃˈæptɚ sˈɛvən. lˈiː hˈɑːɹvi ˈɑːswəld:
15 | DUMMY1/LJ009-0114.wav|mˈɪstɚ wˈeɪkfiːld wˈaɪndz ˈʌp hɪz ɡɹˈæfɪk bˌʌt sˈʌmwʌt sɛnsˈeɪʃənəl ɐkˈaʊnt baɪ dɪskɹˈaɪbɪŋ ɐnˈʌðɚ ɹɪlˈɪdʒəs sˈɜːvɪs, wˌɪtʃ mˈeɪ ɐpɹˈoʊpɹɪətli biː ɪnsˈɜːɾᵻd hˈɪɹ.
16 | DUMMY1/LJ028-0506.wav|ɐ mˈɑːdɚn ˈɑːɹɾɪst wʊdhɐv dˈɪfɪkˌʌlti ɪn dˌuːɪŋ sˈʌtʃ ˈækjʊɹət wˈɜːk.
17 | DUMMY1/LJ050-0168.wav|wɪððə pɚtˈɪkjʊlɚ pˈɜːpəsᵻz ʌvðɪ ˈeɪdʒənsi ɪnvˈɑːlvd. ðə kəmˈɪʃən ɹˈɛkəɡnˌaɪzɪz ðæt ðɪs ɪz ɐ kˌɑːntɹəvˈɜːʃəl ˈɛɹiə
18 | DUMMY1/LJ039-0223.wav|ˈɑːswəldz mɚɹˈiːn tɹˈeɪnɪŋ ɪn mˈɑːɹksmənʃˌɪp, hɪz ˈʌðɚ ɹˈaɪfəl ɛkspˈiəɹɪəns ænd hɪz ɪstˈæblɪʃt fəmˌɪlɪˈæɹɪɾi wɪð ðɪs pɚtˈɪkjʊlɚ wˈɛpən
19 | DUMMY1/LJ029-0032.wav|ɐkˈoːɹdɪŋ tʊ oʊdˈɑːnəl, kwˈoʊt, wiː hɐd ɐ mˈoʊɾɚkˌeɪd wɛɹɹˈɛvɚ wiː wˈɛnt, ˈɛnd kwˈoʊt.
20 | DUMMY1/LJ031-0070.wav|dˈɑːktɚ klˈɑːɹk, hˌuː mˈoʊst klˈoʊsli ɑːbzˈɜːvd ðə hˈɛd wˈuːnd,
21 | DUMMY1/LJ034-0198.wav|jˈuːɪnz, hˌuː wʌz ɑːnðə saʊθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən stɹˈiːts tˈɛstɪfˌaɪd ðæt hiː kʊd nˌɑːt dɪskɹˈaɪb ðə mˈæn hiː sˈɔː ɪnðə wˈɪndoʊ.
22 | DUMMY1/LJ026-0068.wav|ˈɛnɚdʒi ˈɛntɚz ðə plˈænt, tʊ ɐ smˈɔːl ɛkstˈɛnt,
23 | DUMMY1/LJ039-0075.wav|wˈʌns juː nˈoʊ ðæt juː mˈʌst pˌʊt ðə kɹˈɔshɛɹz ɑːnðə tˈɑːɹɡɪt ænd ðæt ɪz ˈɔːl ðæt ɪz nˈɛsəsɚɹi.
24 | DUMMY1/LJ004-0096.wav|ðə fˈeɪɾəl kˈɑːnsɪkwənsᵻz wˈɛɹɑːf mˌaɪt biː pɹɪvˈɛntᵻd ɪf ðə dʒˈʌstɪsᵻz ʌvðə pˈiːs wɜː djˈuːli ˈɔːθɚɹˌaɪzd
25 | DUMMY1/LJ005-0014.wav|spˈiːkɪŋ ˌɑːn ɐ dɪbˈeɪt ˌɑːn pɹˈɪzən mˈæɾɚz, hiː dᵻklˈɛɹd ðˈæt
26 | DUMMY1/LJ012-0161.wav|hiː wʌz ɹɪpˈoːɹɾᵻd tə hæv fˈɔːlən ɐwˈeɪ tʊ ɐ ʃˈædoʊ.
27 | DUMMY1/LJ018-0239.wav|hɪz dˌɪsɐpˈɪɹəns ɡˈeɪv kˈʌlɚ ænd sˈʌbstəns tʊ ˈiːvəl ɹɪpˈoːɹts ɔːlɹˌɛdi ɪn sˌɜːkjʊlˈeɪʃən ðætðə wɪl ænd kənvˈeɪəns əbˌʌv ɹɪfˈɜːd tuː
28 | DUMMY1/LJ019-0257.wav|hˈɪɹ ðə tɹˈɛdwˈiːl wʌz ɪn jˈuːs, ðɛɹ sˈɛljʊlɚ kɹˈæŋks, ɔːɹ hˈɑːɹdlˈeɪbɚ məʃˈiːnz.
29 | DUMMY1/LJ028-0008.wav|juː tˈæp dʒˈɛntli wɪð jʊɹ hˈiːl əpˌɑːn ðə ʃˈoʊldɚɹ ʌvðə dɹˈoʊmdɚɹi tʊ ˈɜːdʒ hɜːɹ ˈɑːn.
30 | DUMMY1/LJ024-0083.wav|ðɪs plˈæn ʌv mˈaɪn ɪz nˈoʊ ɐtˈæk ɑːnðə kˈoːɹt;
31 | DUMMY1/LJ042-0129.wav|nˈoʊ nˈaɪt klˈʌbz ɔːɹ bˈoʊlɪŋ ˈælɪz, nˈoʊ plˈeɪsᵻz ʌv ɹˌɛkɹiːˈeɪʃən ɛksˈɛpt ðə tɹˈeɪd jˈuːniən dˈænsᵻz. ˈaɪ hæv hɐd ɪnˈʌf.
32 | DUMMY1/LJ036-0103.wav|ðə pəlˈiːs ˈæskt hˌɪm wˈɛðɚ hiː kʊd pˈɪk ˈaʊt hɪz pˈæsɪndʒɚ fɹʌmðə lˈaɪnʌp.
33 | DUMMY1/LJ046-0058.wav|dˈʊɹɪŋ hɪz pɹˈɛzɪdənsi, fɹˈæŋklɪn dˈiː. ɹˈoʊzəvˌɛlt mˌeɪd ˈɔːlmoʊst fˈoːɹ hˈʌndɹəd dʒˈɜːnɪz ænd tɹˈævəld mˈoːɹ ðɐn θɹˈiː hˈʌndɹəd fˈɪfti θˈaʊzənd mˈaɪlz.
34 | DUMMY1/LJ014-0076.wav|hiː wʌz sˈiːn ˈæftɚwɚdz smˈoʊkɪŋ ænd tˈɔːkɪŋ wɪð hɪz hˈoʊsts ɪn ðɛɹ bˈæk pˈɑːɹlɚ, ænd nˈɛvɚ sˈiːn ɐɡˈɛn ɐlˈaɪv.
35 | DUMMY1/LJ002-0043.wav|lˈɑːŋ nˈæɹoʊ ɹˈuːmz wˈʌn θˈɜːɾisˈɪks fˈiːt, sˈɪks twˈɛntiθɹˈiː fˈiːt, ænd ðɪ ˈeɪtθ eɪtˈiːn,
36 | DUMMY1/LJ009-0076.wav|wiː kˈʌm tə ðə sˈɜːmən.
37 | DUMMY1/LJ017-0131.wav|ˈiːvən wɛn ðə hˈaɪ ʃˈɛɹɪf hɐd tˈoʊld hˌɪm ðɛɹwˌʌz nˈoʊ pˌɑːsəbˈɪlɪɾi əvɚ ɹɪpɹˈiːv, ænd wɪðˌɪn ɐ fjˈuː ˈaɪʊɹz ʌv ˌɛksɪkjˈuːʃən.
38 | DUMMY1/LJ046-0184.wav|bˌʌt ðɛɹ ɪz ɐ sˈɪstəm fɚðɪ ɪmˈiːdɪət nˌoʊɾɪfɪkˈeɪʃən ʌvðə sˈiːkɹət sˈɜːvɪs baɪ ðə kənfˈaɪnɪŋ ˌɪnstɪtˈuːʃən wɛn ɐ sˈʌbdʒɛkt ɪz ɹɪlˈiːsd ɔːɹ ɛskˈeɪps.
39 | DUMMY1/LJ014-0263.wav|wˌɛn ˈʌðɚ plˈɛʒɚz pˈɔːld hiː tˈʊk ɐ θˈiəɾɚ, ænd pˈoʊzd æz ɐ mjuːnˈɪfɪsənt pˈeɪtɹən ʌvðə dɹəmˈæɾɪk ˈɑːɹt.
40 | DUMMY1/LJ042-0096.wav| ˈoʊld ɛkstʃˈeɪndʒ ɹˈeɪt ɪn ɐdˈɪʃən tə hɪz fˈæktɚɹi sˈælɚɹi ʌv ɐpɹˈɑːksɪmətli ˈiːkwəl ɐmˈaʊnt
41 | DUMMY1/LJ049-0050.wav|hˈɪl hɐd bˈoʊθ fˈiːt ɑːnðə kˈɑːɹ ænd wʌz klˈaɪmɪŋ ɐbˈoːɹd tʊ ɐsˈɪst pɹˈɛzɪdənt ænd mɪsˈɛs kˈɛnədi.
42 | DUMMY1/LJ019-0186.wav|sˈiːɪŋ ðæt sˈɪns ðɪ ɪstˈæblɪʃmənt ʌvðə sˈɛntɹəl kɹˈɪmɪnəl kˈoːɹt, nˈuːɡeɪt ɹɪsˈiːvd pɹˈɪzənɚz fɔːɹ tɹˈaɪəl fɹʌm sˈɛvɹəl kˈaʊntɪz,
43 | DUMMY1/LJ028-0307.wav|ðˈɛn lˈɛt twˈɛnti dˈeɪz pˈæs, ænd æt ðɪ ˈɛnd ʌv ðæt tˈaɪm stˈeɪʃən nˌɪɹ ðə tʃˈældæsən ɡˈeɪts ɐ bˈɑːdi ʌv fˈoːɹ θˈaʊzənd.
44 | DUMMY1/LJ012-0235.wav|wˌaɪl ðeɪ wɜːɹ ɪn ɐ stˈeɪt ʌv ɪnsˌɛnsəbˈɪlɪɾi ðə mˈɜːdɚ wʌz kəmˈɪɾᵻd.
45 | DUMMY1/LJ034-0053.wav|ɹˈiːtʃt ðə sˈeɪm kənklˈuːʒən æz lætˈoʊnə ðætðə pɹˈɪnts fˈaʊnd ɑːnðə kˈɑːɹtənz wɜː ðoʊz ʌv lˈiː hˈɑːɹvi ˈɑːswəld.
46 | DUMMY1/LJ014-0030.wav|ðiːz wɜː dˈæmnətˌoːɹi fˈækts wˌɪtʃ wˈɛl səpˈoːɹɾᵻd ðə pɹˌɑːsɪkjˈuːʃən.
47 | DUMMY1/LJ015-0203.wav|bˌʌt wɜː ðə pɹɪkˈɔːʃənz tˈuː mˈɪnɪt, ðə vˈɪdʒɪləns tˈuː klˈoʊs təbi ɪlˈuːdᵻd ɔːɹ ˌoʊvɚkˈʌm?
48 | DUMMY1/LJ028-0093.wav|bˌʌt hɪz skɹˈaɪb ɹˈoʊt ɪt ɪnðə mˈænɚ kˈʌstəmˌɛɹi fɚðə skɹˈaɪbz ʌv ðoʊz dˈeɪz tə ɹˈaɪt ʌv ðɛɹ ɹˈɔɪəl mˈæstɚz.
49 | DUMMY1/LJ002-0018.wav|ðɪ ɪnˈædɪkwəsi ʌvðə dʒˈeɪl wʌz nˈoʊɾɪsd ænd ɹɪpˈoːɹɾᵻd əpˌɑːn ɐɡˈɛn ænd ɐɡˈɛn baɪ ðə ɡɹˈænd dʒˈʊɹɪz ʌvðə sˈɪɾi ʌv lˈʌndən,
50 | DUMMY1/LJ028-0275.wav|æt lˈæst, ɪnðə twˈɛntiəθ mˈʌnθ,
51 | DUMMY1/LJ012-0042.wav|wˌɪtʃ hiː kˈɛpt kənsˈiːld ɪn ɐ hˈaɪdɪŋplˈeɪs wɪð ɐ tɹˈæpdˈoːɹ dʒˈʌst ˌʌndɚ hɪz bˈɛd.
52 | DUMMY1/LJ011-0096.wav|hiː mˈæɹɪd ɐ lˈeɪdi ˈɑːlsoʊ bɪlˈɑːŋɪŋ tə ðə səsˈaɪəɾi ʌv fɹˈɛndz, hˌuː bɹˈɔːt hˌɪm ɐ lˈɑːɹdʒ fˈɔːɹtʃən, wˈɪtʃ, ænd hɪz ˈoʊn mˈʌni, hiː pˌʊt ˌɪntʊ ɐ sˈɪɾi fˈɜːm,
53 | DUMMY1/LJ036-0077.wav|ɹˈɑːdʒɚ dˈiː. kɹˈeɪɡ, ɐ dˈɛpjuːɾi ʃˈɛɹɪf ʌv dˈæləs kˈaʊnti,
54 | DUMMY1/LJ016-0318.wav|ˈʌðɚɹ əfˈɪʃəlz, ɡɹˈeɪt lˈɔɪɚz, ɡˈʌvɚnɚz ʌv pɹˈɪzənz, ænd tʃˈæplɪnz səpˈoːɹɾᵻd ðɪs vjˈuː.
55 | DUMMY1/LJ013-0164.wav|hˌuː kˈeɪm fɹʌm hɪz ɹˈuːm ɹˈɛdi dɹˈɛst, ɐ səspˈɪʃəs sˈɜːkəmstˌæns, æz hiː wʌz ˈɔːlweɪz lˈeɪt ɪnðə mˈɔːɹnɪŋ.
56 | DUMMY1/LJ027-0141.wav|ɪz klˈoʊsli ɹɪpɹədˈuːst ɪnðə lˈaɪfhˈɪstɚɹi ʌv ɛɡzˈɪstɪŋ dˈɪɹ. ˈɔːɹ, ɪn ˈʌðɚ wˈɜːdz,
57 | DUMMY1/LJ028-0335.wav|ɐkˈoːɹdɪŋli ðeɪ kəmˈɪɾᵻd tə hˌɪm ðə kəmˈænd ʌv ðɛɹ hˈoʊl ˈɑːɹmi, ænd pˌʊt ðə kˈiːz ʌv ðɛɹ sˈɪɾi ˌɪntʊ hɪz hˈændz.
58 | DUMMY1/LJ031-0202.wav|mɪsˈɛs kˈɛnədi tʃˈoʊz ðə hˈɑːspɪɾəl ɪn bəθˈɛzdə fɚðɪ ˈɔːtɑːpsi bɪkˈʌz ðə pɹˈɛzɪdənt hɐd sˈɜːvd ɪnðə nˈeɪvi.
59 | DUMMY1/LJ021-0145.wav|fɹʌm ðoʊz wˈɪlɪŋ tə dʒˈɔɪn ɪn ɪstˈæblɪʃɪŋ ðɪs hˈoʊptfɔːɹ pˈiəɹɪəd ʌv pˈiːs,
60 | DUMMY1/LJ016-0288.wav|"mˈʌlɚ, mˈʌlɚ, hiːz ðə mˈæn," tˈɪl ɐ daɪvˈɜːʒən wʌz kɹiːˈeɪɾᵻd baɪ ðɪ ɐpˈɪɹəns ʌvðə ɡˈæloʊz, wˌɪtʃ wʌz ɹɪsˈiːvd wɪð kəntˈɪnjuːəs jˈɛlz.
61 | DUMMY1/LJ028-0081.wav|jˈɪɹz lˈeɪɾɚ, wˌɛn ðɪ ˌɑːɹkiːˈɑːlədʒˌɪsts kʊd ɹˈɛdɪli dɪstˈɪŋɡwɪʃ ðə fˈɑːls fɹʌmðə tɹˈuː,
62 | DUMMY1/LJ018-0081.wav|hɪz dɪfˈɛns bˌiːɪŋ ðæt hiː hɐd ɪntˈɛndᵻd tə kəmˈɪt sˈuːɪsˌaɪd, bˌʌt ðˈæt, ɑːnðɪ ɐpˈɪɹəns ʌv ðɪs ˈɑːfɪsɚ hˌuː hɐd ɹˈɔŋd hˌɪm,
63 | DUMMY1/LJ021-0066.wav|təɡˌɛðɚ wɪð ɐ ɡɹˈeɪt ˈɪnkɹiːs ɪnðə pˈeɪɹoʊlz, ðɛɹ hɐz kˈʌm ɐ səbstˈænʃəl ɹˈaɪz ɪnðə tˈoʊɾəl ʌv ɪndˈʌstɹɪəl pɹˈɑːfɪts
64 | DUMMY1/LJ009-0238.wav|ˈæftɚ ðɪs ðə ʃˈɛɹɪfs sˈɛnt fɔːɹ ɐnˈʌðɚ ɹˈoʊp, bˌʌt ðə spɛktˈeɪɾɚz ˌɪntəfˈɪɹd, ænd ðə mˈæn wʌz kˈæɹɪd bˈæk tə dʒˈeɪl.
65 | DUMMY1/LJ005-0079.wav|ænd ɪmpɹˈuːv ðə mˈɔːɹəlz ʌvðə pɹˈɪzənɚz, ænd ʃˌæl ɪnʃˈʊɹ ðə pɹˈɑːpɚ mˈɛʒɚɹ ʌv pˈʌnɪʃmənt tə kənvˈɪktᵻd əfˈɛndɚz.
66 | DUMMY1/LJ035-0019.wav|dɹˈoʊv tə ðə nɔːɹθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən, ænd pˈɑːɹkt ɐpɹˈɑːksɪmətli tˈɛn fˈiːt fɹʌmðə tɹˈæfɪk sˈɪɡnəl.
67 | DUMMY1/LJ036-0174.wav|ðɪs ɪz ðɪ ɐpɹˈɑːksɪmət tˈaɪm hiː ˈɛntɚd ðə ɹˈuːmɪŋhˌaʊs, ɐkˈoːɹdɪŋ tʊ ˈɜːliːn ɹˈɑːbɚts, ðə hˈaʊskiːpɚ ðˈɛɹ.
68 | DUMMY1/LJ046-0146.wav|ðə kɹaɪtˈiəɹɪə ɪn ɪfˈɛkt pɹˈaɪɚ tə noʊvˈɛmbɚ twˈɛntitˈuː, naɪntˈiːn sˈɪkstiθɹˈiː, fɔːɹ dɪtˈɜːmɪnɪŋ wˈɛðɚ tʊ ɐksˈɛpt mətˈiəɹɪəl fɚðə pˌiːˌɑːɹˈɛs dʒˈɛnɚɹəl fˈaɪlz
69 | DUMMY1/LJ017-0044.wav|ænd ðə dˈiːpəst æŋzˈaɪəɾi wʌz fˈɛlt ðætðə kɹˈaɪm, ɪf kɹˈaɪm ðˈɛɹ hɐdbɪn, ʃˌʊd biː bɹˈɔːt hˈoʊm tʊ ɪts pˈɜːpɪtɹˌeɪɾɚ.
70 | DUMMY1/LJ017-0070.wav|bˌʌt hɪz spˈoːɹɾɪŋ ˌɑːpɚɹˈeɪʃənz dɪdnˌɑːt pɹˈɑːspɚ, ænd hiː bɪkˌeɪm ɐ nˈiːdi mˈæn, ˈɔːlweɪz dɹˈɪvən tə dˈɛspɚɹət stɹˈeɪts fɔːɹ kˈæʃ.
71 | DUMMY1/LJ014-0020.wav|hiː wʌz sˈuːn ˈæftɚwɚdz ɐɹˈɛstᵻd ˌɑːn səspˈɪʃən, ænd ɐ sˈɜːtʃ ʌv hɪz lˈɑːdʒɪŋz bɹˈɔːt tə lˈaɪt sˈɛvɹəl ɡˈɑːɹmənts sˈætʃɚɹˌeɪɾᵻd wɪð blˈʌd;
72 | DUMMY1/LJ016-0020.wav|hiː nˈɛvɚ ɹˈiːtʃt ðə sˈɪstɚn, bˌʌt fˈɛl bˈæk ˌɪntʊ ðə jˈɑːɹd, ˈɪndʒɚɹɪŋ hɪz lˈɛɡz sɪvˈɪɹli.
73 | DUMMY1/LJ045-0230.wav|wˌɛn hiː wʌz fˈaɪnəli ˌæpɹɪhˈɛndᵻd ɪnðə tˈɛksəs θˈiəɾɚ. ɑːlðˈoʊ ɪt ɪz nˌɑːt fˈʊli kɚɹˈɑːbɚɹˌeɪɾᵻd baɪ ˈʌðɚz hˌuː wɜː pɹˈɛzənt,
74 | DUMMY1/LJ035-0129.wav|ænd ʃiː mˈʌstɐv ɹˈʌn dˌaʊn ðə stˈɛɹz ɐhˈɛd ʌv ˈɑːswəld ænd wʊd pɹˈɑːbəbli hæv sˈiːn ɔːɹ hˈɜːd hˌɪm.
75 | DUMMY1/LJ008-0307.wav|ˈæftɚwɚdz ɛkspɹˈɛs ɐ wˈɪʃ tə mˈɜːdɚ ðə ɹɪkˈoːɹdɚ fɔːɹ hˌævɪŋ kˈɛpt ðˌɛm sˌoʊ lˈɑːŋ ɪn səspˈɛns.
76 | DUMMY1/LJ008-0294.wav|nˌɪɹli ɪndˈɛfɪnətli dɪfˈɜːd.
77 | DUMMY1/LJ047-0148.wav|ˌɑːn ɑːktˈoʊbɚ twˈɛntifˈaɪv,
78 | DUMMY1/LJ008-0111.wav|ðeɪ ˈɛntɚd ˈeɪ "stˈoʊn kˈoʊld ɹˈuːm," ænd wɜː pɹˈɛzəntli dʒˈɔɪnd baɪ ðə pɹˈɪzənɚ.
79 | DUMMY1/LJ034-0042.wav|ðæt hiː kʊd ˈoʊnli tˈɛstɪfˌaɪ wɪð sˈɜːtənti ðætðə pɹˈɪnt wʌz lˈɛs ðɐn θɹˈiː dˈeɪz ˈoʊld.
80 | DUMMY1/LJ037-0234.wav|mɪsˈɛs mˈɛɹi bɹˈɑːk, ðə wˈaɪf əvə mɪkˈænɪk hˌuː wˈɜːkt æt ðə stˈeɪʃən, wʌz ðɛɹ æt ðə tˈaɪm ænd ʃiː sˈɔː ɐ wˈaɪt mˈeɪl,
81 | DUMMY1/LJ040-0002.wav|tʃˈæptɚ sˈɛvən. lˈiː hˈɑːɹvi ˈɑːswəld: bˈækɡɹaʊnd ænd pˈɑːsəbəl mˈoʊɾɪvz, pˈɑːɹt wˌʌn.
82 | DUMMY1/LJ045-0140.wav|ðɪ ˈɑːɹɡjuːmənts hiː jˈuːzd tə dʒˈʌstɪfˌaɪ hɪz jˈuːs ʌvðɪ ˈeɪliəs sədʒˈɛst ðæt ˈɑːswəld mˌeɪhɐv kˈʌm tə θˈɪŋk ðætðə hˈoʊl wˈɜːld wʌz bɪkˈʌmɪŋ ɪnvˈɑːlvd
83 | DUMMY1/LJ012-0035.wav|ðə nˈʌmbɚ ænd nˈeɪmz ˌɑːn wˈɑːtʃᵻz, wɜː kˈɛɹfəli ɹɪmˈuːvd ɔːɹ əblˈɪɾɚɹˌeɪɾᵻd ˈæftɚ ðə ɡˈʊdz pˈæst ˌaʊɾəv hɪz hˈændz.
84 | DUMMY1/LJ012-0250.wav|ɑːnðə sˈɛvənθ dʒuːlˈaɪ, eɪtˈiːn θˈɜːɾisˈɛvən,
85 | DUMMY1/LJ016-0179.wav|kəntɹˈæktᵻd wɪð ʃˈɛɹɪfs ænd kənvˈɛnɚz tə wˈɜːk baɪ ðə dʒˈɑːb.
86 | DUMMY1/LJ016-0138.wav|æɾə dˈɪstəns fɹʌmðə pɹˈɪzən.
87 | DUMMY1/LJ027-0052.wav|ðiːz pɹˈɪnsɪpəlz ʌv həmˈɑːlədʒi ɑːɹ ɪsˈɛnʃəl tʊ ɐ kɚɹˈɛkt ɪntˌɜːpɹɪtˈeɪʃən ʌvðə fˈækts ʌv mɔːɹfˈɑːlədʒi.
88 | DUMMY1/LJ031-0134.wav|ˌɑːn wˈʌn əkˈeɪʒən mɪsˈɛs dʒˈɑːnsən, ɐkˈʌmpənɪd baɪ tˈuː sˈiːkɹət sˈɜːvɪs ˈeɪdʒənts, lˈɛft ðə ɹˈuːm tə sˈiː mɪsˈɛs kˈɛnədi ænd mɪsˈɛs kənˈæli.
89 | DUMMY1/LJ019-0273.wav|wˌɪtʃ sˌɜː dʒˈɑːʃjuːə dʒˈɛb tˈoʊld ðə kəmˈɪɾi hiː kənsˈɪdɚd ðə pɹˈɑːpɚɹ ˈɛlɪmənts ʌv pˈiːnəl dˈɪsɪplˌɪn.
90 | DUMMY1/LJ014-0110.wav|æt ðə fˈɜːst ðə bˈɑːksᵻz wɜːɹ ɪmpˈaʊndᵻd, ˈoʊpənd, ænd fˈaʊnd tə kəntˈeɪn mˈɛnɪəv oʊkˈɑːnɚz ɪfˈɛkts.
91 | DUMMY1/LJ034-0160.wav|ˌɑːn bɹˈɛnənz sˈʌbsɪkwənt sˈɜːtən aɪdˈɛntɪfɪkˈeɪʃən ʌv lˈiː hˈɑːɹvi ˈɑːswəld æz ðə mˈæn hiː sˈɔː fˈaɪɚ ðə ɹˈaɪfəl.
92 | DUMMY1/LJ038-0199.wav|ɪlˈɛvən. ɪf ˈaɪ æm ɐlˈaɪv ænd tˈeɪkən pɹˈɪzənɚ,
93 | DUMMY1/LJ014-0010.wav|jˈɛt hiː kʊd nˌɑːt ˌoʊvɚkˈʌm ðə stɹˈeɪndʒ fˌæsᵻnˈeɪʃən ɪt hˈɐd fɔːɹ hˌɪm, ænd ɹɪmˈeɪnd baɪ ðə sˈaɪd ʌvðə kˈɔːɹps tˈɪl ðə stɹˈɛtʃɚ kˈeɪm.
94 | DUMMY1/LJ033-0047.wav|ˈaɪ nˈoʊɾɪsd wɛn ˈaɪ wɛnt ˈaʊt ðætðə lˈaɪt wʌz ˈɑːn, ˈɛnd kwˈoʊt,
95 | DUMMY1/LJ040-0027.wav|hiː wʌz nˈɛvɚ sˈæɾɪsfˌaɪd wɪð ˈɛnɪθˌɪŋ.
96 | DUMMY1/LJ048-0228.wav|ænd ˈʌðɚz hˌuː wɜː pɹˈɛzənt sˈeɪ ðæt nˈoʊ ˈeɪdʒənt wʌz ɪnˈiːbɹɪˌeɪɾᵻd ɔːɹ ˈæktᵻd ɪmpɹˈɑːpɚli.
97 | DUMMY1/LJ003-0111.wav|hiː wʌz ɪn kˈɑːnsɪkwəns pˌʊt ˌaʊɾəv ðə pɹətˈɛkʃən ʌv ðɛɹ ɪntˈɜːnəl lˈɔː, ˈɛnd kwˈoʊt. ðɛɹ kˈoʊd wʌzɐ sˈʌbdʒɛkt ʌv sˌʌm kjˌʊɹɪˈɑːsɪɾi.
98 | DUMMY1/LJ008-0258.wav|lˈɛt mˌiː ɹɪtɹˈeɪs maɪ stˈɛps, ænd spˈiːk mˈoːɹ ɪn diːtˈeɪl ʌvðə tɹˈiːtmənt ʌvðə kəndˈɛmd ɪn ðoʊz blˈʌdθɜːsti ænd bɹˈuːɾəli ɪndˈɪfɹənt dˈeɪz,
99 | DUMMY1/LJ029-0022.wav|ðɪ ɚɹˈɪdʒɪnəl plˈæn kˈɔːld fɚðə pɹˈɛzɪdənt tə spˈɛnd ˈoʊnli wˈʌn dˈeɪ ɪnðə stˈeɪt, mˌeɪkɪŋ wˈɜːlwɪnd vˈɪzɪts tə dˈæləs, fˈɔːɹt wˈɜːθ, sˌæn æntˈoʊnɪˌoʊ, ænd hjˈuːstən.
100 | DUMMY1/LJ004-0045.wav|mˈɪstɚ stˈɜːdʒᵻz bˈoːɹn, sˌɜː dʒˈeɪmz mˈækɪntˌɑːʃ, sˌɜː dʒˈeɪmz skˈɑːɹlɪt, ænd wˈɪljəm wˈɪlbɚfˌoːɹs.
101 |
--------------------------------------------------------------------------------
/filelists/vctk_audio_sid_text_test_filelist.txt:
--------------------------------------------------------------------------------
1 | DUMMY2/p229/p229_128.wav|67|The whole process is a vicious circle at the moment.
2 | DUMMY2/p234/p234_112.wav|3|That would be a serious problem.
3 | DUMMY2/p298/p298_125.wav|68|I asked why he had come.
4 | DUMMY2/p283/p283_318.wav|95|If not, he should go home.
5 | DUMMY2/p260/p260_046.wav|81|It is marvellous.
6 | DUMMY2/p281/p281_306.wav|36|These figures are truly awful.
7 | DUMMY2/p285/p285_247.wav|2|Now, suddenly, we have this new landscape.
8 | DUMMY2/p237/p237_180.wav|61|A helpline number is published at the end of this article.
9 | DUMMY2/p259/p259_052.wav|7|Maybe full-time referees will provide the answer.)
10 | DUMMY2/p314/p314_053.wav|51|Rangers deserved to beat us.
11 | DUMMY2/p345/p345_070.wav|82|I haven't made any definite decisions.
12 | DUMMY2/p269/p269_132.wav|94|Who will attend?
13 | DUMMY2/p347/p347_295.wav|46|It is typical of me.
14 | DUMMY2/p251/p251_223.wav|9|For the refugees, the return will not come a moment too soon.
15 | DUMMY2/p300/p300_224.wav|102|There is nothing like this back home.
16 | DUMMY2/p276/p276_076.wav|106|He confirmed that the document was valid.
17 | DUMMY2/p294/p294_271.wav|104|I also thought, This is a feature film.
18 | DUMMY2/p259/p259_257.wav|7|The amount of alcohol as a whole was very high.)
19 | DUMMY2/p248/p248_131.wav|99|The whole thing of doing the movie was a risk.
20 | DUMMY2/p334/p334_023.wav|38|If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow.
21 | DUMMY2/p345/p345_386.wav|82|It is quite simple.
22 | DUMMY2/p330/p330_382.wav|1|Neither was involved in violence.
23 | DUMMY2/p246/p246_133.wav|5|My daughter is an adult.
24 | DUMMY2/p257/p257_140.wav|105|It's not true.
25 | DUMMY2/p340/p340_011.wav|74|When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow.
26 | DUMMY2/p284/p284_409.wav|16|Then , he laughs.
27 | DUMMY2/p317/p317_129.wav|97|You would be wrong.
28 | DUMMY2/p279/p279_183.wav|25|Government will intervene.
29 | DUMMY2/p376/p376_273.wav|71|"If not, he should go home."
30 | DUMMY2/p233/p233_109.wav|84|It is not affected by the sale.
31 | DUMMY2/p234/p234_118.wav|3|When we looked at the company.
32 | DUMMY2/p336/p336_207.wav|98|The train was on time.
33 | DUMMY2/p227/p227_213.wav|29|What are you not good at ?
34 | DUMMY2/p347/p347_113.wav|46|They are all Arabs.
35 | DUMMY2/p317/p317_125.wav|97|Alan Milburn, the health secretary, refused to comment.
36 | DUMMY2/p341/p341_031.wav|66|I was left-handed, but it was just a matter of practice.
37 | DUMMY2/p244/p244_338.wav|78|This isn't a betrayal of public services, it's their renewal.
38 | DUMMY2/p250/p250_288.wav|24|Is it in the right place ?
39 | DUMMY2/p233/p233_156.wav|84|It opens the door to the Champions League.
40 | DUMMY2/p334/p334_118.wav|38|The sanctions are about collective punishment.
41 | DUMMY2/p258/p258_027.wav|26|People come into the Borders for the beauty of the background.
42 | DUMMY2/p341/p341_187.wav|66|His signature is his handwriting.
43 | DUMMY2/p258/p258_347.wav|26|The composer will conduct.
44 | DUMMY2/p262/p262_005.wav|45|She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.
45 | DUMMY2/p231/p231_174.wav|50|One season, they might do well.
46 | DUMMY2/p363/p363_285.wav|6|But he was far from alone.
47 | DUMMY2/p303/p303_113.wav|44|Winning, meanwhile, is headed back to New York City.
48 | DUMMY2/p274/p274_181.wav|32|Is it in the right place ?)
49 | DUMMY2/p297/p297_023.wav|42|If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow.
50 | DUMMY2/p247/p247_065.wav|14|We will pay their bills.)
51 | DUMMY2/p273/p273_105.wav|56|The pressure is on them.
52 | DUMMY2/p245/p245_167.wav|59|It was an odd affair, in many respects.
53 | DUMMY2/p364/p364_239.wav|88|It was a long time coming.
54 | DUMMY2/p263/p263_047.wav|39|The Yugoslav president said he did not recognise the election outcome.
55 | DUMMY2/p283/p283_333.wav|95|No final decision has been taken.
56 | DUMMY2/p335/p335_313.wav|49|The issues are very intense.
57 | DUMMY2/p280/p280_172.wav|52|He said some things which were better left alone.
58 | DUMMY2/p266/p266_006.wav|20|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.
59 | DUMMY2/p260/p260_027.wav|81|Is this accurate?
60 | DUMMY2/p326/p326_214.wav|28|It's not long enough.
61 | DUMMY2/p259/p259_253.wav|7|You are like an animal.)
62 | DUMMY2/p228/p228_109.wav|57|However, the intensive care unit at the Southern General Hospital was full.
63 | DUMMY2/p376/p376_228.wav|71|"Half of young people had had contact with the police."
64 | DUMMY2/p361/p361_057.wav|79|That was something else.
65 | DUMMY2/p341/p341_058.wav|66|Labour accused the Tory leader of panicking.
66 | DUMMY2/p363/p363_247.wav|6|We are taking no chances this time.
67 | DUMMY2/p262/p262_054.wav|45|Already, he has been a tremendous influence in the dressing room.
68 | DUMMY2/p238/p238_090.wav|37|He thought she was amazing.
69 | DUMMY2/p306/p306_020.wav|12|Many complicated ideas about the rainbow have been formed.
70 | DUMMY2/p238/p238_339.wav|37|Harry Potter has lost his magic.
71 | DUMMY2/p302/p302_285.wav|30|Others said they had been beaten by police.
72 | DUMMY2/p275/p275_377.wav|40|Family liaison officers are now working to support the family.
73 | DUMMY2/p267/p267_286.wav|0|And they were being paid ?
74 | DUMMY2/p243/p243_090.wav|53|Among them was Gary Robertson from Dundee.
75 | DUMMY2/p274/p274_213.wav|32|It's easy to be negative about these things.)
76 | DUMMY2/p286/p286_310.wav|63|But it has been an amazing experience.
77 | DUMMY2/p294/p294_293.wav|104|That case has still not been settled.
78 | DUMMY2/p273/p273_174.wav|56|Two years later, she was dead.
79 | DUMMY2/p231/p231_408.wav|50|I should think so, too.
80 | DUMMY2/p323/p323_084.wav|34|And, within itself, it is visionary.
81 | DUMMY2/p248/p248_025.wav|99|She is given a new deputy minister for transport and planning.
82 | DUMMY2/p288/p288_197.wav|47|They are in the euro.
83 | DUMMY2/p300/p300_029.wav|102|Of course, this is nice to hear.
84 | DUMMY2/p299/p299_344.wav|58|He will never walk the streets again.
85 | DUMMY2/p376/p376_168.wav|71|"They will do their own thing."
86 | DUMMY2/p275/p275_277.wav|40|He looked very sharp.
87 | DUMMY2/p312/p312_022.wav|62|The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.
88 | DUMMY2/p278/p278_093.wav|10|It was some time before she found out he was safe.
89 | DUMMY2/p302/p302_312.wav|30|However, there was no hope, and glory too, for Scotland.
90 | DUMMY2/p236/p236_368.wav|75|It was like a weekly wage.
91 | DUMMY2/p237/p237_056.wav|61|No-one has appeared in court in relation to her death.
92 | DUMMY2/p305/p305_162.wav|54|For starters, many of the Scotland team didn't turn up.
93 | DUMMY2/p275/p275_018.wav|40|Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.
94 | DUMMY2/p310/p310_039.wav|17|But one shouldn't go by that.
95 | DUMMY2/p299/p299_310.wav|58|Farmers have been an endangered species.
96 | DUMMY2/p259/p259_428.wav|7|In general terms, the proposals are very much in line with expectations.)
97 | DUMMY2/p339/p339_155.wav|18|It is just a matter of time.
98 | DUMMY2/p229/p229_347.wav|67|I've got no secret.
99 | DUMMY2/p256/p256_308.wav|90|Let's hope it's an investment in the future.
100 | DUMMY2/p360/p360_204.wav|60|It is dangerous and it is a lie.
101 | DUMMY2/p238/p238_208.wav|37|The refund is fully justified.
102 | DUMMY2/p341/p341_319.wav|66|This is very bad news.
103 | DUMMY2/p336/p336_399.wav|98|The pressure is enormous.
104 | DUMMY2/p229/p229_067.wav|67|That was the easy election.
105 | DUMMY2/p329/p329_159.wav|103|No-one, not even the Scottish Arts Council, was interested in her.
106 | DUMMY2/p258/p258_304.wav|26|The confidence is low, but it is a difficult thing to understand.
107 | DUMMY2/p312/p312_033.wav|62|Haven't been so lucky since.
108 | DUMMY2/p266/p266_093.wav|20|Two other men, including the taxi driver, were wounded in the attack.
109 | DUMMY2/p307/p307_396.wav|22|It was early morning.
110 | DUMMY2/p326/p326_039.wav|28|There is nothing like this back home.
111 | DUMMY2/p333/p333_009.wav|64|There is , according to legend, a boiling pot of gold at one end.
112 | DUMMY2/p295/p295_154.wav|92|I remember it clearly.
113 | DUMMY2/p297/p297_007.wav|42|The rainbow is a division of white light into many beautiful colors.
114 | DUMMY2/p233/p233_153.wav|84|It would be a last resort.
115 | DUMMY2/p244/p244_220.wav|78|This year, it will amount to a few hundred thousand pounds.
116 | DUMMY2/p267/p267_136.wav|0|It has become a way of life.
117 | DUMMY2/p311/p311_313.wav|4|The decision was left entirely to him.
118 | DUMMY2/p230/p230_113.wav|35|You can spend money on housing.
119 | DUMMY2/p318/p318_295.wav|19|We gave them the goal.
120 | DUMMY2/p236/p236_090.wav|75|After the match, do you ?
121 | DUMMY2/p364/p364_156.wav|88|Ferguson had done his homework.
122 | DUMMY2/p310/p310_260.wav|17|That is my role.
123 | DUMMY2/p323/p323_261.wav|34|They hoped to remain in the Edinburgh area.
124 | DUMMY2/p284/p284_393.wav|16|On the contrary, it was actually very funny.
125 | DUMMY2/p276/p276_460.wav|106|We will pay their bills.
126 | DUMMY2/p363/p363_273.wav|6|The plot is minimal.
127 | DUMMY2/p250/p250_039.wav|24|Costs have got to be controlled.
128 | DUMMY2/p317/p317_244.wav|97|This event allows us to emphasise the positive.
129 | DUMMY2/p280/p280_042.wav|52|He does not even trust his own members.
130 | DUMMY2/p227/p227_342.wav|29|Have a look at this lot.
131 | DUMMY2/p333/p333_255.wav|64|But the Foreign Secretary can cope.
132 | DUMMY2/p232/p232_103.wav|96|We recognise the important role of golf in attracting visitors.
133 | DUMMY2/p305/p305_138.wav|54|Another suggested the company should carry only pedestrians.
134 | DUMMY2/p248/p248_196.wav|99|It is not satisfied with the standard of fire safety provisions.
135 | DUMMY2/p230/p230_166.wav|35|We believe in the medium term.
136 | DUMMY2/p303/p303_275.wav|44|There are lots of these women in Finland.
137 | DUMMY2/p280/p280_208.wav|52|Gas production was also at record levels last year.
138 | DUMMY2/p330/p330_252.wav|1|I first met him last summer.
139 | DUMMY2/p330/p330_209.wav|1|I will not take you out of context.
140 | DUMMY2/p240/p240_214.wav|93|It had all been arranged.
141 | DUMMY2/p293/p293_185.wav|23|This is the stuff of live music.
142 | DUMMY2/p237/p237_230.wav|61|Clearly, the stakes are high.
143 | DUMMY2/p277/p277_014.wav|89|To the Hebrews it was a token that there would be no more universal floods.
144 | DUMMY2/p251/p251_107.wav|9|It had been played at festivals.
145 | DUMMY2/p302/p302_011.wav|30|When a man looks for something beyond his reach, his friends say he's looking for the pot of gold at the end of the rainbow.
146 | DUMMY2/p264/p264_147.wav|65|She's been shot.)
147 | DUMMY2/p236/p236_288.wav|75|Brown is an interesting man, but he is not desperate.
148 | DUMMY2/p323/p323_297.wav|34|I've got the shirt.
149 | DUMMY2/p297/p297_402.wav|42|I don't have a problem with getting older.
150 | DUMMY2/p267/p267_182.wav|0|A team is a team.
151 | DUMMY2/p226/p226_121.wav|43|Maybe this battle has been.
152 | DUMMY2/p311/p311_226.wav|4|Gone with them is any sense of narrative.
153 | DUMMY2/p335/p335_279.wav|49|Or rather he did and he didn't.
154 | DUMMY2/p270/p270_068.wav|8|Then followed a bout of flu.
155 | DUMMY2/p260/p260_072.wav|81|It was magic.
156 | DUMMY2/p362/p362_341.wav|15|The result could be all down to turnout.
157 | DUMMY2/p228/p228_180.wav|57|One season, they might do well.
158 | DUMMY2/p316/p316_152.wav|85|Failure is not an option.
159 | DUMMY2/p317/p317_423.wav|97|Manchester United are the classic example.
160 | DUMMY2/p243/p243_292.wav|53|Its work includes dealing with child abuse.
161 | DUMMY2/p362/p362_054.wav|15|We certainly hope we have been successful.
162 | DUMMY2/p243/p243_305.wav|53|What happened in that game ?
163 | DUMMY2/p364/p364_297.wav|88|It was just one man.
164 | DUMMY2/p255/p255_049.wav|31|We were surprised to see the photograph.
165 | DUMMY2/p297/p297_358.wav|42|He said he had no reports of casualties.
166 | DUMMY2/p283/p283_430.wav|95|My aim is a top six finish.
167 | DUMMY2/p310/p310_300.wav|17|Mike Tyson went to prison.
168 | DUMMY2/p363/p363_051.wav|6|The nation has his music.
169 | DUMMY2/p261/p261_112.wav|100|It became a national network.
170 | DUMMY2/p234/p234_036.wav|3|It was sold at a loss.
171 | DUMMY2/p247/p247_470.wav|14|They were good years for him.)
172 | DUMMY2/p303/p303_269.wav|44|After that nothing could save him.
173 | DUMMY2/p317/p317_256.wav|97|The man was pronounced dead on arrival.
174 | DUMMY2/p351/p351_161.wav|33|Paterson can afford to be generous.
175 | DUMMY2/p314/p314_295.wav|51|You take a risk.
176 | DUMMY2/p293/p293_268.wav|23|Our children are our future.
177 | DUMMY2/p306/p306_352.wav|12|Who has the second highest?
178 | DUMMY2/p273/p273_098.wav|56|The following are the principal provisions.
179 | DUMMY2/p285/p285_029.wav|2|Their courage, and their honesty, should be respected.
180 | DUMMY2/p266/p266_073.wav|20|It works for us.
181 | DUMMY2/p374/p374_288.wav|11|I had a good life at Rangers.
182 | DUMMY2/p280/p280_171.wav|52|You need a long-term strategy in football.
183 | DUMMY2/p239/p239_203.wav|48|It is all to do with the coaching.
184 | DUMMY2/p287/p287_292.wav|77|In essence, the teaching profession has a choice.
185 | DUMMY2/p330/p330_112.wav|1|Wallace was in at the deep end.
186 | DUMMY2/p247/p247_141.wav|14|They made such decisions in London.)
187 | DUMMY2/p277/p277_050.wav|89|This represents a tough game for us.
188 | DUMMY2/p233/p233_289.wav|84|He looked very sharp.
189 | DUMMY2/p284/p284_103.wav|16|Meanwhile, the Scottish Consumer Council yesterday offered support for the new Bill.
190 | DUMMY2/p334/p334_366.wav|38|We will miss him very much.
191 | DUMMY2/p238/p238_196.wav|37|Tiger is not the norm.
192 | DUMMY2/p304/p304_193.wav|72|Then they were awarded a penalty.
193 | DUMMY2/p229/p229_348.wav|67|Look at the witnesses.
194 | DUMMY2/p268/p268_147.wav|87|how do you get it back ?
195 | DUMMY2/p293/p293_348.wav|23|He quit in October.
196 | DUMMY2/p341/p341_082.wav|66|John Reid, the Northern Ireland secretary, yesterday appealed for restraint.
197 | DUMMY2/p258/p258_097.wav|26|It is a good lifestyle.
198 | DUMMY2/p340/p340_220.wav|74|Not so, it seems.
199 | DUMMY2/p269/p269_174.wav|94|Mark Fisher was a guest of the Northern Ireland Tourist Board.
200 | DUMMY2/p270/p270_078.wav|8|I've had it for the exams.
201 | DUMMY2/p334/p334_224.wav|38|I can't blame the fans.
202 | DUMMY2/p307/p307_306.wav|22|We're talking about creating an attractive neighbourhood.
203 | DUMMY2/p361/p361_205.wav|79|Translation - we got it wrong.
204 | DUMMY2/p229/p229_142.wav|67|What will happen then ?
205 | DUMMY2/p310/p310_221.wav|17|We will look into it.
206 | DUMMY2/p232/p232_357.wav|96|He had played well in that central role.
207 | DUMMY2/p263/p263_389.wav|39|This season has been a nightmare.
208 | DUMMY2/p283/p283_273.wav|95|Did he trip ?
209 | DUMMY2/p374/p374_277.wav|11|Where do you start?
210 | DUMMY2/p301/p301_289.wav|91|Children are using books in a terrible condition.
211 | DUMMY2/p345/p345_267.wav|82|Her presence was almost everywhere.
212 | DUMMY2/p264/p264_226.wav|65|No partners would lose their jobs.)
213 | DUMMY2/p253/p253_050.wav|70|A neighbour said.
214 | DUMMY2/p276/p276_118.wav|106|If they liked it then I'll be happy.
215 | DUMMY2/p295/p295_175.wav|92|Anything that can be done, the Government will do.
216 | DUMMY2/p247/p247_466.wav|14|I think it's a great system.)
217 | DUMMY2/p301/p301_182.wav|91|I see social work as a vocation, a commitment.
218 | DUMMY2/p294/p294_156.wav|104|We are well insured.
219 | DUMMY2/p287/p287_190.wav|77|We have to recognise that he is an elusive character.
220 | DUMMY2/p258/p258_333.wav|26|Robert is a special talent.
221 | DUMMY2/p275/p275_122.wav|40|Who would have?
222 | DUMMY2/p231/p231_259.wav|50|It was the climax of the thing.
223 | DUMMY2/p330/p330_073.wav|1|Over time, with patience and precision, the terrorists will be pursued.
224 | DUMMY2/p277/p277_239.wav|89|I should think so, too.
225 | DUMMY2/p374/p374_352.wav|11|As if they ever stopped.
226 | DUMMY2/p244/p244_258.wav|78|If it doesn't, it doesn't.
227 | DUMMY2/p277/p277_194.wav|89|I would think about the end of January, the beginning of February.
228 | DUMMY2/p241/p241_177.wav|86|The clarity is vital.
229 | DUMMY2/p247/p247_275.wav|14|What form did that take ?)
230 | DUMMY2/p230/p230_230.wav|35|That has been the easy part.
231 | DUMMY2/p323/p323_015.wav|34|The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.
232 | DUMMY2/p269/p269_365.wav|94|But the real problem is the closure of the export market.
233 | DUMMY2/p310/p310_049.wav|17|They had four children together.
234 | DUMMY2/p281/p281_068.wav|36|I have proved that in the past.
235 | DUMMY2/p343/p343_162.wav|21|Dancing was her life.
236 | DUMMY2/p299/p299_208.wav|58|I'm a bit annoyed.
237 | DUMMY2/p329/p329_292.wav|103|The methadone programme is completely out of control.
238 | DUMMY2/p232/p232_376.wav|96|He could make it.
239 | DUMMY2/p305/p305_135.wav|54|I COULD hardly keep up with Professor McKean.
240 | DUMMY2/p351/p351_231.wav|33|We are pursuing legal action against the government.
241 | DUMMY2/p265/p265_153.wav|73|Military action is the only option we have on the table today.
242 | DUMMY2/p323/p323_137.wav|34|Everything was a dead end.
243 | DUMMY2/p305/p305_176.wav|54|That has given me great confidence.
244 | DUMMY2/p238/p238_053.wav|37|Does it matter ?
245 | DUMMY2/p230/p230_195.wav|35|It is not all good news and relief for Labour, however.
246 | DUMMY2/p238/p238_093.wav|37|He seems to have everything.
247 | DUMMY2/p259/p259_323.wav|7|We feel very comfortable in this international environment.)
248 | DUMMY2/p285/p285_032.wav|2|This is the window.
249 | DUMMY2/p302/p302_208.wav|30|Which means it matters.
250 | DUMMY2/p231/p231_176.wav|50|This much I can tell you.
251 | DUMMY2/p301/p301_054.wav|91|Here he is, in effect, appointing himself a judge.
252 | DUMMY2/p310/p310_102.wav|17|We shall rely on human beings.
253 | DUMMY2/p305/p305_121.wav|54|It is a vicious circle.
254 | DUMMY2/p231/p231_458.wav|50|She has reached the top of her profession.
255 | DUMMY2/p311/p311_024.wav|4|This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.
256 | DUMMY2/p245/p245_248.wav|59|Is there on his hands?
257 | DUMMY2/p333/p333_311.wav|64|Councillor Gordon has refused to stand down.
258 | DUMMY2/p299/p299_007.wav|58|The rainbow is a division of white light into many beautiful colors.
259 | DUMMY2/p229/p229_333.wav|67|She did not attend the courtroom.
260 | DUMMY2/p307/p307_286.wav|22|The board would report to the Scottish Parliament.
261 | DUMMY2/p305/p305_414.wav|54|It is the Holiday programme with a mortgage.
262 | DUMMY2/p264/p264_140.wav|65|He nearly killed my son.)
263 | DUMMY2/p374/p374_114.wav|11|I don't think it would make any difference.
264 | DUMMY2/p363/p363_369.wav|6|We have been overwhelmed by the response.
265 | DUMMY2/p293/p293_374.wav|23|I don't think the referees are against us.
266 | DUMMY2/p316/p316_329.wav|85|It's a production company.
267 | DUMMY2/p236/p236_018.wav|75|Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.
268 | DUMMY2/p234/p234_332.wav|3|But it can be done.
269 | DUMMY2/p277/p277_132.wav|89|No production was achieved.
270 | DUMMY2/p326/p326_205.wav|28|As agreed, the prime minister was driven to Westminster Hall.
271 | DUMMY2/p272/p272_134.wav|69|This tour is critical for New Zealand rugby.
272 | DUMMY2/p316/p316_125.wav|85|And now the pressure is off.
273 | DUMMY2/p274/p274_149.wav|32|I prefer the clarity of the existing system.)
274 | DUMMY2/p227/p227_368.wav|29|A crucial moment has arrived.
275 | DUMMY2/p334/p334_206.wav|38|We'll have to work hard today.
276 | DUMMY2/p339/p339_087.wav|18|I am not completely insane.
277 | DUMMY2/p286/p286_453.wav|63|He was said to be emotionally disturbed.
278 | DUMMY2/p301/p301_110.wav|91|People want to see me on the screen.
279 | DUMMY2/p282/p282_188.wav|83|Suddenly, the rugby world had changed.
280 | DUMMY2/p263/p263_147.wav|39|Losing in that manner is very hard to take.
281 | DUMMY2/p256/p256_253.wav|90|I was never going to play against Scotland.
282 | DUMMY2/p374/p374_165.wav|11|Something has got to change.
283 | DUMMY2/p262/p262_232.wav|45|It's very safe.
284 | DUMMY2/p267/p267_417.wav|0|This is no reflection on Rangers.
285 | DUMMY2/p240/p240_078.wav|93|I've got the shirt.
286 | DUMMY2/p347/p347_143.wav|46|There is no sign of anyone being hurt.
287 | DUMMY2/p245/p245_069.wav|59|She died in hospital two hours later.
288 | DUMMY2/p233/p233_172.wav|84|They say that vital evidence was not heard in court.
289 | DUMMY2/p280/p280_282.wav|52|Overall, the last hole was good to the women.
290 | DUMMY2/p298/p298_364.wav|68|I'm looking at ways to do that now.
291 | DUMMY2/p339/p339_240.wav|18|The teacher would have approved.
292 | DUMMY2/p361/p361_387.wav|79|We have been going for three years.
293 | DUMMY2/p278/p278_221.wav|10|January is a bad time of year.
294 | DUMMY2/p334/p334_289.wav|38|They married in August last year.
295 | DUMMY2/p250/p250_187.wav|24|This championship is different from the other majors.
296 | DUMMY2/p248/p248_283.wav|99|Maloney is an engaging talent.
297 | DUMMY2/p275/p275_261.wav|40|It will be done in stages.
298 | DUMMY2/p288/p288_024.wav|47|This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.
299 | DUMMY2/p271/p271_454.wav|27|It's a miracle.
300 | DUMMY2/p252/p252_408.wav|55|They had to have hospital treatment.
301 | DUMMY2/p261/p261_192.wav|100|It was a pre-emptive strike.
302 | DUMMY2/p308/p308_099.wav|107|Glasgow deserved their win, but we made them look good.
303 | DUMMY2/p288/p288_070.wav|47|Neither it is.
304 | DUMMY2/p317/p317_356.wav|97|We're not an employment agency.
305 | DUMMY2/p351/p351_251.wav|33|That is a matter for the Scottish Parliament.
306 | DUMMY2/p329/p329_075.wav|103|This will be no easy option.
307 | DUMMY2/p261/p261_180.wav|100|I've been in two finals, and I've got a medal.
308 | DUMMY2/p301/p301_272.wav|91|Drink and petrol prices remain untouched.
309 | DUMMY2/p277/p277_404.wav|89|Whether the High Court will interfere with the sentence is another matter.
310 | DUMMY2/p301/p301_135.wav|91|How good is Lennox Lewis?
311 | DUMMY2/p246/p246_333.wav|5|Two people were interviewed.
312 | DUMMY2/p340/p340_250.wav|74|The film was great.
313 | DUMMY2/p268/p268_355.wav|87|They had declined in each of the two preceding quarters.
314 | DUMMY2/p236/p236_143.wav|75|The whole industry is a shambles.
315 | DUMMY2/p231/p231_398.wav|50|They have failed to deliver.
316 | DUMMY2/p340/p340_322.wav|74|I am extremely cautious.
317 | DUMMY2/p228/p228_048.wav|57|The Scottish Parliament is also looking at similar measures.
318 | DUMMY2/p334/p334_193.wav|38|It is in our own hands.
319 | DUMMY2/p226/p226_128.wav|43|I felt very strongly that England should have it.
320 | DUMMY2/p279/p279_064.wav|25|We have not given up hope.
321 | DUMMY2/p304/p304_416.wav|72|He took over our lives.
322 | DUMMY2/p313/p313_119.wav|76|O Neill is reputed to have replied.
323 | DUMMY2/p287/p287_195.wav|77|It comes from reflection or thinking.
324 | DUMMY2/p234/p234_008.wav|3|These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon.
325 | DUMMY2/p277/p277_119.wav|89|Naturally, it was not difficult to find support for these proposals.
326 | DUMMY2/p281/p281_394.wav|36|What are they for ?
327 | DUMMY2/p287/p287_272.wav|77|And they were being paid ?
328 | DUMMY2/p288/p288_071.wav|47|He seemed to lose his focus.
329 | DUMMY2/p335/p335_245.wav|49|A friendship that will endure.
330 | DUMMY2/p239/p239_061.wav|48|All manner of precaution and protection are taken.
331 | DUMMY2/p254/p254_003.wav|41|Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob.
332 | DUMMY2/p259/p259_282.wav|7|Washington is consumed by the crisis.)
333 | DUMMY2/p253/p253_202.wav|70|Sadly, it can't.
334 | DUMMY2/p318/p318_333.wav|19|But when we do it is great.
335 | DUMMY2/p351/p351_363.wav|33|You will never forget the clutching horror.
336 | DUMMY2/p241/p241_374.wav|86|There is no signature.
337 | DUMMY2/p272/p272_216.wav|69|The report is due out next month.
338 | DUMMY2/p330/p330_355.wav|1|I've got my own ideas.
339 | DUMMY2/p270/p270_179.wav|8|The outcome is now in our own hands.
340 | DUMMY2/p257/p257_079.wav|105|It is not long term, but I need time to recover.
341 | DUMMY2/p257/p257_027.wav|105|They should have a major rethink about the event for next year.
342 | DUMMY2/p279/p279_118.wav|25|Does this mean.
343 | DUMMY2/p334/p334_058.wav|38|Hopefully, it will be built by next year.
344 | DUMMY2/p363/p363_178.wav|6|That was a huge experience.
345 | DUMMY2/p376/p376_227.wav|71|"And thought we would get away with it."
346 | DUMMY2/p330/p330_411.wav|1|You know, he was struggling with his game all week.
347 | DUMMY2/p326/p326_316.wav|28|It certainly sounded it at times.
348 | DUMMY2/p323/p323_048.wav|34|Mackie was at home, unable to watch.
349 | DUMMY2/p313/p313_422.wav|76|Now, that is a good deal.
350 | DUMMY2/p364/p364_113.wav|88|It was just great.
351 | DUMMY2/p286/p286_414.wav|63|It's just a training thing.
352 | DUMMY2/p288/p288_229.wav|47|However, no further action was taken by police.
353 | DUMMY2/p259/p259_142.wav|7|What happened in that game ?)
354 | DUMMY2/p297/p297_118.wav|42|It's too big a risk to take.
355 | DUMMY2/p313/p313_209.wav|76|The night is young.
356 | DUMMY2/p303/p303_279.wav|44|I bought a car at auction.
357 | DUMMY2/p345/p345_166.wav|82|Miller was every bit as happy.
358 | DUMMY2/p333/p333_289.wav|64|It's going to be quite a challenge.
359 | DUMMY2/p336/p336_323.wav|98|One paper was not returned.
360 | DUMMY2/p271/p271_082.wav|27|He is in the queue.
361 | DUMMY2/p314/p314_175.wav|51|There is no substitute.
362 | DUMMY2/p248/p248_124.wav|99|I can't even get into the A team.
363 | DUMMY2/p297/p297_160.wav|42|Tax is a matter for national governments.
364 | DUMMY2/p236/p236_299.wav|75|how do you get it back ?
365 | DUMMY2/p248/p248_300.wav|99|It wasn't just the character and energy of the playing.
366 | DUMMY2/p231/p231_429.wav|50|He is on the wrong side.
367 | DUMMY2/p250/p250_368.wav|24|We put our bid in last night.
368 | DUMMY2/p376/p376_191.wav|71|"I am totally surprised."
369 | DUMMY2/p250/p250_419.wav|24|She started to put on weight.
370 | DUMMY2/p239/p239_037.wav|48|He works at the airport.
371 | DUMMY2/p340/p340_165.wav|74|He was very fit.
372 | DUMMY2/p339/p339_258.wav|18|There are not too many like him.
373 | DUMMY2/p326/p326_266.wav|28|It may also be her last.
374 | DUMMY2/p231/p231_472.wav|50|He felt it was the right time.
375 | DUMMY2/p261/p261_411.wav|100|I had a fortunate war.
376 | DUMMY2/p272/p272_359.wav|69|Now, though, he has an incentive.
377 | DUMMY2/p340/p340_015.wav|74|The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.
378 | DUMMY2/p283/p283_022.wav|95|The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.
379 | DUMMY2/p281/p281_334.wav|36|However, the groups denied the claims.
380 | DUMMY2/p318/p318_223.wav|19|We remain committed to it, as does the government.
381 | DUMMY2/p281/p281_039.wav|36|This film will be totally awesome.
382 | DUMMY2/p270/p270_013.wav|8|Some have accepted it as a miracle without physical explanation.
383 | DUMMY2/p243/p243_047.wav|53|However, there is an issue, isn't there ?
384 | DUMMY2/p374/p374_122.wav|11|The course is in great condition.
385 | DUMMY2/p302/p302_040.wav|30|On fuel, the Chancellor has a number of options.
386 | DUMMY2/p254/p254_231.wav|41|And thought we would get away with it.
387 | DUMMY2/p246/p246_222.wav|5|It's not before time.
388 | DUMMY2/p262/p262_044.wav|45|It is difficult for Ali.
389 | DUMMY2/p270/p270_005.wav|8|She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.
390 | DUMMY2/p274/p274_340.wav|32|This is a historic occasion.)
391 | DUMMY2/p329/p329_045.wav|103|I hope you will leave it at that.
392 | DUMMY2/p285/p285_188.wav|2|Any change would be subject to the Scottish Parliament's approval.
393 | DUMMY2/p260/p260_193.wav|81|The Shadow Chancellor is away on holiday.
394 | DUMMY2/p259/p259_371.wav|7|He was unable to come.)
395 | DUMMY2/p275/p275_052.wav|40|Several other pupils and staff were seriously injured in the accident.
396 | DUMMY2/p233/p233_159.wav|84|But he stressed that the partnership is not a construction company.
397 | DUMMY2/p277/p277_312.wav|89|It will work.
398 | DUMMY2/p295/p295_211.wav|92|Leaving the Labour Party is one thing.
399 | DUMMY2/p297/p297_150.wav|42|It is the wealthiest in Europe.
400 | DUMMY2/p305/p305_026.wav|54|He added, however, that all options are under review.
401 | DUMMY2/p292/p292_121.wav|13|This would not be my first choice.
402 | DUMMY2/p253/p253_346.wav|70|It is the Holiday programme with a mortgage.
403 | DUMMY2/p363/p363_171.wav|6|He didn't know where to look.
404 | DUMMY2/p233/p233_128.wav|84|It is still too early for any likely contenders to have emerged.
405 | DUMMY2/p251/p251_137.wav|9|We are currently consulting with a wide range of interested parties.
406 | DUMMY2/p334/p334_034.wav|38|Appointed general secretary last September.
407 | DUMMY2/p286/p286_225.wav|63|This will take several weeks.
408 | DUMMY2/p363/p363_183.wav|6|Public safety is paramount.
409 | DUMMY2/p256/p256_207.wav|90|After that time, the market itself will set the prices.
410 | DUMMY2/p273/p273_311.wav|56|Job losses were also announced.
411 | DUMMY2/p274/p274_425.wav|32|The projections are very positive for South Africa.)
412 | DUMMY2/p254/p254_065.wav|41|That's the day job.
413 | DUMMY2/p335/p335_123.wav|49|Wagner was never like this.
414 | DUMMY2/p258/p258_105.wav|26|We do not expect any surplus.
415 | DUMMY2/p286/p286_294.wav|63|It was an easy decision to come here.
416 | DUMMY2/p361/p361_218.wav|79|But we were wrong.
417 | DUMMY2/p247/p247_426.wav|14|Being captain of this club is fantastic.)
418 | DUMMY2/p266/p266_391.wav|20|In time, may prove a worthy successor to Billy Dodds.
419 | DUMMY2/p253/p253_116.wav|70|It is so sad.
420 | DUMMY2/p261/p261_081.wav|100|Our mother is very worried.
421 | DUMMY2/p268/p268_131.wav|87|But then they scored their fourth.
422 | DUMMY2/p229/p229_192.wav|67|I have the first six months of next season to prove myself.
423 | DUMMY2/p275/p275_260.wav|40|They want to shut the Scottish Office.
424 | DUMMY2/p313/p313_109.wav|76|Nothing is being offered in exchange.
425 | DUMMY2/p347/p347_072.wav|46|Thankfully, Mr Campbell was able to help.
426 | DUMMY2/p298/p298_334.wav|68|Hopefully, the whole of Scottish rugby was paying attention.
427 | DUMMY2/p271/p271_232.wav|27|Jim Wallace, the justice minister, acknowledged that prisoner numbers were a concern.
428 | DUMMY2/p283/p283_056.wav|95|For the meantime, though, the signs are good.
429 | DUMMY2/p255/p255_239.wav|31|It's the same as Glasgow.
430 | DUMMY2/p267/p267_244.wav|0|We have come a long way in the last few sessions.
431 | DUMMY2/p340/p340_403.wav|74|I had a ball today.
432 | DUMMY2/p230/p230_083.wav|35|It might change your life.
433 | DUMMY2/p299/p299_403.wav|58|We will have to see, but it makes you think.
434 | DUMMY2/p343/p343_128.wav|21|Then came the crunch.
435 | DUMMY2/p297/p297_021.wav|42|The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases.
436 | DUMMY2/p298/p298_275.wav|68|There are lots of these women in Finland.
437 | DUMMY2/p347/p347_286.wav|46|It is just too long since the war.
438 | DUMMY2/p239/p239_445.wav|48|Either group is living in fantasy land.
439 | DUMMY2/p286/p286_003.wav|63|Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob.
440 | DUMMY2/p299/p299_082.wav|58|I still feel like a wee boy.
441 | DUMMY2/p306/p306_213.wav|12|It is like being a qualifier again.
442 | DUMMY2/p339/p339_305.wav|18|We know the goals will come.
443 | DUMMY2/p265/p265_274.wav|73|This time, for Rangers, it is certainly the latter.
444 | DUMMY2/p310/p310_382.wav|17|It has the Bank of Scotland behind it.
445 | DUMMY2/p335/p335_403.wav|49|Anyway, even if they didn't it wouldn't have mattered.
446 | DUMMY2/p246/p246_330.wav|5|I would be quite happy for the money to be given back.
447 | DUMMY2/p288/p288_386.wav|47|There is a solution, she believes.
448 | DUMMY2/p234/p234_019.wav|3|Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows.
449 | DUMMY2/p287/p287_408.wav|77|FIRST, we had the Battle of Britain.
450 | DUMMY2/p286/p286_249.wav|63|She will attend in July.
451 | DUMMY2/p251/p251_235.wav|9|I'd never seen a play about me.
452 | DUMMY2/p347/p347_291.wav|46|Insurance will be covered by the receiving galleries.
453 | DUMMY2/p257/p257_058.wav|105|It is not great art.
454 | DUMMY2/p231/p231_471.wav|50|Dennis was not so sure.
455 | DUMMY2/p341/p341_107.wav|66|There was great support all round the route.
456 | DUMMY2/p264/p264_160.wav|65|It was clearly not a battle.)
457 | DUMMY2/p252/p252_155.wav|55|I think, therefore I am ?
458 | DUMMY2/p336/p336_264.wav|98|Ferguson must take the blame.
459 | DUMMY2/p274/p274_142.wav|32|The referee faces a massive job.)
460 | DUMMY2/p303/p303_005.wav|44|She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.
461 | DUMMY2/p233/p233_240.wav|84|The singer is expected to be in hospital for several days.
462 | DUMMY2/p333/p333_220.wav|64|This process of attrition is expected to continue.
463 | DUMMY2/p285/p285_303.wav|2|Alex Smith has been a massive influence on my career as well.
464 | DUMMY2/p277/p277_348.wav|89|To do so he reckons that a good opening result is essential.
465 | DUMMY2/p311/p311_290.wav|4|I am not in denial.
466 | DUMMY2/p286/p286_316.wav|63|I am a retailer by nature.
467 | DUMMY2/p306/p306_119.wav|12|Completion is expected by October the following year.
468 | DUMMY2/p240/p240_028.wav|93|Is this accurate?
469 | DUMMY2/p238/p238_295.wav|37|It has been recorded twice.
470 | DUMMY2/p278/p278_049.wav|10|He will need that machine.
471 | DUMMY2/p351/p351_282.wav|33|We just wish they had done so before.
472 | DUMMY2/p267/p267_348.wav|0|Many of these properties are located in the south of England.
473 | DUMMY2/p312/p312_360.wav|62|And Scotland is no different.
474 | DUMMY2/p311/p311_324.wav|4|He pretended not to care.
475 | DUMMY2/p283/p283_389.wav|95|Scrutiny by the European Parliament is limited.
476 | DUMMY2/p266/p266_079.wav|20|He is in the queue.
477 | DUMMY2/p274/p274_424.wav|32|Not that Scotland can claim the moral high ground.)
478 | DUMMY2/p303/p303_169.wav|44|For athletes in our current climate, their sport is their livelihood.
479 | DUMMY2/p252/p252_237.wav|55|You know the type.
480 | DUMMY2/p323/p323_115.wav|34|Parts of the system are already overstretched.
481 | DUMMY2/p361/p361_013.wav|79|Some have accepted it as a miracle without physical explanation.
482 | DUMMY2/p333/p333_356.wav|64|Which he can do.
483 | DUMMY2/p241/p241_029.wav|86|However, the following year the cancer returned.
484 | DUMMY2/p248/p248_371.wav|99|Whether his stance is shared by the incoming manager is another matter.
485 | DUMMY2/p260/p260_007.wav|81|The rainbow is a division of white light into many beautiful colors.
486 | DUMMY2/p287/p287_257.wav|77|The concerns are the same.
487 | DUMMY2/p263/p263_125.wav|39|It isn't a happy memory.
488 | DUMMY2/p277/p277_258.wav|89|Immediate action must be taken.
489 | DUMMY2/p363/p363_219.wav|6|It was important in training terms.
490 | DUMMY2/p269/p269_191.wav|94|My main concern is that public health is not put at risk.
491 | DUMMY2/p262/p262_020.wav|45|Many complicated ideas about the rainbow have been formed.
492 | DUMMY2/p273/p273_023.wav|56|If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow.
493 | DUMMY2/p278/p278_029.wav|10|They have now been banned from Celtic Park for life.
494 | DUMMY2/p310/p310_065.wav|17|I have had no social life at all.
495 | DUMMY2/p255/p255_352.wav|31|He's very explosive.
496 | DUMMY2/p376/p376_019.wav|71|"Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows. "
497 | DUMMY2/p263/p263_307.wav|39|Is there a waiting list ?
498 | DUMMY2/p249/p249_258.wav|80|They must play for each other.
499 | DUMMY2/p258/p258_111.wav|26|Maybe this battle has been.
500 | DUMMY2/p316/p316_129.wav|85|There can be no compromise on that demand.
501 |
--------------------------------------------------------------------------------
/filelists/vctk_audio_sid_text_val_filelist.txt:
--------------------------------------------------------------------------------
1 | DUMMY2/p364/p364_240.wav|88|It had happened to him.
2 | DUMMY2/p280/p280_148.wav|52|It is open season on the Old Firm.
3 | DUMMY2/p231/p231_320.wav|50|However, he is a coach, and he remains a coach at heart.
4 | DUMMY2/p282/p282_129.wav|83|It is not a U-turn.
5 | DUMMY2/p254/p254_015.wav|41|The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.
6 | DUMMY2/p228/p228_285.wav|57|The songs are just so good.
7 | DUMMY2/p334/p334_307.wav|38|If they don't, they can expect their funding to be cut.
8 | DUMMY2/p287/p287_081.wav|77|I've never seen anything like it.
9 | DUMMY2/p247/p247_083.wav|14|It is a job creation scheme.)
10 | DUMMY2/p264/p264_051.wav|65|We were leading by two goals.)
11 | DUMMY2/p335/p335_058.wav|49|Let's see that increase over the years.
12 | DUMMY2/p236/p236_225.wav|75|There is no quick fix.
13 | DUMMY2/p374/p374_353.wav|11|And that brings us to the point.
14 | DUMMY2/p272/p272_076.wav|69|Sounds like The Sixth Sense?
15 | DUMMY2/p271/p271_152.wav|27|The petition was formally presented at Downing Street yesterday.
16 | DUMMY2/p228/p228_127.wav|57|They've got to account for it.
17 | DUMMY2/p276/p276_223.wav|106|It's been a humbling year.
18 | DUMMY2/p262/p262_248.wav|45|The project has already secured the support of Sir Sean Connery.
19 | DUMMY2/p314/p314_086.wav|51|The team this year is going places.
20 | DUMMY2/p225/p225_038.wav|101|Diving is no part of football.
21 | DUMMY2/p279/p279_088.wav|25|The shareholders will vote to wind up the company on Friday morning.
22 | DUMMY2/p272/p272_018.wav|69|Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.
23 | DUMMY2/p256/p256_098.wav|90|She told The Herald.
24 | DUMMY2/p261/p261_218.wav|100|All will be revealed in due course.
25 | DUMMY2/p265/p265_063.wav|73|IT shouldn't come as a surprise, but it does.
26 | DUMMY2/p314/p314_042.wav|51|It is all about people being assaulted, abused.
27 | DUMMY2/p241/p241_188.wav|86|I wish I could say something.
28 | DUMMY2/p283/p283_111.wav|95|It's good to have a voice.
29 | DUMMY2/p275/p275_006.wav|40|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.
30 | DUMMY2/p228/p228_092.wav|57|Today I couldn't run on it.
31 | DUMMY2/p295/p295_343.wav|92|The atmosphere is businesslike.
32 | DUMMY2/p228/p228_187.wav|57|They will run a mile.
33 | DUMMY2/p294/p294_317.wav|104|It didn't put me off.
34 | DUMMY2/p231/p231_445.wav|50|It sounded like a bomb.
35 | DUMMY2/p272/p272_086.wav|69|Today she has been released.
36 | DUMMY2/p255/p255_210.wav|31|It was worth a photograph.
37 | DUMMY2/p229/p229_060.wav|67|And a film maker was born.
38 | DUMMY2/p260/p260_232.wav|81|The Home Office would not release any further details about the group.
39 | DUMMY2/p245/p245_025.wav|59|Johnson was pretty low.
40 | DUMMY2/p333/p333_185.wav|64|This area is perfect for children.
41 | DUMMY2/p244/p244_242.wav|78|He is a man of the people.
42 | DUMMY2/p376/p376_187.wav|71|"It is a terrible loss."
43 | DUMMY2/p239/p239_156.wav|48|It is a good lifestyle.
44 | DUMMY2/p307/p307_037.wav|22|He released a half-dozen solo albums.
45 | DUMMY2/p305/p305_185.wav|54|I am not even thinking about that.
46 | DUMMY2/p272/p272_081.wav|69|It was magic.
47 | DUMMY2/p302/p302_297.wav|30|I'm trying to stay open on that.
48 | DUMMY2/p275/p275_320.wav|40|We are in the end game.
49 | DUMMY2/p239/p239_231.wav|48|Then we will face the Danish champions.
50 | DUMMY2/p268/p268_301.wav|87|It was only later that the condition was diagnosed.
51 | DUMMY2/p336/p336_088.wav|98|They failed to reach agreement yesterday.
52 | DUMMY2/p278/p278_255.wav|10|They made such decisions in London.
53 | DUMMY2/p361/p361_132.wav|79|That got me out.
54 | DUMMY2/p307/p307_146.wav|22|You hope he prevails.
55 | DUMMY2/p244/p244_147.wav|78|They could not ignore the will of parliament, he claimed.
56 | DUMMY2/p294/p294_283.wav|104|This is our unfinished business.
57 | DUMMY2/p283/p283_300.wav|95|I would have the hammer in the crowd.
58 | DUMMY2/p239/p239_079.wav|48|I can understand the frustrations of our fans.
59 | DUMMY2/p264/p264_009.wav|65|There is , according to legend, a boiling pot of gold at one end. )
60 | DUMMY2/p307/p307_348.wav|22|He did not oppose the divorce.
61 | DUMMY2/p304/p304_308.wav|72|We are the gateway to justice.
62 | DUMMY2/p281/p281_056.wav|36|None has ever been found.
63 | DUMMY2/p267/p267_158.wav|0|We were given a warm and friendly reception.
64 | DUMMY2/p300/p300_169.wav|102|Who do these people think they are?
65 | DUMMY2/p276/p276_177.wav|106|They exist in name alone.
66 | DUMMY2/p228/p228_245.wav|57|It is a policy which has the full support of the minister.
67 | DUMMY2/p300/p300_303.wav|102|I'm wondering what you feel about the youngest.
68 | DUMMY2/p362/p362_247.wav|15|This would give Scotland around eight members.
69 | DUMMY2/p326/p326_031.wav|28|United were in control without always being dominant.
70 | DUMMY2/p361/p361_288.wav|79|I did not think it was very proper.
71 | DUMMY2/p286/p286_145.wav|63|Tiger is not the norm.
72 | DUMMY2/p234/p234_071.wav|3|She did that for the rest of her life.
73 | DUMMY2/p263/p263_296.wav|39|The decision was announced at its annual conference in Dunfermline.
74 | DUMMY2/p323/p323_228.wav|34|She became a heroine of my childhood.
75 | DUMMY2/p280/p280_346.wav|52|It was a bit like having children.
76 | DUMMY2/p333/p333_080.wav|64|But the tragedy did not stop there.
77 | DUMMY2/p226/p226_268.wav|43|That decision is for the British Parliament and people.
78 | DUMMY2/p362/p362_314.wav|15|Is that right?
79 | DUMMY2/p240/p240_047.wav|93|It is so sad.
80 | DUMMY2/p250/p250_207.wav|24|You could feel the heat.
81 | DUMMY2/p273/p273_176.wav|56|Neither side would reveal the details of the offer.
82 | DUMMY2/p316/p316_147.wav|85|And frankly, it's been a while.
83 | DUMMY2/p265/p265_047.wav|73|It is unique.
84 | DUMMY2/p336/p336_353.wav|98|Sometimes you get them, sometimes you don't.
85 | DUMMY2/p230/p230_376.wav|35|This hasn't happened in a vacuum.
86 | DUMMY2/p308/p308_209.wav|107|There is great potential on this river.
87 | DUMMY2/p250/p250_442.wav|24|We have not yet received a letter from the Irish.
88 | DUMMY2/p260/p260_037.wav|81|It's a fact.
89 | DUMMY2/p299/p299_345.wav|58|We're very excited and challenged by the project.
90 | DUMMY2/p269/p269_218.wav|94|A Grampian Police spokesman said.
91 | DUMMY2/p306/p306_014.wav|12|To the Hebrews it was a token that there would be no more universal floods.
92 | DUMMY2/p271/p271_292.wav|27|It's a record label, not a form of music.
93 | DUMMY2/p247/p247_225.wav|14|I am considered a teenager.)
94 | DUMMY2/p294/p294_094.wav|104|It should be a condition of employment.
95 | DUMMY2/p269/p269_031.wav|94|Is this accurate?
96 | DUMMY2/p275/p275_116.wav|40|It's not fair.
97 | DUMMY2/p265/p265_006.wav|73|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.
98 | DUMMY2/p285/p285_072.wav|2|Mr Irvine said Mr Rafferty was now in good spirits.
99 | DUMMY2/p270/p270_167.wav|8|We did what we had to do.
100 | DUMMY2/p360/p360_397.wav|60|It is a relief.
101 |
--------------------------------------------------------------------------------
/filelists/vctk_audio_sid_text_val_filelist.txt.cleaned:
--------------------------------------------------------------------------------
1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm.
2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm.
3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt.
4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn.
5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn.
6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd.
7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt.
8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt.
9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm.
10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz.
11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz.
12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks.
13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt.
14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns?
15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ.
16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt.
17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ.
18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi.
19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz.
20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl.
21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ.
22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn.
23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld.
24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs.
25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz.
26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd.
27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ.
28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs.
29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ.
30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt.
31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk.
32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl.
33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf.
34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm.
35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd.
36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf.
37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn.
38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp.
39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ.
40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən.
41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl.
42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs."
43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl.
44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz.
45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt.
46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk.
47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt.
48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm.
49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz.
50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd.
51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ.
52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən.
53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt.
54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz.
55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd.
56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs.
57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd.
58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz.
59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd.
60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs.
61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs.
62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd.
63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən.
64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ?
65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn.
66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ.
67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst.
68 | DUMMY2/p362/p362_247.wav|15|ðɪs wʊd ɡˈɪv skˈɑːtlənd ɐɹˈaʊnd ˈeɪt mˈɛmbɚz.
69 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt.
70 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ.
71 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm.
72 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf.
73 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn.
74 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd.
75 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən.
76 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ.
77 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl.
78 | DUMMY2/p362/p362_314.wav|15|ɪz ðæt ɹˈaɪt?
79 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd.
80 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt.
81 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ.
82 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl.
83 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk.
84 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt.
85 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm.
86 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ.
87 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ.
88 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt.
89 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt.
90 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd.
91 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz.
92 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk.
93 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ.
94 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt.
95 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət?
96 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ.
97 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ.
98 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts.
99 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː.
100 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf.
101 |
--------------------------------------------------------------------------------
/inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "import IPython.display as ipd\n",
12 | "\n",
13 | "import os\n",
14 | "import json\n",
15 | "import math\n",
16 | "import torch\n",
17 | "from torch import nn\n",
18 | "from torch.nn import functional as F\n",
19 | "from torch.utils.data import DataLoader\n",
20 | "\n",
21 | "import commons\n",
22 | "import utils\n",
23 | "from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate\n",
24 | "from models import SynthesizerTrn\n",
25 | "from text.symbols import symbols\n",
26 | "from text import text_to_sequence\n",
27 | "\n",
28 | "from scipy.io.wavfile import write\n",
29 | "\n",
30 | "\n",
31 | "def get_text(text, hps):\n",
32 | "    text_norm = text_to_sequence(text, hps.data.text_cleaners)\n",
33 | "    if hps.data.add_blank:\n",
34 | "        text_norm = commons.intersperse(text_norm, 0)\n",
35 | "    text_norm = torch.LongTensor(text_norm)\n",
36 | "    return text_norm"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "## LJ Speech"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "hps = utils.get_hparams_from_file(\"./configs/ljs_base.json\")"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "net_g = SynthesizerTrn(\n",
62 | "    len(symbols),\n",
63 | "    hps.data.filter_length // 2 + 1,\n",
64 | "    hps.train.segment_size // hps.data.hop_length,\n",
65 | "    **hps.model).cuda()\n",
66 | "_ = net_g.eval()\n",
67 | "\n",
68 | "_ = utils.load_checkpoint(\"/path/to/pretrained_ljs.pth\", net_g, None)"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "stn_tst = get_text(\"VITS is Awesome!\", hps)\n",
78 | "with torch.no_grad():\n",
79 | "    x_tst = stn_tst.cuda().unsqueeze(0)\n",
80 | "    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
81 | "    audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
82 | "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {},
88 | "source": [
89 | "## VCTK"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "hps = utils.get_hparams_from_file(\"./configs/vctk_base.json\")"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "net_g = SynthesizerTrn(\n",
108 | "    len(symbols),\n",
109 | "    hps.data.filter_length // 2 + 1,\n",
110 | "    hps.train.segment_size // hps.data.hop_length,\n",
111 | "    n_speakers=hps.data.n_speakers,\n",
112 | "    **hps.model).cuda()\n",
113 | "_ = net_g.eval()\n",
114 | "\n",
115 | "_ = utils.load_checkpoint(\"/path/to/pretrained_vctk.pth\", net_g, None)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "stn_tst = get_text(\"VITS is Awesome!\", hps)\n",
125 | "with torch.no_grad():\n",
126 | "    x_tst = stn_tst.cuda().unsqueeze(0)\n",
127 | "    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
128 | "    sid = torch.LongTensor([4]).cuda()\n",
129 | "    audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
130 | "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "metadata": {},
136 | "source": [
137 | "### Voice Conversion"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)\n",
147 | "collate_fn = TextAudioSpeakerCollate()\n",
148 | "loader = DataLoader(dataset, num_workers=8, shuffle=False,\n",
149 | "                    batch_size=1, pin_memory=True,\n",
150 | "                    drop_last=True, collate_fn=collate_fn)\n",
151 | "data_list = list(loader)"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "with torch.no_grad():\n",
161 | "    x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.cuda() for x in data_list[0]]\n",
162 | "    sid_tgt1 = torch.LongTensor([1]).cuda()\n",
163 | "    sid_tgt2 = torch.LongTensor([2]).cuda()\n",
164 | "    sid_tgt3 = torch.LongTensor([4]).cuda()\n",
165 | "    audio1 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0,0].data.cpu().float().numpy()\n",
166 | "    audio2 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt2)[0][0,0].data.cpu().float().numpy()\n",
167 | "    audio3 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt3)[0][0,0].data.cpu().float().numpy()\n",
168 | "print(\"Original SID: %d\" % sid_src.item())\n",
169 | "ipd.display(ipd.Audio(y[0].cpu().numpy(), rate=hps.data.sampling_rate, normalize=False))\n",
170 | "print(\"Converted SID: %d\" % sid_tgt1.item())\n",
171 | "ipd.display(ipd.Audio(audio1, rate=hps.data.sampling_rate, normalize=False))\n",
172 | "print(\"Converted SID: %d\" % sid_tgt2.item())\n",
173 | "ipd.display(ipd.Audio(audio2, rate=hps.data.sampling_rate, normalize=False))\n",
174 | "print(\"Converted SID: %d\" % sid_tgt3.item())\n",
175 | "ipd.display(ipd.Audio(audio3, rate=hps.data.sampling_rate, normalize=False))"
176 | ]
177 | }
178 | ],
179 | "metadata": {
180 | "kernelspec": {
181 | "display_name": "Python 3",
182 | "language": "python",
183 | "name": "python3"
184 | },
185 | "language_info": {
186 | "codemirror_mode": {
187 | "name": "ipython",
188 | "version": 3
189 | },
190 | "file_extension": ".py",
191 | "mimetype": "text/x-python",
192 | "name": "python",
193 | "nbconvert_exporter": "python",
194 | "pygments_lexer": "ipython3",
195 | "version": "3.7.7"
196 | }
197 | },
198 | "nbformat": 4,
199 | "nbformat_minor": 4
200 | }
201 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import commons
5 |
6 |
7 | def feature_loss(fmap_r, fmap_g):
8 | loss = 0
9 | for dr, dg in zip(fmap_r, fmap_g):
10 | for rl, gl in zip(dr, dg):
11 | rl = rl.float().detach()
12 | gl = gl.float()
13 | loss += torch.mean(torch.abs(rl - gl))
14 |
15 | return loss * 2
16 |
17 |
18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
19 | loss = 0
20 | r_losses = []
21 | g_losses = []
22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
23 | dr = dr.float()
24 | dg = dg.float()
25 | r_loss = torch.mean((1-dr)**2)
26 | g_loss = torch.mean(dg**2)
27 | loss += (r_loss + g_loss)
28 | r_losses.append(r_loss.item())
29 | g_losses.append(g_loss.item())
30 |
31 | return loss, r_losses, g_losses
32 |
33 |
34 | def generator_loss(disc_outputs):
35 | loss = 0
36 | gen_losses = []
37 | for dg in disc_outputs:
38 | dg = dg.float()
39 | l = torch.mean((1-dg)**2)
40 | gen_losses.append(l)
41 | loss += l
42 |
43 | return loss, gen_losses
44 |
45 |
46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
47 | """
48 | z_p, logs_q: [b, h, t_t]
49 | m_p, logs_p: [b, h, t_t]
50 | """
51 | z_p = z_p.float()
52 | logs_q = logs_q.float()
53 | m_p = m_p.float()
54 | logs_p = logs_p.float()
55 | z_mask = z_mask.float()
56 |
57 | kl = logs_p - logs_q - 0.5
58 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
59 | kl = torch.sum(kl * z_mask)
60 | l = kl / torch.sum(z_mask)
61 | return l
62 |
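The `kl_loss` function above sums a per-element term and divides by the mask size. The sketch below restates that term for scalars and compares its average with the closed-form KL divergence between two Gaussians. It assumes `z_p` is sampled directly from `N(m_q, exp(logs_q)^2)`, which ignores the flow that the model applies to `z` before the loss is computed. The helper names `kl_term`, `analytic_kl` and `monte_carlo_kl` are illustrative and are not part of the repository.

```python
import math
import random


def kl_term(z_p, logs_q, m_p, logs_p):
    """Per-element term accumulated by kl_loss, written for scalar inputs."""
    return logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p) ** 2 * math.exp(-2.0 * logs_p)


def analytic_kl(m_q, logs_q, m_p, logs_p):
    """Closed-form KL(N(m_q, s_q^2) || N(m_p, s_p^2)) with s = exp(logs)."""
    var_q = math.exp(2.0 * logs_q)
    var_p = math.exp(2.0 * logs_p)
    return logs_p - logs_q + (var_q + (m_q - m_p) ** 2) / (2.0 * var_p) - 0.5


def monte_carlo_kl(m_q, logs_q, m_p, logs_p, n_samples=100_000, seed=0):
    """Average kl_term over samples z ~ N(m_q, exp(logs_q)^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = m_q + rng.gauss(0.0, 1.0) * math.exp(logs_q)
        total += kl_term(z, logs_q, m_p, logs_p)
    return total / n_samples


if __name__ == "__main__":
    args = (0.3, -0.2, -0.1, 0.1)
    print("sampled :", monte_carlo_kl(*args))
    print("analytic:", analytic_kl(*args))
```

A single evaluation of the term can be negative (for example when `z_p == m_p` and both log-scales are zero); only its average over samples is a KL estimate.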
--------------------------------------------------------------------------------
/mel_processing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import numpy as np
9 | import librosa
10 | import librosa.util as librosa_util
11 | from librosa.util import normalize, pad_center, tiny
12 | from scipy.signal import get_window
13 | from scipy.io.wavfile import read
14 | from librosa.filters import mel as librosa_mel_fn
15 |
16 | MAX_WAV_VALUE = 32768.0
17 |
18 |
19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20 | """
21 | PARAMS
22 | ------
23 | C: compression factor
24 | """
25 | return torch.log(torch.clamp(x, min=clip_val) * C)
26 |
27 |
28 | def dynamic_range_decompression_torch(x, C=1):
29 | """
30 | PARAMS
31 | ------
32 | C: compression factor used to compress
33 | """
34 | return torch.exp(x) / C
35 |
36 |
37 | def spectral_normalize_torch(magnitudes):
38 | output = dynamic_range_compression_torch(magnitudes)
39 | return output
40 |
41 |
42 | def spectral_de_normalize_torch(magnitudes):
43 | output = dynamic_range_decompression_torch(magnitudes)
44 | return output
45 |
46 |
47 | mel_basis = {}
48 | hann_window = {}
49 |
50 |
51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52 | if torch.min(y) < -1.:
53 | print('min value is ', torch.min(y))
54 | if torch.max(y) > 1.:
55 | print('max value is ', torch.max(y))
56 |
57 | global hann_window
58 | dtype_device = str(y.dtype) + '_' + str(y.device)
59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
60 | if wnsize_dtype_device not in hann_window:
61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
62 |
63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64 | y = y.squeeze(1)
65 |
66 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
67 | center=center, pad_mode='reflect', normalized=False, onesided=True)
68 |
69 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
70 | return spec
71 |
72 |
73 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
74 | global mel_basis
75 | dtype_device = str(spec.dtype) + '_' + str(spec.device)
76 | fmax_dtype_device = str(fmax) + '_' + dtype_device
77 | if fmax_dtype_device not in mel_basis:
78 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
80 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
81 | spec = spectral_normalize_torch(spec)
82 | return spec
83 |
84 |
85 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
86 | if torch.min(y) < -1.:
87 | print('min value is ', torch.min(y))
88 | if torch.max(y) > 1.:
89 | print('max value is ', torch.max(y))
90 |
91 | global mel_basis, hann_window
92 | dtype_device = str(y.dtype) + '_' + str(y.device)
93 | fmax_dtype_device = str(fmax) + '_' + dtype_device
94 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
95 | if fmax_dtype_device not in mel_basis:
96 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
97 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
98 | if wnsize_dtype_device not in hann_window:
99 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
100 |
101 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
102 | y = y.squeeze(1)
103 |
104 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
105 | center=center, pad_mode='reflect', normalized=False, onesided=True)
106 |
107 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
108 |
109 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
110 | spec = spectral_normalize_torch(spec)
111 |
112 | return spec
113 |
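Both spectrogram functions above reflect-pad the waveform by `int((n_fft - hop_size) / 2)` samples on each side and then call `torch.stft` with `center=False`. The sketch below works through the resulting frame count in plain Python. `padded_frame_count` is a hypothetical helper, and the sketch assumes the padded signal is at least `n_fft` samples long.

```python
def padded_frame_count(num_samples, n_fft, hop_size):
    """Number of STFT frames after the manual reflect padding used above."""
    pad = int((n_fft - hop_size) / 2)  # same expression as in the source
    padded = num_samples + 2 * pad
    # A non-centered STFT produces 1 + (padded - n_fft) // hop_size frames.
    return 1 + (padded - n_fft) // hop_size


if __name__ == "__main__":
    for n in (8192, 22050):
        print(n, padded_frame_count(n, n_fft=1024, hop_size=256), n // 256)
```

When `n_fft - hop_size` is even, the count reduces to `num_samples // hop_size`, so each frame corresponds to one hop of audio.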
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | import commons
8 | import modules
9 | import attentions
10 | import monotonic_align
11 |
12 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
13 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14 | from commons import init_weights, get_padding
15 |
16 |
17 | class StochasticDurationPredictor(nn.Module):
18 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
19 | super().__init__()
20 | filter_channels = in_channels # NOTE: overrides the filter_channels argument; to be removed in a future version.
21 | self.in_channels = in_channels
22 | self.filter_channels = filter_channels
23 | self.kernel_size = kernel_size
24 | self.p_dropout = p_dropout
25 | self.n_flows = n_flows
26 | self.gin_channels = gin_channels
27 |
28 | self.log_flow = modules.Log()
29 | self.flows = nn.ModuleList()
30 | self.flows.append(modules.ElementwiseAffine(2))
31 | for i in range(n_flows):
32 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
33 | self.flows.append(modules.Flip())
34 |
35 | self.post_pre = nn.Conv1d(1, filter_channels, 1)
36 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
37 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
38 | self.post_flows = nn.ModuleList()
39 | self.post_flows.append(modules.ElementwiseAffine(2))
40 | for i in range(4):
41 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
42 | self.post_flows.append(modules.Flip())
43 |
44 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
45 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
46 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
47 | if gin_channels != 0:
48 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
49 |
50 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
51 | x = torch.detach(x)
52 | x = self.pre(x)
53 | if g is not None:
54 | g = torch.detach(g)
55 | x = x + self.cond(g)
56 | x = self.convs(x, x_mask)
57 | x = self.proj(x) * x_mask
58 |
59 | if not reverse:
60 | flows = self.flows
61 | assert w is not None
62 |
63 | logdet_tot_q = 0
64 | h_w = self.post_pre(w)
65 | h_w = self.post_convs(h_w, x_mask)
66 | h_w = self.post_proj(h_w) * x_mask
67 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
68 | z_q = e_q
69 | for flow in self.post_flows:
70 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
71 | logdet_tot_q += logdet_q
72 | z_u, z1 = torch.split(z_q, [1, 1], 1)
73 | u = torch.sigmoid(z_u) * x_mask
74 | z0 = (w - u) * x_mask
75 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
76 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
77 |
78 | logdet_tot = 0
79 | z0, logdet = self.log_flow(z0, x_mask)
80 | logdet_tot += logdet
81 | z = torch.cat([z0, z1], 1)
82 | for flow in flows:
83 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
84 | logdet_tot = logdet_tot + logdet
85 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
86 | return nll + logq # [b]
87 | else:
88 | flows = list(reversed(self.flows))
89 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
90 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
91 | for flow in flows:
92 | z = flow(z, x_mask, g=x, reverse=reverse)
93 | z0, z1 = torch.split(z, [1, 1], 1)
94 | logw = z0
95 | return logw
96 |
97 |
98 | class DurationPredictor(nn.Module):
99 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
100 | super().__init__()
101 |
102 | self.in_channels = in_channels
103 | self.filter_channels = filter_channels
104 | self.kernel_size = kernel_size
105 | self.p_dropout = p_dropout
106 | self.gin_channels = gin_channels
107 |
108 | self.drop = nn.Dropout(p_dropout)
109 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
110 | self.norm_1 = modules.LayerNorm(filter_channels)
111 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
112 | self.norm_2 = modules.LayerNorm(filter_channels)
113 | self.proj = nn.Conv1d(filter_channels, 1, 1)
114 |
115 | if gin_channels != 0:
116 | self.cond = nn.Conv1d(gin_channels, in_channels, 1)
117 |
118 | def forward(self, x, x_mask, g=None):
119 | x = torch.detach(x)
120 | if g is not None:
121 | g = torch.detach(g)
122 | x = x + self.cond(g)
123 | x = self.conv_1(x * x_mask)
124 | x = torch.relu(x)
125 | x = self.norm_1(x)
126 | x = self.drop(x)
127 | x = self.conv_2(x * x_mask)
128 | x = torch.relu(x)
129 | x = self.norm_2(x)
130 | x = self.drop(x)
131 | x = self.proj(x * x_mask)
132 | return x * x_mask
133 |
134 |
135 | class TextEncoder(nn.Module):
136 | def __init__(self,
137 | n_vocab,
138 | out_channels,
139 | hidden_channels,
140 | filter_channels,
141 | n_heads,
142 | n_layers,
143 | kernel_size,
144 | p_dropout):
145 | super().__init__()
146 | self.n_vocab = n_vocab
147 | self.out_channels = out_channels
148 | self.hidden_channels = hidden_channels
149 | self.filter_channels = filter_channels
150 | self.n_heads = n_heads
151 | self.n_layers = n_layers
152 | self.kernel_size = kernel_size
153 | self.p_dropout = p_dropout
154 |
155 | self.emb = nn.Embedding(n_vocab, hidden_channels)
156 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
157 |
158 | self.encoder = attentions.Encoder(
159 | hidden_channels,
160 | filter_channels,
161 | n_heads,
162 | n_layers,
163 | kernel_size,
164 | p_dropout)
165 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
166 |
167 | def forward(self, x, x_lengths):
168 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
169 | x = torch.transpose(x, 1, -1) # [b, h, t]
170 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
171 |
172 | x = self.encoder(x * x_mask, x_mask)
173 | stats = self.proj(x) * x_mask
174 |
175 | m, logs = torch.split(stats, self.out_channels, dim=1)
176 | return x, m, logs, x_mask
177 |
178 |
179 | class ResidualCouplingBlock(nn.Module):
180 | def __init__(self,
181 | channels,
182 | hidden_channels,
183 | kernel_size,
184 | dilation_rate,
185 | n_layers,
186 | n_flows=4,
187 | gin_channels=0):
188 | super().__init__()
189 | self.channels = channels
190 | self.hidden_channels = hidden_channels
191 | self.kernel_size = kernel_size
192 | self.dilation_rate = dilation_rate
193 | self.n_layers = n_layers
194 | self.n_flows = n_flows
195 | self.gin_channels = gin_channels
196 |
197 | self.flows = nn.ModuleList()
198 | for i in range(n_flows):
199 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
200 | self.flows.append(modules.Flip())
201 |
202 | def forward(self, x, x_mask, g=None, reverse=False):
203 | if not reverse:
204 | for flow in self.flows:
205 | x, _ = flow(x, x_mask, g=g, reverse=reverse)
206 | else:
207 | for flow in reversed(self.flows):
208 | x = flow(x, x_mask, g=g, reverse=reverse)
209 | return x
210 |
211 |
212 | class PosteriorEncoder(nn.Module):
213 | def __init__(self,
214 | in_channels,
215 | out_channels,
216 | hidden_channels,
217 | kernel_size,
218 | dilation_rate,
219 | n_layers,
220 | gin_channels=0):
221 | super().__init__()
222 | self.in_channels = in_channels
223 | self.out_channels = out_channels
224 | self.hidden_channels = hidden_channels
225 | self.kernel_size = kernel_size
226 | self.dilation_rate = dilation_rate
227 | self.n_layers = n_layers
228 | self.gin_channels = gin_channels
229 |
230 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
231 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
232 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
233 |
234 | def forward(self, x, x_lengths, g=None):
235 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
236 | x = self.pre(x) * x_mask
237 | x = self.enc(x, x_mask, g=g)
238 | stats = self.proj(x) * x_mask
239 | m, logs = torch.split(stats, self.out_channels, dim=1)
240 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
241 | return z, m, logs, x_mask
242 |
243 |
244 | class Generator(torch.nn.Module):
245 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
246 | super(Generator, self).__init__()
247 | self.num_kernels = len(resblock_kernel_sizes)
248 | self.num_upsamples = len(upsample_rates)
249 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
250 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
251 |
252 | self.ups = nn.ModuleList()
253 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
254 | self.ups.append(weight_norm(
255 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
256 | k, u, padding=(k-u)//2)))
257 |
258 | self.resblocks = nn.ModuleList()
259 | for i in range(len(self.ups)):
260 | ch = upsample_initial_channel//(2**(i+1))
261 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
262 | self.resblocks.append(resblock(ch, k, d))
263 |
264 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
265 | self.ups.apply(init_weights)
266 |
267 | if gin_channels != 0:
268 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
269 |
270 | def forward(self, x, g=None):
271 | x = self.conv_pre(x)
272 | if g is not None:
273 | x = x + self.cond(g)
274 |
275 | for i in range(self.num_upsamples):
276 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
277 | x = self.ups[i](x)
278 | xs = None
279 | for j in range(self.num_kernels):
280 | if xs is None:
281 | xs = self.resblocks[i*self.num_kernels+j](x)
282 | else:
283 | xs += self.resblocks[i*self.num_kernels+j](x)
284 | x = xs / self.num_kernels
285 | x = F.leaky_relu(x)
286 | x = self.conv_post(x)
287 | x = torch.tanh(x)
288 |
289 | return x
290 |
291 | def remove_weight_norm(self):
292 | print('Removing weight norm...')
293 | for l in self.ups:
294 | remove_weight_norm(l)
295 | for l in self.resblocks:
296 | l.remove_weight_norm()
297 |
298 |
299 | class DiscriminatorP(torch.nn.Module):
300 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
301 | super(DiscriminatorP, self).__init__()
302 | self.period = period
303 | self.use_spectral_norm = use_spectral_norm
304 | norm_f = spectral_norm if use_spectral_norm else weight_norm
305 | self.convs = nn.ModuleList([
306 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
307 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
308 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
309 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
310 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
311 | ])
312 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
313 |
314 | def forward(self, x):
315 | fmap = []
316 |
317 | # 1d to 2d
318 | b, c, t = x.shape
319 | if t % self.period != 0: # pad first
320 | n_pad = self.period - (t % self.period)
321 | x = F.pad(x, (0, n_pad), "reflect")
322 | t = t + n_pad
323 | x = x.view(b, c, t // self.period, self.period)
324 |
325 | for l in self.convs:
326 | x = l(x)
327 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
328 | fmap.append(x)
329 | x = self.conv_post(x)
330 | fmap.append(x)
331 | x = torch.flatten(x, 1, -1)
332 |
333 | return x, fmap
334 |
335 |
336 | class DiscriminatorS(torch.nn.Module):
337 | def __init__(self, use_spectral_norm=False):
338 | super(DiscriminatorS, self).__init__()
339 | norm_f = spectral_norm if use_spectral_norm else weight_norm
340 | self.convs = nn.ModuleList([
341 | norm_f(Conv1d(1, 16, 15, 1, padding=7)),
342 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
343 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
344 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
345 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
346 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
347 | ])
348 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
349 |
350 | def forward(self, x):
351 | fmap = []
352 |
353 | for l in self.convs:
354 | x = l(x)
355 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
356 | fmap.append(x)
357 | x = self.conv_post(x)
358 | fmap.append(x)
359 | x = torch.flatten(x, 1, -1)
360 |
361 | return x, fmap
362 |
363 |
364 | class MultiPeriodDiscriminator(torch.nn.Module):
365 | def __init__(self, use_spectral_norm=False):
366 | super(MultiPeriodDiscriminator, self).__init__()
367 | periods = [2,3,5,7,11]
368 |
369 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
370 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
371 | self.discriminators = nn.ModuleList(discs)
372 |
373 | def forward(self, y, y_hat):
374 | y_d_rs = []
375 | y_d_gs = []
376 | fmap_rs = []
377 | fmap_gs = []
378 | for i, d in enumerate(self.discriminators):
379 | y_d_r, fmap_r = d(y)
380 | y_d_g, fmap_g = d(y_hat)
381 | y_d_rs.append(y_d_r)
382 | y_d_gs.append(y_d_g)
383 | fmap_rs.append(fmap_r)
384 | fmap_gs.append(fmap_g)
385 |
386 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
387 |
388 |
389 |
390 | class SynthesizerTrn(nn.Module):
391 | """
392 | Synthesizer for Training
393 | """
394 |
395 | def __init__(self,
396 | n_vocab,
397 | spec_channels,
398 | segment_size,
399 | inter_channels,
400 | hidden_channels,
401 | filter_channels,
402 | n_heads,
403 | n_layers,
404 | kernel_size,
405 | p_dropout,
406 | resblock,
407 | resblock_kernel_sizes,
408 | resblock_dilation_sizes,
409 | upsample_rates,
410 | upsample_initial_channel,
411 | upsample_kernel_sizes,
412 | n_speakers=0,
413 | gin_channels=0,
414 | use_sdp=True,
415 | **kwargs):
416 |
417 | super().__init__()
418 | self.n_vocab = n_vocab
419 | self.spec_channels = spec_channels
420 | self.inter_channels = inter_channels
421 | self.hidden_channels = hidden_channels
422 | self.filter_channels = filter_channels
423 | self.n_heads = n_heads
424 | self.n_layers = n_layers
425 | self.kernel_size = kernel_size
426 | self.p_dropout = p_dropout
427 | self.resblock = resblock
428 | self.resblock_kernel_sizes = resblock_kernel_sizes
429 | self.resblock_dilation_sizes = resblock_dilation_sizes
430 | self.upsample_rates = upsample_rates
431 | self.upsample_initial_channel = upsample_initial_channel
432 | self.upsample_kernel_sizes = upsample_kernel_sizes
433 | self.segment_size = segment_size
434 | self.n_speakers = n_speakers
435 | self.gin_channels = gin_channels
436 |
437 | self.use_sdp = use_sdp
438 |
439 | self.enc_p = TextEncoder(n_vocab,
440 | inter_channels,
441 | hidden_channels,
442 | filter_channels,
443 | n_heads,
444 | n_layers,
445 | kernel_size,
446 | p_dropout)
447 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
448 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
449 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
450 |
451 | if use_sdp:
452 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
453 | else:
454 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
455 |
456 | if n_speakers > 0:
457 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
458 |
459 | def forward(self, x, x_lengths, y, y_lengths, sid=None):
460 |
461 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
462 | if self.n_speakers > 0:
463 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
464 | else:
465 | g = None
466 |
467 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
468 | z_p = self.flow(z, y_mask, g=g)
469 |
470 | with torch.no_grad():
471 | # negative cross-entropy
472 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
473 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
474 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
475 | neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
476 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
477 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
478 |
479 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
480 | attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
481 |
482 | w = attn.sum(2)
483 | if self.use_sdp:
484 | l_length = self.dp(x, x_mask, w, g=g)
485 | l_length = l_length / torch.sum(x_mask)
486 | else:
487 | logw_ = torch.log(w + 1e-6) * x_mask
488 | logw = self.dp(x, x_mask, g=g)
489 | l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
490 |
491 | # expand prior
492 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
493 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
494 |
495 | z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
496 | o = self.dec(z_slice, g=g)
497 | return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
498 |
499 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
500 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
501 | if self.n_speakers > 0:
502 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
503 | else:
504 | g = None
505 |
506 | if self.use_sdp:
507 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
508 | else:
509 | logw = self.dp(x, x_mask, g=g)
510 | w = torch.exp(logw) * x_mask * length_scale
511 | w_ceil = torch.ceil(w)
512 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
513 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
514 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
515 | attn = commons.generate_path(w_ceil, attn_mask)
516 |
517 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
518 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
519 |
520 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
521 | z = self.flow(z_p, y_mask, g=g, reverse=True)
522 | o = self.dec((z * y_mask)[:,:,:max_len], g=g)
523 | return o, attn, y_mask, (z, z_p, m_p, logs_p)
524 |
525 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
526 | assert self.n_speakers > 0, "n_speakers has to be larger than 0."
527 | g_src = self.emb_g(sid_src).unsqueeze(-1)
528 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
529 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
530 | z_p = self.flow(z, y_mask, g=g_src)
531 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
532 | o_hat = self.dec(z_hat * y_mask, g=g_tgt)
533 | return o_hat, y_mask, (z, z_p, z_hat)
534 |
535 |
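`SynthesizerTrn.infer` above turns predicted log-durations into an alignment: it exponentiates, scales by `length_scale`, rounds up, and passes the result to `commons.generate_path`. That helper is not shown in this section, so the sketch below is an assumption about what it computes: a hard monotonic path of shape `[t', t]` in which each output frame points at exactly one input token. It handles one unbatched utterance without masking, and `durations_to_path` is an illustrative name.

```python
import math


def durations_to_path(logw, length_scale=1.0):
    """Expand per-token log-durations into a hard monotonic alignment."""
    # Same arithmetic as infer: w = exp(logw) * length_scale, then ceil.
    w_ceil = [math.ceil(math.exp(lw) * length_scale) for lw in logw]
    total_frames = max(sum(w_ceil), 1)  # mirrors the clamp_min(..., 1)
    path = [[0] * len(w_ceil) for _ in range(total_frames)]
    frame = 0
    for token, duration in enumerate(w_ceil):
        for _ in range(duration):
            path[frame][token] = 1  # this output frame is owned by this token
            frame += 1
    return w_ceil, path


if __name__ == "__main__":
    durations, path = durations_to_path([0.0, math.log(2.5), math.log(0.4)])
    print(durations)
    for row in path:
        print(row)
```

Multiplying such a path by the per-token prior statistics repeats each token's `m_p` and `logs_p` for as many frames as its duration, which is what the two `torch.matmul` calls in `infer` do.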
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import scipy
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10 | from torch.nn.utils import weight_norm, remove_weight_norm
11 |
12 | import commons
13 | from commons import init_weights, get_padding
14 | from transforms import piecewise_rational_quadratic_transform
15 |
16 |
17 | LRELU_SLOPE = 0.1
18 |
19 |
20 | class LayerNorm(nn.Module):
21 | def __init__(self, channels, eps=1e-5):
22 | super().__init__()
23 | self.channels = channels
24 | self.eps = eps
25 |
26 | self.gamma = nn.Parameter(torch.ones(channels))
27 | self.beta = nn.Parameter(torch.zeros(channels))
28 |
29 | def forward(self, x):
30 | x = x.transpose(1, -1)
31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32 | return x.transpose(1, -1)
33 |
34 |
35 | class ConvReluNorm(nn.Module):
36 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
37 | super().__init__()
38 | self.in_channels = in_channels
39 | self.hidden_channels = hidden_channels
40 | self.out_channels = out_channels
41 | self.kernel_size = kernel_size
42 | self.n_layers = n_layers
43 | self.p_dropout = p_dropout
44 | assert n_layers > 1, "Number of layers should be larger than 1."
45 |
46 | self.conv_layers = nn.ModuleList()
47 | self.norm_layers = nn.ModuleList()
48 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
49 | self.norm_layers.append(LayerNorm(hidden_channels))
50 | self.relu_drop = nn.Sequential(
51 | nn.ReLU(),
52 | nn.Dropout(p_dropout))
53 | for _ in range(n_layers-1):
54 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
55 | self.norm_layers.append(LayerNorm(hidden_channels))
56 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
57 | self.proj.weight.data.zero_()
58 | self.proj.bias.data.zero_()
59 |
60 | def forward(self, x, x_mask):
61 | x_org = x
62 | for i in range(self.n_layers):
63 | x = self.conv_layers[i](x * x_mask)
64 | x = self.norm_layers[i](x)
65 | x = self.relu_drop(x)
66 | x = x_org + self.proj(x)
67 | return x * x_mask
68 |
69 |
70 | class DDSConv(nn.Module):
71 | """
72 | Dilated and Depth-Separable Convolution
73 | """
74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
75 | super().__init__()
76 | self.channels = channels
77 | self.kernel_size = kernel_size
78 | self.n_layers = n_layers
79 | self.p_dropout = p_dropout
80 |
81 | self.drop = nn.Dropout(p_dropout)
82 | self.convs_sep = nn.ModuleList()
83 | self.convs_1x1 = nn.ModuleList()
84 | self.norms_1 = nn.ModuleList()
85 | self.norms_2 = nn.ModuleList()
86 | for i in range(n_layers):
87 | dilation = kernel_size ** i
88 | padding = (kernel_size * dilation - dilation) // 2
89 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
90 | groups=channels, dilation=dilation, padding=padding
91 | ))
92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
93 | self.norms_1.append(LayerNorm(channels))
94 | self.norms_2.append(LayerNorm(channels))
95 |
96 | def forward(self, x, x_mask, g=None):
97 | if g is not None:
98 | x = x + g
99 | for i in range(self.n_layers):
100 | y = self.convs_sep[i](x * x_mask)
101 | y = self.norms_1[i](y)
102 | y = F.gelu(y)
103 | y = self.convs_1x1[i](y)
104 | y = self.norms_2[i](y)
105 | y = F.gelu(y)
106 | y = self.drop(y)
107 | x = x + y
108 | return x * x_mask
109 |
110 |
111 | class WN(torch.nn.Module):
112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
113 | super(WN, self).__init__()
114 | assert(kernel_size % 2 == 1)
115 | self.hidden_channels =hidden_channels
116 | self.kernel_size = kernel_size
117 | self.dilation_rate = dilation_rate
118 | self.n_layers = n_layers
119 | self.gin_channels = gin_channels
120 | self.p_dropout = p_dropout
121 |
122 | self.in_layers = torch.nn.ModuleList()
123 | self.res_skip_layers = torch.nn.ModuleList()
124 | self.drop = nn.Dropout(p_dropout)
125 |
126 | if gin_channels != 0:
127 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
129 |
130 | for i in range(n_layers):
131 | dilation = dilation_rate ** i
132 | padding = int((kernel_size * dilation - dilation) / 2)
133 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
134 | dilation=dilation, padding=padding)
135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
136 | self.in_layers.append(in_layer)
137 |
138 | # last one is not necessary
139 | if i < n_layers - 1:
140 | res_skip_channels = 2 * hidden_channels
141 | else:
142 | res_skip_channels = hidden_channels
143 |
144 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
145 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
146 | self.res_skip_layers.append(res_skip_layer)
147 |
148 | def forward(self, x, x_mask, g=None, **kwargs):
149 | output = torch.zeros_like(x)
150 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
151 |
152 | if g is not None:
153 | g = self.cond_layer(g)
154 |
155 | for i in range(self.n_layers):
156 | x_in = self.in_layers[i](x)
157 | if g is not None:
158 | cond_offset = i * 2 * self.hidden_channels
159 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
160 | else:
161 | g_l = torch.zeros_like(x_in)
162 |
163 | acts = commons.fused_add_tanh_sigmoid_multiply(
164 | x_in,
165 | g_l,
166 | n_channels_tensor)
167 | acts = self.drop(acts)
168 |
169 | res_skip_acts = self.res_skip_layers[i](acts)
170 | if i < self.n_layers - 1:
171 | res_acts = res_skip_acts[:,:self.hidden_channels,:]
172 | x = (x + res_acts) * x_mask
173 | output = output + res_skip_acts[:,self.hidden_channels:,:]
174 | else:
175 | output = output + res_skip_acts
176 | return output * x_mask
177 |
178 | def remove_weight_norm(self):
179 | if self.gin_channels != 0:
180 | torch.nn.utils.remove_weight_norm(self.cond_layer)
181 | for l in self.in_layers:
182 | torch.nn.utils.remove_weight_norm(l)
183 | for l in self.res_skip_layers:
184 | torch.nn.utils.remove_weight_norm(l)
185 |
186 |
187 | class ResBlock1(torch.nn.Module):
188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
189 | super(ResBlock1, self).__init__()
190 | self.convs1 = nn.ModuleList([
191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
192 | padding=get_padding(kernel_size, dilation[0]))),
193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
194 | padding=get_padding(kernel_size, dilation[1]))),
195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
196 | padding=get_padding(kernel_size, dilation[2])))
197 | ])
198 | self.convs1.apply(init_weights)
199 |
200 | self.convs2 = nn.ModuleList([
201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
202 | padding=get_padding(kernel_size, 1))),
203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
204 | padding=get_padding(kernel_size, 1))),
205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
206 | padding=get_padding(kernel_size, 1)))
207 | ])
208 | self.convs2.apply(init_weights)
209 |
210 | def forward(self, x, x_mask=None):
211 | for c1, c2 in zip(self.convs1, self.convs2):
212 | xt = F.leaky_relu(x, LRELU_SLOPE)
213 | if x_mask is not None:
214 | xt = xt * x_mask
215 | xt = c1(xt)
216 | xt = F.leaky_relu(xt, LRELU_SLOPE)
217 | if x_mask is not None:
218 | xt = xt * x_mask
219 | xt = c2(xt)
220 | x = xt + x
221 | if x_mask is not None:
222 | x = x * x_mask
223 | return x
224 |
225 | def remove_weight_norm(self):
226 | for l in self.convs1:
227 | remove_weight_norm(l)
228 | for l in self.convs2:
229 | remove_weight_norm(l)
230 |
231 |
232 | class ResBlock2(torch.nn.Module):
233 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
234 | super(ResBlock2, self).__init__()
235 | self.convs = nn.ModuleList([
236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
237 | padding=get_padding(kernel_size, dilation[0]))),
238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
239 | padding=get_padding(kernel_size, dilation[1])))
240 | ])
241 | self.convs.apply(init_weights)
242 |
243 | def forward(self, x, x_mask=None):
244 | for c in self.convs:
245 | xt = F.leaky_relu(x, LRELU_SLOPE)
246 | if x_mask is not None:
247 | xt = xt * x_mask
248 | xt = c(xt)
249 | x = xt + x
250 | if x_mask is not None:
251 | x = x * x_mask
252 | return x
253 |
254 | def remove_weight_norm(self):
255 | for l in self.convs:
256 | remove_weight_norm(l)
257 |
258 |
259 | class Log(nn.Module):
260 | def forward(self, x, x_mask, reverse=False, **kwargs):
261 | if not reverse:
262 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
263 | logdet = torch.sum(-y, [1, 2])
264 | return y, logdet
265 | else:
266 | x = torch.exp(x) * x_mask
267 | return x
268 |
269 |
270 | class Flip(nn.Module):
271 | def forward(self, x, *args, reverse=False, **kwargs):
272 | x = torch.flip(x, [1])
273 | if not reverse:
274 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
275 | return x, logdet
276 | else:
277 | return x
278 |
279 |
280 | class ElementwiseAffine(nn.Module):
281 | def __init__(self, channels):
282 | super().__init__()
283 | self.channels = channels
284 | self.m = nn.Parameter(torch.zeros(channels,1))
285 | self.logs = nn.Parameter(torch.zeros(channels,1))
286 |
287 | def forward(self, x, x_mask, reverse=False, **kwargs):
288 | if not reverse:
289 | y = self.m + torch.exp(self.logs) * x
290 | y = y * x_mask
291 | logdet = torch.sum(self.logs * x_mask, [1,2])
292 | return y, logdet
293 | else:
294 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
295 | return x
296 |
297 |
298 | class ResidualCouplingLayer(nn.Module):
299 | def __init__(self,
300 | channels,
301 | hidden_channels,
302 | kernel_size,
303 | dilation_rate,
304 | n_layers,
305 | p_dropout=0,
306 | gin_channels=0,
307 | mean_only=False):
308 | assert channels % 2 == 0, "channels should be divisible by 2"
309 | super().__init__()
310 | self.channels = channels
311 | self.hidden_channels = hidden_channels
312 | self.kernel_size = kernel_size
313 | self.dilation_rate = dilation_rate
314 | self.n_layers = n_layers
315 | self.half_channels = channels // 2
316 | self.mean_only = mean_only
317 |
318 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
319 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
320 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
321 | self.post.weight.data.zero_()
322 | self.post.bias.data.zero_()
323 |
324 | def forward(self, x, x_mask, g=None, reverse=False):
325 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
326 | h = self.pre(x0) * x_mask
327 | h = self.enc(h, x_mask, g=g)
328 | stats = self.post(h) * x_mask
329 | if not self.mean_only:
330 | m, logs = torch.split(stats, [self.half_channels]*2, 1)
331 | else:
332 | m = stats
333 | logs = torch.zeros_like(m)
334 |
335 | if not reverse:
336 | x1 = m + x1 * torch.exp(logs) * x_mask
337 | x = torch.cat([x0, x1], 1)
338 | logdet = torch.sum(logs, [1,2])
339 | return x, logdet
340 | else:
341 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
342 | x = torch.cat([x0, x1], 1)
343 | return x
344 |
345 |
346 | class ConvFlow(nn.Module):
347 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
348 | super().__init__()
349 | self.in_channels = in_channels
350 | self.filter_channels = filter_channels
351 | self.kernel_size = kernel_size
352 | self.n_layers = n_layers
353 | self.num_bins = num_bins
354 | self.tail_bound = tail_bound
355 | self.half_channels = in_channels // 2
356 |
357 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
358 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
359 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
360 | self.proj.weight.data.zero_()
361 | self.proj.bias.data.zero_()
362 |
363 | def forward(self, x, x_mask, g=None, reverse=False):
364 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
365 | h = self.pre(x0)
366 | h = self.convs(h, x_mask, g=g)
367 | h = self.proj(h) * x_mask
368 |
369 | b, c, t = x0.shape
370 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, c*(3*num_bins-1), t] -> [b, c, t, 3*num_bins-1]
371 |
372 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
373 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
374 | unnormalized_derivatives = h[..., 2 * self.num_bins:]
375 |
376 | x1, logabsdet = piecewise_rational_quadratic_transform(x1,
377 | unnormalized_widths,
378 | unnormalized_heights,
379 | unnormalized_derivatives,
380 | inverse=reverse,
381 | tails='linear',
382 | tail_bound=self.tail_bound
383 | )
384 |
385 | x = torch.cat([x0, x1], 1) * x_mask
386 | logdet = torch.sum(logabsdet * x_mask, [1,2])
387 | if not reverse:
388 | return x, logdet
389 | else:
390 | return x
391 |
--------------------------------------------------------------------------------
/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from .monotonic_align.core import maximum_path_c
4 |
5 |
6 | def maximum_path(neg_cent, mask):
7 | """ Cython optimized version.
8 | neg_cent: [b, t_t, t_s]
9 | mask: [b, t_t, t_s]
10 | """
11 | device = neg_cent.device
12 | dtype = neg_cent.dtype
13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
14 | path = np.zeros(neg_cent.shape, dtype=np.int32)
15 |
16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max)
19 | return torch.from_numpy(path).to(device=device, dtype=dtype)
20 |
--------------------------------------------------------------------------------
/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
1 | cimport cython
2 | from cython.parallel import prange
3 |
4 |
5 | @cython.boundscheck(False)
6 | @cython.wraparound(False)
7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
8 | cdef int x
9 | cdef int y
10 | cdef float v_prev
11 | cdef float v_cur
12 | cdef float tmp
13 | cdef int index = t_x - 1
14 |
15 | for y in range(t_y):
16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
17 | if x == y:
18 | v_cur = max_neg_val
19 | else:
20 | v_cur = value[y-1, x]
21 | if x == 0:
22 | if y == 0:
23 | v_prev = 0.
24 | else:
25 | v_prev = max_neg_val
26 | else:
27 | v_prev = value[y-1, x-1]
28 | value[y, x] += max(v_prev, v_cur)
29 |
30 | for y in range(t_y - 1, -1, -1):
31 | path[y, index] = 1
32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
33 | index = index - 1
34 |
35 |
36 | @cython.boundscheck(False)
37 | @cython.wraparound(False)
38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
39 | cdef int b = paths.shape[0]
40 | cdef int i
41 | for i in prange(b, nogil=True):
42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
43 |
--------------------------------------------------------------------------------
/monotonic_align/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | import numpy
4 |
5 | setup(
6 | name = 'monotonic_align',
7 | ext_modules = cythonize("core.pyx"),
8 | include_dirs=[numpy.get_include()]
9 | )
10 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import text
3 | from utils import load_filepaths_and_text
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--out_extension", default="cleaned")
8 | parser.add_argument("--text_index", default=1, type=int)
9 | parser.add_argument("--filelists", nargs="+", default=["filelists/ljs_audio_text_val_filelist.txt", "filelists/ljs_audio_text_test_filelist.txt"])
10 | parser.add_argument("--text_cleaners", nargs="+", default=["english_cleaners2"])
11 |
12 | args = parser.parse_args()
13 |
14 |
15 | for filelist in args.filelists:
16 | print("START:", filelist)
17 | filepaths_and_text = load_filepaths_and_text(filelist)
18 | for i in range(len(filepaths_and_text)):
19 | original_text = filepaths_and_text[i][args.text_index]
20 | cleaned_text = text._clean_text(original_text, args.text_cleaners)
21 | filepaths_and_text[i][args.text_index] = cleaned_text
22 |
23 | new_filelist = filelist + "." + args.out_extension
24 | with open(new_filelist, "w", encoding="utf-8") as f:
25 | f.writelines(["|".join(x) + "\n" for x in filepaths_and_text])
26 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython==0.29.21
2 | librosa==0.8.0
3 | matplotlib==3.3.1
4 | numpy==1.18.5
5 | phonemizer==2.2.1
6 | scipy==1.5.2
7 | tensorboard==2.3.0
8 | torch==1.6.0
9 | torchvision==0.7.0
10 | Unidecode==1.1.1
11 |
--------------------------------------------------------------------------------
/resources/fig_1a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaywalnut310/vits/2e561ba58618d021b5b8323d3765880f7e0ecfdb/resources/fig_1a.png
--------------------------------------------------------------------------------
/resources/fig_1b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaywalnut310/vits/2e561ba58618d021b5b8323d3765880f7e0ecfdb/resources/fig_1b.png
--------------------------------------------------------------------------------
/resources/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaywalnut310/vits/2e561ba58618d021b5b8323d3765880f7e0ecfdb/resources/training.png
--------------------------------------------------------------------------------
/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | from text import cleaners
3 | from text.symbols import symbols
4 |
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9 |
10 |
11 | def text_to_sequence(text, cleaner_names):
12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
13 | Args:
14 | text: string to convert to a sequence
15 | cleaner_names: names of the cleaner functions to run the text through
16 | Returns:
17 | List of integers corresponding to the symbols in the text
18 | '''
19 | sequence = []
20 |
21 | clean_text = _clean_text(text, cleaner_names)
22 | for symbol in clean_text:
23 | symbol_id = _symbol_to_id[symbol]
24 | sequence += [symbol_id]
25 | return sequence
26 |
27 |
28 | def cleaned_text_to_sequence(cleaned_text):
29 | '''Converts an already-cleaned string to a sequence of IDs corresponding to its symbols.
30 | Args:
31 | cleaned_text: string that has already been run through the text cleaners
32 | Returns:
33 | List of integers corresponding to the symbols in the text
34 | '''
35 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text]
36 | return sequence
37 |
38 |
39 | def sequence_to_text(sequence):
40 | '''Converts a sequence of IDs back to a string'''
41 | result = ''
42 | for symbol_id in sequence:
43 | s = _id_to_symbol[symbol_id]
44 | result += s
45 | return result
46 |
47 |
48 | def _clean_text(text, cleaner_names):
49 | for name in cleaner_names:
50 | cleaner = getattr(cleaners, name, None)
51 | if not cleaner:
52 | raise Exception('Unknown cleaner: %s' % name)
53 | text = cleaner(text)
54 | return text
55 |
--------------------------------------------------------------------------------
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 | import re
16 | from unidecode import unidecode
17 | from phonemizer import phonemize
18 |
19 |
20 | # Regular expression matching whitespace:
21 | _whitespace_re = re.compile(r'\s+')
22 |
23 | # List of (regular expression, replacement) pairs for abbreviations:
24 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
25 | ('mrs', 'misess'),
26 | ('mr', 'mister'),
27 | ('dr', 'doctor'),
28 | ('st', 'saint'),
29 | ('co', 'company'),
30 | ('jr', 'junior'),
31 | ('maj', 'major'),
32 | ('gen', 'general'),
33 | ('drs', 'doctors'),
34 | ('rev', 'reverend'),
35 | ('lt', 'lieutenant'),
36 | ('hon', 'honorable'),
37 | ('sgt', 'sergeant'),
38 | ('capt', 'captain'),
39 | ('esq', 'esquire'),
40 | ('ltd', 'limited'),
41 | ('col', 'colonel'),
42 | ('ft', 'fort'),
43 | ]]
44 |
45 |
46 | def expand_abbreviations(text):
47 | for regex, replacement in _abbreviations:
48 | text = re.sub(regex, replacement, text)
49 | return text
50 |
51 |
52 | def expand_numbers(text):
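53 | return normalize_numbers(text)  # FIXME: normalize_numbers is neither defined nor imported in this module, so calling this raises NameError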
54 |
55 |
56 | def lowercase(text):
57 | return text.lower()
58 |
59 |
60 | def collapse_whitespace(text):
61 | return re.sub(_whitespace_re, ' ', text)
62 |
63 |
64 | def convert_to_ascii(text):
65 | return unidecode(text)
66 |
67 |
68 | def basic_cleaners(text):
69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
70 | text = lowercase(text)
71 | text = collapse_whitespace(text)
72 | return text
73 |
74 |
75 | def transliteration_cleaners(text):
76 | '''Pipeline for non-English text that transliterates to ASCII.'''
77 | text = convert_to_ascii(text)
78 | text = lowercase(text)
79 | text = collapse_whitespace(text)
80 | return text
81 |
82 |
83 | def english_cleaners(text):
84 | '''Pipeline for English text, including abbreviation expansion.'''
85 | text = convert_to_ascii(text)
86 | text = lowercase(text)
87 | text = expand_abbreviations(text)
88 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True)
89 | phonemes = collapse_whitespace(phonemes)
90 | return phonemes
91 |
92 |
93 | def english_cleaners2(text):
94 | '''Pipeline for English text, including abbreviation expansion; preserves punctuation and stress marks.'''
95 | text = convert_to_ascii(text)
96 | text = lowercase(text)
97 | text = expand_abbreviations(text)
98 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, preserve_punctuation=True, with_stress=True)
99 | phonemes = collapse_whitespace(phonemes)
100 | return phonemes
101 |
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Defines the set of symbols used in text input to the model.
5 | '''
6 | _pad = '_'
7 | _punctuation = ';:,.!?¡¿—…"«»“” '
8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
10 |
11 |
12 | # Export all symbols:
13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
14 |
15 | # Special symbol ids
16 | SPACE_ID = symbols.index(" ")
17 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import itertools
5 | import math
6 | import torch
7 | from torch import nn, optim
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader
10 | from torch.utils.tensorboard import SummaryWriter
11 | import torch.multiprocessing as mp
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.cuda.amp import autocast, GradScaler
15 |
16 | import commons
17 | import utils
18 | from data_utils import (
19 | TextAudioLoader,
20 | TextAudioCollate,
21 | DistributedBucketSampler
22 | )
23 | from models import (
24 | SynthesizerTrn,
25 | MultiPeriodDiscriminator,
26 | )
27 | from losses import (
28 | generator_loss,
29 | discriminator_loss,
30 | feature_loss,
31 | kl_loss
32 | )
33 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
34 | from text.symbols import symbols
35 |
36 |
37 | torch.backends.cudnn.benchmark = True
38 | global_step = 0
39 |
40 |
41 | def main():
42 | """Assumes single-node, multi-GPU training only."""
43 | assert torch.cuda.is_available(), "CPU training is not allowed."
44 |
45 | n_gpus = torch.cuda.device_count()
46 | os.environ['MASTER_ADDR'] = 'localhost'
47 | os.environ['MASTER_PORT'] = '8000'  # must be a valid TCP port (1-65535)
48 |
49 | hps = utils.get_hparams()
50 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
51 |
52 |
53 | def run(rank, n_gpus, hps):
54 | global global_step
55 | if rank == 0:
56 | logger = utils.get_logger(hps.model_dir)
57 | logger.info(hps)
58 | utils.check_git_hash(hps.model_dir)
59 | writer = SummaryWriter(log_dir=hps.model_dir)
60 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
61 |
62 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
63 | torch.manual_seed(hps.train.seed)
64 | torch.cuda.set_device(rank)
65 |
66 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
67 | train_sampler = DistributedBucketSampler(
68 | train_dataset,
69 | hps.train.batch_size,
70 | [32,300,400,500,600,700,800,900,1000],
71 | num_replicas=n_gpus,
72 | rank=rank,
73 | shuffle=True)
74 | collate_fn = TextAudioCollate()
75 | train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True,
76 | collate_fn=collate_fn, batch_sampler=train_sampler)
77 | if rank == 0:
78 | eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data)
79 | eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False,
80 | batch_size=hps.train.batch_size, pin_memory=True,
81 | drop_last=False, collate_fn=collate_fn)
82 |
83 | net_g = SynthesizerTrn(
84 | len(symbols),
85 | hps.data.filter_length // 2 + 1,
86 | hps.train.segment_size // hps.data.hop_length,
87 | **hps.model).cuda(rank)
88 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
89 | optim_g = torch.optim.AdamW(
90 | net_g.parameters(),
91 | hps.train.learning_rate,
92 | betas=hps.train.betas,
93 | eps=hps.train.eps)
94 | optim_d = torch.optim.AdamW(
95 | net_d.parameters(),
96 | hps.train.learning_rate,
97 | betas=hps.train.betas,
98 | eps=hps.train.eps)
99 | net_g = DDP(net_g, device_ids=[rank])
100 | net_d = DDP(net_d, device_ids=[rank])
101 |
102 | try:
103 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
104 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)
105 | global_step = (epoch_str - 1) * len(train_loader)
106 | except Exception:
107 | epoch_str = 1
108 | global_step = 0
109 |
110 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
111 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
112 |
113 | scaler = GradScaler(enabled=hps.train.fp16_run)
114 |
115 | for epoch in range(epoch_str, hps.train.epochs + 1):
116 | if rank==0:
117 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
118 | else:
119 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None)
120 | scheduler_g.step()
121 | scheduler_d.step()
122 |
123 |
124 | def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
125 | net_g, net_d = nets
126 | optim_g, optim_d = optims
127 | scheduler_g, scheduler_d = schedulers
128 | train_loader, eval_loader = loaders
129 | if writers is not None:
130 | writer, writer_eval = writers
131 |
132 | train_loader.batch_sampler.set_epoch(epoch)
133 | global global_step
134 |
135 | net_g.train()
136 | net_d.train()
137 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(train_loader):
138 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)
139 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
140 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
141 |
142 | with autocast(enabled=hps.train.fp16_run):
143 | y_hat, l_length, attn, ids_slice, x_mask, z_mask,\
144 | (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths)
145 |
146 | mel = spec_to_mel_torch(
147 | spec,
148 | hps.data.filter_length,
149 | hps.data.n_mel_channels,
150 | hps.data.sampling_rate,
151 | hps.data.mel_fmin,
152 | hps.data.mel_fmax)
153 | y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
154 | y_hat_mel = mel_spectrogram_torch(
155 | y_hat.squeeze(1),
156 | hps.data.filter_length,
157 | hps.data.n_mel_channels,
158 | hps.data.sampling_rate,
159 | hps.data.hop_length,
160 | hps.data.win_length,
161 | hps.data.mel_fmin,
162 | hps.data.mel_fmax
163 | )
164 |
165 | y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
166 |
167 | # Discriminator
168 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
169 | with autocast(enabled=False):
170 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
171 | loss_disc_all = loss_disc
172 | optim_d.zero_grad()
173 | scaler.scale(loss_disc_all).backward()
174 | scaler.unscale_(optim_d)
175 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
176 | scaler.step(optim_d)
177 |
178 | with autocast(enabled=hps.train.fp16_run):
179 | # Generator
180 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
181 | with autocast(enabled=False):
182 | loss_dur = torch.sum(l_length.float())
183 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
184 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
185 |
186 | loss_fm = feature_loss(fmap_r, fmap_g)
187 | loss_gen, losses_gen = generator_loss(y_d_hat_g)
188 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
189 | optim_g.zero_grad()
190 | scaler.scale(loss_gen_all).backward()
191 | scaler.unscale_(optim_g)
192 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
193 | scaler.step(optim_g)
194 | scaler.update()
195 |
196 | if rank==0:
197 | if global_step % hps.train.log_interval == 0:
198 | lr = optim_g.param_groups[0]['lr']
199 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
200 | logger.info('Train Epoch: {} [{:.0f}%]'.format(
201 | epoch,
202 | 100. * batch_idx / len(train_loader)))
203 | logger.info([x.item() for x in losses] + [global_step, lr])
204 |
205 | scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
206 | scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl})
207 |
208 | scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
209 | scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
210 | scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
211 | image_dict = {
212 | "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
213 | "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
214 | "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
215 | "all/attn": utils.plot_alignment_to_numpy(attn[0,0].data.cpu().numpy())
216 | }
217 | utils.summarize(
218 | writer=writer,
219 | global_step=global_step,
220 | images=image_dict,
221 | scalars=scalar_dict)
222 |
223 | if global_step % hps.train.eval_interval == 0:
224 | evaluate(hps, net_g, eval_loader, writer_eval)
225 | utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
226 | utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
227 | global_step += 1
228 |
229 | if rank == 0:
230 | logger.info('====> Epoch: {}'.format(epoch))
231 |
232 |
233 | def evaluate(hps, generator, eval_loader, writer_eval):
234 | generator.eval()
235 | with torch.no_grad():
236 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(eval_loader):
237 | x, x_lengths = x.cuda(0), x_lengths.cuda(0)
238 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
239 | y, y_lengths = y.cuda(0), y_lengths.cuda(0)
240 |
241 | # keep only the first sample of the first batch
242 | x = x[:1]
243 | x_lengths = x_lengths[:1]
244 | spec = spec[:1]
245 | spec_lengths = spec_lengths[:1]
246 | y = y[:1]
247 | y_lengths = y_lengths[:1]
248 | break
249 | y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, max_len=1000)
250 | y_hat_lengths = mask.sum([1,2]).long() * hps.data.hop_length
251 |
252 | mel = spec_to_mel_torch(
253 | spec,
254 | hps.data.filter_length,
255 | hps.data.n_mel_channels,
256 | hps.data.sampling_rate,
257 | hps.data.mel_fmin,
258 | hps.data.mel_fmax)
259 | y_hat_mel = mel_spectrogram_torch(
260 | y_hat.squeeze(1).float(),
261 | hps.data.filter_length,
262 | hps.data.n_mel_channels,
263 | hps.data.sampling_rate,
264 | hps.data.hop_length,
265 | hps.data.win_length,
266 | hps.data.mel_fmin,
267 | hps.data.mel_fmax
268 | )
269 | image_dict = {
270 | "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
271 | }
272 | audio_dict = {
273 | "gen/audio": y_hat[0,:,:y_hat_lengths[0]]
274 | }
275 | if global_step == 0:
276 | image_dict.update({"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
277 | audio_dict.update({"gt/audio": y[0,:,:y_lengths[0]]})
278 |
279 | utils.summarize(
280 | writer=writer_eval,
281 | global_step=global_step,
282 | images=image_dict,
283 | audios=audio_dict,
284 | audio_sampling_rate=hps.data.sampling_rate
285 | )
286 | generator.train()
287 |
288 |
289 | if __name__ == "__main__":
290 | main()
291 |
--------------------------------------------------------------------------------
/train_ms.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import itertools
5 | import math
6 | import torch
7 | from torch import nn, optim
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader
10 | from torch.utils.tensorboard import SummaryWriter
11 | import torch.multiprocessing as mp
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.cuda.amp import autocast, GradScaler
15 |
16 | import commons
17 | import utils
18 | from data_utils import (
19 | TextAudioSpeakerLoader,
20 | TextAudioSpeakerCollate,
21 | DistributedBucketSampler
22 | )
23 | from models import (
24 | SynthesizerTrn,
25 | MultiPeriodDiscriminator,
26 | )
27 | from losses import (
28 | generator_loss,
29 | discriminator_loss,
30 | feature_loss,
31 | kl_loss
32 | )
33 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
34 | from text.symbols import symbols
35 |
36 |
37 | torch.backends.cudnn.benchmark = True
38 | global_step = 0
39 |
40 |
41 | def main():
42 | """Assumes single-node, multi-GPU training only."""
43 | assert torch.cuda.is_available(), "CPU training is not allowed."
44 |
45 | n_gpus = torch.cuda.device_count()
46 | os.environ['MASTER_ADDR'] = 'localhost'
47 | os.environ['MASTER_PORT'] = '8000'  # must be a valid TCP port (1-65535); '80000' is out of range
48 |
49 | hps = utils.get_hparams()
50 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
51 |
52 |
53 | def run(rank, n_gpus, hps):
54 | global global_step
55 | if rank == 0:
56 | logger = utils.get_logger(hps.model_dir)
57 | logger.info(hps)
58 | utils.check_git_hash(hps.model_dir)
59 | writer = SummaryWriter(log_dir=hps.model_dir)
60 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
61 |
62 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
63 | torch.manual_seed(hps.train.seed)
64 | torch.cuda.set_device(rank)
65 |
66 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
67 | train_sampler = DistributedBucketSampler(
68 | train_dataset,
69 | hps.train.batch_size,
70 | [32,300,400,500,600,700,800,900,1000],
71 | num_replicas=n_gpus,
72 | rank=rank,
73 | shuffle=True)
74 | collate_fn = TextAudioSpeakerCollate()
75 | train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True,
76 | collate_fn=collate_fn, batch_sampler=train_sampler)
77 | if rank == 0:
78 | eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
79 | eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False,
80 | batch_size=hps.train.batch_size, pin_memory=True,
81 | drop_last=False, collate_fn=collate_fn)
82 |
83 | net_g = SynthesizerTrn(
84 | len(symbols),
85 | hps.data.filter_length // 2 + 1,
86 | hps.train.segment_size // hps.data.hop_length,
87 | n_speakers=hps.data.n_speakers,
88 | **hps.model).cuda(rank)
89 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
90 | optim_g = torch.optim.AdamW(
91 | net_g.parameters(),
92 | hps.train.learning_rate,
93 | betas=hps.train.betas,
94 | eps=hps.train.eps)
95 | optim_d = torch.optim.AdamW(
96 | net_d.parameters(),
97 | hps.train.learning_rate,
98 | betas=hps.train.betas,
99 | eps=hps.train.eps)
100 | net_g = DDP(net_g, device_ids=[rank])
101 | net_d = DDP(net_d, device_ids=[rank])
102 |
103 | try:
104 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
105 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)
106 | global_step = (epoch_str - 1) * len(train_loader)
107 | except Exception:  # no usable checkpoint found: start from scratch
108 | epoch_str = 1
109 | global_step = 0
110 |
111 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
112 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
113 |
114 | scaler = GradScaler(enabled=hps.train.fp16_run)
115 |
116 | for epoch in range(epoch_str, hps.train.epochs + 1):
117 | if rank==0:
118 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
119 | else:
120 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None)
121 | scheduler_g.step()
122 | scheduler_d.step()
123 |
124 |
125 | def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
126 | net_g, net_d = nets
127 | optim_g, optim_d = optims
128 | scheduler_g, scheduler_d = schedulers
129 | train_loader, eval_loader = loaders
130 | if writers is not None:
131 | writer, writer_eval = writers
132 |
133 | train_loader.batch_sampler.set_epoch(epoch)
134 | global global_step
135 |
136 | net_g.train()
137 | net_d.train()
138 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(train_loader):
139 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)
140 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
141 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
142 | speakers = speakers.cuda(rank, non_blocking=True)
143 |
144 | with autocast(enabled=hps.train.fp16_run):
145 | y_hat, l_length, attn, ids_slice, x_mask, z_mask,\
146 | (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers)
147 |
148 | mel = spec_to_mel_torch(
149 | spec,
150 | hps.data.filter_length,
151 | hps.data.n_mel_channels,
152 | hps.data.sampling_rate,
153 | hps.data.mel_fmin,
154 | hps.data.mel_fmax)
155 | y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
156 | y_hat_mel = mel_spectrogram_torch(
157 | y_hat.squeeze(1),
158 | hps.data.filter_length,
159 | hps.data.n_mel_channels,
160 | hps.data.sampling_rate,
161 | hps.data.hop_length,
162 | hps.data.win_length,
163 | hps.data.mel_fmin,
164 | hps.data.mel_fmax
165 | )
166 |
167 | y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
168 |
169 | # Discriminator
170 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
171 | with autocast(enabled=False):
172 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
173 | loss_disc_all = loss_disc
174 | optim_d.zero_grad()
175 | scaler.scale(loss_disc_all).backward()
176 | scaler.unscale_(optim_d)
177 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
178 | scaler.step(optim_d)
179 |
180 | with autocast(enabled=hps.train.fp16_run):
181 | # Generator
182 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
183 | with autocast(enabled=False):
184 | loss_dur = torch.sum(l_length.float())
185 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
186 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
187 |
188 | loss_fm = feature_loss(fmap_r, fmap_g)
189 | loss_gen, losses_gen = generator_loss(y_d_hat_g)
190 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
191 | optim_g.zero_grad()
192 | scaler.scale(loss_gen_all).backward()
193 | scaler.unscale_(optim_g)
194 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
195 | scaler.step(optim_g)
196 | scaler.update()
197 |
198 | if rank==0:
199 | if global_step % hps.train.log_interval == 0:
200 | lr = optim_g.param_groups[0]['lr']
201 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
202 | logger.info('Train Epoch: {} [{:.0f}%]'.format(
203 | epoch,
204 | 100. * batch_idx / len(train_loader)))
205 | logger.info([x.item() for x in losses] + [global_step, lr])
206 |
207 | scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
208 | scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl})
209 |
210 | scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
211 | scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
212 | scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
213 | image_dict = {
214 | "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
215 | "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
216 | "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
217 | "all/attn": utils.plot_alignment_to_numpy(attn[0,0].data.cpu().numpy())
218 | }
219 | utils.summarize(
220 | writer=writer,
221 | global_step=global_step,
222 | images=image_dict,
223 | scalars=scalar_dict)
224 |
225 | if global_step % hps.train.eval_interval == 0:
226 | evaluate(hps, net_g, eval_loader, writer_eval)
227 | utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
228 | utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
229 | global_step += 1
230 |
231 | if rank == 0:
232 | logger.info('====> Epoch: {}'.format(epoch))
233 |
234 |
235 | def evaluate(hps, generator, eval_loader, writer_eval):
236 | generator.eval()
237 | with torch.no_grad():
238 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(eval_loader):
239 | x, x_lengths = x.cuda(0), x_lengths.cuda(0)
240 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
241 | y, y_lengths = y.cuda(0), y_lengths.cuda(0)
242 | speakers = speakers.cuda(0)
243 |
244 | # evaluate only the first sample of the first batch
245 | x = x[:1]
246 | x_lengths = x_lengths[:1]
247 | spec = spec[:1]
248 | spec_lengths = spec_lengths[:1]
249 | y = y[:1]
250 | y_lengths = y_lengths[:1]
251 | speakers = speakers[:1]
252 | break
253 | y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, max_len=1000)
254 | y_hat_lengths = mask.sum([1,2]).long() * hps.data.hop_length
255 |
256 | mel = spec_to_mel_torch(
257 | spec,
258 | hps.data.filter_length,
259 | hps.data.n_mel_channels,
260 | hps.data.sampling_rate,
261 | hps.data.mel_fmin,
262 | hps.data.mel_fmax)
263 | y_hat_mel = mel_spectrogram_torch(
264 | y_hat.squeeze(1).float(),
265 | hps.data.filter_length,
266 | hps.data.n_mel_channels,
267 | hps.data.sampling_rate,
268 | hps.data.hop_length,
269 | hps.data.win_length,
270 | hps.data.mel_fmin,
271 | hps.data.mel_fmax
272 | )
273 | image_dict = {
274 | "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
275 | }
276 | audio_dict = {
277 | "gen/audio": y_hat[0,:,:y_hat_lengths[0]]
278 | }
279 | if global_step == 0:
280 | image_dict.update({"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
281 | audio_dict.update({"gt/audio": y[0,:,:y_lengths[0]]})
282 |
283 | utils.summarize(
284 | writer=writer_eval,
285 | global_step=global_step,
286 | images=image_dict,
287 | audios=audio_dict,
288 | audio_sampling_rate=hps.data.sampling_rate
289 | )
290 | generator.train()
291 |
292 |
293 | if __name__ == "__main__":
294 | main()
295 |
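`main()` sets `MASTER_PORT` for the `env://` rendezvous. TCP ports are 16-bit, so only values from 1 to 65535 can be bound. The following is a minimal sketch of how a port could be validated, or picked automatically, before the workers are spawned. The helper names `is_valid_tcp_port` and `find_free_port` are hypothetical and not part of this repository.

```python
import socket


def is_valid_tcp_port(port: str) -> bool:
    """Return True if `port` is a decimal string in the usable TCP range 1..65535."""
    return port.isascii() and port.isdigit() and 1 <= int(port) <= 65535


def find_free_port(host: str = "localhost") -> int:
    """Ask the OS for a currently unused port by binding to port 0 (hypothetical helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        # getsockname() returns (address, port) for an AF_INET socket.
        return sock.getsockname()[1]
```

One possible use is `os.environ['MASTER_PORT'] = str(find_free_port())`, which would reduce clashes between concurrent runs on one machine. The port could in principle be taken by another process between this call and the rendezvous, so this is a convenience rather than a guarantee.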
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | DEFAULT_MIN_BIN_WIDTH = 1e-3
8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
9 | DEFAULT_MIN_DERIVATIVE = 1e-3
10 |
11 |
12 | def piecewise_rational_quadratic_transform(inputs,
13 | unnormalized_widths,
14 | unnormalized_heights,
15 | unnormalized_derivatives,
16 | inverse=False,
17 | tails=None,
18 | tail_bound=1.,
19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
21 | min_derivative=DEFAULT_MIN_DERIVATIVE):
22 |
23 | if tails is None:
24 | spline_fn = rational_quadratic_spline
25 | spline_kwargs = {}
26 | else:
27 | spline_fn = unconstrained_rational_quadratic_spline
28 | spline_kwargs = {
29 | 'tails': tails,
30 | 'tail_bound': tail_bound
31 | }
32 |
33 | outputs, logabsdet = spline_fn(
34 | inputs=inputs,
35 | unnormalized_widths=unnormalized_widths,
36 | unnormalized_heights=unnormalized_heights,
37 | unnormalized_derivatives=unnormalized_derivatives,
38 | inverse=inverse,
39 | min_bin_width=min_bin_width,
40 | min_bin_height=min_bin_height,
41 | min_derivative=min_derivative,
42 | **spline_kwargs
43 | )
44 | return outputs, logabsdet
45 |
46 |
47 | def searchsorted(bin_locations, inputs, eps=1e-6):
48 | bin_locations[..., -1] += eps
49 | return torch.sum(
50 | inputs[..., None] >= bin_locations,
51 | dim=-1
52 | ) - 1
53 |
54 |
55 | def unconstrained_rational_quadratic_spline(inputs,
56 | unnormalized_widths,
57 | unnormalized_heights,
58 | unnormalized_derivatives,
59 | inverse=False,
60 | tails='linear',
61 | tail_bound=1.,
62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
64 | min_derivative=DEFAULT_MIN_DERIVATIVE):
65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66 | outside_interval_mask = ~inside_interval_mask
67 |
68 | outputs = torch.zeros_like(inputs)
69 | logabsdet = torch.zeros_like(inputs)
70 |
71 | if tails == 'linear':
72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73 | constant = np.log(np.exp(1 - min_derivative) - 1)
74 | unnormalized_derivatives[..., 0] = constant
75 | unnormalized_derivatives[..., -1] = constant
76 |
77 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
78 | logabsdet[outside_interval_mask] = 0
79 | else:
80 | raise RuntimeError('{} tails are not implemented.'.format(tails))
81 |
82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83 | inputs=inputs[inside_interval_mask],
84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87 | inverse=inverse,
88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
89 | min_bin_width=min_bin_width,
90 | min_bin_height=min_bin_height,
91 | min_derivative=min_derivative
92 | )
93 |
94 | return outputs, logabsdet
95 |
96 | def rational_quadratic_spline(inputs,
97 | unnormalized_widths,
98 | unnormalized_heights,
99 | unnormalized_derivatives,
100 | inverse=False,
101 | left=0., right=1., bottom=0., top=1.,
102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
104 | min_derivative=DEFAULT_MIN_DERIVATIVE):
105 | if torch.min(inputs) < left or torch.max(inputs) > right:
106 | raise ValueError('Input to a transform is not within its domain')
107 |
108 | num_bins = unnormalized_widths.shape[-1]
109 |
110 | if min_bin_width * num_bins > 1.0:
111 | raise ValueError('Minimal bin width too large for the number of bins')
112 | if min_bin_height * num_bins > 1.0:
113 | raise ValueError('Minimal bin height too large for the number of bins')
114 |
115 | widths = F.softmax(unnormalized_widths, dim=-1)
116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
117 | cumwidths = torch.cumsum(widths, dim=-1)
118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
119 | cumwidths = (right - left) * cumwidths + left
120 | cumwidths[..., 0] = left
121 | cumwidths[..., -1] = right
122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
123 |
124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
125 |
126 | heights = F.softmax(unnormalized_heights, dim=-1)
127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
128 | cumheights = torch.cumsum(heights, dim=-1)
129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
130 | cumheights = (top - bottom) * cumheights + bottom
131 | cumheights[..., 0] = bottom
132 | cumheights[..., -1] = top
133 | heights = cumheights[..., 1:] - cumheights[..., :-1]
134 |
135 | if inverse:
136 | bin_idx = searchsorted(cumheights, inputs)[..., None]
137 | else:
138 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
139 |
140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
142 |
143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
144 | delta = heights / widths
145 | input_delta = delta.gather(-1, bin_idx)[..., 0]
146 |
147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
149 |
150 | input_heights = heights.gather(-1, bin_idx)[..., 0]
151 |
152 | if inverse:
153 | a = (((inputs - input_cumheights) * (input_derivatives
154 | + input_derivatives_plus_one
155 | - 2 * input_delta)
156 | + input_heights * (input_delta - input_derivatives)))
157 | b = (input_heights * input_derivatives
158 | - (inputs - input_cumheights) * (input_derivatives
159 | + input_derivatives_plus_one
160 | - 2 * input_delta))
161 | c = - input_delta * (inputs - input_cumheights)
162 |
163 | discriminant = b.pow(2) - 4 * a * c
164 | assert (discriminant >= 0).all()
165 |
166 | root = (2 * c) / (-b - torch.sqrt(discriminant))
167 | outputs = root * input_bin_widths + input_cumwidths
168 |
169 | theta_one_minus_theta = root * (1 - root)
170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
171 | * theta_one_minus_theta)
172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
173 | + 2 * input_delta * theta_one_minus_theta
174 | + input_derivatives * (1 - root).pow(2))
175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
176 |
177 | return outputs, -logabsdet
178 | else:
179 | theta = (inputs - input_cumwidths) / input_bin_widths
180 | theta_one_minus_theta = theta * (1 - theta)
181 |
182 | numerator = input_heights * (input_delta * theta.pow(2)
183 | + input_derivatives * theta_one_minus_theta)
184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185 | * theta_one_minus_theta)
186 | outputs = input_cumheights + numerator / denominator
187 |
188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
189 | + 2 * input_delta * theta_one_minus_theta
190 | + input_derivatives * (1 - theta).pow(2))
191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
192 |
193 | return outputs, logabsdet
194 |
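The forward and inverse branches of `rational_quadratic_spline` are meant to invert each other within a bin. As a rough scalar sketch of that relationship, the functions below re-implement the two branches for a single bin using only `math`. The function names and the example bin parameters are illustrative assumptions, and the log-determinant is omitted.

```python
import math


def rq_forward(x, x_k, w, y_k, h, d0, d1):
    """Forward map of one rational-quadratic bin (scalar form of the non-inverse branch)."""
    s = h / w                  # bin slope, `input_delta` in the source
    theta = (x - x_k) / w      # relative position inside the bin, in [0, 1]
    t1t = theta * (1 - theta)
    numerator = h * (s * theta ** 2 + d0 * t1t)
    denominator = s + (d0 + d1 - 2 * s) * t1t
    return y_k + numerator / denominator


def rq_inverse(y, x_k, w, y_k, h, d0, d1):
    """Inverse map of the same bin: solve the quadratic in theta, as the inverse branch does."""
    s = h / w
    dy = y - y_k
    a = dy * (d0 + d1 - 2 * s) + h * (s - d0)
    b = h * d0 - dy * (d0 + d1 - 2 * s)
    c = -s * dy
    discriminant = b * b - 4 * a * c
    # Same numerically stable root form as the source: 2c / (-b - sqrt(disc)).
    root = (2 * c) / (-b - math.sqrt(discriminant))
    return root * w + x_k


# Illustrative bin: left edge 0.0, width 0.5, bottom 0.0, height 0.3,
# knot derivatives 1.2 (left) and 0.7 (right).
bin_params = (0.0, 0.5, 0.0, 0.3, 1.2, 0.7)
y = rq_forward(0.2, *bin_params)
x_back = rq_inverse(y, *bin_params)
```

With these parameters `x_back` should agree with the original input `0.2` up to floating-point error, and the bin edges map to the bin's bottom and top.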
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import argparse
5 | import logging
6 | import json
7 | import subprocess
8 | import numpy as np
9 | from scipy.io.wavfile import read
10 | import torch
11 |
12 | MATPLOTLIB_FLAG = False
13 |
14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15 | logger = logging
16 |
17 |
18 | def load_checkpoint(checkpoint_path, model, optimizer=None):
19 | assert os.path.isfile(checkpoint_path)
20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21 | iteration = checkpoint_dict['iteration']
22 | learning_rate = checkpoint_dict['learning_rate']
23 | if optimizer is not None:
24 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
25 | saved_state_dict = checkpoint_dict['model']
26 | if hasattr(model, 'module'):
27 | state_dict = model.module.state_dict()
28 | else:
29 | state_dict = model.state_dict()
30 | new_state_dict = {}
31 | for k, v in state_dict.items():
32 | try:
33 | new_state_dict[k] = saved_state_dict[k]
34 | except KeyError:
35 | logger.info("%s is not in the checkpoint" % k)
36 | new_state_dict[k] = v
37 | if hasattr(model, 'module'):
38 | model.module.load_state_dict(new_state_dict)
39 | else:
40 | model.load_state_dict(new_state_dict)
41 | logger.info("Loaded checkpoint '{}' (iteration {})" .format(
42 | checkpoint_path, iteration))
43 | return model, optimizer, learning_rate, iteration
44 |
45 |
46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
47 | logger.info("Saving model and optimizer state at iteration {} to {}".format(
48 | iteration, checkpoint_path))
49 | if hasattr(model, 'module'):
50 | state_dict = model.module.state_dict()
51 | else:
52 | state_dict = model.state_dict()
53 | torch.save({'model': state_dict,
54 | 'iteration': iteration,
55 | 'optimizer': optimizer.state_dict(),
56 | 'learning_rate': learning_rate}, checkpoint_path)
57 |
58 |
59 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
60 | for k, v in scalars.items():
61 | writer.add_scalar(k, v, global_step)
62 | for k, v in histograms.items():
63 | writer.add_histogram(k, v, global_step)
64 | for k, v in images.items():
65 | writer.add_image(k, v, global_step, dataformats='HWC')
66 | for k, v in audios.items():
67 | writer.add_audio(k, v, global_step, audio_sampling_rate)
68 |
69 |
70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
71 | f_list = glob.glob(os.path.join(dir_path, regex))
72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
73 | x = f_list[-1]
74 | print(x)
75 | return x
76 |
77 |
78 | def plot_spectrogram_to_numpy(spectrogram):
79 | global MATPLOTLIB_FLAG
80 | if not MATPLOTLIB_FLAG:
81 | import matplotlib
82 | matplotlib.use("Agg")
83 | MATPLOTLIB_FLAG = True
84 | mpl_logger = logging.getLogger('matplotlib')
85 | mpl_logger.setLevel(logging.WARNING)
86 | import matplotlib.pylab as plt
87 | import numpy as np
88 |
89 | fig, ax = plt.subplots(figsize=(10,2))
90 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
91 | interpolation='none')
92 | plt.colorbar(im, ax=ax)
93 | plt.xlabel("Frames")
94 | plt.ylabel("Channels")
95 | plt.tight_layout()
96 |
97 | fig.canvas.draw()
98 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
99 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
100 | plt.close()
101 | return data
102 |
103 |
104 | def plot_alignment_to_numpy(alignment, info=None):
105 | global MATPLOTLIB_FLAG
106 | if not MATPLOTLIB_FLAG:
107 | import matplotlib
108 | matplotlib.use("Agg")
109 | MATPLOTLIB_FLAG = True
110 | mpl_logger = logging.getLogger('matplotlib')
111 | mpl_logger.setLevel(logging.WARNING)
112 | import matplotlib.pylab as plt
113 | import numpy as np
114 |
115 | fig, ax = plt.subplots(figsize=(6, 4))
116 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
117 | interpolation='none')
118 | fig.colorbar(im, ax=ax)
119 | xlabel = 'Decoder timestep'
120 | if info is not None:
121 | xlabel += '\n\n' + info
122 | plt.xlabel(xlabel)
123 | plt.ylabel('Encoder timestep')
124 | plt.tight_layout()
125 |
126 | fig.canvas.draw()
127 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
128 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
129 | plt.close()
130 | return data
131 |
132 |
133 | def load_wav_to_torch(full_path):
134 | sampling_rate, data = read(full_path)
135 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate
136 |
137 |
138 | def load_filepaths_and_text(filename, split="|"):
139 | with open(filename, encoding='utf-8') as f:
140 | filepaths_and_text = [line.strip().split(split) for line in f]
141 | return filepaths_and_text
142 |
143 |
144 | def get_hparams(init=True):
145 | parser = argparse.ArgumentParser()
146 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
147 | help='JSON file for configuration')
148 | parser.add_argument('-m', '--model', type=str, required=True,
149 | help='Model name')
150 |
151 | args = parser.parse_args()
152 | model_dir = os.path.join("./logs", args.model)
153 |
154 | if not os.path.exists(model_dir):
155 | os.makedirs(model_dir)
156 |
157 | config_path = args.config
158 | config_save_path = os.path.join(model_dir, "config.json")
159 | if init:
160 | with open(config_path, "r") as f:
161 | data = f.read()
162 | with open(config_save_path, "w") as f:
163 | f.write(data)
164 | else:
165 | with open(config_save_path, "r") as f:
166 | data = f.read()
167 | config = json.loads(data)
168 |
169 | hparams = HParams(**config)
170 | hparams.model_dir = model_dir
171 | return hparams
172 |
173 |
174 | def get_hparams_from_dir(model_dir):
175 | config_save_path = os.path.join(model_dir, "config.json")
176 | with open(config_save_path, "r") as f:
177 | data = f.read()
178 | config = json.loads(data)
179 |
180 | hparams = HParams(**config)
181 | hparams.model_dir = model_dir
182 | return hparams
183 |
184 |
185 | def get_hparams_from_file(config_path):
186 | with open(config_path, "r") as f:
187 | data = f.read()
188 | config = json.loads(data)
189 |
190 | hparams = HParams(**config)
191 | return hparams
192 |
193 |
194 | def check_git_hash(model_dir):
195 | source_dir = os.path.dirname(os.path.realpath(__file__))
196 | if not os.path.exists(os.path.join(source_dir, ".git")):
197 | logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
198 | source_dir
199 | ))
200 | return
201 |
202 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
203 |
204 | path = os.path.join(model_dir, "githash")
205 | if os.path.exists(path):
206 | saved_hash = open(path).read()
207 | if saved_hash != cur_hash:
208 | logger.warning("git hash values are different. {}(saved) != {}(current)".format(
209 | saved_hash[:8], cur_hash[:8]))
210 | else:
211 | open(path, "w").write(cur_hash)
212 |
213 |
214 | def get_logger(model_dir, filename="train.log"):
215 | global logger
216 | logger = logging.getLogger(os.path.basename(model_dir))
217 | logger.setLevel(logging.DEBUG)
218 |
219 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
220 | if not os.path.exists(model_dir):
221 | os.makedirs(model_dir)
222 | h = logging.FileHandler(os.path.join(model_dir, filename))
223 | h.setLevel(logging.DEBUG)
224 | h.setFormatter(formatter)
225 | logger.addHandler(h)
226 | return logger
227 |
228 |
229 | class HParams():
230 | def __init__(self, **kwargs):
231 | for k, v in kwargs.items():
232 | if isinstance(v, dict):
233 | v = HParams(**v)
234 | self[k] = v
235 |
236 | def keys(self):
237 | return self.__dict__.keys()
238 |
239 | def items(self):
240 | return self.__dict__.items()
241 |
242 | def values(self):
243 | return self.__dict__.values()
244 |
245 | def __len__(self):
246 | return len(self.__dict__)
247 |
248 | def __getitem__(self, key):
249 | return getattr(self, key)
250 |
251 | def __setitem__(self, key, value):
252 | return setattr(self, key, value)
253 |
254 | def __contains__(self, key):
255 | return key in self.__dict__
256 |
257 | def __repr__(self):
258 | return self.__dict__.__repr__()
259 |
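`latest_checkpoint_path` orders files by the integer formed from every digit in the path rather than by name, so `G_1000.pth` sorts after `G_200.pth`. A small sketch of that sort key on made-up paths follows; the helper name `digit_key` and the file names are hypothetical.

```python
def digit_key(path: str) -> int:
    """Join all digits in `path` and read them as one integer, like the sort key in `latest_checkpoint_path`."""
    return int("".join(filter(str.isdigit, path)))


# Hypothetical checkpoint paths, deliberately not in step order.
paths = ["logs/demo/G_200.pth", "logs/demo/G_1000.pth", "logs/demo/G_0.pth"]

latest_by_digits = sorted(paths, key=digit_key)[-1]  # numeric ordering
latest_by_name = sorted(paths)[-1]                   # plain string ordering
```

Here the digit-based key selects `G_1000.pth`, while plain string sorting would select `G_200.pth`. Note that a path containing no digits at all would make `int("")` raise `ValueError`.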
--------------------------------------------------------------------------------