├── .gitignore ├── LICENSE ├── README.md ├── data_utils.py ├── datasets ├── ljs_base │ ├── config.yaml │ ├── filelists │ │ ├── test.txt │ │ ├── train.txt │ │ └── val.txt │ ├── prepare │ │ └── filelists.ipynb │ └── vocab.txt ├── ljs_nosdp │ └── config.yaml ├── madasr23_base │ ├── config.yaml │ └── prepare │ │ ├── filelists.ipynb │ │ └── metadata.ipynb └── vctk_base │ ├── config.yaml │ └── filelists │ ├── vctk_audio_sid_text_test_filelist.txt │ ├── vctk_audio_sid_text_test_filelist.txt.cleaned │ ├── vctk_audio_sid_text_train_filelist.txt │ ├── vctk_audio_sid_text_train_filelist.txt.cleaned │ ├── vctk_audio_sid_text_val_filelist.txt │ └── vctk_audio_sid_text_val_filelist.txt.cleaned ├── figures ├── figure01.png ├── figure02.png └── figure03.png ├── inference.ipynb ├── inference_batch.ipynb ├── losses.py ├── model ├── condition.py ├── decoder.py ├── discriminator.py ├── duration_predictors.py ├── encoders.py ├── models.py ├── modules.py ├── normalization.py ├── normalizing_flows.py └── transformer.py ├── preprocess ├── README.md ├── audio_find_corrupted.ipynb ├── audio_resample.ipynb ├── audio_resampling.py ├── mel_transform.py └── vocab_generation.ipynb ├── requirements.txt ├── text ├── LICENSE ├── __init__.py ├── cleaners.py ├── normalize_numbers.py └── symbols.py ├── train.py ├── train_ms.py └── utils ├── hparams.py ├── mel_processing.py ├── model.py ├── monotonic_align.py ├── task.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | DUMMY* 2 | 3 | __pycache__ 4 | .ipynb_checkpoints 5 | .*.swp 6 | 7 | build 8 | *.c 9 | monotonic_align/monotonic_align 10 | 11 | .vscode 12 | .DS_Store 13 | 14 | logs 15 | test 16 | datasets/madasr23_base/filelists -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jaehyeon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VITS2: Improving Quality and Efficiency of Single-Stage Text-to-Speech with Adversarial Learning and Architecture Design 2 | 3 | ### Jungil Kong, Jihoon Park, Beomjeong Kim, Jeongmin Kim, Dohee Kong, Sangjin Kim 4 | 5 | ### SK Telecom, South Korea 6 | 7 | Single-stage text-to-speech models have been actively studied recently, and their results have outperformed two-stage pipeline systems. Although the previous single-stage model has made great progress, there is room for improvement in terms of its intermittent unnaturalness, computational efficiency, and strong dependence on phoneme conversion. In this work, we introduce VITS2, a single-stage text-to-speech model that efficiently synthesizes a more natural speech by improving several aspects of the previous work. We propose improved structures and training mechanisms and present that the proposed methods are effective in improving naturalness, similarity of speech characteristics in a multi-speaker model, and efficiency of training and inference. Furthermore, we demonstrate that the strong dependence on phoneme conversion in previous works can be significantly reduced with our method, which allows a fully end-to-end single-stage approach. 8 | 9 | Demo: https://vits-2.github.io/demo/ 10 | 11 | Paper: https://arxiv.org/abs/2307.16430 12 | 13 | Unofficial implementation of VITS2. This is a work in progress. Please refer to [TODO](#todo) for more details. 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
Duration Predictor | Normalizing Flows | Text Encoder (architecture comparison figures; see `figures/`)
27 | 28 | ## Audio Samples 29 | 30 | [In progress] 31 | 32 | Audio sample after 52,000 steps of training on 1 GPU for LJSpeech dataset: 33 | https://github.com/daniilrobnikov/vits2/assets/91742765/d769c77a-bd92-4732-96e7-ab53bf50d783 34 | 35 | ## Installation: 36 | 37 | 38 | 39 | **Clone the repo** 40 | 41 | ```shell 42 | git clone git@github.com:daniilrobnikov/vits2.git 43 | cd vits2 44 | ``` 45 | 46 | ## Setting up the conda env 47 | 48 | This is assuming you have navigated to the `vits2` root after cloning it. 49 | 50 | **NOTE:** This is tested under `python3.11` with conda env. For other python versions, you might encounter version conflicts. 51 | 52 | **PyTorch 2.0** 53 | Please refer [requirements.txt](requirements.txt) 54 | 55 | ```shell 56 | # install required packages (for pytorch 2.0) 57 | conda create -n vits2 python=3.11 58 | conda activate vits2 59 | pip install -r requirements.txt 60 | 61 | conda env config vars set PYTHONPATH="/path/to/vits2" 62 | ``` 63 | 64 | ## Download datasets 65 | 66 | There are three options you can choose from: LJ Speech, VCTK, or custom dataset. 67 | 68 | 1. LJ Speech: [LJ Speech dataset](#lj-speech-dataset). Used for single speaker TTS. 69 | 2. VCTK: [VCTK dataset](#vctk-dataset). Used for multi-speaker TTS. 70 | 3. Custom dataset: You can use your own dataset. Please refer [here](#custom-dataset). 71 | 72 | ### LJ Speech dataset 73 | 74 | 1. download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/) 75 | 76 | ```shell 77 | wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 78 | tar -xvf LJSpeech-1.1.tar.bz2 79 | cd LJSpeech-1.1/wavs 80 | rm -rf wavs 81 | ``` 82 | 83 | 3. preprocess mel-spectrograms. See [mel_transform.py](preprocess/mel_transform.py) 84 | 85 | ```shell 86 | python preprocess/mel_transform.py --data_dir /path/to/LJSpeech-1.1 -c datasets/ljs_base/config.yaml 87 | ``` 88 | 89 | 3. preprocess text. See [prepare/filelists.ipynb](datasets/ljs_base/prepare/filelists.ipynb) 90 | 91 | 4. rename or create a link to the dataset folder. 92 | 93 | ```shell 94 | ln -s /path/to/LJSpeech-1.1 DUMMY1 95 | ``` 96 | 97 | ### VCTK dataset 98 | 99 | 1. download and extract the [VCTK dataset](https://www.kaggle.com/datasets/showmik50/vctk-dataset) 100 | 101 | ```shell 102 | wget https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip 103 | unzip VCTK-Corpus-0.92.zip 104 | ``` 105 | 106 | 2. (optional): downsample the audio files to 22050 Hz. See [audio_resample.ipynb](preprocess/audio_resample.ipynb) 107 | 108 | 3. preprocess mel-spectrograms. See [mel_transform.py](preprocess/mel_transform.py) 109 | 110 | ```shell 111 | python preprocess/mel_transform.py --data_dir /path/to/VCTK-Corpus-0.92 -c datasets/vctk_base/config.yaml 112 | ``` 113 | 114 | 4. preprocess text. See [prepare/filelists.ipynb](datasets/ljs_base/prepare/filelists.ipynb) 115 | 116 | 5. rename or create a link to the dataset folder. 117 | 118 | ```shell 119 | ln -s /path/to/VCTK-Corpus-0.92 DUMMY2 120 | ``` 121 | 122 | ### Custom dataset 123 | 124 | 1. create a folder with wav files 125 | 2. duplicate the `ljs_base` in `datasets` directory and rename it to `custom_base` 126 | 3. 
open [custom_base](datasets/custom_base) and change the following fields in `config.yaml`: 127 | 128 | ```yaml 129 | data: 130 | training_files: datasets/custom_base/filelists/train.txt 131 | validation_files: datasets/custom_base/filelists/val.txt 132 | text_cleaners: # See text/cleaners.py 133 | - phonemize_text 134 | - tokenize_text 135 | - add_bos_eos 136 | cleaned_text: true # True if you ran step 6. 137 | language: en-us # language of your dataset. See espeak-ng 138 | sample_rate: 22050 # sample rate, based on your dataset 139 | ... 140 | n_speakers: 0 # 0 for single speaker, > 0 for multi-speaker 141 | ``` 142 | 143 | 4. preprocess mel-spectrograms. See [mel_transform.py](preprocess/mel_transform.py) 144 | 145 | ```shell 146 | python preprocess/mel_transform.py --data_dir /path/to/custom_dataset -c datasets/custom_base/config.yaml 147 | ``` 148 | 149 | 6. preprocess text. See [prepare/filelists.ipynb](datasets/ljs_base/prepare/filelists.ipynb) 150 | 151 | **NOTE:** You may need to install `espeak-ng` if you want to use `phonemize_text` cleaner. Please refer [espeak-ng](https://github.com/espeak-ng/espeak-ng) 152 | 153 | 7. rename or create a link to the dataset folder. 154 | 155 | ```shell 156 | ln -s /path/to/custom_dataset DUMMY3 157 | ``` 158 | 159 | ## Training Examples 160 | 161 | ```shell 162 | # LJ Speech 163 | python train.py -c datasets/ljs_base/config.yaml -m ljs_base 164 | 165 | # VCTK 166 | python train_ms.py -c datasets/vctk_base/config.yaml -m vctk_base 167 | 168 | # Custom dataset (multi-speaker) 169 | python train_ms.py -c datasets/custom_base/config.yaml -m custom_base 170 | ``` 171 | 172 | ## Inference Examples 173 | 174 | See [inference.ipynb](inference.ipynb) and [inference_batch.ipynb](inference_batch.ipynb) 175 | 176 | ## Pretrained Models 177 | 178 | [In progress] 179 | 180 | ## Todo 181 | 182 | - [ ] model (vits2) 183 | - [x] update TextEncoder to support speaker conditioning 184 | - [x] support for high-resolution mel-spectrograms in training. See [mel_transform.py](preprocess/mel_transform.py) 185 | - [x] Monotonic Alignment Search with Gaussian noise 186 | - [x] Normalizing Flows using Transformer Block 187 | - [ ] Stochastic Duration Predictor with Time Step-wise Conditional Discriminator 188 | - [ ] model (YourTTS) 189 | - [ ] Language Conditioning 190 | - [ ] Speaker Encoder 191 | - [ ] model (NaturalSpeech) 192 | - [x] KL Divergence Loss after Prior Enhancing 193 | - [ ] GAN loss for e2e training 194 | - [ ] other 195 | - [x] support for batch inference 196 | - [x] special tokens in tokenizer 197 | - [x] test numba.jit and numba.cuda.jit implementations of MAS. See [monotonic_align.py](monotonic_align.py) 198 | - [ ] KL Divergence Loss between TextEncoder and Projection 199 | - [ ] support for streaming inference. Please refer [vits_chinese](https://github.com/PlayVoice/vits_chinese/blob/master/text/symbols.py) 200 | - [ ] use optuna for hyperparameter tuning 201 | - [ ] future work 202 | - [ ] update model to vits2. Please refer [VITS2](https://arxiv.org/abs/2307.16430) 203 | - [ ] update model to YourTTS with zero-shot learning. See [YourTTS](https://arxiv.org/abs/2112.02418) 204 | - [ ] update model to NaturalSpeech. 
Please refer [NaturalSpeech](https://arxiv.org/abs/2205.04421) 205 | 206 | ## Acknowledgements 207 | 208 | - This is unofficial repo based on [VITS2](https://arxiv.org/abs/2307.16430) 209 | - g2p for multiple languages is based on [phonemizer](https://github.com/bootphon/phonemizer) 210 | - We also thank GhatGPT for providing writing assistance. 211 | 212 | ## References 213 | 214 | - [VITS2: Improving Quality and Efficiency of Single-Stage Text-to-Speech with Adversarial Learning and Architecture Design](https://arxiv.org/abs/2307.16430) 215 | - [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) 216 | - [YourTTS: Towards Zero-Shot Multi-Speaker TTS and Zero-Shot Voice Conversion for everyone](https://arxiv.org/abs/2112.02418) 217 | - [NaturalSpeech: End-to-End Text to Speech Synthesis with Human-Level Quality](https://arxiv.org/abs/2205.04421) 218 | - [A TensorFlow implementation of Google's Tacotron speech synthesis with pre-trained model (unofficial)](https://github.com/keithito/tacotron) 219 | 220 | # VITS2 221 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torch.utils.data 5 | 6 | from utils.mel_processing import wav_to_spec, wav_to_mel 7 | from utils.task import load_vocab, load_wav_to_torch, load_filepaths_and_text 8 | from text import tokenizer 9 | 10 | 11 | class TextAudioLoader(torch.utils.data.Dataset): 12 | """ 13 | 1) loads audio, text pairs 14 | 2) normalizes text and converts them to sequences of integers 15 | 3) computes spectrograms from audio files. 16 | """ 17 | 18 | def __init__(self, audiopaths_and_text, hps_data): 19 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 20 | self.vocab = load_vocab(hps_data.vocab_file) 21 | self.text_cleaners = hps_data.text_cleaners 22 | self.sample_rate = hps_data.sample_rate 23 | self.n_fft = hps_data.n_fft 24 | self.hop_length = hps_data.hop_length 25 | self.win_length = hps_data.win_length 26 | self.n_mels = hps_data.n_mels 27 | self.f_min = hps_data.f_min 28 | self.f_max = hps_data.f_max 29 | self.use_mel = hps_data.use_mel 30 | 31 | self.language = getattr(hps_data, "language", "en-us") 32 | self.cleaned_text = getattr(hps_data, "cleaned_text", False) 33 | self.min_text_len = getattr(hps_data, "min_text_len", 1) 34 | self.max_text_len = getattr(hps_data, "max_text_len", 200) 35 | 36 | random.seed(1234) 37 | random.shuffle(self.audiopaths_and_text) 38 | self._filter() 39 | 40 | def _filter(self): 41 | """ 42 | Filter text & store spec lengths 43 | """ 44 | # Store spectrogram lengths for Bucketing 45 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 46 | # spec_length = wav_length // hop_length 47 | 48 | audiopaths_and_text_new = [] 49 | lengths = [] 50 | for audiopath, text in self.audiopaths_and_text: 51 | text_len = text.count("\t") + 1 52 | if self.min_text_len <= text_len and text_len <= self.max_text_len: 53 | audiopaths_and_text_new.append([audiopath, text]) 54 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) 55 | self.audiopaths_and_text = audiopaths_and_text_new 56 | self.lengths = lengths 57 | 58 | def get_audio_text_pair(self, audiopath_and_text): 59 | # separate filename and text 60 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1] 61 | text = self.get_text(text) 
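        # Load the waveform (get_audio() checks it matches the target sample rate), then the
        # linear- or mel-spectrogram; get_spec() caches it as "<name>.spec.pt" next to the wav
        # so later epochs skip the STFT.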
62 | wav = self.get_audio(audiopath) 63 | spec = self.get_spec(audiopath, wav) 64 | return (text, spec, wav) 65 | 66 | def get_text(self, text): 67 | text_norm = tokenizer(text, self.vocab, self.text_cleaners, language=self.language, cleaned_text=self.cleaned_text) 68 | text_norm = torch.LongTensor(text_norm) 69 | return text_norm 70 | 71 | def get_audio(self, filename): 72 | audio, sample_rate = load_wav_to_torch(filename) 73 | assert sample_rate == self.sample_rate, f"{sample_rate} SR doesn't match target {self.sample_rate} SR" 74 | return audio 75 | 76 | def get_spec(self, filename: str, wav): 77 | spec_filename = filename.replace(".wav", ".spec.pt") 78 | 79 | if os.path.exists(spec_filename): 80 | spec = torch.load(spec_filename) 81 | else: 82 | if self.use_mel: 83 | spec = wav_to_mel(wav, self.n_fft, self.n_mels, self.sample_rate, self.hop_length, self.win_length, self.f_min, self.f_max, center=False, norm=False) 84 | else: 85 | spec = wav_to_spec(wav, self.n_fft, self.sample_rate, self.hop_length, self.win_length, center=False) 86 | spec = torch.squeeze(spec, 0) 87 | torch.save(spec, spec_filename) 88 | 89 | return spec 90 | 91 | def __getitem__(self, index): 92 | return self.get_audio_text_pair(self.audiopaths_and_text[index]) 93 | 94 | def __len__(self): 95 | return len(self.audiopaths_and_text) 96 | 97 | 98 | class TextAudioCollate: 99 | """Zero-pads model inputs and targets""" 100 | 101 | def __init__(self, return_ids=False): 102 | self.return_ids = return_ids 103 | 104 | def __call__(self, batch): 105 | """Collate's training batch from normalized text and aduio 106 | PARAMS 107 | ------ 108 | batch: [text_normalized, spec_normalized, wav_normalized] 109 | """ 110 | # Right zero-pad all one-hot text sequences to max input length 111 | _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True) 112 | 113 | max_text_len = max([len(x[0]) for x in batch]) 114 | max_spec_len = max([x[1].size(1) for x in batch]) 115 | max_wav_len = max([x[2].size(1) for x in batch]) 116 | 117 | text_lengths = torch.LongTensor(len(batch)) 118 | spec_lengths = torch.LongTensor(len(batch)) 119 | wav_lengths = torch.LongTensor(len(batch)) 120 | 121 | text_padded = torch.LongTensor(len(batch), max_text_len) 122 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 123 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 124 | text_padded.zero_() 125 | spec_padded.zero_() 126 | wav_padded.zero_() 127 | for i in range(len(ids_sorted_decreasing)): 128 | row = batch[ids_sorted_decreasing[i]] 129 | 130 | text = row[0] 131 | text_padded[i, : text.size(0)] = text 132 | text_lengths[i] = text.size(0) 133 | 134 | spec = row[1] 135 | spec_padded[i, :, : spec.size(1)] = spec 136 | spec_lengths[i] = spec.size(1) 137 | 138 | wav = row[2] 139 | wav_padded[i, :, : wav.size(1)] = wav 140 | wav_lengths[i] = wav.size(1) 141 | 142 | if self.return_ids: 143 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, ids_sorted_decreasing 144 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths 145 | 146 | 147 | """Multi speaker version""" 148 | 149 | 150 | class TextAudioSpeakerLoader(torch.utils.data.Dataset): 151 | """ 152 | 1) loads audio, speaker_id, text pairs 153 | 2) normalizes text and converts them to sequences of integers 154 | 3) computes spectrograms from audio files. 
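    4) expects filelist lines of the form "<wav_path>|<speaker_id>|<text>"
       (see datasets/vctk_base/filelists for examples)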
155 | """ 156 | 157 | def __init__(self, audiopaths_sid_text, hps_data): 158 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) 159 | self.vocab = load_vocab(hps_data.vocab_file) 160 | self.text_cleaners = hps_data.text_cleaners 161 | self.sample_rate = hps_data.sample_rate 162 | self.n_fft = hps_data.n_fft 163 | self.hop_length = hps_data.hop_length 164 | self.win_length = hps_data.win_length 165 | self.n_mels = hps_data.n_mels 166 | self.f_min = hps_data.f_min 167 | self.f_max = hps_data.f_max 168 | self.use_mel = hps_data.use_mel 169 | 170 | self.language = getattr(hps_data, "language", "en-us") 171 | self.cleaned_text = getattr(hps_data, "cleaned_text", False) 172 | self.min_text_len = getattr(hps_data, "min_text_len", 1) 173 | self.max_text_len = getattr(hps_data, "max_text_len", 200) 174 | 175 | random.seed(1234) 176 | random.shuffle(self.audiopaths_sid_text) 177 | self._filter() 178 | 179 | def _filter(self): 180 | """ 181 | Filter text & store spec lengths 182 | """ 183 | # Store spectrogram lengths for Bucketing 184 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 185 | # spec_length = wav_length // hop_length 186 | 187 | audiopaths_sid_text_new = [] 188 | lengths = [] 189 | for audiopath, sid, text in self.audiopaths_sid_text: 190 | text_len = text.count("\t") + 1 191 | if self.min_text_len <= text_len and text_len <= self.max_text_len: 192 | audiopaths_sid_text_new.append([audiopath, sid, text]) 193 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) 194 | self.audiopaths_sid_text = audiopaths_sid_text_new 195 | self.lengths = lengths 196 | 197 | def get_audio_text_speaker_pair(self, audiopath_sid_text): 198 | # separate filename, speaker_id and text 199 | audiopath, sid, text = audiopath_sid_text[0], audiopath_sid_text[1], audiopath_sid_text[2] 200 | text = self.get_text(text) 201 | wav = self.get_audio(audiopath) 202 | spec = self.get_spec(audiopath, wav) 203 | sid = self.get_sid(sid) 204 | return (text, spec, wav, sid) 205 | 206 | def get_text(self, text): 207 | text_norm = tokenizer(text, self.vocab, self.text_cleaners, language=self.language, cleaned_text=self.cleaned_text) 208 | text_norm = torch.LongTensor(text_norm) 209 | return text_norm 210 | 211 | def get_audio(self, filename): 212 | audio, sample_rate = load_wav_to_torch(filename) 213 | assert sample_rate == self.sample_rate, f"{sample_rate} SR doesn't match target {self.sample_rate} SR" 214 | return audio 215 | 216 | def get_spec(self, filename: str, wav): 217 | spec_filename = filename.replace(".wav", ".spec.pt") 218 | 219 | if os.path.exists(spec_filename): 220 | spec = torch.load(spec_filename) 221 | else: 222 | if self.use_mel: 223 | spec = wav_to_mel(wav, self.n_fft, self.n_mels, self.sample_rate, self.hop_length, self.win_length, self.f_min, self.f_max, center=False, norm=False) 224 | else: 225 | spec = wav_to_spec(wav, self.n_fft, self.sample_rate, self.hop_length, self.win_length, center=False) 226 | spec = torch.squeeze(spec, 0) 227 | torch.save(spec, spec_filename) 228 | 229 | return spec 230 | 231 | def get_sid(self, sid): 232 | sid = torch.LongTensor([int(sid)]) 233 | return sid 234 | 235 | def __getitem__(self, index): 236 | return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) 237 | 238 | def __len__(self): 239 | return len(self.audiopaths_sid_text) 240 | 241 | 242 | class TextAudioSpeakerCollate: 243 | """Zero-pads model inputs and targets""" 244 | 245 | def __init__(self, return_ids=False): 246 | 
self.return_ids = return_ids 247 | 248 | def __call__(self, batch): 249 | """Collate's training batch from normalized text, audio and speaker identities 250 | PARAMS 251 | ------ 252 | batch: [text_normalized, spec_normalized, wav_normalized, sid] 253 | """ 254 | # Right zero-pad all one-hot text sequences to max input length 255 | _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True) 256 | 257 | max_text_len = max([len(x[0]) for x in batch]) 258 | max_spec_len = max([x[1].size(1) for x in batch]) 259 | max_wav_len = max([x[2].size(1) for x in batch]) 260 | 261 | text_lengths = torch.LongTensor(len(batch)) 262 | spec_lengths = torch.LongTensor(len(batch)) 263 | wav_lengths = torch.LongTensor(len(batch)) 264 | sid = torch.LongTensor(len(batch)) 265 | 266 | text_padded = torch.LongTensor(len(batch), max_text_len) 267 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 268 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 269 | text_padded.zero_() 270 | spec_padded.zero_() 271 | wav_padded.zero_() 272 | for i in range(len(ids_sorted_decreasing)): 273 | row = batch[ids_sorted_decreasing[i]] 274 | 275 | text = row[0] 276 | text_padded[i, : text.size(0)] = text 277 | text_lengths[i] = text.size(0) 278 | 279 | spec = row[1] 280 | spec_padded[i, :, : spec.size(1)] = spec 281 | spec_lengths[i] = spec.size(1) 282 | 283 | wav = row[2] 284 | wav_padded[i, :, : wav.size(1)] = wav 285 | wav_lengths[i] = wav.size(1) 286 | 287 | sid[i] = row[3] 288 | 289 | if self.return_ids: 290 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing 291 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid 292 | 293 | 294 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 295 | """ 296 | Maintain similar input lengths in a batch. 297 | Length groups are specified by boundaries. 298 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 299 | 300 | It removes samples which are not included in the boundaries. 301 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
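    Ex) illustrative usage (batch_size, boundaries, n_gpus and rank are placeholders taken from the
        training setup) with a dataset that exposes `.lengths`, e.g. TextAudioLoader:
            sampler = DistributedBucketSampler(train_dataset, batch_size=32,
                                               boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
                                               num_replicas=n_gpus, rank=rank, shuffle=True)
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler,
                                                       collate_fn=TextAudioCollate(), num_workers=4)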
302 | """ 303 | 304 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): 305 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 306 | self.lengths = dataset.lengths 307 | self.batch_size = batch_size 308 | self.boundaries = boundaries 309 | 310 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 311 | self.total_size = sum(self.num_samples_per_bucket) 312 | self.num_samples = self.total_size // self.num_replicas 313 | 314 | def _create_buckets(self): 315 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 316 | for i in range(len(self.lengths)): 317 | length = self.lengths[i] 318 | idx_bucket = self._bisect(length) 319 | if idx_bucket != -1: 320 | buckets[idx_bucket].append(i) 321 | 322 | for i in range(len(buckets) - 1, 0, -1): 323 | if len(buckets[i]) == 0: 324 | buckets.pop(i) 325 | self.boundaries.pop(i + 1) 326 | 327 | num_samples_per_bucket = [] 328 | for i in range(len(buckets)): 329 | len_bucket = len(buckets[i]) 330 | total_batch_size = self.num_replicas * self.batch_size 331 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size 332 | num_samples_per_bucket.append(len_bucket + rem) 333 | return buckets, num_samples_per_bucket 334 | 335 | def __iter__(self): 336 | # deterministically shuffle based on epoch 337 | g = torch.Generator() 338 | g.manual_seed(self.epoch) 339 | 340 | indices = [] 341 | if self.shuffle: 342 | for bucket in self.buckets: 343 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 344 | else: 345 | for bucket in self.buckets: 346 | indices.append(list(range(len(bucket)))) 347 | 348 | batches = [] 349 | for i in range(len(self.buckets)): 350 | bucket = self.buckets[i] 351 | len_bucket = len(bucket) 352 | ids_bucket = indices[i] 353 | num_samples_bucket = self.num_samples_per_bucket[i] 354 | 355 | # add extra samples to make it evenly divisible 356 | rem = num_samples_bucket - len_bucket 357 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)] 358 | 359 | # subsample 360 | ids_bucket = ids_bucket[self.rank :: self.num_replicas] 361 | 362 | # batching 363 | for j in range(len(ids_bucket) // self.batch_size): 364 | batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]] 365 | batches.append(batch) 366 | 367 | if self.shuffle: 368 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 369 | batches = [batches[i] for i in batch_ids] 370 | self.batches = batches 371 | 372 | assert len(self.batches) * self.batch_size == self.num_samples 373 | return iter(self.batches) 374 | 375 | def _bisect(self, x, lo=0, hi=None): 376 | if hi is None: 377 | hi = len(self.boundaries) - 1 378 | 379 | if hi > lo: 380 | mid = (hi + lo) // 2 381 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 382 | return mid 383 | elif x <= self.boundaries[mid]: 384 | return self._bisect(x, lo, mid) 385 | else: 386 | return self._bisect(x, mid + 1, hi) 387 | else: 388 | return -1 389 | 390 | def __len__(self): 391 | return self.num_samples // self.batch_size 392 | -------------------------------------------------------------------------------- /datasets/ljs_base/config.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | log_interval: 100 3 | eval_interval: 1000 4 | seed: 1234 5 | epochs: 20000 6 | learning_rate: 2.0e-4 7 | betas: [0.8, 0.99] 8 | eps: 1.0e-09 9 | batch_size: 64 # TODO Try more 10 | fp16_run: 
true 11 | lr_decay: 0.999875 12 | segment_size: 8192 13 | init_lr_ratio: 1 14 | warmup_epochs: 0 15 | c_mel: 45 16 | c_kl_text: 0 # default: 0 17 | c_kl_dur: 2 # default: 2 18 | c_kl_audio: 0.05 # 0.05 Allow more audio variation. default: 1 19 | 20 | data: 21 | training_files: datasets/ljs_base/filelists/train.txt 22 | validation_files: datasets/ljs_base/filelists/val.txt 23 | vocab_file: datasets/ljs_base/vocab.txt 24 | text_cleaners: 25 | - phonemize_text 26 | - add_spaces 27 | - tokenize_text 28 | - add_bos_eos 29 | cleaned_text: true 30 | language: en-us 31 | bits_per_sample: 16 32 | sample_rate: 22050 33 | n_fft: 2048 34 | hop_length: 256 35 | win_length: 1024 36 | n_mels: 80 # 100 works better with "slaney" mel-transform. default: 80 37 | f_min: 0 38 | f_max: 39 | n_speakers: 0 40 | use_mel: true 41 | 42 | model: 43 | inter_channels: 192 44 | hidden_channels: 192 45 | filter_channels: 768 46 | n_heads: 2 47 | n_layers: 6 48 | n_layers_q: 12 # default: 16 49 | n_flows: 8 # default: 4 50 | kernel_size: 3 51 | p_dropout: 0.1 52 | speaker_cond_layer: 0 # 0 to disable speaker conditioning 53 | resblock: "1" 54 | resblock_kernel_sizes: [3, 7, 11] 55 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 56 | upsample_rates: [8, 8, 2, 2] 57 | upsample_initial_channel: 512 58 | upsample_kernel_sizes: [16, 16, 4, 4] 59 | mas_noise_scale: 0.01 # 0.0 to disable Gaussian noise in MAS 60 | mas_noise_scale_decay: 2.0e-06 61 | use_spectral_norm: false 62 | use_transformer_flow: false 63 | -------------------------------------------------------------------------------- /datasets/ljs_base/vocab.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | n 6 8 | t 7 9 | ə 8 10 | s 9 11 | d 10 12 | ð 11 13 | ɹ 12 14 | k 13 15 | z 14 16 | ɪ 15 17 | l 16 18 | m 17 19 | ˈɪ 18 20 | p 19 21 | w 20 22 | v 21 23 | ˈɛ 22 24 | f 23 25 | ˈeɪ 24 26 | b 25 27 | ɚ 26 28 | , 27 29 | ʌ 28 30 | ˈæ 29 31 | h 30 32 | ᵻ 31 33 | i 32 34 | æ 33 35 | . 34 36 | ˈaɪ 35 37 | ˈiː 36 38 | ʃ 37 39 | uː 38 40 | ˈoʊ 39 41 | ˈɑː 40 42 | ˈʌ 41 43 | ŋ 42 44 | əl 43 45 | ˈuː 44 46 | ɾ 45 47 | ɡ 46 48 | ɐ 47 49 | ˈɜː 48 50 | dʒ 49 51 | tʃ 50 52 | iː 51 53 | j 52 54 | ˈaʊ 53 55 | θ 54 56 | ˌɪ 55 57 | ˈɔː 56 58 | ˈɔ 57 59 | ˈoːɹ 58 60 | ɔːɹ 59 61 | ɛ 60 62 | ˌɛ 61 63 | ˌʌ 62 64 | ˈɑːɹ 63 65 | ˌæ 64 66 | ˈɔːɹ 65 67 | ˈʊ 66 68 | ɜː 67 69 | oʊ 68 70 | eɪ 69 71 | ˈɛɹ 70 72 | ˈɪɹ 71 73 | " 72 74 | ˌeɪ 73 75 | iə 74 76 | ʊ 75 77 | ˌaɪ 76 78 | ˈɔɪ 77 79 | ˌɑː 78 80 | ; 79 81 | aɪ 80 82 | ɛɹ 81 83 | ˈʊɹ 82 84 | ɑːɹ 83 85 | ʒ 84 86 | ˈaɪɚ 85 87 | ˌiː 86 88 | ˌuː 87 89 | ˌoʊ 88 90 | aʊ 89 91 | ˈiə 90 92 | ɑː 91 93 | ɔː 92 94 | n̩ 93 95 | ʔ 94 96 | ˈaɪə 95 97 | : 96 98 | oːɹ 97 99 | ˌaʊ 98 100 | ˌɑːɹ 99 101 | ˌɜː 100 102 | ˌoː 101 103 | ˈoː 102 104 | ? 103 105 | ˌɔːɹ 104 106 | ˌɔː 105 107 | ɪɹ 106 108 | ʊɹ 107 109 | oː 108 110 | ! 
109 111 | ɔɪ 110 112 | ˌʊɹ 111 113 | ˌʊ 112 114 | ˌiə 113 115 | ˌɔɪ 114 116 | r 115 117 | ɔ 116 118 | ˌoːɹ 117 119 | aɪə 118 120 | ˌɪɹ 119 121 | aɪɚ 120 122 | ˌɔ 121 123 | ˌɛɹ 122 124 | x 123 125 | “ 124 126 | ” 125 127 | ˈɚ 126 128 | ˌaɪɚ 127 129 | ˌn̩ 128 130 | -------------------------------------------------------------------------------- /datasets/ljs_nosdp/config.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | log_interval: 100 3 | eval_interval: 1000 4 | seed: 1234 5 | epochs: 20000 6 | learning_rate: 2.0e-4 7 | betas: [0.8, 0.99] 8 | eps: 1.0e-09 9 | batch_size: 64 10 | fp16_run: true 11 | lr_decay: 0.999875 12 | segment_size: 8192 13 | init_lr_ratio: 1 14 | warmup_epochs: 0 15 | c_mel: 45 16 | c_kl_text: 0 # default: 0 17 | c_kl_dur: 2 # default: 2 18 | c_kl_audio: 0.05 # 0.05 Allow more audio variation. default: 1 19 | 20 | data: 21 | training_files: datasets/ljs_base/filelists/train.txt 22 | validation_files: datasets/ljs_base/filelists/val.txt 23 | vocab_file: datasets/ljs_base/vocab.txt 24 | text_cleaners: 25 | - phonemize_text 26 | - add_spaces 27 | - tokenize_text 28 | - add_bos_eos 29 | cleaned_text: true 30 | language: en-us 31 | bits_per_sample: 16 32 | sample_rate: 22050 33 | n_fft: 2048 34 | hop_length: 256 35 | win_length: 1024 36 | n_mels: 80 # 100 works better with "slaney" mel-transform. default: 80 37 | f_min: 0 38 | f_max: 39 | n_speakers: 0 40 | use_mel: true 41 | 42 | model: 43 | inter_channels: 192 44 | hidden_channels: 192 45 | filter_channels: 768 46 | n_heads: 2 47 | n_layers: 6 48 | n_layers_q: 12 # default: 16 49 | n_flows: 8 # default: 4 50 | kernel_size: 3 51 | p_dropout: 0.1 52 | speaker_cond_layer: 0 # 0 to disable speaker conditioning 53 | resblock: "1" 54 | resblock_kernel_sizes: [3, 7, 11] 55 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 56 | upsample_rates: [8, 8, 2, 2] 57 | upsample_initial_channel: 512 58 | upsample_kernel_sizes: [16, 16, 4, 4] 59 | mas_noise_scale: 0.01 # 0.0 to disable Gaussian noise in MAS 60 | mas_noise_scale_decay: 2.0e-06 61 | use_spectral_norm: false 62 | use_transformer_flow: false 63 | use_sdp: false 64 | -------------------------------------------------------------------------------- /datasets/madasr23_base/config.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | log_interval: 500 3 | eval_interval: 5000 4 | seed: 1234 5 | epochs: 10000 6 | learning_rate: 0.0002 7 | betas: [0.8, 0.99] 8 | eps: 1.0e-09 9 | batch_size: 64 10 | fp16_run: true 11 | lr_decay: 0.999875 12 | segment_size: 8192 13 | init_lr_ratio: 1 14 | warmup_epochs: 0 15 | c_mel: 45 16 | c_kl_text: 0 # default: 0 17 | c_kl_dur: 2 # default: 2 18 | c_kl_audio: 0.05 # 0.05 Allow more audio variation. default: 1 19 | 20 | data: 21 | training_files: datasets/madasr23_base/filelists/train.txt 22 | validation_files: datasets/madasr23_base/filelists/val.txt 23 | vocab_file: datasets/madasr23_base/vocab.txt 24 | text_cleaners: 25 | - phonemize_text 26 | - add_spaces 27 | - tokenize_text 28 | - add_bos_eos 29 | cleaned_text: true 30 | language: bn 31 | bits_per_sample: 16 32 | sample_rate: 16000 33 | n_fft: 2048 34 | hop_length: 256 35 | win_length: 1024 36 | n_mels: 80 # 100 works better with "slaney" mel-transform. 
default: 80 37 | f_min: 0 38 | f_max: 39 | n_speakers: 2011 40 | use_mel: true 41 | 42 | model: 43 | inter_channels: 192 44 | hidden_channels: 192 45 | filter_channels: 768 46 | n_heads: 2 47 | n_layers: 6 48 | n_layers_q: 12 # default: 16 49 | n_flows: 8 # default: 4 50 | kernel_size: 3 51 | p_dropout: 0.1 52 | speaker_cond_layer: 3 # 0 to disable speaker conditioning 53 | resblock: "1" 54 | resblock_kernel_sizes: [3, 7, 11] 55 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 56 | upsample_rates: [8, 8, 2, 2] 57 | upsample_initial_channel: 512 58 | upsample_kernel_sizes: [16, 16, 4, 4] 59 | mas_noise_scale: 0.01 # 0.0 to disable Gaussian noise in MAS 60 | mas_noise_scale_decay: 2.0e-06 61 | use_spectral_norm: false 62 | use_transformer_flow: false 63 | gin_channels: 256 64 | -------------------------------------------------------------------------------- /datasets/vctk_base/config.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | log_interval: 200 3 | eval_interval: 1000 4 | seed: 1234 5 | epochs: 10000 6 | learning_rate: 0.0002 7 | betas: [0.8, 0.99] 8 | eps: 1.0e-09 9 | batch_size: 64 10 | fp16_run: true 11 | lr_decay: 0.999875 12 | segment_size: 8192 13 | init_lr_ratio: 1 14 | warmup_epochs: 0 15 | c_mel: 45 16 | c_kl_text: 0 # default: 0 17 | c_kl_dur: 2 # default: 2 18 | c_kl_audio: 0.05 # 0.05 Allow more audio variation. default: 1 19 | 20 | data: 21 | training_files: datasets/vctk_base/filelists/vctk_audio_sid_text_train_filelist.txt.cleaned 22 | validation_files: datasets/vctk_base/filelists/vctk_audio_sid_text_val_filelist.txt.cleaned 23 | vocab_file: datasets/vctk_base/vocab.txt 24 | text_cleaners: 25 | - phonemize_text 26 | - add_spaces 27 | - tokenize_text 28 | - add_bos_eos 29 | cleaned_text: true 30 | language: en-us 31 | bits_per_sample: 16 32 | sample_rate: 22050 33 | n_fft: 2048 34 | hop_length: 256 35 | win_length: 1024 36 | n_mels: 80 # 100 works better with "slaney" mel-transform. default: 80 37 | f_min: 0 38 | f_max: 39 | n_speakers: 109 40 | use_mel: true 41 | 42 | model: 43 | inter_channels: 192 44 | hidden_channels: 192 45 | filter_channels: 768 46 | n_heads: 2 47 | n_layers: 6 48 | n_layers_q: 12 # default: 16 49 | n_flows: 8 # default: 4 50 | kernel_size: 3 51 | p_dropout: 0.1 52 | speaker_cond_layer: 3 # 0 to disable speaker conditioning 53 | resblock: "1" 54 | resblock_kernel_sizes: [3, 7, 11] 55 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 56 | upsample_rates: [8, 8, 2, 2] 57 | upsample_initial_channel: 512 58 | upsample_kernel_sizes: [16, 16, 4, 4] 59 | mas_noise_scale: 0.01 # 0.0 to disable Gaussian noise in MAS 60 | mas_noise_scale_decay: 2.0e-06 61 | use_spectral_norm: false 62 | use_transformer_flow: false 63 | gin_channels: 256 64 | -------------------------------------------------------------------------------- /datasets/vctk_base/filelists/vctk_audio_sid_text_val_filelist.txt: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|It had happened to him. 2 | DUMMY2/p280/p280_148.wav|52|It is open season on the Old Firm. 3 | DUMMY2/p231/p231_320.wav|50|However, he is a coach, and he remains a coach at heart. 4 | DUMMY2/p282/p282_129.wav|83|It is not a U-turn. 5 | DUMMY2/p254/p254_015.wav|41|The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain. 6 | DUMMY2/p228/p228_285.wav|57|The songs are just so good. 
7 | DUMMY2/p334/p334_307.wav|38|If they don't, they can expect their funding to be cut. 8 | DUMMY2/p287/p287_081.wav|77|I've never seen anything like it. 9 | DUMMY2/p247/p247_083.wav|14|It is a job creation scheme.) 10 | DUMMY2/p264/p264_051.wav|65|We were leading by two goals.) 11 | DUMMY2/p335/p335_058.wav|49|Let's see that increase over the years. 12 | DUMMY2/p236/p236_225.wav|75|There is no quick fix. 13 | DUMMY2/p374/p374_353.wav|11|And that brings us to the point. 14 | DUMMY2/p272/p272_076.wav|69|Sounds like The Sixth Sense? 15 | DUMMY2/p271/p271_152.wav|27|The petition was formally presented at Downing Street yesterday. 16 | DUMMY2/p228/p228_127.wav|57|They've got to account for it. 17 | DUMMY2/p276/p276_223.wav|106|It's been a humbling year. 18 | DUMMY2/p262/p262_248.wav|45|The project has already secured the support of Sir Sean Connery. 19 | DUMMY2/p314/p314_086.wav|51|The team this year is going places. 20 | DUMMY2/p225/p225_038.wav|101|Diving is no part of football. 21 | DUMMY2/p279/p279_088.wav|25|The shareholders will vote to wind up the company on Friday morning. 22 | DUMMY2/p272/p272_018.wav|69|Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain. 23 | DUMMY2/p256/p256_098.wav|90|She told The Herald. 24 | DUMMY2/p261/p261_218.wav|100|All will be revealed in due course. 25 | DUMMY2/p265/p265_063.wav|73|IT shouldn't come as a surprise, but it does. 26 | DUMMY2/p314/p314_042.wav|51|It is all about people being assaulted, abused. 27 | DUMMY2/p241/p241_188.wav|86|I wish I could say something. 28 | DUMMY2/p283/p283_111.wav|95|It's good to have a voice. 29 | DUMMY2/p275/p275_006.wav|40|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. 30 | DUMMY2/p228/p228_092.wav|57|Today I couldn't run on it. 31 | DUMMY2/p295/p295_343.wav|92|The atmosphere is businesslike. 32 | DUMMY2/p228/p228_187.wav|57|They will run a mile. 33 | DUMMY2/p294/p294_317.wav|104|It didn't put me off. 34 | DUMMY2/p231/p231_445.wav|50|It sounded like a bomb. 35 | DUMMY2/p272/p272_086.wav|69|Today she has been released. 36 | DUMMY2/p255/p255_210.wav|31|It was worth a photograph. 37 | DUMMY2/p229/p229_060.wav|67|And a film maker was born. 38 | DUMMY2/p260/p260_232.wav|81|The Home Office would not release any further details about the group. 39 | DUMMY2/p245/p245_025.wav|59|Johnson was pretty low. 40 | DUMMY2/p333/p333_185.wav|64|This area is perfect for children. 41 | DUMMY2/p244/p244_242.wav|78|He is a man of the people. 42 | DUMMY2/p376/p376_187.wav|71|"It is a terrible loss." 43 | DUMMY2/p239/p239_156.wav|48|It is a good lifestyle. 44 | DUMMY2/p307/p307_037.wav|22|He released a half-dozen solo albums. 45 | DUMMY2/p305/p305_185.wav|54|I am not even thinking about that. 46 | DUMMY2/p272/p272_081.wav|69|It was magic. 47 | DUMMY2/p302/p302_297.wav|30|I'm trying to stay open on that. 48 | DUMMY2/p275/p275_320.wav|40|We are in the end game. 49 | DUMMY2/p239/p239_231.wav|48|Then we will face the Danish champions. 50 | DUMMY2/p268/p268_301.wav|87|It was only later that the condition was diagnosed. 51 | DUMMY2/p336/p336_088.wav|98|They failed to reach agreement yesterday. 52 | DUMMY2/p278/p278_255.wav|10|They made such decisions in London. 53 | DUMMY2/p361/p361_132.wav|79|That got me out. 54 | DUMMY2/p307/p307_146.wav|22|You hope he prevails. 55 | DUMMY2/p244/p244_147.wav|78|They could not ignore the will of parliament, he claimed. 56 | DUMMY2/p294/p294_283.wav|104|This is our unfinished business. 
57 | DUMMY2/p283/p283_300.wav|95|I would have the hammer in the crowd. 58 | DUMMY2/p239/p239_079.wav|48|I can understand the frustrations of our fans. 59 | DUMMY2/p264/p264_009.wav|65|There is , according to legend, a boiling pot of gold at one end. ) 60 | DUMMY2/p307/p307_348.wav|22|He did not oppose the divorce. 61 | DUMMY2/p304/p304_308.wav|72|We are the gateway to justice. 62 | DUMMY2/p281/p281_056.wav|36|None has ever been found. 63 | DUMMY2/p267/p267_158.wav|0|We were given a warm and friendly reception. 64 | DUMMY2/p300/p300_169.wav|102|Who do these people think they are? 65 | DUMMY2/p276/p276_177.wav|106|They exist in name alone. 66 | DUMMY2/p228/p228_245.wav|57|It is a policy which has the full support of the minister. 67 | DUMMY2/p300/p300_303.wav|102|I'm wondering what you feel about the youngest. 68 | DUMMY2/p362/p362_247.wav|15|This would give Scotland around eight members. 69 | DUMMY2/p326/p326_031.wav|28|United were in control without always being dominant. 70 | DUMMY2/p361/p361_288.wav|79|I did not think it was very proper. 71 | DUMMY2/p286/p286_145.wav|63|Tiger is not the norm. 72 | DUMMY2/p234/p234_071.wav|3|She did that for the rest of her life. 73 | DUMMY2/p263/p263_296.wav|39|The decision was announced at its annual conference in Dunfermline. 74 | DUMMY2/p323/p323_228.wav|34|She became a heroine of my childhood. 75 | DUMMY2/p280/p280_346.wav|52|It was a bit like having children. 76 | DUMMY2/p333/p333_080.wav|64|But the tragedy did not stop there. 77 | DUMMY2/p226/p226_268.wav|43|That decision is for the British Parliament and people. 78 | DUMMY2/p362/p362_314.wav|15|Is that right? 79 | DUMMY2/p240/p240_047.wav|93|It is so sad. 80 | DUMMY2/p250/p250_207.wav|24|You could feel the heat. 81 | DUMMY2/p273/p273_176.wav|56|Neither side would reveal the details of the offer. 82 | DUMMY2/p316/p316_147.wav|85|And frankly, it's been a while. 83 | DUMMY2/p265/p265_047.wav|73|It is unique. 84 | DUMMY2/p336/p336_353.wav|98|Sometimes you get them, sometimes you don't. 85 | DUMMY2/p230/p230_376.wav|35|This hasn't happened in a vacuum. 86 | DUMMY2/p308/p308_209.wav|107|There is great potential on this river. 87 | DUMMY2/p250/p250_442.wav|24|We have not yet received a letter from the Irish. 88 | DUMMY2/p260/p260_037.wav|81|It's a fact. 89 | DUMMY2/p299/p299_345.wav|58|We're very excited and challenged by the project. 90 | DUMMY2/p269/p269_218.wav|94|A Grampian Police spokesman said. 91 | DUMMY2/p306/p306_014.wav|12|To the Hebrews it was a token that there would be no more universal floods. 92 | DUMMY2/p271/p271_292.wav|27|It's a record label, not a form of music. 93 | DUMMY2/p247/p247_225.wav|14|I am considered a teenager.) 94 | DUMMY2/p294/p294_094.wav|104|It should be a condition of employment. 95 | DUMMY2/p269/p269_031.wav|94|Is this accurate? 96 | DUMMY2/p275/p275_116.wav|40|It's not fair. 97 | DUMMY2/p265/p265_006.wav|73|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. 98 | DUMMY2/p285/p285_072.wav|2|Mr Irvine said Mr Rafferty was now in good spirits. 99 | DUMMY2/p270/p270_167.wav|8|We did what we had to do. 100 | DUMMY2/p360/p360_397.wav|60|It is a relief. 101 | -------------------------------------------------------------------------------- /datasets/vctk_base/filelists/vctk_audio_sid_text_val_filelist.txt.cleaned: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm. 2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm. 
3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt. 4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn. 5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn. 6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd. 7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt. 8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt. 9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm. 10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz. 11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz. 12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks. 13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt. 14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns? 15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ. 16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt. 17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ. 18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi. 19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz. 20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl. 21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ. 22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn. 23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld. 24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs. 25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz. 26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd. 27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ. 28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs. 29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt. 31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk. 32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl. 33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf. 34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm. 35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd. 36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf. 37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn. 38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp. 39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ. 40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən. 41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl. 42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs." 43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl. 44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz. 45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt. 46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk. 47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt. 48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm. 49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz. 50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd. 
51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ. 52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən. 53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt. 54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz. 55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd. 56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs. 57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd. 58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz. 59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd. 60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs. 61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs. 62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd. 63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən. 64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ? 65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn. 66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ. 67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst. 68 | DUMMY2/p362/p362_247.wav|15|ðɪs wʊd ɡˈɪv skˈɑːtlənd ɐɹˈaʊnd ˈeɪt mˈɛmbɚz. 69 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt. 70 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ. 71 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm. 72 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf. 73 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn. 74 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd. 75 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən. 76 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ. 77 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl. 78 | DUMMY2/p362/p362_314.wav|15|ɪz ðæt ɹˈaɪt? 79 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd. 80 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt. 81 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ. 82 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl. 83 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk. 84 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt. 85 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm. 86 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ. 87 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ. 88 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt. 89 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt. 90 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd. 91 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz. 92 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk. 93 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ. 94 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt. 95 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət? 96 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ. 97 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 
98 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts. 99 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː. 100 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf. 101 | -------------------------------------------------------------------------------- /figures/figure01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniilrobnikov/vits2/0525da4a558da999a725b9fddaa4584617df328b/figures/figure01.png -------------------------------------------------------------------------------- /figures/figure02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniilrobnikov/vits2/0525da4a558da999a725b9fddaa4584617df328b/figures/figure02.png -------------------------------------------------------------------------------- /figures/figure03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniilrobnikov/vits2/0525da4a558da999a725b9fddaa4584617df328b/figures/figure03.png -------------------------------------------------------------------------------- /inference_batch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "from tqdm import tqdm\n", 11 | "import torch\n", 12 | "import torchaudio\n", 13 | "from torch.utils.data import Dataset, DataLoader\n", 14 | "\n", 15 | "from utils.task import load_checkpoint\n", 16 | "from utils.hparams import get_hparams_from_file\n", 17 | "from model.models import SynthesizerTrn\n", 18 | "from text.symbols import symbols\n", 19 | "from text import tokenizer\n", 20 | "\n", 21 | "\n", 22 | "def get_text(text: str, hps) -> torch.LongTensor:\n", 23 | " text_norm = tokenizer(text, hps.data.text_cleaners, language=hps.data.language)\n", 24 | " text_norm = torch.LongTensor(text_norm)\n", 25 | " return text_norm" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "dataset_path = \"filelists/madasr23_test.csv\"\n", 35 | "output_path = \"/path/to/output/directory\"\n", 36 | "data = pd.read_csv(dataset_path, sep=\"|\")\n", 37 | "print(data.head())" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## MADASR23 batch inference\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "model = \"custom_base\"\n", 54 | "hps = get_hparams_from_file(f\"./datasets/{model}/config.yaml\")" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "net_g = SynthesizerTrn(len(symbols), hps.data.n_mels if hps.data.use_mel else hps.data.n_fft // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model).cuda()\n", 64 | "_ = net_g.eval()\n", 65 | "\n", 66 | "_ = load_checkpoint(f\"./datasets/{model}/logs/G_15000.pth\", net_g, None)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "class MyDataset(Dataset):\n", 76 | " def __init__(self, dataframe, hps):\n", 77 | " self.data = dataframe\n", 78 | " self.hps = hps\n", 79 
| "\n", 80 | " def __len__(self):\n", 81 | " return len(self.data)\n", 82 | "\n", 83 | " def __getitem__(self, idx):\n", 84 | " sid_idx = self.data[\"sid_idx\"][idx]\n", 85 | " sid = self.data[\"sid\"][idx]\n", 86 | " phonemes = self.data[\"phonemes\"][idx]\n", 87 | " stn_tst = get_text(phonemes, self.hps)\n", 88 | " return sid_idx, sid, stn_tst, idx\n", 89 | "\n", 90 | "\n", 91 | "# Initialize the dataset and data loader\n", 92 | "dataset = MyDataset(data, hps)\n", 93 | "data_loader = DataLoader(dataset, batch_size=1, num_workers=8)\n", 94 | "\n", 95 | "for sid_idx, spk_id, stn_tst, i in tqdm(data_loader):\n", 96 | " sid_idx = int(sid_idx)\n", 97 | " spk_id = int(spk_id)\n", 98 | " i = int(i)\n", 99 | " stn_tst = stn_tst[0]\n", 100 | " with torch.no_grad():\n", 101 | " x_tst = stn_tst.cuda().unsqueeze(0)\n", 102 | " x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n", 103 | " sid = torch.LongTensor([sid_idx]).cuda()\n", 104 | " audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667,\n", 105 | " noise_scale_w=0.8, length_scale=1)[0][0].data.cpu()\n", 106 | " torchaudio.save(f\"{output_path}/{spk_id}_{i}.wav\", audio,\n", 107 | " hps.data.sample_rate, bits_per_sample=hps.data.bits_per_sample)\n", 108 | "\n", 109 | "print(\"Done!\")" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "### Voice Conversion\n", 117 | "\n", 118 | "TODO: Add batch inference for voice conversion\n" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [] 125 | } 126 | ], 127 | "metadata": { 128 | "kernelspec": { 129 | "display_name": "Python 3", 130 | "language": "python", 131 | "name": "python3" 132 | }, 133 | "language_info": { 134 | "codemirror_mode": { 135 | "name": "ipython", 136 | "version": 3 137 | }, 138 | "file_extension": ".py", 139 | "mimetype": "text/x-python", 140 | "name": "python", 141 | "nbconvert_exporter": "python", 142 | "pygments_lexer": "ipython3", 143 | "version": "3.11.4" 144 | } 145 | }, 146 | "nbformat": 4, 147 | "nbformat_minor": 4 148 | } 149 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List 3 | 4 | 5 | def feature_loss(fmap_r: List[torch.Tensor], fmap_g: List[torch.Tensor]): 6 | loss = 0 7 | for dr, dg in zip(fmap_r, fmap_g): 8 | for rl, gl in zip(dr, dg): 9 | rl = rl.float().detach() 10 | gl = gl.float() 11 | loss += torch.mean(torch.abs(rl - gl)) 12 | 13 | return loss * 2 14 | 15 | 16 | def discriminator_loss(disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]): 17 | loss = 0 18 | r_losses = [] 19 | g_losses = [] 20 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 21 | dr = dr.float() 22 | dg = dg.float() 23 | r_loss = torch.mean((1 - dr) ** 2) 24 | g_loss = torch.mean(dg**2) 25 | loss += r_loss + g_loss 26 | r_losses.append(r_loss.item()) 27 | g_losses.append(g_loss.item()) 28 | 29 | return loss, r_losses, g_losses 30 | 31 | 32 | def generator_loss(disc_outputs: List[torch.Tensor]): 33 | loss = 0 34 | gen_losses = [] 35 | for dg in disc_outputs: 36 | dg = dg.float() 37 | l = torch.mean((1 - dg) ** 2) 38 | gen_losses.append(l) 39 | loss += l 40 | 41 | return loss, gen_losses 42 | 43 | 44 | def kl_loss(z_p: torch.Tensor, logs_q: torch.Tensor, m_p: torch.Tensor, logs_p: torch.Tensor, z_mask: torch.Tensor): 45 | """ 46 | z_p, logs_q: [b, h, t_t] 47 | m_p, 
logs_p: [b, h, t_t] 48 | """ 49 | z_p = z_p.float() 50 | logs_q = logs_q.float() 51 | m_p = m_p.float() 52 | logs_p = logs_p.float() 53 | z_mask = z_mask.float() 54 | 55 | kl = logs_p - logs_q - 0.5 56 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 57 | kl = torch.sum(kl * z_mask) 58 | l = kl / torch.sum(z_mask) 59 | return l 60 | 61 | 62 | def kl_loss_normal(m_q: torch.Tensor, logs_q: torch.Tensor, m_p: torch.Tensor, logs_p: torch.Tensor, z_mask: torch.Tensor): 63 | """ 64 | z_p, logs_q: [b, h, t_t] 65 | m_p, logs_p: [b, h, t_t] 66 | """ 67 | m_q = m_q.float() 68 | logs_q = logs_q.float() 69 | m_p = m_p.float() 70 | logs_p = logs_p.float() 71 | z_mask = z_mask.float() 72 | 73 | kl = logs_p - logs_q - 0.5 74 | kl += 0.5 * (torch.exp(2.0 * logs_q) + (m_q - m_p) ** 2) * torch.exp(-2.0 * logs_p) 75 | kl = torch.sum(kl * z_mask) 76 | l = kl / torch.sum(z_mask) 77 | return l 78 | -------------------------------------------------------------------------------- /model/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MultiCondLayer(nn.Module): 6 | def __init__( 7 | self, 8 | gin_channels: int, 9 | out_channels: int, 10 | n_cond: int, 11 | ): 12 | """MultiCondLayer of VITS model. 13 | 14 | Args: 15 | gin_channels (int): Number of conditioning tensor channels. 16 | out_channels (int): Number of output tensor channels. 17 | n_cond (int): Number of conditions. 18 | """ 19 | super().__init__() 20 | self.n_cond = n_cond 21 | 22 | self.cond_layers = nn.ModuleList() 23 | for _ in range(n_cond): 24 | self.cond_layers.append(nn.Linear(gin_channels, out_channels)) 25 | 26 | def forward(self, cond: torch.Tensor, x_mask: torch.Tensor): 27 | """ 28 | Shapes: 29 | - cond: :math:`[B, C, N]` 30 | - x_mask: :math`[B, 1, T]` 31 | """ 32 | 33 | cond_out = torch.zeros_like(cond) 34 | for i in range(self.n_cond): 35 | cond_in = self.cond_layers[i](cond.mT).mT 36 | cond_out = cond_out + cond_in 37 | return cond_out * x_mask 38 | -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm, remove_weight_norm 5 | 6 | from model.modules import LRELU_SLOPE 7 | from utils.model import init_weights, get_padding 8 | 9 | 10 | class Generator(nn.Module): 11 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 12 | super(Generator, self).__init__() 13 | self.num_kernels = len(resblock_kernel_sizes) 14 | self.num_upsamples = len(upsample_rates) 15 | self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) 16 | resblock = ResBlock1 if resblock == "1" else ResBlock2 17 | 18 | self.ups = nn.ModuleList() 19 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 20 | self.ups.append(weight_norm(nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2))) 21 | 22 | self.resblocks = nn.ModuleList() 23 | for i in range(len(self.ups)): 24 | ch = upsample_initial_channel // (2 ** (i + 1)) 25 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 26 | self.resblocks.append(resblock(ch, k, d)) 27 | 28 | self.conv_post = 
nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) 29 | self.ups.apply(init_weights) 30 | 31 | if gin_channels != 0: 32 | self.cond = nn.Linear(gin_channels, upsample_initial_channel) 33 | 34 | def forward(self, x, g=None): 35 | x = self.conv_pre(x) 36 | if g is not None: 37 | x = x + self.cond(g.mT).mT 38 | 39 | for i in range(self.num_upsamples): 40 | x = F.leaky_relu(x, LRELU_SLOPE) 41 | x = self.ups[i](x) 42 | xs = None 43 | for j in range(self.num_kernels): 44 | if xs is None: 45 | xs = self.resblocks[i * self.num_kernels + j](x) 46 | else: 47 | xs += self.resblocks[i * self.num_kernels + j](x) 48 | x = xs / self.num_kernels 49 | x = F.leaky_relu(x) 50 | x = self.conv_post(x) 51 | x = torch.tanh(x) 52 | 53 | return x 54 | 55 | def remove_weight_norm(self): 56 | print("Removing weight norm...") 57 | for l in self.ups: 58 | remove_weight_norm(l) 59 | for l in self.resblocks: 60 | l.remove_weight_norm() 61 | 62 | 63 | class ResBlock1(nn.Module): 64 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 65 | super(ResBlock1, self).__init__() 66 | self.convs1 = nn.ModuleList( 67 | [ 68 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), 69 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), 70 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))), 71 | ] 72 | ) 73 | self.convs1.apply(init_weights) 74 | 75 | self.convs2 = nn.ModuleList( 76 | [ 77 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), 78 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), 79 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), 80 | ] 81 | ) 82 | self.convs2.apply(init_weights) 83 | 84 | def forward(self, x, x_mask=None): 85 | for c1, c2 in zip(self.convs1, self.convs2): 86 | xt = F.leaky_relu(x, LRELU_SLOPE) 87 | if x_mask is not None: 88 | xt = xt * x_mask 89 | xt = c1(xt) 90 | xt = F.leaky_relu(xt, LRELU_SLOPE) 91 | if x_mask is not None: 92 | xt = xt * x_mask 93 | xt = c2(xt) 94 | x = xt + x 95 | if x_mask is not None: 96 | x = x * x_mask 97 | return x 98 | 99 | def remove_weight_norm(self): 100 | for l in self.convs1: 101 | remove_weight_norm(l) 102 | for l in self.convs2: 103 | remove_weight_norm(l) 104 | 105 | 106 | class ResBlock2(nn.Module): 107 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 108 | super(ResBlock2, self).__init__() 109 | self.convs = nn.ModuleList( 110 | [ 111 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), 112 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), 113 | ] 114 | ) 115 | self.convs.apply(init_weights) 116 | 117 | def forward(self, x, x_mask=None): 118 | for c in self.convs: 119 | xt = F.leaky_relu(x, LRELU_SLOPE) 120 | if x_mask is not None: 121 | xt = xt * x_mask 122 | xt = c(xt) 123 | x = xt + x 124 | if x_mask is not None: 125 | x = x * x_mask 126 | return x 127 | 128 | def remove_weight_norm(self): 129 | for l in self.convs: 130 | remove_weight_norm(l) 131 | -------------------------------------------------------------------------------- /model/discriminator.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm, spectral_norm 5 | 6 | from model.modules import LRELU_SLOPE 7 | from utils.model import get_padding 8 | 9 | 10 | class DiscriminatorP(nn.Module): 11 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 12 | super(DiscriminatorP, self).__init__() 13 | self.period = period 14 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 15 | self.convs = nn.ModuleList( 16 | [ 17 | norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 18 | norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 19 | norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 20 | norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 21 | norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 22 | ] 23 | ) 24 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 25 | 26 | def forward(self, x): 27 | fmap = [] 28 | 29 | # 1d to 2d 30 | b, c, t = x.shape 31 | if t % self.period != 0: # pad first 32 | n_pad = self.period - (t % self.period) 33 | x = F.pad(x, (0, n_pad), "reflect") 34 | t = t + n_pad 35 | x = x.view(b, c, t // self.period, self.period) 36 | 37 | for l in self.convs: 38 | x = l(x) 39 | x = F.leaky_relu(x, LRELU_SLOPE) 40 | fmap.append(x) 41 | x = self.conv_post(x) 42 | fmap.append(x) 43 | x = torch.flatten(x, 1, -1) 44 | 45 | return x, fmap 46 | 47 | 48 | class DiscriminatorS(nn.Module): 49 | def __init__(self, use_spectral_norm=False): 50 | super(DiscriminatorS, self).__init__() 51 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 52 | self.convs = nn.ModuleList( 53 | [ 54 | norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)), 55 | norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), 56 | norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), 57 | norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 58 | norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 59 | norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), 60 | ] 61 | ) 62 | self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) 63 | 64 | def forward(self, x): 65 | fmap = [] 66 | 67 | for l in self.convs: 68 | x = l(x) 69 | x = F.leaky_relu(x, LRELU_SLOPE) 70 | fmap.append(x) 71 | x = self.conv_post(x) 72 | fmap.append(x) 73 | x = torch.flatten(x, 1, -1) 74 | 75 | return x, fmap 76 | 77 | 78 | class MultiPeriodDiscriminator(nn.Module): 79 | def __init__(self, use_spectral_norm=False): 80 | super(MultiPeriodDiscriminator, self).__init__() 81 | periods = [2, 3, 5, 7, 11] # [1, 2, 3, 5, 7, 11] 82 | 83 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 84 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 85 | self.discriminators = nn.ModuleList(discs) 86 | 87 | def forward(self, y, y_hat): 88 | y_d_rs = [] 89 | y_d_gs = [] 90 | fmap_rs = [] 91 | fmap_gs = [] 92 | for i, d in enumerate(self.discriminators): 93 | y_d_r, fmap_r = d(y) 94 | y_d_g, fmap_g = d(y_hat) 95 | y_d_rs.append(y_d_r) 96 | y_d_gs.append(y_d_g) 97 | fmap_rs.append(fmap_r) 98 | fmap_gs.append(fmap_g) 99 | 100 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 101 | 
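The pieces above are consumed together during training: `MultiPeriodDiscriminator` scores real and generated waveforms, and its outputs feed the adversarial and feature-matching losses defined in `losses.py`. A minimal sketch of that wiring follows; the batch size and 8192-sample slice length are illustrative assumptions, not values taken from the repository configs.

```python
import torch

from losses import discriminator_loss, feature_loss, generator_loss
from model.discriminator import MultiPeriodDiscriminator

net_d = MultiPeriodDiscriminator()
y = torch.randn(4, 1, 8192)       # real waveform slices [b, 1, t]
y_hat = torch.randn(4, 1, 8192)   # generated slices (stand-in for Generator output)

# Discriminator update: score real audio against detached generated audio.
y_d_rs, y_d_gs, _, _ = net_d(y, y_hat.detach())
loss_disc, r_losses, g_losses = discriminator_loss(y_d_rs, y_d_gs)

# Generator update: adversarial loss plus feature-matching loss on the feature maps.
_, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat)
loss_fm = feature_loss(fmap_rs, fmap_gs)
loss_gen, gen_losses = generator_loss(y_d_gs)
```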
-------------------------------------------------------------------------------- /model/duration_predictors.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from model.modules import Flip 7 | from model.normalization import LayerNorm 8 | from utils.transforms import piecewise_rational_quadratic_transform 9 | 10 | 11 | class StochasticDurationPredictor(nn.Module): 12 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 13 | super().__init__() 14 | self.log_flow = Log() 15 | self.flows = nn.ModuleList() 16 | self.flows.append(ElementwiseAffine(2)) 17 | for i in range(n_flows): 18 | self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 19 | self.flows.append(Flip()) 20 | 21 | self.pre = nn.Linear(in_channels, filter_channels) 22 | self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 23 | self.proj = nn.Linear(filter_channels, filter_channels) 24 | 25 | self.post_pre = nn.Linear(1, filter_channels) 26 | self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 27 | self.post_proj = nn.Linear(filter_channels, filter_channels) 28 | 29 | self.post_flows = nn.ModuleList() 30 | self.post_flows.append(ElementwiseAffine(2)) 31 | for i in range(4): 32 | self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 33 | self.post_flows.append(Flip()) 34 | 35 | if gin_channels != 0: 36 | self.cond = nn.Linear(gin_channels, filter_channels) 37 | 38 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor, w=None, g=None, reverse=False, noise_scale=1.0): 39 | x = torch.detach(x) 40 | x = self.pre(x.mT).mT 41 | if g is not None: 42 | g = torch.detach(g) 43 | x = x + self.cond(g.mT).mT 44 | x = self.convs(x, x_mask) 45 | x = self.proj(x.mT).mT * x_mask 46 | 47 | if not reverse: 48 | flows = self.flows 49 | assert w is not None 50 | 51 | logdet_tot_q = 0 52 | h_w = self.post_pre(w.mT).mT 53 | h_w = self.post_convs(h_w, x_mask) 54 | h_w = self.post_proj(h_w.mT).mT * x_mask 55 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 56 | z_q = e_q 57 | for flow in self.post_flows: 58 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 59 | logdet_tot_q += logdet_q 60 | z_u, z1 = torch.split(z_q, [1, 1], 1) 61 | u = torch.sigmoid(z_u) * x_mask 62 | z0 = (w - u) * x_mask 63 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) 64 | logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q 65 | 66 | logdet_tot = 0 67 | z0, logdet = self.log_flow(z0, x_mask) 68 | logdet_tot += logdet 69 | z = torch.cat([z0, z1], 1) 70 | for flow in flows: 71 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 72 | logdet_tot = logdet_tot + logdet 73 | nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot 74 | return nll + logq # [b] 75 | else: 76 | flows = list(reversed(self.flows)) 77 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 78 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 79 | for flow in flows: 80 | z = flow(z, x_mask, g=x, reverse=reverse) 81 | z0, z1 = torch.split(z, [1, 1], 1) 82 | logw = z0 83 | return logw 84 | 85 | 86 | class ConvFlow(nn.Module): 87 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 
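        # ConvFlow couples the two channel halves: a DDSConv network conditioned on one half predicts the bin widths, heights and derivatives of a piecewise rational-quadratic spline applied to the other half.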
88 | super().__init__() 89 | self.filter_channels = filter_channels 90 | self.num_bins = num_bins 91 | self.tail_bound = tail_bound 92 | self.half_channels = in_channels // 2 93 | 94 | self.pre = nn.Linear(self.half_channels, filter_channels) 95 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) 96 | self.proj = nn.Linear(filter_channels, self.half_channels * (num_bins * 3 - 1)) 97 | self.proj.weight.data.zero_() 98 | self.proj.bias.data.zero_() 99 | 100 | def forward(self, x, x_mask, g=None, reverse=False): 101 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 102 | h = self.pre(x0.mT).mT 103 | h = self.convs(h, x_mask, g=g) 104 | h = self.proj(h.mT).mT * x_mask 105 | 106 | b, c, t = x0.shape 107 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 108 | 109 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) 110 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) 111 | unnormalized_derivatives = h[..., 2 * self.num_bins :] 112 | 113 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, inverse=reverse, tails="linear", tail_bound=self.tail_bound) 114 | 115 | x = torch.cat([x0, x1], 1) * x_mask 116 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 117 | if not reverse: 118 | return x, logdet 119 | else: 120 | return x 121 | 122 | 123 | class DDSConv(nn.Module): 124 | """ 125 | Dialted and Depth-Separable Convolution 126 | """ 127 | 128 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): 129 | super().__init__() 130 | self.n_layers = n_layers 131 | 132 | self.drop = nn.Dropout(p_dropout) 133 | self.convs_sep = nn.ModuleList() 134 | self.linears = nn.ModuleList() 135 | self.norms_1 = nn.ModuleList() 136 | self.norms_2 = nn.ModuleList() 137 | for i in range(n_layers): 138 | dilation = kernel_size**i 139 | padding = (kernel_size * dilation - dilation) // 2 140 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)) 141 | self.linears.append(nn.Linear(channels, channels)) 142 | self.norms_1.append(LayerNorm(channels)) 143 | self.norms_2.append(LayerNorm(channels)) 144 | 145 | def forward(self, x, x_mask, g=None): 146 | if g is not None: 147 | x = x + g 148 | for i in range(self.n_layers): 149 | y = self.convs_sep[i](x * x_mask) 150 | y = self.norms_1[i](y) 151 | y = F.gelu(y) 152 | y = self.linears[i](y.mT).mT 153 | y = self.norms_2[i](y) 154 | y = F.gelu(y) 155 | y = self.drop(y) 156 | x = x + y 157 | return x * x_mask 158 | 159 | 160 | # TODO convert to class method 161 | class Log(nn.Module): 162 | def forward(self, x, x_mask, reverse=False, **kwargs): 163 | if not reverse: 164 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 165 | logdet = torch.sum(-y, [1, 2]) 166 | return y, logdet 167 | else: 168 | x = torch.exp(x) * x_mask 169 | return x 170 | 171 | 172 | class ElementwiseAffine(nn.Module): 173 | def __init__(self, channels): 174 | super().__init__() 175 | self.m = nn.Parameter(torch.zeros(channels, 1)) 176 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 177 | 178 | def forward(self, x, x_mask, reverse=False, **kwargs): 179 | if not reverse: 180 | y = self.m + torch.exp(self.logs) * x 181 | y = y * x_mask 182 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 183 | return y, logdet 184 | else: 185 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 186 | return x 187 | 188 | 189 | class 
DurationPredictor(nn.Module): 190 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): 191 | super().__init__() 192 | self.drop = nn.Dropout(p_dropout) 193 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) 194 | self.norm_1 = LayerNorm(filter_channels) 195 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) 196 | self.norm_2 = LayerNorm(filter_channels) 197 | self.proj = nn.Linear(filter_channels, 1) 198 | 199 | if gin_channels != 0: 200 | self.cond = nn.Linear(gin_channels, in_channels) 201 | 202 | def forward(self, x, x_mask, g=None): 203 | x = torch.detach(x) 204 | if g is not None: 205 | g = torch.detach(g) 206 | x = x + self.cond(g.mT).mT 207 | x = self.conv_1(x * x_mask) 208 | x = torch.relu(x) 209 | x = self.norm_1(x) 210 | x = self.drop(x) 211 | x = self.conv_2(x * x_mask) 212 | x = torch.relu(x) 213 | x = self.norm_2(x) 214 | x = self.drop(x) 215 | x = self.proj((x * x_mask).mT).mT 216 | return x * x_mask 217 | -------------------------------------------------------------------------------- /model/encoders.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from model.modules import WN 6 | from model.transformer import RelativePositionTransformer 7 | from utils.model import sequence_mask 8 | 9 | 10 | # * Ready and Tested 11 | class TextEncoder(nn.Module): 12 | def __init__( 13 | self, 14 | n_vocab: int, 15 | out_channels: int, 16 | hidden_channels: int, 17 | hidden_channels_ffn: int, 18 | n_heads: int, 19 | n_layers: int, 20 | kernel_size: int, 21 | dropout: float, 22 | gin_channels=0, 23 | lang_channels=0, 24 | speaker_cond_layer=0, 25 | ): 26 | """Text Encoder for VITS model. 27 | 28 | Args: 29 | n_vocab (int): Number of characters for the embedding layer. 30 | out_channels (int): Number of channels for the output. 31 | hidden_channels (int): Number of channels for the hidden layers. 32 | hidden_channels_ffn (int): Number of channels for the convolutional layers. 33 | n_heads (int): Number of attention heads for the Transformer layers. 34 | n_layers (int): Number of Transformer layers. 35 | kernel_size (int): Kernel size for the FFN layers in Transformer network. 36 | dropout (float): Dropout rate for the Transformer layers. 37 | gin_channels (int, optional): Number of channels for speaker embedding. Defaults to 0. 38 | lang_channels (int, optional): Number of channels for language embedding. Defaults to 0. 
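            speaker_cond_layer (int, optional): 1-indexed Transformer layer at which the speaker embedding is injected; 0 disables speaker conditioning. Defaults to 0.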
39 | """ 40 | super().__init__() 41 | self.out_channels = out_channels 42 | self.hidden_channels = hidden_channels 43 | 44 | self.emb = nn.Embedding(n_vocab, hidden_channels) 45 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 46 | 47 | self.encoder = RelativePositionTransformer( 48 | in_channels=hidden_channels, 49 | out_channels=hidden_channels, 50 | hidden_channels=hidden_channels, 51 | hidden_channels_ffn=hidden_channels_ffn, 52 | n_heads=n_heads, 53 | n_layers=n_layers, 54 | kernel_size=kernel_size, 55 | dropout=dropout, 56 | window_size=4, 57 | gin_channels=gin_channels, 58 | lang_channels=lang_channels, 59 | speaker_cond_layer=speaker_cond_layer, 60 | ) 61 | self.proj = nn.Linear(hidden_channels, out_channels * 2) 62 | 63 | def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: torch.Tensor = None, lang: torch.Tensor = None): 64 | """ 65 | Shapes: 66 | - x: :math:`[B, T]` 67 | - x_length: :math:`[B]` 68 | """ 69 | x = self.emb(x).mT * math.sqrt(self.hidden_channels) # [b, h, t] 70 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 71 | 72 | x = self.encoder(x, x_mask, g=g, lang=lang) 73 | stats = self.proj(x.mT).mT * x_mask 74 | 75 | m, logs = torch.split(stats, self.out_channels, dim=1) 76 | z = m + torch.randn_like(m) * torch.exp(logs) * x_mask 77 | return z, m, logs, x, x_mask 78 | 79 | 80 | # * Ready and Tested 81 | class PosteriorEncoder(nn.Module): 82 | def __init__( 83 | self, 84 | in_channels: int, 85 | out_channels: int, 86 | hidden_channels: int, 87 | kernel_size: int, 88 | dilation_rate: int, 89 | n_layers: int, 90 | gin_channels=0, 91 | ): 92 | """Posterior Encoder of VITS model. 93 | 94 | :: 95 | x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z 96 | 97 | Args: 98 | in_channels (int): Number of input tensor channels. 99 | out_channels (int): Number of output tensor channels. 100 | hidden_channels (int): Number of hidden channels. 101 | kernel_size (int): Kernel size of the WaveNet convolution layers. 102 | dilation_rate (int): Dilation rate of the WaveNet layers. 103 | num_layers (int): Number of the WaveNet layers. 104 | cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0. 105 | """ 106 | super().__init__() 107 | self.out_channels = out_channels 108 | 109 | self.pre = nn.Linear(in_channels, hidden_channels) 110 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 111 | self.proj = nn.Linear(hidden_channels, out_channels * 2) 112 | 113 | def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g=None): 114 | """ 115 | Shapes: 116 | - x: :math:`[B, C, T]` 117 | - x_lengths: :math:`[B, 1]` 118 | - g: :math:`[B, C, 1]` 119 | """ 120 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 121 | x = self.pre(x.mT).mT * x_mask 122 | x = self.enc(x, x_mask, g=g) 123 | stats = self.proj(x.mT).mT * x_mask 124 | m, logs = torch.split(stats, self.out_channels, dim=1) 125 | z = m + torch.randn_like(m) * torch.exp(logs) * x_mask 126 | return z, m, logs, x_mask 127 | 128 | 129 | # TODO Ready for testing 130 | class AudioEncoder(nn.Module): 131 | def __init__( 132 | self, 133 | in_channels: int, 134 | out_channels: int, 135 | hidden_channels: int, 136 | hidden_channels_ffn: int, 137 | n_heads: int, 138 | n_layers: int, 139 | kernel_size: int, 140 | dropout: float, 141 | gin_channels=0, 142 | lang_channels=0, 143 | speaker_cond_layer=0, 144 | ): 145 | """Audio Encoder of VITS model. 
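        ::
            x -> linear() -> RelativePositionTransformer() -> linear() -> split() -> [m, s] -> sample(m, s) -> z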
146 | 147 | Args: 148 | in_channels (int): Number of input tensor channels. 149 | out_channels (int): Number of channels for the output. 150 | hidden_channels (int): Number of channels for the hidden layers. 151 | hidden_channels_ffn (int): Number of channels for the convolutional layers. 152 | n_heads (int): Number of attention heads for the Transformer layers. 153 | n_layers (int): Number of Transformer layers. 154 | kernel_size (int): Kernel size for the FFN layers in Transformer network. 155 | dropout (float): Dropout rate for the Transformer layers. 156 | gin_channels (int, optional): Number of channels for speaker embedding. Defaults to 0. 157 | lang_channels (int, optional): Number of channels for language embedding. Defaults to 0. 158 | """ 159 | super().__init__() 160 | self.out_channels = out_channels 161 | self.hidden_channels = hidden_channels 162 | 163 | self.pre = nn.Linear(in_channels, hidden_channels) 164 | self.encoder = RelativePositionTransformer( 165 | in_channels=hidden_channels, 166 | out_channels=hidden_channels, 167 | hidden_channels=hidden_channels, 168 | hidden_channels_ffn=hidden_channels_ffn, 169 | n_heads=n_heads, 170 | n_layers=n_layers, 171 | kernel_size=kernel_size, 172 | dropout=dropout, 173 | window_size=4, 174 | gin_channels=gin_channels, 175 | lang_channels=lang_channels, 176 | speaker_cond_layer=speaker_cond_layer, 177 | ) 178 | self.post = nn.Linear(hidden_channels, out_channels * 2) 179 | 180 | def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: torch.Tensor = None, lang: torch.Tensor = None): 181 | """ 182 | Shapes: 183 | - x: :math:`[B, C, T]` 184 | - x_lengths: :math:`[B, 1]` 185 | - g: :math:`[B, C, 1]` 186 | - lang: :math:`[B, C, 1]` 187 | """ 188 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 189 | 190 | x = self.pre(x.mT).mT * x_mask # [B, C, t'] 191 | x = self.encoder(x, x_mask, g=g, lang=lang) 192 | stats = self.post(x.mT).mT * x_mask 193 | 194 | m, logs = torch.split(stats, self.out_channels, dim=1) 195 | z = m + torch.randn_like(m) * torch.exp(logs) * x_mask 196 | return z, m, logs, x_mask 197 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from model.encoders import TextEncoder, PosteriorEncoder, AudioEncoder 5 | from model.normalizing_flows import ResidualCouplingBlock 6 | from model.duration_predictors import DurationPredictor, StochasticDurationPredictor 7 | from model.decoder import Generator 8 | from utils.monotonic_align import search_path, generate_path 9 | from utils.model import sequence_mask, rand_slice_segments 10 | 11 | 12 | class SynthesizerTrn(nn.Module): 13 | """ 14 | Synthesizer for Training 15 | """ 16 | 17 | def __init__( 18 | self, 19 | n_vocab, 20 | spec_channels, 21 | segment_size, 22 | inter_channels, 23 | hidden_channels, 24 | filter_channels, 25 | n_heads, 26 | n_layers, 27 | n_layers_q, 28 | n_flows, 29 | kernel_size, 30 | p_dropout, 31 | speaker_cond_layer, 32 | resblock, 33 | resblock_kernel_sizes, 34 | resblock_dilation_sizes, 35 | upsample_rates, 36 | upsample_initial_channel, 37 | upsample_kernel_sizes, 38 | mas_noise_scale, 39 | mas_noise_scale_decay, 40 | use_sdp=True, 41 | use_transformer_flow=True, 42 | n_speakers=0, 43 | gin_channels=0, 44 | **kwargs 45 | ): 46 | super().__init__() 47 | self.segment_size = segment_size 48 | self.n_speakers = n_speakers 49 | self.use_sdp = use_sdp 50 
| self.mas_noise_scale = mas_noise_scale 51 | self.mas_noise_scale_decay = mas_noise_scale_decay 52 | 53 | self.enc_p = TextEncoder(n_vocab, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels=gin_channels, speaker_cond_layer=speaker_cond_layer) 54 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, n_layers_q, gin_channels=gin_channels) 55 | # self.enc_q = AudioEncoder(spec_channels, inter_channels, 32, 768, n_heads, 2, kernel_size, p_dropout, gin_channels=gin_channels) 56 | # self.enc_q = AudioEncoder(spec_channels, inter_channels, 32, 32, n_heads, 3, kernel_size, p_dropout, gin_channels=gin_channels) 57 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) 58 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, n_flows=n_flows, gin_channels=gin_channels, mean_only=False, use_transformer_flow=use_transformer_flow) 59 | 60 | if use_sdp: 61 | self.dp = StochasticDurationPredictor(hidden_channels, hidden_channels, 3, 0.5, 4, gin_channels=gin_channels) 62 | else: 63 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) 64 | 65 | if n_speakers > 1: 66 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 67 | 68 | def forward(self, x, x_lengths, y, y_lengths, sid=None): 69 | if self.n_speakers > 0: 70 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 71 | else: 72 | g = None 73 | 74 | z_p_text, m_p_text, logs_p_text, h_text, x_mask = self.enc_p(x, x_lengths, g=g) 75 | z_q_audio, m_q_audio, logs_q_audio, y_mask = self.enc_q(y, y_lengths, g=g) 76 | z_q_dur, m_q_dur, logs_q_dur = self.flow(z_q_audio, m_q_audio, logs_q_audio, y_mask, g=g) 77 | 78 | attn = search_path(z_q_dur, m_p_text, logs_p_text, x_mask, y_mask, mas_noise_scale=self.mas_noise_scale) 79 | self.mas_noise_scale = max(self.mas_noise_scale - self.mas_noise_scale_decay, 0.0) 80 | 81 | w = attn.sum(2) # [b, 1, t_s] 82 | 83 | # * reduce posterior 84 | # TODO Test gain constant 85 | if False: 86 | attn_inv = attn.squeeze(1) * (1 / (w + 1e-9)) 87 | m_q_text = torch.matmul(attn_inv.mT, m_q_dur.mT).mT 88 | logs_q_text = torch.matmul(attn_inv.mT, logs_q_dur.mT).mT 89 | 90 | # * expand prior 91 | if self.use_sdp: 92 | l_length = self.dp(h_text, x_mask, w, g=g) 93 | l_length = l_length / torch.sum(x_mask) 94 | else: 95 | logw_ = torch.log(w + 1e-6) * x_mask 96 | logw = self.dp(h_text, x_mask, g=g) 97 | l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask) # for averaging 98 | m_p_dur = torch.matmul(attn.squeeze(1), m_p_text.mT).mT 99 | logs_p_dur = torch.matmul(attn.squeeze(1), logs_p_text.mT).mT 100 | z_p_dur = m_p_dur + torch.randn_like(m_p_dur) * torch.exp(logs_p_dur) * y_mask 101 | 102 | z_p_audio, m_p_audio, logs_p_audio = self.flow(z_p_dur, m_p_dur, logs_p_dur, y_mask, g=g, reverse=True) 103 | 104 | z_slice, ids_slice = rand_slice_segments(z_q_audio, y_lengths, self.segment_size) 105 | o = self.dec(z_slice, g=g) 106 | return ( 107 | o, 108 | l_length, 109 | attn, 110 | ids_slice, 111 | x_mask, 112 | y_mask, 113 | (m_p_text, logs_p_text), 114 | (m_p_dur, logs_p_dur, z_q_dur, logs_q_dur), 115 | (m_p_audio, logs_p_audio, m_q_audio, logs_q_audio), 116 | ) 117 | 118 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, max_len=None): 119 | if self.n_speakers > 0: 120 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 
121 | else: 122 | g = None 123 | 124 | z_p_text, m_p_text, logs_p_text, h_text, x_mask = self.enc_p(x, x_lengths, g=g) 125 | 126 | if self.use_sdp: 127 | logw = self.dp(h_text, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) 128 | else: 129 | logw = self.dp(h_text, x_mask, g=g) 130 | w = torch.exp(logw) * x_mask * length_scale 131 | w_ceil = torch.ceil(w) 132 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 133 | y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x_mask.dtype) 134 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 135 | attn = generate_path(w_ceil, attn_mask) 136 | 137 | m_p_dur = torch.matmul(attn.squeeze(1), m_p_text.mT).mT # [b, t', t], [b, t, d] -> [b, d, t'] 138 | logs_p_dur = torch.matmul(attn.squeeze(1), logs_p_text.mT).mT # [b, t', t], [b, t, d] -> [b, d, t'] 139 | z_p_dur = m_p_dur + torch.randn_like(m_p_dur) * torch.exp(logs_p_dur) * noise_scale 140 | 141 | z_p_audio, m_p_audio, logs_p_audio = self.flow(z_p_dur, m_p_dur, logs_p_dur, y_mask, g=g, reverse=True) 142 | o = self.dec((z_p_audio * y_mask)[:, :, :max_len], g=g) 143 | return o, attn, y_mask, (z_p_dur, m_p_dur, logs_p_dur), (z_p_audio, m_p_audio, logs_p_audio) 144 | 145 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): 146 | assert self.n_speakers > 0, "n_speakers have to be larger than 0." 147 | g_src = self.emb_g(sid_src).unsqueeze(-1) 148 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) 149 | z_q_audio, m_q_audio, logs_q_audio, y_mask = self.enc_q(y, y_lengths, g=g_src) 150 | z_q_dur, m_q_dur, logs_q_dur = self.flow(z_q_audio, m_q_audio, logs_q_audio, y_mask, g=g_src) 151 | z_p_audio, m_p_audio, logs_p_audio = self.flow(z_q_dur, m_q_dur, logs_q_dur, y_mask, g=g_tgt, reverse=True) 152 | o_hat = self.dec(z_p_audio * y_mask, g=g_tgt) 153 | return o_hat, y_mask, (z_q_dur, m_q_dur, logs_q_dur), (z_p_audio, m_p_audio, logs_p_audio) 154 | 155 | def voice_restoration(self, y, y_lengths, sid=None): 156 | if self.n_speakers > 0: 157 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 158 | else: 159 | g = None 160 | z_q_audio, m_q_audio, logs_q_audio, y_mask = self.enc_q(y, y_lengths, g=g) 161 | z_q_dur, m_q_dur, logs_q_dur = self.flow(z_q_audio, m_q_audio, logs_q_audio, y_mask, g=g) 162 | z_p_audio, m_p_audio, logs_p_audio = self.flow(z_q_dur, m_q_dur, logs_q_dur, y_mask, g=g, reverse=True) 163 | o_hat = self.dec(z_p_audio * y_mask, g=g) 164 | return o_hat, y_mask, (z_q_dur, m_q_dur, logs_q_dur), (z_p_audio, m_p_audio, logs_p_audio) 165 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.model import fused_add_tanh_sigmoid_multiply 5 | 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | # ! PosteriorEncoder 11 | # ! 
ResidualCouplingLayer 12 | class WN(nn.Module): 13 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 14 | super(WN, self).__init__() 15 | assert kernel_size % 2 == 1 16 | self.hidden_channels = hidden_channels 17 | self.kernel_size = (kernel_size,) 18 | self.n_layers = n_layers 19 | self.gin_channels = gin_channels 20 | 21 | self.in_layers = nn.ModuleList() 22 | self.res_skip_layers = nn.ModuleList() 23 | self.drop = nn.Dropout(p_dropout) 24 | 25 | if gin_channels != 0: 26 | cond_layer = nn.Linear(gin_channels, 2 * hidden_channels * n_layers) 27 | self.cond_layer = nn.utils.weight_norm(cond_layer, name="weight") 28 | 29 | for i in range(n_layers): 30 | dilation = dilation_rate**i 31 | padding = int((kernel_size * dilation - dilation) / 2) 32 | in_layer = nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding) 33 | in_layer = nn.utils.weight_norm(in_layer, name="weight") 34 | self.in_layers.append(in_layer) 35 | 36 | # last one is not necessary 37 | res_skip_channels = 2 * hidden_channels if i < n_layers - 1 else hidden_channels 38 | res_skip_layer = nn.Linear(hidden_channels, res_skip_channels) 39 | res_skip_layer = nn.utils.weight_norm(res_skip_layer, name="weight") 40 | self.res_skip_layers.append(res_skip_layer) 41 | 42 | def forward(self, x, x_mask, g=None, **kwargs): 43 | output = torch.zeros_like(x) 44 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 45 | 46 | if g is not None: 47 | g = self.cond_layer(g.mT).mT 48 | 49 | for i in range(self.n_layers): 50 | x_in = self.in_layers[i](x) 51 | if g is not None: 52 | cond_offset = i * 2 * self.hidden_channels 53 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 54 | else: 55 | g_l = torch.zeros_like(x_in) 56 | 57 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 58 | acts = self.drop(acts) 59 | 60 | res_skip_acts = self.res_skip_layers[i](acts.mT).mT 61 | if i < self.n_layers - 1: 62 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 63 | x = (x + res_acts) * x_mask 64 | output = output + res_skip_acts[:, self.hidden_channels :, :] 65 | else: 66 | output = output + res_skip_acts 67 | return output * x_mask 68 | 69 | def remove_weight_norm(self): 70 | if self.gin_channels != 0: 71 | nn.utils.remove_weight_norm(self.cond_layer) 72 | for l in self.in_layers: 73 | nn.utils.remove_weight_norm(l) 74 | for l in self.res_skip_layers: 75 | nn.utils.remove_weight_norm(l) 76 | 77 | 78 | # ! StochasticDurationPredictor 79 | # ! 
ResidualCouplingBlock 80 | # TODO convert to class method 81 | class Flip(nn.Module): 82 | def forward(self, x, *args, reverse=False, **kwargs): 83 | x = torch.flip(x, [1]) 84 | if not reverse: 85 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 86 | return x, logdet 87 | else: 88 | return x 89 | -------------------------------------------------------------------------------- /model/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LayerNorm(nn.Module): 7 | def __init__(self, channels, eps=1e-5): 8 | super().__init__() 9 | self.channels = channels 10 | self.eps = eps 11 | 12 | self.gamma = nn.Parameter(torch.ones(channels)) 13 | self.beta = nn.Parameter(torch.zeros(channels)) 14 | 15 | def forward(self, x: torch.Tensor): 16 | x = F.layer_norm(x.mT, (self.channels,), self.gamma, self.beta, self.eps) 17 | return x.mT 18 | 19 | 20 | class CondLayerNorm(nn.Module): 21 | def __init__(self, channels, eps=1e-5, cond_channels=0): 22 | super().__init__() 23 | self.channels = channels 24 | self.eps = eps 25 | 26 | self.linear_gamma = nn.Linear(cond_channels, channels) 27 | self.linear_beta = nn.Linear(cond_channels, channels) 28 | 29 | def forward(self, x: torch.Tensor, cond: torch.Tensor): 30 | gamma = self.linear_gamma(cond) 31 | beta = self.linear_beta(cond) 32 | 33 | x = F.layer_norm(x.mT, (self.channels,), gamma, beta, self.eps) 34 | return x.mT 35 | -------------------------------------------------------------------------------- /model/normalizing_flows.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from model.transformer import RelativePositionTransformer 5 | from model.modules import WN 6 | 7 | 8 | class ResidualCouplingBlock(nn.Module): 9 | def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0, mean_only=False, use_transformer_flow=True): 10 | super().__init__() 11 | self.flows = nn.ModuleList() 12 | for i in range(n_flows): 13 | use_transformer = use_transformer_flow if (i == n_flows - 1) else False # TODO or (i == n_flows - 2) 14 | self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=mean_only, use_transformer_flow=use_transformer)) 15 | self.flows.append(Flip()) 16 | 17 | def forward(self, x, m, logs, x_mask, g=None, reverse=False): 18 | if reverse: 19 | for flow in reversed(self.flows): 20 | x, m, logs = flow(x, m, logs, x_mask, g=g, reverse=reverse) 21 | else: 22 | for flow in self.flows: 23 | x, m, logs = flow(x, m, logs, x_mask, g=g, reverse=reverse) 24 | return x, m, logs 25 | 26 | 27 | # TODO rewrite for 256x256 attention score map 28 | class ResidualCouplingLayer(nn.Module): 29 | def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False, use_transformer_flow=True): 30 | assert channels % 2 == 0, "channels should be divisible by 2" 31 | super().__init__() 32 | self.half_channels = channels // 2 33 | self.mean_only = mean_only 34 | 35 | self.pre_transformer = ( 36 | RelativePositionTransformer( 37 | self.half_channels, 38 | self.half_channels, 39 | self.half_channels, 40 | self.half_channels, 41 | n_heads=2, 42 | n_layers=1, 43 | kernel_size=3, 44 | dropout=0.1, 45 | window_size=None, 46 | ) 47 | if use_transformer_flow 
48 | else None 49 | ) 50 | 51 | self.pre = nn.Linear(self.half_channels, hidden_channels) 52 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 53 | self.post = nn.Linear(hidden_channels, self.half_channels * (2 - mean_only)) 54 | self.post.weight.data.zero_() 55 | self.post.bias.data.zero_() 56 | 57 | def forward(self, x, m, logs, x_mask, g=None, reverse=False): 58 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 59 | m0, m1 = torch.split(m, [self.half_channels] * 2, 1) 60 | logs0, logs1 = torch.split(logs, [self.half_channels] * 2, 1) 61 | x0_ = x0 62 | if self.pre_transformer is not None: 63 | x0_ = self.pre_transformer(x0 * x_mask, x_mask) 64 | x0_ = x0_ + x0 # residual connection 65 | h = self.pre(x0_.mT).mT * x_mask 66 | h = self.enc(h, x_mask, g=g) 67 | stats = self.post(h.mT).mT * x_mask 68 | if not self.mean_only: 69 | m_flow, logs_flow = torch.split(stats, [self.half_channels] * 2, 1) 70 | else: 71 | m_flow = stats 72 | logs_flow = torch.zeros_like(m) 73 | 74 | if reverse: 75 | x1 = (x1 - m_flow) * torch.exp(-logs_flow) * x_mask 76 | m1 = (m1 - m_flow) * torch.exp(-logs_flow) * x_mask 77 | logs1 = logs1 - logs_flow 78 | 79 | x = torch.cat([x0, x1], 1) 80 | m = torch.cat([m0, m1], 1) 81 | logs = torch.cat([logs0, logs1], 1) 82 | return x, m, logs 83 | else: 84 | x1 = m_flow + x1 * torch.exp(logs_flow) * x_mask 85 | m1 = m_flow + m1 * torch.exp(logs_flow) * x_mask 86 | logs1 = logs1 + logs_flow 87 | 88 | x = torch.cat([x0, x1], 1) 89 | m = torch.cat([m0, m1], 1) 90 | logs = torch.cat([logs0, logs1], 1) 91 | return x, m, logs 92 | 93 | 94 | class Flip(nn.Module): 95 | def forward(self, x, m, logs, *args, reverse=False, **kwargs): 96 | x = torch.flip(x, [1]) 97 | m = torch.flip(m, [1]) 98 | logs = torch.flip(logs, [1]) 99 | return x, m, logs 100 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from utils.model import convert_pad_shape 7 | from model.normalization import LayerNorm 8 | 9 | 10 | # TODO add conditioning on language 11 | # TODO check whether we need to stop gradient for speaker embedding 12 | class RelativePositionTransformer(nn.Module): 13 | def __init__( 14 | self, 15 | in_channels: int, 16 | hidden_channels: int, 17 | out_channels: int, 18 | hidden_channels_ffn: int, 19 | n_heads: int, 20 | n_layers: int, 21 | kernel_size=1, 22 | dropout=0.0, 23 | window_size=4, 24 | gin_channels=0, 25 | lang_channels=0, 26 | speaker_cond_layer=0, 27 | ): 28 | super().__init__() 29 | self.n_layers = n_layers 30 | self.speaker_cond_layer = speaker_cond_layer 31 | 32 | self.drop = nn.Dropout(dropout) 33 | self.attn_layers = nn.ModuleList() 34 | self.norm_layers_1 = nn.ModuleList() 35 | self.ffn_layers = nn.ModuleList() 36 | self.norm_layers_2 = nn.ModuleList() 37 | for i in range(self.n_layers): 38 | self.attn_layers.append(MultiHeadAttention(hidden_channels if i != 0 else in_channels, hidden_channels, n_heads, p_dropout=dropout, window_size=window_size)) 39 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 40 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, hidden_channels_ffn, kernel_size, p_dropout=dropout)) 41 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 42 | if gin_channels != 0: 43 | self.cond = nn.Linear(gin_channels, hidden_channels) 44 
| 45 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: torch.Tensor = None, lang: torch.Tensor = None): 46 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 47 | x = x * x_mask 48 | for i in range(self.n_layers): 49 | # TODO consider using other conditioning 50 | # TODO https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/modules/attentions.py#L12 51 | if i == self.speaker_cond_layer - 1 and g is not None: 52 | # ! g = torch.detach(g) 53 | x = x + self.cond(g.mT).mT 54 | x = x * x_mask 55 | y = self.attn_layers[i](x, x, attn_mask) 56 | y = self.drop(y) 57 | x = self.norm_layers_1[i](x + y) 58 | 59 | y = self.ffn_layers[i](x, x_mask) 60 | y = self.drop(y) 61 | x = self.norm_layers_2[i](x + y) 62 | x = x * x_mask 63 | return x 64 | 65 | 66 | class MultiHeadAttention(nn.Module): 67 | def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): 68 | super().__init__() 69 | assert channels % n_heads == 0 70 | 71 | self.channels = channels 72 | self.out_channels = out_channels 73 | self.n_heads = n_heads 74 | self.p_dropout = p_dropout 75 | self.window_size = window_size 76 | self.heads_share = heads_share 77 | self.block_length = block_length 78 | self.proximal_bias = proximal_bias 79 | self.proximal_init = proximal_init 80 | self.attn = None 81 | 82 | self.k_channels = channels // n_heads 83 | self.conv_q = nn.Linear(channels, channels) 84 | self.conv_k = nn.Linear(channels, channels) 85 | self.conv_v = nn.Linear(channels, channels) 86 | self.conv_o = nn.Linear(channels, out_channels) 87 | self.drop = nn.Dropout(p_dropout) 88 | 89 | if window_size is not None: 90 | n_heads_rel = 1 if heads_share else n_heads 91 | rel_stddev = self.k_channels**-0.5 92 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 93 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 94 | 95 | nn.init.xavier_uniform_(self.conv_q.weight) 96 | nn.init.xavier_uniform_(self.conv_k.weight) 97 | nn.init.xavier_uniform_(self.conv_v.weight) 98 | if proximal_init: 99 | with torch.no_grad(): 100 | self.conv_k.weight.copy_(self.conv_q.weight) 101 | self.conv_k.bias.copy_(self.conv_q.bias) 102 | 103 | def forward(self, x, c, attn_mask=None): 104 | q = self.conv_q(x.mT).mT 105 | k = self.conv_k(c.mT).mT 106 | v = self.conv_v(c.mT).mT 107 | 108 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 109 | 110 | x = self.conv_o(x.mT).mT 111 | return x 112 | 113 | def attention(self, query, key, value, mask=None): 114 | # reshape [b, d, t] -> [b, n_h, t, d_k] 115 | b, d, t_s, t_t = (*key.size(), query.size(2)) 116 | query = query.view(b, self.n_heads, self.k_channels, t_t).mT 117 | key = key.view(b, self.n_heads, self.k_channels, t_s).mT 118 | value = value.view(b, self.n_heads, self.k_channels, t_s).mT 119 | 120 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT) 121 | if self.window_size is not None: 122 | assert t_s == t_t, "Relative attention is only available for self-attention." 
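            # Relative positional attention: query scores against the learned relative-key embeddings are converted from relative to absolute indexing and added to the content-based scores.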
123 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 124 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) 125 | scores_local = self._relative_position_to_absolute_position(rel_logits) 126 | scores = scores + scores_local 127 | if self.proximal_bias: 128 | assert t_s == t_t, "Proximal bias is only available for self-attention." 129 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 130 | if mask is not None: 131 | scores = scores.masked_fill(mask == 0, -1e4) 132 | if self.block_length is not None: 133 | assert t_s == t_t, "Local attention is only available for self-attention." 134 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 135 | scores = scores.masked_fill(block_mask == 0, -1e4) 136 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 137 | p_attn = self.drop(p_attn) 138 | output = torch.matmul(p_attn, value) 139 | if self.window_size is not None: 140 | relative_weights = self._absolute_position_to_relative_position(p_attn) 141 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 142 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 143 | output = output.mT.contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 144 | return output, p_attn 145 | 146 | def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor): 147 | """ 148 | x: [b, h, l, m] 149 | y: [h or 1, m, d] 150 | ret: [b, h, l, d] 151 | """ 152 | return torch.matmul(x, y.unsqueeze(0)) 153 | 154 | def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor): 155 | """ 156 | x: [b, h, l, d] 157 | y: [h or 1, m, d] 158 | ret: [b, h, l, m] 159 | """ 160 | return torch.matmul(x, y.unsqueeze(0).mT) 161 | 162 | def _get_relative_embeddings(self, relative_embeddings, length): 163 | max_relative_position = 2 * self.window_size + 1 164 | # Pad first before slice to avoid using cond ops. 165 | pad_length = max(length - (self.window_size + 1), 0) 166 | slice_start_position = max((self.window_size + 1) - length, 0) 167 | slice_end_position = slice_start_position + 2 * length - 1 168 | if pad_length > 0: 169 | padded_relative_embeddings = F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 170 | else: 171 | padded_relative_embeddings = relative_embeddings 172 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] 173 | return used_relative_embeddings 174 | 175 | def _relative_position_to_absolute_position(self, x): 176 | """ 177 | x: [b, h, l, 2*l-1] 178 | ret: [b, h, l, l] 179 | """ 180 | batch, heads, length, _ = x.size() 181 | # Concat columns of pad to shift from relative to absolute indexing. 182 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 183 | 184 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 185 | x_flat = x.view([batch, heads, length * 2 * length]) 186 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) 187 | 188 | # Reshape and slice out the padded elements. 
189 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] 190 | return x_final 191 | 192 | def _absolute_position_to_relative_position(self, x): 193 | """ 194 | x: [b, h, l, l] 195 | ret: [b, h, l, 2*l-1] 196 | """ 197 | batch, heads, length, _ = x.size() 198 | # padd along column 199 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) 200 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) 201 | # add 0's in the beginning that will skew the elements after reshape 202 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 203 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 204 | return x_final 205 | 206 | def _attention_bias_proximal(self, length): 207 | """Bias for self-attention to encourage attention to close positions. 208 | Args: 209 | length: an integer scalar. 210 | Returns: 211 | a Tensor with shape [1, 1, length, length] 212 | """ 213 | r = torch.arange(length, dtype=torch.float32) 214 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 215 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 216 | 217 | 218 | class FFN(nn.Module): 219 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, causal=False): 220 | super().__init__() 221 | self.kernel_size = kernel_size 222 | self.padding = self._causal_padding if causal else self._same_padding 223 | 224 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 225 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 226 | self.drop = nn.Dropout(p_dropout) 227 | 228 | def forward(self, x, x_mask): 229 | x = self.conv_1(self.padding(x * x_mask)) 230 | x = torch.relu(x) 231 | x = self.drop(x) 232 | x = self.conv_2(self.padding(x * x_mask)) 233 | return x * x_mask 234 | 235 | def _causal_padding(self, x): 236 | if self.kernel_size == 1: 237 | return x 238 | pad_l = self.kernel_size - 1 239 | pad_r = 0 240 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 241 | x = F.pad(x, convert_pad_shape(padding)) 242 | return x 243 | 244 | def _same_padding(self, x): 245 | if self.kernel_size == 1: 246 | return x 247 | pad_l = (self.kernel_size - 1) // 2 248 | pad_r = self.kernel_size // 2 249 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 250 | x = F.pad(x, convert_pad_shape(padding)) 251 | return x 252 | -------------------------------------------------------------------------------- /preprocess/README.md: -------------------------------------------------------------------------------- 1 | # VITS2 | Preprocessing 2 | 3 | ## Todo 4 | 5 | - [x] text preprocessing 6 | - [x] update vocabulary to support all symbols and features from IPA. See [phonemes.md](https://github.com/espeak-ng/espeak-ng/blob/ed9a7bcf5778a188cdec202ac4316461badb28e1/docs/phonemes.md#L5) 7 | - [x] per dataset filelists preprocessing. Please refer [prepare/filelists.ipynb](datasets/ljs_base/prepare/filelists.ipynb) 8 | - [x] handle unknown (out of vocabulary) symbols. Please refer [vocab - TorchText](https://pytorch.org/text/stable/vocab.html) 9 | - [x] handle special symbols in tokenizer. Please refer [text/symbols.py](text/symbols.py) 10 | - [ ] audio preprocessing 11 | - [x] replaced scipy and librosa dependencies with torchaudio. See docs [torchaudio.load](https://pytorch.org/audio/stable/backend.html#id2) and [torchaudio.transforms](https://pytorch.org/audio/stable/transforms.html) 12 | - [ ] remove necessity for speakers indexation. 
See [vits/issues/58](https://github.com/jaywalnut310/vits/issues/58) 13 | - [ ] update batch audio resampling. Please refer [audio_resample.ipynb](preprocess/audio_resample.ipynb) 14 | - [ ] test stereo audio (multi-channel) training 15 | 16 | # VITS2 | Preprocessing 17 | -------------------------------------------------------------------------------- /preprocess/audio_find_corrupted.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Check for corrupted audio files in dataset\n" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import torchaudio\n", 18 | "import concurrent.futures\n", 19 | "\n", 20 | "i_dir = \"path/to/your/dataset\"" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "def check_wav(file_path):\n", 30 | " \"\"\"Load a .wav file and return if it's corrupted or not\"\"\"\n", 31 | " try:\n", 32 | " waveform, sample_rate = torchaudio.load(file_path)\n", 33 | " return (file_path, True)\n", 34 | " except Exception as e:\n", 35 | " return (file_path, False)\n", 36 | "\n", 37 | "\n", 38 | "def find_wavs(directory):\n", 39 | " \"\"\"Find all .wav files in a directory\"\"\"\n", 40 | " for foldername, subfolders, filenames in os.walk(directory):\n", 41 | " for filename in filenames:\n", 42 | " if filename.endswith(\".wav\"):\n", 43 | " yield os.path.join(foldername, filename)\n", 44 | "\n", 45 | "\n", 46 | "def main(directory):\n", 47 | " \"\"\"Check all .wav files in a directory and its subdirectories\"\"\"\n", 48 | " with concurrent.futures.ThreadPoolExecutor() as executor:\n", 49 | " wav_files = list(find_wavs(directory))\n", 50 | " future_to_file = {executor.submit(check_wav, wav): wav for wav in wav_files}\n", 51 | "\n", 52 | " done_count = 0\n", 53 | " for future in concurrent.futures.as_completed(future_to_file):\n", 54 | " file_path = future_to_file[future]\n", 55 | " try:\n", 56 | " is_valid = future.result()\n", 57 | " except Exception as exc:\n", 58 | " print(f\"{file_path} generated an exception: {exc}\")\n", 59 | " else:\n", 60 | " if not is_valid[1]:\n", 61 | " print(f\"Corrupted file: {file_path}\")\n", 62 | "\n", 63 | " done_count += 1\n", 64 | " if done_count % 5000 == 0:\n", 65 | " print(f\"Processed {done_count} files...\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "main(i_dir)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [] 83 | } 84 | ], 85 | "metadata": { 86 | "language_info": { 87 | "name": "python" 88 | }, 89 | "orig_nbformat": 4 90 | }, 91 | "nbformat": 4, 92 | "nbformat_minor": 2 93 | } 94 | -------------------------------------------------------------------------------- /preprocess/audio_resample.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Resample audio wavs\n", 8 | "\n", 9 | "Refer to: [audio resampling tutorial](https://pytorch.org/audio/0.10.0/tutorials/audio_resampling_tutorial.html)\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | 
"outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "import torchaudio\n", 20 | "import torchaudio.transforms as T\n", 21 | "import concurrent.futures\n", 22 | "from pathlib import Path\n", 23 | "import random" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# Example usage:\n", 33 | "input_directory = \"/path/to/dataset\"\n", 34 | "output_directory = f\"{input_directory}.cleaned\"\n", 35 | "orig_sr = 16000\n", 36 | "new_sr = 22050" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "def resample_wav_files(input_dir, output_dir, sr, new_sr):\n", 46 | " # Create the output directory if it doesn't exist\n", 47 | " os.makedirs(output_dir, exist_ok=True)\n", 48 | "\n", 49 | " # Create a resampler object\n", 50 | " resampler = T.Resample(\n", 51 | " sr,\n", 52 | " new_sr,\n", 53 | " lowpass_filter_width=128,\n", 54 | " rolloff=0.99999,\n", 55 | " resampling_method=\"sinc_interp_hann\",\n", 56 | " )\n", 57 | "\n", 58 | " def resample_file(file_path):\n", 59 | " # Load the audio file\n", 60 | " waveform, sample_rate = torchaudio.load(file_path)\n", 61 | " assert sample_rate == sr\n", 62 | "\n", 63 | " # Resample the audio\n", 64 | " resampled_waveform = resampler(waveform)\n", 65 | "\n", 66 | " # Construct the output file path\n", 67 | " output_file = Path(output_dir) / Path(file_path).relative_to(input_dir)\n", 68 | "\n", 69 | " # Save the resampled audio\n", 70 | " torchaudio.save(output_file, resampled_waveform,\n", 71 | " new_sr, bits_per_sample=16)\n", 72 | "\n", 73 | " return output_file\n", 74 | "\n", 75 | " # Use generator to find .wav files and pre-create output directories\n", 76 | " def find_and_prep_wav_files(input_dir, output_dir):\n", 77 | " for root, _, files in os.walk(input_dir):\n", 78 | " for file in files:\n", 79 | " if file.endswith(\".wav\"):\n", 80 | " file_path = Path(root) / file\n", 81 | " output_file = Path(output_dir) / \\\n", 82 | " file_path.relative_to(input_dir)\n", 83 | " os.makedirs(output_file.parent, exist_ok=True)\n", 84 | " yield str(file_path)\n", 85 | "\n", 86 | " # Resample the .wav files using threads for parallel processing\n", 87 | " wav_files = find_and_prep_wav_files(input_dir, output_dir)\n", 88 | " with concurrent.futures.ThreadPoolExecutor() as executor:\n", 89 | " for i, output_file in enumerate(executor.map(resample_file, wav_files)):\n", 90 | " if i % 1000 == 0:\n", 91 | " print(f\"{i}: {output_file}\")\n", 92 | "\n", 93 | "\n", 94 | "resample_wav_files(input_directory, output_directory, orig_sr, new_sr)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# Test random file to see if it worked\n", 104 | "out_path = os.path.join(output_directory, os.listdir(output_directory)[random.randint(0, len(os.listdir(output_directory)))])\n", 105 | "\n", 106 | "print(torchaudio.info(out_path))\n", 107 | "resampled_waveform, sample_rate = torchaudio.load(out_path)\n", 108 | "print(f\"max: {resampled_waveform.max()}, min: {resampled_waveform.min()}\")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "g2p", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | 
"codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.11.4" 136 | }, 137 | "orig_nbformat": 4 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 2 141 | } 142 | -------------------------------------------------------------------------------- /preprocess/audio_resampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import concurrent.futures 3 | import os 4 | from concurrent.futures import ProcessPoolExecutor 5 | from multiprocessing import cpu_count 6 | 7 | import librosa 8 | import numpy as np 9 | from rich.progress import track 10 | from scipy.io import wavfile 11 | 12 | 13 | def load_wav(wav_path): 14 | return librosa.load(wav_path, sr=None) 15 | 16 | 17 | def trim_wav(wav, top_db=40): 18 | return librosa.effects.trim(wav, top_db=top_db) 19 | 20 | 21 | def normalize_peak(wav, threshold=1.0): 22 | peak = np.abs(wav).max() 23 | if peak > threshold: 24 | wav = 0.98 * wav / peak 25 | return wav 26 | 27 | 28 | def resample_wav(wav, sr, target_sr): 29 | return librosa.resample(wav, orig_sr=sr, target_sr=target_sr) 30 | 31 | 32 | def save_wav_to_path(wav, save_path, sr): 33 | wavfile.write(save_path, sr, (wav * np.iinfo(np.int16).max).astype(np.int16)) 34 | 35 | 36 | def process(item): 37 | spkdir, wav_name, args = item 38 | speaker = spkdir.replace("\\", "/").split("/")[-1] 39 | 40 | wav_path = os.path.join(args.in_dir, speaker, wav_name) 41 | if os.path.exists(wav_path) and ".wav" in wav_path: 42 | os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True) 43 | 44 | wav, sr = load_wav(wav_path) 45 | wav, _ = trim_wav(wav) 46 | wav = normalize_peak(wav) 47 | resampled_wav = resample_wav(wav, sr, args.sr2) 48 | 49 | if not args.skip_loudnorm: 50 | resampled_wav /= np.max(np.abs(resampled_wav)) 51 | 52 | save_path2 = os.path.join(args.out_dir2, speaker, wav_name) 53 | save_wav_to_path(resampled_wav, save_path2, args.sr2) 54 | 55 | 56 | """ 57 | def process_all_speakers(): 58 | process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1) 59 | 60 | with ThreadPoolExecutor(max_workers=process_count) as executor: 61 | for speaker in speakers: 62 | spk_dir = os.path.join(args.in_dir, speaker) 63 | if os.path.isdir(spk_dir): 64 | print(spk_dir) 65 | futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")] 66 | for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 67 | pass 68 | """ 69 | # multi process 70 | 71 | 72 | def process_all_speakers(): 73 | process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1) 74 | with ProcessPoolExecutor(max_workers=process_count) as executor: 75 | for speaker in speakers: 76 | spk_dir = os.path.join(args.in_dir, speaker) 77 | if os.path.isdir(spk_dir): 78 | print(spk_dir) 79 | futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")] 80 | for _ in track(concurrent.futures.as_completed(futures), total=len(futures), description="resampling:"): 81 | pass 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--sr2", type=int, default=44100, help="sampling rate") 87 | parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source 
dir") 88 | parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir") 89 | parser.add_argument("--skip_loudnorm", action="store_true", help="Skip loudness matching if you have done it") 90 | args = parser.parse_args() 91 | 92 | print(f"CPU count: {cpu_count()}") 93 | speakers = os.listdir(args.in_dir) 94 | process_all_speakers() 95 | -------------------------------------------------------------------------------- /preprocess/mel_transform.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import logging 5 | import argparse 6 | import traceback 7 | from tqdm import tqdm 8 | import torch 9 | import torch.multiprocessing as mp 10 | from concurrent.futures import ProcessPoolExecutor 11 | import torchaudio 12 | 13 | from utils.hparams import get_hparams_from_file, HParams 14 | from utils.mel_processing import wav_to_mel 15 | 16 | os.environ["OMP_NUM_THREADS"] = "1" 17 | log_format = "%(asctime)s %(message)s" 18 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt="%m/%d %I:%M:%S %p") 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--data_dir", type=str, required=True, help="Directory containing audio files") 24 | parser.add_argument("-c", "--config", type=str, required=True, help="YAML file for configuration") 25 | args = parser.parse_args() 26 | 27 | hparams = get_hparams_from_file(args.config) 28 | hparams.data_dir = args.data_dir 29 | return hparams 30 | 31 | 32 | def process_batch(batch, sr_hps, n_fft, hop_size, win_size, n_mels, fmin, fmax): 33 | wavs = [] 34 | for ifile in batch: 35 | try: 36 | wav, sr = torchaudio.load(ifile) 37 | assert sr == sr_hps, f"sample rate: {sr}, expected: {sr_hps}" 38 | wavs.append(wav) 39 | except: 40 | traceback.print_exc() 41 | print("Failed to process {}".format(ifile)) 42 | return None 43 | 44 | wav_lengths = torch.tensor([x.size(1) for x in wavs]) 45 | max_wav_len = wav_lengths.max() 46 | 47 | wav_padded = torch.zeros(len(batch), 1, max_wav_len) 48 | for i, wav in enumerate(wavs): 49 | wav_padded[i, :, : wav.size(1)] = wav 50 | 51 | spec = wav_to_mel(wav_padded, n_fft, n_mels, sr_hps, hop_size, win_size, fmin, fmax, center=False, norm=False) 52 | spec = torch.squeeze(spec, 1) 53 | 54 | for i, ifile in enumerate(batch): 55 | ofile = ifile.replace(".wav", ".spec.pt") 56 | spec_i = spec[i, :, : wav_lengths[i] // hop_size].clone() 57 | torch.save(spec_i, ofile) 58 | 59 | return batch 60 | 61 | 62 | def process_data(hps: HParams): 63 | wav_fns = sorted(glob.glob(f"{hps.data_dir}/**/*.wav", recursive=True)) 64 | # wav_fns = wav_fns[:100] # * Enable for testing 65 | logging.info(f"Max: {mp.cpu_count()}; using 32 CPU cores") 66 | logging.info(f"Preprocessing {len(wav_fns)} files...") 67 | 68 | sr = hps.data.sample_rate 69 | n_fft = hps.data.n_fft 70 | hop_size = hps.data.hop_length 71 | win_size = hps.data.win_length 72 | n_mels = hps.data.n_mels 73 | fmin = hps.data.f_min 74 | fmax = hps.data.f_max 75 | 76 | # Batch files to optimize disk I/O and computation 77 | batch_size = 128 # Change as needed 78 | audio_file_batches = [wav_fns[i : i + batch_size] for i in range(0, len(wav_fns), batch_size)] 79 | 80 | # Use multiprocessing to speed up the conversion 81 | with ProcessPoolExecutor(max_workers=32) as executor: 82 | futures = [executor.submit(process_batch, batch, sr, n_fft, hop_size, win_size, n_mels, fmin, fmax) for batch in audio_file_batches] 83 | for 
future in tqdm(futures): 84 | if future.result() is None: 85 | logging.warning(f"Failed to process a batch.") 86 | return 87 | 88 | 89 | def get_size_by_ext(directory, extension): 90 | total_size = 0 91 | for dirpath, dirnames, filenames in os.walk(directory): 92 | for f in filenames: 93 | if f.endswith(extension): 94 | fp = os.path.join(dirpath, f) 95 | total_size += os.path.getsize(fp) 96 | 97 | return total_size 98 | 99 | 100 | def human_readable_size(size): 101 | """Converts size in bytes to a human-readable format.""" 102 | for unit in ["B", "KB", "MB", "GB", "TB"]: 103 | if size < 1024: 104 | return f"{size:.2f}{unit}" 105 | size /= 1024 106 | return f"{size:.2f}PB" # PB is for petabyte, which will be used if the size is too large. 107 | 108 | 109 | if __name__ == "__main__": 110 | from time import time 111 | 112 | hps = parse_args() 113 | 114 | start = time() 115 | process_data(hps) 116 | logging.info(f"Processed data in {time() - start} seconds") 117 | 118 | extension = ".spec.pt" 119 | size_spec = get_size_by_ext(hps.data_dir, extension) 120 | logging.info(f"{extension}: \t{human_readable_size(size_spec)}") 121 | extension = ".wav" 122 | size_wav = get_size_by_ext(hps.data_dir, extension) 123 | logging.info(f"{extension}: \t{human_readable_size(size_wav)}") 124 | logging.info(f"Total: \t\t{human_readable_size(size_spec + size_wav)}") 125 | -------------------------------------------------------------------------------- /preprocess/vocab_generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Create a list of all ipa symbols\n", 8 | "\n", 9 | "Please refer [phonemes.md](text/phonemes.md)\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# Consonants\n", 19 | "\n", 20 | "consonants = \"\"\"\n", 21 | " m̥ | m | | ɱ | | | n̥ | n | | | ɳ̊ | ɳ | ɲ̟̊ | ɲ̟ | ɲ̊ | ɲ | ŋ̊ | ŋ | ɴ̥ | ɴ | | | | |\n", 22 | " p | b | p̪ | b̪ | t̪ | d̪ | t | d | | | ʈ | ɖ | | | c | ɟ | k | ɡ | q | ɢ | ʡ | | ʔ | |\n", 23 | " | | | | | | t͡s | d͡z | t͡ʃ | d͡ʒ | ʈ͡ʂ | ɖ͡ʐ | t͡ɕ | d͡ʑ | | | | | | | | | | |\n", 24 | " p͡ɸ | b͡β | p̪͡f | b̪͡v | t͡θ | d͡ð | | | | | | | | | c͡ç | ɟ͡ʝ | k͡x | ɡ͡ɣ | q͡χ | ɢ͡ʁ | ʡ͡ħ | ʡ͡ʕ | ʔ͡h | |\n", 25 | " | | | | | | t͡ɬ | d͡ɮ | | | ʈ͡ɭ̊˔ | | | | c͡ʎ̥˔ | | k͡ʟ̝̊ | ɡ͡ʟ̝ | | | | | | |\n", 26 | " | | | | | | s | z | ʃ | ʒ | ʂ | ʐ | ɕ | ʑ | | | | | | | | | | |\n", 27 | " ɸ | β | f | v | θ | ð | | | | | | | | | ç | ʝ | x | ɣ | χ | ʁ | ħ | ʕ | h | ɦ |\n", 28 | " | | | | | | ɬ | ɮ | | | ɭ̊˔ | | | | ʎ̥˔ | ʎ̝ | ʟ̝̊ | ʟ̝ | | | | | | |\n", 29 | " | | ʋ̥ | ʋ | | | ɹ̥ | ɹ | | | ɻ̊ | ɻ | | | j̊ | j | ɰ̊ | ɰ | | | | | | |\n", 30 | " | | | | | | l̥ | l | | | ɭ̊ | ɭ | | | ʎ̥ | ʎ | ʟ̥ | ʟ | | ʟ̠ | | | | |\n", 31 | " | ⱱ̟ | | ⱱ | | | ɾ̥ | ɾ | | | ɽ̊ | ɽ | | | | | | | | ɢ̆ | | ʡ̮ | | |\n", 32 | " | | | | | | | ɺ | | | | ɭ̆ | | | | ʎ̮ | | ʟ̆ | | | | | | |\n", 33 | " | ʙ | | | | | r̥ | r | | | ɽ͡r̥ | ɽ͡r | | | | | | | ʀ̥ | ʀ | ʜ | ʢ | | |\n", 34 | " ʘ | | | | ǀ | | ǃ | | | | | | ǂ | | | | | | | | | | | |\n", 35 | " | | | | | | ǁ | | | | | | | | | | | | | | | | | |\n", 36 | " | ɓ | | | | | | ɗ | | | | | | | | ʄ | | ɠ | | ʛ | | | | |\n", 37 | " pʼ | | | | | | tʼ | | | | ʈʼ | | | | cʼ | | kʼ | | qʼ | | ʡʼ | | | |\n", 38 | " | | fʼ | | θʼ | | sʼ | | ʃʼ | | ʂʼ | | | | | | xʼ | | χʼ | | | | | |\n", 39 | " | | | | | | ɬʼ | | | | | | | | | | | | | | | | | |\n", 40 | "\"\"\"\n", 41 | 
"\n", 42 | "consonants_other = \"\"\"\n", 43 | " | | | | | | | | | ŋ͡m | | |\n", 44 | " | | | | | | | | k͡p | ɡ͡b | | |\n", 45 | " p͡f | b͡v | | | | | | | | | | |\n", 46 | " | | | | ɧ | | | | | | | |\n", 47 | " | | | | | | | ɥ | | | ʍ | w |\n", 48 | " | | | ɫ | | | | | | | | |\n", 49 | "\"\"\"\n", 50 | "\n", 51 | "\n", 52 | "manner_of_articulation = \"\"\"\n", 53 | " ʼ |\n", 54 | "\"\"\"\n", 55 | "\n", 56 | "# Vowels\n", 57 | "\n", 58 | "vowels = \"\"\"\n", 59 | " i | y | ɨ | ʉ | ɯ | u |\n", 60 | " ɪ | ʏ | | | | ʊ |\n", 61 | " e | ø | ɘ | ɵ | ɤ | o |\n", 62 | " | | ə | | | |\n", 63 | " ɛ | œ | ɜ | ɞ | ʌ | ɔ |\n", 64 | " æ | | ɐ | | | |\n", 65 | " a | ɶ | | | ɑ | ɒ |\n", 66 | "\"\"\"\n", 67 | "\n", 68 | "\n", 69 | "vowels_other = \"\"\"\n", 70 | "| ɚ |\n", 71 | "| ɝ |\n", 72 | "\"\"\"\n", 73 | "\n", 74 | "# Diacritics\n", 75 | "\n", 76 | "articulation = \"\"\"\n", 77 | " ◌̼ |\n", 78 | " ◌̪͆ |\n", 79 | " ◌̪ |\n", 80 | " ◌̺ |\n", 81 | " ◌̻ |\n", 82 | " ◌̟ |\n", 83 | " ◌̠ |\n", 84 | " ◌̈ |\n", 85 | " ◌̽ |\n", 86 | " ◌̝ |\n", 87 | " ◌̞ |\n", 88 | "\"\"\"\n", 89 | "\n", 90 | "air_flow = \"\"\"\n", 91 | " ↑ |\n", 92 | " ↓ |\n", 93 | "\"\"\"\n", 94 | "\n", 95 | "phonation = \"\"\"\n", 96 | " ◌̤ |\n", 97 | " ◌̥ |\n", 98 | " ◌̬ |\n", 99 | " ◌̰ |\n", 100 | " ʔ͡◌ |\n", 101 | "\"\"\"\n", 102 | "\n", 103 | "rounding_and_labialization = \"\"\"\n", 104 | " ◌ʷ◌ᶣ |\n", 105 | " ◌ᵝ |\n", 106 | " ◌̹ |\n", 107 | " ◌̜ |\n", 108 | "\"\"\"\n", 109 | "\n", 110 | "\n", 111 | "syllabicity = \"\"\"\n", 112 | " ◌̩ |\n", 113 | " ◌̯ |\n", 114 | "\"\"\"\n", 115 | "\n", 116 | "consonant_release = \"\"\"\n", 117 | " ◌ʰ |\n", 118 | " ◌ⁿ |\n", 119 | " ◌ˡ |\n", 120 | " ◌̚ |\n", 121 | "\"\"\"\n", 122 | "\n", 123 | "co_articulation = \"\"\"\n", 124 | " ◌ʲ |\n", 125 | " ◌ˠ◌̴ |\n", 126 | " ◌ˤ◌̴ |\n", 127 | " ◌̃ |\n", 128 | " ◌˞ |\n", 129 | "\"\"\"\n", 130 | "\n", 131 | "tongue_root = \"\"\"\n", 132 | " ◌̘ |\n", 133 | " ◌̙ |\n", 134 | "\"\"\"\n", 135 | "\n", 136 | "fortis_and_lenis = \"\"\"\n", 137 | " ◌͈ |\n", 138 | " ◌͉ |\n", 139 | "\"\"\"\n", 140 | "\n", 141 | "# Suprasegmentals\n", 142 | "\n", 143 | "stress = \"\"\"\n", 144 | " ˈ◌ |\n", 145 | " ˌ◌ |\n", 146 | "\"\"\"\n", 147 | "\n", 148 | "length = \"\"\"\n", 149 | " ◌̆ |\n", 150 | " ◌ˑ |\n", 151 | " ◌ː |\n", 152 | " ◌ːː |\n", 153 | "\"\"\"\n", 154 | "\n", 155 | "rhythm = \"\"\"\n", 156 | " . 
|\n", 157 | " ◌‿◌ |\n", 158 | "\"\"\"\n", 159 | "\n", 160 | "tones = \"\"\"\n", 161 | " ◌˥ |\n", 162 | " ◌˦ |\n", 163 | " ◌˧ |\n", 164 | " ◌˨ |\n", 165 | " ◌˩ |\n", 166 | " ꜛ◌ |\n", 167 | " ꜜ◌ |\n", 168 | "\"\"\"\n", 169 | "\n", 170 | "intonation = \"\"\"\n", 171 | " | |\n", 172 | " ‖ |\n", 173 | " ↗︎ |\n", 174 | " ↘︎ |\n", 175 | "\"\"\"" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "def get_non_empty_content(md_table):\n", 185 | " non_empty_content = []\n", 186 | "\n", 187 | " # Split table by lines\n", 188 | " lines = md_table.split(\"\\n\")\n", 189 | "\n", 190 | " for line in lines:\n", 191 | " # Split each line by \"|\" to get the cells\n", 192 | " cells = line.split(\"|\")\n", 193 | "\n", 194 | " for cell in cells:\n", 195 | " cell_content = cell.strip()\n", 196 | "\n", 197 | " # If the cell content is not empty, add it to the list\n", 198 | " if cell_content != \"\":\n", 199 | " non_empty_content.append(cell_content)\n", 200 | "\n", 201 | " non_empty_content = \"\".join(non_empty_content)\n", 202 | "\n", 203 | " # unique non_empty_content\n", 204 | " non_empty_content = set(non_empty_content)\n", 205 | " non_empty_content = \"\".join(non_empty_content)\n", 206 | "\n", 207 | " # sort non_empty_content\n", 208 | " non_empty_content = sorted(non_empty_content)\n", 209 | "\n", 210 | " return non_empty_content\n", 211 | "\n", 212 | "\n", 213 | "# Consonants\n", 214 | "consonants = get_non_empty_content(consonants)\n", 215 | "consonants_other = get_non_empty_content(consonants_other)\n", 216 | "manner_of_articulation = get_non_empty_content(manner_of_articulation)\n", 217 | "# Vowels\n", 218 | "vowels = get_non_empty_content(vowels)\n", 219 | "vowels_other = get_non_empty_content(vowels_other)\n", 220 | "# Diacritics\n", 221 | "articulation = get_non_empty_content(articulation)\n", 222 | "air_flow = get_non_empty_content(air_flow)\n", 223 | "phonation = get_non_empty_content(phonation)\n", 224 | "rounding_and_labialization = get_non_empty_content(rounding_and_labialization)\n", 225 | "syllabicity = get_non_empty_content(syllabicity)\n", 226 | "consonant_release = get_non_empty_content(consonant_release)\n", 227 | "co_articulation = get_non_empty_content(co_articulation)\n", 228 | "tongue_root = get_non_empty_content(tongue_root)\n", 229 | "fortis_and_lenis = get_non_empty_content(fortis_and_lenis)\n", 230 | "# Suprasegmentals\n", 231 | "stress = get_non_empty_content(stress)\n", 232 | "length = get_non_empty_content(length)\n", 233 | "rhythm = get_non_empty_content(rhythm)\n", 234 | "tones = get_non_empty_content(tones)\n", 235 | "intonation = get_non_empty_content(intonation)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "# All symbols\n", 245 | "_ipa = (\n", 246 | " consonants\n", 247 | " + consonants_other\n", 248 | " + manner_of_articulation\n", 249 | " + vowels\n", 250 | " + vowels_other\n", 251 | " + articulation\n", 252 | " + air_flow\n", 253 | " + phonation\n", 254 | " + rounding_and_labialization\n", 255 | " + syllabicity\n", 256 | " + consonant_release\n", 257 | " + co_articulation\n", 258 | " + tongue_root\n", 259 | " + fortis_and_lenis\n", 260 | " + stress\n", 261 | " + length\n", 262 | " + rhythm\n", 263 | " + tones\n", 264 | " + intonation\n", 265 | ")\n", 266 | "\n", 267 | "print(_ipa)\n", 268 | "print(len(_ipa))" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 
| "execution_count": null, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "_ipa = \"\".join(_ipa)\n", 278 | "\n", 279 | "# unique _ipa\n", 280 | "_ipa = set(_ipa)\n", 281 | "_ipa = \"\".join(_ipa)\n", 282 | "\n", 283 | "# sort symbols\n", 284 | "_ipa = sorted(_ipa)\n", 285 | "\n", 286 | "print(f'_ipa = \"{\"\".join(_ipa)}\"')\n", 287 | "print(len(_ipa))" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "_punctuation = ';:,.!?¡¿—…\"«»“” '\n", 297 | "\n", 298 | "symbols = list(_punctuation) + list(_ipa)\n", 299 | "\n", 300 | "# unique symbols\n", 301 | "symbols = set(symbols)\n", 302 | "symbols = \"\".join(symbols)\n", 303 | "\n", 304 | "# sort symbols\n", 305 | "symbols = sorted(symbols)\n", 306 | "\n", 307 | "symbols = \"\".join(symbols)\n", 308 | "\n", 309 | "print(f'symbols = \"{\"\".join(symbols)}\"')\n", 310 | "print(len(symbols))" 311 | ] 312 | } 313 | ], 314 | "metadata": { 315 | "kernelspec": { 316 | "display_name": "py11", 317 | "language": "python", 318 | "name": "python3" 319 | }, 320 | "language_info": { 321 | "codemirror_mode": { 322 | "name": "ipython", 323 | "version": 3 324 | }, 325 | "file_extension": ".py", 326 | "mimetype": "text/x-python", 327 | "name": "python", 328 | "nbconvert_exporter": "python", 329 | "pygments_lexer": "ipython3", 330 | "version": "3.11.4" 331 | }, 332 | "orig_nbformat": 4 333 | }, 334 | "nbformat": 4, 335 | "nbformat_minor": 2 336 | } 337 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | torchtext 5 | 6 | phonemizer 7 | inflect 8 | pandas 9 | 10 | numpy 11 | numba 12 | matplotlib 13 | 14 | tensorboard 15 | tensorboardX 16 | 17 | tqdm 18 | PyYAML 19 | ipykernel 20 | pytorch_lightning -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from text import cleaners 3 | from torchtext.vocab import Vocab 4 | 5 | 6 | def tokenizer(text: str, vocab: Vocab, cleaner_names: List[str], language="en-us", cleaned_text=False) -> List[int]: 7 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 8 | Args: 9 | text: string to convert to a sequence of IDs 10 | cleaner_names: names of the cleaner functions from text/cleaners.py 11 | language: language ID from https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md 12 | cleaned_text: whether the text has already been cleaned 13 | Returns: 14 | List of integers corresponding to the symbols in the text 15 | """ 16 | if not cleaned_text: 17 | return _clean_text(text, vocab, cleaner_names, language=language) 18 | else: 19 | return list(map(int, text.split("\t"))) 20 | 21 | 22 | def detokenizer(sequence: List[int], vocab: Vocab) -> str: 23 | """Converts a sequence of tokens back to a string""" 24 | return "".join(vocab.lookup_tokens(sequence)) 25 | 26 | 27 | def _clean_text(text: str, vocab: Vocab, cleaner_names: List[str], language="en-us") -> str: 28 | for name in cleaner_names: 29 | cleaner = getattr(cleaners, name) 30 | assert callable(cleaner), f"Unknown cleaner: {name}" 31 | text = cleaner(text, vocab=vocab, language=language) 32 | return text 33 | 34 | 35 | if __name__ == "__main__": 36 | from utils.task import load_vocab 37 | 38 | vocab = load_vocab("datasets/ljs_base/vocab.txt") 39 | cleaner_names = ["phonemize_text", "add_spaces", "tokenize_text", "delete_unks", "add_bos_eos", "detokenize_sequence"] 40 | text = "Well, I like pizza. You know … Who doesn't like pizza? " 41 | print(tokenizer(text, vocab, cleaner_names, language="en-us", cleaned_text=False)) 42 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cleaners are transformations that run over the input text at both training and eval time. 3 | 4 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 5 | hyperparameter. 
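For example, the __main__ demo in text/__init__.py chains the cleaners
["phonemize_text", "add_spaces", "tokenize_text", "delete_unks", "add_bos_eos", "detokenize_sequence"].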
6 | """ 7 | 8 | import re 9 | from typing import List 10 | from torchtext.vocab import Vocab 11 | from phonemizer import phonemize 12 | from phonemizer.separator import Separator 13 | 14 | 15 | from text.normalize_numbers import normalize_numbers 16 | 17 | from text.symbols import _punctuation, PAD_ID, UNK_ID, BOS_ID, EOS_ID 18 | 19 | 20 | _whitespace_re = re.compile(r"\s+") 21 | _preserved_symbols_re = re.compile(rf"[{_punctuation}]|<.*?>") 22 | separator = Separator(word="", phone=" ") 23 | 24 | 25 | # ---------------------------------------------------------------------------- # 26 | # | Text cleaners | # 27 | # ---------------------------------------------------------------------------- # 28 | def lowercase(text: str, *args, **kwargs): 29 | return text.lower() 30 | 31 | 32 | def collapse_whitespace(text: str, *args, **kwargs): 33 | return re.sub(_whitespace_re, " ", text) 34 | 35 | 36 | def expand_numbers(text: str, *args, **kwargs): 37 | return normalize_numbers(text) 38 | 39 | 40 | def phonemize_text(text: List[str] | str, *args, language="en-us", **kwargs): 41 | return phonemize(text, language=language, backend="espeak", separator=separator, strip=True, preserve_punctuation=True, punctuation_marks=_preserved_symbols_re, with_stress=True, njobs=8) 42 | 43 | 44 | def add_spaces(text: str, *args, **kwargs): 45 | spaced_text = re.sub(_preserved_symbols_re, r" \g<0> ", text) 46 | cleaned_text = re.sub(_whitespace_re, " ", spaced_text) 47 | return cleaned_text.strip() 48 | 49 | 50 | # ---------------------------------------------------------------------------- # 51 | # | Token cleaners | # 52 | # ---------------------------------------------------------------------------- # 53 | 54 | 55 | def tokenize_text(text: str, vocab: Vocab, *args, **kwargs): 56 | tokens = text.split() 57 | return vocab(tokens) 58 | 59 | 60 | def add_bos_eos(tokens: List[int], *args, **kwargs): 61 | return [BOS_ID] + tokens + [EOS_ID] 62 | 63 | 64 | def add_blank(tokens: List[int], *args, **kwargs): 65 | result = [PAD_ID] * (len(tokens) * 2 + 1) 66 | result[1::2] = tokens 67 | return result 68 | 69 | 70 | def delete_unks(tokens: List[int], *args, **kwargs): 71 | return [token for token in tokens if token != UNK_ID] 72 | 73 | 74 | def detokenize_sequence(sequence: List[int], vocab: Vocab, *args, **kwargs): 75 | return "".join(vocab.lookup_tokens(sequence)) 76 | -------------------------------------------------------------------------------- /text/normalize_numbers.py: -------------------------------------------------------------------------------- 1 | import inflect 2 | import re 3 | 4 | 5 | _inflect = inflect.engine() 6 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 7 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 8 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 9 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 10 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 11 | _number_re = re.compile(r"[0-9]+") 12 | 13 | 14 | def _remove_commas(m): 15 | return m.group(1).replace(",", "") 16 | 17 | 18 | def _expand_decimal_point(m): 19 | return m.group(1).replace(".", " point ") 20 | 21 | 22 | def _expand_dollars(m): 23 | match = m.group(1) 24 | parts = match.split(".") 25 | if len(parts) > 2: 26 | return match + " dollars" # Unexpected format 27 | dollars = int(parts[0]) if parts[0] else 0 28 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 29 | if dollars and cents: 30 | dollar_unit = "dollar" if dollars == 1 else "dollars" 31 | cent_unit = "cent" if cents == 1 else "cents" 32 | return 
"%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 33 | elif dollars: 34 | dollar_unit = "dollar" if dollars == 1 else "dollars" 35 | return "%s %s" % (dollars, dollar_unit) 36 | elif cents: 37 | cent_unit = "cent" if cents == 1 else "cents" 38 | return "%s %s" % (cents, cent_unit) 39 | else: 40 | return "zero dollars" 41 | 42 | 43 | def _expand_ordinal(m): 44 | return _inflect.number_to_words(m.group(0)) 45 | 46 | 47 | def _expand_number(m): 48 | num = int(m.group(0)) 49 | if num > 1000 and num < 3000: 50 | if num == 2000: 51 | return "two thousand" 52 | elif num > 2000 and num < 2010: 53 | return "two thousand " + _inflect.number_to_words(num % 100) 54 | elif num % 100 == 0: 55 | return _inflect.number_to_words(num // 100) + " hundred" 56 | else: 57 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 58 | else: 59 | return _inflect.number_to_words(num, andword="") 60 | 61 | 62 | def normalize_numbers(text): 63 | text = re.sub(_comma_number_re, _remove_commas, text) 64 | text = re.sub(_pounds_re, r"\1 pounds", text) 65 | text = re.sub(_dollars_re, _expand_dollars, text) 66 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 67 | text = re.sub(_ordinal_re, _expand_ordinal, text) 68 | text = re.sub(_number_re, _expand_number, text) 69 | return text 70 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ 2 | Set of symbols 3 | """ 4 | _punctuation = ';:,.!?¡¿—…"«»“”' 5 | 6 | 7 | """ 8 | Special symbols 9 | """ 10 | # Define special symbols and indices 11 | special_symbols = ["", "", "", "", "", ""] 12 | PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3 13 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import torch 4 | from torch import nn, optim 5 | from torch.nn import functional as F 6 | from torch.utils.data import DataLoader 7 | from torch.utils.tensorboard import SummaryWriter 8 | import torch.multiprocessing as mp 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.cuda.amp import autocast, GradScaler 12 | from typing import List 13 | 14 | import utils.task as task 15 | from utils.hparams import get_hparams 16 | from model.models import SynthesizerTrn 17 | from model.discriminator import MultiPeriodDiscriminator 18 | from data_utils import TextAudioLoader, TextAudioCollate, DistributedBucketSampler 19 | from losses import generator_loss, discriminator_loss, feature_loss, kl_loss, kl_loss_normal 20 | from utils.mel_processing import wav_to_mel, spec_to_mel, spectral_norm 21 | from utils.model import slice_segments, clip_grad_value_ 22 | 23 | 24 | torch.backends.cudnn.benchmark = True 25 | global_step = 0 26 | 27 | 28 | def main(): 29 | """Assume Single Node Multi GPUs Training Only""" 30 | assert torch.cuda.is_available(), "CPU training is not allowed." 
31 | 32 | n_gpus = torch.cuda.device_count() 33 | os.environ["MASTER_ADDR"] = "localhost" 34 | os.environ["MASTER_PORT"] = "8000" 35 | 36 | hps = get_hparams() 37 | mp.spawn( 38 | run, 39 | nprocs=n_gpus, 40 | args=( 41 | n_gpus, 42 | hps, 43 | ), 44 | ) 45 | 46 | 47 | def run(rank, n_gpus, hps): 48 | global global_step 49 | if rank == 0: 50 | logger = task.get_logger(hps.model_dir) 51 | logger.info(hps) 52 | task.check_git_hash(hps.model_dir) 53 | writer = SummaryWriter(log_dir=hps.model_dir) 54 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 55 | 56 | dist.init_process_group(backend="nccl", init_method="env://", world_size=n_gpus, rank=rank) 57 | torch.manual_seed(hps.train.seed) 58 | torch.cuda.set_device(rank) 59 | 60 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data) 61 | train_sampler = DistributedBucketSampler(train_dataset, hps.train.batch_size, [32, 300, 400, 500, 600, 700, 800, 900, 1000], num_replicas=n_gpus, rank=rank, shuffle=True) 62 | collate_fn = TextAudioCollate() 63 | train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler) 64 | if rank == 0: 65 | eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data) 66 | eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False, batch_size=hps.train.batch_size, pin_memory=True, drop_last=False, collate_fn=collate_fn) 67 | 68 | net_g = SynthesizerTrn(len(train_dataset.vocab), hps.data.n_mels if hps.data.use_mel else hps.data.n_fft // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model).cuda(rank) 69 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) 70 | optim_g = torch.optim.AdamW(net_g.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps) 71 | optim_d = torch.optim.AdamW(net_d.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps) 72 | net_g = DDP(net_g, device_ids=[rank]) 73 | net_d = DDP(net_d, device_ids=[rank]) 74 | 75 | try: 76 | _, _, _, epoch_str = task.load_checkpoint(task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g) 77 | _, _, _, epoch_str = task.load_checkpoint(task.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d) 78 | global_step = (epoch_str - 1) * len(train_loader) 79 | net_g.module.mas_noise_scale = max(hps.model.mas_noise_scale - global_step * hps.model.mas_noise_scale_decay, 0.0) 80 | except: 81 | epoch_str = 1 82 | global_step = 0 83 | 84 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # TODO: check 85 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) 86 | 87 | scaler = GradScaler(enabled=hps.train.fp16_run) 88 | 89 | for epoch in range(epoch_str, hps.train.epochs + 1): 90 | if rank == 0: 91 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) 92 | else: 93 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None) 94 | scheduler_g.step() 95 | scheduler_d.step() 96 | 97 | 98 | def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): 99 | net_g, net_d = nets 100 | optim_g, optim_d = optims 101 | scheduler_g, scheduler_d = schedulers 102 | train_loader, eval_loader = 
loaders 103 | if writers is not None: 104 | writer, writer_eval = writers 105 | 106 | train_loader.batch_sampler.set_epoch(epoch) 107 | global global_step 108 | 109 | net_g.train() 110 | net_d.train() 111 | if rank == 0: 112 | loader = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}") 113 | else: 114 | loader = train_loader 115 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(loader): 116 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True) 117 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True) 118 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) 119 | 120 | with autocast(enabled=hps.train.fp16_run): 121 | ( 122 | y_hat, 123 | l_length, 124 | attn, 125 | ids_slice, 126 | x_mask, 127 | z_mask, 128 | (m_p_text, logs_p_text), 129 | (m_p_dur, logs_p_dur, z_q_dur, logs_q_dur), 130 | (m_p_audio, logs_p_audio, m_q_audio, logs_q_audio), 131 | ) = net_g(x, x_lengths, spec, spec_lengths) 132 | 133 | mel = spectral_norm(spec) if hps.data.use_mel else spec_to_mel(spec, hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.f_min, hps.data.f_max) 134 | y_hat_mel = wav_to_mel(y_hat.squeeze(1), hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.hop_length, hps.data.win_length, hps.data.f_min, hps.data.f_max) 135 | 136 | y_mel = slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) 137 | y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice 138 | 139 | # Discriminator 140 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) 141 | with autocast(enabled=False): 142 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) 143 | loss_disc_all = loss_disc 144 | optim_d.zero_grad() 145 | scaler.scale(loss_disc_all).backward() 146 | scaler.unscale_(optim_d) 147 | grad_norm_d = clip_grad_value_(net_d.parameters(), None) 148 | scaler.step(optim_d) 149 | 150 | with autocast(enabled=hps.train.fp16_run): 151 | # Generator 152 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) 153 | with autocast(enabled=False): 154 | loss_dur = torch.sum(l_length.float()) 155 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel 156 | loss_gen, losses_gen = generator_loss(y_d_hat_g) 157 | 158 | # TODO Test gain constant 159 | if False: 160 | loss_kl_text = kl_loss_normal(m_q_text, logs_q_text, m_p_text, logs_p_text, x_mask) * hps.train.c_kl_text 161 | loss_kl_dur = kl_loss(z_q_dur, logs_q_dur, m_p_dur, logs_p_dur, z_mask) * hps.train.c_kl_dur 162 | loss_kl_audio = kl_loss_normal(m_p_audio, logs_p_audio, m_q_audio, logs_q_audio, z_mask) * hps.train.c_kl_audio 163 | 164 | loss_fm = feature_loss(fmap_r, fmap_g) 165 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl_dur + loss_kl_audio # TODO + loss_kl_text 166 | optim_g.zero_grad() 167 | scaler.scale(loss_gen_all).backward() 168 | scaler.unscale_(optim_g) 169 | grad_norm_g = clip_grad_value_(net_g.parameters(), None) 170 | scaler.step(optim_g) 171 | scaler.update() 172 | 173 | if rank == 0: 174 | if global_step % hps.train.log_interval == 0: 175 | lr = optim_g.param_groups[0]["lr"] 176 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl_dur, loss_kl_audio] # TODO loss_kl_text 177 | losses_str = " ".join(f"{loss.item():.3f}" for loss in losses) 178 | loader.set_postfix_str(f"{losses_str}, {global_step}, {lr:.9f}") 179 | 180 | # scalar_dict = {"loss/g/total": loss_gen_all, 
"loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} 181 | # scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl_dur}) 182 | 183 | # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) 184 | # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) 185 | # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) 186 | # image_dict = { 187 | # "slice/mel_org": task.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), 188 | # "slice/mel_gen": task.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), 189 | # "all/mel": task.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), 190 | # "all/attn": task.plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy()), 191 | # } 192 | # task.summarize(writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict, sample_rate=hps.data.sample_rate) 193 | 194 | # Save checkpoint on CPU to prevent GPU OOM 195 | if global_step % hps.train.eval_interval == 0: 196 | # evaluate(hps, net_g, eval_loader, writer_eval) 197 | task.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) 198 | task.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) 199 | global_step += 1 200 | 201 | 202 | def evaluate(hps, generator, eval_loader, writer_eval): 203 | generator.eval() 204 | with torch.no_grad(): 205 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(eval_loader): 206 | x, x_lengths = x.cuda(0), x_lengths.cuda(0) 207 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0) 208 | y, y_lengths = y.cuda(0), y_lengths.cuda(0) 209 | 210 | # remove else 211 | x = x[:1] 212 | x_lengths = x_lengths[:1] 213 | spec = spec[:1] 214 | spec_lengths = spec_lengths[:1] 215 | y = y[:1] 216 | y_lengths = y_lengths[:1] 217 | break 218 | y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, max_len=1000) 219 | y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length 220 | 221 | mel = spectral_norm(spec) if hps.data.use_mel else spec_to_mel(spec, hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.f_min, hps.data.f_max) 222 | y_hat_mel = wav_to_mel(y_hat.squeeze(1).float(), hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.hop_length, hps.data.win_length, hps.data.f_min, hps.data.f_max) 223 | image_dict = {"gen/mel": task.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())} 224 | audio_dict = {"gen/audio": y_hat[0, :, : y_hat_lengths[0]]} 225 | if global_step == 0: 226 | image_dict.update({"gt/mel": task.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}) 227 | audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]}) 228 | 229 | task.summarize(writer=writer_eval, global_step=global_step, images=image_dict, audios=audio_dict, sample_rate=hps.data.sample_rate) 230 | generator.train() 231 | 232 | 233 | if __name__ == "__main__": 234 | main() 235 | -------------------------------------------------------------------------------- /train_ms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import logging 4 | import torch 5 | from torch import nn, optim 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | import 
torch.multiprocessing as mp 10 | import torch.distributed as dist 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | from torch.cuda.amp import autocast, GradScaler 13 | from typing import List 14 | 15 | import utils.task as task 16 | from utils.hparams import get_hparams 17 | from model.models import SynthesizerTrn 18 | from model.discriminator import MultiPeriodDiscriminator 19 | from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate, DistributedBucketSampler 20 | from losses import generator_loss, discriminator_loss, feature_loss, kl_loss, kl_loss_normal 21 | from utils.mel_processing import wav_to_mel, spec_to_mel, spectral_norm 22 | from utils.model import slice_segments, clip_grad_value_ 23 | 24 | 25 | torch.backends.cudnn.benchmark = True 26 | global_step = 0 27 | 28 | 29 | def main(): 30 | """Assume Single Node Multi GPUs Training Only""" 31 | assert torch.cuda.is_available(), "CPU training is not allowed." 32 | 33 | n_gpus = torch.cuda.device_count() 34 | os.environ["MASTER_ADDR"] = "localhost" 35 | os.environ["MASTER_PORT"] = "8000" 36 | 37 | hps = get_hparams() 38 | mp.spawn( 39 | run, 40 | nprocs=n_gpus, 41 | args=( 42 | n_gpus, 43 | hps, 44 | ), 45 | ) 46 | 47 | 48 | def run(rank, n_gpus, hps): 49 | global global_step 50 | if rank == 0: 51 | logger = task.get_logger(hps.model_dir) 52 | logger.info(hps) 53 | task.check_git_hash(hps.model_dir) 54 | writer = SummaryWriter(log_dir=hps.model_dir) 55 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 56 | 57 | dist.init_process_group(backend="nccl", init_method="env://", world_size=n_gpus, rank=rank) 58 | torch.manual_seed(hps.train.seed) 59 | torch.cuda.set_device(rank) 60 | 61 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data) 62 | train_sampler = DistributedBucketSampler(train_dataset, hps.train.batch_size, [32, 300, 400, 500, 600, 700, 800, 900, 1000], num_replicas=n_gpus, rank=rank, shuffle=True) 63 | collate_fn = TextAudioSpeakerCollate() 64 | train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler) 65 | if rank == 0: 66 | eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data) 67 | eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False, batch_size=hps.train.batch_size, pin_memory=True, drop_last=False, collate_fn=collate_fn) 68 | 69 | net_g = SynthesizerTrn( 70 | len(train_dataset.vocab), hps.data.n_mels if hps.data.use_mel else hps.data.n_fft // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model 71 | ).cuda(rank) 72 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) 73 | optim_g = torch.optim.AdamW(net_g.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps) 74 | optim_d = torch.optim.AdamW(net_d.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps) 75 | net_g = DDP(net_g, device_ids=[rank]) 76 | net_d = DDP(net_d, device_ids=[rank]) 77 | 78 | try: 79 | _, _, _, epoch_str = task.load_checkpoint(task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g) 80 | _, _, _, epoch_str = task.load_checkpoint(task.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d) 81 | global_step = (epoch_str - 1) * len(train_loader) 82 | net_g.module.mas_noise_scale = max(hps.model.mas_noise_scale - global_step * hps.model.mas_noise_scale_decay, 0.0) 83 | except: 84 | epoch_str = 1 85 | global_step = 0 86 
| 87 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) 88 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) 89 | 90 | scaler = GradScaler(enabled=hps.train.fp16_run) 91 | 92 | for epoch in range(epoch_str, hps.train.epochs + 1): 93 | if rank == 0: 94 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) 95 | else: 96 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None) 97 | scheduler_g.step() 98 | scheduler_d.step() 99 | 100 | 101 | def train_and_evaluate(rank, epoch, hps, nets: List[torch.nn.parallel.DistributedDataParallel], optims: List[torch.optim.Optimizer], schedulers, scaler: GradScaler, loaders, logger: logging.Logger, writers): 102 | net_g, net_d = nets 103 | 104 | optim_g, optim_d = optims 105 | scheduler_g, scheduler_d = schedulers 106 | train_loader, eval_loader = loaders 107 | if writers is not None: 108 | writer, writer_eval = writers 109 | 110 | train_loader.batch_sampler.set_epoch(epoch) 111 | global global_step 112 | 113 | net_g.train() 114 | net_d.train() 115 | if rank == 0: 116 | loader = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}") 117 | else: 118 | loader = train_loader 119 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(loader): 120 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True) 121 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True) 122 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) 123 | speakers = speakers.cuda(rank, non_blocking=True) 124 | 125 | with autocast(enabled=hps.train.fp16_run): 126 | ( 127 | y_hat, 128 | l_length, 129 | attn, 130 | ids_slice, 131 | x_mask, 132 | z_mask, 133 | (m_p_text, logs_p_text), 134 | (m_p_dur, logs_p_dur, z_q_dur, logs_q_dur), 135 | (m_p_audio, logs_p_audio, m_q_audio, logs_q_audio), 136 | ) = net_g(x, x_lengths, spec, spec_lengths, speakers) 137 | 138 | mel = spectral_norm(spec) if hps.data.use_mel else spec_to_mel(spec, hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.f_min, hps.data.f_max) 139 | y_hat_mel = wav_to_mel(y_hat.squeeze(1), hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.hop_length, hps.data.win_length, hps.data.f_min, hps.data.f_max) 140 | 141 | y_mel = slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) 142 | y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice 143 | 144 | # Discriminator 145 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) 146 | with autocast(enabled=False): 147 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) 148 | loss_disc_all = loss_disc 149 | optim_d.zero_grad() 150 | scaler.scale(loss_disc_all).backward() 151 | scaler.unscale_(optim_d) 152 | grad_norm_d = clip_grad_value_(net_d.parameters(), None) 153 | scaler.step(optim_d) 154 | 155 | with autocast(enabled=hps.train.fp16_run): 156 | # Generator 157 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) 158 | with autocast(enabled=False): 159 | loss_dur = torch.sum(l_length.float()) 160 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel 161 | loss_gen, losses_gen = 
generator_loss(y_d_hat_g) 162 | 163 | # TODO Test gain constant 164 | if False: 165 | loss_kl_text = kl_loss_normal(m_q_text, logs_q_text, m_p_text, logs_p_text, x_mask) * hps.train.c_kl_text 166 | loss_kl_dur = kl_loss(z_q_dur, logs_q_dur, m_p_dur, logs_p_dur, z_mask) * hps.train.c_kl_dur 167 | loss_kl_audio = kl_loss_normal(m_p_audio, logs_p_audio, m_q_audio, logs_q_audio, z_mask) * hps.train.c_kl_audio 168 | 169 | loss_fm = feature_loss(fmap_r, fmap_g) 170 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl_dur + loss_kl_audio # TODO + loss_kl_text 171 | optim_g.zero_grad() 172 | scaler.scale(loss_gen_all).backward() 173 | scaler.unscale_(optim_g) 174 | grad_norm_g = clip_grad_value_(net_g.parameters(), None) 175 | scaler.step(optim_g) 176 | scaler.update() 177 | 178 | if rank == 0: 179 | if global_step % hps.train.log_interval == 0: 180 | lr = optim_g.param_groups[0]["lr"] 181 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl_dur, loss_kl_audio] # TODO loss_kl_text 182 | losses_str = " ".join(f"{loss.item():.3f}" for loss in losses) 183 | loader.set_postfix_str(f"{losses_str}, {global_step}, {lr:.9f}") 184 | 185 | # scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} 186 | # scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl_dur}) 187 | 188 | # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) 189 | # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) 190 | # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) 191 | # image_dict = { 192 | # "slice/mel_org": task.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), 193 | # "slice/mel_gen": task.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), 194 | # "all/mel": task.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), 195 | # "all/attn": task.plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy()), 196 | # } 197 | # task.summarize(writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict, sample_rate=hps.data.sample_rate) 198 | 199 | # Save checkpoint on CPU to prevent GPU OOM 200 | if global_step % hps.train.eval_interval == 0: 201 | # evaluate(hps, net_g, eval_loader, writer_eval) 202 | task.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) 203 | task.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) 204 | global_step += 1 205 | 206 | 207 | def evaluate(hps, generator, eval_loader, writer_eval): 208 | generator.eval() 209 | with torch.no_grad(): 210 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(eval_loader): 211 | x, x_lengths = x.cuda(0), x_lengths.cuda(0) 212 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0) 213 | y, y_lengths = y.cuda(0), y_lengths.cuda(0) 214 | speakers = speakers.cuda(0) 215 | 216 | # remove else 217 | x = x[:1] 218 | x_lengths = x_lengths[:1] 219 | spec = spec[:1] 220 | spec_lengths = spec_lengths[:1] 221 | y = y[:1] 222 | y_lengths = y_lengths[:1] 223 | speakers = speakers[:1] 224 | break 225 | y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, max_len=1000) 226 | y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length 227 | 228 | mel = spectral_norm(spec) if 
hps.data.use_mel else spec_to_mel(spec, hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.f_min, hps.data.f_max) 229 | y_hat_mel = wav_to_mel(y_hat.squeeze(1).float(), hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.hop_length, hps.data.win_length, hps.data.f_min, hps.data.f_max) 230 | image_dict = {"gen/mel": task.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())} 231 | audio_dict = {"gen/audio": y_hat[0, :, : y_hat_lengths[0]]} 232 | if global_step == 0: 233 | image_dict.update({"gt/mel": task.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}) 234 | audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]}) 235 | 236 | task.summarize(writer=writer_eval, global_step=global_step, images=image_dict, audios=audio_dict, sample_rate=hps.data.sample_rate) 237 | generator.train() 238 | 239 | 240 | if __name__ == "__main__": 241 | main() 242 | -------------------------------------------------------------------------------- /utils/hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import argparse 5 | import os 6 | import yaml 7 | 8 | 9 | class HParams: 10 | def __init__(self, **kwargs): 11 | for k, v in kwargs.items(): 12 | if type(v) == dict: 13 | v = HParams(**v) 14 | self[k] = v 15 | 16 | def keys(self): 17 | return self.__dict__.keys() 18 | 19 | def items(self): 20 | return self.__dict__.items() 21 | 22 | def values(self): 23 | return self.__dict__.values() 24 | 25 | def __len__(self): 26 | return len(self.__dict__) 27 | 28 | def __getitem__(self, key): 29 | return getattr(self, key) 30 | 31 | def __setitem__(self, key, value): 32 | return setattr(self, key, value) 33 | 34 | def __contains__(self, key): 35 | return key in self.__dict__ 36 | 37 | def __repr__(self): 38 | return self.__dict__.__repr__() 39 | 40 | 41 | def get_hparams() -> HParams: 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("-c", "--config", type=str, default="./datasets/base/config.yaml", help="YAML file for configuration") 44 | parser.add_argument("-m", "--model", type=str, required=True, help="Model name") 45 | args = parser.parse_args() 46 | 47 | # assert that path cnsists directory "datasets" and file "config.yaml 48 | assert os.path.exists("./datasets"), "`datasets` directory not found, navigate to the root of the project." 
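    # Expected layout: ./datasets/<model>/config.yaml; checkpoints and TensorBoard logs are
    # written to ./datasets/<model>/logs (created below if it does not exist).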
49 | assert os.path.exists(f"./datasets/{args.model}"), f"`{args.model}` not found in `./datasets/`" 50 | assert os.path.exists(f"./datasets/{args.model}/config.yaml"), f"`config.yaml` not found in `./datasets/{args.model}/`" 51 | 52 | model_dir = f"./datasets/{args.model}/logs" 53 | if not os.path.exists(model_dir): 54 | os.makedirs(model_dir) 55 | 56 | config_path = args.config 57 | hparams = get_hparams_from_file(config_path) 58 | hparams.model_dir = model_dir 59 | return hparams 60 | 61 | 62 | def get_hparams_from_file(config_path: str) -> HParams: 63 | with open(config_path, "r") as f: 64 | data = f.read() 65 | config = yaml.safe_load(data) 66 | 67 | hparams = HParams(**config) 68 | return hparams 69 | 70 | 71 | if __name__ == "__main__": 72 | hparams = get_hparams() 73 | print(hparams) 74 | -------------------------------------------------------------------------------- /utils/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio.transforms as T 3 | import torch.utils.data 4 | 5 | spectrogram_basis = {} 6 | mel_scale_basis = {} 7 | mel_spectrogram_basis = {} 8 | 9 | 10 | def spectral_norm(x: torch.Tensor, clip_val=1e-9): 11 | return torch.log(torch.clamp(x, min=clip_val)) 12 | 13 | 14 | def wav_to_spec(y: torch.Tensor, n_fft, sample_rate, hop_length, win_length, center=False) -> torch.Tensor: 15 | assert torch.min(y) >= -1.0, f"min value is {torch.min(y)}" 16 | assert torch.max(y) <= 1.0, f"max value is {torch.max(y)}" 17 | 18 | global spectrogram_basis 19 | dtype_device = str(y.dtype) + "_" + str(y.device) 20 | hparams = dtype_device + "_" + str(n_fft) + "_" + str(hop_length) 21 | if hparams not in spectrogram_basis: 22 | spectrogram_basis[hparams] = T.Spectrogram( 23 | n_fft=n_fft, 24 | win_length=win_length, 25 | hop_length=hop_length, 26 | pad=(n_fft - hop_length) // 2, 27 | power=1, 28 | center=center, 29 | ).to(device=y.device, dtype=y.dtype) 30 | 31 | spec = spectrogram_basis[hparams](y) 32 | spec = torch.sqrt(spec.pow(2) + 1e-6) 33 | return spec 34 | 35 | 36 | def spec_to_mel(spec: torch.Tensor, n_fft, n_mels, sample_rate, f_min, f_max, norm=True) -> torch.Tensor: 37 | global mel_scale_basis 38 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 39 | hparams = dtype_device + "_" + str(n_fft) + "_" + str(n_mels) + "_" + str(f_max) 40 | if hparams not in mel_scale_basis: 41 | mel_scale_basis[hparams] = T.MelScale(n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, n_stft=n_fft // 2 + 1, norm="slaney", mel_scale="slaney").to(device=spec.device, dtype=spec.dtype) 42 | 43 | mel = torch.matmul(mel_scale_basis[hparams].fb.T, spec) 44 | if norm: 45 | mel = spectral_norm(mel) 46 | return mel 47 | 48 | 49 | def wav_to_mel(y: torch.Tensor, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False, norm=True) -> torch.Tensor: 50 | assert torch.min(y) >= -1.0, f"min value is {torch.min(y)}" 51 | assert torch.max(y) <= 1.0, f"max value is {torch.max(y)}" 52 | 53 | global mel_spectrogram_basis 54 | dtype_device = str(y.dtype) + "_" + str(y.device) 55 | hparams = dtype_device + "_" + str(n_fft) + "_" + str(num_mels) + "_" + str(hop_size) + "_" + str(fmax) 56 | if hparams not in mel_spectrogram_basis: 57 | mel_spectrogram_basis[hparams] = T.MelSpectrogram( 58 | sample_rate=sampling_rate, 59 | n_fft=n_fft, 60 | win_length=win_size, 61 | hop_length=hop_size, 62 | n_mels=num_mels, 63 | f_min=fmin, 64 | f_max=fmax, 65 | pad=(n_fft - hop_size) // 2, 66 | power=1, 67 | 
center=center, 68 | norm="slaney", 69 | mel_scale="slaney", 70 | ).to(device=y.device, dtype=y.dtype) 71 | 72 | mel = mel_spectrogram_basis[hparams](y) 73 | if norm: 74 | mel = spectral_norm(mel) 75 | return mel 76 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size * dilation - dilation) / 2) 16 | 17 | 18 | def intersperse(lst, item): 19 | result = [item] * (len(lst) * 2 + 1) 20 | result[1::2] = lst 21 | return result 22 | 23 | 24 | # TODO remove this 25 | def kl_divergence(m_p, logs_p, m_q, logs_q): 26 | """KL(P||Q)""" 27 | kl = (logs_q - logs_p) - 0.5 28 | kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 29 | return kl 30 | 31 | 32 | # TODO remove this 33 | def rand_gumbel(shape): 34 | """Sample from the Gumbel distribution, protect from overflows.""" 35 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 36 | return -torch.log(-torch.log(uniform_samples)) 37 | 38 | 39 | # TODO remove this 40 | def rand_gumbel_like(x): 41 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 42 | return g 43 | 44 | 45 | def slice_segments(x, ids_str, segment_size=4): 46 | ret = torch.zeros_like(x[:, :, :segment_size]) 47 | for i in range(x.size(0)): 48 | idx_str = ids_str[i] 49 | idx_end = idx_str + segment_size 50 | ret[i] = x[i, :, idx_str:idx_end] 51 | return ret 52 | 53 | 54 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 55 | b, d, t = x.size() 56 | if x_lengths is None: 57 | x_lengths = t 58 | ids_str_max = x_lengths - segment_size + 1 59 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 60 | ret = slice_segments(x, ids_str, segment_size) 61 | return ret, ids_str 62 | 63 | 64 | # TODO remove this 65 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 66 | position = torch.arange(length, dtype=torch.float) 67 | num_timescales = channels // 2 68 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1) 69 | inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 70 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 71 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 72 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 73 | signal = signal.view(1, channels, length) 74 | return signal 75 | 76 | 77 | # TODO remove this 78 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 79 | b, channels, length = x.size() 80 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 81 | return x + signal.to(dtype=x.dtype, device=x.device) 82 | 83 | 84 | # TODO remove this 85 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 86 | b, channels, length = x.size() 87 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 88 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 89 | 90 | 91 | # TODO remove this 92 | def 
subsequent_mask(length): 93 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 94 | return mask 95 | 96 | 97 | @torch.jit.script 98 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 99 | n_channels_int = n_channels[0] 100 | in_act = input_a + input_b 101 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 102 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 103 | acts = t_act * s_act 104 | return acts 105 | 106 | 107 | def convert_pad_shape(pad_shape): 108 | l = pad_shape[::-1] 109 | pad_shape = [item for sublist in l for item in sublist] 110 | return pad_shape 111 | 112 | 113 | # TODO remove this 114 | def shift_1d(x): 115 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 116 | return x 117 | 118 | 119 | def sequence_mask(length: torch.Tensor, max_length=None) -> torch.Tensor: 120 | if max_length is None: 121 | max_length = length.max() 122 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 123 | return x.unsqueeze(0) < length.unsqueeze(1) 124 | 125 | 126 | def clip_grad_value_(parameters, clip_value, norm_type=2): 127 | if isinstance(parameters, torch.Tensor): 128 | parameters = [parameters] 129 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 130 | norm_type = float(norm_type) 131 | if clip_value is not None: 132 | clip_value = float(clip_value) 133 | 134 | total_norm = 0 135 | for p in parameters: 136 | param_norm = p.grad.data.norm(norm_type) 137 | total_norm += param_norm.item() ** norm_type 138 | if clip_value is not None: 139 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 140 | total_norm = total_norm ** (1.0 / norm_type) 141 | return total_norm 142 | -------------------------------------------------------------------------------- /utils/monotonic_align.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | import numba 5 | import numpy as np 6 | from numba import cuda 7 | 8 | from utils.model import sequence_mask, convert_pad_shape 9 | 10 | 11 | # * Ready and Tested 12 | def search_path(z_p, m_p, logs_p, x_mask, y_mask, mas_noise_scale=0.01): 13 | with torch.no_grad(): 14 | o_scale = torch.exp(-2 * logs_p) # [b, d, t] 15 | logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t] 16 | logp2 = torch.matmul(-0.5 * (z_p**2).mT, o_scale) # [b, t', d] x [b, d, t] = [b, t', t] 17 | logp3 = torch.matmul(z_p.mT, (m_p * o_scale)) # [b, t', d] x [b, d, t] = [b, t', t] 18 | logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1], keepdim=True) # [b, 1, t] 19 | logp = logp1 + logp2 + logp3 + logp4 # [b, t', t] 20 | 21 | if mas_noise_scale > 0.0: 22 | epsilon = torch.std(logp) * torch.randn_like(logp) * mas_noise_scale 23 | logp = logp + epsilon 24 | 25 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) # [b, 1, t] * [b, t', 1] = [b, t', t] 26 | attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t', t] maximum_path_cuda 27 | return attn 28 | 29 | 30 | def generate_path(duration: torch.Tensor, mask: torch.Tensor): 31 | """ 32 | duration: [b, 1, t_x] 33 | mask: [b, 1, t_y, t_x] 34 | """ 35 | b, _, t_y, t_x = mask.shape 36 | cum_duration = torch.cumsum(duration, -1) 37 | 38 | cum_duration_flat = cum_duration.view(b * t_x) 39 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 40 | path = path.view(b, t_x, t_y) 41 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 42 | 
path = path.unsqueeze(1).mT * mask 43 | return path 44 | 45 | 46 | # ! ----------------------------- CUDA monotonic_align.py ----------------------------- 47 | 48 | 49 | # TODO test for the optimal blockspergrid and threadsperblock values 50 | def maximum_path_cuda(neg_cent: torch.Tensor, mask: torch.Tensor): 51 | """CUDA optimized version. 52 | neg_cent: [b, t_t, t_s] 53 | mask: [b, t_t, t_s] 54 | """ 55 | device = neg_cent.device 56 | dtype = neg_cent.dtype 57 | 58 | neg_cent_device = cuda.as_cuda_array(neg_cent) 59 | path_device = cuda.device_array(neg_cent.shape, dtype=np.int32) 60 | t_t_max_device = cuda.as_cuda_array(mask.sum(1, dtype=torch.int32)[:, 0]) 61 | t_s_max_device = cuda.as_cuda_array(mask.sum(2, dtype=torch.int32)[:, 0]) 62 | 63 | blockspergrid = neg_cent.shape[0] 64 | threadsperblock = max(neg_cent.shape[1], neg_cent.shape[2]) 65 | 66 | maximum_path_cuda_jit[blockspergrid, threadsperblock](path_device, neg_cent_device, t_t_max_device, t_s_max_device) 67 | 68 | # Convert device array back to tensor 69 | path = torch.as_tensor(path_device.copy_to_host(), device=device, dtype=dtype) 70 | return path 71 | 72 | 73 | @cuda.jit("void(int32[:,:,:], float32[:,:,:], int32[:], int32[:])") 74 | def maximum_path_cuda_jit(paths, values, t_ys, t_xs): 75 | max_neg_val = -1e9 76 | i = cuda.grid(1) 77 | if i >= paths.shape[0]: # exit if the thread is out of the index range 78 | return 79 | 80 | path = paths[i] 81 | value = values[i] 82 | t_y = t_ys[i] 83 | t_x = t_xs[i] 84 | 85 | v_prev = v_cur = 0.0 86 | index = t_x - 1 87 | 88 | for y in range(t_y): 89 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 90 | v_cur = value[y - 1, x] if x != y else max_neg_val 91 | v_prev = value[y - 1, x - 1] if x != 0 else (0.0 if y == 0 else max_neg_val) 92 | value[y, x] += max(v_prev, v_cur) 93 | 94 | for y in range(t_y - 1, -1, -1): 95 | path[y, index] = 1 96 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): 97 | index = index - 1 98 | cuda.syncthreads() 99 | 100 | 101 | # ! ------------------------------- CPU monotonic_align.py ------------------------------- 102 | 103 | 104 | def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor): 105 | """numba optimized version. 
106 | neg_cent: [b, t_t, t_s] 107 | mask: [b, t_t, t_s] 108 | """ 109 | device = neg_cent.device 110 | dtype = neg_cent.dtype 111 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 112 | path = np.zeros(neg_cent.shape, dtype=np.int32) 113 | 114 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 115 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 116 | maximum_path_jit(path, neg_cent, t_t_max, t_s_max) 117 | return torch.from_numpy(path).to(device=device, dtype=dtype) 118 | 119 | 120 | @numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]), nopython=True, nogil=True) 121 | def maximum_path_jit(paths, values, t_ys, t_xs): 122 | b = paths.shape[0] 123 | max_neg_val = -1e9 124 | for i in range(int(b)): 125 | path = paths[i] 126 | value = values[i] 127 | t_y = t_ys[i] 128 | t_x = t_xs[i] 129 | 130 | v_prev = v_cur = 0.0 131 | index = t_x - 1 132 | 133 | for y in range(t_y): 134 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 135 | if x == y: 136 | v_cur = max_neg_val 137 | else: 138 | v_cur = value[y - 1, x] 139 | if x == 0: 140 | if y == 0: 141 | v_prev = 0.0 142 | else: 143 | v_prev = max_neg_val 144 | else: 145 | v_prev = value[y - 1, x - 1] 146 | value[y, x] += max(v_prev, v_cur) 147 | 148 | for y in range(t_y - 1, -1, -1): 149 | path[y, index] = 1 150 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): 151 | index = index - 1 152 | -------------------------------------------------------------------------------- /utils/task.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import logging 5 | import subprocess 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | 10 | MATPLOTLIB_FLAG = False 11 | 12 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 13 | logging.getLogger("numba").setLevel(logging.WARNING) 14 | logger = logging 15 | 16 | 17 | def load_checkpoint(checkpoint_path, model, optimizer=None): 18 | assert os.path.isfile(checkpoint_path) 19 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 20 | iteration = checkpoint_dict["iteration"] 21 | learning_rate = checkpoint_dict["learning_rate"] 22 | if optimizer is not None: 23 | optimizer.load_state_dict(checkpoint_dict["optimizer"]) 24 | saved_state_dict = checkpoint_dict["model"] 25 | if hasattr(model, "module"): 26 | state_dict = model.module.state_dict() 27 | else: 28 | state_dict = model.state_dict() 29 | new_state_dict = {} 30 | for k, v in state_dict.items(): 31 | try: 32 | new_state_dict[k] = saved_state_dict[k] 33 | except: 34 | logger.info("%s is not in the checkpoint" % k) 35 | new_state_dict[k] = v 36 | if hasattr(model, "module"): 37 | model.module.load_state_dict(new_state_dict) 38 | else: 39 | model.load_state_dict(new_state_dict) 40 | logger.info("Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)) 41 | del checkpoint_dict 42 | torch.cuda.empty_cache() 43 | return model, optimizer, learning_rate, iteration 44 | 45 | 46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 47 | logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path)) 48 | if hasattr(model, "module"): 49 | state_dict = model.module.state_dict() 50 | else: 51 | state_dict = model.state_dict() 52 | torch.save({"model": state_dict, "iteration": iteration, "optimizer": optimizer.state_dict(), 
"learning_rate": learning_rate}, checkpoint_path) 53 | 54 | 55 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, sample_rate=22050): 56 | for k, v in scalars.items(): 57 | writer.add_scalar(k, v, global_step) 58 | for k, v in histograms.items(): 59 | writer.add_histogram(k, v, global_step) 60 | for k, v in images.items(): 61 | writer.add_image(k, v, global_step, dataformats="HWC") 62 | for k, v in audios.items(): 63 | writer.add_audio(k, v, global_step, sample_rate) 64 | 65 | 66 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 67 | f_list = glob.glob(os.path.join(dir_path, regex)) 68 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 69 | x = f_list[-1] 70 | print(x) 71 | return x 72 | 73 | 74 | def plot_spectrogram_to_numpy(spectrogram): 75 | global MATPLOTLIB_FLAG 76 | if not MATPLOTLIB_FLAG: 77 | import matplotlib 78 | 79 | matplotlib.use("Agg") 80 | MATPLOTLIB_FLAG = True 81 | mpl_logger = logging.getLogger("matplotlib") 82 | mpl_logger.setLevel(logging.WARNING) 83 | import matplotlib.pylab as plt 84 | import numpy as np 85 | 86 | fig, ax = plt.subplots(figsize=(10, 2)) 87 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 88 | plt.colorbar(im, ax=ax) 89 | plt.xlabel("Frames") 90 | plt.ylabel("Channels") 91 | plt.tight_layout() 92 | 93 | fig.canvas.draw() 94 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 95 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 96 | plt.close() 97 | return data 98 | 99 | 100 | def plot_alignment_to_numpy(alignment, info=None): 101 | global MATPLOTLIB_FLAG 102 | if not MATPLOTLIB_FLAG: 103 | import matplotlib 104 | 105 | matplotlib.use("Agg") 106 | MATPLOTLIB_FLAG = True 107 | mpl_logger = logging.getLogger("matplotlib") 108 | mpl_logger.setLevel(logging.WARNING) 109 | import matplotlib.pylab as plt 110 | import numpy as np 111 | 112 | fig, ax = plt.subplots(figsize=(6, 4)) 113 | im = ax.imshow(alignment.transpose(), aspect="auto", origin="lower", interpolation="none") 114 | fig.colorbar(im, ax=ax) 115 | xlabel = "Decoder timestep" 116 | if info is not None: 117 | xlabel += "\n\n" + info 118 | plt.xlabel(xlabel) 119 | plt.ylabel("Encoder timestep") 120 | plt.tight_layout() 121 | 122 | fig.canvas.draw() 123 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 124 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 125 | plt.close() 126 | return data 127 | 128 | 129 | def load_vocab(vocab_file: str): 130 | """Load vocabulary from text file 131 | Args: 132 | vocab_file (str): Path to vocabulary file 133 | Returns: 134 | torchtext.vocab.Vocab: Vocabulary object 135 | """ 136 | from torchtext.vocab import vocab as transform_vocab 137 | from text.symbols import UNK_ID, special_symbols 138 | 139 | vocab = {} 140 | with open(vocab_file, "r") as f: 141 | for line in f: 142 | token, index = line.split() 143 | vocab[token] = int(index) 144 | vocab = transform_vocab(vocab, specials=special_symbols) 145 | vocab.set_default_index(UNK_ID) 146 | return vocab 147 | 148 | 149 | def save_vocab(vocab, vocab_file: str): 150 | """Save vocabulary as token index pairs in a text file, sorted by the indices 151 | Args: 152 | vocab (torchtext.vocab.Vocab): Vocabulary object 153 | vocab_file (str): Path to vocabulary file 154 | """ 155 | with open(vocab_file, "w") as f: 156 | for token, index in sorted(vocab.get_stoi().items(), key=lambda kv: kv[1]): 157 | f.write(f"{token}\t{index}\n") 158 | 159 | 160 | def 
load_wav_to_torch(full_path): 161 | """Load wav file 162 | Args: 163 | full_path (str): Full path of the wav file 164 | 165 | Returns: 166 | waveform (torch.FloatTensor): Stereo audio signal [channel, time] in range [-1, 1] 167 | sample_rate (int): Sampling rate of audio signal (Hz) 168 | """ 169 | waveform, sample_rate = torchaudio.load(full_path) 170 | return waveform, sample_rate 171 | 172 | 173 | def load_filepaths_and_text(filename, split="|"): 174 | with open(filename, encoding="utf-8") as f: 175 | filepaths_and_text = [line.strip().split(split) for line in f] 176 | return filepaths_and_text 177 | 178 | 179 | def check_git_hash(model_dir): 180 | source_dir = os.path.dirname(os.path.realpath(__file__)) 181 | if not os.path.exists(os.path.join(source_dir, ".git")): 182 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(source_dir)) 183 | return 184 | 185 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 186 | 187 | path = os.path.join(model_dir, "githash") 188 | if os.path.exists(path): 189 | saved_hash = open(path).read() 190 | if saved_hash != cur_hash: 191 | logger.warn("git hash values are different. {}(saved) != {}(current)".format(saved_hash[:8], cur_hash[:8])) 192 | else: 193 | open(path, "w").write(cur_hash) 194 | 195 | 196 | def get_logger(model_dir, filename="train.log"): 197 | global logger 198 | logger = logging.getLogger(os.path.basename(model_dir)) 199 | logger.setLevel(logging.DEBUG) 200 | 201 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 202 | if not os.path.exists(model_dir): 203 | os.makedirs(model_dir) 204 | h = logging.FileHandler(os.path.join(model_dir, filename)) 205 | h.setLevel(logging.DEBUG) 206 | h.setFormatter(formatter) 207 | logger.addHandler(h) 208 | return logger 209 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1.0, 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE, 23 | ): 24 | if tails is None: 25 | spline_fn = rational_quadratic_spline 26 | spline_kwargs = {} 27 | else: 28 | spline_fn = unconstrained_rational_quadratic_spline 29 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound} 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 48 | 49 | 50 | def unconstrained_rational_quadratic_spline( 51 | inputs, 52 | unnormalized_widths, 53 | unnormalized_heights, 54 | unnormalized_derivatives, 55 | inverse=False, 56 | tails="linear", 57 | 
tail_bound=1.0, 58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 60 | min_derivative=DEFAULT_MIN_DERIVATIVE, 61 | ): 62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 63 | outside_interval_mask = ~inside_interval_mask 64 | 65 | outputs = torch.zeros_like(inputs) 66 | logabsdet = torch.zeros_like(inputs) 67 | 68 | if tails == "linear": 69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 70 | constant = np.log(np.exp(1 - min_derivative) - 1) 71 | unnormalized_derivatives[..., 0] = constant 72 | unnormalized_derivatives[..., -1] = constant 73 | 74 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 75 | logabsdet[outside_interval_mask] = 0 76 | else: 77 | raise RuntimeError("{} tails are not implemented.".format(tails)) 78 | 79 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 80 | inputs=inputs[inside_interval_mask], 81 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 82 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 83 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 84 | inverse=inverse, 85 | left=-tail_bound, 86 | right=tail_bound, 87 | bottom=-tail_bound, 88 | top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative, 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | 97 | def rational_quadratic_spline( 98 | inputs, 99 | unnormalized_widths, 100 | unnormalized_heights, 101 | unnormalized_derivatives, 102 | inverse=False, 103 | left=0.0, 104 | right=1.0, 105 | bottom=0.0, 106 | top=1.0, 107 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 108 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 109 | min_derivative=DEFAULT_MIN_DERIVATIVE, 110 | ): 111 | if torch.min(inputs) < left or torch.max(inputs) > right: 112 | raise ValueError("Input to a transform is not within its domain") 113 | 114 | num_bins = unnormalized_widths.shape[-1] 115 | 116 | if min_bin_width * num_bins > 1.0: 117 | raise ValueError("Minimal bin width too large for the number of bins") 118 | if min_bin_height * num_bins > 1.0: 119 | raise ValueError("Minimal bin height too large for the number of bins") 120 | 121 | widths = F.softmax(unnormalized_widths, dim=-1) 122 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 123 | cumwidths = torch.cumsum(widths, dim=-1) 124 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 125 | cumwidths = (right - left) * cumwidths + left 126 | cumwidths[..., 0] = left 127 | cumwidths[..., -1] = right 128 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 129 | 130 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 131 | 132 | heights = F.softmax(unnormalized_heights, dim=-1) 133 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 134 | cumheights = torch.cumsum(heights, dim=-1) 135 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 136 | cumheights = (top - bottom) * cumheights + bottom 137 | cumheights[..., 0] = bottom 138 | cumheights[..., -1] = top 139 | heights = cumheights[..., 1:] - cumheights[..., :-1] 140 | 141 | if inverse: 142 | bin_idx = searchsorted(cumheights, inputs)[..., None] 143 | else: 144 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 145 | 146 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 147 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 148 | 149 | input_cumheights = cumheights.gather(-1, 
bin_idx)[..., 0] 150 | delta = heights / widths 151 | input_delta = delta.gather(-1, bin_idx)[..., 0] 152 | 153 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 154 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 155 | 156 | input_heights = heights.gather(-1, bin_idx)[..., 0] 157 | 158 | if inverse: 159 | a = (inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + input_heights * (input_delta - input_derivatives) 160 | b = input_heights * input_derivatives - (inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 161 | c = -input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta) 171 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - root).pow(2)) 172 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 173 | 174 | return outputs, -logabsdet 175 | else: 176 | theta = (inputs - input_cumwidths) / input_bin_widths 177 | theta_one_minus_theta = theta * (1 - theta) 178 | 179 | numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) 180 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta) 181 | outputs = input_cumheights + numerator / denominator 182 | 183 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2)) 184 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 185 | 186 | return outputs, logabsdet 187 | --------------------------------------------------------------------------------
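A minimal usage sketch of the spline transform above, with toy shapes chosen purely for illustration; it assumes the repository root is on the path so that `utils.transforms` is importable, matching the `utils.*` imports used elsewhere in the repo. It runs a forward pass and then the inverse on the result, which should reconstruct the input and produce log-determinants that cancel:

import torch
from utils.transforms import piecewise_rational_quadratic_transform

num_bins = 10
x = torch.randn(4, 192)                # toy batch of values to transform
w = torch.randn(4, 192, num_bins)      # unnormalized bin widths
h = torch.randn(4, 192, num_bins)      # unnormalized bin heights
d = torch.randn(4, 192, num_bins - 1)  # unnormalized knot derivatives ("linear" tails pad both ends)

y, logdet = piecewise_rational_quadratic_transform(x, w, h, d, inverse=False, tails="linear", tail_bound=5.0)
x_rec, inv_logdet = piecewise_rational_quadratic_transform(y, w, h, d, inverse=True, tails="linear", tail_bound=5.0)

print((x - x_rec).abs().max())            # round-trip reconstruction error, ~0
print((logdet + inv_logdet).abs().max())  # forward/inverse log|det J| cancel, ~0

In VITS-style models these splines sit inside the flow layers of the stochastic duration predictor, with the unnormalized widths, heights, and derivatives predicted by convolutional layers.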