├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── attentions.py ├── colab_requirements.txt ├── commons.py ├── configs ├── vits2_ljs_base.json ├── vits2_ljs_nosdp.json ├── vits2_vctk_base.json ├── vits2_vctk_base_pr.json └── vits2_vctk_standard.json ├── data_utils.py ├── export_onnx.py ├── filelists ├── ljs_audio_text_test_filelist.txt ├── ljs_audio_text_test_filelist.txt.cleaned ├── ljs_audio_text_train_filelist.txt ├── ljs_audio_text_train_filelist.txt.cleaned ├── ljs_audio_text_val_filelist.txt ├── ljs_audio_text_val_filelist.txt.cleaned ├── vctk_audio_sid_text_test_filelist.txt ├── vctk_audio_sid_text_test_filelist.txt.cleaned ├── vctk_audio_sid_text_train_filelist.txt ├── vctk_audio_sid_text_train_filelist.txt.cleaned ├── vctk_audio_sid_text_train_filelist_new.txt ├── vctk_audio_sid_text_train_filelist_new.txt.cleaned ├── vctk_audio_sid_text_val_filelist.txt ├── vctk_audio_sid_text_val_filelist.txt.cleaned └── vctk_audio_sid_text_val_filelist_new.txt.cleaned ├── infer_onnx.py ├── inference.ipynb ├── inference.py ├── inference_ms.py ├── losses.py ├── mel_processing.py ├── models.py ├── modules.py ├── monotonic_align ├── __init__.py ├── core.pyx ├── monotonic_align │ └── .gitkeep └── setup.py ├── preprocess.py ├── preprocess_audio.py ├── requirements.txt ├── resources ├── image.png ├── sid_src_3.wav ├── sid_src_3_to_tgt_1.wav ├── sid_src_3_to_tgt_2.wav ├── sid_src_3_to_tgt_4.wav ├── test.wav ├── vctk_onnx_test.wav └── vctk_test.wav ├── text ├── LICENSE ├── __init__.py ├── cleaners.py └── symbols.py ├── train.py ├── train_ms.py ├── transforms.py ├── utils.py └── webui.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | DUMMY1 2 | DUMMY2 3 | DUMMY3 4 
| logs 5 | __pycache__ 6 | .ipynb_checkpoints 7 | .*.swp 8 | 9 | build 10 | *.c 11 | monotonic_align/monotonic_align 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jaehyeon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VITS2: Improving Quality and Efficiency of Single-Stage Text-to-Speech with Adversarial Learning and Architecture Design 2 | ### Jungil Kong, Jihoon Park, Beomjeong Kim, Jeongmin Kim, Dohee Kong, Sangjin Kim 3 | Unofficial implementation of the [VITS2 paper](https://arxiv.org/abs/2307.16430), sequel to [VITS paper](https://arxiv.org/abs/2106.06103). 
(thanks to the authors for their work!) 4 | 5 | ![Alt text](resources/image.png) 6 | 7 | Single-stage text-to-speech models have been actively studied recently, and their results have outperformed two-stage pipeline systems. Although the previous single-stage model has made great progress, there is room for improvement in terms of its intermittent unnaturalness, computational efficiency, and strong dependence on phoneme conversion. In this work, we introduce VITS2, a single-stage text-to-speech model that efficiently synthesizes a more natural speech by improving several aspects of the previous work. We propose improved structures and training mechanisms and present that the proposed methods are effective in improving naturalness, similarity of speech characteristics in a multi-speaker model, and efficiency of training and inference. Furthermore, we demonstrate that the strong dependence on phoneme conversion in previous works can be significantly reduced with our method, which allows a fully end-to-end single-stage approach. 8 | 9 | ## Credits 10 | - We will build this repo based on the [VITS repo](https://github.com/jaywalnut310/vits). The goal is to make transfer learning from a pretrained VITS model easier! 11 | - (08-17-2023) - The authors were really kind to guide me through the paper and answer my questions. I am open to discussing any changes or answering questions regarding the implementation. Please feel free to open an issue or contact me directly. 12 | 13 | # Pretrained checkpoints 14 | - [LJSpeech-no-sdp](https://drive.google.com/drive/folders/1U-1EqBMXqmEqK0aUhbCJOquowbvKkLmc?usp=sharing) (refer to config.yaml in this checkpoint folder) | 64k steps | proof that training works! 15 | Experts may want to rename the checkpoints to *_0.pth and start training from them using transfer learning. (I will add a notebook for this soon to help beginners). 
16 | - [x] Check 'Discussion' page for training logs and tensorboard links and other community contributions. 17 | 18 | # Sample audio 19 | - Russian trained model samples [#32](https://github.com/p0p4k/vits2_pytorch/discussions/32). Thanks to [@shigabeev](https://github.com/shigabeev) for sharing the samples. 20 | - Some samples on non-native EN dataset [discussion page](https://github.com/p0p4k/vits2_pytorch/discussions/18). Thanks to [@athenasaurav](https://github.com/athenasaurav) for using his private GPU resources and dataset! 21 | - Added sample audio @104k steps. [ljspeech-nosdp](resources/test.wav) ; [tensorboard](https://github.com/p0p4k/vits2_pytorch/discussions/12) 22 | - [Vietnamese samples](https://github.com/p0p4k/vits2_pytorch/pull/10#issuecomment-1682307529) Thanks to [@ductho9799](https://github.com/ductho9799) for sharing! 23 | 24 | ## Prerequisites 25 | 1. Python >= 3.10 26 | 2. Tested on PyTorch version 1.13.1 with Google Colab and LambdaLabs cloud. 27 | 3. Clone this repository 28 | 4. Install Python requirements. Please refer to [requirements.txt](requirements.txt) 29 | 1. You may need to install espeak first: `apt-get install espeak` 30 | 5. Download datasets 31 | 1. Download and extract the LJ Speech dataset, then rename or create a link to the dataset folder: `ln -s /path/to/LJSpeech-1.1/wavs DUMMY1` 32 | 1. For the multi-speaker setting, download and extract the VCTK dataset, and downsample wav files to 22050 Hz. Then rename or create a link to the dataset folder: `ln -s /path/to/VCTK-Corpus/downsampled_wavs DUMMY2` 33 | 6. Build Monotonic Alignment Search and run preprocessing if you use your own datasets. 34 | 35 | ```sh 36 | # Cython-version Monotonic Alignment Search 37 | cd monotonic_align 38 | python setup.py build_ext --inplace 39 | 40 | # Preprocessing (g2p) for your own datasets. Preprocessed phonemes for LJ Speech and VCTK have already been provided. 
41 | # python preprocess.py --text_index 1 --filelists filelists/ljs_audio_text_train_filelist.txt filelists/ljs_audio_text_val_filelist.txt filelists/ljs_audio_text_test_filelist.txt 42 | # python preprocess.py --text_index 2 --filelists filelists/vctk_audio_sid_text_train_filelist.txt filelists/vctk_audio_sid_text_val_filelist.txt filelists/vctk_audio_sid_text_test_filelist.txt 43 | ``` 44 | 45 | ## How to run (dry-run) 46 | - model forward pass (dry-run) 47 | ```python 48 | import torch 49 | from models import SynthesizerTrn 50 | 51 | net_g = SynthesizerTrn( 52 | n_vocab=256, 53 | spec_channels=80, # <--- vits2 parameter (changed from 513 to 80) 54 | segment_size=8192, 55 | inter_channels=192, 56 | hidden_channels=192, 57 | filter_channels=768, 58 | n_heads=2, 59 | n_layers=6, 60 | kernel_size=3, 61 | p_dropout=0.1, 62 | resblock="1", 63 | resblock_kernel_sizes=[3, 7, 11], 64 | resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], 65 | upsample_rates=[8, 8, 2, 2], 66 | upsample_initial_channel=512, 67 | upsample_kernel_sizes=[16, 16, 4, 4], 68 | n_speakers=0, 69 | gin_channels=0, 70 | use_sdp=True, 71 | use_transformer_flows=True, # <--- vits2 parameter 72 | # (choose from "pre_conv", "fft", "mono_layer_inter_residual", "mono_layer_post_residual") 73 | transformer_flow_type="fft", # <--- vits2 parameter 74 | use_spk_conditioned_encoder=True, # <--- vits2 parameter 75 | use_noise_scaled_mas=True, # <--- vits2 parameter 76 | use_duration_discriminator=True, # <--- vits2 parameter 77 | ) 78 | 79 | x = torch.LongTensor([[1, 2, 3],[4, 5, 6]]) # token ids 80 | x_lengths = torch.LongTensor([3, 2]) # token lengths 81 | y = torch.randn(2, 80, 100) # mel spectrograms 82 | y_lengths = torch.Tensor([100, 80]) # mel spectrogram lengths 83 | 84 | net_g( 85 | x=x, 86 | x_lengths=x_lengths, 87 | y=y, 88 | y_lengths=y_lengths, 89 | ) 90 | 91 | # calculate loss and backpropagate 92 | ``` 93 | 94 | ## Training Example 95 | [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12iCPWMpKdekM6F1HujYCh5H8PyzrgJN2?usp=sharing) 96 | ```sh 97 | # LJ Speech 98 | python train.py -c configs/vits2_ljs_nosdp.json -m ljs_base # no-sdp; (recommended) 99 | python train.py -c configs/vits2_ljs_base.json -m ljs_base # with sdp; 100 | 101 | # VCTK 102 | python train_ms.py -c configs/vits2_vctk_base.json -m vctk_base 103 | 104 | # for onnx export of trained models 105 | python export_onnx.py --model-path="G_64000.pth" --config-path="config.json" --output="vits2.onnx" 106 | python infer_onnx.py --model="vits2.onnx" --config-path="config.json" --output-wav-path="output.wav" --text="hello world, how are you?" 107 | ``` 108 | 109 | ## TODOs, features and notes 110 | 111 | #### Duration predictor (fig 1a) 112 | - [x] Added LSTM discriminator to duration predictor. 113 | - [x] Added adversarial loss to duration predictor. ("use_duration_discriminator" flag in config file; default is "True") 114 | - [x] Monotonic Alignment Search with Gaussian Noise added; might need expert verification (Section 2.2) 115 | - [x] Added "use_noise_scaled_mas" flag in config file. Choose from True or False; updates noise while training based on number of steps and never goes below 0.0 116 | - [x] Update models.py/train.py/train_ms.py 117 | - [x] Update config files (vits2_vctk_base.json; vits2_ljs_base.json) 118 | - [x] Update losses in train.py and train_ms.py 119 | #### Transformer block in the normalizing flow (fig 1b) 120 | - [x] Added transformer block to the normalizing flow. There are three types of transformer blocks: pre-convolution (my implementation), FFT (from [so-vits-svc](https://github.com/svc-develop-team/so-vits-svc/commit/fc8336fffd40c39bdb225c1b041ab4dd15fac4e9) repo) and mono-layer. 121 | - [x] Added "transformer_flow_type" flag in config file. Choose from "pre_conv", "fft", "mono_layer_inter_residual", "mono_layer_post_residual". 
122 | - [x] Added layers and blocks in models.py 123 | (ResidualCouplingTransformersLayer, 124 | ResidualCouplingTransformersBlock, 125 | FFTransformerCouplingLayer, 126 | MonoTransformerFlowLayer) 127 | - [x] Add in config file (vits2_ljs_base.json; can be turned on using "use_transformer_flows" flag) 128 | #### Speaker-conditioned text encoder (fig 1c) 129 | - [x] Added speaker embedding to the text encoder in models.py (TextEncoder; backward compatible with VITS) 130 | - [x] Add in config file (vits2_ljs_base.json; can be turned on using "use_spk_conditioned_encoder" flag) 131 | #### Mel spectrogram posterior encoder (Section 3) 132 | - [x] Added mel spectrogram posterior encoder in train.py 133 | - [x] Added new config file (vits2_ljs_base.json; can be turned on using "use_mel_posterior_encoder" flag) 134 | - [x] Updated 'data_utils.py' to use the "use_mel_posterior_encoder" flag for vits2 135 | #### Training scripts 136 | - [x] Added vits2 flags to train.py (single-speaker model) 137 | - [x] Added vits2 flags to train_ms.py (multi-speaker model) 138 | #### ONNX export 139 | - [x] Add ONNX export support. 140 | #### Gradio Demo 141 | - [x] Add Gradio demo support. 142 | 143 | ## Special mentions 144 | - [@erogol](https://github.com/erogol) for quick feedback and guidance. (Please check his awesome [CoquiTTS](https://github.com/coqui-ai/TTS) repo). 145 | - [@lexkoro](https://github.com/lexkoro) for discussions and help with the prototype training. 146 | - [@manmay-nakhashi](https://github.com/manmay-nakhashi) for discussions and help with the code. 147 | - [@athenasaurav](https://github.com/athenasaurav) for offering GPU support for training. 148 | - [@w11wo](https://github.com/w11wo) for ONNX support. 149 | - [@Subarasheese](https://github.com/Subarasheese) for Gradio UI. 
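
## Note: noise-scaled MAS schedule (illustrative sketch)

The notes on the "use_noise_scaled_mas" flag say that the noise is updated during training based on the number of steps and never goes below 0.0. The snippet below is only a minimal sketch of one way such a schedule could look, assuming a linear decay. The function name `mas_noise_scale` and the default values for `initial` and `delta` are hypothetical and chosen for illustration; please check models.py and train.py for the values and update rule actually used.

```python
def mas_noise_scale(step: int, initial: float = 0.01, delta: float = 2e-6) -> float:
    """Return a noise scale that decays with the training step and is clamped at 0.0.

    NOTE: `initial` and `delta` are hypothetical illustration values, and the
    linear decay itself is an assumption; the README only states that the noise
    depends on the number of steps and never goes below 0.0.
    """
    return max(initial - delta * step, 0.0)


if __name__ == "__main__":
    # Print the scale at a few training steps to see the decay and the clamp.
    for step in (0, 2_500, 5_000, 10_000):
        print(step, mas_noise_scale(step))
```

With a schedule of this shape, the Gaussian noise added to the alignment scores is largest early in training and is switched off entirely once `delta * step` reaches `initial`.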
150 | -------------------------------------------------------------------------------- /attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.utils import remove_weight_norm, weight_norm 8 | 9 | import commons 10 | import modules 11 | from modules import LayerNorm 12 | 13 | 14 | class Encoder(nn.Module): # backward compatible vits2 encoder 15 | def __init__( 16 | self, 17 | hidden_channels, 18 | filter_channels, 19 | n_heads, 20 | n_layers, 21 | kernel_size=1, 22 | p_dropout=0.0, 23 | window_size=4, 24 | **kwargs 25 | ): 26 | super().__init__() 27 | self.hidden_channels = hidden_channels 28 | self.filter_channels = filter_channels 29 | self.n_heads = n_heads 30 | self.n_layers = n_layers 31 | self.kernel_size = kernel_size 32 | self.p_dropout = p_dropout 33 | self.window_size = window_size 34 | 35 | self.drop = nn.Dropout(p_dropout) 36 | self.attn_layers = nn.ModuleList() 37 | self.norm_layers_1 = nn.ModuleList() 38 | self.ffn_layers = nn.ModuleList() 39 | self.norm_layers_2 = nn.ModuleList() 40 | # if kwargs has spk_emb_dim, then add a linear layer to project spk_emb_dim to hidden_channels 41 | self.cond_layer_idx = self.n_layers 42 | if "gin_channels" in kwargs: 43 | self.gin_channels = kwargs["gin_channels"] 44 | if self.gin_channels != 0: 45 | self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) 46 | # vits2 says 3rd block, so idx is 2 by default 47 | self.cond_layer_idx = ( 48 | kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 49 | ) 50 | assert ( 51 | self.cond_layer_idx < self.n_layers 52 | ), "cond_layer_idx should be less than n_layers" 53 | 54 | for i in range(self.n_layers): 55 | self.attn_layers.append( 56 | MultiHeadAttention( 57 | hidden_channels, 58 | hidden_channels, 59 | n_heads, 60 | p_dropout=p_dropout, 61 | 
window_size=window_size, 62 | ) 63 | ) 64 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 65 | self.ffn_layers.append( 66 | FFN( 67 | hidden_channels, 68 | hidden_channels, 69 | filter_channels, 70 | kernel_size, 71 | p_dropout=p_dropout, 72 | ) 73 | ) 74 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 75 | 76 | def forward(self, x, x_mask, g=None): 77 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 78 | x = x * x_mask 79 | for i in range(self.n_layers): 80 | if i == self.cond_layer_idx and g is not None: 81 | g = self.spk_emb_linear(g.transpose(1, 2)) 82 | g = g.transpose(1, 2) 83 | x = x + g 84 | x = x * x_mask 85 | y = self.attn_layers[i](x, x, attn_mask) 86 | y = self.drop(y) 87 | x = self.norm_layers_1[i](x + y) 88 | 89 | y = self.ffn_layers[i](x, x_mask) 90 | y = self.drop(y) 91 | x = self.norm_layers_2[i](x + y) 92 | x = x * x_mask 93 | return x 94 | 95 | 96 | class Decoder(nn.Module): 97 | def __init__( 98 | self, 99 | hidden_channels, 100 | filter_channels, 101 | n_heads, 102 | n_layers, 103 | kernel_size=1, 104 | p_dropout=0.0, 105 | proximal_bias=False, 106 | proximal_init=True, 107 | **kwargs 108 | ): 109 | super().__init__() 110 | self.hidden_channels = hidden_channels 111 | self.filter_channels = filter_channels 112 | self.n_heads = n_heads 113 | self.n_layers = n_layers 114 | self.kernel_size = kernel_size 115 | self.p_dropout = p_dropout 116 | self.proximal_bias = proximal_bias 117 | self.proximal_init = proximal_init 118 | 119 | self.drop = nn.Dropout(p_dropout) 120 | self.self_attn_layers = nn.ModuleList() 121 | self.norm_layers_0 = nn.ModuleList() 122 | self.encdec_attn_layers = nn.ModuleList() 123 | self.norm_layers_1 = nn.ModuleList() 124 | self.ffn_layers = nn.ModuleList() 125 | self.norm_layers_2 = nn.ModuleList() 126 | for i in range(self.n_layers): 127 | self.self_attn_layers.append( 128 | MultiHeadAttention( 129 | hidden_channels, 130 | hidden_channels, 131 | n_heads, 132 | p_dropout=p_dropout, 133 | 
proximal_bias=proximal_bias, 134 | proximal_init=proximal_init, 135 | ) 136 | ) 137 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 138 | self.encdec_attn_layers.append( 139 | MultiHeadAttention( 140 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout 141 | ) 142 | ) 143 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 144 | self.ffn_layers.append( 145 | FFN( 146 | hidden_channels, 147 | hidden_channels, 148 | filter_channels, 149 | kernel_size, 150 | p_dropout=p_dropout, 151 | causal=True, 152 | ) 153 | ) 154 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 155 | 156 | def forward(self, x, x_mask, h, h_mask): 157 | """ 158 | x: decoder input 159 | h: encoder output 160 | """ 161 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 162 | device=x.device, dtype=x.dtype 163 | ) 164 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 165 | x = x * x_mask 166 | for i in range(self.n_layers): 167 | y = self.self_attn_layers[i](x, x, self_attn_mask) 168 | y = self.drop(y) 169 | x = self.norm_layers_0[i](x + y) 170 | 171 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 172 | y = self.drop(y) 173 | x = self.norm_layers_1[i](x + y) 174 | 175 | y = self.ffn_layers[i](x, x_mask) 176 | y = self.drop(y) 177 | x = self.norm_layers_2[i](x + y) 178 | x = x * x_mask 179 | return x 180 | 181 | 182 | class MultiHeadAttention(nn.Module): 183 | def __init__( 184 | self, 185 | channels, 186 | out_channels, 187 | n_heads, 188 | p_dropout=0.0, 189 | window_size=None, 190 | heads_share=True, 191 | block_length=None, 192 | proximal_bias=False, 193 | proximal_init=False, 194 | ): 195 | super().__init__() 196 | assert channels % n_heads == 0 197 | 198 | self.channels = channels 199 | self.out_channels = out_channels 200 | self.n_heads = n_heads 201 | self.p_dropout = p_dropout 202 | self.window_size = window_size 203 | self.heads_share = heads_share 204 | self.block_length = block_length 205 | self.proximal_bias = proximal_bias 
206 | self.proximal_init = proximal_init 207 | self.attn = None 208 | 209 | self.k_channels = channels // n_heads 210 | self.conv_q = nn.Conv1d(channels, channels, 1) 211 | self.conv_k = nn.Conv1d(channels, channels, 1) 212 | self.conv_v = nn.Conv1d(channels, channels, 1) 213 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 214 | self.drop = nn.Dropout(p_dropout) 215 | 216 | if window_size is not None: 217 | n_heads_rel = 1 if heads_share else n_heads 218 | rel_stddev = self.k_channels**-0.5 219 | self.emb_rel_k = nn.Parameter( 220 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 221 | * rel_stddev 222 | ) 223 | self.emb_rel_v = nn.Parameter( 224 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 225 | * rel_stddev 226 | ) 227 | 228 | nn.init.xavier_uniform_(self.conv_q.weight) 229 | nn.init.xavier_uniform_(self.conv_k.weight) 230 | nn.init.xavier_uniform_(self.conv_v.weight) 231 | if proximal_init: 232 | with torch.no_grad(): 233 | self.conv_k.weight.copy_(self.conv_q.weight) 234 | self.conv_k.bias.copy_(self.conv_q.bias) 235 | 236 | def forward(self, x, c, attn_mask=None): 237 | q = self.conv_q(x) 238 | k = self.conv_k(c) 239 | v = self.conv_v(c) 240 | 241 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 242 | 243 | x = self.conv_o(x) 244 | return x 245 | 246 | def attention(self, query, key, value, mask=None): 247 | # reshape [b, d, t] -> [b, n_h, t, d_k] 248 | b, d, t_s, t_t = (*key.size(), query.size(2)) 249 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 250 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 251 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 252 | 253 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 254 | if self.window_size is not None: 255 | assert ( 256 | t_s == t_t 257 | ), "Relative attention is only available for self-attention." 
258 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 259 | rel_logits = self._matmul_with_relative_keys( 260 | query / math.sqrt(self.k_channels), key_relative_embeddings 261 | ) 262 | scores_local = self._relative_position_to_absolute_position(rel_logits) 263 | scores = scores + scores_local 264 | if self.proximal_bias: 265 | assert t_s == t_t, "Proximal bias is only available for self-attention." 266 | scores = scores + self._attention_bias_proximal(t_s).to( 267 | device=scores.device, dtype=scores.dtype 268 | ) 269 | if mask is not None: 270 | scores = scores.masked_fill(mask == 0, -1e4) 271 | if self.block_length is not None: 272 | assert ( 273 | t_s == t_t 274 | ), "Local attention is only available for self-attention." 275 | block_mask = ( 276 | torch.ones_like(scores) 277 | .triu(-self.block_length) 278 | .tril(self.block_length) 279 | ) 280 | scores = scores.masked_fill(block_mask == 0, -1e4) 281 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 282 | p_attn = self.drop(p_attn) 283 | output = torch.matmul(p_attn, value) 284 | if self.window_size is not None: 285 | relative_weights = self._absolute_position_to_relative_position(p_attn) 286 | value_relative_embeddings = self._get_relative_embeddings( 287 | self.emb_rel_v, t_s 288 | ) 289 | output = output + self._matmul_with_relative_values( 290 | relative_weights, value_relative_embeddings 291 | ) 292 | output = ( 293 | output.transpose(2, 3).contiguous().view(b, d, t_t) 294 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 295 | return output, p_attn 296 | 297 | def _matmul_with_relative_values(self, x, y): 298 | """ 299 | x: [b, h, l, m] 300 | y: [h or 1, m, d] 301 | ret: [b, h, l, d] 302 | """ 303 | ret = torch.matmul(x, y.unsqueeze(0)) 304 | return ret 305 | 306 | def _matmul_with_relative_keys(self, x, y): 307 | """ 308 | x: [b, h, l, d] 309 | y: [h or 1, m, d] 310 | ret: [b, h, l, m] 311 | """ 312 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 313 | return ret 314 
| 315 | def _get_relative_embeddings(self, relative_embeddings, length): 316 | max_relative_position = 2 * self.window_size + 1 317 | # Pad first before slice to avoid using cond ops. 318 | pad_length = max(length - (self.window_size + 1), 0) 319 | slice_start_position = max((self.window_size + 1) - length, 0) 320 | slice_end_position = slice_start_position + 2 * length - 1 321 | if pad_length > 0: 322 | padded_relative_embeddings = F.pad( 323 | relative_embeddings, 324 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 325 | ) 326 | else: 327 | padded_relative_embeddings = relative_embeddings 328 | used_relative_embeddings = padded_relative_embeddings[ 329 | :, slice_start_position:slice_end_position 330 | ] 331 | return used_relative_embeddings 332 | 333 | def _relative_position_to_absolute_position(self, x): 334 | """ 335 | x: [b, h, l, 2*l-1] 336 | ret: [b, h, l, l] 337 | """ 338 | batch, heads, length, _ = x.size() 339 | # Concat columns of pad to shift from relative to absolute indexing. 340 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 341 | 342 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 343 | x_flat = x.view([batch, heads, length * 2 * length]) 344 | x_flat = F.pad( 345 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) 346 | ) 347 | 348 | # Reshape and slice out the padded elements. 
349 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 350 | :, :, :length, length - 1 : 351 | ] 352 | return x_final 353 | 354 | def _absolute_position_to_relative_position(self, x): 355 | """ 356 | x: [b, h, l, l] 357 | ret: [b, h, l, 2*l-1] 358 | """ 359 | batch, heads, length, _ = x.size() 360 | # padd along column 361 | x = F.pad( 362 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) 363 | ) 364 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) 365 | # add 0's in the beginning that will skew the elements after reshape 366 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 367 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 368 | return x_final 369 | 370 | def _attention_bias_proximal(self, length): 371 | """Bias for self-attention to encourage attention to close positions. 372 | Args: 373 | length: an integer scalar. 374 | Returns: 375 | a Tensor with shape [1, 1, length, length] 376 | """ 377 | r = torch.arange(length, dtype=torch.float32) 378 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 379 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 380 | 381 | 382 | class FFN(nn.Module): 383 | def __init__( 384 | self, 385 | in_channels, 386 | out_channels, 387 | filter_channels, 388 | kernel_size, 389 | p_dropout=0.0, 390 | activation=None, 391 | causal=False, 392 | ): 393 | super().__init__() 394 | self.in_channels = in_channels 395 | self.out_channels = out_channels 396 | self.filter_channels = filter_channels 397 | self.kernel_size = kernel_size 398 | self.p_dropout = p_dropout 399 | self.activation = activation 400 | self.causal = causal 401 | 402 | if causal: 403 | self.padding = self._causal_padding 404 | else: 405 | self.padding = self._same_padding 406 | 407 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 408 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 409 | 
self.drop = nn.Dropout(p_dropout) 410 | 411 | def forward(self, x, x_mask): 412 | x = self.conv_1(self.padding(x * x_mask)) 413 | if self.activation == "gelu": 414 | x = x * torch.sigmoid(1.702 * x) 415 | else: 416 | x = torch.relu(x) 417 | x = self.drop(x) 418 | x = self.conv_2(self.padding(x * x_mask)) 419 | return x * x_mask 420 | 421 | def _causal_padding(self, x): 422 | if self.kernel_size == 1: 423 | return x 424 | pad_l = self.kernel_size - 1 425 | pad_r = 0 426 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 427 | x = F.pad(x, commons.convert_pad_shape(padding)) 428 | return x 429 | 430 | def _same_padding(self, x): 431 | if self.kernel_size == 1: 432 | return x 433 | pad_l = (self.kernel_size - 1) // 2 434 | pad_r = self.kernel_size // 2 435 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 436 | x = F.pad(x, commons.convert_pad_shape(padding)) 437 | return x 438 | 439 | 440 | class Depthwise_Separable_Conv1D(nn.Module): 441 | def __init__( 442 | self, 443 | in_channels, 444 | out_channels, 445 | kernel_size, 446 | stride=1, 447 | padding=0, 448 | dilation=1, 449 | bias=True, 450 | padding_mode="zeros", # TODO: refine this type 451 | device=None, 452 | dtype=None, 453 | ): 454 | super().__init__() 455 | self.depth_conv = nn.Conv1d( 456 | in_channels=in_channels, 457 | out_channels=in_channels, 458 | kernel_size=kernel_size, 459 | groups=in_channels, 460 | stride=stride, 461 | padding=padding, 462 | dilation=dilation, 463 | bias=bias, 464 | padding_mode=padding_mode, 465 | device=device, 466 | dtype=dtype, 467 | ) 468 | self.point_conv = nn.Conv1d( 469 | in_channels=in_channels, 470 | out_channels=out_channels, 471 | kernel_size=1, 472 | bias=bias, 473 | device=device, 474 | dtype=dtype, 475 | ) 476 | 477 | def forward(self, input): 478 | return self.point_conv(self.depth_conv(input)) 479 | 480 | def weight_norm(self): 481 | self.depth_conv = weight_norm(self.depth_conv, name="weight") 482 | self.point_conv = weight_norm(self.point_conv, name="weight") 483 | 484 | 
def remove_weight_norm(self): 485 | self.depth_conv = remove_weight_norm(self.depth_conv, name="weight") 486 | self.point_conv = remove_weight_norm(self.point_conv, name="weight") 487 | 488 | 489 | class Depthwise_Separable_TransposeConv1D(nn.Module): 490 | def __init__( 491 | self, 492 | in_channels, 493 | out_channels, 494 | kernel_size, 495 | stride=1, 496 | padding=0, 497 | output_padding=0, 498 | bias=True, 499 | dilation=1, 500 | padding_mode="zeros", # TODO: refine this type 501 | device=None, 502 | dtype=None, 503 | ): 504 | super().__init__() 505 | self.depth_conv = nn.ConvTranspose1d( 506 | in_channels=in_channels, 507 | out_channels=in_channels, 508 | kernel_size=kernel_size, 509 | groups=in_channels, 510 | stride=stride, 511 | output_padding=output_padding, 512 | padding=padding, 513 | dilation=dilation, 514 | bias=bias, 515 | padding_mode=padding_mode, 516 | device=device, 517 | dtype=dtype, 518 | ) 519 | self.point_conv = nn.Conv1d( 520 | in_channels=in_channels, 521 | out_channels=out_channels, 522 | kernel_size=1, 523 | bias=bias, 524 | device=device, 525 | dtype=dtype, 526 | ) 527 | 528 | def forward(self, input): 529 | return self.point_conv(self.depth_conv(input)) 530 | 531 | def weight_norm(self): 532 | self.depth_conv = weight_norm(self.depth_conv, name="weight") 533 | self.point_conv = weight_norm(self.point_conv, name="weight") 534 | 535 | def remove_weight_norm(self): 536 | remove_weight_norm(self.depth_conv, name="weight") 537 | remove_weight_norm(self.point_conv, name="weight") 538 | 539 | 540 | def weight_norm_modules(module, name="weight", dim=0): 541 | if isinstance(module, Depthwise_Separable_Conv1D) or isinstance( 542 | module, Depthwise_Separable_TransposeConv1D 543 | ): 544 | module.weight_norm() 545 | return module 546 | else: 547 | return weight_norm(module, name, dim) 548 | 549 | 550 | def remove_weight_norm_modules(module, name="weight"): 551 | if isinstance(module, Depthwise_Separable_Conv1D) or isinstance( 552 | module, 
Depthwise_Separable_TransposeConv1D 553 | ): 554 | module.remove_weight_norm() 555 | else: 556 | remove_weight_norm(module, name) 557 | 558 | 559 | class FFT(nn.Module): 560 | def __init__( 561 | self, 562 | hidden_channels, 563 | filter_channels, 564 | n_heads, 565 | n_layers=1, 566 | kernel_size=1, 567 | p_dropout=0.0, 568 | proximal_bias=False, 569 | proximal_init=True, 570 | isflow=False, 571 | **kwargs 572 | ): 573 | super().__init__() 574 | self.hidden_channels = hidden_channels 575 | self.filter_channels = filter_channels 576 | self.n_heads = n_heads 577 | self.n_layers = n_layers 578 | self.kernel_size = kernel_size 579 | self.p_dropout = p_dropout 580 | self.proximal_bias = proximal_bias 581 | self.proximal_init = proximal_init 582 | if isflow and "gin_channels" in kwargs and kwargs["gin_channels"] > 0: 583 | cond_layer = torch.nn.Conv1d( 584 | kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1 585 | ) 586 | self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) 587 | self.cond_layer = weight_norm_modules(cond_layer, name="weight") 588 | self.gin_channels = kwargs["gin_channels"] 589 | self.drop = nn.Dropout(p_dropout) 590 | self.self_attn_layers = nn.ModuleList() 591 | self.norm_layers_0 = nn.ModuleList() 592 | self.ffn_layers = nn.ModuleList() 593 | self.norm_layers_1 = nn.ModuleList() 594 | for i in range(self.n_layers): 595 | self.self_attn_layers.append( 596 | MultiHeadAttention( 597 | hidden_channels, 598 | hidden_channels, 599 | n_heads, 600 | p_dropout=p_dropout, 601 | proximal_bias=proximal_bias, 602 | proximal_init=proximal_init, 603 | ) 604 | ) 605 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 606 | self.ffn_layers.append( 607 | FFN( 608 | hidden_channels, 609 | hidden_channels, 610 | filter_channels, 611 | kernel_size, 612 | p_dropout=p_dropout, 613 | causal=True, 614 | ) 615 | ) 616 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 617 | 618 | def forward(self, x, x_mask, g=None): 619 | """ 620 | x: 
decoder input 621 | h: encoder output 622 | """ 623 | if g is not None: 624 | g = self.cond_layer(g) 625 | 626 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 627 | device=x.device, dtype=x.dtype 628 | ) 629 | x = x * x_mask 630 | for i in range(self.n_layers): 631 | if g is not None: 632 | x = self.cond_pre(x) 633 | cond_offset = i * 2 * self.hidden_channels 634 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 635 | x = commons.fused_add_tanh_sigmoid_multiply( 636 | x, g_l, torch.IntTensor([self.hidden_channels]) 637 | ) 638 | y = self.self_attn_layers[i](x, x, self_attn_mask) 639 | y = self.drop(y) 640 | x = self.norm_layers_0[i](x + y) 641 | 642 | y = self.ffn_layers[i](x, x_mask) 643 | y = self.drop(y) 644 | x = self.norm_layers_1[i](x + y) 645 | x = x * x_mask 646 | return x 647 | -------------------------------------------------------------------------------- /colab_requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.21 2 | librosa==0.8.0 3 | matplotlib==3.3.1 4 | phonemizer==2.2.1 5 | Unidecode==1.1.1 6 | tensorboardX 7 | scipy -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size * dilation - dilation) / 2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | 
return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += ( 34 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 35 | ) 36 | return kl 37 | 38 | 39 | def rand_gumbel(shape): 40 | """Sample from the Gumbel distribution, protect from overflows.""" 41 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 42 | return -torch.log(-torch.log(uniform_samples)) 43 | 44 | 45 | def rand_gumbel_like(x): 46 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 47 | return g 48 | 49 | 50 | def slice_segments(x, ids_str, segment_size=4): 51 | ret = torch.zeros_like(x[:, :, :segment_size]) 52 | for i in range(x.size(0)): 53 | idx_str = ids_str[i] 54 | idx_end = idx_str + segment_size 55 | ret[i] = x[i, :, idx_str:idx_end] 56 | return ret 57 | 58 | 59 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 60 | b, d, t = x.size() 61 | if x_lengths is None: 62 | x_lengths = t 63 | ids_str_max = x_lengths - segment_size + 1 64 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 65 | ret = slice_segments(x, ids_str, segment_size) 66 | return ret, ids_str 67 | 68 | 69 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 70 | position = torch.arange(length, dtype=torch.float) 71 | num_timescales = channels // 2 72 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 73 | num_timescales - 1 74 | ) 75 | inv_timescales = min_timescale * torch.exp( 76 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 77 | ) 78 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 79 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 80 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 81 | signal = signal.view(1, channels, length) 82 | return signal 83 | 84 | 85 | def add_timing_signal_1d(x, min_timescale=1.0, 
max_timescale=1.0e4): 86 | b, channels, length = x.size() 87 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 88 | return x + signal.to(dtype=x.dtype, device=x.device) 89 | 90 | 91 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 92 | b, channels, length = x.size() 93 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 94 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 95 | 96 | 97 | def subsequent_mask(length): 98 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 99 | return mask 100 | 101 | 102 | @torch.jit.script 103 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 104 | n_channels_int = n_channels[0] 105 | in_act = input_a + input_b 106 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 107 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 108 | acts = t_act * s_act 109 | return acts 110 | 111 | 112 | def convert_pad_shape(pad_shape): 113 | l = pad_shape[::-1] 114 | pad_shape = [item for sublist in l for item in sublist] 115 | return pad_shape 116 | 117 | 118 | def shift_1d(x): 119 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 120 | return x 121 | 122 | 123 | def sequence_mask(length, max_length=None): 124 | if max_length is None: 125 | max_length = length.max() 126 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 127 | return x.unsqueeze(0) < length.unsqueeze(1) 128 | 129 | 130 | def generate_path(duration, mask): 131 | """ 132 | duration: [b, 1, t_x] 133 | mask: [b, 1, t_y, t_x] 134 | """ 135 | device = duration.device 136 | 137 | b, _, t_y, t_x = mask.shape 138 | cum_duration = torch.cumsum(duration, -1) 139 | 140 | cum_duration_flat = cum_duration.view(b * t_x) 141 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 142 | path = path.view(b, t_x, t_y) 143 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 144 
| path = path.unsqueeze(1).transpose(2, 3) * mask 145 | return path 146 | 147 | 148 | def clip_grad_value_(parameters, clip_value, norm_type=2): 149 | if isinstance(parameters, torch.Tensor): 150 | parameters = [parameters] 151 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 152 | norm_type = float(norm_type) 153 | if clip_value is not None: 154 | clip_value = float(clip_value) 155 | 156 | total_norm = 0 157 | for p in parameters: 158 | param_norm = p.grad.data.norm(norm_type) 159 | total_norm += param_norm.item() ** norm_type 160 | if clip_value is not None: 161 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 162 | total_norm = total_norm ** (1.0 / norm_type) 163 | return total_norm 164 | -------------------------------------------------------------------------------- /configs/vits2_ljs_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 20000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners2"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 0, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv", 40 | 
"use_spk_conditioned_encoder": false, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [8,8,2,2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16,4,4], 56 | "n_layers_q": 3, 57 | "use_spectral_norm": false 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /configs/vits2_ljs_nosdp.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 20000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners2"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 0, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv", 40 | "use_spk_conditioned_encoder": false, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "inter_channels": 192, 44 | 
"hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [8,8,2,2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16,4,4], 56 | "n_layers_q": 3, 57 | "use_spectral_norm": false, 58 | "use_sdp": false 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /configs/vits2_vctk_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners2"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 109, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv", 40 | "use_spk_conditioned_encoder": true, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 
| "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [8,8,2,2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16,4,4], 56 | "n_layers_q": 3, 57 | "use_spectral_norm": false, 58 | "use_sdp": false, 59 | "gin_channels": 256 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /configs/vits2_vctk_base_pr.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners3"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 109, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv", 40 | "use_spk_conditioned_encoder": true, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | 
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [8,8,2,2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16,4,4], 56 | "n_layers_q": 3, 57 | "use_spectral_norm": false, 58 | "use_sdp": false, 59 | "gin_channels": 256 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /configs/vits2_vctk_standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": true, 21 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned", 22 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned", 23 | "text_cleaners":["english_cleaners3"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 109, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "pre_conv2", 40 | "use_spk_conditioned_encoder": true, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "duration_discriminator_type": "dur_disc_2", 44 | "inter_channels": 192, 45 | "hidden_channels": 192, 46 | "filter_channels": 768, 47 | "n_heads": 2, 48 | "n_layers": 6, 49 | "kernel_size": 3, 50 | "p_dropout": 0.1, 51 | "resblock": "1", 52 | "resblock_kernel_sizes": [3,7,11], 53 | "resblock_dilation_sizes": [[1,3,5], 
[1,3,5], [1,3,5]], 54 | "upsample_rates": [8,8,2,2], 55 | "upsample_initial_channel": 512, 56 | "upsample_kernel_sizes": [16,16,4,4], 57 | "n_layers_q": 3, 58 | "use_spectral_norm": false, 59 | "use_sdp": true, 60 | "gin_channels": 256 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data 8 | 9 | import commons 10 | from mel_processing import (mel_spectrogram_torch, spec_to_mel_torch, 11 | spectrogram_torch) 12 | from text import cleaned_text_to_sequence, text_to_sequence 13 | from utils import load_filepaths_and_text, load_wav_to_torch 14 | 15 | 16 | class TextAudioLoader(torch.utils.data.Dataset): 17 | """ 18 | 1) loads audio, text pairs 19 | 2) normalizes text and converts them to sequences of integers 20 | 3) computes spectrograms from audio files. 
21 | """ 22 | 23 | def __init__(self, audiopaths_and_text, hparams): 24 | self.hparams = hparams 25 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 26 | self.text_cleaners = hparams.text_cleaners 27 | self.max_wav_value = hparams.max_wav_value 28 | self.sampling_rate = hparams.sampling_rate 29 | self.filter_length = hparams.filter_length 30 | self.hop_length = hparams.hop_length 31 | self.win_length = hparams.win_length 32 | self.sampling_rate = hparams.sampling_rate 33 | 34 | self.use_mel_spec_posterior = getattr( 35 | hparams, "use_mel_posterior_encoder", False 36 | ) 37 | if self.use_mel_spec_posterior: 38 | self.n_mel_channels = getattr(hparams, "n_mel_channels", 80) 39 | self.cleaned_text = getattr(hparams, "cleaned_text", False) 40 | 41 | self.add_blank = hparams.add_blank 42 | self.min_text_len = getattr(hparams, "min_text_len", 1) 43 | self.max_text_len = getattr(hparams, "max_text_len", 190) 44 | 45 | random.seed(1234) 46 | random.shuffle(self.audiopaths_and_text) 47 | self._filter() 48 | 49 | def _filter(self): 50 | """ 51 | Filter text & store spec lengths 52 | """ 53 | # Store spectrogram lengths for Bucketing 54 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 55 | # spec_length = wav_length // hop_length 56 | 57 | audiopaths_and_text_new = [] 58 | lengths = [] 59 | for audiopath, text in self.audiopaths_and_text: 60 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 61 | audiopaths_and_text_new.append([audiopath, text]) 62 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) 63 | self.audiopaths_and_text = audiopaths_and_text_new 64 | self.lengths = lengths 65 | 66 | def get_audio_text_pair(self, audiopath_and_text): 67 | # separate filename and text 68 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1] 69 | text = self.get_text(text) 70 | spec, wav = self.get_audio(audiopath) 71 | return (text, spec, wav) 72 | 73 | def get_audio(self, 
filename): 74 | # TODO : if linear spec exists convert to mel from existing linear spec 75 | audio, sampling_rate = load_wav_to_torch(filename) 76 | if sampling_rate != self.sampling_rate: 77 | raise ValueError( 78 | "{} {} SR doesn't match target {} SR".format( 79 | filename, sampling_rate, self.sampling_rate 80 | ) 81 | ) 82 | audio_norm = audio / self.max_wav_value 83 | audio_norm = audio_norm.unsqueeze(0) 84 | spec_filename = filename.replace(".wav", ".spec.pt") 85 | if self.use_mel_spec_posterior: 86 | spec_filename = spec_filename.replace(".spec.pt", ".mel.pt") 87 | if os.path.exists(spec_filename): 88 | spec = torch.load(spec_filename) 89 | else: 90 | if self.use_mel_spec_posterior: 91 | """TODO : (need verification) 92 | if linear spec exists convert to 93 | mel from existing linear spec (uncomment below lines)""" 94 | # if os.path.exists(filename.replace(".wav", ".spec.pt")): 95 | # # spec, n_fft, num_mels, sampling_rate, fmin, fmax 96 | # spec = spec_to_mel_torch( 97 | # torch.load(filename.replace(".wav", ".spec.pt")), 98 | # self.filter_length, self.n_mel_channels, self.sampling_rate, 99 | # self.hparams.mel_fmin, self.hparams.mel_fmax) 100 | spec = mel_spectrogram_torch( 101 | audio_norm, 102 | self.filter_length, 103 | self.n_mel_channels, 104 | self.sampling_rate, 105 | self.hop_length, 106 | self.win_length, 107 | self.hparams.mel_fmin, 108 | self.hparams.mel_fmax, 109 | center=False, 110 | ) 111 | else: 112 | spec = spectrogram_torch( 113 | audio_norm, 114 | self.filter_length, 115 | self.sampling_rate, 116 | self.hop_length, 117 | self.win_length, 118 | center=False, 119 | ) 120 | spec = torch.squeeze(spec, 0) 121 | torch.save(spec, spec_filename) 122 | return spec, audio_norm 123 | 124 | def get_text(self, text): 125 | if self.cleaned_text: 126 | text_norm = cleaned_text_to_sequence(text) 127 | else: 128 | text_norm = text_to_sequence(text, self.text_cleaners) 129 | if self.add_blank: 130 | text_norm = commons.intersperse(text_norm, 0) 131 | text_norm = 
torch.LongTensor(text_norm) 132 | return text_norm 133 | 134 | def __getitem__(self, index): 135 | return self.get_audio_text_pair(self.audiopaths_and_text[index]) 136 | 137 | def __len__(self): 138 | return len(self.audiopaths_and_text) 139 | 140 | 141 | class TextAudioCollate: 142 | """Zero-pads model inputs and targets""" 143 | 144 | def __init__(self, return_ids=False): 145 | self.return_ids = return_ids 146 | 147 | def __call__(self, batch): 148 | """Collates a training batch from normalized text and audio 149 | PARAMS 150 | ------ 151 | batch: [text_normalized, spec_normalized, wav_normalized] 152 | """ 153 | # Right zero-pad all integer-encoded text sequences to max input length 154 | _, ids_sorted_decreasing = torch.sort( 155 | torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True 156 | ) 157 | 158 | max_text_len = max([len(x[0]) for x in batch]) 159 | max_spec_len = max([x[1].size(1) for x in batch]) 160 | max_wav_len = max([x[2].size(1) for x in batch]) 161 | 162 | text_lengths = torch.LongTensor(len(batch)) 163 | spec_lengths = torch.LongTensor(len(batch)) 164 | wav_lengths = torch.LongTensor(len(batch)) 165 | 166 | text_padded = torch.LongTensor(len(batch), max_text_len) 167 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 168 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 169 | text_padded.zero_() 170 | spec_padded.zero_() 171 | wav_padded.zero_() 172 | for i in range(len(ids_sorted_decreasing)): 173 | row = batch[ids_sorted_decreasing[i]] 174 | 175 | text = row[0] 176 | text_padded[i, : text.size(0)] = text 177 | text_lengths[i] = text.size(0) 178 | 179 | spec = row[1] 180 | spec_padded[i, :, : spec.size(1)] = spec 181 | spec_lengths[i] = spec.size(1) 182 | 183 | wav = row[2] 184 | wav_padded[i, :, : wav.size(1)] = wav 185 | wav_lengths[i] = wav.size(1) 186 | 187 | if self.return_ids: 188 | return ( 189 | text_padded, 190 | text_lengths, 191 | spec_padded, 192 | spec_lengths, 193 | wav_padded, 194 
| wav_lengths, 195 | ids_sorted_decreasing, 196 | ) 197 | return ( 198 | text_padded, 199 | text_lengths, 200 | spec_padded, 201 | spec_lengths, 202 | wav_padded, 203 | wav_lengths, 204 | ) 205 | 206 | 207 | """Multi speaker version""" 208 | 209 | 210 | class TextAudioSpeakerLoader(torch.utils.data.Dataset): 211 | """ 212 | 1) loads audio, speaker_id, text pairs 213 | 2) normalizes text and converts them to sequences of integers 214 | 3) computes spectrograms from audio files. 215 | """ 216 | 217 | def __init__(self, audiopaths_sid_text, hparams): 218 | self.hparams = hparams 219 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) 220 | self.text_cleaners = hparams.text_cleaners 221 | self.max_wav_value = hparams.max_wav_value 222 | self.sampling_rate = hparams.sampling_rate 223 | self.filter_length = hparams.filter_length 224 | self.hop_length = hparams.hop_length 225 | self.win_length = hparams.win_length 226 | self.sampling_rate = hparams.sampling_rate 227 | 228 | self.use_mel_spec_posterior = getattr( 229 | hparams, "use_mel_posterior_encoder", False 230 | ) 231 | if self.use_mel_spec_posterior: 232 | self.n_mel_channels = getattr(hparams, "n_mel_channels", 80) 233 | self.cleaned_text = getattr(hparams, "cleaned_text", False) 234 | 235 | self.add_blank = hparams.add_blank 236 | self.min_text_len = getattr(hparams, "min_text_len", 1) 237 | self.max_text_len = getattr(hparams, "max_text_len", 190) 238 | self.min_audio_len = getattr(hparams, "min_audio_len", 8192) 239 | 240 | random.seed(1234) 241 | random.shuffle(self.audiopaths_sid_text) 242 | self._filter() 243 | 244 | def _filter(self): 245 | """ 246 | Filter text & store spec lengths 247 | """ 248 | # Store spectrogram lengths for Bucketing 249 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 250 | # spec_length = wav_length // hop_length 251 | 252 | audiopaths_sid_text_new = [] 253 | lengths = [] 254 | for audiopath, sid, text in 
self.audiopaths_sid_text: 255 | if not os.path.isfile(audiopath): 256 | continue 257 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 258 | length = os.path.getsize(audiopath) // (2 * self.hop_length) 259 | if length < self.min_audio_len // self.hop_length: 260 | continue 261 | audiopaths_sid_text_new.append([audiopath, sid, text]) 262 | lengths.append(length) 263 | self.audiopaths_sid_text = audiopaths_sid_text_new 264 | self.lengths = lengths 265 | print( 266 | len(self.lengths) 267 | ) # number of samples kept after filtering; on a large corpus this also signals that filtering has finished 268 | 269 | def get_audio_text_speaker_pair(self, audiopath_sid_text): 270 | # separate filename, speaker_id and text 271 | audiopath, sid, text = ( 272 | audiopath_sid_text[0], 273 | audiopath_sid_text[1], 274 | audiopath_sid_text[2], 275 | ) 276 | text = self.get_text(text) 277 | spec, wav = self.get_audio(audiopath) 278 | sid = self.get_sid(sid) 279 | return (text, spec, wav, sid) 280 | 281 | def get_audio(self, filename): 282 | # TODO : if linear spec exists convert to mel from existing linear spec 283 | audio, sampling_rate = load_wav_to_torch(filename) 284 | if sampling_rate != self.sampling_rate: 285 | raise ValueError( 286 | "{} {} SR doesn't match target {} SR".format( 287 | filename, sampling_rate, self.sampling_rate 288 | ) 289 | ) 290 | audio_norm = audio / self.max_wav_value 291 | audio_norm = audio_norm.unsqueeze(0) 292 | spec_filename = filename.replace(".wav", ".spec.pt") 293 | if self.use_mel_spec_posterior: 294 | spec_filename = spec_filename.replace(".spec.pt", ".mel.pt") 295 | if os.path.exists(spec_filename): 296 | spec = torch.load(spec_filename) 297 | else: 298 | if self.use_mel_spec_posterior: 299 | """TODO : (need verification) 300 | if linear spec exists convert to 301 | mel from existing linear spec (uncomment below lines)""" 302 | # if os.path.exists(filename.replace(".wav", ".spec.pt")): 303 | # # spec, n_fft, num_mels, sampling_rate, fmin, fmax 304 | # spec = 
spec_to_mel_torch( 305 | # torch.load(filename.replace(".wav", ".spec.pt")), 306 | # self.filter_length, self.n_mel_channels, self.sampling_rate, 307 | # self.hparams.mel_fmin, self.hparams.mel_fmax) 308 | spec = mel_spectrogram_torch( 309 | audio_norm, 310 | self.filter_length, 311 | self.n_mel_channels, 312 | self.sampling_rate, 313 | self.hop_length, 314 | self.win_length, 315 | self.hparams.mel_fmin, 316 | self.hparams.mel_fmax, 317 | center=False, 318 | ) 319 | else: 320 | spec = spectrogram_torch( 321 | audio_norm, 322 | self.filter_length, 323 | self.sampling_rate, 324 | self.hop_length, 325 | self.win_length, 326 | center=False, 327 | ) 328 | spec = torch.squeeze(spec, 0) 329 | torch.save(spec, spec_filename) 330 | return spec, audio_norm 331 | 332 | def get_text(self, text): 333 | if self.cleaned_text: 334 | text_norm = cleaned_text_to_sequence(text) 335 | else: 336 | text_norm = text_to_sequence(text, self.text_cleaners) 337 | if self.add_blank: 338 | text_norm = commons.intersperse(text_norm, 0) 339 | text_norm = torch.LongTensor(text_norm) 340 | return text_norm 341 | 342 | def get_sid(self, sid): 343 | sid = torch.LongTensor([int(sid)]) 344 | return sid 345 | 346 | def __getitem__(self, index): 347 | return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) 348 | 349 | def __len__(self): 350 | return len(self.audiopaths_sid_text) 351 | 352 | 353 | class TextAudioSpeakerCollate: 354 | """Zero-pads model inputs and targets""" 355 | 356 | def __init__(self, return_ids=False): 357 | self.return_ids = return_ids 358 | 359 | def __call__(self, batch): 360 | """Collates a training batch from normalized text, audio and speaker identities 361 | PARAMS 362 | ------ 363 | batch: [text_normalized, spec_normalized, wav_normalized, sid] 364 | """ 365 | # Right zero-pad all integer-encoded text sequences to max input length 366 | _, ids_sorted_decreasing = torch.sort( 367 | torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True 368 | ) 369 | 
370 | max_text_len = max([len(x[0]) for x in batch]) 371 | max_spec_len = max([x[1].size(1) for x in batch]) 372 | max_wav_len = max([x[2].size(1) for x in batch]) 373 | 374 | text_lengths = torch.LongTensor(len(batch)) 375 | spec_lengths = torch.LongTensor(len(batch)) 376 | wav_lengths = torch.LongTensor(len(batch)) 377 | sid = torch.LongTensor(len(batch)) 378 | 379 | text_padded = torch.LongTensor(len(batch), max_text_len) 380 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 381 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 382 | text_padded.zero_() 383 | spec_padded.zero_() 384 | wav_padded.zero_() 385 | for i in range(len(ids_sorted_decreasing)): 386 | row = batch[ids_sorted_decreasing[i]] 387 | 388 | text = row[0] 389 | text_padded[i, : text.size(0)] = text 390 | text_lengths[i] = text.size(0) 391 | 392 | spec = row[1] 393 | spec_padded[i, :, : spec.size(1)] = spec 394 | spec_lengths[i] = spec.size(1) 395 | 396 | wav = row[2] 397 | wav_padded[i, :, : wav.size(1)] = wav 398 | wav_lengths[i] = wav.size(1) 399 | 400 | sid[i] = row[3] 401 | 402 | if self.return_ids: 403 | return ( 404 | text_padded, 405 | text_lengths, 406 | spec_padded, 407 | spec_lengths, 408 | wav_padded, 409 | wav_lengths, 410 | sid, 411 | ids_sorted_decreasing, 412 | ) 413 | return ( 414 | text_padded, 415 | text_lengths, 416 | spec_padded, 417 | spec_lengths, 418 | wav_padded, 419 | wav_lengths, 420 | sid, 421 | ) 422 | 423 | 424 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 425 | """ 426 | Maintain similar input lengths in a batch. 427 | Length groups are specified by boundaries. 428 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 429 | 430 | It removes samples which are not included in the boundaries. 431 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
432 | """ 433 | 434 | def __init__( 435 | self, 436 | dataset, 437 | batch_size, 438 | boundaries, 439 | num_replicas=None, 440 | rank=None, 441 | shuffle=True, 442 | ): 443 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 444 | self.lengths = dataset.lengths 445 | self.batch_size = batch_size 446 | self.boundaries = boundaries 447 | 448 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 449 | self.total_size = sum(self.num_samples_per_bucket) 450 | self.num_samples = self.total_size // self.num_replicas 451 | 452 | def _create_buckets(self): 453 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 454 | for i in range(len(self.lengths)): 455 | length = self.lengths[i] 456 | idx_bucket = self._bisect(length) 457 | if idx_bucket != -1: 458 | buckets[idx_bucket].append(i) 459 | 460 | for i in range(len(buckets) - 1, -1, -1): # include bucket 0, so an empty first bucket cannot cause a division by zero in __iter__ 461 | if len(buckets[i]) == 0: 462 | buckets.pop(i) 463 | self.boundaries.pop(i + 1) 464 | 465 | num_samples_per_bucket = [] 466 | for i in range(len(buckets)): 467 | len_bucket = len(buckets[i]) 468 | total_batch_size = self.num_replicas * self.batch_size 469 | rem = ( 470 | total_batch_size - (len_bucket % total_batch_size) 471 | ) % total_batch_size 472 | num_samples_per_bucket.append(len_bucket + rem) 473 | return buckets, num_samples_per_bucket 474 | 475 | def __iter__(self): 476 | # deterministically shuffle based on epoch 477 | g = torch.Generator() 478 | g.manual_seed(self.epoch) 479 | 480 | indices = [] 481 | if self.shuffle: 482 | for bucket in self.buckets: 483 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 484 | else: 485 | for bucket in self.buckets: 486 | indices.append(list(range(len(bucket)))) 487 | 488 | batches = [] 489 | for i in range(len(self.buckets)): 490 | bucket = self.buckets[i] 491 | len_bucket = len(bucket) 492 | ids_bucket = indices[i] 493 | num_samples_bucket = self.num_samples_per_bucket[i] 494 | 495 | # add extra samples to make it evenly 
divisible 496 | rem = num_samples_bucket - len_bucket 497 | ids_bucket = ( 498 | ids_bucket 499 | + ids_bucket * (rem // len_bucket) 500 | + ids_bucket[: (rem % len_bucket)] 501 | ) 502 | 503 | # subsample 504 | ids_bucket = ids_bucket[self.rank :: self.num_replicas] 505 | 506 | # batching 507 | for j in range(len(ids_bucket) // self.batch_size): 508 | batch = [ 509 | bucket[idx] 510 | for idx in ids_bucket[ 511 | j * self.batch_size : (j + 1) * self.batch_size 512 | ] 513 | ] 514 | batches.append(batch) 515 | 516 | if self.shuffle: 517 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 518 | batches = [batches[i] for i in batch_ids] 519 | self.batches = batches 520 | 521 | assert len(self.batches) * self.batch_size == self.num_samples 522 | return iter(self.batches) 523 | 524 | def _bisect(self, x, lo=0, hi=None): 525 | if hi is None: 526 | hi = len(self.boundaries) - 1 527 | 528 | if hi > lo: 529 | mid = (hi + lo) // 2 530 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 531 | return mid 532 | elif x <= self.boundaries[mid]: 533 | return self._bisect(x, lo, mid) 534 | else: 535 | return self._bisect(x, mid + 1, hi) 536 | else: 537 | return -1 538 | 539 | def __len__(self): 540 | return self.num_samples // self.batch_size 541 | -------------------------------------------------------------------------------- /export_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | import utils 8 | from models import SynthesizerTrn 9 | from text.symbols import symbols 10 | 11 | OPSET_VERSION = 15 12 | 13 | 14 | def main() -> None: 15 | torch.manual_seed(1234) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--model-path", required=True, help="Path to model weights (.pth)" 20 | ) 21 | parser.add_argument( 22 | "--config-path", required=True, help="Path to model config (.json)" 
23 | ) 24 | parser.add_argument("--output", required=True, help="Path to output model (.onnx)") 25 | 26 | args = parser.parse_args() 27 | 28 | args.model_path = Path(args.model_path) 29 | args.config_path = Path(args.config_path) 30 | args.output = Path(args.output) 31 | args.output.parent.mkdir(parents=True, exist_ok=True) 32 | 33 | hps = utils.get_hparams_from_file(args.config_path) 34 | 35 | if ( 36 | "use_mel_posterior_encoder" in hps.model.keys() 37 | and hps.model.use_mel_posterior_encoder 38 | ): 39 | print("Using mel posterior encoder for VITS2") 40 | posterior_channels = 80 # vits2 41 | hps.data.use_mel_posterior_encoder = True 42 | else: 43 | print("Using lin posterior encoder for VITS1") 44 | posterior_channels = hps.data.filter_length // 2 + 1 45 | hps.data.use_mel_posterior_encoder = False 46 | 47 | model_g = SynthesizerTrn( 48 | len(symbols), 49 | posterior_channels, 50 | hps.train.segment_size // hps.data.hop_length, 51 | n_speakers=hps.data.n_speakers, 52 | **hps.model, 53 | ) 54 | 55 | _ = model_g.eval() 56 | 57 | _ = utils.load_checkpoint(args.model_path, model_g, None) 58 | 59 | def infer_forward(text, text_lengths, scales, sid=None): 60 | noise_scale = scales[0] 61 | length_scale = scales[1] 62 | noise_scale_w = scales[2] 63 | audio = model_g.infer( 64 | text, 65 | text_lengths, 66 | noise_scale=noise_scale, 67 | length_scale=length_scale, 68 | noise_scale_w=noise_scale_w, 69 | sid=sid, 70 | )[0] 71 | 72 | return audio 73 | 74 | model_g.forward = infer_forward 75 | 76 | dummy_input_length = 50 77 | sequences = torch.randint( 78 | low=0, high=len(symbols), size=(1, dummy_input_length), dtype=torch.long 79 | ) 80 | sequence_lengths = torch.LongTensor([sequences.size(1)]) 81 | 82 | sid: Optional[torch.LongTensor] = None 83 | if hps.data.n_speakers > 1: 84 | sid = torch.LongTensor([0]) 85 | 86 | # noise, length, noise_w 87 | scales = torch.FloatTensor([0.667, 1.0, 0.8]) 88 | dummy_input = (sequences, sequence_lengths, scales, sid) 89 | 90 | 
# Export 91 | torch.onnx.export( 92 | model=model_g, 93 | args=dummy_input, 94 | f=str(args.output), 95 | verbose=False, 96 | opset_version=OPSET_VERSION, 97 | input_names=["input", "input_lengths", "scales", "sid"], 98 | output_names=["output"], 99 | dynamic_axes={ 100 | "input": {0: "batch_size", 1: "phonemes"}, 101 | "input_lengths": {0: "batch_size"}, 102 | "output": {0: "batch_size", 1: "time1", 2: "time2"}, 103 | }, 104 | ) 105 | 106 | print(f"Exported model to {args.output}") 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /filelists/ljs_audio_text_val_filelist.txt: -------------------------------------------------------------------------------- 1 | DUMMY1/LJ022-0023.wav|The overwhelming majority of people in this country know how to sift the wheat from the chaff in what they hear and what they read. 2 | DUMMY1/LJ043-0030.wav|If somebody did that to me, a lousy trick like that, to take my wife away, and all the furniture, I would be mad as hell, too. 3 | DUMMY1/LJ005-0201.wav|as is shown by the report of the Commissioners to inquire into the state of the municipal corporations in eighteen thirty-five. 4 | DUMMY1/LJ001-0110.wav|Even the Caslon type when enlarged shows great shortcomings in this respect: 5 | DUMMY1/LJ003-0345.wav|All the committee could do in this respect was to throw the responsibility on others. 6 | DUMMY1/LJ007-0154.wav|These pungent and well-grounded strictures applied with still greater force to the unconvicted prisoner, the man who came to the prison innocent, and still uncontaminated, 7 | DUMMY1/LJ018-0098.wav|and recognized as one of the frequenters of the bogus law-stationers. His arrest led to that of others. 8 | DUMMY1/LJ047-0044.wav|Oswald was, however, willing to discuss his contacts with Soviet authorities. 
He denied having any involvement with Soviet intelligence agencies 9 | DUMMY1/LJ031-0038.wav|The first physician to see the President at Parkland Hospital was Dr. Charles J. Carrico, a resident in general surgery. 10 | DUMMY1/LJ048-0194.wav|during the morning of November twenty-two prior to the motorcade. 11 | DUMMY1/LJ049-0026.wav|On occasion the Secret Service has been permitted to have an agent riding in the passenger compartment with the President. 12 | DUMMY1/LJ004-0152.wav|although at Mr. Buxton's visit a new jail was in process of erection, the first step towards reform since Howard's visitation in seventeen seventy-four. 13 | DUMMY1/LJ008-0278.wav|or theirs might be one of many, and it might be considered necessary to "make an example." 14 | DUMMY1/LJ043-0002.wav|The Warren Commission Report. By The President's Commission on the Assassination of President Kennedy. Chapter seven. Lee Harvey Oswald: 15 | DUMMY1/LJ009-0114.wav|Mr. Wakefield winds up his graphic but somewhat sensational account by describing another religious service, which may appropriately be inserted here. 16 | DUMMY1/LJ028-0506.wav|A modern artist would have difficulty in doing such accurate work. 17 | DUMMY1/LJ050-0168.wav|with the particular purposes of the agency involved. The Commission recognizes that this is a controversial area 18 | DUMMY1/LJ039-0223.wav|Oswald's Marine training in marksmanship, his other rifle experience and his established familiarity with this particular weapon 19 | DUMMY1/LJ029-0032.wav|According to O'Donnell, quote, we had a motorcade wherever we went, end quote. 20 | DUMMY1/LJ031-0070.wav|Dr. Clark, who most closely observed the head wound, 21 | DUMMY1/LJ034-0198.wav|Euins, who was on the southwest corner of Elm and Houston Streets testified that he could not describe the man he saw in the window. 
22 | DUMMY1/LJ026-0068.wav|Energy enters the plant, to a small extent, 23 | DUMMY1/LJ039-0075.wav|once you know that you must put the crosshairs on the target and that is all that is necessary. 24 | DUMMY1/LJ004-0096.wav|the fatal consequences whereof might be prevented if the justices of the peace were duly authorized 25 | DUMMY1/LJ005-0014.wav|Speaking on a debate on prison matters, he declared that 26 | DUMMY1/LJ012-0161.wav|he was reported to have fallen away to a shadow. 27 | DUMMY1/LJ018-0239.wav|His disappearance gave color and substance to evil reports already in circulation that the will and conveyance above referred to 28 | DUMMY1/LJ019-0257.wav|Here the tread-wheel was in use, there cellular cranks, or hard-labor machines. 29 | DUMMY1/LJ028-0008.wav|you tap gently with your heel upon the shoulder of the dromedary to urge her on. 30 | DUMMY1/LJ024-0083.wav|This plan of mine is no attack on the Court; 31 | DUMMY1/LJ042-0129.wav|No night clubs or bowling alleys, no places of recreation except the trade union dances. I have had enough. 32 | DUMMY1/LJ036-0103.wav|The police asked him whether he could pick out his passenger from the lineup. 33 | DUMMY1/LJ046-0058.wav|During his Presidency, Franklin D. Roosevelt made almost four hundred journeys and traveled more than three hundred fifty thousand miles. 34 | DUMMY1/LJ014-0076.wav|He was seen afterwards smoking and talking with his hosts in their back parlor, and never seen again alive. 35 | DUMMY1/LJ002-0043.wav|long narrow rooms -- one thirty-six feet, six twenty-three feet, and the eighth eighteen, 36 | DUMMY1/LJ009-0076.wav|We come to the sermon. 37 | DUMMY1/LJ017-0131.wav|even when the high sheriff had told him there was no possibility of a reprieve, and within a few hours of execution. 38 | DUMMY1/LJ046-0184.wav|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes. 
39 | DUMMY1/LJ014-0263.wav|When other pleasures palled he took a theatre, and posed as a munificent patron of the dramatic art. 40 | DUMMY1/LJ042-0096.wav|(old exchange rate) in addition to his factory salary of approximately equal amount 41 | DUMMY1/LJ049-0050.wav|Hill had both feet on the car and was climbing aboard to assist President and Mrs. Kennedy. 42 | DUMMY1/LJ019-0186.wav|seeing that since the establishment of the Central Criminal Court, Newgate received prisoners for trial from several counties, 43 | DUMMY1/LJ028-0307.wav|then let twenty days pass, and at the end of that time station near the Chaldasan gates a body of four thousand. 44 | DUMMY1/LJ012-0235.wav|While they were in a state of insensibility the murder was committed. 45 | DUMMY1/LJ034-0053.wav|reached the same conclusion as Latona that the prints found on the cartons were those of Lee Harvey Oswald. 46 | DUMMY1/LJ014-0030.wav|These were damnatory facts which well supported the prosecution. 47 | DUMMY1/LJ015-0203.wav|but were the precautions too minute, the vigilance too close to be eluded or overcome? 48 | DUMMY1/LJ028-0093.wav|but his scribe wrote it in the manner customary for the scribes of those days to write of their royal masters. 49 | DUMMY1/LJ002-0018.wav|The inadequacy of the jail was noticed and reported upon again and again by the grand juries of the city of London, 50 | DUMMY1/LJ028-0275.wav|At last, in the twentieth month, 51 | DUMMY1/LJ012-0042.wav|which he kept concealed in a hiding-place with a trap-door just under his bed. 52 | DUMMY1/LJ011-0096.wav|He married a lady also belonging to the Society of Friends, who brought him a large fortune, which, and his own money, he put into a city firm, 53 | DUMMY1/LJ036-0077.wav|Roger D. Craig, a deputy sheriff of Dallas County, 54 | DUMMY1/LJ016-0318.wav|Other officials, great lawyers, governors of prisons, and chaplains supported this view. 
55 | DUMMY1/LJ013-0164.wav|who came from his room ready dressed, a suspicious circumstance, as he was always late in the morning. 56 | DUMMY1/LJ027-0141.wav|is closely reproduced in the life-history of existing deer. Or, in other words, 57 | DUMMY1/LJ028-0335.wav|accordingly they committed to him the command of their whole army, and put the keys of their city into his hands. 58 | DUMMY1/LJ031-0202.wav|Mrs. Kennedy chose the hospital in Bethesda for the autopsy because the President had served in the Navy. 59 | DUMMY1/LJ021-0145.wav|From those willing to join in establishing this hoped-for period of peace, 60 | DUMMY1/LJ016-0288.wav|"Müller, Müller, He's the man," till a diversion was created by the appearance of the gallows, which was received with continuous yells. 61 | DUMMY1/LJ028-0081.wav|Years later, when the archaeologists could readily distinguish the false from the true, 62 | DUMMY1/LJ018-0081.wav|his defense being that he had intended to commit suicide, but that, on the appearance of this officer who had wronged him, 63 | DUMMY1/LJ021-0066.wav|together with a great increase in the payrolls, there has come a substantial rise in the total of industrial profits 64 | DUMMY1/LJ009-0238.wav|After this the sheriffs sent for another rope, but the spectators interfered, and the man was carried back to jail. 65 | DUMMY1/LJ005-0079.wav|and improve the morals of the prisoners, and shall insure the proper measure of punishment to convicted offenders. 66 | DUMMY1/LJ035-0019.wav|drove to the northwest corner of Elm and Houston, and parked approximately ten feet from the traffic signal. 67 | DUMMY1/LJ036-0174.wav|This is the approximate time he entered the roominghouse, according to Earlene Roberts, the housekeeper there. 
68 | DUMMY1/LJ046-0146.wav|The criteria in effect prior to November twenty-two, nineteen sixty-three, for determining whether to accept material for the PRS general files 69 | DUMMY1/LJ017-0044.wav|and the deepest anxiety was felt that the crime, if crime there had been, should be brought home to its perpetrator. 70 | DUMMY1/LJ017-0070.wav|but his sporting operations did not prosper, and he became a needy man, always driven to desperate straits for cash. 71 | DUMMY1/LJ014-0020.wav|He was soon afterwards arrested on suspicion, and a search of his lodgings brought to light several garments saturated with blood; 72 | DUMMY1/LJ016-0020.wav|He never reached the cistern, but fell back into the yard, injuring his legs severely. 73 | DUMMY1/LJ045-0230.wav|when he was finally apprehended in the Texas Theatre. Although it is not fully corroborated by others who were present, 74 | DUMMY1/LJ035-0129.wav|and she must have run down the stairs ahead of Oswald and would probably have seen or heard him. 75 | DUMMY1/LJ008-0307.wav|afterwards express a wish to murder the Recorder for having kept them so long in suspense. 76 | DUMMY1/LJ008-0294.wav|nearly indefinitely deferred. 77 | DUMMY1/LJ047-0148.wav|On October twenty-five, 78 | DUMMY1/LJ008-0111.wav|They entered a "stone cold room," and were presently joined by the prisoner. 79 | DUMMY1/LJ034-0042.wav|that he could only testify with certainty that the print was less than three days old. 80 | DUMMY1/LJ037-0234.wav|Mrs. Mary Brock, the wife of a mechanic who worked at the station, was there at the time and she saw a white male, 81 | DUMMY1/LJ040-0002.wav|Chapter seven. Lee Harvey Oswald: Background and Possible Motives, Part one. 82 | DUMMY1/LJ045-0140.wav|The arguments he used to justify his use of the alias suggest that Oswald may have come to think that the whole world was becoming involved 83 | DUMMY1/LJ012-0035.wav|the number and names on watches, were carefully removed or obliterated after the goods passed out of his hands. 
84 | DUMMY1/LJ012-0250.wav|On the seventh July, eighteen thirty-seven, 85 | DUMMY1/LJ016-0179.wav|contracted with sheriffs and conveners to work by the job. 86 | DUMMY1/LJ016-0138.wav|at a distance from the prison. 87 | DUMMY1/LJ027-0052.wav|These principles of homology are essential to a correct interpretation of the facts of morphology. 88 | DUMMY1/LJ031-0134.wav|On one occasion Mrs. Johnson, accompanied by two Secret Service agents, left the room to see Mrs. Kennedy and Mrs. Connally. 89 | DUMMY1/LJ019-0273.wav|which Sir Joshua Jebb told the committee he considered the proper elements of penal discipline. 90 | DUMMY1/LJ014-0110.wav|At the first the boxes were impounded, opened, and found to contain many of O'Connor's effects. 91 | DUMMY1/LJ034-0160.wav|on Brennan's subsequent certain identification of Lee Harvey Oswald as the man he saw fire the rifle. 92 | DUMMY1/LJ038-0199.wav|eleven. If I am alive and taken prisoner, 93 | DUMMY1/LJ014-0010.wav|yet he could not overcome the strange fascination it had for him, and remained by the side of the corpse till the stretcher came. 94 | DUMMY1/LJ033-0047.wav|I noticed when I went out that the light was on, end quote, 95 | DUMMY1/LJ040-0027.wav|He was never satisfied with anything. 96 | DUMMY1/LJ048-0228.wav|and others who were present say that no agent was inebriated or acted improperly. 97 | DUMMY1/LJ003-0111.wav|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity. 98 | DUMMY1/LJ008-0258.wav|Let me retrace my steps, and speak more in detail of the treatment of the condemned in those bloodthirsty and brutally indifferent days, 99 | DUMMY1/LJ029-0022.wav|The original plan called for the President to spend only one day in the State, making whirlwind visits to Dallas, Fort Worth, San Antonio, and Houston. 100 | DUMMY1/LJ004-0045.wav|Mr. Sturges Bourne, Sir James Mackintosh, Sir James Scarlett, and William Wilberforce. 
101 | -------------------------------------------------------------------------------- /filelists/ljs_audio_text_val_filelist.txt.cleaned: -------------------------------------------------------------------------------- 1 | DUMMY1/LJ022-0023.wav|ðɪ ˌoʊvɚwˈɛlmɪŋ mədʒˈɔːɹɪɾi ʌv pˈiːpəl ɪn ðɪs kˈʌntɹi nˈoʊ hˌaʊ tə sˈɪft ðə wˈiːt fɹʌmðə tʃˈæf ɪn wˌʌt ðeɪ hˈɪɹ ænd wˌʌt ðeɪ ɹˈiːd. 2 | DUMMY1/LJ043-0030.wav|ɪf sˈʌmbɑːdi dˈɪd ðˈæt tə mˌiː, ɐ lˈaʊsi tɹˈɪk lˈaɪk ðˈæt, tə tˈeɪk maɪ wˈaɪf ɐwˈeɪ, ænd ˈɔːl ðə fˈɜːnɪtʃɚ, ˈaɪ wʊd biː mˈæd æz hˈɛl, tˈuː. 3 | DUMMY1/LJ005-0201.wav|ˌæzˌɪz ʃˈoʊn baɪ ðə ɹɪpˈoːɹt ʌvðə kəmˈɪʃənɚz tʊ ɪnkwˈaɪɚɹ ˌɪntʊ ðə stˈeɪt ʌvðə mjuːnˈɪsɪpəl kˌɔːɹpɚɹˈeɪʃənz ɪn eɪtˈiːn θˈɜːɾifˈaɪv. 4 | DUMMY1/LJ001-0110.wav|ˈiːvən ðə kˈæslɑːn tˈaɪp wɛn ɛnlˈɑːɹdʒd ʃˈoʊz ɡɹˈeɪt ʃˈɔːɹtkʌmɪŋz ɪn ðɪs ɹɪspˈɛkt: 5 | DUMMY1/LJ003-0345.wav|ˈɔːl ðə kəmˈɪɾi kʊd dˈuː ɪn ðɪs ɹɪspˈɛkt wʌz tə θɹˈoʊ ðə ɹɪspˌɑːnsəbˈɪlɪɾi ˌɑːn ˈʌðɚz. 6 | DUMMY1/LJ007-0154.wav|ðiːz pˈʌndʒənt ænd wˈɛlɡɹˈaʊndᵻd stɹˈɪktʃɚz ɐplˈaɪd wɪð stˈɪl ɡɹˈeɪɾɚ fˈoːɹs tə ðɪ ʌnkənvˈɪktᵻd pɹˈɪzənɚ, ðə mˈæn hˌuː kˈeɪm tə ðə pɹˈɪzən ˈɪnəsənt, ænd stˈɪl ʌnkəntˈæmᵻnˌeɪɾᵻd, 7 | DUMMY1/LJ018-0098.wav|ænd ɹˈɛkəɡnˌaɪzd æz wˈʌn ʌvðə fɹˈiːkwɛntɚz ʌvðə bˈoʊɡəs lˈɔːstˈeɪʃənɚz. hɪz ɐɹˈɛst lˈɛd tə ðæt ʌv ˈʌðɚz. 8 | DUMMY1/LJ047-0044.wav|ˈɑːswəld wʌz, haʊˈɛvɚ, wˈɪlɪŋ tə dɪskˈʌs hɪz kˈɑːntækts wɪð sˈoʊviət ɐθˈɔːɹɪɾiz. hiː dɪnˈaɪd hˌævɪŋ ˌɛni ɪnvˈɑːlvmənt wɪð sˈoʊviət ɪntˈɛlɪdʒəns ˈeɪdʒənsiz 9 | DUMMY1/LJ031-0038.wav|ðə fˈɜːst fɪzˈɪʃən tə sˈiː ðə pɹˈɛzɪdənt æt pˈɑːɹklənd hˈɑːspɪɾəl wʌz dˈɑːktɚ tʃˈɑːɹlz dʒˈeɪ. kˈæɹɪkˌoʊ, ɐ ɹˈɛzɪdənt ɪn dʒˈɛnɚɹəl sˈɜːdʒɚɹi. 10 | DUMMY1/LJ048-0194.wav|dˈʊɹɪŋ ðə mˈɔːɹnɪŋ ʌv noʊvˈɛmbɚ twˈɛntitˈuː pɹˈaɪɚ tə ðə mˈoʊɾɚkˌeɪd. 11 | DUMMY1/LJ049-0026.wav|ˌɑːn əkˈeɪʒən ðə sˈiːkɹət sˈɜːvɪs hɐzbɪn pɚmˈɪɾᵻd tə hæv ɐn ˈeɪdʒənt ɹˈaɪdɪŋ ɪnðə pˈæsɪndʒɚ kəmpˈɑːɹtmənt wɪððə pɹˈɛzɪdənt. 
12 | DUMMY1/LJ004-0152.wav|ɑːlðˈoʊ æt mˈɪstɚ bˈʌkstənz vˈɪzɪt ɐ nˈuː dʒˈeɪl wʌz ɪn pɹˈɑːsɛs ʌv ɪɹˈɛkʃən, ðə fˈɜːst stˈɛp tʊwˈɔːɹdz ɹɪfˈɔːɹm sˈɪns hˈaʊɚdz vˌɪzɪtˈeɪʃən ɪn sˌɛvəntˈiːn sˈɛvəntifˈoːɹ. 13 | DUMMY1/LJ008-0278.wav|ɔːɹ ðˈɛɹz mˌaɪt biː wˈʌn ʌv mˈɛni, ænd ɪt mˌaɪt biː kənsˈɪdɚd nˈɛsəsɚɹi tuː "mˌeɪk ɐn ɛɡzˈæmpəl." 14 | DUMMY1/LJ043-0002.wav|ðə wˈɔːɹən kəmˈɪʃən ɹɪpˈoːɹt. baɪ ðə pɹˈɛzɪdənts kəmˈɪʃən ɑːnðɪ ɐsˌæsᵻnˈeɪʃən ʌv pɹˈɛzɪdənt kˈɛnədi. tʃˈæptɚ sˈɛvən. lˈiː hˈɑːɹvi ˈɑːswəld: 15 | DUMMY1/LJ009-0114.wav|mˈɪstɚ wˈeɪkfiːld wˈaɪndz ˈʌp hɪz ɡɹˈæfɪk bˌʌt sˈʌmwʌt sɛnsˈeɪʃənəl ɐkˈaʊnt baɪ dɪskɹˈaɪbɪŋ ɐnˈʌðɚ ɹɪlˈɪdʒəs sˈɜːvɪs, wˌɪtʃ mˈeɪ ɐpɹˈoʊpɹɪətli biː ɪnsˈɜːɾᵻd hˈɪɹ. 16 | DUMMY1/LJ028-0506.wav|ɐ mˈɑːdɚn ˈɑːɹɾɪst wʊdhɐv dˈɪfɪkˌʌlti ɪn dˌuːɪŋ sˈʌtʃ ˈækjʊɹət wˈɜːk. 17 | DUMMY1/LJ050-0168.wav|wɪððə pɚtˈɪkjʊlɚ pˈɜːpəsᵻz ʌvðɪ ˈeɪdʒənsi ɪnvˈɑːlvd. ðə kəmˈɪʃən ɹˈɛkəɡnˌaɪzɪz ðæt ðɪs ɪz ɐ kˌɑːntɹəvˈɜːʃəl ˈɛɹiə 18 | DUMMY1/LJ039-0223.wav|ˈɑːswəldz mɚɹˈiːn tɹˈeɪnɪŋ ɪn mˈɑːɹksmənʃˌɪp, hɪz ˈʌðɚ ɹˈaɪfəl ɛkspˈiəɹɪəns ænd hɪz ɪstˈæblɪʃt fəmˌɪlɪˈæɹɪɾi wɪð ðɪs pɚtˈɪkjʊlɚ wˈɛpən 19 | DUMMY1/LJ029-0032.wav|ɐkˈoːɹdɪŋ tʊ oʊdˈɑːnəl, kwˈoʊt, wiː hɐd ɐ mˈoʊɾɚkˌeɪd wɛɹɹˈɛvɚ wiː wˈɛnt, ˈɛnd kwˈoʊt. 20 | DUMMY1/LJ031-0070.wav|dˈɑːktɚ klˈɑːɹk, hˌuː mˈoʊst klˈoʊsli ɑːbzˈɜːvd ðə hˈɛd wˈuːnd, 21 | DUMMY1/LJ034-0198.wav|jˈuːɪnz, hˌuː wʌz ɑːnðə saʊθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən stɹˈiːts tˈɛstɪfˌaɪd ðæt hiː kʊd nˌɑːt dɪskɹˈaɪb ðə mˈæn hiː sˈɔː ɪnðə wˈɪndoʊ. 22 | DUMMY1/LJ026-0068.wav|ˈɛnɚdʒi ˈɛntɚz ðə plˈænt, tʊ ɐ smˈɔːl ɛkstˈɛnt, 23 | DUMMY1/LJ039-0075.wav|wˈʌns juː nˈoʊ ðæt juː mˈʌst pˌʊt ðə kɹˈɔshɛɹz ɑːnðə tˈɑːɹɡɪt ænd ðæt ɪz ˈɔːl ðæt ɪz nˈɛsəsɚɹi. 24 | DUMMY1/LJ004-0096.wav|ðə fˈeɪɾəl kˈɑːnsɪkwənsᵻz wˈɛɹɑːf mˌaɪt biː pɹɪvˈɛntᵻd ɪf ðə dʒˈʌstɪsᵻz ʌvðə pˈiːs wɜː djˈuːli ˈɔːθɚɹˌaɪzd 25 | DUMMY1/LJ005-0014.wav|spˈiːkɪŋ ˌɑːn ɐ dɪbˈeɪt ˌɑːn pɹˈɪzən mˈæɾɚz, hiː dᵻklˈɛɹd ðˈæt 26 | DUMMY1/LJ012-0161.wav|hiː wʌz ɹɪpˈoːɹɾᵻd tə hæv fˈɔːlən ɐwˈeɪ tʊ ɐ ʃˈædoʊ. 
27 | DUMMY1/LJ018-0239.wav|hɪz dˌɪsɐpˈɪɹəns ɡˈeɪv kˈʌlɚ ænd sˈʌbstəns tʊ ˈiːvəl ɹɪpˈoːɹts ɔːlɹˌɛdi ɪn sˌɜːkjʊlˈeɪʃən ðætðə wɪl ænd kənvˈeɪəns əbˌʌv ɹɪfˈɜːd tuː 28 | DUMMY1/LJ019-0257.wav|hˈɪɹ ðə tɹˈɛdwˈiːl wʌz ɪn jˈuːs, ðɛɹ sˈɛljʊlɚ kɹˈæŋks, ɔːɹ hˈɑːɹdlˈeɪbɚ məʃˈiːnz. 29 | DUMMY1/LJ028-0008.wav|juː tˈæp dʒˈɛntli wɪð jʊɹ hˈiːl əpˌɑːn ðə ʃˈoʊldɚɹ ʌvðə dɹˈoʊmdɚɹi tʊ ˈɜːdʒ hɜːɹ ˈɑːn. 30 | DUMMY1/LJ024-0083.wav|ðɪs plˈæn ʌv mˈaɪn ɪz nˈoʊ ɐtˈæk ɑːnðə kˈoːɹt; 31 | DUMMY1/LJ042-0129.wav|nˈoʊ nˈaɪt klˈʌbz ɔːɹ bˈoʊlɪŋ ˈælɪz, nˈoʊ plˈeɪsᵻz ʌv ɹˌɛkɹiːˈeɪʃən ɛksˈɛpt ðə tɹˈeɪd jˈuːniən dˈænsᵻz. ˈaɪ hæv hɐd ɪnˈʌf. 32 | DUMMY1/LJ036-0103.wav|ðə pəlˈiːs ˈæskt hˌɪm wˈɛðɚ hiː kʊd pˈɪk ˈaʊt hɪz pˈæsɪndʒɚ fɹʌmðə lˈaɪnʌp. 33 | DUMMY1/LJ046-0058.wav|dˈʊɹɪŋ hɪz pɹˈɛzɪdənsi, fɹˈæŋklɪn dˈiː. ɹˈoʊzəvˌɛlt mˌeɪd ˈɔːlmoʊst fˈoːɹ hˈʌndɹəd dʒˈɜːnɪz ænd tɹˈævəld mˈoːɹ ðɐn θɹˈiː hˈʌndɹəd fˈɪfti θˈaʊzənd mˈaɪlz. 34 | DUMMY1/LJ014-0076.wav|hiː wʌz sˈiːn ˈæftɚwɚdz smˈoʊkɪŋ ænd tˈɔːkɪŋ wɪð hɪz hˈoʊsts ɪn ðɛɹ bˈæk pˈɑːɹlɚ, ænd nˈɛvɚ sˈiːn ɐɡˈɛn ɐlˈaɪv. 35 | DUMMY1/LJ002-0043.wav|lˈɑːŋ nˈæɹoʊ ɹˈuːmz wˈʌn θˈɜːɾisˈɪks fˈiːt, sˈɪks twˈɛntiθɹˈiː fˈiːt, ænd ðɪ ˈeɪtθ eɪtˈiːn, 36 | DUMMY1/LJ009-0076.wav|wiː kˈʌm tə ðə sˈɜːmən. 37 | DUMMY1/LJ017-0131.wav|ˈiːvən wɛn ðə hˈaɪ ʃˈɛɹɪf hɐd tˈoʊld hˌɪm ðɛɹwˌʌz nˈoʊ pˌɑːsəbˈɪlɪɾi əvɚ ɹɪpɹˈiːv, ænd wɪðˌɪn ɐ fjˈuː ˈaɪʊɹz ʌv ˌɛksɪkjˈuːʃən. 38 | DUMMY1/LJ046-0184.wav|bˌʌt ðɛɹ ɪz ɐ sˈɪstəm fɚðɪ ɪmˈiːdɪət nˌoʊɾɪfɪkˈeɪʃən ʌvðə sˈiːkɹət sˈɜːvɪs baɪ ðə kənfˈaɪnɪŋ ˌɪnstɪtˈuːʃən wɛn ɐ sˈʌbdʒɛkt ɪz ɹɪlˈiːsd ɔːɹ ɛskˈeɪps. 39 | DUMMY1/LJ014-0263.wav|wˌɛn ˈʌðɚ plˈɛʒɚz pˈɔːld hiː tˈʊk ɐ θˈiəɾɚ, ænd pˈoʊzd æz ɐ mjuːnˈɪfɪsənt pˈeɪtɹən ʌvðə dɹəmˈæɾɪk ˈɑːɹt. 40 | DUMMY1/LJ042-0096.wav| ˈoʊld ɛkstʃˈeɪndʒ ɹˈeɪt ɪn ɐdˈɪʃən tə hɪz fˈæktɚɹi sˈælɚɹi ʌv ɐpɹˈɑːksɪmətli ˈiːkwəl ɐmˈaʊnt 41 | DUMMY1/LJ049-0050.wav|hˈɪl hɐd bˈoʊθ fˈiːt ɑːnðə kˈɑːɹ ænd wʌz klˈaɪmɪŋ ɐbˈoːɹd tʊ ɐsˈɪst pɹˈɛzɪdənt ænd mɪsˈɛs kˈɛnədi. 
42 | DUMMY1/LJ019-0186.wav|sˈiːɪŋ ðæt sˈɪns ðɪ ɪstˈæblɪʃmənt ʌvðə sˈɛntɹəl kɹˈɪmɪnəl kˈoːɹt, nˈuːɡeɪt ɹɪsˈiːvd pɹˈɪzənɚz fɔːɹ tɹˈaɪəl fɹʌm sˈɛvɹəl kˈaʊntɪz, 43 | DUMMY1/LJ028-0307.wav|ðˈɛn lˈɛt twˈɛnti dˈeɪz pˈæs, ænd æt ðɪ ˈɛnd ʌv ðæt tˈaɪm stˈeɪʃən nˌɪɹ ðə tʃˈældæsən ɡˈeɪts ɐ bˈɑːdi ʌv fˈoːɹ θˈaʊzənd. 44 | DUMMY1/LJ012-0235.wav|wˌaɪl ðeɪ wɜːɹ ɪn ɐ stˈeɪt ʌv ɪnsˌɛnsəbˈɪlɪɾi ðə mˈɜːdɚ wʌz kəmˈɪɾᵻd. 45 | DUMMY1/LJ034-0053.wav|ɹˈiːtʃt ðə sˈeɪm kənklˈuːʒən æz lætˈoʊnə ðætðə pɹˈɪnts fˈaʊnd ɑːnðə kˈɑːɹtənz wɜː ðoʊz ʌv lˈiː hˈɑːɹvi ˈɑːswəld. 46 | DUMMY1/LJ014-0030.wav|ðiːz wɜː dˈæmnətˌoːɹi fˈækts wˌɪtʃ wˈɛl səpˈoːɹɾᵻd ðə pɹˌɑːsɪkjˈuːʃən. 47 | DUMMY1/LJ015-0203.wav|bˌʌt wɜː ðə pɹɪkˈɔːʃənz tˈuː mˈɪnɪt, ðə vˈɪdʒɪləns tˈuː klˈoʊs təbi ɪlˈuːdᵻd ɔːɹ ˌoʊvɚkˈʌm? 48 | DUMMY1/LJ028-0093.wav|bˌʌt hɪz skɹˈaɪb ɹˈoʊt ɪt ɪnðə mˈænɚ kˈʌstəmˌɛɹi fɚðə skɹˈaɪbz ʌv ðoʊz dˈeɪz tə ɹˈaɪt ʌv ðɛɹ ɹˈɔɪəl mˈæstɚz. 49 | DUMMY1/LJ002-0018.wav|ðɪ ɪnˈædɪkwəsi ʌvðə dʒˈeɪl wʌz nˈoʊɾɪsd ænd ɹɪpˈoːɹɾᵻd əpˌɑːn ɐɡˈɛn ænd ɐɡˈɛn baɪ ðə ɡɹˈænd dʒˈʊɹɪz ʌvðə sˈɪɾi ʌv lˈʌndən, 50 | DUMMY1/LJ028-0275.wav|æt lˈæst, ɪnðə twˈɛntiəθ mˈʌnθ, 51 | DUMMY1/LJ012-0042.wav|wˌɪtʃ hiː kˈɛpt kənsˈiːld ɪn ɐ hˈaɪdɪŋplˈeɪs wɪð ɐ tɹˈæpdˈoːɹ dʒˈʌst ˌʌndɚ hɪz bˈɛd. 52 | DUMMY1/LJ011-0096.wav|hiː mˈæɹɪd ɐ lˈeɪdi ˈɑːlsoʊ bɪlˈɑːŋɪŋ tə ðə səsˈaɪəɾi ʌv fɹˈɛndz, hˌuː bɹˈɔːt hˌɪm ɐ lˈɑːɹdʒ fˈɔːɹtʃən, wˈɪtʃ, ænd hɪz ˈoʊn mˈʌni, hiː pˌʊt ˌɪntʊ ɐ sˈɪɾi fˈɜːm, 53 | DUMMY1/LJ036-0077.wav|ɹˈɑːdʒɚ dˈiː. kɹˈeɪɡ, ɐ dˈɛpjuːɾi ʃˈɛɹɪf ʌv dˈæləs kˈaʊnti, 54 | DUMMY1/LJ016-0318.wav|ˈʌðɚɹ əfˈɪʃəlz, ɡɹˈeɪt lˈɔɪɚz, ɡˈʌvɚnɚz ʌv pɹˈɪzənz, ænd tʃˈæplɪnz səpˈoːɹɾᵻd ðɪs vjˈuː. 55 | DUMMY1/LJ013-0164.wav|hˌuː kˈeɪm fɹʌm hɪz ɹˈuːm ɹˈɛdi dɹˈɛst, ɐ səspˈɪʃəs sˈɜːkəmstˌæns, æz hiː wʌz ˈɔːlweɪz lˈeɪt ɪnðə mˈɔːɹnɪŋ. 56 | DUMMY1/LJ027-0141.wav|ɪz klˈoʊsli ɹɪpɹədˈuːst ɪnðə lˈaɪfhˈɪstɚɹi ʌv ɛɡzˈɪstɪŋ dˈɪɹ. ˈɔːɹ, ɪn ˈʌðɚ wˈɜːdz, 57 | DUMMY1/LJ028-0335.wav|ɐkˈoːɹdɪŋli ðeɪ kəmˈɪɾᵻd tə hˌɪm ðə kəmˈænd ʌv ðɛɹ hˈoʊl ˈɑːɹmi, ænd pˌʊt ðə kˈiːz ʌv ðɛɹ sˈɪɾi ˌɪntʊ hɪz hˈændz. 
58 | DUMMY1/LJ031-0202.wav|mɪsˈɛs kˈɛnədi tʃˈoʊz ðə hˈɑːspɪɾəl ɪn bəθˈɛzdə fɚðɪ ˈɔːtɑːpsi bɪkˈʌz ðə pɹˈɛzɪdənt hɐd sˈɜːvd ɪnðə nˈeɪvi. 59 | DUMMY1/LJ021-0145.wav|fɹʌm ðoʊz wˈɪlɪŋ tə dʒˈɔɪn ɪn ɪstˈæblɪʃɪŋ ðɪs hˈoʊptfɔːɹ pˈiəɹɪəd ʌv pˈiːs, 60 | DUMMY1/LJ016-0288.wav|"mˈʌlɚ, mˈʌlɚ, hiːz ðə mˈæn," tˈɪl ɐ daɪvˈɜːʒən wʌz kɹiːˈeɪɾᵻd baɪ ðɪ ɐpˈɪɹəns ʌvðə ɡˈæloʊz, wˌɪtʃ wʌz ɹɪsˈiːvd wɪð kəntˈɪnjuːəs jˈɛlz. 61 | DUMMY1/LJ028-0081.wav|jˈɪɹz lˈeɪɾɚ, wˌɛn ðɪ ˌɑːɹkiːˈɑːlədʒˌɪsts kʊd ɹˈɛdɪli dɪstˈɪŋɡwɪʃ ðə fˈɑːls fɹʌmðə tɹˈuː, 62 | DUMMY1/LJ018-0081.wav|hɪz dɪfˈɛns bˌiːɪŋ ðæt hiː hɐd ɪntˈɛndᵻd tə kəmˈɪt sˈuːɪsˌaɪd, bˌʌt ðˈæt, ɑːnðɪ ɐpˈɪɹəns ʌv ðɪs ˈɑːfɪsɚ hˌuː hɐd ɹˈɔŋd hˌɪm, 63 | DUMMY1/LJ021-0066.wav|təɡˌɛðɚ wɪð ɐ ɡɹˈeɪt ˈɪnkɹiːs ɪnðə pˈeɪɹoʊlz, ðɛɹ hɐz kˈʌm ɐ səbstˈænʃəl ɹˈaɪz ɪnðə tˈoʊɾəl ʌv ɪndˈʌstɹɪəl pɹˈɑːfɪts 64 | DUMMY1/LJ009-0238.wav|ˈæftɚ ðɪs ðə ʃˈɛɹɪfs sˈɛnt fɔːɹ ɐnˈʌðɚ ɹˈoʊp, bˌʌt ðə spɛktˈeɪɾɚz ˌɪntəfˈɪɹd, ænd ðə mˈæn wʌz kˈæɹɪd bˈæk tə dʒˈeɪl. 65 | DUMMY1/LJ005-0079.wav|ænd ɪmpɹˈuːv ðə mˈɔːɹəlz ʌvðə pɹˈɪzənɚz, ænd ʃˌæl ɪnʃˈʊɹ ðə pɹˈɑːpɚ mˈɛʒɚɹ ʌv pˈʌnɪʃmənt tə kənvˈɪktᵻd əfˈɛndɚz. 66 | DUMMY1/LJ035-0019.wav|dɹˈoʊv tə ðə nɔːɹθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən, ænd pˈɑːɹkt ɐpɹˈɑːksɪmətli tˈɛn fˈiːt fɹʌmðə tɹˈæfɪk sˈɪɡnəl. 67 | DUMMY1/LJ036-0174.wav|ðɪs ɪz ðɪ ɐpɹˈɑːksɪmət tˈaɪm hiː ˈɛntɚd ðə ɹˈuːmɪŋhˌaʊs, ɐkˈoːɹdɪŋ tʊ ˈɜːliːn ɹˈɑːbɚts, ðə hˈaʊskiːpɚ ðˈɛɹ. 68 | DUMMY1/LJ046-0146.wav|ðə kɹaɪtˈiəɹɪə ɪn ɪfˈɛkt pɹˈaɪɚ tə noʊvˈɛmbɚ twˈɛntitˈuː, naɪntˈiːn sˈɪkstiθɹˈiː, fɔːɹ dɪtˈɜːmɪnɪŋ wˈɛðɚ tʊ ɐksˈɛpt mətˈiəɹɪəl fɚðə pˌiːˌɑːɹˈɛs dʒˈɛnɚɹəl fˈaɪlz 69 | DUMMY1/LJ017-0044.wav|ænd ðə dˈiːpəst æŋzˈaɪəɾi wʌz fˈɛlt ðætðə kɹˈaɪm, ɪf kɹˈaɪm ðˈɛɹ hɐdbɪn, ʃˌʊd biː bɹˈɔːt hˈoʊm tʊ ɪts pˈɜːpɪtɹˌeɪɾɚ. 70 | DUMMY1/LJ017-0070.wav|bˌʌt hɪz spˈoːɹɾɪŋ ˌɑːpɚɹˈeɪʃənz dɪdnˌɑːt pɹˈɑːspɚ, ænd hiː bɪkˌeɪm ɐ nˈiːdi mˈæn, ˈɔːlweɪz dɹˈɪvən tə dˈɛspɚɹət stɹˈeɪts fɔːɹ kˈæʃ. 
71 | DUMMY1/LJ014-0020.wav|hiː wʌz sˈuːn ˈæftɚwɚdz ɐɹˈɛstᵻd ˌɑːn səspˈɪʃən, ænd ɐ sˈɜːtʃ ʌv hɪz lˈɑːdʒɪŋz bɹˈɔːt tə lˈaɪt sˈɛvɹəl ɡˈɑːɹmənts sˈætʃɚɹˌeɪɾᵻd wɪð blˈʌd; 72 | DUMMY1/LJ016-0020.wav|hiː nˈɛvɚ ɹˈiːtʃt ðə sˈɪstɚn, bˌʌt fˈɛl bˈæk ˌɪntʊ ðə jˈɑːɹd, ˈɪndʒɚɹɪŋ hɪz lˈɛɡz sɪvˈɪɹli. 73 | DUMMY1/LJ045-0230.wav|wˌɛn hiː wʌz fˈaɪnəli ˌæpɹɪhˈɛndᵻd ɪnðə tˈɛksəs θˈiəɾɚ. ɑːlðˈoʊ ɪt ɪz nˌɑːt fˈʊli kɚɹˈɑːbɚɹˌeɪɾᵻd baɪ ˈʌðɚz hˌuː wɜː pɹˈɛzənt, 74 | DUMMY1/LJ035-0129.wav|ænd ʃiː mˈʌstɐv ɹˈʌn dˌaʊn ðə stˈɛɹz ɐhˈɛd ʌv ˈɑːswəld ænd wʊd pɹˈɑːbəbli hæv sˈiːn ɔːɹ hˈɜːd hˌɪm. 75 | DUMMY1/LJ008-0307.wav|ˈæftɚwɚdz ɛkspɹˈɛs ɐ wˈɪʃ tə mˈɜːdɚ ðə ɹɪkˈoːɹdɚ fɔːɹ hˌævɪŋ kˈɛpt ðˌɛm sˌoʊ lˈɑːŋ ɪn səspˈɛns. 76 | DUMMY1/LJ008-0294.wav|nˌɪɹli ɪndˈɛfɪnətli dɪfˈɜːd. 77 | DUMMY1/LJ047-0148.wav|ˌɑːn ɑːktˈoʊbɚ twˈɛntifˈaɪv, 78 | DUMMY1/LJ008-0111.wav|ðeɪ ˈɛntɚd ˈeɪ "stˈoʊn kˈoʊld ɹˈuːm," ænd wɜː pɹˈɛzəntli dʒˈɔɪnd baɪ ðə pɹˈɪzənɚ. 79 | DUMMY1/LJ034-0042.wav|ðæt hiː kʊd ˈoʊnli tˈɛstɪfˌaɪ wɪð sˈɜːtənti ðætðə pɹˈɪnt wʌz lˈɛs ðɐn θɹˈiː dˈeɪz ˈoʊld. 80 | DUMMY1/LJ037-0234.wav|mɪsˈɛs mˈɛɹi bɹˈɑːk, ðə wˈaɪf əvə mɪkˈænɪk hˌuː wˈɜːkt æt ðə stˈeɪʃən, wʌz ðɛɹ æt ðə tˈaɪm ænd ʃiː sˈɔː ɐ wˈaɪt mˈeɪl, 81 | DUMMY1/LJ040-0002.wav|tʃˈæptɚ sˈɛvən. lˈiː hˈɑːɹvi ˈɑːswəld: bˈækɡɹaʊnd ænd pˈɑːsəbəl mˈoʊɾɪvz, pˈɑːɹt wˌʌn. 82 | DUMMY1/LJ045-0140.wav|ðɪ ˈɑːɹɡjuːmənts hiː jˈuːzd tə dʒˈʌstɪfˌaɪ hɪz jˈuːs ʌvðɪ ˈeɪliəs sədʒˈɛst ðæt ˈɑːswəld mˌeɪhɐv kˈʌm tə θˈɪŋk ðætðə hˈoʊl wˈɜːld wʌz bɪkˈʌmɪŋ ɪnvˈɑːlvd 83 | DUMMY1/LJ012-0035.wav|ðə nˈʌmbɚ ænd nˈeɪmz ˌɑːn wˈɑːtʃᵻz, wɜː kˈɛɹfəli ɹɪmˈuːvd ɔːɹ əblˈɪɾɚɹˌeɪɾᵻd ˈæftɚ ðə ɡˈʊdz pˈæst ˌaʊɾəv hɪz hˈændz. 84 | DUMMY1/LJ012-0250.wav|ɑːnðə sˈɛvənθ dʒuːlˈaɪ, eɪtˈiːn θˈɜːɾisˈɛvən, 85 | DUMMY1/LJ016-0179.wav|kəntɹˈæktᵻd wɪð ʃˈɛɹɪfs ænd kənvˈɛnɚz tə wˈɜːk baɪ ðə dʒˈɑːb. 86 | DUMMY1/LJ016-0138.wav|æɾə dˈɪstəns fɹʌmðə pɹˈɪzən. 87 | DUMMY1/LJ027-0052.wav|ðiːz pɹˈɪnsɪpəlz ʌv həmˈɑːlədʒi ɑːɹ ɪsˈɛnʃəl tʊ ɐ kɚɹˈɛkt ɪntˌɜːpɹɪtˈeɪʃən ʌvðə fˈækts ʌv mɔːɹfˈɑːlədʒi. 
88 | DUMMY1/LJ031-0134.wav|ˌɑːn wˈʌn əkˈeɪʒən mɪsˈɛs dʒˈɑːnsən, ɐkˈʌmpənɪd baɪ tˈuː sˈiːkɹət sˈɜːvɪs ˈeɪdʒənts, lˈɛft ðə ɹˈuːm tə sˈiː mɪsˈɛs kˈɛnədi ænd mɪsˈɛs kənˈæli. 89 | DUMMY1/LJ019-0273.wav|wˌɪtʃ sˌɜː dʒˈɑːʃjuːə dʒˈɛb tˈoʊld ðə kəmˈɪɾi hiː kənsˈɪdɚd ðə pɹˈɑːpɚɹ ˈɛlɪmənts ʌv pˈiːnəl dˈɪsɪplˌɪn. 90 | DUMMY1/LJ014-0110.wav|æt ðə fˈɜːst ðə bˈɑːksᵻz wɜːɹ ɪmpˈaʊndᵻd, ˈoʊpənd, ænd fˈaʊnd tə kəntˈeɪn mˈɛnɪəv oʊkˈɑːnɚz ɪfˈɛkts. 91 | DUMMY1/LJ034-0160.wav|ˌɑːn bɹˈɛnənz sˈʌbsɪkwənt sˈɜːtən aɪdˈɛntɪfɪkˈeɪʃən ʌv lˈiː hˈɑːɹvi ˈɑːswəld æz ðə mˈæn hiː sˈɔː fˈaɪɚ ðə ɹˈaɪfəl. 92 | DUMMY1/LJ038-0199.wav|ɪlˈɛvən. ɪf ˈaɪ æm ɐlˈaɪv ænd tˈeɪkən pɹˈɪzənɚ, 93 | DUMMY1/LJ014-0010.wav|jˈɛt hiː kʊd nˌɑːt ˌoʊvɚkˈʌm ðə stɹˈeɪndʒ fˌæsᵻnˈeɪʃən ɪt hˈɐd fɔːɹ hˌɪm, ænd ɹɪmˈeɪnd baɪ ðə sˈaɪd ʌvðə kˈɔːɹps tˈɪl ðə stɹˈɛtʃɚ kˈeɪm. 94 | DUMMY1/LJ033-0047.wav|ˈaɪ nˈoʊɾɪsd wɛn ˈaɪ wɛnt ˈaʊt ðætðə lˈaɪt wʌz ˈɑːn, ˈɛnd kwˈoʊt, 95 | DUMMY1/LJ040-0027.wav|hiː wʌz nˈɛvɚ sˈæɾɪsfˌaɪd wɪð ˈɛnɪθˌɪŋ. 96 | DUMMY1/LJ048-0228.wav|ænd ˈʌðɚz hˌuː wɜː pɹˈɛzənt sˈeɪ ðæt nˈoʊ ˈeɪdʒənt wʌz ɪnˈiːbɹɪˌeɪɾᵻd ɔːɹ ˈæktᵻd ɪmpɹˈɑːpɚli. 97 | DUMMY1/LJ003-0111.wav|hiː wʌz ɪn kˈɑːnsɪkwəns pˌʊt ˌaʊɾəv ðə pɹətˈɛkʃən ʌv ðɛɹ ɪntˈɜːnəl lˈɔː, ˈɛnd kwˈoʊt. ðɛɹ kˈoʊd wʌzɐ sˈʌbdʒɛkt ʌv sˌʌm kjˌʊɹɪˈɑːsɪɾi. 98 | DUMMY1/LJ008-0258.wav|lˈɛt mˌiː ɹɪtɹˈeɪs maɪ stˈɛps, ænd spˈiːk mˈoːɹ ɪn diːtˈeɪl ʌvðə tɹˈiːtmənt ʌvðə kəndˈɛmd ɪn ðoʊz blˈʌdθɜːsti ænd bɹˈuːɾəli ɪndˈɪfɹənt dˈeɪz, 99 | DUMMY1/LJ029-0022.wav|ðɪ ɚɹˈɪdʒɪnəl plˈæn kˈɔːld fɚðə pɹˈɛzɪdənt tə spˈɛnd ˈoʊnli wˈʌn dˈeɪ ɪnðə stˈeɪt, mˌeɪkɪŋ wˈɜːlwɪnd vˈɪzɪts tə dˈæləs, fˈɔːɹt wˈɜːθ, sˌæn æntˈoʊnɪˌoʊ, ænd hjˈuːstən. 100 | DUMMY1/LJ004-0045.wav|mˈɪstɚ stˈɜːdʒᵻz bˈoːɹn, sˌɜː dʒˈeɪmz mˈækɪntˌɑːʃ, sˌɜː dʒˈeɪmz skˈɑːɹlɪt, ænd wˈɪljəm wˈɪlbɚfˌoːɹs. 
101 | -------------------------------------------------------------------------------- /filelists/vctk_audio_sid_text_val_filelist.txt: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|It had happened to him. 2 | DUMMY2/p280/p280_148.wav|52|It is open season on the Old Firm. 3 | DUMMY2/p231/p231_320.wav|50|However, he is a coach, and he remains a coach at heart. 4 | DUMMY2/p282/p282_129.wav|83|It is not a U-turn. 5 | DUMMY2/p254/p254_015.wav|41|The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain. 6 | DUMMY2/p228/p228_285.wav|57|The songs are just so good. 7 | DUMMY2/p334/p334_307.wav|38|If they don't, they can expect their funding to be cut. 8 | DUMMY2/p287/p287_081.wav|77|I've never seen anything like it. 9 | DUMMY2/p247/p247_083.wav|14|It is a job creation scheme.) 10 | DUMMY2/p264/p264_051.wav|65|We were leading by two goals.) 11 | DUMMY2/p335/p335_058.wav|49|Let's see that increase over the years. 12 | DUMMY2/p236/p236_225.wav|75|There is no quick fix. 13 | DUMMY2/p374/p374_353.wav|11|And that brings us to the point. 14 | DUMMY2/p272/p272_076.wav|69|Sounds like The Sixth Sense? 15 | DUMMY2/p271/p271_152.wav|27|The petition was formally presented at Downing Street yesterday. 16 | DUMMY2/p228/p228_127.wav|57|They've got to account for it. 17 | DUMMY2/p276/p276_223.wav|106|It's been a humbling year. 18 | DUMMY2/p262/p262_248.wav|45|The project has already secured the support of Sir Sean Connery. 19 | DUMMY2/p314/p314_086.wav|51|The team this year is going places. 20 | DUMMY2/p225/p225_038.wav|101|Diving is no part of football. 21 | DUMMY2/p279/p279_088.wav|25|The shareholders will vote to wind up the company on Friday morning. 22 | DUMMY2/p272/p272_018.wav|69|Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain. 23 | DUMMY2/p256/p256_098.wav|90|She told The Herald. 
24 | DUMMY2/p261/p261_218.wav|100|All will be revealed in due course. 25 | DUMMY2/p265/p265_063.wav|73|IT shouldn't come as a surprise, but it does. 26 | DUMMY2/p314/p314_042.wav|51|It is all about people being assaulted, abused. 27 | DUMMY2/p241/p241_188.wav|86|I wish I could say something. 28 | DUMMY2/p283/p283_111.wav|95|It's good to have a voice. 29 | DUMMY2/p275/p275_006.wav|40|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. 30 | DUMMY2/p228/p228_092.wav|57|Today I couldn't run on it. 31 | DUMMY2/p295/p295_343.wav|92|The atmosphere is businesslike. 32 | DUMMY2/p228/p228_187.wav|57|They will run a mile. 33 | DUMMY2/p294/p294_317.wav|104|It didn't put me off. 34 | DUMMY2/p231/p231_445.wav|50|It sounded like a bomb. 35 | DUMMY2/p272/p272_086.wav|69|Today she has been released. 36 | DUMMY2/p255/p255_210.wav|31|It was worth a photograph. 37 | DUMMY2/p229/p229_060.wav|67|And a film maker was born. 38 | DUMMY2/p260/p260_232.wav|81|The Home Office would not release any further details about the group. 39 | DUMMY2/p245/p245_025.wav|59|Johnson was pretty low. 40 | DUMMY2/p333/p333_185.wav|64|This area is perfect for children. 41 | DUMMY2/p244/p244_242.wav|78|He is a man of the people. 42 | DUMMY2/p376/p376_187.wav|71|"It is a terrible loss." 43 | DUMMY2/p239/p239_156.wav|48|It is a good lifestyle. 44 | DUMMY2/p307/p307_037.wav|22|He released a half-dozen solo albums. 45 | DUMMY2/p305/p305_185.wav|54|I am not even thinking about that. 46 | DUMMY2/p272/p272_081.wav|69|It was magic. 47 | DUMMY2/p302/p302_297.wav|30|I'm trying to stay open on that. 48 | DUMMY2/p275/p275_320.wav|40|We are in the end game. 49 | DUMMY2/p239/p239_231.wav|48|Then we will face the Danish champions. 50 | DUMMY2/p268/p268_301.wav|87|It was only later that the condition was diagnosed. 51 | DUMMY2/p336/p336_088.wav|98|They failed to reach agreement yesterday. 52 | DUMMY2/p278/p278_255.wav|10|They made such decisions in London. 
53 | DUMMY2/p361/p361_132.wav|79|That got me out. 54 | DUMMY2/p307/p307_146.wav|22|You hope he prevails. 55 | DUMMY2/p244/p244_147.wav|78|They could not ignore the will of parliament, he claimed. 56 | DUMMY2/p294/p294_283.wav|104|This is our unfinished business. 57 | DUMMY2/p283/p283_300.wav|95|I would have the hammer in the crowd. 58 | DUMMY2/p239/p239_079.wav|48|I can understand the frustrations of our fans. 59 | DUMMY2/p264/p264_009.wav|65|There is , according to legend, a boiling pot of gold at one end. ) 60 | DUMMY2/p307/p307_348.wav|22|He did not oppose the divorce. 61 | DUMMY2/p304/p304_308.wav|72|We are the gateway to justice. 62 | DUMMY2/p281/p281_056.wav|36|None has ever been found. 63 | DUMMY2/p267/p267_158.wav|0|We were given a warm and friendly reception. 64 | DUMMY2/p300/p300_169.wav|102|Who do these people think they are? 65 | DUMMY2/p276/p276_177.wav|106|They exist in name alone. 66 | DUMMY2/p228/p228_245.wav|57|It is a policy which has the full support of the minister. 67 | DUMMY2/p300/p300_303.wav|102|I'm wondering what you feel about the youngest. 68 | DUMMY2/p362/p362_247.wav|15|This would give Scotland around eight members. 69 | DUMMY2/p326/p326_031.wav|28|United were in control without always being dominant. 70 | DUMMY2/p361/p361_288.wav|79|I did not think it was very proper. 71 | DUMMY2/p286/p286_145.wav|63|Tiger is not the norm. 72 | DUMMY2/p234/p234_071.wav|3|She did that for the rest of her life. 73 | DUMMY2/p263/p263_296.wav|39|The decision was announced at its annual conference in Dunfermline. 74 | DUMMY2/p323/p323_228.wav|34|She became a heroine of my childhood. 75 | DUMMY2/p280/p280_346.wav|52|It was a bit like having children. 76 | DUMMY2/p333/p333_080.wav|64|But the tragedy did not stop there. 77 | DUMMY2/p226/p226_268.wav|43|That decision is for the British Parliament and people. 78 | DUMMY2/p362/p362_314.wav|15|Is that right? 79 | DUMMY2/p240/p240_047.wav|93|It is so sad. 80 | DUMMY2/p250/p250_207.wav|24|You could feel the heat. 
81 | DUMMY2/p273/p273_176.wav|56|Neither side would reveal the details of the offer. 82 | DUMMY2/p316/p316_147.wav|85|And frankly, it's been a while. 83 | DUMMY2/p265/p265_047.wav|73|It is unique. 84 | DUMMY2/p336/p336_353.wav|98|Sometimes you get them, sometimes you don't. 85 | DUMMY2/p230/p230_376.wav|35|This hasn't happened in a vacuum. 86 | DUMMY2/p308/p308_209.wav|107|There is great potential on this river. 87 | DUMMY2/p250/p250_442.wav|24|We have not yet received a letter from the Irish. 88 | DUMMY2/p260/p260_037.wav|81|It's a fact. 89 | DUMMY2/p299/p299_345.wav|58|We're very excited and challenged by the project. 90 | DUMMY2/p269/p269_218.wav|94|A Grampian Police spokesman said. 91 | DUMMY2/p306/p306_014.wav|12|To the Hebrews it was a token that there would be no more universal floods. 92 | DUMMY2/p271/p271_292.wav|27|It's a record label, not a form of music. 93 | DUMMY2/p247/p247_225.wav|14|I am considered a teenager.) 94 | DUMMY2/p294/p294_094.wav|104|It should be a condition of employment. 95 | DUMMY2/p269/p269_031.wav|94|Is this accurate? 96 | DUMMY2/p275/p275_116.wav|40|It's not fair. 97 | DUMMY2/p265/p265_006.wav|73|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. 98 | DUMMY2/p285/p285_072.wav|2|Mr Irvine said Mr Rafferty was now in good spirits. 99 | DUMMY2/p270/p270_167.wav|8|We did what we had to do. 100 | DUMMY2/p360/p360_397.wav|60|It is a relief. 101 | -------------------------------------------------------------------------------- /filelists/vctk_audio_sid_text_val_filelist.txt.cleaned: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm. 2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm. 3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt. 4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn. 
5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn. 6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd. 7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt. 8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt. 9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm. 10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz. 11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz. 12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks. 13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt. 14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns? 15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ. 16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt. 17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ. 18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi. 19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz. 20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl. 21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ. 22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn. 23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld. 24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs. 25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz. 26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd. 27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ. 28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs. 29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 
30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt. 31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk. 32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl. 33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf. 34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm. 35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd. 36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf. 37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn. 38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp. 39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ. 40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən. 41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl. 42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs." 43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl. 44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz. 45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt. 46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk. 47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt. 48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm. 49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz. 50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd. 51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ. 52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən. 53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt. 54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz. 55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd. 56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs. 57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd. 58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz. 
59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd. 60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs. 61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs. 62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd. 63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən. 64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ? 65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn. 66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ. 67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst. 68 | DUMMY2/p362/p362_247.wav|15|ðɪs wʊd ɡˈɪv skˈɑːtlənd ɐɹˈaʊnd ˈeɪt mˈɛmbɚz. 69 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt. 70 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ. 71 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm. 72 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf. 73 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn. 74 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd. 75 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən. 76 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ. 77 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl. 78 | DUMMY2/p362/p362_314.wav|15|ɪz ðæt ɹˈaɪt? 79 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd. 80 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt. 81 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ. 82 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl. 83 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk. 84 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt. 85 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm. 
86 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ. 87 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ. 88 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt. 89 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt. 90 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd. 91 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz. 92 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk. 93 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ. 94 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt. 95 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət? 96 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ. 97 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 98 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts. 99 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː. 100 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf. 101 | -------------------------------------------------------------------------------- /filelists/vctk_audio_sid_text_val_filelist_new.txt.cleaned: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm. 2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm. 3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt. 4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn. 5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn. 6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd. 7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt. 8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt. 
9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm. 10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz. 11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz. 12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks. 13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt. 14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns? 15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ. 16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt. 17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ. 18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi. 19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz. 20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl. 21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ. 22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn. 23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld. 24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs. 25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz. 26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd. 27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ. 28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs. 29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt. 31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk. 32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl. 33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf. 34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm. 35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd. 
36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf. 37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn. 38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp. 39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ. 40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən. 41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl. 42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs." 43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl. 44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz. 45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt. 46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk. 47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt. 48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm. 49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz. 50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd. 51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ. 52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən. 53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt. 54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz. 55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd. 56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs. 57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd. 58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz. 59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd. 60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs. 61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs. 62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd. 63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən. 
64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ? 65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn. 66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ. 67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst. 68 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt. 69 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ. 70 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm. 71 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf. 72 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn. 73 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd. 74 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən. 75 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ. 76 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl. 77 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd. 78 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt. 79 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ. 80 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl. 81 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk. 82 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt. 83 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm. 84 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ. 85 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ. 86 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt. 87 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt. 88 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd. 89 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz. 
90 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk. 91 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ. 92 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt. 93 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət? 94 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ. 95 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 96 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts. 97 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː. 98 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf. 99 | -------------------------------------------------------------------------------- /infer_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import onnxruntime 5 | import torch 6 | from scipy.io.wavfile import write 7 | 8 | import commons 9 | import utils 10 | from text import text_to_sequence 11 | 12 | 13 | def get_text(text, hps): 14 | text_norm = text_to_sequence(text, hps.data.text_cleaners) 15 | if hps.data.add_blank: 16 | text_norm = commons.intersperse(text_norm, 0) 17 | text_norm = torch.LongTensor(text_norm) 18 | return text_norm 19 | 20 | 21 | def main() -> None: 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--model", required=True, help="Path to model (.onnx)") 24 | parser.add_argument( 25 | "--config-path", required=True, help="Path to model config (.json)" 26 | ) 27 | parser.add_argument( 28 | "--output-wav-path", required=True, help="Path to write WAV file" 29 | ) 30 | parser.add_argument("--text", required=True, type=str, help="Text to synthesize") 31 | parser.add_argument("--sid", required=False, type=int, help="Speaker ID to synthesize") 32 | args = parser.parse_args() 33 | 34 | sess_options = onnxruntime.SessionOptions() 35 | model = onnxruntime.InferenceSession(str(args.model), 
sess_options=sess_options, providers=["CPUExecutionProvider"]) 36 | 37 | hps = utils.get_hparams_from_file(args.config_path) 38 | 39 | phoneme_ids = get_text(args.text, hps) 40 | text = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0) 41 | text_lengths = np.array([text.shape[1]], dtype=np.int64) 42 | scales = np.array([0.667, 1.0, 0.8], dtype=np.float32) 43 | sid = np.array([int(args.sid)]) if args.sid is not None else None 44 | 45 | audio = model.run( 46 | None, 47 | { 48 | "input": text, 49 | "input_lengths": text_lengths, 50 | "scales": scales, 51 | "sid": sid, 52 | }, 53 | )[0].squeeze((0, 1)) 54 | 55 | write(data=audio, rate=hps.data.sampling_rate, filename=args.output_wav_path) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d1097e33-73b2-4001-a76f-782c0cd17644", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "%matplotlib inline\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import IPython.display as ipd\n", 15 | "\n", 16 | "import os\n", 17 | "import json\n", 18 | "import math\n", 19 | "import torch\n", 20 | "from torch import nn\n", 21 | "from torch.nn import functional as F\n", 22 | "from torch.utils.data import DataLoader\n", 23 | "\n", 24 | "import commons\n", 25 | "import utils\n", 26 | "from data_utils import (\n", 27 | " TextAudioLoader,\n", 28 | " TextAudioCollate,\n", 29 | " TextAudioSpeakerLoader,\n", 30 | " TextAudioSpeakerCollate,\n", 31 | ")\n", 32 | "from models import SynthesizerTrn\n", 33 | "from text.symbols import symbols\n", 34 | "from text import text_to_sequence\n", 35 | "\n", 36 | "from scipy.io.wavfile import write\n", 37 | "\n", 38 | "\n", 39 | "def get_text(text, hps):\n", 40 | " text_norm = 
text_to_sequence(text, hps.data.text_cleaners)\n", 41 | " if hps.data.add_blank:\n", 42 | " text_norm = commons.intersperse(text_norm, 0)\n", 43 | " text_norm = torch.LongTensor(text_norm)\n", 44 | " return text_norm" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "9b653abf-4b03-47d6-80c3-8a4405ba9e56", 51 | "metadata": { 52 | "tags": [] 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "device = torch.device(\"cpu\")" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "id": "ec35f535-6d34-467c-bf33-aa7c1e673f34", 62 | "metadata": { 63 | "jp-MarkdownHeadingCollapsed": true, 64 | "tags": [] 65 | }, 66 | "source": [ 67 | "# LJSpeech" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "eb5302d7-7e4b-41c2-a39e-4fec7e46403c", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "hps = utils.get_hparams_from_file(\"./configs/vits2_ljs_base.json\")" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "074d0c26-5a06-4878-b8da-361f68bd45c5", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "net_g = SynthesizerTrn(\n", 88 | " len(symbols),\n", 89 | " hps.data.filter_length // 2 + 1,\n", 90 | " hps.train.segment_size // hps.data.hop_length,\n", 91 | " **hps.model).to(device)\n", 92 | "_ = net_g.eval()\n", 93 | "\n", 94 | "_ = utils.load_checkpoint(\"/path/to/pretrained_ljs.pth\", net_g, None)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "196962fc-bd97-4b00-9194-e7c0d899e69a", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "stn_tst = get_text(\"VITS is Awesome!\", hps)\n", 105 | "with torch.no_grad():\n", 106 | " x_tst = stn_tst.to(device).unsqueeze(0)\n", 107 | " x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)\n", 108 | " audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n", 109 |
"ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "267ffb0b-d2f0-4e1a-9450-5fc593075810", 115 | "metadata": {}, 116 | "source": [ 117 | "# VCTK" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "51a2ddc8-3cf2-459c-837f-a3fb08261b7d", 124 | "metadata": { 125 | "tags": [] 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "hps = utils.get_hparams_from_file(\"./configs/vits2_vctk_base.json\")" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "8a5d8345-cae7-4137-a700-0474a846112f", 136 | "metadata": { 137 | "tags": [] 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "if hps.model.use_mel_posterior_encoder == True:\n", 142 | " print(\"Using mel posterior encoder for VITS2\")\n", 143 | " posterior_channels = 80 # vits2\n", 144 | " hps.data.use_mel_posterior_encoder = True\n", 145 | "else:\n", 146 | " print(\"Using lin posterior encoder for VITS1\")\n", 147 | " posterior_channels = hps.data.filter_length // 2 + 1\n", 148 | " hps.data.use_mel_posterior_encoder = False\n", 149 | "\n", 150 | "net_g = SynthesizerTrn(\n", 151 | " len(symbols),\n", 152 | " posterior_channels,\n", 153 | " None,\n", 154 | " n_speakers=hps.data.n_speakers,\n", 155 | " **hps.model,\n", 156 | ").to(device)\n", 157 | "_ = net_g.eval()\n", 158 | "\n", 159 | "_ = utils.load_checkpoint(\"/path/to/the/pretrained.pth\", net_g, None)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "id": "256a6792-da17-4f86-8b67-f7469200a252", 166 | "metadata": { 167 | "tags": [] 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "text = \"\"\"VITS2 is Awesome!\"\"\"\n", 172 | "sid = 4\n", 173 | "\n", 174 | "stn_tst = get_text(text, hps)\n", 175 | "with torch.no_grad():\n", 176 | " x_tst = stn_tst.to(device).unsqueeze(0)\n", 177 | " x_tst_lengths =
torch.LongTensor([stn_tst.size(0)]).to(device)\n", 178 | " sid = torch.LongTensor([int(sid)]).to(device)\n", 179 | " audio = (\n", 180 | " net_g.infer(\n", 181 | " x_tst,\n", 182 | " x_tst_lengths,\n", 183 | " sid=sid,\n", 184 | " noise_scale=0.667,\n", 185 | " noise_scale_w=0.8,\n", 186 | " length_scale=1,\n", 187 | " )[0][0, 0]\n", 188 | " .data.cpu()\n", 189 | " .float()\n", 190 | " .numpy()\n", 191 | " )\n", 192 | "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "7f0753ef-8103-4ef5-b431-3067a86e92b7", 198 | "metadata": {}, 199 | "source": [ 200 | "# Voice Conversion" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "id": "92aaa17e-19fa-4cc9-935b-02a6a4f897de", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)\n", 211 | "collate_fn = TextAudioSpeakerCollate()\n", 212 | "loader = DataLoader(dataset, num_workers=0, shuffle=False,\n", 213 | " batch_size=1, pin_memory=False,\n", 214 | " drop_last=True, collate_fn=collate_fn)\n", 215 | "data_list = list(loader)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "0bb98f74-78ed-4ea1-894d-120af96335c3", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "with torch.no_grad():\n", 226 | " x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.to(device) for x in data_list[0]]\n", 227 | " sid_tgt1 = torch.LongTensor([1]).to(device)\n", 228 | " sid_tgt2 = torch.LongTensor([2]).to(device)\n", 229 | " sid_tgt3 = torch.LongTensor([4]).to(device)\n", 230 | " audio1 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0,0].data.cpu().float().numpy()\n", 231 | " audio2 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt2)[0][0,0].data.cpu().float().numpy()\n", 232 | " audio3 = 
net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt3)[0][0,0].data.cpu().float().numpy()\n", 233 | "print(\"Original SID: %d\" % sid_src.item())\n", 234 | "ipd.display(ipd.Audio(y[0].cpu().numpy(), rate=hps.data.sampling_rate, normalize=False))\n", 235 | "print(\"Converted SID: %d\" % sid_tgt1.item())\n", 236 | "ipd.display(ipd.Audio(audio1, rate=hps.data.sampling_rate, normalize=False))\n", 237 | "print(\"Converted SID: %d\" % sid_tgt2.item())\n", 238 | "ipd.display(ipd.Audio(audio2, rate=hps.data.sampling_rate, normalize=False))\n", 239 | "print(\"Converted SID: %d\" % sid_tgt3.item())\n", 240 | "ipd.display(ipd.Audio(audio3, rate=hps.data.sampling_rate, normalize=False))" 241 | ] 242 | } 243 | ], 244 | "metadata": { 245 | "kernelspec": { 246 | "display_name": "Python 3 (ipykernel)", 247 | "language": "python", 248 | "name": "python3" 249 | }, 250 | "language_info": { 251 | "codemirror_mode": { 252 | "name": "ipython", 253 | "version": 3 254 | }, 255 | "file_extension": ".py", 256 | "mimetype": "text/x-python", 257 | "name": "python", 258 | "nbconvert_exporter": "python", 259 | "pygments_lexer": "ipython3", 260 | "version": "3.10.8" 261 | } 262 | }, 263 | "nbformat": 4, 264 | "nbformat_minor": 5 265 | } 266 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | ## LJSpeech 2 | import torch 3 | 4 | import commons 5 | import utils 6 | from models import SynthesizerTrn 7 | from text.symbols import symbols 8 | from text import text_to_sequence 9 | 10 | from scipy.io.wavfile import write 11 | 12 | 13 | def get_text(text, hps): 14 | text_norm = text_to_sequence(text, hps.data.text_cleaners) 15 | if hps.data.add_blank: 16 | text_norm = commons.intersperse(text_norm, 0) 17 | text_norm = torch.LongTensor(text_norm) 18 | return text_norm 19 | 20 | 21 | CONFIG_PATH = "./configs/vits2_ljs_nosdp.json" 22 | MODEL_PATH = 
"./logs/G_114000.pth" 23 | TEXT = "VITS-2 is Awesome!" 24 | OUTPUT_WAV_PATH = "sample_vits2.wav" 25 | 26 | hps = utils.get_hparams_from_file(CONFIG_PATH) 27 | 28 | if ( 29 | "use_mel_posterior_encoder" in hps.model.keys() 30 | and hps.model.use_mel_posterior_encoder == True 31 | ): 32 | print("Using mel posterior encoder for VITS2") 33 | posterior_channels = 80 # vits2 34 | hps.data.use_mel_posterior_encoder = True 35 | else: 36 | print("Using lin posterior encoder for VITS1") 37 | posterior_channels = hps.data.filter_length // 2 + 1 38 | hps.data.use_mel_posterior_encoder = False 39 | 40 | net_g = SynthesizerTrn( 41 | len(symbols), 42 | posterior_channels, 43 | hps.train.segment_size // hps.data.hop_length, 44 | **hps.model 45 | ).cuda() 46 | _ = net_g.eval() 47 | 48 | _ = utils.load_checkpoint(MODEL_PATH, net_g, None) 49 | 50 | stn_tst = get_text(TEXT, hps) 51 | with torch.no_grad(): 52 | x_tst = stn_tst.cuda().unsqueeze(0) 53 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() 54 | audio = ( 55 | net_g.infer( 56 | x_tst, x_tst_lengths, noise_scale=0.667, noise_scale_w=0.8, length_scale=1 57 | )[0][0, 0] 58 | .data.cpu() 59 | .float() 60 | .numpy() 61 | ) 62 | 63 | write(data=audio, rate=hps.data.sampling_rate, filename=OUTPUT_WAV_PATH) 64 | -------------------------------------------------------------------------------- /inference_ms.py: -------------------------------------------------------------------------------- 1 | ## VCTK 2 | import torch 3 | 4 | import commons 5 | import utils 6 | from models import SynthesizerTrn 7 | from text.symbols import symbols 8 | from text import text_to_sequence 9 | 10 | from scipy.io.wavfile import write 11 | 12 | 13 | def get_text(text, hps): 14 | text_norm = text_to_sequence(text, hps.data.text_cleaners) 15 | if hps.data.add_blank: 16 | text_norm = commons.intersperse(text_norm, 0) 17 | text_norm = torch.LongTensor(text_norm) 18 | return text_norm 19 | 20 | 21 | CONFIG_PATH = "./configs/vits2_vctk_base.json" 22 | 
MODEL_PATH = "/path/to/pretrained_vctk.pth" 23 | TEXT = "VITS-2 is Awesome!" 24 | SPK_ID = 4 25 | OUTPUT_WAV_PATH = "sample_vits2_ms.wav" 26 | 27 | hps = utils.get_hparams_from_file(CONFIG_PATH) 28 | 29 | if ( 30 | "use_mel_posterior_encoder" in hps.model.keys() 31 | and hps.model.use_mel_posterior_encoder == True 32 | ): 33 | print("Using mel posterior encoder for VITS2") 34 | posterior_channels = 80 # vits2 35 | hps.data.use_mel_posterior_encoder = True 36 | else: 37 | print("Using lin posterior encoder for VITS1") 38 | posterior_channels = hps.data.filter_length // 2 + 1 39 | hps.data.use_mel_posterior_encoder = False 40 | 41 | net_g = SynthesizerTrn( 42 | len(symbols), 43 | posterior_channels, 44 | hps.train.segment_size // hps.data.hop_length, 45 | n_speakers=hps.data.n_speakers, 46 | **hps.model 47 | ).cuda() 48 | _ = net_g.eval() 49 | 50 | _ = utils.load_checkpoint(MODEL_PATH, net_g, None) 51 | 52 | stn_tst = get_text(TEXT, hps) 53 | with torch.no_grad(): 54 | x_tst = stn_tst.cuda().unsqueeze(0) 55 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() 56 | sid = torch.LongTensor([SPK_ID]).cuda() 57 | audio = ( 58 | net_g.infer( 59 | x_tst, 60 | x_tst_lengths, 61 | sid=sid, 62 | noise_scale=0.667, 63 | noise_scale_w=0.8, 64 | length_scale=1, 65 | )[0][0, 0] 66 | .data.cpu() 67 | .float() 68 | .numpy() 69 | ) 70 | 71 | write(data=audio, rate=hps.data.sampling_rate, filename=OUTPUT_WAV_PATH) 72 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import commons 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def discriminator_loss(disc_real_outputs, 
disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1 - dr) ** 2) 26 | g_loss = torch.mean(dg**2) 27 | loss += r_loss + g_loss 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1 - dg) ** 2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = z_mask.float() 56 | 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | # warnings.simplefilter(action='ignore', category=FutureWarning) 4 | warnings.filterwarnings(action="ignore") 5 | 6 | import math 7 | import os 8 | import random 9 | 10 | import librosa 11 | import librosa.util as librosa_util 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils.data 16 | from librosa.filters import mel as librosa_mel_fn 17 | from librosa.util import normalize, pad_center, tiny 18 | from packaging import version 19 | from scipy.io.wavfile import read 20 | from scipy.signal import get_window 21 | from torch import nn 22 | 23 | MAX_WAV_VALUE = 32768.0 24 | 25 | 26 | def dynamic_range_compression_torch(x, C=1, 
clip_val=1e-5): 27 | """ 28 | PARAMS 29 | ------ 30 | C: compression factor 31 | """ 32 | return torch.log(torch.clamp(x, min=clip_val) * C) 33 | 34 | 35 | def dynamic_range_decompression_torch(x, C=1): 36 | """ 37 | PARAMS 38 | ------ 39 | C: compression factor used to compress 40 | """ 41 | return torch.exp(x) / C 42 | 43 | 44 | def spectral_normalize_torch(magnitudes): 45 | output = dynamic_range_compression_torch(magnitudes) 46 | return output 47 | 48 | 49 | def spectral_de_normalize_torch(magnitudes): 50 | output = dynamic_range_decompression_torch(magnitudes) 51 | return output 52 | 53 | 54 | mel_basis = {} 55 | hann_window = {} 56 | 57 | 58 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 59 | if torch.min(y) < -1.0: 60 | print("min value is ", torch.min(y)) 61 | if torch.max(y) > 1.0: 62 | print("max value is ", torch.max(y)) 63 | 64 | global hann_window 65 | dtype_device = str(y.dtype) + "_" + str(y.device) 66 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 67 | if wnsize_dtype_device not in hann_window: 68 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 69 | dtype=y.dtype, device=y.device 70 | ) 71 | 72 | y = torch.nn.functional.pad( 73 | y.unsqueeze(1), 74 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 75 | mode="reflect", 76 | ) 77 | y = y.squeeze(1) 78 | 79 | if version.parse(torch.__version__) >= version.parse("2"): 80 | spec = torch.stft( 81 | y, 82 | n_fft, 83 | hop_length=hop_size, 84 | win_length=win_size, 85 | window=hann_window[wnsize_dtype_device], 86 | center=center, 87 | pad_mode="reflect", 88 | normalized=False, 89 | onesided=True, 90 | return_complex=False, 91 | ) 92 | else: 93 | spec = torch.stft( 94 | y, 95 | n_fft, 96 | hop_length=hop_size, 97 | win_length=win_size, 98 | window=hann_window[wnsize_dtype_device], 99 | center=center, 100 | pad_mode="reflect", 101 | normalized=False, 102 | onesided=True, 103 | ) 104 | 105 | spec = torch.sqrt(spec.pow(2).sum(-1) 
+ 1e-6) 106 | return spec 107 | 108 | 109 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 110 | global mel_basis 111 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 112 | fmax_dtype_device = str(fmax) + "_" + dtype_device 113 | if fmax_dtype_device not in mel_basis: 114 | mel = librosa_mel_fn( 115 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 116 | ) 117 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 118 | dtype=spec.dtype, device=spec.device 119 | ) 120 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 121 | spec = spectral_normalize_torch(spec) 122 | return spec 123 | 124 | 125 | def mel_spectrogram_torch( 126 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 127 | ): 128 | if torch.min(y) < -1.0: 129 | print("min value is ", torch.min(y)) 130 | if torch.max(y) > 1.0: 131 | print("max value is ", torch.max(y)) 132 | 133 | global mel_basis, hann_window 134 | dtype_device = str(y.dtype) + "_" + str(y.device) 135 | fmax_dtype_device = str(fmax) + "_" + dtype_device 136 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 137 | if fmax_dtype_device not in mel_basis: 138 | mel = librosa_mel_fn( 139 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 140 | ) 141 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 142 | dtype=y.dtype, device=y.device 143 | ) 144 | if wnsize_dtype_device not in hann_window: 145 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 146 | dtype=y.dtype, device=y.device 147 | ) 148 | 149 | y = torch.nn.functional.pad( 150 | y.unsqueeze(1), 151 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 152 | mode="reflect", 153 | ) 154 | y = y.squeeze(1) 155 | 156 | if version.parse(torch.__version__) >= version.parse("2"): 157 | spec = torch.stft( 158 | y, 159 | n_fft, 160 | hop_length=hop_size, 161 | win_length=win_size, 162 | window=hann_window[wnsize_dtype_device], 163 | 
center=center, 164 | pad_mode="reflect", 165 | normalized=False, 166 | onesided=True, 167 | return_complex=False, 168 | ) 169 | else: 170 | spec = torch.stft( 171 | y, 172 | n_fft, 173 | hop_length=hop_size, 174 | win_length=win_size, 175 | window=hann_window[wnsize_dtype_device], 176 | center=center, 177 | pad_mode="reflect", 178 | normalized=False, 179 | onesided=True, 180 | ) 181 | 182 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 183 | 184 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 185 | spec = spectral_normalize_torch(spec) 186 | 187 | return spec 188 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | import commons 13 | from commons import init_weights, get_padding 14 | from transforms import piecewise_rational_quadratic_transform 15 | 16 | 17 | LRELU_SLOPE = 0.1 18 | 19 | 20 | class LayerNorm(nn.Module): 21 | def __init__(self, channels, eps=1e-5): 22 | super().__init__() 23 | self.channels = channels 24 | self.eps = eps 25 | 26 | self.gamma = nn.Parameter(torch.ones(channels)) 27 | self.beta = nn.Parameter(torch.zeros(channels)) 28 | 29 | def forward(self, x): 30 | x = x.transpose(1, -1) 31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 32 | return x.transpose(1, -1) 33 | 34 | 35 | class ConvReluNorm(nn.Module): 36 | def __init__( 37 | self, 38 | in_channels, 39 | hidden_channels, 40 | out_channels, 41 | kernel_size, 42 | n_layers, 43 | p_dropout, 44 | ): 45 | super().__init__() 46 | self.in_channels = in_channels 47 | self.hidden_channels = hidden_channels 48 | self.out_channels = out_channels 
49 | self.kernel_size = kernel_size 50 | self.n_layers = n_layers 51 | self.p_dropout = p_dropout 52 | assert n_layers > 1, "Number of layers should be larger than 1." 53 | 54 | self.conv_layers = nn.ModuleList() 55 | self.norm_layers = nn.ModuleList() 56 | self.conv_layers.append( 57 | nn.Conv1d( 58 | in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 59 | ) 60 | ) 61 | self.norm_layers.append(LayerNorm(hidden_channels)) 62 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) 63 | for _ in range(n_layers - 1): 64 | self.conv_layers.append( 65 | nn.Conv1d( 66 | hidden_channels, 67 | hidden_channels, 68 | kernel_size, 69 | padding=kernel_size // 2, 70 | ) 71 | ) 72 | self.norm_layers.append(LayerNorm(hidden_channels)) 73 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 74 | self.proj.weight.data.zero_() 75 | self.proj.bias.data.zero_() 76 | 77 | def forward(self, x, x_mask): 78 | x_org = x 79 | for i in range(self.n_layers): 80 | x = self.conv_layers[i](x * x_mask) 81 | x = self.norm_layers[i](x) 82 | x = self.relu_drop(x) 83 | x = x_org + self.proj(x) 84 | return x * x_mask 85 | 86 | 87 | class DDSConv(nn.Module): 88 | """ 89 | Dilated and Depth-Separable Convolution 90 | """ 91 | 92 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): 93 | super().__init__() 94 | self.channels = channels 95 | self.kernel_size = kernel_size 96 | self.n_layers = n_layers 97 | self.p_dropout = p_dropout 98 | 99 | self.drop = nn.Dropout(p_dropout) 100 | self.convs_sep = nn.ModuleList() 101 | self.convs_1x1 = nn.ModuleList() 102 | self.norms_1 = nn.ModuleList() 103 | self.norms_2 = nn.ModuleList() 104 | for i in range(n_layers): 105 | dilation = kernel_size**i 106 | padding = (kernel_size * dilation - dilation) // 2 107 | self.convs_sep.append( 108 | nn.Conv1d( 109 | channels, 110 | channels, 111 | kernel_size, 112 | groups=channels, 113 | dilation=dilation, 114 | padding=padding, 115 | ) 116 | ) 117 | 
self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 118 | self.norms_1.append(LayerNorm(channels)) 119 | self.norms_2.append(LayerNorm(channels)) 120 | 121 | def forward(self, x, x_mask, g=None): 122 | if g is not None: 123 | x = x + g 124 | for i in range(self.n_layers): 125 | y = self.convs_sep[i](x * x_mask) 126 | y = self.norms_1[i](y) 127 | y = F.gelu(y) 128 | y = self.convs_1x1[i](y) 129 | y = self.norms_2[i](y) 130 | y = F.gelu(y) 131 | y = self.drop(y) 132 | x = x + y 133 | return x * x_mask 134 | 135 | 136 | class WN(torch.nn.Module): 137 | def __init__( 138 | self, 139 | hidden_channels, 140 | kernel_size, 141 | dilation_rate, 142 | n_layers, 143 | gin_channels=0, 144 | p_dropout=0, 145 | ): 146 | super(WN, self).__init__() 147 | assert kernel_size % 2 == 1 148 | self.hidden_channels = hidden_channels 149 | self.kernel_size = (kernel_size,) 150 | self.dilation_rate = dilation_rate 151 | self.n_layers = n_layers 152 | self.gin_channels = gin_channels 153 | self.p_dropout = p_dropout 154 | 155 | self.in_layers = torch.nn.ModuleList() 156 | self.res_skip_layers = torch.nn.ModuleList() 157 | self.drop = nn.Dropout(p_dropout) 158 | 159 | if gin_channels != 0: 160 | cond_layer = torch.nn.Conv1d( 161 | gin_channels, 2 * hidden_channels * n_layers, 1 162 | ) 163 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") 164 | 165 | for i in range(n_layers): 166 | dilation = dilation_rate**i 167 | padding = int((kernel_size * dilation - dilation) / 2) 168 | in_layer = torch.nn.Conv1d( 169 | hidden_channels, 170 | 2 * hidden_channels, 171 | kernel_size, 172 | dilation=dilation, 173 | padding=padding, 174 | ) 175 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") 176 | self.in_layers.append(in_layer) 177 | 178 | # last one is not necessary 179 | if i < n_layers - 1: 180 | res_skip_channels = 2 * hidden_channels 181 | else: 182 | res_skip_channels = hidden_channels 183 | 184 | res_skip_layer = torch.nn.Conv1d(hidden_channels, 
res_skip_channels, 1) 185 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 186 | self.res_skip_layers.append(res_skip_layer) 187 | 188 | def forward(self, x, x_mask, g=None, **kwargs): 189 | output = torch.zeros_like(x) 190 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 191 | 192 | if g is not None: 193 | g = self.cond_layer(g) 194 | 195 | for i in range(self.n_layers): 196 | x_in = self.in_layers[i](x) 197 | if g is not None: 198 | cond_offset = i * 2 * self.hidden_channels 199 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 200 | else: 201 | g_l = torch.zeros_like(x_in) 202 | 203 | acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 204 | acts = self.drop(acts) 205 | 206 | res_skip_acts = self.res_skip_layers[i](acts) 207 | if i < self.n_layers - 1: 208 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 209 | x = (x + res_acts) * x_mask 210 | output = output + res_skip_acts[:, self.hidden_channels :, :] 211 | else: 212 | output = output + res_skip_acts 213 | return output * x_mask 214 | 215 | def remove_weight_norm(self): 216 | if self.gin_channels != 0: 217 | torch.nn.utils.remove_weight_norm(self.cond_layer) 218 | for l in self.in_layers: 219 | torch.nn.utils.remove_weight_norm(l) 220 | for l in self.res_skip_layers: 221 | torch.nn.utils.remove_weight_norm(l) 222 | 223 | 224 | class ResBlock1(torch.nn.Module): 225 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 226 | super(ResBlock1, self).__init__() 227 | self.convs1 = nn.ModuleList( 228 | [ 229 | weight_norm( 230 | Conv1d( 231 | channels, 232 | channels, 233 | kernel_size, 234 | 1, 235 | dilation=dilation[0], 236 | padding=get_padding(kernel_size, dilation[0]), 237 | ) 238 | ), 239 | weight_norm( 240 | Conv1d( 241 | channels, 242 | channels, 243 | kernel_size, 244 | 1, 245 | dilation=dilation[1], 246 | padding=get_padding(kernel_size, dilation[1]), 247 | ) 248 | ), 249 | weight_norm( 250 | 
Conv1d( 251 | channels, 252 | channels, 253 | kernel_size, 254 | 1, 255 | dilation=dilation[2], 256 | padding=get_padding(kernel_size, dilation[2]), 257 | ) 258 | ), 259 | ] 260 | ) 261 | self.convs1.apply(init_weights) 262 | 263 | self.convs2 = nn.ModuleList( 264 | [ 265 | weight_norm( 266 | Conv1d( 267 | channels, 268 | channels, 269 | kernel_size, 270 | 1, 271 | dilation=1, 272 | padding=get_padding(kernel_size, 1), 273 | ) 274 | ), 275 | weight_norm( 276 | Conv1d( 277 | channels, 278 | channels, 279 | kernel_size, 280 | 1, 281 | dilation=1, 282 | padding=get_padding(kernel_size, 1), 283 | ) 284 | ), 285 | weight_norm( 286 | Conv1d( 287 | channels, 288 | channels, 289 | kernel_size, 290 | 1, 291 | dilation=1, 292 | padding=get_padding(kernel_size, 1), 293 | ) 294 | ), 295 | ] 296 | ) 297 | self.convs2.apply(init_weights) 298 | 299 | def forward(self, x, x_mask=None): 300 | for c1, c2 in zip(self.convs1, self.convs2): 301 | xt = F.leaky_relu(x, LRELU_SLOPE) 302 | if x_mask is not None: 303 | xt = xt * x_mask 304 | xt = c1(xt) 305 | xt = F.leaky_relu(xt, LRELU_SLOPE) 306 | if x_mask is not None: 307 | xt = xt * x_mask 308 | xt = c2(xt) 309 | x = xt + x 310 | if x_mask is not None: 311 | x = x * x_mask 312 | return x 313 | 314 | def remove_weight_norm(self): 315 | for l in self.convs1: 316 | remove_weight_norm(l) 317 | for l in self.convs2: 318 | remove_weight_norm(l) 319 | 320 | 321 | class ResBlock2(torch.nn.Module): 322 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 323 | super(ResBlock2, self).__init__() 324 | self.convs = nn.ModuleList( 325 | [ 326 | weight_norm( 327 | Conv1d( 328 | channels, 329 | channels, 330 | kernel_size, 331 | 1, 332 | dilation=dilation[0], 333 | padding=get_padding(kernel_size, dilation[0]), 334 | ) 335 | ), 336 | weight_norm( 337 | Conv1d( 338 | channels, 339 | channels, 340 | kernel_size, 341 | 1, 342 | dilation=dilation[1], 343 | padding=get_padding(kernel_size, dilation[1]), 344 | ) 345 | ), 346 | ] 347 | ) 348 | 
self.convs.apply(init_weights) 349 | 350 | def forward(self, x, x_mask=None): 351 | for c in self.convs: 352 | xt = F.leaky_relu(x, LRELU_SLOPE) 353 | if x_mask is not None: 354 | xt = xt * x_mask 355 | xt = c(xt) 356 | x = xt + x 357 | if x_mask is not None: 358 | x = x * x_mask 359 | return x 360 | 361 | def remove_weight_norm(self): 362 | for l in self.convs: 363 | remove_weight_norm(l) 364 | 365 | 366 | class Log(nn.Module): 367 | def forward(self, x, x_mask, reverse=False, **kwargs): 368 | if not reverse: 369 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 370 | logdet = torch.sum(-y, [1, 2]) 371 | return y, logdet 372 | else: 373 | x = torch.exp(x) * x_mask 374 | return x 375 | 376 | 377 | class Flip(nn.Module): 378 | def forward(self, x, *args, reverse=False, **kwargs): 379 | x = torch.flip(x, [1]) 380 | if not reverse: 381 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 382 | return x, logdet 383 | else: 384 | return x 385 | 386 | 387 | class ElementwiseAffine(nn.Module): 388 | def __init__(self, channels): 389 | super().__init__() 390 | self.channels = channels 391 | self.m = nn.Parameter(torch.zeros(channels, 1)) 392 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 393 | 394 | def forward(self, x, x_mask, reverse=False, **kwargs): 395 | if not reverse: 396 | y = self.m + torch.exp(self.logs) * x 397 | y = y * x_mask 398 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 399 | return y, logdet 400 | else: 401 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 402 | return x 403 | 404 | 405 | class ResidualCouplingLayer(nn.Module): 406 | def __init__( 407 | self, 408 | channels, 409 | hidden_channels, 410 | kernel_size, 411 | dilation_rate, 412 | n_layers, 413 | p_dropout=0, 414 | gin_channels=0, 415 | mean_only=False, 416 | ): 417 | assert channels % 2 == 0, "channels should be divisible by 2" 418 | super().__init__() 419 | self.channels = channels 420 | self.hidden_channels = hidden_channels 421 | self.kernel_size = kernel_size 
422 | self.dilation_rate = dilation_rate 423 | self.n_layers = n_layers 424 | self.half_channels = channels // 2 425 | self.mean_only = mean_only 426 | 427 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 428 | self.enc = WN( 429 | hidden_channels, 430 | kernel_size, 431 | dilation_rate, 432 | n_layers, 433 | p_dropout=p_dropout, 434 | gin_channels=gin_channels, 435 | ) 436 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 437 | self.post.weight.data.zero_() 438 | self.post.bias.data.zero_() 439 | 440 | def forward(self, x, x_mask, g=None, reverse=False): 441 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 442 | h = self.pre(x0) * x_mask 443 | h = self.enc(h, x_mask, g=g) 444 | stats = self.post(h) * x_mask 445 | if not self.mean_only: 446 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 447 | else: 448 | m = stats 449 | logs = torch.zeros_like(m) 450 | 451 | if not reverse: 452 | x1 = m + x1 * torch.exp(logs) * x_mask 453 | x = torch.cat([x0, x1], 1) 454 | logdet = torch.sum(logs, [1, 2]) 455 | return x, logdet 456 | else: 457 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 458 | x = torch.cat([x0, x1], 1) 459 | return x 460 | 461 | 462 | class ConvFlow(nn.Module): 463 | def __init__( 464 | self, 465 | in_channels, 466 | filter_channels, 467 | kernel_size, 468 | n_layers, 469 | num_bins=10, 470 | tail_bound=5.0, 471 | ): 472 | super().__init__() 473 | self.in_channels = in_channels 474 | self.filter_channels = filter_channels 475 | self.kernel_size = kernel_size 476 | self.n_layers = n_layers 477 | self.num_bins = num_bins 478 | self.tail_bound = tail_bound 479 | self.half_channels = in_channels // 2 480 | 481 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 482 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) 483 | self.proj = nn.Conv1d( 484 | filter_channels, self.half_channels * (num_bins * 3 - 1), 1 485 | ) 486 | self.proj.weight.data.zero_() 487 | 
self.proj.bias.data.zero_() 488 | 489 | def forward(self, x, x_mask, g=None, reverse=False): 490 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 491 | h = self.pre(x0) 492 | h = self.convs(h, x_mask, g=g) 493 | h = self.proj(h) * x_mask 494 | 495 | b, c, t = x0.shape 496 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 497 | 498 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) 499 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( 500 | self.filter_channels 501 | ) 502 | unnormalized_derivatives = h[..., 2 * self.num_bins :] 503 | 504 | x1, logabsdet = piecewise_rational_quadratic_transform( 505 | x1, 506 | unnormalized_widths, 507 | unnormalized_heights, 508 | unnormalized_derivatives, 509 | inverse=reverse, 510 | tails="linear", 511 | tail_bound=self.tail_bound, 512 | ) 513 | 514 | x = torch.cat([x0, x1], 1) * x_mask 515 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 516 | if not reverse: 517 | return x, logdet 518 | else: 519 | return x 520 | -------------------------------------------------------------------------------- /monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(neg_cent, mask): 7 | """Cython optimized version. 
8 | neg_cent: [b, t_t, t_s] 9 | mask: [b, t_t, t_s] 10 | """ 11 | device = neg_cent.device 12 | dtype = neg_cent.dtype 13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 14 | path = np.zeros(neg_cent.shape, dtype=np.int32) 15 | 16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 19 | return torch.from_numpy(path).to(device=device, dtype=dtype) 20 | -------------------------------------------------------------------------------- /monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel import prange 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 8 | cdef int x 9 | cdef int y 10 | cdef float v_prev 11 | cdef float v_cur 12 | cdef float tmp 13 | cdef int index = t_x - 1 14 | 15 | for y in range(t_y): 16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 17 | if x == y: 18 | v_cur = max_neg_val 19 | else: 20 | v_cur = value[y-1, x] 21 | if x == 0: 22 | if y == 0: 23 | v_prev = 0. 
24 | else: 25 | v_prev = max_neg_val 26 | else: 27 | v_prev = value[y-1, x-1] 28 | value[y, x] += max(v_prev, v_cur) 29 | 30 | for y in range(t_y - 1, -1, -1): 31 | path[y, index] = 1 32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): 33 | index = index - 1 34 | 35 | 36 | @cython.boundscheck(False) 37 | @cython.wraparound(False) 38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 39 | cdef int b = paths.shape[0] 40 | cdef int i 41 | for i in prange(b, nogil=True): 42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 43 | -------------------------------------------------------------------------------- /monotonic_align/monotonic_align/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/monotonic_align/monotonic_align/.gitkeep -------------------------------------------------------------------------------- /monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup # distutils was removed in Python 3.12 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name="monotonic_align", 7 | ext_modules=cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()], 9 | ) 10 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import text 3 | from utils import load_filepaths_and_text 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--out_extension", default="cleaned") 8 | parser.add_argument("--text_index", default=1, type=int) 9 | parser.add_argument( 10 | "--filelists", 11 | nargs="+", 12 | default=[ 13 | "filelists/ljs_audio_text_val_filelist.txt", 14 | 
"filelists/ljs_audio_text_test_filelist.txt", 15 | ], 16 | ) 17 | parser.add_argument("--text_cleaners", nargs="+", default=["english_cleaners2"]) 18 | 19 | args = parser.parse_args() 20 | 21 | for filelist in args.filelists: 22 | print("START:", filelist) 23 | filepaths_and_text = load_filepaths_and_text(filelist) 24 | for i in range(len(filepaths_and_text)): 25 | original_text = filepaths_and_text[i][args.text_index] 26 | cleaned_text = text._clean_text(original_text, args.text_cleaners) 27 | filepaths_and_text[i][args.text_index] = cleaned_text 28 | 29 | new_filelist = filelist + "." + args.out_extension 30 | with open(new_filelist, "w", encoding="utf-8") as f: 31 | f.writelines(["|".join(x) + "\n" for x in filepaths_and_text]) 32 | -------------------------------------------------------------------------------- /preprocess_audio.py: -------------------------------------------------------------------------------- 1 | """ 2 | VCTK 3 | https://datashare.ed.ac.uk/handle/10283/3443 4 | VCTK trim info 5 | https://github.com/nii-yamagishilab/vctk-silence-labels 6 | 7 | Warning! This code is not properly debugged. 8 | It is recommended to run it only once for the initial state of the audio file (flac or wav). 9 | If executed repeatedly, consecutive application of "trim" may potentially damage the audio file. 
10 | 11 | >>> $ pip install librosa==0.9.2 numpy==1.23.5 scipy==1.9.1 tqdm # [optional] 12 | >>> $ cd /path/to/your/vits2 13 | >>> $ ln -s /path/to/the/VCTK/* DUMMY2/ 14 | >>> $ git clone https://github.com/nii-yamagishilab/vctk-silence-labels filelists/vctk-silence-labels 15 | >>> $ python preprocess_audio.py --filelists <~/filelist.txt> --config <~/config.json> --trim <~/info.txt> 16 | """ 17 | 18 | import argparse 19 | import os 20 | 21 | import librosa 22 | import numpy as np 23 | from scipy.io import wavfile 24 | from tqdm.auto import tqdm 25 | 26 | import utils 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--filelists", 32 | nargs="+", 33 | default=[ 34 | "filelists/vctk_audio_sid_text_test_filelist.txt", 35 | "filelists/vctk_audio_sid_text_val_filelist.txt", 36 | "filelists/vctk_audio_sid_text_train_filelist.txt", 37 | ], 38 | ) 39 | parser.add_argument("--config", default="configs/vits2_vctk_base.json", type=str) 40 | parser.add_argument( 41 | "--trim", 42 | default="filelists/vctk-silence-labels/vctk-silences.0.92.txt", 43 | type=str, 44 | ) 45 | args = parser.parse_args() 46 | 47 | with open(args.trim, "r", encoding="utf8") as f: 48 | lines = list(filter(lambda x: len(x) > 0, f.read().split("\n"))) 49 | trim_info = {} 50 | for line in lines: 51 | line = line.split(" ") 52 | trim_info[line[0]] = (float(line[1]), float(line[2])) 53 | 54 | hps = utils.get_hparams_from_file(args.config) 55 | for filelist in args.filelists: 56 | print("START:", filelist) 57 | with open(filelist, "r", encoding="utf8") as f: 58 | lines = list(filter(lambda x: len(x) > 0, f.read().split("\n"))) 59 | 60 | for line in tqdm(lines, total=len(lines), desc=filelist): 61 | src_filename = line.split("|")[0] 62 | if not os.path.isfile(src_filename): 63 | if os.path.isfile(src_filename.replace(".wav", "_mic1.flac")): 64 | src_filename = src_filename.replace(".wav", "_mic1.flac") 65 | else: 66 | continue 67 | 68 | if 
src_filename.endswith("_mic1.flac"): 69 | tgt_filename = src_filename.replace("_mic1.flac", ".wav") 70 | else: 71 | tgt_filename = src_filename 72 | 73 | basename = os.path.splitext(os.path.basename(src_filename))[0].replace( 74 | "_mic1", "" 75 | ) 76 | if trim_info.get(basename) is None: 77 | print( 78 | f"file info: '{src_filename}' doesn't exist in trim info '{args.trim}'" 79 | ) 80 | continue 81 | 82 | start, end = trim_info[basename][0], trim_info[basename][1] 83 | 84 | # warning: this could leave the file in an unusable state 85 | y, _ = librosa.core.load( 86 | src_filename, 87 | sr=hps.data.sampling_rate, 88 | mono=True, 89 | res_type="scipy", 90 | offset=start, 91 | duration=end - start, 92 | ) 93 | 94 | # y, _ = librosa.effects.trim( 95 | # y=y, 96 | # frame_length=4096, 97 | # hop_length=256, 98 | # top_db=35, 99 | # ) 100 | 101 | if y.shape[-1] < hps.train.segment_size: 102 | continue 103 | 104 | y = y * hps.data.max_wav_value 105 | wavfile.write( 106 | filename=tgt_filename, 107 | rate=hps.data.sampling_rate, 108 | data=y.astype(np.int16), 109 | ) 110 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==3.0.2 2 | librosa==0.10.1 3 | matplotlib==3.7.2 4 | numpy==1.24.4 5 | phonemizer==3.2.1 6 | scipy==1.11.2 7 | # torch==2.0.1 8 | # torchaudio==2.0.2 9 | # torchvision==0.15.2 10 | Unidecode==1.3.6 11 | tensorboard==2.14.0 12 | onnx==1.14.1 13 | onnxruntime==1.15.1 14 | gradio 15 | -------------------------------------------------------------------------------- /resources/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/resources/image.png -------------------------------------------------------------------------------- /resources/sid_src_3.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/resources/sid_src_3.wav -------------------------------------------------------------------------------- /resources/sid_src_3_to_tgt_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/resources/sid_src_3_to_tgt_1.wav -------------------------------------------------------------------------------- /resources/sid_src_3_to_tgt_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/resources/sid_src_3_to_tgt_2.wav -------------------------------------------------------------------------------- /resources/sid_src_3_to_tgt_4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/resources/sid_src_3_to_tgt_4.wav -------------------------------------------------------------------------------- /resources/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/resources/test.wav -------------------------------------------------------------------------------- /resources/vctk_onnx_test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/resources/vctk_onnx_test.wav -------------------------------------------------------------------------------- /resources/vctk_test.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/p0p4k/vits2_pytorch/1f4f3790568180f8dec4419d5cad5d0877b034bb/resources/vctk_test.wav -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from text import cleaners 3 | from text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, cleaner_names): 12 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | """ 19 | sequence = [] 20 | 21 | clean_text = _clean_text(text, cleaner_names) 22 | for symbol in clean_text: 23 | if symbol in _symbol_to_id.keys(): 24 | symbol_id = _symbol_to_id[symbol] 25 | sequence += [symbol_id] 26 | else: 27 | continue 28 | return sequence 29 | 30 | 31 | def cleaned_text_to_sequence(cleaned_text): 32 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
33 | Args: 34 | cleaned_text: already-cleaned string to convert to a sequence 35 | Returns: 36 | List of integers corresponding to the symbols in the text 37 | """ 38 | sequence = [] 39 | 40 | for symbol in cleaned_text: 41 | if symbol in _symbol_to_id.keys(): 42 | symbol_id = _symbol_to_id[symbol] 43 | sequence += [symbol_id] 44 | else: 45 | continue 46 | return sequence 47 | 48 | 49 | def sequence_to_text(sequence): 50 | """Converts a sequence of IDs back to a string""" 51 | result = "" 52 | for symbol_id in sequence: 53 | s = _id_to_symbol[symbol_id] 54 | result += s 55 | return result 56 | 57 | 58 | def _clean_text(text, cleaner_names): 59 | for name in cleaner_names: 60 | cleaner = getattr(cleaners, name, None) 61 | if not cleaner: 62 | raise Exception("Unknown cleaner: %s" % name) 63 | text = cleaner(text) 64 | return text 65 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 
13 | """ 14 | 15 | import re 16 | from unidecode import unidecode 17 | from phonemizer import phonemize 18 | from phonemizer.backend import EspeakBackend 19 | backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True) 20 | 21 | 22 | # Regular expression matching whitespace: 23 | _whitespace_re = re.compile(r"\s+") 24 | 25 | # List of (regular expression, replacement) pairs for abbreviations: 26 | _abbreviations = [ 27 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) 28 | for x in [ 29 | ("mrs", "misess"), 30 | ("mr", "mister"), 31 | ("dr", "doctor"), 32 | ("st", "saint"), 33 | ("co", "company"), 34 | ("jr", "junior"), 35 | ("maj", "major"), 36 | ("gen", "general"), 37 | ("drs", "doctors"), 38 | ("rev", "reverend"), 39 | ("lt", "lieutenant"), 40 | ("hon", "honorable"), 41 | ("sgt", "sergeant"), 42 | ("capt", "captain"), 43 | ("esq", "esquire"), 44 | ("ltd", "limited"), 45 | ("col", "colonel"), 46 | ("ft", "fort"), 47 | ] 48 | ] 49 | 50 | 51 | def expand_abbreviations(text): 52 | for regex, replacement in _abbreviations: 53 | text = re.sub(regex, replacement, text) 54 | return text 55 | 56 | 57 | def expand_numbers(text): 58 | return normalize_numbers(text)  # NOTE: normalize_numbers is neither defined nor imported in this module, so calling expand_numbers raises NameError 59 | 60 | 61 | def lowercase(text): 62 | return text.lower() 63 | 64 | 65 | def collapse_whitespace(text): 66 | return re.sub(_whitespace_re, " ", text) 67 | 68 | 69 | def convert_to_ascii(text): 70 | return unidecode(text) 71 | 72 | 73 | def basic_cleaners(text): 74 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 75 | text = lowercase(text) 76 | text = collapse_whitespace(text) 77 | return text 78 | 79 | 80 | def transliteration_cleaners(text): 81 | """Pipeline for non-English text that transliterates to ASCII.""" 82 | text = convert_to_ascii(text) 83 | text = lowercase(text) 84 | text = collapse_whitespace(text) 85 | return text 86 | 87 | 88 | def english_cleaners(text): 89 | """Pipeline for English text, including abbreviation expansion.""" 90 | 
text = convert_to_ascii(text) 91 | text = lowercase(text) 92 | text = expand_abbreviations(text) 93 | phonemes = phonemize(text, language="en-us", backend="espeak", strip=True) 94 | phonemes = collapse_whitespace(phonemes) 95 | return phonemes 96 | 97 | 98 | def english_cleaners2(text): 99 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 100 | text = convert_to_ascii(text) 101 | text = lowercase(text) 102 | text = expand_abbreviations(text) 103 | phonemes = phonemize( 104 | text, 105 | language="en-us", 106 | backend="espeak", 107 | strip=True, 108 | preserve_punctuation=True, 109 | with_stress=True, 110 | ) 111 | phonemes = collapse_whitespace(phonemes) 112 | return phonemes 113 | 114 | 115 | def english_cleaners3(text): 116 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 117 | text = convert_to_ascii(text) 118 | text = lowercase(text) 119 | text = expand_abbreviations(text) 120 | phonemes = backend.phonemize([text], strip=True)[0] 121 | phonemes = collapse_whitespace(phonemes) 122 | return phonemes 123 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 
5 | """ 6 | _pad = "_" 7 | _punctuation = ';:,.!?¡¿—…"«»“” ' 8 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | 11 | 12 | # Export all symbols: 13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 14 | 15 | # Special symbol ids 16 | SPACE_ID = symbols.index(" ") 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import itertools 5 | import math 6 | import torch 7 | from torch import nn, optim 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | # from tensorboardX import SummaryWriter 13 | import torch.multiprocessing as mp 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.cuda.amp import autocast, GradScaler 17 | import tqdm 18 | 19 | import commons 20 | import utils 21 | from data_utils import TextAudioLoader, TextAudioCollate, DistributedBucketSampler 22 | from models import ( 23 | SynthesizerTrn, 24 | MultiPeriodDiscriminator, 25 | DurationDiscriminatorV1, 26 | DurationDiscriminatorV2, 27 | AVAILABLE_FLOW_TYPES, 28 | AVAILABLE_DURATION_DISCRIMINATOR_TYPES 29 | ) 30 | from losses import generator_loss, discriminator_loss, feature_loss, kl_loss 31 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch 32 | from text.symbols import symbols 33 | 34 | 35 | torch.backends.cudnn.benchmark = True 36 | global_step = 0 37 | 38 | 39 | def main(): 40 | """Assume Single Node Multi GPUs Training Only""" 41 | assert torch.cuda.is_available(), "CPU training is not allowed." 
42 | 43 | n_gpus = torch.cuda.device_count() 44 | os.environ["MASTER_ADDR"] = "localhost" 45 | os.environ["MASTER_PORT"] = "6060" 46 | 47 | hps = utils.get_hparams() 48 | mp.spawn( 49 | run, 50 | nprocs=n_gpus, 51 | args=( 52 | n_gpus, 53 | hps, 54 | ), 55 | ) 56 | 57 | 58 | def run(rank, n_gpus, hps): 59 | global global_step 60 | if rank == 0: 61 | logger = utils.get_logger(hps.model_dir) 62 | logger.info(hps) 63 | utils.check_git_hash(hps.model_dir) 64 | writer = SummaryWriter(log_dir=hps.model_dir) 65 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 66 | 67 | dist.init_process_group( 68 | backend="nccl", init_method="env://", world_size=n_gpus, rank=rank 69 | ) 70 | torch.manual_seed(hps.train.seed) 71 | torch.cuda.set_device(rank) 72 | 73 | if ( 74 | "use_mel_posterior_encoder" in hps.model.keys() 75 | and hps.model.use_mel_posterior_encoder == True 76 | ): 77 | print("Using mel posterior encoder for VITS2") 78 | posterior_channels = 80 # vits2 79 | hps.data.use_mel_posterior_encoder = True 80 | else: 81 | print("Using lin posterior encoder for VITS1") 82 | posterior_channels = hps.data.filter_length // 2 + 1 83 | hps.data.use_mel_posterior_encoder = False 84 | 85 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data) 86 | train_sampler = DistributedBucketSampler( 87 | train_dataset, 88 | hps.train.batch_size, 89 | [32, 300, 400, 500, 600, 700, 800, 900, 1000], 90 | num_replicas=n_gpus, 91 | rank=rank, 92 | shuffle=True, 93 | ) 94 | 95 | collate_fn = TextAudioCollate() 96 | train_loader = DataLoader( 97 | train_dataset, 98 | num_workers=8, 99 | shuffle=False, 100 | pin_memory=True, 101 | collate_fn=collate_fn, 102 | batch_sampler=train_sampler, 103 | ) 104 | if rank == 0: 105 | eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data) 106 | eval_loader = DataLoader( 107 | eval_dataset, 108 | num_workers=8, 109 | shuffle=False, 110 | batch_size=hps.train.batch_size, 111 | pin_memory=True, 112 | drop_last=False, 
113 | collate_fn=collate_fn, 114 | ) 115 | # some of these flags are not being used in the code and directly set in hps json file. 116 | # they are kept here for reference and prototyping. 117 | 118 | if ( 119 | "use_transformer_flows" in hps.model.keys() 120 | and hps.model.use_transformer_flows == True 121 | ): 122 | use_transformer_flows = True 123 | transformer_flow_type = hps.model.transformer_flow_type 124 | print(f"Using transformer flows {transformer_flow_type} for VITS2") 125 | assert ( 126 | transformer_flow_type in AVAILABLE_FLOW_TYPES 127 | ), f"transformer_flow_type must be one of {AVAILABLE_FLOW_TYPES}" 128 | else: 129 | print("Using normal flows for VITS1") 130 | use_transformer_flows = False 131 | 132 | if ( 133 | "use_spk_conditioned_encoder" in hps.model.keys() 134 | and hps.model.use_spk_conditioned_encoder == True 135 | ): 136 | if hps.data.n_speakers == 0: 137 | print("Warning: use_spk_conditioned_encoder is True but n_speakers is 0") 138 | print( 139 | "Setting use_spk_conditioned_encoder to False as model is a single speaker model" 140 | ) 141 | use_spk_conditioned_encoder = False 142 | else: 143 | print("Using normal encoder for VITS1") 144 | use_spk_conditioned_encoder = False 145 | 146 | if ( 147 | "use_noise_scaled_mas" in hps.model.keys() 148 | and hps.model.use_noise_scaled_mas == True 149 | ): 150 | print("Using noise scaled MAS for VITS2") 151 | use_noise_scaled_mas = True 152 | mas_noise_scale_initial = 0.01 153 | noise_scale_delta = 2e-6 154 | else: 155 | print("Using normal MAS for VITS1") 156 | use_noise_scaled_mas = False 157 | mas_noise_scale_initial = 0.0 158 | noise_scale_delta = 0.0 159 | 160 | if ( 161 | "use_duration_discriminator" in hps.model.keys() 162 | and hps.model.use_duration_discriminator == True 163 | ): 164 | # print("Using duration discriminator for VITS2") 165 | use_duration_discriminator = True 166 | duration_discriminator_type = hps.model.duration_discriminator_type 167 | print(f"Using duration_discriminator 
{duration_discriminator_type} for VITS2") 168 | assert duration_discriminator_type in AVAILABLE_DURATION_DISCRIMINATOR_TYPES, f"duration_discriminator_type must be one of {AVAILABLE_DURATION_DISCRIMINATOR_TYPES}" 169 | if duration_discriminator_type == "dur_disc_1": 170 | net_dur_disc = DurationDiscriminatorV1( 171 | hps.model.hidden_channels, 172 | hps.model.hidden_channels, 173 | 3, 174 | 0.1, 175 | gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0, 176 | ).cuda(rank) 177 | elif duration_discriminator_type == "dur_disc_2": 178 | net_dur_disc = DurationDiscriminatorV2( 179 | hps.model.hidden_channels, 180 | hps.model.hidden_channels, 181 | 3, 182 | 0.1, 183 | gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0, 184 | ).cuda(rank) 185 | else: 186 | print("NOT using any duration discriminator like VITS1") 187 | net_dur_disc = None 188 | use_duration_discriminator = False 189 | 190 | net_g = SynthesizerTrn( 191 | len(symbols), 192 | posterior_channels, 193 | hps.train.segment_size // hps.data.hop_length, 194 | mas_noise_scale_initial=mas_noise_scale_initial, 195 | noise_scale_delta=noise_scale_delta, 196 | **hps.model, 197 | ).cuda(rank) 198 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) 199 | optim_g = torch.optim.AdamW( 200 | net_g.parameters(), 201 | hps.train.learning_rate, 202 | betas=hps.train.betas, 203 | eps=hps.train.eps, 204 | ) 205 | optim_d = torch.optim.AdamW( 206 | net_d.parameters(), 207 | hps.train.learning_rate, 208 | betas=hps.train.betas, 209 | eps=hps.train.eps, 210 | ) 211 | if net_dur_disc is not None: 212 | optim_dur_disc = torch.optim.AdamW( 213 | net_dur_disc.parameters(), 214 | hps.train.learning_rate, 215 | betas=hps.train.betas, 216 | eps=hps.train.eps, 217 | ) 218 | else: 219 | optim_dur_disc = None 220 | 221 | net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) 222 | net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) 223 | if net_dur_disc 
is not None: 224 | net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True) 225 | 226 | try: 227 | _, _, _, epoch_str = utils.load_checkpoint( 228 | utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g 229 | ) 230 | _, _, _, epoch_str = utils.load_checkpoint( 231 | utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d 232 | ) 233 | if net_dur_disc is not None: 234 | _, _, _, epoch_str = utils.load_checkpoint( 235 | utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"), 236 | net_dur_disc, 237 | optim_dur_disc, 238 | ) 239 | global_step = (epoch_str - 1) * len(train_loader) 240 | except Exception:  # checkpoint loading failed (e.g. none saved yet): start from scratch 241 | epoch_str = 1 242 | global_step = 0 243 | 244 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR( 245 | optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 246 | ) 247 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR( 248 | optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 249 | ) 250 | if net_dur_disc is not None: 251 | scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR( 252 | optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 253 | ) 254 | else: 255 | scheduler_dur_disc = None 256 | 257 | scaler = GradScaler(enabled=hps.train.fp16_run) 258 | 259 | for epoch in range(epoch_str, hps.train.epochs + 1): 260 | if rank == 0: 261 | train_and_evaluate( 262 | rank, 263 | epoch, 264 | hps, 265 | [net_g, net_d, net_dur_disc], 266 | [optim_g, optim_d, optim_dur_disc], 267 | [scheduler_g, scheduler_d, scheduler_dur_disc], 268 | scaler, 269 | [train_loader, eval_loader], 270 | logger, 271 | [writer, writer_eval], 272 | ) 273 | else: 274 | train_and_evaluate( 275 | rank, 276 | epoch, 277 | hps, 278 | [net_g, net_d, net_dur_disc], 279 | [optim_g, optim_d, optim_dur_disc], 280 | [scheduler_g, scheduler_d, scheduler_dur_disc], 281 | scaler, 282 | [train_loader, None], 283 | None, 284 | None, 285 | ) 286 | scheduler_g.step() 287 | scheduler_d.step() 288 | if net_dur_disc is not 
None: 289 | scheduler_dur_disc.step() 290 | 291 | 292 | def train_and_evaluate( 293 | rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers 294 | ): 295 | net_g, net_d, net_dur_disc = nets 296 | optim_g, optim_d, optim_dur_disc = optims 297 | scheduler_g, scheduler_d, scheduler_dur_disc = schedulers 298 | train_loader, eval_loader = loaders 299 | if writers is not None: 300 | writer, writer_eval = writers 301 | 302 | train_loader.batch_sampler.set_epoch(epoch) 303 | global global_step 304 | 305 | net_g.train() 306 | net_d.train() 307 | if net_dur_disc is not None: 308 | net_dur_disc.train() 309 | 310 | if rank == 0: 311 | loader = tqdm.tqdm(train_loader, desc="Loading train data") 312 | else: 313 | loader = train_loader 314 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate( 315 | loader 316 | ): 317 | if net_g.module.use_noise_scaled_mas: 318 | current_mas_noise_scale = ( 319 | net_g.module.mas_noise_scale_initial 320 | - net_g.module.noise_scale_delta * global_step 321 | ) 322 | net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0) 323 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda( 324 | rank, non_blocking=True 325 | ) 326 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( 327 | rank, non_blocking=True 328 | ) 329 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda( 330 | rank, non_blocking=True 331 | ) 332 | 333 | with autocast(enabled=hps.train.fp16_run): 334 | ( 335 | y_hat, 336 | l_length, 337 | attn, 338 | ids_slice, 339 | x_mask, 340 | z_mask, 341 | (z, z_p, m_p, logs_p, m_q, logs_q), 342 | (hidden_x, logw, logw_), 343 | ) = net_g(x, x_lengths, spec, spec_lengths) 344 | 345 | if ( 346 | hps.model.use_mel_posterior_encoder 347 | or hps.data.use_mel_posterior_encoder 348 | ): 349 | mel = spec 350 | else: 351 | mel = spec_to_mel_torch( 352 | spec.float(), 353 | hps.data.filter_length, 354 | hps.data.n_mel_channels, 355 | 
hps.data.sampling_rate, 356 | hps.data.mel_fmin, 357 | hps.data.mel_fmax, 358 | ) 359 | y_mel = commons.slice_segments( 360 | mel, ids_slice, hps.train.segment_size // hps.data.hop_length 361 | ) 362 | y_hat_mel = mel_spectrogram_torch( 363 | y_hat.squeeze(1), 364 | hps.data.filter_length, 365 | hps.data.n_mel_channels, 366 | hps.data.sampling_rate, 367 | hps.data.hop_length, 368 | hps.data.win_length, 369 | hps.data.mel_fmin, 370 | hps.data.mel_fmax, 371 | ) 372 | 373 | y = commons.slice_segments( 374 | y, ids_slice * hps.data.hop_length, hps.train.segment_size 375 | ) # slice 376 | 377 | # Discriminator 378 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) 379 | with autocast(enabled=False): 380 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( 381 | y_d_hat_r, y_d_hat_g 382 | ) 383 | loss_disc_all = loss_disc 384 | 385 | # Duration Discriminator 386 | if net_dur_disc is not None: 387 | y_dur_hat_r, y_dur_hat_g = net_dur_disc( 388 | hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach() 389 | ) 390 | with autocast(enabled=False): 391 | # TODO: the mean should probably be taken over the masked positions only; for now, average over all 392 | ( 393 | loss_dur_disc, 394 | losses_dur_disc_r, 395 | losses_dur_disc_g, 396 | ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g) 397 | loss_dur_disc_all = loss_dur_disc 398 | optim_dur_disc.zero_grad() 399 | scaler.scale(loss_dur_disc_all).backward() 400 | scaler.unscale_(optim_dur_disc) 401 | grad_norm_dur_disc = commons.clip_grad_value_( 402 | net_dur_disc.parameters(), None 403 | ) 404 | scaler.step(optim_dur_disc) 405 | 406 | optim_d.zero_grad() 407 | scaler.scale(loss_disc_all).backward() 408 | scaler.unscale_(optim_d) 409 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) 410 | scaler.step(optim_d) 411 | 412 | with autocast(enabled=hps.train.fp16_run): 413 | # Generator 414 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) 415 | if net_dur_disc is not None: 416 | y_dur_hat_r, y_dur_hat_g = 
net_dur_disc(hidden_x, x_mask, logw_, logw) 417 | with autocast(enabled=False): 418 | loss_dur = torch.sum(l_length.float()) 419 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel 420 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl 421 | 422 | loss_fm = feature_loss(fmap_r, fmap_g) 423 | loss_gen, losses_gen = generator_loss(y_d_hat_g) 424 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl 425 | if net_dur_disc is not None: 426 | loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g) 427 | loss_gen_all += loss_dur_gen 428 | 429 | optim_g.zero_grad() 430 | scaler.scale(loss_gen_all).backward() 431 | scaler.unscale_(optim_g) 432 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) 433 | scaler.step(optim_g) 434 | scaler.update() 435 | 436 | if rank == 0: 437 | if global_step % hps.train.log_interval == 0: 438 | lr = optim_g.param_groups[0]["lr"] 439 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl] 440 | logger.info( 441 | "Train Epoch: {} [{:.0f}%]".format( 442 | epoch, 100.0 * batch_idx / len(train_loader) 443 | ) 444 | ) 445 | logger.info([x.item() for x in losses] + [global_step, lr]) 446 | 447 | scalar_dict = { 448 | "loss/g/total": loss_gen_all, 449 | "loss/d/total": loss_disc_all, 450 | "learning_rate": lr, 451 | "grad_norm_d": grad_norm_d, 452 | "grad_norm_g": grad_norm_g, 453 | } 454 | if net_dur_disc is not None: 455 | scalar_dict.update( 456 | { 457 | "loss/dur_disc/total": loss_dur_disc_all, 458 | "grad_norm_dur_disc": grad_norm_dur_disc, 459 | } 460 | ) 461 | scalar_dict.update( 462 | { 463 | "loss/g/fm": loss_fm, 464 | "loss/g/mel": loss_mel, 465 | "loss/g/dur": loss_dur, 466 | "loss/g/kl": loss_kl, 467 | } 468 | ) 469 | 470 | scalar_dict.update( 471 | {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} 472 | ) 473 | scalar_dict.update( 474 | {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} 475 | ) 476 | scalar_dict.update( 477 | 
{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} 478 | ) 479 | 480 | # if net_dur_disc is not None: 481 | # scalar_dict.update({"loss/dur_disc_r" : f"{losses_dur_disc_r}"}) 482 | # scalar_dict.update({"loss/dur_disc_g" : f"{losses_dur_disc_g}"}) 483 | # scalar_dict.update({"loss/dur_gen" : f"{loss_dur_gen}"}) 484 | 485 | image_dict = { 486 | "slice/mel_org": utils.plot_spectrogram_to_numpy( 487 | y_mel[0].data.cpu().numpy() 488 | ), 489 | "slice/mel_gen": utils.plot_spectrogram_to_numpy( 490 | y_hat_mel[0].data.cpu().numpy() 491 | ), 492 | "all/mel": utils.plot_spectrogram_to_numpy( 493 | mel[0].data.cpu().numpy() 494 | ), 495 | "all/attn": utils.plot_alignment_to_numpy( 496 | attn[0, 0].data.cpu().numpy() 497 | ), 498 | } 499 | utils.summarize( 500 | writer=writer, 501 | global_step=global_step, 502 | images=image_dict, 503 | scalars=scalar_dict, 504 | ) 505 | 506 | if global_step % hps.train.eval_interval == 0: 507 | evaluate(hps, net_g, eval_loader, writer_eval) 508 | utils.save_checkpoint( 509 | net_g, 510 | optim_g, 511 | hps.train.learning_rate, 512 | epoch, 513 | os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), 514 | ) 515 | utils.save_checkpoint( 516 | net_d, 517 | optim_d, 518 | hps.train.learning_rate, 519 | epoch, 520 | os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), 521 | ) 522 | if net_dur_disc is not None: 523 | utils.save_checkpoint( 524 | net_dur_disc, 525 | optim_dur_disc, 526 | hps.train.learning_rate, 527 | epoch, 528 | os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)), 529 | ) 530 | utils.remove_old_checkpoints(hps.model_dir, prefixes=["G_*.pth", "D_*.pth", "DUR_*.pth"]) 531 | global_step += 1 532 | 533 | if rank == 0: 534 | logger.info("====> Epoch: {}".format(epoch)) 535 | 536 | 537 | def evaluate(hps, generator, eval_loader, writer_eval): 538 | generator.eval() 539 | with torch.no_grad(): 540 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate( 541 | eval_loader 
542 | ): 543 | x, x_lengths = x.cuda(0), x_lengths.cuda(0) 544 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0) 545 | y, y_lengths = y.cuda(0), y_lengths.cuda(0) 546 | 547 | # remove else 548 | x = x[:1] 549 | x_lengths = x_lengths[:1] 550 | spec = spec[:1] 551 | spec_lengths = spec_lengths[:1] 552 | y = y[:1] 553 | y_lengths = y_lengths[:1] 554 | break 555 | y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, max_len=1000) 556 | y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length 557 | 558 | if hps.model.use_mel_posterior_encoder or hps.data.use_mel_posterior_encoder: 559 | mel = spec 560 | else: 561 | mel = spec_to_mel_torch( 562 | spec, 563 | hps.data.filter_length, 564 | hps.data.n_mel_channels, 565 | hps.data.sampling_rate, 566 | hps.data.mel_fmin, 567 | hps.data.mel_fmax, 568 | ) 569 | y_hat_mel = mel_spectrogram_torch( 570 | y_hat.squeeze(1).float(), 571 | hps.data.filter_length, 572 | hps.data.n_mel_channels, 573 | hps.data.sampling_rate, 574 | hps.data.hop_length, 575 | hps.data.win_length, 576 | hps.data.mel_fmin, 577 | hps.data.mel_fmax, 578 | ) 579 | image_dict = { 580 | "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()) 581 | } 582 | audio_dict = {"gen/audio": y_hat[0, :, : y_hat_lengths[0]]} 583 | if global_step == 0: 584 | image_dict.update( 585 | {"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())} 586 | ) 587 | audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]}) 588 | 589 | utils.summarize( 590 | writer=writer_eval, 591 | global_step=global_step, 592 | images=image_dict, 593 | audios=audio_dict, 594 | audio_sampling_rate=hps.data.sampling_rate, 595 | ) 596 | generator.train() 597 | 598 | 599 | if __name__ == "__main__": 600 | main() 601 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 
import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1.0, 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE, 23 | ): 24 | if tails is None: 25 | spline_fn = rational_quadratic_spline 26 | spline_kwargs = {} 27 | else: 28 | spline_fn = unconstrained_rational_quadratic_spline 29 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound} 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 48 | 49 | 50 | def unconstrained_rational_quadratic_spline( 51 | inputs, 52 | unnormalized_widths, 53 | unnormalized_heights, 54 | unnormalized_derivatives, 55 | inverse=False, 56 | tails="linear", 57 | tail_bound=1.0, 58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 60 | min_derivative=DEFAULT_MIN_DERIVATIVE, 61 | ): 62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 63 | outside_interval_mask = ~inside_interval_mask 64 | 65 | outputs = torch.zeros_like(inputs) 66 | logabsdet = torch.zeros_like(inputs) 67 | 68 | if tails == "linear": 69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 70 | constant = np.log(np.exp(1 - 
min_derivative) - 1) 71 | unnormalized_derivatives[..., 0] = constant 72 | unnormalized_derivatives[..., -1] = constant 73 | 74 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 75 | logabsdet[outside_interval_mask] = 0 76 | else: 77 | raise RuntimeError("{} tails are not implemented.".format(tails)) 78 | 79 | ( 80 | outputs[inside_interval_mask], 81 | logabsdet[inside_interval_mask], 82 | ) = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, 89 | right=tail_bound, 90 | bottom=-tail_bound, 91 | top=tail_bound, 92 | min_bin_width=min_bin_width, 93 | min_bin_height=min_bin_height, 94 | min_derivative=min_derivative, 95 | ) 96 | 97 | return outputs, logabsdet 98 | 99 | 100 | def rational_quadratic_spline( 101 | inputs, 102 | unnormalized_widths, 103 | unnormalized_heights, 104 | unnormalized_derivatives, 105 | inverse=False, 106 | left=0.0, 107 | right=1.0, 108 | bottom=0.0, 109 | top=1.0, 110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 112 | min_derivative=DEFAULT_MIN_DERIVATIVE, 113 | ): 114 | if torch.min(inputs) < left or torch.max(inputs) > right: 115 | raise ValueError("Input to a transform is not within its domain") 116 | 117 | num_bins = unnormalized_widths.shape[-1] 118 | 119 | if min_bin_width * num_bins > 1.0: 120 | raise ValueError("Minimal bin width too large for the number of bins") 121 | if min_bin_height * num_bins > 1.0: 122 | raise ValueError("Minimal bin height too large for the number of bins") 123 | 124 | widths = F.softmax(unnormalized_widths, dim=-1) 125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 126 | cumwidths = torch.cumsum(widths, dim=-1) 127 | cumwidths = F.pad(cumwidths, pad=(1, 
0), mode="constant", value=0.0) 128 | cumwidths = (right - left) * cumwidths + left 129 | cumwidths[..., 0] = left 130 | cumwidths[..., -1] = right 131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 132 | 133 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 134 | 135 | heights = F.softmax(unnormalized_heights, dim=-1) 136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 137 | cumheights = torch.cumsum(heights, dim=-1) 138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 139 | cumheights = (top - bottom) * cumheights + bottom 140 | cumheights[..., 0] = bottom 141 | cumheights[..., -1] = top 142 | heights = cumheights[..., 1:] - cumheights[..., :-1] 143 | 144 | if inverse: 145 | bin_idx = searchsorted(cumheights, inputs)[..., None] 146 | else: 147 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 148 | 149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 150 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 151 | 152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 153 | delta = heights / widths 154 | input_delta = delta.gather(-1, bin_idx)[..., 0] 155 | 156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 158 | 159 | input_heights = heights.gather(-1, bin_idx)[..., 0] 160 | 161 | if inverse: 162 | a = (inputs - input_cumheights) * ( 163 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 164 | ) + input_heights * (input_delta - input_derivatives) 165 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 166 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 167 | ) 168 | c = -input_delta * (inputs - input_cumheights) 169 | 170 | discriminant = b.pow(2) - 4 * a * c 171 | assert (discriminant >= 0).all() 172 | 173 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 174 | outputs = root * input_bin_widths + input_cumwidths 
175 | 176 | theta_one_minus_theta = root * (1 - root) 177 | denominator = input_delta + ( 178 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 179 | * theta_one_minus_theta 180 | ) 181 | derivative_numerator = input_delta.pow(2) * ( 182 | input_derivatives_plus_one * root.pow(2) 183 | + 2 * input_delta * theta_one_minus_theta 184 | + input_derivatives * (1 - root).pow(2) 185 | ) 186 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 187 | 188 | return outputs, -logabsdet 189 | else: 190 | theta = (inputs - input_cumwidths) / input_bin_widths 191 | theta_one_minus_theta = theta * (1 - theta) 192 | 193 | numerator = input_heights * ( 194 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta 195 | ) 196 | denominator = input_delta + ( 197 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 198 | * theta_one_minus_theta 199 | ) 200 | outputs = input_cumheights + numerator / denominator 201 | 202 | derivative_numerator = input_delta.pow(2) * ( 203 | input_derivatives_plus_one * theta.pow(2) 204 | + 2 * input_delta * theta_one_minus_theta 205 | + input_derivatives * (1 - theta).pow(2) 206 | ) 207 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 208 | 209 | return outputs, logabsdet 210 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 
21 | iteration = checkpoint_dict["iteration"] 22 | learning_rate = checkpoint_dict["learning_rate"] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict["optimizer"]) 25 | saved_state_dict = checkpoint_dict["model"] 26 | if hasattr(model, "module"): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except KeyError: 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, "module"): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info( 42 | "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration) 43 | ) 44 | return model, optimizer, learning_rate, iteration 45 | 46 | 47 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 48 | logger.info( 49 | "Saving model and optimizer state at iteration {} to {}".format( 50 | iteration, checkpoint_path 51 | ) 52 | ) 53 | if hasattr(model, "module"): 54 | state_dict = model.module.state_dict() 55 | else: 56 | state_dict = model.state_dict() 57 | torch.save( 58 | { 59 | "model": state_dict, 60 | "iteration": iteration, 61 | "optimizer": optimizer.state_dict(), 62 | "learning_rate": learning_rate, 63 | }, 64 | checkpoint_path, 65 | ) 66 | 67 | 68 | def summarize( 69 | writer, 70 | global_step, 71 | scalars={}, 72 | histograms={}, 73 | images={}, 74 | audios={}, 75 | audio_sampling_rate=22050, 76 | ): 77 | for k, v in scalars.items(): 78 | writer.add_scalar(k, v, global_step) 79 | for k, v in histograms.items(): 80 | writer.add_histogram(k, v, global_step) 81 | for k, v in images.items(): 82 | writer.add_image(k, v, global_step, dataformats="HWC") 83 | for k, v in audios.items(): 84 | writer.add_audio(k, v, global_step, audio_sampling_rate) 85 | 86 | 87 | def scan_checkpoint(dir_path, regex): 88 | 
f_list = glob.glob(os.path.join(dir_path, regex)) 89 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 90 | if len(f_list) == 0: 91 | return None 92 | return f_list 93 | 94 | 95 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 96 | f_list = scan_checkpoint(dir_path, regex) 97 | if not f_list: 98 | return None 99 | x = f_list[-1] 100 | print(x) 101 | return x 102 | 103 | 104 | def remove_old_checkpoints(cp_dir, prefixes=['G_*.pth', 'D_*.pth', 'DUR_*.pth']): 105 | for prefix in prefixes: 106 | sorted_ckpts = scan_checkpoint(cp_dir, prefix) 107 | if sorted_ckpts and len(sorted_ckpts) > 3: 108 | for ckpt_path in sorted_ckpts[:-3]: 109 | os.remove(ckpt_path) 110 | print("removed {}".format(ckpt_path)) 111 | 112 | 113 | def plot_spectrogram_to_numpy(spectrogram): 114 | global MATPLOTLIB_FLAG 115 | if not MATPLOTLIB_FLAG: 116 | import matplotlib 117 | 118 | matplotlib.use("Agg") 119 | MATPLOTLIB_FLAG = True 120 | mpl_logger = logging.getLogger("matplotlib") 121 | mpl_logger.setLevel(logging.WARNING) 122 | import matplotlib.pylab as plt 123 | import numpy as np 124 | 125 | fig, ax = plt.subplots(figsize=(10, 2)) 126 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 127 | plt.colorbar(im, ax=ax) 128 | plt.xlabel("Frames") 129 | plt.ylabel("Channels") 130 | plt.tight_layout() 131 | 132 | fig.canvas.draw() 133 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 134 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 135 | plt.close() 136 | return data 137 | 138 | 139 | def plot_alignment_to_numpy(alignment, info=None): 140 | global MATPLOTLIB_FLAG 141 | if not MATPLOTLIB_FLAG: 142 | import matplotlib 143 | 144 | matplotlib.use("Agg") 145 | MATPLOTLIB_FLAG = True 146 | mpl_logger = logging.getLogger("matplotlib") 147 | mpl_logger.setLevel(logging.WARNING) 148 | import matplotlib.pylab as plt 149 | import numpy as np 150 | 151 | fig, ax = plt.subplots(figsize=(6, 4)) 152 | im = 
ax.imshow( 153 | alignment.transpose(), aspect="auto", origin="lower", interpolation="none" 154 | ) 155 | fig.colorbar(im, ax=ax) 156 | xlabel = "Decoder timestep" 157 | if info is not None: 158 | xlabel += "\n\n" + info 159 | plt.xlabel(xlabel) 160 | plt.ylabel("Encoder timestep") 161 | plt.tight_layout() 162 | 163 | fig.canvas.draw() 164 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 165 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 166 | plt.close() 167 | return data 168 | 169 | 170 | def load_wav_to_torch(full_path): 171 | sampling_rate, data = read(full_path) 172 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 173 | 174 | 175 | def load_filepaths_and_text(filename, split="|"): 176 | with open(filename, encoding="utf-8") as f: 177 | filepaths_and_text = [line.strip().split(split) for line in f] 178 | return filepaths_and_text 179 | 180 | 181 | def get_hparams(init=True): 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument( 184 | "-c", 185 | "--config", 186 | type=str, 187 | default="./configs/base.json", 188 | help="JSON file for configuration", 189 | ) 190 | parser.add_argument("-m", "--model", type=str, required=True, help="Model name") 191 | 192 | args = parser.parse_args() 193 | model_dir = os.path.join("./logs", args.model) 194 | 195 | if not os.path.exists(model_dir): 196 | os.makedirs(model_dir) 197 | 198 | config_path = args.config 199 | config_save_path = os.path.join(model_dir, "config.json") 200 | if init: 201 | with open(config_path, "r") as f: 202 | data = f.read() 203 | with open(config_save_path, "w") as f: 204 | f.write(data) 205 | else: 206 | with open(config_save_path, "r") as f: 207 | data = f.read() 208 | config = json.loads(data) 209 | 210 | hparams = HParams(**config) 211 | hparams.model_dir = model_dir 212 | return hparams 213 | 214 | 215 | def get_hparams_from_dir(model_dir): 216 | config_save_path = os.path.join(model_dir, "config.json") 217 | with 
open(config_save_path, "r") as f: 218 | data = f.read() 219 | config = json.loads(data) 220 | 221 | hparams = HParams(**config) 222 | hparams.model_dir = model_dir 223 | return hparams 224 | 225 | 226 | def get_hparams_from_file(config_path): 227 | with open(config_path, "r") as f: 228 | data = f.read() 229 | config = json.loads(data) 230 | 231 | hparams = HParams(**config) 232 | return hparams 233 | 234 | 235 | def check_git_hash(model_dir): 236 | source_dir = os.path.dirname(os.path.realpath(__file__)) 237 | if not os.path.exists(os.path.join(source_dir, ".git")): 238 | logger.warning( 239 | "{} is not a git repository, therefore hash value comparison will be ignored.".format( 240 | source_dir 241 | ) 242 | ) 243 | return 244 | 245 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 246 | 247 | path = os.path.join(model_dir, "githash") 248 | if os.path.exists(path): 249 | saved_hash = open(path).read() 250 | if saved_hash != cur_hash: 251 | logger.warning( 252 | "git hash values are different. 
{}(saved) != {}(current)".format( 253 | saved_hash[:8], cur_hash[:8] 254 | ) 255 | ) 256 | else: 257 | open(path, "w").write(cur_hash) 258 | 259 | 260 | def get_logger(model_dir, filename="train.log"): 261 | global logger 262 | logger = logging.getLogger(os.path.basename(model_dir)) 263 | logger.setLevel(logging.DEBUG) 264 | 265 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 266 | if not os.path.exists(model_dir): 267 | os.makedirs(model_dir) 268 | h = logging.FileHandler(os.path.join(model_dir, filename)) 269 | h.setLevel(logging.DEBUG) 270 | h.setFormatter(formatter) 271 | logger.addHandler(h) 272 | return logger 273 | 274 | 275 | class HParams: 276 | def __init__(self, **kwargs): 277 | for k, v in kwargs.items(): 278 | if isinstance(v, dict): 279 | v = HParams(**v) 280 | self[k] = v 281 | 282 | def keys(self): 283 | return self.__dict__.keys() 284 | 285 | def items(self): 286 | return self.__dict__.items() 287 | 288 | def values(self): 289 | return self.__dict__.values() 290 | 291 | def __len__(self): 292 | return len(self.__dict__) 293 | 294 | def __getitem__(self, key): 295 | return getattr(self, key) 296 | 297 | def __setitem__(self, key, value): 298 | return setattr(self, key, value) 299 | 300 | def __contains__(self, key): 301 | return key in self.__dict__ 302 | 303 | def __repr__(self): 304 | return self.__dict__.__repr__() 305 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gradio as gr 3 | from gradio import components 4 | import os 5 | import torch 6 | import commons 7 | import utils 8 | from models import SynthesizerTrn 9 | from text.symbols import symbols 10 | from text import text_to_sequence 11 | from scipy.io.wavfile import write 12 | 13 | def get_text(text, hps): 14 | text_norm = text_to_sequence(text, hps.data.text_cleaners) 15 | if hps.data.add_blank: 16 | 
text_norm = commons.intersperse(text_norm, 0) 17 | text_norm = torch.LongTensor(text_norm) 18 | return text_norm 19 | 20 | def tts(model_path, config_path, text): 21 | model_path = './logs/' + model_path 22 | config_path = './configs/' + config_path 23 | hps = utils.get_hparams_from_file(config_path) 24 | 25 | if "use_mel_posterior_encoder" in hps.model.keys() and hps.model.use_mel_posterior_encoder == True: 26 | posterior_channels = 80 27 | hps.data.use_mel_posterior_encoder = True 28 | else: 29 | posterior_channels = hps.data.filter_length // 2 + 1 30 | hps.data.use_mel_posterior_encoder = False 31 | 32 | net_g = SynthesizerTrn( 33 | len(symbols), 34 | posterior_channels, 35 | hps.train.segment_size // hps.data.hop_length, 36 | **hps.model).cuda() 37 | _ = net_g.eval() 38 | _ = utils.load_checkpoint(model_path, net_g, None) 39 | 40 | stn_tst = get_text(text, hps) 41 | x_tst = stn_tst.cuda().unsqueeze(0) 42 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() 43 | 44 | with torch.no_grad(): 45 | audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy() 46 | 47 | output_wav_path = "output.wav" 48 | write(output_wav_path, hps.data.sampling_rate, audio) 49 | 50 | return output_wav_path 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--model_path', type=str, default=None, help='Path to the model file.') 55 | parser.add_argument('--config_path', type=str, default=None, help='Path to the config file.') 56 | args = parser.parse_args() 57 | 58 | model_files = [f for f in os.listdir('./logs/') if f.endswith('.pth')] 59 | model_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True) 60 | config_files = [f for f in os.listdir('./configs/') if f.endswith('.json')] 61 | 62 | default_model_file = args.model_path if args.model_path else (model_files[0] if model_files else None) 63 | default_config_file = args.config_path if 
args.config_path else 'config.json' 64 | 65 | gr.Interface( 66 | fn=tts, 67 | inputs=[components.Dropdown(model_files,value=default_model_file, label="Model File"), components.Dropdown(config_files,value=default_config_file, label="Config File"), components.Textbox(label="Text Input")], 68 | outputs=components.Audio(type='filepath', label="Generated Speech"), 69 | live=False 70 | ).launch() 71 | --------------------------------------------------------------------------------