├── .gitignore ├── LICENSE ├── README.md ├── audio ├── __init__.py ├── audio_processing.py ├── stft.py └── tools.py ├── config └── LibriTTS │ ├── model.yaml │ ├── preprocess.yaml │ └── train.yaml ├── dataset.py ├── demo └── LibriTTS │ ├── references │ ├── 0_19_198_000000_000000.lab │ ├── 0_19_198_000000_000000.wav │ ├── 1_26_495_000004_000000.lab │ ├── 1_26_495_000004_000000.wav │ ├── 2_27_123349_000001_000000.lab │ ├── 2_27_123349_000001_000000.wav │ ├── 3_32_4137_000005_000001.lab │ ├── 3_32_4137_000005_000001.wav │ ├── 4_39_121916_000015_000005.lab │ ├── 4_39_121916_000015_000005.wav │ ├── 5_40_222_000001_000000.lab │ └── 5_40_222_000001_000000.wav │ └── results │ ├── 0_the two children therefore got up, dressed themselves quickly, and went away..png │ ├── 0_the two children therefore got up, dressed themselves quickly, and went away..wav │ ├── 1_the two children therefore got up, dressed themselves quickly, and went away..png │ ├── 1_the two children therefore got up, dressed themselves quickly, and went away..wav │ ├── 2_the two children therefore got up, dressed themselves quickly, and went away..png │ ├── 2_the two children therefore got up, dressed themselves quickly, and went away..wav │ ├── 3_the two children therefore got up, dressed themselves quickly, and went away..png │ ├── 3_the two children therefore got up, dressed themselves quickly, and went away..wav │ ├── 4_the two children therefore got up, dressed themselves quickly, and went away..png │ ├── 4_the two children therefore got up, dressed themselves quickly, and went away..wav │ ├── 5_the two children therefore got up, dressed themselves quickly, and went away..png │ └── 5_the two children therefore got up, dressed themselves quickly, and went away..wav ├── evaluate.py ├── filelist_filtering.py ├── hifigan ├── LICENSE ├── __init__.py ├── config.json ├── generator_LJSpeech.pth.tar.zip ├── generator_universal.pth.tar.zip └── models.py ├── img ├── model_1.png ├── model_2.png ├── tensorboard_audio.png ├── tensorboard_loss.png └── tensorboard_spec.png ├── lexicon ├── librispeech-lexicon.txt └── pinyin-lexicon-r.txt ├── model ├── StyleSpeech.py ├── __init__.py ├── blocks.py ├── loss.py ├── modules.py └── optimizer.py ├── prepare_align.py ├── preprocess.py ├── preprocessed_data └── LibriTTS │ ├── speakers.json │ ├── stats.json │ ├── train.txt │ ├── train_filtered.txt │ └── val.txt ├── preprocessor ├── aishell3.py ├── libritts.py ├── ljspeech.py └── preprocessor.py ├── requirements.txt ├── synthesize.py ├── text ├── __init__.py ├── cleaners.py ├── cmudict.py ├── numbers.py ├── pinyin.py └── symbols.py ├── train.py └── utils ├── model.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | __pycache__ 107 | .vscode 108 | .DS_Store 109 | 110 | # MFA 111 | montreal-forced-aligner/ 112 | 113 | # data, checkpoint, and models 114 | raw_data/ 115 | output/ 116 | *.npy 117 | TextGrid/ 118 | hifigan/*.pth.tar 119 | *.out 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Keon Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StyleSpeech - PyTorch Implementation 2 | 3 | PyTorch Implementation of [Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation](https://arxiv.org/abs/2106.03153). 4 | 5 |

6 | ![](./img/model_1.png) 7 | 8 | ![](./img/model_2.png) 9 | 10 | 11 |

12 | 13 | # Branch 14 | - [x] StyleSpeech (`naive` branch) 15 | - [x] Meta-StyleSpeech (`main` branch) 16 | 17 | # Quickstart 18 | 19 | ## Dependencies 20 | You can install the Python dependencies with 21 | ``` 22 | pip3 install -r requirements.txt 23 | ``` 24 | 25 | ## Inference 26 | 27 | First, download the [pretrained models](https://drive.google.com/drive/folders/1fQmu1v7fRgfM-TwxAJ96UUgnl79f1FHt?usp=sharing) and put them in ``output/ckpt/LibriTTS_meta_learner/``. 28 | 29 | For English multi-speaker TTS, run 30 | ``` 31 | python3 synthesize.py --text "YOUR_DESIRED_TEXT" --ref_audio path/to/reference_audio.wav --restore_step 200000 --mode single -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml 32 | ``` 33 | The generated utterances will be saved in ``output/result/``. The synthesized speech follows the style of `ref_audio`. 34 | 35 | 36 | ## Batch Inference 37 | Batch inference is also supported. Run 38 | 39 | ``` 40 | python3 synthesize.py --source preprocessed_data/LibriTTS/val.txt --restore_step 200000 --mode batch -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml 41 | ``` 42 | to synthesize all utterances in ``preprocessed_data/LibriTTS/val.txt``. This can be viewed as reconstructing the validation set, with each utterance using itself as the style reference. 43 | 44 | ## Controllability 45 | The pitch/volume/speaking rate of the synthesized utterances can be controlled by specifying the desired pitch/energy/duration ratios. 46 | For example, one can speed up the speech and reduce its volume by setting the duration and energy ratios to 0.8: 47 | 48 | ``` 49 | python3 synthesize.py --text "YOUR_DESIRED_TEXT" --restore_step 200000 --mode single -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml --duration_control 0.8 --energy_control 0.8 50 | ``` 51 | Note that this controllability originates from FastSpeech2 and is not a primary focus of StyleSpeech. Please refer to [STYLER](https://arxiv.org/abs/2103.09474) [[demo](https://keonlee9420.github.io/STYLER-Demo/), [code](https://github.com/keonlee9420/STYLER)] for the controllability of each style factor. 52 | 53 | # Training 54 | 55 | ## Datasets 56 | 57 | The supported datasets are 58 | 59 | - [LibriTTS](https://research.google/tools/datasets/libri-tts/): a multi-speaker English dataset containing 585 hours of speech by 2456 speakers. 60 | - (more to be added) 61 | 62 | ## Preprocessing 63 | 64 | Run 65 | ``` 66 | python3 prepare_align.py config/LibriTTS/preprocess.yaml 67 | ``` 68 | to prepare the data for alignment. 69 | 70 | For the forced alignment, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/) (MFA) is used to obtain the alignments between the utterances and the phoneme sequences. 71 | Pre-extracted alignments for the datasets are provided [here](https://drive.google.com/drive/folders/1fizpyOiQ1lG2UDaMlXnT3Ll4_j6Xwg7K?usp=sharing). 72 | Unzip the files into `preprocessed_data/LibriTTS/TextGrid/`. Alternatively, you can [run the aligner by yourself](https://montreal-forced-aligner.readthedocs.io/en/latest/user_guide/workflows/index.html).
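If you run the aligner yourself, it can help to check alignment coverage before preprocessing. The snippet below is a minimal, hypothetical helper (not part of this repository); it assumes the `raw_data/LibriTTS/<speaker>/<basename>.wav` layout written by `prepare_align.py` and the `TextGrid/<speaker>/<basename>.TextGrid` layout expected by the preprocessor:
```
import os
from glob import glob

raw_dir = "raw_data/LibriTTS"
tg_dir = "preprocessed_data/LibriTTS/TextGrid"

# Collect utterances that have no TextGrid alignment yet.
missing = []
for wav_path in glob(os.path.join(raw_dir, "*", "*.wav")):
    speaker = os.path.basename(os.path.dirname(wav_path))
    basename = os.path.splitext(os.path.basename(wav_path))[0]
    if not os.path.exists(os.path.join(tg_dir, speaker, basename + ".TextGrid")):
        missing.append((speaker, basename))

print("{} utterances have no alignment".format(len(missing)))
```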
73 | 74 | After that, run the preprocessing script by 75 | ``` 76 | python3 preprocess.py config/LibriTTS/preprocess.yaml 77 | ``` 78 | 79 | ## Training 80 | 81 | Train your model with 82 | ``` 83 | python3 train.py -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml 84 | ``` 85 | As described in the paper, the script first pre-trains the naive model for `meta_learning_warmup` steps and then meta-trains it for the remaining steps via episodic training. 86 | 87 | # TensorBoard 88 | 89 | Use 90 | ``` 91 | tensorboard --logdir output/log/LibriTTS 92 | ``` 93 | 94 | to serve TensorBoard on your localhost. 95 | The loss curves, synthesized mel-spectrograms, and audio samples are shown. 96 | 97 | ![](./img/tensorboard_loss.png) 98 | ![](./img/tensorboard_spec.png) 99 | ![](./img/tensorboard_audio.png) 100 | 101 | # Implementation Issues 102 | 103 | 1. Use a `22050Hz` sampling rate instead of `16kHz`. 104 | 2. Add one fully connected layer at the beginning of the Mel-Style Encoder to project the input mel-spectrogram from `80` to `128` channels. 105 | 3. The model size including the meta-learner is `28.197M` parameters. 106 | 4. Use a maximum batch size of `16` during training instead of `48` or `20`, mainly due to the memory capacity of a single **24GiB TITAN-RTX**. This is made possible by the following script, which filters out data longer than `max_seq_len`: 107 | ``` 108 | python3 filelist_filtering.py -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml 109 | ``` 110 | This will generate `train_filtered.txt` in the same location as `train.txt`. 111 | 5. Since the total batch size is decreased, the number of training steps is doubled compared to the original paper. 112 | 6. Use **HiFi-GAN** instead of **MelGAN** for vocoding. 113 | 114 | # Citation 115 | 116 | ``` 117 | @misc{lee2021stylespeech, 118 | author = {Lee, Keon}, 119 | title = {StyleSpeech}, 120 | year = {2021}, 121 | publisher = {GitHub}, 122 | journal = {GitHub repository}, 123 | howpublished = {\url{https://github.com/keonlee9420/StyleSpeech}} 124 | } 125 | ``` 126 | 127 | # References 128 | - [Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation](https://arxiv.org/abs/2106.03153) 129 | - [A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/abs/1812.04948) 130 | - [Matching Networks for One Shot Learning](https://arxiv.org/abs/1606.04080) 131 | - [Prototypical Networks for Few-shot Learning](https://arxiv.org/abs/1703.05175) 132 | - [TADAM: Task dependent adaptive metric for improved few-shot learning](https://arxiv.org/abs/1805.10123) 133 | - [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) 134 | -------------------------------------------------------------------------------- /audio/__init__.py: -------------------------------------------------------------------------------- 1 | import audio.tools 2 | import audio.stft 3 | import audio.audio_processing 4 | -------------------------------------------------------------------------------- /audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length.
19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return torch.log(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, 
:]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data.cuda(), 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 
141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes): 152 | output = dynamic_range_compression(magnitudes) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1 170 | assert torch.max(y.data) <= 1 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | return mel_output, energy 179 | -------------------------------------------------------------------------------- /audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | 5 | from audio.audio_processing import griffin_lim 6 | 7 | 8 | def get_mel_from_wav(audio, _stft): 9 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 10 | audio = torch.autograd.Variable(audio, requires_grad=False) 11 | melspec, energy = _stft.mel_spectrogram(audio) 12 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 13 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 14 | 15 | return melspec, energy 16 | 17 | 18 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 19 | mel = torch.stack([mel]) 20 | mel_decompress = _stft.spectral_de_normalize(mel) 21 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 22 | spec_from_mel_scaling = 1000 23 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 24 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 25 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 26 | 27 | audio = griffin_lim( 28 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters 29 | ) 30 | 31 | audio = audio.squeeze() 32 | audio = audio.cpu().numpy() 33 | audio_path = out_filename 34 | write(audio_path, _stft.sampling_rate, audio) 35 | -------------------------------------------------------------------------------- /config/LibriTTS/model.yaml: -------------------------------------------------------------------------------- 1 | prenet: 2 | conv_kernel_size: 3 3 | dropout: 0.1 4 | 5 | transformer: 6 | encoder_layer: 4 7 | encoder_head: 2 8 | encoder_hidden: 256 9 | decoder_layer: 4 10 | decoder_head: 2 11 | decoder_hidden: 256 12 | conv_filter_size: 1024 13 | conv_kernel_size: [9, 1] 14 | encoder_dropout: 0.1 15 | decoder_dropout: 0.1 16 | 17 | melencoder: 18 | encoder_hidden: 128 19 | spectral_layer: 2 20 | temporal_layer: 2 21 | slf_attn_layer: 1 22 | slf_attn_head: 2 23 | conv_kernel_size: 5 24 | encoder_dropout: 0.1 25 | 26 | variance_predictor: 27 | filter_size: 256 28 |
kernel_size: 3 29 | dropout: 0.5 30 | 31 | variance_embedding: 32 | kernel_size: 9 33 | 34 | discriminator: 35 | mel_linear_size: 256 36 | phoneme_layer: 3 37 | phoneme_hidden: 512 38 | 39 | multi_speaker: True 40 | 41 | max_seq_len: 1000 42 | 43 | vocoder: 44 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN' 45 | speaker: "universal" # support 'LJSpeech', 'universal' 46 | -------------------------------------------------------------------------------- /config/LibriTTS/preprocess.yaml: -------------------------------------------------------------------------------- 1 | dataset: "LibriTTS" 2 | 3 | path: 4 | corpus_path: "/mnt/nfs2/speech-datasets/en/LibriTTS/train-clean-100" 5 | lexicon_path: "lexicon/librispeech-lexicon.txt" 6 | raw_path: "./raw_data/LibriTTS" 7 | preprocessed_path: "./preprocessed_data/LibriTTS" 8 | 9 | preprocessing: 10 | val_size: 512 11 | text: 12 | text_cleaners: ["english_cleaners"] 13 | language: "en" 14 | audio: 15 | sampling_rate: 22050 16 | max_wav_value: 32768.0 17 | stft: 18 | filter_length: 1024 19 | hop_length: 256 20 | win_length: 1024 21 | mel: 22 | n_mel_channels: 80 23 | mel_fmin: 0 24 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder 25 | pitch: 26 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 27 | normalization: True 28 | energy: 29 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 30 | normalization: True 31 | -------------------------------------------------------------------------------- /config/LibriTTS/train.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt/LibriTTS_meta_learner" 3 | log_path: "./output/log/LibriTTS_meta_learner" 4 | result_path: "./output/result/LibriTTS_meta_learner" 5 | optimizer: 6 | batch_size: 16 7 | betas: [0.9, 0.98] 8 | eps: 0.000000001 9 | weight_decay: 0.0 10 | grad_clip_thresh: 1.0 11 | grad_acc_step: 1 12 | warm_up_step: 4000 13 | anneal_steps: [300000, 400000, 500000] 14 | anneal_rate: 0.3 15 | lr_disc: 0.00002 16 | alpha: 10 17 | step: 18 | meta_learning_warmup: 120000 19 | total_step: 200000 20 | log_step: 100 21 | synth_step: 1000 22 | val_step: 1000 23 | save_step: 40000 24 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | 9 | from text import text_to_sequence 10 | from utils.tools import pad_1D, pad_2D, expand 11 | 12 | random.seed(1234) 13 | 14 | 15 | class Dataset(Dataset): 16 | def __init__( 17 | self, filename, preprocess_config, train_config, sort=False, drop_last=False 18 | ): 19 | self.dataset_name = preprocess_config["dataset"] 20 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"] 21 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 22 | self.batch_size = train_config["optimizer"]["batch_size"] 23 | 24 | self.basename, self.speaker, self.text, self.raw_text, self.speaker_to_ids = self.process_meta( 25 | filename 26 | ) 27 | with open(os.path.join(self.preprocessed_path, "speakers.json")) as f: 28 | self.speaker_map = json.load(f) 29 | self.sort = sort 30 | self.drop_last = drop_last 31 | 32 | def __len__(self): 33 | return len(self.text) 34 | 35 | def __getitem__(self, idx): 36 | basename = self.basename[idx] 37 | speaker = 
self.speaker[idx] 38 | speaker_id = self.speaker_map[speaker] 39 | raw_text = self.raw_text[idx] 40 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 41 | query_idx = random.choice(self.speaker_to_ids[speaker]) # Sample the query text 42 | raw_quary_text = self.raw_text[query_idx] 43 | query_phone = np.array(text_to_sequence(self.text[query_idx], self.cleaners)) 44 | mel_path = os.path.join( 45 | self.preprocessed_path, 46 | "mel", 47 | "{}-mel-{}.npy".format(speaker, basename), 48 | ) 49 | mel = np.load(mel_path) 50 | pitch_path = os.path.join( 51 | self.preprocessed_path, 52 | "pitch", 53 | "{}-pitch-{}.npy".format(speaker, basename), 54 | ) 55 | pitch = np.load(pitch_path) 56 | energy_path = os.path.join( 57 | self.preprocessed_path, 58 | "energy", 59 | "{}-energy-{}.npy".format(speaker, basename), 60 | ) 61 | energy = np.load(energy_path) 62 | duration_path = os.path.join( 63 | self.preprocessed_path, 64 | "duration", 65 | "{}-duration-{}.npy".format(speaker, basename), 66 | ) 67 | duration = np.load(duration_path) 68 | quary_duration_path = os.path.join( 69 | self.preprocessed_path, 70 | "duration", 71 | "{}-duration-{}.npy".format(self.speaker[query_idx], self.basename[query_idx]), 72 | ) 73 | quary_duration = np.load(quary_duration_path) 74 | 75 | sample = { 76 | "id": basename, 77 | "speaker": speaker_id, 78 | "text": phone, 79 | "raw_text": raw_text, 80 | "quary_text": query_phone, 81 | "raw_quary_text": raw_quary_text, 82 | "mel": mel, 83 | "pitch": pitch, 84 | "energy": energy, 85 | "duration": duration, 86 | "quary_duration": quary_duration, 87 | } 88 | 89 | return sample 90 | 91 | def process_meta(self, filename): 92 | with open( 93 | os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8" 94 | ) as f: 95 | name = [] 96 | speaker = [] 97 | text = [] 98 | raw_text = [] 99 | speaker_to_ids = dict() 100 | for i, line in enumerate(f.readlines()): 101 | n, s, t, r = line.strip("\n").split("|") 102 | name.append(n) 103 | speaker.append(s) 104 | text.append(t) 105 | raw_text.append(r) 106 | if s not in speaker_to_ids: 107 | speaker_to_ids[s] = [i] 108 | else: 109 | speaker_to_ids[s] += [i] 110 | return name, speaker, text, raw_text, speaker_to_ids 111 | 112 | def reprocess(self, data, idxs): 113 | ids = [data[idx]["id"] for idx in idxs] 114 | speakers = [data[idx]["speaker"] for idx in idxs] 115 | texts = [data[idx]["text"] for idx in idxs] 116 | raw_texts = [data[idx]["raw_text"] for idx in idxs] 117 | quary_texts = [data[idx]["quary_text"] for idx in idxs] 118 | raw_quary_texts = [data[idx]["raw_quary_text"] for idx in idxs] 119 | mels = [data[idx]["mel"] for idx in idxs] 120 | pitches = [data[idx]["pitch"] for idx in idxs] 121 | energies = [data[idx]["energy"] for idx in idxs] 122 | durations = [data[idx]["duration"] for idx in idxs] 123 | quary_durations = [data[idx]["quary_duration"] for idx in idxs] 124 | 125 | text_lens = np.array([text.shape[0] for text in texts]) 126 | quary_text_lens = np.array([text.shape[0] for text in quary_texts]) 127 | mel_lens = np.array([mel.shape[0] for mel in mels]) 128 | 129 | speakers = np.array(speakers) 130 | texts = pad_1D(texts) 131 | quary_texts = pad_1D(quary_texts) 132 | mels = pad_2D(mels) 133 | pitches = pad_1D(pitches) 134 | energies = pad_1D(energies) 135 | durations = pad_1D(durations) 136 | quary_durations = pad_1D(quary_durations) 137 | 138 | return ( 139 | ids, 140 | raw_texts, 141 | speakers, 142 | texts, 143 | text_lens, 144 | max(text_lens), 145 | mels, 146 | mel_lens, 147 | max(mel_lens), 148 | 
pitches, 149 | energies, 150 | durations, 151 | raw_quary_texts, 152 | quary_texts, 153 | quary_text_lens, 154 | max(quary_text_lens), 155 | quary_durations, 156 | ) 157 | 158 | def collate_fn(self, data): 159 | data_size = len(data) 160 | 161 | if self.sort: 162 | len_arr = np.array([d["text"].shape[0] for d in data]) 163 | idx_arr = np.argsort(-len_arr) 164 | else: 165 | idx_arr = np.arange(data_size) 166 | 167 | tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :] 168 | idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)] 169 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist() 170 | if not self.drop_last and len(tail) > 0: 171 | idx_arr += [tail.tolist()] 172 | 173 | output = list() 174 | for idx in idx_arr: 175 | output.append(self.reprocess(data, idx)) 176 | 177 | return output 178 | 179 | 180 | class BatchInferenceDataset(Dataset): 181 | def __init__(self, filepath, preprocess_config): 182 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 183 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"] 184 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"] 185 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"] 186 | 187 | self.basename, self.speaker, self.text, self.raw_text = self.process_meta( 188 | filepath 189 | ) 190 | with open( 191 | os.path.join( 192 | preprocess_config["path"]["preprocessed_path"], "speakers.json" 193 | ) 194 | ) as f: 195 | self.speaker_map = json.load(f) 196 | 197 | def __len__(self): 198 | return len(self.text) 199 | 200 | def __getitem__(self, idx): 201 | basename = self.basename[idx] 202 | speaker = self.speaker[idx] 203 | speaker_id = self.speaker_map[speaker] 204 | raw_text = self.raw_text[idx] 205 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 206 | mel_path = os.path.join( 207 | self.preprocessed_path, 208 | "mel", 209 | "{}-mel-{}.npy".format(speaker, basename), 210 | ) 211 | mel = np.load(mel_path) 212 | pitch_path = os.path.join( 213 | self.preprocessed_path, 214 | "pitch", 215 | "{}-pitch-{}.npy".format(speaker, basename), 216 | ) 217 | pitch = np.load(pitch_path) 218 | energy_path = os.path.join( 219 | self.preprocessed_path, 220 | "energy", 221 | "{}-energy-{}.npy".format(speaker, basename), 222 | ) 223 | energy = np.load(energy_path) 224 | duration_path = os.path.join( 225 | self.preprocessed_path, 226 | "duration", 227 | "{}-duration-{}.npy".format(speaker, basename), 228 | ) 229 | duration = np.load(duration_path) 230 | 231 | return (basename, speaker_id, phone, raw_text, mel, pitch, energy, duration) 232 | 233 | def process_meta(self, filename): 234 | with open(filename, "r", encoding="utf-8") as f: 235 | name = [] 236 | speaker = [] 237 | text = [] 238 | raw_text = [] 239 | for line in f.readlines(): 240 | n, s, t, r = line.strip("\n").split("|") 241 | name.append(n) 242 | speaker.append(s) 243 | text.append(t) 244 | raw_text.append(r) 245 | return name, speaker, text, raw_text 246 | 247 | def collate_fn(self, data): 248 | ids = [d[0] for d in data] 249 | speakers = np.array([d[1] for d in data]) 250 | texts = [d[2] for d in data] 251 | raw_texts = [d[3] for d in data] 252 | mels = [d[4] for d in data] 253 | pitches = [d[5] for d in data] 254 | energies = [d[6] for d in data] 255 | durations = [d[7] for d in data] 256 | 257 | text_lens = np.array([text.shape[0] for text in texts]) 258 | mel_lens = np.array([mel.shape[0] for mel in mels]) 259 | 260 | ref_infos = list() 261 | 
for _, (m, p, e, d) in enumerate(zip(mels, pitches, energies, durations)): 262 | if self.pitch_feature_level == "phoneme_level": 263 | pitch = expand(p, d) 264 | else: 265 | pitch = p 266 | if self.energy_feature_level == "phoneme_level": 267 | energy = expand(e, d) 268 | else: 269 | energy = e 270 | ref_infos.append((m.T, pitch, energy)) 271 | 272 | texts = pad_1D(texts) 273 | mels = pad_2D(mels) 274 | 275 | return ( 276 | ids, 277 | raw_texts, 278 | speakers, 279 | texts, 280 | text_lens, 281 | max(text_lens), 282 | mels, 283 | mel_lens, 284 | max(mel_lens), 285 | ref_infos, 286 | ) 287 | -------------------------------------------------------------------------------- /demo/LibriTTS/references/0_19_198_000000_000000.lab: -------------------------------------------------------------------------------- 1 | this is a librivox recording. -------------------------------------------------------------------------------- /demo/LibriTTS/references/0_19_198_000000_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/0_19_198_000000_000000.wav -------------------------------------------------------------------------------- /demo/LibriTTS/references/1_26_495_000004_000000.lab: -------------------------------------------------------------------------------- 1 | by daniel defoe -------------------------------------------------------------------------------- /demo/LibriTTS/references/1_26_495_000004_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/1_26_495_000004_000000.wav -------------------------------------------------------------------------------- /demo/LibriTTS/references/2_27_123349_000001_000000.lab: -------------------------------------------------------------------------------- 1 | at length all differences were compromised. -------------------------------------------------------------------------------- /demo/LibriTTS/references/2_27_123349_000001_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/2_27_123349_000001_000000.wav -------------------------------------------------------------------------------- /demo/LibriTTS/references/3_32_4137_000005_000001.lab: -------------------------------------------------------------------------------- 1 | as many of the slaves had been brought up in richmond, and had relations residing there, the slave trader determined to leave the city early in the morning, so as not to witness any of those scenes so common where slaves are separated from their relatives and friends, when about departing for the southern market. 
-------------------------------------------------------------------------------- /demo/LibriTTS/references/3_32_4137_000005_000001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/3_32_4137_000005_000001.wav -------------------------------------------------------------------------------- /demo/LibriTTS/references/4_39_121916_000015_000005.lab: -------------------------------------------------------------------------------- 1 | ours are all apple tarts. -------------------------------------------------------------------------------- /demo/LibriTTS/references/4_39_121916_000015_000005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/4_39_121916_000015_000005.wav -------------------------------------------------------------------------------- /demo/LibriTTS/references/5_40_222_000001_000000.lab: -------------------------------------------------------------------------------- 1 | chapter twenty five -------------------------------------------------------------------------------- /demo/LibriTTS/references/5_40_222_000001_000000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/5_40_222_000001_000000.wav -------------------------------------------------------------------------------- /demo/LibriTTS/results/0_the two children therefore got up, dressed themselves quickly, and went away..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/0_the two children therefore got up, dressed themselves quickly, and went away..png -------------------------------------------------------------------------------- /demo/LibriTTS/results/0_the two children therefore got up, dressed themselves quickly, and went away..wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/0_the two children therefore got up, dressed themselves quickly, and went away..wav -------------------------------------------------------------------------------- /demo/LibriTTS/results/1_the two children therefore got up, dressed themselves quickly, and went away..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/1_the two children therefore got up, dressed themselves quickly, and went away..png -------------------------------------------------------------------------------- /demo/LibriTTS/results/1_the two children therefore got up, dressed themselves quickly, and went away..wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/1_the two children therefore got up, dressed themselves quickly, and went away..wav 
-------------------------------------------------------------------------------- /demo/LibriTTS/results/2_the two children therefore got up, dressed themselves quickly, and went away..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/2_the two children therefore got up, dressed themselves quickly, and went away..png -------------------------------------------------------------------------------- /demo/LibriTTS/results/2_the two children therefore got up, dressed themselves quickly, and went away..wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/2_the two children therefore got up, dressed themselves quickly, and went away..wav -------------------------------------------------------------------------------- /demo/LibriTTS/results/3_the two children therefore got up, dressed themselves quickly, and went away..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/3_the two children therefore got up, dressed themselves quickly, and went away..png -------------------------------------------------------------------------------- /demo/LibriTTS/results/3_the two children therefore got up, dressed themselves quickly, and went away..wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/3_the two children therefore got up, dressed themselves quickly, and went away..wav -------------------------------------------------------------------------------- /demo/LibriTTS/results/4_the two children therefore got up, dressed themselves quickly, and went away..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/4_the two children therefore got up, dressed themselves quickly, and went away..png -------------------------------------------------------------------------------- /demo/LibriTTS/results/4_the two children therefore got up, dressed themselves quickly, and went away..wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/4_the two children therefore got up, dressed themselves quickly, and went away..wav -------------------------------------------------------------------------------- /demo/LibriTTS/results/5_the two children therefore got up, dressed themselves quickly, and went away..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/5_the two children therefore got up, dressed themselves quickly, and went away..png -------------------------------------------------------------------------------- /demo/LibriTTS/results/5_the two children therefore got up, dressed themselves quickly, and went away..wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/5_the two children therefore got up, dressed themselves quickly, and went away..wav -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import yaml 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | 9 | from utils.model import get_model, get_vocoder 10 | from utils.tools import to_device, log, synth_one_sample 11 | from model import MetaStyleSpeechLossMain 12 | from dataset import Dataset 13 | 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | def evaluate(model, step, configs, logger=None, vocoder=None, loss_len=5): 19 | preprocess_config, model_config, train_config = configs 20 | 21 | # Get dataset 22 | dataset = Dataset( 23 | "val.txt", preprocess_config, train_config, sort=False, drop_last=False 24 | ) 25 | batch_size = train_config["optimizer"]["batch_size"] 26 | loader = DataLoader( 27 | dataset, 28 | batch_size=batch_size, 29 | shuffle=False, 30 | collate_fn=dataset.collate_fn, 31 | ) 32 | 33 | # Get loss function 34 | Loss = MetaStyleSpeechLossMain(preprocess_config, model_config, train_config).to(device) 35 | 36 | # Evaluation 37 | loss_sums = [0 for _ in range(loss_len)] 38 | for batchs in loader: 39 | for batch in batchs: 40 | batch = to_device(batch, device) 41 | with torch.no_grad(): 42 | # Forward 43 | output = (None, None, *model(*(batch[2:-5]))) 44 | 45 | # Cal Loss 46 | losses = Loss(batch, output) 47 | 48 | for i in range(len(losses)): 49 | loss_sums[i] += losses[i].item() * len(batch[0]) 50 | 51 | loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums] 52 | 53 | message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format( 54 | *([step] + [l for l in loss_means[:5]]) 55 | ) 56 | 57 | if logger is not None: 58 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample( 59 | batch, 60 | output[2:], 61 | vocoder, 62 | model_config, 63 | preprocess_config, 64 | ) 65 | 66 | log(logger, step, losses=loss_means) 67 | log( 68 | logger, 69 | fig=fig, 70 | tag="Validation/step_{}_{}".format(step, tag), 71 | ) 72 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 73 | log( 74 | logger, 75 | audio=wav_reconstruction, 76 | sampling_rate=sampling_rate, 77 | tag="Validation/step_{}_{}_reconstructed".format(step, tag), 78 | ) 79 | log( 80 | logger, 81 | audio=wav_prediction, 82 | sampling_rate=sampling_rate, 83 | tag="Validation/step_{}_{}_synthesized".format(step, tag), 84 | ) 85 | 86 | return message 87 | -------------------------------------------------------------------------------- /filelist_filtering.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import yaml 4 | import argparse 5 | 6 | def main(preprocess_config, model_config): 7 | preprocessed_path = preprocess_config["path"]["preprocessed_path"] 8 | max_seq_len = model_config["max_seq_len"] 9 | 10 | with open( 11 | os.path.join(preprocessed_path, "train.txt"), "r", encoding="utf-8" 12 | ) as f: 13 | filtered_list = [] 14 | for i, line in enumerate(f.readlines()): 15 | basename, speaker, 
*_ = line.strip("\n").split("|") 16 | mel_path = os.path.join( 17 | preprocessed_path, 18 | "mel", 19 | "{}-mel-{}.npy".format(speaker, basename), 20 | ) 21 | mel = np.load(mel_path) 22 | if mel.shape[0] <= max_seq_len: 23 | filtered_list.append(line) 24 | 25 | # Write Filtered Filelist 26 | with open(os.path.join(preprocessed_path, "train_filtered.txt"), "w", encoding="utf-8") as f: 27 | for line in filtered_list: 28 | f.write(line) 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | "-p", 34 | "--preprocess_config", 35 | type=str, 36 | required=True, 37 | help="path to preprocess.yaml", 38 | ) 39 | parser.add_argument( 40 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 41 | ) 42 | args = parser.parse_args() 43 | 44 | # Read Config 45 | preprocess_config = yaml.load( 46 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 47 | ) 48 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 49 | 50 | main(preprocess_config, model_config) -------------------------------------------------------------------------------- /hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
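The `hifigan` package below bundles the vocoder used at synthesis time (see the Implementation Issues section of the README); the pretrained generator weights ship as `generator_LJSpeech.pth.tar.zip` and `generator_universal.pth.tar.zip` and must be unzipped before use. A minimal, hedged sketch of loading the universal generator on its own — assuming the unzipped checkpoint stores its weights under a `"generator"` key, as official HiFi-GAN checkpoints do — could look like:
```
import json

import torch

import hifigan  # exposes AttrDict and Generator (defined below)

with open("hifigan/config.json") as f:
    config = hifigan.AttrDict(json.load(f))

vocoder = hifigan.Generator(config)
ckpt = torch.load("hifigan/generator_universal.pth.tar", map_location="cpu")
vocoder.load_state_dict(ckpt["generator"])  # assumption: weights stored under the "generator" key
vocoder.eval()
vocoder.remove_weight_norm()

# Placeholder mel-spectrogram of shape (batch, n_mel_channels, frames);
# in practice this is the acoustic model's output transposed to (B, 80, T).
mel = torch.zeros(1, 80, 200)
with torch.no_grad():
    wav = vocoder(mel).squeeze(1)  # waveform in [-1, 1], hop_size * frames samples per item
```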
-------------------------------------------------------------------------------- /hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Generator 2 | 3 | 4 | class AttrDict(dict): 5 | def __init__(self, *args, **kwargs): 6 | super(AttrDict, self).__init__(*args, **kwargs) 7 | self.__dict__ = self -------------------------------------------------------------------------------- /hifigan/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /hifigan/generator_LJSpeech.pth.tar.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/hifigan/generator_LJSpeech.pth.tar.zip -------------------------------------------------------------------------------- /hifigan/generator_universal.pth.tar.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/hifigan/generator_universal.pth.tar.zip -------------------------------------------------------------------------------- /hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | 
self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2 ** i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) -------------------------------------------------------------------------------- /img/model_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/model_1.png -------------------------------------------------------------------------------- /img/model_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/model_2.png 
-------------------------------------------------------------------------------- /img/tensorboard_audio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/tensorboard_audio.png -------------------------------------------------------------------------------- /img/tensorboard_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/tensorboard_loss.png -------------------------------------------------------------------------------- /img/tensorboard_spec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/tensorboard_spec.png -------------------------------------------------------------------------------- /model/StyleSpeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .modules import ( 9 | MelStyleEncoder, 10 | PhonemeEncoder, 11 | MelDecoder, 12 | VarianceAdaptor, 13 | PhonemeDiscriminator, 14 | StyleDiscriminator, 15 | ) 16 | from utils.tools import get_mask_from_lengths 17 | 18 | 19 | class StyleSpeech(nn.Module): 20 | """ StyleSpeech """ 21 | 22 | def __init__(self, preprocess_config, model_config): 23 | super(StyleSpeech, self).__init__() 24 | self.model_config = model_config 25 | 26 | self.mel_style_encoder = MelStyleEncoder(preprocess_config, model_config) 27 | self.phoneme_encoder = PhonemeEncoder(model_config) 28 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config) 29 | self.mel_decoder = MelDecoder(model_config) 30 | self.phoneme_linear = nn.Linear( 31 | model_config["transformer"]["encoder_hidden"], 32 | model_config["transformer"]["encoder_hidden"], 33 | ) 34 | self.mel_linear = nn.Linear( 35 | model_config["transformer"]["decoder_hidden"], 36 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"], 37 | ) 38 | self.D_t = PhonemeDiscriminator(preprocess_config, model_config) 39 | self.D_s = StyleDiscriminator(preprocess_config, model_config) 40 | 41 | with open( 42 | os.path.join( 43 | preprocess_config["path"]["preprocessed_path"], "speakers.json" 44 | ), 45 | "r", 46 | ) as f: 47 | n_speaker = len(json.load(f)) 48 | self.style_prototype = nn.Embedding( 49 | n_speaker, 50 | model_config["melencoder"]["encoder_hidden"], 51 | ) 52 | 53 | def G( 54 | self, 55 | style_vector, 56 | texts, 57 | src_masks, 58 | mel_masks, 59 | max_mel_len, 60 | p_targets=None, 61 | e_targets=None, 62 | d_targets=None, 63 | p_control=1.0, 64 | e_control=1.0, 65 | d_control=1.0, 66 | ): 67 | output = self.phoneme_encoder(texts, style_vector, src_masks) 68 | output = self.phoneme_linear(output) 69 | 70 | ( 71 | output, 72 | p_predictions, 73 | e_predictions, 74 | log_d_predictions, 75 | d_rounded, 76 | mel_lens, 77 | mel_masks, 78 | ) = self.variance_adaptor( 79 | output, 80 | src_masks, 81 | mel_masks, 82 | max_mel_len, 83 | p_targets, 84 | e_targets, 85 | d_targets, 86 | p_control, 87 | e_control, 88 | d_control, 89 | ) 90 | 91 | output, mel_masks = self.mel_decoder(output, style_vector, mel_masks) 92 | output = self.mel_linear(output) 93 | 94 | return ( 95 | output, 96 | p_predictions, 97 | e_predictions, 98 
| log_d_predictions, 99 | d_rounded, 100 | mel_lens, 101 | mel_masks, 102 | ) 103 | 104 | def forward( 105 | self, 106 | _, 107 | texts, 108 | src_lens, 109 | max_src_len, 110 | mels, 111 | mel_lens, 112 | max_mel_len, 113 | p_targets=None, 114 | e_targets=None, 115 | d_targets=None, 116 | p_control=1.0, 117 | e_control=1.0, 118 | d_control=1.0, 119 | ): 120 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 121 | mel_masks = get_mask_from_lengths(mel_lens, max_mel_len) 122 | 123 | style_vector = self.mel_style_encoder(mels, mel_masks) 124 | 125 | ( 126 | output, 127 | p_predictions, 128 | e_predictions, 129 | log_d_predictions, 130 | d_rounded, 131 | mel_lens, 132 | mel_masks, 133 | ) = self.G( 134 | style_vector, 135 | texts, 136 | src_masks, 137 | mel_masks, 138 | max_mel_len, 139 | p_targets, 140 | e_targets, 141 | d_targets, 142 | p_control, 143 | e_control, 144 | d_control, 145 | ) 146 | 147 | return ( 148 | output, 149 | p_predictions, 150 | e_predictions, 151 | log_d_predictions, 152 | d_rounded, 153 | src_masks, 154 | mel_masks, 155 | src_lens, 156 | mel_lens, 157 | ) 158 | 159 | def meta_learner_1( 160 | self, 161 | speakers, 162 | texts, 163 | src_lens, 164 | max_src_len, 165 | mels, 166 | mel_lens, 167 | max_mel_len, 168 | p_targets=None, 169 | e_targets=None, 170 | d_targets=None, 171 | raw_quary_texts=None, 172 | quary_texts=None, 173 | quary_src_lens=None, 174 | max_quary_src_len=None, 175 | quary_d_targets=None, 176 | p_control=1.0, 177 | e_control=1.0, 178 | d_control=1.0, 179 | ): 180 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 181 | mel_masks = get_mask_from_lengths(mel_lens, max_mel_len) 182 | 183 | quary_mel_lens = quary_d_targets.sum(dim=-1) 184 | max_quary_mel_len = max(quary_mel_lens).item() 185 | quary_src_masks = get_mask_from_lengths(quary_src_lens, max_quary_src_len) 186 | quary_mel_masks = get_mask_from_lengths(quary_mel_lens, max_quary_mel_len) 187 | 188 | style_vector = self.mel_style_encoder(mels, mel_masks) 189 | 190 | ( 191 | output, 192 | _, 193 | _, 194 | _, 195 | d_rounded_adv, 196 | mel_lens_adv, 197 | mel_masks_adv, 198 | ) = self.G( 199 | style_vector, 200 | quary_texts, 201 | quary_src_masks, 202 | quary_mel_masks, 203 | max_quary_mel_len, 204 | None, 205 | None, 206 | None, 207 | p_control, 208 | e_control, 209 | d_control, 210 | ) 211 | 212 | D_s = self.D_s(self.style_prototype, speakers, output, mel_masks_adv) 213 | 214 | quary_texts = self.phoneme_encoder.src_word_emb(quary_texts) 215 | D_t = self.D_t(self.variance_adaptor.upsample, quary_texts, output, max(mel_lens_adv).item(), mel_masks_adv, d_rounded_adv) 216 | 217 | ( 218 | G, 219 | p_predictions, 220 | e_predictions, 221 | log_d_predictions, 222 | d_rounded, 223 | mel_lens, 224 | mel_masks, 225 | ) = self.G( 226 | style_vector, 227 | texts, 228 | src_masks, 229 | mel_masks, 230 | max_mel_len, 231 | p_targets, 232 | e_targets, 233 | d_targets, 234 | p_control, 235 | e_control, 236 | d_control, 237 | ) 238 | 239 | return ( 240 | D_s, 241 | D_t, 242 | G, 243 | p_predictions, 244 | e_predictions, 245 | log_d_predictions, 246 | d_rounded, 247 | src_masks, 248 | mel_masks, 249 | src_lens, 250 | mel_lens, 251 | ) 252 | 253 | def meta_learner_2( 254 | self, 255 | speakers, 256 | texts, 257 | src_lens, 258 | max_src_len, 259 | mels, 260 | mel_lens, 261 | max_mel_len, 262 | p_targets=None, 263 | e_targets=None, 264 | d_targets=None, 265 | raw_quary_texts=None, 266 | quary_texts=None, 267 | quary_src_lens=None, 268 | max_quary_src_len=None, 269 | quary_d_targets=None, 270 | 
p_control=1.0, 271 | e_control=1.0, 272 | d_control=1.0, 273 | ): 274 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 275 | mel_masks = get_mask_from_lengths(mel_lens, max_mel_len) 276 | 277 | quary_mel_lens = quary_d_targets.sum(dim=-1) 278 | max_quary_mel_len = max(quary_mel_lens).item() 279 | quary_src_masks = get_mask_from_lengths(quary_src_lens, max_quary_src_len) 280 | quary_mel_masks = get_mask_from_lengths(quary_mel_lens, max_quary_mel_len) 281 | 282 | style_vector = self.mel_style_encoder(mels, mel_masks) 283 | 284 | ( 285 | output, 286 | _, 287 | _, 288 | _, 289 | d_rounded_adv, 290 | mel_lens_adv, 291 | mel_masks_adv, 292 | ) = self.G( 293 | style_vector, 294 | quary_texts, 295 | quary_src_masks, 296 | quary_mel_masks, 297 | max_quary_mel_len, 298 | None, 299 | None, 300 | None, 301 | p_control, 302 | e_control, 303 | d_control, 304 | ) 305 | 306 | texts = self.phoneme_encoder.src_word_emb(texts) 307 | D_t_s = self.D_t(self.variance_adaptor.upsample, texts, mels, max_mel_len, mel_masks, d_targets) 308 | 309 | quary_texts = self.phoneme_encoder.src_word_emb(quary_texts) 310 | D_t_q = self.D_t(self.variance_adaptor.upsample, quary_texts, output, max(mel_lens_adv).item(), mel_masks_adv, d_rounded_adv) 311 | 312 | D_s_s = self.D_s(self.style_prototype, speakers, mels, mel_masks) 313 | D_s_q = self.D_s(self.style_prototype, speakers, output, mel_masks_adv) 314 | 315 | # Get Style Logit 316 | w = style_vector.squeeze() # [B, H] 317 | style_logit = torch.matmul(w, self.style_prototype.weight.contiguous().transpose(0, 1)) # [B, K] 318 | 319 | return D_t_s, D_t_q, D_s_s, D_s_q, style_logit 320 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .StyleSpeech import StyleSpeech 2 | from .loss import MetaStyleSpeechLossMain, MetaStyleSpeechLossDisc 3 | from .optimizer import ScheduledOptimMain, ScheduledOptimDisc -------------------------------------------------------------------------------- /model/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn import functional as F 5 | 6 | 7 | class Mish(nn.Module): 8 | def forward(self, x): 9 | return x * torch.tanh(F.softplus(x)) 10 | 11 | 12 | class StyleAdaptiveLayerNorm(nn.Module): 13 | """ Style-Adaptive Layer Norm (SALN) """ 14 | 15 | def __init__(self, w_size, hidden_size, bias=False): 16 | super(StyleAdaptiveLayerNorm, self).__init__() 17 | self.hidden_size = hidden_size 18 | self.affine_layer = LinearNorm( 19 | w_size, 20 | 2 * hidden_size, # For both b (bias) g (gain) 21 | bias, 22 | ) 23 | 24 | def forward(self, h, w): 25 | """ 26 | h --- [B, T, H_m] 27 | w --- [B, 1, H_w] 28 | o --- [B, T, H_m] 29 | """ 30 | 31 | # Normalize Input Features 32 | mu, sigma = torch.mean(h, dim=-1, keepdim=True), torch.std(h, dim=-1, keepdim=True) 33 | y = (h - mu) / sigma # [B, T, H_m] 34 | 35 | # Get Bias and Gain 36 | b, g = torch.split(self.affine_layer(w), self.hidden_size, dim=-1) # [B, 1, 2 * H_m] --> 2 * [B, 1, H_m] 37 | 38 | # Perform Scailing and Shifting 39 | o = g * y + b # [B, T, H_m] 40 | 41 | return o 42 | 43 | 44 | class FCBlock(nn.Module): 45 | """ Fully Connected Block """ 46 | 47 | def __init__(self, in_features, out_features, activation=None, bias=False, dropout=None, spectral_norm=False): 48 | super(FCBlock, self).__init__() 49 | self.fc_layer = nn.Sequential() 50 | 
self.fc_layer.add_module( 51 | "fc_layer", 52 | LinearNorm( 53 | in_features, 54 | out_features, 55 | bias, 56 | spectral_norm, 57 | ), 58 | ) 59 | if activation is not None: 60 | self.fc_layer.add_module("activ", activation) 61 | self.dropout = dropout 62 | 63 | def forward(self, x): 64 | x = self.fc_layer(x) 65 | if self.dropout is not None: 66 | x = F.dropout(x, self.dropout, self.training) 67 | return x 68 | 69 | 70 | class LinearNorm(nn.Module): 71 | """ LinearNorm Projection """ 72 | 73 | def __init__(self, in_features, out_features, bias=False, spectral_norm=False): 74 | super(LinearNorm, self).__init__() 75 | self.linear = nn.Linear(in_features, out_features, bias) 76 | 77 | nn.init.xavier_uniform_(self.linear.weight) 78 | if bias: 79 | nn.init.constant_(self.linear.bias, 0.0) 80 | if spectral_norm: 81 | self.linear = nn.utils.spectral_norm(self.linear) 82 | 83 | def forward(self, x): 84 | x = self.linear(x) 85 | return x 86 | 87 | 88 | class Conv1DBlock(nn.Module): 89 | """ 1D Convolutional Block """ 90 | 91 | def __init__(self, in_channels, out_channels, kernel_size, activation=None, dropout=None, spectral_norm=False): 92 | super(Conv1DBlock, self).__init__() 93 | 94 | self.conv_layer = nn.Sequential() 95 | self.conv_layer.add_module( 96 | "conv_layer", 97 | ConvNorm( 98 | in_channels, 99 | out_channels, 100 | kernel_size=kernel_size, 101 | stride=1, 102 | padding=int((kernel_size - 1) / 2), 103 | dilation=1, 104 | w_init_gain="tanh", 105 | spectral_norm=spectral_norm, 106 | ), 107 | ) 108 | if activation is not None: 109 | self.conv_layer.add_module("activ", activation) 110 | self.dropout = dropout 111 | 112 | def forward(self, x, mask=None): 113 | x = x.contiguous().transpose(1, 2) 114 | x = self.conv_layer(x) 115 | 116 | if self.dropout is not None: 117 | x = F.dropout(x, self.dropout, self.training) 118 | 119 | x = x.contiguous().transpose(1, 2) 120 | if mask is not None: 121 | x = x.masked_fill(mask.unsqueeze(-1), 0) 122 | 123 | return x 124 | 125 | 126 | class ConvNorm(nn.Module): 127 | """ 1D Convolution """ 128 | 129 | def __init__( 130 | self, 131 | in_channels, 132 | out_channels, 133 | kernel_size=1, 134 | stride=1, 135 | padding=None, 136 | dilation=1, 137 | bias=True, 138 | w_init_gain="linear", 139 | spectral_norm=False, 140 | ): 141 | super(ConvNorm, self).__init__() 142 | 143 | if padding is None: 144 | assert kernel_size % 2 == 1 145 | padding = int(dilation * (kernel_size - 1) / 2) 146 | 147 | self.conv = nn.Conv1d( 148 | in_channels, 149 | out_channels, 150 | kernel_size=kernel_size, 151 | stride=stride, 152 | padding=padding, 153 | dilation=dilation, 154 | bias=bias, 155 | ) 156 | if spectral_norm: 157 | self.conv = nn.utils.spectral_norm(self.conv) 158 | 159 | def forward(self, signal): 160 | conv_signal = self.conv(signal) 161 | 162 | return conv_signal 163 | 164 | 165 | class SALNFFTBlock(nn.Module): 166 | """ FFT Block with SALN """ 167 | 168 | def __init__(self, d_model, d_w, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 169 | super(SALNFFTBlock, self).__init__() 170 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 171 | self.pos_ffn = PositionwiseFeedForward( 172 | d_model, d_inner, kernel_size, dropout=dropout 173 | ) 174 | self.layer_norm_1 = StyleAdaptiveLayerNorm(d_w, d_model) 175 | self.layer_norm_2 = StyleAdaptiveLayerNorm(d_w, d_model) 176 | 177 | def forward(self, enc_input, w, mask=None, slf_attn_mask=None): 178 | enc_output, enc_slf_attn = self.slf_attn( 179 | enc_input, enc_input, enc_input, 
mask=slf_attn_mask 180 | ) 181 | enc_output = self.layer_norm_1(enc_output, w) 182 | if mask is not None: 183 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 184 | 185 | enc_output = self.pos_ffn(enc_output) 186 | enc_output = self.layer_norm_2(enc_output, w) 187 | if mask is not None: 188 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 189 | 190 | return enc_output, enc_slf_attn 191 | 192 | 193 | class MultiHeadAttention(nn.Module): 194 | """ Multi-Head Attention """ 195 | 196 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, layer_norm=False, spectral_norm=False): 197 | super(MultiHeadAttention, self).__init__() 198 | 199 | self.n_head = n_head 200 | self.d_k = d_k 201 | self.d_v = d_v 202 | 203 | self.w_qs = LinearNorm(d_model, n_head * d_k, spectral_norm=spectral_norm) 204 | self.w_ks = LinearNorm(d_model, n_head * d_k, spectral_norm=spectral_norm) 205 | self.w_vs = LinearNorm(d_model, n_head * d_v, spectral_norm=spectral_norm) 206 | 207 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 208 | self.layer_norm = nn.LayerNorm(d_model) if layer_norm else None 209 | 210 | self.fc = LinearNorm(n_head * d_v, d_model, spectral_norm=spectral_norm) 211 | 212 | self.dropout = nn.Dropout(dropout) 213 | 214 | def forward(self, q, k, v, mask=None): 215 | 216 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 217 | 218 | sz_b, len_q, _ = q.size() 219 | sz_b, len_k, _ = k.size() 220 | sz_b, len_v, _ = v.size() 221 | 222 | residual = q 223 | 224 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 225 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 226 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 227 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 228 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 229 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 230 | 231 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
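# q, k and v have been reshaped to (n_head * batch) x len x d_k/d_v, so a single batched matmul
# below evaluates every attention head at once; the mask is tiled with repeat(n_head, 1, 1) to
# match that layout, and the per-head outputs are folded back into the feature dimension right
# after the attention call.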
232 | output, attn = self.attention(q, k, v, mask=mask) 233 | 234 | output = output.view(n_head, sz_b, len_q, d_v) 235 | output = ( 236 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) 237 | ) # b x lq x (n*dv) 238 | 239 | output = self.dropout(self.fc(output)) 240 | output = output + residual 241 | if self.layer_norm is not None: 242 | output = self.layer_norm(output) 243 | 244 | return output, attn 245 | 246 | 247 | class ScaledDotProductAttention(nn.Module): 248 | """ Scaled Dot-Product Attention """ 249 | 250 | def __init__(self, temperature): 251 | super(ScaledDotProductAttention, self).__init__() 252 | self.temperature = temperature 253 | self.softmax = nn.Softmax(dim=2) 254 | 255 | def forward(self, q, k, v, mask=None): 256 | 257 | attn = torch.bmm(q, k.transpose(1, 2)) 258 | attn = attn / self.temperature 259 | 260 | if mask is not None: 261 | attn = attn.masked_fill(mask, -np.inf) 262 | 263 | attn = self.softmax(attn) 264 | output = torch.bmm(attn, v) 265 | 266 | return output, attn 267 | 268 | 269 | class PositionwiseFeedForward(nn.Module): 270 | """ A two-feed-forward-layer """ 271 | 272 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1): 273 | super(PositionwiseFeedForward, self).__init__() 274 | 275 | # Use Conv1D 276 | # position-wise 277 | self.w_1 = nn.Conv1d( 278 | d_in, 279 | d_hid, 280 | kernel_size=kernel_size[0], 281 | padding=(kernel_size[0] - 1) // 2, 282 | ) 283 | # position-wise 284 | self.w_2 = nn.Conv1d( 285 | d_hid, 286 | d_in, 287 | kernel_size=kernel_size[1], 288 | padding=(kernel_size[1] - 1) // 2, 289 | ) 290 | self.dropout = nn.Dropout(dropout) 291 | 292 | def forward(self, x): 293 | residual = x 294 | output = x.transpose(1, 2) 295 | output = self.w_2(F.relu(self.w_1(output))) 296 | output = output.transpose(1, 2) 297 | output = self.dropout(output) 298 | output = output + residual 299 | 300 | return output 301 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MetaStyleSpeechLossMain(nn.Module): 6 | """ Meta-StyleSpeech Loss for naive StyleSpeech and Step 1 """ 7 | 8 | def __init__(self, preprocess_config, model_config, train_config): 9 | super(MetaStyleSpeechLossMain, self).__init__() 10 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 11 | "feature" 12 | ] 13 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 14 | "feature" 15 | ] 16 | self.alpha = train_config["optimizer"]["alpha"] 17 | self.mse_loss = nn.MSELoss() 18 | self.mae_loss = nn.L1Loss() 19 | 20 | def forward(self, inputs, predictions): 21 | ( 22 | mel_targets, 23 | _, 24 | _, 25 | pitch_targets, 26 | energy_targets, 27 | duration_targets, 28 | _, 29 | _, 30 | _, 31 | _, 32 | _, 33 | ) = inputs[6:] 34 | ( 35 | D_s, 36 | D_t, 37 | mel_predictions, 38 | pitch_predictions, 39 | energy_predictions, 40 | log_duration_predictions, 41 | _, 42 | src_masks, 43 | mel_masks, 44 | _, 45 | _, 46 | ) = predictions 47 | src_masks = ~src_masks 48 | mel_masks = ~mel_masks 49 | log_duration_targets = torch.log(duration_targets.float() + 1) 50 | mel_targets = mel_targets[:, : mel_masks.shape[1], :] 51 | mel_masks = mel_masks[:, :mel_masks.shape[1]] 52 | 53 | log_duration_targets.requires_grad = False 54 | pitch_targets.requires_grad = False 55 | energy_targets.requires_grad = False 56 | mel_targets.requires_grad = False 57 | 58 | if 
self.pitch_feature_level == "phoneme_level": 59 | pitch_predictions = pitch_predictions.masked_select(src_masks) 60 | pitch_targets = pitch_targets.masked_select(src_masks) 61 | elif self.pitch_feature_level == "frame_level": 62 | pitch_predictions = pitch_predictions.masked_select(mel_masks) 63 | pitch_targets = pitch_targets.masked_select(mel_masks) 64 | 65 | if self.energy_feature_level == "phoneme_level": 66 | energy_predictions = energy_predictions.masked_select(src_masks) 67 | energy_targets = energy_targets.masked_select(src_masks) 68 | if self.energy_feature_level == "frame_level": 69 | energy_predictions = energy_predictions.masked_select(mel_masks) 70 | energy_targets = energy_targets.masked_select(mel_masks) 71 | 72 | log_duration_predictions = log_duration_predictions.masked_select(src_masks) 73 | log_duration_targets = log_duration_targets.masked_select(src_masks) 74 | 75 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1)) 76 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1)) 77 | 78 | mel_loss = self.mae_loss(mel_predictions, mel_targets) 79 | 80 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets) 81 | energy_loss = self.mse_loss(energy_predictions, energy_targets) 82 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets) 83 | 84 | alpha = 1 85 | D_s_loss = D_t_loss = torch.tensor([0.], device=mel_predictions.device, requires_grad=False) 86 | if D_s is not None and D_t is not None: 87 | D_s_loss = self.mse_loss(D_s, torch.ones_like(D_s, requires_grad=False)) 88 | D_t_loss = self.mse_loss(D_t, torch.ones_like(D_t, requires_grad=False)) 89 | alpha = self.alpha 90 | 91 | recon_loss = alpha * (mel_loss + duration_loss + pitch_loss + energy_loss) 92 | total_loss = ( 93 | recon_loss + D_s_loss + D_t_loss 94 | ) 95 | 96 | return ( 97 | total_loss, 98 | mel_loss, 99 | pitch_loss, 100 | energy_loss, 101 | duration_loss, 102 | D_s_loss, 103 | D_t_loss, 104 | ) 105 | 106 | 107 | class MetaStyleSpeechLossDisc(nn.Module): 108 | """ Meta-StyleSpeech Loss for Step 2 """ 109 | 110 | def __init__(self, preprocess_config, model_config): 111 | super(MetaStyleSpeechLossDisc, self).__init__() 112 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 113 | "feature" 114 | ] 115 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 116 | "feature" 117 | ] 118 | self.mse_loss = nn.MSELoss() 119 | self.mae_loss = nn.L1Loss() 120 | self.cross_entropy_loss = nn.CrossEntropyLoss() 121 | 122 | def forward(self, speakers, predictions): 123 | ( 124 | D_t_s, 125 | D_t_q, 126 | D_s_s, 127 | D_s_q, 128 | style_logit, 129 | ) = predictions 130 | speakers.requires_grad = False 131 | 132 | D_t_loss = self.mse_loss(D_t_s, torch.ones_like(D_t_s, requires_grad=False))\ 133 | + self.mse_loss(D_t_q, torch.zeros_like(D_t_q, requires_grad=False)) 134 | D_s_loss = self.mse_loss(D_s_s, torch.ones_like(D_s_s, requires_grad=False))\ 135 | + self.mse_loss(D_s_q, torch.zeros_like(D_s_q, requires_grad=False)) 136 | cls_loss = self.cross_entropy_loss(style_logit, speakers) 137 | 138 | total_loss = ( 139 | D_t_loss + D_s_loss + cls_loss 140 | ) 141 | 142 | return ( 143 | total_loss, 144 | D_s_loss, 145 | D_t_loss, 146 | cls_loss, 147 | ) 148 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import math 5 | from collections import 
OrderedDict 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.parameter import Parameter 10 | import numpy as np 11 | import torch.nn.functional as F 12 | 13 | from utils.tools import get_mask_from_lengths, pad 14 | 15 | from .blocks import ( 16 | Mish, 17 | FCBlock, 18 | Conv1DBlock, 19 | SALNFFTBlock, 20 | MultiHeadAttention, 21 | ) 22 | from text.symbols import symbols 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | 27 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 28 | """ Sinusoid position encoding table """ 29 | 30 | def cal_angle(position, hid_idx): 31 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 32 | 33 | def get_posi_angle_vec(position): 34 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 35 | 36 | sinusoid_table = np.array( 37 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)] 38 | ) 39 | 40 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 41 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 42 | 43 | if padding_idx is not None: 44 | # zero vector for padding dimension 45 | sinusoid_table[padding_idx] = 0.0 46 | 47 | return torch.FloatTensor(sinusoid_table) 48 | 49 | 50 | class MelStyleEncoder(nn.Module): 51 | """ Mel-Style Encoder """ 52 | 53 | def __init__(self, preprocess_config, model_config): 54 | super(MelStyleEncoder, self).__init__() 55 | n_position = model_config["max_seq_len"] + 1 56 | n_mel_channels = preprocess_config["preprocessing"]["mel"]["n_mel_channels"] 57 | d_melencoder = model_config["melencoder"]["encoder_hidden"] 58 | n_spectral_layer = model_config["melencoder"]["spectral_layer"] 59 | n_temporal_layer = model_config["melencoder"]["temporal_layer"] 60 | n_slf_attn_layer = model_config["melencoder"]["slf_attn_layer"] 61 | n_slf_attn_head = model_config["melencoder"]["slf_attn_head"] 62 | d_k = d_v = ( 63 | model_config["melencoder"]["encoder_hidden"] 64 | // model_config["melencoder"]["slf_attn_head"] 65 | ) 66 | kernel_size = model_config["melencoder"]["conv_kernel_size"] 67 | dropout = model_config["melencoder"]["encoder_dropout"] 68 | 69 | self.max_seq_len = model_config["max_seq_len"] 70 | 71 | self.fc_1 = FCBlock(n_mel_channels, d_melencoder) 72 | 73 | self.spectral_stack = nn.ModuleList( 74 | [ 75 | FCBlock( 76 | d_melencoder, d_melencoder, activation=Mish() 77 | ) 78 | for _ in range(n_spectral_layer) 79 | ] 80 | ) 81 | 82 | self.temporal_stack = nn.ModuleList( 83 | [ 84 | nn.Sequential( 85 | Conv1DBlock( 86 | d_melencoder, 2 * d_melencoder, kernel_size, activation=Mish(), dropout=dropout 87 | ), 88 | nn.GLU(), 89 | ) 90 | for _ in range(n_temporal_layer) 91 | ] 92 | ) 93 | 94 | self.slf_attn_stack = nn.ModuleList( 95 | [ 96 | MultiHeadAttention( 97 | n_slf_attn_head, d_melencoder, d_k, d_v, dropout=dropout, layer_norm=True 98 | ) 99 | for _ in range(n_slf_attn_layer) 100 | ] 101 | ) 102 | 103 | self.fc_2 = FCBlock(d_melencoder, d_melencoder) 104 | 105 | def forward(self, mel, mask): 106 | 107 | max_len = mel.shape[1] 108 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 109 | 110 | enc_output = self.fc_1(mel) 111 | 112 | # Spectral Processing 113 | for _, layer in enumerate(self.spectral_stack): 114 | enc_output = layer(enc_output) 115 | 116 | # Temporal Processing 117 | for _, layer in enumerate(self.temporal_stack): 118 | residual = enc_output 119 | enc_output = layer(enc_output) 120 | enc_output = residual + enc_output 121 | 122 | # Multi-head self-attention 123 | for _, 
layer in enumerate(self.slf_attn_stack): 124 | residual = enc_output 125 | enc_output, _ = layer( 126 | enc_output, enc_output, enc_output, mask=slf_attn_mask 127 | ) 128 | enc_output = residual + enc_output 129 | 130 | # Final Layer 131 | enc_output = self.fc_2(enc_output) # [B, T, H] 132 | 133 | # Temporal Average Pooling 134 | enc_output = torch.mean(enc_output, dim=1, keepdim=True) # [B, 1, H] 135 | 136 | return enc_output 137 | 138 | 139 | class PhonemePreNet(nn.Module): 140 | """ Phoneme Encoder PreNet """ 141 | 142 | def __init__(self, config): 143 | super(PhonemePreNet, self).__init__() 144 | d_model = config["transformer"]["encoder_hidden"] 145 | kernel_size = config["prenet"]["conv_kernel_size"] 146 | dropout = config["prenet"]["dropout"] 147 | 148 | self.prenet_layer = nn.Sequential( 149 | Conv1DBlock( 150 | d_model, d_model, kernel_size, activation=Mish(), dropout=dropout 151 | ), 152 | Conv1DBlock( 153 | d_model, d_model, kernel_size, activation=Mish(), dropout=dropout 154 | ), 155 | FCBlock(d_model, d_model, dropout=dropout), 156 | ) 157 | 158 | def forward(self, x, mask=None): 159 | residual = x 160 | x = self.prenet_layer(x) 161 | if mask is not None: 162 | x = x.masked_fill(mask.unsqueeze(-1), 0) 163 | x = residual + x 164 | return x 165 | 166 | 167 | class PhonemeEncoder(nn.Module): 168 | """ PhonemeText Encoder """ 169 | 170 | def __init__(self, config): 171 | super(PhonemeEncoder, self).__init__() 172 | 173 | n_position = config["max_seq_len"] + 1 174 | n_src_vocab = len(symbols) + 1 175 | d_word_vec = config["transformer"]["encoder_hidden"] 176 | n_layers = config["transformer"]["encoder_layer"] 177 | n_head = config["transformer"]["encoder_head"] 178 | d_w = config["melencoder"]["encoder_hidden"] 179 | d_k = d_v = ( 180 | config["transformer"]["encoder_hidden"] 181 | // config["transformer"]["encoder_head"] 182 | ) 183 | d_model = config["transformer"]["encoder_hidden"] 184 | d_inner = config["transformer"]["conv_filter_size"] 185 | kernel_size = config["transformer"]["conv_kernel_size"] 186 | dropout = config["transformer"]["encoder_dropout"] 187 | 188 | self.max_seq_len = config["max_seq_len"] 189 | self.d_model = d_model 190 | 191 | self.src_word_emb = nn.Embedding( 192 | n_src_vocab, d_word_vec, padding_idx=0 193 | ) 194 | self.phoneme_prenet = PhonemePreNet(config) 195 | self.position_enc = nn.Parameter( 196 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 197 | requires_grad=False, 198 | ) 199 | 200 | self.layer_stack = nn.ModuleList( 201 | [ 202 | SALNFFTBlock( 203 | d_model, d_w, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 204 | ) 205 | for _ in range(n_layers) 206 | ] 207 | ) 208 | 209 | def forward(self, src_seq, w, mask, return_attns=False): 210 | 211 | enc_slf_attn_list = [] 212 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 213 | 214 | # -- Prepare masks 215 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 216 | 217 | # -- PreNet 218 | src_seq = self.phoneme_prenet(self.src_word_emb(src_seq), mask) 219 | 220 | # -- Forward 221 | if not self.training and src_seq.shape[1] > self.max_seq_len: 222 | enc_output = src_seq + get_sinusoid_encoding_table( 223 | src_seq.shape[1], self.d_model 224 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 225 | src_seq.device 226 | ) 227 | else: 228 | enc_output = src_seq + self.position_enc[ 229 | :, :max_len, : 230 | ].expand(batch_size, -1, -1) 231 | 232 | for enc_layer in self.layer_stack: 233 | enc_output, enc_slf_attn = enc_layer( 234 | 
enc_output, w, mask=mask, slf_attn_mask=slf_attn_mask 235 | ) 236 | if return_attns: 237 | enc_slf_attn_list += [enc_slf_attn] 238 | 239 | return enc_output 240 | 241 | 242 | class MelPreNet(nn.Module): 243 | """ Mel-spectrogram Decoder PreNet """ 244 | 245 | def __init__(self, config): 246 | super(MelPreNet, self).__init__() 247 | d_model = config["transformer"]["encoder_hidden"] 248 | d_melencoder = config["melencoder"]["encoder_hidden"] 249 | dropout = config["prenet"]["dropout"] 250 | 251 | self.prenet_layer = nn.Sequential( 252 | FCBlock(d_model, d_melencoder, activation=Mish(), dropout=dropout), 253 | FCBlock(d_melencoder, d_model, activation=Mish(), dropout=dropout), 254 | ) 255 | 256 | def forward(self, x, mask=None): 257 | x = self.prenet_layer(x) 258 | if mask is not None: 259 | x = x.masked_fill(mask.unsqueeze(-1), 0) 260 | return x 261 | 262 | 263 | class MelDecoder(nn.Module): 264 | """ MelDecoder """ 265 | 266 | def __init__(self, config): 267 | super(MelDecoder, self).__init__() 268 | 269 | n_position = config["max_seq_len"] + 1 270 | d_word_vec = config["transformer"]["decoder_hidden"] 271 | n_layers = config["transformer"]["decoder_layer"] 272 | n_head = config["transformer"]["decoder_head"] 273 | d_w = config["melencoder"]["encoder_hidden"] 274 | d_k = d_v = ( 275 | config["transformer"]["decoder_hidden"] 276 | // config["transformer"]["decoder_head"] 277 | ) 278 | d_model = config["transformer"]["decoder_hidden"] 279 | d_inner = config["transformer"]["conv_filter_size"] 280 | kernel_size = config["transformer"]["conv_kernel_size"] 281 | dropout = config["transformer"]["decoder_dropout"] 282 | 283 | self.max_seq_len = config["max_seq_len"] 284 | self.d_model = d_model 285 | 286 | self.mel_prenet = MelPreNet(config) 287 | self.position_enc = nn.Parameter( 288 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 289 | requires_grad=False, 290 | ) 291 | 292 | self.layer_stack = nn.ModuleList( 293 | [ 294 | SALNFFTBlock( 295 | d_model, d_w, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 296 | ) 297 | for _ in range(n_layers) 298 | ] 299 | ) 300 | 301 | def forward(self, enc_seq, w, mask, return_attns=False): 302 | 303 | dec_slf_attn_list = [] 304 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 305 | 306 | # -- PreNet 307 | enc_seq = self.mel_prenet(enc_seq, mask) 308 | 309 | # -- Forward 310 | if not self.training and enc_seq.shape[1] > self.max_seq_len: 311 | # -- Prepare masks 312 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 313 | dec_output = enc_seq + get_sinusoid_encoding_table( 314 | enc_seq.shape[1], self.d_model 315 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 316 | enc_seq.device 317 | ) 318 | else: 319 | max_len = min(max_len, self.max_seq_len) 320 | 321 | # -- Prepare masks 322 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 323 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[ 324 | :, :max_len, : 325 | ].expand(batch_size, -1, -1) 326 | mask = mask[:, :max_len] 327 | slf_attn_mask = slf_attn_mask[:, :, :max_len] 328 | 329 | for dec_layer in self.layer_stack: 330 | dec_output, dec_slf_attn = dec_layer( 331 | dec_output, w, mask=mask, slf_attn_mask=slf_attn_mask 332 | ) 333 | if return_attns: 334 | dec_slf_attn_list += [dec_slf_attn] 335 | 336 | return dec_output, mask 337 | 338 | 339 | class VarianceAdaptor(nn.Module): 340 | """ Variance Adaptor """ 341 | 342 | def __init__(self, preprocess_config, model_config): 343 | super(VarianceAdaptor, self).__init__() 344 | 
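# The variance adaptor follows the FastSpeech 2 recipe: a duration predictor feeds the length
# regulator that upsamples phoneme-level features to frame level, while pitch and energy
# predictors add Conv1D embeddings of either the ground-truth values (training) or the scaled
# predictions (inference, via the p_control / e_control knobs).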
self.duration_predictor = VariancePredictor(model_config) 345 | self.length_regulator = LengthRegulator() 346 | self.pitch_predictor = VariancePredictor(model_config) 347 | self.energy_predictor = VariancePredictor(model_config) 348 | 349 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 350 | "feature" 351 | ] 352 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 353 | "feature" 354 | ] 355 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"] 356 | assert self.energy_feature_level in ["phoneme_level", "frame_level"] 357 | 358 | d_model = model_config["transformer"]["encoder_hidden"] 359 | kernel_size = model_config["variance_embedding"]["kernel_size"] 360 | self.pitch_embedding = Conv1DBlock( 361 | 1, d_model, kernel_size 362 | ) 363 | self.energy_embedding = Conv1DBlock( 364 | 1, d_model, kernel_size 365 | ) 366 | 367 | def get_pitch_embedding(self, x, target, mask, control): 368 | prediction = self.pitch_predictor(x, mask) 369 | if target is not None: 370 | embedding = self.pitch_embedding(target.unsqueeze(-1)) 371 | else: 372 | prediction = prediction * control 373 | embedding = self.pitch_embedding(prediction.unsqueeze(-1)) 374 | return prediction, embedding 375 | 376 | def get_energy_embedding(self, x, target, mask, control): 377 | prediction = self.energy_predictor(x, mask) 378 | if target is not None: 379 | embedding = self.energy_embedding(target.unsqueeze(-1)) 380 | else: 381 | prediction = prediction * control 382 | embedding = self.energy_embedding(prediction.unsqueeze(-1)) 383 | return prediction, embedding 384 | 385 | def upsample(self, x, mel_mask, max_len, log_duration_prediction=None, duration_target=None, d_control=1.0): 386 | if duration_target is not None: 387 | x, mel_len = self.length_regulator(x, duration_target, max_len) 388 | duration_rounded = duration_target 389 | else: 390 | duration_rounded = torch.clamp( 391 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control), 392 | min=0, 393 | ) 394 | x, mel_len = self.length_regulator(x, duration_rounded, None) 395 | mel_mask = get_mask_from_lengths(mel_len) 396 | return x, duration_rounded, mel_len, mel_mask 397 | 398 | def forward( 399 | self, 400 | x, 401 | src_mask, 402 | mel_mask, 403 | max_len, 404 | pitch_target=None, 405 | energy_target=None, 406 | duration_target=None, 407 | p_control=1.0, 408 | e_control=1.0, 409 | d_control=1.0, 410 | ): 411 | upsampled_text = None 412 | log_duration_prediction = self.duration_predictor(x, src_mask) 413 | if self.pitch_feature_level == "phoneme_level": 414 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 415 | x, pitch_target, src_mask, p_control 416 | ) 417 | x = x + pitch_embedding 418 | if self.energy_feature_level == "phoneme_level": 419 | energy_prediction, energy_embedding = self.get_energy_embedding( 420 | x, energy_target, src_mask, e_control 421 | ) 422 | x = x + energy_embedding 423 | 424 | x, duration_rounded, mel_len, mel_mask = self.upsample( 425 | x, mel_mask, max_len, log_duration_prediction=log_duration_prediction, duration_target=duration_target, d_control=d_control 426 | ) 427 | 428 | if self.pitch_feature_level == "frame_level": 429 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 430 | x, pitch_target, mel_mask, p_control 431 | ) 432 | x = x + pitch_embedding 433 | if self.energy_feature_level == "frame_level": 434 | energy_prediction, energy_embedding = self.get_energy_embedding( 435 | x, energy_target, mel_mask, e_control 436 | ) 437 | x = x +
energy_embedding 438 | 439 | return ( 440 | x, 441 | pitch_prediction, 442 | energy_prediction, 443 | log_duration_prediction, 444 | duration_rounded, 445 | mel_len, 446 | mel_mask, 447 | ) 448 | 449 | 450 | class LengthRegulator(nn.Module): 451 | """ Length Regulator """ 452 | 453 | def __init__(self): 454 | super(LengthRegulator, self).__init__() 455 | 456 | def LR(self, x, duration, max_len): 457 | output = list() 458 | mel_len = list() 459 | for batch, expand_target in zip(x, duration): 460 | expanded = self.expand(batch, expand_target) 461 | output.append(expanded) 462 | mel_len.append(expanded.shape[0]) 463 | 464 | if max_len is not None: 465 | output = pad(output, max_len) 466 | else: 467 | output = pad(output) 468 | 469 | return output, torch.LongTensor(mel_len).to(device) 470 | 471 | def expand(self, batch, predicted): 472 | out = list() 473 | 474 | for i, vec in enumerate(batch): 475 | expand_size = predicted[i].item() 476 | out.append(vec.expand(max(int(expand_size), 0), -1)) 477 | out = torch.cat(out, 0) 478 | 479 | return out 480 | 481 | def forward(self, x, duration, max_len): 482 | output, mel_len = self.LR(x, duration, max_len) 483 | return output, mel_len 484 | 485 | 486 | class VariancePredictor(nn.Module): 487 | """ Duration, Pitch and Energy Predictor """ 488 | 489 | def __init__(self, model_config): 490 | super(VariancePredictor, self).__init__() 491 | 492 | self.input_size = model_config["transformer"]["encoder_hidden"] 493 | self.filter_size = model_config["variance_predictor"]["filter_size"] 494 | self.kernel = model_config["variance_predictor"]["kernel_size"] 495 | self.conv_output_size = model_config["variance_predictor"]["filter_size"] 496 | self.dropout = model_config["variance_predictor"]["dropout"] 497 | 498 | self.conv_layer = nn.Sequential( 499 | OrderedDict( 500 | [ 501 | ( 502 | "conv1d_1", 503 | Conv( 504 | self.input_size, 505 | self.filter_size, 506 | kernel_size=self.kernel, 507 | padding=(self.kernel - 1) // 2, 508 | ), 509 | ), 510 | ("relu_1", nn.ReLU()), 511 | ("layer_norm_1", nn.LayerNorm(self.filter_size)), 512 | ("dropout_1", nn.Dropout(self.dropout)), 513 | ( 514 | "conv1d_2", 515 | Conv( 516 | self.filter_size, 517 | self.filter_size, 518 | kernel_size=self.kernel, 519 | padding=1, 520 | ), 521 | ), 522 | ("relu_2", nn.ReLU()), 523 | ("layer_norm_2", nn.LayerNorm(self.filter_size)), 524 | ("dropout_2", nn.Dropout(self.dropout)), 525 | ] 526 | ) 527 | ) 528 | 529 | self.linear_layer = nn.Linear(self.conv_output_size, 1) 530 | 531 | def forward(self, encoder_output, mask): 532 | out = self.conv_layer(encoder_output) 533 | out = self.linear_layer(out) 534 | out = out.squeeze(-1) 535 | 536 | if mask is not None: 537 | out = out.masked_fill(mask, 0.0) 538 | 539 | return out 540 | 541 | 542 | class Conv(nn.Module): 543 | """ 544 | Convolution Module 545 | """ 546 | 547 | def __init__( 548 | self, 549 | in_channels, 550 | out_channels, 551 | kernel_size=1, 552 | stride=1, 553 | padding=0, 554 | dilation=1, 555 | bias=True, 556 | w_init="linear", 557 | ): 558 | """ 559 | :param in_channels: dimension of input 560 | :param out_channels: dimension of output 561 | :param kernel_size: size of kernel 562 | :param stride: size of stride 563 | :param padding: size of padding 564 | :param dilation: dilation rate 565 | :param bias: boolean. if True, bias is included. 566 | :param w_init: str. weight inits with xavier initialization. 
567 | """ 568 | super(Conv, self).__init__() 569 | 570 | self.conv = nn.Conv1d( 571 | in_channels, 572 | out_channels, 573 | kernel_size=kernel_size, 574 | stride=stride, 575 | padding=padding, 576 | dilation=dilation, 577 | bias=bias, 578 | ) 579 | 580 | def forward(self, x): 581 | x = x.contiguous().transpose(1, 2) 582 | x = self.conv(x) 583 | x = x.contiguous().transpose(1, 2) 584 | 585 | return x 586 | 587 | 588 | class PhonemeDiscriminator(nn.Module): 589 | """ Phoneme Discriminator """ 590 | 591 | def __init__(self, preprocess_config, model_config): 592 | super(PhonemeDiscriminator, self).__init__() 593 | n_mel_channels = preprocess_config["preprocessing"]["mel"]["n_mel_channels"] 594 | d_mel_linear = model_config["discriminator"]["mel_linear_size"] 595 | d_model = model_config["discriminator"]["phoneme_hidden"] 596 | d_layer = model_config["discriminator"]["phoneme_layer"] 597 | 598 | self.max_seq_len = model_config["max_seq_len"] 599 | self.mel_linear = nn.Sequential( 600 | FCBlock(n_mel_channels, d_mel_linear, activation=nn.LeakyReLU(), spectral_norm=True), 601 | FCBlock(d_mel_linear, d_mel_linear, activation=nn.LeakyReLU(), spectral_norm=True), 602 | ) 603 | self.discriminator_stack = nn.ModuleList( 604 | [ 605 | FCBlock( 606 | d_model, d_model, activation=nn.LeakyReLU(), spectral_norm=True 607 | ) 608 | for _ in range(d_layer) 609 | ] 610 | ) 611 | self.final_linear = FCBlock(d_model, 1, spectral_norm=True) 612 | 613 | def forward(self, upsampler, text, mel, max_len, mask, duration_target): 614 | 615 | # Prepare Upsampled Text 616 | upsampled_text, _, _, _ = upsampler( 617 | text, mask, max_len, duration_target=duration_target 618 | ) 619 | max_len = min(max_len, self.max_seq_len) 620 | upsampled_text = upsampled_text[:, :max_len, :] 621 | 622 | # Prepare Mel 623 | mel = self.mel_linear(mel)[:, :max_len, :] 624 | mel = mel.masked_fill(mask.unsqueeze(-1)[:, :max_len, :], 0) 625 | 626 | # Prepare Input 627 | x = torch.cat([upsampled_text, mel], dim=-1) 628 | 629 | # Phoneme Discriminator 630 | for _, layer in enumerate(self.discriminator_stack): 631 | x = layer(x) 632 | x = self.final_linear(x) # [B, T, 1] 633 | x = x.masked_fill(mask.unsqueeze(-1)[:, :max_len, :], 0) 634 | 635 | # Temporal Average Pooling 636 | x = torch.mean(x, dim=1, keepdim=True) # [B, 1, 1] 637 | x = x.squeeze() # [B,] 638 | 639 | return x 640 | 641 | 642 | class StyleDiscriminator(nn.Module): 643 | """ Style Discriminator """ 644 | 645 | def __init__(self, preprocess_config, model_config): 646 | super(StyleDiscriminator, self).__init__() 647 | n_position = model_config["max_seq_len"] + 1 648 | n_mel_channels = preprocess_config["preprocessing"]["mel"]["n_mel_channels"] 649 | d_melencoder = model_config["melencoder"]["encoder_hidden"] 650 | n_spectral_layer = model_config["melencoder"]["spectral_layer"] 651 | n_temporal_layer = model_config["melencoder"]["temporal_layer"] 652 | n_slf_attn_layer = model_config["melencoder"]["slf_attn_layer"] 653 | n_slf_attn_head = model_config["melencoder"]["slf_attn_head"] 654 | d_k = d_v = ( 655 | model_config["melencoder"]["encoder_hidden"] 656 | // model_config["melencoder"]["slf_attn_head"] 657 | ) 658 | kernel_size = model_config["melencoder"]["conv_kernel_size"] 659 | 660 | self.max_seq_len = model_config["max_seq_len"] 661 | 662 | self.fc_1 = FCBlock(n_mel_channels, d_melencoder, spectral_norm=True) 663 | 664 | self.spectral_stack = nn.ModuleList( 665 | [ 666 | FCBlock( 667 | d_melencoder, d_melencoder, activation=nn.LeakyReLU(), spectral_norm=True 668 | ) 669 | for _ 
in range(n_spectral_layer) 670 | ] 671 | ) 672 | 673 | self.temporal_stack = nn.ModuleList( 674 | [ 675 | Conv1DBlock( 676 | d_melencoder, d_melencoder, kernel_size, activation=nn.LeakyReLU(), spectral_norm=True 677 | ) 678 | for _ in range(n_temporal_layer) 679 | ] 680 | ) 681 | 682 | self.slf_attn_stack = nn.ModuleList( 683 | [ 684 | MultiHeadAttention( 685 | n_slf_attn_head, d_melencoder, d_k, d_v, layer_norm=True, spectral_norm=True 686 | ) 687 | for _ in range(n_slf_attn_layer) 688 | ] 689 | ) 690 | 691 | self.fc_2 = FCBlock(d_melencoder, d_melencoder, spectral_norm=True) 692 | 693 | self.V = FCBlock(d_melencoder, d_melencoder) 694 | self.w_b_0 = FCBlock(1, 1, bias=True) 695 | 696 | def forward(self, style_prototype, speakers, mel, mask): 697 | 698 | max_len = mel.shape[1] 699 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 700 | 701 | x = self.fc_1(mel) 702 | 703 | # Spectral Processing 704 | for _, layer in enumerate(self.spectral_stack): 705 | x = layer(x) 706 | 707 | # Temporal Processing 708 | for _, layer in enumerate(self.temporal_stack): 709 | residual = x 710 | x = layer(x) 711 | x = residual + x 712 | 713 | # Multi-head self-attention 714 | for _, layer in enumerate(self.slf_attn_stack): 715 | residual = x 716 | x, _ = layer( 717 | x, x, x, mask=slf_attn_mask 718 | ) 719 | x = residual + x 720 | 721 | # Final Layer 722 | x = self.fc_2(x) # [B, T, H] 723 | 724 | # Temporal Average Pooling, h(x) 725 | x = torch.mean(x, dim=1, keepdim=True) # [B, 1, H] 726 | 727 | # Output Computation 728 | s_i = style_prototype(speakers) # [B, H] 729 | V = self.V(s_i).unsqueeze(2) # [B, H, 1] 730 | o = torch.matmul(x, V).squeeze(2) # [B, 1] 731 | o = self.w_b_0(o).squeeze() # [B,] 732 | 733 | return o -------------------------------------------------------------------------------- /model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptimMain: 6 | """ A simple wrapper class for learning rate scheduling """ 7 | 8 | def __init__(self, model, train_config, model_config, current_step): 9 | self._optimizer = torch.optim.Adam( 10 | [param for name, param in model.named_parameters() 11 | if not any([filtered_name in name for filtered_name in ['D_s', 'D_t']])], 12 | betas=train_config["optimizer"]["betas"], 13 | eps=train_config["optimizer"]["eps"], 14 | weight_decay=train_config["optimizer"]["weight_decay"], 15 | ) 16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"] 17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"] 18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"] 19 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5) 20 | # meta_learning_warmup = train_config["step"]["meta_learning_warmup"] 21 | self.current_step = current_step# if current_step <= meta_learning_warmup else current_step - meta_learning_warmup 22 | 23 | def step_and_update_lr(self): 24 | self._update_learning_rate() 25 | self._optimizer.step() 26 | 27 | def zero_grad(self): 28 | # print(self.init_lr) 29 | self._optimizer.zero_grad() 30 | 31 | def load_state_dict(self, state_dict): 32 | state_dict['param_groups'] = self._optimizer.state_dict()['param_groups'] 33 | self._optimizer.load_state_dict(state_dict) 34 | 35 | def _get_lr_scale(self): 36 | lr = np.min( 37 | [ 38 | np.power(self.current_step, -0.5), 39 | np.power(self.n_warmup_steps, -1.5) * self.current_step, 40 | ] 41 | ) 42 | for s in self.anneal_steps: 43 | if 
self.current_step > s: 44 | lr = lr * self.anneal_rate 45 | return lr 46 | 47 | def _update_learning_rate(self): 48 | """ Learning rate scheduling per step """ 49 | self.current_step += 1 50 | lr = self.init_lr * self._get_lr_scale() 51 | 52 | for param_group in self._optimizer.param_groups: 53 | param_group["lr"] = lr 54 | 55 | 56 | class ScheduledOptimDisc: 57 | """ A simple wrapper class for learning rate scheduling """ 58 | 59 | def __init__(self, model, train_config): 60 | 61 | self._optimizer = torch.optim.Adam( 62 | [param for name, param in model.named_parameters() 63 | if any([filtered_name in name for filtered_name in ['D_s', 'D_t']])], 64 | betas=train_config["optimizer"]["betas"], 65 | eps=train_config["optimizer"]["eps"], 66 | weight_decay=train_config["optimizer"]["weight_decay"], 67 | ) 68 | self.init_lr = train_config["optimizer"]["lr_disc"] 69 | self._init_learning_rate() 70 | 71 | def step_and_update_lr(self): 72 | self._optimizer.step() 73 | 74 | def zero_grad(self): 75 | # print(self.init_lr) 76 | self._optimizer.zero_grad() 77 | 78 | def load_state_dict(self, state_dict): 79 | state_dict['param_groups'] = self._optimizer.state_dict()['param_groups'] 80 | self._optimizer.load_state_dict(state_dict) 81 | 82 | def _init_learning_rate(self): 83 | lr = self.init_lr 84 | for param_group in self._optimizer.param_groups: 85 | param_group["lr"] = lr 86 | -------------------------------------------------------------------------------- /prepare_align.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | from preprocessor import ljspeech, aishell3, libritts 6 | 7 | 8 | def main(config): 9 | if "LJSpeech" in config["dataset"]: 10 | ljspeech.prepare_align(config) 11 | if "AISHELL3" in config["dataset"]: 12 | aishell3.prepare_align(config) 13 | if "LibriTTS" in config["dataset"]: 14 | libritts.prepare_align(config) 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("config", type=str, help="path to preprocess.yaml") 20 | args = parser.parse_args() 21 | 22 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 23 | main(config) 24 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | from preprocessor.preprocessor import Preprocessor 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("config", type=str, help="path to preprocess.yaml") 11 | args = parser.parse_args() 12 | 13 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 14 | preprocessor = Preprocessor(config) 15 | preprocessor.build_from_path() 16 | -------------------------------------------------------------------------------- /preprocessed_data/LibriTTS/speakers.json: -------------------------------------------------------------------------------- 1 | {"229": 0, "5789": 1, "1034": 2, "3983": 3, "5463": 4, "4481": 5, "1355": 6, "1970": 7, "7264": 8, "625": 9, "5688": 10, "6880": 11, "2002": 12, "1502": 13, "7517": 14, "298": 15, "2989": 16, "8468": 17, "3235": 18, "103": 19, "1455": 20, "3947": 21, "887": 22, "669": 23, "5393": 24, "8838": 25, "5750": 26, "3242": 27, "3857": 28, "696": 29, "60": 30, "78": 31, "125": 32, "1898": 33, "2514": 34, "6454": 35, "6836": 36, "3112": 37, "6925": 38, "839": 39, "8630": 40, "4859": 41, "7148": 42, 
"6078": 43, "2007": 44, "3723": 45, "1447": 46, "4788": 47, "4340": 48, "8051": 49, "458": 50, "5456": 51, "39": 52, "2893": 53, "4898": 54, "8629": 55, "3259": 56, "5514": 57, "3879": 58, "8580": 59, "3664": 60, "6529": 61, "307": 62, "481": 63, "5022": 64, "2843": 65, "3440": 66, "83": 67, "3436": 68, "7078": 69, "311": 70, "405": 71, "2136": 72, "118": 73, "6385": 74, "1963": 75, "730": 76, "3830": 77, "8238": 78, "412": 79, "6437": 80, "587": 81, "89": 82, "6272": 83, "1081": 84, "6848": 85, "254": 86, "4297": 87, "4680": 88, "7113": 89, "27": 90, "374": 91, "7794": 92, "1926": 93, "8123": 94, "3486": 95, "2416": 96, "6818": 97, "8095": 98, "7278": 99, "332": 100, "200": 101, "7800": 102, "6147": 103, "2952": 104, "7190": 105, "1088": 106, "248": 107, "32": 108, "6000": 109, "6531": 110, "909": 111, "1578": 112, "7635": 113, "6476": 114, "5678": 115, "3526": 116, "1867": 117, "446": 118, "8609": 119, "5049": 120, "302": 121, "250": 122, "1743": 123, "4853": 124, "1363": 125, "8797": 126, "4406": 127, "6563": 128, "2159": 129, "2436": 130, "8465": 131, "426": 132, "7178": 133, "4088": 134, "8312": 135, "8098": 136, "3240": 137, "2836": 138, "198": 139, "2384": 140, "7067": 141, "5778": 142, "7402": 143, "5808": 144, "7505": 145, "1841": 146, "5163": 147, "1098": 148, "4397": 149, "3982": 150, "7226": 151, "26": 152, "6019": 153, "2196": 154, "4214": 155, "4160": 156, "8419": 157, "8226": 158, "1183": 159, "1235": 160, "1334": 161, "7447": 162, "40": 163, "8770": 164, "196": 165, "5339": 166, "3214": 167, "6367": 168, "1069": 169, "5703": 170, "201": 171, "4137": 172, "5104": 173, "7859": 174, "8088": 175, "87": 176, "2691": 177, "3607": 178, "7302": 179, "1263": 180, "1553": 181, "4640": 182, "2182": 183, "5322": 184, "2817": 185, "8108": 186, "19": 187, "211": 188, "1624": 189, "150": 190, "4195": 191, "8975": 192, "6081": 193, "1246": 194, "3807": 195, "5867": 196, "8747": 197, "1737": 198, "8063": 199, "5192": 200, "2764": 201, "6209": 202, "2518": 203, "8014": 204, "1992": 205, "1594": 206, "3374": 207, "460": 208, "4267": 209, "163": 210, "289": 211, "403": 212, "3699": 213, "5390": 214, "233": 215, "7511": 216, "2391": 217, "4014": 218, "7312": 219, "4018": 220, "4441": 221, "2289": 222, "7780": 223, "6181": 224, "5652": 225, "911": 226, "7059": 227, "1040": 228, "831": 229, "1116": 230, "4830": 231, "4813": 232, "6415": 233, "7367": 234, "8425": 235, "5561": 236, "2911": 237, "6064": 238, "8324": 239, "2092": 240, "3168": 241, "2910": 242, "4051": 243, "4362": 244, "322": 245, "226": 246} -------------------------------------------------------------------------------- /preprocessed_data/LibriTTS/stats.json: -------------------------------------------------------------------------------- 1 | {"pitch": [-2.7442070957317246, 11.15693693509074, 165.6901007628454, 60.37813291152693], "energy": [-1.2591148614883423, 10.99685001373291, 41.67014254956833, 33.09479306846138], "mel": [-11.512925148010254, 2.3185698986053467]} -------------------------------------------------------------------------------- /preprocessor/aishell3.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | 8 | 9 | def prepare_align(config): 10 | in_dir = config["path"]["corpus_path"] 11 | out_dir = config["path"]["raw_path"] 12 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 13 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 
14 | for dataset in ["train", "test"]: 15 | print("Processing {}ing set...".format(dataset)) 16 | with open(os.path.join(in_dir, dataset, "content.txt"), encoding="utf-8") as f: 17 | for line in tqdm(f): 18 | wav_name, text = line.strip("\n").split("\t") 19 | speaker = wav_name[:7] 20 | text = text.split(" ")[1::2] 21 | wav_path = os.path.join(in_dir, dataset, "wav", speaker, wav_name) 22 | if os.path.exists(wav_path): 23 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 24 | wav, _ = librosa.load(wav_path, sampling_rate) 25 | wav = wav / max(abs(wav)) * max_wav_value 26 | wavfile.write( 27 | os.path.join(out_dir, speaker, wav_name), 28 | sampling_rate, 29 | wav.astype(np.int16), 30 | ) 31 | with open( 32 | os.path.join(out_dir, speaker, "{}.lab".format(wav_name[:11])), 33 | "w", 34 | ) as f1: 35 | f1.write(" ".join(text)) -------------------------------------------------------------------------------- /preprocessor/libritts.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | 8 | from text import _clean_text 9 | 10 | 11 | def prepare_align(config): 12 | in_dir = config["path"]["corpus_path"] 13 | out_dir = config["path"]["raw_path"] 14 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 15 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 16 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 17 | for speaker in tqdm(os.listdir(in_dir)): 18 | for chapter in os.listdir(os.path.join(in_dir, speaker)): 19 | for file_name in os.listdir(os.path.join(in_dir, speaker, chapter)): 20 | if file_name[-4:] != ".wav": 21 | continue 22 | base_name = file_name[:-4] 23 | text_path = os.path.join( 24 | in_dir, speaker, chapter, "{}.normalized.txt".format(base_name) 25 | ) 26 | wav_path = os.path.join( 27 | in_dir, speaker, chapter, "{}.wav".format(base_name) 28 | ) 29 | with open(text_path) as f: 30 | text = f.readline().strip("\n") 31 | text = _clean_text(text, cleaners) 32 | 33 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 34 | wav, _ = librosa.load(wav_path, sampling_rate) 35 | wav = wav / max(abs(wav)) * max_wav_value 36 | wavfile.write( 37 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)), 38 | sampling_rate, 39 | wav.astype(np.int16), 40 | ) 41 | with open( 42 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)), 43 | "w", 44 | ) as f1: 45 | f1.write(text) -------------------------------------------------------------------------------- /preprocessor/ljspeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | 8 | from text import _clean_text 9 | 10 | 11 | def prepare_align(config): 12 | in_dir = config["path"]["corpus_path"] 13 | out_dir = config["path"]["raw_path"] 14 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 15 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 16 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 17 | speaker = "LJSpeech" 18 | with open(os.path.join(in_dir, "metadata.csv"), encoding="utf-8") as f: 19 | for line in tqdm(f): 20 | parts = line.strip().split("|") 21 | base_name = parts[0] 22 | text = parts[2] 23 | text = _clean_text(text, cleaners) 24 | 25 | wav_path = os.path.join(in_dir, "wavs", "{}.wav".format(base_name)) 26 | if 
os.path.exists(wav_path): 27 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 28 | wav, _ = librosa.load(wav_path, sampling_rate) 29 | wav = wav / max(abs(wav)) * max_wav_value 30 | wavfile.write( 31 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)), 32 | sampling_rate, 33 | wav.astype(np.int16), 34 | ) 35 | with open( 36 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)), 37 | "w", 38 | ) as f1: 39 | f1.write(text) -------------------------------------------------------------------------------- /preprocessor/preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | 5 | import tgt 6 | import librosa 7 | import numpy as np 8 | import pyworld as pw 9 | from scipy.interpolate import interp1d 10 | from sklearn.preprocessing import StandardScaler 11 | from tqdm import tqdm 12 | 13 | import audio as Audio 14 | 15 | random.seed(1234) 16 | 17 | 18 | class Preprocessor: 19 | def __init__(self, config): 20 | self.config = config 21 | self.in_dir = config["path"]["raw_path"] 22 | self.out_dir = config["path"]["preprocessed_path"] 23 | self.val_size = config["preprocessing"]["val_size"] 24 | self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 25 | self.hop_length = config["preprocessing"]["stft"]["hop_length"] 26 | 27 | assert config["preprocessing"]["pitch"]["feature"] in [ 28 | "phoneme_level", 29 | "frame_level", 30 | ] 31 | assert config["preprocessing"]["energy"]["feature"] in [ 32 | "phoneme_level", 33 | "frame_level", 34 | ] 35 | self.pitch_phoneme_averaging = ( 36 | config["preprocessing"]["pitch"]["feature"] == "phoneme_level" 37 | ) 38 | self.energy_phoneme_averaging = ( 39 | config["preprocessing"]["energy"]["feature"] == "phoneme_level" 40 | ) 41 | 42 | self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"] 43 | self.energy_normalization = config["preprocessing"]["energy"]["normalization"] 44 | 45 | self.STFT = Audio.stft.TacotronSTFT( 46 | config["preprocessing"]["stft"]["filter_length"], 47 | config["preprocessing"]["stft"]["hop_length"], 48 | config["preprocessing"]["stft"]["win_length"], 49 | config["preprocessing"]["mel"]["n_mel_channels"], 50 | config["preprocessing"]["audio"]["sampling_rate"], 51 | config["preprocessing"]["mel"]["mel_fmin"], 52 | config["preprocessing"]["mel"]["mel_fmax"], 53 | ) 54 | 55 | def build_from_path(self): 56 | os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True) 57 | os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True) 58 | os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True) 59 | os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True) 60 | 61 | print("Processing Data ...") 62 | out = list() 63 | n_frames = 0 64 | mel_min = float('inf') 65 | mel_max = -float('inf') 66 | pitch_scaler = StandardScaler() 67 | energy_scaler = StandardScaler() 68 | 69 | # Compute pitch, energy, duration, and mel-spectrogram 70 | speakers = {} 71 | for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))): 72 | speakers[speaker] = i 73 | for wav_name in tqdm(os.listdir(os.path.join(self.in_dir, speaker))): 74 | if ".wav" not in wav_name: 75 | continue 76 | 77 | basename = wav_name.split(".")[0] 78 | chapter = basename.split("_")[1] 79 | tg_path = os.path.join( 80 | self.out_dir, "TextGrid", speaker, chapter, "{}.TextGrid".format(basename) 81 | ) 82 | if os.path.exists(tg_path): 83 | ret = self.process_utterance(speaker, chapter, basename) 84 | if ret is None: 85 | 
continue 86 | else: 87 | info, pitch, energy, n, m_min, m_max = ret 88 | out.append(info) 89 | 90 | if len(pitch) > 0: 91 | pitch_scaler.partial_fit(pitch.reshape((-1, 1))) 92 | if len(energy) > 0: 93 | energy_scaler.partial_fit(energy.reshape((-1, 1))) 94 | if mel_min > m_min: 95 | mel_min = m_min 96 | if mel_max < m_max: 97 | mel_max = m_max 98 | 99 | n_frames += n 100 | 101 | print("Computing statistic quantities ...") 102 | # Perform normalization if necessary 103 | if self.pitch_normalization: 104 | pitch_mean = pitch_scaler.mean_[0] 105 | pitch_std = pitch_scaler.scale_[0] 106 | else: 107 | # A numerical trick to avoid normalization... 108 | pitch_mean = 0 109 | pitch_std = 1 110 | if self.energy_normalization: 111 | energy_mean = energy_scaler.mean_[0] 112 | energy_std = energy_scaler.scale_[0] 113 | else: 114 | energy_mean = 0 115 | energy_std = 1 116 | 117 | pitch_min, pitch_max = self.normalize( 118 | os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std 119 | ) 120 | energy_min, energy_max = self.normalize( 121 | os.path.join(self.out_dir, "energy"), energy_mean, energy_std 122 | ) 123 | 124 | # Save files 125 | with open(os.path.join(self.out_dir, "speakers.json"), "w") as f: 126 | f.write(json.dumps(speakers)) 127 | 128 | with open(os.path.join(self.out_dir, "stats.json"), "w") as f: 129 | stats = { 130 | "pitch": [ 131 | float(pitch_min), 132 | float(pitch_max), 133 | float(pitch_mean), 134 | float(pitch_std), 135 | ], 136 | "energy": [ 137 | float(energy_min), 138 | float(energy_max), 139 | float(energy_mean), 140 | float(energy_std), 141 | ], 142 | "mel": [ 143 | float(mel_min), 144 | float(mel_max), 145 | ], 146 | } 147 | f.write(json.dumps(stats)) 148 | 149 | print( 150 | "Total time: {} hours".format( 151 | n_frames * self.hop_length / self.sampling_rate / 3600 152 | ) 153 | ) 154 | 155 | random.shuffle(out) 156 | out = [r for r in out if r is not None] 157 | 158 | # Write metadata 159 | with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f: 160 | for m in out[self.val_size :]: 161 | f.write(m + "\n") 162 | with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f: 163 | for m in out[: self.val_size]: 164 | f.write(m + "\n") 165 | 166 | return out 167 | 168 | def process_utterance(self, speaker, chapter, basename): 169 | wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename)) 170 | text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename)) 171 | tg_path = os.path.join( 172 | self.out_dir, "TextGrid", speaker, chapter, "{}.TextGrid".format(basename) 173 | ) 174 | 175 | # Get alignments 176 | textgrid = tgt.io.read_textgrid(tg_path) 177 | phone, duration, start, end = self.get_alignment( 178 | textgrid.get_tier_by_name("phones") 179 | ) 180 | text = "{" + " ".join(phone) + "}" 181 | if start >= end: 182 | return None 183 | 184 | # Read and trim wav files 185 | wav, _ = librosa.load(wav_path) 186 | wav = wav[ 187 | int(self.sampling_rate * start) : int(self.sampling_rate * end) 188 | ].astype(np.float32) 189 | 190 | # Read raw text 191 | with open(text_path, "r") as f: 192 | raw_text = f.readline().strip("\n") 193 | 194 | # Compute fundamental frequency 195 | pitch, t = pw.dio( 196 | wav.astype(np.float64), 197 | self.sampling_rate, 198 | frame_period=self.hop_length / self.sampling_rate * 1000, 199 | ) 200 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate) 201 | 202 | pitch = pitch[: sum(duration)] 203 | if np.sum(pitch != 0) <= 1: 204 | return None 205 | 206 | # 
Compute mel-scale spectrogram and energy 207 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT) 208 | mel_spectrogram = mel_spectrogram[:, : sum(duration)] 209 | energy = energy[: sum(duration)] 210 | 211 | if self.pitch_phoneme_averaging: 212 | # perform linear interpolation 213 | nonzero_ids = np.where(pitch != 0)[0] 214 | interp_fn = interp1d( 215 | nonzero_ids, 216 | pitch[nonzero_ids], 217 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), 218 | bounds_error=False, 219 | ) 220 | pitch = interp_fn(np.arange(0, len(pitch))) 221 | 222 | # Phoneme-level average 223 | pos = 0 224 | for i, d in enumerate(duration): 225 | if d > 0: 226 | pitch[i] = np.mean(pitch[pos : pos + d]) 227 | else: 228 | pitch[i] = 0 229 | pos += d 230 | pitch = pitch[: len(duration)] 231 | 232 | if self.energy_phoneme_averaging: 233 | # Phoneme-level average 234 | pos = 0 235 | for i, d in enumerate(duration): 236 | if d > 0: 237 | energy[i] = np.mean(energy[pos : pos + d]) 238 | else: 239 | energy[i] = 0 240 | pos += d 241 | energy = energy[: len(duration)] 242 | 243 | # Save files 244 | dur_filename = "{}-duration-{}.npy".format(speaker, basename) 245 | np.save(os.path.join(self.out_dir, "duration", dur_filename), duration) 246 | 247 | pitch_filename = "{}-pitch-{}.npy".format(speaker, basename) 248 | np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch) 249 | 250 | energy_filename = "{}-energy-{}.npy".format(speaker, basename) 251 | np.save(os.path.join(self.out_dir, "energy", energy_filename), energy) 252 | 253 | mel_filename = "{}-mel-{}.npy".format(speaker, basename) 254 | np.save( 255 | os.path.join(self.out_dir, "mel", mel_filename), 256 | mel_spectrogram.T, 257 | ) 258 | 259 | return ( 260 | "|".join([basename, speaker, text, raw_text]), 261 | self.remove_outlier(pitch), 262 | self.remove_outlier(energy), 263 | mel_spectrogram.shape[1], 264 | np.min(mel_spectrogram), 265 | np.max(mel_spectrogram), 266 | ) 267 | 268 | def get_alignment(self, tier): 269 | sil_phones = ["sil", "sp", "spn"] 270 | 271 | phones = [] 272 | durations = [] 273 | start_time = 0 274 | end_time = 0 275 | end_idx = 0 276 | for t in tier._objects: 277 | s, e, p = t.start_time, t.end_time, t.text 278 | 279 | # Trim leading silences 280 | if phones == []: 281 | if p in sil_phones: 282 | continue 283 | else: 284 | start_time = s 285 | 286 | if p not in sil_phones: 287 | # For ordinary phones 288 | phones.append(p) 289 | end_time = e 290 | end_idx = len(phones) 291 | else: 292 | # For silent phones 293 | phones.append(p) 294 | 295 | durations.append( 296 | int( 297 | np.round(e * self.sampling_rate / self.hop_length) 298 | - np.round(s * self.sampling_rate / self.hop_length) 299 | ) 300 | ) 301 | 302 | # Trim tailing silences 303 | phones = phones[:end_idx] 304 | durations = durations[:end_idx] 305 | 306 | return phones, durations, start_time, end_time 307 | 308 | def remove_outlier(self, values): 309 | values = np.array(values) 310 | p25 = np.percentile(values, 25) 311 | p75 = np.percentile(values, 75) 312 | lower = p25 - 1.5 * (p75 - p25) 313 | upper = p75 + 1.5 * (p75 - p25) 314 | normal_indices = np.logical_and(values > lower, values < upper) 315 | 316 | return values[normal_indices] 317 | 318 | def normalize(self, in_dir, mean, std): 319 | max_value = np.finfo(np.float64).min 320 | min_value = np.finfo(np.float64).max 321 | for filename in os.listdir(in_dir): 322 | filename = os.path.join(in_dir, filename) 323 | values = (np.load(filename) - mean) / std 324 | np.save(filename, values) 325 | 
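            # Track the global min/max of the normalized values so build_from_path()
            # can record them in stats.json alongside the mean/std used above.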
326 | max_value = max(max_value, max(values)) 327 | min_value = min(min_value, min(values)) 328 | 329 | return min_value, max_value 330 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | g2p-en==2.1.0 2 | inflect==4.1.0 3 | librosa==0.7.2 4 | matplotlib==3.4.2 5 | numba==0.48 6 | numpy==1.19.0 7 | pypinyin==0.39.0 8 | pyworld==0.3.0 9 | PyYAML==5.4.1 10 | scikit-learn==0.23.2 11 | scipy==1.6.3 12 | soundfile==0.10.3.post1 13 | tensorboard==2.2.2 14 | tgt==1.4.4 15 | torch==1.8.1 16 | tqdm==4.46.1 17 | unidecode==1.1.1 -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | from string import punctuation 4 | import torch 5 | import yaml 6 | import numpy as np 7 | import os 8 | import json 9 | 10 | import librosa 11 | import pyworld as pw 12 | import audio as Audio 13 | 14 | from torch.utils.data import DataLoader 15 | from g2p_en import G2p 16 | from pypinyin import pinyin, Style 17 | 18 | from utils.model import get_model, get_vocoder 19 | from utils.tools import to_device, synth_samples 20 | from dataset import BatchInferenceDataset 21 | from text import text_to_sequence 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | 26 | def read_lexicon(lex_path): 27 | lexicon = {} 28 | with open(lex_path) as f: 29 | for line in f: 30 | temp = re.split(r"\s+", line.strip("\n")) 31 | word = temp[0] 32 | phones = temp[1:] 33 | if word.lower() not in lexicon: 34 | lexicon[word.lower()] = phones 35 | return lexicon 36 | 37 | 38 | def get_audio(preprocess_config, wav_path): 39 | 40 | hop_length = preprocess_config["preprocessing"]["stft"]["hop_length"] 41 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 42 | STFT = Audio.stft.TacotronSTFT( 43 | preprocess_config["preprocessing"]["stft"]["filter_length"], 44 | hop_length, 45 | preprocess_config["preprocessing"]["stft"]["win_length"], 46 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"], 47 | sampling_rate, 48 | preprocess_config["preprocessing"]["mel"]["mel_fmin"], 49 | preprocess_config["preprocessing"]["mel"]["mel_fmax"], 50 | ) 51 | with open( 52 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 53 | ) as f: 54 | stats = json.load(f) 55 | stats = stats["pitch"][2:] + stats["energy"][2:] 56 | pitch_mean, pitch_std, energy_mean, energy_std = stats 57 | 58 | # Read and trim wav files 59 | wav, _ = librosa.load(wav_path) 60 | 61 | # Compute fundamental frequency 62 | pitch, t = pw.dio( 63 | wav.astype(np.float64), 64 | sampling_rate, 65 | frame_period=hop_length / sampling_rate * 1000, 66 | ) 67 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, sampling_rate) 68 | 69 | # Compute mel-scale spectrogram and energy 70 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav.astype(np.float32), STFT) 71 | 72 | # Normalize Variance 73 | pitch = (pitch - pitch_mean) / pitch_std 74 | energy = (energy - energy_mean) / energy_std 75 | 76 | mels = mel_spectrogram.T[None] 77 | mel_lens = np.array([len(mels[0])]) 78 | 79 | mel_spectrogram = mel_spectrogram.astype(np.float32) 80 | energy = energy.astype(np.float32) 81 | 82 | return mels, mel_lens, (mel_spectrogram, pitch, energy) 83 | 84 | 85 | def preprocess_english(text, preprocess_config): 86 | text = 
text.rstrip(punctuation) 87 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"]) 88 | 89 | g2p = G2p() 90 | phones = [] 91 | words = re.split(r"([,;.\-\?\!\s+])", text) 92 | for w in words: 93 | if w.lower() in lexicon: 94 | phones += lexicon[w.lower()] 95 | else: 96 | phones += list(filter(lambda p: p != " ", g2p(w))) 97 | phones = "{" + "}{".join(phones) + "}" 98 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones) 99 | phones = phones.replace("}{", " ") 100 | 101 | print("Raw Text Sequence: {}".format(text)) 102 | print("Phoneme Sequence: {}".format(phones)) 103 | sequence = np.array( 104 | text_to_sequence( 105 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"] 106 | ) 107 | ) 108 | 109 | return np.array(sequence) 110 | 111 | 112 | def preprocess_mandarin(text, preprocess_config): 113 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"]) 114 | 115 | phones = [] 116 | pinyins = [ 117 | p[0] 118 | for p in pinyin( 119 | text, style=Style.TONE3, strict=False, neutral_tone_with_five=True 120 | ) 121 | ] 122 | for p in pinyins: 123 | if p in lexicon: 124 | phones += lexicon[p] 125 | else: 126 | phones.append("sp") 127 | 128 | phones = "{" + " ".join(phones) + "}" 129 | print("Raw Text Sequence: {}".format(text)) 130 | print("Phoneme Sequence: {}".format(phones)) 131 | sequence = np.array( 132 | text_to_sequence( 133 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"] 134 | ) 135 | ) 136 | 137 | return np.array(sequence) 138 | 139 | 140 | def synthesize(model, step, configs, vocoder, batchs, control_values): 141 | preprocess_config, model_config, train_config = configs 142 | pitch_control, energy_control, duration_control = control_values 143 | 144 | for batch in batchs: 145 | batch = to_device(batch, device) 146 | with torch.no_grad(): 147 | # Forward 148 | output = model( 149 | *(batch[2:-1]), 150 | p_control=pitch_control, 151 | e_control=energy_control, 152 | d_control=duration_control 153 | ) 154 | synth_samples( 155 | batch, 156 | output, 157 | vocoder, 158 | model_config, 159 | preprocess_config, 160 | train_config["path"]["result_path"], 161 | ) 162 | 163 | 164 | if __name__ == "__main__": 165 | 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("--restore_step", type=int, required=True) 168 | parser.add_argument( 169 | "--mode", 170 | type=str, 171 | choices=["batch", "single"], 172 | required=True, 173 | help="Synthesize a whole dataset or a single sentence", 174 | ) 175 | parser.add_argument( 176 | "--source", 177 | type=str, 178 | default=None, 179 | help="path to a source file with format like train.txt and val.txt, for batch mode only", 180 | ) 181 | parser.add_argument( 182 | "--text", 183 | type=str, 184 | default=None, 185 | help="raw text to synthesize, for single-sentence mode only", 186 | ) 187 | parser.add_argument( 188 | "--ref_audio", 189 | type=str, 190 | default=None, 191 | help="reference audio path to extract the speech style, for single-sentence mode only", 192 | ) 193 | parser.add_argument( 194 | "-p", 195 | "--preprocess_config", 196 | type=str, 197 | required=True, 198 | help="path to preprocess.yaml", 199 | ) 200 | parser.add_argument( 201 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 202 | ) 203 | parser.add_argument( 204 | "-t", "--train_config", type=str, required=True, help="path to train.yaml" 205 | ) 206 | parser.add_argument( 207 | "--pitch_control", 208 | type=float, 209 | default=1.0, 210 | help="control the pitch of the whole utterance, 
larger value for higher pitch", 211 | ) 212 | parser.add_argument( 213 | "--energy_control", 214 | type=float, 215 | default=1.0, 216 | help="control the energy of the whole utterance, larger value for larger volume", 217 | ) 218 | parser.add_argument( 219 | "--duration_control", 220 | type=float, 221 | default=1.0, 222 | help="control the speed of the whole utterance, larger value for slower speaking rate", 223 | ) 224 | args = parser.parse_args() 225 | 226 | # Check source texts 227 | if args.mode == "batch": 228 | assert args.source is not None and args.text is None 229 | if args.mode == "single": 230 | assert args.source is None and args.text is not None 231 | 232 | # Read Config 233 | preprocess_config = yaml.load( 234 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 235 | ) 236 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 237 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 238 | configs = (preprocess_config, model_config, train_config) 239 | 240 | # Get model 241 | model = get_model(args, configs, device, train=False) 242 | 243 | # Load vocoder 244 | vocoder = get_vocoder(model_config, device) 245 | 246 | # Preprocess texts 247 | if args.mode == "batch": 248 | # Get dataset 249 | dataset = BatchInferenceDataset(args.source, preprocess_config) 250 | batchs = DataLoader( 251 | dataset, 252 | batch_size=8, 253 | collate_fn=dataset.collate_fn, 254 | ) 255 | if args.mode == "single": 256 | ids = raw_texts = [args.text[:100]] 257 | if preprocess_config["preprocessing"]["text"]["language"] == "en": 258 | texts = np.array([preprocess_english(args.text, preprocess_config)]) 259 | elif preprocess_config["preprocessing"]["text"]["language"] == "zh": 260 | texts = np.array([preprocess_mandarin(args.text, preprocess_config)]) 261 | text_lens = np.array([len(texts[0])]) 262 | mels, mel_lens, ref_info = get_audio(preprocess_config, args.ref_audio) 263 | batchs = [(["_".join([os.path.basename(args.ref_audio).strip(".wav"), id]) for id in ids], \ 264 | raw_texts, None, texts, text_lens, max(text_lens), mels, mel_lens, max(mel_lens), [ref_info])] 265 | 266 | control_values = args.pitch_control, args.energy_control, args.duration_control 267 | 268 | synthesize(model, args.restore_step, configs, vocoder, batchs, control_values) 269 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
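    A minimal illustrative call (the returned IDs depend on text/symbols.py):
      text_to_sequence("hi {HH AY1}.", ["english_cleaners"])
    Text outside the braces is run through the cleaners, while the braced
    ARPAbet symbols are looked up directly as "@"-prefixed symbols.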
20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | """ 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | 34 | if not m: 35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 36 | break 37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 38 | sequence += _arpabet_to_sequence(m.group(2)) 39 | text = m.group(3) 40 | 41 | return sequence 42 | 43 | 44 | def sequence_to_text(sequence): 45 | """Converts a sequence of IDs back to a string""" 46 | result = "" 47 | for symbol_id in sequence: 48 | if symbol_id in _id_to_symbol: 49 | s = _id_to_symbol[symbol_id] 50 | # Enclose ARPAbet back in curly braces: 51 | if len(s) > 1 and s[0] == "@": 52 | s = "{%s}" % s[1:] 53 | result += s 54 | return result.replace("}{", " ") 55 | 56 | 57 | def _clean_text(text, cleaner_names): 58 | for name in cleaner_names: 59 | cleaner = getattr(cleaners, name) 60 | if not cleaner: 61 | raise Exception("Unknown cleaner: %s" % name) 62 | text = cleaner(text) 63 | return text 64 | 65 | 66 | def _symbols_to_sequence(symbols): 67 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 68 | 69 | 70 | def _arpabet_to_sequence(text): 71 | return _symbols_to_sequence(["@" + s for s in text.split()]) 72 | 73 | 74 | def _should_keep_symbol(s): 75 | return s in _symbol_to_id and s != "_" and s != "~" 76 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | _whitespace_re = re.compile(r'\s+') 21 | 22 | # List of (regular expression, replacement) pairs for abbreviations: 23 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 69 | text = lowercase(text) 70 | text = collapse_whitespace(text) 71 | return text 72 | 73 | 74 | def transliteration_cleaners(text): 75 | '''Pipeline for non-English text that transliterates to ASCII.''' 76 | text = convert_to_ascii(text) 77 | text = lowercase(text) 78 | text = collapse_whitespace(text) 79 | return text 80 | 81 | 82 | def english_cleaners(text): 83 | '''Pipeline for English text, including number and abbreviation expansion.''' 84 | text = convert_to_ascii(text) 85 | text = lowercase(text) 86 | text = expand_numbers(text) 87 | text = expand_abbreviations(text) 88 | text = collapse_whitespace(text) 89 | return text 90 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. 
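    Each entry maps an upper-case word to one or more space-separated
    ARPAbet pronunciation strings; lookup() upper-cases its argument first.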
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def 
normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /text/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | "v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of 
symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from text import cmudict, pinyin 9 | 10 | _pad = "_" 11 | _punctuation = "!'(),.:;? " 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 18 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = ( 22 | [_pad] 23 | + list(_special) 24 | + list(_punctuation) 25 | + list(_letters) 26 | + _arpabet 27 | + _pinyin 28 | + _silences 29 | ) 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import yaml 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tqdm import tqdm 10 | 11 | from utils.model import get_model, get_vocoder, get_param_num 12 | from utils.tools import to_device, log, synth_one_sample 13 | from model import MetaStyleSpeechLossMain, MetaStyleSpeechLossDisc 14 | from dataset import Dataset 15 | 16 | from evaluate import evaluate 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | def backward(model, optimizer, total_loss, step, grad_acc_step, grad_clip_thresh): 22 | total_loss = total_loss / grad_acc_step 23 | total_loss.backward() 24 | if step % grad_acc_step == 0: 25 | # Clipping gradients to avoid gradient explosion 26 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh) 27 | 28 | # Update weights 29 | optimizer.step_and_update_lr() 30 | optimizer.zero_grad() 31 | 32 | 33 | def main(args, configs): 34 | print("Prepare training ...") 35 | 36 | preprocess_config, model_config, train_config = configs 37 | 38 | # Get dataset 39 | dataset = Dataset( 40 | "train_filtered.txt", preprocess_config, train_config, sort=True, drop_last=True 41 | ) 42 | batch_size = train_config["optimizer"]["batch_size"] 43 | group_size = 4 # Set this larger than 1 to enable sorting in Dataset 44 | assert batch_size * group_size < len(dataset) 45 | loader = DataLoader( 46 | dataset, 47 | batch_size=batch_size * group_size, 48 | shuffle=True, 49 | collate_fn=dataset.collate_fn, 50 | ) 51 | 52 | # Prepare model 53 | model, optimizer_main, optimizer_disc = get_model(args, configs, device, train=True) 54 | model = nn.DataParallel(model) 55 | num_param = get_param_num(model) 56 | Loss_1 = MetaStyleSpeechLossMain(preprocess_config, model_config, train_config).to(device) 57 | Loss_2 = MetaStyleSpeechLossDisc(preprocess_config, model_config).to(device) 58 | print("Number of StyleSpeech Parameters:", num_param) 59 | 60 | # Load vocoder 61 | vocoder = get_vocoder(model_config, device) 62 | 63 | # Init logger 64 | for p in train_config["path"].values(): 65 | os.makedirs(p, exist_ok=True) 66 | train_log_path = os.path.join(train_config["path"]["log_path"], "train") 67 | val_log_path = os.path.join(train_config["path"]["log_path"], "val") 68 | os.makedirs(train_log_path, exist_ok=True) 69 | os.makedirs(val_log_path, exist_ok=True) 70 | train_logger = 
SummaryWriter(train_log_path) 71 | val_logger = SummaryWriter(val_log_path) 72 | 73 | # Training 74 | step = args.restore_step + 1 75 | epoch = 1 76 | meta_learning_warmup = train_config["step"]["meta_learning_warmup"] 77 | grad_acc_step = train_config["optimizer"]["grad_acc_step"] 78 | grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"] 79 | total_step = train_config["step"]["total_step"] 80 | log_step = train_config["step"]["log_step"] 81 | save_step = train_config["step"]["save_step"] 82 | synth_step = train_config["step"]["synth_step"] 83 | val_step = train_config["step"]["val_step"] 84 | 85 | outer_bar = tqdm(total=total_step, desc="Training", position=0) 86 | outer_bar.n = args.restore_step 87 | outer_bar.update() 88 | 89 | while True: 90 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1) 91 | for batchs in loader: 92 | for batch in batchs: 93 | batch = to_device(batch, device) 94 | 95 | # Warm-up Stage 96 | if step <= meta_learning_warmup: 97 | # Forward 98 | output = (None, None, *model(*(batch[2:-5]))) 99 | # Meta Learning 100 | else: 101 | # Step 1: Update Enc_s and G 102 | output = model.module.meta_learner_1(*(batch[2:])) 103 | 104 | # Cal Loss 105 | losses_1 = Loss_1(batch, output) 106 | total_loss = losses_1[0] 107 | 108 | # Backward 109 | backward(model, optimizer_main, total_loss, step, grad_acc_step, grad_clip_thresh) 110 | 111 | # Meta Learning 112 | if step > meta_learning_warmup: 113 | # Step 2: Update D_t and D_s 114 | output_disc = model.module.meta_learner_2(*(batch[2:])) 115 | 116 | losses_2 = Loss_2(batch[2], output_disc) 117 | total_loss_disc = losses_2[0] 118 | 119 | backward(model, optimizer_disc, total_loss_disc, step, grad_acc_step, grad_clip_thresh) 120 | 121 | if step % log_step == 0: 122 | if step > meta_learning_warmup: 123 | losses = [l.item() for l in (losses_1+losses_2[1:])] 124 | else: 125 | losses = [l.item() for l in (losses_1+tuple([torch.zeros(1).to(device) for _ in range(3)]))] 126 | message1 = "Step {}/{}, ".format(step, total_step) 127 | message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}, Adversarial_D_s Loss: {:.4f}, Adversarial_D_t Loss: {:.4f}, D_s Loss: {:.4f}, D_t Loss: {:.4f}, cls Loss: {:.4f}".format( 128 | *losses 129 | ) 130 | 131 | with open(os.path.join(train_log_path, "log.txt"), "a") as f: 132 | f.write(message1 + message2 + "\n") 133 | 134 | outer_bar.write(message1 + message2) 135 | 136 | log(train_logger, step, losses=losses) 137 | 138 | if step % synth_step == 0: 139 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample( 140 | batch, 141 | output[2:], 142 | vocoder, 143 | model_config, 144 | preprocess_config, 145 | ) 146 | log( 147 | train_logger, 148 | fig=fig, 149 | tag="Training/step_{}_{}".format(step, tag), 150 | ) 151 | sampling_rate = preprocess_config["preprocessing"]["audio"][ 152 | "sampling_rate" 153 | ] 154 | log( 155 | train_logger, 156 | audio=wav_reconstruction, 157 | sampling_rate=sampling_rate, 158 | tag="Training/step_{}_{}_reconstructed".format(step, tag), 159 | ) 160 | log( 161 | train_logger, 162 | audio=wav_prediction, 163 | sampling_rate=sampling_rate, 164 | tag="Training/step_{}_{}_synthesized".format(step, tag), 165 | ) 166 | 167 | if step % val_step == 0: 168 | model.eval() 169 | message = evaluate(model, step, configs, val_logger, vocoder, len(losses)) 170 | with open(os.path.join(val_log_path, "log.txt"), "a") as f: 171 | f.write(message + "\n") 172 | outer_bar.write(message) 173 | 174 
| model.train() 175 | 176 | if step % save_step == 0: 177 | torch.save( 178 | { 179 | "model": model.module.state_dict(), 180 | "optimizer_main": optimizer_main._optimizer.state_dict(), 181 | "optimizer_disc": optimizer_disc._optimizer.state_dict(), 182 | }, 183 | os.path.join( 184 | train_config["path"]["ckpt_path"], 185 | "{}.pth.tar".format(step), 186 | ), 187 | ) 188 | 189 | if step == total_step: 190 | quit() 191 | step += 1 192 | outer_bar.update(1) 193 | 194 | inner_bar.update(1) 195 | epoch += 1 196 | 197 | 198 | if __name__ == "__main__": 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument("--restore_step", type=int, default=0) 201 | parser.add_argument( 202 | "-p", 203 | "--preprocess_config", 204 | type=str, 205 | required=True, 206 | help="path to preprocess.yaml", 207 | ) 208 | parser.add_argument( 209 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 210 | ) 211 | parser.add_argument( 212 | "-t", "--train_config", type=str, required=True, help="path to train.yaml" 213 | ) 214 | args = parser.parse_args() 215 | 216 | # Read Config 217 | preprocess_config = yaml.load( 218 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 219 | ) 220 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 221 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 222 | configs = (preprocess_config, model_config, train_config) 223 | 224 | main(args, configs) 225 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import hifigan 8 | from model import StyleSpeech, ScheduledOptimMain, ScheduledOptimDisc 9 | 10 | 11 | def get_model(args, configs, device, train=False): 12 | (preprocess_config, model_config, train_config) = configs 13 | 14 | model = StyleSpeech(preprocess_config, model_config).to(device) 15 | if args.restore_step: 16 | ckpt_path = os.path.join( 17 | train_config["path"]["ckpt_path"], 18 | "{}.pth.tar".format(args.restore_step), 19 | ) 20 | ckpt = torch.load(ckpt_path) 21 | model.load_state_dict(ckpt["model"]) 22 | 23 | if train: 24 | scheduled_optim_main = ScheduledOptimMain( 25 | model, train_config, model_config, args.restore_step 26 | ) 27 | scheduled_optim_disc = ScheduledOptimDisc( 28 | model, train_config 29 | ) 30 | if args.restore_step: 31 | scheduled_optim_main.load_state_dict(ckpt["optimizer_main"]) 32 | scheduled_optim_disc.load_state_dict(ckpt["optimizer_disc"]) 33 | model.train() 34 | return model, scheduled_optim_main, scheduled_optim_disc 35 | 36 | model.eval() 37 | model.requires_grad_ = False 38 | return model 39 | 40 | 41 | def get_param_num(model): 42 | num_param = sum(param.numel() for param in model.parameters()) 43 | return num_param 44 | 45 | 46 | def get_vocoder(config, device): 47 | name = config["vocoder"]["model"] 48 | speaker = config["vocoder"]["speaker"] 49 | 50 | if name == "MelGAN": 51 | if speaker == "LJSpeech": 52 | vocoder = torch.hub.load( 53 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" 54 | ) 55 | elif speaker == "universal": 56 | vocoder = torch.hub.load( 57 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 58 | ) 59 | vocoder.mel2wav.eval() 60 | vocoder.mel2wav.to(device) 61 | elif name == "HiFi-GAN": 62 | with open("hifigan/config.json", "r") as f: 63 | config = json.load(f) 64 | config = 
hifigan.AttrDict(config) 65 | vocoder = hifigan.Generator(config) 66 | if speaker == "LJSpeech": 67 | ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar") 68 | elif speaker == "universal": 69 | ckpt = torch.load("hifigan/generator_universal.pth.tar") 70 | vocoder.load_state_dict(ckpt["generator"]) 71 | vocoder.eval() 72 | vocoder.remove_weight_norm() 73 | vocoder.to(device) 74 | 75 | return vocoder 76 | 77 | 78 | def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None): 79 | name = model_config["vocoder"]["model"] 80 | with torch.no_grad(): 81 | if name == "MelGAN": 82 | wavs = vocoder.inverse(mels / np.log(10)) 83 | elif name == "HiFi-GAN": 84 | wavs = vocoder(mels).squeeze(1) 85 | 86 | wavs = ( 87 | wavs.cpu().numpy() 88 | * preprocess_config["preprocessing"]["audio"]["max_wav_value"] 89 | ).astype("int16") 90 | wavs = [wav for wav in wavs] 91 | 92 | for i in range(len(mels)): 93 | if lengths is not None: 94 | wavs[i] = wavs[i][: lengths[i]] 95 | 96 | return wavs 97 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import matplotlib 8 | from scipy.io import wavfile 9 | from matplotlib import pyplot as plt 10 | 11 | 12 | matplotlib.use("Agg") 13 | 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | def to_device(data, device): 19 | if len(data) == 17: 20 | ( 21 | ids, 22 | raw_texts, 23 | speakers, 24 | texts, 25 | src_lens, 26 | max_src_len, 27 | mels, 28 | mel_lens, 29 | max_mel_len, 30 | pitches, 31 | energies, 32 | durations, 33 | raw_quary_texts, 34 | quary_texts, 35 | quary_src_lens, 36 | max_quary_src_len, 37 | quary_durations, 38 | ) = data 39 | 40 | speakers = torch.from_numpy(speakers).long().to(device) 41 | texts = torch.from_numpy(texts).long().to(device) 42 | src_lens = torch.from_numpy(src_lens).to(device) 43 | quary_texts = torch.from_numpy(quary_texts).long().to(device) 44 | quary_src_lens = torch.from_numpy(quary_src_lens).to(device) 45 | mels = torch.from_numpy(mels).float().to(device) 46 | mel_lens = torch.from_numpy(mel_lens).to(device) 47 | pitches = torch.from_numpy(pitches).float().to(device) 48 | energies = torch.from_numpy(energies).to(device) 49 | durations = torch.from_numpy(durations).long().to(device) 50 | quary_durations = torch.from_numpy(quary_durations).long().to(device) 51 | 52 | return ( 53 | ids, 54 | raw_texts, 55 | speakers, 56 | texts, 57 | src_lens, 58 | max_src_len, 59 | mels, 60 | mel_lens, 61 | max_mel_len, 62 | pitches, 63 | energies, 64 | durations, 65 | raw_quary_texts, 66 | quary_texts, 67 | quary_src_lens, 68 | max_quary_src_len, 69 | quary_durations, 70 | ) 71 | 72 | if len(data) == 10: 73 | ( 74 | ids, 75 | raw_texts, 76 | speakers, 77 | texts, 78 | src_lens, 79 | max_src_len, 80 | mels, 81 | mel_lens, 82 | max_mel_len, 83 | ref_infos, 84 | ) = data 85 | 86 | texts = torch.from_numpy(texts).long().to(device) 87 | src_lens = torch.from_numpy(src_lens).to(device) 88 | mels = torch.from_numpy(mels).float().to(device) 89 | mel_lens = torch.from_numpy(mel_lens).to(device) 90 | 91 | return ( 92 | ids, 93 | raw_texts, 94 | speakers, 95 | texts, 96 | src_lens, 97 | max_src_len, 98 | mels, 99 | mel_lens, 100 | max_mel_len, 101 | ref_infos, 102 | ) 103 | 104 | 105 | def log( 106 | logger, step=None, losses=None, fig=None, audio=None, 
sampling_rate=22050, tag="" 107 | ): 108 | if losses is not None: 109 | logger.add_scalar("Loss/total_loss", losses[0], step) 110 | logger.add_scalar("Loss/mel_loss", losses[1], step) 111 | logger.add_scalar("Loss/pitch_loss", losses[2], step) 112 | logger.add_scalar("Loss/energy_loss", losses[3], step) 113 | logger.add_scalar("Loss/duration_loss", losses[4], step) 114 | logger.add_scalar("Loss/adv_D_s_loss", losses[5], step) 115 | logger.add_scalar("Loss/adv_D_t_loss", losses[6], step) 116 | logger.add_scalar("Loss/D_s_loss", losses[7], step) 117 | logger.add_scalar("Loss/D_t_loss", losses[8], step) 118 | logger.add_scalar("Loss/cls_loss", losses[9], step) 119 | 120 | if fig is not None: 121 | logger.add_figure(tag, fig) 122 | 123 | if audio is not None: 124 | logger.add_audio( 125 | tag, 126 | audio / max(abs(audio)), 127 | sample_rate=sampling_rate, 128 | ) 129 | 130 | 131 | def get_mask_from_lengths(lengths, max_len=None): 132 | batch_size = lengths.shape[0] 133 | if max_len is None: 134 | max_len = torch.max(lengths).item() 135 | 136 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) 137 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 138 | 139 | return mask 140 | 141 | 142 | def expand(values, durations): 143 | out = list() 144 | for value, d in zip(values, durations): 145 | out += [value] * max(0, int(d)) 146 | return np.array(out) 147 | 148 | 149 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config): 150 | 151 | basename = targets[0][0] 152 | src_len = predictions[7][0].item() 153 | mel_len = predictions[8][0].item() 154 | mel_target = targets[6][0, :mel_len].detach().transpose(0, 1) 155 | mel_prediction = predictions[0][0, :mel_len].detach().transpose(0, 1) 156 | duration = targets[11][0, :src_len].detach().cpu().numpy() 157 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 158 | pitch = targets[9][0, :src_len].detach().cpu().numpy() 159 | pitch = expand(pitch, duration) 160 | else: 161 | pitch = targets[9][0, :mel_len].detach().cpu().numpy() 162 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 163 | energy = targets[10][0, :src_len].detach().cpu().numpy() 164 | energy = expand(energy, duration) 165 | else: 166 | energy = targets[10][0, :mel_len].detach().cpu().numpy() 167 | 168 | with open( 169 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 170 | ) as f: 171 | stats = json.load(f) 172 | stats = stats["pitch"] + stats["energy"][:2] 173 | 174 | fig = plot_mel( 175 | [ 176 | (mel_prediction.cpu().numpy(), pitch, energy), 177 | (mel_target.cpu().numpy(), pitch, energy), 178 | ], 179 | stats, 180 | ["Synthetized Spectrogram", "Ground-Truth Spectrogram"], 181 | ) 182 | 183 | if vocoder is not None: 184 | from .model import vocoder_infer 185 | 186 | wav_reconstruction = vocoder_infer( 187 | mel_target.unsqueeze(0), 188 | vocoder, 189 | model_config, 190 | preprocess_config, 191 | )[0] 192 | wav_prediction = vocoder_infer( 193 | mel_prediction.unsqueeze(0), 194 | vocoder, 195 | model_config, 196 | preprocess_config, 197 | )[0] 198 | else: 199 | wav_reconstruction = wav_prediction = None 200 | 201 | return fig, wav_reconstruction, wav_prediction, basename 202 | 203 | 204 | def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path): 205 | 206 | basenames = targets[0] 207 | for i in range(len(predictions[0])): 208 | basename = basenames[i] 209 | src_len = predictions[7][i].item() 210 | mel_len = 
predictions[8][i].item() 211 | mel_prediction = predictions[0][i, :mel_len].detach().transpose(0, 1) 212 | duration = predictions[4][i, :src_len].detach().cpu().numpy() 213 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 214 | pitch = predictions[1][i, :src_len].detach().cpu().numpy() 215 | pitch = expand(pitch, duration) 216 | else: 217 | pitch = predictions[1][i, :mel_len].detach().cpu().numpy() 218 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 219 | energy = predictions[2][i, :src_len].detach().cpu().numpy() 220 | energy = expand(energy, duration) 221 | else: 222 | energy = predictions[2][i, :mel_len].detach().cpu().numpy() 223 | 224 | with open( 225 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 226 | ) as f: 227 | stats = json.load(f) 228 | stats = stats["pitch"] + stats["energy"][:2] 229 | 230 | fig = plot_mel( 231 | [ 232 | (mel_prediction.cpu().numpy(), pitch, energy), 233 | targets[-1][i], 234 | ], 235 | stats, 236 | ["Synthetized Spectrogram", "Reference Spectrogram"], 237 | ) 238 | plt.savefig(os.path.join(path, "{}.png".format(basename))) 239 | plt.close() 240 | 241 | from .model import vocoder_infer 242 | 243 | mel_predictions = predictions[0].transpose(1, 2) 244 | lengths = predictions[8] * preprocess_config["preprocessing"]["stft"]["hop_length"] 245 | wav_predictions = vocoder_infer( 246 | mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths 247 | ) 248 | 249 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 250 | for wav, basename in zip(wav_predictions, basenames): 251 | wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav) 252 | 253 | 254 | def plot_mel(data, stats, titles): 255 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 256 | if titles is None: 257 | titles = [None for i in range(len(data))] 258 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats 259 | pitch_min = pitch_min * pitch_std + pitch_mean 260 | pitch_max = pitch_max * pitch_std + pitch_mean 261 | 262 | def add_axis(fig, old_ax): 263 | ax = fig.add_axes(old_ax.get_position(), anchor="W") 264 | ax.set_facecolor("None") 265 | return ax 266 | 267 | for i in range(len(data)): 268 | mel, pitch, energy = data[i] 269 | pitch = pitch * pitch_std + pitch_mean 270 | axes[i][0].imshow(mel, origin="lower") 271 | axes[i][0].set_aspect(2.5, adjustable="box") 272 | axes[i][0].set_ylim(0, mel.shape[0]) 273 | axes[i][0].set_title(titles[i], fontsize="medium") 274 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 275 | axes[i][0].set_anchor("W") 276 | 277 | ax1 = add_axis(fig, axes[i][0]) 278 | ax1.plot(pitch, color="tomato", linewidth=.7) 279 | ax1.set_xlim(0, mel.shape[1]) 280 | ax1.set_ylim(0, pitch_max) 281 | ax1.set_ylabel("F0", color="tomato") 282 | ax1.tick_params( 283 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False 284 | ) 285 | 286 | ax2 = add_axis(fig, axes[i][0]) 287 | ax2.plot(energy, color="darkviolet", linewidth=.7) 288 | ax2.set_xlim(0, mel.shape[1]) 289 | ax2.set_ylim(energy_min, energy_max) 290 | ax2.set_ylabel("Energy", color="darkviolet") 291 | ax2.yaxis.set_label_position("right") 292 | ax2.tick_params( 293 | labelsize="x-small", 294 | colors="darkviolet", 295 | bottom=False, 296 | labelbottom=False, 297 | left=False, 298 | labelleft=False, 299 | right=True, 300 | labelright=True, 301 | ) 302 | 303 | return fig 304 | 305 | 306 | def pad_1D(inputs, 
PAD=0): 307 | def pad_data(x, length, PAD): 308 | x_padded = np.pad( 309 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD 310 | ) 311 | return x_padded 312 | 313 | max_len = max((len(x) for x in inputs)) 314 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 315 | 316 | return padded 317 | 318 | 319 | def pad_2D(inputs, maxlen=None): 320 | def pad(x, max_len): 321 | PAD = 0 322 | if np.shape(x)[0] > max_len: 323 | raise ValueError("not max_len") 324 | 325 | s = np.shape(x)[1] 326 | x_padded = np.pad( 327 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD 328 | ) 329 | return x_padded[:, :s] 330 | 331 | if maxlen: 332 | output = np.stack([pad(x, maxlen) for x in inputs]) 333 | else: 334 | max_len = max(np.shape(x)[0] for x in inputs) 335 | output = np.stack([pad(x, max_len) for x in inputs]) 336 | 337 | return output 338 | 339 | 340 | def pad(input_ele, mel_max_length=None): 341 | if mel_max_length: 342 | max_len = mel_max_length 343 | else: 344 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 345 | 346 | out_list = list() 347 | for i, batch in enumerate(input_ele): 348 | if len(batch.shape) == 1: 349 | one_batch_padded = F.pad( 350 | batch, (0, max_len - batch.size(0)), "constant", 0.0 351 | ) 352 | elif len(batch.shape) == 2: 353 | one_batch_padded = F.pad( 354 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 355 | ) 356 | out_list.append(one_batch_padded) 357 | out_padded = torch.stack(out_list) 358 | return out_padded 359 | --------------------------------------------------------------------------------
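A minimal sketch of how the padding helpers in utils/tools.py behave (illustrative arrays only, not taken from the repository):

import numpy as np
from utils.tools import pad_1D, pad_2D

# pad_1D right-pads variable-length 1-D arrays (e.g. phoneme durations) to a common length.
pad_1D([np.array([1, 2, 3]), np.array([4, 5])])
# -> array([[1, 2, 3],
#           [4, 5, 0]])

# pad_2D right-pads variable-length 2-D arrays (e.g. [T, n_mels] mel-spectrograms)
# along the time axis so a batch can be stacked into a single array.
pad_2D([np.zeros((3, 80)), np.zeros((5, 80))]).shape
# -> (2, 5, 80)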