├── .gitignore ├── .gitmodules ├── README.md ├── audio ├── __init__.py ├── audio_processing.py ├── stft.py └── tools.py ├── config ├── README.md ├── __init__.py ├── algorithm │ ├── base_emb_vad.1-shot.yaml │ ├── base_emb_vad.avg_train_spk_emb.yaml │ ├── base_emb_vad.train_all.1-shot.yaml │ ├── base_emb_vad.train_all.avg_train_spk_emb.yaml │ ├── base_emb_vad.train_all.yaml │ ├── base_emb_vad.train_clean.1-shot.yaml │ ├── base_emb_vad.train_clean.avg_train_spk_emb.yaml │ ├── base_emb_vad.train_clean.yaml │ ├── base_emb_vad.yaml │ ├── base_share_emb_va_d.yaml │ ├── base_table_emb_va_d.yaml │ ├── dev.yaml │ ├── dvec.1-shot.yaml │ ├── dvec.yaml │ ├── encoder.1-shot.yaml │ ├── encoder.yaml │ ├── meta_emb_vad.1-shot.yaml │ ├── meta_emb_vad.avg_train_spk_emb.yaml │ ├── meta_emb_vad.train_all.1-shot.yaml │ ├── meta_emb_vad.train_all.avg_train_spk_emb.yaml │ ├── meta_emb_vad.train_all.yaml │ ├── meta_emb_vad.train_clean.1-shot.yaml │ ├── meta_emb_vad.train_clean.avg_train_spk_emb.yaml │ ├── meta_emb_vad.train_clean.yaml │ ├── meta_emb_vad.yaml │ ├── meta_lingual.yaml │ ├── meta_share_emb_va_d.yaml │ ├── meta_table_emb_va_d.yaml │ ├── scratch_encoder.1-shot.yaml │ └── scratch_encoder.yaml ├── comet.py ├── model │ ├── base.yaml │ ├── dev.yaml │ └── new_dev.yaml ├── preprocess │ ├── LibriTTS.yaml │ ├── VCTK.yaml │ └── miniLibriTTS.yaml └── train │ ├── LibriTTS.yaml │ ├── VCTK.yaml │ ├── base.yaml │ ├── dev.yaml │ └── miniLibriTTS.yaml ├── dataset.py ├── evaluation ├── README.md ├── centroid_similarity.py ├── compute_mos.py ├── config.py ├── images │ ├── LibriTTS │ │ ├── auc_encoder.png │ │ ├── det_encoder.png │ │ ├── eer_encoder.png │ │ ├── errorbar_plot_encoder.png │ │ └── roc_encoder.png │ ├── VCTK │ │ ├── auc_encoder.png │ │ ├── det_encoder.png │ │ ├── eer_encoder.png │ │ ├── errorbar_plot_encoder.png │ │ └── roc_encoder.png │ ├── evaluate_flowchart.jpg │ ├── meta-FastSpeech2.png │ ├── meta-TTS-meta-task.png │ └── meta-TTS-multi-task.png ├── json │ ├── LibriTTS │ │ └── pair.json │ └── VCTK │ │ └── pair.json ├── main.py ├── merge_image.py ├── pair_similarity.py ├── similarity_plot.py ├── speaker_verification.py ├── txt │ ├── LibriTTS │ │ ├── eer.txt │ │ ├── mbnet.txt │ │ ├── mosnet.txt │ │ └── wav2vec2.txt │ └── VCTK │ │ ├── eer.txt │ │ ├── mbnet.txt │ │ └── mosnet.txt ├── visualize.py └── wavs_to_dvector.py ├── lexicon └── librispeech-lexicon.txt ├── lightning ├── callbacks │ ├── __init__.py │ ├── progressbar.py │ ├── saver.py │ └── utils.py ├── collate.py ├── datamodules │ ├── __init__.py │ ├── base_datamodule.py │ ├── baseline_datamodule.py │ ├── define.py │ ├── meta_datamodule.py │ └── utils.py ├── model │ ├── __init__.py │ ├── fastspeech2.py │ ├── loss.py │ ├── modules.py │ ├── optimizer.py │ ├── phoneme_embedding.py │ └── speaker_encoder.py ├── optimizer.py ├── sampler.py ├── scheduler.py ├── systems │ ├── __init__.py │ ├── base_adaptor.py │ ├── baseline.py │ ├── imaml.py │ ├── meta.py │ ├── system.py │ └── utils.py └── utils.py ├── main.py ├── prepare_align.py ├── preprocess.py ├── preprocessed_data └── example_corpus │ └── TextGrid │ └── speaker1 │ └── speaker1_utterance1.TextGrid ├── preprocessor ├── libritts.py ├── preprocessor.py └── vctk.py ├── requirements.txt ├── text ├── __init__.py ├── cleaners.py ├── cmudict.py ├── numbers.py ├── pinyin.py └── symbols.py ├── transformer ├── Constants.py ├── Layers.py ├── Models.py ├── Modules.py ├── SubLayers.py └── __init__.py └── utils ├── model.py └── tools.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | __pycache__ 107 | .vscode 108 | .DS_Store 109 | 110 | # MFA 111 | montreal-forced-aligner/ 112 | 113 | # learn2learn 114 | learn2learn/ 115 | 116 | # data, checkpoint, and models 117 | .comet.config 118 | lightning_logs/ 119 | raw_data/ 120 | output 121 | preprocessed_data/miniAISHELL-3/ 122 | preprocessed_data/miniLibriTTS/ 123 | preprocessed_data/LibriTTS 124 | preprocessed_data/VCTK 125 | evaluation/speechmetrics/ 126 | evaluation/Pytorch_MBNet/ 127 | evaluation/images/**/._* 128 | evaluation/images/**/*.png 129 | *.npy 130 | *.csv 131 | *.swp 132 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "hypertorch"] 2 | path = hypertorch 3 | url = git@github.com:prolearner/hypertorch.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta-TTS: Meta-Learning for Few-shot SpeakerAdaptive Text-to-Speech 2 | 3 | This repository is the official implementation of ["Meta-TTS: Meta-Learning for Few-shot SpeakerAdaptive Text-to-Speech"](https://arxiv.org/abs/2111.04040v1). 
4 | 5 | 6 | 7 | | multi-task learning | meta learning | 8 | | --- | --- | 9 | | ![](evaluation/images/meta-TTS-multi-task.png) | ![](evaluation/images/meta-TTS-meta-task.png) | 10 | 11 | ### Meta-TTS 12 | 13 | ![image](evaluation/images/meta-FastSpeech2.png) 14 | 15 | ## Requirements 16 | 17 | This is how I built my environment; yours does not need to be exactly the same: 18 | - Sign up for [Comet.ml](https://www.comet.ml/), find your workspace and API key via [www.comet.ml/api/my/settings](www.comet.ml/api/my/settings) and fill them in `config/comet.py`. The Comet logger is used throughout the train/val/test stages. 19 | - Check my training logs [here](https://www.comet.ml/b02901071/meta-tts/view/Zvh3Lz3Wvy2AiWcinD06TaS0G). 20 | - [Optional] Install [pyenv](https://github.com/pyenv/pyenv.git) for Python version 21 | control, and switch to Python 3.8.6. 22 | ```bash 23 | # After downloading and installing pyenv: 24 | pyenv install 3.8.6 25 | pyenv local 3.8.6 26 | ``` 27 | - [Optional] Install [pyenv-virtualenv](https://github.com/pyenv/pyenv-virtualenv.git) as a pyenv plugin for a clean virtual environment. 28 | ```bash 29 | # After installing pyenv-virtualenv: 30 | pyenv virtualenv meta-tts 31 | pyenv activate meta-tts 32 | ``` 33 | - Install requirements: 34 | ```bash 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ## Preprocessing 39 | First, download [LibriTTS](https://www.openslr.org/60/) and [VCTK](https://datashare.ed.ac.uk/handle/10283/3443), change the corpus paths in `config/preprocess/LibriTTS.yaml` and `config/preprocess/VCTK.yaml`, then run 40 | ```bash 41 | python3 prepare_align.py config/preprocess/LibriTTS.yaml 42 | python3 prepare_align.py config/preprocess/VCTK.yaml 43 | ``` 44 | for some preparations. 45 | 46 | The alignments of LibriTTS are provided [here](https://github.com/kan-bayashi/LibriTTSLabel.git), and 47 | the alignments of VCTK are provided [here](https://drive.google.com/file/d/1ScLIiyIgLRIZ03DqCmrZ8F75miC77o8g/view?usp=sharing). 48 | You have to unzip the files into `preprocessed_data/LibriTTS/TextGrid/` and 49 | `preprocessed_data/VCTK/TextGrid/`. 50 | 51 | Then run the preprocessing script: 52 | ```bash 53 | python3 preprocess.py config/preprocess/LibriTTS.yaml 54 | 55 | # Copy stats from LibriTTS to VCTK so pitch/energy normalization uses the same shift and bias. 56 | cp preprocessed_data/LibriTTS/stats.json preprocessed_data/VCTK/ 57 | 58 | python3 preprocess.py config/preprocess/VCTK.yaml 59 | ``` 60 | 61 | ## Training 62 | 63 | To train the models in the paper, run this command: 64 | 65 | ```bash 66 | python3 main.py -s train \ 67 | -p config/preprocess/<corpus>.yaml \ 68 | -m config/model/base.yaml \ 69 | -t config/train/base.yaml config/train/<corpus>.yaml \ 70 | -a config/algorithm/<algorithm>.yaml 71 | ``` 72 | 73 | To reproduce the paper, please use 8 V100 GPUs for the meta models and 1 V100 GPU for the baseline 74 | models; otherwise you might need to tune the gradient accumulation step (grad_acc_step) 75 | setting in `config/train/base.yaml` to get the correct meta batch size. 76 | Note that each GPU has its own random seed, so even if the meta batch size is the 77 | same, a different number of GPUs is equivalent to a different random seed. 78 | 79 | After training, you can find your checkpoints under 80 | `output/ckpt/<corpus>/<project_name>/<experiment_key>/checkpoints/`, where the 81 | project name is set in `config/comet.py`. 
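As a rough, back-of-the-envelope illustration of the meta batch size note above (the arithmetic and the function name are assumptions for illustration, not code from this repository; it assumes gradients are averaged across GPUs and accumulation steps):

```python
# Hypothetical sketch: effective meta batch size under DDP-style training.
def effective_meta_batch_size(tasks_per_gpu_per_step: int, num_gpus: int, grad_acc_step: int) -> int:
    return tasks_per_gpu_per_step * num_gpus * grad_acc_step

print(effective_meta_batch_size(1, 8, 1))  # 8 GPUs, no accumulation -> 8
print(effective_meta_batch_size(1, 1, 8))  # 1 GPU, grad_acc_step=8  -> 8 (same effective size, different random seeds)
```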
82 | 83 | To run inference with the models, run: 84 | ```bash 85 | python3 main.py -s test \ 86 | -p config/preprocess/<corpus>.yaml \ 87 | -m config/model/base.yaml \ 88 | -t config/train/base.yaml config/train/<corpus>.yaml \ 89 | -a config/algorithm/<algorithm>.yaml \ 90 | -e <experiment_key> -c <checkpoint_file> 91 | ``` 92 | and the results will be under 93 | `output/result/<corpus>/<experiment_key>/<algorithm>/`. 94 | 95 | ## Evaluation 96 | 97 | > **Note:** The evaluation code is not well-refactored yet. 98 | 99 | `cd evaluation/` and check [README.md](evaluation/README.md) 100 | 101 | ## Pre-trained Models 102 | 103 | > **Note:** The checkpoints are from an older version and might not be compatible with 104 | > the current code. We will fix this in the future. 105 | 106 | Since our code uses the Comet logger, you might need to create a dummy 107 | experiment by running: 108 | ```Python 109 | from comet_ml import Experiment 110 | experiment = Experiment() 111 | ``` 112 | then put the checkpoint files under 113 | `output/ckpt/LibriTTS/<project_name>/<experiment_key>/checkpoints/`. 114 | 115 | You can download pretrained models [here](https://drive.google.com/drive/folders/1Av7afSMcHX6pp2_ZmpHqfJNx6ONM7N8d?usp=sharing). 116 | 117 | ## Results 118 | 119 | | Corpus | LibriTTS | VCTK | 120 | | --- | --- | --- | 121 | | Speaker Similarity | ![](evaluation/images/LibriTTS/errorbar_plot_encoder.png) | ![](evaluation/images/VCTK/errorbar_plot_encoder.png) | 122 | | Speaker Verification | ![](evaluation/images/LibriTTS/eer_encoder.png)
![](evaluation/images/LibriTTS/det_encoder.png) | ![](evaluation/images/VCTK/eer_encoder.png)
![](evaluation/images/VCTK/det_encoder.png) | 123 | | Synthesized Speech Detection | ![](evaluation/images/LibriTTS/auc_encoder.png)
![](evaluation/images/LibriTTS/roc_encoder.png) | ![](evaluation/images/VCTK/auc_encoder.png)
![](evaluation/images/VCTK/roc_encoder.png) | 124 | 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /audio/__init__.py: -------------------------------------------------------------------------------- 1 | import audio.tools 2 | import audio.stft 3 | import audio.audio_processing 4 | -------------------------------------------------------------------------------- /audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return torch.log(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as 
np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data.cuda(), 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), 
requires_grad=False 110 | ) 111 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes): 152 | output = dynamic_range_compression(magnitudes) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1 170 | assert torch.max(y.data) <= 1 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | return mel_output, energy 179 | -------------------------------------------------------------------------------- /audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | 5 | from audio.audio_processing import griffin_lim 6 | 7 | 8 | def get_mel_from_wav(audio, _stft): 9 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 10 | audio = torch.autograd.Variable(audio, requires_grad=False) 11 | melspec, energy = _stft.mel_spectrogram(audio) 12 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 13 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 14 | 15 | return melspec, energy 16 | 17 | 18 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 19 | mel = torch.stack([mel]) 20 | mel_decompress = _stft.spectral_de_normalize(mel) 21 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 22 | spec_from_mel_scaling = 1000 23 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 24 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 25 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 26 | 27 | audio = griffin_lim( 28 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), 
_stft.stft_fn, griffin_iters 29 | ) 30 | 31 | audio = audio.squeeze() 32 | audio = audio.cpu().numpy() 33 | audio_path = out_filename 34 | write(audio_path, _stft.sampling_rate, audio) 35 | -------------------------------------------------------------------------------- /config/README.md: -------------------------------------------------------------------------------- 1 | # Config 2 | Here are the config files used to train the multi-speaker TTS models. 3 | Two different configurations are given: 4 | - LibriTTS: suggested configuration for the LibriTTS dataset. 5 | - VCTK: suggested configuration for the VCTK dataset. 6 | 7 | Some important hyper-parameters are explained here. 8 | 9 | ## preprocess.yaml 10 | - **path.lexicon_path**: the lexicon (which maps words to phonemes) used by the Montreal Forced Aligner. 11 | We provide an English lexicon in `lexicon/`. 12 | - **mel.stft.mel_fmax**: set it to null as MelGAN is used. 13 | - **pitch.feature & energy.feature**: the original FastSpeech 2 paper proposed to predict and apply frame-level pitch and energy features to the inputs of the TTS decoder to control the pitch and energy of the synthesized utterances. 14 | However, in our experiments, we find that using phoneme-level features makes the prosody of the synthesized utterances more natural. 15 | - **pitch.normalization & energy.normalization**: whether to normalize the pitch and energy values. 16 | The original FastSpeech 2 paper did not normalize these values. 17 | 18 | ## train.yaml 19 | - **optimizer.grad_acc_step**: the number of batches to accumulate gradients over before updating the model parameters and calling optimizer.zero_grad(), which is useful if you wish to train the model with a large batch size but do not have sufficient GPU memory (see the sketch at the end of this file). 20 | - **optimizer.anneal_steps & optimizer.anneal_rate**: the learning rate is reduced at the **anneal_steps** by the ratio specified with **anneal_rate**. 21 | 22 | ## model.yaml 23 | - **transformer.decoder_layer**: the original FastSpeech 2 paper used a 4-layer decoder, but we find it better to use a 6-layer decoder, especially for multi-speaker TTS; this is the only architectural difference from the original FastSpeech 2. 24 | - **variance_embedding.pitch_quantization**: when the pitch values are normalized as specified in ``preprocess.yaml``, it is not valid to use log-scale quantization bins as proposed in the original paper, so we use linearly-scaled bins instead. 25 | - **multi_speaker**: whether to apply a speaker embedding table to enable multi-speaker TTS. 26 | - **vocoder.speaker**: should be set to 'universal'. 
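Below is a minimal, generic PyTorch sketch of what gradient accumulation does (for illustration only; the toy model, data, and loop are placeholders, not the training loop of this repository):

```python
import torch

# Toy model and optimizer, purely for illustration.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

grad_acc_step = 4  # plays the role of optimizer.grad_acc_step

optimizer.zero_grad()
for step in range(100):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = loss_fn(model(x), y) / grad_acc_step  # scale so accumulated gradients match one large batch
    loss.backward()                              # gradients keep accumulating in .grad
    if (step + 1) % grad_acc_step == 0:
        optimizer.step()       # update parameters once every grad_acc_step batches
        optimizer.zero_grad()  # then clear the accumulated gradients
```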
27 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/config/__init__.py -------------------------------------------------------------------------------- /config/algorithm/base_emb_vad.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_vad-1_shot 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 1000 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: True 39 | -------------------------------------------------------------------------------- /config/algorithm/base_emb_vad.avg_train_spk_emb.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_va_d-avg_train_spk_emb 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: True 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/base_emb_vad.train_all.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_vad-train_all-1_shot 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 1000 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: True 39 | -------------------------------------------------------------------------------- 
/config/algorithm/base_emb_vad.train_all.avg_train_spk_emb.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_vad-train_all-avg_train_spk_emb 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: True 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/base_emb_vad.train_all.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_vad-train_all 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/base_emb_vad.train_clean.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_vad-train_clean-1_shot 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 1000 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: True 39 | -------------------------------------------------------------------------------- /config/algorithm/base_emb_vad.train_clean.avg_train_spk_emb.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_va_d-train_clean-avg_train_spk_emb 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk 
# spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: True 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/base_emb_vad.train_clean.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_vad-train_clean 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/base_emb_vad.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_vad 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/base_share_emb_va_d.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_va_d shared 2 | type: baseline # meta/baseline 3 | 4 | adapt: 5 | speaker_emb: shared # shared/table/encoder 6 | modules: 7 | - speaker_emb 8 | - variance_adaptor 9 | - decoder 10 | - mel_linear 11 | - postnet 12 | 13 | ways: 1 14 | shots: 5 15 | queries: 5 16 | steps: 5 17 | lr: 0.001 18 | meta_batch_size: 8 19 | 20 | test: 21 | queries: 1 22 | steps: 100 # max adaptation steps for testing 23 | -------------------------------------------------------------------------------- /config/algorithm/base_table_emb_va_d.yaml: -------------------------------------------------------------------------------- 1 | name: base_emb_vad 2 | 
type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/dev.yaml: -------------------------------------------------------------------------------- 1 | name: dev meta_emb_va_d table 2 | type: imaml # meta/baseline/imaml 3 | #meta_type: spk 4 | # 5 | _phn_emb_config: 6 | embedding: &embedding 7 | type: embedding 8 | refresh: False 9 | codebook: &codebook 10 | type: codebook 11 | size: 30 12 | representation_dim: 1024 13 | attention: 14 | type: hard 15 | share: False 16 | 17 | adapt: 18 | type: lang # spk/lang 19 | class: iMAML # MAML/iMAML 20 | speaker_emb: table # shared/table/encoder 21 | phoneme_emb: *codebook # *embedding/*codebook 22 | imaml: 23 | K: 5 # CG steps # TODO: need tuning 24 | reg_param: 1 # TODO: need tuning 25 | batch_size: 5 26 | stochastic: True 27 | 28 | modules: 29 | - encoder 30 | - variance_adaptor 31 | - decoder 32 | - mel_linear 33 | - postnet 34 | 35 | task: &task 36 | ways: 1 37 | shots: 10 38 | queries: 5 39 | lr: 0.001 40 | 41 | train: 42 | << : *task 43 | steps: 2 44 | meta_batch_size: 1 45 | 46 | test: 47 | << : *task 48 | queries: 1 49 | steps: 100 # max adaptation steps for testing 50 | -------------------------------------------------------------------------------- /config/algorithm/dvec.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: dvec-1_shot 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: dvec # shared/table/encoder/dvec/scratch_encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: [] 15 | 16 | task: &task 17 | ways: 1 18 | shots: 5 19 | queries: 5 20 | lr: 0.001 21 | 22 | train: 23 | << : *task 24 | steps: 5 25 | meta_batch_size: 8 26 | 27 | test: 28 | << : *task 29 | queries: 1 30 | steps: 0 # max adaptation steps for testing 31 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 32 | avg_train_spk_emb: False 33 | 1-shot: True 34 | -------------------------------------------------------------------------------- /config/algorithm/dvec.yaml: -------------------------------------------------------------------------------- 1 | name: dvec 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: dvec # shared/table/encoder/dvec/scratch_encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: [] 15 | 16 | task: &task 17 | ways: 1 18 | shots: 5 19 | queries: 
5 20 | lr: 0.001 21 | 22 | train: 23 | << : *task 24 | steps: 5 25 | meta_batch_size: 8 26 | 27 | test: 28 | << : *task 29 | queries: 1 30 | steps: 0 # max adaptation steps for testing 31 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 32 | avg_train_spk_emb: False 33 | 1-shot: False 34 | -------------------------------------------------------------------------------- /config/algorithm/encoder.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: encoder-1_shot 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: encoder # shared/table/encoder/dvec/scratch_encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: [] 15 | 16 | task: &task 17 | ways: 1 18 | shots: 5 19 | queries: 5 20 | lr: 0.001 21 | 22 | train: 23 | << : *task 24 | steps: 5 25 | meta_batch_size: 8 26 | 27 | test: 28 | << : *task 29 | queries: 1 30 | steps: 0 # max adaptation steps for testing 31 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 32 | avg_train_spk_emb: False 33 | 1-shot: True 34 | -------------------------------------------------------------------------------- /config/algorithm/encoder.yaml: -------------------------------------------------------------------------------- 1 | name: encoder 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: encoder # shared/table/encoder/dvec/scratch_encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: [] 15 | 16 | task: &task 17 | ways: 1 18 | shots: 5 19 | queries: 5 20 | lr: 0.001 21 | 22 | train: 23 | << : *task 24 | steps: 5 25 | meta_batch_size: 8 26 | 27 | test: 28 | << : *task 29 | queries: 1 30 | steps: 0 # max adaptation steps for testing 31 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 32 | avg_train_spk_emb: False 33 | 1-shot: False 34 | -------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad-1_shot 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 1000 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: True 39 | -------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.avg_train_spk_emb.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad-avg_train_spk_emb 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | 
_phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: True 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.train_all.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad-train_all-1_shot 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 1000 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: True 39 | -------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.train_all.avg_train_spk_emb.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad-train_all-avg_train_spk_emb 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: True 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.train_all.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad-train_all 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | 
ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.train_clean.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad-train_clean-1_shot 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 1000 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: True 39 | -------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.train_clean.avg_train_spk_emb.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad-train_clean-avg_train_spk_emb 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: True 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.train_clean.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad-train_clean 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: False 39 | 
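The algorithm configs above share their few-shot task settings through YAML anchors (`&task`, `&embedding`) and merge keys (`<< : *task`). As a quick, standalone illustration of how the `train` and `test` sections inherit and override those shared values when loaded (this snippet uses PyYAML directly and is not part of the repository):

```python
import yaml  # PyYAML resolves the YAML 1.1 "<<" merge key on load

config_text = """
task: &task
  ways: 1
  shots: 5
  queries: 5
  lr: 0.001

train:
  << : *task
  steps: 5
  meta_batch_size: 8

test:
  << : *task
  queries: 1   # explicit key overrides the value merged from the anchor
  steps: 100
"""

config = yaml.safe_load(config_text)
print(config["train"])            # ways/shots/queries/lr from the anchor plus steps and meta_batch_size
print(config["test"]["queries"])  # 1 -- the local override wins over the merged value
```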
-------------------------------------------------------------------------------- /config/algorithm/meta_emb_vad.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_vad 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/meta_lingual.yaml: -------------------------------------------------------------------------------- 1 | name: meta-lingual-debug 2 | type: meta # meta/baseline 3 | meta_type: lang 4 | 5 | adapt: 6 | speaker_emb: table # shared/table/encoder 7 | modules: 8 | - encoder 9 | - variance_adaptor 10 | - decoder 11 | - mel_linear 12 | - postnet 13 | 14 | ways: 1 15 | shots: 25 16 | queries: 25 17 | steps: 5 18 | lr: 0.0003 19 | meta_batch_size: 1 20 | 21 | test: 22 | queries: 25 23 | steps: 100 # max adaptation steps for testing 24 | -------------------------------------------------------------------------------- /config/algorithm/meta_share_emb_va_d.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_va_d shared 2 | type: meta # meta/baseline 3 | 4 | adapt: 5 | speaker_emb: shared # shared/table/encoder 6 | modules: 7 | - speaker_emb 8 | - variance_adaptor 9 | - decoder 10 | - mel_linear 11 | - postnet 12 | 13 | ways: 1 14 | shots: 5 15 | queries: 5 16 | steps: 5 17 | lr: 0.001 18 | meta_batch_size: 8 19 | 20 | test: 21 | queries: 1 22 | steps: 100 # max adaptation steps for testing 23 | -------------------------------------------------------------------------------- /config/algorithm/meta_table_emb_va_d.yaml: -------------------------------------------------------------------------------- 1 | name: meta_emb_va_d table 2 | type: meta # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: MAML # MAML/iMAML, not used 12 | speaker_emb: table # shared/table/encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: 15 | - speaker_emb 16 | - variance_adaptor 17 | - decoder 18 | - mel_linear 19 | - postnet 20 | 21 | task: &task 22 | ways: 1 23 | shots: 5 24 | queries: 5 25 | lr: 0.001 26 | 27 | train: 28 | << : *task 29 | steps: 5 30 | meta_batch_size: 8 31 | 32 | test: 33 | << : *task 34 | queries: 1 35 | steps: 100 # max adaptation steps for testing 36 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 37 | avg_train_spk_emb: False 38 | 1-shot: False 39 | -------------------------------------------------------------------------------- /config/algorithm/scratch_encoder.1-shot.yaml: -------------------------------------------------------------------------------- 1 | name: scratch_encoder-1_shot 2 | type: 
baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: scratch_encoder # shared/table/encoder/dvec/scratch_encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: [] 15 | 16 | task: &task 17 | ways: 1 18 | shots: 5 19 | queries: 5 20 | lr: 0.001 21 | 22 | train: 23 | << : *task 24 | steps: 5 25 | meta_batch_size: 8 26 | 27 | test: 28 | << : *task 29 | queries: 1 30 | steps: 0 # max adaptation steps for testing 31 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 32 | avg_train_spk_emb: False 33 | 1-shot: True 34 | -------------------------------------------------------------------------------- /config/algorithm/scratch_encoder.yaml: -------------------------------------------------------------------------------- 1 | name: scratch_encoder 2 | type: baseline # meta/baseline/imaml, get_system 3 | 4 | _phn_emb_config: 5 | embedding: &embedding 6 | type: embedding 7 | refresh: False 8 | 9 | adapt: 10 | type: spk # spk/lang 11 | #class: iMAML # MAML/iMAML, not used 12 | speaker_emb: scratch_encoder # shared/table/encoder/dvec/scratch_encoder 13 | phoneme_emb: *embedding # *embedding/*codebook 14 | modules: [] 15 | 16 | task: &task 17 | ways: 1 18 | shots: 5 19 | queries: 5 20 | lr: 0.001 21 | 22 | train: 23 | << : *task 24 | steps: 5 25 | meta_batch_size: 8 26 | 27 | test: 28 | << : *task 29 | queries: 1 30 | steps: 0 # max adaptation steps for testing 31 | saving_steps: [5, 10, 20, 50, 100, 200, 400, 600, 800, 1000] 32 | avg_train_spk_emb: False 33 | 1-shot: False 34 | -------------------------------------------------------------------------------- /config/comet.py: -------------------------------------------------------------------------------- 1 | 2 | COMET_CONFIG = { 3 | "api_key": "API_KEY", 4 | "workspace": "WORKSPACE", 5 | "project_name": "PROJECT_NAME", 6 | "log_code": True, 7 | "log_graph": True, 8 | "parse_args": True, 9 | "log_env_details": True, 10 | "log_git_metadata": True, 11 | "log_git_patch": True, 12 | "log_env_gpu": True, 13 | "log_env_cpu": True, 14 | "log_env_host": True, 15 | } 16 | -------------------------------------------------------------------------------- /config/model/base.yaml: -------------------------------------------------------------------------------- 1 | transformer: 2 | encoder_layer: 4 3 | encoder_head: 2 4 | encoder_hidden: 256 5 | decoder_layer: 6 6 | decoder_head: 2 7 | decoder_hidden: 256 8 | conv_filter_size: 1024 9 | conv_kernel_size: [9, 1] 10 | encoder_dropout: 0.2 11 | decoder_dropout: 0.2 12 | 13 | variance_predictor: 14 | filter_size: 256 15 | kernel_size: 3 16 | dropout: 0.5 17 | 18 | variance_embedding: 19 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 21 | n_bins: 256 22 | 23 | multi_speaker: True 24 | multi_lingual: True 25 | 26 | max_seq_len: 1000 27 | 28 | vocoder: 29 | model: "MelGAN" 30 | speaker: "universal" # support 'LJSpeech', 'universal' 31 | -------------------------------------------------------------------------------- /config/model/dev.yaml: -------------------------------------------------------------------------------- 1 | transformer: 2 | encoder_layer: 1 3 | encoder_head: 2 4 | 
encoder_hidden: 256 5 | decoder_layer: 1 6 | decoder_head: 2 7 | decoder_hidden: 256 8 | conv_filter_size: 1024 9 | conv_kernel_size: [9, 1] 10 | encoder_dropout: 0.2 11 | decoder_dropout: 0.2 12 | 13 | variance_predictor: 14 | filter_size: 256 15 | kernel_size: 3 16 | dropout: 0.5 17 | 18 | variance_embedding: 19 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 21 | n_bins: 256 22 | 23 | multi_speaker: True 24 | multi_lingual: True 25 | 26 | max_seq_len: 1000 27 | 28 | vocoder: 29 | model: "MelGAN" 30 | speaker: "universal" # support 'LJSpeech', 'universal' 31 | -------------------------------------------------------------------------------- /config/model/new_dev.yaml: -------------------------------------------------------------------------------- 1 | transformer: 2 | encoder_layer: 4 3 | encoder_head: 2 4 | encoder_hidden: 256 5 | decoder_layer: 6 6 | decoder_head: 2 7 | decoder_hidden: 256 8 | conv_filter_size: 1024 9 | conv_kernel_size: [9, 1] 10 | encoder_dropout: 0.2 11 | decoder_dropout: 0.2 12 | 13 | variance_predictor: 14 | filter_size: 256 15 | kernel_size: 3 16 | dropout: 0.5 17 | 18 | variance_embedding: 19 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 21 | n_bins: 256 22 | 23 | multi_speaker: True 24 | multi_lingual: True 25 | 26 | max_seq_len: 1000 27 | codebook_size: 30 28 | representation_dim: 1024 29 | 30 | vocoder: 31 | model: "MelGAN" 32 | speaker: "universal" # support 'LJSpeech', 'universal' 33 | -------------------------------------------------------------------------------- /config/preprocess/LibriTTS.yaml: -------------------------------------------------------------------------------- 1 | dataset: "LibriTTS" 2 | lang_id: 0 3 | 4 | path: 5 | corpus_path: "/home/r06942045/myData/LibriTTS" 6 | lexicon_path: "lexicon/librispeech-lexicon.txt" 7 | raw_path: "./raw_data/LibriTTS" 8 | preprocessed_path: "./preprocessed_data/LibriTTS" 9 | 10 | subsets: 11 | #train: "train-all" 12 | train: "train-clean" 13 | #train: 14 | #- train-clean-100 15 | #- train-clean-360 16 | #- train-other-500 17 | val: "dev-clean" 18 | test: "test-clean" 19 | 20 | preprocessing: 21 | val_size: 512 22 | text: 23 | text_cleaners: ["english_cleaners"] 24 | language: "en" 25 | audio: 26 | sampling_rate: 22050 27 | max_wav_value: 32768.0 28 | stft: 29 | filter_length: 1024 30 | hop_length: 256 31 | win_length: 1024 32 | mel: 33 | n_mel_channels: 80 34 | mel_fmin: 0 35 | mel_fmax: Null # set to null for MelGAN vocoder 36 | pitch: 37 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 38 | normalization: True 39 | energy: 40 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 41 | normalization: True 42 | -------------------------------------------------------------------------------- /config/preprocess/VCTK.yaml: -------------------------------------------------------------------------------- 1 | dataset: "VCTK" 2 | lang_id: 0 3 | 4 | path: 5 | corpus_path: "/home/r06942045/myData/VCTK-Corpus" 6 | lexicon_path: "lexicon/librispeech-lexicon.txt" 7 | raw_path: "./raw_data/VCTK" 8 | 
preprocessed_path: "./preprocessed_data/VCTK" 9 | 10 | subsets: 11 | test: "all" 12 | 13 | preprocessing: 14 | val_size: 0 15 | text: 16 | text_cleaners: ["english_cleaners"] 17 | language: "en" 18 | audio: 19 | sampling_rate: 22050 20 | max_wav_value: 32768.0 21 | stft: 22 | filter_length: 1024 23 | hop_length: 256 24 | win_length: 1024 25 | mel: 26 | n_mel_channels: 80 27 | mel_fmin: 0 28 | mel_fmax: Null # set to null for MelGAN vocoder 29 | pitch: 30 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 31 | normalization: True 32 | energy: 33 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 34 | normalization: True 35 | -------------------------------------------------------------------------------- /config/preprocess/miniLibriTTS.yaml: -------------------------------------------------------------------------------- 1 | dataset: "miniLibriTTS" 2 | lang_id: 0 3 | 4 | path: 5 | corpus_path: "/home/r06942045/myData/LibriTTS" 6 | lexicon_path: "lexicon/librispeech-lexicon.txt" 7 | raw_path: "./raw_data/LibriTTS" 8 | preprocessed_path: "./preprocessed_data/miniLibriTTS" 9 | 10 | subsets: 11 | train: "test-clean" 12 | val: "test-clean" 13 | test: "test-clean" 14 | 15 | preprocessing: 16 | val_size: 512 17 | text: 18 | text_cleaners: ["english_cleaners"] 19 | language: "en" 20 | audio: 21 | sampling_rate: 22050 22 | max_wav_value: 32768.0 23 | stft: 24 | filter_length: 1024 25 | hop_length: 256 26 | win_length: 1024 27 | mel: 28 | n_mel_channels: 80 29 | mel_fmin: 0 30 | mel_fmax: Null # set to null for MelGAN vocoder 31 | pitch: 32 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 33 | normalization: True 34 | energy: 35 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 36 | normalization: True 37 | -------------------------------------------------------------------------------- /config/train/LibriTTS.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt/LibriTTS" 3 | log_path: "./output/log/LibriTTS" 4 | result_path: "./output/result/LibriTTS" 5 | -------------------------------------------------------------------------------- /config/train/VCTK.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt/VCTK" 3 | log_path: "./output/log/VCTK" 4 | result_path: "./output/result/VCTK" 5 | -------------------------------------------------------------------------------- /config/train/base.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | batch_size: 80 # meta_batch_size * (shots + query) # Only used by baseline model 3 | betas: [0.9, 0.98] 4 | eps: 0.000000001 5 | weight_decay: 0.0 6 | grad_clip_thresh: 1.0 7 | grad_acc_step: 1 8 | warm_up_step: 4000 9 | anneal_steps: [300000, 400000, 500000] 10 | anneal_rate: 0.3 11 | step: 12 | total_step: 100000 13 | log_step: 100 14 | synth_step: 1000 15 | val_step: 1000 16 | save_step: 1000 17 | -------------------------------------------------------------------------------- /config/train/dev.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | batch_size: 16 3 | betas: [0.9, 0.98] 4 | eps: 0.000000001 5 | weight_decay: 0.0 6 | grad_clip_thresh: 1.0 7 | grad_acc_step: 1 8 | warm_up_step: 4000 9 | anneal_steps: [300000, 400000, 500000] 10 | anneal_rate: 0.3 11 | step: 12 | total_step: 600 13 | log_step: 100 14 | synth_step: 100 15 | 
val_step: 100 16 | save_step: 100 17 | -------------------------------------------------------------------------------- /config/train/miniLibriTTS.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt/miniLibriTTS" 3 | log_path: "./output/log/miniLibriTTS" 4 | result_path: "./output/result/miniLibriTTS" 5 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation_for_TTS 2 | 3 | For MOS prediction of [Utilizing Self-supervised Representations for MOS Prediction](https://arxiv.org/abs/2104.03017), 4 | we asked the authors for their code, but since it has not been publicly released, we cannot provide it here. 5 | To get the code, please contact its authors. 6 | 7 | ## Prepare 8 | - Install [speechmetrics](https://github.com/aliutkus/speechmetrics.git) for 9 | MOSNet: 10 | ```bash 11 | git clone https://github.com/aliutkus/speechmetrics.git 12 | cd speechmetrics 13 | # bash 14 | pip install -e .[gpu] 15 | # zsh 16 | pip install -e .\[gpu\] 17 | ``` 18 | - Download [Pytorch_MBNet](https://github.com/sky1456723/Pytorch-MBNet.git) for MBNet: 19 | ```bash 20 | git clone https://github.com/sky1456723/Pytorch-MBNet.git 21 | ``` 22 | - Fix paths and configurations in `config.py` 23 | - Prepare output directories 24 | ```bash 25 | mkdir -p npy/LibriTTS 26 | mkdir -p npy/VCTK 27 | mkdir -p csv/LibriTTS 28 | mkdir -p csv/VCTK 29 | ``` 30 | 31 | ## MOS prediction 32 | ```bash 33 | python compute_mos.py --net 34 | ``` 35 | 36 | ## Speaker adaptation metrics 37 | ![image](images/evaluate_flowchart.jpg) 38 | -------------------------------------------------------------------------------- /evaluation/images/LibriTTS/auc_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/LibriTTS/auc_encoder.png -------------------------------------------------------------------------------- /evaluation/images/LibriTTS/det_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/LibriTTS/det_encoder.png -------------------------------------------------------------------------------- /evaluation/images/LibriTTS/eer_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/LibriTTS/eer_encoder.png -------------------------------------------------------------------------------- /evaluation/images/LibriTTS/errorbar_plot_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/LibriTTS/errorbar_plot_encoder.png -------------------------------------------------------------------------------- /evaluation/images/LibriTTS/roc_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/LibriTTS/roc_encoder.png
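The speaker adaptation metrics outlined in `evaluation/README.md` above come down to comparing d-vectors of synthesized and reference utterances and summarizing the positive/negative pair similarities as an equal error rate (the `threshold:... EER:...` numbers stored under `evaluation/txt/`). The block below is a minimal, self-contained sketch of that computation using only NumPy and scikit-learn; the helper names (`cosine_similarity`, `compute_eer`) and the random stand-in d-vectors are illustrative assumptions, not the repository's API. In the repository itself, d-vectors come from Resemblyzer's `VoiceEncoder` and the scoring lives in `pair_similarity.py` and `speaker_verification.py`.

```python
# Minimal sketch (editor's illustration, not part of the repository): pairwise
# cosine similarity between d-vectors, reduced to an equal error rate (EER).
# Helper names and the random stand-in d-vectors are assumptions for the demo.
import numpy as np
from sklearn.metrics import roc_curve


def cosine_similarity(a, b):
    """Row-wise cosine similarity between two [N, D] d-vector arrays."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return np.sum(a * b, axis=1)


def compute_eer(pos_scores, neg_scores):
    """Return (EER, threshold) where false accept rate == false reject rate."""
    scores = np.concatenate([pos_scores, neg_scores])
    labels = np.concatenate([np.ones_like(pos_scores), np.zeros_like(neg_scores)])
    fpr, tpr, thresholds = roc_curve(labels, scores)
    fnr = 1.0 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))
    return float((fpr[idx] + fnr[idx]) / 2.0), float(thresholds[idx])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Stand-ins for 256-dim d-vectors: synthesized utterances, enrollment
    # utterances from the same speaker (correlated), and from other speakers.
    synth = rng.normal(size=(1000, 256))
    same_spk = synth + 0.5 * rng.normal(size=(1000, 256))
    diff_spk = rng.normal(size=(1000, 256))
    pos = cosine_similarity(synth, same_spk)   # should score high
    neg = cosine_similarity(synth, diff_spk)   # should score low
    eer, thr = compute_eer(pos, neg)
    print(f"threshold:{thr:.4f} EER:{eer:.4f}")  # same format as txt/VCTK/eer.txt
```

Run as a script, this prints a single line in the same `threshold:... EER:...` format as the entries in `evaluation/txt/VCTK/eer.txt` below.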
-------------------------------------------------------------------------------- /evaluation/images/VCTK/auc_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/VCTK/auc_encoder.png -------------------------------------------------------------------------------- /evaluation/images/VCTK/det_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/VCTK/det_encoder.png -------------------------------------------------------------------------------- /evaluation/images/VCTK/eer_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/VCTK/eer_encoder.png -------------------------------------------------------------------------------- /evaluation/images/VCTK/errorbar_plot_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/VCTK/errorbar_plot_encoder.png -------------------------------------------------------------------------------- /evaluation/images/VCTK/roc_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/VCTK/roc_encoder.png -------------------------------------------------------------------------------- /evaluation/images/evaluate_flowchart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/evaluate_flowchart.jpg -------------------------------------------------------------------------------- /evaluation/images/meta-FastSpeech2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/meta-FastSpeech2.png -------------------------------------------------------------------------------- /evaluation/images/meta-TTS-meta-task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/meta-TTS-meta-task.png -------------------------------------------------------------------------------- /evaluation/images/meta-TTS-multi-task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/Meta-TTS/85c228c0a277e066e525fab959088aad5c59d293/evaluation/images/meta-TTS-multi-task.png -------------------------------------------------------------------------------- /evaluation/main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from wavs_to_dvector import WavsToDvector 3 | from centroid_similarity import CentroidSimilarity 4 | from pair_similarity import PairSimilarity 5 | from speaker_verification import SpeakerVerification 6 | 7 | 8 | if __name__ 
== '__main__': 9 | parser = ArgumentParser() 10 | parser.add_argument('--new_pair', type=bool, default=False) 11 | parser.add_argument('--output_path', type=str, default='eer.txt') 12 | args = parser.parse_args() 13 | main = WavsToDvector(args) 14 | 15 | main = CentroidSimilarity() 16 | main.load_dvector() 17 | main.get_centroid_similarity() 18 | main.save_centroid_similarity() 19 | 20 | main = PairSimilarity() 21 | main.load_dvector() 22 | main.get_pair_similarity() 23 | main.save_pair_similarity() 24 | 25 | main = SpeakerVerification(args) 26 | main.load_pair_similarity() 27 | main.get_eer() 28 | # for suffix in [ '_encoder']: 29 | # for suffix in ['_base_emb', '_base_emb1', '_meta_emb', '_meta_emb1']: 30 | # main.set_suffix(suffix) 31 | # main.plot_eer(suffix) 32 | # main.plot_auc(suffix) 33 | 34 | -------------------------------------------------------------------------------- /evaluation/merge_image.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | im1 = Image.open('images/LibriTTS/det.png') 4 | im2 = Image.open('images/VCTK/det.png') 5 | w, h = im1.size 6 | pad = 0 7 | crop_w = int(w*0.68) 8 | im = Image.new('RGB', (crop_w+pad+w, h), (255,255,255)) 9 | im.paste(im1, (0, 0)) 10 | im.paste(im2.crop((0, 0, crop_w, h)), (w+pad,0)) 11 | im.save('images/det.png') 12 | im.show() 13 | 14 | # im1 = Image.open('images/LibriTTS/eer.png') 15 | # im2 = Image.open('images/VCTK/eer.png') 16 | # w, h = im1.size 17 | # pad = 0 18 | # crop_w = int(w*0.66) 19 | # im = Image.new('RGB', (crop_w+pad+w, h), (255,255,255)) 20 | # im.paste(im1, (0,0)) 21 | # im.paste(im2.crop((0, 0, crop_w, h)), (w+pad, 0)) 22 | # im.save('images/eer.png') 23 | # im.show() 24 | 25 | # im1 = Image.open('images/LibriTTS/errorbar_plot.png') 26 | # im2 = Image.open('images/VCTK/errorbar_plot.png') 27 | # w, h = im1.size 28 | # pad = 0 29 | # crop_w = int(w*0.66) 30 | # im = Image.new('RGB', (crop_w+pad+w, h), (255,255,255)) 31 | # im.paste(im1, (0,0)) 32 | # im.paste(im2.crop((0, 0, crop_w, h)), (w+pad, 0)) 33 | # im.save('images/errorbar_plot.png') 34 | # im.show() 35 | 36 | # im1 = Image.open('images/LibriTTS/roc.png') 37 | # im2 = Image.open('images/VCTK/roc.png') 38 | # w, h = im1.size 39 | # pad = 0 40 | # crop_w = int(w*0.68) 41 | # im = Image.new('RGB', (crop_w+pad+w, h), (255,255,255)) 42 | # im.paste(im1, (0, 0)) 43 | # im.paste(im2.crop((0, 0, crop_w, h)), (w+pad,0)) 44 | # im.save('images/roc.png') 45 | # im.show() 46 | -------------------------------------------------------------------------------- /evaluation/pair_similarity.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torchaudio 6 | import numpy as np 7 | import random 8 | import json 9 | 10 | from resemblyzer import VoiceEncoder, preprocess_wav 11 | from pathlib import Path 12 | from tqdm import tqdm 13 | 14 | import config 15 | 16 | class PairSimilarity: 17 | def __init__(self): 18 | self.corpus = config.corpus 19 | # self.pair_sim_mode_list = config.pair_sim_mode_list 20 | self.mode_list = config.mode_list 21 | self.step_list = config.step_list 22 | self.mode_step_list = config.mode_step_list 23 | 24 | def load_dvector(self): 25 | self.dvector_list_dict = dict() 26 | for mode in ['recon', 'real', 'pair']: 27 | self.dvector_list_dict[mode] = np.load(f'npy/{self.corpus}/{mode}_dvector.npy', allow_pickle=True) 28 | for mode, steps in tqdm(self.mode_step_list, 
desc='mode'): 29 | for step in tqdm(steps, leave=False): 30 | # if mode in ['scratch_encoder', 'encoder', 'dvec'] and step != 0: 31 | # continue 32 | self.dvector_list_dict[f'{mode}_step{step}'] = np.load( 33 | f'npy/{self.corpus}/{mode}_step{step}_dvector.npy', allow_pickle=True 34 | ) 35 | 36 | def get_pair_similarity(self): 37 | self.pair_similarity_dict = dict() 38 | for mode in ['recon', 'real']: 39 | print(f'Getting pair similarity of mode: {mode}') 40 | if os.path.exists(f'npy/{self.corpus}/{mode}_pair_sim.npy'): 41 | print(f'\tLoading from: \n\t\tnpy/{self.corpus}/{mode}_pair_sim.npy') 42 | self.pair_similarity_dict[mode] = np.load( 43 | f'npy/{self.corpus}/{mode}_pair_sim.npy', allow_pickle=True 44 | ) 45 | else: 46 | self.pair_similarity_dict[mode] = self.compute_pair_similarity(self.dvector_list_dict[mode]) 47 | if not os.path.exists(f'npy/{self.corpus}/{mode}_pair_sim.npy'): 48 | print(f'\tSaving to: \n\t\tnpy/{self.corpus}/{mode}_pair_sim.npy') 49 | np.save(f'npy/{self.corpus}/{mode}_pair_sim.npy', 50 | self.pair_similarity_dict[mode], allow_pickle=True) 51 | 52 | for mode, steps in tqdm(self.mode_step_list, desc='mode'): 53 | for step in tqdm(steps, leave=False): 54 | print(f'Getting pair similarity of mode: {mode}, step: {step}') 55 | if os.path.exists(f'npy/{self.corpus}/{mode}_step{step}_pair_sim.npy'): 56 | print(f'\tLoading from: \n\t\tnpy/{self.corpus}/{mode}_step{step}_pair_sim.npy') 57 | self.pair_similarity_dict[f'{mode}_step{step}'] = np.load( 58 | f'npy/{self.corpus}/{mode}_step{step}_pair_sim.npy', allow_pickle=True 59 | ) 60 | else: 61 | self.pair_similarity_dict[f'{mode}_step{step}'] = self.compute_pair_similarity(self.dvector_list_dict[f'{mode}_step{step}']) 62 | if not os.path.exists(f'npy/{self.corpus}/{mode}_step{step}_pair_sim.npy'): 63 | print(f'\tSaving to: \n\t\tnpy/{self.corpus}/{mode}_step{step}_pair_sim.npy') 64 | np.save(f'npy/{self.corpus}/{mode}_step{step}_pair_sim.npy', 65 | self.pair_similarity_dict[f'{mode}_step{step}'], allow_pickle=True) 66 | 67 | 68 | def compute_pair_similarity(self, check_list): 69 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 70 | dvector_test_repeat_tensor = torch.from_numpy(np.repeat(check_list, 4, axis=0)) 71 | 72 | if (dvector_test_repeat_tensor.shape[0] == self.dvector_list_dict['pair'].shape[1]): 73 | pair_dvector_list_positive = torch.from_numpy(self.dvector_list_dict['pair'][0,:]) 74 | pair_dvector_list_negative = torch.from_numpy(self.dvector_list_dict['pair'][1,:]) 75 | else: 76 | assert (dvector_test_repeat_tensor.shape[0] == self.dvector_list_dict['pair'].shape[1]*5) 77 | dvector_list_dict = np.repeat(self.dvector_list_dict['pair'], 5, axis=1) 78 | pair_dvector_list_positive = torch.from_numpy(dvector_list_dict[0,:]) 79 | pair_dvector_list_negative = torch.from_numpy(dvector_list_dict[1,:]) 80 | 81 | with torch.no_grad(): 82 | pair_similarity_list_positive = cos(dvector_test_repeat_tensor, pair_dvector_list_positive).detach().cpu().numpy() 83 | pair_similarity_list_negative = cos(dvector_test_repeat_tensor, pair_dvector_list_negative).detach().cpu().numpy() 84 | pos_exp = np.expand_dims(pair_similarity_list_positive, axis=0) 85 | neg_exp = np.expand_dims(pair_similarity_list_negative, axis=0) 86 | pair_similarity_list = np.concatenate((pos_exp, neg_exp), axis=0) 87 | 88 | return pair_similarity_list # [2, num_test_samples] 89 | 90 | def save_pair_similarity(self): 91 | np.save(f'npy/{self.corpus}/pair_similarity.npy', self.pair_similarity_dict, allow_pickle=True) 92 | 93 | def load_pair_similarity(self): 94 | 
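        # np.save stores a Python dict as a 0-d object array; the trailing [()] unpacks it back into a dict.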
self.pair_similarity_dict = np.load(f'npy/{self.corpus}/pair_similarity.npy', allow_pickle=True)[()] 95 | 96 | if __name__ == '__main__': 97 | main = PairSimilarity() 98 | main.load_dvector() 99 | main.get_pair_similarity() 100 | main.save_pair_similarity() 101 | #main.load_pair_similarity() 102 | -------------------------------------------------------------------------------- /evaluation/txt/LibriTTS/mbnet.txt: -------------------------------------------------------------------------------- 1 | real, 3.0998066712759043, 0.06761463442595016 2 | recon, 2.9494027452249276, 0.058265167806764044 3 | base_emb_vad_step0, 3.0932134197730767, 0.048714665169813556 4 | base_emb_vad_step5, 2.7906651332190164, 0.04839244804587661 5 | base_emb_vad_step10, 2.7646724122919535, 0.04964458128604268 6 | base_emb_vad_step20, 2.738788509251256, 0.04931496835138024 7 | base_emb_vad_step50, 2.722774910103334, 0.05102192893134009 8 | base_emb_vad_step100, 2.697758700894682, 0.05152738839691235 9 | base_emb_va_step0, 3.092493726234687, 0.04856032778932958 10 | base_emb_va_step5, 3.0481630271594775, 0.04855360269371888 11 | base_emb_va_step10, 3.050547972713646, 0.050197317936808546 12 | base_emb_va_step20, 3.0882336822779557, 0.04915642744849016 13 | base_emb_va_step50, 3.1423092629564437, 0.04941750661989038 14 | base_emb_va_step100, 3.136937092009344, 0.04928097536965588 15 | base_emb_d_step0, 3.093197797866244, 0.04869906045021205 16 | base_emb_d_step5, 2.8628291842576705, 0.04711960903510713 17 | base_emb_d_step10, 2.8245130675403693, 0.04708691637171462 18 | base_emb_d_step20, 2.801979055138011, 0.04801655181374102 19 | base_emb_d_step50, 2.7418729849159718, 0.0484094365797929 20 | base_emb_d_step100, 2.7268280473194624, 0.049144253716850266 21 | base_emb_step0, 3.092493726234687, 0.04856032778932958 22 | base_emb_step5, 3.089926293609958, 0.04873540432708355 23 | base_emb_step10, 3.0827179343292586, 0.048761583249807686 24 | base_emb_step20, 3.087727585121205, 0.04887121459610161 25 | base_emb_step50, 3.089080808782264, 0.0491646467468258 26 | base_emb_step100, 3.079163065670352, 0.05035499783570206 27 | meta_emb_vad_step0, 3.5074708153934857, 0.0478152966121964 28 | meta_emb_vad_step5, 2.5488798718311285, 0.04885451233081697 29 | meta_emb_vad_step10, 2.441913863742038, 0.049371695409596295 30 | meta_emb_vad_step20, 2.3993183333230648, 0.05099444407794531 31 | meta_emb_vad_step50, 2.4028517411727655, 0.050821362557678676 32 | meta_emb_vad_step100, 2.4383454401242104, 0.050914249851754245 33 | meta_emb_va_step0, 3.4577871606146036, 0.046928229181820665 34 | meta_emb_va_step5, 3.0339753568956724, 0.05528877643073055 35 | meta_emb_va_step10, 2.9993242009689935, 0.056476118595492744 36 | meta_emb_va_step20, 3.004195406444763, 0.05680643691732692 37 | meta_emb_va_step50, 3.0252432772203495, 0.05463495931765407 38 | meta_emb_va_step100, 3.0236595386737273, 0.05579881225284358 39 | meta_emb_d_step0, 3.302149846365577, 0.04927718970292861 40 | meta_emb_d_step5, 2.6492644060207042, 0.05031595512793882 41 | meta_emb_d_step10, 2.5705940172468362, 0.05073750998550067 42 | meta_emb_d_step20, 2.548281764905704, 0.051272076783215534 43 | meta_emb_d_step50, 2.547390048245066, 0.051696118643244734 44 | meta_emb_d_step100, 2.5450702072366287, 0.051746405880803155 45 | meta_emb_step0, 3.1525100926427463, 0.04917916940711618 46 | meta_emb_step5, 3.1492875183099196, 0.04852238333364393 47 | meta_emb_step10, 3.1430198942360126, 0.04900129221511211 48 | meta_emb_step20, 3.1490705099545027, 0.04838677734626939 49 | 
meta_emb_step50, 3.1606833911255787, 0.04769419767611777 50 | meta_emb_step100, 3.1575357504189014, 0.04791131424682647 51 | base_emb1_vad_step0, 3.3178421572635046, 0.046513501691966226 52 | base_emb1_vad_step5, 3.1666759234902107, 0.0449544568953031 53 | base_emb1_vad_step10, 3.1500962331498923, 0.04437951018500616 54 | base_emb1_vad_step20, 3.125835326549254, 0.04584222309614975 55 | base_emb1_vad_step50, 3.091314305796435, 0.04404654560985202 56 | base_emb1_vad_step100, 3.098910886010057, 0.04661970339388942 57 | base_emb1_va_step0, 3.31887457206061, 0.046539539856274126 58 | base_emb1_va_step5, 3.31290452084259, 0.04743488055318983 59 | base_emb1_va_step10, 3.3258937930590227, 0.04641912133445338 60 | base_emb1_va_step20, 3.3196748679405763, 0.04717819738988407 61 | base_emb1_va_step50, 3.3507169167462147, 0.044881614899367164 62 | base_emb1_va_step100, 3.32251760814535, 0.0439895072259104 63 | base_emb1_d_step0, 3.318235134803935, 0.046530921329043934 64 | base_emb1_d_step5, 3.1502003877570757, 0.04358965249896795 65 | base_emb1_d_step10, 3.1254205346891752, 0.043893442091995354 66 | base_emb1_d_step20, 3.132610975323539, 0.04366638493670724 67 | base_emb1_d_step50, 3.100661910678211, 0.04452637875308769 68 | base_emb1_d_step100, 3.078260793105552, 0.045583532776910005 69 | base_emb1_step0, 3.31887457206061, 0.046539539856274126 70 | base_emb1_step5, 3.317996844649315, 0.04601485342610973 71 | base_emb1_step10, 3.310502472676729, 0.0468134500203139 72 | base_emb1_step20, 3.3212314795114493, 0.046901836567603876 73 | base_emb1_step50, 3.3152759441811788, 0.046770500787381826 74 | base_emb1_step100, 3.318976945194759, 0.04699973144513864 75 | meta_emb1_vad_step0, 2.435106070222039, 0.06636215892106394 76 | meta_emb1_vad_step5, 1.8389447158888768, 0.04439320592783737 77 | meta_emb1_vad_step10, 1.8258926356701475, 0.044190779280264214 78 | meta_emb1_vad_step20, 1.8243700491362496, 0.04251511907043484 79 | meta_emb1_vad_step50, 1.843956349711669, 0.04461960390059327 80 | meta_emb1_vad_step100, 1.8356421887874603, 0.04257216972126996 81 | meta_emb1_va_step0, 3.818016982588329, 0.03572363337873478 82 | meta_emb1_va_step5, 3.1604798651839556, 0.05291811585893242 83 | meta_emb1_va_step10, 3.1332908378619897, 0.05336828565173607 84 | meta_emb1_va_step20, 3.133008793976746, 0.054857853012403465 85 | meta_emb1_va_step50, 3.121938541531563, 0.05378801599931929 86 | meta_emb1_va_step100, 3.123715956352259, 0.052894441387369796 87 | meta_emb1_d_step0, 2.891439030437093, 0.061231351009742314 88 | meta_emb1_d_step5, 2.569915217592528, 0.04543138611712987 89 | meta_emb1_d_step10, 2.5332564044939843, 0.04702989130142927 90 | meta_emb1_d_step20, 2.505083627802761, 0.04823837577553091 91 | meta_emb1_d_step50, 2.4972414017507902, 0.04886929481302274 92 | meta_emb1_d_step100, 2.4600070532607403, 0.0469536441245108 93 | meta_emb1_step0, 3.314951139649278, 0.046708537418881646 94 | meta_emb1_step5, 3.3260014710065566, 0.0465437738558877 95 | meta_emb1_step10, 3.322444368545946, 0.046788461951275 96 | meta_emb1_step20, 3.3248588405549526, 0.04577748568758765 97 | meta_emb1_step50, 3.3148307745393955, 0.046748117883944576 98 | meta_emb1_step100, 3.313822230814319, 0.047122075087280606 99 | -------------------------------------------------------------------------------- /evaluation/txt/LibriTTS/mosnet.txt: -------------------------------------------------------------------------------- 1 | real, 3.201337968832568, 0.03986935105834017 2 | recon, 3.412636025171531, 0.046701488171906586 3 | base_emb_vad_step0, 
3.3676418650307154, 0.03746364184698117 4 | base_emb_vad_step5, 3.3039062174135134, 0.038448626851326634 5 | base_emb_vad_step10, 3.265101950223509, 0.0397729065798709 6 | base_emb_vad_step20, 3.2479452859414253, 0.03897903074288306 7 | base_emb_vad_step50, 3.2540491405678424, 0.04625916918458891 8 | base_emb_vad_step100, 3.2304635077322783, 0.04620690648528665 9 | base_emb_va_step0, 3.3686475020490194, 0.03743740530766208 10 | base_emb_va_step5, 3.3686112924233864, 0.03803199971154426 11 | base_emb_va_step10, 3.3492137199561847, 0.03811669278946754 12 | base_emb_va_step20, 3.375590327342874, 0.03946152046852547 13 | base_emb_va_step50, 3.389424511084431, 0.04319472292808167 14 | base_emb_va_step100, 3.3853114964930633, 0.041294608346386666 15 | base_emb_d_step0, 3.3677600155535496, 0.037466699240184254 16 | base_emb_d_step5, 3.311088865524844, 0.037942129254370086 17 | base_emb_d_step10, 3.2946582452247015, 0.03803098135602754 18 | base_emb_d_step20, 3.2701840290897772, 0.037476209797372005 19 | base_emb_d_step50, 3.2292421038605665, 0.03936099565870525 20 | base_emb_d_step100, 3.2132847852220663, 0.040094795392910186 21 | base_emb_step0, 3.3686475020490194, 0.03743740530766208 22 | base_emb_step5, 3.3678191861039712, 0.037167794163616874 23 | base_emb_step10, 3.3618024289608, 0.03719943721310199 24 | base_emb_step20, 3.3706686065385214, 0.03727269291284123 25 | base_emb_step50, 3.3669375718424193, 0.037189019174485936 26 | base_emb_step100, 3.3659172548275245, 0.03725904072517303 27 | meta_emb_vad_step0, 3.155492122432119, 0.036629131490509406 28 | meta_emb_vad_step5, 3.034593613720254, 0.04445795930129351 29 | meta_emb_vad_step10, 3.027069683725897, 0.04562205980206068 30 | meta_emb_vad_step20, 3.028296868659948, 0.04724232535556341 31 | meta_emb_vad_step50, 3.0485325418412685, 0.04749964616829884 32 | meta_emb_vad_step100, 3.059189623319789, 0.04765164103114746 33 | meta_emb_va_step0, 3.3922596197379264, 0.040861292819441876 34 | meta_emb_va_step5, 3.268508694869907, 0.04194995295280227 35 | meta_emb_va_step10, 3.236906358285954, 0.04071600622044631 36 | meta_emb_va_step20, 3.2602794799757633, 0.04181124831969164 37 | meta_emb_va_step50, 3.280323849304726, 0.04370169603823064 38 | meta_emb_va_step100, 3.292748646712617, 0.04401586715593092 39 | meta_emb_d_step0, 3.138027667607132, 0.037726834376788844 40 | meta_emb_d_step5, 3.064130145664278, 0.042075798004692896 41 | meta_emb_d_step10, 3.0561642919323946, 0.043702583355429044 42 | meta_emb_d_step20, 3.0463048043219665, 0.04264791217187038 43 | meta_emb_d_step50, 3.062404863732426, 0.043165708264194406 44 | meta_emb_d_step100, 3.063221792445371, 0.04286769725378485 45 | meta_emb_step0, 3.3324303709362684, 0.037614818876335804 46 | meta_emb_step5, 3.3328311213929402, 0.0378248262385322 47 | meta_emb_step10, 3.3328188993036747, 0.0376571217183766 48 | meta_emb_step20, 3.331526673938099, 0.03734534984177455 49 | meta_emb_step50, 3.3420766752801443, 0.03709001267074178 50 | meta_emb_step100, 3.345018107640116, 0.03716280817721549 51 | base_emb1_vad_step0, 3.2201322881798995, 0.042113937675226996 52 | base_emb1_vad_step5, 3.2596841044723988, 0.04468057825094692 53 | base_emb1_vad_step10, 3.2523149389185404, 0.047184580868469723 54 | base_emb1_vad_step20, 3.2847280765050337, 0.04826966729976369 55 | base_emb1_vad_step50, 3.4006105353565594, 0.05622455179688294 56 | base_emb1_vad_step100, 3.5299757217106067, 0.05912666870716859 57 | base_emb1_va_step0, 3.223578382087381, 0.042106068212486865 58 | base_emb1_va_step5, 3.248281316537606, 
0.042818369492010566 59 | base_emb1_va_step10, 3.2576987263011303, 0.04489693910621258 60 | base_emb1_va_step20, 3.322443688386365, 0.049283377515292374 61 | base_emb1_va_step50, 3.4722456743842676, 0.054633863494289024 62 | base_emb1_va_step100, 3.635908254471264, 0.0566658484258312 63 | base_emb1_d_step0, 3.2202925980091095, 0.042128236860930576 64 | base_emb1_d_step5, 3.2121427649337995, 0.04159585545177458 65 | base_emb1_d_step10, 3.192517094706234, 0.04246277870092773 66 | base_emb1_d_step20, 3.193853928854591, 0.04237430266614909 67 | base_emb1_d_step50, 3.1750146268229735, 0.042151526328500145 68 | base_emb1_d_step100, 3.186405345405403, 0.04358198679995892 69 | base_emb1_step0, 3.223578382087381, 0.042106068212486865 70 | base_emb1_step5, 3.2272761162174377, 0.04226331841637534 71 | base_emb1_step10, 3.2352519564722715, 0.042420047720244616 72 | base_emb1_step20, 3.243023547294893, 0.04265648546308094 73 | base_emb1_step50, 3.2334551517116394, 0.042155603100098984 74 | base_emb1_step100, 3.231393805263858, 0.042828524282683834 75 | meta_emb1_vad_step0, 2.760109054022714, 0.02506048216632578 76 | meta_emb1_vad_step5, 2.739273412055091, 0.0508273521500014 77 | meta_emb1_vad_step10, 2.744325948388953, 0.051938936695801345 78 | meta_emb1_vad_step20, 2.768014217873937, 0.052014932127693744 79 | meta_emb1_vad_step50, 2.769045090008723, 0.05279744424460109 80 | meta_emb1_vad_step100, 2.7712345166425956, 0.051659305949102244 81 | meta_emb1_va_step0, 3.6504660758532976, 0.04856879840274154 82 | meta_emb1_va_step5, 3.408578964440446, 0.05351444921506671 83 | meta_emb1_va_step10, 3.4180875440177165, 0.05360646768046595 84 | meta_emb1_va_step20, 3.368900363578608, 0.05066626203256213 85 | meta_emb1_va_step50, 3.335397705043617, 0.05021283514633093 86 | meta_emb1_va_step100, 3.335811276380953, 0.05022845461278615 87 | meta_emb1_d_step0, 2.108970970111458, 0.021493574910985993 88 | meta_emb1_d_step5, 2.9641674206052957, 0.041156813209318355 89 | meta_emb1_d_step10, 2.9631937246181463, 0.04235695044000117 90 | meta_emb1_d_step20, 2.9732736213818978, 0.04371786406899768 91 | meta_emb1_d_step50, 2.9561530173216997, 0.04318085333281836 92 | meta_emb1_d_step100, 2.958218639422404, 0.043677672668892685 93 | meta_emb1_step0, 3.200500512397603, 0.042467654224700925 94 | meta_emb1_step5, 3.1993633218501745, 0.04238316509433875 95 | meta_emb1_step10, 3.198031627229954, 0.04249910201361394 96 | meta_emb1_step20, 3.2026829119575653, 0.04276099150105015 97 | meta_emb1_step50, 3.188899510589085, 0.041979769746260706 98 | meta_emb1_step100, 3.197268416614909, 0.04286970682441824 99 | -------------------------------------------------------------------------------- /evaluation/txt/LibriTTS/wav2vec2.txt: -------------------------------------------------------------------------------- 1 | real, 4.42456654184743, 0.026250306950686683 2 | recon, 3.9786542587374387, 0.03758378426074139 3 | base_emb_vad_step0, 4.30329809847631, 0.02827131807833441 4 | base_emb_vad_step5, 4.2018882461676474, 0.030254691048952783 5 | base_emb_vad_step10, 4.219324217030876, 0.031101746585594704 6 | base_emb_vad_step20, 4.231188449420427, 0.02814475074881777 7 | base_emb_vad_step50, 4.233088667455473, 0.028864731928429848 8 | base_emb_vad_step100, 4.199176441289876, 0.03158014908272819 9 | base_emb_va_step0, 4.306127946627767, 0.028373907092473474 10 | base_emb_va_step5, 4.2819249884862645, 0.02772959716410172 11 | base_emb_va_step10, 4.310626111140377, 0.02561730822003704 12 | base_emb_va_step20, 4.329034181017625, 0.02613418500911687 13 
| base_emb_va_step50, 4.373736571128431, 0.024838187070443676 14 | base_emb_va_step100, 4.380138202325294, 0.025325164550489526 15 | base_emb_d_step0, 4.303466816089656, 0.028264226625620018 16 | base_emb_d_step5, 4.242390267943081, 0.02800746562412018 17 | base_emb_d_step10, 4.2334286252919, 0.029467367813603977 18 | base_emb_d_step20, 4.22548213444258, 0.030295680810658445 19 | base_emb_d_step50, 4.186092880211379, 0.031208800067914162 20 | base_emb_d_step100, 4.151742135223589, 0.03295160059975424 21 | base_emb_step0, 4.306127946627767, 0.028373907092473474 22 | base_emb_step5, 4.30078866371983, 0.029609570341761225 23 | base_emb_step10, 4.305139288698372, 0.028916519517767723 24 | base_emb_step20, 4.308063836474168, 0.028171494076008847 25 | base_emb_step50, 4.307667328338874, 0.027648483932367126 26 | base_emb_step100, 4.310512022752511, 0.026443908748655407 27 | meta_emb_vad_step0, 4.265659945183678, 0.027164501027608406 28 | meta_emb_vad_step5, 4.04895039138041, 0.028866622738939595 29 | meta_emb_vad_step10, 3.9636664806227935, 0.03257466975713684 30 | meta_emb_vad_step20, 3.893797249974389, 0.03549160239606095 31 | meta_emb_vad_step50, 3.8755226119568476, 0.03496805109042087 32 | meta_emb_vad_step100, 3.878892378195336, 0.03410210650567211 33 | meta_emb_va_step0, 4.3212904812474, 0.028893397768236562 34 | meta_emb_va_step5, 4.214754113437314, 0.030531255625899464 35 | meta_emb_va_step10, 4.196747286931465, 0.030807907487113352 36 | meta_emb_va_step20, 4.203729180521087, 0.03113370327197175 37 | meta_emb_va_step50, 4.219166655681636, 0.028681017932307155 38 | meta_emb_va_step100, 4.229745752325184, 0.02704967156087518 39 | meta_emb_d_step0, 4.265645944758465, 0.026820668241836796 40 | meta_emb_d_step5, 4.121271540067698, 0.031075045683074877 41 | meta_emb_d_step10, 4.032727787369176, 0.03441445641605655 42 | meta_emb_d_step20, 3.96393337688948, 0.03619262217178412 43 | meta_emb_d_step50, 3.931758446128745, 0.03541838505557114 44 | meta_emb_d_step100, 3.9145606275843945, 0.03472047684589429 45 | meta_emb_step0, 4.366588279212776, 0.02248959264237703 46 | meta_emb_step5, 4.366672911534184, 0.022223812996105985 47 | meta_emb_step10, 4.368018280126546, 0.02230722626240597 48 | meta_emb_step20, 4.368979346595313, 0.021696831524857834 49 | meta_emb_step50, 4.369441663356204, 0.022433004312863096 50 | meta_emb_step100, 4.365810421344481, 0.022962560985711894 51 | base_emb1_vad_step0, 4.463462595092623, 0.02073665622131917 52 | base_emb1_vad_step5, 4.439665093233711, 0.0211713595650555 53 | base_emb1_vad_step10, 4.438386037357543, 0.022661617165118792 54 | base_emb1_vad_step20, 4.435737983176582, 0.020618926770053135 55 | base_emb1_vad_step50, 4.392466587847785, 0.021985646965091518 56 | base_emb1_vad_step100, 4.335330097298873, 0.026205830804997655 57 | base_emb1_va_step0, 4.466208813221831, 0.020621955447720875 58 | base_emb1_va_step5, 4.459405666119174, 0.02136389436716317 59 | base_emb1_va_step10, 4.465054333602128, 0.020980586044788037 60 | base_emb1_va_step20, 4.471216410790619, 0.0190257750126452 61 | base_emb1_va_step50, 4.4511884278372715, 0.02049275614736155 62 | base_emb1_va_step100, 4.411165007635167, 0.02302761947335878 63 | base_emb1_d_step0, 4.463777856607186, 0.020663435688678743 64 | base_emb1_d_step5, 4.454637635303171, 0.021667653017134324 65 | base_emb1_d_step10, 4.446284733516605, 0.02263208938976037 66 | base_emb1_d_step20, 4.440510096518617, 0.021666220252990084 67 | base_emb1_d_step50, 4.4051582342699955, 0.024005763122250806 68 | base_emb1_d_step100, 
4.367845735267589, 0.025343012262229198 69 | base_emb1_step0, 4.466208813221831, 0.020621955447720875 70 | base_emb1_step5, 4.463140482573133, 0.02092770156643521 71 | base_emb1_step10, 4.463800158155592, 0.0208021885120554 72 | base_emb1_step20, 4.459585409023259, 0.020983706204784756 73 | base_emb1_step50, 4.463072405833947, 0.020898537139660583 74 | base_emb1_step100, 4.457653117415152, 0.020341320810945106 75 | meta_emb1_vad_step0, 3.2074698531313945, 0.06595511222802213 76 | meta_emb1_vad_step5, 2.9720823443249653, 0.06006993633609541 77 | meta_emb1_vad_step10, 2.8975056254942166, 0.05492799519876597 78 | meta_emb1_vad_step20, 2.9411468331358934, 0.055155745364220965 79 | meta_emb1_vad_step50, 2.9358481698130308, 0.054904934214018986 80 | meta_emb1_vad_step100, 2.950684112545691, 0.05160767903534765 81 | meta_emb1_va_step0, 4.5206782900189095, 0.022192777457958952 82 | meta_emb1_va_step5, 4.413133616902326, 0.0231245430309985 83 | meta_emb1_va_step10, 4.414536211835711, 0.02131353894637311 84 | meta_emb1_va_step20, 4.4050117660509915, 0.02092451337122775 85 | meta_emb1_va_step50, 4.382645149372126, 0.02467413447894565 86 | meta_emb1_va_step100, 4.392414617303171, 0.022839132186852504 87 | meta_emb1_d_step0, 2.9519021487549733, 0.030661451627903482 88 | meta_emb1_d_step5, 3.9998618432957875, 0.03368436356819319 89 | meta_emb1_d_step10, 3.945032059754196, 0.035002074557426294 90 | meta_emb1_d_step20, 3.9088215439727434, 0.03535456958159461 91 | meta_emb1_d_step50, 3.867711244052962, 0.03602897104818789 92 | meta_emb1_d_step100, 3.854876233362838, 0.03600596866172569 93 | meta_emb1_step0, 4.479024097323418, 0.02232855279584162 94 | meta_emb1_step5, 4.473775724047108, 0.023265405120504103 95 | meta_emb1_step10, 4.472217404528668, 0.02325513187693406 96 | meta_emb1_step20, 4.47112299776391, 0.023745366838673863 97 | meta_emb1_step50, 4.47211504257039, 0.023760659951777185 98 | meta_emb1_step100, 4.470362861297633, 0.02398611883529576 99 | -------------------------------------------------------------------------------- /evaluation/txt/VCTK/eer.txt: -------------------------------------------------------------------------------- 1 | recon: 2 | threshold:0.6944 EER:0.0951 3 | real: 4 | threshold:0.7091 EER:0.0641 5 | base_emb_vad_step0: 6 | threshold:0.5356 EER:0.4926 7 | base_emb_vad_step5: 8 | threshold:0.5386 EER:0.4796 9 | base_emb_vad_step10: 10 | threshold:0.5444 EER:0.4589 11 | base_emb_vad_step20: 12 | threshold:0.5545 EER:0.4333 13 | base_emb_vad_step50: 14 | threshold:0.5793 EER:0.3624 15 | base_emb_vad_step100: 16 | threshold:0.6032 EER:0.2983 17 | base_emb_va_step0: 18 | threshold:0.5357 EER:0.4916 19 | base_emb_va_step5: 20 | threshold:0.5364 EER:0.4910 21 | base_emb_va_step10: 22 | threshold:0.5378 EER:0.4877 23 | base_emb_va_step20: 24 | threshold:0.5386 EER:0.4796 25 | base_emb_va_step50: 26 | threshold:0.5418 EER:0.4692 27 | base_emb_va_step100: 28 | threshold:0.5466 EER:0.4381 29 | base_emb_d_step0: 30 | threshold:0.5356 EER:0.4925 31 | base_emb_d_step5: 32 | threshold:0.5385 EER:0.4800 33 | base_emb_d_step10: 34 | threshold:0.5434 EER:0.4663 35 | base_emb_d_step20: 36 | threshold:0.5540 EER:0.4429 37 | base_emb_d_step50: 38 | threshold:0.5728 EER:0.3958 39 | base_emb_d_step100: 40 | threshold:0.5872 EER:0.3587 41 | base_emb_step0: 42 | threshold:0.5357 EER:0.4916 43 | base_emb_step5: 44 | threshold:0.5362 EER:0.4907 45 | base_emb_step10: 46 | threshold:0.5368 EER:0.4899 47 | base_emb_step20: 48 | threshold:0.5370 EER:0.4918 49 | base_emb_step50: 50 | threshold:0.5382 
EER:0.4886 51 | base_emb_step100: 52 | threshold:0.5382 EER:0.4877 53 | meta_emb_vad_step0: 54 | threshold:0.5297 EER:0.4968 55 | meta_emb_vad_step5: 56 | threshold:0.5926 EER:0.3069 57 | meta_emb_vad_step10: 58 | threshold:0.6050 EER:0.2626 59 | meta_emb_vad_step20: 60 | threshold:0.6156 EER:0.2341 61 | meta_emb_vad_step50: 62 | threshold:0.6234 EER:0.2134 63 | meta_emb_vad_step100: 64 | threshold:0.6303 EER:0.2021 65 | meta_emb_va_step0: 66 | threshold:0.5318 EER:0.4891 67 | meta_emb_va_step5: 68 | threshold:0.5490 EER:0.4062 69 | meta_emb_va_step10: 70 | threshold:0.5524 EER:0.3944 71 | meta_emb_va_step20: 72 | threshold:0.5527 EER:0.3997 73 | meta_emb_va_step50: 74 | threshold:0.5551 EER:0.3932 75 | meta_emb_va_step100: 76 | threshold:0.5556 EER:0.3986 77 | meta_emb_d_step0: 78 | threshold:0.5346 EER:0.4968 79 | meta_emb_d_step5: 80 | threshold:0.5675 EER:0.3801 81 | meta_emb_d_step10: 82 | threshold:0.5774 EER:0.3469 83 | meta_emb_d_step20: 84 | threshold:0.5864 EER:0.3193 85 | meta_emb_d_step50: 86 | threshold:0.5954 EER:0.2983 87 | meta_emb_d_step100: 88 | threshold:0.6009 EER:0.2870 89 | meta_emb_step0: 90 | threshold:0.5280 EER:0.5123 91 | meta_emb_step5: 92 | threshold:0.5284 EER:0.5120 93 | meta_emb_step10: 94 | threshold:0.5290 EER:0.5098 95 | meta_emb_step20: 96 | threshold:0.5289 EER:0.5107 97 | meta_emb_step50: 98 | threshold:0.5295 EER:0.5081 99 | meta_emb_step100: 100 | threshold:0.5287 EER:0.5052 101 | base_emb1_vad_step0: 102 | threshold:0.5259 EER:0.4957 103 | base_emb1_vad_step5: 104 | threshold:0.5266 EER:0.4808 105 | base_emb1_vad_step10: 106 | threshold:0.5301 EER:0.4754 107 | base_emb1_vad_step20: 108 | threshold:0.5387 EER:0.4580 109 | base_emb1_vad_step50: 110 | threshold:0.5567 EER:0.4126 111 | base_emb1_vad_step100: 112 | threshold:0.5746 EER:0.3712 113 | base_emb1_va_step0: 114 | threshold:0.5262 EER:0.4962 115 | base_emb1_va_step5: 116 | threshold:0.5259 EER:0.4902 117 | base_emb1_va_step10: 118 | threshold:0.5302 EER:0.4865 119 | base_emb1_va_step20: 120 | threshold:0.5310 EER:0.4839 121 | base_emb1_va_step50: 122 | threshold:0.5380 EER:0.4672 123 | base_emb1_va_step100: 124 | threshold:0.5452 EER:0.4484 125 | base_emb1_d_step0: 126 | threshold:0.5261 EER:0.4959 127 | base_emb1_d_step5: 128 | threshold:0.5249 EER:0.4861 129 | base_emb1_d_step10: 130 | threshold:0.5282 EER:0.4819 131 | base_emb1_d_step20: 132 | threshold:0.5331 EER:0.4677 133 | base_emb1_d_step50: 134 | threshold:0.5447 EER:0.4413 135 | base_emb1_d_step100: 136 | threshold:0.5563 EER:0.4159 137 | base_emb1_step0: 138 | threshold:0.5262 EER:0.4962 139 | base_emb1_step5: 140 | threshold:0.5264 EER:0.4944 141 | base_emb1_step10: 142 | threshold:0.5262 EER:0.4951 143 | base_emb1_step20: 144 | threshold:0.5253 EER:0.4981 145 | base_emb1_step50: 146 | threshold:0.5262 EER:0.4906 147 | base_emb1_step100: 148 | threshold:0.5273 EER:0.4900 149 | meta_emb1_vad_step0: 150 | threshold:0.5382 EER:0.4951 151 | meta_emb1_vad_step5: 152 | threshold:0.5929 EER:0.2778 153 | meta_emb1_vad_step10: 154 | threshold:0.5951 EER:0.2617 155 | meta_emb1_vad_step20: 156 | threshold:0.5998 EER:0.2522 157 | meta_emb1_vad_step50: 158 | threshold:0.6001 EER:0.2468 159 | meta_emb1_vad_step100: 160 | threshold:0.6042 EER:0.2391 161 | meta_emb1_va_step0: 162 | threshold:0.5360 EER:0.4968 163 | meta_emb1_va_step5: 164 | threshold:0.5761 EER:0.3490 165 | meta_emb1_va_step10: 166 | threshold:0.5761 EER:0.3435 167 | meta_emb1_va_step20: 168 | threshold:0.5756 EER:0.3422 169 | meta_emb1_va_step50: 170 | threshold:0.5730 EER:0.3477 
171 | meta_emb1_va_step100: 172 | threshold:0.5753 EER:0.3453 173 | meta_emb1_d_step0: 174 | threshold:0.5009 EER:0.4880 175 | meta_emb1_d_step5: 176 | threshold:0.6037 EER:0.3155 177 | meta_emb1_d_step10: 178 | threshold:0.6066 EER:0.2979 179 | meta_emb1_d_step20: 180 | threshold:0.6055 EER:0.2905 181 | meta_emb1_d_step50: 182 | threshold:0.6063 EER:0.2894 183 | meta_emb1_d_step100: 184 | threshold:0.6075 EER:0.2815 185 | meta_emb1_step0: 186 | threshold:0.5272 EER:0.4902 187 | meta_emb1_step5: 188 | threshold:0.5273 EER:0.4922 189 | meta_emb1_step10: 190 | threshold:0.5268 EER:0.4891 191 | meta_emb1_step20: 192 | threshold:0.5275 EER:0.4880 193 | meta_emb1_step50: 194 | threshold:0.5278 EER:0.4905 195 | meta_emb1_step100: 196 | threshold:0.5277 EER:0.4891 197 | scratch_encoder_step0: 198 | threshold:0.5771 EER:0.3394 199 | encoder_step0: 200 | threshold:0.6245 EER:0.2737 201 | dvec_step0: 202 | threshold:0.6408 EER:0.2159 203 | -------------------------------------------------------------------------------- /evaluation/txt/VCTK/mbnet.txt: -------------------------------------------------------------------------------- 1 | real, 2.603792264436682, 0.02929591085527332 2 | recon, 3.4300844905277095, 0.024522813717607672 3 | base_emb_vad_step0, 3.4205478598122245, 0.02800934817366612 4 | base_emb_vad_step5, 3.1568520387151726, 0.028913171323214307 5 | base_emb_vad_step10, 3.1399767087703503, 0.02825469538795088 6 | base_emb_vad_step20, 3.1452079111779176, 0.028059757466759677 7 | base_emb_vad_step50, 3.1467463188563234, 0.027790707003117605 8 | base_emb_vad_step100, 3.108099532348138, 0.026648564139731865 9 | base_emb_va_step0, 3.416997024927426, 0.028053743373246225 10 | base_emb_va_step5, 3.4102204187462726, 0.028766035413678122 11 | base_emb_va_step10, 3.429574969742033, 0.027295461764901986 12 | base_emb_va_step20, 3.421257048017449, 0.027937227555057176 13 | base_emb_va_step50, 3.4540376626644975, 0.027019311558638823 14 | base_emb_va_step100, 3.419004302540863, 0.026998013728795314 15 | base_emb_d_step0, 3.4205358716210834, 0.028013288171463504 16 | base_emb_d_step5, 3.180630559929543, 0.02908387848823438 17 | base_emb_d_step10, 3.1606844807112657, 0.0282615302661083 18 | base_emb_d_step20, 3.149966186809319, 0.028530315591141872 19 | base_emb_d_step50, 3.1216270206151187, 0.02882104051886327 20 | base_emb_d_step100, 3.113108624945636, 0.028607570326214055 21 | base_emb_step0, 3.416997024927426, 0.028053743373246225 22 | base_emb_step5, 3.426375813506268, 0.027809415450059223 23 | base_emb_step10, 3.4261919211044356, 0.027668876779906912 24 | base_emb_step20, 3.4222459190835557, 0.027674214013142927 25 | base_emb_step50, 3.4171393357769206, 0.02829515036472088 26 | base_emb_step100, 3.4210810434349157, 0.02774494718886563 27 | meta_emb_vad_step0, 3.5774335308621326, 0.028013972819050922 28 | meta_emb_vad_step5, 2.918553749996203, 0.029215120767142506 29 | meta_emb_vad_step10, 2.8573853921973043, 0.029695280749454756 30 | meta_emb_vad_step20, 2.8506717881256782, 0.028931210971310493 31 | meta_emb_vad_step50, 2.8566617050242646, 0.02875380454307618 32 | meta_emb_vad_step100, 2.8916157499231674, 0.028480728757941404 33 | meta_emb_va_step0, 3.664785830648961, 0.023916356968493355 34 | meta_emb_va_step5, 3.4457255877140494, 0.028681490032293852 35 | meta_emb_va_step10, 3.443760226239209, 0.02800346066223213 36 | meta_emb_va_step20, 3.447991964579732, 0.028626778346187095 37 | meta_emb_va_step50, 3.4425465147252434, 0.027394804199692032 38 | meta_emb_va_step100, 3.4382404467711845, 
0.027690749716341407 39 | meta_emb_d_step0, 3.5366774903679334, 0.027594872747729288 40 | meta_emb_d_step5, 2.997376237133587, 0.027139958347979703 41 | meta_emb_d_step10, 2.9183262223722757, 0.027778159778449184 42 | meta_emb_d_step20, 2.8640159472685167, 0.027865961429496624 43 | meta_emb_d_step50, 2.860272313433665, 0.028118348975391085 44 | meta_emb_d_step100, 2.840425332663236, 0.027879488119746776 45 | meta_emb_step0, 3.4267570375016443, 0.027332972231362183 46 | meta_emb_step5, 3.4255768412517176, 0.027272712987986784 47 | meta_emb_step10, 3.429027485764689, 0.02708875134841815 48 | meta_emb_step20, 3.425404195777244, 0.027406623070524837 49 | meta_emb_step50, 3.423860019210864, 0.02708018432112285 50 | meta_emb_step100, 3.4206761769536467, 0.026744213844383774 51 | base_emb1_vad_step0, 3.4702669905705585, 0.02659931847923218 52 | base_emb1_vad_step5, 3.2893670589007713, 0.02643735088894779 53 | base_emb1_vad_step10, 3.262023709507452, 0.026029307998880837 54 | base_emb1_vad_step20, 3.2617441578595727, 0.025851319485003445 55 | base_emb1_vad_step50, 3.2690580681104353, 0.02475533168457977 56 | base_emb1_vad_step100, 3.2459616414788695, 0.025221372705622736 57 | base_emb1_va_step0, 3.4691621227258884, 0.026731225330090178 58 | base_emb1_va_step5, 3.466075749722896, 0.02668181610935017 59 | base_emb1_va_step10, 3.470354644857623, 0.02640465439713356 60 | base_emb1_va_step20, 3.482779949972475, 0.026199674444743103 61 | base_emb1_va_step50, 3.5138768651695163, 0.02475455047788744 62 | base_emb1_va_step100, 3.4993907413272947, 0.025110848636036395 63 | base_emb1_d_step0, 3.469883299605162, 0.026600284050598787 64 | base_emb1_d_step5, 3.2887588808382, 0.025962891581505568 65 | base_emb1_d_step10, 3.2612943672747523, 0.026283216453906175 66 | base_emb1_d_step20, 3.2514576177906105, 0.026560020444501585 67 | base_emb1_d_step50, 3.2334895527197256, 0.026253231108981184 68 | base_emb1_d_step100, 3.235178620175079, 0.025795290451668514 69 | base_emb1_step0, 3.4691621227258884, 0.026731225330090178 70 | base_emb1_step5, 3.4707646116062447, 0.02642455633969774 71 | base_emb1_step10, 3.465612700071048, 0.026818220584705656 72 | base_emb1_step20, 3.4647744598074093, 0.026673884099920997 73 | base_emb1_step50, 3.4699253872450857, 0.026464475995280887 74 | base_emb1_step100, 3.4769048825320272, 0.025920973719480545 75 | meta_emb1_vad_step0, 3.0323347462410175, 0.037808794137207204 76 | meta_emb1_vad_step5, 2.4880514896854207, 0.033596332074662685 77 | meta_emb1_vad_step10, 2.511500399214802, 0.0329117789976576 78 | meta_emb1_vad_step20, 2.5333056445750923, 0.032481784862227824 79 | meta_emb1_vad_step50, 2.5089744688184172, 0.03207417578732985 80 | meta_emb1_vad_step100, 2.5388326544866517, 0.03234876446448038 81 | meta_emb1_va_step0, 3.8190324032610214, 0.020976954674332672 82 | meta_emb1_va_step5, 3.4483640460918346, 0.025895668194668035 83 | meta_emb1_va_step10, 3.4450953692473747, 0.025624466833493453 84 | meta_emb1_va_step20, 3.4088451063467398, 0.02615492705749868 85 | meta_emb1_va_step50, 3.4209371298827507, 0.025751843828596587 86 | meta_emb1_va_step100, 3.4238058360362493, 0.025484148943272406 87 | meta_emb1_d_step0, 3.050584825238696, 0.03596101562326014 88 | meta_emb1_d_step5, 2.865401001026233, 0.027579211815109243 89 | meta_emb1_d_step10, 2.8140108813014297, 0.027620378724771102 90 | meta_emb1_d_step20, 2.761705452852227, 0.02755707656684019 91 | meta_emb1_d_step50, 2.7242669525245824, 0.02788408900848567 92 | meta_emb1_d_step100, 2.7201049349236266, 0.028446886906332472 93 | 
meta_emb1_step0, 3.4838303742171437, 0.025282553545030367 94 | meta_emb1_step5, 3.4821389610706657, 0.025222486672536938 95 | meta_emb1_step10, 3.4811191473984056, 0.025680858858837512 96 | meta_emb1_step20, 3.4842451472801192, 0.02531852352565581 97 | meta_emb1_step50, 3.4846121521873608, 0.025595428477594644 98 | meta_emb1_step100, 3.48271810159915, 0.02548362830930797 99 | -------------------------------------------------------------------------------- /evaluation/txt/VCTK/mosnet.txt: -------------------------------------------------------------------------------- 1 | real, 2.9914067336530596, 0.01724796889154286 2 | recon, 3.8635682521594896, 0.02867461832157345 3 | base_emb_vad_step0, 3.3818420625671193, 0.025721201157375705 4 | base_emb_vad_step5, 3.291588210259323, 0.024989455318100467 5 | base_emb_vad_step10, 3.2834533025959023, 0.024332763829917798 6 | base_emb_vad_step20, 3.286765437159273, 0.025002307592003975 7 | base_emb_vad_step50, 3.3294351819074817, 0.027211353263875786 8 | base_emb_vad_step100, 3.3400339198609195, 0.02822864919384735 9 | base_emb_va_step0, 3.3837837056705244, 0.025756148396111532 10 | base_emb_va_step5, 3.3867477497982756, 0.025510106321355364 11 | base_emb_va_step10, 3.404021498988624, 0.026115340752530525 12 | base_emb_va_step20, 3.4022307305562274, 0.02600037194484692 13 | base_emb_va_step50, 3.4203954502526255, 0.02723834033171941 14 | base_emb_va_step100, 3.425382275203312, 0.02845974930994204 15 | base_emb_d_step0, 3.3818757907935866, 0.025720634193677698 16 | base_emb_d_step5, 3.300832493851582, 0.025196789695751415 17 | base_emb_d_step10, 3.294140972473003, 0.02451680733991757 18 | base_emb_d_step20, 3.29816129486318, 0.024783413118946118 19 | base_emb_d_step50, 3.311159746169492, 0.02527834344237692 20 | base_emb_d_step100, 3.343000315681652, 0.02580843829706604 21 | base_emb_step0, 3.3837837056705244, 0.025756148396111532 22 | base_emb_step5, 3.3901545354337603, 0.0255213842160454 23 | base_emb_step10, 3.3922790843579502, 0.025593752933839177 24 | base_emb_step20, 3.3928047297177493, 0.02555516132589962 25 | base_emb_step50, 3.39462937966541, 0.025499150073422985 26 | base_emb_step100, 3.399532319356998, 0.025617245999271645 27 | meta_emb_vad_step0, 3.2885612144514367, 0.02457238029901364 28 | meta_emb_vad_step5, 3.117598002569543, 0.027552292007502026 29 | meta_emb_vad_step10, 3.1066604367008916, 0.027050395434378337 30 | meta_emb_vad_step20, 3.1164710130542517, 0.027503541033347907 31 | meta_emb_vad_step50, 3.1456200281089104, 0.028076841394706762 32 | meta_emb_vad_step100, 3.1787155948717287, 0.028500123598582325 33 | meta_emb_va_step0, 3.5064448038185083, 0.026979499893524127 34 | meta_emb_va_step5, 3.4084164298105017, 0.028454838895207016 35 | meta_emb_va_step10, 3.409978693972031, 0.029022496718878136 36 | meta_emb_va_step20, 3.4192714829135826, 0.029036097521198023 37 | meta_emb_va_step50, 3.4472676315517337, 0.02989000881422263 38 | meta_emb_va_step100, 3.457302544955854, 0.029704549099748092 39 | meta_emb_d_step0, 3.2671030550091356, 0.024018833196679846 40 | meta_emb_d_step5, 3.1728670229376466, 0.025542221097970375 41 | meta_emb_d_step10, 3.178490358301335, 0.026670105205189983 42 | meta_emb_d_step20, 3.170115514072003, 0.026882432688683445 43 | meta_emb_d_step50, 3.1917620671706066, 0.027535784811300354 44 | meta_emb_d_step100, 3.1996339205652475, 0.027621850324839808 45 | meta_emb_step0, 3.4010300621804266, 0.026080912907923835 46 | meta_emb_step5, 3.4024173166878797, 0.02595192018602425 47 | meta_emb_step10, 3.405437099574893, 
0.025928107493238855 48 | meta_emb_step20, 3.4060055592821703, 0.025978488534966464 49 | meta_emb_step50, 3.3962330816796533, 0.02572363205214968 50 | meta_emb_step100, 3.3988231166645333, 0.025581218589343393 51 | base_emb1_vad_step0, 3.307169395464438, 0.02735597083562226 52 | base_emb1_vad_step5, 3.314538378889362, 0.028035095525919162 53 | base_emb1_vad_step10, 3.302908473889585, 0.02784411260705156 54 | base_emb1_vad_step20, 3.320704742093329, 0.02899953702599309 55 | base_emb1_vad_step50, 3.3613764882363655, 0.030810512631431374 56 | base_emb1_vad_step100, 3.4605840901257814, 0.03255763061874665 57 | base_emb1_va_step0, 3.312428825017479, 0.027393061806808207 58 | base_emb1_va_step5, 3.318314754507608, 0.028095932795092444 59 | base_emb1_va_step10, 3.318092281964642, 0.028347003826369847 60 | base_emb1_va_step20, 3.331678691768536, 0.02933970849073485 61 | base_emb1_va_step50, 3.3942113570455046, 0.030913576595935934 62 | base_emb1_va_step100, 3.429963778980352, 0.032112660091454304 63 | base_emb1_d_step0, 3.308364159875998, 0.027402240644681535 64 | base_emb1_d_step5, 3.312498135285245, 0.027510193740361125 65 | base_emb1_d_step10, 3.304473970223356, 0.02711066048555232 66 | base_emb1_d_step20, 3.2990986564783036, 0.027757950105959484 67 | base_emb1_d_step50, 3.312400787792824, 0.027984315531737764 68 | base_emb1_d_step100, 3.3650946877896786, 0.029172371804348868 69 | base_emb1_step0, 3.312428825017479, 0.027393061806808207 70 | base_emb1_step5, 3.310911410660655, 0.027301918992913832 71 | base_emb1_step10, 3.3141138136248895, 0.027556232232169672 72 | base_emb1_step20, 3.312599429515777, 0.027473789021807835 73 | base_emb1_step50, 3.308620948275482, 0.02729968208939933 74 | base_emb1_step100, 3.3177566619382963, 0.02783927460861786 75 | meta_emb1_vad_step0, 2.8113651915832802, 0.01690823814423584 76 | meta_emb1_vad_step5, 2.901348269785996, 0.028376914197852352 77 | meta_emb1_vad_step10, 2.9285528419056424, 0.029073026514891957 78 | meta_emb1_vad_step20, 2.95374071839507, 0.028959971582366326 79 | meta_emb1_vad_step50, 2.968389001640457, 0.029510798349690545 80 | meta_emb1_vad_step100, 2.970130405492253, 0.029438052875342822 81 | meta_emb1_va_step0, 3.334032991832053, 0.028162743526945935 82 | meta_emb1_va_step5, 3.3459207350733102, 0.030108101004989975 83 | meta_emb1_va_step10, 3.3468229518858372, 0.030312563098072363 84 | meta_emb1_va_step20, 3.3226510185610363, 0.029891030184199287 85 | meta_emb1_va_step50, 3.342045980885073, 0.029955446346356045 86 | meta_emb1_va_step100, 3.3398830351178295, 0.029690205449984648 87 | meta_emb1_d_step0, 2.062801736717423, 0.013854937287058002 88 | meta_emb1_d_step5, 3.027584376917393, 0.02547195106571681 89 | meta_emb1_d_step10, 3.020958310545043, 0.026169694170265136 90 | meta_emb1_d_step20, 3.006842028901533, 0.0260638625869649 91 | meta_emb1_d_step50, 2.99890773964149, 0.02552307509447711 92 | meta_emb1_d_step100, 3.0030143545181662, 0.025936676449024552 93 | meta_emb1_step0, 3.3062533606533653, 0.02810278032531253 94 | meta_emb1_step5, 3.303396538451866, 0.028066807735458315 95 | meta_emb1_step10, 3.3023921924608723, 0.028260776348371566 96 | meta_emb1_step20, 3.2959860599151365, 0.028440538674798697 97 | meta_emb1_step50, 3.297571825346461, 0.02815199990720343 98 | meta_emb1_step100, 3.2957564244667688, 0.028338068232996637 99 | -------------------------------------------------------------------------------- /lightning/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.progressbar import GlobalProgressBar 2 | from .saver import Saver 3 | -------------------------------------------------------------------------------- /lightning/callbacks/progressbar.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | 3 | import sys 4 | import torch 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.callbacks import Callback 7 | from pytorch_lightning.callbacks.progress import ProgressBarBase 8 | 9 | 10 | class GlobalProgressBar(Callback): 11 | """Global progress bar. 12 | TODO: add progress bar for training, validation and testing loop. 13 | """ 14 | 15 | def __init__(self, global_progress: bool = True, leave_global_progress: bool = True, process_position=0): 16 | super().__init__() 17 | 18 | self.global_progress = global_progress 19 | self.global_desc = "Steps: {steps}/{max_steps}" 20 | self.leave_global_progress = leave_global_progress 21 | self.global_pb = None 22 | self.process_position = process_position 23 | 24 | def on_train_start(self, trainer, pl_module): 25 | if pl_module.local_rank == 0: 26 | desc = self.global_desc.format(steps=pl_module.global_step + 1, max_steps=trainer.max_steps) 27 | 28 | self.global_pb = tqdm( 29 | desc=desc, 30 | dynamic_ncols=True, 31 | total=trainer.max_steps, 32 | initial=pl_module.global_step, 33 | leave=self.leave_global_progress, 34 | disable=not self.global_progress, 35 | position=self.process_position, 36 | file=sys.stdout, 37 | ) 38 | 39 | def on_train_end(self, trainer, pl_module): 40 | if pl_module.local_rank == 0: 41 | self.global_pb.close() 42 | self.global_pb = None 43 | 44 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 45 | if pl_module.local_rank == 0: 46 | 47 | # Set description 48 | desc = self.global_desc.format(steps=pl_module.global_step + 1, max_steps=trainer.max_steps) 49 | self.global_pb.set_description(desc) 50 | 51 | # Update progress 52 | if (pl_module.global_step+1) % trainer.accumulate_grad_batches == 0: 53 | self.global_pb.update(1) 54 | 55 | -------------------------------------------------------------------------------- /lightning/callbacks/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import matplotlib 4 | matplotlib.use("Agg") 5 | from matplotlib import pyplot as plt 6 | from scipy.io import wavfile 7 | 8 | from utils.tools import expand, plot_mel 9 | 10 | 11 | def synth_one_sample_with_target(targets, predictions, vocoder, preprocess_config): 12 | """Synthesize the first sample of the batch given target pitch/duration/energy.""" 13 | basename = targets[0][0] 14 | src_len = predictions[8][0].item() 15 | mel_len = predictions[9][0].item() 16 | mel_target = targets[6][0, :mel_len].detach().transpose(0, 1) 17 | duration = targets[11][0, :src_len].detach().cpu().numpy() 18 | pitch = targets[9][0, :src_len].detach().cpu().numpy() 19 | energy = targets[10][0, :src_len].detach().cpu().numpy() 20 | mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1) 21 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 22 | pitch = expand(pitch, duration) 23 | else: 24 | pitch = targets[9][0, :mel_len].detach().cpu().numpy() 25 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 26 | energy = expand(energy, duration) 27 | else: 28 | energy = targets[10][0, :mel_len].detach().cpu().numpy() 29 | 30 | with open( 31 | 
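        # stats.json is produced during preprocessing and holds the pitch/energy statistics passed to plot_mel below.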
os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 32 | ) as f: 33 | stats = json.load(f) 34 | stats = stats["pitch"] + stats["energy"][:2] 35 | 36 | fig = plot_mel( 37 | [ 38 | (mel_prediction.cpu().numpy(), pitch, energy), 39 | (mel_target.cpu().numpy(), pitch, energy), 40 | ], 41 | stats, 42 | ["Synthetized Spectrogram", "Ground-Truth Spectrogram"], 43 | ) 44 | 45 | if vocoder.mel2wav is not None: 46 | max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"] 47 | 48 | wav_reconstruction = vocoder.infer(mel_target.unsqueeze(0), max_wav_value)[0] 49 | wav_prediction = vocoder.infer(mel_prediction.unsqueeze(0), max_wav_value)[0] 50 | else: 51 | wav_reconstruction = wav_prediction = None 52 | 53 | return fig, wav_reconstruction, wav_prediction, basename 54 | 55 | def recon_samples(targets, predictions, vocoder, preprocess_config, figure_dir, audio_dir): 56 | """Reconstruct all samples of the batch.""" 57 | for i in range(len(predictions[0])): 58 | basename = targets[0][i] 59 | src_len = predictions[8][i].item() 60 | mel_len = predictions[9][i].item() 61 | mel_target = targets[6][i, :mel_len].detach().transpose(0, 1) 62 | duration = targets[11][i, :src_len].detach().cpu().numpy() 63 | pitch = targets[9][i, :src_len].detach().cpu().numpy() 64 | energy = targets[10][i, :src_len].detach().cpu().numpy() 65 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 66 | pitch = expand(pitch, duration) 67 | else: 68 | pitch = targets[9][i, :mel_len].detach().cpu().numpy() 69 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 70 | energy = expand(energy, duration) 71 | else: 72 | energy = targets[10][i, :mel_len].detach().cpu().numpy() 73 | 74 | with open( 75 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 76 | ) as f: 77 | stats = json.load(f) 78 | stats = stats["pitch"] + stats["energy"][:2] 79 | 80 | fig = plot_mel( 81 | [ 82 | (mel_target.cpu().numpy(), pitch, energy), 83 | ], 84 | stats, 85 | ["Ground-Truth Spectrogram"], 86 | ) 87 | plt.savefig(os.path.join(figure_dir, f"{basename}.target.png")) 88 | plt.close() 89 | 90 | mel_targets = targets[6].transpose(1, 2) 91 | lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"] 92 | max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"] 93 | wav_targets = vocoder.infer(mel_targets, max_wav_value, lengths=lengths) 94 | 95 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 96 | for wav, basename in zip(wav_targets, targets[0]): 97 | wavfile.write(os.path.join(audio_dir, f"{basename}.recon.wav"), sampling_rate, wav) 98 | 99 | def synth_samples(targets, predictions, vocoder, preprocess_config, figure_dir, audio_dir, name): 100 | """Synthesize the first sample of the batch.""" 101 | for i in range(len(predictions[0])): 102 | basename = targets[0][i] 103 | src_len = predictions[8][i].item() 104 | mel_len = predictions[9][i].item() 105 | mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) 106 | duration = predictions[5][i, :src_len].detach().cpu().numpy() 107 | pitch = predictions[2][i, :src_len].detach().cpu().numpy() 108 | energy = predictions[3][i, :src_len].detach().cpu().numpy() 109 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 110 | pitch = expand(pitch, duration) 111 | else: 112 | pitch = targets[9][i, :mel_len].detach().cpu().numpy() 113 | if preprocess_config["preprocessing"]["energy"]["feature"] 
== "phoneme_level": 114 | energy = expand(energy, duration) 115 | else: 116 | energy = targets[10][i, :mel_len].detach().cpu().numpy() 117 | 118 | with open( 119 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 120 | ) as f: 121 | stats = json.load(f) 122 | stats = stats["pitch"] + stats["energy"][:2] 123 | 124 | fig = plot_mel( 125 | [ 126 | (mel_prediction.cpu().numpy(), pitch, energy), 127 | ], 128 | stats, 129 | ["Synthetized Spectrogram"], 130 | ) 131 | plt.savefig(os.path.join(figure_dir, f"{basename}.{name}.synth.png")) 132 | plt.close() 133 | 134 | mel_predictions = predictions[1].transpose(1, 2) 135 | lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"] 136 | max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"] 137 | wav_predictions = vocoder.infer(mel_predictions, max_wav_value, lengths=lengths) 138 | 139 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 140 | for wav, basename in zip(wav_predictions, targets[0]): 141 | wavfile.write(os.path.join(audio_dir, f"{basename}.{name}.synth.wav"), sampling_rate, wav) 142 | 143 | -------------------------------------------------------------------------------- /lightning/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_datamodule import BaseDataModule 2 | from .baseline_datamodule import BaselineDataModule 3 | from .meta_datamodule import MetaDataModule 4 | 5 | 6 | DATA_MODULE = { 7 | "base": BaseDataModule, 8 | "meta": MetaDataModule, 9 | "imaml": MetaDataModule, 10 | "baseline": BaselineDataModule, 11 | } 12 | 13 | def get_datamodule(algorithm): 14 | return DATA_MODULE[algorithm] 15 | -------------------------------------------------------------------------------- /lightning/datamodules/base_datamodule.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import ConcatDataset 2 | import pytorch_lightning as pl 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from dataset import MonolingualTTSDataset as Dataset 7 | from lightning.collate import get_single_collate 8 | 9 | 10 | class BaseDataModule(pl.LightningDataModule): 11 | def __init__(self, preprocess_configs, train_config, algorithm_config, log_dir, result_dir): 12 | super().__init__() 13 | self.preprocess_configs = preprocess_configs 14 | self.train_config = train_config 15 | self.algorithm_config = algorithm_config 16 | 17 | self.log_dir = log_dir 18 | self.result_dir = result_dir 19 | 20 | 21 | def setup(self, stage=None): 22 | spk_refer_wav = (self.algorithm_config["adapt"]["speaker_emb"] 23 | in ["dvec", "encoder", "scratch_encoder"]) 24 | 25 | if stage in (None, 'fit', 'validate'): 26 | self.train_datasets = [ 27 | Dataset( 28 | f"{preprocess_config['subsets']['train']}.txt", 29 | preprocess_config, self.train_config, sort=True, drop_last=True, spk_refer_wav=spk_refer_wav 30 | ) for preprocess_config in self.preprocess_configs 31 | ] 32 | self.val_datasets = [ 33 | Dataset( 34 | f"{preprocess_config['subsets']['val']}.txt", 35 | preprocess_config, self.train_config, sort=False, drop_last=False, spk_refer_wav=spk_refer_wav 36 | ) for preprocess_config in self.preprocess_configs 37 | ] 38 | 39 | if stage in (None, 'test', 'predict'): 40 | self.test_datasets = [ 41 | Dataset( 42 | f"{preprocess_config['subsets']['test']}.txt", 43 | preprocess_config, self.train_config, sort=False, drop_last=False, spk_refer_wav=spk_refer_wav 44 | 
) for preprocess_config in self.preprocess_configs 45 | ] 46 | 47 | 48 | def train_dataloader(self): 49 | """Training dataloader, not modified for multiple dataloaders.""" 50 | batch_size = self.train_config["optimizer"]["batch_size"] 51 | self.train_dataset = ConcatDataset(self.train_datasets) 52 | self.train_loader = DataLoader( 53 | self.train_dataset, 54 | batch_size=batch_size//torch.cuda.device_count(), 55 | shuffle=True, 56 | drop_last=True, 57 | num_workers=4, 58 | collate_fn=get_single_collate(False), 59 | ) 60 | return self.train_loader 61 | 62 | def val_dataloader(self): 63 | """Validation dataloader, not modified for multiple dataloaders.""" 64 | batch_size = self.train_config["optimizer"]["batch_size"] 65 | self.val_dataset = ConcatDataset(self.val_datasets) 66 | self.val_loader = DataLoader( 67 | self.val_dataset, 68 | batch_size=batch_size//torch.cuda.device_count(), 69 | shuffle=False, 70 | drop_last=False, 71 | num_workers=4, 72 | collate_fn=get_single_collate(False), 73 | ) 74 | return self.val_loader 75 | -------------------------------------------------------------------------------- /lightning/datamodules/baseline_datamodule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from torch.utils.data import DataLoader, ConcatDataset 5 | 6 | from lightning.collate import get_single_collate 7 | from lightning.utils import seed_all, EpisodicInfiniteWrapper 8 | 9 | from .base_datamodule import BaseDataModule 10 | from .utils import few_shot_task_dataset, prefetch_tasks 11 | 12 | 13 | class BaselineDataModule(BaseDataModule): 14 | def __init__(self, preprocess_config, train_config, algorithm_config, log_dir, result_dir): 15 | super().__init__(preprocess_config, train_config, algorithm_config, log_dir, result_dir) 16 | self.meta_type = self.algorithm_config["adapt"]["type"] 17 | 18 | self.train_ways = self.algorithm_config["adapt"]["train"]["ways"] 19 | self.train_shots = self.algorithm_config["adapt"]["train"]["shots"] 20 | self.train_queries = self.algorithm_config["adapt"]["train"]["queries"] 21 | 22 | self.test_ways = self.algorithm_config["adapt"]["test"]["ways"] 23 | self.test_shots = self.algorithm_config["adapt"]["test"]["shots"] 24 | self.test_queries = self.algorithm_config["adapt"]["test"]["queries"] 25 | 26 | self.meta_batch_size = self.algorithm_config["adapt"]["train"]["meta_batch_size"] 27 | self.val_step = self.train_config["step"]["val_step"] 28 | 29 | 30 | def setup(self, stage=None): 31 | super().setup(stage) 32 | # pl.seed_everything(43, True) 33 | 34 | if stage in (None, 'fit', 'validate'): 35 | self._train_setup() 36 | self._validation_setup() 37 | 38 | if stage in (None, 'test', 'predict'): 39 | self._test_setup() 40 | 41 | 42 | def _train_setup(self): 43 | self.train_dataset = ConcatDataset(self.train_datasets) 44 | if not isinstance(self.train_dataset, EpisodicInfiniteWrapper): 45 | self.batch_size = self.train_ways * (self.train_shots + self.train_queries) * self.meta_batch_size 46 | self.train_dataset = EpisodicInfiniteWrapper(self.train_dataset, self.val_step*self.batch_size) 47 | 48 | 49 | def _validation_setup(self): 50 | self.val_dataset = ConcatDataset(self.val_datasets) 51 | self.val_task_dataset = few_shot_task_dataset( 52 | self.val_dataset, self.test_ways, self.test_shots, self.test_queries, 53 | n_tasks_per_label=8, type=self.meta_type 54 | ) 55 | with seed_all(43): 56 | self.val_SQids2Tid = prefetch_tasks(self.val_task_dataset, 'val', self.log_dir) 
57 | 58 | 59 | def _test_setup(self): 60 | self.test_dataset = ConcatDataset(self.test_datasets) 61 | self.test_task_dataset = few_shot_task_dataset( 62 | self.test_dataset, self.test_ways, self.test_shots, self.test_queries, 63 | n_tasks_per_label=16, type=self.meta_type 64 | ) 65 | with seed_all(43): 66 | self.test_SQids2Tid = prefetch_tasks(self.test_task_dataset, 'test', self.result_dir) 67 | 68 | 69 | def train_dataloader(self): 70 | """Training dataloader""" 71 | self.train_loader = DataLoader( 72 | self.train_dataset, 73 | batch_size=self.batch_size//torch.cuda.device_count(), 74 | shuffle=True, 75 | drop_last=True, 76 | num_workers=4, 77 | collate_fn=get_single_collate(False), 78 | ) 79 | return self.train_loader 80 | 81 | 82 | def val_dataloader(self): 83 | """Validation dataloader""" 84 | self.val_loader = DataLoader( 85 | self.val_task_dataset, 86 | batch_size=1, 87 | shuffle=False, 88 | num_workers=0, 89 | collate_fn=lambda batch: batch, 90 | ) 91 | return self.val_loader 92 | 93 | 94 | def test_dataloader(self): 95 | """Test dataloader""" 96 | self.test_loader = DataLoader( 97 | self.test_task_dataset, 98 | batch_size=1, 99 | shuffle=False, 100 | num_workers=0, 101 | collate_fn=lambda batch: batch, 102 | ) 103 | return self.test_loader 104 | -------------------------------------------------------------------------------- /lightning/datamodules/define.py: -------------------------------------------------------------------------------- 1 | # NOTE: this should be move to other place 2 | 3 | from text.symbols import symbols 4 | 5 | LANG_ID2SYMBOLS = { 6 | 0: symbols, 7 | 1: symbols 8 | } 9 | -------------------------------------------------------------------------------- /lightning/datamodules/meta_datamodule.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | from torch.utils.data import DataLoader, ConcatDataset 4 | 5 | from .baseline_datamodule import BaselineDataModule 6 | from .utils import few_shot_task_dataset 7 | 8 | 9 | class MetaDataModule(BaselineDataModule): 10 | def __init__(self, preprocess_config, train_config, algorithm_config, log_dir, result_dir): 11 | super().__init__(preprocess_config, train_config, algorithm_config, log_dir, result_dir) 12 | 13 | 14 | def setup(self, stage=None): 15 | super(BaselineDataModule, self).setup(stage) 16 | # pl.seed_everything(43, True) 17 | 18 | if stage in (None, 'fit', 'validate'): 19 | self._train_setup() 20 | self._validation_setup() 21 | 22 | if stage in (None, 'test', 'predict'): 23 | self._test_setup() 24 | 25 | 26 | def _train_setup(self): 27 | epoch_length = self.meta_batch_size * self.val_step 28 | self.train_dataset = ConcatDataset(self.train_datasets) 29 | 30 | self.train_task_dataset = few_shot_task_dataset( 31 | self.train_dataset, self.train_ways, self.train_shots, self.train_queries, 32 | n_tasks_per_label=-1, epoch_length=epoch_length, type=self.meta_type 33 | ) 34 | 35 | 36 | def train_dataloader(self): 37 | """Training dataloader""" 38 | self.train_loader = DataLoader( 39 | self.train_task_dataset, 40 | batch_size=1, 41 | shuffle=True, 42 | num_workers=4, 43 | collate_fn=lambda batch: batch, 44 | ) 45 | return self.train_loader 46 | -------------------------------------------------------------------------------- /lightning/datamodules/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import ConcatDataset 5 | from learn2learn.data import 
MetaDataset, TaskDataset 6 | from learn2learn.data.transforms import FusedNWaysKShots, LoadData 7 | from learn2learn.data.task_dataset import DataDescription 8 | from learn2learn.utils.lightning import EpisodicBatcher 9 | 10 | from lightning.collate import SpeakerTaskCollate, LanguageTaskCollate 11 | from .define import LANG_ID2SYMBOLS 12 | 13 | 14 | def few_shot_task_dataset(_dataset, ways, shots, queries, n_tasks_per_label=-1, epoch_length=-1, type="spk"): 15 | """ 16 | _dataset is already a `ConcatDataset` 17 | """ 18 | if type == "spk": 19 | id2lb = get_multispeaker_id2lb(_dataset.datasets) 20 | _collate = SpeakerTaskCollate() 21 | else: 22 | id2lb = get_multilingual_id2lb(_dataset.datasets) 23 | _collate = LanguageTaskCollate({ 24 | "lang_id2symbols": LANG_ID2SYMBOLS, 25 | "representation_dim": 1024, 26 | }) 27 | 28 | meta_dataset = MetaDataset(_dataset, indices_to_labels=id2lb) 29 | 30 | if n_tasks_per_label > 0: 31 | # For val/test, constant number of tasks per label 32 | tasks = [] 33 | for label, indices in meta_dataset.labels_to_indices.items(): 34 | if len(indices) >= shots+queries: 35 | # 1-way-K-shots-Q-queries transforms per label 36 | transforms = [ 37 | FusedNWaysKShots(meta_dataset, n=ways, k=shots+queries, 38 | replacement=False, filter_labels=[label]), 39 | LoadData(meta_dataset), 40 | ] 41 | # 1-way-K-shots-Q-queries task dataset 42 | _tasks = TaskDataset( 43 | meta_dataset, task_transforms=transforms, num_tasks=n_tasks_per_label, 44 | task_collate=_collate.get_meta_collate(shots, queries), 45 | ) 46 | tasks.append(_tasks) 47 | tasks = ConcatDataset(tasks) 48 | 49 | else: 50 | # For train, dynamic tasks 51 | # 1-way-K-shots-Q-queries transforms 52 | transforms = [ 53 | FusedNWaysKShots(meta_dataset, n=ways, k=shots+queries, replacement=True), 54 | LoadData(meta_dataset), 55 | ] 56 | # 1-way-K-shots-Q-queries task dataset 57 | tasks = TaskDataset( 58 | meta_dataset, task_transforms=transforms, 59 | task_collate=_collate.get_meta_collate(shots, queries), 60 | ) 61 | if epoch_length > 0: 62 | # Epochify task dataset, for periodic validation 63 | tasks = EpisodicBatcher(tasks, epoch_length=epoch_length).train_dataloader() 64 | 65 | return tasks 66 | 67 | 68 | def load_descriptions(tasks, filename): 69 | with open(filename, 'r') as f: 70 | loaded_descriptions = json.load(f) 71 | assert len(tasks.datasets) == len(loaded_descriptions), "TaskDataset count mismatch" 72 | 73 | for i, _tasks in enumerate(tasks.datasets): 74 | descriptions = loaded_descriptions[i] 75 | assert len(descriptions) == _tasks.num_tasks, "num_tasks mismatch" 76 | for j in descriptions: 77 | data_descriptions = [DataDescription(index) for index in descriptions[j]] 78 | task_descriptions = _tasks.task_transforms[-1](data_descriptions) 79 | _tasks.sampled_descriptions[int(j)] = task_descriptions 80 | 81 | 82 | def write_descriptions(tasks, filename): 83 | descriptions = [] 84 | for ds in tasks.datasets: 85 | data_descriptions = {} 86 | for i in ds.sampled_descriptions: 87 | data_descriptions[i] = [desc.index for desc in ds.sampled_descriptions[i]] 88 | descriptions.append(data_descriptions) 89 | 90 | with open(filename, 'w') as f: 91 | json.dump(descriptions, f, indent=4) 92 | 93 | 94 | def load_SQids2Tid(SQids_filename, tag): 95 | with open(SQids_filename, 'r') as f: 96 | SQids = json.load(f) 97 | SQids2Tid = {} 98 | for i, SQids_dict in enumerate(SQids): 99 | sup_ids, qry_ids = SQids_dict['sup_id'], SQids_dict['qry_id'] 100 | SQids2Tid[f"{'-'.join(sup_ids)}.{'-'.join(qry_ids)}"] = f"{tag}_{i:03d}" 101 | 
return SQids, SQids2Tid 102 | 103 | 104 | def get_SQids2Tid(tasks, tag): 105 | SQids = [] 106 | SQids2Tid = {} 107 | for i, task in enumerate(tasks): 108 | sup_ids, qry_ids = task[0][0][0], task[1][0][0] 109 | SQids.append({'sup_id': sup_ids, 'qry_id': qry_ids}) 110 | SQids2Tid[f"{'-'.join(sup_ids)}.{'-'.join(qry_ids)}"] = f"{tag}_{i:03d}" 111 | return SQids, SQids2Tid 112 | 113 | 114 | def prefetch_tasks(tasks, tag='val', log_dir=''): 115 | if (os.path.exists(os.path.join(log_dir, f'{tag}_descriptions.json')) 116 | and os.path.exists(os.path.join(log_dir, f'{tag}_SQids.json'))): 117 | # Recover descriptions 118 | load_descriptions(tasks, os.path.join(log_dir, f'{tag}_descriptions.json')) 119 | SQids, SQids2Tid = load_SQids2Tid(os.path.join(log_dir, f'{tag}_SQids.json'), tag) 120 | 121 | else: 122 | os.makedirs(log_dir, exist_ok=True) 123 | 124 | # Run through tasks to get descriptions 125 | SQids, SQids2Tid = get_SQids2Tid(tasks, tag) 126 | with open(os.path.join(log_dir, f"{tag}_SQids.json"), 'w') as f: 127 | json.dump(SQids, f, indent=4) 128 | write_descriptions(tasks, os.path.join(log_dir, f"{tag}_descriptions.json")) 129 | 130 | return SQids2Tid 131 | 132 | 133 | def get_multispeaker_id2lb(datasets): 134 | id2lb = {} 135 | total = 0 136 | for dataset in datasets: 137 | l = len(dataset) 138 | id2lb.update({k: f"corpus_{dataset.lang_id}-spk_{dataset.speaker[k - total]}" 139 | for k in range(total, total + l)}) 140 | total += l 141 | 142 | return id2lb 143 | 144 | 145 | def get_multilingual_id2lb(datasets): 146 | id2lb = {} 147 | total = 0 148 | for dataset in datasets: 149 | l = len(dataset) 150 | id2lb.update({k: dataset.lang_id for k in range(total, total + l)}) 151 | total += l 152 | 153 | return id2lb 154 | -------------------------------------------------------------------------------- /lightning/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .fastspeech2 import FastSpeech2 2 | from .loss import FastSpeech2Loss 3 | from .optimizer import ScheduledOptim -------------------------------------------------------------------------------- /lightning/model/fastspeech2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | 9 | from transformer import Encoder, Decoder, PostNet 10 | from .modules import VarianceAdaptor 11 | from .speaker_encoder import SpeakerEncoder 12 | from .phoneme_embedding import PhonemeEmbedding 13 | from utils.tools import get_mask_from_lengths 14 | 15 | 16 | class FastSpeech2(pl.LightningModule): 17 | """ FastSpeech2 """ 18 | 19 | def __init__(self, preprocess_config, model_config, algorithm_config): 20 | super(FastSpeech2, self).__init__() 21 | self.model_config = model_config 22 | 23 | self.encoder = Encoder(model_config) 24 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config) 25 | self.decoder = Decoder(model_config) 26 | self.mel_linear = nn.Linear( 27 | model_config["transformer"]["decoder_hidden"], 28 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"], 29 | ) 30 | self.postnet = PostNet() 31 | 32 | # If not using multi-speaker, would return None 33 | self.speaker_emb = SpeakerEncoder(preprocess_config, model_config, algorithm_config) 34 | 35 | if algorithm_config["adapt"]["type"] == "lang": 36 | self.phn_emb_generator = PhonemeEmbedding(model_config, algorithm_config) 37 | 
print("PhonemeEmbedding", self.phn_emb_generator) 38 | 39 | 40 | def forward( 41 | self, 42 | speaker_args, 43 | texts, 44 | src_lens, 45 | max_src_len, 46 | mels=None, 47 | mel_lens=None, 48 | max_mel_len=None, 49 | p_targets=None, 50 | e_targets=None, 51 | d_targets=None, 52 | p_control=1.0, 53 | e_control=1.0, 54 | d_control=1.0, 55 | ): 56 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 57 | mel_masks = ( 58 | get_mask_from_lengths(mel_lens, max_mel_len) 59 | if mel_lens is not None 60 | else None 61 | ) 62 | 63 | output = self.encoder(texts, src_masks) 64 | 65 | if self.speaker_emb is not None: 66 | output = output + self.speaker_emb(speaker_args).unsqueeze(1).expand( 67 | -1, max_src_len, -1 68 | ) 69 | 70 | ( 71 | output, 72 | p_predictions, 73 | e_predictions, 74 | log_d_predictions, 75 | d_rounded, 76 | mel_lens, 77 | mel_masks, 78 | ) = self.variance_adaptor( 79 | output, 80 | src_masks, 81 | mel_masks, 82 | max_mel_len, 83 | p_targets, 84 | e_targets, 85 | d_targets, 86 | p_control, 87 | e_control, 88 | d_control, 89 | ) 90 | 91 | if self.speaker_emb is not None: 92 | output = output + self.speaker_emb(speaker_args).unsqueeze(1).expand( 93 | -1, max(mel_lens), -1 94 | ) 95 | 96 | output, mel_masks = self.decoder(output, mel_masks) 97 | output = self.mel_linear(output) 98 | 99 | postnet_output = self.postnet(output) + output 100 | 101 | return ( 102 | output, 103 | postnet_output, 104 | p_predictions, 105 | e_predictions, 106 | log_d_predictions, 107 | d_rounded, 108 | src_masks, 109 | mel_masks, 110 | src_lens, 111 | mel_lens, 112 | ) 113 | -------------------------------------------------------------------------------- /lightning/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FastSpeech2Loss(nn.Module): 6 | """ FastSpeech2 Loss """ 7 | 8 | def __init__(self, preprocess_config, model_config): 9 | super(FastSpeech2Loss, self).__init__() 10 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 11 | "feature" 12 | ] 13 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 14 | "feature" 15 | ] 16 | self.mse_loss = nn.MSELoss() 17 | self.mae_loss = nn.L1Loss() 18 | 19 | def forward(self, inputs, predictions): 20 | ( 21 | mel_targets, 22 | _, 23 | _, 24 | pitch_targets, 25 | energy_targets, 26 | duration_targets, 27 | ) = inputs[6:] 28 | ( 29 | mel_predictions, 30 | postnet_mel_predictions, 31 | pitch_predictions, 32 | energy_predictions, 33 | log_duration_predictions, 34 | _, 35 | src_masks, 36 | mel_masks, 37 | _, 38 | _, 39 | ) = predictions 40 | src_masks = ~src_masks 41 | mel_masks = ~mel_masks 42 | log_duration_targets = torch.log(duration_targets.float() + 1) 43 | mel_targets = mel_targets[:, : mel_masks.shape[1], :] 44 | mel_masks = mel_masks[:, :mel_masks.shape[1]] 45 | 46 | log_duration_targets.requires_grad = False 47 | pitch_targets.requires_grad = False 48 | energy_targets.requires_grad = False 49 | mel_targets.requires_grad = False 50 | 51 | if self.pitch_feature_level == "phoneme_level": 52 | pitch_predictions = pitch_predictions.masked_select(src_masks) 53 | pitch_targets = pitch_targets.masked_select(src_masks) 54 | elif self.pitch_feature_level == "frame_level": 55 | pitch_predictions = pitch_predictions.masked_select(mel_masks) 56 | pitch_targets = pitch_targets.masked_select(mel_masks) 57 | 58 | if self.energy_feature_level == "phoneme_level": 59 | energy_predictions = 
energy_predictions.masked_select(src_masks) 60 | energy_targets = energy_targets.masked_select(src_masks) 61 | if self.energy_feature_level == "frame_level": 62 | energy_predictions = energy_predictions.masked_select(mel_masks) 63 | energy_targets = energy_targets.masked_select(mel_masks) 64 | 65 | log_duration_predictions = log_duration_predictions.masked_select(src_masks) 66 | log_duration_targets = log_duration_targets.masked_select(src_masks) 67 | 68 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1)) 69 | postnet_mel_predictions = postnet_mel_predictions.masked_select( 70 | mel_masks.unsqueeze(-1) 71 | ) 72 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1)) 73 | 74 | mel_loss = self.mae_loss(mel_predictions, mel_targets) 75 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets) 76 | 77 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets) 78 | energy_loss = self.mse_loss(energy_predictions, energy_targets) 79 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets) 80 | 81 | total_loss = ( 82 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss 83 | ) 84 | 85 | return ( 86 | total_loss, 87 | mel_loss, 88 | postnet_mel_loss, 89 | pitch_loss, 90 | energy_loss, 91 | duration_loss, 92 | ) 93 | -------------------------------------------------------------------------------- /lightning/model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim: 6 | """ A simple wrapper class for learning rate scheduling """ 7 | 8 | def __init__(self, model, train_config, model_config, current_step): 9 | 10 | self._optimizer = torch.optim.Adam( 11 | model.parameters(), 12 | betas=train_config["optimizer"]["betas"], 13 | eps=train_config["optimizer"]["eps"], 14 | weight_decay=train_config["optimizer"]["weight_decay"], 15 | ) 16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"] 17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"] 18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"] 19 | self.current_step = current_step 20 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5) 21 | 22 | def step_and_update_lr(self): 23 | self._update_learning_rate() 24 | self._optimizer.step() 25 | 26 | def zero_grad(self): 27 | # print(self.init_lr) 28 | self._optimizer.zero_grad() 29 | 30 | def load_state_dict(self, path): 31 | self._optimizer.load_state_dict(path) 32 | 33 | def _get_lr_scale(self): 34 | lr = np.min( 35 | [ 36 | np.power(self.current_step, -0.5), 37 | np.power(self.n_warmup_steps, -1.5) * self.current_step, 38 | ] 39 | ) 40 | for s in self.anneal_steps: 41 | if self.current_step > s: 42 | lr = lr * self.anneal_rate 43 | return lr 44 | 45 | def _update_learning_rate(self): 46 | """ Learning rate scheduling per step """ 47 | self.current_step += 1 48 | lr = self.init_lr * self._get_lr_scale() 49 | 50 | for param_group in self._optimizer.param_groups: 51 | param_group["lr"] = lr 52 | -------------------------------------------------------------------------------- /lightning/model/phoneme_embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import numpy as np 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | import pytorch_lightning as pl 11 | 12 | import transformer.Constants as 
Constants 13 | from transformer import Encoder, Decoder, PostNet 14 | from transformer.Modules import ScaledDotProductAttention 15 | from text.symbols import symbols 16 | 17 | 18 | class PhonemeEmbedding(pl.LightningModule): 19 | # NOTE: 20 | # Tested: 21 | # - hard att 22 | # TODO: 23 | # - soft att 24 | # - share bank 25 | # - embedding mode 26 | 27 | def __init__(self, model_config, algorithm_config): 28 | """ 29 | All the data members would be moved to cuda together with the whole model 30 | by pytorch_lightning default settings. 31 | Directly replacing `model.encoder.src_word_emb` with this class would do 32 | no harm to the `model.encoder` forwarding. 33 | """ 34 | super().__init__() 35 | self.emb_type = algorithm_config["adapt"]["phoneme_emb"]["type"] 36 | 37 | n_src_vocab = len(symbols) + 1 38 | d_word_vec = model_config["transformer"]["encoder_hidden"] 39 | self.d_word_vec = d_word_vec 40 | 41 | if self.emb_type == "embedding": 42 | # TODO 43 | pass 44 | 45 | elif self.emb_type == "codebook": 46 | self.codebook_config = algorithm_config["adapt"]["phoneme_emb"] 47 | codebook_size = self.codebook_config["size"] 48 | 49 | self.emb_banks = nn.Parameter(torch.randn(codebook_size, d_word_vec)) 50 | 51 | att_config = self.codebook_config["attention"] 52 | d_feat = self.codebook_config["representation_dim"] 53 | 54 | if att_config["type"] == "hard": 55 | # One-hot similarity 56 | 57 | # feats <-> att_banks -> token_id -> emb_banks 58 | self.att_banks = nn.Parameter(torch.randn(codebook_size, d_feat)) 59 | # TODO: init from SSL-feature centroids 60 | 61 | elif att_config["type"] == "soft": 62 | # Attention layer 63 | # key: att_banks 64 | # value: emb_banks 65 | # query: refs 66 | 67 | # att(feats, att_banks) -> token_id weights -> emb_banks 68 | self.att_banks = nn.Parameter(torch.randn(codebook_size, d_word_vec)) 69 | if att_config["share"]: 70 | # TPU shared weights are copied independently 71 | # on the XLA device and this line won't have any effect. 72 | # However, it works fine for CPU and GPU. 73 | self.att_banks.weight = self.emb_banks.weight 74 | 75 | self.w_qs = nn.Linear(d_feat, d_word_vec) 76 | self.w_ks = nn.Linear(d_word_vec, d_word_vec) 77 | self.attention = ScaledDotProductAttention( 78 | temperature=np.power(d_word_vec, 0.5) # sqrt of the query/key dimension 79 | ) 80 | 81 | 82 | def on_post_move_to_device(self): 83 | if (hasattr(self, 'codebook_config') 84 | and self.codebook_config["attention"]["type"] == "soft" 85 | and self.codebook_config["attention"]["share"]): 86 | # Weights shared after the model has been moved to TPU Device 87 | self.att_banks.weight = self.emb_banks.weight 88 | 89 | 90 | def get_new_embedding(self, ref): 91 | """ Compute binary quantize matrix from reference representations. 92 | Should run before inner_loop update. 93 | 94 | Args: 95 | ref: Reference representations with size (vocab_size, representation_dim), 96 | where vocab_size <= n_src_vocab. 97 | Assert the tensor is already moved to cuda by pytorch_lightning.
98 | """ 99 | try: 100 | assert ref.device == self.device 101 | except: 102 | ref = ref.to(device=self.device) 103 | 104 | if self.codebook_config["attention"]["type"] == "hard": 105 | ref_norm = ref.norm(dim=1, keepdim=True) 106 | ref_mask, _ = torch.nonzero(ref_norm, as_tuple=True) 107 | normed_ref = ref[ref_mask] / ref_norm[ref_mask] 108 | 109 | bank_norms = self.att_banks.norm(dim=1, keepdim=True) 110 | normed_banks = self.att_banks / bank_norms 111 | 112 | similarity = normed_ref @ normed_banks.T 113 | 114 | with torch.no_grad(): 115 | weighting_matrix = torch.zeros( 116 | ref.shape[0], self.att_banks.shape[0], 117 | device=self.device 118 | ) 119 | # NOTE: I don't know why I can't directly assign = 1 120 | weighting_matrix[ref_mask, similarity.argmax(1)] = torch.ones_like(ref_mask).float() 121 | # padding_idx 122 | weighting_matrix[Constants.PAD].fill_(0) 123 | weighted_embedding = weighting_matrix @ self.emb_banks 124 | return weighted_embedding 125 | 126 | elif self.codebook_config["attention"]["type"] == "soft": 127 | """ Soft attention weight. Better not share_att_banks. 128 | key: self.att_banks 129 | value: self.emb_banks 130 | query: ref 131 | """ 132 | q = self.w_qs(ref).view(1, -1, 1, self.d_word_vec) 133 | k = self.w_ks(self.att_banks).view(1, -1, 1, self.d_word_vec) 134 | q = q.permute(2, 0, 1, 3).contiguous().view(1, -1, d_word_vec) # 1 x vocab_size x dk 135 | k = k.permute(2, 0, 1, 3).contiguous().view(1, -1, d_word_vec) # 1 x codebook_size x dk 136 | v = self.emb_banks.unsqueeze(0) 137 | weighted_embedding = self.attention(q, k, v) 138 | with torch.no_grad(): 139 | weighted_embedding[Constants.PAD].fill_(0) 140 | return weighted_embedding 141 | 142 | -------------------------------------------------------------------------------- /lightning/model/speaker_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pytorch_lightning as pl 5 | 6 | from torch import nn 7 | from resemblyzer import VoiceEncoder 8 | 9 | 10 | mel_n_channels = 40 11 | model_hidden_size = 256 12 | model_embedding_size = 256 13 | model_num_layers = 3 14 | class GE2E(VoiceEncoder): 15 | """ VoiceEncoder from scratch """ 16 | 17 | def __init__(self, device=None): 18 | super(VoiceEncoder, self).__init__() 19 | 20 | # Define the network 21 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) 22 | self.linear = nn.Linear(model_hidden_size, model_embedding_size) 23 | self.relu = nn.ReLU() 24 | 25 | # Get the target device 26 | if device is None: 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | elif isinstance(device, str): 29 | device = torch.device(device) 30 | self.device = device 31 | 32 | 33 | class SpeakerEncoder(pl.LightningModule): 34 | """ Could be a NN encoder or an embedding. 
""" 35 | 36 | def __new__(cls, *args, **kwargs): 37 | # If not using multi-speaker, do not construct speaker encoder 38 | if len(args) == 3: 39 | preprocess_config, model_config, algorithm_config = args 40 | if not model_config["multi_speaker"]: 41 | return None 42 | return super().__new__(cls) 43 | 44 | def __init__(self, preprocess_config, model_config, algorithm_config): 45 | super().__init__() 46 | self.emb_type = algorithm_config["adapt"]["speaker_emb"] 47 | 48 | if self.emb_type == "table": 49 | speaker_file = os.path.join(preprocess_config["path"]["preprocessed_path"], "speakers.json") 50 | n_speaker = len(json.load(open(speaker_file, "r"))) 51 | self.model = nn.Embedding(n_speaker, model_config["transformer"]["encoder_hidden"]) 52 | elif self.emb_type == "shared": 53 | self.model = nn.Embedding(1, model_config["transformer"]["encoder_hidden"]) 54 | elif self.emb_type == "encoder": 55 | self.model = VoiceEncoder('cpu') 56 | elif self.emb_type == "dvec": 57 | self.model = VoiceEncoder('cpu') 58 | self.freeze() 59 | elif self.emb_type == "scratch_encoder": 60 | self.model = GE2E('cpu') 61 | 62 | def forward(self, args): 63 | if self.emb_type == "table": 64 | speaker = args 65 | return self.model(speaker) 66 | 67 | elif self.emb_type == "shared": 68 | speaker = args 69 | return self.model(torch.zeros_like(speaker)) 70 | 71 | elif self.emb_type == "encoder" or self.emb_type == "dvec" or self.emb_type == "scratch_encoder": 72 | ref_mels, ref_slices = args 73 | partial_embeds = self.model(ref_mels) 74 | speaker_embeds = [partial_embeds[ref_slice].mean(dim=0) for ref_slice in ref_slices] 75 | speaker_emb = torch.stack([torch.nn.functional.normalize(spk_emb, dim=0) for spk_emb in speaker_embeds]) 76 | return speaker_emb 77 | -------------------------------------------------------------------------------- /lightning/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.optim.lr_scheduler import LambdaLR 4 | 5 | 6 | def get_optimizer(model, model_config, train_config): 7 | init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5) 8 | 9 | optimizer = torch.optim.Adam( 10 | model.parameters(), 11 | lr=init_lr, 12 | betas=train_config["optimizer"]["betas"], 13 | eps=train_config["optimizer"]["eps"], 14 | weight_decay=train_config["optimizer"]["weight_decay"], 15 | ) 16 | return optimizer 17 | -------------------------------------------------------------------------------- /lightning/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset, BatchSampler, DistributedSampler 4 | 5 | 6 | class GroupBatchSampler(BatchSampler): 7 | def __init__(self, sampler, group_size, batch_size, drop_last, sort): 8 | super().__init__(sampler, batch_size, drop_last) 9 | self.dataset = sampler.data_source 10 | self.group_size = group_size 11 | self.sort = sort 12 | 13 | def sort_batches(self, batches): 14 | gbidx = [idx for batch in batches for idx in batch] 15 | texts = [np.array(text_to_sequence(self.dataset.text[idx], self.dataset.cleaners)) for idx in gbidx] 16 | len_arr = np.array([text.shape[0] for text in texts]) 17 | idx_arr = np.argsort(-len_arr) 18 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist() 19 | assert len(idx_arr) == len(batches) 20 | return idx_arr 21 | 22 | def __iter__(self): 23 | batches = [] 24 | batch = [] 25 | for idx in self.sampler: 26 | batch.append(idx) 
27 | if len(batch) == self.batch_size: 28 | batches.append(batch) 29 | batch = [] 30 | if len(batches) == self.group_size: 31 | if self.sort: 32 | sorted_batches = self.sort_batches(batches) 33 | else: 34 | sorted_batches = batches 35 | for b in sorted_batches: 36 | yield b 37 | batches = [] 38 | if len(batches) > 0: 39 | if self.sort: 40 | sorted_batches = self.sort_batches(batches) 41 | else: 42 | sorted_batches = batches 43 | for b in sorted_batches: 44 | yield b 45 | if len(batch) > 0 and not self.drop_last: 46 | yield batch 47 | 48 | 49 | class DistributedBatchSampler(BatchSampler): 50 | """ `BatchSampler` wrapper that distributes across each batch multiple workers. 51 | 52 | Args: 53 | batch_sampler (torch.utils.data.sampler.BatchSampler) 54 | num_replicas (int, optional): Number of processes participating in distributed training. 55 | rank (int, optional): Rank of the current process within num_replicas. 56 | 57 | Example: 58 | >>> from torch.utils.data.sampler import BatchSampler 59 | >>> from torch.utils.data.sampler import SequentialSampler 60 | >>> sampler = SequentialSampler(list(range(12))) 61 | >>> batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False) 62 | >>> 63 | >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=0)) 64 | [[0, 2], [4, 6], [8, 10]] 65 | >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=1)) 66 | [[1, 3], [5, 7], [9, 11]] 67 | 68 | Reference: 69 | torchnlp.samplers.distributed_batch_sampler 70 | """ 71 | 72 | def __init__(self, batch_sampler, **kwargs): 73 | self.batch_sampler = batch_sampler 74 | self.kwargs = kwargs 75 | 76 | def __iter__(self): 77 | for batch in self.batch_sampler: 78 | yield list(DistributedSampler(batch, **self.kwargs)) 79 | 80 | def __len__(self): 81 | return len(self.batch_sampler) 82 | -------------------------------------------------------------------------------- /lightning/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.optim.lr_scheduler import LambdaLR 4 | 5 | 6 | def get_scheduler(optimizer, train_config): 7 | n_warmup_steps = train_config["optimizer"]["warm_up_step"] 8 | anneal_steps = train_config["optimizer"]["anneal_steps"] 9 | anneal_rate = train_config["optimizer"]["anneal_rate"] 10 | 11 | def lr_lambda(step): 12 | """ For lightning with LambdaLR scheduler """ 13 | current_step = step + 1 14 | lr = np.min( 15 | [ 16 | np.power(current_step, -0.5), 17 | np.power(n_warmup_steps, -1.5) * current_step, 18 | ] 19 | ) 20 | for s in anneal_steps: 21 | if current_step > s: 22 | lr = lr * anneal_rate 23 | return lr 24 | 25 | scheduler = torch.optim.lr_scheduler.LambdaLR( 26 | optimizer=optimizer, 27 | lr_lambda=lr_lambda, 28 | ) 29 | return scheduler 30 | -------------------------------------------------------------------------------- /lightning/systems/__init__.py: -------------------------------------------------------------------------------- 1 | from .baseline import BaselineSystem 2 | from .meta import MetaSystem 3 | from .imaml import IMAMLSystem 4 | 5 | SYSTEM = { 6 | "meta": MetaSystem, 7 | "imaml": IMAMLSystem, 8 | "baseline": BaselineSystem, 9 | } 10 | 11 | # def get_system(algorithm, preprocess_config, model_config, train_config, algorithm_config, log_dir, result_dir): 12 | # return SYSTEM[algorithm](preprocess_config, model_config, train_config, algorithm_config, log_dir, result_dir) 13 | def get_system(algorithm): 14 | return SYSTEM[algorithm] 15 | 
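16 | # Usage sketch (illustrative; mirrors `main.py` later in this repo): `algorithm_config["type"]` selects the system class, e.g. `system = get_system(algorithm_config["type"])`, then `model = system(preprocess_configs[0], model_config, train_config, algorithm_config, log_dir, result_dir)` for training, or `system.load_from_checkpoint(...)` for test/predict.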
-------------------------------------------------------------------------------- /lightning/systems/baseline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import json 5 | import torch 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import learn2learn as l2l 9 | 10 | from utils.tools import get_mask_from_lengths 11 | from lightning.systems.base_adaptor import BaseAdaptorSystem 12 | from lightning.utils import loss2dict 13 | 14 | 15 | class BaselineSystem(BaseAdaptorSystem): 16 | """A PyTorch Lightning module for ANIL for FastSpeech2. 17 | """ 18 | 19 | def __init__(self, *args, **kwargs): 20 | super().__init__(*args, **kwargs) 21 | 22 | def on_train_batch_start(self, batch, batch_idx, dataloader_idx): 23 | assert len(batch) == 12, "data with 12 elements" 24 | 25 | def training_step(self, batch, batch_idx): 26 | """ Normal forwarding. 27 | 28 | Function: 29 | common_step(): Defined in `lightning.systems.system.System` 30 | """ 31 | loss, output = self.common_step(batch, batch_idx, train=True) 32 | 33 | # Log metrics to CometLogger 34 | loss_dict = {f"Train/{k}":v for k,v in loss2dict(loss).items()} 35 | self.log_dict(loss_dict, sync_dist=True) 36 | return {'loss': loss[0], 'losses': loss, 'output': output, '_batch': batch} 37 | 38 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): 39 | self._on_meta_batch_start(batch) 40 | 41 | def validation_step(self, batch, batch_idx): 42 | """ Adapted forwarding. 43 | 44 | Function: 45 | meta_learn(): Defined in `lightning.systems.base_adaptor.BaseAdaptorSystem` 46 | """ 47 | val_loss, predictions = self.meta_learn(batch, batch_idx, train=False) 48 | qry_batch = batch[0][1][0] 49 | 50 | # Log metrics to CometLogger 51 | loss_dict = {f"Val/{k}":v for k,v in loss2dict(val_loss).items()} 52 | self.log_dict(loss_dict, sync_dist=True) 53 | return {'losses': val_loss, 'output': predictions, '_batch': qry_batch} 54 | 55 | -------------------------------------------------------------------------------- /lightning/systems/meta.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import json 5 | import torch 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import learn2learn as l2l 9 | 10 | from learn2learn.algorithms.lightning import LightningMAML 11 | 12 | from utils.tools import get_mask_from_lengths 13 | from lightning.systems.base_adaptor import BaseAdaptorSystem 14 | from lightning.utils import loss2dict 15 | 16 | 17 | class MetaSystem(BaseAdaptorSystem): 18 | """A PyTorch Lightning module for ANIL for FastSpeech2. 19 | """ 20 | 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | 24 | def on_after_batch_transfer(self, batch, dataloader_idx): 25 | if self.algorithm_config["adapt"]["phoneme_emb"]["type"] == "codebook": 26 | # NOTE: `self.model.encoder` and `self.learner.encoder` are pointing to 27 | # the same variable, they are not two variables with the same values. 
28 | sup_batch, qry_batch, ref_phn_feats = batch[0] 29 | self.model.encoder.src_word_emb._parameters['weight'] = \ 30 | self.model.phn_emb_generator.get_new_embedding(ref_phn_feats).clone() 31 | return [(sup_batch, qry_batch)] 32 | else: 33 | return batch 34 | 35 | # Second order gradients for RNNs 36 | # @torch.backends.cudnn.flags(enabled=False) 37 | # @torch.enable_grad() 38 | # def adapt(self, batch, adaptation_steps=5, learner=None, train=True): 39 | # # TODO: overwrite for supporting SGD and iMAML 40 | # return super().adapt(batch, adaptation_steps, learner, train) 41 | 42 | # # MAML 43 | # # NOTE: skipped 44 | # # TODO: SGD data reuse for more steps 45 | # if learner is None: 46 | # learner = self.learner.clone() 47 | # learner.train() 48 | 49 | # sup_batch = batch[0][0][0] 50 | # first_order = not train 51 | # n_minibatch = 5 52 | # for step in range(adaptation_steps): 53 | # subset = slice(step*n_minibatch, (step+1)*n_minibatch) 54 | # mini_batch = [feat if i in [5, 8] else feat[subset] 55 | # for i, feat in enumerate(sup_batch)] 56 | 57 | # preds = self.forward_learner(learner, *mini_batch[2:]) 58 | # train_error = self.loss_func(mini_batch, preds) 59 | # learner.adapt( 60 | # train_error[0], first_order=first_order, 61 | # allow_unused=False, allow_nograd=True 62 | # ) 63 | # return learner 64 | 65 | def on_train_batch_start(self, batch, batch_idx, dataloader_idx): 66 | self._on_meta_batch_start(batch) 67 | 68 | def training_step(self, batch, batch_idx): 69 | """ Normal forwarding. 70 | 71 | Function: 72 | common_step(): Defined in `lightning.systems.system.System` 73 | """ 74 | train_loss, predictions = self.meta_learn(batch, batch_idx, train=True) 75 | qry_batch = batch[0][1][0] 76 | 77 | # Log metrics to CometLogger 78 | loss_dict = {f"Train/{k}": v for k, v in loss2dict(train_loss).items()} 79 | self.log_dict(loss_dict, sync_dist=True) 80 | return {'loss': train_loss[0], 'losses': train_loss, 'output': predictions, '_batch': qry_batch} 81 | 82 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): 83 | self._on_meta_batch_start(batch) 84 | 85 | def validation_step(self, batch, batch_idx): 86 | """ Adapted forwarding. 
87 | 88 | Function: 89 | meta_learn(): Defined in `lightning.systems.base_adaptor.BaseAdaptorSystem` 90 | """ 91 | val_loss, predictions = self.meta_learn(batch, batch_idx) 92 | qry_batch = batch[0][1][0] 93 | 94 | # Log metrics to CometLogger 95 | loss_dict = {f"Val/{k}": v for k, v in loss2dict(val_loss).items()} 96 | self.log_dict(loss_dict, sync_dist=True) 97 | return {'losses': val_loss, 'output': predictions, '_batch': qry_batch} 98 | -------------------------------------------------------------------------------- /lightning/utils.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import numpy as np 3 | import torch 4 | import random 5 | from contextlib import contextmanager 6 | 7 | 8 | class LightningMelGAN(pl.LightningModule): 9 | def __init__(self): 10 | super().__init__() 11 | vocoder = torch.hub.load( 12 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 13 | ) 14 | self.mel2wav = vocoder.mel2wav 15 | 16 | def inverse(self, mel): 17 | with torch.no_grad(): 18 | return self.mel2wav(mel).squeeze(1) 19 | 20 | def infer(self, mels, max_wav_value, lengths=None): 21 | """preprocess_config["preprocessing"]["audio"]["max_wav_value"] 22 | """ 23 | wavs = self.inverse(mels / np.log(10)) 24 | wavs = (wavs.cpu().numpy() * max_wav_value).astype("int16") 25 | wavs = [wav for wav in wavs] 26 | 27 | for i in range(len(mels)): 28 | if lengths is not None: 29 | wavs[i] = wavs[i][: lengths[i]] 30 | return wavs 31 | 32 | @contextmanager 33 | def seed_all(seed=None, devices=None): 34 | rstate = random.getstate() 35 | nstate = np.random.get_state() 36 | with torch.random.fork_rng(devices): 37 | random.seed(seed) 38 | np.random.seed(seed) 39 | if seed is None: 40 | seed = torch.seed() 41 | torch.cuda.manual_seed_all(seed) 42 | else: 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed_all(seed) 45 | yield 46 | random.setstate(rstate) 47 | np.random.set_state(nstate) 48 | 49 | class EpisodicInfiniteWrapper: 50 | def __init__(self, dataset, epoch_length): 51 | self.dataset = dataset 52 | self.epoch_length = epoch_length 53 | 54 | def __getitem__(self, idx): 55 | # new_idx = random.randrange(len(self.dataset)) 56 | # return self.dataset[new_idx] 57 | return random.choice(self.dataset) 58 | 59 | def __len__(self): 60 | return self.epoch_length 61 | 62 | def loss2str(loss): 63 | return dict2str(loss2dict(loss)) 64 | 65 | def loss2dict(loss): 66 | tblog_dict = { 67 | "Total Loss" : loss[0].item(), 68 | "Mel Loss" : loss[1].item(), 69 | "Mel-Postnet Loss" : loss[2].item(), 70 | "Pitch Loss" : loss[3].item(), 71 | "Energy Loss" : loss[4].item(), 72 | "Duration Loss" : loss[5].item(), 73 | } 74 | return tblog_dict 75 | 76 | def dict2loss(tblog_dict): 77 | loss = ( 78 | tblog_dict["Total Loss"], 79 | tblog_dict["Mel Loss"], 80 | tblog_dict["Mel-Postnet Loss"], 81 | tblog_dict["Pitch Loss"], 82 | tblog_dict["Energy Loss"], 83 | tblog_dict["Duration Loss"], 84 | ) 85 | return loss 86 | 87 | def dict2str(tblog_dict): 88 | message = ", ".join([f"{k}: {v:.4f}" for k, v in tblog_dict.items()]) 89 | return message 90 | 91 | 92 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import comet_ml 5 | import pytorch_lightning as pl 6 | import torch 7 | import yaml 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import 
SummaryWriter 11 | from tqdm import tqdm 12 | from pytorch_lightning.profiler import AdvancedProfiler 13 | 14 | from config.comet import COMET_CONFIG 15 | from lightning.datamodules import get_datamodule 16 | from lightning.systems import get_system 17 | 18 | quiet = False 19 | if quiet: 20 | # NOTSET/DEBUG/INFO/WARNING/ERROR/CRITICAL 21 | os.environ["COMET_LOGGING_CONSOLE"] = "ERROR" 22 | import warnings 23 | warnings.filterwarnings("ignore") 24 | import logging 25 | # configure logging at the root level of lightning 26 | logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) 27 | 28 | 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | TRAINER_CONFIG = { 31 | "gpus": -1 if torch.cuda.is_available() else None, 32 | "strategy": "ddp" if torch.cuda.is_available() else None, 33 | "auto_select_gpus": True, 34 | "limit_train_batches": 1.0, # Useful for fast experiment 35 | "deterministic": True, 36 | "process_position": 1, 37 | "profiler": 'simple', 38 | } 39 | 40 | 41 | def main(args, configs): 42 | print("Prepare training ...") 43 | 44 | preprocess_configs, model_config, train_config, algorithm_config = configs 45 | 46 | for p in train_config["path"].values(): 47 | os.makedirs(p, exist_ok=True) 48 | 49 | # Checkpoint for resume training or testing 50 | ckpt_file = None 51 | if args.exp_key is not None: 52 | ckpt_file = os.path.join( 53 | 'output/ckpt/LibriTTS', COMET_CONFIG["project_name"], 54 | args.exp_key, 'checkpoints', args.ckpt_file 55 | ) 56 | 57 | trainer_training_config = { 58 | 'max_steps': train_config["step"]["total_step"], 59 | 'log_every_n_steps': train_config["step"]["log_step"], 60 | 'weights_save_path': train_config["path"]["ckpt_path"], 61 | 'gradient_clip_val': train_config["optimizer"]["grad_clip_thresh"], 62 | 'accumulate_grad_batches': train_config["optimizer"]["grad_acc_step"], 63 | 'resume_from_checkpoint': ckpt_file, 64 | } 65 | if algorithm_config["type"] == 'imaml': 66 | # should manually clip grad 67 | del trainer_training_config['gradient_clip_val'] 68 | 69 | if args.stage == 'train': 70 | # Init logger 71 | comet_logger = pl.loggers.CometLogger( 72 | save_dir=os.path.join(train_config["path"]["log_path"], "meta"), 73 | experiment_key=args.exp_key, 74 | experiment_name=algorithm_config["name"], 75 | **COMET_CONFIG 76 | ) 77 | comet_logger.log_hyperparams({ 78 | "preprocess_config": preprocess_configs, 79 | "model_config": model_config, 80 | "train_config": train_config, 81 | "algorithm_config": algorithm_config, 82 | }) 83 | loggers = [comet_logger] 84 | log_dir = os.path.join(comet_logger._save_dir, comet_logger.version) 85 | result_dir = os.path.join( 86 | train_config['path']['result_path'], comet_logger.version 87 | ) 88 | else: 89 | assert args.exp_key is not None 90 | log_dir = os.path.join( 91 | train_config["path"]["log_path"], "meta", args.exp_key 92 | ) 93 | result_dir = os.path.join( 94 | train_config['path']['result_path'], args.exp_key, algorithm_config["name"] 95 | ) 96 | 97 | # Get dataset 98 | datamodule = get_datamodule(algorithm_config["type"])( 99 | preprocess_configs, train_config, algorithm_config, log_dir, result_dir 100 | ) 101 | 102 | if args.stage == 'train': 103 | # Get model 104 | system = get_system(algorithm_config["type"]) 105 | model = system( 106 | preprocess_configs[0], model_config, train_config, algorithm_config, 107 | log_dir, result_dir 108 | ) 109 | # Train 110 | trainer = pl.Trainer( 111 | logger=loggers, **TRAINER_CONFIG, **trainer_training_config 112 | ) 113 | pl.seed_everything(43, 
True) 114 | trainer.fit(model, datamodule=datamodule) 115 | 116 | elif args.stage == 'test' or args.stage == 'predict': 117 | # Get model 118 | system = get_system(algorithm_config["type"]) 119 | model = system.load_from_checkpoint( 120 | ckpt_file, 121 | preprocess_config=preprocess_configs[0], 122 | model_config=model_config, 123 | train_config=train_config, 124 | algorithm_config=algorithm_config, 125 | log_dir=log_dir, result_dir=result_dir, 126 | strict=False, 127 | ) 128 | # Test 129 | trainer = pl.Trainer(**TRAINER_CONFIG) 130 | trainer.test(model, datamodule=datamodule) 131 | 132 | elif args.stage == 'debug': 133 | del datamodule 134 | datamodule = get_datamodule("base")( 135 | preprocess_configs, train_config, algorithm_config, log_dir, result_dir 136 | ) 137 | datamodule.setup('test') 138 | for _ in tqdm(datamodule.test_dataset, desc="test_dataset"): 139 | pass 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument( 145 | "-p", "--preprocess_config", type=str, nargs='+', help="path to preprocess.yaml", 146 | default=['config/preprocess/miniLibriTTS.yaml'], 147 | # default=['config/preprocess/LibriTTS.yaml'], 148 | ) 149 | parser.add_argument( 150 | "-m", "--model_config", type=str, help="path to model.yaml", 151 | default='config/model/dev.yaml', 152 | # default='config/model/base.yaml', 153 | ) 154 | parser.add_argument( 155 | "-t", "--train_config", type=str, nargs='+', help="path to train.yaml", 156 | default=['config/train/dev.yaml', 'config/train/miniLibriTTS.yaml'], 157 | # default=['config/train/base.yaml', 'config/train/LibriTTS.yaml'], 158 | ) 159 | parser.add_argument( 160 | "-a", "--algorithm_config", type=str, help="path to algorithm.yaml", 161 | default='config/algorithm/dev.yaml', 162 | ) 163 | parser.add_argument( 164 | "-e", "--exp_key", type=str, help="experiment key", 165 | default=None, 166 | ) 167 | parser.add_argument( 168 | "-c", "--ckpt_file", type=str, help="ckpt file name", 169 | default="last.ckpt", 170 | ) 171 | parser.add_argument( 172 | "-s", "--stage", type=str, help="stage (train/val/test/predict)", 173 | default="train", 174 | ) 175 | args = parser.parse_args() 176 | 177 | # Read Config 178 | preprocess_configs = [ 179 | yaml.load(open(path, "r"), Loader=yaml.FullLoader) 180 | for path in args.preprocess_config 181 | ] 182 | model_config = yaml.load( 183 | open(args.model_config, "r"), Loader=yaml.FullLoader 184 | ) 185 | train_config = yaml.load( 186 | open(args.train_config[0], "r"), Loader=yaml.FullLoader 187 | ) 188 | train_config.update( 189 | yaml.load(open(args.train_config[1], "r"), Loader=yaml.FullLoader) 190 | ) 191 | algorithm_config = yaml.load( 192 | open(args.algorithm_config, "r"), Loader=yaml.FullLoader 193 | ) 194 | configs = (preprocess_configs, model_config, train_config, algorithm_config) 195 | 196 | main(args, configs) 197 | -------------------------------------------------------------------------------- /prepare_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import yaml 4 | 5 | from preprocessor import libritts, vctk 6 | 7 | 8 | def main(config): 9 | if "LibriTTS" in config["dataset"]: 10 | corpus_path = config["path"]["corpus_path"] 11 | raw_path = config["path"]["raw_path"] 12 | dsets = [] 13 | for dmode, dset in config["subsets"].items(): 14 | if dset == "train-clean": 15 | dsets += ["train-clean-100", "train-clean-360"] 16 | elif dset == "train-all": 17 | dsets += ["train-clean-100", 
"train-clean-360", "train-other-500"] 18 | # if isinstance(dset, list): 19 | # dsets += dset 20 | # elif isinstance(dset, str): 21 | # dsets.append(dset) 22 | for dset in dsets: 23 | config["path"]["corpus_path"] = os.path.join(corpus_path, dset) 24 | config["path"]["raw_path"] = os.path.join(raw_path, dset) 25 | libritts.prepare_align(config) 26 | if "VCTK" in config["dataset"]: 27 | corpus_path = config["path"]["corpus_path"] 28 | raw_path = config["path"]["raw_path"] 29 | for dmode, dset in config["subsets"].items(): 30 | # config["path"]["corpus_path"] = corpus_path 31 | config["path"]["raw_path"] = os.path.join(raw_path, dset) 32 | vctk.prepare_align(config) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("config", type=str, help="path to preprocess.yaml") 38 | args = parser.parse_args() 39 | 40 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 41 | main(config) 42 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | import yaml 4 | import os 5 | 6 | from preprocessor.preprocessor import Preprocessor 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("config", type=str, help="path to preprocess.yaml") 12 | args = parser.parse_args() 13 | 14 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 15 | if "LibriTTS" in config["dataset"]: 16 | if config["subsets"]["train"] == "train-clean": 17 | config["subsets"]["train"] = ["train-clean-100", "train-clean-360"] 18 | elif config["subsets"]["train"] == "train-all": 19 | config["subsets"]["train"] = ["train-clean-100", "train-clean-360", "train-other-500"] 20 | preprocessor = Preprocessor(config) 21 | preprocessor.build_from_path() 22 | if "LibriTTS" in config["dataset"]: 23 | if config["subsets"]["train"] == ["train-clean-100", "train-clean-360"]: 24 | with open(os.path.join(config["path"]["preprocessed_path"], "train-clean.txt"), 'wb') as out_file: 25 | with open(os.path.join(config["path"]["preprocessed_path"], "train-clean-100.txt"), 'rb') as in_file: 26 | shutil.copyfileobj(in_file, out_file) 27 | with open(os.path.join(config["path"]["preprocessed_path"], "train-clean-360.txt"), 'rb') as in_file: 28 | shutil.copyfileobj(in_file, out_file) 29 | elif config["subsets"]["train"] == ["train-clean-100", "train-clean-360", "train-other-500"]: 30 | with open(os.path.join(config["path"]["preprocessed_path"], "train-all.txt"), 'wb') as out_file: 31 | with open(os.path.join(config["path"]["preprocessed_path"], "train-clean-100.txt"), 'rb') as in_file: 32 | shutil.copyfileobj(in_file, out_file) 33 | with open(os.path.join(config["path"]["preprocessed_path"], "train-clean-360.txt"), 'rb') as in_file: 34 | shutil.copyfileobj(in_file, out_file) 35 | with open(os.path.join(config["path"]["preprocessed_path"], "train-other-500.txt"), 'rb') as in_file: 36 | shutil.copyfileobj(in_file, out_file) 37 | -------------------------------------------------------------------------------- /preprocessed_data/example_corpus/TextGrid/speaker1/speaker1_utterance1.TextGrid: -------------------------------------------------------------------------------- 1 | File type = "ooTextFile" 2 | Object class = "TextGrid" 3 | 4 | xmin = 0.0 5 | xmax = 1.48 6 | tiers? 
7 | size = 2 8 | item []: 9 | item [1]: 10 | class = "IntervalTier" 11 | name = "words" 12 | xmin = 0.0 13 | xmax = 1.48 14 | intervals: size = 5 15 | intervals [1]: 16 | xmin = 0.000 17 | xmax = 0.360 18 | text = "tom" 19 | intervals [2]: 20 | xmin = 0.360 21 | xmax = 0.440 22 | text = "the" 23 | intervals [3]: 24 | xmin = 0.440 25 | xmax = 0.930 26 | text = "piper's" 27 | intervals [4]: 28 | xmin = 0.930 29 | xmax = 1.420 30 | text = "son" 31 | intervals [5]: 32 | xmin = 1.420 33 | xmax = 1.48 34 | text = "" 35 | item [2]: 36 | class = "IntervalTier" 37 | name = "phones" 38 | xmin = 0.0 39 | xmax = 1.48 40 | intervals: size = 15 41 | intervals [1]: 42 | xmin = 0.000 43 | xmax = 0.110 44 | text = "T" 45 | intervals [2]: 46 | xmin = 0.110 47 | xmax = 0.290 48 | text = "AA1" 49 | intervals [3]: 50 | xmin = 0.290 51 | xmax = 0.360 52 | text = "M" 53 | intervals [4]: 54 | xmin = 0.360 55 | xmax = 0.400 56 | text = "DH" 57 | intervals [5]: 58 | xmin = 0.400 59 | xmax = 0.440 60 | text = "AH1" 61 | intervals [6]: 62 | xmin = 0.440 63 | xmax = 0.540 64 | text = "P" 65 | intervals [7]: 66 | xmin = 0.540 67 | xmax = 0.650 68 | text = "AY1" 69 | intervals [8]: 70 | xmin = 0.650 71 | xmax = 0.750 72 | text = "P" 73 | intervals [9]: 74 | xmin = 0.750 75 | xmax = 0.860 76 | text = "ER0" 77 | intervals [10]: 78 | xmin = 0.860 79 | xmax = 0.930 80 | text = "Z" 81 | intervals [11]: 82 | xmin = 0.930 83 | xmax = 1.070 84 | text = "S" 85 | intervals [12]: 86 | xmin = 1.070 87 | xmax = 1.260 88 | text = "AH1" 89 | intervals [13]: 90 | xmin = 1.260 91 | xmax = 1.420 92 | text = "N" 93 | intervals [14]: 94 | xmin = 1.420 95 | xmax = 1.460 96 | text = "sp" 97 | intervals [15]: 98 | xmin = 1.460 99 | xmax = 1.48 100 | text = "" 101 | -------------------------------------------------------------------------------- /preprocessor/libritts.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | 8 | from text import _clean_text 9 | 10 | 11 | def prepare_align(config): 12 | in_dir = config["path"]["corpus_path"] 13 | out_dir = config["path"]["raw_path"] 14 | corpus = config["dataset"] 15 | dset = out_dir.rsplit('/', 1)[1] 16 | 17 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 18 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 19 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 20 | for speaker in tqdm(os.listdir(in_dir), desc=f"{corpus}/{dset}"): 21 | for chapter in os.listdir(os.path.join(in_dir, speaker)): 22 | for file_name in os.listdir(os.path.join(in_dir, speaker, chapter)): 23 | if file_name[-4:] != ".wav": 24 | continue 25 | base_name = file_name[:-4] 26 | text_path = os.path.join( 27 | in_dir, speaker, chapter, "{}.normalized.txt".format(base_name) 28 | ) 29 | wav_path = os.path.join( 30 | in_dir, speaker, chapter, "{}.wav".format(base_name) 31 | ) 32 | with open(text_path) as f: 33 | text = f.readline().strip("\n") 34 | text = _clean_text(text, cleaners) 35 | 36 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 37 | wav, _ = librosa.load(wav_path, sampling_rate) 38 | wav = wav / max(abs(wav)) * max_wav_value 39 | wavfile.write( 40 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)), 41 | sampling_rate, 42 | wav.astype(np.int16), 43 | ) 44 | with open( 45 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)), 46 | "w", 47 | ) as f1: 48 | f1.write(text) 49 | 
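The `prepare_align` routine above reads only a handful of keys from the preprocess config. Below is a minimal sketch of a matching config written as a Python dict; the key names mirror the accesses in `preprocessor/libritts.py` and `prepare_align.py`, while the path strings, subset name, and audio values are illustrative placeholders rather than values copied from `config/preprocess/LibriTTS.yaml`.

# Hypothetical minimal config for libritts.prepare_align (all values are placeholders)
example_config = {
    "dataset": "LibriTTS",
    "subsets": {"train": "train-clean"},            # expanded to train-clean-100/360 by prepare_align.py
    "path": {
        "corpus_path": "raw_data/LibriTTS_corpus",  # where the downloaded corpus lives
        "raw_path": "raw_data/LibriTTS",            # where resampled .wav and .lab files are written
    },
    "preprocessing": {
        "audio": {"sampling_rate": 22050, "max_wav_value": 32768.0},
        "text": {"text_cleaners": ["english_cleaners"]},
    },
}

# prepare_align.py appends the subset name to both paths before calling:
# libritts.prepare_align(example_config)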
-------------------------------------------------------------------------------- /preprocessor/vctk.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | 8 | from text import _clean_text 9 | 10 | 11 | def prepare_align(config): 12 | in_dir = config["path"]["corpus_path"] 13 | out_dir = config["path"]["raw_path"] 14 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 15 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 16 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 17 | for speaker in tqdm(os.listdir(os.path.join(in_dir, "wav48_silence_trimmed"))): 18 | if os.path.isfile(os.path.join(in_dir, "wav48_silence_trimmed", speaker)): 19 | continue 20 | for file_name in os.listdir(os.path.join(in_dir, "wav48_silence_trimmed", speaker)): 21 | if file_name[-10:] != "_mic2.flac": 22 | continue 23 | base_name = file_name[:-10] 24 | text_path = os.path.join( 25 | in_dir, "txt", speaker, "{}.txt".format(base_name) 26 | ) 27 | wav_path = os.path.join( 28 | in_dir, "wav48_silence_trimmed", speaker, "{}_mic2.flac".format(base_name) 29 | ) 30 | with open(text_path) as f: 31 | text = f.readline().strip("\n") 32 | text = _clean_text(text, cleaners) 33 | 34 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 35 | wav, _ = librosa.load(wav_path, sampling_rate) 36 | wav = wav / max(abs(wav)) * max_wav_value 37 | wavfile.write( 38 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)), 39 | sampling_rate, 40 | wav.astype(np.int16), 41 | ) 42 | with open( 43 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)), 44 | "w", 45 | ) as f1: 46 | f1.write(text) 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning 2 | learn2learn >= 0.1.6 3 | comet-ml 4 | resemblyzer 5 | g2p-en 6 | inflect 7 | librosa 8 | matplotlib 9 | numba 10 | numpy 11 | pypinyin 12 | pyworld 13 | PyYAML 14 | scikit-learn 15 | scipy 16 | soundfile 17 | tensorboard 18 | tgt 19 | torch 20 | tqdm 21 | unidecode 22 | torchaudio 23 | 24 | # Evaluation 25 | seaborn 26 | -e git+https://github.com/jasminsternkopf/mel_cepstral_distance.git@main#egg=mcd 27 | -e git+https://github.com/aliutkus/speechmetrics#egg=speechmetrics[gpu] 28 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | """ 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | 34 | if not m: 35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 36 | break 37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 38 | sequence += _arpabet_to_sequence(m.group(2)) 39 | text = m.group(3) 40 | 41 | return sequence 42 | 43 | 44 | def sequence_to_text(sequence): 45 | """Converts a sequence of IDs back to a string""" 46 | result = "" 47 | for symbol_id in sequence: 48 | if symbol_id in _id_to_symbol: 49 | s = _id_to_symbol[symbol_id] 50 | # Enclose ARPAbet back in curly braces: 51 | if len(s) > 1 and s[0] == "@": 52 | s = "{%s}" % s[1:] 53 | result += s 54 | return result.replace("}{", " ") 55 | 56 | 57 | def _clean_text(text, cleaner_names): 58 | for name in cleaner_names: 59 | cleaner = getattr(cleaners, name) 60 | if not cleaner: 61 | raise Exception("Unknown cleaner: %s" % name) 62 | text = cleaner(text) 63 | return text 64 | 65 | 66 | def _symbols_to_sequence(symbols): 67 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 68 | 69 | 70 | def _arpabet_to_sequence(text): 71 | return _symbols_to_sequence(["@" + s for s in text.split()]) 72 | 73 | 74 | def _should_keep_symbol(s): 75 | return s in _symbol_to_id and s != "_" and s != "~" 76 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | _whitespace_re = re.compile(r'\s+') 21 | 22 | # List of (regular expression, replacement) pairs for abbreviations: 23 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 69 | text = lowercase(text) 70 | text = collapse_whitespace(text) 71 | return text 72 | 73 | 74 | def transliteration_cleaners(text): 75 | '''Pipeline for non-English text that transliterates to ASCII.''' 76 | text = convert_to_ascii(text) 77 | text = lowercase(text) 78 | text = collapse_whitespace(text) 79 | return text 80 | 81 | 82 | def english_cleaners(text): 83 | '''Pipeline for English text, including number and abbreviation expansion.''' 84 | text = convert_to_ascii(text) 85 | text = lowercase(text) 86 | text = expand_numbers(text) 87 | text = expand_abbreviations(text) 88 | text = collapse_whitespace(text) 89 | return text 90 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def 
normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /text/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | "v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of 
symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from text import cmudict, pinyin 9 | 10 | _pad = "_" 11 | _punctuation = "!'(),.:;? " 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 18 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = ( 22 | [_pad] 23 | + list(_special) 24 | + list(_punctuation) 25 | + list(_letters) 26 | + _arpabet 27 | + _pinyin 28 | + _silences 29 | ) 30 | -------------------------------------------------------------------------------- /transformer/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = "" 7 | UNK_WORD = "" 8 | BOS_WORD = "" 9 | EOS_WORD = "" 10 | -------------------------------------------------------------------------------- /transformer/Layers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.nn import functional as F 7 | 8 | from .SubLayers import MultiHeadAttention, PositionwiseFeedForward 9 | 10 | 11 | class FFTBlock(torch.nn.Module): 12 | """FFT Block""" 13 | 14 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 15 | super(FFTBlock, self).__init__() 16 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 17 | self.pos_ffn = PositionwiseFeedForward( 18 | d_model, d_inner, kernel_size, dropout=dropout 19 | ) 20 | 21 | def forward(self, enc_input, mask=None, slf_attn_mask=None): 22 | enc_output, enc_slf_attn = self.slf_attn( 23 | enc_input, enc_input, enc_input, mask=slf_attn_mask 24 | ) 25 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 26 | 27 | enc_output = self.pos_ffn(enc_output) 28 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 29 | 30 | return enc_output, enc_slf_attn 31 | 32 | 33 | class ConvNorm(torch.nn.Module): 34 | def __init__( 35 | self, 36 | in_channels, 37 | out_channels, 38 | kernel_size=1, 39 | stride=1, 40 | padding=None, 41 | dilation=1, 42 | bias=True, 43 | w_init_gain="linear", 44 | ): 45 | super(ConvNorm, self).__init__() 46 | 47 | if padding is None: 48 | assert kernel_size % 2 == 1 49 | padding = int(dilation * (kernel_size - 1) / 2) 50 | 51 | self.conv = torch.nn.Conv1d( 52 | in_channels, 53 | out_channels, 54 | kernel_size=kernel_size, 55 | stride=stride, 56 | padding=padding, 57 | dilation=dilation, 58 | bias=bias, 59 | ) 60 | 61 | def forward(self, signal): 62 | conv_signal = self.conv(signal) 63 | 64 | return conv_signal 65 | 66 | 67 | class PostNet(nn.Module): 68 | """ 69 | PostNet: Five 1-d convolution with 512 channels and kernel size 5 70 | """ 71 | 72 | def __init__( 73 | self, 74 | n_mel_channels=80, 75 | postnet_embedding_dim=512, 76 | postnet_kernel_size=5, 77 | postnet_n_convolutions=5, 78 | ): 79 | 80 | super(PostNet, self).__init__() 81 | self.convolutions = nn.ModuleList() 82 | 83 | self.convolutions.append( 84 | nn.Sequential( 85 | ConvNorm( 86 | n_mel_channels, 87 | 
postnet_embedding_dim, 88 | kernel_size=postnet_kernel_size, 89 | stride=1, 90 | padding=int((postnet_kernel_size - 1) / 2), 91 | dilation=1, 92 | w_init_gain="tanh", 93 | ), 94 | nn.BatchNorm1d(postnet_embedding_dim), 95 | ) 96 | ) 97 | 98 | for i in range(1, postnet_n_convolutions - 1): 99 | self.convolutions.append( 100 | nn.Sequential( 101 | ConvNorm( 102 | postnet_embedding_dim, 103 | postnet_embedding_dim, 104 | kernel_size=postnet_kernel_size, 105 | stride=1, 106 | padding=int((postnet_kernel_size - 1) / 2), 107 | dilation=1, 108 | w_init_gain="tanh", 109 | ), 110 | nn.BatchNorm1d(postnet_embedding_dim), 111 | ) 112 | ) 113 | 114 | self.convolutions.append( 115 | nn.Sequential( 116 | ConvNorm( 117 | postnet_embedding_dim, 118 | n_mel_channels, 119 | kernel_size=postnet_kernel_size, 120 | stride=1, 121 | padding=int((postnet_kernel_size - 1) / 2), 122 | dilation=1, 123 | w_init_gain="linear", 124 | ), 125 | nn.BatchNorm1d(n_mel_channels), 126 | ) 127 | ) 128 | 129 | def forward(self, x): 130 | x = x.contiguous().transpose(1, 2) 131 | 132 | for i in range(len(self.convolutions) - 1): 133 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 134 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 135 | 136 | x = x.contiguous().transpose(1, 2) 137 | return x 138 | -------------------------------------------------------------------------------- /transformer/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import transformer.Constants as Constants 6 | from .Layers import FFTBlock 7 | from text.symbols import symbols 8 | 9 | 10 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 11 | """ Sinusoid position encoding table """ 12 | 13 | def cal_angle(position, hid_idx): 14 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 15 | 16 | def get_posi_angle_vec(position): 17 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 18 | 19 | sinusoid_table = np.array( 20 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)] 21 | ) 22 | 23 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 24 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 25 | 26 | if padding_idx is not None: 27 | # zero vector for padding dimension 28 | sinusoid_table[padding_idx] = 0.0 29 | 30 | return torch.FloatTensor(sinusoid_table) 31 | 32 | 33 | class Encoder(nn.Module): 34 | """ Encoder """ 35 | 36 | def __init__(self, config): 37 | super(Encoder, self).__init__() 38 | 39 | n_position = config["max_seq_len"] + 1 40 | n_src_vocab = len(symbols) + 1 41 | d_word_vec = config["transformer"]["encoder_hidden"] 42 | n_layers = config["transformer"]["encoder_layer"] 43 | n_head = config["transformer"]["encoder_head"] 44 | d_k = d_v = ( 45 | config["transformer"]["encoder_hidden"] 46 | // config["transformer"]["encoder_head"] 47 | ) 48 | d_model = config["transformer"]["encoder_hidden"] 49 | d_inner = config["transformer"]["conv_filter_size"] 50 | kernel_size = config["transformer"]["conv_kernel_size"] 51 | dropout = config["transformer"]["encoder_dropout"] 52 | 53 | self.max_seq_len = config["max_seq_len"] 54 | self.d_model = d_model 55 | 56 | self.src_word_emb = nn.Embedding( 57 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD 58 | ) 59 | self.position_enc = nn.Parameter( 60 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 61 | requires_grad=False, 62 | ) 63 | 64 | 
self.layer_stack = nn.ModuleList( 65 | [ 66 | FFTBlock( 67 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 68 | ) 69 | for _ in range(n_layers) 70 | ] 71 | ) 72 | 73 | def forward(self, src_seq, mask, return_attns=False): 74 | 75 | enc_slf_attn_list = [] 76 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 77 | 78 | # -- Prepare masks 79 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 80 | 81 | # -- Forward 82 | if not self.training and src_seq.shape[1] > self.max_seq_len: 83 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table( 84 | src_seq.shape[1], self.d_model 85 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 86 | src_seq.device 87 | ) 88 | else: 89 | enc_output = self.src_word_emb(src_seq) + self.position_enc[ 90 | :, :max_len, : 91 | ].expand(batch_size, -1, -1) 92 | 93 | for enc_layer in self.layer_stack: 94 | enc_output, enc_slf_attn = enc_layer( 95 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask 96 | ) 97 | if return_attns: 98 | enc_slf_attn_list += [enc_slf_attn] 99 | 100 | return enc_output 101 | 102 | 103 | class Decoder(nn.Module): 104 | """ Decoder """ 105 | 106 | def __init__(self, config): 107 | super(Decoder, self).__init__() 108 | 109 | n_position = config["max_seq_len"] + 1 110 | d_word_vec = config["transformer"]["decoder_hidden"] 111 | n_layers = config["transformer"]["decoder_layer"] 112 | n_head = config["transformer"]["decoder_head"] 113 | d_k = d_v = ( 114 | config["transformer"]["decoder_hidden"] 115 | // config["transformer"]["decoder_head"] 116 | ) 117 | d_model = config["transformer"]["decoder_hidden"] 118 | d_inner = config["transformer"]["conv_filter_size"] 119 | kernel_size = config["transformer"]["conv_kernel_size"] 120 | dropout = config["transformer"]["decoder_dropout"] 121 | 122 | self.max_seq_len = config["max_seq_len"] 123 | self.d_model = d_model 124 | 125 | self.position_enc = nn.Parameter( 126 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 127 | requires_grad=False, 128 | ) 129 | 130 | self.layer_stack = nn.ModuleList( 131 | [ 132 | FFTBlock( 133 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 134 | ) 135 | for _ in range(n_layers) 136 | ] 137 | ) 138 | 139 | def forward(self, enc_seq, mask, return_attns=False): 140 | 141 | dec_slf_attn_list = [] 142 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 143 | 144 | # -- Forward 145 | if not self.training and enc_seq.shape[1] > self.max_seq_len: 146 | # -- Prepare masks 147 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 148 | dec_output = enc_seq + get_sinusoid_encoding_table( 149 | enc_seq.shape[1], self.d_model 150 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 151 | enc_seq.device 152 | ) 153 | else: 154 | max_len = min(max_len, self.max_seq_len) 155 | 156 | # -- Prepare masks 157 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 158 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[ 159 | :, :max_len, : 160 | ].expand(batch_size, -1, -1) 161 | mask = mask[:, :max_len] 162 | slf_attn_mask = slf_attn_mask[:, :, :max_len] 163 | 164 | for dec_layer in self.layer_stack: 165 | dec_output, dec_slf_attn = dec_layer( 166 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask 167 | ) 168 | if return_attns: 169 | dec_slf_attn_list += [dec_slf_attn] 170 | 171 | return dec_output, mask 172 | -------------------------------------------------------------------------------- /transformer/Modules.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class ScaledDotProductAttention(nn.Module): 7 | """ Scaled Dot-Product Attention """ 8 | 9 | def __init__(self, temperature): 10 | super().__init__() 11 | self.temperature = temperature 12 | self.softmax = nn.Softmax(dim=2) 13 | 14 | def forward(self, q, k, v, mask=None): 15 | 16 | attn = torch.bmm(q, k.transpose(1, 2)) 17 | attn = attn / self.temperature 18 | 19 | if mask is not None: 20 | attn = attn.masked_fill(mask, -np.inf) 21 | 22 | attn = self.softmax(attn) 23 | output = torch.bmm(attn, v) 24 | 25 | return output, attn 26 | -------------------------------------------------------------------------------- /transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from .Modules import ScaledDotProductAttention 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | """ Multi-Head Attention module """ 10 | 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 12 | super().__init__() 13 | 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v) 21 | 22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 23 | self.layer_norm = nn.LayerNorm(d_model) 24 | 25 | self.fc = nn.Linear(n_head * d_v, d_model) 26 | 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, q, k, v, mask=None): 30 | 31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 32 | 33 | sz_b, len_q, _ = q.size() 34 | sz_b, len_k, _ = k.size() 35 | sz_b, len_v, _ = v.size() 36 | 37 | residual = q 38 | 39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 45 | 46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
47 | output, attn = self.attention(q, k, v, mask=mask) 48 | 49 | output = output.view(n_head, sz_b, len_q, d_v) 50 | output = ( 51 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) 52 | ) # b x lq x (n*dv) 53 | 54 | output = self.dropout(self.fc(output)) 55 | output = self.layer_norm(output + residual) 56 | 57 | return output, attn 58 | 59 | 60 | class PositionwiseFeedForward(nn.Module): 61 | """ A two-feed-forward-layer module """ 62 | 63 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1): 64 | super().__init__() 65 | 66 | # Use Conv1D 67 | # position-wise 68 | self.w_1 = nn.Conv1d( 69 | d_in, 70 | d_hid, 71 | kernel_size=kernel_size[0], 72 | padding=(kernel_size[0] - 1) // 2, 73 | ) 74 | # position-wise 75 | self.w_2 = nn.Conv1d( 76 | d_hid, 77 | d_in, 78 | kernel_size=kernel_size[1], 79 | padding=(kernel_size[1] - 1) // 2, 80 | ) 81 | 82 | self.layer_norm = nn.LayerNorm(d_in) 83 | self.dropout = nn.Dropout(dropout) 84 | 85 | def forward(self, x): 86 | residual = x 87 | output = x.transpose(1, 2) 88 | output = self.w_2(F.relu(self.w_1(output))) 89 | output = output.transpose(1, 2) 90 | output = self.dropout(output) 91 | output = self.layer_norm(output + residual) 92 | 93 | return output 94 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Models import Encoder, Decoder 2 | from .Layers import PostNet 3 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def get_param_num(model): 9 | num_param = sum(param.numel() for param in model.parameters()) 10 | return num_param 11 | 12 | 13 | def get_vocoder(config, device): 14 | name = config["vocoder"]["model"] 15 | speaker = config["vocoder"]["speaker"] 16 | 17 | if name == "MelGAN": 18 | if speaker == "LJSpeech": 19 | vocoder = torch.hub.load( 20 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" 21 | ) 22 | elif speaker == "universal": 23 | vocoder = torch.hub.load( 24 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 25 | ) 26 | vocoder.mel2wav.eval() 27 | vocoder.mel2wav.to(device) 28 | 29 | return vocoder 30 | 31 | 32 | def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None): 33 | name = model_config["vocoder"]["model"] 34 | with torch.no_grad(): 35 | if name == "MelGAN": 36 | wavs = vocoder.inverse(mels / np.log(10)) 37 | elif name == "HiFi-GAN": 38 | wavs = vocoder(mels).squeeze(1) 39 | 40 | wavs = ( 41 | wavs.cpu().numpy() 42 | * preprocess_config["preprocessing"]["audio"]["max_wav_value"] 43 | ).astype("int16") 44 | wavs = [wav for wav in wavs] 45 | 46 | for i in range(len(mels)): 47 | if lengths is not None: 48 | wavs[i] = wavs[i][: lengths[i]] 49 | 50 | return wavs 51 | --------------------------------------------------------------------------------
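As a closing reference, here is a hedged usage sketch of the vocoder helpers defined in `utils/model.py` above. The config dicts and the dummy mel tensor are illustrative stand-ins: only the key paths that `get_vocoder` and `vocoder_infer` actually read are filled in, and the 80-bin mel shape and per-utterance sample lengths are assumptions, not values taken from the repository's configs.

import torch

from utils.model import get_vocoder, vocoder_infer

# Illustrative configs; real values come from config/model/*.yaml and config/preprocess/*.yaml
model_config = {"vocoder": {"model": "MelGAN", "speaker": "universal"}}
preprocess_config = {"preprocessing": {"audio": {"max_wav_value": 32768.0}}}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocoder = get_vocoder(model_config, device)  # loads the multi-speaker MelGAN checkpoint via torch.hub

# Dummy batch of log-mel spectrograms, shaped (batch, n_mel_channels, frames)
mels = torch.randn(2, 80, 200, device=device)

# Optional per-utterance lengths (in samples) are used to trim the synthesized int16 waveforms
wavs = vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=[40000, 32000])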