├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── attentions.py
├── colab_requirements.txt
├── commons.py
├── configs
│   ├── vits3_ljs_base.json
│   ├── vits3_ljs_nosdp.json
│   ├── vits3_vctk_base.json
│   ├── vits3_vctk_base_pr.json
│   └── vits3_vctk_standard.json
├── data_utils.py
├── export_onnx.py
├── filelists
│   ├── ljs_audio_text_test_filelist.txt
│   ├── ljs_audio_text_test_filelist.txt.cleaned
│   ├── ljs_audio_text_train_filelist.txt
│   ├── ljs_audio_text_train_filelist.txt.cleaned
│   ├── ljs_audio_text_val_filelist.txt
│   ├── ljs_audio_text_val_filelist.txt.cleaned
│   ├── vctk_audio_sid_text_test_filelist.txt
│   ├── vctk_audio_sid_text_test_filelist.txt.cleaned
│   ├── vctk_audio_sid_text_train_filelist.txt
│   ├── vctk_audio_sid_text_train_filelist.txt.cleaned
│   ├── vctk_audio_sid_text_train_filelist_new.txt
│   ├── vctk_audio_sid_text_train_filelist_new.txt.cleaned
│   ├── vctk_audio_sid_text_val_filelist.txt
│   ├── vctk_audio_sid_text_val_filelist.txt.cleaned
│   └── vctk_audio_sid_text_val_filelist_new.txt.cleaned
├── infer_onnx.py
├── inference.ipynb
├── inference.py
├── inference_ms.py
├── losses.py
├── mel_processing.py
├── models.py
├── modules.py
├── monotonic_align
│   ├── __init__.py
│   ├── core.pyx
│   └── setup.py
├── preprocess.py
├── preprocess_audio.py
├── requirements.txt
├── text
│   ├── LICENSE
│   ├── __init__.py
│   ├── cleaners.py
│   └── symbols.py
├── train.py
├── train_ms.py
├── transforms.py
├── utils.py
└── webui.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | DUMMY1
2 | DUMMY2
3 | DUMMY3
4 | logs
5 | __pycache__
6 | .ipynb_checkpoints
7 | .*.swp
8 | 
9 | build
10 | *.c
11 | monotonic_align/monotonic_align
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 p0p
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Homebrewed VITS-3 with an extra flow to improve the text encoder's projected normalizing-flow distribution and prior loss (WIP 🚧)
2 | ### Inspired by VITS-2 and GradTTS
3 | ## Prerequisites
4 | 1. Python >= 3.10
5 | 2. Clone this repository
6 | 3. Install Python requirements. Please refer to [requirements.txt](requirements.txt)
7 |    1. You may need to install espeak first: `apt-get install espeak`
8 | 4. Download datasets
9 |    1. Download and extract the LJ Speech dataset, then rename or create a link to the dataset folder: `ln -s /path/to/LJSpeech-1.1/wavs DUMMY1`
10 |    2. For the multi-speaker setting, download and extract the VCTK dataset, downsample the wav files to 22050 Hz, then rename or create a link to the dataset folder: `ln -s /path/to/VCTK-Corpus/downsampled_wavs DUMMY2`
11 | 5. Build Monotonic Alignment Search and run preprocessing if you use your own datasets.
12 | 
13 | ```sh
14 | # Cython-version Monotonic Alignment Search
15 | cd monotonic_align
16 | python setup.py build_ext --inplace
17 | 
18 | # Preprocessing (g2p) for your own datasets. Preprocessed phonemes for LJ Speech and VCTK have already been provided.
19 | # python preprocess.py --text_index 1 --filelists filelists/ljs_audio_text_train_filelist.txt filelists/ljs_audio_text_val_filelist.txt filelists/ljs_audio_text_test_filelist.txt
20 | # python preprocess.py --text_index 2 --filelists filelists/vctk_audio_sid_text_train_filelist.txt filelists/vctk_audio_sid_text_val_filelist.txt filelists/vctk_audio_sid_text_test_filelist.txt
21 | ```
22 | 
23 | ## How to run (dry-run)
24 | - model forward pass (dry-run); a sketch of the subsequent loss computation is given at the end of this README
25 | ```python
26 | import torch
27 | from models import SynthesizerTrn
28 | 
29 | net_g = SynthesizerTrn(
30 |     n_vocab=256,
31 |     spec_channels=80,
32 |     segment_size=8192,
33 |     inter_channels=192,
34 |     hidden_channels=192,
35 |     filter_channels=768,
36 |     n_heads=2,
37 |     n_layers=6,
38 |     kernel_size=3,
39 |     p_dropout=0.1,
40 |     resblock="1",
41 |     resblock_kernel_sizes=[3, 7, 11],
42 |     resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
43 |     upsample_rates=[8, 8, 2, 2],
44 |     upsample_initial_channel=512,
45 |     upsample_kernel_sizes=[16, 16, 4, 4],
46 |     n_speakers=0,
47 |     gin_channels=0,
48 |     use_sdp=True,
49 |     use_transformer_flows=True,
50 |     # (choose from "pre_conv", "fft", "mono_layer_inter_residual", "mono_layer_post_residual")
51 |     transformer_flow_type="fft",
52 |     use_spk_conditioned_encoder=True,
53 |     use_noise_scaled_mas=True,
54 |     use_duration_discriminator=True,
55 | )
56 | 
57 | x = torch.LongTensor([[1, 2, 3], [4, 5, 6]])  # token ids
58 | x_lengths = torch.LongTensor([3, 2])  # token lengths
59 | y = torch.randn(2, 80, 100)  # mel spectrograms
60 | y_lengths = torch.LongTensor([100, 80])  # mel spectrogram lengths
61 | 
62 | net_g(
63 |     x=x,
64 |     x_lengths=x_lengths,
65 |     y=y,
66 |     y_lengths=y_lengths,
67 | )
68 | 
69 | # calculate loss and backpropagate
70 | ```
71 | 
72 | ## Training Example
73 | 
74 | ```sh
75 | # LJ Speech
76 | python train.py -c configs/vits3_ljs_nosdp.json -m ljs_base  # no SDP (recommended)
77 | python train.py -c configs/vits3_ljs_base.json -m ljs_base   # with SDP
78 | 
79 | # VCTK
80 | python train_ms.py -c configs/vits3_vctk_base.json -m vctk_base
81 | ```
82 | 
83 | ## TODOs, features and notes
84 | 
85 | - [ ] Train on LJ Speech and get sample audio
86 | 
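87 | ## Loss computation (sketch)
88 | 
89 | The dry-run above stops at the forward pass. Below is a minimal sketch of the duration/reconstruction/KL part of the training objective. It assumes this fork keeps the upstream-VITS return signature of `SynthesizerTrn.forward` and the upstream helpers `commons.slice_segments`, `losses.kl_loss` and `mel_processing.mel_spectrogram_torch`; the adversarial and feature-matching terms (and this repo's extra prior-flow loss) are omitted because they also require the discriminators.
90 | 
91 | ```python
92 | import torch
93 | import torch.nn.functional as F
94 | 
95 | import commons
96 | from losses import kl_loss
97 | from mel_processing import mel_spectrogram_torch
98 | 
99 | # Unpack the forward pass (upstream-VITS signature; this fork may append
100 | # extra terms for the additional flow).
101 | (
102 |     y_hat,      # generated waveform windows [b, 1, segment_size * hop_length]
103 |     l_length,   # duration loss term
104 |     attn,       # monotonic alignment
105 |     ids_slice,  # start frame of each randomly sliced training window
106 |     x_mask,
107 |     z_mask,
108 |     (z, z_p, m_p, logs_p, m_q, logs_q),
109 | ) = net_g(x=x, x_lengths=x_lengths, y=y, y_lengths=y_lengths)
110 | 
111 | # Slice the target mel to the same random windows the generator decoded.
112 | y_mel = commons.slice_segments(y, ids_slice, net_g.segment_size)
113 | # Mel spectrogram of the generated audio (STFT settings from the LJS configs).
114 | y_hat_mel = mel_spectrogram_torch(
115 |     y_hat.squeeze(1), 1024, 80, 22050, 256, 1024, 0.0, None
116 | )
117 | 
118 | loss_dur = torch.sum(l_length.float())
119 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * 45.0  # c_mel from the configs
120 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask)  # scaled by c_kl = 1.0
121 | 
122 | (loss_dur + loss_mel + loss_kl).backward()
123 | ```
124 | 
125 | Note that inside the model `segment_size` is counted in latent frames (a train.py-style setup would pass `train.segment_size // hop_length`), so the dry-run values above are illustrative rather than trainable as-is.
126 | 
127 | ## Data loading (sketch)
128 | 
129 | The dataset, collate and bucket-sampler classes are defined in `data_utils.py` further down. A minimal single-process wiring sketch follows; `utils.get_hparams_from_file` is assumed to exist as in upstream VITS, and the bucket boundaries (in spectrogram frames) are upstream's defaults.
130 | 
131 | ```python
132 | from torch.utils.data import DataLoader
133 | 
134 | import utils  # assumed to provide get_hparams_from_file, as in upstream VITS
135 | from data_utils import (TextAudioCollate, TextAudioLoader,
136 |                         DistributedBucketSampler)
137 | 
138 | hps = utils.get_hparams_from_file("configs/vits3_ljs_base.json")
139 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
140 | train_sampler = DistributedBucketSampler(
141 |     train_dataset,
142 |     hps.train.batch_size,
143 |     [32, 300, 400, 500, 600, 700, 800, 900, 1000],  # bucket boundaries (frames)
144 |     num_replicas=1,  # single process; train.py normally sets these via DDP
145 |     rank=0,
146 |     shuffle=True,
147 | )
148 | train_loader = DataLoader(
149 |     train_dataset,
150 |     num_workers=2,
151 |     pin_memory=True,
152 |     collate_fn=TextAudioCollate(),
153 |     batch_sampler=train_sampler,
154 | )
155 | x, x_lengths, spec, spec_lengths, wav, wav_lengths = next(iter(train_loader))
156 | ```
157 | 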
-------------------------------------------------------------------------------- /attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.utils import remove_weight_norm, weight_norm 8 | 9 | import commons 10 | import modules 11 | from modules import LayerNorm 12 | 13 | 14 | class Encoder(nn.Module): # backward compatible vits2 encoder 15 | def __init__( 16 | self, 17 | hidden_channels, 18 | filter_channels, 19 | n_heads, 20 | n_layers, 21 | kernel_size=1, 22 | p_dropout=0.0, 23 | window_size=4, 24 | **kwargs 25 | ): 26 | super().__init__() 27 | self.hidden_channels = hidden_channels 28 | self.filter_channels = filter_channels 29 | self.n_heads = n_heads 30 | self.n_layers = n_layers 31 | self.kernel_size = kernel_size 32 | self.p_dropout = p_dropout 33 | self.window_size = window_size 34 | 35 | self.drop = nn.Dropout(p_dropout) 36 | self.attn_layers = nn.ModuleList() 37 | self.norm_layers_1 = nn.ModuleList() 38 | self.ffn_layers = nn.ModuleList() 39 | self.norm_layers_2 = nn.ModuleList() 40 | # if kwargs has spk_emb_dim, then add a linear layer to project spk_emb_dim to hidden_channels 41 | self.cond_layer_idx = self.n_layers 42 | if "gin_channels" in kwargs: 43 | self.gin_channels = kwargs["gin_channels"] 44 | if self.gin_channels != 0: 45 | self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) 46 | # vits2 says 3rd block, so idx is 2 by default 47 | self.cond_layer_idx = ( 48 | kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 49 | ) 50 | assert ( 51 | self.cond_layer_idx < self.n_layers 52 | ), "cond_layer_idx should be less than n_layers" 53 | 54 | for i in range(self.n_layers): 55 | self.attn_layers.append( 56 | MultiHeadAttention( 57 | hidden_channels, 58 | hidden_channels, 59 | n_heads, 60 | p_dropout=p_dropout, 61 | window_size=window_size, 62 | ) 63 | ) 64 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 65 | self.ffn_layers.append( 66 | FFN( 67 | hidden_channels, 68 | hidden_channels, 69 | filter_channels, 70 | kernel_size, 71 | p_dropout=p_dropout, 72 | ) 73 | ) 74 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 75 | 76 | def forward(self, x, x_mask, g=None): 77 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 78 | x = x * x_mask 79 | for i in range(self.n_layers): 80 | if i == self.cond_layer_idx and g is not None: 81 | g = self.spk_emb_linear(g.transpose(1, 2)) 82 | g = g.transpose(1, 2) 83 | x = x + g 84 | x = x * x_mask 85 | y = self.attn_layers[i](x, x, attn_mask) 86 | y = self.drop(y) 87 | x = self.norm_layers_1[i](x + y) 88 | 89 | y = self.ffn_layers[i](x, x_mask) 90 | y = self.drop(y) 91 | x = self.norm_layers_2[i](x + y) 92 | x = x * x_mask 93 | return x 94 | 95 | 96 | class Decoder(nn.Module): 97 | def __init__( 98 | self, 99 | hidden_channels, 100 | filter_channels, 101 | n_heads, 102 | n_layers, 103 | kernel_size=1, 104 | p_dropout=0.0, 105 | proximal_bias=False, 106 | proximal_init=True, 107 | **kwargs 108 | ): 109 | super().__init__() 110 | self.hidden_channels = hidden_channels 111 | self.filter_channels = filter_channels 112 | self.n_heads = n_heads 113 | self.n_layers = n_layers 114 | self.kernel_size = kernel_size 115 | self.p_dropout = p_dropout 116 | self.proximal_bias = proximal_bias 117 | self.proximal_init = proximal_init 118 | 119 | self.drop = nn.Dropout(p_dropout) 120 | self.self_attn_layers = 
nn.ModuleList() 121 | self.norm_layers_0 = nn.ModuleList() 122 | self.encdec_attn_layers = nn.ModuleList() 123 | self.norm_layers_1 = nn.ModuleList() 124 | self.ffn_layers = nn.ModuleList() 125 | self.norm_layers_2 = nn.ModuleList() 126 | for i in range(self.n_layers): 127 | self.self_attn_layers.append( 128 | MultiHeadAttention( 129 | hidden_channels, 130 | hidden_channels, 131 | n_heads, 132 | p_dropout=p_dropout, 133 | proximal_bias=proximal_bias, 134 | proximal_init=proximal_init, 135 | ) 136 | ) 137 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 138 | self.encdec_attn_layers.append( 139 | MultiHeadAttention( 140 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout 141 | ) 142 | ) 143 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 144 | self.ffn_layers.append( 145 | FFN( 146 | hidden_channels, 147 | hidden_channels, 148 | filter_channels, 149 | kernel_size, 150 | p_dropout=p_dropout, 151 | causal=True, 152 | ) 153 | ) 154 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 155 | 156 | def forward(self, x, x_mask, h, h_mask): 157 | """ 158 | x: decoder input 159 | h: encoder output 160 | """ 161 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 162 | device=x.device, dtype=x.dtype 163 | ) 164 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 165 | x = x * x_mask 166 | for i in range(self.n_layers): 167 | y = self.self_attn_layers[i](x, x, self_attn_mask) 168 | y = self.drop(y) 169 | x = self.norm_layers_0[i](x + y) 170 | 171 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 172 | y = self.drop(y) 173 | x = self.norm_layers_1[i](x + y) 174 | 175 | y = self.ffn_layers[i](x, x_mask) 176 | y = self.drop(y) 177 | x = self.norm_layers_2[i](x + y) 178 | x = x * x_mask 179 | return x 180 | 181 | 182 | class MultiHeadAttention(nn.Module): 183 | def __init__( 184 | self, 185 | channels, 186 | out_channels, 187 | n_heads, 188 | p_dropout=0.0, 189 | window_size=None, 190 | heads_share=True, 191 | block_length=None, 192 | proximal_bias=False, 193 | proximal_init=False, 194 | ): 195 | super().__init__() 196 | assert channels % n_heads == 0 197 | 198 | self.channels = channels 199 | self.out_channels = out_channels 200 | self.n_heads = n_heads 201 | self.p_dropout = p_dropout 202 | self.window_size = window_size 203 | self.heads_share = heads_share 204 | self.block_length = block_length 205 | self.proximal_bias = proximal_bias 206 | self.proximal_init = proximal_init 207 | self.attn = None 208 | 209 | self.k_channels = channels // n_heads 210 | self.conv_q = nn.Conv1d(channels, channels, 1) 211 | self.conv_k = nn.Conv1d(channels, channels, 1) 212 | self.conv_v = nn.Conv1d(channels, channels, 1) 213 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 214 | self.drop = nn.Dropout(p_dropout) 215 | 216 | if window_size is not None: 217 | n_heads_rel = 1 if heads_share else n_heads 218 | rel_stddev = self.k_channels**-0.5 219 | self.emb_rel_k = nn.Parameter( 220 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 221 | * rel_stddev 222 | ) 223 | self.emb_rel_v = nn.Parameter( 224 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 225 | * rel_stddev 226 | ) 227 | 228 | nn.init.xavier_uniform_(self.conv_q.weight) 229 | nn.init.xavier_uniform_(self.conv_k.weight) 230 | nn.init.xavier_uniform_(self.conv_v.weight) 231 | if proximal_init: 232 | with torch.no_grad(): 233 | self.conv_k.weight.copy_(self.conv_q.weight) 234 | self.conv_k.bias.copy_(self.conv_q.bias) 235 | 236 | def forward(self, x, c, attn_mask=None): 237 
| q = self.conv_q(x) 238 | k = self.conv_k(c) 239 | v = self.conv_v(c) 240 | 241 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 242 | 243 | x = self.conv_o(x) 244 | return x 245 | 246 | def attention(self, query, key, value, mask=None): 247 | # reshape [b, d, t] -> [b, n_h, t, d_k] 248 | b, d, t_s, t_t = (*key.size(), query.size(2)) 249 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 250 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 251 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 252 | 253 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 254 | if self.window_size is not None: 255 | assert ( 256 | t_s == t_t 257 | ), "Relative attention is only available for self-attention." 258 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 259 | rel_logits = self._matmul_with_relative_keys( 260 | query / math.sqrt(self.k_channels), key_relative_embeddings 261 | ) 262 | scores_local = self._relative_position_to_absolute_position(rel_logits) 263 | scores = scores + scores_local 264 | if self.proximal_bias: 265 | assert t_s == t_t, "Proximal bias is only available for self-attention." 266 | scores = scores + self._attention_bias_proximal(t_s).to( 267 | device=scores.device, dtype=scores.dtype 268 | ) 269 | if mask is not None: 270 | scores = scores.masked_fill(mask == 0, -1e4) 271 | if self.block_length is not None: 272 | assert ( 273 | t_s == t_t 274 | ), "Local attention is only available for self-attention." 275 | block_mask = ( 276 | torch.ones_like(scores) 277 | .triu(-self.block_length) 278 | .tril(self.block_length) 279 | ) 280 | scores = scores.masked_fill(block_mask == 0, -1e4) 281 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 282 | p_attn = self.drop(p_attn) 283 | output = torch.matmul(p_attn, value) 284 | if self.window_size is not None: 285 | relative_weights = self._absolute_position_to_relative_position(p_attn) 286 | value_relative_embeddings = self._get_relative_embeddings( 287 | self.emb_rel_v, t_s 288 | ) 289 | output = output + self._matmul_with_relative_values( 290 | relative_weights, value_relative_embeddings 291 | ) 292 | output = ( 293 | output.transpose(2, 3).contiguous().view(b, d, t_t) 294 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 295 | return output, p_attn 296 | 297 | def _matmul_with_relative_values(self, x, y): 298 | """ 299 | x: [b, h, l, m] 300 | y: [h or 1, m, d] 301 | ret: [b, h, l, d] 302 | """ 303 | ret = torch.matmul(x, y.unsqueeze(0)) 304 | return ret 305 | 306 | def _matmul_with_relative_keys(self, x, y): 307 | """ 308 | x: [b, h, l, d] 309 | y: [h or 1, m, d] 310 | ret: [b, h, l, m] 311 | """ 312 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 313 | return ret 314 | 315 | def _get_relative_embeddings(self, relative_embeddings, length): 316 | max_relative_position = 2 * self.window_size + 1 317 | # Pad first before slice to avoid using cond ops. 
318 | pad_length = max(length - (self.window_size + 1), 0) 319 | slice_start_position = max((self.window_size + 1) - length, 0) 320 | slice_end_position = slice_start_position + 2 * length - 1 321 | if pad_length > 0: 322 | padded_relative_embeddings = F.pad( 323 | relative_embeddings, 324 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 325 | ) 326 | else: 327 | padded_relative_embeddings = relative_embeddings 328 | used_relative_embeddings = padded_relative_embeddings[ 329 | :, slice_start_position:slice_end_position 330 | ] 331 | return used_relative_embeddings 332 | 333 | def _relative_position_to_absolute_position(self, x): 334 | """ 335 | x: [b, h, l, 2*l-1] 336 | ret: [b, h, l, l] 337 | """ 338 | batch, heads, length, _ = x.size() 339 | # Concat columns of pad to shift from relative to absolute indexing. 340 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 341 | 342 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 343 | x_flat = x.view([batch, heads, length * 2 * length]) 344 | x_flat = F.pad( 345 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) 346 | ) 347 | 348 | # Reshape and slice out the padded elements. 349 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 350 | :, :, :length, length - 1 : 351 | ] 352 | return x_final 353 | 354 | def _absolute_position_to_relative_position(self, x): 355 | """ 356 | x: [b, h, l, l] 357 | ret: [b, h, l, 2*l-1] 358 | """ 359 | batch, heads, length, _ = x.size() 360 | # padd along column 361 | x = F.pad( 362 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) 363 | ) 364 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) 365 | # add 0's in the beginning that will skew the elements after reshape 366 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 367 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 368 | return x_final 369 | 370 | def _attention_bias_proximal(self, length): 371 | """Bias for self-attention to encourage attention to close positions. 372 | Args: 373 | length: an integer scalar. 
374 | Returns: 375 | a Tensor with shape [1, 1, length, length] 376 | """ 377 | r = torch.arange(length, dtype=torch.float32) 378 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 379 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 380 | 381 | 382 | class FFN(nn.Module): 383 | def __init__( 384 | self, 385 | in_channels, 386 | out_channels, 387 | filter_channels, 388 | kernel_size, 389 | p_dropout=0.0, 390 | activation=None, 391 | causal=False, 392 | ): 393 | super().__init__() 394 | self.in_channels = in_channels 395 | self.out_channels = out_channels 396 | self.filter_channels = filter_channels 397 | self.kernel_size = kernel_size 398 | self.p_dropout = p_dropout 399 | self.activation = activation 400 | self.causal = causal 401 | 402 | if causal: 403 | self.padding = self._causal_padding 404 | else: 405 | self.padding = self._same_padding 406 | 407 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 408 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 409 | self.drop = nn.Dropout(p_dropout) 410 | 411 | def forward(self, x, x_mask): 412 | x = self.conv_1(self.padding(x * x_mask)) 413 | if self.activation == "gelu": 414 | x = x * torch.sigmoid(1.702 * x) 415 | else: 416 | x = torch.relu(x) 417 | x = self.drop(x) 418 | x = self.conv_2(self.padding(x * x_mask)) 419 | return x * x_mask 420 | 421 | def _causal_padding(self, x): 422 | if self.kernel_size == 1: 423 | return x 424 | pad_l = self.kernel_size - 1 425 | pad_r = 0 426 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 427 | x = F.pad(x, commons.convert_pad_shape(padding)) 428 | return x 429 | 430 | def _same_padding(self, x): 431 | if self.kernel_size == 1: 432 | return x 433 | pad_l = (self.kernel_size - 1) // 2 434 | pad_r = self.kernel_size // 2 435 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 436 | x = F.pad(x, commons.convert_pad_shape(padding)) 437 | return x 438 | 439 | 440 | class Depthwise_Separable_Conv1D(nn.Module): 441 | def __init__( 442 | self, 443 | in_channels, 444 | out_channels, 445 | kernel_size, 446 | stride=1, 447 | padding=0, 448 | dilation=1, 449 | bias=True, 450 | padding_mode="zeros", # TODO: refine this type 451 | device=None, 452 | dtype=None, 453 | ): 454 | super().__init__() 455 | self.depth_conv = nn.Conv1d( 456 | in_channels=in_channels, 457 | out_channels=in_channels, 458 | kernel_size=kernel_size, 459 | groups=in_channels, 460 | stride=stride, 461 | padding=padding, 462 | dilation=dilation, 463 | bias=bias, 464 | padding_mode=padding_mode, 465 | device=device, 466 | dtype=dtype, 467 | ) 468 | self.point_conv = nn.Conv1d( 469 | in_channels=in_channels, 470 | out_channels=out_channels, 471 | kernel_size=1, 472 | bias=bias, 473 | device=device, 474 | dtype=dtype, 475 | ) 476 | 477 | def forward(self, input): 478 | return self.point_conv(self.depth_conv(input)) 479 | 480 | def weight_norm(self): 481 | self.depth_conv = weight_norm(self.depth_conv, name="weight") 482 | self.point_conv = weight_norm(self.point_conv, name="weight") 483 | 484 | def remove_weight_norm(self): 485 | self.depth_conv = remove_weight_norm(self.depth_conv, name="weight") 486 | self.point_conv = remove_weight_norm(self.point_conv, name="weight") 487 | 488 | 489 | class Depthwise_Separable_TransposeConv1D(nn.Module): 490 | def __init__( 491 | self, 492 | in_channels, 493 | out_channels, 494 | kernel_size, 495 | stride=1, 496 | padding=0, 497 | output_padding=0, 498 | bias=True, 499 | dilation=1, 500 | padding_mode="zeros", # TODO: refine this type 501 | device=None, 
502 |         dtype=None,
503 |     ):
504 |         super().__init__()
505 |         self.depth_conv = nn.ConvTranspose1d(
506 |             in_channels=in_channels,
507 |             out_channels=in_channels,
508 |             kernel_size=kernel_size,
509 |             groups=in_channels,
510 |             stride=stride,
511 |             output_padding=output_padding,
512 |             padding=padding,
513 |             dilation=dilation,
514 |             bias=bias,
515 |             padding_mode=padding_mode,
516 |             device=device,
517 |             dtype=dtype,
518 |         )
519 |         self.point_conv = nn.Conv1d(
520 |             in_channels=in_channels,
521 |             out_channels=out_channels,
522 |             kernel_size=1,
523 |             bias=bias,
524 |             device=device,
525 |             dtype=dtype,
526 |         )
527 | 
528 |     def forward(self, input):
529 |         return self.point_conv(self.depth_conv(input))
530 | 
531 |     def weight_norm(self):
532 |         self.depth_conv = weight_norm(self.depth_conv, name="weight")
533 |         self.point_conv = weight_norm(self.point_conv, name="weight")
534 | 
535 |     def remove_weight_norm(self):
536 |         remove_weight_norm(self.depth_conv, name="weight")
537 |         remove_weight_norm(self.point_conv, name="weight")
538 | 
539 | 
540 | def weight_norm_modules(module, name="weight", dim=0):
541 |     if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
542 |         module, Depthwise_Separable_TransposeConv1D
543 |     ):
544 |         module.weight_norm()
545 |         return module
546 |     else:
547 |         return weight_norm(module, name, dim)
548 | 
549 | 
550 | def remove_weight_norm_modules(module, name="weight"):
551 |     if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
552 |         module, Depthwise_Separable_TransposeConv1D
553 |     ):
554 |         module.remove_weight_norm()
555 |     else:
556 |         remove_weight_norm(module, name)
557 | 
558 | 
559 | class FFT(nn.Module):
560 |     def __init__(
561 |         self,
562 |         hidden_channels,
563 |         filter_channels,
564 |         n_heads,
565 |         n_layers=1,
566 |         kernel_size=1,
567 |         p_dropout=0.0,
568 |         proximal_bias=False,
569 |         proximal_init=True,
570 |         isflow=False,
571 |         **kwargs
572 |     ):
573 |         super().__init__()
574 |         self.hidden_channels = hidden_channels
575 |         self.filter_channels = filter_channels
576 |         self.n_heads = n_heads
577 |         self.n_layers = n_layers
578 |         self.kernel_size = kernel_size
579 |         self.p_dropout = p_dropout
580 |         self.proximal_bias = proximal_bias
581 |         self.proximal_init = proximal_init
582 |         if isflow and "gin_channels" in kwargs and kwargs["gin_channels"] > 0:
583 |             cond_layer = torch.nn.Conv1d(
584 |                 kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
585 |             )
586 |             self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
587 |             self.cond_layer = weight_norm_modules(cond_layer, name="weight")
588 |             self.gin_channels = kwargs["gin_channels"]
589 |         self.drop = nn.Dropout(p_dropout)
590 |         self.self_attn_layers = nn.ModuleList()
591 |         self.norm_layers_0 = nn.ModuleList()
592 |         self.ffn_layers = nn.ModuleList()
593 |         self.norm_layers_1 = nn.ModuleList()
594 |         for i in range(self.n_layers):
595 |             self.self_attn_layers.append(
596 |                 MultiHeadAttention(
597 |                     hidden_channels,
598 |                     hidden_channels,
599 |                     n_heads,
600 |                     p_dropout=p_dropout,
601 |                     proximal_bias=proximal_bias,
602 |                     proximal_init=proximal_init,
603 |                 )
604 |             )
605 |             self.norm_layers_0.append(LayerNorm(hidden_channels))
606 |             self.ffn_layers.append(
607 |                 FFN(
608 |                     hidden_channels,
609 |                     hidden_channels,
610 |                     filter_channels,
611 |                     kernel_size,
612 |                     p_dropout=p_dropout,
613 |                     causal=True,
614 |                 )
615 |             )
616 |             self.norm_layers_1.append(LayerNorm(hidden_channels))
617 | 
618 |     def forward(self, x, x_mask, g=None):
619 |         """
620 |         x: input sequence [b, h, t]
621 |         g: optional global conditioning (e.g. speaker embedding)
622 |         """
623 |         if g is not None:
624 | g = self.cond_layer(g) 625 | 626 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 627 | device=x.device, dtype=x.dtype 628 | ) 629 | x = x * x_mask 630 | for i in range(self.n_layers): 631 | if g is not None: 632 | x = self.cond_pre(x) 633 | cond_offset = i * 2 * self.hidden_channels 634 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 635 | x = commons.fused_add_tanh_sigmoid_multiply( 636 | x, g_l, torch.IntTensor([self.hidden_channels]) 637 | ) 638 | y = self.self_attn_layers[i](x, x, self_attn_mask) 639 | y = self.drop(y) 640 | x = self.norm_layers_0[i](x + y) 641 | 642 | y = self.ffn_layers[i](x, x_mask) 643 | y = self.drop(y) 644 | x = self.norm_layers_1[i](x + y) 645 | x = x * x_mask 646 | return x 647 | -------------------------------------------------------------------------------- /colab_requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.21 2 | librosa==0.8.0 3 | matplotlib==3.3.1 4 | phonemizer==2.2.1 5 | Unidecode==1.1.1 6 | tensorboardX 7 | scipy -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size * dilation - dilation) / 2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += ( 34 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 35 | ) 36 | return kl 37 | 38 | 39 | def rand_gumbel(shape): 40 | """Sample from the Gumbel distribution, protect from overflows.""" 41 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 42 | return -torch.log(-torch.log(uniform_samples)) 43 | 44 | 45 | def rand_gumbel_like(x): 46 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 47 | return g 48 | 49 | 50 | def slice_segments(x, ids_str, segment_size=4): 51 | ret = torch.zeros_like(x[:, :, :segment_size]) 52 | for i in range(x.size(0)): 53 | idx_str = ids_str[i] 54 | idx_end = idx_str + segment_size 55 | ret[i] = x[i, :, idx_str:idx_end] 56 | return ret 57 | 58 | 59 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 60 | b, d, t = x.size() 61 | if x_lengths is None: 62 | x_lengths = t 63 | ids_str_max = x_lengths - segment_size + 1 64 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 65 | ret = slice_segments(x, ids_str, segment_size) 66 | return ret, ids_str 67 | 68 | 69 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 70 | position = torch.arange(length, dtype=torch.float) 71 | num_timescales = channels // 2 72 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 73 | num_timescales - 1 74 | ) 75 | inv_timescales = min_timescale * torch.exp( 76 | torch.arange(num_timescales, 
dtype=torch.float) * -log_timescale_increment 77 | ) 78 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 79 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 80 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 81 | signal = signal.view(1, channels, length) 82 | return signal 83 | 84 | 85 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 86 | b, channels, length = x.size() 87 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 88 | return x + signal.to(dtype=x.dtype, device=x.device) 89 | 90 | 91 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 92 | b, channels, length = x.size() 93 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 94 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 95 | 96 | 97 | def subsequent_mask(length): 98 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 99 | return mask 100 | 101 | 102 | @torch.jit.script 103 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 104 | n_channels_int = n_channels[0] 105 | in_act = input_a + input_b 106 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 107 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 108 | acts = t_act * s_act 109 | return acts 110 | 111 | 112 | def convert_pad_shape(pad_shape): 113 | l = pad_shape[::-1] 114 | pad_shape = [item for sublist in l for item in sublist] 115 | return pad_shape 116 | 117 | 118 | def shift_1d(x): 119 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 120 | return x 121 | 122 | 123 | def sequence_mask(length, max_length=None): 124 | if max_length is None: 125 | max_length = length.max() 126 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 127 | return x.unsqueeze(0) < length.unsqueeze(1) 128 | 129 | 130 | def generate_path(duration, mask): 131 | """ 132 | duration: [b, 1, t_x] 133 | mask: [b, 1, t_y, t_x] 134 | """ 135 | device = duration.device 136 | 137 | b, _, t_y, t_x = mask.shape 138 | cum_duration = torch.cumsum(duration, -1) 139 | 140 | cum_duration_flat = cum_duration.view(b * t_x) 141 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 142 | path = path.view(b, t_x, t_y) 143 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 144 | path = path.unsqueeze(1).transpose(2, 3) * mask 145 | return path 146 | 147 | 148 | def clip_grad_value_(parameters, clip_value, norm_type=2): 149 | if isinstance(parameters, torch.Tensor): 150 | parameters = [parameters] 151 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 152 | norm_type = float(norm_type) 153 | if clip_value is not None: 154 | clip_value = float(clip_value) 155 | 156 | total_norm = 0 157 | for p in parameters: 158 | param_norm = p.grad.data.norm(norm_type) 159 | total_norm += param_norm.item() ** norm_type 160 | if clip_value is not None: 161 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 162 | total_norm = total_norm ** (1.0 / norm_type) 163 | return total_norm 164 | -------------------------------------------------------------------------------- /configs/vits3_ljs_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 20000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 
0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners2"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 0, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv", 40 | "use_spk_conditioned_encoder": false, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [8,8,2,2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16,4,4], 56 | "n_layers_q": 3, 57 | "use_spectral_norm": false 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /configs/vits3_ljs_nosdp.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 20000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners2"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 0, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv", 40 | "use_spk_conditioned_encoder": false, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [8,8,2,2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16,4,4], 56 | "n_layers_q": 3, 57 | "use_spectral_norm": false, 58 | "use_sdp": false 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /configs/vits3_vctk_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | 
"seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners2"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 109, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv", 40 | "use_spk_conditioned_encoder": true, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [8,8,2,2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16,4,4], 56 | "n_layers_q": 3, 57 | "use_spectral_norm": false, 58 | "use_sdp": false, 59 | "gin_channels": 256 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /configs/vits3_vctk_base_pr.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners3"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 109, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv", 40 | "use_spk_conditioned_encoder": true, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [8,8,2,2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16,4,4], 56 | "n_layers_q": 3, 57 | "use_spectral_norm": false, 58 | "use_sdp": false, 59 | "gin_channels": 256 60 | } 61 | } 62 | 
-------------------------------------------------------------------------------- /configs/vits3_vctk_standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners3"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 109, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv2", 40 | "use_spk_conditioned_encoder": true, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "duration_discriminator_type": "dur_disc_2", 44 | "inter_channels": 192, 45 | "hidden_channels": 192, 46 | "filter_channels": 768, 47 | "n_heads": 2, 48 | "n_layers": 6, 49 | "kernel_size": 3, 50 | "p_dropout": 0.1, 51 | "resblock": "1", 52 | "resblock_kernel_sizes": [3,7,11], 53 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 54 | "upsample_rates": [8,8,2,2], 55 | "upsample_initial_channel": 512, 56 | "upsample_kernel_sizes": [16,16,4,4], 57 | "n_layers_q": 3, 58 | "use_spectral_norm": false, 59 | "use_sdp": true, 60 | "gin_channels": 256 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data 8 | 9 | import commons 10 | from mel_processing import (mel_spectrogram_torch, spec_to_mel_torch, 11 | spectrogram_torch) 12 | from text import cleaned_text_to_sequence, text_to_sequence 13 | from utils import load_filepaths_and_text, load_wav_to_torch 14 | 15 | 16 | class TextAudioLoader(torch.utils.data.Dataset): 17 | """ 18 | 1) loads audio, text pairs 19 | 2) normalizes text and converts them to sequences of integers 20 | 3) computes spectrograms from audio files. 
21 | """ 22 | 23 | def __init__(self, audiopaths_and_text, hparams): 24 | self.hparams = hparams 25 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 26 | self.text_cleaners = hparams.text_cleaners 27 | self.max_wav_value = hparams.max_wav_value 28 | self.sampling_rate = hparams.sampling_rate 29 | self.filter_length = hparams.filter_length 30 | self.hop_length = hparams.hop_length 31 | self.win_length = hparams.win_length 32 | self.sampling_rate = hparams.sampling_rate 33 | 34 | self.use_mel_spec_posterior = getattr( 35 | hparams, "use_mel_posterior_encoder", False 36 | ) 37 | if self.use_mel_spec_posterior: 38 | self.n_mel_channels = getattr(hparams, "n_mel_channels", 80) 39 | self.cleaned_text = getattr(hparams, "cleaned_text", False) 40 | 41 | self.add_blank = hparams.add_blank 42 | self.min_text_len = getattr(hparams, "min_text_len", 1) 43 | self.max_text_len = getattr(hparams, "max_text_len", 190) 44 | 45 | random.seed(1234) 46 | random.shuffle(self.audiopaths_and_text) 47 | self._filter() 48 | 49 | def _filter(self): 50 | """ 51 | Filter text & store spec lengths 52 | """ 53 | # Store spectrogram lengths for Bucketing 54 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 55 | # spec_length = wav_length // hop_length 56 | 57 | audiopaths_and_text_new = [] 58 | lengths = [] 59 | for audiopath, text in self.audiopaths_and_text: 60 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 61 | audiopaths_and_text_new.append([audiopath, text]) 62 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) 63 | self.audiopaths_and_text = audiopaths_and_text_new 64 | self.lengths = lengths 65 | 66 | def get_audio_text_pair(self, audiopath_and_text): 67 | # separate filename and text 68 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1] 69 | text = self.get_text(text) 70 | spec, wav = self.get_audio(audiopath) 71 | return (text, spec, wav) 72 | 73 | def get_audio(self, filename): 74 | # TODO : if linear spec exists convert to mel from existing linear spec 75 | audio, sampling_rate = load_wav_to_torch(filename) 76 | if sampling_rate != self.sampling_rate: 77 | raise ValueError( 78 | "{} {} SR doesn't match target {} SR".format( 79 | sampling_rate, self.sampling_rate 80 | ) 81 | ) 82 | audio_norm = audio / self.max_wav_value 83 | audio_norm = audio_norm.unsqueeze(0) 84 | spec_filename = filename.replace(".wav", ".spec.pt") 85 | if self.use_mel_spec_posterior: 86 | spec_filename = spec_filename.replace(".spec.pt", ".mel.pt") 87 | if os.path.exists(spec_filename): 88 | spec = torch.load(spec_filename) 89 | else: 90 | if self.use_mel_spec_posterior: 91 | """TODO : (need verification) 92 | if linear spec exists convert to 93 | mel from existing linear spec (uncomment below lines)""" 94 | # if os.path.exists(filename.replace(".wav", ".spec.pt")): 95 | # # spec, n_fft, num_mels, sampling_rate, fmin, fmax 96 | # spec = spec_to_mel_torch( 97 | # torch.load(filename.replace(".wav", ".spec.pt")), 98 | # self.filter_length, self.n_mel_channels, self.sampling_rate, 99 | # self.hparams.mel_fmin, self.hparams.mel_fmax) 100 | spec = mel_spectrogram_torch( 101 | audio_norm, 102 | self.filter_length, 103 | self.n_mel_channels, 104 | self.sampling_rate, 105 | self.hop_length, 106 | self.win_length, 107 | self.hparams.mel_fmin, 108 | self.hparams.mel_fmax, 109 | center=False, 110 | ) 111 | else: 112 | spec = spectrogram_torch( 113 | audio_norm, 114 | self.filter_length, 115 | self.sampling_rate, 116 | self.hop_length, 
117 |                     self.win_length,
118 |                     center=False,
119 |                 )
120 |             spec = torch.squeeze(spec, 0)
121 |             torch.save(spec, spec_filename)
122 |         return spec, audio_norm
123 | 
124 |     def get_text(self, text):
125 |         if self.cleaned_text:
126 |             text_norm = cleaned_text_to_sequence(text)
127 |         else:
128 |             text_norm = text_to_sequence(text, self.text_cleaners)
129 |         if self.add_blank:
130 |             text_norm = commons.intersperse(text_norm, 0)
131 |         text_norm = torch.LongTensor(text_norm)
132 |         return text_norm
133 | 
134 |     def __getitem__(self, index):
135 |         return self.get_audio_text_pair(self.audiopaths_and_text[index])
136 | 
137 |     def __len__(self):
138 |         return len(self.audiopaths_and_text)
139 | 
140 | 
141 | class TextAudioCollate:
142 |     """Zero-pads model inputs and targets"""
143 | 
144 |     def __init__(self, return_ids=False):
145 |         self.return_ids = return_ids
146 | 
147 |     def __call__(self, batch):
148 |         """Collates a training batch from normalized text and audio
149 |         PARAMS
150 |         ------
151 |         batch: [text_normalized, spec_normalized, wav_normalized]
152 |         """
153 |         # Right zero-pad all one-hot text sequences to max input length
154 |         _, ids_sorted_decreasing = torch.sort(
155 |             torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
156 |         )
157 | 
158 |         max_text_len = max([len(x[0]) for x in batch])
159 |         max_spec_len = max([x[1].size(1) for x in batch])
160 |         max_wav_len = max([x[2].size(1) for x in batch])
161 | 
162 |         text_lengths = torch.LongTensor(len(batch))
163 |         spec_lengths = torch.LongTensor(len(batch))
164 |         wav_lengths = torch.LongTensor(len(batch))
165 | 
166 |         text_padded = torch.LongTensor(len(batch), max_text_len)
167 |         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
168 |         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
169 |         text_padded.zero_()
170 |         spec_padded.zero_()
171 |         wav_padded.zero_()
172 |         for i in range(len(ids_sorted_decreasing)):
173 |             row = batch[ids_sorted_decreasing[i]]
174 | 
175 |             text = row[0]
176 |             text_padded[i, : text.size(0)] = text
177 |             text_lengths[i] = text.size(0)
178 | 
179 |             spec = row[1]
180 |             spec_padded[i, :, : spec.size(1)] = spec
181 |             spec_lengths[i] = spec.size(1)
182 | 
183 |             wav = row[2]
184 |             wav_padded[i, :, : wav.size(1)] = wav
185 |             wav_lengths[i] = wav.size(1)
186 | 
187 |         if self.return_ids:
188 |             return (
189 |                 text_padded,
190 |                 text_lengths,
191 |                 spec_padded,
192 |                 spec_lengths,
193 |                 wav_padded,
194 |                 wav_lengths,
195 |                 ids_sorted_decreasing,
196 |             )
197 |         return (
198 |             text_padded,
199 |             text_lengths,
200 |             spec_padded,
201 |             spec_lengths,
202 |             wav_padded,
203 |             wav_lengths,
204 |         )
205 | 
206 | 
207 | """Multi speaker version"""
208 | 
209 | 
210 | class TextAudioSpeakerLoader(torch.utils.data.Dataset):
211 |     """
212 |     1) loads audio, speaker_id, text pairs
213 |     2) normalizes text and converts them to sequences of integers
214 |     3) computes spectrograms from audio files.
215 | """ 216 | 217 | def __init__(self, audiopaths_sid_text, hparams): 218 | self.hparams = hparams 219 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) 220 | self.text_cleaners = hparams.text_cleaners 221 | self.max_wav_value = hparams.max_wav_value 222 | self.sampling_rate = hparams.sampling_rate 223 | self.filter_length = hparams.filter_length 224 | self.hop_length = hparams.hop_length 225 | self.win_length = hparams.win_length 226 | self.sampling_rate = hparams.sampling_rate 227 | 228 | self.use_mel_spec_posterior = getattr( 229 | hparams, "use_mel_posterior_encoder", False 230 | ) 231 | if self.use_mel_spec_posterior: 232 | self.n_mel_channels = getattr(hparams, "n_mel_channels", 80) 233 | self.cleaned_text = getattr(hparams, "cleaned_text", False) 234 | 235 | self.add_blank = hparams.add_blank 236 | self.min_text_len = getattr(hparams, "min_text_len", 1) 237 | self.max_text_len = getattr(hparams, "max_text_len", 190) 238 | self.min_audio_len = getattr(hparams, "min_audio_len", 8192) 239 | 240 | random.seed(1234) 241 | random.shuffle(self.audiopaths_sid_text) 242 | self._filter() 243 | 244 | def _filter(self): 245 | """ 246 | Filter text & store spec lengths 247 | """ 248 | # Store spectrogram lengths for Bucketing 249 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 250 | # spec_length = wav_length // hop_length 251 | 252 | audiopaths_sid_text_new = [] 253 | lengths = [] 254 | for audiopath, sid, text in self.audiopaths_sid_text: 255 | if not os.path.isfile(audiopath): 256 | continue 257 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 258 | audiopaths_sid_text_new.append([audiopath, sid, text]) 259 | length = os.path.getsize(audiopath) // (2 * self.hop_length) 260 | if length < self.min_audio_len // self.hop_length: 261 | continue 262 | lengths.append(length) 263 | self.audiopaths_sid_text = audiopaths_sid_text_new 264 | self.lengths = lengths 265 | print( 266 | len(self.lengths) 267 | ) # if we use large corpus dataset, we can check how much time it takes. 
268 | 
269 |     def get_audio_text_speaker_pair(self, audiopath_sid_text):
270 |         # separate filename, speaker_id and text
271 |         audiopath, sid, text = (
272 |             audiopath_sid_text[0],
273 |             audiopath_sid_text[1],
274 |             audiopath_sid_text[2],
275 |         )
276 |         text = self.get_text(text)
277 |         spec, wav = self.get_audio(audiopath)
278 |         sid = self.get_sid(sid)
279 |         return (text, spec, wav, sid)
280 | 
281 |     def get_audio(self, filename):
282 |         # TODO : if linear spec exists convert to mel from existing linear spec
283 |         audio, sampling_rate = load_wav_to_torch(filename)
284 |         if sampling_rate != self.sampling_rate:
285 |             raise ValueError(
286 |                 "{} SR doesn't match target {} SR".format(
287 |                     sampling_rate, self.sampling_rate
288 |                 )
289 |             )
290 |         audio_norm = audio / self.max_wav_value
291 |         audio_norm = audio_norm.unsqueeze(0)
292 |         spec_filename = filename.replace(".wav", ".spec.pt")
293 |         if self.use_mel_spec_posterior:
294 |             spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
295 |         if os.path.exists(spec_filename):
296 |             spec = torch.load(spec_filename)
297 |         else:
298 |             if self.use_mel_spec_posterior:
299 |                 """TODO : (need verification)
300 |                 if linear spec exists convert to
301 |                 mel from existing linear spec (uncomment below lines)"""
302 |                 # if os.path.exists(filename.replace(".wav", ".spec.pt")):
303 |                 #     # spec, n_fft, num_mels, sampling_rate, fmin, fmax
304 |                 #     spec = spec_to_mel_torch(
305 |                 #         torch.load(filename.replace(".wav", ".spec.pt")),
306 |                 #         self.filter_length, self.n_mel_channels, self.sampling_rate,
307 |                 #         self.hparams.mel_fmin, self.hparams.mel_fmax)
308 |                 spec = mel_spectrogram_torch(
309 |                     audio_norm,
310 |                     self.filter_length,
311 |                     self.n_mel_channels,
312 |                     self.sampling_rate,
313 |                     self.hop_length,
314 |                     self.win_length,
315 |                     self.hparams.mel_fmin,
316 |                     self.hparams.mel_fmax,
317 |                     center=False,
318 |                 )
319 |             else:
320 |                 spec = spectrogram_torch(
321 |                     audio_norm,
322 |                     self.filter_length,
323 |                     self.sampling_rate,
324 |                     self.hop_length,
325 |                     self.win_length,
326 |                     center=False,
327 |                 )
328 |             spec = torch.squeeze(spec, 0)
329 |             torch.save(spec, spec_filename)
330 |         return spec, audio_norm
331 | 
332 |     def get_text(self, text):
333 |         if self.cleaned_text:
334 |             text_norm = cleaned_text_to_sequence(text)
335 |         else:
336 |             text_norm = text_to_sequence(text, self.text_cleaners)
337 |         if self.add_blank:
338 |             text_norm = commons.intersperse(text_norm, 0)
339 |         text_norm = torch.LongTensor(text_norm)
340 |         return text_norm
341 | 
342 |     def get_sid(self, sid):
343 |         sid = torch.LongTensor([int(sid)])
344 |         return sid
345 | 
346 |     def __getitem__(self, index):
347 |         return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
348 | 
349 |     def __len__(self):
350 |         return len(self.audiopaths_sid_text)
351 | 
352 | 
353 | class TextAudioSpeakerCollate:
354 |     """Zero-pads model inputs and targets"""
355 | 
356 |     def __init__(self, return_ids=False):
357 |         self.return_ids = return_ids
358 | 
359 |     def __call__(self, batch):
360 |         """Collates a training batch from normalized text, audio and speaker identities
361 |         PARAMS
362 |         ------
363 |         batch: [text_normalized, spec_normalized, wav_normalized, sid]
364 |         """
365 |         # Right zero-pad all one-hot text sequences to max input length
366 |         _, ids_sorted_decreasing = torch.sort(
367 |             torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
368 |         )
369 | 
370 |         max_text_len = max([len(x[0]) for x in batch])
371 |         max_spec_len = max([x[1].size(1) for x in batch])
372 |         max_wav_len = max([x[2].size(1) for x in batch])
373 | 
374 |         text_lengths = torch.LongTensor(len(batch))
375 |         spec_lengths = torch.LongTensor(len(batch))
376 |         wav_lengths = torch.LongTensor(len(batch))
377 |         sid = torch.LongTensor(len(batch))
378 | 
379 |         text_padded = torch.LongTensor(len(batch), max_text_len)
380 |         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
381 |         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
382 |         text_padded.zero_()
383 |         spec_padded.zero_()
384 |         wav_padded.zero_()
385 |         for i in range(len(ids_sorted_decreasing)):
386 |             row = batch[ids_sorted_decreasing[i]]
387 | 
388 |             text = row[0]
389 |             text_padded[i, : text.size(0)] = text
390 |             text_lengths[i] = text.size(0)
391 | 
392 |             spec = row[1]
393 |             spec_padded[i, :, : spec.size(1)] = spec
394 |             spec_lengths[i] = spec.size(1)
395 | 
396 |             wav = row[2]
397 |             wav_padded[i, :, : wav.size(1)] = wav
398 |             wav_lengths[i] = wav.size(1)
399 | 
400 |             sid[i] = row[3]
401 | 
402 |         if self.return_ids:
403 |             return (
404 |                 text_padded,
405 |                 text_lengths,
406 |                 spec_padded,
407 |                 spec_lengths,
408 |                 wav_padded,
409 |                 wav_lengths,
410 |                 sid,
411 |                 ids_sorted_decreasing,
412 |             )
413 |         return (
414 |             text_padded,
415 |             text_lengths,
416 |             spec_padded,
417 |             spec_lengths,
418 |             wav_padded,
419 |             wav_lengths,
420 |             sid,
421 |         )
422 | 
423 | 
424 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
425 |     """
426 |     Maintain similar input lengths in a batch.
427 |     Length groups are specified by boundaries.
428 |     Ex) boundaries = [b1, b2, b3] -> any batch falls into either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
429 | 
430 |     It removes samples which are not included in the boundaries.
431 |     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
432 | """ 433 | 434 | def __init__( 435 | self, 436 | dataset, 437 | batch_size, 438 | boundaries, 439 | num_replicas=None, 440 | rank=None, 441 | shuffle=True, 442 | ): 443 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 444 | self.lengths = dataset.lengths 445 | self.batch_size = batch_size 446 | self.boundaries = boundaries 447 | 448 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 449 | self.total_size = sum(self.num_samples_per_bucket) 450 | self.num_samples = self.total_size // self.num_replicas 451 | 452 | def _create_buckets(self): 453 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 454 | for i in range(len(self.lengths)): 455 | length = self.lengths[i] 456 | idx_bucket = self._bisect(length) 457 | if idx_bucket != -1: 458 | buckets[idx_bucket].append(i) 459 | 460 | for i in range(len(buckets) - 1, 0, -1): 461 | if len(buckets[i]) == 0: 462 | buckets.pop(i) 463 | self.boundaries.pop(i + 1) 464 | 465 | num_samples_per_bucket = [] 466 | for i in range(len(buckets)): 467 | len_bucket = len(buckets[i]) 468 | total_batch_size = self.num_replicas * self.batch_size 469 | rem = ( 470 | total_batch_size - (len_bucket % total_batch_size) 471 | ) % total_batch_size 472 | num_samples_per_bucket.append(len_bucket + rem) 473 | return buckets, num_samples_per_bucket 474 | 475 | def __iter__(self): 476 | # deterministically shuffle based on epoch 477 | g = torch.Generator() 478 | g.manual_seed(self.epoch) 479 | 480 | indices = [] 481 | if self.shuffle: 482 | for bucket in self.buckets: 483 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 484 | else: 485 | for bucket in self.buckets: 486 | indices.append(list(range(len(bucket)))) 487 | 488 | batches = [] 489 | for i in range(len(self.buckets)): 490 | bucket = self.buckets[i] 491 | len_bucket = len(bucket) 492 | ids_bucket = indices[i] 493 | num_samples_bucket = self.num_samples_per_bucket[i] 494 | 495 | # add extra samples to make it evenly divisible 496 | rem = num_samples_bucket - len_bucket 497 | ids_bucket = ( 498 | ids_bucket 499 | + ids_bucket * (rem // len_bucket) 500 | + ids_bucket[: (rem % len_bucket)] 501 | ) 502 | 503 | # subsample 504 | ids_bucket = ids_bucket[self.rank :: self.num_replicas] 505 | 506 | # batching 507 | for j in range(len(ids_bucket) // self.batch_size): 508 | batch = [ 509 | bucket[idx] 510 | for idx in ids_bucket[ 511 | j * self.batch_size : (j + 1) * self.batch_size 512 | ] 513 | ] 514 | batches.append(batch) 515 | 516 | if self.shuffle: 517 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 518 | batches = [batches[i] for i in batch_ids] 519 | self.batches = batches 520 | 521 | assert len(self.batches) * self.batch_size == self.num_samples 522 | return iter(self.batches) 523 | 524 | def _bisect(self, x, lo=0, hi=None): 525 | if hi is None: 526 | hi = len(self.boundaries) - 1 527 | 528 | if hi > lo: 529 | mid = (hi + lo) // 2 530 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 531 | return mid 532 | elif x <= self.boundaries[mid]: 533 | return self._bisect(x, lo, mid) 534 | else: 535 | return self._bisect(x, mid + 1, hi) 536 | else: 537 | return -1 538 | 539 | def __len__(self): 540 | return self.num_samples // self.batch_size 541 | -------------------------------------------------------------------------------- /export_onnx.py: -------------------------------------------------------------------------------- 1 | # import argparse 2 | # from pathlib import Path 3 | # from typing import Optional 4 | 5 | # 
import torch 6 | 7 | # import utils 8 | # from models import SynthesizerTrn 9 | # from text.symbols import symbols 10 | 11 | # OPSET_VERSION = 15 12 | 13 | 14 | # def main() -> None: 15 | # torch.manual_seed(1234) 16 | 17 | # parser = argparse.ArgumentParser() 18 | # parser.add_argument( 19 | # "--model-path", required=True, help="Path to model weights (.pth)" 20 | # ) 21 | # parser.add_argument( 22 | # "--config-path", required=True, help="Path to model config (.json)" 23 | # ) 24 | # parser.add_argument("--output", required=True, help="Path to output model (.onnx)") 25 | 26 | # args = parser.parse_args() 27 | 28 | # args.model_path = Path(args.model_path) 29 | # args.config_path = Path(args.config_path) 30 | # args.output = Path(args.output) 31 | # args.output.parent.mkdir(parents=True, exist_ok=True) 32 | 33 | # hps = utils.get_hparams_from_file(args.config_path) 34 | 35 | # if ( 36 | # "use_mel_posterior_encoder" in hps.model.keys() 37 | # and hps.model.use_mel_posterior_encoder == True 38 | # ): 39 | # print("Using mel posterior encoder for VITS2") 40 | # posterior_channels = 80 # vits2 41 | # hps.data.use_mel_posterior_encoder = True 42 | # else: 43 | # print("Using lin posterior encoder for VITS1") 44 | # posterior_channels = hps.data.filter_length // 2 + 1 45 | # hps.data.use_mel_posterior_encoder = False 46 | 47 | # model_g = SynthesizerTrn( 48 | # len(symbols), 49 | # posterior_channels, 50 | # hps.train.segment_size // hps.data.hop_length, 51 | # n_speakers=hps.data.n_speakers, 52 | # **hps.model, 53 | # ) 54 | 55 | # _ = model_g.eval() 56 | 57 | # _ = utils.load_checkpoint(args.model_path, model_g, None) 58 | 59 | # def infer_forward(text, text_lengths, scales, sid=None): 60 | # noise_scale = scales[0] 61 | # length_scale = scales[1] 62 | # noise_scale_w = scales[2] 63 | # audio = model_g.infer( 64 | # text, 65 | # text_lengths, 66 | # noise_scale=noise_scale, 67 | # length_scale=length_scale, 68 | # noise_scale_w=noise_scale_w, 69 | # sid=sid, 70 | # )[0] 71 | 72 | # return audio 73 | 74 | # model_g.forward = infer_forward 75 | 76 | # dummy_input_length = 50 77 | # sequences = torch.randint( 78 | # low=0, high=len(symbols), size=(1, dummy_input_length), dtype=torch.long 79 | # ) 80 | # sequence_lengths = torch.LongTensor([sequences.size(1)]) 81 | 82 | # sid: Optional[torch.LongTensor] = None 83 | # if hps.data.n_speakers > 1: 84 | # sid = torch.LongTensor([0]) 85 | 86 | # # noise, length, noise_w 87 | # scales = torch.FloatTensor([0.667, 1.0, 0.8]) 88 | # dummy_input = (sequences, sequence_lengths, scales, sid) 89 | 90 | # # Export 91 | # torch.onnx.export( 92 | # model=model_g, 93 | # args=dummy_input, 94 | # f=str(args.output), 95 | # verbose=False, 96 | # opset_version=OPSET_VERSION, 97 | # input_names=["input", "input_lengths", "scales", "sid"], 98 | # output_names=["output"], 99 | # dynamic_axes={ 100 | # "input": {0: "batch_size", 1: "phonemes"}, 101 | # "input_lengths": {0: "batch_size"}, 102 | # "output": {0: "batch_size", 1: "time1", 2: "time2"}, 103 | # }, 104 | # ) 105 | 106 | # print(f"Exported model to {args.output}") 107 | 108 | 109 | # if __name__ == "__main__": 110 | # main() 111 | -------------------------------------------------------------------------------- /filelists/ljs_audio_text_val_filelist.txt: -------------------------------------------------------------------------------- 1 | DUMMY1/LJ022-0023.wav|The overwhelming majority of people in this country know how to sift the wheat from the chaff in what they hear and what they read. 
2 | DUMMY1/LJ043-0030.wav|If somebody did that to me, a lousy trick like that, to take my wife away, and all the furniture, I would be mad as hell, too. 3 | DUMMY1/LJ005-0201.wav|as is shown by the report of the Commissioners to inquire into the state of the municipal corporations in eighteen thirty-five. 4 | DUMMY1/LJ001-0110.wav|Even the Caslon type when enlarged shows great shortcomings in this respect: 5 | DUMMY1/LJ003-0345.wav|All the committee could do in this respect was to throw the responsibility on others. 6 | DUMMY1/LJ007-0154.wav|These pungent and well-grounded strictures applied with still greater force to the unconvicted prisoner, the man who came to the prison innocent, and still uncontaminated, 7 | DUMMY1/LJ018-0098.wav|and recognized as one of the frequenters of the bogus law-stationers. His arrest led to that of others. 8 | DUMMY1/LJ047-0044.wav|Oswald was, however, willing to discuss his contacts with Soviet authorities. He denied having any involvement with Soviet intelligence agencies 9 | DUMMY1/LJ031-0038.wav|The first physician to see the President at Parkland Hospital was Dr. Charles J. Carrico, a resident in general surgery. 10 | DUMMY1/LJ048-0194.wav|during the morning of November twenty-two prior to the motorcade. 11 | DUMMY1/LJ049-0026.wav|On occasion the Secret Service has been permitted to have an agent riding in the passenger compartment with the President. 12 | DUMMY1/LJ004-0152.wav|although at Mr. Buxton's visit a new jail was in process of erection, the first step towards reform since Howard's visitation in seventeen seventy-four. 13 | DUMMY1/LJ008-0278.wav|or theirs might be one of many, and it might be considered necessary to "make an example." 14 | DUMMY1/LJ043-0002.wav|The Warren Commission Report. By The President's Commission on the Assassination of President Kennedy. Chapter seven. Lee Harvey Oswald: 15 | DUMMY1/LJ009-0114.wav|Mr. Wakefield winds up his graphic but somewhat sensational account by describing another religious service, which may appropriately be inserted here. 16 | DUMMY1/LJ028-0506.wav|A modern artist would have difficulty in doing such accurate work. 17 | DUMMY1/LJ050-0168.wav|with the particular purposes of the agency involved. The Commission recognizes that this is a controversial area 18 | DUMMY1/LJ039-0223.wav|Oswald's Marine training in marksmanship, his other rifle experience and his established familiarity with this particular weapon 19 | DUMMY1/LJ029-0032.wav|According to O'Donnell, quote, we had a motorcade wherever we went, end quote. 20 | DUMMY1/LJ031-0070.wav|Dr. Clark, who most closely observed the head wound, 21 | DUMMY1/LJ034-0198.wav|Euins, who was on the southwest corner of Elm and Houston Streets testified that he could not describe the man he saw in the window. 22 | DUMMY1/LJ026-0068.wav|Energy enters the plant, to a small extent, 23 | DUMMY1/LJ039-0075.wav|once you know that you must put the crosshairs on the target and that is all that is necessary. 24 | DUMMY1/LJ004-0096.wav|the fatal consequences whereof might be prevented if the justices of the peace were duly authorized 25 | DUMMY1/LJ005-0014.wav|Speaking on a debate on prison matters, he declared that 26 | DUMMY1/LJ012-0161.wav|he was reported to have fallen away to a shadow. 27 | DUMMY1/LJ018-0239.wav|His disappearance gave color and substance to evil reports already in circulation that the will and conveyance above referred to 28 | DUMMY1/LJ019-0257.wav|Here the tread-wheel was in use, there cellular cranks, or hard-labor machines. 
29 | DUMMY1/LJ028-0008.wav|you tap gently with your heel upon the shoulder of the dromedary to urge her on. 30 | DUMMY1/LJ024-0083.wav|This plan of mine is no attack on the Court; 31 | DUMMY1/LJ042-0129.wav|No night clubs or bowling alleys, no places of recreation except the trade union dances. I have had enough. 32 | DUMMY1/LJ036-0103.wav|The police asked him whether he could pick out his passenger from the lineup. 33 | DUMMY1/LJ046-0058.wav|During his Presidency, Franklin D. Roosevelt made almost four hundred journeys and traveled more than three hundred fifty thousand miles. 34 | DUMMY1/LJ014-0076.wav|He was seen afterwards smoking and talking with his hosts in their back parlor, and never seen again alive. 35 | DUMMY1/LJ002-0043.wav|long narrow rooms -- one thirty-six feet, six twenty-three feet, and the eighth eighteen, 36 | DUMMY1/LJ009-0076.wav|We come to the sermon. 37 | DUMMY1/LJ017-0131.wav|even when the high sheriff had told him there was no possibility of a reprieve, and within a few hours of execution. 38 | DUMMY1/LJ046-0184.wav|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes. 39 | DUMMY1/LJ014-0263.wav|When other pleasures palled he took a theatre, and posed as a munificent patron of the dramatic art. 40 | DUMMY1/LJ042-0096.wav|(old exchange rate) in addition to his factory salary of approximately equal amount 41 | DUMMY1/LJ049-0050.wav|Hill had both feet on the car and was climbing aboard to assist President and Mrs. Kennedy. 42 | DUMMY1/LJ019-0186.wav|seeing that since the establishment of the Central Criminal Court, Newgate received prisoners for trial from several counties, 43 | DUMMY1/LJ028-0307.wav|then let twenty days pass, and at the end of that time station near the Chaldasan gates a body of four thousand. 44 | DUMMY1/LJ012-0235.wav|While they were in a state of insensibility the murder was committed. 45 | DUMMY1/LJ034-0053.wav|reached the same conclusion as Latona that the prints found on the cartons were those of Lee Harvey Oswald. 46 | DUMMY1/LJ014-0030.wav|These were damnatory facts which well supported the prosecution. 47 | DUMMY1/LJ015-0203.wav|but were the precautions too minute, the vigilance too close to be eluded or overcome? 48 | DUMMY1/LJ028-0093.wav|but his scribe wrote it in the manner customary for the scribes of those days to write of their royal masters. 49 | DUMMY1/LJ002-0018.wav|The inadequacy of the jail was noticed and reported upon again and again by the grand juries of the city of London, 50 | DUMMY1/LJ028-0275.wav|At last, in the twentieth month, 51 | DUMMY1/LJ012-0042.wav|which he kept concealed in a hiding-place with a trap-door just under his bed. 52 | DUMMY1/LJ011-0096.wav|He married a lady also belonging to the Society of Friends, who brought him a large fortune, which, and his own money, he put into a city firm, 53 | DUMMY1/LJ036-0077.wav|Roger D. Craig, a deputy sheriff of Dallas County, 54 | DUMMY1/LJ016-0318.wav|Other officials, great lawyers, governors of prisons, and chaplains supported this view. 55 | DUMMY1/LJ013-0164.wav|who came from his room ready dressed, a suspicious circumstance, as he was always late in the morning. 56 | DUMMY1/LJ027-0141.wav|is closely reproduced in the life-history of existing deer. Or, in other words, 57 | DUMMY1/LJ028-0335.wav|accordingly they committed to him the command of their whole army, and put the keys of their city into his hands. 58 | DUMMY1/LJ031-0202.wav|Mrs. 
Kennedy chose the hospital in Bethesda for the autopsy because the President had served in the Navy. 59 | DUMMY1/LJ021-0145.wav|From those willing to join in establishing this hoped-for period of peace, 60 | DUMMY1/LJ016-0288.wav|"Müller, Müller, He's the man," till a diversion was created by the appearance of the gallows, which was received with continuous yells. 61 | DUMMY1/LJ028-0081.wav|Years later, when the archaeologists could readily distinguish the false from the true, 62 | DUMMY1/LJ018-0081.wav|his defense being that he had intended to commit suicide, but that, on the appearance of this officer who had wronged him, 63 | DUMMY1/LJ021-0066.wav|together with a great increase in the payrolls, there has come a substantial rise in the total of industrial profits 64 | DUMMY1/LJ009-0238.wav|After this the sheriffs sent for another rope, but the spectators interfered, and the man was carried back to jail. 65 | DUMMY1/LJ005-0079.wav|and improve the morals of the prisoners, and shall insure the proper measure of punishment to convicted offenders. 66 | DUMMY1/LJ035-0019.wav|drove to the northwest corner of Elm and Houston, and parked approximately ten feet from the traffic signal. 67 | DUMMY1/LJ036-0174.wav|This is the approximate time he entered the roominghouse, according to Earlene Roberts, the housekeeper there. 68 | DUMMY1/LJ046-0146.wav|The criteria in effect prior to November twenty-two, nineteen sixty-three, for determining whether to accept material for the PRS general files 69 | DUMMY1/LJ017-0044.wav|and the deepest anxiety was felt that the crime, if crime there had been, should be brought home to its perpetrator. 70 | DUMMY1/LJ017-0070.wav|but his sporting operations did not prosper, and he became a needy man, always driven to desperate straits for cash. 71 | DUMMY1/LJ014-0020.wav|He was soon afterwards arrested on suspicion, and a search of his lodgings brought to light several garments saturated with blood; 72 | DUMMY1/LJ016-0020.wav|He never reached the cistern, but fell back into the yard, injuring his legs severely. 73 | DUMMY1/LJ045-0230.wav|when he was finally apprehended in the Texas Theatre. Although it is not fully corroborated by others who were present, 74 | DUMMY1/LJ035-0129.wav|and she must have run down the stairs ahead of Oswald and would probably have seen or heard him. 75 | DUMMY1/LJ008-0307.wav|afterwards express a wish to murder the Recorder for having kept them so long in suspense. 76 | DUMMY1/LJ008-0294.wav|nearly indefinitely deferred. 77 | DUMMY1/LJ047-0148.wav|On October twenty-five, 78 | DUMMY1/LJ008-0111.wav|They entered a "stone cold room," and were presently joined by the prisoner. 79 | DUMMY1/LJ034-0042.wav|that he could only testify with certainty that the print was less than three days old. 80 | DUMMY1/LJ037-0234.wav|Mrs. Mary Brock, the wife of a mechanic who worked at the station, was there at the time and she saw a white male, 81 | DUMMY1/LJ040-0002.wav|Chapter seven. Lee Harvey Oswald: Background and Possible Motives, Part one. 82 | DUMMY1/LJ045-0140.wav|The arguments he used to justify his use of the alias suggest that Oswald may have come to think that the whole world was becoming involved 83 | DUMMY1/LJ012-0035.wav|the number and names on watches, were carefully removed or obliterated after the goods passed out of his hands. 84 | DUMMY1/LJ012-0250.wav|On the seventh July, eighteen thirty-seven, 85 | DUMMY1/LJ016-0179.wav|contracted with sheriffs and conveners to work by the job. 86 | DUMMY1/LJ016-0138.wav|at a distance from the prison. 
87 | DUMMY1/LJ027-0052.wav|These principles of homology are essential to a correct interpretation of the facts of morphology. 88 | DUMMY1/LJ031-0134.wav|On one occasion Mrs. Johnson, accompanied by two Secret Service agents, left the room to see Mrs. Kennedy and Mrs. Connally. 89 | DUMMY1/LJ019-0273.wav|which Sir Joshua Jebb told the committee he considered the proper elements of penal discipline. 90 | DUMMY1/LJ014-0110.wav|At the first the boxes were impounded, opened, and found to contain many of O'Connor's effects. 91 | DUMMY1/LJ034-0160.wav|on Brennan's subsequent certain identification of Lee Harvey Oswald as the man he saw fire the rifle. 92 | DUMMY1/LJ038-0199.wav|eleven. If I am alive and taken prisoner, 93 | DUMMY1/LJ014-0010.wav|yet he could not overcome the strange fascination it had for him, and remained by the side of the corpse till the stretcher came. 94 | DUMMY1/LJ033-0047.wav|I noticed when I went out that the light was on, end quote, 95 | DUMMY1/LJ040-0027.wav|He was never satisfied with anything. 96 | DUMMY1/LJ048-0228.wav|and others who were present say that no agent was inebriated or acted improperly. 97 | DUMMY1/LJ003-0111.wav|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity. 98 | DUMMY1/LJ008-0258.wav|Let me retrace my steps, and speak more in detail of the treatment of the condemned in those bloodthirsty and brutally indifferent days, 99 | DUMMY1/LJ029-0022.wav|The original plan called for the President to spend only one day in the State, making whirlwind visits to Dallas, Fort Worth, San Antonio, and Houston. 100 | DUMMY1/LJ004-0045.wav|Mr. Sturges Bourne, Sir James Mackintosh, Sir James Scarlett, and William Wilberforce. 101 | -------------------------------------------------------------------------------- /filelists/ljs_audio_text_val_filelist.txt.cleaned: -------------------------------------------------------------------------------- 1 | DUMMY1/LJ022-0023.wav|ðɪ ˌoʊvɚwˈɛlmɪŋ mədʒˈɔːɹɪɾi ʌv pˈiːpəl ɪn ðɪs kˈʌntɹi nˈoʊ hˌaʊ tə sˈɪft ðə wˈiːt fɹʌmðə tʃˈæf ɪn wˌʌt ðeɪ hˈɪɹ ænd wˌʌt ðeɪ ɹˈiːd. 2 | DUMMY1/LJ043-0030.wav|ɪf sˈʌmbɑːdi dˈɪd ðˈæt tə mˌiː, ɐ lˈaʊsi tɹˈɪk lˈaɪk ðˈæt, tə tˈeɪk maɪ wˈaɪf ɐwˈeɪ, ænd ˈɔːl ðə fˈɜːnɪtʃɚ, ˈaɪ wʊd biː mˈæd æz hˈɛl, tˈuː. 3 | DUMMY1/LJ005-0201.wav|ˌæzˌɪz ʃˈoʊn baɪ ðə ɹɪpˈoːɹt ʌvðə kəmˈɪʃənɚz tʊ ɪnkwˈaɪɚɹ ˌɪntʊ ðə stˈeɪt ʌvðə mjuːnˈɪsɪpəl kˌɔːɹpɚɹˈeɪʃənz ɪn eɪtˈiːn θˈɜːɾifˈaɪv. 4 | DUMMY1/LJ001-0110.wav|ˈiːvən ðə kˈæslɑːn tˈaɪp wɛn ɛnlˈɑːɹdʒd ʃˈoʊz ɡɹˈeɪt ʃˈɔːɹtkʌmɪŋz ɪn ðɪs ɹɪspˈɛkt: 5 | DUMMY1/LJ003-0345.wav|ˈɔːl ðə kəmˈɪɾi kʊd dˈuː ɪn ðɪs ɹɪspˈɛkt wʌz tə θɹˈoʊ ðə ɹɪspˌɑːnsəbˈɪlɪɾi ˌɑːn ˈʌðɚz. 6 | DUMMY1/LJ007-0154.wav|ðiːz pˈʌndʒənt ænd wˈɛlɡɹˈaʊndᵻd stɹˈɪktʃɚz ɐplˈaɪd wɪð stˈɪl ɡɹˈeɪɾɚ fˈoːɹs tə ðɪ ʌnkənvˈɪktᵻd pɹˈɪzənɚ, ðə mˈæn hˌuː kˈeɪm tə ðə pɹˈɪzən ˈɪnəsənt, ænd stˈɪl ʌnkəntˈæmᵻnˌeɪɾᵻd, 7 | DUMMY1/LJ018-0098.wav|ænd ɹˈɛkəɡnˌaɪzd æz wˈʌn ʌvðə fɹˈiːkwɛntɚz ʌvðə bˈoʊɡəs lˈɔːstˈeɪʃənɚz. hɪz ɐɹˈɛst lˈɛd tə ðæt ʌv ˈʌðɚz. 8 | DUMMY1/LJ047-0044.wav|ˈɑːswəld wʌz, haʊˈɛvɚ, wˈɪlɪŋ tə dɪskˈʌs hɪz kˈɑːntækts wɪð sˈoʊviət ɐθˈɔːɹɪɾiz. hiː dɪnˈaɪd hˌævɪŋ ˌɛni ɪnvˈɑːlvmənt wɪð sˈoʊviət ɪntˈɛlɪdʒəns ˈeɪdʒənsiz 9 | DUMMY1/LJ031-0038.wav|ðə fˈɜːst fɪzˈɪʃən tə sˈiː ðə pɹˈɛzɪdənt æt pˈɑːɹklənd hˈɑːspɪɾəl wʌz dˈɑːktɚ tʃˈɑːɹlz dʒˈeɪ. kˈæɹɪkˌoʊ, ɐ ɹˈɛzɪdənt ɪn dʒˈɛnɚɹəl sˈɜːdʒɚɹi. 10 | DUMMY1/LJ048-0194.wav|dˈʊɹɪŋ ðə mˈɔːɹnɪŋ ʌv noʊvˈɛmbɚ twˈɛntitˈuː pɹˈaɪɚ tə ðə mˈoʊɾɚkˌeɪd. 
11 | DUMMY1/LJ049-0026.wav|ˌɑːn əkˈeɪʒən ðə sˈiːkɹət sˈɜːvɪs hɐzbɪn pɚmˈɪɾᵻd tə hæv ɐn ˈeɪdʒənt ɹˈaɪdɪŋ ɪnðə pˈæsɪndʒɚ kəmpˈɑːɹtmənt wɪððə pɹˈɛzɪdənt. 12 | DUMMY1/LJ004-0152.wav|ɑːlðˈoʊ æt mˈɪstɚ bˈʌkstənz vˈɪzɪt ɐ nˈuː dʒˈeɪl wʌz ɪn pɹˈɑːsɛs ʌv ɪɹˈɛkʃən, ðə fˈɜːst stˈɛp tʊwˈɔːɹdz ɹɪfˈɔːɹm sˈɪns hˈaʊɚdz vˌɪzɪtˈeɪʃən ɪn sˌɛvəntˈiːn sˈɛvəntifˈoːɹ. 13 | DUMMY1/LJ008-0278.wav|ɔːɹ ðˈɛɹz mˌaɪt biː wˈʌn ʌv mˈɛni, ænd ɪt mˌaɪt biː kənsˈɪdɚd nˈɛsəsɚɹi tuː "mˌeɪk ɐn ɛɡzˈæmpəl." 14 | DUMMY1/LJ043-0002.wav|ðə wˈɔːɹən kəmˈɪʃən ɹɪpˈoːɹt. baɪ ðə pɹˈɛzɪdənts kəmˈɪʃən ɑːnðɪ ɐsˌæsᵻnˈeɪʃən ʌv pɹˈɛzɪdənt kˈɛnədi. tʃˈæptɚ sˈɛvən. lˈiː hˈɑːɹvi ˈɑːswəld: 15 | DUMMY1/LJ009-0114.wav|mˈɪstɚ wˈeɪkfiːld wˈaɪndz ˈʌp hɪz ɡɹˈæfɪk bˌʌt sˈʌmwʌt sɛnsˈeɪʃənəl ɐkˈaʊnt baɪ dɪskɹˈaɪbɪŋ ɐnˈʌðɚ ɹɪlˈɪdʒəs sˈɜːvɪs, wˌɪtʃ mˈeɪ ɐpɹˈoʊpɹɪətli biː ɪnsˈɜːɾᵻd hˈɪɹ. 16 | DUMMY1/LJ028-0506.wav|ɐ mˈɑːdɚn ˈɑːɹɾɪst wʊdhɐv dˈɪfɪkˌʌlti ɪn dˌuːɪŋ sˈʌtʃ ˈækjʊɹət wˈɜːk. 17 | DUMMY1/LJ050-0168.wav|wɪððə pɚtˈɪkjʊlɚ pˈɜːpəsᵻz ʌvðɪ ˈeɪdʒənsi ɪnvˈɑːlvd. ðə kəmˈɪʃən ɹˈɛkəɡnˌaɪzɪz ðæt ðɪs ɪz ɐ kˌɑːntɹəvˈɜːʃəl ˈɛɹiə 18 | DUMMY1/LJ039-0223.wav|ˈɑːswəldz mɚɹˈiːn tɹˈeɪnɪŋ ɪn mˈɑːɹksmənʃˌɪp, hɪz ˈʌðɚ ɹˈaɪfəl ɛkspˈiəɹɪəns ænd hɪz ɪstˈæblɪʃt fəmˌɪlɪˈæɹɪɾi wɪð ðɪs pɚtˈɪkjʊlɚ wˈɛpən 19 | DUMMY1/LJ029-0032.wav|ɐkˈoːɹdɪŋ tʊ oʊdˈɑːnəl, kwˈoʊt, wiː hɐd ɐ mˈoʊɾɚkˌeɪd wɛɹɹˈɛvɚ wiː wˈɛnt, ˈɛnd kwˈoʊt. 20 | DUMMY1/LJ031-0070.wav|dˈɑːktɚ klˈɑːɹk, hˌuː mˈoʊst klˈoʊsli ɑːbzˈɜːvd ðə hˈɛd wˈuːnd, 21 | DUMMY1/LJ034-0198.wav|jˈuːɪnz, hˌuː wʌz ɑːnðə saʊθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən stɹˈiːts tˈɛstɪfˌaɪd ðæt hiː kʊd nˌɑːt dɪskɹˈaɪb ðə mˈæn hiː sˈɔː ɪnðə wˈɪndoʊ. 22 | DUMMY1/LJ026-0068.wav|ˈɛnɚdʒi ˈɛntɚz ðə plˈænt, tʊ ɐ smˈɔːl ɛkstˈɛnt, 23 | DUMMY1/LJ039-0075.wav|wˈʌns juː nˈoʊ ðæt juː mˈʌst pˌʊt ðə kɹˈɔshɛɹz ɑːnðə tˈɑːɹɡɪt ænd ðæt ɪz ˈɔːl ðæt ɪz nˈɛsəsɚɹi. 24 | DUMMY1/LJ004-0096.wav|ðə fˈeɪɾəl kˈɑːnsɪkwənsᵻz wˈɛɹɑːf mˌaɪt biː pɹɪvˈɛntᵻd ɪf ðə dʒˈʌstɪsᵻz ʌvðə pˈiːs wɜː djˈuːli ˈɔːθɚɹˌaɪzd 25 | DUMMY1/LJ005-0014.wav|spˈiːkɪŋ ˌɑːn ɐ dɪbˈeɪt ˌɑːn pɹˈɪzən mˈæɾɚz, hiː dᵻklˈɛɹd ðˈæt 26 | DUMMY1/LJ012-0161.wav|hiː wʌz ɹɪpˈoːɹɾᵻd tə hæv fˈɔːlən ɐwˈeɪ tʊ ɐ ʃˈædoʊ. 27 | DUMMY1/LJ018-0239.wav|hɪz dˌɪsɐpˈɪɹəns ɡˈeɪv kˈʌlɚ ænd sˈʌbstəns tʊ ˈiːvəl ɹɪpˈoːɹts ɔːlɹˌɛdi ɪn sˌɜːkjʊlˈeɪʃən ðætðə wɪl ænd kənvˈeɪəns əbˌʌv ɹɪfˈɜːd tuː 28 | DUMMY1/LJ019-0257.wav|hˈɪɹ ðə tɹˈɛdwˈiːl wʌz ɪn jˈuːs, ðɛɹ sˈɛljʊlɚ kɹˈæŋks, ɔːɹ hˈɑːɹdlˈeɪbɚ məʃˈiːnz. 29 | DUMMY1/LJ028-0008.wav|juː tˈæp dʒˈɛntli wɪð jʊɹ hˈiːl əpˌɑːn ðə ʃˈoʊldɚɹ ʌvðə dɹˈoʊmdɚɹi tʊ ˈɜːdʒ hɜːɹ ˈɑːn. 30 | DUMMY1/LJ024-0083.wav|ðɪs plˈæn ʌv mˈaɪn ɪz nˈoʊ ɐtˈæk ɑːnðə kˈoːɹt; 31 | DUMMY1/LJ042-0129.wav|nˈoʊ nˈaɪt klˈʌbz ɔːɹ bˈoʊlɪŋ ˈælɪz, nˈoʊ plˈeɪsᵻz ʌv ɹˌɛkɹiːˈeɪʃən ɛksˈɛpt ðə tɹˈeɪd jˈuːniən dˈænsᵻz. ˈaɪ hæv hɐd ɪnˈʌf. 32 | DUMMY1/LJ036-0103.wav|ðə pəlˈiːs ˈæskt hˌɪm wˈɛðɚ hiː kʊd pˈɪk ˈaʊt hɪz pˈæsɪndʒɚ fɹʌmðə lˈaɪnʌp. 33 | DUMMY1/LJ046-0058.wav|dˈʊɹɪŋ hɪz pɹˈɛzɪdənsi, fɹˈæŋklɪn dˈiː. ɹˈoʊzəvˌɛlt mˌeɪd ˈɔːlmoʊst fˈoːɹ hˈʌndɹəd dʒˈɜːnɪz ænd tɹˈævəld mˈoːɹ ðɐn θɹˈiː hˈʌndɹəd fˈɪfti θˈaʊzənd mˈaɪlz. 34 | DUMMY1/LJ014-0076.wav|hiː wʌz sˈiːn ˈæftɚwɚdz smˈoʊkɪŋ ænd tˈɔːkɪŋ wɪð hɪz hˈoʊsts ɪn ðɛɹ bˈæk pˈɑːɹlɚ, ænd nˈɛvɚ sˈiːn ɐɡˈɛn ɐlˈaɪv. 35 | DUMMY1/LJ002-0043.wav|lˈɑːŋ nˈæɹoʊ ɹˈuːmz wˈʌn θˈɜːɾisˈɪks fˈiːt, sˈɪks twˈɛntiθɹˈiː fˈiːt, ænd ðɪ ˈeɪtθ eɪtˈiːn, 36 | DUMMY1/LJ009-0076.wav|wiː kˈʌm tə ðə sˈɜːmən. 37 | DUMMY1/LJ017-0131.wav|ˈiːvən wɛn ðə hˈaɪ ʃˈɛɹɪf hɐd tˈoʊld hˌɪm ðɛɹwˌʌz nˈoʊ pˌɑːsəbˈɪlɪɾi əvɚ ɹɪpɹˈiːv, ænd wɪðˌɪn ɐ fjˈuː ˈaɪʊɹz ʌv ˌɛksɪkjˈuːʃən. 
38 | DUMMY1/LJ046-0184.wav|bˌʌt ðɛɹ ɪz ɐ sˈɪstəm fɚðɪ ɪmˈiːdɪət nˌoʊɾɪfɪkˈeɪʃən ʌvðə sˈiːkɹət sˈɜːvɪs baɪ ðə kənfˈaɪnɪŋ ˌɪnstɪtˈuːʃən wɛn ɐ sˈʌbdʒɛkt ɪz ɹɪlˈiːsd ɔːɹ ɛskˈeɪps. 39 | DUMMY1/LJ014-0263.wav|wˌɛn ˈʌðɚ plˈɛʒɚz pˈɔːld hiː tˈʊk ɐ θˈiəɾɚ, ænd pˈoʊzd æz ɐ mjuːnˈɪfɪsənt pˈeɪtɹən ʌvðə dɹəmˈæɾɪk ˈɑːɹt. 40 | DUMMY1/LJ042-0096.wav| ˈoʊld ɛkstʃˈeɪndʒ ɹˈeɪt ɪn ɐdˈɪʃən tə hɪz fˈæktɚɹi sˈælɚɹi ʌv ɐpɹˈɑːksɪmətli ˈiːkwəl ɐmˈaʊnt 41 | DUMMY1/LJ049-0050.wav|hˈɪl hɐd bˈoʊθ fˈiːt ɑːnðə kˈɑːɹ ænd wʌz klˈaɪmɪŋ ɐbˈoːɹd tʊ ɐsˈɪst pɹˈɛzɪdənt ænd mɪsˈɛs kˈɛnədi. 42 | DUMMY1/LJ019-0186.wav|sˈiːɪŋ ðæt sˈɪns ðɪ ɪstˈæblɪʃmənt ʌvðə sˈɛntɹəl kɹˈɪmɪnəl kˈoːɹt, nˈuːɡeɪt ɹɪsˈiːvd pɹˈɪzənɚz fɔːɹ tɹˈaɪəl fɹʌm sˈɛvɹəl kˈaʊntɪz, 43 | DUMMY1/LJ028-0307.wav|ðˈɛn lˈɛt twˈɛnti dˈeɪz pˈæs, ænd æt ðɪ ˈɛnd ʌv ðæt tˈaɪm stˈeɪʃən nˌɪɹ ðə tʃˈældæsən ɡˈeɪts ɐ bˈɑːdi ʌv fˈoːɹ θˈaʊzənd. 44 | DUMMY1/LJ012-0235.wav|wˌaɪl ðeɪ wɜːɹ ɪn ɐ stˈeɪt ʌv ɪnsˌɛnsəbˈɪlɪɾi ðə mˈɜːdɚ wʌz kəmˈɪɾᵻd. 45 | DUMMY1/LJ034-0053.wav|ɹˈiːtʃt ðə sˈeɪm kənklˈuːʒən æz lætˈoʊnə ðætðə pɹˈɪnts fˈaʊnd ɑːnðə kˈɑːɹtənz wɜː ðoʊz ʌv lˈiː hˈɑːɹvi ˈɑːswəld. 46 | DUMMY1/LJ014-0030.wav|ðiːz wɜː dˈæmnətˌoːɹi fˈækts wˌɪtʃ wˈɛl səpˈoːɹɾᵻd ðə pɹˌɑːsɪkjˈuːʃən. 47 | DUMMY1/LJ015-0203.wav|bˌʌt wɜː ðə pɹɪkˈɔːʃənz tˈuː mˈɪnɪt, ðə vˈɪdʒɪləns tˈuː klˈoʊs təbi ɪlˈuːdᵻd ɔːɹ ˌoʊvɚkˈʌm? 48 | DUMMY1/LJ028-0093.wav|bˌʌt hɪz skɹˈaɪb ɹˈoʊt ɪt ɪnðə mˈænɚ kˈʌstəmˌɛɹi fɚðə skɹˈaɪbz ʌv ðoʊz dˈeɪz tə ɹˈaɪt ʌv ðɛɹ ɹˈɔɪəl mˈæstɚz. 49 | DUMMY1/LJ002-0018.wav|ðɪ ɪnˈædɪkwəsi ʌvðə dʒˈeɪl wʌz nˈoʊɾɪsd ænd ɹɪpˈoːɹɾᵻd əpˌɑːn ɐɡˈɛn ænd ɐɡˈɛn baɪ ðə ɡɹˈænd dʒˈʊɹɪz ʌvðə sˈɪɾi ʌv lˈʌndən, 50 | DUMMY1/LJ028-0275.wav|æt lˈæst, ɪnðə twˈɛntiəθ mˈʌnθ, 51 | DUMMY1/LJ012-0042.wav|wˌɪtʃ hiː kˈɛpt kənsˈiːld ɪn ɐ hˈaɪdɪŋplˈeɪs wɪð ɐ tɹˈæpdˈoːɹ dʒˈʌst ˌʌndɚ hɪz bˈɛd. 52 | DUMMY1/LJ011-0096.wav|hiː mˈæɹɪd ɐ lˈeɪdi ˈɑːlsoʊ bɪlˈɑːŋɪŋ tə ðə səsˈaɪəɾi ʌv fɹˈɛndz, hˌuː bɹˈɔːt hˌɪm ɐ lˈɑːɹdʒ fˈɔːɹtʃən, wˈɪtʃ, ænd hɪz ˈoʊn mˈʌni, hiː pˌʊt ˌɪntʊ ɐ sˈɪɾi fˈɜːm, 53 | DUMMY1/LJ036-0077.wav|ɹˈɑːdʒɚ dˈiː. kɹˈeɪɡ, ɐ dˈɛpjuːɾi ʃˈɛɹɪf ʌv dˈæləs kˈaʊnti, 54 | DUMMY1/LJ016-0318.wav|ˈʌðɚɹ əfˈɪʃəlz, ɡɹˈeɪt lˈɔɪɚz, ɡˈʌvɚnɚz ʌv pɹˈɪzənz, ænd tʃˈæplɪnz səpˈoːɹɾᵻd ðɪs vjˈuː. 55 | DUMMY1/LJ013-0164.wav|hˌuː kˈeɪm fɹʌm hɪz ɹˈuːm ɹˈɛdi dɹˈɛst, ɐ səspˈɪʃəs sˈɜːkəmstˌæns, æz hiː wʌz ˈɔːlweɪz lˈeɪt ɪnðə mˈɔːɹnɪŋ. 56 | DUMMY1/LJ027-0141.wav|ɪz klˈoʊsli ɹɪpɹədˈuːst ɪnðə lˈaɪfhˈɪstɚɹi ʌv ɛɡzˈɪstɪŋ dˈɪɹ. ˈɔːɹ, ɪn ˈʌðɚ wˈɜːdz, 57 | DUMMY1/LJ028-0335.wav|ɐkˈoːɹdɪŋli ðeɪ kəmˈɪɾᵻd tə hˌɪm ðə kəmˈænd ʌv ðɛɹ hˈoʊl ˈɑːɹmi, ænd pˌʊt ðə kˈiːz ʌv ðɛɹ sˈɪɾi ˌɪntʊ hɪz hˈændz. 58 | DUMMY1/LJ031-0202.wav|mɪsˈɛs kˈɛnədi tʃˈoʊz ðə hˈɑːspɪɾəl ɪn bəθˈɛzdə fɚðɪ ˈɔːtɑːpsi bɪkˈʌz ðə pɹˈɛzɪdənt hɐd sˈɜːvd ɪnðə nˈeɪvi. 59 | DUMMY1/LJ021-0145.wav|fɹʌm ðoʊz wˈɪlɪŋ tə dʒˈɔɪn ɪn ɪstˈæblɪʃɪŋ ðɪs hˈoʊptfɔːɹ pˈiəɹɪəd ʌv pˈiːs, 60 | DUMMY1/LJ016-0288.wav|"mˈʌlɚ, mˈʌlɚ, hiːz ðə mˈæn," tˈɪl ɐ daɪvˈɜːʒən wʌz kɹiːˈeɪɾᵻd baɪ ðɪ ɐpˈɪɹəns ʌvðə ɡˈæloʊz, wˌɪtʃ wʌz ɹɪsˈiːvd wɪð kəntˈɪnjuːəs jˈɛlz. 61 | DUMMY1/LJ028-0081.wav|jˈɪɹz lˈeɪɾɚ, wˌɛn ðɪ ˌɑːɹkiːˈɑːlədʒˌɪsts kʊd ɹˈɛdɪli dɪstˈɪŋɡwɪʃ ðə fˈɑːls fɹʌmðə tɹˈuː, 62 | DUMMY1/LJ018-0081.wav|hɪz dɪfˈɛns bˌiːɪŋ ðæt hiː hɐd ɪntˈɛndᵻd tə kəmˈɪt sˈuːɪsˌaɪd, bˌʌt ðˈæt, ɑːnðɪ ɐpˈɪɹəns ʌv ðɪs ˈɑːfɪsɚ hˌuː hɐd ɹˈɔŋd hˌɪm, 63 | DUMMY1/LJ021-0066.wav|təɡˌɛðɚ wɪð ɐ ɡɹˈeɪt ˈɪnkɹiːs ɪnðə pˈeɪɹoʊlz, ðɛɹ hɐz kˈʌm ɐ səbstˈænʃəl ɹˈaɪz ɪnðə tˈoʊɾəl ʌv ɪndˈʌstɹɪəl pɹˈɑːfɪts 64 | DUMMY1/LJ009-0238.wav|ˈæftɚ ðɪs ðə ʃˈɛɹɪfs sˈɛnt fɔːɹ ɐnˈʌðɚ ɹˈoʊp, bˌʌt ðə spɛktˈeɪɾɚz ˌɪntəfˈɪɹd, ænd ðə mˈæn wʌz kˈæɹɪd bˈæk tə dʒˈeɪl. 
65 | DUMMY1/LJ005-0079.wav|ænd ɪmpɹˈuːv ðə mˈɔːɹəlz ʌvðə pɹˈɪzənɚz, ænd ʃˌæl ɪnʃˈʊɹ ðə pɹˈɑːpɚ mˈɛʒɚɹ ʌv pˈʌnɪʃmənt tə kənvˈɪktᵻd əfˈɛndɚz. 66 | DUMMY1/LJ035-0019.wav|dɹˈoʊv tə ðə nɔːɹθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən, ænd pˈɑːɹkt ɐpɹˈɑːksɪmətli tˈɛn fˈiːt fɹʌmðə tɹˈæfɪk sˈɪɡnəl. 67 | DUMMY1/LJ036-0174.wav|ðɪs ɪz ðɪ ɐpɹˈɑːksɪmət tˈaɪm hiː ˈɛntɚd ðə ɹˈuːmɪŋhˌaʊs, ɐkˈoːɹdɪŋ tʊ ˈɜːliːn ɹˈɑːbɚts, ðə hˈaʊskiːpɚ ðˈɛɹ. 68 | DUMMY1/LJ046-0146.wav|ðə kɹaɪtˈiəɹɪə ɪn ɪfˈɛkt pɹˈaɪɚ tə noʊvˈɛmbɚ twˈɛntitˈuː, naɪntˈiːn sˈɪkstiθɹˈiː, fɔːɹ dɪtˈɜːmɪnɪŋ wˈɛðɚ tʊ ɐksˈɛpt mətˈiəɹɪəl fɚðə pˌiːˌɑːɹˈɛs dʒˈɛnɚɹəl fˈaɪlz 69 | DUMMY1/LJ017-0044.wav|ænd ðə dˈiːpəst æŋzˈaɪəɾi wʌz fˈɛlt ðætðə kɹˈaɪm, ɪf kɹˈaɪm ðˈɛɹ hɐdbɪn, ʃˌʊd biː bɹˈɔːt hˈoʊm tʊ ɪts pˈɜːpɪtɹˌeɪɾɚ. 70 | DUMMY1/LJ017-0070.wav|bˌʌt hɪz spˈoːɹɾɪŋ ˌɑːpɚɹˈeɪʃənz dɪdnˌɑːt pɹˈɑːspɚ, ænd hiː bɪkˌeɪm ɐ nˈiːdi mˈæn, ˈɔːlweɪz dɹˈɪvən tə dˈɛspɚɹət stɹˈeɪts fɔːɹ kˈæʃ. 71 | DUMMY1/LJ014-0020.wav|hiː wʌz sˈuːn ˈæftɚwɚdz ɐɹˈɛstᵻd ˌɑːn səspˈɪʃən, ænd ɐ sˈɜːtʃ ʌv hɪz lˈɑːdʒɪŋz bɹˈɔːt tə lˈaɪt sˈɛvɹəl ɡˈɑːɹmənts sˈætʃɚɹˌeɪɾᵻd wɪð blˈʌd; 72 | DUMMY1/LJ016-0020.wav|hiː nˈɛvɚ ɹˈiːtʃt ðə sˈɪstɚn, bˌʌt fˈɛl bˈæk ˌɪntʊ ðə jˈɑːɹd, ˈɪndʒɚɹɪŋ hɪz lˈɛɡz sɪvˈɪɹli. 73 | DUMMY1/LJ045-0230.wav|wˌɛn hiː wʌz fˈaɪnəli ˌæpɹɪhˈɛndᵻd ɪnðə tˈɛksəs θˈiəɾɚ. ɑːlðˈoʊ ɪt ɪz nˌɑːt fˈʊli kɚɹˈɑːbɚɹˌeɪɾᵻd baɪ ˈʌðɚz hˌuː wɜː pɹˈɛzənt, 74 | DUMMY1/LJ035-0129.wav|ænd ʃiː mˈʌstɐv ɹˈʌn dˌaʊn ðə stˈɛɹz ɐhˈɛd ʌv ˈɑːswəld ænd wʊd pɹˈɑːbəbli hæv sˈiːn ɔːɹ hˈɜːd hˌɪm. 75 | DUMMY1/LJ008-0307.wav|ˈæftɚwɚdz ɛkspɹˈɛs ɐ wˈɪʃ tə mˈɜːdɚ ðə ɹɪkˈoːɹdɚ fɔːɹ hˌævɪŋ kˈɛpt ðˌɛm sˌoʊ lˈɑːŋ ɪn səspˈɛns. 76 | DUMMY1/LJ008-0294.wav|nˌɪɹli ɪndˈɛfɪnətli dɪfˈɜːd. 77 | DUMMY1/LJ047-0148.wav|ˌɑːn ɑːktˈoʊbɚ twˈɛntifˈaɪv, 78 | DUMMY1/LJ008-0111.wav|ðeɪ ˈɛntɚd ˈeɪ "stˈoʊn kˈoʊld ɹˈuːm," ænd wɜː pɹˈɛzəntli dʒˈɔɪnd baɪ ðə pɹˈɪzənɚ. 79 | DUMMY1/LJ034-0042.wav|ðæt hiː kʊd ˈoʊnli tˈɛstɪfˌaɪ wɪð sˈɜːtənti ðætðə pɹˈɪnt wʌz lˈɛs ðɐn θɹˈiː dˈeɪz ˈoʊld. 80 | DUMMY1/LJ037-0234.wav|mɪsˈɛs mˈɛɹi bɹˈɑːk, ðə wˈaɪf əvə mɪkˈænɪk hˌuː wˈɜːkt æt ðə stˈeɪʃən, wʌz ðɛɹ æt ðə tˈaɪm ænd ʃiː sˈɔː ɐ wˈaɪt mˈeɪl, 81 | DUMMY1/LJ040-0002.wav|tʃˈæptɚ sˈɛvən. lˈiː hˈɑːɹvi ˈɑːswəld: bˈækɡɹaʊnd ænd pˈɑːsəbəl mˈoʊɾɪvz, pˈɑːɹt wˌʌn. 82 | DUMMY1/LJ045-0140.wav|ðɪ ˈɑːɹɡjuːmənts hiː jˈuːzd tə dʒˈʌstɪfˌaɪ hɪz jˈuːs ʌvðɪ ˈeɪliəs sədʒˈɛst ðæt ˈɑːswəld mˌeɪhɐv kˈʌm tə θˈɪŋk ðætðə hˈoʊl wˈɜːld wʌz bɪkˈʌmɪŋ ɪnvˈɑːlvd 83 | DUMMY1/LJ012-0035.wav|ðə nˈʌmbɚ ænd nˈeɪmz ˌɑːn wˈɑːtʃᵻz, wɜː kˈɛɹfəli ɹɪmˈuːvd ɔːɹ əblˈɪɾɚɹˌeɪɾᵻd ˈæftɚ ðə ɡˈʊdz pˈæst ˌaʊɾəv hɪz hˈændz. 84 | DUMMY1/LJ012-0250.wav|ɑːnðə sˈɛvənθ dʒuːlˈaɪ, eɪtˈiːn θˈɜːɾisˈɛvən, 85 | DUMMY1/LJ016-0179.wav|kəntɹˈæktᵻd wɪð ʃˈɛɹɪfs ænd kənvˈɛnɚz tə wˈɜːk baɪ ðə dʒˈɑːb. 86 | DUMMY1/LJ016-0138.wav|æɾə dˈɪstəns fɹʌmðə pɹˈɪzən. 87 | DUMMY1/LJ027-0052.wav|ðiːz pɹˈɪnsɪpəlz ʌv həmˈɑːlədʒi ɑːɹ ɪsˈɛnʃəl tʊ ɐ kɚɹˈɛkt ɪntˌɜːpɹɪtˈeɪʃən ʌvðə fˈækts ʌv mɔːɹfˈɑːlədʒi. 88 | DUMMY1/LJ031-0134.wav|ˌɑːn wˈʌn əkˈeɪʒən mɪsˈɛs dʒˈɑːnsən, ɐkˈʌmpənɪd baɪ tˈuː sˈiːkɹət sˈɜːvɪs ˈeɪdʒənts, lˈɛft ðə ɹˈuːm tə sˈiː mɪsˈɛs kˈɛnədi ænd mɪsˈɛs kənˈæli. 89 | DUMMY1/LJ019-0273.wav|wˌɪtʃ sˌɜː dʒˈɑːʃjuːə dʒˈɛb tˈoʊld ðə kəmˈɪɾi hiː kənsˈɪdɚd ðə pɹˈɑːpɚɹ ˈɛlɪmənts ʌv pˈiːnəl dˈɪsɪplˌɪn. 90 | DUMMY1/LJ014-0110.wav|æt ðə fˈɜːst ðə bˈɑːksᵻz wɜːɹ ɪmpˈaʊndᵻd, ˈoʊpənd, ænd fˈaʊnd tə kəntˈeɪn mˈɛnɪəv oʊkˈɑːnɚz ɪfˈɛkts. 91 | DUMMY1/LJ034-0160.wav|ˌɑːn bɹˈɛnənz sˈʌbsɪkwənt sˈɜːtən aɪdˈɛntɪfɪkˈeɪʃən ʌv lˈiː hˈɑːɹvi ˈɑːswəld æz ðə mˈæn hiː sˈɔː fˈaɪɚ ðə ɹˈaɪfəl. 92 | DUMMY1/LJ038-0199.wav|ɪlˈɛvən. 
ɪf ˈaɪ æm ɐlˈaɪv ænd tˈeɪkən pɹˈɪzənɚ, 93 | DUMMY1/LJ014-0010.wav|jˈɛt hiː kʊd nˌɑːt ˌoʊvɚkˈʌm ðə stɹˈeɪndʒ fˌæsᵻnˈeɪʃən ɪt hˈɐd fɔːɹ hˌɪm, ænd ɹɪmˈeɪnd baɪ ðə sˈaɪd ʌvðə kˈɔːɹps tˈɪl ðə stɹˈɛtʃɚ kˈeɪm. 94 | DUMMY1/LJ033-0047.wav|ˈaɪ nˈoʊɾɪsd wɛn ˈaɪ wɛnt ˈaʊt ðætðə lˈaɪt wʌz ˈɑːn, ˈɛnd kwˈoʊt, 95 | DUMMY1/LJ040-0027.wav|hiː wʌz nˈɛvɚ sˈæɾɪsfˌaɪd wɪð ˈɛnɪθˌɪŋ. 96 | DUMMY1/LJ048-0228.wav|ænd ˈʌðɚz hˌuː wɜː pɹˈɛzənt sˈeɪ ðæt nˈoʊ ˈeɪdʒənt wʌz ɪnˈiːbɹɪˌeɪɾᵻd ɔːɹ ˈæktᵻd ɪmpɹˈɑːpɚli. 97 | DUMMY1/LJ003-0111.wav|hiː wʌz ɪn kˈɑːnsɪkwəns pˌʊt ˌaʊɾəv ðə pɹətˈɛkʃən ʌv ðɛɹ ɪntˈɜːnəl lˈɔː, ˈɛnd kwˈoʊt. ðɛɹ kˈoʊd wʌzɐ sˈʌbdʒɛkt ʌv sˌʌm kjˌʊɹɪˈɑːsɪɾi. 98 | DUMMY1/LJ008-0258.wav|lˈɛt mˌiː ɹɪtɹˈeɪs maɪ stˈɛps, ænd spˈiːk mˈoːɹ ɪn diːtˈeɪl ʌvðə tɹˈiːtmənt ʌvðə kəndˈɛmd ɪn ðoʊz blˈʌdθɜːsti ænd bɹˈuːɾəli ɪndˈɪfɹənt dˈeɪz, 99 | DUMMY1/LJ029-0022.wav|ðɪ ɚɹˈɪdʒɪnəl plˈæn kˈɔːld fɚðə pɹˈɛzɪdənt tə spˈɛnd ˈoʊnli wˈʌn dˈeɪ ɪnðə stˈeɪt, mˌeɪkɪŋ wˈɜːlwɪnd vˈɪzɪts tə dˈæləs, fˈɔːɹt wˈɜːθ, sˌæn æntˈoʊnɪˌoʊ, ænd hjˈuːstən. 100 | DUMMY1/LJ004-0045.wav|mˈɪstɚ stˈɜːdʒᵻz bˈoːɹn, sˌɜː dʒˈeɪmz mˈækɪntˌɑːʃ, sˌɜː dʒˈeɪmz skˈɑːɹlɪt, ænd wˈɪljəm wˈɪlbɚfˌoːɹs. 101 | -------------------------------------------------------------------------------- /filelists/vctk_audio_sid_text_val_filelist.txt: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|It had happened to him. 2 | DUMMY2/p280/p280_148.wav|52|It is open season on the Old Firm. 3 | DUMMY2/p231/p231_320.wav|50|However, he is a coach, and he remains a coach at heart. 4 | DUMMY2/p282/p282_129.wav|83|It is not a U-turn. 5 | DUMMY2/p254/p254_015.wav|41|The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain. 6 | DUMMY2/p228/p228_285.wav|57|The songs are just so good. 7 | DUMMY2/p334/p334_307.wav|38|If they don't, they can expect their funding to be cut. 8 | DUMMY2/p287/p287_081.wav|77|I've never seen anything like it. 9 | DUMMY2/p247/p247_083.wav|14|It is a job creation scheme.) 10 | DUMMY2/p264/p264_051.wav|65|We were leading by two goals.) 11 | DUMMY2/p335/p335_058.wav|49|Let's see that increase over the years. 12 | DUMMY2/p236/p236_225.wav|75|There is no quick fix. 13 | DUMMY2/p374/p374_353.wav|11|And that brings us to the point. 14 | DUMMY2/p272/p272_076.wav|69|Sounds like The Sixth Sense? 15 | DUMMY2/p271/p271_152.wav|27|The petition was formally presented at Downing Street yesterday. 16 | DUMMY2/p228/p228_127.wav|57|They've got to account for it. 17 | DUMMY2/p276/p276_223.wav|106|It's been a humbling year. 18 | DUMMY2/p262/p262_248.wav|45|The project has already secured the support of Sir Sean Connery. 19 | DUMMY2/p314/p314_086.wav|51|The team this year is going places. 20 | DUMMY2/p225/p225_038.wav|101|Diving is no part of football. 21 | DUMMY2/p279/p279_088.wav|25|The shareholders will vote to wind up the company on Friday morning. 22 | DUMMY2/p272/p272_018.wav|69|Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain. 23 | DUMMY2/p256/p256_098.wav|90|She told The Herald. 24 | DUMMY2/p261/p261_218.wav|100|All will be revealed in due course. 25 | DUMMY2/p265/p265_063.wav|73|IT shouldn't come as a surprise, but it does. 26 | DUMMY2/p314/p314_042.wav|51|It is all about people being assaulted, abused. 27 | DUMMY2/p241/p241_188.wav|86|I wish I could say something. 28 | DUMMY2/p283/p283_111.wav|95|It's good to have a voice. 29 | DUMMY2/p275/p275_006.wav|40|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. 
30 | DUMMY2/p228/p228_092.wav|57|Today I couldn't run on it. 31 | DUMMY2/p295/p295_343.wav|92|The atmosphere is businesslike. 32 | DUMMY2/p228/p228_187.wav|57|They will run a mile. 33 | DUMMY2/p294/p294_317.wav|104|It didn't put me off. 34 | DUMMY2/p231/p231_445.wav|50|It sounded like a bomb. 35 | DUMMY2/p272/p272_086.wav|69|Today she has been released. 36 | DUMMY2/p255/p255_210.wav|31|It was worth a photograph. 37 | DUMMY2/p229/p229_060.wav|67|And a film maker was born. 38 | DUMMY2/p260/p260_232.wav|81|The Home Office would not release any further details about the group. 39 | DUMMY2/p245/p245_025.wav|59|Johnson was pretty low. 40 | DUMMY2/p333/p333_185.wav|64|This area is perfect for children. 41 | DUMMY2/p244/p244_242.wav|78|He is a man of the people. 42 | DUMMY2/p376/p376_187.wav|71|"It is a terrible loss." 43 | DUMMY2/p239/p239_156.wav|48|It is a good lifestyle. 44 | DUMMY2/p307/p307_037.wav|22|He released a half-dozen solo albums. 45 | DUMMY2/p305/p305_185.wav|54|I am not even thinking about that. 46 | DUMMY2/p272/p272_081.wav|69|It was magic. 47 | DUMMY2/p302/p302_297.wav|30|I'm trying to stay open on that. 48 | DUMMY2/p275/p275_320.wav|40|We are in the end game. 49 | DUMMY2/p239/p239_231.wav|48|Then we will face the Danish champions. 50 | DUMMY2/p268/p268_301.wav|87|It was only later that the condition was diagnosed. 51 | DUMMY2/p336/p336_088.wav|98|They failed to reach agreement yesterday. 52 | DUMMY2/p278/p278_255.wav|10|They made such decisions in London. 53 | DUMMY2/p361/p361_132.wav|79|That got me out. 54 | DUMMY2/p307/p307_146.wav|22|You hope he prevails. 55 | DUMMY2/p244/p244_147.wav|78|They could not ignore the will of parliament, he claimed. 56 | DUMMY2/p294/p294_283.wav|104|This is our unfinished business. 57 | DUMMY2/p283/p283_300.wav|95|I would have the hammer in the crowd. 58 | DUMMY2/p239/p239_079.wav|48|I can understand the frustrations of our fans. 59 | DUMMY2/p264/p264_009.wav|65|There is , according to legend, a boiling pot of gold at one end. ) 60 | DUMMY2/p307/p307_348.wav|22|He did not oppose the divorce. 61 | DUMMY2/p304/p304_308.wav|72|We are the gateway to justice. 62 | DUMMY2/p281/p281_056.wav|36|None has ever been found. 63 | DUMMY2/p267/p267_158.wav|0|We were given a warm and friendly reception. 64 | DUMMY2/p300/p300_169.wav|102|Who do these people think they are? 65 | DUMMY2/p276/p276_177.wav|106|They exist in name alone. 66 | DUMMY2/p228/p228_245.wav|57|It is a policy which has the full support of the minister. 67 | DUMMY2/p300/p300_303.wav|102|I'm wondering what you feel about the youngest. 68 | DUMMY2/p362/p362_247.wav|15|This would give Scotland around eight members. 69 | DUMMY2/p326/p326_031.wav|28|United were in control without always being dominant. 70 | DUMMY2/p361/p361_288.wav|79|I did not think it was very proper. 71 | DUMMY2/p286/p286_145.wav|63|Tiger is not the norm. 72 | DUMMY2/p234/p234_071.wav|3|She did that for the rest of her life. 73 | DUMMY2/p263/p263_296.wav|39|The decision was announced at its annual conference in Dunfermline. 74 | DUMMY2/p323/p323_228.wav|34|She became a heroine of my childhood. 75 | DUMMY2/p280/p280_346.wav|52|It was a bit like having children. 76 | DUMMY2/p333/p333_080.wav|64|But the tragedy did not stop there. 77 | DUMMY2/p226/p226_268.wav|43|That decision is for the British Parliament and people. 78 | DUMMY2/p362/p362_314.wav|15|Is that right? 79 | DUMMY2/p240/p240_047.wav|93|It is so sad. 80 | DUMMY2/p250/p250_207.wav|24|You could feel the heat. 
81 | DUMMY2/p273/p273_176.wav|56|Neither side would reveal the details of the offer. 82 | DUMMY2/p316/p316_147.wav|85|And frankly, it's been a while. 83 | DUMMY2/p265/p265_047.wav|73|It is unique. 84 | DUMMY2/p336/p336_353.wav|98|Sometimes you get them, sometimes you don't. 85 | DUMMY2/p230/p230_376.wav|35|This hasn't happened in a vacuum. 86 | DUMMY2/p308/p308_209.wav|107|There is great potential on this river. 87 | DUMMY2/p250/p250_442.wav|24|We have not yet received a letter from the Irish. 88 | DUMMY2/p260/p260_037.wav|81|It's a fact. 89 | DUMMY2/p299/p299_345.wav|58|We're very excited and challenged by the project. 90 | DUMMY2/p269/p269_218.wav|94|A Grampian Police spokesman said. 91 | DUMMY2/p306/p306_014.wav|12|To the Hebrews it was a token that there would be no more universal floods. 92 | DUMMY2/p271/p271_292.wav|27|It's a record label, not a form of music. 93 | DUMMY2/p247/p247_225.wav|14|I am considered a teenager.) 94 | DUMMY2/p294/p294_094.wav|104|It should be a condition of employment. 95 | DUMMY2/p269/p269_031.wav|94|Is this accurate? 96 | DUMMY2/p275/p275_116.wav|40|It's not fair. 97 | DUMMY2/p265/p265_006.wav|73|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. 98 | DUMMY2/p285/p285_072.wav|2|Mr Irvine said Mr Rafferty was now in good spirits. 99 | DUMMY2/p270/p270_167.wav|8|We did what we had to do. 100 | DUMMY2/p360/p360_397.wav|60|It is a relief. 101 | -------------------------------------------------------------------------------- /filelists/vctk_audio_sid_text_val_filelist.txt.cleaned: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm. 2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm. 3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt. 4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn. 5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn. 6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd. 7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt. 8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt. 9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm. 10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz. 11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz. 12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks. 13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt. 14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns? 15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ. 16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt. 17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ. 18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi. 19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz. 20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl. 21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ. 22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn. 23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld. 24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs. 
25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz. 26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd. 27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ. 28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs. 29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt. 31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk. 32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl. 33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf. 34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm. 35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd. 36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf. 37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn. 38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp. 39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ. 40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən. 41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl. 42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs." 43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl. 44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz. 45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt. 46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk. 47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt. 48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm. 49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz. 50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd. 51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ. 52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən. 53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt. 54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz. 55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd. 56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs. 57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd. 58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz. 59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd. 60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs. 61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs. 62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd. 63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən. 64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ? 65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn. 66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ. 67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst. 68 | DUMMY2/p362/p362_247.wav|15|ðɪs wʊd ɡˈɪv skˈɑːtlənd ɐɹˈaʊnd ˈeɪt mˈɛmbɚz. 69 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt. 70 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ. 71 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm. 72 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf. 73 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn. 
74 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd. 75 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən. 76 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ. 77 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl. 78 | DUMMY2/p362/p362_314.wav|15|ɪz ðæt ɹˈaɪt? 79 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd. 80 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt. 81 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ. 82 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl. 83 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk. 84 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt. 85 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm. 86 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ. 87 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ. 88 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt. 89 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt. 90 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd. 91 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz. 92 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk. 93 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ. 94 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt. 95 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət? 96 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ. 97 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 98 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts. 99 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː. 100 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf. 101 | -------------------------------------------------------------------------------- /filelists/vctk_audio_sid_text_val_filelist_new.txt.cleaned: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm. 2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm. 3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt. 4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn. 5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn. 6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd. 7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt. 8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt. 9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm. 10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz. 11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz. 12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks. 13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt. 14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns? 15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ. 16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt. 17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ. 18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi. 
19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz. 20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl. 21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ. 22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn. 23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld. 24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs. 25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz. 26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd. 27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ. 28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs. 29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt. 31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk. 32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl. 33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf. 34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm. 35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd. 36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf. 37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn. 38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp. 39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ. 40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən. 41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl. 42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs." 43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl. 44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz. 45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt. 46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk. 47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt. 48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm. 49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz. 50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd. 51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ. 52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən. 53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt. 54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz. 55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd. 56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs. 57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd. 58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz. 59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd. 60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs. 61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs. 62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd. 63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən. 64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ? 65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn. 66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ. 
67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst. 68 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt. 69 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ. 70 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm. 71 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf. 72 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn. 73 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd. 74 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən. 75 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ. 76 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl. 77 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd. 78 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt. 79 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ. 80 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl. 81 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk. 82 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt. 83 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm. 84 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ. 85 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ. 86 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt. 87 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt. 88 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd. 89 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz. 90 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk. 91 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ. 92 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt. 93 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət? 94 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ. 95 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 96 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts. 97 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː. 98 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf. 
99 | -------------------------------------------------------------------------------- /infer_onnx.py: -------------------------------------------------------------------------------- 1 | # import argparse 2 | 3 | # import numpy as np 4 | # import onnxruntime 5 | # import torch 6 | # from scipy.io.wavfile import write 7 | 8 | # import commons 9 | # import utils 10 | # from text import text_to_sequence 11 | 12 | 13 | # def get_text(text, hps): 14 | # text_norm = text_to_sequence(text, hps.data.text_cleaners) 15 | # if hps.data.add_blank: 16 | # text_norm = commons.intersperse(text_norm, 0) 17 | # text_norm = torch.LongTensor(text_norm) 18 | # return text_norm 19 | 20 | 21 | # def main() -> None: 22 | # parser = argparse.ArgumentParser() 23 | # parser.add_argument("--model", required=True, help="Path to model (.onnx)") 24 | # parser.add_argument( 25 | # "--config-path", required=True, help="Path to model config (.json)" 26 | # ) 27 | # parser.add_argument( 28 | # "--output-wav-path", required=True, help="Path to write WAV file" 29 | # ) 30 | # parser.add_argument("--text", required=True, type=str, help="Text to synthesize") 31 | # parser.add_argument("--sid", required=False, type=int, help="Speaker ID to synthesize") 32 | # args = parser.parse_args() 33 | 34 | # sess_options = onnxruntime.SessionOptions() 35 | # model = onnxruntime.InferenceSession(str(args.model), sess_options=sess_options, providers=["CPUExecutionProvider"]) 36 | 37 | # hps = utils.get_hparams_from_file(args.config_path) 38 | 39 | # phoneme_ids = get_text(args.text, hps) 40 | # text = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0) 41 | # text_lengths = np.array([text.shape[1]], dtype=np.int64) 42 | # scales = np.array([0.667, 1.0, 0.8], dtype=np.float32) 43 | # sid = np.array([int(args.sid)]) if args.sid is not None else None 44 | 45 | # audio = model.run( 46 | # None, 47 | # { 48 | # "input": text, 49 | # "input_lengths": text_lengths, 50 | # "scales": scales, 51 | # "sid": sid, 52 | # }, 53 | # )[0].squeeze((0, 1)) 54 | 55 | # write(data=audio, rate=hps.data.sampling_rate, filename=args.output_wav_path) 56 | 57 | 58 | # if __name__ == "__main__": 59 | # main() 60 | -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d1097e33-73b2-4001-a76f-782c0cd17644", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "%matplotlib inline\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import IPython.display as ipd\n", 15 | "\n", 16 | "import os\n", 17 | "import json\n", 18 | "import math\n", 19 | "import torch\n", 20 | "from torch import nn\n", 21 | "from torch.nn import functional as F\n", 22 | "from torch.utils.data import DataLoader\n", 23 | "\n", 24 | "import commons\n", 25 | "import utils\n", 26 | "from data_utils import (\n", 27 | " TextAudioLoader,\n", 28 | " TextAudioCollate,\n", 29 | " TextAudioSpeakerLoader,\n", 30 | " TextAudioSpeakerCollate,\n", 31 | ")\n", 32 | "from models import SynthesizerTrn\n", 33 | "from text.symbols import symbols\n", 34 | "from text import text_to_sequence\n", 35 | "\n", 36 | "from scipy.io.wavfile import write\n", 37 | "\n", 38 | "\n", 39 | "def get_text(text, hps):\n", 40 | " text_norm = text_to_sequence(text, hps.data.text_cleaners)\n", 41 | " if hps.data.add_blank:\n", 42 | " text_norm = commons.intersperse(text_norm, 0)\n", 43 | " 
text_norm = torch.LongTensor(text_norm)\n", 44 | " return text_norm" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "9b653abf-4b03-47d6-80c3-8a4405ba9e56", 51 | "metadata": { 52 | "tags": [] 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "device = torch.device(\"cpu\")" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "id": "ec35f535-6d34-467c-bf33-aa7c1e673f34", 62 | "metadata": { 63 | "jp-MarkdownHeadingCollapsed": true, 64 | "tags": [] 65 | }, 66 | "source": [ 67 | "# LJSpeech" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "eb5302d7-7e4b-41c2-a39e-4fec7e46403c", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "hps = utils.get_hparams_from_file(\"./configs/vits3_ljs_base.json\")" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "074d0c26-5a06-4878-b8da-361f68bd45c5", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "net_g = SynthesizerTrn(\n", 88 | " len(symbols),\n", 89 | " hps.data.filter_length // 2 + 1,\n", 90 | " hps.train.segment_size // hps.data.hop_length,\n", 91 | " **hps.model).cuda()\n", 92 | "_ = net_g.eval()\n", 93 | "\n", 94 | "_ = utils.load_checkpoint(\"/path/to/pretrained_ljs.pth\", net_g, None)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "196962fc-bd97-4b00-9194-e7c0d899e69a", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "stn_tst = get_text(\"VITS is Awesome!\", hps)\n", 105 | "with torch.no_grad():\n", 106 | " x_tst = stn_tst.cuda().unsqueeze(0)\n", 107 | " x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n", 108 | " audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n", 109 | "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "267ffb0b-d2f0-4e1a-9450-5fc593075810", 115 | "metadata": {}, 116 | "source": [ 117 | "# VCTK" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "51a2ddc8-3cf2-459c-837f-a3fb08261b7d", 124 | "metadata": { 125 | "tags": [] 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "hps = utils.get_hparams_from_file(\"./configs/vits3_vctk_base2.json\")" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "8a5d8345-cae7-4137-a700-0474a846112f", 136 | "metadata": { 137 | "tags": [] 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "if hps.model.use_mel_posterior_encoder == True:\n", 142 | " print(\"Using mel posterior encoder for VITS3\")\n", 143 | " posterior_channels = 80\n", 144 | " hps.data.use_mel_posterior_encoder = True\n", 145 | "else:\n", 146 | " print(\"Using lin posterior encoder for VITS1\")\n", 147 | " posterior_channels = hps.data.filter_length // 2 + 1\n", 148 | " hps.data.use_mel_posterior_encoder = False\n", 149 | "\n", 150 | "net_g = SynthesizerTrn(\n", 151 | " len(symbols),\n", 152 | " hps.data.n_mel_channels,\n", 153 | " None,\n", 154 | " n_speakers=hps.data.n_speakers,\n", 155 | " **hps.model,\n", 156 | ").to(device)\n", 157 | "_ = net_g.eval()\n", 158 | "\n", 159 | "_ = utils.load_checkpoint(\"/path/to/the/pretrained.pth\", net_g, None)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "id": "256a6792-da17-4f86-8b67-f7469200a252", 166 | "metadata": { 167 | "tags": [] 168 | }, 169 | "outputs": 
[], 170 | "source": [ 171 | "text = \"\"\"VITS3 is Awesome!\"\"\"\n", 172 | "sid = 4\n", 173 | "\n", 174 | "stn_tst = get_text(text, hps)\n", 175 | "with torch.no_grad():\n", 176 | " x_tst = stn_tst.to(device).unsqueeze(0)\n", 177 | " x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)\n", 178 | " sid = torch.LongTensor([int(sid)]).to(device)\n", 179 | " audio = (\n", 180 | " net_g.infer(\n", 181 | " x_tst,\n", 182 | " x_tst_lengths,\n", 183 | " sid=sid,\n", 184 | " noise_scale=0.667,\n", 185 | " noise_scale_w=0.8,\n", 186 | " length_scale=1,\n", 187 | " )[0][0, 0]\n", 188 | " .data.cpu()\n", 189 | " .float()\n", 190 | " .numpy()\n", 191 | " )\n", 192 | "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "7f0753ef-8103-4ef5-b431-3067a86e92b7", 198 | "metadata": {}, 199 | "source": [ 200 | "# Voice Conversion" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "id": "92aaa17e-19fa-4cc9-935b-02a6a4f897de", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)\n", 211 | "collate_fn = TextAudioSpeakerCollate()\n", 212 | "loader = DataLoader(dataset, num_workers=0, shuffle=False,\n", 213 | " batch_size=1, pin_memory=False,\n", 214 | " drop_last=True, collate_fn=collate_fn)\n", 215 | "data_list = list(loader)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "0bb98f74-78ed-4ea1-894d-120af96335c3", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "with torch.no_grad():\n", 226 | " x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.to(device) for x in data_list[0]]\n", 227 | " sid_tgt1 = torch.LongTensor([1]).to(device)\n", 228 | " sid_tgt2 = torch.LongTensor([2]).to(device)\n", 229 | " sid_tgt3 = torch.LongTensor([4]).to(device)\n", 230 | " audio1 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0,0].data.cpu().float().numpy()\n", 231 | " audio2 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt2)[0][0,0].data.cpu().float().numpy()\n", 232 | " audio3 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt3)[0][0,0].data.cpu().float().numpy()\n", 233 | "print(\"Original SID: %d\" % sid_src.item())\n", 234 | "ipd.display(ipd.Audio(y[0].cpu().numpy(), rate=hps.data.sampling_rate, normalize=False))\n", 235 | "print(\"Converted SID: %d\" % sid_tgt1.item())\n", 236 | "ipd.display(ipd.Audio(audio1, rate=hps.data.sampling_rate, normalize=False))\n", 237 | "print(\"Converted SID: %d\" % sid_tgt2.item())\n", 238 | "ipd.display(ipd.Audio(audio2, rate=hps.data.sampling_rate, normalize=False))\n", 239 | "print(\"Converted SID: %d\" % sid_tgt3.item())\n", 240 | "ipd.display(ipd.Audio(audio3, rate=hps.data.sampling_rate, normalize=False))" 241 | ] 242 | } 243 | ], 244 | "metadata": { 245 | "kernelspec": { 246 | "display_name": "Python 3 (ipykernel)", 247 | "language": "python", 248 | "name": "python3" 249 | }, 250 | "language_info": { 251 | "codemirror_mode": { 252 | "name": "ipython", 253 | "version": 3 254 | }, 255 | "file_extension": ".py", 256 | "mimetype": "text/x-python", 257 | "name": "python", 258 | "nbconvert_exporter": "python", 259 | "pygments_lexer": "ipython3", 260 | "version": "3.10.8" 261 | } 262 | }, 263 | "nbformat": 4, 264 | "nbformat_minor": 5 265 | } 266 | 
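The Voice Conversion cells above pull `spec` from the validation DataLoader. To convert an arbitrary wav instead, the linear spectrogram can be built directly with `spectrogram_torch` from `mel_processing.py`; a minimal sketch (`wav_to_spec` is a hypothetical helper name, and this assumes a config that uses the default linear posterior encoder rather than the mel variant):

```python
import librosa
import torch

from mel_processing import spectrogram_torch


def wav_to_spec(wav_path, hps):
    """Build the (spec, spec_lengths) pair expected by net_g.voice_conversion."""
    y, _ = librosa.load(wav_path, sr=hps.data.sampling_rate, mono=True)
    y = torch.FloatTensor(y).unsqueeze(0)  # [1, T]
    spec = spectrogram_torch(
        y,
        hps.data.filter_length,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        center=False,
    )  # [1, filter_length // 2 + 1, frames]
    spec_lengths = torch.LongTensor([spec.size(-1)])
    return spec, spec_lengths
```

The returned pair can then be passed to `net_g.voice_conversion(spec, spec_lengths, sid_src=..., sid_tgt=...)` exactly like the loader-provided tensors.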
-------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | ## LJSpeech 2 | import torch 3 | 4 | import commons 5 | import utils 6 | from models import SynthesizerTrn 7 | from text.symbols import symbols 8 | from text import text_to_sequence 9 | 10 | from scipy.io.wavfile import write 11 | 12 | 13 | def get_text(text, hps): 14 | text_norm = text_to_sequence(text, hps.data.text_cleaners) 15 | if hps.data.add_blank: 16 | text_norm = commons.intersperse(text_norm, 0) 17 | text_norm = torch.LongTensor(text_norm) 18 | return text_norm 19 | 20 | 21 | CONFIG_PATH = "./configs/vits3_ljs_nosdp.json" 22 | MODEL_PATH = "./logs/G_114000.pth" 23 | TEXT = "VITS-3 is Awesome!" 24 | OUTPUT_WAV_PATH = "sample_vits3.wav" 25 | 26 | hps = utils.get_hparams_from_file(CONFIG_PATH) 27 | 28 | if ( 29 | "use_mel_posterior_encoder" in hps.model.keys() 30 | and hps.model.use_mel_posterior_encoder == True 31 | ): 32 | print("Using mel posterior encoder for VITS3") 33 | posterior_channels = 80 # vits2 34 | hps.data.use_mel_posterior_encoder = True 35 | else: 36 | print("Using lin posterior encoder for VITS1") 37 | posterior_channels = hps.data.filter_length // 2 + 1 38 | hps.data.use_mel_posterior_encoder = False 39 | 40 | net_g = SynthesizerTrn( 41 | len(symbols), 42 | posterior_channels, 43 | hps.train.segment_size // hps.data.hop_length, 44 | **hps.model 45 | ).cuda() 46 | _ = net_g.eval() 47 | 48 | _ = utils.load_checkpoint(MODEL_PATH, net_g, None) 49 | 50 | stn_tst = get_text(TEXT, hps) 51 | with torch.no_grad(): 52 | x_tst = stn_tst.cuda().unsqueeze(0) 53 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() 54 | audio = ( 55 | net_g.infer( 56 | x_tst, x_tst_lengths, noise_scale=0.667, noise_scale_w=0.8, length_scale=1 57 | )[0][0, 0] 58 | .data.cpu() 59 | .float() 60 | .numpy() 61 | ) 62 | 63 | write(data=audio, rate=hps.data.sampling_rate, filename=OUTPUT_WAV_PATH) 64 | -------------------------------------------------------------------------------- /inference_ms.py: -------------------------------------------------------------------------------- 1 | ## VCTK 2 | import torch 3 | 4 | import commons 5 | import utils 6 | from models import SynthesizerTrn 7 | from text.symbols import symbols 8 | from text import text_to_sequence 9 | 10 | from scipy.io.wavfile import write 11 | 12 | 13 | def get_text(text, hps): 14 | text_norm = text_to_sequence(text, hps.data.text_cleaners) 15 | if hps.data.add_blank: 16 | text_norm = commons.intersperse(text_norm, 0) 17 | text_norm = torch.LongTensor(text_norm) 18 | return text_norm 19 | 20 | 21 | CONFIG_PATH = "./configs/vits3_vctk_base.json" 22 | MODEL_PATH = "/path/to/pretrained_vctk.pth" 23 | TEXT = "VITS-2 is Awesome!" 
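# A quick note on the sampling knobs passed to net_g.infer further down (the values
# used here are the upstream VITS defaults): noise_scale is the temperature of the
# prior noise, noise_scale_w scales the stochastic duration predictor's noise (no
# effect when use_sdp is disabled), and length_scale multiplies the predicted
# durations (>1.0 slows the speech down, <1.0 speeds it up).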
24 | SPK_ID = 4 25 | OUTPUT_WAV_PATH = "sample_vits3_ms.wav" 26 | 27 | hps = utils.get_hparams_from_file(CONFIG_PATH) 28 | 29 | if ( 30 | "use_mel_posterior_encoder" in hps.model.keys() 31 | and hps.model.use_mel_posterior_encoder == True 32 | ): 33 | print("Using mel posterior encoder for VITS2") 34 | posterior_channels = 80 # vits2 35 | hps.data.use_mel_posterior_encoder = True 36 | else: 37 | print("Using lin posterior encoder for VITS1") 38 | posterior_channels = hps.data.filter_length // 2 + 1 39 | hps.data.use_mel_posterior_encoder = False 40 | 41 | net_g = SynthesizerTrn( 42 | len(symbols), 43 | posterior_channels, 44 | hps.train.segment_size // hps.data.hop_length, 45 | n_speakers=hps.data.n_speakers, 46 | **hps.model 47 | ).cuda() 48 | _ = net_g.eval() 49 | 50 | _ = utils.load_checkpoint(MODEL_PATH, net_g, None) 51 | 52 | stn_tst = get_text(TEXT, hps) 53 | with torch.no_grad(): 54 | x_tst = stn_tst.cuda().unsqueeze(0) 55 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() 56 | sid = torch.LongTensor([SPK_ID]).cuda() 57 | audio = ( 58 | net_g.infer( 59 | x_tst, 60 | x_tst_lengths, 61 | sid=sid, 62 | noise_scale=0.667, 63 | noise_scale_w=0.8, 64 | length_scale=1, 65 | )[0][0, 0] 66 | .data.cpu() 67 | .float() 68 | .numpy() 69 | ) 70 | 71 | write(data=audio, rate=hps.data.sampling_rate, filename=OUTPUT_WAV_PATH) 72 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import commons 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1 - dr) ** 2) 26 | g_loss = torch.mean(dg**2) 27 | loss += r_loss + g_loss 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1 - dg) ** 2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = z_mask.float() 56 | 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | # warnings.simplefilter(action='ignore', category=FutureWarning) 4 | warnings.filterwarnings(action="ignore") 5 | 6 | import math 7 | import os 8 | import random 9 | 10 | import librosa 11 | import librosa.util as librosa_util 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import 
torch.utils.data 16 | from librosa.filters import mel as librosa_mel_fn 17 | from librosa.util import normalize, pad_center, tiny 18 | from packaging import version 19 | from scipy.io.wavfile import read 20 | from scipy.signal import get_window 21 | from torch import nn 22 | 23 | MAX_WAV_VALUE = 32768.0 24 | 25 | 26 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 27 | """ 28 | PARAMS 29 | ------ 30 | C: compression factor 31 | """ 32 | return torch.log(torch.clamp(x, min=clip_val) * C) 33 | 34 | 35 | def dynamic_range_decompression_torch(x, C=1): 36 | """ 37 | PARAMS 38 | ------ 39 | C: compression factor used to compress 40 | """ 41 | return torch.exp(x) / C 42 | 43 | 44 | def spectral_normalize_torch(magnitudes): 45 | output = dynamic_range_compression_torch(magnitudes) 46 | return output 47 | 48 | 49 | def spectral_de_normalize_torch(magnitudes): 50 | output = dynamic_range_decompression_torch(magnitudes) 51 | return output 52 | 53 | 54 | mel_basis = {} 55 | hann_window = {} 56 | 57 | 58 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 59 | if torch.min(y) < -1.0: 60 | print("min value is ", torch.min(y)) 61 | if torch.max(y) > 1.0: 62 | print("max value is ", torch.max(y)) 63 | 64 | global hann_window 65 | dtype_device = str(y.dtype) + "_" + str(y.device) 66 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 67 | if wnsize_dtype_device not in hann_window: 68 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 69 | dtype=y.dtype, device=y.device 70 | ) 71 | 72 | y = torch.nn.functional.pad( 73 | y.unsqueeze(1), 74 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 75 | mode="reflect", 76 | ) 77 | y = y.squeeze(1) 78 | 79 | if version.parse(torch.__version__) >= version.parse("2"): 80 | spec = torch.stft( 81 | y, 82 | n_fft, 83 | hop_length=hop_size, 84 | win_length=win_size, 85 | window=hann_window[wnsize_dtype_device], 86 | center=center, 87 | pad_mode="reflect", 88 | normalized=False, 89 | onesided=True, 90 | return_complex=False, 91 | ) 92 | else: 93 | spec = torch.stft( 94 | y, 95 | n_fft, 96 | hop_length=hop_size, 97 | win_length=win_size, 98 | window=hann_window[wnsize_dtype_device], 99 | center=center, 100 | pad_mode="reflect", 101 | normalized=False, 102 | onesided=True, 103 | ) 104 | 105 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 106 | return spec 107 | 108 | 109 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 110 | global mel_basis 111 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 112 | fmax_dtype_device = str(fmax) + "_" + dtype_device 113 | if fmax_dtype_device not in mel_basis: 114 | mel = librosa_mel_fn( 115 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 116 | ) 117 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 118 | dtype=spec.dtype, device=spec.device 119 | ) 120 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 121 | spec = spectral_normalize_torch(spec) 122 | return spec 123 | 124 | 125 | def mel_spectrogram_torch( 126 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 127 | ): 128 | if torch.min(y) < -1.0: 129 | print("min value is ", torch.min(y)) 130 | if torch.max(y) > 1.0: 131 | print("max value is ", torch.max(y)) 132 | 133 | global mel_basis, hann_window 134 | dtype_device = str(y.dtype) + "_" + str(y.device) 135 | fmax_dtype_device = str(fmax) + "_" + dtype_device 136 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 137 | if fmax_dtype_device not in 
mel_basis: 138 | mel = librosa_mel_fn( 139 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 140 | ) 141 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 142 | dtype=y.dtype, device=y.device 143 | ) 144 | if wnsize_dtype_device not in hann_window: 145 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 146 | dtype=y.dtype, device=y.device 147 | ) 148 | 149 | y = torch.nn.functional.pad( 150 | y.unsqueeze(1), 151 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 152 | mode="reflect", 153 | ) 154 | y = y.squeeze(1) 155 | 156 | if version.parse(torch.__version__) >= version.parse("2"): 157 | spec = torch.stft( 158 | y, 159 | n_fft, 160 | hop_length=hop_size, 161 | win_length=win_size, 162 | window=hann_window[wnsize_dtype_device], 163 | center=center, 164 | pad_mode="reflect", 165 | normalized=False, 166 | onesided=True, 167 | return_complex=False, 168 | ) 169 | else: 170 | spec = torch.stft( 171 | y, 172 | n_fft, 173 | hop_length=hop_size, 174 | win_length=win_size, 175 | window=hann_window[wnsize_dtype_device], 176 | center=center, 177 | pad_mode="reflect", 178 | normalized=False, 179 | onesided=True, 180 | ) 181 | 182 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 183 | 184 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 185 | spec = spectral_normalize_torch(spec) 186 | 187 | return spec 188 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | import commons 13 | from commons import init_weights, get_padding 14 | from transforms import piecewise_rational_quadratic_transform 15 | 16 | 17 | LRELU_SLOPE = 0.1 18 | 19 | 20 | class LayerNorm(nn.Module): 21 | def __init__(self, channels, eps=1e-5): 22 | super().__init__() 23 | self.channels = channels 24 | self.eps = eps 25 | 26 | self.gamma = nn.Parameter(torch.ones(channels)) 27 | self.beta = nn.Parameter(torch.zeros(channels)) 28 | 29 | def forward(self, x): 30 | x = x.transpose(1, -1) 31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 32 | return x.transpose(1, -1) 33 | 34 | 35 | class ConvReluNorm(nn.Module): 36 | def __init__( 37 | self, 38 | in_channels, 39 | hidden_channels, 40 | out_channels, 41 | kernel_size, 42 | n_layers, 43 | p_dropout, 44 | ): 45 | super().__init__() 46 | self.in_channels = in_channels 47 | self.hidden_channels = hidden_channels 48 | self.out_channels = out_channels 49 | self.kernel_size = kernel_size 50 | self.n_layers = n_layers 51 | self.p_dropout = p_dropout 52 | assert n_layers > 1, "Number of layers should be larger than 0." 
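        # What follows: n_layers blocks of (Conv1d -> LayerNorm -> ReLU -> Dropout).
        # The closing 1x1 projection is zero-initialized, so the residual connection
        # in forward (x_org + self.proj(x)) makes the block an identity map at init.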
53 | 
54 |         self.conv_layers = nn.ModuleList()
55 |         self.norm_layers = nn.ModuleList()
56 |         self.conv_layers.append(
57 |             nn.Conv1d(
58 |                 in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59 |             )
60 |         )
61 |         self.norm_layers.append(LayerNorm(hidden_channels))
62 |         self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63 |         for _ in range(n_layers - 1):
64 |             self.conv_layers.append(
65 |                 nn.Conv1d(
66 |                     hidden_channels,
67 |                     hidden_channels,
68 |                     kernel_size,
69 |                     padding=kernel_size // 2,
70 |                 )
71 |             )
72 |             self.norm_layers.append(LayerNorm(hidden_channels))
73 |         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74 |         self.proj.weight.data.zero_()
75 |         self.proj.bias.data.zero_()
76 | 
77 |     def forward(self, x, x_mask):
78 |         x_org = x
79 |         for i in range(self.n_layers):
80 |             x = self.conv_layers[i](x * x_mask)
81 |             x = self.norm_layers[i](x)
82 |             x = self.relu_drop(x)
83 |         x = x_org + self.proj(x)
84 |         return x * x_mask
85 | 
86 | 
87 | class DDSConv(nn.Module):
88 |     """
89 |     Dilated and Depth-Separable Convolution
90 |     """
91 | 
92 |     def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93 |         super().__init__()
94 |         self.channels = channels
95 |         self.kernel_size = kernel_size
96 |         self.n_layers = n_layers
97 |         self.p_dropout = p_dropout
98 | 
99 |         self.drop = nn.Dropout(p_dropout)
100 |         self.convs_sep = nn.ModuleList()
101 |         self.convs_1x1 = nn.ModuleList()
102 |         self.norms_1 = nn.ModuleList()
103 |         self.norms_2 = nn.ModuleList()
104 |         for i in range(n_layers):
105 |             dilation = kernel_size**i
106 |             padding = (kernel_size * dilation - dilation) // 2
107 |             self.convs_sep.append(
108 |                 nn.Conv1d(
109 |                     channels,
110 |                     channels,
111 |                     kernel_size,
112 |                     groups=channels,
113 |                     dilation=dilation,
114 |                     padding=padding,
115 |                 )
116 |             )
117 |             self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118 |             self.norms_1.append(LayerNorm(channels))
119 |             self.norms_2.append(LayerNorm(channels))
120 | 
121 |     def forward(self, x, x_mask, g=None):
122 |         if g is not None:
123 |             x = x + g
124 |         for i in range(self.n_layers):
125 |             y = self.convs_sep[i](x * x_mask)
126 |             y = self.norms_1[i](y)
127 |             y = F.gelu(y)
128 |             y = self.convs_1x1[i](y)
129 |             y = self.norms_2[i](y)
130 |             y = F.gelu(y)
131 |             y = self.drop(y)
132 |             x = x + y
133 |         return x * x_mask
134 | 
135 | 
136 | class WN(torch.nn.Module):
137 |     def __init__(
138 |         self,
139 |         hidden_channels,
140 |         kernel_size,
141 |         dilation_rate,
142 |         n_layers,
143 |         gin_channels=0,
144 |         p_dropout=0,
145 |     ):
146 |         super(WN, self).__init__()
147 |         assert kernel_size % 2 == 1
148 |         self.hidden_channels = hidden_channels
149 |         self.kernel_size = kernel_size
150 |         self.dilation_rate = dilation_rate
151 |         self.n_layers = n_layers
152 |         self.gin_channels = gin_channels
153 |         self.p_dropout = p_dropout
154 | 
155 |         self.in_layers = torch.nn.ModuleList()
156 |         self.res_skip_layers = torch.nn.ModuleList()
157 |         self.drop = nn.Dropout(p_dropout)
158 | 
159 |         if gin_channels != 0:
160 |             cond_layer = torch.nn.Conv1d(
161 |                 gin_channels, 2 * hidden_channels * n_layers, 1
162 |             )
163 |             self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164 | 
165 |         for i in range(n_layers):
166 |             dilation = dilation_rate**i
167 |             padding = int((kernel_size * dilation - dilation) / 2)
168 |             in_layer = torch.nn.Conv1d(
169 |                 hidden_channels,
170 |                 2 * hidden_channels,
171 |                 kernel_size,
172 |                 dilation=dilation,
173 |                 padding=padding,
174 |             )
175 |             in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176 | 
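            # Each in_layer outputs 2 * hidden_channels; in forward the two halves
            # are combined by WaveNet-style gating, tanh(a) * sigmoid(b), through
            # commons.fused_add_tanh_sigmoid_multiply, after adding the matching
            # per-layer slice of the conditioning signal g.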
self.in_layers.append(in_layer) 177 | 178 | # last one is not necessary 179 | if i < n_layers - 1: 180 | res_skip_channels = 2 * hidden_channels 181 | else: 182 | res_skip_channels = hidden_channels 183 | 184 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 185 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 186 | self.res_skip_layers.append(res_skip_layer) 187 | 188 | def forward(self, x, x_mask, g=None, **kwargs): 189 | output = torch.zeros_like(x) 190 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 191 | 192 | if g is not None: 193 | g = self.cond_layer(g) 194 | 195 | for i in range(self.n_layers): 196 | x_in = self.in_layers[i](x) 197 | if g is not None: 198 | cond_offset = i * 2 * self.hidden_channels 199 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 200 | else: 201 | g_l = torch.zeros_like(x_in) 202 | 203 | acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 204 | acts = self.drop(acts) 205 | 206 | res_skip_acts = self.res_skip_layers[i](acts) 207 | if i < self.n_layers - 1: 208 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 209 | x = (x + res_acts) * x_mask 210 | output = output + res_skip_acts[:, self.hidden_channels :, :] 211 | else: 212 | output = output + res_skip_acts 213 | return output * x_mask 214 | 215 | def remove_weight_norm(self): 216 | if self.gin_channels != 0: 217 | torch.nn.utils.remove_weight_norm(self.cond_layer) 218 | for l in self.in_layers: 219 | torch.nn.utils.remove_weight_norm(l) 220 | for l in self.res_skip_layers: 221 | torch.nn.utils.remove_weight_norm(l) 222 | 223 | 224 | class ResBlock1(torch.nn.Module): 225 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 226 | super(ResBlock1, self).__init__() 227 | self.convs1 = nn.ModuleList( 228 | [ 229 | weight_norm( 230 | Conv1d( 231 | channels, 232 | channels, 233 | kernel_size, 234 | 1, 235 | dilation=dilation[0], 236 | padding=get_padding(kernel_size, dilation[0]), 237 | ) 238 | ), 239 | weight_norm( 240 | Conv1d( 241 | channels, 242 | channels, 243 | kernel_size, 244 | 1, 245 | dilation=dilation[1], 246 | padding=get_padding(kernel_size, dilation[1]), 247 | ) 248 | ), 249 | weight_norm( 250 | Conv1d( 251 | channels, 252 | channels, 253 | kernel_size, 254 | 1, 255 | dilation=dilation[2], 256 | padding=get_padding(kernel_size, dilation[2]), 257 | ) 258 | ), 259 | ] 260 | ) 261 | self.convs1.apply(init_weights) 262 | 263 | self.convs2 = nn.ModuleList( 264 | [ 265 | weight_norm( 266 | Conv1d( 267 | channels, 268 | channels, 269 | kernel_size, 270 | 1, 271 | dilation=1, 272 | padding=get_padding(kernel_size, 1), 273 | ) 274 | ), 275 | weight_norm( 276 | Conv1d( 277 | channels, 278 | channels, 279 | kernel_size, 280 | 1, 281 | dilation=1, 282 | padding=get_padding(kernel_size, 1), 283 | ) 284 | ), 285 | weight_norm( 286 | Conv1d( 287 | channels, 288 | channels, 289 | kernel_size, 290 | 1, 291 | dilation=1, 292 | padding=get_padding(kernel_size, 1), 293 | ) 294 | ), 295 | ] 296 | ) 297 | self.convs2.apply(init_weights) 298 | 299 | def forward(self, x, x_mask=None): 300 | for c1, c2 in zip(self.convs1, self.convs2): 301 | xt = F.leaky_relu(x, LRELU_SLOPE) 302 | if x_mask is not None: 303 | xt = xt * x_mask 304 | xt = c1(xt) 305 | xt = F.leaky_relu(xt, LRELU_SLOPE) 306 | if x_mask is not None: 307 | xt = xt * x_mask 308 | xt = c2(xt) 309 | x = xt + x 310 | if x_mask is not None: 311 | x = x * x_mask 312 | return x 313 | 314 | def remove_weight_norm(self): 315 | for l in 
self.convs1: 316 | remove_weight_norm(l) 317 | for l in self.convs2: 318 | remove_weight_norm(l) 319 | 320 | 321 | class ResBlock2(torch.nn.Module): 322 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 323 | super(ResBlock2, self).__init__() 324 | self.convs = nn.ModuleList( 325 | [ 326 | weight_norm( 327 | Conv1d( 328 | channels, 329 | channels, 330 | kernel_size, 331 | 1, 332 | dilation=dilation[0], 333 | padding=get_padding(kernel_size, dilation[0]), 334 | ) 335 | ), 336 | weight_norm( 337 | Conv1d( 338 | channels, 339 | channels, 340 | kernel_size, 341 | 1, 342 | dilation=dilation[1], 343 | padding=get_padding(kernel_size, dilation[1]), 344 | ) 345 | ), 346 | ] 347 | ) 348 | self.convs.apply(init_weights) 349 | 350 | def forward(self, x, x_mask=None): 351 | for c in self.convs: 352 | xt = F.leaky_relu(x, LRELU_SLOPE) 353 | if x_mask is not None: 354 | xt = xt * x_mask 355 | xt = c(xt) 356 | x = xt + x 357 | if x_mask is not None: 358 | x = x * x_mask 359 | return x 360 | 361 | def remove_weight_norm(self): 362 | for l in self.convs: 363 | remove_weight_norm(l) 364 | 365 | 366 | class Log(nn.Module): 367 | def forward(self, x, x_mask, reverse=False, **kwargs): 368 | if not reverse: 369 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 370 | logdet = torch.sum(-y, [1, 2]) 371 | return y, logdet 372 | else: 373 | x = torch.exp(x) * x_mask 374 | return x 375 | 376 | 377 | class Flip(nn.Module): 378 | def forward(self, x, *args, reverse=False, **kwargs): 379 | x = torch.flip(x, [1]) 380 | if not reverse: 381 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 382 | return x, logdet 383 | else: 384 | return x 385 | 386 | 387 | class ElementwiseAffine(nn.Module): 388 | def __init__(self, channels): 389 | super().__init__() 390 | self.channels = channels 391 | self.m = nn.Parameter(torch.zeros(channels, 1)) 392 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 393 | 394 | def forward(self, x, x_mask, reverse=False, **kwargs): 395 | if not reverse: 396 | y = self.m + torch.exp(self.logs) * x 397 | y = y * x_mask 398 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 399 | return y, logdet 400 | else: 401 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 402 | return x 403 | 404 | 405 | class ResidualCouplingLayer(nn.Module): 406 | def __init__( 407 | self, 408 | channels, 409 | hidden_channels, 410 | kernel_size, 411 | dilation_rate, 412 | n_layers, 413 | p_dropout=0, 414 | gin_channels=0, 415 | mean_only=False, 416 | ): 417 | assert channels % 2 == 0, "channels should be divisible by 2" 418 | super().__init__() 419 | self.channels = channels 420 | self.hidden_channels = hidden_channels 421 | self.kernel_size = kernel_size 422 | self.dilation_rate = dilation_rate 423 | self.n_layers = n_layers 424 | self.half_channels = channels // 2 425 | self.mean_only = mean_only 426 | 427 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 428 | self.enc = WN( 429 | hidden_channels, 430 | kernel_size, 431 | dilation_rate, 432 | n_layers, 433 | p_dropout=p_dropout, 434 | gin_channels=gin_channels, 435 | ) 436 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 437 | self.post.weight.data.zero_() 438 | self.post.bias.data.zero_() 439 | 440 | def forward(self, x, x_mask, g=None, reverse=False): 441 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 442 | h = self.pre(x0) * x_mask 443 | h = self.enc(h, x_mask, g=g) 444 | stats = self.post(h) * x_mask 445 | if not self.mean_only: 446 | m, logs = torch.split(stats, 
[self.half_channels] * 2, 1) 447 | else: 448 | m = stats 449 | logs = torch.zeros_like(m) 450 | 451 | if not reverse: 452 | x1 = m + x1 * torch.exp(logs) * x_mask 453 | x = torch.cat([x0, x1], 1) 454 | logdet = torch.sum(logs, [1, 2]) 455 | return x, logdet 456 | else: 457 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 458 | x = torch.cat([x0, x1], 1) 459 | return x 460 | 461 | 462 | class ConvFlow(nn.Module): 463 | def __init__( 464 | self, 465 | in_channels, 466 | filter_channels, 467 | kernel_size, 468 | n_layers, 469 | num_bins=10, 470 | tail_bound=5.0, 471 | ): 472 | super().__init__() 473 | self.in_channels = in_channels 474 | self.filter_channels = filter_channels 475 | self.kernel_size = kernel_size 476 | self.n_layers = n_layers 477 | self.num_bins = num_bins 478 | self.tail_bound = tail_bound 479 | self.half_channels = in_channels // 2 480 | 481 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 482 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) 483 | self.proj = nn.Conv1d( 484 | filter_channels, self.half_channels * (num_bins * 3 - 1), 1 485 | ) 486 | self.proj.weight.data.zero_() 487 | self.proj.bias.data.zero_() 488 | 489 | def forward(self, x, x_mask, g=None, reverse=False): 490 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 491 | h = self.pre(x0) 492 | h = self.convs(h, x_mask, g=g) 493 | h = self.proj(h) * x_mask 494 | 495 | b, c, t = x0.shape 496 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 497 | 498 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) 499 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( 500 | self.filter_channels 501 | ) 502 | unnormalized_derivatives = h[..., 2 * self.num_bins :] 503 | 504 | x1, logabsdet = piecewise_rational_quadratic_transform( 505 | x1, 506 | unnormalized_widths, 507 | unnormalized_heights, 508 | unnormalized_derivatives, 509 | inverse=reverse, 510 | tails="linear", 511 | tail_bound=self.tail_bound, 512 | ) 513 | 514 | x = torch.cat([x0, x1], 1) * x_mask 515 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 516 | if not reverse: 517 | return x, logdet 518 | else: 519 | return x 520 | -------------------------------------------------------------------------------- /monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(neg_cent, mask): 7 | """Cython optimized version. 
8 | neg_cent: [b, t_t, t_s] 9 | mask: [b, t_t, t_s] 10 | """ 11 | device = neg_cent.device 12 | dtype = neg_cent.dtype 13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 14 | path = np.zeros(neg_cent.shape, dtype=np.int32) 15 | 16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 19 | return torch.from_numpy(path).to(device=device, dtype=dtype) 20 | -------------------------------------------------------------------------------- /monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel import prange 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 8 | cdef int x 9 | cdef int y 10 | cdef float v_prev 11 | cdef float v_cur 12 | cdef float tmp 13 | cdef int index = t_x - 1 14 | 15 | for y in range(t_y): 16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 17 | if x == y: 18 | v_cur = max_neg_val 19 | else: 20 | v_cur = value[y-1, x] 21 | if x == 0: 22 | if y == 0: 23 | v_prev = 0. 24 | else: 25 | v_prev = max_neg_val 26 | else: 27 | v_prev = value[y-1, x-1] 28 | value[y, x] += max(v_prev, v_cur) 29 | 30 | for y in range(t_y - 1, -1, -1): 31 | path[y, index] = 1 32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): 33 | index = index - 1 34 | 35 | 36 | @cython.boundscheck(False) 37 | @cython.wraparound(False) 38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 39 | cdef int b = paths.shape[0] 40 | cdef int i 41 | for i in prange(b, nogil=True): 42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 43 | -------------------------------------------------------------------------------- /monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name="monotonic_align", 7 | ext_modules=cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()], 9 | ) 10 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import text 3 | from utils import load_filepaths_and_text 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--out_extension", default="cleaned") 8 | parser.add_argument("--text_index", default=1, type=int) 9 | parser.add_argument( 10 | "--filelists", 11 | nargs="+", 12 | default=[ 13 | "filelists/ljs_audio_text_val_filelist.txt", 14 | "filelists/ljs_audio_text_test_filelist.txt", 15 | ], 16 | ) 17 | parser.add_argument("--text_cleaners", nargs="+", default=["english_cleaners2"]) 18 | 19 | args = parser.parse_args() 20 | 21 | for filelist in args.filelists: 22 | print("START:", filelist) 23 | filepaths_and_text = load_filepaths_and_text(filelist) 24 | for i in range(len(filepaths_and_text)): 25 | original_text = filepaths_and_text[i][args.text_index] 26 | cleaned_text = text._clean_text(original_text, args.text_cleaners) 27 | filepaths_and_text[i][args.text_index] = cleaned_text 28 | 29 | new_filelist = filelist + "." 
+ args.out_extension
30 |     with open(new_filelist, "w", encoding="utf-8") as f:
31 |         f.writelines(["|".join(x) + "\n" for x in filepaths_and_text])
32 | 
-------------------------------------------------------------------------------- /preprocess_audio.py: --------------------------------------------------------------------------------
1 | """
2 | VCTK
3 | https://datashare.ed.ac.uk/handle/10283/3443
4 | VCTK trim info
5 | https://github.com/nii-yamagishilab/vctk-silence-labels
6 | 
7 | Warning! This code is not properly debugged.
8 | It is recommended to run it only once, on the original audio files (flac or wav).
9 | If executed repeatedly, the "trim" is applied again each time and may damage the audio files.
10 | 
11 | >>> $ pip install librosa==0.9.2 numpy==1.23.5 scipy==1.9.1 tqdm # [option]
12 | >>> $ cd /path/to/your/vits3
13 | >>> $ ln -s /path/to/the/VCTK/* DUMMY2/
14 | >>> $ git clone https://github.com/nii-yamagishilab/vctk-silence-labels filelists/vctk-silence-labels
15 | >>> $ python preprocess_audio.py --filelists <~/filelist.txt> --config <~/config.json> --trim <~/info.txt>
16 | """
17 | 
18 | import argparse
19 | import os
20 | 
21 | import librosa
22 | import numpy as np
23 | from scipy.io import wavfile
24 | from tqdm.auto import tqdm
25 | 
26 | import utils
27 | 
28 | if __name__ == "__main__":
29 |     parser = argparse.ArgumentParser()
30 |     parser.add_argument(
31 |         "--filelists",
32 |         nargs="+",
33 |         default=[
34 |             "filelists/vctk_audio_sid_text_test_filelist.txt",
35 |             "filelists/vctk_audio_sid_text_val_filelist.txt",
36 |             "filelists/vctk_audio_sid_text_train_filelist.txt",
37 |         ],
38 |     )
39 |     parser.add_argument("--config", default="configs/vctk_base2.json", type=str)
40 |     parser.add_argument(
41 |         "--trim",
42 |         default="filelists/vctk-silence-labels/vctk-silences.0.92.txt",
43 |         type=str,
44 |     )
45 |     args = parser.parse_args()
46 | 
47 |     with open(args.trim, "r", encoding="utf8") as f:
48 |         lines = list(filter(lambda x: len(x) > 0, f.read().split("\n")))
49 |         trim_info = {}
50 |         for line in lines:
51 |             line = line.split(" ")
52 |             trim_info[line[0]] = (float(line[1]), float(line[2]))
53 | 
54 |     hps = utils.get_hparams_from_file(args.config)
55 |     for filelist in args.filelists:
56 |         print("START:", filelist)
57 |         with open(filelist, "r", encoding="utf8") as f:
58 |             lines = list(filter(lambda x: len(x) > 0, f.read().split("\n")))
59 | 
60 |         for line in tqdm(lines, total=len(lines), desc=filelist):
61 |             src_filename = line.split("|")[0]
62 |             if not os.path.isfile(src_filename):
63 |                 if os.path.isfile(src_filename.replace(".wav", "_mic1.flac")):
64 |                     src_filename = src_filename.replace(".wav", "_mic1.flac")
65 |                 else:
66 |                     continue
67 | 
68 |             if src_filename.endswith("_mic1.flac"):
69 |                 tgt_filename = src_filename.replace("_mic1.flac", ".wav")
70 |             else:
71 |                 tgt_filename = src_filename
72 | 
73 |             basename = os.path.splitext(os.path.basename(src_filename))[0].replace(
74 |                 "_mic1", ""
75 |             )
76 |             if trim_info.get(basename) is None:
77 |                 print(
78 |                     f"'{src_filename}' has no entry in the trim info file '{args.trim}'"
79 |                 )
80 |                 continue
81 | 
82 |             start, end = trim_info[basename][0], trim_info[basename][1]
83 | 
84 |             # warning: this trims the audio on load; repeated runs can make the file unusable
85 |             y, _ = librosa.core.load(
86 |                 src_filename,
87 |                 sr=hps.data.sampling_rate,
88 |                 mono=True,
89 |                 res_type="scipy",
90 |                 offset=start,
91 |                 duration=end - start,
92 |             )
93 | 
94 |             # y, _ = librosa.effects.trim(
95 |             #     y=y,
96 |             #     frame_length=4096,
97 |             #     hop_length=256,
98 |             #     top_db=35,
99 |             # )
100 | 101 | if y.shape[-1] < hps.train.segment_size: 102 | continue 103 | 104 | y = y * hps.data.max_wav_value 105 | wavfile.write( 106 | filename=tgt_filename, 107 | rate=hps.data.sampling_rate, 108 | data=y.astype(np.int16), 109 | ) 110 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==3.0.2 2 | librosa==0.10.1 3 | matplotlib==3.7.2 4 | numpy==1.24.4 5 | phonemizer==3.2.1 6 | scipy==1.11.2 7 | # torch==2.0.1 8 | # torchaudio==2.0.2 9 | # torchvision==0.15.2 10 | Unidecode==1.3.6 11 | tensorboard==2.14.0 12 | onnx==1.14.1 13 | onnxruntime==1.15.1 14 | gradio 15 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from text import cleaners 3 | from text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, cleaner_names): 12 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | """ 19 | sequence = [] 20 | 21 | clean_text = _clean_text(text, cleaner_names) 22 | for symbol in clean_text: 23 | if symbol in _symbol_to_id.keys(): 24 | symbol_id = _symbol_to_id[symbol] 25 | sequence += [symbol_id] 26 | else: 27 | continue 28 | return sequence 29 | 30 | 31 | def cleaned_text_to_sequence(cleaned_text): 32 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
33 | Args: 34 | text: string to convert to a sequence 35 | Returns: 36 | List of integers corresponding to the symbols in the text 37 | """ 38 | sequence = [] 39 | 40 | for symbol in cleaned_text: 41 | if symbol in _symbol_to_id.keys(): 42 | symbol_id = _symbol_to_id[symbol] 43 | sequence += [symbol_id] 44 | else: 45 | continue 46 | return sequence 47 | 48 | 49 | def sequence_to_text(sequence): 50 | """Converts a sequence of IDs back to a string""" 51 | result = "" 52 | for symbol_id in sequence: 53 | s = _id_to_symbol[symbol_id] 54 | result += s 55 | return result 56 | 57 | 58 | def _clean_text(text, cleaner_names): 59 | for name in cleaner_names: 60 | cleaner = getattr(cleaners, name) 61 | if not cleaner: 62 | raise Exception("Unknown cleaner: %s" % name) 63 | text = cleaner(text) 64 | return text 65 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | """ 14 | 15 | import re 16 | from unidecode import unidecode 17 | from phonemizer import phonemize 18 | from phonemizer.backend import EspeakBackend 19 | backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True) 20 | 21 | 22 | # Regular expression matching whitespace: 23 | _whitespace_re = re.compile(r"\s+") 24 | 25 | # List of (regular expression, replacement) pairs for abbreviations: 26 | _abbreviations = [ 27 | (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) 28 | for x in [ 29 | ("mrs", "misess"), 30 | ("mr", "mister"), 31 | ("dr", "doctor"), 32 | ("st", "saint"), 33 | ("co", "company"), 34 | ("jr", "junior"), 35 | ("maj", "major"), 36 | ("gen", "general"), 37 | ("drs", "doctors"), 38 | ("rev", "reverend"), 39 | ("lt", "lieutenant"), 40 | ("hon", "honorable"), 41 | ("sgt", "sergeant"), 42 | ("capt", "captain"), 43 | ("esq", "esquire"), 44 | ("ltd", "limited"), 45 | ("col", "colonel"), 46 | ("ft", "fort"), 47 | ] 48 | ] 49 | 50 | 51 | def expand_abbreviations(text): 52 | for regex, replacement in _abbreviations: 53 | text = re.sub(regex, replacement, text) 54 | return text 55 | 56 | 57 | def expand_numbers(text): 58 | return normalize_numbers(text) 59 | 60 | 61 | def lowercase(text): 62 | return text.lower() 63 | 64 | 65 | def collapse_whitespace(text): 66 | return re.sub(_whitespace_re, " ", text) 67 | 68 | 69 | def convert_to_ascii(text): 70 | return unidecode(text) 71 | 72 | 73 | def basic_cleaners(text): 74 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 75 | text = lowercase(text) 76 | text = collapse_whitespace(text) 77 | return text 78 | 79 | 80 | def transliteration_cleaners(text): 81 | """Pipeline for non-English text that transliterates to ASCII.""" 82 | text = convert_to_ascii(text) 83 | text = lowercase(text) 84 | text = collapse_whitespace(text) 85 | return text 86 | 87 | 88 | def english_cleaners(text): 89 | """Pipeline for English text, including abbreviation expansion.""" 90 | text = convert_to_ascii(text) 91 | text = lowercase(text) 92 | text = expand_abbreviations(text) 93 | phonemes = phonemize(text, language="en-us", backend="espeak", strip=True) 94 | phonemes = collapse_whitespace(phonemes) 95 | return phonemes 96 | 97 | 98 | def english_cleaners2(text): 99 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 100 | text = convert_to_ascii(text) 101 | text = lowercase(text) 102 | text = expand_abbreviations(text) 103 | phonemes = phonemize( 104 | text, 105 | language="en-us", 106 | backend="espeak", 107 | strip=True, 108 | preserve_punctuation=True, 109 | with_stress=True, 110 | ) 111 | phonemes = collapse_whitespace(phonemes) 112 | return phonemes 113 | 114 | 115 | def english_cleaners3(text): 116 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 117 | text = convert_to_ascii(text) 118 | text = lowercase(text) 119 | text = expand_abbreviations(text) 120 | phonemes = backend.phonemize([text], strip=True)[0] 121 | phonemes = collapse_whitespace(phonemes) 122 | return phonemes 123 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 
5 | """ 6 | _pad = "_" 7 | _punctuation = ';:,.!?¡¿—…"«»“” ' 8 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | 11 | 12 | # Export all symbols: 13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 14 | 15 | # Special symbol ids 16 | SPACE_ID = symbols.index(" ") 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import itertools 5 | import math 6 | import torch 7 | from torch import nn, optim 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | # from tensorboardX import SummaryWriter 13 | import torch.multiprocessing as mp 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.cuda.amp import autocast, GradScaler 17 | import tqdm 18 | 19 | import commons 20 | import utils 21 | from data_utils import TextAudioLoader, TextAudioCollate, DistributedBucketSampler 22 | from models import ( 23 | SynthesizerTrn, 24 | MultiPeriodDiscriminator, 25 | DurationDiscriminatorV1, 26 | DurationDiscriminatorV2, 27 | AVAILABLE_FLOW_TYPES, 28 | AVAILABLE_DURATION_DISCRIMINATOR_TYPES 29 | ) 30 | from losses import generator_loss, discriminator_loss, feature_loss, kl_loss 31 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch 32 | from text.symbols import symbols 33 | 34 | 35 | torch.backends.cudnn.benchmark = True 36 | global_step = 0 37 | 38 | 39 | def main(): 40 | """Assume Single Node Multi GPUs Training Only""" 41 | assert torch.cuda.is_available(), "CPU training is not allowed." 
42 | 43 | n_gpus = torch.cuda.device_count() 44 | os.environ["MASTER_ADDR"] = "localhost" 45 | os.environ["MASTER_PORT"] = "6060" 46 | 47 | hps = utils.get_hparams() 48 | mp.spawn( 49 | run, 50 | nprocs=n_gpus, 51 | args=( 52 | n_gpus, 53 | hps, 54 | ), 55 | ) 56 | 57 | 58 | def run(rank, n_gpus, hps): 59 | global global_step 60 | if rank == 0: 61 | logger = utils.get_logger(hps.model_dir) 62 | logger.info(hps) 63 | utils.check_git_hash(hps.model_dir) 64 | writer = SummaryWriter(log_dir=hps.model_dir) 65 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 66 | 67 | dist.init_process_group( 68 | backend="nccl", init_method="env://", world_size=n_gpus, rank=rank 69 | ) 70 | torch.manual_seed(hps.train.seed) 71 | torch.cuda.set_device(rank) 72 | 73 | if ( 74 | "use_mel_posterior_encoder" in hps.model.keys() 75 | and hps.model.use_mel_posterior_encoder == True 76 | ): 77 | print("Using mel posterior encoder for VITS3") 78 | posterior_channels = 80 # vits3 79 | hps.data.use_mel_posterior_encoder = True 80 | else: 81 | print("Using lin posterior encoder for VITS1") 82 | posterior_channels = hps.data.filter_length // 2 + 1 83 | hps.data.use_mel_posterior_encoder = False 84 | 85 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data) 86 | train_sampler = DistributedBucketSampler( 87 | train_dataset, 88 | hps.train.batch_size, 89 | [32, 300, 400, 500, 600, 700, 800, 900, 1000], 90 | num_replicas=n_gpus, 91 | rank=rank, 92 | shuffle=True, 93 | ) 94 | 95 | collate_fn = TextAudioCollate() 96 | train_loader = DataLoader( 97 | train_dataset, 98 | num_workers=8, 99 | shuffle=False, 100 | pin_memory=True, 101 | collate_fn=collate_fn, 102 | batch_sampler=train_sampler, 103 | ) 104 | if rank == 0: 105 | eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data) 106 | eval_loader = DataLoader( 107 | eval_dataset, 108 | num_workers=8, 109 | shuffle=False, 110 | batch_size=hps.train.batch_size, 111 | pin_memory=True, 112 | drop_last=False, 113 | collate_fn=collate_fn, 114 | ) 115 | # some of these flags are not being used in the code and directly set in hps json file. 116 | # they are kept here for reference and prototyping. 
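    # run() reads the optional flags below from the "model" section of the config
    # JSON; a sketch of the relevant fields (values illustrative, not a complete
    # config):
    #   "model": {
    #     "use_mel_posterior_encoder": true,
    #     "use_transformer_flows": true,
    #     "transformer_flow_type": "fft",
    #     "use_spk_conditioned_encoder": true,
    #     "use_noise_scaled_mas": true,
    #     "use_duration_discriminator": true,
    #     "duration_discriminator_type": "dur_disc_1"
    #   }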
117 | 118 | if ( 119 | "use_transformer_flows" in hps.model.keys() 120 | and hps.model.use_transformer_flows == True 121 | ): 122 | use_transformer_flows = True 123 | transformer_flow_type = hps.model.transformer_flow_type 124 | print(f"Using transformer flows {transformer_flow_type} for VITS3") 125 | assert ( 126 | transformer_flow_type in AVAILABLE_FLOW_TYPES 127 | ), f"transformer_flow_type must be one of {AVAILABLE_FLOW_TYPES}" 128 | else: 129 | print("Using normal flows for VITS1") 130 | use_transformer_flows = False 131 | 132 | if ( 133 | "use_spk_conditioned_encoder" in hps.model.keys() 134 | and hps.model.use_spk_conditioned_encoder == True 135 | ): 136 | if hps.data.n_speakers == 0: 137 | print("Warning: use_spk_conditioned_encoder is True but n_speakers is 0") 138 | print( 139 | "Setting use_spk_conditioned_encoder to False as model is a single speaker model" 140 | ) 141 | use_spk_conditioned_encoder = False 142 | else: 143 | print("Using normal encoder for VITS1") 144 | use_spk_conditioned_encoder = False 145 | 146 | if ( 147 | "use_noise_scaled_mas" in hps.model.keys() 148 | and hps.model.use_noise_scaled_mas == True 149 | ): 150 | print("Using noise scaled MAS for VITS3") 151 | use_noise_scaled_mas = True 152 | mas_noise_scale_initial = 0.01 153 | noise_scale_delta = 2e-6 154 | else: 155 | print("Using normal MAS for VITS1") 156 | use_noise_scaled_mas = False 157 | mas_noise_scale_initial = 0.0 158 | noise_scale_delta = 0.0 159 | 160 | if ( 161 | "use_duration_discriminator" in hps.model.keys() 162 | and hps.model.use_duration_discriminator == True 163 | ): 164 | # print("Using duration discriminator for VITS3") 165 | use_duration_discriminator = True 166 | duration_discriminator_type = hps.model.duration_discriminator_type 167 | print(f"Using duration_discriminator {duration_discriminator_type} for VITS3") 168 | assert duration_discriminator_type in AVAILABLE_DURATION_DISCRIMINATOR_TYPES, f"duration_discriminator_type must be one of {AVAILABLE_DURATION_DISCRIMINATOR_TYPES}" 169 | if duration_discriminator_type == "dur_disc_1": 170 | net_dur_disc = DurationDiscriminatorV1( 171 | hps.model.hidden_channels, 172 | hps.model.hidden_channels, 173 | 3, 174 | 0.1, 175 | gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0, 176 | ).cuda(rank) 177 | elif duration_discriminator_type == "dur_disc_2": 178 | net_dur_disc = DurationDiscriminatorV2( 179 | hps.model.hidden_channels, 180 | hps.model.hidden_channels, 181 | 3, 182 | 0.1, 183 | gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0, 184 | ).cuda(rank) 185 | else: 186 | print("NOT using any duration discriminator like VITS1") 187 | net_dur_disc = None 188 | use_duration_discriminator = False 189 | 190 | net_g = SynthesizerTrn( 191 | len(symbols), 192 | posterior_channels, 193 | hps.train.segment_size // hps.data.hop_length, 194 | mas_noise_scale_initial=mas_noise_scale_initial, 195 | noise_scale_delta=noise_scale_delta, 196 | **hps.model, 197 | ).cuda(rank) 198 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) 199 | optim_g = torch.optim.AdamW( 200 | net_g.parameters(), 201 | hps.train.learning_rate, 202 | betas=hps.train.betas, 203 | eps=hps.train.eps, 204 | ) 205 | optim_d = torch.optim.AdamW( 206 | net_d.parameters(), 207 | hps.train.learning_rate, 208 | betas=hps.train.betas, 209 | eps=hps.train.eps, 210 | ) 211 | if net_dur_disc is not None: 212 | optim_dur_disc = torch.optim.AdamW( 213 | net_dur_disc.parameters(), 214 | hps.train.learning_rate, 215 | 
betas=hps.train.betas, 216 | eps=hps.train.eps, 217 | ) 218 | else: 219 | optim_dur_disc = None 220 | 221 | net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) 222 | net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) 223 | if net_dur_disc is not None: 224 | net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True) 225 | 226 | try: 227 | _, _, _, epoch_str = utils.load_checkpoint( 228 | utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g 229 | ) 230 | _, _, _, epoch_str = utils.load_checkpoint( 231 | utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d 232 | ) 233 | if net_dur_disc is not None: 234 | _, _, _, epoch_str = utils.load_checkpoint( 235 | utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"), 236 | net_dur_disc, 237 | optim_dur_disc, 238 | ) 239 | global_step = (epoch_str - 1) * len(train_loader) 240 | except: 241 | epoch_str = 1 242 | global_step = 0 243 | 244 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR( 245 | optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 246 | ) 247 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR( 248 | optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 249 | ) 250 | if net_dur_disc is not None: 251 | scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR( 252 | optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 253 | ) 254 | else: 255 | scheduler_dur_disc = None 256 | 257 | scaler = GradScaler(enabled=hps.train.fp16_run) 258 | 259 | for epoch in range(epoch_str, hps.train.epochs + 1): 260 | if rank == 0: 261 | train_and_evaluate( 262 | rank, 263 | epoch, 264 | hps, 265 | [net_g, net_d, net_dur_disc], 266 | [optim_g, optim_d, optim_dur_disc], 267 | [scheduler_g, scheduler_d, scheduler_dur_disc], 268 | scaler, 269 | [train_loader, eval_loader], 270 | logger, 271 | [writer, writer_eval], 272 | ) 273 | else: 274 | train_and_evaluate( 275 | rank, 276 | epoch, 277 | hps, 278 | [net_g, net_d, net_dur_disc], 279 | [optim_g, optim_d, optim_dur_disc], 280 | [scheduler_g, scheduler_d, scheduler_dur_disc], 281 | scaler, 282 | [train_loader, None], 283 | None, 284 | None, 285 | ) 286 | scheduler_g.step() 287 | scheduler_d.step() 288 | if net_dur_disc is not None: 289 | scheduler_dur_disc.step() 290 | 291 | 292 | def train_and_evaluate( 293 | rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers 294 | ): 295 | net_g, net_d, net_dur_disc = nets 296 | optim_g, optim_d, optim_dur_disc = optims 297 | scheduler_g, scheduler_d, scheduler_dur_disc = schedulers 298 | train_loader, eval_loader = loaders 299 | if writers is not None: 300 | writer, writer_eval = writers 301 | 302 | train_loader.batch_sampler.set_epoch(epoch) 303 | global global_step 304 | 305 | net_g.train() 306 | net_d.train() 307 | if net_dur_disc is not None: 308 | net_dur_disc.train() 309 | 310 | if rank == 0: 311 | loader = tqdm.tqdm(train_loader, desc="Loading train data") 312 | else: 313 | loader = train_loader 314 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate( 315 | loader 316 | ): 317 | if net_g.module.use_noise_scaled_mas: 318 | current_mas_noise_scale = ( 319 | net_g.module.mas_noise_scale_initial 320 | - net_g.module.noise_scale_delta * global_step 321 | ) 322 | net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0) 323 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda( 324 | rank, non_blocking=True 325 | ) 326 | spec, spec_lengths = spec.cuda(rank, 
        x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
            rank, non_blocking=True
        )
        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
            rank, non_blocking=True
        )
        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
            rank, non_blocking=True
        )

        with autocast(enabled=hps.train.fp16_run):
            (
                y_hat,
                l_length,
                attn,
                ids_slice,
                x_mask,
                z_mask,
                (z, z_p, m_p, logs_p, m_q, logs_q),
                (hidden_x, logw, logw_),
                (x_flown, x_expanded),
            ) = net_g(x, x_lengths, spec, spec_lengths)

            if (
                hps.model.use_mel_posterior_encoder
                or hps.data.use_mel_posterior_encoder
            ):
                mel = spec
            else:
                mel = spec_to_mel_torch(
                    spec.float(),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )

            y = commons.slice_segments(
                y, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
                loss_disc_all = loss_disc

        # Duration Discriminator
        if net_dur_disc is not None:
            y_dur_hat_r, y_dur_hat_g = net_dur_disc(
                hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach()
            )
            with autocast(enabled=False):
                # TODO: these losses should arguably be averaged under x_mask;
                # for now plain means over all positions are used.
                (
                    loss_dur_disc,
                    losses_dur_disc_r,
                    losses_dur_disc_g,
                ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
                loss_dur_disc_all = loss_dur_disc
            optim_dur_disc.zero_grad()
            scaler.scale(loss_dur_disc_all).backward()
            scaler.unscale_(optim_dur_disc)
            grad_norm_dur_disc = commons.clip_grad_value_(
                net_dur_disc.parameters(), None
            )
            scaler.step(optim_dur_disc)

        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            if net_dur_disc is not None:
                y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw)
            with autocast(enabled=False):
                loss_dur = torch.sum(l_length.float())
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                # Prior loss of the extra flow: z_mask is the valid-frame mask
                # returned by the forward pass above.
                prior_loss = torch.sum(
                    0.5
                    * ((z.detach() - x_expanded) ** 2 + math.log(2 * math.pi))
                    * z_mask
                )
                prior_loss = prior_loss / (
                    torch.sum(z_mask) * net_g.module.spec_channels
                )
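                # Up to constants, the two lines above are the negative
                # log-likelihood -log N(z; x_expanded, I) of the posterior
                # sample under the flown text prior, normalized by
                # (#valid frames x #channels) -- a GradTTS-style prior term.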
                # loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_kl = kl_loss(x_flown, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                # prior_loss (computed above) joins the generator objective.
                loss_gen_all = (
                    loss_gen + loss_fm + loss_mel + loss_dur + loss_kl + prior_loss
                )
                if net_dur_disc is not None:
                    loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
                    loss_gen_all += loss_dur_gen

        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc_all,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                if net_dur_disc is not None:
                    scalar_dict.update(
                        {
                            "loss/dur_disc/total": loss_dur_disc_all,
                            "grad_norm_dur_disc": grad_norm_dur_disc,
                        }
                    )
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/dur": loss_dur,
                        "loss/g/kl": loss_kl,
                        "loss/g/prior": prior_loss,
                    }
                )

                scalar_dict.update(
                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                )
                scalar_dict.update(
                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
                )
                scalar_dict.update(
                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                )

                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                    "all/attn": utils.plot_alignment_to_numpy(
                        attn[0, 0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )

            if global_step % hps.train.eval_interval == 0:
                evaluate(hps, net_g, eval_loader, writer_eval)
                utils.save_checkpoint(
                    net_g,
                    optim_g,
                    hps.train.learning_rate,
                    epoch,
                    os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
                )
                utils.save_checkpoint(
                    net_d,
                    optim_d,
                    hps.train.learning_rate,
                    epoch,
                    os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
                )
                if net_dur_disc is not None:
                    utils.save_checkpoint(
                        net_dur_disc,
                        optim_dur_disc,
                        hps.train.learning_rate,
                        epoch,
                        os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)),
                    )
                utils.remove_old_checkpoints(
                    hps.model_dir, prefixes=["G_*.pth", "D_*.pth", "DUR_*.pth"]
                )
        global_step += 1

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
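

# evaluate() below synthesizes a single utterance (the first item of the
# first eval batch) through the full inference path and logs its mel
# spectrogram and audio to TensorBoard; ground truth is logged once at step 0.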
def evaluate(hps, generator, eval_loader, writer_eval):
    generator.eval()
    with torch.no_grad():
        for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(
            eval_loader
        ):
            x, x_lengths = x.cuda(0), x_lengths.cuda(0)
            spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
            y, y_lengths = y.cuda(0), y_lengths.cuda(0)

            # Keep only the first utterance of the first batch.
            x = x[:1]
            x_lengths = x_lengths[:1]
            spec = spec[:1]
            spec_lengths = spec_lengths[:1]
            y = y[:1]
            y_lengths = y_lengths[:1]
            break
        y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, max_len=1000)
        y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

        if hps.model.use_mel_posterior_encoder or hps.data.use_mel_posterior_encoder:
            mel = spec
        else:
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1).float(),
            hps.data.filter_length,
            hps.data.n_mel_channels,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            hps.data.mel_fmin,
            hps.data.mel_fmax,
        )
    image_dict = {
        "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
    }
    audio_dict = {"gen/audio": y_hat[0, :, : y_hat_lengths[0]]}
    if global_step == 0:
        image_dict.update(
            {"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}
        )
        audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]})

    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate,
    )
    generator.train()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import functional as F

import numpy as np


DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet


def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
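

# Shape note: num_bins is inferred from the last dimension of
# unnormalized_widths / unnormalized_heights. With tails="linear",
# unnormalized_derivatives carries only the num_bins - 1 interior knots;
# it is padded on both sides below and the boundary slopes are pinned to 1
# so the transform stays the identity outside [-tail_bound, tail_bound].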
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet


def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
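    # derivatives holds one slope per bin edge (num_bins + 1 values for
    # num_bins bins), so the slope at the right edge of bin k is entry k + 1;
    # dropping the first entry lets the same bin_idx gather select it.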
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import glob
import sys
import argparse
import logging
import json
import subprocess
import numpy as np
from scipy.io.wavfile import read
import torch

MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging


def load_checkpoint(checkpoint_path, model, optimizer=None):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )
    return model, optimizer, learning_rate, iteration
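

# Note: loading above is deliberately lenient -- any parameter missing from
# the checkpoint keeps its freshly initialized value (and is logged), which
# lets training resume after architectural changes at the cost of silently
# mixing old and new weights.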


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(
        {
            "model": state_dict,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )


def summarize(
    writer,
    global_step,
    scalars={},
    histograms={},
    images={},
    audios={},
    audio_sampling_rate=22050,
):
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def scan_checkpoint(dir_path, regex):
    f_list = glob.glob(os.path.join(dir_path, regex))
    # Sort numerically by the digits in the filename (e.g. G_12000.pth).
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    if len(f_list) == 0:
        return None
    return f_list


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    f_list = scan_checkpoint(dir_path, regex)
    if not f_list:
        return None
    x = f_list[-1]
    print(x)
    return x


def remove_old_checkpoints(cp_dir, prefixes=["G_*.pth", "D_*.pth", "DUR_*.pth"]):
    # Keep only the three most recent checkpoints per prefix.
    for prefix in prefixes:
        sorted_ckpts = scan_checkpoint(cp_dir, prefix)
        if sorted_ckpts and len(sorted_ckpts) > 3:
            for ckpt_path in sorted_ckpts[:-3]:
                os.remove(ckpt_path)
                print("removed {}".format(ckpt_path))


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")  # headless backend for training servers
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring is deprecated for binary data; frombuffer is the
    # supported path.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def plot_alignment_to_numpy(alignment, info=None):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding="utf-8") as f:
        filepaths_and_text = [line.strip().split(split) for line in f]
    return filepaths_and_text


def get_hparams(init=True):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/base.json",
        help="JSON file for configuration",
    )
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_dir(model_dir):
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams


def check_git_hash(model_dir):
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        saved_hash = open(path).read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        open(path, "w").write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    h = logging.FileHandler(os.path.join(model_dir, filename))
    h.setLevel(logging.DEBUG)
    h.setFormatter(formatter)
    logger.addHandler(h)
    return logger


class HParams:
    """Nested attribute/dict hybrid built from the JSON config."""

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
--------------------------------------------------------------------------------
/webui.py:
--------------------------------------------------------------------------------
import argparse
import os

import gradio as gr
import torch
from gradio import components
from scipy.io.wavfile import write

import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence
from text.symbols import symbols


def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm


def tts(model_path, config_path, text):
    model_path = "./logs/" + model_path
    config_path = "./configs/" + config_path
    hps = utils.get_hparams_from_file(config_path)

    if (
        "use_mel_posterior_encoder" in hps.model.keys()
        and hps.model.use_mel_posterior_encoder
    ):
        posterior_channels = 80  # mel-spectrogram channels
        hps.data.use_mel_posterior_encoder = True
    else:
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False

    net_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    ).cuda()
    _ = net_g.eval()
    _ = utils.load_checkpoint(model_path, net_g, None)

    stn_tst = get_text(text, hps)
    x_tst = stn_tst.cuda().unsqueeze(0)
    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()

    with torch.no_grad():
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=1,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )

    output_wav_path = "output.wav"
    write(output_wav_path, hps.data.sampling_rate, audio)

    return output_wav_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str, default=None, help="Path to the model file."
    )
    parser.add_argument(
        "--config_path", type=str, default=None, help="Path to the config file."
    )
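    # The dropdowns below are populated from ./logs/ (checkpoints) and
    # ./configs/ (JSON configs); tts() re-prepends those directories, so only
    # bare filenames are passed through the UI.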
    args = parser.parse_args()

    model_files = [f for f in os.listdir("./logs/") if f.endswith(".pth")]
    # Newest checkpoint first, sorted by the step number in the filename.
    model_files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]), reverse=True)
    config_files = [f for f in os.listdir("./configs/") if f.endswith(".json")]

    default_model_file = (
        args.model_path
        if args.model_path
        else (model_files[0] if model_files else None)
    )
    default_config_file = args.config_path if args.config_path else "config.json"

    gr.Interface(
        fn=tts,
        inputs=[
            components.Dropdown(
                model_files, value=default_model_file, label="Model File"
            ),
            components.Dropdown(
                config_files, value=default_config_file, label="Config File"
            ),
            components.Textbox(label="Text Input"),
        ],
        outputs=components.Audio(type="filepath", label="Generated Speech"),
        live=False,
    ).launch()
--------------------------------------------------------------------------------