├── .gitignore
├── LICENSE
├── README.md
├── data_utils.py
├── datasets
├── ljs_base
│ ├── config.yaml
│ ├── filelists
│ │ ├── test.txt
│ │ ├── train.txt
│ │ └── val.txt
│ ├── prepare
│ │ └── filelists.ipynb
│ └── vocab.txt
├── ljs_nosdp
│ └── config.yaml
├── madasr23_base
│ ├── config.yaml
│ └── prepare
│ │ ├── filelists.ipynb
│ │ └── metadata.ipynb
└── vctk_base
│ ├── config.yaml
│ └── filelists
│ ├── vctk_audio_sid_text_test_filelist.txt
│ ├── vctk_audio_sid_text_test_filelist.txt.cleaned
│ ├── vctk_audio_sid_text_train_filelist.txt
│ ├── vctk_audio_sid_text_train_filelist.txt.cleaned
│ ├── vctk_audio_sid_text_val_filelist.txt
│ └── vctk_audio_sid_text_val_filelist.txt.cleaned
├── figures
├── figure01.png
├── figure02.png
└── figure03.png
├── inference.ipynb
├── inference_batch.ipynb
├── losses.py
├── model
├── condition.py
├── decoder.py
├── discriminator.py
├── duration_predictors.py
├── encoders.py
├── models.py
├── modules.py
├── normalization.py
├── normalizing_flows.py
└── transformer.py
├── preprocess
├── README.md
├── audio_find_corrupted.ipynb
├── audio_resample.ipynb
├── audio_resampling.py
├── mel_transform.py
└── vocab_generation.ipynb
├── requirements.txt
├── text
├── LICENSE
├── __init__.py
├── cleaners.py
├── normalize_numbers.py
└── symbols.py
├── train.py
├── train_ms.py
└── utils
├── hparams.py
├── mel_processing.py
├── model.py
├── monotonic_align.py
├── task.py
└── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | DUMMY*
2 |
3 | __pycache__
4 | .ipynb_checkpoints
5 | .*.swp
6 |
7 | build
8 | *.c
9 | monotonic_align/monotonic_align
10 |
11 | .vscode
12 | .DS_Store
13 |
14 | logs
15 | test
16 | datasets/madasr23_base/filelists
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jaehyeon Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VITS2: Improving Quality and Efficiency of Single-Stage Text-to-Speech with Adversarial Learning and Architecture Design
2 |
3 | ### Jungil Kong, Jihoon Park, Beomjeong Kim, Jeongmin Kim, Dohee Kong, Sangjin Kim
4 |
5 | ### SK Telecom, South Korea
6 |
7 | Single-stage text-to-speech models have been actively studied recently, and their results have outperformed two-stage pipeline systems. Although the previous single-stage model has made great progress, there is room for improvement in terms of its intermittent unnaturalness, computational efficiency, and strong dependence on phoneme conversion. In this work, we introduce VITS2, a single-stage text-to-speech model that efficiently synthesizes a more natural speech by improving several aspects of the previous work. We propose improved structures and training mechanisms and present that the proposed methods are effective in improving naturalness, similarity of speech characteristics in a multi-speaker model, and efficiency of training and inference. Furthermore, we demonstrate that the strong dependence on phoneme conversion in previous works can be significantly reduced with our method, which allows a fully end-to-end single-stage approach.
8 |
9 | Demo: https://vits-2.github.io/demo/
10 |
11 | Paper: https://arxiv.org/abs/2307.16430
12 |
13 | Unofficial implementation of VITS2. This is a work in progress. Please refer to [TODO](#todo) for more details.
14 |
15 |
16 |
17 | | Duration Predictor | Normalizing Flows | Text Encoder |
18 | | :----------------: | :---------------: | :----------: |
19 | | ![figure01](figures/figure01.png) | ![figure02](figures/figure02.png) | ![figure03](figures/figure03.png) |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 | ## Audio Samples
29 |
30 | [In progress]
31 |
32 | Audio sample after 52,000 steps of training on a single GPU on the LJSpeech dataset:
33 | https://github.com/daniilrobnikov/vits2/assets/91742765/d769c77a-bd92-4732-96e7-ab53bf50d783
34 |
35 | ## Installation
36 |
37 |
38 |
39 | **Clone the repo**
40 |
41 | ```shell
42 | git clone git@github.com:daniilrobnikov/vits2.git
43 | cd vits2
44 | ```
45 |
46 | ## Setting up the conda env
47 |
48 | This is assuming you have navigated to the `vits2` root after cloning it.
49 |
50 | **NOTE:** This has been tested with `python3.11` in a conda env. Other Python versions may run into dependency conflicts.
51 |
52 | **PyTorch 2.0**
53 | Please refer to [requirements.txt](requirements.txt)
54 |
55 | ```shell
56 | # install required packages (for pytorch 2.0)
57 | conda create -n vits2 python=3.11
58 | conda activate vits2
59 | pip install -r requirements.txt
60 |
61 | conda env config vars set PYTHONPATH="/path/to/vits2"
62 | ```
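
Alternatively, if you prefer not to persist the variable in the conda env, you can export `PYTHONPATH` per shell session (a minimal sketch, assuming the repo lives at `/path/to/vits2`):

```shell
export PYTHONPATH="$PYTHONPATH:/path/to/vits2"
# quick import check (uses a module that ships with this repo)
python -c "from utils.hparams import get_hparams_from_file; print('PYTHONPATH OK')"
```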
63 |
64 | ## Download datasets
65 |
66 | There are three options you can choose from: LJ Speech, VCTK, or custom dataset.
67 |
68 | 1. LJ Speech: [LJ Speech dataset](#lj-speech-dataset). Used for single-speaker TTS.
69 | 2. VCTK: [VCTK dataset](#vctk-dataset). Used for multi-speaker TTS.
70 | 3. Custom dataset: You can use your own dataset. Please refer to the instructions [here](#custom-dataset).
71 |
72 | ### LJ Speech dataset
73 |
74 | 1. download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/)
75 |
76 | ```shell
77 | wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
78 | tar -xvf LJSpeech-1.1.tar.bz2
81 | ```
82 |
83 | 2. preprocess mel-spectrograms. See [mel_transform.py](preprocess/mel_transform.py). A quick check of the cached spectrograms is sketched after this list.
84 |
85 | ```shell
86 | python preprocess/mel_transform.py --data_dir /path/to/LJSpeech-1.1 -c datasets/ljs_base/config.yaml
87 | ```
88 |
89 | 3. preprocess text. See [prepare/filelists.ipynb](datasets/ljs_base/prepare/filelists.ipynb)
90 |
91 | 4. rename or create a link to the dataset folder.
92 |
93 | ```shell
94 | ln -s /path/to/LJSpeech-1.1 DUMMY1
95 | ```
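
Note: during training, `data_utils.py` caches computed (mel-)spectrograms as `.spec.pt` files next to the wavs (see `get_spec`). Assuming `mel_transform.py` writes the same cache, a quick count after preprocessing should roughly match the number of wavs:

```shell
find /path/to/LJSpeech-1.1 -name "*.spec.pt" | wc -l
find /path/to/LJSpeech-1.1 -name "*.wav" | wc -l
```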
96 |
97 | ### VCTK dataset
98 |
99 | 1. download and extract the [VCTK dataset](https://www.kaggle.com/datasets/showmik50/vctk-dataset)
100 |
101 | ```shell
102 | wget https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip
103 | unzip VCTK-Corpus-0.92.zip
104 | ```
105 |
106 | 2. (optional): downsample the audio files to 22050 Hz. See [audio_resample.ipynb](preprocess/audio_resample.ipynb)
107 |
108 | 3. preprocess mel-spectrograms. See [mel_transform.py](preprocess/mel_transform.py)
109 |
110 | ```shell
111 | python preprocess/mel_transform.py --data_dir /path/to/VCTK-Corpus-0.92 -c datasets/vctk_base/config.yaml
112 | ```
113 |
114 | 4. preprocess text. See [prepare/filelists.ipynb](datasets/ljs_base/prepare/filelists.ipynb)
115 |
116 | 5. rename or create a link to the dataset folder.
117 |
118 | ```shell
119 | ln -s /path/to/VCTK-Corpus-0.92 DUMMY2
120 | ```
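
The bundled VCTK filelists reference audio as `DUMMY2/<speaker>/<utterance>.wav` (see `datasets/vctk_base/filelists/`), so `DUMMY2` must point at the directory that directly contains the speaker folders. A quick sanity check:

```shell
head -n 3 datasets/vctk_base/filelists/vctk_audio_sid_text_val_filelist.txt
ls DUMMY2/p225 | head  # adjust the link target if the speaker folders are nested, e.g. under a wav48 subdirectory
```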
121 |
122 | ### Custom dataset
123 |
124 | 1. create a folder with wav files
125 | 2. duplicate `ljs_base` in the `datasets` directory and rename it to `custom_base`
126 | 3. open [custom_base](datasets/custom_base) and change the following fields in `config.yaml`:
127 |
128 | ```yaml
129 | data:
130 | training_files: datasets/custom_base/filelists/train.txt
131 | validation_files: datasets/custom_base/filelists/val.txt
132 | text_cleaners: # See text/cleaners.py
133 | - phonemize_text
134 | - tokenize_text
135 | - add_bos_eos
136 |   cleaned_text: true # True if you ran step 5 (text preprocessing).
137 | language: en-us # language of your dataset. See espeak-ng
138 | sample_rate: 22050 # sample rate, based on your dataset
139 | ...
140 | n_speakers: 0 # 0 for single speaker, > 0 for multi-speaker
141 | ```
142 |
143 | 4. preprocess mel-spectrograms. See [mel_transform.py](preprocess/mel_transform.py)
144 |
145 | ```shell
146 | python preprocess/mel_transform.py --data_dir /path/to/custom_dataset -c datasets/custom_base/config.yaml
147 | ```
148 |
149 | 5. preprocess text. See [prepare/filelists.ipynb](datasets/ljs_base/prepare/filelists.ipynb)
150 |
151 | **NOTE:** You may need to install `espeak-ng` if you want to use the `phonemize_text` cleaner. Please refer to [espeak-ng](https://github.com/espeak-ng/espeak-ng). A quick phonemization check is sketched after this list.
152 |
153 | 6. rename or create a link to the dataset folder.
154 |
155 | ```shell
156 | ln -s /path/to/custom_dataset DUMMY3
157 | ```
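
If phonemization produces empty or odd output, you can test `espeak-ng` and the `language` code from your `config.yaml` directly (a minimal check; the voice name depends on your dataset's language):

```shell
espeak-ng -v en-us -q --ipa "the quick brown fox"
# expected: an IPA transcription printed to stdout
```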
158 |
159 | ## Training Examples
160 |
161 | ```shell
162 | # LJ Speech
163 | python train.py -c datasets/ljs_base/config.yaml -m ljs_base
164 |
165 | # VCTK
166 | python train_ms.py -c datasets/vctk_base/config.yaml -m vctk_base
167 |
168 | # Custom dataset (multi-speaker)
169 | python train_ms.py -c datasets/custom_base/config.yaml -m custom_base
170 | ```
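
If the training scripts write TensorBoard event files (as in the original VITS recipe), you can monitor a run by pointing TensorBoard at its log directory. The exact location is an assumption here: `.gitignore` suggests a top-level `logs/`, while `inference_batch.ipynb` loads checkpoints from `datasets/<model>/logs/`.

```shell
tensorboard --logdir logs/ljs_base
# or, if events/checkpoints land next to the dataset config:
# tensorboard --logdir datasets/ljs_base/logs
```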
171 |
172 | ## Inference Examples
173 |
174 | See [inference.ipynb](inference.ipynb) and [inference_batch.ipynb](inference_batch.ipynb)
175 |
176 | ## Pretrained Models
177 |
178 | [In progress]
179 |
180 | ## Todo
181 |
182 | - [ ] model (vits2)
183 | - [x] update TextEncoder to support speaker conditioning
184 | - [x] support for high-resolution mel-spectrograms in training. See [mel_transform.py](preprocess/mel_transform.py)
185 | - [x] Monotonic Alignment Search with Gaussian noise
186 | - [x] Normalizing Flows using Transformer Block
187 | - [ ] Stochastic Duration Predictor with Time Step-wise Conditional Discriminator
188 | - [ ] model (YourTTS)
189 | - [ ] Language Conditioning
190 | - [ ] Speaker Encoder
191 | - [ ] model (NaturalSpeech)
192 | - [x] KL Divergence Loss after Prior Enhancing
193 | - [ ] GAN loss for e2e training
194 | - [ ] other
195 | - [x] support for batch inference
196 | - [x] special tokens in tokenizer
197 |   - [x] test numba.jit and numba.cuda.jit implementations of MAS. See [monotonic_align.py](utils/monotonic_align.py)
198 | - [ ] KL Divergence Loss between TextEncoder and Projection
199 |   - [ ] support for streaming inference. Please refer to [vits_chinese](https://github.com/PlayVoice/vits_chinese/blob/master/text/symbols.py)
200 | - [ ] use optuna for hyperparameter tuning
201 | - [ ] future work
202 |   - [ ] update model to vits2. Please refer to [VITS2](https://arxiv.org/abs/2307.16430)
203 |   - [ ] update model to YourTTS with zero-shot learning. See [YourTTS](https://arxiv.org/abs/2112.02418)
204 |   - [ ] update model to NaturalSpeech. Please refer to [NaturalSpeech](https://arxiv.org/abs/2205.04421)
205 |
206 | ## Acknowledgements
207 |
208 | - This is an unofficial repo based on [VITS2](https://arxiv.org/abs/2307.16430)
209 | - g2p for multiple languages is based on [phonemizer](https://github.com/bootphon/phonemizer)
210 | - We also thank ChatGPT for providing writing assistance.
211 |
212 | ## References
213 |
214 | - [VITS2: Improving Quality and Efficiency of Single-Stage Text-to-Speech with Adversarial Learning and Architecture Design](https://arxiv.org/abs/2307.16430)
215 | - [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103)
216 | - [YourTTS: Towards Zero-Shot Multi-Speaker TTS and Zero-Shot Voice Conversion for everyone](https://arxiv.org/abs/2112.02418)
217 | - [NaturalSpeech: End-to-End Text to Speech Synthesis with Human-Level Quality](https://arxiv.org/abs/2205.04421)
218 | - [A TensorFlow implementation of Google's Tacotron speech synthesis with pre-trained model (unofficial)](https://github.com/keithito/tacotron)
219 |
221 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import torch.utils.data
5 |
6 | from utils.mel_processing import wav_to_spec, wav_to_mel
7 | from utils.task import load_vocab, load_wav_to_torch, load_filepaths_and_text
8 | from text import tokenizer
9 |
10 |
11 | class TextAudioLoader(torch.utils.data.Dataset):
12 | """
13 | 1) loads audio, text pairs
14 | 2) normalizes text and converts them to sequences of integers
15 | 3) computes spectrograms from audio files.
16 | """
17 |
18 | def __init__(self, audiopaths_and_text, hps_data):
19 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
20 | self.vocab = load_vocab(hps_data.vocab_file)
21 | self.text_cleaners = hps_data.text_cleaners
22 | self.sample_rate = hps_data.sample_rate
23 | self.n_fft = hps_data.n_fft
24 | self.hop_length = hps_data.hop_length
25 | self.win_length = hps_data.win_length
26 | self.n_mels = hps_data.n_mels
27 | self.f_min = hps_data.f_min
28 | self.f_max = hps_data.f_max
29 | self.use_mel = hps_data.use_mel
30 |
31 | self.language = getattr(hps_data, "language", "en-us")
32 | self.cleaned_text = getattr(hps_data, "cleaned_text", False)
33 | self.min_text_len = getattr(hps_data, "min_text_len", 1)
34 | self.max_text_len = getattr(hps_data, "max_text_len", 200)
35 |
36 | random.seed(1234)
37 | random.shuffle(self.audiopaths_and_text)
38 | self._filter()
39 |
40 | def _filter(self):
41 | """
42 | Filter text & store spec lengths
43 | """
44 | # Store spectrogram lengths for Bucketing
45 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
46 | # spec_length = wav_length // hop_length
47 |
48 | audiopaths_and_text_new = []
49 | lengths = []
50 | for audiopath, text in self.audiopaths_and_text:
51 | text_len = text.count("\t") + 1
52 | if self.min_text_len <= text_len and text_len <= self.max_text_len:
53 | audiopaths_and_text_new.append([audiopath, text])
54 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
55 | self.audiopaths_and_text = audiopaths_and_text_new
56 | self.lengths = lengths
57 |
58 | def get_audio_text_pair(self, audiopath_and_text):
59 | # separate filename and text
60 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
61 | text = self.get_text(text)
62 | wav = self.get_audio(audiopath)
63 | spec = self.get_spec(audiopath, wav)
64 | return (text, spec, wav)
65 |
66 | def get_text(self, text):
67 | text_norm = tokenizer(text, self.vocab, self.text_cleaners, language=self.language, cleaned_text=self.cleaned_text)
68 | text_norm = torch.LongTensor(text_norm)
69 | return text_norm
70 |
71 | def get_audio(self, filename):
72 | audio, sample_rate = load_wav_to_torch(filename)
73 | assert sample_rate == self.sample_rate, f"{sample_rate} SR doesn't match target {self.sample_rate} SR"
74 | return audio
75 |
76 | def get_spec(self, filename: str, wav):
77 | spec_filename = filename.replace(".wav", ".spec.pt")
78 |
79 | if os.path.exists(spec_filename):
80 | spec = torch.load(spec_filename)
81 | else:
82 | if self.use_mel:
83 | spec = wav_to_mel(wav, self.n_fft, self.n_mels, self.sample_rate, self.hop_length, self.win_length, self.f_min, self.f_max, center=False, norm=False)
84 | else:
85 | spec = wav_to_spec(wav, self.n_fft, self.sample_rate, self.hop_length, self.win_length, center=False)
86 | spec = torch.squeeze(spec, 0)
87 | torch.save(spec, spec_filename)
88 |
89 | return spec
90 |
91 | def __getitem__(self, index):
92 | return self.get_audio_text_pair(self.audiopaths_and_text[index])
93 |
94 | def __len__(self):
95 | return len(self.audiopaths_and_text)
96 |
97 |
98 | class TextAudioCollate:
99 | """Zero-pads model inputs and targets"""
100 |
101 | def __init__(self, return_ids=False):
102 | self.return_ids = return_ids
103 |
104 | def __call__(self, batch):
105 |         """Collates a training batch from normalized text and audio
106 | PARAMS
107 | ------
108 | batch: [text_normalized, spec_normalized, wav_normalized]
109 | """
110 | # Right zero-pad all one-hot text sequences to max input length
111 | _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
112 |
113 | max_text_len = max([len(x[0]) for x in batch])
114 | max_spec_len = max([x[1].size(1) for x in batch])
115 | max_wav_len = max([x[2].size(1) for x in batch])
116 |
117 | text_lengths = torch.LongTensor(len(batch))
118 | spec_lengths = torch.LongTensor(len(batch))
119 | wav_lengths = torch.LongTensor(len(batch))
120 |
121 | text_padded = torch.LongTensor(len(batch), max_text_len)
122 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
123 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
124 | text_padded.zero_()
125 | spec_padded.zero_()
126 | wav_padded.zero_()
127 | for i in range(len(ids_sorted_decreasing)):
128 | row = batch[ids_sorted_decreasing[i]]
129 |
130 | text = row[0]
131 | text_padded[i, : text.size(0)] = text
132 | text_lengths[i] = text.size(0)
133 |
134 | spec = row[1]
135 | spec_padded[i, :, : spec.size(1)] = spec
136 | spec_lengths[i] = spec.size(1)
137 |
138 | wav = row[2]
139 | wav_padded[i, :, : wav.size(1)] = wav
140 | wav_lengths[i] = wav.size(1)
141 |
142 | if self.return_ids:
143 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, ids_sorted_decreasing
144 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths
145 |
146 |
147 | """Multi speaker version"""
148 |
149 |
150 | class TextAudioSpeakerLoader(torch.utils.data.Dataset):
151 | """
152 | 1) loads audio, speaker_id, text pairs
153 | 2) normalizes text and converts them to sequences of integers
154 | 3) computes spectrograms from audio files.
155 | """
156 |
157 | def __init__(self, audiopaths_sid_text, hps_data):
158 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
159 | self.vocab = load_vocab(hps_data.vocab_file)
160 | self.text_cleaners = hps_data.text_cleaners
161 | self.sample_rate = hps_data.sample_rate
162 | self.n_fft = hps_data.n_fft
163 | self.hop_length = hps_data.hop_length
164 | self.win_length = hps_data.win_length
165 | self.n_mels = hps_data.n_mels
166 | self.f_min = hps_data.f_min
167 | self.f_max = hps_data.f_max
168 | self.use_mel = hps_data.use_mel
169 |
170 | self.language = getattr(hps_data, "language", "en-us")
171 | self.cleaned_text = getattr(hps_data, "cleaned_text", False)
172 | self.min_text_len = getattr(hps_data, "min_text_len", 1)
173 | self.max_text_len = getattr(hps_data, "max_text_len", 200)
174 |
175 | random.seed(1234)
176 | random.shuffle(self.audiopaths_sid_text)
177 | self._filter()
178 |
179 | def _filter(self):
180 | """
181 | Filter text & store spec lengths
182 | """
183 | # Store spectrogram lengths for Bucketing
184 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
185 | # spec_length = wav_length // hop_length
186 |
187 | audiopaths_sid_text_new = []
188 | lengths = []
189 | for audiopath, sid, text in self.audiopaths_sid_text:
190 | text_len = text.count("\t") + 1
191 | if self.min_text_len <= text_len and text_len <= self.max_text_len:
192 | audiopaths_sid_text_new.append([audiopath, sid, text])
193 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
194 | self.audiopaths_sid_text = audiopaths_sid_text_new
195 | self.lengths = lengths
196 |
197 | def get_audio_text_speaker_pair(self, audiopath_sid_text):
198 | # separate filename, speaker_id and text
199 | audiopath, sid, text = audiopath_sid_text[0], audiopath_sid_text[1], audiopath_sid_text[2]
200 | text = self.get_text(text)
201 | wav = self.get_audio(audiopath)
202 | spec = self.get_spec(audiopath, wav)
203 | sid = self.get_sid(sid)
204 | return (text, spec, wav, sid)
205 |
206 | def get_text(self, text):
207 | text_norm = tokenizer(text, self.vocab, self.text_cleaners, language=self.language, cleaned_text=self.cleaned_text)
208 | text_norm = torch.LongTensor(text_norm)
209 | return text_norm
210 |
211 | def get_audio(self, filename):
212 | audio, sample_rate = load_wav_to_torch(filename)
213 | assert sample_rate == self.sample_rate, f"{sample_rate} SR doesn't match target {self.sample_rate} SR"
214 | return audio
215 |
216 | def get_spec(self, filename: str, wav):
217 | spec_filename = filename.replace(".wav", ".spec.pt")
218 |
219 | if os.path.exists(spec_filename):
220 | spec = torch.load(spec_filename)
221 | else:
222 | if self.use_mel:
223 | spec = wav_to_mel(wav, self.n_fft, self.n_mels, self.sample_rate, self.hop_length, self.win_length, self.f_min, self.f_max, center=False, norm=False)
224 | else:
225 | spec = wav_to_spec(wav, self.n_fft, self.sample_rate, self.hop_length, self.win_length, center=False)
226 | spec = torch.squeeze(spec, 0)
227 | torch.save(spec, spec_filename)
228 |
229 | return spec
230 |
231 | def get_sid(self, sid):
232 | sid = torch.LongTensor([int(sid)])
233 | return sid
234 |
235 | def __getitem__(self, index):
236 | return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
237 |
238 | def __len__(self):
239 | return len(self.audiopaths_sid_text)
240 |
241 |
242 | class TextAudioSpeakerCollate:
243 | """Zero-pads model inputs and targets"""
244 |
245 | def __init__(self, return_ids=False):
246 | self.return_ids = return_ids
247 |
248 | def __call__(self, batch):
249 |         """Collates a training batch from normalized text, audio and speaker identities
250 | PARAMS
251 | ------
252 | batch: [text_normalized, spec_normalized, wav_normalized, sid]
253 | """
254 | # Right zero-pad all one-hot text sequences to max input length
255 | _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
256 |
257 | max_text_len = max([len(x[0]) for x in batch])
258 | max_spec_len = max([x[1].size(1) for x in batch])
259 | max_wav_len = max([x[2].size(1) for x in batch])
260 |
261 | text_lengths = torch.LongTensor(len(batch))
262 | spec_lengths = torch.LongTensor(len(batch))
263 | wav_lengths = torch.LongTensor(len(batch))
264 | sid = torch.LongTensor(len(batch))
265 |
266 | text_padded = torch.LongTensor(len(batch), max_text_len)
267 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
268 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
269 | text_padded.zero_()
270 | spec_padded.zero_()
271 | wav_padded.zero_()
272 | for i in range(len(ids_sorted_decreasing)):
273 | row = batch[ids_sorted_decreasing[i]]
274 |
275 | text = row[0]
276 | text_padded[i, : text.size(0)] = text
277 | text_lengths[i] = text.size(0)
278 |
279 | spec = row[1]
280 | spec_padded[i, :, : spec.size(1)] = spec
281 | spec_lengths[i] = spec.size(1)
282 |
283 | wav = row[2]
284 | wav_padded[i, :, : wav.size(1)] = wav
285 | wav_lengths[i] = wav.size(1)
286 |
287 | sid[i] = row[3]
288 |
289 | if self.return_ids:
290 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing
291 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid
292 |
293 |
294 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
295 | """
296 | Maintain similar input lengths in a batch.
297 | Length groups are specified by boundaries.
298 |     Ex) boundaries = [b1, b2, b3] -> any batch is included in either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
299 |
300 | It removes samples which are not included in the boundaries.
301 |     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
302 | """
303 |
304 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
305 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
306 | self.lengths = dataset.lengths
307 | self.batch_size = batch_size
308 | self.boundaries = boundaries
309 |
310 | self.buckets, self.num_samples_per_bucket = self._create_buckets()
311 | self.total_size = sum(self.num_samples_per_bucket)
312 | self.num_samples = self.total_size // self.num_replicas
313 |
314 | def _create_buckets(self):
315 | buckets = [[] for _ in range(len(self.boundaries) - 1)]
316 | for i in range(len(self.lengths)):
317 | length = self.lengths[i]
318 | idx_bucket = self._bisect(length)
319 | if idx_bucket != -1:
320 | buckets[idx_bucket].append(i)
321 |
322 | for i in range(len(buckets) - 1, 0, -1):
323 | if len(buckets[i]) == 0:
324 | buckets.pop(i)
325 | self.boundaries.pop(i + 1)
326 |
327 | num_samples_per_bucket = []
328 | for i in range(len(buckets)):
329 | len_bucket = len(buckets[i])
330 | total_batch_size = self.num_replicas * self.batch_size
331 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
332 | num_samples_per_bucket.append(len_bucket + rem)
333 | return buckets, num_samples_per_bucket
334 |
335 | def __iter__(self):
336 | # deterministically shuffle based on epoch
337 | g = torch.Generator()
338 | g.manual_seed(self.epoch)
339 |
340 | indices = []
341 | if self.shuffle:
342 | for bucket in self.buckets:
343 | indices.append(torch.randperm(len(bucket), generator=g).tolist())
344 | else:
345 | for bucket in self.buckets:
346 | indices.append(list(range(len(bucket))))
347 |
348 | batches = []
349 | for i in range(len(self.buckets)):
350 | bucket = self.buckets[i]
351 | len_bucket = len(bucket)
352 | ids_bucket = indices[i]
353 | num_samples_bucket = self.num_samples_per_bucket[i]
354 |
355 | # add extra samples to make it evenly divisible
356 | rem = num_samples_bucket - len_bucket
357 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]
358 |
359 | # subsample
360 | ids_bucket = ids_bucket[self.rank :: self.num_replicas]
361 |
362 | # batching
363 | for j in range(len(ids_bucket) // self.batch_size):
364 | batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
365 | batches.append(batch)
366 |
367 | if self.shuffle:
368 | batch_ids = torch.randperm(len(batches), generator=g).tolist()
369 | batches = [batches[i] for i in batch_ids]
370 | self.batches = batches
371 |
372 | assert len(self.batches) * self.batch_size == self.num_samples
373 | return iter(self.batches)
374 |
375 | def _bisect(self, x, lo=0, hi=None):
376 | if hi is None:
377 | hi = len(self.boundaries) - 1
378 |
379 | if hi > lo:
380 | mid = (hi + lo) // 2
381 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
382 | return mid
383 | elif x <= self.boundaries[mid]:
384 | return self._bisect(x, lo, mid)
385 | else:
386 | return self._bisect(x, mid + 1, hi)
387 | else:
388 | return -1
389 |
390 | def __len__(self):
391 | return self.num_samples // self.batch_size
392 |
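# Minimal single-process usage sketch (illustrative only; the batch size and bucket
# boundaries below are assumptions, not necessarily the values used by train.py):
#
#   from torch.utils.data import DataLoader
#   from utils.hparams import get_hparams_from_file
#
#   hps = get_hparams_from_file("datasets/ljs_base/config.yaml")
#   dataset = TextAudioLoader(hps.data.training_files, hps.data)
#   sampler = DistributedBucketSampler(dataset, batch_size=16,
#                                      boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
#                                      num_replicas=1, rank=0, shuffle=True)
#   loader = DataLoader(dataset, batch_sampler=sampler,
#                       collate_fn=TextAudioCollate(), num_workers=4)
#   # each batch: (text, text_lengths, spec, spec_lengths, wav, wav_lengths),
#   # drawn from a single length bucket to minimize padding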
--------------------------------------------------------------------------------
/datasets/ljs_base/config.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | log_interval: 100
3 | eval_interval: 1000
4 | seed: 1234
5 | epochs: 20000
6 | learning_rate: 2.0e-4
7 | betas: [0.8, 0.99]
8 | eps: 1.0e-09
9 | batch_size: 64 # TODO Try more
10 | fp16_run: true
11 | lr_decay: 0.999875
12 | segment_size: 8192
13 | init_lr_ratio: 1
14 | warmup_epochs: 0
15 | c_mel: 45
16 | c_kl_text: 0 # default: 0
17 | c_kl_dur: 2 # default: 2
18 | c_kl_audio: 0.05 # 0.05 Allow more audio variation. default: 1
19 |
20 | data:
21 | training_files: datasets/ljs_base/filelists/train.txt
22 | validation_files: datasets/ljs_base/filelists/val.txt
23 | vocab_file: datasets/ljs_base/vocab.txt
24 | text_cleaners:
25 | - phonemize_text
26 | - add_spaces
27 | - tokenize_text
28 | - add_bos_eos
29 | cleaned_text: true
30 | language: en-us
31 | bits_per_sample: 16
32 | sample_rate: 22050
33 | n_fft: 2048
34 | hop_length: 256
35 | win_length: 1024
36 | n_mels: 80 # 100 works better with "slaney" mel-transform. default: 80
37 | f_min: 0
38 | f_max:
39 | n_speakers: 0
40 | use_mel: true
41 |
42 | model:
43 | inter_channels: 192
44 | hidden_channels: 192
45 | filter_channels: 768
46 | n_heads: 2
47 | n_layers: 6
48 | n_layers_q: 12 # default: 16
49 | n_flows: 8 # default: 4
50 | kernel_size: 3
51 | p_dropout: 0.1
52 | speaker_cond_layer: 0 # 0 to disable speaker conditioning
53 | resblock: "1"
54 | resblock_kernel_sizes: [3, 7, 11]
55 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
56 | upsample_rates: [8, 8, 2, 2]
57 | upsample_initial_channel: 512
58 | upsample_kernel_sizes: [16, 16, 4, 4]
59 | mas_noise_scale: 0.01 # 0.0 to disable Gaussian noise in MAS
60 | mas_noise_scale_decay: 2.0e-06
61 | use_spectral_norm: false
62 | use_transformer_flow: false
63 |
--------------------------------------------------------------------------------
/datasets/ljs_base/vocab.txt:
--------------------------------------------------------------------------------
1 | 0
2 | 1
3 | 2
4 | 3
5 | 4
6 | 5
7 | n 6
8 | t 7
9 | ə 8
10 | s 9
11 | d 10
12 | ð 11
13 | ɹ 12
14 | k 13
15 | z 14
16 | ɪ 15
17 | l 16
18 | m 17
19 | ˈɪ 18
20 | p 19
21 | w 20
22 | v 21
23 | ˈɛ 22
24 | f 23
25 | ˈeɪ 24
26 | b 25
27 | ɚ 26
28 | , 27
29 | ʌ 28
30 | ˈæ 29
31 | h 30
32 | ᵻ 31
33 | i 32
34 | æ 33
35 | . 34
36 | ˈaɪ 35
37 | ˈiː 36
38 | ʃ 37
39 | uː 38
40 | ˈoʊ 39
41 | ˈɑː 40
42 | ˈʌ 41
43 | ŋ 42
44 | əl 43
45 | ˈuː 44
46 | ɾ 45
47 | ɡ 46
48 | ɐ 47
49 | ˈɜː 48
50 | dʒ 49
51 | tʃ 50
52 | iː 51
53 | j 52
54 | ˈaʊ 53
55 | θ 54
56 | ˌɪ 55
57 | ˈɔː 56
58 | ˈɔ 57
59 | ˈoːɹ 58
60 | ɔːɹ 59
61 | ɛ 60
62 | ˌɛ 61
63 | ˌʌ 62
64 | ˈɑːɹ 63
65 | ˌæ 64
66 | ˈɔːɹ 65
67 | ˈʊ 66
68 | ɜː 67
69 | oʊ 68
70 | eɪ 69
71 | ˈɛɹ 70
72 | ˈɪɹ 71
73 | " 72
74 | ˌeɪ 73
75 | iə 74
76 | ʊ 75
77 | ˌaɪ 76
78 | ˈɔɪ 77
79 | ˌɑː 78
80 | ; 79
81 | aɪ 80
82 | ɛɹ 81
83 | ˈʊɹ 82
84 | ɑːɹ 83
85 | ʒ 84
86 | ˈaɪɚ 85
87 | ˌiː 86
88 | ˌuː 87
89 | ˌoʊ 88
90 | aʊ 89
91 | ˈiə 90
92 | ɑː 91
93 | ɔː 92
94 | n̩ 93
95 | ʔ 94
96 | ˈaɪə 95
97 | : 96
98 | oːɹ 97
99 | ˌaʊ 98
100 | ˌɑːɹ 99
101 | ˌɜː 100
102 | ˌoː 101
103 | ˈoː 102
104 | ? 103
105 | ˌɔːɹ 104
106 | ˌɔː 105
107 | ɪɹ 106
108 | ʊɹ 107
109 | oː 108
110 | ! 109
111 | ɔɪ 110
112 | ˌʊɹ 111
113 | ˌʊ 112
114 | ˌiə 113
115 | ˌɔɪ 114
116 | r 115
117 | ɔ 116
118 | ˌoːɹ 117
119 | aɪə 118
120 | ˌɪɹ 119
121 | aɪɚ 120
122 | ˌɔ 121
123 | ˌɛɹ 122
124 | x 123
125 | “ 124
126 | ” 125
127 | ˈɚ 126
128 | ˌaɪɚ 127
129 | ˌn̩ 128
130 |
--------------------------------------------------------------------------------
/datasets/ljs_nosdp/config.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | log_interval: 100
3 | eval_interval: 1000
4 | seed: 1234
5 | epochs: 20000
6 | learning_rate: 2.0e-4
7 | betas: [0.8, 0.99]
8 | eps: 1.0e-09
9 | batch_size: 64
10 | fp16_run: true
11 | lr_decay: 0.999875
12 | segment_size: 8192
13 | init_lr_ratio: 1
14 | warmup_epochs: 0
15 | c_mel: 45
16 | c_kl_text: 0 # default: 0
17 | c_kl_dur: 2 # default: 2
18 | c_kl_audio: 0.05 # 0.05 Allow more audio variation. default: 1
19 |
20 | data:
21 | training_files: datasets/ljs_base/filelists/train.txt
22 | validation_files: datasets/ljs_base/filelists/val.txt
23 | vocab_file: datasets/ljs_base/vocab.txt
24 | text_cleaners:
25 | - phonemize_text
26 | - add_spaces
27 | - tokenize_text
28 | - add_bos_eos
29 | cleaned_text: true
30 | language: en-us
31 | bits_per_sample: 16
32 | sample_rate: 22050
33 | n_fft: 2048
34 | hop_length: 256
35 | win_length: 1024
36 | n_mels: 80 # 100 works better with "slaney" mel-transform. default: 80
37 | f_min: 0
38 | f_max:
39 | n_speakers: 0
40 | use_mel: true
41 |
42 | model:
43 | inter_channels: 192
44 | hidden_channels: 192
45 | filter_channels: 768
46 | n_heads: 2
47 | n_layers: 6
48 | n_layers_q: 12 # default: 16
49 | n_flows: 8 # default: 4
50 | kernel_size: 3
51 | p_dropout: 0.1
52 | speaker_cond_layer: 0 # 0 to disable speaker conditioning
53 | resblock: "1"
54 | resblock_kernel_sizes: [3, 7, 11]
55 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
56 | upsample_rates: [8, 8, 2, 2]
57 | upsample_initial_channel: 512
58 | upsample_kernel_sizes: [16, 16, 4, 4]
59 | mas_noise_scale: 0.01 # 0.0 to disable Gaussian noise in MAS
60 | mas_noise_scale_decay: 2.0e-06
61 | use_spectral_norm: false
62 | use_transformer_flow: false
63 | use_sdp: false
64 |
--------------------------------------------------------------------------------
/datasets/madasr23_base/config.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | log_interval: 500
3 | eval_interval: 5000
4 | seed: 1234
5 | epochs: 10000
6 | learning_rate: 0.0002
7 | betas: [0.8, 0.99]
8 | eps: 1.0e-09
9 | batch_size: 64
10 | fp16_run: true
11 | lr_decay: 0.999875
12 | segment_size: 8192
13 | init_lr_ratio: 1
14 | warmup_epochs: 0
15 | c_mel: 45
16 | c_kl_text: 0 # default: 0
17 | c_kl_dur: 2 # default: 2
18 | c_kl_audio: 0.05 # 0.05 Allow more audio variation. default: 1
19 |
20 | data:
21 | training_files: datasets/madasr23_base/filelists/train.txt
22 | validation_files: datasets/madasr23_base/filelists/val.txt
23 | vocab_file: datasets/madasr23_base/vocab.txt
24 | text_cleaners:
25 | - phonemize_text
26 | - add_spaces
27 | - tokenize_text
28 | - add_bos_eos
29 | cleaned_text: true
30 | language: bn
31 | bits_per_sample: 16
32 | sample_rate: 16000
33 | n_fft: 2048
34 | hop_length: 256
35 | win_length: 1024
36 | n_mels: 80 # 100 works better with "slaney" mel-transform. default: 80
37 | f_min: 0
38 | f_max:
39 | n_speakers: 2011
40 | use_mel: true
41 |
42 | model:
43 | inter_channels: 192
44 | hidden_channels: 192
45 | filter_channels: 768
46 | n_heads: 2
47 | n_layers: 6
48 | n_layers_q: 12 # default: 16
49 | n_flows: 8 # default: 4
50 | kernel_size: 3
51 | p_dropout: 0.1
52 | speaker_cond_layer: 3 # 0 to disable speaker conditioning
53 | resblock: "1"
54 | resblock_kernel_sizes: [3, 7, 11]
55 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
56 | upsample_rates: [8, 8, 2, 2]
57 | upsample_initial_channel: 512
58 | upsample_kernel_sizes: [16, 16, 4, 4]
59 | mas_noise_scale: 0.01 # 0.0 to disable Gaussian noise in MAS
60 | mas_noise_scale_decay: 2.0e-06
61 | use_spectral_norm: false
62 | use_transformer_flow: false
63 | gin_channels: 256
64 |
--------------------------------------------------------------------------------
/datasets/vctk_base/config.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | log_interval: 200
3 | eval_interval: 1000
4 | seed: 1234
5 | epochs: 10000
6 | learning_rate: 0.0002
7 | betas: [0.8, 0.99]
8 | eps: 1.0e-09
9 | batch_size: 64
10 | fp16_run: true
11 | lr_decay: 0.999875
12 | segment_size: 8192
13 | init_lr_ratio: 1
14 | warmup_epochs: 0
15 | c_mel: 45
16 | c_kl_text: 0 # default: 0
17 | c_kl_dur: 2 # default: 2
18 | c_kl_audio: 0.05 # 0.05 Allow more audio variation. default: 1
19 |
20 | data:
21 | training_files: datasets/vctk_base/filelists/vctk_audio_sid_text_train_filelist.txt.cleaned
22 | validation_files: datasets/vctk_base/filelists/vctk_audio_sid_text_val_filelist.txt.cleaned
23 | vocab_file: datasets/vctk_base/vocab.txt
24 | text_cleaners:
25 | - phonemize_text
26 | - add_spaces
27 | - tokenize_text
28 | - add_bos_eos
29 | cleaned_text: true
30 | language: en-us
31 | bits_per_sample: 16
32 | sample_rate: 22050
33 | n_fft: 2048
34 | hop_length: 256
35 | win_length: 1024
36 | n_mels: 80 # 100 works better with "slaney" mel-transform. default: 80
37 | f_min: 0
38 | f_max:
39 | n_speakers: 109
40 | use_mel: true
41 |
42 | model:
43 | inter_channels: 192
44 | hidden_channels: 192
45 | filter_channels: 768
46 | n_heads: 2
47 | n_layers: 6
48 | n_layers_q: 12 # default: 16
49 | n_flows: 8 # default: 4
50 | kernel_size: 3
51 | p_dropout: 0.1
52 | speaker_cond_layer: 3 # 0 to disable speaker conditioning
53 | resblock: "1"
54 | resblock_kernel_sizes: [3, 7, 11]
55 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
56 | upsample_rates: [8, 8, 2, 2]
57 | upsample_initial_channel: 512
58 | upsample_kernel_sizes: [16, 16, 4, 4]
59 | mas_noise_scale: 0.01 # 0.0 to disable Gaussian noise in MAS
60 | mas_noise_scale_decay: 2.0e-06
61 | use_spectral_norm: false
62 | use_transformer_flow: false
63 | gin_channels: 256
64 |
--------------------------------------------------------------------------------
/datasets/vctk_base/filelists/vctk_audio_sid_text_val_filelist.txt:
--------------------------------------------------------------------------------
1 | DUMMY2/p364/p364_240.wav|88|It had happened to him.
2 | DUMMY2/p280/p280_148.wav|52|It is open season on the Old Firm.
3 | DUMMY2/p231/p231_320.wav|50|However, he is a coach, and he remains a coach at heart.
4 | DUMMY2/p282/p282_129.wav|83|It is not a U-turn.
5 | DUMMY2/p254/p254_015.wav|41|The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.
6 | DUMMY2/p228/p228_285.wav|57|The songs are just so good.
7 | DUMMY2/p334/p334_307.wav|38|If they don't, they can expect their funding to be cut.
8 | DUMMY2/p287/p287_081.wav|77|I've never seen anything like it.
9 | DUMMY2/p247/p247_083.wav|14|It is a job creation scheme.)
10 | DUMMY2/p264/p264_051.wav|65|We were leading by two goals.)
11 | DUMMY2/p335/p335_058.wav|49|Let's see that increase over the years.
12 | DUMMY2/p236/p236_225.wav|75|There is no quick fix.
13 | DUMMY2/p374/p374_353.wav|11|And that brings us to the point.
14 | DUMMY2/p272/p272_076.wav|69|Sounds like The Sixth Sense?
15 | DUMMY2/p271/p271_152.wav|27|The petition was formally presented at Downing Street yesterday.
16 | DUMMY2/p228/p228_127.wav|57|They've got to account for it.
17 | DUMMY2/p276/p276_223.wav|106|It's been a humbling year.
18 | DUMMY2/p262/p262_248.wav|45|The project has already secured the support of Sir Sean Connery.
19 | DUMMY2/p314/p314_086.wav|51|The team this year is going places.
20 | DUMMY2/p225/p225_038.wav|101|Diving is no part of football.
21 | DUMMY2/p279/p279_088.wav|25|The shareholders will vote to wind up the company on Friday morning.
22 | DUMMY2/p272/p272_018.wav|69|Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.
23 | DUMMY2/p256/p256_098.wav|90|She told The Herald.
24 | DUMMY2/p261/p261_218.wav|100|All will be revealed in due course.
25 | DUMMY2/p265/p265_063.wav|73|IT shouldn't come as a surprise, but it does.
26 | DUMMY2/p314/p314_042.wav|51|It is all about people being assaulted, abused.
27 | DUMMY2/p241/p241_188.wav|86|I wish I could say something.
28 | DUMMY2/p283/p283_111.wav|95|It's good to have a voice.
29 | DUMMY2/p275/p275_006.wav|40|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.
30 | DUMMY2/p228/p228_092.wav|57|Today I couldn't run on it.
31 | DUMMY2/p295/p295_343.wav|92|The atmosphere is businesslike.
32 | DUMMY2/p228/p228_187.wav|57|They will run a mile.
33 | DUMMY2/p294/p294_317.wav|104|It didn't put me off.
34 | DUMMY2/p231/p231_445.wav|50|It sounded like a bomb.
35 | DUMMY2/p272/p272_086.wav|69|Today she has been released.
36 | DUMMY2/p255/p255_210.wav|31|It was worth a photograph.
37 | DUMMY2/p229/p229_060.wav|67|And a film maker was born.
38 | DUMMY2/p260/p260_232.wav|81|The Home Office would not release any further details about the group.
39 | DUMMY2/p245/p245_025.wav|59|Johnson was pretty low.
40 | DUMMY2/p333/p333_185.wav|64|This area is perfect for children.
41 | DUMMY2/p244/p244_242.wav|78|He is a man of the people.
42 | DUMMY2/p376/p376_187.wav|71|"It is a terrible loss."
43 | DUMMY2/p239/p239_156.wav|48|It is a good lifestyle.
44 | DUMMY2/p307/p307_037.wav|22|He released a half-dozen solo albums.
45 | DUMMY2/p305/p305_185.wav|54|I am not even thinking about that.
46 | DUMMY2/p272/p272_081.wav|69|It was magic.
47 | DUMMY2/p302/p302_297.wav|30|I'm trying to stay open on that.
48 | DUMMY2/p275/p275_320.wav|40|We are in the end game.
49 | DUMMY2/p239/p239_231.wav|48|Then we will face the Danish champions.
50 | DUMMY2/p268/p268_301.wav|87|It was only later that the condition was diagnosed.
51 | DUMMY2/p336/p336_088.wav|98|They failed to reach agreement yesterday.
52 | DUMMY2/p278/p278_255.wav|10|They made such decisions in London.
53 | DUMMY2/p361/p361_132.wav|79|That got me out.
54 | DUMMY2/p307/p307_146.wav|22|You hope he prevails.
55 | DUMMY2/p244/p244_147.wav|78|They could not ignore the will of parliament, he claimed.
56 | DUMMY2/p294/p294_283.wav|104|This is our unfinished business.
57 | DUMMY2/p283/p283_300.wav|95|I would have the hammer in the crowd.
58 | DUMMY2/p239/p239_079.wav|48|I can understand the frustrations of our fans.
59 | DUMMY2/p264/p264_009.wav|65|There is , according to legend, a boiling pot of gold at one end. )
60 | DUMMY2/p307/p307_348.wav|22|He did not oppose the divorce.
61 | DUMMY2/p304/p304_308.wav|72|We are the gateway to justice.
62 | DUMMY2/p281/p281_056.wav|36|None has ever been found.
63 | DUMMY2/p267/p267_158.wav|0|We were given a warm and friendly reception.
64 | DUMMY2/p300/p300_169.wav|102|Who do these people think they are?
65 | DUMMY2/p276/p276_177.wav|106|They exist in name alone.
66 | DUMMY2/p228/p228_245.wav|57|It is a policy which has the full support of the minister.
67 | DUMMY2/p300/p300_303.wav|102|I'm wondering what you feel about the youngest.
68 | DUMMY2/p362/p362_247.wav|15|This would give Scotland around eight members.
69 | DUMMY2/p326/p326_031.wav|28|United were in control without always being dominant.
70 | DUMMY2/p361/p361_288.wav|79|I did not think it was very proper.
71 | DUMMY2/p286/p286_145.wav|63|Tiger is not the norm.
72 | DUMMY2/p234/p234_071.wav|3|She did that for the rest of her life.
73 | DUMMY2/p263/p263_296.wav|39|The decision was announced at its annual conference in Dunfermline.
74 | DUMMY2/p323/p323_228.wav|34|She became a heroine of my childhood.
75 | DUMMY2/p280/p280_346.wav|52|It was a bit like having children.
76 | DUMMY2/p333/p333_080.wav|64|But the tragedy did not stop there.
77 | DUMMY2/p226/p226_268.wav|43|That decision is for the British Parliament and people.
78 | DUMMY2/p362/p362_314.wav|15|Is that right?
79 | DUMMY2/p240/p240_047.wav|93|It is so sad.
80 | DUMMY2/p250/p250_207.wav|24|You could feel the heat.
81 | DUMMY2/p273/p273_176.wav|56|Neither side would reveal the details of the offer.
82 | DUMMY2/p316/p316_147.wav|85|And frankly, it's been a while.
83 | DUMMY2/p265/p265_047.wav|73|It is unique.
84 | DUMMY2/p336/p336_353.wav|98|Sometimes you get them, sometimes you don't.
85 | DUMMY2/p230/p230_376.wav|35|This hasn't happened in a vacuum.
86 | DUMMY2/p308/p308_209.wav|107|There is great potential on this river.
87 | DUMMY2/p250/p250_442.wav|24|We have not yet received a letter from the Irish.
88 | DUMMY2/p260/p260_037.wav|81|It's a fact.
89 | DUMMY2/p299/p299_345.wav|58|We're very excited and challenged by the project.
90 | DUMMY2/p269/p269_218.wav|94|A Grampian Police spokesman said.
91 | DUMMY2/p306/p306_014.wav|12|To the Hebrews it was a token that there would be no more universal floods.
92 | DUMMY2/p271/p271_292.wav|27|It's a record label, not a form of music.
93 | DUMMY2/p247/p247_225.wav|14|I am considered a teenager.)
94 | DUMMY2/p294/p294_094.wav|104|It should be a condition of employment.
95 | DUMMY2/p269/p269_031.wav|94|Is this accurate?
96 | DUMMY2/p275/p275_116.wav|40|It's not fair.
97 | DUMMY2/p265/p265_006.wav|73|When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.
98 | DUMMY2/p285/p285_072.wav|2|Mr Irvine said Mr Rafferty was now in good spirits.
99 | DUMMY2/p270/p270_167.wav|8|We did what we had to do.
100 | DUMMY2/p360/p360_397.wav|60|It is a relief.
101 |
--------------------------------------------------------------------------------
/datasets/vctk_base/filelists/vctk_audio_sid_text_val_filelist.txt.cleaned:
--------------------------------------------------------------------------------
1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm.
2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm.
3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt.
4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn.
5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn.
6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd.
7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt.
8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt.
9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm.
10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz.
11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz.
12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks.
13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt.
14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns?
15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ.
16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt.
17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ.
18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi.
19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz.
20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl.
21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ.
22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn.
23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld.
24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs.
25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz.
26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd.
27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ.
28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs.
29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ.
30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt.
31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk.
32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl.
33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf.
34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm.
35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd.
36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf.
37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn.
38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp.
39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ.
40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən.
41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl.
42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs."
43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl.
44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz.
45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt.
46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk.
47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt.
48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm.
49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz.
50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd.
51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ.
52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən.
53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt.
54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz.
55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd.
56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs.
57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd.
58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz.
59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd.
60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs.
61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs.
62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd.
63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən.
64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ?
65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn.
66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ.
67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst.
68 | DUMMY2/p362/p362_247.wav|15|ðɪs wʊd ɡˈɪv skˈɑːtlənd ɐɹˈaʊnd ˈeɪt mˈɛmbɚz.
69 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt.
70 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ.
71 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm.
72 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf.
73 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn.
74 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd.
75 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən.
76 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ.
77 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl.
78 | DUMMY2/p362/p362_314.wav|15|ɪz ðæt ɹˈaɪt?
79 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd.
80 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt.
81 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ.
82 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl.
83 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk.
84 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt.
85 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm.
86 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ.
87 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ.
88 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt.
89 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt.
90 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd.
91 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz.
92 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk.
93 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ.
94 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt.
95 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət?
96 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ.
97 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ.
98 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts.
99 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː.
100 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf.
101 |
--------------------------------------------------------------------------------
/figures/figure01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daniilrobnikov/vits2/0525da4a558da999a725b9fddaa4584617df328b/figures/figure01.png
--------------------------------------------------------------------------------
/figures/figure02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daniilrobnikov/vits2/0525da4a558da999a725b9fddaa4584617df328b/figures/figure02.png
--------------------------------------------------------------------------------
/figures/figure03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daniilrobnikov/vits2/0525da4a558da999a725b9fddaa4584617df328b/figures/figure03.png
--------------------------------------------------------------------------------
/inference_batch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "from tqdm import tqdm\n",
11 | "import torch\n",
12 | "import torchaudio\n",
13 | "from torch.utils.data import Dataset, DataLoader\n",
14 | "\n",
15 | "from utils.task import load_checkpoint\n",
16 | "from utils.hparams import get_hparams_from_file\n",
17 | "from model.models import SynthesizerTrn\n",
18 | "from text.symbols import symbols\n",
19 | "from text import tokenizer\n",
20 | "\n",
21 | "\n",
22 | "def get_text(text: str, hps) -> torch.LongTensor:\n",
23 | " text_norm = tokenizer(text, hps.data.text_cleaners, language=hps.data.language)\n",
24 | " text_norm = torch.LongTensor(text_norm)\n",
25 | " return text_norm"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "dataset_path = \"filelists/madasr23_test.csv\"\n",
35 | "output_path = \"/path/to/output/directory\"\n",
36 | "data = pd.read_csv(dataset_path, sep=\"|\")\n",
37 | "print(data.head())"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "## MADASR23 batch inference\n"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "model = \"custom_base\"\n",
54 | "hps = get_hparams_from_file(f\"./datasets/{model}/config.yaml\")"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "net_g = SynthesizerTrn(len(symbols), hps.data.n_mels if hps.data.use_mel else hps.data.n_fft // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model).cuda()\n",
64 | "_ = net_g.eval()\n",
65 | "\n",
66 | "_ = load_checkpoint(f\"./datasets/{model}/logs/G_15000.pth\", net_g, None)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "class MyDataset(Dataset):\n",
76 | " def __init__(self, dataframe, hps):\n",
77 | " self.data = dataframe\n",
78 | " self.hps = hps\n",
79 | "\n",
80 | " def __len__(self):\n",
81 | " return len(self.data)\n",
82 | "\n",
83 | " def __getitem__(self, idx):\n",
84 | " sid_idx = self.data[\"sid_idx\"][idx]\n",
85 | " sid = self.data[\"sid\"][idx]\n",
86 | " phonemes = self.data[\"phonemes\"][idx]\n",
87 | " stn_tst = get_text(phonemes, self.hps)\n",
88 | " return sid_idx, sid, stn_tst, idx\n",
89 | "\n",
90 | "\n",
91 | "# Initialize the dataset and data loader\n",
92 | "dataset = MyDataset(data, hps)\n",
93 | "data_loader = DataLoader(dataset, batch_size=1, num_workers=8)\n",
94 | "\n",
95 | "for sid_idx, spk_id, stn_tst, i in tqdm(data_loader):\n",
96 | " sid_idx = int(sid_idx)\n",
97 | " spk_id = int(spk_id)\n",
98 | " i = int(i)\n",
99 | " stn_tst = stn_tst[0]\n",
100 | " with torch.no_grad():\n",
101 | " x_tst = stn_tst.cuda().unsqueeze(0)\n",
102 | " x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
103 | " sid = torch.LongTensor([sid_idx]).cuda()\n",
104 | " audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667,\n",
105 | " noise_scale_w=0.8, length_scale=1)[0][0].data.cpu()\n",
106 | " torchaudio.save(f\"{output_path}/{spk_id}_{i}.wav\", audio,\n",
107 | " hps.data.sample_rate, bits_per_sample=hps.data.bits_per_sample)\n",
108 | "\n",
109 | "print(\"Done!\")"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "### Voice Conversion\n",
117 | "\n",
118 | "TODO: Add batch inference for voice conversion\n"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "metadata": {},
124 | "source": []
125 | }
126 | ],
127 | "metadata": {
128 | "kernelspec": {
129 | "display_name": "Python 3",
130 | "language": "python",
131 | "name": "python3"
132 | },
133 | "language_info": {
134 | "codemirror_mode": {
135 | "name": "ipython",
136 | "version": 3
137 | },
138 | "file_extension": ".py",
139 | "mimetype": "text/x-python",
140 | "name": "python",
141 | "nbconvert_exporter": "python",
142 | "pygments_lexer": "ipython3",
143 | "version": "3.11.4"
144 | }
145 | },
146 | "nbformat": 4,
147 | "nbformat_minor": 4
148 | }
149 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 |
4 |
5 | def feature_loss(fmap_r: List[torch.Tensor], fmap_g: List[torch.Tensor]):
6 | loss = 0
7 | for dr, dg in zip(fmap_r, fmap_g):
8 | for rl, gl in zip(dr, dg):
9 | rl = rl.float().detach()
10 | gl = gl.float()
11 | loss += torch.mean(torch.abs(rl - gl))
12 |
13 | return loss * 2
14 |
15 |
16 | def discriminator_loss(disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]):
17 | loss = 0
18 | r_losses = []
19 | g_losses = []
20 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
21 | dr = dr.float()
22 | dg = dg.float()
23 | r_loss = torch.mean((1 - dr) ** 2)
24 | g_loss = torch.mean(dg**2)
25 | loss += r_loss + g_loss
26 | r_losses.append(r_loss.item())
27 | g_losses.append(g_loss.item())
28 |
29 | return loss, r_losses, g_losses
30 |
31 |
32 | def generator_loss(disc_outputs: List[torch.Tensor]):
33 | loss = 0
34 | gen_losses = []
35 | for dg in disc_outputs:
36 | dg = dg.float()
37 | l = torch.mean((1 - dg) ** 2)
38 | gen_losses.append(l)
39 | loss += l
40 |
41 | return loss, gen_losses
42 |
43 |
44 | def kl_loss(z_p: torch.Tensor, logs_q: torch.Tensor, m_p: torch.Tensor, logs_p: torch.Tensor, z_mask: torch.Tensor):
45 | """
46 | z_p, logs_q: [b, h, t_t]
47 | m_p, logs_p: [b, h, t_t]
48 | """
49 | z_p = z_p.float()
50 | logs_q = logs_q.float()
51 | m_p = m_p.float()
52 | logs_p = logs_p.float()
53 | z_mask = z_mask.float()
54 |
55 | kl = logs_p - logs_q - 0.5
56 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
57 | kl = torch.sum(kl * z_mask)
58 | l = kl / torch.sum(z_mask)
59 | return l
60 |
61 |
62 | def kl_loss_normal(m_q: torch.Tensor, logs_q: torch.Tensor, m_p: torch.Tensor, logs_p: torch.Tensor, z_mask: torch.Tensor):
63 | """
64 |     m_q, logs_q: [b, h, t_t]
65 | m_p, logs_p: [b, h, t_t]
66 | """
67 | m_q = m_q.float()
68 | logs_q = logs_q.float()
69 | m_p = m_p.float()
70 | logs_p = logs_p.float()
71 | z_mask = z_mask.float()
72 |
73 | kl = logs_p - logs_q - 0.5
74 | kl += 0.5 * (torch.exp(2.0 * logs_q) + (m_q - m_p) ** 2) * torch.exp(-2.0 * logs_p)
75 | kl = torch.sum(kl * z_mask)
76 | l = kl / torch.sum(z_mask)
77 | return l
78 |
--------------------------------------------------------------------------------
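For reference, `kl_loss_normal` is the closed-form KL divergence between two diagonal Gaussians (with `logs = log σ`), while `kl_loss` keeps the entropy of q in closed form but evaluates the cross-entropy term at the posterior sample `z_p`. Per element the closed form is

```latex
\mathrm{KL}\!\left(\mathcal{N}(\mu_q,\sigma_q^2)\,\middle\|\,\mathcal{N}(\mu_p,\sigma_p^2)\right)
  = \log\frac{\sigma_p}{\sigma_q}
  + \frac{\sigma_q^2 + (\mu_q-\mu_p)^2}{2\,\sigma_p^2}
  - \frac{1}{2}
```

which matches `logs_p - logs_q - 0.5 + 0.5 * (exp(2*logs_q) + (m_q - m_p)**2) * exp(-2*logs_p)`; both losses then average over the frames selected by `z_mask`.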
/model/condition.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MultiCondLayer(nn.Module):
6 | def __init__(
7 | self,
8 | gin_channels: int,
9 | out_channels: int,
10 | n_cond: int,
11 | ):
12 | """MultiCondLayer of VITS model.
13 |
14 | Args:
15 | gin_channels (int): Number of conditioning tensor channels.
16 | out_channels (int): Number of output tensor channels.
17 | n_cond (int): Number of conditions.
18 | """
19 | super().__init__()
20 | self.n_cond = n_cond
21 |
22 | self.cond_layers = nn.ModuleList()
23 | for _ in range(n_cond):
24 | self.cond_layers.append(nn.Linear(gin_channels, out_channels))
25 |
26 | def forward(self, cond: torch.Tensor, x_mask: torch.Tensor):
27 | """
28 | Shapes:
29 | - cond: :math:`[B, C, N]`
30 |             - x_mask: :math:`[B, 1, T]`
31 | """
32 |
33 | cond_out = torch.zeros_like(cond)
34 | for i in range(self.n_cond):
35 | cond_in = self.cond_layers[i](cond.mT).mT
36 | cond_out = cond_out + cond_in
37 | return cond_out * x_mask
38 |
--------------------------------------------------------------------------------
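A minimal shape sketch for `MultiCondLayer` (channel sizes are illustrative, not taken from any config): each of the `n_cond` linear layers is applied to `cond` and the projections are summed, then broadcast over the time axis by `x_mask`. Note that `forward` initialises the running sum with `zeros_like(cond)`, so this sketch keeps `gin_channels == out_channels`.

```python
import torch

from model.condition import MultiCondLayer

cond_layer = MultiCondLayer(gin_channels=192, out_channels=192, n_cond=2)
cond = torch.randn(1, 192, 1)    # e.g. a speaker embedding [B, gin_channels, 1]
x_mask = torch.ones(1, 1, 50)    # frame mask [B, 1, T]
out = cond_layer(cond, x_mask)   # summed projections, broadcast over T by the mask
print(out.shape)                 # torch.Size([1, 192, 50])
```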
/model/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import weight_norm, remove_weight_norm
5 |
6 | from model.modules import LRELU_SLOPE
7 | from utils.model import init_weights, get_padding
8 |
9 |
10 | class Generator(nn.Module):
11 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
12 | super(Generator, self).__init__()
13 | self.num_kernels = len(resblock_kernel_sizes)
14 | self.num_upsamples = len(upsample_rates)
15 | self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
16 | resblock = ResBlock1 if resblock == "1" else ResBlock2
17 |
18 | self.ups = nn.ModuleList()
19 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
20 | self.ups.append(weight_norm(nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
21 |
22 | self.resblocks = nn.ModuleList()
23 | for i in range(len(self.ups)):
24 | ch = upsample_initial_channel // (2 ** (i + 1))
25 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
26 | self.resblocks.append(resblock(ch, k, d))
27 |
28 | self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
29 | self.ups.apply(init_weights)
30 |
31 | if gin_channels != 0:
32 | self.cond = nn.Linear(gin_channels, upsample_initial_channel)
33 |
34 | def forward(self, x, g=None):
35 | x = self.conv_pre(x)
36 | if g is not None:
37 | x = x + self.cond(g.mT).mT
38 |
39 | for i in range(self.num_upsamples):
40 | x = F.leaky_relu(x, LRELU_SLOPE)
41 | x = self.ups[i](x)
42 | xs = None
43 | for j in range(self.num_kernels):
44 | if xs is None:
45 | xs = self.resblocks[i * self.num_kernels + j](x)
46 | else:
47 | xs += self.resblocks[i * self.num_kernels + j](x)
48 | x = xs / self.num_kernels
49 | x = F.leaky_relu(x)
50 | x = self.conv_post(x)
51 | x = torch.tanh(x)
52 |
53 | return x
54 |
55 | def remove_weight_norm(self):
56 | print("Removing weight norm...")
57 | for l in self.ups:
58 | remove_weight_norm(l)
59 | for l in self.resblocks:
60 | l.remove_weight_norm()
61 |
62 |
63 | class ResBlock1(nn.Module):
64 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
65 | super(ResBlock1, self).__init__()
66 | self.convs1 = nn.ModuleList(
67 | [
68 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))),
69 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))),
70 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))),
71 | ]
72 | )
73 | self.convs1.apply(init_weights)
74 |
75 | self.convs2 = nn.ModuleList(
76 | [
77 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
78 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
79 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
80 | ]
81 | )
82 | self.convs2.apply(init_weights)
83 |
84 | def forward(self, x, x_mask=None):
85 | for c1, c2 in zip(self.convs1, self.convs2):
86 | xt = F.leaky_relu(x, LRELU_SLOPE)
87 | if x_mask is not None:
88 | xt = xt * x_mask
89 | xt = c1(xt)
90 | xt = F.leaky_relu(xt, LRELU_SLOPE)
91 | if x_mask is not None:
92 | xt = xt * x_mask
93 | xt = c2(xt)
94 | x = xt + x
95 | if x_mask is not None:
96 | x = x * x_mask
97 | return x
98 |
99 | def remove_weight_norm(self):
100 | for l in self.convs1:
101 | remove_weight_norm(l)
102 | for l in self.convs2:
103 | remove_weight_norm(l)
104 |
105 |
106 | class ResBlock2(nn.Module):
107 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
108 | super(ResBlock2, self).__init__()
109 | self.convs = nn.ModuleList(
110 | [
111 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))),
112 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))),
113 | ]
114 | )
115 | self.convs.apply(init_weights)
116 |
117 | def forward(self, x, x_mask=None):
118 | for c in self.convs:
119 | xt = F.leaky_relu(x, LRELU_SLOPE)
120 | if x_mask is not None:
121 | xt = xt * x_mask
122 | xt = c(xt)
123 | x = xt + x
124 | if x_mask is not None:
125 | x = x * x_mask
126 | return x
127 |
128 | def remove_weight_norm(self):
129 | for l in self.convs:
130 | remove_weight_norm(l)
131 |
--------------------------------------------------------------------------------
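A rough usage sketch for the HiFi-GAN-style `Generator`. The hyper-parameters below are only illustrative (the real ones live in `datasets/<model>/config.yaml`); the key relationship is that the product of `upsample_rates` equals the hop length, so `T` latent frames become `T * prod(upsample_rates)` waveform samples.

```python
import torch

from model.decoder import Generator

dec = Generator(
    initial_channel=192,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2],            # product = 256 samples per frame
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
)
z = torch.randn(1, 192, 32)                 # latent slice [B, inter_channels, frames]
wav = dec(z)                                # [B, 1, frames * 256]
print(wav.shape)                            # torch.Size([1, 1, 8192])
```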
/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import weight_norm, spectral_norm
5 |
6 | from model.modules import LRELU_SLOPE
7 | from utils.model import get_padding
8 |
9 |
10 | class DiscriminatorP(nn.Module):
11 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
12 | super(DiscriminatorP, self).__init__()
13 | self.period = period
14 |         norm_f = spectral_norm if use_spectral_norm else weight_norm
15 | self.convs = nn.ModuleList(
16 | [
17 | norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
18 | norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
19 | norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
20 | norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
21 | norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
22 | ]
23 | )
24 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
25 |
26 | def forward(self, x):
27 | fmap = []
28 |
29 | # 1d to 2d
30 | b, c, t = x.shape
31 | if t % self.period != 0: # pad first
32 | n_pad = self.period - (t % self.period)
33 | x = F.pad(x, (0, n_pad), "reflect")
34 | t = t + n_pad
35 | x = x.view(b, c, t // self.period, self.period)
36 |
37 | for l in self.convs:
38 | x = l(x)
39 | x = F.leaky_relu(x, LRELU_SLOPE)
40 | fmap.append(x)
41 | x = self.conv_post(x)
42 | fmap.append(x)
43 | x = torch.flatten(x, 1, -1)
44 |
45 | return x, fmap
46 |
47 |
48 | class DiscriminatorS(nn.Module):
49 | def __init__(self, use_spectral_norm=False):
50 | super(DiscriminatorS, self).__init__()
51 |         norm_f = spectral_norm if use_spectral_norm else weight_norm
52 | self.convs = nn.ModuleList(
53 | [
54 | norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
55 | norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
56 | norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
57 | norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
58 | norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
59 | norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
60 | ]
61 | )
62 | self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
63 |
64 | def forward(self, x):
65 | fmap = []
66 |
67 | for l in self.convs:
68 | x = l(x)
69 | x = F.leaky_relu(x, LRELU_SLOPE)
70 | fmap.append(x)
71 | x = self.conv_post(x)
72 | fmap.append(x)
73 | x = torch.flatten(x, 1, -1)
74 |
75 | return x, fmap
76 |
77 |
78 | class MultiPeriodDiscriminator(nn.Module):
79 | def __init__(self, use_spectral_norm=False):
80 | super(MultiPeriodDiscriminator, self).__init__()
81 | periods = [2, 3, 5, 7, 11] # [1, 2, 3, 5, 7, 11]
82 |
83 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
84 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
85 | self.discriminators = nn.ModuleList(discs)
86 |
87 | def forward(self, y, y_hat):
88 | y_d_rs = []
89 | y_d_gs = []
90 | fmap_rs = []
91 | fmap_gs = []
92 | for i, d in enumerate(self.discriminators):
93 | y_d_r, fmap_r = d(y)
94 | y_d_g, fmap_g = d(y_hat)
95 | y_d_rs.append(y_d_r)
96 | y_d_gs.append(y_d_g)
97 | fmap_rs.append(fmap_r)
98 | fmap_gs.append(fmap_g)
99 |
100 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
101 |
--------------------------------------------------------------------------------
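A quick shape-and-loss check on random stand-ins for real and generated waveform slices; this is purely a smoke test of the discriminator and the losses above, not a training step.

```python
import torch

from model.discriminator import MultiPeriodDiscriminator
from losses import discriminator_loss, generator_loss, feature_loss

mpd = MultiPeriodDiscriminator()
y = torch.randn(2, 1, 8192)        # "real" waveform slices [B, 1, T]
y_hat = torch.randn(2, 1, 8192)    # "generated" slices of the same length

y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
loss_disc, r_losses, g_losses = discriminator_loss(y_d_rs, y_d_gs)
loss_gen, gen_losses = generator_loss(y_d_gs)
loss_fm = feature_loss(fmap_rs, fmap_gs)
print(loss_disc.item(), loss_gen.item(), loss_fm.item())
```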
/model/duration_predictors.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from model.modules import Flip
7 | from model.normalization import LayerNorm
8 | from utils.transforms import piecewise_rational_quadratic_transform
9 |
10 |
11 | class StochasticDurationPredictor(nn.Module):
12 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
13 | super().__init__()
14 | self.log_flow = Log()
15 | self.flows = nn.ModuleList()
16 | self.flows.append(ElementwiseAffine(2))
17 | for i in range(n_flows):
18 | self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
19 | self.flows.append(Flip())
20 |
21 | self.pre = nn.Linear(in_channels, filter_channels)
22 | self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
23 | self.proj = nn.Linear(filter_channels, filter_channels)
24 |
25 | self.post_pre = nn.Linear(1, filter_channels)
26 | self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
27 | self.post_proj = nn.Linear(filter_channels, filter_channels)
28 |
29 | self.post_flows = nn.ModuleList()
30 | self.post_flows.append(ElementwiseAffine(2))
31 | for i in range(4):
32 | self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
33 | self.post_flows.append(Flip())
34 |
35 | if gin_channels != 0:
36 | self.cond = nn.Linear(gin_channels, filter_channels)
37 |
38 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor, w=None, g=None, reverse=False, noise_scale=1.0):
39 | x = torch.detach(x)
40 | x = self.pre(x.mT).mT
41 | if g is not None:
42 | g = torch.detach(g)
43 | x = x + self.cond(g.mT).mT
44 | x = self.convs(x, x_mask)
45 | x = self.proj(x.mT).mT * x_mask
46 |
47 | if not reverse:
48 | flows = self.flows
49 | assert w is not None
50 |
51 | logdet_tot_q = 0
52 | h_w = self.post_pre(w.mT).mT
53 | h_w = self.post_convs(h_w, x_mask)
54 | h_w = self.post_proj(h_w.mT).mT * x_mask
55 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
56 | z_q = e_q
57 | for flow in self.post_flows:
58 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
59 | logdet_tot_q += logdet_q
60 | z_u, z1 = torch.split(z_q, [1, 1], 1)
61 | u = torch.sigmoid(z_u) * x_mask
62 | z0 = (w - u) * x_mask
63 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
64 | logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
65 |
66 | logdet_tot = 0
67 | z0, logdet = self.log_flow(z0, x_mask)
68 | logdet_tot += logdet
69 | z = torch.cat([z0, z1], 1)
70 | for flow in flows:
71 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
72 | logdet_tot = logdet_tot + logdet
73 | nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
74 | return nll + logq # [b]
75 | else:
76 | flows = list(reversed(self.flows))
77 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
78 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
79 | for flow in flows:
80 | z = flow(z, x_mask, g=x, reverse=reverse)
81 | z0, z1 = torch.split(z, [1, 1], 1)
82 | logw = z0
83 | return logw
84 |
85 |
86 | class ConvFlow(nn.Module):
87 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
88 | super().__init__()
89 | self.filter_channels = filter_channels
90 | self.num_bins = num_bins
91 | self.tail_bound = tail_bound
92 | self.half_channels = in_channels // 2
93 |
94 | self.pre = nn.Linear(self.half_channels, filter_channels)
95 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
96 | self.proj = nn.Linear(filter_channels, self.half_channels * (num_bins * 3 - 1))
97 | self.proj.weight.data.zero_()
98 | self.proj.bias.data.zero_()
99 |
100 | def forward(self, x, x_mask, g=None, reverse=False):
101 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
102 | h = self.pre(x0.mT).mT
103 | h = self.convs(h, x_mask, g=g)
104 | h = self.proj(h.mT).mT * x_mask
105 |
106 | b, c, t = x0.shape
107 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
108 |
109 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
110 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
111 | unnormalized_derivatives = h[..., 2 * self.num_bins :]
112 |
113 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, inverse=reverse, tails="linear", tail_bound=self.tail_bound)
114 |
115 | x = torch.cat([x0, x1], 1) * x_mask
116 | logdet = torch.sum(logabsdet * x_mask, [1, 2])
117 | if not reverse:
118 | return x, logdet
119 | else:
120 | return x
121 |
122 |
123 | class DDSConv(nn.Module):
124 | """
125 |     Dilated and Depth-Separable Convolution
126 | """
127 |
128 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
129 | super().__init__()
130 | self.n_layers = n_layers
131 |
132 | self.drop = nn.Dropout(p_dropout)
133 | self.convs_sep = nn.ModuleList()
134 | self.linears = nn.ModuleList()
135 | self.norms_1 = nn.ModuleList()
136 | self.norms_2 = nn.ModuleList()
137 | for i in range(n_layers):
138 | dilation = kernel_size**i
139 | padding = (kernel_size * dilation - dilation) // 2
140 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding))
141 | self.linears.append(nn.Linear(channels, channels))
142 | self.norms_1.append(LayerNorm(channels))
143 | self.norms_2.append(LayerNorm(channels))
144 |
145 | def forward(self, x, x_mask, g=None):
146 | if g is not None:
147 | x = x + g
148 | for i in range(self.n_layers):
149 | y = self.convs_sep[i](x * x_mask)
150 | y = self.norms_1[i](y)
151 | y = F.gelu(y)
152 | y = self.linears[i](y.mT).mT
153 | y = self.norms_2[i](y)
154 | y = F.gelu(y)
155 | y = self.drop(y)
156 | x = x + y
157 | return x * x_mask
158 |
159 |
160 | # TODO convert to class method
161 | class Log(nn.Module):
162 | def forward(self, x, x_mask, reverse=False, **kwargs):
163 | if not reverse:
164 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
165 | logdet = torch.sum(-y, [1, 2])
166 | return y, logdet
167 | else:
168 | x = torch.exp(x) * x_mask
169 | return x
170 |
171 |
172 | class ElementwiseAffine(nn.Module):
173 | def __init__(self, channels):
174 | super().__init__()
175 | self.m = nn.Parameter(torch.zeros(channels, 1))
176 | self.logs = nn.Parameter(torch.zeros(channels, 1))
177 |
178 | def forward(self, x, x_mask, reverse=False, **kwargs):
179 | if not reverse:
180 | y = self.m + torch.exp(self.logs) * x
181 | y = y * x_mask
182 | logdet = torch.sum(self.logs * x_mask, [1, 2])
183 | return y, logdet
184 | else:
185 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
186 | return x
187 |
188 |
189 | class DurationPredictor(nn.Module):
190 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
191 | super().__init__()
192 | self.drop = nn.Dropout(p_dropout)
193 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
194 | self.norm_1 = LayerNorm(filter_channels)
195 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
196 | self.norm_2 = LayerNorm(filter_channels)
197 | self.proj = nn.Linear(filter_channels, 1)
198 |
199 | if gin_channels != 0:
200 | self.cond = nn.Linear(gin_channels, in_channels)
201 |
202 | def forward(self, x, x_mask, g=None):
203 | x = torch.detach(x)
204 | if g is not None:
205 | g = torch.detach(g)
206 | x = x + self.cond(g.mT).mT
207 | x = self.conv_1(x * x_mask)
208 | x = torch.relu(x)
209 | x = self.norm_1(x)
210 | x = self.drop(x)
211 | x = self.conv_2(x * x_mask)
212 | x = torch.relu(x)
213 | x = self.norm_2(x)
214 | x = self.drop(x)
215 | x = self.proj((x * x_mask).mT).mT
216 | return x * x_mask
217 |
--------------------------------------------------------------------------------
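A small sketch of the deterministic `DurationPredictor` (channel sizes are illustrative): it maps text-encoder hidden states to one log-duration per token, which `SynthesizerTrn.infer` later turns into frame counts via `w = torch.exp(logw) * x_mask * length_scale`.

```python
import torch

from model.duration_predictors import DurationPredictor

dp = DurationPredictor(in_channels=192, filter_channels=256, kernel_size=3, p_dropout=0.5)
h_text = torch.randn(1, 192, 40)   # text-encoder hidden states [B, H, T_text]
x_mask = torch.ones(1, 1, 40)      # token mask [B, 1, T_text]
logw = dp(h_text, x_mask)          # predicted log-durations [B, 1, T_text]
w = torch.exp(logw) * x_mask       # durations in frames
print(w.shape)                     # torch.Size([1, 1, 40])
```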
/model/encoders.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | from model.modules import WN
6 | from model.transformer import RelativePositionTransformer
7 | from utils.model import sequence_mask
8 |
9 |
10 | # * Ready and Tested
11 | class TextEncoder(nn.Module):
12 | def __init__(
13 | self,
14 | n_vocab: int,
15 | out_channels: int,
16 | hidden_channels: int,
17 | hidden_channels_ffn: int,
18 | n_heads: int,
19 | n_layers: int,
20 | kernel_size: int,
21 | dropout: float,
22 | gin_channels=0,
23 | lang_channels=0,
24 | speaker_cond_layer=0,
25 | ):
26 | """Text Encoder for VITS model.
27 |
28 | Args:
29 | n_vocab (int): Number of characters for the embedding layer.
30 | out_channels (int): Number of channels for the output.
31 | hidden_channels (int): Number of channels for the hidden layers.
32 | hidden_channels_ffn (int): Number of channels for the convolutional layers.
33 | n_heads (int): Number of attention heads for the Transformer layers.
34 | n_layers (int): Number of Transformer layers.
35 | kernel_size (int): Kernel size for the FFN layers in Transformer network.
36 | dropout (float): Dropout rate for the Transformer layers.
37 | gin_channels (int, optional): Number of channels for speaker embedding. Defaults to 0.
38 | lang_channels (int, optional): Number of channels for language embedding. Defaults to 0.
39 | """
40 | super().__init__()
41 | self.out_channels = out_channels
42 | self.hidden_channels = hidden_channels
43 |
44 | self.emb = nn.Embedding(n_vocab, hidden_channels)
45 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
46 |
47 | self.encoder = RelativePositionTransformer(
48 | in_channels=hidden_channels,
49 | out_channels=hidden_channels,
50 | hidden_channels=hidden_channels,
51 | hidden_channels_ffn=hidden_channels_ffn,
52 | n_heads=n_heads,
53 | n_layers=n_layers,
54 | kernel_size=kernel_size,
55 | dropout=dropout,
56 | window_size=4,
57 | gin_channels=gin_channels,
58 | lang_channels=lang_channels,
59 | speaker_cond_layer=speaker_cond_layer,
60 | )
61 | self.proj = nn.Linear(hidden_channels, out_channels * 2)
62 |
63 | def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: torch.Tensor = None, lang: torch.Tensor = None):
64 | """
65 | Shapes:
66 | - x: :math:`[B, T]`
67 | - x_length: :math:`[B]`
68 | """
69 | x = self.emb(x).mT * math.sqrt(self.hidden_channels) # [b, h, t]
70 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
71 |
72 | x = self.encoder(x, x_mask, g=g, lang=lang)
73 | stats = self.proj(x.mT).mT * x_mask
74 |
75 | m, logs = torch.split(stats, self.out_channels, dim=1)
76 | z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
77 | return z, m, logs, x, x_mask
78 |
79 |
80 | # * Ready and Tested
81 | class PosteriorEncoder(nn.Module):
82 | def __init__(
83 | self,
84 | in_channels: int,
85 | out_channels: int,
86 | hidden_channels: int,
87 | kernel_size: int,
88 | dilation_rate: int,
89 | n_layers: int,
90 | gin_channels=0,
91 | ):
92 | """Posterior Encoder of VITS model.
93 |
94 | ::
95 | x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
96 |
97 | Args:
98 | in_channels (int): Number of input tensor channels.
99 | out_channels (int): Number of output tensor channels.
100 | hidden_channels (int): Number of hidden channels.
101 | kernel_size (int): Kernel size of the WaveNet convolution layers.
102 | dilation_rate (int): Dilation rate of the WaveNet layers.
103 |             n_layers (int): Number of WaveNet layers.
104 |             gin_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
105 | """
106 | super().__init__()
107 | self.out_channels = out_channels
108 |
109 | self.pre = nn.Linear(in_channels, hidden_channels)
110 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
111 | self.proj = nn.Linear(hidden_channels, out_channels * 2)
112 |
113 | def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g=None):
114 | """
115 | Shapes:
116 | - x: :math:`[B, C, T]`
117 | - x_lengths: :math:`[B, 1]`
118 | - g: :math:`[B, C, 1]`
119 | """
120 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
121 | x = self.pre(x.mT).mT * x_mask
122 | x = self.enc(x, x_mask, g=g)
123 | stats = self.proj(x.mT).mT * x_mask
124 | m, logs = torch.split(stats, self.out_channels, dim=1)
125 | z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
126 | return z, m, logs, x_mask
127 |
128 |
129 | # TODO Ready for testing
130 | class AudioEncoder(nn.Module):
131 | def __init__(
132 | self,
133 | in_channels: int,
134 | out_channels: int,
135 | hidden_channels: int,
136 | hidden_channels_ffn: int,
137 | n_heads: int,
138 | n_layers: int,
139 | kernel_size: int,
140 | dropout: float,
141 | gin_channels=0,
142 | lang_channels=0,
143 | speaker_cond_layer=0,
144 | ):
145 | """Audio Encoder of VITS model.
146 |
147 | Args:
148 | in_channels (int): Number of input tensor channels.
149 | out_channels (int): Number of channels for the output.
150 | hidden_channels (int): Number of channels for the hidden layers.
151 | hidden_channels_ffn (int): Number of channels for the convolutional layers.
152 | n_heads (int): Number of attention heads for the Transformer layers.
153 | n_layers (int): Number of Transformer layers.
154 | kernel_size (int): Kernel size for the FFN layers in Transformer network.
155 | dropout (float): Dropout rate for the Transformer layers.
156 | gin_channels (int, optional): Number of channels for speaker embedding. Defaults to 0.
157 | lang_channels (int, optional): Number of channels for language embedding. Defaults to 0.
158 | """
159 | super().__init__()
160 | self.out_channels = out_channels
161 | self.hidden_channels = hidden_channels
162 |
163 | self.pre = nn.Linear(in_channels, hidden_channels)
164 | self.encoder = RelativePositionTransformer(
165 | in_channels=hidden_channels,
166 | out_channels=hidden_channels,
167 | hidden_channels=hidden_channels,
168 | hidden_channels_ffn=hidden_channels_ffn,
169 | n_heads=n_heads,
170 | n_layers=n_layers,
171 | kernel_size=kernel_size,
172 | dropout=dropout,
173 | window_size=4,
174 | gin_channels=gin_channels,
175 | lang_channels=lang_channels,
176 | speaker_cond_layer=speaker_cond_layer,
177 | )
178 | self.post = nn.Linear(hidden_channels, out_channels * 2)
179 |
180 | def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: torch.Tensor = None, lang: torch.Tensor = None):
181 | """
182 | Shapes:
183 | - x: :math:`[B, C, T]`
184 | - x_lengths: :math:`[B, 1]`
185 | - g: :math:`[B, C, 1]`
186 | - lang: :math:`[B, C, 1]`
187 | """
188 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
189 |
190 | x = self.pre(x.mT).mT * x_mask # [B, C, t']
191 | x = self.encoder(x, x_mask, g=g, lang=lang)
192 | stats = self.post(x.mT).mT * x_mask
193 |
194 | m, logs = torch.split(stats, self.out_channels, dim=1)
195 | z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
196 | return z, m, logs, x_mask
197 |
--------------------------------------------------------------------------------
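A shape sketch of the `PosteriorEncoder` pipeline from its docstring (`x -> linear -> WaveNet -> [m, logs] -> z`). Channel sizes are illustrative; `in_channels` would normally be `n_fft // 2 + 1` for linear spectrograms or `n_mels` for mel inputs, as in the notebook above.

```python
import torch

from model.encoders import PosteriorEncoder

enc_q = PosteriorEncoder(
    in_channels=513,        # e.g. n_fft // 2 + 1
    out_channels=192,
    hidden_channels=192,
    kernel_size=5,
    dilation_rate=1,
    n_layers=16,
)
spec = torch.randn(1, 513, 120)          # [B, C_spec, T_frames]
spec_lengths = torch.LongTensor([120])
z, m, logs, y_mask = enc_q(spec, spec_lengths)
print(z.shape, y_mask.shape)             # [1, 192, 120] [1, 1, 120]
```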
/model/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from model.encoders import TextEncoder, PosteriorEncoder, AudioEncoder
5 | from model.normalizing_flows import ResidualCouplingBlock
6 | from model.duration_predictors import DurationPredictor, StochasticDurationPredictor
7 | from model.decoder import Generator
8 | from utils.monotonic_align import search_path, generate_path
9 | from utils.model import sequence_mask, rand_slice_segments
10 |
11 |
12 | class SynthesizerTrn(nn.Module):
13 | """
14 | Synthesizer for Training
15 | """
16 |
17 | def __init__(
18 | self,
19 | n_vocab,
20 | spec_channels,
21 | segment_size,
22 | inter_channels,
23 | hidden_channels,
24 | filter_channels,
25 | n_heads,
26 | n_layers,
27 | n_layers_q,
28 | n_flows,
29 | kernel_size,
30 | p_dropout,
31 | speaker_cond_layer,
32 | resblock,
33 | resblock_kernel_sizes,
34 | resblock_dilation_sizes,
35 | upsample_rates,
36 | upsample_initial_channel,
37 | upsample_kernel_sizes,
38 | mas_noise_scale,
39 | mas_noise_scale_decay,
40 | use_sdp=True,
41 | use_transformer_flow=True,
42 | n_speakers=0,
43 | gin_channels=0,
44 | **kwargs
45 | ):
46 | super().__init__()
47 | self.segment_size = segment_size
48 | self.n_speakers = n_speakers
49 | self.use_sdp = use_sdp
50 | self.mas_noise_scale = mas_noise_scale
51 | self.mas_noise_scale_decay = mas_noise_scale_decay
52 |
53 | self.enc_p = TextEncoder(n_vocab, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels=gin_channels, speaker_cond_layer=speaker_cond_layer)
54 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, n_layers_q, gin_channels=gin_channels)
55 | # self.enc_q = AudioEncoder(spec_channels, inter_channels, 32, 768, n_heads, 2, kernel_size, p_dropout, gin_channels=gin_channels)
56 | # self.enc_q = AudioEncoder(spec_channels, inter_channels, 32, 32, n_heads, 3, kernel_size, p_dropout, gin_channels=gin_channels)
57 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
58 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, n_flows=n_flows, gin_channels=gin_channels, mean_only=False, use_transformer_flow=use_transformer_flow)
59 |
60 | if use_sdp:
61 | self.dp = StochasticDurationPredictor(hidden_channels, hidden_channels, 3, 0.5, 4, gin_channels=gin_channels)
62 | else:
63 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
64 |
65 | if n_speakers > 1:
66 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
67 |
68 | def forward(self, x, x_lengths, y, y_lengths, sid=None):
69 | if self.n_speakers > 0:
70 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
71 | else:
72 | g = None
73 |
74 | z_p_text, m_p_text, logs_p_text, h_text, x_mask = self.enc_p(x, x_lengths, g=g)
75 | z_q_audio, m_q_audio, logs_q_audio, y_mask = self.enc_q(y, y_lengths, g=g)
76 | z_q_dur, m_q_dur, logs_q_dur = self.flow(z_q_audio, m_q_audio, logs_q_audio, y_mask, g=g)
77 |
78 | attn = search_path(z_q_dur, m_p_text, logs_p_text, x_mask, y_mask, mas_noise_scale=self.mas_noise_scale)
79 | self.mas_noise_scale = max(self.mas_noise_scale - self.mas_noise_scale_decay, 0.0)
80 |
81 | w = attn.sum(2) # [b, 1, t_s]
82 |
83 | # * reduce posterior
84 | # TODO Test gain constant
85 | if False:
86 | attn_inv = attn.squeeze(1) * (1 / (w + 1e-9))
87 | m_q_text = torch.matmul(attn_inv.mT, m_q_dur.mT).mT
88 | logs_q_text = torch.matmul(attn_inv.mT, logs_q_dur.mT).mT
89 |
90 | # * expand prior
91 | if self.use_sdp:
92 | l_length = self.dp(h_text, x_mask, w, g=g)
93 | l_length = l_length / torch.sum(x_mask)
94 | else:
95 | logw_ = torch.log(w + 1e-6) * x_mask
96 | logw = self.dp(h_text, x_mask, g=g)
97 | l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask) # for averaging
98 | m_p_dur = torch.matmul(attn.squeeze(1), m_p_text.mT).mT
99 | logs_p_dur = torch.matmul(attn.squeeze(1), logs_p_text.mT).mT
100 | z_p_dur = m_p_dur + torch.randn_like(m_p_dur) * torch.exp(logs_p_dur) * y_mask
101 |
102 | z_p_audio, m_p_audio, logs_p_audio = self.flow(z_p_dur, m_p_dur, logs_p_dur, y_mask, g=g, reverse=True)
103 |
104 | z_slice, ids_slice = rand_slice_segments(z_q_audio, y_lengths, self.segment_size)
105 | o = self.dec(z_slice, g=g)
106 | return (
107 | o,
108 | l_length,
109 | attn,
110 | ids_slice,
111 | x_mask,
112 | y_mask,
113 | (m_p_text, logs_p_text),
114 | (m_p_dur, logs_p_dur, z_q_dur, logs_q_dur),
115 | (m_p_audio, logs_p_audio, m_q_audio, logs_q_audio),
116 | )
117 |
118 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, max_len=None):
119 | if self.n_speakers > 0:
120 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
121 | else:
122 | g = None
123 |
124 | z_p_text, m_p_text, logs_p_text, h_text, x_mask = self.enc_p(x, x_lengths, g=g)
125 |
126 | if self.use_sdp:
127 | logw = self.dp(h_text, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
128 | else:
129 | logw = self.dp(h_text, x_mask, g=g)
130 | w = torch.exp(logw) * x_mask * length_scale
131 | w_ceil = torch.ceil(w)
132 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
133 | y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
134 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
135 | attn = generate_path(w_ceil, attn_mask)
136 |
137 | m_p_dur = torch.matmul(attn.squeeze(1), m_p_text.mT).mT # [b, t', t], [b, t, d] -> [b, d, t']
138 | logs_p_dur = torch.matmul(attn.squeeze(1), logs_p_text.mT).mT # [b, t', t], [b, t, d] -> [b, d, t']
139 | z_p_dur = m_p_dur + torch.randn_like(m_p_dur) * torch.exp(logs_p_dur) * noise_scale
140 |
141 | z_p_audio, m_p_audio, logs_p_audio = self.flow(z_p_dur, m_p_dur, logs_p_dur, y_mask, g=g, reverse=True)
142 | o = self.dec((z_p_audio * y_mask)[:, :, :max_len], g=g)
143 | return o, attn, y_mask, (z_p_dur, m_p_dur, logs_p_dur), (z_p_audio, m_p_audio, logs_p_audio)
144 |
145 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
146 |         assert self.n_speakers > 0, "n_speakers has to be larger than 0."
147 | g_src = self.emb_g(sid_src).unsqueeze(-1)
148 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
149 | z_q_audio, m_q_audio, logs_q_audio, y_mask = self.enc_q(y, y_lengths, g=g_src)
150 | z_q_dur, m_q_dur, logs_q_dur = self.flow(z_q_audio, m_q_audio, logs_q_audio, y_mask, g=g_src)
151 | z_p_audio, m_p_audio, logs_p_audio = self.flow(z_q_dur, m_q_dur, logs_q_dur, y_mask, g=g_tgt, reverse=True)
152 | o_hat = self.dec(z_p_audio * y_mask, g=g_tgt)
153 | return o_hat, y_mask, (z_q_dur, m_q_dur, logs_q_dur), (z_p_audio, m_p_audio, logs_p_audio)
154 |
155 | def voice_restoration(self, y, y_lengths, sid=None):
156 | if self.n_speakers > 0:
157 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
158 | else:
159 | g = None
160 | z_q_audio, m_q_audio, logs_q_audio, y_mask = self.enc_q(y, y_lengths, g=g)
161 | z_q_dur, m_q_dur, logs_q_dur = self.flow(z_q_audio, m_q_audio, logs_q_audio, y_mask, g=g)
162 | z_p_audio, m_p_audio, logs_p_audio = self.flow(z_q_dur, m_q_dur, logs_q_dur, y_mask, g=g, reverse=True)
163 | o_hat = self.dec(z_p_audio * y_mask, g=g)
164 | return o_hat, y_mask, (z_q_dur, m_q_dur, logs_q_dur), (z_p_audio, m_p_audio, logs_p_audio)
165 |
--------------------------------------------------------------------------------
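`voice_conversion` above is the entry point behind the voice-conversion TODO in `inference_batch.ipynb`. A self-contained sketch with illustrative hyper-parameters and random weights (the real values come from `datasets/<model>/config.yaml` and a trained checkpoint), just to show the call signature and output shape:

```python
import torch

from model.models import SynthesizerTrn

net_g = SynthesizerTrn(
    n_vocab=100,
    spec_channels=513,
    segment_size=32,
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    n_layers_q=16,
    n_flows=4,
    kernel_size=3,
    p_dropout=0.1,
    speaker_cond_layer=0,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
    mas_noise_scale=0.01,
    mas_noise_scale_decay=2e-6,
    use_sdp=True,
    use_transformer_flow=True,
    n_speakers=4,
    gin_channels=256,
).eval()

spec = torch.randn(1, 513, 80)                  # source utterance spectrogram [B, C_spec, T]
spec_lengths = torch.LongTensor([80])
sid_src, sid_tgt = torch.LongTensor([0]), torch.LongTensor([1])
with torch.no_grad():
    o_hat, y_mask, *_ = net_g.voice_conversion(spec, spec_lengths, sid_src, sid_tgt)
print(o_hat.shape)                              # torch.Size([1, 1, 20480]) = 80 frames * 256
```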
/model/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils.model import fused_add_tanh_sigmoid_multiply
5 |
6 |
7 | LRELU_SLOPE = 0.1
8 |
9 |
10 | # ! PosteriorEncoder
11 | # ! ResidualCouplingLayer
12 | class WN(nn.Module):
13 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
14 | super(WN, self).__init__()
15 | assert kernel_size % 2 == 1
16 | self.hidden_channels = hidden_channels
17 |         self.kernel_size = kernel_size
18 | self.n_layers = n_layers
19 | self.gin_channels = gin_channels
20 |
21 | self.in_layers = nn.ModuleList()
22 | self.res_skip_layers = nn.ModuleList()
23 | self.drop = nn.Dropout(p_dropout)
24 |
25 | if gin_channels != 0:
26 | cond_layer = nn.Linear(gin_channels, 2 * hidden_channels * n_layers)
27 | self.cond_layer = nn.utils.weight_norm(cond_layer, name="weight")
28 |
29 | for i in range(n_layers):
30 | dilation = dilation_rate**i
31 | padding = int((kernel_size * dilation - dilation) / 2)
32 | in_layer = nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding)
33 | in_layer = nn.utils.weight_norm(in_layer, name="weight")
34 | self.in_layers.append(in_layer)
35 |
36 | # last one is not necessary
37 | res_skip_channels = 2 * hidden_channels if i < n_layers - 1 else hidden_channels
38 | res_skip_layer = nn.Linear(hidden_channels, res_skip_channels)
39 | res_skip_layer = nn.utils.weight_norm(res_skip_layer, name="weight")
40 | self.res_skip_layers.append(res_skip_layer)
41 |
42 | def forward(self, x, x_mask, g=None, **kwargs):
43 | output = torch.zeros_like(x)
44 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
45 |
46 | if g is not None:
47 | g = self.cond_layer(g.mT).mT
48 |
49 | for i in range(self.n_layers):
50 | x_in = self.in_layers[i](x)
51 | if g is not None:
52 | cond_offset = i * 2 * self.hidden_channels
53 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
54 | else:
55 | g_l = torch.zeros_like(x_in)
56 |
57 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
58 | acts = self.drop(acts)
59 |
60 | res_skip_acts = self.res_skip_layers[i](acts.mT).mT
61 | if i < self.n_layers - 1:
62 | res_acts = res_skip_acts[:, : self.hidden_channels, :]
63 | x = (x + res_acts) * x_mask
64 | output = output + res_skip_acts[:, self.hidden_channels :, :]
65 | else:
66 | output = output + res_skip_acts
67 | return output * x_mask
68 |
69 | def remove_weight_norm(self):
70 | if self.gin_channels != 0:
71 | nn.utils.remove_weight_norm(self.cond_layer)
72 | for l in self.in_layers:
73 | nn.utils.remove_weight_norm(l)
74 | for l in self.res_skip_layers:
75 | nn.utils.remove_weight_norm(l)
76 |
77 |
78 | # ! StochasticDurationPredictor
79 | # ! ResidualCouplingBlock
80 | # TODO convert to class method
81 | class Flip(nn.Module):
82 | def forward(self, x, *args, reverse=False, **kwargs):
83 | x = torch.flip(x, [1])
84 | if not reverse:
85 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
86 | return x, logdet
87 | else:
88 | return x
89 |
--------------------------------------------------------------------------------
/model/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class LayerNorm(nn.Module):
7 | def __init__(self, channels, eps=1e-5):
8 | super().__init__()
9 | self.channels = channels
10 | self.eps = eps
11 |
12 | self.gamma = nn.Parameter(torch.ones(channels))
13 | self.beta = nn.Parameter(torch.zeros(channels))
14 |
15 | def forward(self, x: torch.Tensor):
16 | x = F.layer_norm(x.mT, (self.channels,), self.gamma, self.beta, self.eps)
17 | return x.mT
18 |
19 |
20 | class CondLayerNorm(nn.Module):
21 | def __init__(self, channels, eps=1e-5, cond_channels=0):
22 | super().__init__()
23 | self.channels = channels
24 | self.eps = eps
25 |
26 | self.linear_gamma = nn.Linear(cond_channels, channels)
27 | self.linear_beta = nn.Linear(cond_channels, channels)
28 |
29 | def forward(self, x: torch.Tensor, cond: torch.Tensor):
30 | gamma = self.linear_gamma(cond)
31 | beta = self.linear_beta(cond)
32 |
33 | x = F.layer_norm(x.mT, (self.channels,), gamma, beta, self.eps)
34 | return x.mT
35 |
--------------------------------------------------------------------------------
/model/normalizing_flows.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from model.transformer import RelativePositionTransformer
5 | from model.modules import WN
6 |
7 |
8 | class ResidualCouplingBlock(nn.Module):
9 | def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0, mean_only=False, use_transformer_flow=True):
10 | super().__init__()
11 | self.flows = nn.ModuleList()
12 | for i in range(n_flows):
13 | use_transformer = use_transformer_flow if (i == n_flows - 1) else False # TODO or (i == n_flows - 2)
14 | self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=mean_only, use_transformer_flow=use_transformer))
15 | self.flows.append(Flip())
16 |
17 | def forward(self, x, m, logs, x_mask, g=None, reverse=False):
18 | if reverse:
19 | for flow in reversed(self.flows):
20 | x, m, logs = flow(x, m, logs, x_mask, g=g, reverse=reverse)
21 | else:
22 | for flow in self.flows:
23 | x, m, logs = flow(x, m, logs, x_mask, g=g, reverse=reverse)
24 | return x, m, logs
25 |
26 |
27 | # TODO rewrite for 256x256 attention score map
28 | class ResidualCouplingLayer(nn.Module):
29 | def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False, use_transformer_flow=True):
30 | assert channels % 2 == 0, "channels should be divisible by 2"
31 | super().__init__()
32 | self.half_channels = channels // 2
33 | self.mean_only = mean_only
34 |
35 | self.pre_transformer = (
36 | RelativePositionTransformer(
37 | self.half_channels,
38 | self.half_channels,
39 | self.half_channels,
40 | self.half_channels,
41 | n_heads=2,
42 | n_layers=1,
43 | kernel_size=3,
44 | dropout=0.1,
45 | window_size=None,
46 | )
47 | if use_transformer_flow
48 | else None
49 | )
50 |
51 | self.pre = nn.Linear(self.half_channels, hidden_channels)
52 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
53 | self.post = nn.Linear(hidden_channels, self.half_channels * (2 - mean_only))
54 | self.post.weight.data.zero_()
55 | self.post.bias.data.zero_()
56 |
57 | def forward(self, x, m, logs, x_mask, g=None, reverse=False):
58 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
59 | m0, m1 = torch.split(m, [self.half_channels] * 2, 1)
60 | logs0, logs1 = torch.split(logs, [self.half_channels] * 2, 1)
61 | x0_ = x0
62 | if self.pre_transformer is not None:
63 | x0_ = self.pre_transformer(x0 * x_mask, x_mask)
64 | x0_ = x0_ + x0 # residual connection
65 | h = self.pre(x0_.mT).mT * x_mask
66 | h = self.enc(h, x_mask, g=g)
67 | stats = self.post(h.mT).mT * x_mask
68 | if not self.mean_only:
69 | m_flow, logs_flow = torch.split(stats, [self.half_channels] * 2, 1)
70 | else:
71 | m_flow = stats
72 | logs_flow = torch.zeros_like(m)
73 |
74 | if reverse:
75 | x1 = (x1 - m_flow) * torch.exp(-logs_flow) * x_mask
76 | m1 = (m1 - m_flow) * torch.exp(-logs_flow) * x_mask
77 | logs1 = logs1 - logs_flow
78 |
79 | x = torch.cat([x0, x1], 1)
80 | m = torch.cat([m0, m1], 1)
81 | logs = torch.cat([logs0, logs1], 1)
82 | return x, m, logs
83 | else:
84 | x1 = m_flow + x1 * torch.exp(logs_flow) * x_mask
85 | m1 = m_flow + m1 * torch.exp(logs_flow) * x_mask
86 | logs1 = logs1 + logs_flow
87 |
88 | x = torch.cat([x0, x1], 1)
89 | m = torch.cat([m0, m1], 1)
90 | logs = torch.cat([logs0, logs1], 1)
91 | return x, m, logs
92 |
93 |
94 | class Flip(nn.Module):
95 | def forward(self, x, m, logs, *args, reverse=False, **kwargs):
96 | x = torch.flip(x, [1])
97 | m = torch.flip(m, [1])
98 | logs = torch.flip(logs, [1])
99 | return x, m, logs
100 |
--------------------------------------------------------------------------------
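A small invertibility check for the coupling block (channel sizes are illustrative): running it forward and then in reverse should reproduce its input. Note that `post` is zero-initialised, so an untrained block is essentially an identity map apart from the channel flips; the check is more telling after loading trained weights.

```python
import torch

from model.normalizing_flows import ResidualCouplingBlock

torch.manual_seed(0)
flow = ResidualCouplingBlock(
    channels=192, hidden_channels=192, kernel_size=5,
    dilation_rate=1, n_layers=4, n_flows=4,
).eval()  # disable dropout in the transformer-conditioned layer

x = torch.randn(1, 192, 30)
m = torch.zeros_like(x)
logs = torch.zeros_like(x)
x_mask = torch.ones(1, 1, 30)

with torch.no_grad():
    z, m_z, logs_z = flow(x, m, logs, x_mask)                    # forward direction
    x_rec, m_rec, logs_rec = flow(z, m_z, logs_z, x_mask, reverse=True)

print(torch.allclose(x, x_rec, atol=1e-4))                       # True
```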
/model/transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from utils.model import convert_pad_shape
7 | from model.normalization import LayerNorm
8 |
9 |
10 | # TODO add conditioning on language
11 | # TODO check whether we need to stop gradient for speaker embedding
12 | class RelativePositionTransformer(nn.Module):
13 | def __init__(
14 | self,
15 | in_channels: int,
16 | hidden_channels: int,
17 | out_channels: int,
18 | hidden_channels_ffn: int,
19 | n_heads: int,
20 | n_layers: int,
21 | kernel_size=1,
22 | dropout=0.0,
23 | window_size=4,
24 | gin_channels=0,
25 | lang_channels=0,
26 | speaker_cond_layer=0,
27 | ):
28 | super().__init__()
29 | self.n_layers = n_layers
30 | self.speaker_cond_layer = speaker_cond_layer
31 |
32 | self.drop = nn.Dropout(dropout)
33 | self.attn_layers = nn.ModuleList()
34 | self.norm_layers_1 = nn.ModuleList()
35 | self.ffn_layers = nn.ModuleList()
36 | self.norm_layers_2 = nn.ModuleList()
37 | for i in range(self.n_layers):
38 | self.attn_layers.append(MultiHeadAttention(hidden_channels if i != 0 else in_channels, hidden_channels, n_heads, p_dropout=dropout, window_size=window_size))
39 | self.norm_layers_1.append(LayerNorm(hidden_channels))
40 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, hidden_channels_ffn, kernel_size, p_dropout=dropout))
41 | self.norm_layers_2.append(LayerNorm(hidden_channels))
42 | if gin_channels != 0:
43 | self.cond = nn.Linear(gin_channels, hidden_channels)
44 |
45 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: torch.Tensor = None, lang: torch.Tensor = None):
46 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
47 | x = x * x_mask
48 | for i in range(self.n_layers):
49 | # TODO consider using other conditioning
50 | # TODO https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/modules/attentions.py#L12
51 | if i == self.speaker_cond_layer - 1 and g is not None:
52 | # ! g = torch.detach(g)
53 | x = x + self.cond(g.mT).mT
54 | x = x * x_mask
55 | y = self.attn_layers[i](x, x, attn_mask)
56 | y = self.drop(y)
57 | x = self.norm_layers_1[i](x + y)
58 |
59 | y = self.ffn_layers[i](x, x_mask)
60 | y = self.drop(y)
61 | x = self.norm_layers_2[i](x + y)
62 | x = x * x_mask
63 | return x
64 |
65 |
66 | class MultiHeadAttention(nn.Module):
67 | def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
68 | super().__init__()
69 | assert channels % n_heads == 0
70 |
71 | self.channels = channels
72 | self.out_channels = out_channels
73 | self.n_heads = n_heads
74 | self.p_dropout = p_dropout
75 | self.window_size = window_size
76 | self.heads_share = heads_share
77 | self.block_length = block_length
78 | self.proximal_bias = proximal_bias
79 | self.proximal_init = proximal_init
80 | self.attn = None
81 |
82 | self.k_channels = channels // n_heads
83 | self.conv_q = nn.Linear(channels, channels)
84 | self.conv_k = nn.Linear(channels, channels)
85 | self.conv_v = nn.Linear(channels, channels)
86 | self.conv_o = nn.Linear(channels, out_channels)
87 | self.drop = nn.Dropout(p_dropout)
88 |
89 | if window_size is not None:
90 | n_heads_rel = 1 if heads_share else n_heads
91 | rel_stddev = self.k_channels**-0.5
92 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
93 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
94 |
95 | nn.init.xavier_uniform_(self.conv_q.weight)
96 | nn.init.xavier_uniform_(self.conv_k.weight)
97 | nn.init.xavier_uniform_(self.conv_v.weight)
98 | if proximal_init:
99 | with torch.no_grad():
100 | self.conv_k.weight.copy_(self.conv_q.weight)
101 | self.conv_k.bias.copy_(self.conv_q.bias)
102 |
103 | def forward(self, x, c, attn_mask=None):
104 | q = self.conv_q(x.mT).mT
105 | k = self.conv_k(c.mT).mT
106 | v = self.conv_v(c.mT).mT
107 |
108 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
109 |
110 | x = self.conv_o(x.mT).mT
111 | return x
112 |
113 | def attention(self, query, key, value, mask=None):
114 | # reshape [b, d, t] -> [b, n_h, t, d_k]
115 | b, d, t_s, t_t = (*key.size(), query.size(2))
116 | query = query.view(b, self.n_heads, self.k_channels, t_t).mT
117 | key = key.view(b, self.n_heads, self.k_channels, t_s).mT
118 | value = value.view(b, self.n_heads, self.k_channels, t_s).mT
119 |
120 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT)
121 | if self.window_size is not None:
122 | assert t_s == t_t, "Relative attention is only available for self-attention."
123 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
124 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
125 | scores_local = self._relative_position_to_absolute_position(rel_logits)
126 | scores = scores + scores_local
127 | if self.proximal_bias:
128 | assert t_s == t_t, "Proximal bias is only available for self-attention."
129 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
130 | if mask is not None:
131 | scores = scores.masked_fill(mask == 0, -1e4)
132 | if self.block_length is not None:
133 | assert t_s == t_t, "Local attention is only available for self-attention."
134 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
135 | scores = scores.masked_fill(block_mask == 0, -1e4)
136 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
137 | p_attn = self.drop(p_attn)
138 | output = torch.matmul(p_attn, value)
139 | if self.window_size is not None:
140 | relative_weights = self._absolute_position_to_relative_position(p_attn)
141 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
142 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
143 | output = output.mT.contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
144 | return output, p_attn
145 |
146 | def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor):
147 | """
148 | x: [b, h, l, m]
149 | y: [h or 1, m, d]
150 | ret: [b, h, l, d]
151 | """
152 | return torch.matmul(x, y.unsqueeze(0))
153 |
154 | def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor):
155 | """
156 | x: [b, h, l, d]
157 | y: [h or 1, m, d]
158 | ret: [b, h, l, m]
159 | """
160 | return torch.matmul(x, y.unsqueeze(0).mT)
161 |
162 | def _get_relative_embeddings(self, relative_embeddings, length):
163 | max_relative_position = 2 * self.window_size + 1
164 | # Pad first before slice to avoid using cond ops.
165 | pad_length = max(length - (self.window_size + 1), 0)
166 | slice_start_position = max((self.window_size + 1) - length, 0)
167 | slice_end_position = slice_start_position + 2 * length - 1
168 | if pad_length > 0:
169 | padded_relative_embeddings = F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
170 | else:
171 | padded_relative_embeddings = relative_embeddings
172 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
173 | return used_relative_embeddings
174 |
175 | def _relative_position_to_absolute_position(self, x):
176 | """
177 | x: [b, h, l, 2*l-1]
178 | ret: [b, h, l, l]
179 | """
180 | batch, heads, length, _ = x.size()
181 | # Concat columns of pad to shift from relative to absolute indexing.
182 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
183 |
184 |         # Concat extra elements so as to add up to shape (len+1, 2*len-1).
185 | x_flat = x.view([batch, heads, length * 2 * length])
186 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
187 |
188 | # Reshape and slice out the padded elements.
189 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
190 | return x_final
191 |
192 | def _absolute_position_to_relative_position(self, x):
193 | """
194 | x: [b, h, l, l]
195 | ret: [b, h, l, 2*l-1]
196 | """
197 | batch, heads, length, _ = x.size()
198 |         # pad along columns
199 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
200 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
201 | # add 0's in the beginning that will skew the elements after reshape
202 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
203 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
204 | return x_final
205 |
206 | def _attention_bias_proximal(self, length):
207 | """Bias for self-attention to encourage attention to close positions.
208 | Args:
209 | length: an integer scalar.
210 | Returns:
211 | a Tensor with shape [1, 1, length, length]
212 | """
213 | r = torch.arange(length, dtype=torch.float32)
214 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
215 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
216 |
217 |
218 | class FFN(nn.Module):
219 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, causal=False):
220 | super().__init__()
221 | self.kernel_size = kernel_size
222 | self.padding = self._causal_padding if causal else self._same_padding
223 |
224 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
225 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
226 | self.drop = nn.Dropout(p_dropout)
227 |
228 | def forward(self, x, x_mask):
229 | x = self.conv_1(self.padding(x * x_mask))
230 | x = torch.relu(x)
231 | x = self.drop(x)
232 | x = self.conv_2(self.padding(x * x_mask))
233 | return x * x_mask
234 |
235 | def _causal_padding(self, x):
236 | if self.kernel_size == 1:
237 | return x
238 | pad_l = self.kernel_size - 1
239 | pad_r = 0
240 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
241 | x = F.pad(x, convert_pad_shape(padding))
242 | return x
243 |
244 | def _same_padding(self, x):
245 | if self.kernel_size == 1:
246 | return x
247 | pad_l = (self.kernel_size - 1) // 2
248 | pad_r = self.kernel_size // 2
249 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
250 | x = F.pad(x, convert_pad_shape(padding))
251 | return x
252 |
--------------------------------------------------------------------------------
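A shape sketch for `RelativePositionTransformer`, the shared backbone of `TextEncoder`, `AudioEncoder` and the transformer-conditioned coupling layer (hidden sizes are illustrative):

```python
import torch

from model.transformer import RelativePositionTransformer

enc = RelativePositionTransformer(
    in_channels=192, hidden_channels=192, out_channels=192,
    hidden_channels_ffn=768, n_heads=2, n_layers=6,
    kernel_size=3, dropout=0.1, window_size=4,
)
x = torch.randn(1, 192, 40)    # [B, H, T]
x_mask = torch.ones(1, 1, 40)  # token mask [B, 1, T]
y = enc(x, x_mask)
print(y.shape)                 # torch.Size([1, 192, 40])
```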
/preprocess/README.md:
--------------------------------------------------------------------------------
1 | # VITS2 | Preprocessing
2 |
3 | ## Todo
4 |
5 | - [x] text preprocessing
6 | - [x] update vocabulary to support all symbols and features from IPA. See [phonemes.md](https://github.com/espeak-ng/espeak-ng/blob/ed9a7bcf5778a188cdec202ac4316461badb28e1/docs/phonemes.md#L5)
7 | - [x] per-dataset filelist preprocessing. Please refer to [prepare/filelists.ipynb](datasets/ljs_base/prepare/filelists.ipynb)
8 | - [x] handle unknown (out-of-vocabulary) symbols. Please refer to [vocab - TorchText](https://pytorch.org/text/stable/vocab.html)
9 | - [x] handle special symbols in the tokenizer. Please refer to [text/symbols.py](text/symbols.py)
10 | - [ ] audio preprocessing
11 | - [x] replaced scipy and librosa dependencies with torchaudio. See docs [torchaudio.load](https://pytorch.org/audio/stable/backend.html#id2) and [torchaudio.transforms](https://pytorch.org/audio/stable/transforms.html)
12 | - [ ] remove the need for speaker indexing. See [vits/issues/58](https://github.com/jaywalnut310/vits/issues/58)
13 | - [ ] update batch audio resampling. Please refer to [audio_resample.ipynb](preprocess/audio_resample.ipynb)
14 | - [ ] test stereo audio (multi-channel) training
15 |
--------------------------------------------------------------------------------
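A minimal sketch of the out-of-vocabulary handling mentioned in the checklist above, using TorchText's `build_vocab_from_iterator` and `set_default_index`. The vocabulary file path, the assumption of one symbol per line, and the special tokens are illustrative; the actual vocabulary is produced by `preprocess/vocab_generation.ipynb`.

```python
from torchtext.vocab import build_vocab_from_iterator

specials = ["<pad>", "<unk>"]   # hypothetical special tokens; see text/symbols.py
with open("datasets/ljs_base/vocab.txt", encoding="utf-8") as f:
    symbols = [line.rstrip("\n") for line in f if line.rstrip("\n")]

vocab = build_vocab_from_iterator([symbols], specials=specials)
vocab.set_default_index(vocab["<unk>"])   # unseen symbols map to <unk> instead of raising

print(vocab(["a", "b", "never-seen-symbol"]))   # the last index is vocab["<unk>"]
```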
/preprocess/audio_find_corrupted.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Check for corrupted audio files in dataset\n"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import os\n",
17 | "import torchaudio\n",
18 | "import concurrent.futures\n",
19 | "\n",
20 | "i_dir = \"path/to/your/dataset\""
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "def check_wav(file_path):\n",
30 | " \"\"\"Load a .wav file and return if it's corrupted or not\"\"\"\n",
31 | " try:\n",
32 | " waveform, sample_rate = torchaudio.load(file_path)\n",
33 | " return (file_path, True)\n",
34 | " except Exception as e:\n",
35 | " return (file_path, False)\n",
36 | "\n",
37 | "\n",
38 | "def find_wavs(directory):\n",
39 | " \"\"\"Find all .wav files in a directory\"\"\"\n",
40 | " for foldername, subfolders, filenames in os.walk(directory):\n",
41 | " for filename in filenames:\n",
42 | " if filename.endswith(\".wav\"):\n",
43 | " yield os.path.join(foldername, filename)\n",
44 | "\n",
45 | "\n",
46 | "def main(directory):\n",
47 | " \"\"\"Check all .wav files in a directory and its subdirectories\"\"\"\n",
48 | " with concurrent.futures.ThreadPoolExecutor() as executor:\n",
49 | " wav_files = list(find_wavs(directory))\n",
50 | " future_to_file = {executor.submit(check_wav, wav): wav for wav in wav_files}\n",
51 | "\n",
52 | " done_count = 0\n",
53 | " for future in concurrent.futures.as_completed(future_to_file):\n",
54 | " file_path = future_to_file[future]\n",
55 | " try:\n",
56 | " is_valid = future.result()\n",
57 | " except Exception as exc:\n",
58 | " print(f\"{file_path} generated an exception: {exc}\")\n",
59 | " else:\n",
60 | " if not is_valid[1]:\n",
61 | " print(f\"Corrupted file: {file_path}\")\n",
62 | "\n",
63 | " done_count += 1\n",
64 | " if done_count % 5000 == 0:\n",
65 | " print(f\"Processed {done_count} files...\")"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "main(i_dir)"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": []
83 | }
84 | ],
85 | "metadata": {
86 | "language_info": {
87 | "name": "python"
88 | },
89 | "orig_nbformat": 4
90 | },
91 | "nbformat": 4,
92 | "nbformat_minor": 2
93 | }
94 |
--------------------------------------------------------------------------------
/preprocess/audio_resample.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Resample audio wavs\n",
8 | "\n",
9 | "Refer to: [audio resampling tutorial](https://pytorch.org/audio/0.10.0/tutorials/audio_resampling_tutorial.html)\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import os\n",
19 | "import torchaudio\n",
20 | "import torchaudio.transforms as T\n",
21 | "import concurrent.futures\n",
22 | "from pathlib import Path\n",
23 | "import random"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# Example usage:\n",
33 | "input_directory = \"/path/to/dataset\"\n",
34 | "output_directory = f\"{input_directory}.cleaned\"\n",
35 | "orig_sr = 16000\n",
36 | "new_sr = 22050"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "def resample_wav_files(input_dir, output_dir, sr, new_sr):\n",
46 | " # Create the output directory if it doesn't exist\n",
47 | " os.makedirs(output_dir, exist_ok=True)\n",
48 | "\n",
49 | " # Create a resampler object\n",
50 | " resampler = T.Resample(\n",
51 | " sr,\n",
52 | " new_sr,\n",
53 | " lowpass_filter_width=128,\n",
54 | " rolloff=0.99999,\n",
55 | " resampling_method=\"sinc_interp_hann\",\n",
56 | " )\n",
57 | "\n",
58 | " def resample_file(file_path):\n",
59 | " # Load the audio file\n",
60 | " waveform, sample_rate = torchaudio.load(file_path)\n",
61 | " assert sample_rate == sr\n",
62 | "\n",
63 | " # Resample the audio\n",
64 | " resampled_waveform = resampler(waveform)\n",
65 | "\n",
66 | " # Construct the output file path\n",
67 | " output_file = Path(output_dir) / Path(file_path).relative_to(input_dir)\n",
68 | "\n",
69 | " # Save the resampled audio\n",
70 | " torchaudio.save(output_file, resampled_waveform,\n",
71 | " new_sr, bits_per_sample=16)\n",
72 | "\n",
73 | " return output_file\n",
74 | "\n",
75 | " # Use generator to find .wav files and pre-create output directories\n",
76 | " def find_and_prep_wav_files(input_dir, output_dir):\n",
77 | " for root, _, files in os.walk(input_dir):\n",
78 | " for file in files:\n",
79 | " if file.endswith(\".wav\"):\n",
80 | " file_path = Path(root) / file\n",
81 | " output_file = Path(output_dir) / \\\n",
82 | " file_path.relative_to(input_dir)\n",
83 | " os.makedirs(output_file.parent, exist_ok=True)\n",
84 | " yield str(file_path)\n",
85 | "\n",
86 | " # Resample the .wav files using threads for parallel processing\n",
87 | " wav_files = find_and_prep_wav_files(input_dir, output_dir)\n",
88 | " with concurrent.futures.ThreadPoolExecutor() as executor:\n",
89 | " for i, output_file in enumerate(executor.map(resample_file, wav_files)):\n",
90 | " if i % 1000 == 0:\n",
91 | " print(f\"{i}: {output_file}\")\n",
92 | "\n",
93 | "\n",
94 | "resample_wav_files(input_directory, output_directory, orig_sr, new_sr)"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "# Test random file to see if it worked\n",
104 | "out_path = os.path.join(output_directory, random.choice(os.listdir(output_directory)))\n",
105 | "\n",
106 | "print(torchaudio.info(out_path))\n",
107 | "resampled_waveform, sample_rate = torchaudio.load(out_path)\n",
108 | "print(f\"max: {resampled_waveform.max()}, min: {resampled_waveform.min()}\")"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": []
117 | }
118 | ],
119 | "metadata": {
120 | "kernelspec": {
121 | "display_name": "g2p",
122 | "language": "python",
123 | "name": "python3"
124 | },
125 | "language_info": {
126 | "codemirror_mode": {
127 | "name": "ipython",
128 | "version": 3
129 | },
130 | "file_extension": ".py",
131 | "mimetype": "text/x-python",
132 | "name": "python",
133 | "nbconvert_exporter": "python",
134 | "pygments_lexer": "ipython3",
135 | "version": "3.11.4"
136 | },
137 | "orig_nbformat": 4
138 | },
139 | "nbformat": 4,
140 | "nbformat_minor": 2
141 | }
142 |
--------------------------------------------------------------------------------
/preprocess/audio_resampling.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import concurrent.futures
3 | import os
4 | from concurrent.futures import ProcessPoolExecutor
5 | from multiprocessing import cpu_count
6 |
7 | import librosa
8 | import numpy as np
9 | from rich.progress import track
10 | from scipy.io import wavfile
11 |
12 |
13 | def load_wav(wav_path):
14 | return librosa.load(wav_path, sr=None)
15 |
16 |
17 | def trim_wav(wav, top_db=40):
18 | return librosa.effects.trim(wav, top_db=top_db)
19 |
20 |
21 | def normalize_peak(wav, threshold=1.0):
22 | peak = np.abs(wav).max()
23 | if peak > threshold:
24 | wav = 0.98 * wav / peak
25 | return wav
26 |
27 |
28 | def resample_wav(wav, sr, target_sr):
29 | return librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
30 |
31 |
32 | def save_wav_to_path(wav, save_path, sr):
33 | wavfile.write(save_path, sr, (wav * np.iinfo(np.int16).max).astype(np.int16))
34 |
35 |
36 | def process(item):
37 | spkdir, wav_name, args = item
38 | speaker = spkdir.replace("\\", "/").split("/")[-1]
39 |
40 | wav_path = os.path.join(args.in_dir, speaker, wav_name)
41 | if os.path.exists(wav_path) and ".wav" in wav_path:
42 | os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True)
43 |
44 | wav, sr = load_wav(wav_path)
45 | wav, _ = trim_wav(wav)
46 | wav = normalize_peak(wav)
47 | resampled_wav = resample_wav(wav, sr, args.sr2)
48 |
49 | if not args.skip_loudnorm:
50 | resampled_wav /= np.max(np.abs(resampled_wav))
51 |
52 | save_path2 = os.path.join(args.out_dir2, speaker, wav_name)
53 | save_wav_to_path(resampled_wav, save_path2, args.sr2)
54 |
55 |
56 | """
57 | def process_all_speakers():
58 | process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1)
59 |
60 | with ThreadPoolExecutor(max_workers=process_count) as executor:
61 | for speaker in speakers:
62 | spk_dir = os.path.join(args.in_dir, speaker)
63 | if os.path.isdir(spk_dir):
64 | print(spk_dir)
65 | futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")]
66 | for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
67 | pass
68 | """
69 | # multi process
70 |
71 |
72 | def process_all_speakers():
73 | process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1)
74 | with ProcessPoolExecutor(max_workers=process_count) as executor:
75 | for speaker in speakers:
76 | spk_dir = os.path.join(args.in_dir, speaker)
77 | if os.path.isdir(spk_dir):
78 | print(spk_dir)
79 | futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")]
80 | for _ in track(concurrent.futures.as_completed(futures), total=len(futures), description="resampling:"):
81 | pass
82 |
83 |
84 | if __name__ == "__main__":
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument("--sr2", type=int, default=44100, help="sampling rate")
87 | parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source dir")
88 | parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir")
89 | parser.add_argument("--skip_loudnorm", action="store_true", help="Skip loudness matching if you have done it")
90 | args = parser.parse_args()
91 |
92 | print(f"CPU count: {cpu_count()}")
93 | speakers = os.listdir(args.in_dir)
94 | process_all_speakers()
95 |
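For a quick spot-check of a single file, `process` can also be driven directly with a hand-built namespace instead of the CLI; the speaker directory and file name below are placeholders.

```python
# Hypothetical one-off call, bypassing argparse; adjust the paths to your own layout.
from argparse import Namespace

args = Namespace(in_dir="./dataset_raw", out_dir2="./dataset/44k", sr2=44100, skip_loudnorm=False)
process(("./dataset_raw/speaker01", "sample.wav", args))  # load -> trim -> peak-normalize -> resample -> save
```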
--------------------------------------------------------------------------------
/preprocess/mel_transform.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import logging
5 | import argparse
6 | import traceback
7 | from tqdm import tqdm
8 | import torch
9 | import torch.multiprocessing as mp
10 | from concurrent.futures import ProcessPoolExecutor
11 | import torchaudio
12 |
13 | from utils.hparams import get_hparams_from_file, HParams
14 | from utils.mel_processing import wav_to_mel
15 |
16 | os.environ["OMP_NUM_THREADS"] = "1"
17 | log_format = "%(asctime)s %(message)s"
18 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt="%m/%d %I:%M:%S %p")
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--data_dir", type=str, required=True, help="Directory containing audio files")
24 | parser.add_argument("-c", "--config", type=str, required=True, help="YAML file for configuration")
25 | args = parser.parse_args()
26 |
27 | hparams = get_hparams_from_file(args.config)
28 | hparams.data_dir = args.data_dir
29 | return hparams
30 |
31 |
32 | def process_batch(batch, sr_hps, n_fft, hop_size, win_size, n_mels, fmin, fmax):
33 | wavs = []
34 | for ifile in batch:
35 | try:
36 | wav, sr = torchaudio.load(ifile)
37 | assert sr == sr_hps, f"sample rate: {sr}, expected: {sr_hps}"
38 | wavs.append(wav)
39 |         except Exception:
40 | traceback.print_exc()
41 | print("Failed to process {}".format(ifile))
42 | return None
43 |
44 | wav_lengths = torch.tensor([x.size(1) for x in wavs])
45 | max_wav_len = wav_lengths.max()
46 |
47 | wav_padded = torch.zeros(len(batch), 1, max_wav_len)
48 | for i, wav in enumerate(wavs):
49 | wav_padded[i, :, : wav.size(1)] = wav
50 |
51 | spec = wav_to_mel(wav_padded, n_fft, n_mels, sr_hps, hop_size, win_size, fmin, fmax, center=False, norm=False)
52 | spec = torch.squeeze(spec, 1)
53 |
54 | for i, ifile in enumerate(batch):
55 | ofile = ifile.replace(".wav", ".spec.pt")
56 | spec_i = spec[i, :, : wav_lengths[i] // hop_size].clone()
57 | torch.save(spec_i, ofile)
58 |
59 | return batch
60 |
61 |
62 | def process_data(hps: HParams):
63 | wav_fns = sorted(glob.glob(f"{hps.data_dir}/**/*.wav", recursive=True))
64 | # wav_fns = wav_fns[:100] # * Enable for testing
65 |     logging.info(f"CPUs available: {mp.cpu_count()}; using up to 32 worker processes")
66 | logging.info(f"Preprocessing {len(wav_fns)} files...")
67 |
68 | sr = hps.data.sample_rate
69 | n_fft = hps.data.n_fft
70 | hop_size = hps.data.hop_length
71 | win_size = hps.data.win_length
72 | n_mels = hps.data.n_mels
73 | fmin = hps.data.f_min
74 | fmax = hps.data.f_max
75 |
76 | # Batch files to optimize disk I/O and computation
77 | batch_size = 128 # Change as needed
78 | audio_file_batches = [wav_fns[i : i + batch_size] for i in range(0, len(wav_fns), batch_size)]
79 |
80 | # Use multiprocessing to speed up the conversion
81 | with ProcessPoolExecutor(max_workers=32) as executor:
82 | futures = [executor.submit(process_batch, batch, sr, n_fft, hop_size, win_size, n_mels, fmin, fmax) for batch in audio_file_batches]
83 | for future in tqdm(futures):
84 | if future.result() is None:
85 |                 logging.warning("Failed to process a batch.")
86 | return
87 |
88 |
89 | def get_size_by_ext(directory, extension):
90 | total_size = 0
91 | for dirpath, dirnames, filenames in os.walk(directory):
92 | for f in filenames:
93 | if f.endswith(extension):
94 | fp = os.path.join(dirpath, f)
95 | total_size += os.path.getsize(fp)
96 |
97 | return total_size
98 |
99 |
100 | def human_readable_size(size):
101 | """Converts size in bytes to a human-readable format."""
102 | for unit in ["B", "KB", "MB", "GB", "TB"]:
103 | if size < 1024:
104 | return f"{size:.2f}{unit}"
105 | size /= 1024
106 |     return f"{size:.2f}PB"  # petabytes, for anything beyond TB
107 |
108 |
109 | if __name__ == "__main__":
110 | from time import time
111 |
112 | hps = parse_args()
113 |
114 | start = time()
115 | process_data(hps)
116 | logging.info(f"Processed data in {time() - start} seconds")
117 |
118 | extension = ".spec.pt"
119 | size_spec = get_size_by_ext(hps.data_dir, extension)
120 | logging.info(f"{extension}: \t{human_readable_size(size_spec)}")
121 | extension = ".wav"
122 | size_wav = get_size_by_ext(hps.data_dir, extension)
123 | logging.info(f"{extension}: \t{human_readable_size(size_wav)}")
124 | logging.info(f"Total: \t\t{human_readable_size(size_spec + size_wav)}")
125 |
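A note on the `: wav_lengths[i] // hop_size` slice in `process_batch`: `wav_to_mel` pads each signal by `(n_fft - hop_length) // 2` on both sides and runs the STFT with `center=False` (see `utils/mel_processing.py`), which yields exactly `len(wav) // hop_length` frames. A quick arithmetic check, with example settings (n_fft=1024, hop=256) standing in for whatever the dataset config actually uses:

```python
# Frame-count check for the two-sided padding used in utils/mel_processing.py (center=False).
def n_frames(wav_len: int, n_fft: int, hop: int) -> int:
    padded = wav_len + 2 * ((n_fft - hop) // 2)  # pad applied by T.Spectrogram
    return (padded - n_fft) // hop + 1           # standard STFT frame count

assert n_frames(22050, n_fft=1024, hop=256) == 22050 // 256  # 86 frames for 1 s at 22.05 kHz
assert n_frames(48000, n_fft=1024, hop=256) == 48000 // 256
```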
--------------------------------------------------------------------------------
/preprocess/vocab_generation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Create a list of all ipa symbols\n",
8 | "\n",
9 | "Please refer to [phonemes.md](https://github.com/espeak-ng/espeak-ng/blob/ed9a7bcf5778a188cdec202ac4316461badb28e1/docs/phonemes.md)\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# Consonants\n",
19 | "\n",
20 | "consonants = \"\"\"\n",
21 | " m̥ | m | | ɱ | | | n̥ | n | | | ɳ̊ | ɳ | ɲ̟̊ | ɲ̟ | ɲ̊ | ɲ | ŋ̊ | ŋ | ɴ̥ | ɴ | | | | |\n",
22 | " p | b | p̪ | b̪ | t̪ | d̪ | t | d | | | ʈ | ɖ | | | c | ɟ | k | ɡ | q | ɢ | ʡ | | ʔ | |\n",
23 | " | | | | | | t͡s | d͡z | t͡ʃ | d͡ʒ | ʈ͡ʂ | ɖ͡ʐ | t͡ɕ | d͡ʑ | | | | | | | | | | |\n",
24 | " p͡ɸ | b͡β | p̪͡f | b̪͡v | t͡θ | d͡ð | | | | | | | | | c͡ç | ɟ͡ʝ | k͡x | ɡ͡ɣ | q͡χ | ɢ͡ʁ | ʡ͡ħ | ʡ͡ʕ | ʔ͡h | |\n",
25 | " | | | | | | t͡ɬ | d͡ɮ | | | ʈ͡ɭ̊˔ | | | | c͡ʎ̥˔ | | k͡ʟ̝̊ | ɡ͡ʟ̝ | | | | | | |\n",
26 | " | | | | | | s | z | ʃ | ʒ | ʂ | ʐ | ɕ | ʑ | | | | | | | | | | |\n",
27 | " ɸ | β | f | v | θ | ð | | | | | | | | | ç | ʝ | x | ɣ | χ | ʁ | ħ | ʕ | h | ɦ |\n",
28 | " | | | | | | ɬ | ɮ | | | ɭ̊˔ | | | | ʎ̥˔ | ʎ̝ | ʟ̝̊ | ʟ̝ | | | | | | |\n",
29 | " | | ʋ̥ | ʋ | | | ɹ̥ | ɹ | | | ɻ̊ | ɻ | | | j̊ | j | ɰ̊ | ɰ | | | | | | |\n",
30 | " | | | | | | l̥ | l | | | ɭ̊ | ɭ | | | ʎ̥ | ʎ | ʟ̥ | ʟ | | ʟ̠ | | | | |\n",
31 | " | ⱱ̟ | | ⱱ | | | ɾ̥ | ɾ | | | ɽ̊ | ɽ | | | | | | | | ɢ̆ | | ʡ̮ | | |\n",
32 | " | | | | | | | ɺ | | | | ɭ̆ | | | | ʎ̮ | | ʟ̆ | | | | | | |\n",
33 | " | ʙ | | | | | r̥ | r | | | ɽ͡r̥ | ɽ͡r | | | | | | | ʀ̥ | ʀ | ʜ | ʢ | | |\n",
34 | " ʘ | | | | ǀ | | ǃ | | | | | | ǂ | | | | | | | | | | | |\n",
35 | " | | | | | | ǁ | | | | | | | | | | | | | | | | | |\n",
36 | " | ɓ | | | | | | ɗ | | | | | | | | ʄ | | ɠ | | ʛ | | | | |\n",
37 | " pʼ | | | | | | tʼ | | | | ʈʼ | | | | cʼ | | kʼ | | qʼ | | ʡʼ | | | |\n",
38 | " | | fʼ | | θʼ | | sʼ | | ʃʼ | | ʂʼ | | | | | | xʼ | | χʼ | | | | | |\n",
39 | " | | | | | | ɬʼ | | | | | | | | | | | | | | | | | |\n",
40 | "\"\"\"\n",
41 | "\n",
42 | "consonants_other = \"\"\"\n",
43 | " | | | | | | | | | ŋ͡m | | |\n",
44 | " | | | | | | | | k͡p | ɡ͡b | | |\n",
45 | " p͡f | b͡v | | | | | | | | | | |\n",
46 | " | | | | ɧ | | | | | | | |\n",
47 | " | | | | | | | ɥ | | | ʍ | w |\n",
48 | " | | | ɫ | | | | | | | | |\n",
49 | "\"\"\"\n",
50 | "\n",
51 | "\n",
52 | "manner_of_articulation = \"\"\"\n",
53 | " ʼ |\n",
54 | "\"\"\"\n",
55 | "\n",
56 | "# Vowels\n",
57 | "\n",
58 | "vowels = \"\"\"\n",
59 | " i | y | ɨ | ʉ | ɯ | u |\n",
60 | " ɪ | ʏ | | | | ʊ |\n",
61 | " e | ø | ɘ | ɵ | ɤ | o |\n",
62 | " | | ə | | | |\n",
63 | " ɛ | œ | ɜ | ɞ | ʌ | ɔ |\n",
64 | " æ | | ɐ | | | |\n",
65 | " a | ɶ | | | ɑ | ɒ |\n",
66 | "\"\"\"\n",
67 | "\n",
68 | "\n",
69 | "vowels_other = \"\"\"\n",
70 | "| ɚ |\n",
71 | "| ɝ |\n",
72 | "\"\"\"\n",
73 | "\n",
74 | "# Diacritics\n",
75 | "\n",
76 | "articulation = \"\"\"\n",
77 | " ◌̼ |\n",
78 | " ◌̪͆ |\n",
79 | " ◌̪ |\n",
80 | " ◌̺ |\n",
81 | " ◌̻ |\n",
82 | " ◌̟ |\n",
83 | " ◌̠ |\n",
84 | " ◌̈ |\n",
85 | " ◌̽ |\n",
86 | " ◌̝ |\n",
87 | " ◌̞ |\n",
88 | "\"\"\"\n",
89 | "\n",
90 | "air_flow = \"\"\"\n",
91 | " ↑ |\n",
92 | " ↓ |\n",
93 | "\"\"\"\n",
94 | "\n",
95 | "phonation = \"\"\"\n",
96 | " ◌̤ |\n",
97 | " ◌̥ |\n",
98 | " ◌̬ |\n",
99 | " ◌̰ |\n",
100 | " ʔ͡◌ |\n",
101 | "\"\"\"\n",
102 | "\n",
103 | "rounding_and_labialization = \"\"\"\n",
104 | " ◌ʷ◌ᶣ |\n",
105 | " ◌ᵝ |\n",
106 | " ◌̹ |\n",
107 | " ◌̜ |\n",
108 | "\"\"\"\n",
109 | "\n",
110 | "\n",
111 | "syllabicity = \"\"\"\n",
112 | " ◌̩ |\n",
113 | " ◌̯ |\n",
114 | "\"\"\"\n",
115 | "\n",
116 | "consonant_release = \"\"\"\n",
117 | " ◌ʰ |\n",
118 | " ◌ⁿ |\n",
119 | " ◌ˡ |\n",
120 | " ◌̚ |\n",
121 | "\"\"\"\n",
122 | "\n",
123 | "co_articulation = \"\"\"\n",
124 | " ◌ʲ |\n",
125 | " ◌ˠ◌̴ |\n",
126 | " ◌ˤ◌̴ |\n",
127 | " ◌̃ |\n",
128 | " ◌˞ |\n",
129 | "\"\"\"\n",
130 | "\n",
131 | "tongue_root = \"\"\"\n",
132 | " ◌̘ |\n",
133 | " ◌̙ |\n",
134 | "\"\"\"\n",
135 | "\n",
136 | "fortis_and_lenis = \"\"\"\n",
137 | " ◌͈ |\n",
138 | " ◌͉ |\n",
139 | "\"\"\"\n",
140 | "\n",
141 | "# Suprasegmentals\n",
142 | "\n",
143 | "stress = \"\"\"\n",
144 | " ˈ◌ |\n",
145 | " ˌ◌ |\n",
146 | "\"\"\"\n",
147 | "\n",
148 | "length = \"\"\"\n",
149 | " ◌̆ |\n",
150 | " ◌ˑ |\n",
151 | " ◌ː |\n",
152 | " ◌ːː |\n",
153 | "\"\"\"\n",
154 | "\n",
155 | "rhythm = \"\"\"\n",
156 | " . |\n",
157 | " ◌‿◌ |\n",
158 | "\"\"\"\n",
159 | "\n",
160 | "tones = \"\"\"\n",
161 | " ◌˥ |\n",
162 | " ◌˦ |\n",
163 | " ◌˧ |\n",
164 | " ◌˨ |\n",
165 | " ◌˩ |\n",
166 | " ꜛ◌ |\n",
167 | " ꜜ◌ |\n",
168 | "\"\"\"\n",
169 | "\n",
170 | "intonation = \"\"\"\n",
171 | " | |\n",
172 | " ‖ |\n",
173 | " ↗︎ |\n",
174 | " ↘︎ |\n",
175 | "\"\"\""
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "def get_non_empty_content(md_table):\n",
185 | " non_empty_content = []\n",
186 | "\n",
187 | " # Split table by lines\n",
188 | " lines = md_table.split(\"\\n\")\n",
189 | "\n",
190 | " for line in lines:\n",
191 | " # Split each line by \"|\" to get the cells\n",
192 | " cells = line.split(\"|\")\n",
193 | "\n",
194 | " for cell in cells:\n",
195 | " cell_content = cell.strip()\n",
196 | "\n",
197 | " # If the cell content is not empty, add it to the list\n",
198 | " if cell_content != \"\":\n",
199 | " non_empty_content.append(cell_content)\n",
200 | "\n",
201 | " non_empty_content = \"\".join(non_empty_content)\n",
202 | "\n",
203 | " # unique non_empty_content\n",
204 | " non_empty_content = set(non_empty_content)\n",
205 | " non_empty_content = \"\".join(non_empty_content)\n",
206 | "\n",
207 | " # sort non_empty_content\n",
208 | " non_empty_content = sorted(non_empty_content)\n",
209 | "\n",
210 | " return non_empty_content\n",
211 | "\n",
212 | "\n",
213 | "# Consonants\n",
214 | "consonants = get_non_empty_content(consonants)\n",
215 | "consonants_other = get_non_empty_content(consonants_other)\n",
216 | "manner_of_articulation = get_non_empty_content(manner_of_articulation)\n",
217 | "# Vowels\n",
218 | "vowels = get_non_empty_content(vowels)\n",
219 | "vowels_other = get_non_empty_content(vowels_other)\n",
220 | "# Diacritics\n",
221 | "articulation = get_non_empty_content(articulation)\n",
222 | "air_flow = get_non_empty_content(air_flow)\n",
223 | "phonation = get_non_empty_content(phonation)\n",
224 | "rounding_and_labialization = get_non_empty_content(rounding_and_labialization)\n",
225 | "syllabicity = get_non_empty_content(syllabicity)\n",
226 | "consonant_release = get_non_empty_content(consonant_release)\n",
227 | "co_articulation = get_non_empty_content(co_articulation)\n",
228 | "tongue_root = get_non_empty_content(tongue_root)\n",
229 | "fortis_and_lenis = get_non_empty_content(fortis_and_lenis)\n",
230 | "# Suprasegmentals\n",
231 | "stress = get_non_empty_content(stress)\n",
232 | "length = get_non_empty_content(length)\n",
233 | "rhythm = get_non_empty_content(rhythm)\n",
234 | "tones = get_non_empty_content(tones)\n",
235 | "intonation = get_non_empty_content(intonation)"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": null,
241 | "metadata": {},
242 | "outputs": [],
243 | "source": [
244 | "# All symbols\n",
245 | "_ipa = (\n",
246 | " consonants\n",
247 | " + consonants_other\n",
248 | " + manner_of_articulation\n",
249 | " + vowels\n",
250 | " + vowels_other\n",
251 | " + articulation\n",
252 | " + air_flow\n",
253 | " + phonation\n",
254 | " + rounding_and_labialization\n",
255 | " + syllabicity\n",
256 | " + consonant_release\n",
257 | " + co_articulation\n",
258 | " + tongue_root\n",
259 | " + fortis_and_lenis\n",
260 | " + stress\n",
261 | " + length\n",
262 | " + rhythm\n",
263 | " + tones\n",
264 | " + intonation\n",
265 | ")\n",
266 | "\n",
267 | "print(_ipa)\n",
268 | "print(len(_ipa))"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": null,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "_ipa = \"\".join(_ipa)\n",
278 | "\n",
279 | "# unique _ipa\n",
280 | "_ipa = set(_ipa)\n",
281 | "_ipa = \"\".join(_ipa)\n",
282 | "\n",
283 | "# sort symbols\n",
284 | "_ipa = sorted(_ipa)\n",
285 | "\n",
286 | "print(f'_ipa = \"{\"\".join(_ipa)}\"')\n",
287 | "print(len(_ipa))"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "metadata": {},
294 | "outputs": [],
295 | "source": [
296 | "_punctuation = ';:,.!?¡¿—…\"«»“” '\n",
297 | "\n",
298 | "symbols = list(_punctuation) + list(_ipa)\n",
299 | "\n",
300 | "# unique symbols\n",
301 | "symbols = set(symbols)\n",
302 | "symbols = \"\".join(symbols)\n",
303 | "\n",
304 | "# sort symbols\n",
305 | "symbols = sorted(symbols)\n",
306 | "\n",
307 | "symbols = \"\".join(symbols)\n",
308 | "\n",
309 | "print(f'symbols = \"{\"\".join(symbols)}\"')\n",
310 | "print(len(symbols))"
311 | ]
312 | }
313 | ],
314 | "metadata": {
315 | "kernelspec": {
316 | "display_name": "py11",
317 | "language": "python",
318 | "name": "python3"
319 | },
320 | "language_info": {
321 | "codemirror_mode": {
322 | "name": "ipython",
323 | "version": 3
324 | },
325 | "file_extension": ".py",
326 | "mimetype": "text/x-python",
327 | "name": "python",
328 | "nbconvert_exporter": "python",
329 | "pygments_lexer": "ipython3",
330 | "version": "3.11.4"
331 | },
332 | "orig_nbformat": 4
333 | },
334 | "nbformat": 4,
335 | "nbformat_minor": 2
336 | }
337 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchaudio
4 | torchtext
5 |
6 | phonemizer
7 | inflect
8 | pandas
9 |
10 | numpy
11 | numba
12 | matplotlib
13 |
14 | tensorboard
15 | tensorboardX
16 |
17 | tqdm
18 | PyYAML
19 | ipykernel
20 | pytorch_lightning
--------------------------------------------------------------------------------
/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from text import cleaners
3 | from torchtext.vocab import Vocab
4 |
5 |
6 | def tokenizer(text: str, vocab: Vocab, cleaner_names: List[str], language="en-us", cleaned_text=False) -> List[int]:
7 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
8 | Args:
9 | text: string to convert to a sequence of IDs
10 | cleaner_names: names of the cleaner functions from text/cleaners.py
11 | language: language ID from https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md
12 | cleaned_text: whether the text has already been cleaned
13 | Returns:
14 |         Output of the final cleaner (token IDs, or a string if the chain ends with a detokenizer)
15 | """
16 | if not cleaned_text:
17 | return _clean_text(text, vocab, cleaner_names, language=language)
18 | else:
19 | return list(map(int, text.split("\t")))
20 |
21 |
22 | def detokenizer(sequence: List[int], vocab: Vocab) -> str:
23 | """Converts a sequence of tokens back to a string"""
24 | return "".join(vocab.lookup_tokens(sequence))
25 |
26 |
27 | def _clean_text(text: str, vocab: Vocab, cleaner_names: List[str], language="en-us") -> str:
28 | for name in cleaner_names:
29 | cleaner = getattr(cleaners, name)
30 | assert callable(cleaner), f"Unknown cleaner: {name}"
31 | text = cleaner(text, vocab=vocab, language=language)
32 | return text
33 |
34 |
35 | if __name__ == "__main__":
36 | from utils.task import load_vocab
37 |
38 | vocab = load_vocab("datasets/ljs_base/vocab.txt")
39 | cleaner_names = ["phonemize_text", "add_spaces", "tokenize_text", "delete_unks", "add_bos_eos", "detokenize_sequence"]
40 | text = "Well, I like pizza. You know … Who doesn't like pizza? "
41 | print(tokenizer(text, vocab, cleaner_names, language="en-us", cleaned_text=False))
42 |
--------------------------------------------------------------------------------
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """
2 | Cleaners are transformations that run over the input text at both training and eval time.
3 |
4 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5 | hyperparameter.
6 | """
7 |
8 | import re
9 | from typing import List
10 | from torchtext.vocab import Vocab
11 | from phonemizer import phonemize
12 | from phonemizer.separator import Separator
13 |
14 |
15 | from text.normalize_numbers import normalize_numbers
16 |
17 | from text.symbols import _punctuation, PAD_ID, UNK_ID, BOS_ID, EOS_ID
18 |
19 |
20 | _whitespace_re = re.compile(r"\s+")
21 | _preserved_symbols_re = re.compile(rf"[{_punctuation}]|<.*?>")
22 | separator = Separator(word="", phone=" ")
23 |
24 |
25 | # ---------------------------------------------------------------------------- #
26 | # | Text cleaners | #
27 | # ---------------------------------------------------------------------------- #
28 | def lowercase(text: str, *args, **kwargs):
29 | return text.lower()
30 |
31 |
32 | def collapse_whitespace(text: str, *args, **kwargs):
33 | return re.sub(_whitespace_re, " ", text)
34 |
35 |
36 | def expand_numbers(text: str, *args, **kwargs):
37 | return normalize_numbers(text)
38 |
39 |
40 | def phonemize_text(text: List[str] | str, *args, language="en-us", **kwargs):
41 | return phonemize(text, language=language, backend="espeak", separator=separator, strip=True, preserve_punctuation=True, punctuation_marks=_preserved_symbols_re, with_stress=True, njobs=8)
42 |
43 |
44 | def add_spaces(text: str, *args, **kwargs):
45 | spaced_text = re.sub(_preserved_symbols_re, r" \g<0> ", text)
46 | cleaned_text = re.sub(_whitespace_re, " ", spaced_text)
47 | return cleaned_text.strip()
48 |
49 |
50 | # ---------------------------------------------------------------------------- #
51 | # | Token cleaners | #
52 | # ---------------------------------------------------------------------------- #
53 |
54 |
55 | def tokenize_text(text: str, vocab: Vocab, *args, **kwargs):
56 | tokens = text.split()
57 | return vocab(tokens)
58 |
59 |
60 | def add_bos_eos(tokens: List[int], *args, **kwargs):
61 | return [BOS_ID] + tokens + [EOS_ID]
62 |
63 |
64 | def add_blank(tokens: List[int], *args, **kwargs):
65 | result = [PAD_ID] * (len(tokens) * 2 + 1)
66 | result[1::2] = tokens
67 | return result
68 |
69 |
70 | def delete_unks(tokens: List[int], *args, **kwargs):
71 | return [token for token in tokens if token != UNK_ID]
72 |
73 |
74 | def detokenize_sequence(sequence: List[int], vocab: Vocab, *args, **kwargs):
75 | return "".join(vocab.lookup_tokens(sequence))
76 |
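As a concrete example of the token cleaners above: `delete_unks` drops out-of-vocabulary IDs, `add_bos_eos` brackets the sequence, and `add_blank` interleaves `PAD_ID` between and around the tokens (the same interspersing trick the original VITS applies when `add_blank` is enabled).

```python
# Worked example with PAD_ID=0, UNK_ID=1, BOS_ID=2, EOS_ID=3 from text/symbols.py;
# the token IDs are made up for illustration.
tokens = [12, 7, 1, 25]          # includes one <unk> (ID 1)
print(delete_unks(tokens))       # [12, 7, 25]
print(add_bos_eos([12, 7, 25]))  # [2, 12, 7, 25, 3]
print(add_blank([12, 7, 25]))    # [0, 12, 0, 7, 0, 25, 0]
```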
--------------------------------------------------------------------------------
/text/normalize_numbers.py:
--------------------------------------------------------------------------------
1 | import inflect
2 | import re
3 |
4 |
5 | _inflect = inflect.engine()
6 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
7 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
8 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
9 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
10 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
11 | _number_re = re.compile(r"[0-9]+")
12 |
13 |
14 | def _remove_commas(m):
15 | return m.group(1).replace(",", "")
16 |
17 |
18 | def _expand_decimal_point(m):
19 | return m.group(1).replace(".", " point ")
20 |
21 |
22 | def _expand_dollars(m):
23 | match = m.group(1)
24 | parts = match.split(".")
25 | if len(parts) > 2:
26 | return match + " dollars" # Unexpected format
27 | dollars = int(parts[0]) if parts[0] else 0
28 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
29 | if dollars and cents:
30 | dollar_unit = "dollar" if dollars == 1 else "dollars"
31 | cent_unit = "cent" if cents == 1 else "cents"
32 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
33 | elif dollars:
34 | dollar_unit = "dollar" if dollars == 1 else "dollars"
35 | return "%s %s" % (dollars, dollar_unit)
36 | elif cents:
37 | cent_unit = "cent" if cents == 1 else "cents"
38 | return "%s %s" % (cents, cent_unit)
39 | else:
40 | return "zero dollars"
41 |
42 |
43 | def _expand_ordinal(m):
44 | return _inflect.number_to_words(m.group(0))
45 |
46 |
47 | def _expand_number(m):
48 | num = int(m.group(0))
49 | if num > 1000 and num < 3000:
50 | if num == 2000:
51 | return "two thousand"
52 | elif num > 2000 and num < 2010:
53 | return "two thousand " + _inflect.number_to_words(num % 100)
54 | elif num % 100 == 0:
55 | return _inflect.number_to_words(num // 100) + " hundred"
56 | else:
57 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
58 | else:
59 | return _inflect.number_to_words(num, andword="")
60 |
61 |
62 | def normalize_numbers(text):
63 | text = re.sub(_comma_number_re, _remove_commas, text)
64 | text = re.sub(_pounds_re, r"\1 pounds", text)
65 | text = re.sub(_dollars_re, _expand_dollars, text)
66 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
67 | text = re.sub(_ordinal_re, _expand_ordinal, text)
68 | text = re.sub(_number_re, _expand_number, text)
69 | return text
70 |
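The regexes above are applied in a fixed order (commas stripped first, then pounds, dollars, decimals, ordinals, and finally bare numbers), so currency amounts are spelled out before the generic number pass sees them. A short illustration; the exact wording comes from the `inflect` package, so treat the expected output as approximate.

```python
# Illustrative only; import normalize_numbers from text.normalize_numbers before running.
print(normalize_numbers("It costs $5, due on the 2nd of May 1999."))
# -> roughly: "It costs five dollars, due on the second of May nineteen ninety-nine."
```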
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """
2 | Set of symbols
3 | """
4 | _punctuation = ';:,.!?¡¿—…"«»“”'
5 |
6 |
7 | """
8 | Special symbols
9 | """
10 | # Define special symbols and indices
11 | special_symbols = ["<pad>", "<unk>", "<bos>", "<eos>"]  # indices match PAD_ID, UNK_ID, BOS_ID, EOS_ID below
12 | PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
13 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import torch
4 | from torch import nn, optim
5 | from torch.nn import functional as F
6 | from torch.utils.data import DataLoader
7 | from torch.utils.tensorboard import SummaryWriter
8 | import torch.multiprocessing as mp
9 | import torch.distributed as dist
10 | from torch.nn.parallel import DistributedDataParallel as DDP
11 | from torch.cuda.amp import autocast, GradScaler
12 | from typing import List
13 |
14 | import utils.task as task
15 | from utils.hparams import get_hparams
16 | from model.models import SynthesizerTrn
17 | from model.discriminator import MultiPeriodDiscriminator
18 | from data_utils import TextAudioLoader, TextAudioCollate, DistributedBucketSampler
19 | from losses import generator_loss, discriminator_loss, feature_loss, kl_loss, kl_loss_normal
20 | from utils.mel_processing import wav_to_mel, spec_to_mel, spectral_norm
21 | from utils.model import slice_segments, clip_grad_value_
22 |
23 |
24 | torch.backends.cudnn.benchmark = True
25 | global_step = 0
26 |
27 |
28 | def main():
29 | """Assume Single Node Multi GPUs Training Only"""
30 | assert torch.cuda.is_available(), "CPU training is not allowed."
31 |
32 | n_gpus = torch.cuda.device_count()
33 | os.environ["MASTER_ADDR"] = "localhost"
34 | os.environ["MASTER_PORT"] = "8000"
35 |
36 | hps = get_hparams()
37 | mp.spawn(
38 | run,
39 | nprocs=n_gpus,
40 | args=(
41 | n_gpus,
42 | hps,
43 | ),
44 | )
45 |
46 |
47 | def run(rank, n_gpus, hps):
48 | global global_step
49 | if rank == 0:
50 | logger = task.get_logger(hps.model_dir)
51 | logger.info(hps)
52 | task.check_git_hash(hps.model_dir)
53 | writer = SummaryWriter(log_dir=hps.model_dir)
54 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
55 |
56 | dist.init_process_group(backend="nccl", init_method="env://", world_size=n_gpus, rank=rank)
57 | torch.manual_seed(hps.train.seed)
58 | torch.cuda.set_device(rank)
59 |
60 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
61 | train_sampler = DistributedBucketSampler(train_dataset, hps.train.batch_size, [32, 300, 400, 500, 600, 700, 800, 900, 1000], num_replicas=n_gpus, rank=rank, shuffle=True)
62 | collate_fn = TextAudioCollate()
63 | train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler)
64 | if rank == 0:
65 | eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data)
66 | eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False, batch_size=hps.train.batch_size, pin_memory=True, drop_last=False, collate_fn=collate_fn)
67 |
68 | net_g = SynthesizerTrn(len(train_dataset.vocab), hps.data.n_mels if hps.data.use_mel else hps.data.n_fft // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model).cuda(rank)
69 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
70 | optim_g = torch.optim.AdamW(net_g.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps)
71 | optim_d = torch.optim.AdamW(net_d.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps)
72 | net_g = DDP(net_g, device_ids=[rank])
73 | net_d = DDP(net_d, device_ids=[rank])
74 |
75 | try:
76 | _, _, _, epoch_str = task.load_checkpoint(task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
77 | _, _, _, epoch_str = task.load_checkpoint(task.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)
78 | global_step = (epoch_str - 1) * len(train_loader)
79 | net_g.module.mas_noise_scale = max(hps.model.mas_noise_scale - global_step * hps.model.mas_noise_scale_decay, 0.0)
80 |     except Exception:  # no checkpoint found; start from scratch
81 | epoch_str = 1
82 | global_step = 0
83 |
84 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # TODO: check
85 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
86 |
87 | scaler = GradScaler(enabled=hps.train.fp16_run)
88 |
89 | for epoch in range(epoch_str, hps.train.epochs + 1):
90 | if rank == 0:
91 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
92 | else:
93 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None)
94 | scheduler_g.step()
95 | scheduler_d.step()
96 |
97 |
98 | def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
99 | net_g, net_d = nets
100 | optim_g, optim_d = optims
101 | scheduler_g, scheduler_d = schedulers
102 | train_loader, eval_loader = loaders
103 | if writers is not None:
104 | writer, writer_eval = writers
105 |
106 | train_loader.batch_sampler.set_epoch(epoch)
107 | global global_step
108 |
109 | net_g.train()
110 | net_d.train()
111 | if rank == 0:
112 | loader = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
113 | else:
114 | loader = train_loader
115 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(loader):
116 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)
117 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
118 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
119 |
120 | with autocast(enabled=hps.train.fp16_run):
121 | (
122 | y_hat,
123 | l_length,
124 | attn,
125 | ids_slice,
126 | x_mask,
127 | z_mask,
128 | (m_p_text, logs_p_text),
129 | (m_p_dur, logs_p_dur, z_q_dur, logs_q_dur),
130 | (m_p_audio, logs_p_audio, m_q_audio, logs_q_audio),
131 | ) = net_g(x, x_lengths, spec, spec_lengths)
132 |
133 | mel = spectral_norm(spec) if hps.data.use_mel else spec_to_mel(spec, hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.f_min, hps.data.f_max)
134 | y_hat_mel = wav_to_mel(y_hat.squeeze(1), hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.hop_length, hps.data.win_length, hps.data.f_min, hps.data.f_max)
135 |
136 | y_mel = slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
137 | y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
138 |
139 | # Discriminator
140 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
141 | with autocast(enabled=False):
142 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
143 | loss_disc_all = loss_disc
144 | optim_d.zero_grad()
145 | scaler.scale(loss_disc_all).backward()
146 | scaler.unscale_(optim_d)
147 | grad_norm_d = clip_grad_value_(net_d.parameters(), None)
148 | scaler.step(optim_d)
149 |
150 | with autocast(enabled=hps.train.fp16_run):
151 | # Generator
152 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
153 | with autocast(enabled=False):
154 | loss_dur = torch.sum(l_length.float())
155 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
156 | loss_gen, losses_gen = generator_loss(y_d_hat_g)
157 |
158 | # TODO Test gain constant
159 | if False:
160 | loss_kl_text = kl_loss_normal(m_q_text, logs_q_text, m_p_text, logs_p_text, x_mask) * hps.train.c_kl_text
161 | loss_kl_dur = kl_loss(z_q_dur, logs_q_dur, m_p_dur, logs_p_dur, z_mask) * hps.train.c_kl_dur
162 | loss_kl_audio = kl_loss_normal(m_p_audio, logs_p_audio, m_q_audio, logs_q_audio, z_mask) * hps.train.c_kl_audio
163 |
164 | loss_fm = feature_loss(fmap_r, fmap_g)
165 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl_dur + loss_kl_audio # TODO + loss_kl_text
166 | optim_g.zero_grad()
167 | scaler.scale(loss_gen_all).backward()
168 | scaler.unscale_(optim_g)
169 | grad_norm_g = clip_grad_value_(net_g.parameters(), None)
170 | scaler.step(optim_g)
171 | scaler.update()
172 |
173 | if rank == 0:
174 | if global_step % hps.train.log_interval == 0:
175 | lr = optim_g.param_groups[0]["lr"]
176 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl_dur, loss_kl_audio] # TODO loss_kl_text
177 | losses_str = " ".join(f"{loss.item():.3f}" for loss in losses)
178 | loader.set_postfix_str(f"{losses_str}, {global_step}, {lr:.9f}")
179 |
180 | # scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
181 | # scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl_dur})
182 |
183 | # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
184 | # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
185 | # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
186 | # image_dict = {
187 | # "slice/mel_org": task.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
188 | # "slice/mel_gen": task.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
189 | # "all/mel": task.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
190 | # "all/attn": task.plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy()),
191 | # }
192 | # task.summarize(writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict, sample_rate=hps.data.sample_rate)
193 |
194 | # Save checkpoint on CPU to prevent GPU OOM
195 | if global_step % hps.train.eval_interval == 0:
196 | # evaluate(hps, net_g, eval_loader, writer_eval)
197 | task.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
198 | task.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
199 | global_step += 1
200 |
201 |
202 | def evaluate(hps, generator, eval_loader, writer_eval):
203 | generator.eval()
204 | with torch.no_grad():
205 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(eval_loader):
206 | x, x_lengths = x.cuda(0), x_lengths.cuda(0)
207 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
208 | y, y_lengths = y.cuda(0), y_lengths.cuda(0)
209 |
210 |             # evaluate only the first utterance of the first batch
211 | x = x[:1]
212 | x_lengths = x_lengths[:1]
213 | spec = spec[:1]
214 | spec_lengths = spec_lengths[:1]
215 | y = y[:1]
216 | y_lengths = y_lengths[:1]
217 | break
218 | y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, max_len=1000)
219 | y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
220 |
221 | mel = spectral_norm(spec) if hps.data.use_mel else spec_to_mel(spec, hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.f_min, hps.data.f_max)
222 | y_hat_mel = wav_to_mel(y_hat.squeeze(1).float(), hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.hop_length, hps.data.win_length, hps.data.f_min, hps.data.f_max)
223 | image_dict = {"gen/mel": task.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())}
224 | audio_dict = {"gen/audio": y_hat[0, :, : y_hat_lengths[0]]}
225 | if global_step == 0:
226 | image_dict.update({"gt/mel": task.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
227 | audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]})
228 |
229 | task.summarize(writer=writer_eval, global_step=global_step, images=image_dict, audios=audio_dict, sample_rate=hps.data.sample_rate)
230 | generator.train()
231 |
232 |
233 | if __name__ == "__main__":
234 | main()
235 |
--------------------------------------------------------------------------------
/train_ms.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import logging
4 | import torch
5 | from torch import nn, optim
6 | from torch.nn import functional as F
7 | from torch.utils.data import DataLoader
8 | from torch.utils.tensorboard import SummaryWriter
9 | import torch.multiprocessing as mp
10 | import torch.distributed as dist
11 | from torch.nn.parallel import DistributedDataParallel as DDP
12 | from torch.cuda.amp import autocast, GradScaler
13 | from typing import List
14 |
15 | import utils.task as task
16 | from utils.hparams import get_hparams
17 | from model.models import SynthesizerTrn
18 | from model.discriminator import MultiPeriodDiscriminator
19 | from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate, DistributedBucketSampler
20 | from losses import generator_loss, discriminator_loss, feature_loss, kl_loss, kl_loss_normal
21 | from utils.mel_processing import wav_to_mel, spec_to_mel, spectral_norm
22 | from utils.model import slice_segments, clip_grad_value_
23 |
24 |
25 | torch.backends.cudnn.benchmark = True
26 | global_step = 0
27 |
28 |
29 | def main():
30 | """Assume Single Node Multi GPUs Training Only"""
31 | assert torch.cuda.is_available(), "CPU training is not allowed."
32 |
33 | n_gpus = torch.cuda.device_count()
34 | os.environ["MASTER_ADDR"] = "localhost"
35 | os.environ["MASTER_PORT"] = "8000"
36 |
37 | hps = get_hparams()
38 | mp.spawn(
39 | run,
40 | nprocs=n_gpus,
41 | args=(
42 | n_gpus,
43 | hps,
44 | ),
45 | )
46 |
47 |
48 | def run(rank, n_gpus, hps):
49 | global global_step
50 | if rank == 0:
51 | logger = task.get_logger(hps.model_dir)
52 | logger.info(hps)
53 | task.check_git_hash(hps.model_dir)
54 | writer = SummaryWriter(log_dir=hps.model_dir)
55 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
56 |
57 | dist.init_process_group(backend="nccl", init_method="env://", world_size=n_gpus, rank=rank)
58 | torch.manual_seed(hps.train.seed)
59 | torch.cuda.set_device(rank)
60 |
61 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
62 | train_sampler = DistributedBucketSampler(train_dataset, hps.train.batch_size, [32, 300, 400, 500, 600, 700, 800, 900, 1000], num_replicas=n_gpus, rank=rank, shuffle=True)
63 | collate_fn = TextAudioSpeakerCollate()
64 | train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler)
65 | if rank == 0:
66 | eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
67 | eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False, batch_size=hps.train.batch_size, pin_memory=True, drop_last=False, collate_fn=collate_fn)
68 |
69 | net_g = SynthesizerTrn(
70 | len(train_dataset.vocab), hps.data.n_mels if hps.data.use_mel else hps.data.n_fft // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model
71 | ).cuda(rank)
72 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
73 | optim_g = torch.optim.AdamW(net_g.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps)
74 | optim_d = torch.optim.AdamW(net_d.parameters(), hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps)
75 | net_g = DDP(net_g, device_ids=[rank])
76 | net_d = DDP(net_d, device_ids=[rank])
77 |
78 | try:
79 | _, _, _, epoch_str = task.load_checkpoint(task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
80 | _, _, _, epoch_str = task.load_checkpoint(task.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)
81 | global_step = (epoch_str - 1) * len(train_loader)
82 | net_g.module.mas_noise_scale = max(hps.model.mas_noise_scale - global_step * hps.model.mas_noise_scale_decay, 0.0)
83 |     except Exception:  # no checkpoint found; start from scratch
84 | epoch_str = 1
85 | global_step = 0
86 |
87 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
88 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
89 |
90 | scaler = GradScaler(enabled=hps.train.fp16_run)
91 |
92 | for epoch in range(epoch_str, hps.train.epochs + 1):
93 | if rank == 0:
94 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
95 | else:
96 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None)
97 | scheduler_g.step()
98 | scheduler_d.step()
99 |
100 |
101 | def train_and_evaluate(rank, epoch, hps, nets: List[torch.nn.parallel.DistributedDataParallel], optims: List[torch.optim.Optimizer], schedulers, scaler: GradScaler, loaders, logger: logging.Logger, writers):
102 | net_g, net_d = nets
103 |
104 | optim_g, optim_d = optims
105 | scheduler_g, scheduler_d = schedulers
106 | train_loader, eval_loader = loaders
107 | if writers is not None:
108 | writer, writer_eval = writers
109 |
110 | train_loader.batch_sampler.set_epoch(epoch)
111 | global global_step
112 |
113 | net_g.train()
114 | net_d.train()
115 | if rank == 0:
116 | loader = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
117 | else:
118 | loader = train_loader
119 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(loader):
120 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)
121 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
122 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
123 | speakers = speakers.cuda(rank, non_blocking=True)
124 |
125 | with autocast(enabled=hps.train.fp16_run):
126 | (
127 | y_hat,
128 | l_length,
129 | attn,
130 | ids_slice,
131 | x_mask,
132 | z_mask,
133 | (m_p_text, logs_p_text),
134 | (m_p_dur, logs_p_dur, z_q_dur, logs_q_dur),
135 | (m_p_audio, logs_p_audio, m_q_audio, logs_q_audio),
136 | ) = net_g(x, x_lengths, spec, spec_lengths, speakers)
137 |
138 | mel = spectral_norm(spec) if hps.data.use_mel else spec_to_mel(spec, hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.f_min, hps.data.f_max)
139 | y_hat_mel = wav_to_mel(y_hat.squeeze(1), hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.hop_length, hps.data.win_length, hps.data.f_min, hps.data.f_max)
140 |
141 | y_mel = slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
142 | y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
143 |
144 | # Discriminator
145 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
146 | with autocast(enabled=False):
147 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
148 | loss_disc_all = loss_disc
149 | optim_d.zero_grad()
150 | scaler.scale(loss_disc_all).backward()
151 | scaler.unscale_(optim_d)
152 | grad_norm_d = clip_grad_value_(net_d.parameters(), None)
153 | scaler.step(optim_d)
154 |
155 | with autocast(enabled=hps.train.fp16_run):
156 | # Generator
157 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
158 | with autocast(enabled=False):
159 | loss_dur = torch.sum(l_length.float())
160 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
161 | loss_gen, losses_gen = generator_loss(y_d_hat_g)
162 |
163 | # TODO Test gain constant
164 | if False:
165 | loss_kl_text = kl_loss_normal(m_q_text, logs_q_text, m_p_text, logs_p_text, x_mask) * hps.train.c_kl_text
166 | loss_kl_dur = kl_loss(z_q_dur, logs_q_dur, m_p_dur, logs_p_dur, z_mask) * hps.train.c_kl_dur
167 | loss_kl_audio = kl_loss_normal(m_p_audio, logs_p_audio, m_q_audio, logs_q_audio, z_mask) * hps.train.c_kl_audio
168 |
169 | loss_fm = feature_loss(fmap_r, fmap_g)
170 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl_dur + loss_kl_audio # TODO + loss_kl_text
171 | optim_g.zero_grad()
172 | scaler.scale(loss_gen_all).backward()
173 | scaler.unscale_(optim_g)
174 | grad_norm_g = clip_grad_value_(net_g.parameters(), None)
175 | scaler.step(optim_g)
176 | scaler.update()
177 |
178 | if rank == 0:
179 | if global_step % hps.train.log_interval == 0:
180 | lr = optim_g.param_groups[0]["lr"]
181 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl_dur, loss_kl_audio] # TODO loss_kl_text
182 | losses_str = " ".join(f"{loss.item():.3f}" for loss in losses)
183 | loader.set_postfix_str(f"{losses_str}, {global_step}, {lr:.9f}")
184 |
185 | # scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
186 | # scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl_dur})
187 |
188 | # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
189 | # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
190 | # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
191 | # image_dict = {
192 | # "slice/mel_org": task.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
193 | # "slice/mel_gen": task.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
194 | # "all/mel": task.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
195 | # "all/attn": task.plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy()),
196 | # }
197 | # task.summarize(writer=writer, global_step=global_step, images=image_dict, scalars=scalar_dict, sample_rate=hps.data.sample_rate)
198 |
199 | # Save checkpoint on CPU to prevent GPU OOM
200 | if global_step % hps.train.eval_interval == 0:
201 | # evaluate(hps, net_g, eval_loader, writer_eval)
202 | task.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
203 | task.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
204 | global_step += 1
205 |
206 |
207 | def evaluate(hps, generator, eval_loader, writer_eval):
208 | generator.eval()
209 | with torch.no_grad():
210 | for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(eval_loader):
211 | x, x_lengths = x.cuda(0), x_lengths.cuda(0)
212 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
213 | y, y_lengths = y.cuda(0), y_lengths.cuda(0)
214 | speakers = speakers.cuda(0)
215 |
216 |             # evaluate only the first utterance of the first batch
217 | x = x[:1]
218 | x_lengths = x_lengths[:1]
219 | spec = spec[:1]
220 | spec_lengths = spec_lengths[:1]
221 | y = y[:1]
222 | y_lengths = y_lengths[:1]
223 | speakers = speakers[:1]
224 | break
225 | y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, max_len=1000)
226 | y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
227 |
228 | mel = spectral_norm(spec) if hps.data.use_mel else spec_to_mel(spec, hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.f_min, hps.data.f_max)
229 | y_hat_mel = wav_to_mel(y_hat.squeeze(1).float(), hps.data.n_fft, hps.data.n_mels, hps.data.sample_rate, hps.data.hop_length, hps.data.win_length, hps.data.f_min, hps.data.f_max)
230 | image_dict = {"gen/mel": task.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())}
231 | audio_dict = {"gen/audio": y_hat[0, :, : y_hat_lengths[0]]}
232 | if global_step == 0:
233 | image_dict.update({"gt/mel": task.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
234 | audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]})
235 |
236 | task.summarize(writer=writer_eval, global_step=global_step, images=image_dict, audios=audio_dict, sample_rate=hps.data.sample_rate)
237 | generator.train()
238 |
239 |
240 | if __name__ == "__main__":
241 | main()
242 |
--------------------------------------------------------------------------------
/utils/hparams.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import argparse
5 | import os
6 | import yaml
7 |
8 |
9 | class HParams:
10 | def __init__(self, **kwargs):
11 | for k, v in kwargs.items():
12 |             if isinstance(v, dict):
13 | v = HParams(**v)
14 | self[k] = v
15 |
16 | def keys(self):
17 | return self.__dict__.keys()
18 |
19 | def items(self):
20 | return self.__dict__.items()
21 |
22 | def values(self):
23 | return self.__dict__.values()
24 |
25 | def __len__(self):
26 | return len(self.__dict__)
27 |
28 | def __getitem__(self, key):
29 | return getattr(self, key)
30 |
31 | def __setitem__(self, key, value):
32 | return setattr(self, key, value)
33 |
34 | def __contains__(self, key):
35 | return key in self.__dict__
36 |
37 | def __repr__(self):
38 | return self.__dict__.__repr__()
39 |
40 |
41 | def get_hparams() -> HParams:
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("-c", "--config", type=str, default="./datasets/base/config.yaml", help="YAML file for configuration")
44 | parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
45 | args = parser.parse_args()
46 |
47 | # assert that the project root contains "datasets" and that "./datasets/<model>/config.yaml" exists
48 | assert os.path.exists("./datasets"), "`datasets` directory not found, navigate to the root of the project."
49 | assert os.path.exists(f"./datasets/{args.model}"), f"`{args.model}` not found in `./datasets/`"
50 | assert os.path.exists(f"./datasets/{args.model}/config.yaml"), f"`config.yaml` not found in `./datasets/{args.model}/`"
51 |
52 | model_dir = f"./datasets/{args.model}/logs"
53 | if not os.path.exists(model_dir):
54 | os.makedirs(model_dir)
55 |
56 | config_path = args.config or f"./datasets/{args.model}/config.yaml"
57 | hparams = get_hparams_from_file(config_path)
58 | hparams.model_dir = model_dir
59 | return hparams
60 |
61 |
62 | def get_hparams_from_file(config_path: str) -> HParams:
63 | with open(config_path, "r") as f:
64 | data = f.read()
65 | config = yaml.safe_load(data)
66 |
67 | hparams = HParams(**config)
68 | return hparams
69 |
70 |
71 | if __name__ == "__main__":
72 | hparams = get_hparams()
73 | print(hparams)
74 |
--------------------------------------------------------------------------------
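The `HParams` container above turns a nested YAML config into attribute-style access. A small illustration, run from the project root; the keys are made up for the example, real ones live in each dataset's `config.yaml`:

```python
import yaml

from utils.hparams import HParams

# hypothetical config snippet, only for illustration
raw = yaml.safe_load("""
train:
  learning_rate: 2.0e-4
  eval_interval: 1000
data:
  sample_rate: 22050
""")

hps = HParams(**raw)
print(hps.train.learning_rate)   # 0.0002 -- nested dicts become nested HParams
print("data" in hps)             # True   -- __contains__ checks top-level keys
print(hps["data"].sample_rate)   # 22050  -- item access falls back to getattr
```
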
/utils/mel_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio.transforms as T
3 | import torch.utils.data
4 |
5 | spectrogram_basis = {}
6 | mel_scale_basis = {}
7 | mel_spectrogram_basis = {}
8 |
9 |
10 | def spectral_norm(x: torch.Tensor, clip_val=1e-9):
11 | return torch.log(torch.clamp(x, min=clip_val))
12 |
13 |
14 | def wav_to_spec(y: torch.Tensor, n_fft, sample_rate, hop_length, win_length, center=False) -> torch.Tensor:
15 | assert torch.min(y) >= -1.0, f"min value is {torch.min(y)}"
16 | assert torch.max(y) <= 1.0, f"max value is {torch.max(y)}"
17 |
18 | global spectrogram_basis
19 | dtype_device = str(y.dtype) + "_" + str(y.device)
20 | hparams = dtype_device + "_" + str(n_fft) + "_" + str(hop_length)
21 | if hparams not in spectrogram_basis:
22 | spectrogram_basis[hparams] = T.Spectrogram(
23 | n_fft=n_fft,
24 | win_length=win_length,
25 | hop_length=hop_length,
26 | pad=(n_fft - hop_length) // 2,
27 | power=1,
28 | center=center,
29 | ).to(device=y.device, dtype=y.dtype)
30 |
31 | spec = spectrogram_basis[hparams](y)
32 | spec = torch.sqrt(spec.pow(2) + 1e-6)
33 | return spec
34 |
35 |
36 | def spec_to_mel(spec: torch.Tensor, n_fft, n_mels, sample_rate, f_min, f_max, norm=True) -> torch.Tensor:
37 | global mel_scale_basis
38 | dtype_device = str(spec.dtype) + "_" + str(spec.device)
39 | hparams = dtype_device + "_" + str(n_fft) + "_" + str(n_mels) + "_" + str(f_max)
40 | if hparams not in mel_scale_basis:
41 | mel_scale_basis[hparams] = T.MelScale(n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, n_stft=n_fft // 2 + 1, norm="slaney", mel_scale="slaney").to(device=spec.device, dtype=spec.dtype)
42 |
43 | mel = torch.matmul(mel_scale_basis[hparams].fb.T, spec)
44 | if norm:
45 | mel = spectral_norm(mel)
46 | return mel
47 |
48 |
49 | def wav_to_mel(y: torch.Tensor, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False, norm=True) -> torch.Tensor:
50 | assert torch.min(y) >= -1.0, f"min value is {torch.min(y)}"
51 | assert torch.max(y) <= 1.0, f"max value is {torch.max(y)}"
52 |
53 | global mel_spectrogram_basis
54 | dtype_device = str(y.dtype) + "_" + str(y.device)
55 | hparams = dtype_device + "_" + str(n_fft) + "_" + str(num_mels) + "_" + str(hop_size) + "_" + str(fmax)
56 | if hparams not in mel_spectrogram_basis:
57 | mel_spectrogram_basis[hparams] = T.MelSpectrogram(
58 | sample_rate=sampling_rate,
59 | n_fft=n_fft,
60 | win_length=win_size,
61 | hop_length=hop_size,
62 | n_mels=num_mels,
63 | f_min=fmin,
64 | f_max=fmax,
65 | pad=(n_fft - hop_size) // 2,
66 | power=1,
67 | center=center,
68 | norm="slaney",
69 | mel_scale="slaney",
70 | ).to(device=y.device, dtype=y.dtype)
71 |
72 | mel = mel_spectrogram_basis[hparams](y)
73 | if norm:
74 | mel = spectral_norm(mel)
75 | return mel
76 |
--------------------------------------------------------------------------------
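The helpers in `utils/mel_processing.py` above are meant to agree with each other: `wav_to_mel` should match `spec_to_mel(wav_to_spec(...))` up to the small `1e-6` floor added to the linear magnitudes. A rough sanity-check sketch; the STFT/mel parameters are illustrative rather than read from a config:

```python
import torch

from utils.mel_processing import spec_to_mel, wav_to_mel, wav_to_spec

y = torch.rand(1, 22050) * 2 - 1            # 1 s of fake audio in [-1, 1]
n_fft, hop, win, sr = 1024, 256, 1024, 22050
n_mels, f_min, f_max = 80, 0, None          # f_max=None lets torchaudio default to sr // 2

spec = wav_to_spec(y, n_fft, sr, hop, win)                    # [1, n_fft // 2 + 1, frames]
mel_a = spec_to_mel(spec, n_fft, n_mels, sr, f_min, f_max)    # [1, 80, frames], log-compressed
mel_b = wav_to_mel(y, n_fft, n_mels, sr, hop, win, f_min, f_max)
print(spec.shape, mel_a.shape, torch.allclose(mel_a, mel_b, atol=1e-3))
```
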
/utils/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | def init_weights(m, mean=0.0, std=0.01):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | m.weight.data.normal_(mean, std)
12 |
13 |
14 | def get_padding(kernel_size, dilation=1):
15 | return int((kernel_size * dilation - dilation) / 2)
16 |
17 |
18 | def intersperse(lst, item):
19 | result = [item] * (len(lst) * 2 + 1)
20 | result[1::2] = lst
21 | return result
22 |
23 |
24 | # TODO remove this
25 | def kl_divergence(m_p, logs_p, m_q, logs_q):
26 | """KL(P||Q)"""
27 | kl = (logs_q - logs_p) - 0.5
28 | kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
29 | return kl
30 |
31 |
32 | # TODO remove this
33 | def rand_gumbel(shape):
34 | """Sample from the Gumbel distribution, protect from overflows."""
35 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
36 | return -torch.log(-torch.log(uniform_samples))
37 |
38 |
39 | # TODO remove this
40 | def rand_gumbel_like(x):
41 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
42 | return g
43 |
44 |
45 | def slice_segments(x, ids_str, segment_size=4):
46 | ret = torch.zeros_like(x[:, :, :segment_size])
47 | for i in range(x.size(0)):
48 | idx_str = ids_str[i]
49 | idx_end = idx_str + segment_size
50 | ret[i] = x[i, :, idx_str:idx_end]
51 | return ret
52 |
53 |
54 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
55 | b, d, t = x.size()
56 | if x_lengths is None:
57 | x_lengths = t
58 | ids_str_max = x_lengths - segment_size + 1
59 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
60 | ret = slice_segments(x, ids_str, segment_size)
61 | return ret, ids_str
62 |
63 |
64 | # TODO remove this
65 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
66 | position = torch.arange(length, dtype=torch.float)
67 | num_timescales = channels // 2
68 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
69 | inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
70 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
71 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
72 | signal = F.pad(signal, [0, 0, 0, channels % 2])
73 | signal = signal.view(1, channels, length)
74 | return signal
75 |
76 |
77 | # TODO remove this
78 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
79 | b, channels, length = x.size()
80 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
81 | return x + signal.to(dtype=x.dtype, device=x.device)
82 |
83 |
84 | # TODO remove this
85 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
86 | b, channels, length = x.size()
87 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
88 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
89 |
90 |
91 | # TODO remove this
92 | def subsequent_mask(length):
93 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
94 | return mask
95 |
96 |
97 | @torch.jit.script
98 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
99 | n_channels_int = n_channels[0]
100 | in_act = input_a + input_b
101 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
102 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
103 | acts = t_act * s_act
104 | return acts
105 |
106 |
107 | def convert_pad_shape(pad_shape):
108 | l = pad_shape[::-1]
109 | pad_shape = [item for sublist in l for item in sublist]
110 | return pad_shape
111 |
112 |
113 | # TODO remove this
114 | def shift_1d(x):
115 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
116 | return x
117 |
118 |
119 | def sequence_mask(length: torch.Tensor, max_length=None) -> torch.Tensor:
120 | if max_length is None:
121 | max_length = length.max()
122 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
123 | return x.unsqueeze(0) < length.unsqueeze(1)
124 |
125 |
126 | def clip_grad_value_(parameters, clip_value, norm_type=2):
127 | if isinstance(parameters, torch.Tensor):
128 | parameters = [parameters]
129 | parameters = list(filter(lambda p: p.grad is not None, parameters))
130 | norm_type = float(norm_type)
131 | if clip_value is not None:
132 | clip_value = float(clip_value)
133 |
134 | total_norm = 0
135 | for p in parameters:
136 | param_norm = p.grad.data.norm(norm_type)
137 | total_norm += param_norm.item() ** norm_type
138 | if clip_value is not None:
139 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
140 | total_norm = total_norm ** (1.0 / norm_type)
141 | return total_norm
142 |
--------------------------------------------------------------------------------
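A small illustration of the masking and slicing helpers in `utils/model.py` above, which the training code uses to build padding masks and to cut random latent/waveform segments; the shapes are arbitrary:

```python
import torch

from utils.model import intersperse, rand_slice_segments, sequence_mask

lengths = torch.tensor([5, 3])
print(sequence_mask(lengths))          # [2, 5] boolean mask, True where t < length

x = torch.randn(2, 4, 10)              # [batch, channels, time]
segments, ids_str = rand_slice_segments(x, torch.tensor([10, 8]), segment_size=4)
print(segments.shape, ids_str)         # torch.Size([2, 4, 4]) plus the random start indices

print(intersperse([1, 2, 3], 0))       # [0, 1, 0, 2, 0, 3, 0] -- blank token between symbols
```
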
/utils/monotonic_align.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | import numba
5 | import numpy as np
6 | from numba import cuda
7 |
8 | from utils.model import sequence_mask, convert_pad_shape
9 |
10 |
11 | # * Ready and Tested
12 | def search_path(z_p, m_p, logs_p, x_mask, y_mask, mas_noise_scale=0.01):
13 | with torch.no_grad():
14 | o_scale = torch.exp(-2 * logs_p) # [b, d, t]
15 | logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t]
16 | logp2 = torch.matmul(-0.5 * (z_p**2).mT, o_scale) # [b, t', d] x [b, d, t] = [b, t', t]
17 | logp3 = torch.matmul(z_p.mT, (m_p * o_scale)) # [b, t', d] x [b, d, t] = [b, t', t]
18 | logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1], keepdim=True) # [b, 1, t]
19 | logp = logp1 + logp2 + logp3 + logp4 # [b, t', t]
20 |
21 | if mas_noise_scale > 0.0:
22 | epsilon = torch.std(logp) * torch.randn_like(logp) * mas_noise_scale
23 | logp = logp + epsilon
24 |
25 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) # [b, 1, t] * [b, t', 1] = [b, t', t]
26 | attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t', t] maximum_path_cuda
27 | return attn
28 |
29 |
30 | def generate_path(duration: torch.Tensor, mask: torch.Tensor):
31 | """
32 | duration: [b, 1, t_x]
33 | mask: [b, 1, t_y, t_x]
34 | """
35 | b, _, t_y, t_x = mask.shape
36 | cum_duration = torch.cumsum(duration, -1)
37 |
38 | cum_duration_flat = cum_duration.view(b * t_x)
39 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
40 | path = path.view(b, t_x, t_y)
41 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
42 | path = path.unsqueeze(1).mT * mask
43 | return path
44 |
45 |
46 | # ! ----------------------------- CUDA monotonic_align.py -----------------------------
47 |
48 |
49 | # TODO test for the optimal blockspergrid and threadsperblock values
50 | def maximum_path_cuda(neg_cent: torch.Tensor, mask: torch.Tensor):
51 | """CUDA optimized version.
52 | neg_cent: [b, t_t, t_s]
53 | mask: [b, t_t, t_s]
54 | """
55 | device = neg_cent.device
56 | dtype = neg_cent.dtype
57 |
58 | neg_cent_device = cuda.as_cuda_array(neg_cent)
59 | path_device = cuda.to_device(np.zeros(neg_cent.shape, dtype=np.int32))  # zero-init so off-path entries are 0
60 | t_t_max_device = cuda.as_cuda_array(mask.sum(1, dtype=torch.int32)[:, 0])
61 | t_s_max_device = cuda.as_cuda_array(mask.sum(2, dtype=torch.int32)[:, 0])
62 |
63 | blockspergrid = neg_cent.shape[0]
64 | threadsperblock = max(neg_cent.shape[1], neg_cent.shape[2])
65 |
66 | maximum_path_cuda_jit[blockspergrid, threadsperblock](path_device, neg_cent_device, t_t_max_device, t_s_max_device)
67 |
68 | # Convert device array back to tensor
69 | path = torch.as_tensor(path_device.copy_to_host(), device=device, dtype=dtype)
70 | return path
71 |
72 |
73 | @cuda.jit("void(int32[:,:,:], float32[:,:,:], int32[:], int32[:])")
74 | def maximum_path_cuda_jit(paths, values, t_ys, t_xs):
75 | max_neg_val = -1e9
76 | i = cuda.grid(1)
77 | if i >= paths.shape[0]: # exit if the thread is out of the index range
78 | return
79 |
80 | path = paths[i]
81 | value = values[i]
82 | t_y = t_ys[i]
83 | t_x = t_xs[i]
84 |
85 | v_prev = v_cur = 0.0
86 | index = t_x - 1
87 |
88 | for y in range(t_y):
89 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
90 | v_cur = value[y - 1, x] if x != y else max_neg_val
91 | v_prev = value[y - 1, x - 1] if x != 0 else (0.0 if y == 0 else max_neg_val)
92 | value[y, x] += max(v_prev, v_cur)
93 |
94 | for y in range(t_y - 1, -1, -1):
95 | path[y, index] = 1
96 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
97 | index = index - 1
98 | cuda.syncthreads()
99 |
100 |
101 | # ! ------------------------------- CPU monotonic_align.py -------------------------------
102 |
103 |
104 | def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor):
105 | """numba optimized version.
106 | neg_cent: [b, t_t, t_s]
107 | mask: [b, t_t, t_s]
108 | """
109 | device = neg_cent.device
110 | dtype = neg_cent.dtype
111 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
112 | path = np.zeros(neg_cent.shape, dtype=np.int32)
113 |
114 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
115 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
116 | maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
117 | return torch.from_numpy(path).to(device=device, dtype=dtype)
118 |
119 |
120 | @numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]), nopython=True, nogil=True)
121 | def maximum_path_jit(paths, values, t_ys, t_xs):
122 | b = paths.shape[0]
123 | max_neg_val = -1e9
124 | for i in range(int(b)):
125 | path = paths[i]
126 | value = values[i]
127 | t_y = t_ys[i]
128 | t_x = t_xs[i]
129 |
130 | v_prev = v_cur = 0.0
131 | index = t_x - 1
132 |
133 | for y in range(t_y):
134 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
135 | if x == y:
136 | v_cur = max_neg_val
137 | else:
138 | v_cur = value[y - 1, x]
139 | if x == 0:
140 | if y == 0:
141 | v_prev = 0.0
142 | else:
143 | v_prev = max_neg_val
144 | else:
145 | v_prev = value[y - 1, x - 1]
146 | value[y, x] += max(v_prev, v_cur)
147 |
148 | for y in range(t_y - 1, -1, -1):
149 | path[y, index] = 1
150 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
151 | index = index - 1
152 |
--------------------------------------------------------------------------------
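`maximum_path` above takes per-pair log-likelihood scores of shape `[b, t_frames, t_text]` together with a mask of the same shape and returns a hard 0/1 monotonic alignment. A toy call with random scores and no padding:

```python
import torch

from utils.monotonic_align import maximum_path

b, t_frames, t_text = 2, 7, 4
neg_cent = torch.randn(b, t_frames, t_text)   # alignment scores: [b, t_t, t_s]
mask = torch.ones(b, t_frames, t_text)        # no padding in this toy example

attn = maximum_path(neg_cent, mask)           # hard alignment, same shape as neg_cent
print(attn.shape, attn.sum(1))                # each text position is assigned at least one frame
```
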
/utils/task.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import logging
5 | import subprocess
6 | import numpy as np
7 | import torch
8 | import torchaudio
9 |
10 | MATPLOTLIB_FLAG = False
11 |
12 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
13 | logging.getLogger("numba").setLevel(logging.WARNING)
14 | logger = logging
15 |
16 |
17 | def load_checkpoint(checkpoint_path, model, optimizer=None):
18 | assert os.path.isfile(checkpoint_path)
19 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
20 | iteration = checkpoint_dict["iteration"]
21 | learning_rate = checkpoint_dict["learning_rate"]
22 | if optimizer is not None:
23 | optimizer.load_state_dict(checkpoint_dict["optimizer"])
24 | saved_state_dict = checkpoint_dict["model"]
25 | if hasattr(model, "module"):
26 | state_dict = model.module.state_dict()
27 | else:
28 | state_dict = model.state_dict()
29 | new_state_dict = {}
30 | for k, v in state_dict.items():
31 | try:
32 | new_state_dict[k] = saved_state_dict[k]
33 | except KeyError:
34 | logger.info("%s is not in the checkpoint" % k)
35 | new_state_dict[k] = v
36 | if hasattr(model, "module"):
37 | model.module.load_state_dict(new_state_dict)
38 | else:
39 | model.load_state_dict(new_state_dict)
40 | logger.info("Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration))
41 | del checkpoint_dict
42 | torch.cuda.empty_cache()
43 | return model, optimizer, learning_rate, iteration
44 |
45 |
46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
47 | logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path))
48 | if hasattr(model, "module"):
49 | state_dict = model.module.state_dict()
50 | else:
51 | state_dict = model.state_dict()
52 | torch.save({"model": state_dict, "iteration": iteration, "optimizer": optimizer.state_dict(), "learning_rate": learning_rate}, checkpoint_path)
53 |
54 |
55 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, sample_rate=22050):
56 | for k, v in scalars.items():
57 | writer.add_scalar(k, v, global_step)
58 | for k, v in histograms.items():
59 | writer.add_histogram(k, v, global_step)
60 | for k, v in images.items():
61 | writer.add_image(k, v, global_step, dataformats="HWC")
62 | for k, v in audios.items():
63 | writer.add_audio(k, v, global_step, sample_rate)
64 |
65 |
66 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
67 | f_list = glob.glob(os.path.join(dir_path, regex))
68 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
69 | x = f_list[-1]
70 | print(x)
71 | return x
72 |
73 |
74 | def plot_spectrogram_to_numpy(spectrogram):
75 | global MATPLOTLIB_FLAG
76 | if not MATPLOTLIB_FLAG:
77 | import matplotlib
78 |
79 | matplotlib.use("Agg")
80 | MATPLOTLIB_FLAG = True
81 | mpl_logger = logging.getLogger("matplotlib")
82 | mpl_logger.setLevel(logging.WARNING)
83 | import matplotlib.pylab as plt
84 | import numpy as np
85 |
86 | fig, ax = plt.subplots(figsize=(10, 2))
87 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
88 | plt.colorbar(im, ax=ax)
89 | plt.xlabel("Frames")
90 | plt.ylabel("Channels")
91 | plt.tight_layout()
92 |
93 | fig.canvas.draw()
94 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
95 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
96 | plt.close()
97 | return data
98 |
99 |
100 | def plot_alignment_to_numpy(alignment, info=None):
101 | global MATPLOTLIB_FLAG
102 | if not MATPLOTLIB_FLAG:
103 | import matplotlib
104 |
105 | matplotlib.use("Agg")
106 | MATPLOTLIB_FLAG = True
107 | mpl_logger = logging.getLogger("matplotlib")
108 | mpl_logger.setLevel(logging.WARNING)
109 | import matplotlib.pylab as plt
110 | import numpy as np
111 |
112 | fig, ax = plt.subplots(figsize=(6, 4))
113 | im = ax.imshow(alignment.transpose(), aspect="auto", origin="lower", interpolation="none")
114 | fig.colorbar(im, ax=ax)
115 | xlabel = "Decoder timestep"
116 | if info is not None:
117 | xlabel += "\n\n" + info
118 | plt.xlabel(xlabel)
119 | plt.ylabel("Encoder timestep")
120 | plt.tight_layout()
121 |
122 | fig.canvas.draw()
123 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
124 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
125 | plt.close()
126 | return data
127 |
128 |
129 | def load_vocab(vocab_file: str):
130 | """Load vocabulary from text file
131 | Args:
132 | vocab_file (str): Path to vocabulary file
133 | Returns:
134 | torchtext.vocab.Vocab: Vocabulary object
135 | """
136 | from torchtext.vocab import vocab as transform_vocab
137 | from text.symbols import UNK_ID, special_symbols
138 |
139 | vocab = {}
140 | with open(vocab_file, "r") as f:
141 | for line in f:
142 | token, index = line.split()
143 | vocab[token] = int(index)
144 | vocab = transform_vocab(vocab, specials=special_symbols)
145 | vocab.set_default_index(UNK_ID)
146 | return vocab
147 |
148 |
149 | def save_vocab(vocab, vocab_file: str):
150 | """Save vocabulary as token index pairs in a text file, sorted by the indices
151 | Args:
152 | vocab (torchtext.vocab.Vocab): Vocabulary object
153 | vocab_file (str): Path to vocabulary file
154 | """
155 | with open(vocab_file, "w") as f:
156 | for token, index in sorted(vocab.get_stoi().items(), key=lambda kv: kv[1]):
157 | f.write(f"{token}\t{index}\n")
158 |
159 |
160 | def load_wav_to_torch(full_path):
161 | """Load wav file
162 | Args:
163 | full_path (str): Full path of the wav file
164 |
165 | Returns:
166 | waveform (torch.FloatTensor): Audio signal [channels, time] in range [-1, 1]
167 | sample_rate (int): Sampling rate of audio signal (Hz)
168 | """
169 | waveform, sample_rate = torchaudio.load(full_path)
170 | return waveform, sample_rate
171 |
172 |
173 | def load_filepaths_and_text(filename, split="|"):
174 | with open(filename, encoding="utf-8") as f:
175 | filepaths_and_text = [line.strip().split(split) for line in f]
176 | return filepaths_and_text
177 |
178 |
179 | def check_git_hash(model_dir):
180 | source_dir = os.path.dirname(os.path.realpath(__file__))
181 | if not os.path.exists(os.path.join(source_dir, ".git")):
182 | logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(source_dir))
183 | return
184 |
185 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
186 |
187 | path = os.path.join(model_dir, "githash")
188 | if os.path.exists(path):
189 | saved_hash = open(path).read()
190 | if saved_hash != cur_hash:
191 | logger.warning("git hash values are different. {}(saved) != {}(current)".format(saved_hash[:8], cur_hash[:8]))
192 | else:
193 | open(path, "w").write(cur_hash)
194 |
195 |
196 | def get_logger(model_dir, filename="train.log"):
197 | global logger
198 | logger = logging.getLogger(os.path.basename(model_dir))
199 | logger.setLevel(logging.DEBUG)
200 |
201 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
202 | if not os.path.exists(model_dir):
203 | os.makedirs(model_dir)
204 | h = logging.FileHandler(os.path.join(model_dir, filename))
205 | h.setLevel(logging.DEBUG)
206 | h.setFormatter(formatter)
207 | logger.addHandler(h)
208 | return logger
209 |
--------------------------------------------------------------------------------
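`task.summarize` above is a thin wrapper around a TensorBoard `SummaryWriter`: scalars, images (HWC numpy arrays), and audio tensors of shape `[1, samples]` are all logged against the same global step. A minimal sketch with a made-up log directory:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

from utils import task

writer = SummaryWriter(log_dir="./datasets/ljs_base/logs/example")  # hypothetical path
scalars = {"loss/g/total": 1.23, "learning_rate": 2e-4}
audios = {"gen/audio": torch.zeros(1, 22050)}                       # one second of silence
task.summarize(writer=writer, global_step=0, scalars=scalars, audios=audios, sample_rate=22050)
writer.close()
```
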
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | DEFAULT_MIN_BIN_WIDTH = 1e-3
8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
9 | DEFAULT_MIN_DERIVATIVE = 1e-3
10 |
11 |
12 | def piecewise_rational_quadratic_transform(
13 | inputs,
14 | unnormalized_widths,
15 | unnormalized_heights,
16 | unnormalized_derivatives,
17 | inverse=False,
18 | tails=None,
19 | tail_bound=1.0,
20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22 | min_derivative=DEFAULT_MIN_DERIVATIVE,
23 | ):
24 | if tails is None:
25 | spline_fn = rational_quadratic_spline
26 | spline_kwargs = {}
27 | else:
28 | spline_fn = unconstrained_rational_quadratic_spline
29 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30 |
31 | outputs, logabsdet = spline_fn(
32 | inputs=inputs,
33 | unnormalized_widths=unnormalized_widths,
34 | unnormalized_heights=unnormalized_heights,
35 | unnormalized_derivatives=unnormalized_derivatives,
36 | inverse=inverse,
37 | min_bin_width=min_bin_width,
38 | min_bin_height=min_bin_height,
39 | min_derivative=min_derivative,
40 | **spline_kwargs
41 | )
42 | return outputs, logabsdet
43 |
44 |
45 | def searchsorted(bin_locations, inputs, eps=1e-6):
46 | bin_locations[..., -1] += eps
47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48 |
49 |
50 | def unconstrained_rational_quadratic_spline(
51 | inputs,
52 | unnormalized_widths,
53 | unnormalized_heights,
54 | unnormalized_derivatives,
55 | inverse=False,
56 | tails="linear",
57 | tail_bound=1.0,
58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60 | min_derivative=DEFAULT_MIN_DERIVATIVE,
61 | ):
62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63 | outside_interval_mask = ~inside_interval_mask
64 |
65 | outputs = torch.zeros_like(inputs)
66 | logabsdet = torch.zeros_like(inputs)
67 |
68 | if tails == "linear":
69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70 | constant = np.log(np.exp(1 - min_derivative) - 1)
71 | unnormalized_derivatives[..., 0] = constant
72 | unnormalized_derivatives[..., -1] = constant
73 |
74 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
75 | logabsdet[outside_interval_mask] = 0
76 | else:
77 | raise RuntimeError("{} tails are not implemented.".format(tails))
78 |
79 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
80 | inputs=inputs[inside_interval_mask],
81 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
82 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
83 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
84 | inverse=inverse,
85 | left=-tail_bound,
86 | right=tail_bound,
87 | bottom=-tail_bound,
88 | top=tail_bound,
89 | min_bin_width=min_bin_width,
90 | min_bin_height=min_bin_height,
91 | min_derivative=min_derivative,
92 | )
93 |
94 | return outputs, logabsdet
95 |
96 |
97 | def rational_quadratic_spline(
98 | inputs,
99 | unnormalized_widths,
100 | unnormalized_heights,
101 | unnormalized_derivatives,
102 | inverse=False,
103 | left=0.0,
104 | right=1.0,
105 | bottom=0.0,
106 | top=1.0,
107 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
108 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
109 | min_derivative=DEFAULT_MIN_DERIVATIVE,
110 | ):
111 | if torch.min(inputs) < left or torch.max(inputs) > right:
112 | raise ValueError("Input to a transform is not within its domain")
113 |
114 | num_bins = unnormalized_widths.shape[-1]
115 |
116 | if min_bin_width * num_bins > 1.0:
117 | raise ValueError("Minimal bin width too large for the number of bins")
118 | if min_bin_height * num_bins > 1.0:
119 | raise ValueError("Minimal bin height too large for the number of bins")
120 |
121 | widths = F.softmax(unnormalized_widths, dim=-1)
122 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
123 | cumwidths = torch.cumsum(widths, dim=-1)
124 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
125 | cumwidths = (right - left) * cumwidths + left
126 | cumwidths[..., 0] = left
127 | cumwidths[..., -1] = right
128 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
129 |
130 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
131 |
132 | heights = F.softmax(unnormalized_heights, dim=-1)
133 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
134 | cumheights = torch.cumsum(heights, dim=-1)
135 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
136 | cumheights = (top - bottom) * cumheights + bottom
137 | cumheights[..., 0] = bottom
138 | cumheights[..., -1] = top
139 | heights = cumheights[..., 1:] - cumheights[..., :-1]
140 |
141 | if inverse:
142 | bin_idx = searchsorted(cumheights, inputs)[..., None]
143 | else:
144 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
145 |
146 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
147 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
148 |
149 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
150 | delta = heights / widths
151 | input_delta = delta.gather(-1, bin_idx)[..., 0]
152 |
153 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
154 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
155 |
156 | input_heights = heights.gather(-1, bin_idx)[..., 0]
157 |
158 | if inverse:
159 | a = (inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + input_heights * (input_delta - input_derivatives)
160 | b = input_heights * input_derivatives - (inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
161 | c = -input_delta * (inputs - input_cumheights)
162 |
163 | discriminant = b.pow(2) - 4 * a * c
164 | assert (discriminant >= 0).all()
165 |
166 | root = (2 * c) / (-b - torch.sqrt(discriminant))
167 | outputs = root * input_bin_widths + input_cumwidths
168 |
169 | theta_one_minus_theta = root * (1 - root)
170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta)
171 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - root).pow(2))
172 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
173 |
174 | return outputs, -logabsdet
175 | else:
176 | theta = (inputs - input_cumwidths) / input_bin_widths
177 | theta_one_minus_theta = theta * (1 - theta)
178 |
179 | numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
180 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta)
181 | outputs = input_cumheights + numerator / denominator
182 |
183 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
184 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
185 |
186 | return outputs, logabsdet
187 |
--------------------------------------------------------------------------------
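The spline in `utils/transforms.py` above is bijective on `[-tail_bound, tail_bound]` and the identity outside it, so applying the transform and then its inverse should recover the inputs while the two log-determinants cancel. A quick round-trip check with random parameters; note that for `tails="linear"` the derivative tensor carries `num_bins - 1` values, which the function pads to `num_bins + 1` knot derivatives:

```python
import torch

from utils.transforms import piecewise_rational_quadratic_transform

torch.manual_seed(0)
num_bins = 10
x = torch.tensor([-5.5, -2.0, -0.3, 0.7, 3.1, 5.8])   # two points fall outside the +/-5 tail bound
w = torch.randn(6, num_bins)
h = torch.randn(6, num_bins)
d = torch.randn(6, num_bins - 1)

y, logdet = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear", tail_bound=5.0)
x_back, inv_logdet = piecewise_rational_quadratic_transform(y, w, h, d, inverse=True, tails="linear", tail_bound=5.0)
print(torch.allclose(x, x_back, atol=1e-5), torch.allclose(logdet, -inv_logdet, atol=1e-5))
```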