├── .gitignore
├── LICENSE
├── README.md
├── audio
│   ├── __init__.py
│   ├── audio_processing.py
│   ├── stft.py
│   └── tools.py
├── configs
│   └── config.json
├── dataloader.py
├── evaluate.py
├── lexicon
│   └── librispeech-lexicon.txt
├── models
│   ├── Constants.py
│   ├── Discriminators.py
│   ├── Loss.py
│   ├── Modules.py
│   ├── StyleSpeech.py
│   ├── VarianceAdaptor.py
│   └── __init__.py
├── optimizer.py
├── prepare_align.py
├── preprocess.py
├── preprocessors
│   ├── libritts.py
│   └── utils.py
├── requirements.txt
├── synthesize.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
├── train.py
├── train_meta.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | __pycache__
3 | dataset/*
4 | exp*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 min95
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-StyleSpeech: Multi-Speaker Adaptive Text-to-Speech Generation
2 | 
3 | **Recent Updates**
4 | --------
5 | [12/18/2021]
6 | :sparkles: Thanks to Guan-Ting Lin for sharing the pre-trained multi-speaker MelGAN vocoder at 16kHz; the checkpoint is now available at [Pre-trained 16k-MelGAN](https://huggingface.co/Guan-Ting/StyleSpeech-MelGAN-vocoder-16kHz). For usage details, please follow the instructions in [MelGAN](https://github.com/descriptinc/melgan-neurips).
7 | 
8 | [06/09/2021]
9 | A few modifications to the Variance Adaptor which were found to improve the quality of the model: 1) We replace the variance embedding architecture, going from one Conv1D layer to two Conv1D layers followed by a linear layer. 2) We add a layer norm and phoneme-wise positional encoding. Please refer to [here](models/VarianceAdaptor.py).
10 | 
11 | --------
12 | 
13 | Introduction
14 | ----------
15 | 
16 | This is the official code for our recent [paper](https://arxiv.org/abs/2106.03153).
17 | We propose Meta-StyleSpeech: Multi-Speaker Adaptive Text-to-Speech Generation.
18 | We provide our implementation and pretrained models as open source in this repository.
19 | 
20 | **Abstract :**
21 | With rapid progress in neural text-to-speech (TTS) models, personalized speech generation is now in high demand for many applications. For practical applicability, a TTS model should generate high-quality speech with only a few audio samples from the given speaker, that are also short in length. However, existing methods either require to fine-tune the model or achieve low adaptation quality without fine-tuning. In this work, we propose StyleSpeech, a new TTS model which not only synthesizes high-quality speech but also effectively adapts to new speakers. Specifically, we propose Style-Adaptive Layer Normalization (SALN) which aligns gain and bias of the text input according to the style extracted from a reference speech audio. With SALN, our model effectively synthesizes speech in the style of the target speaker even from single speech audio. Furthermore, to enhance StyleSpeech's adaptation to speech from new speakers, we extend it to Meta-StyleSpeech by introducing two discriminators trained with style prototypes, and performing episodic training. The experimental results show that our models generate high-quality speech which accurately follows the speaker's voice with single short-duration (1-3 sec) speech audio, significantly outperforming baselines.
22 | 
23 | Demo audio samples are available on the [demo page](https://stylespeech.github.io/).
24 | 
25 | 
26 | Getting the pretrained models
27 | ----------
28 | | Model | Link to the model |
29 | | :-------------: | :---------------: |
30 | | Meta-StyleSpeech | [Link](https://drive.google.com/file/d/1xGLGt6bK7IapiKNj9YliMBmP5MCBv9OR/view?usp=sharing) |
31 | | StyleSpeech | [Link](https://drive.google.com/file/d/1Q7yLKnFH4UkOjaszikjaovItNAaTyEVN/view?usp=sharing) |
32 | 
33 | 
34 | Prerequisites
35 | -------------
36 | - Clone this repository.
37 | - Install the Python requirements. Please refer to [requirements.txt](requirements.txt).
38 | 
39 | 
40 | Inference
41 | -------------
42 | You need to download a pretrained model and prepare an audio file to use as the reference speech sample.
43 | ```bash
44 | python synthesize.py --text <text to synthesize> --ref_audio <path to reference audio> --checkpoint_path <path to pretrained model>
45 | ```
46 | The generated mel-spectrogram will be saved in the `results/` folder.
47 | 
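If you prefer to drive the model from Python rather than the CLI, the `StyleSpeech` class in `models/StyleSpeech.py` exposes `get_style_vector` and `inference` for the same two-step flow: extract a style vector from a reference mel-spectrogram, then synthesize from a phoneme sequence. The sketch below is only an illustration, not a supported entry point; the checkpoint filename, the `"model"` state-dict key, and wrapping the JSON config in a `Namespace` are assumptions, so check `synthesize.py` for the actual loading and text-processing code.

```python
# Minimal sketch of programmatic inference (assumptions flagged below);
# synthesize.py remains the supported entry point.
import json
from argparse import Namespace

import numpy as np
import torch

from models.StyleSpeech import StyleSpeech
from text import text_to_sequence

with open("configs/config.json") as f:
    config = Namespace(**json.load(f))   # attribute-style access to the flat config

model = StyleSpeech(config).cuda().eval()
ckpt = torch.load("meta_stylespeech.pth.tar", map_location="cuda")  # hypothetical filename
model.load_state_dict(ckpt["model"])                                # assumed checkpoint key

# Reference style: a mel-spectrogram of shape (1, T, n_mel_channels); it can be computed
# with audio.tools.get_mel_from_wav and transposed, since that helper returns (n_mel_channels, T).
ref_mel = torch.from_numpy(np.load("ref_mel.npy")).unsqueeze(0).float().cuda()

with torch.no_grad():
    style_vector = model.get_style_vector(ref_mel)

    # Phoneme ids from the text frontend, cleaned as specified in configs/config.json.
    phones = np.array(text_to_sequence("This is a test sentence.", ["english_cleaners"]))
    src = torch.from_numpy(phones).long().unsqueeze(0).cuda()
    src_len = torch.LongTensor([src.shape[1]]).cuda()

    mel_output, *_ = model.inference(style_vector, src, src_len)
```

The resulting `mel_output` is still a mel-spectrogram; to obtain a waveform you need a vocoder, for example the pre-trained 16kHz MelGAN linked under Recent Updates.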
48 | 
49 | Preprocessing the dataset
50 | -------------
51 | Our models are trained on the [LibriTTS dataset](https://openslr.org/60/). Download, extract, and place it in the `dataset/` folder.
52 | 
53 | To preprocess the dataset:
54 | First, run
55 | ```bash
56 | python prepare_align.py
57 | ```
58 | to resample the audio to 16kHz and perform some other preparation steps.
59 | 
60 | Second, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/) (MFA) is used to obtain the alignments between the utterances and the phoneme sequences.
61 | ```bash
62 | ./montreal-forced-aligner/bin/mfa_align dataset/wav16/ lexicon/librispeech-lexicon.txt english dataset/TextGrid/ -j 10 -v
63 | ```
64 | 
65 | Third, preprocess the dataset to prepare the mel-spectrograms, durations, pitch, and energy values for fast training.
66 | ```bash
67 | python preprocess.py
68 | ```
69 | 
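After these three steps, `dataloader.py` expects `dataset/` to contain `mel/`, `alignment/`, `f0/`, and `energy/` folders holding per-utterance `.npy` files (named `libritts-mel-<basename>.npy`, `libritts-ali-<basename>.npy`, and so on), together with `train.txt`, `val.txt`, and `stats.json`. A quick way to sanity-check the output on a single utterance (the basename below is a placeholder; substitute any entry from `dataset/train.txt`):

```python
# Inspect one preprocessed utterance; the basename is a hypothetical placeholder.
import numpy as np

base = "100_121669_000001_000000"
mel = np.load(f"dataset/mel/libritts-mel-{base}.npy")           # (T, n_mel_channels)
dur = np.load(f"dataset/alignment/libritts-ali-{base}.npy")     # per-phoneme durations in frames
f0 = np.load(f"dataset/f0/libritts-f0-{base}.npy")              # phoneme-level F0
energy = np.load(f"dataset/energy/libritts-energy-{base}.npy")  # phoneme-level energy

print(mel.shape, dur.shape, f0.shape, energy.shape)
# The durations should (approximately) sum to the number of mel frames.
print(int(dur.sum()), mel.shape[0])
```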
70 | Train!
71 | -------------
72 | Train the StyleSpeech model from scratch with
73 | ```bash
74 | python train.py
75 | ```
76 | 
77 | Train Meta-StyleSpeech from a pretrained StyleSpeech checkpoint with
78 | ```bash
79 | python train_meta.py --checkpoint_path <path to pretrained StyleSpeech model>
80 | ```
81 | 
82 | 
83 | ## Acknowledgements
84 | We referred to
85 | * [FastSpeech2](https://arxiv.org/abs/2006.04558)
86 | * [ming024's FastSpeech implementation](https://github.com/ming024/FastSpeech2)
87 | * [Mellotron](https://github.com/NVIDIA/mellotron)
88 | * [Tacotron](https://github.com/keithito/tacotron)
89 | 
--------------------------------------------------------------------------------
/audio/__init__.py:
--------------------------------------------------------------------------------
1 | import audio.tools
2 | import audio.stft
3 | import audio.audio_processing
4 | 
--------------------------------------------------------------------------------
/audio/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.signal import get_window
4 | import librosa.util as librosa_util
5 | 
6 | def window_sumsquare(window, n_frames, hop_length, win_length, n_fft, dtype=np.float32, norm=None):
7 |     """
8 |     # from librosa 0.6
9 |     Compute the sum-square envelope of a window function at a given hop length.
10 | 
11 |     This is used to estimate modulation effects induced by windowing
12 |     observations in short-time fourier transforms.
13 | 
14 |     Parameters
15 |     ----------
16 |     window : string, tuple, number, callable, or list-like
17 |         Window specification, as in `get_window`
18 | 
19 |     n_frames : int > 0
20 |         The number of analysis frames
21 | 
22 |     hop_length : int > 0
23 |         The number of samples to advance between frames
24 | 
25 |     win_length : [optional]
26 |         The length of the window function. By default, this matches `n_fft`.
27 | 
28 |     n_fft : int > 0
29 |         The length of each analysis frame.
30 | 31 | dtype : np.dtype 32 | The data type of the output 33 | 34 | Returns 35 | ------- 36 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 37 | The sum-squared envelope of the window function 38 | """ 39 | if win_length is None: 40 | win_length = n_fft 41 | 42 | n = n_fft + hop_length * (n_frames - 1) 43 | x = np.zeros(n, dtype=dtype) 44 | 45 | # Compute the squared window at the desired length 46 | win_sq = get_window(window, win_length, fftbins=True) 47 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 48 | win_sq = librosa_util.pad_center(win_sq, n_fft) 49 | 50 | # Fill the envelope 51 | for i in range(n_frames): 52 | sample = i * hop_length 53 | x[sample:min(n, sample + n_fft) 54 | ] += win_sq[:max(0, min(n_fft, n - sample))] 55 | return x 56 | 57 | 58 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 59 | """ 60 | PARAMS 61 | ------ 62 | magnitudes: spectrogram magnitudes 63 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 64 | """ 65 | 66 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 67 | angles = angles.astype(np.float32) 68 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 69 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 70 | 71 | for i in range(n_iters): 72 | _, angles = stft_fn.transform(signal) 73 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 74 | return signal 75 | 76 | 77 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 78 | """ 79 | PARAMS 80 | ------ 81 | C: compression factor 82 | """ 83 | return torch.log10(torch.clamp(x, min=clip_val) * C) 84 | 85 | 86 | def dynamic_range_decompression(x, C=1): 87 | """ 88 | PARAMS 89 | ------ 90 | C: compression factor used to compress 91 | """ 92 | return torch.pow(x, 10) / C 93 | -------------------------------------------------------------------------------- /audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | from scipy.signal import get_window 7 | from librosa.util import pad_center, tiny 8 | from librosa.filters import mel as librosa_mel_fn 9 | 10 | from audio.audio_processing import dynamic_range_compression 11 | from audio.audio_processing import dynamic_range_decompression 12 | from audio.audio_processing import window_sumsquare 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, 19 | window='hann'): 20 | super(STFT, self).__init__() 21 | self.filter_length = filter_length 22 | self.hop_length = hop_length 23 | self.win_length = win_length 24 | self.window = window 25 | self.forward_transform = None 26 | scale = self.filter_length / self.hop_length 27 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 28 | 29 | cutoff = int((self.filter_length / 2 + 1)) 30 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 31 | np.imag(fourier_basis[:cutoff, :])]) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 36 | 37 | if window is not None: 38 | assert(filter_length >= win_length) 39 | # get window and zero center pad it to filter_length 40 | fft_window = get_window(window, win_length, fftbins=True) 41 | fft_window = pad_center(fft_window, filter_length) 42 | fft_window = 
torch.from_numpy(fft_window).float() 43 | 44 | # window the bases 45 | forward_basis *= fft_window 46 | inverse_basis *= fft_window 47 | 48 | self.register_buffer('forward_basis', forward_basis.float()) 49 | self.register_buffer('inverse_basis', inverse_basis.float()) 50 | 51 | def transform(self, input_data): 52 | num_batches = input_data.size(0) 53 | num_samples = input_data.size(1) 54 | 55 | self.num_samples = num_samples 56 | 57 | # similar to librosa, reflect-pad the input 58 | input_data = input_data.view(num_batches, 1, num_samples) 59 | input_data = F.pad( 60 | input_data.unsqueeze(1), 61 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 62 | mode='reflect') 63 | input_data = input_data.squeeze(1) 64 | 65 | forward_transform = F.conv1d( 66 | input_data.cuda(), 67 | Variable(self.forward_basis, requires_grad=False).cuda(), 68 | stride=self.hop_length, 69 | padding=0).cpu() 70 | 71 | cutoff = int((self.filter_length / 2) + 1) 72 | real_part = forward_transform[:, :cutoff, :] 73 | imag_part = forward_transform[:, cutoff:, :] 74 | 75 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 76 | phase = torch.autograd.Variable( 77 | torch.atan2(imag_part.data, real_part.data)) 78 | 79 | return magnitude, phase 80 | 81 | def inverse(self, magnitude, phase): 82 | recombine_magnitude_phase = torch.cat( 83 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 84 | 85 | inverse_transform = F.conv_transpose1d( 86 | recombine_magnitude_phase, 87 | Variable(self.inverse_basis, requires_grad=False), 88 | stride=self.hop_length, 89 | padding=0) 90 | 91 | if self.window is not None: 92 | window_sum = window_sumsquare( 93 | self.window, magnitude.size(-1), hop_length=self.hop_length, 94 | win_length=self.win_length, n_fft=self.filter_length, 95 | dtype=np.float32) 96 | # remove modulation effects 97 | approx_nonzero_indices = torch.from_numpy( 98 | np.where(window_sum > tiny(window_sum))[0]) 99 | window_sum = torch.autograd.Variable( 100 | torch.from_numpy(window_sum), requires_grad=False) 101 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 102 | inverse_transform[:, :, 103 | approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 104 | 105 | # scale by hop ratio 106 | inverse_transform *= float(self.filter_length) / self.hop_length 107 | 108 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 109 | inverse_transform = inverse_transform[:, 110 | :, :-int(self.filter_length/2):] 111 | 112 | return inverse_transform 113 | 114 | def forward(self, input_data): 115 | self.magnitude, self.phase = self.transform(input_data) 116 | reconstruction = self.inverse(self.magnitude, self.phase) 117 | return reconstruction 118 | 119 | 120 | class TacotronSTFT(torch.nn.Module): 121 | def __init__(self, filter_length, hop_length, win_length, 122 | n_mel_channels, sampling_rate, mel_fmin=0.0, 123 | mel_fmax=8000.0): 124 | super(TacotronSTFT, self).__init__() 125 | self.n_mel_channels = n_mel_channels 126 | self.sampling_rate = sampling_rate 127 | self.stft_fn = STFT(filter_length, hop_length, win_length) 128 | mel_basis = librosa_mel_fn( 129 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 130 | mel_basis = torch.from_numpy(mel_basis).float() 131 | self.register_buffer('mel_basis', mel_basis) 132 | 133 | def spectral_normalize(self, magnitudes): 134 | output = dynamic_range_compression(magnitudes) 135 | return output 136 | 137 | def spectral_de_normalize(self, magnitudes): 138 | output = 
dynamic_range_decompression(magnitudes) 139 | return output 140 | 141 | def mel_spectrogram(self, y): 142 | """Computes mel-spectrograms from a batch of waves 143 | PARAMS 144 | ------ 145 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 146 | 147 | RETURNS 148 | ------- 149 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 150 | """ 151 | assert(torch.min(y.data) >= -1) 152 | assert(torch.max(y.data) <= 1) 153 | 154 | magnitudes, phases = self.stft_fn.transform(y) 155 | magnitudes = magnitudes.data 156 | mel_output = torch.matmul(self.mel_basis, magnitudes) 157 | mel_output = self.spectral_normalize(mel_output) 158 | energy = torch.norm(magnitudes, dim=1) 159 | 160 | return mel_output, energy 161 | -------------------------------------------------------------------------------- /audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def get_mel_from_wav(audio, _stft): 5 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 6 | audio = torch.autograd.Variable(audio, requires_grad=False) 7 | melspec, energy = _stft.mel_spectrogram(audio) 8 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 9 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 10 | 11 | return melspec, energy 12 | 13 | -------------------------------------------------------------------------------- /configs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "LibriTTS", 3 | "n_speakers": 1124, 4 | 5 | "text_cleaners": ["english_cleaners"], 6 | 7 | "sampling_rate": 16000, 8 | "filter_length": 1024, 9 | "hop_length": 256, 10 | "win_length": 1024, 11 | "max_wav_value": 32768.0, 12 | "mel_fmin": 0.0, 13 | "mel_fmax": 8000.0, 14 | "n_mel_channels": 80, 15 | "max_seq_len": 1000, 16 | 17 | "encoder_layer": 4, 18 | "encoder_head": 2, 19 | "encoder_hidden": 256, 20 | "decoder_layer": 4, 21 | "decoder_head": 2, 22 | "decoder_hidden": 256, 23 | "fft_conv1d_filter_size": 1024, 24 | "fft_conv1d_kernel_size": [9, 1], 25 | "dropout": 0.1, 26 | 27 | "variance_predictor_filter_size": 256, 28 | "variance_predictor_kernel_size": 3, 29 | "variance_embedding_kernel_size": 3, 30 | "variance_dropout": 0.5, 31 | 32 | "style_hidden": 128, 33 | "style_head": 2, 34 | "style_kernel_size": 5, 35 | "style_vector_dim": 128, 36 | 37 | "batch_size": 48, 38 | "meta_batch_size": 20, 39 | "max_iter": 200000, 40 | "meta_iter": 40000, 41 | "n_warm_up_step": 4000, 42 | "grad_clip_thresh": 1.0, 43 | 44 | "betas":[0.9, 0.98], 45 | "eps":1e-9 46 | } 47 | 48 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import numpy as np 3 | import json 4 | import os 5 | from text import text_to_sequence 6 | from utils import pad_1D, pad_2D, process_meta 7 | 8 | 9 | def prepare_dataloader(data_path, filename, batch_size, shuffle=True, num_workers=2, meta_learning=False, seed=0): 10 | dataset = TextMelDataset(data_path, filename) 11 | if meta_learning: 12 | sampler = MetaBatchSampler(dataset.sid_to_indexes, batch_size, seed=seed) 13 | else: 14 | sampler = None 15 | shuffle = shuffle if sampler is None else None 16 | if meta_learning: 17 | loader = DataLoader(dataset, batch_sampler=sampler, 18 | collate_fn=dataset.collate_fn, num_workers=num_workers, pin_memory=True) 19 | else: 
20 | loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size, shuffle=shuffle, 21 | collate_fn=dataset.collate_fn, drop_last=True, num_workers=num_workers) 22 | return loader 23 | 24 | 25 | def replace_outlier(values, max_v, min_v): 26 | values = np.where(valuesmin_v, values, min_v) 28 | return values 29 | 30 | 31 | def norm_mean_std(x, mean, std): 32 | x = (x - mean) / std 33 | return x 34 | 35 | 36 | class TextMelDataset(Dataset): 37 | def __init__(self, data_path, filename="train.txt",): 38 | self.data_path = data_path 39 | self.basename, self.text, self.sid = process_meta(os.path.join(data_path, filename)) 40 | 41 | self.sid_dict = self.create_speaker_table(self.sid) 42 | 43 | with open(os.path.join(data_path, 'stats.json')) as f: 44 | data = f.read() 45 | stats_config = json.loads(data) 46 | self.f0_stat = stats_config["f0_stat"] # max, min, mean, std 47 | self.energy_stat = stats_config["energy_stat"] # max, min, mean, std 48 | 49 | self.create_sid_to_index() 50 | print('Speaker Num :{}'.format(len(self.sid_dict))) 51 | 52 | def create_speaker_table(self, sids): 53 | speaker_ids = np.sort(np.unique(sids)) 54 | d = {speaker_ids[i]: i for i in range(len(speaker_ids))} 55 | return d 56 | 57 | def create_sid_to_index(self): 58 | _sid_to_indexes = {} 59 | # for keeping instance indexes with the same speaker ids 60 | for i, sid in enumerate(self.sid): 61 | if sid in _sid_to_indexes: 62 | _sid_to_indexes[sid].append(i) 63 | else: 64 | _sid_to_indexes[sid] = [i] 65 | self.sid_to_indexes = _sid_to_indexes 66 | 67 | def __len__(self): 68 | return len(self.text) 69 | 70 | def __getitem__(self, idx): 71 | basename = self.basename[idx] 72 | sid = self.sid_dict[self.sid[idx]] 73 | phone = np.array(text_to_sequence(self.text[idx], [])) 74 | mel_path = os.path.join( 75 | self.data_path, "mel", "libritts-mel-{}.npy".format(basename)) 76 | mel_target = np.load(mel_path) 77 | D_path = os.path.join( 78 | self.data_path, "alignment", "libritts-ali-{}.npy".format(basename)) 79 | D = np.load(D_path) 80 | f0_path = os.path.join( 81 | self.data_path, "f0", "libritts-f0-{}.npy".format(basename)) 82 | f0 = np.load(f0_path) 83 | f0 = replace_outlier(f0, self.f0_stat[0], self.f0_stat[1]) 84 | f0 = norm_mean_std(f0, self.f0_stat[2], self.f0_stat[3]) 85 | energy_path = os.path.join( 86 | self.data_path, "energy", "libritts-energy-{}.npy".format(basename)) 87 | energy = np.load(energy_path) 88 | energy = replace_outlier(energy, self.energy_stat[0], self.energy_stat[1]) 89 | energy = norm_mean_std(energy, self.energy_stat[2], self.energy_stat[3]) 90 | 91 | sample = {"id": basename, 92 | "sid": sid, 93 | "text": phone, 94 | "mel_target": mel_target, 95 | "D": D, 96 | "f0": f0, 97 | "energy": energy} 98 | 99 | return sample 100 | 101 | def reprocess(self, batch, cut_list): 102 | ids = [batch[ind]["id"] for ind in cut_list] 103 | sids = [batch[ind]["sid"] for ind in cut_list] 104 | texts = [batch[ind]["text"] for ind in cut_list] 105 | mel_targets = [batch[ind]["mel_target"] for ind in cut_list] 106 | Ds = [batch[ind]["D"] for ind in cut_list] 107 | f0s = [batch[ind]["f0"] for ind in cut_list] 108 | energies = [batch[ind]["energy"] for ind in cut_list] 109 | for text, D, id_ in zip(texts, Ds, ids): 110 | if len(text) != len(D): 111 | print(text, text.shape, D, D.shape, id_) 112 | length_text = np.array(list()) 113 | for text in texts: 114 | length_text = np.append(length_text, text.shape[0]) 115 | 116 | length_mel = np.array(list()) 117 | for mel in mel_targets: 118 | length_mel = np.append(length_mel, 
mel.shape[0]) 119 | 120 | texts = pad_1D(texts) 121 | Ds = pad_1D(Ds) 122 | mel_targets = pad_2D(mel_targets) 123 | f0s = pad_1D(f0s) 124 | energies = pad_1D(energies) 125 | log_Ds = np.log(Ds + 1.) 126 | 127 | out = {"id": ids, 128 | "sid": np.array(sids), 129 | "text": texts, 130 | "mel_target": mel_targets, 131 | "D": Ds, 132 | "log_D": log_Ds, 133 | "f0": f0s, 134 | "energy": energies, 135 | "src_len": length_text, 136 | "mel_len": length_mel} 137 | 138 | return out 139 | 140 | def collate_fn(self, batch): 141 | len_arr = np.array([d["text"].shape[0] for d in batch]) 142 | index_arr = np.argsort(-len_arr) 143 | output = self.reprocess(batch, index_arr) 144 | 145 | return output 146 | 147 | 148 | class MetaBatchSampler(): 149 | def __init__(self, sid_to_idx, batch_size, max_iter=100000, seed=0): 150 | # iterdict contains {sid: [idx1, idx2, ...]} 151 | np.random.seed(seed) 152 | 153 | self.sids = list(sid_to_idx.keys()) 154 | np.random.shuffle(self.sids) 155 | 156 | self.sid_to_idx = sid_to_idx 157 | self.batch_size = batch_size 158 | self.max_iter = max_iter 159 | 160 | def __iter__(self): 161 | for _ in range(self.max_iter): 162 | selected_sids = np.random.choice(self.sids, self.batch_size, replace=False) 163 | batch = [] 164 | for sid in selected_sids: 165 | idx = np.random.choice(self.sid_to_idx[sid], 1)[0] 166 | batch.append(idx) 167 | 168 | assert len(batch) == self.batch_size 169 | yield batch 170 | 171 | def __len__(self): 172 | return self.max_iter 173 | 174 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataloader import prepare_dataloader 3 | 4 | 5 | def evaluate(args, model, step): 6 | # Get dataset 7 | data_loader = prepare_dataloader(args.data_path, "val.txt", batch_size=50, shuffle=False) 8 | 9 | # Get loss function 10 | Loss = model.get_criterion() 11 | 12 | # Evaluation 13 | mel_l_list = [] 14 | d_l_list = [] 15 | f_l_list = [] 16 | e_l_list = [] 17 | current_step = 0 18 | for i, batch in enumerate(data_loader): 19 | # Get Data 20 | id_ = batch["id"] 21 | sid, text, mel_target, D, log_D, f0, energy, \ 22 | src_len, mel_len, max_src_len, max_mel_len = model.parse_batch(batch) 23 | 24 | with torch.no_grad(): 25 | # Forward 26 | mel_output, _, _, log_duration_output, f0_output, energy_output, src_mask, mel_mask, out_mel_len = model( 27 | text, src_len, mel_target, mel_len, D, f0, energy, max_src_len, max_mel_len) 28 | # Cal Loss 29 | mel_loss, d_loss, f_loss, e_loss = Loss(mel_output, mel_target, 30 | log_duration_output, log_D, f0_output, f0, energy_output, energy, src_len, mel_len) 31 | 32 | # Logger 33 | m_l = mel_loss.item() 34 | d_l = d_loss.item() 35 | f_l = f_loss.item() 36 | e_l = e_loss.item() 37 | 38 | mel_l_list.append(m_l) 39 | d_l_list.append(d_l) 40 | f_l_list.append(f_l) 41 | e_l_list.append(e_l) 42 | 43 | current_step += 1 44 | 45 | mel_l = sum(mel_l_list) / len(mel_l_list) 46 | d_l = sum(d_l_list) / len(d_l_list) 47 | f_l = sum(f_l_list) / len(f_l_list) 48 | e_l = sum(e_l_list) / len(e_l_list) 49 | 50 | return mel_l, d_l, f_l, e_l 51 | 52 | 53 | -------------------------------------------------------------------------------- /models/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = '' 7 | UNK_WORD = '' 8 | BOS_WORD = '' 9 | EOS_WORD = '' 10 | 
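# Note: PAD (0) is consumed as padding_idx for the phoneme embedding in models/StyleSpeech.py,
# so padded positions map to a zero, non-trainable embedding row.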
-------------------------------------------------------------------------------- /models/Discriminators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.Modules import ConvNorm, LinearNorm, MultiHeadAttention, get_sinusoid_encoding_table 5 | from models.Loss import LSGANLoss 6 | 7 | LEAKY_RELU = 0.1 8 | 9 | 10 | def dot_product_logit(a, b): 11 | n = a.size(0) 12 | m = b.size(0) 13 | a = a.unsqueeze(1).expand(n, m, -1) 14 | b = b.unsqueeze(0).expand(n, m, -1) 15 | logits = (a*b).sum(dim=2) 16 | return logits 17 | 18 | 19 | class Discriminator(nn.Module): 20 | ''' Discriminator ''' 21 | def __init__(self, config): 22 | super(Discriminator, self).__init__() 23 | 24 | self.style_D = StyleDiscriminator( 25 | config.n_speakers, 26 | config.n_mel_channels, 27 | config.style_vector_dim, 28 | config.style_vector_dim, 29 | config.style_kernel_size, 30 | config.style_head) 31 | 32 | self.phoneme_D = PhonemeDiscriminator( 33 | config.n_mel_channels, 34 | config.encoder_hidden, 35 | config.max_seq_len) 36 | 37 | 38 | def forward(self, mels, srcs, ws, sids, mask): 39 | mels = mels.masked_fill(mask.unsqueeze(-1), 0) 40 | 41 | t_val = self.phoneme_D(mels, srcs, mask) 42 | s_val, ce_loss = self.style_D(mels, ws, sids, mask) 43 | 44 | return t_val, s_val, ce_loss 45 | 46 | def get_criterion(self): 47 | return LSGANLoss() 48 | 49 | 50 | 51 | class StyleDiscriminator(nn.Module): 52 | ''' Style Discriminator ''' 53 | def __init__(self, n_speakers, input_dim, hidden_dim, style_dim, kernel_size, n_head): 54 | super(StyleDiscriminator, self).__init__() 55 | 56 | self.style_prototypes = nn.Embedding(n_speakers, style_dim) 57 | 58 | self.spectral = nn.Sequential( 59 | LinearNorm(input_dim, hidden_dim, spectral_norm=True), 60 | nn.LeakyReLU(LEAKY_RELU), 61 | LinearNorm(hidden_dim, hidden_dim, spectral_norm=True), 62 | nn.LeakyReLU(LEAKY_RELU), 63 | ) 64 | 65 | self.temporal = nn.ModuleList([nn.Sequential( 66 | ConvNorm(hidden_dim, hidden_dim, kernel_size, spectral_norm=True), 67 | nn.LeakyReLU(LEAKY_RELU)) for _ in range(2)]) 68 | 69 | self.slf_attn = MultiHeadAttention(n_head, hidden_dim, hidden_dim//n_head, hidden_dim//n_head, spectral_norm=True) 70 | 71 | self.fc = LinearNorm(hidden_dim, hidden_dim, spectral_norm=True) 72 | 73 | self.V = LinearNorm(style_dim, hidden_dim, spectral_norm=True) 74 | 75 | self.w = nn.Parameter(torch.ones(1)) 76 | self.b = nn.Parameter(torch.zeros(1)) 77 | 78 | def temporal_avg_pool(self, xs, mask): 79 | xs = xs.masked_fill(mask.unsqueeze(-1), 0) 80 | len_ = (~mask).sum(dim=1).unsqueeze(1) 81 | xs = torch.sum(xs, dim=1) 82 | xs = torch.div(xs, len_) 83 | return xs 84 | 85 | def forward(self, mels, ws, sids, mask): 86 | max_len = mels.shape[1] 87 | 88 | # Update style prototypes 89 | if ws is not None: 90 | style_prototypes = self.style_prototypes.weight.clone() 91 | logit = dot_product_logit(ws, style_prototypes) 92 | cls_loss = F.cross_entropy(logit, sids) 93 | else: 94 | cls_loss = None 95 | 96 | # Style discriminator 97 | x = self.spectral(mels) 98 | 99 | for conv in self.temporal: 100 | residual = x 101 | x = x.transpose(1,2) 102 | x = conv(x) 103 | x = x.transpose(1,2) 104 | x = residual + x 105 | 106 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 107 | x, _ = self.slf_attn(x, slf_attn_mask) 108 | 109 | x = self.fc(x) 110 | h = self.temporal_avg_pool(x, mask) 111 | 112 | ps = self.style_prototypes(sids) 113 | s_val = self.w * 
torch.sum(self.V(ps)*h, dim=1) + self.b 114 | 115 | return s_val, cls_loss 116 | 117 | 118 | class PhonemeDiscriminator(nn.Module): 119 | ''' Phoneme Discriminator ''' 120 | def __init__(self, input_dim, hidden_dim, max_seq_len): 121 | super(PhonemeDiscriminator, self).__init__() 122 | self.hidden_dim = hidden_dim 123 | self.max_seq_len = max_seq_len 124 | 125 | self.mel_prenet = nn.Sequential( 126 | LinearNorm(input_dim, hidden_dim, spectral_norm=True), 127 | nn.LeakyReLU(LEAKY_RELU), 128 | LinearNorm(hidden_dim, hidden_dim, spectral_norm=True), 129 | nn.LeakyReLU(LEAKY_RELU), 130 | ) 131 | 132 | n_position = max_seq_len + 1 133 | self.position_enc = nn.Parameter( 134 | get_sinusoid_encoding_table(n_position, hidden_dim).unsqueeze(0), 135 | requires_grad = False) 136 | 137 | self.fcs = nn.Sequential( 138 | LinearNorm(hidden_dim*2, hidden_dim*2, spectral_norm=True), 139 | nn.LeakyReLU(LEAKY_RELU), 140 | LinearNorm(hidden_dim*2, hidden_dim*2, spectral_norm=True), 141 | nn.LeakyReLU(LEAKY_RELU), 142 | LinearNorm(hidden_dim*2, hidden_dim*2, spectral_norm=True), 143 | nn.LeakyReLU(LEAKY_RELU), 144 | LinearNorm(hidden_dim*2, 1, spectral_norm=True) 145 | ) 146 | 147 | 148 | def forward(self, mels, srcs, mask): 149 | batch_size, max_len = mels.shape[0], mels.shape[1] 150 | 151 | mels = self.mel_prenet(mels) 152 | 153 | if srcs.shape[1] > self.max_seq_len: 154 | position_embed = get_sinusoid_encoding_table(srcs.shape[1], self.hidden_dim)[:srcs.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(srcs.device) 155 | else: 156 | position_embed = self.position_enc[:, :max_len, :].expand(batch_size, -1, -1) 157 | srcs = srcs + position_embed 158 | 159 | xs = torch.cat((mels, srcs), dim=-1) 160 | xs = self.fcs(xs) 161 | t_val = xs.squeeze(-1) 162 | mel_len = (~mask).sum(-1) 163 | 164 | # temporal avg pooling 165 | t_val = t_val.masked_fill(mask, 0.) 166 | t_val = torch.div(torch.sum(t_val, dim=1), mel_len) 167 | return t_val 168 | -------------------------------------------------------------------------------- /models/Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class StyleSpeechLoss(nn.Module): 6 | """ StyleSpeech Loss """ 7 | def __init__(self): 8 | super(StyleSpeechLoss, self).__init__() 9 | self.mse_loss = nn.MSELoss() 10 | self.mae_loss = nn.L1Loss() 11 | 12 | def forward(self, mel, mel_target, log_d_predicted, log_d_target, 13 | p_predicted, p_target, e_predicted, e_target, src_len, mel_len): 14 | B = mel_target.shape[0] 15 | log_d_target.requires_grad = False 16 | p_target.requires_grad = False 17 | e_target.requires_grad = False 18 | mel_target.requires_grad = False 19 | 20 | mel_loss = 0. 21 | d_loss = 0. 22 | p_loss = 0. 23 | e_loss = 0. 
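        # Each loss is accumulated over the unpadded region of every utterance
        # (mel frames up to mel_len, phonemes up to src_len), then averaged over the batch.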
24 | 25 | for b, (mel_l, src_l) in enumerate(zip(mel_len, src_len)): 26 | mel_loss += self.mae_loss(mel[b, :mel_l, :], mel_target[b, :mel_l, :]) 27 | d_loss += self.mse_loss(log_d_predicted[b, :src_l], log_d_target[b, :src_l]) 28 | p_loss += self.mse_loss(p_predicted[b, :src_l], p_target[b, :src_l]) 29 | e_loss += self.mse_loss(e_predicted[b, :src_l], e_target[b, :src_l]) 30 | 31 | mel_loss = mel_loss / B 32 | d_loss = d_loss / B 33 | p_loss = p_loss / B 34 | e_loss = e_loss / B 35 | 36 | return mel_loss, d_loss, p_loss, e_loss 37 | 38 | 39 | class LSGANLoss(nn.Module): 40 | """ LSGAN Loss """ 41 | def __init__(self): 42 | super(LSGANLoss, self).__init__() 43 | self.criterion = nn.MSELoss() 44 | 45 | def forward(self, r, is_real): 46 | if is_real: 47 | ones = torch.ones(r.size(), requires_grad=False).to(r.device) 48 | loss = self.criterion(r, ones) 49 | else: 50 | zeros = torch.zeros(r.size(), requires_grad=False).to(r.device) 51 | loss = self.criterion(r, zeros) 52 | return loss -------------------------------------------------------------------------------- /models/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | 7 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 8 | ''' Sinusoid position encoding table ''' 9 | 10 | def cal_angle(position, hid_idx): 11 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 12 | 13 | def get_posi_angle_vec(position): 14 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 15 | 16 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) 17 | for pos_i in range(n_position)]) 18 | 19 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 20 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 21 | 22 | if padding_idx is not None: 23 | # zero vector for padding dimension 24 | sinusoid_table[padding_idx] = 0. 
25 | return torch.FloatTensor(sinusoid_table) 26 | 27 | 28 | class Mish(nn.Module): 29 | def __init__(self): 30 | super(Mish, self).__init__() 31 | def forward(self, x): 32 | return x * torch.tanh(F.softplus(x)) 33 | 34 | 35 | class AffineLinear(nn.Module): 36 | def __init__(self, in_dim, out_dim): 37 | super(AffineLinear, self).__init__() 38 | affine = nn.Linear(in_dim, out_dim) 39 | self.affine = affine 40 | 41 | def forward(self, input): 42 | return self.affine(input) 43 | 44 | 45 | class StyleAdaptiveLayerNorm(nn.Module): 46 | def __init__(self, in_channel, style_dim): 47 | super(StyleAdaptiveLayerNorm, self).__init__() 48 | self.in_channel = in_channel 49 | self.norm = nn.LayerNorm(in_channel, elementwise_affine=False) 50 | 51 | self.style = AffineLinear(style_dim, in_channel * 2) 52 | self.style.affine.bias.data[:in_channel] = 1 53 | self.style.affine.bias.data[in_channel:] = 0 54 | 55 | def forward(self, input, style_code): 56 | # style 57 | style = self.style(style_code).unsqueeze(1) 58 | gamma, beta = style.chunk(2, dim=-1) 59 | 60 | out = self.norm(input) 61 | out = gamma * out + beta 62 | return out 63 | 64 | 65 | class LinearNorm(nn.Module): 66 | def __init__(self, 67 | in_channels, 68 | out_channels, 69 | bias=True, 70 | spectral_norm=False, 71 | ): 72 | super(LinearNorm, self).__init__() 73 | self.fc = nn.Linear(in_channels, out_channels, bias) 74 | 75 | if spectral_norm: 76 | self.fc = nn.utils.spectral_norm(self.fc) 77 | 78 | def forward(self, input): 79 | out = self.fc(input) 80 | return out 81 | 82 | 83 | class ConvNorm(nn.Module): 84 | def __init__(self, 85 | in_channels, 86 | out_channels, 87 | kernel_size=1, 88 | stride=1, 89 | padding=None, 90 | dilation=1, 91 | bias=True, 92 | spectral_norm=False, 93 | ): 94 | super(ConvNorm, self).__init__() 95 | 96 | if padding is None: 97 | assert(kernel_size % 2 == 1) 98 | padding = int(dilation * (kernel_size - 1) / 2) 99 | 100 | self.conv = torch.nn.Conv1d(in_channels, 101 | out_channels, 102 | kernel_size=kernel_size, 103 | stride=stride, 104 | padding=padding, 105 | dilation=dilation, 106 | bias=bias) 107 | 108 | if spectral_norm: 109 | self.conv = nn.utils.spectral_norm(self.conv) 110 | 111 | def forward(self, input): 112 | out = self.conv(input) 113 | return out 114 | 115 | 116 | class MultiHeadAttention(nn.Module): 117 | ''' Multi-Head Attention module ''' 118 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0., spectral_norm=False): 119 | super().__init__() 120 | 121 | self.n_head = n_head 122 | self.d_k = d_k 123 | self.d_v = d_v 124 | 125 | self.w_qs = nn.Linear(d_model, n_head * d_k) 126 | self.w_ks = nn.Linear(d_model, n_head * d_k) 127 | self.w_vs = nn.Linear(d_model, n_head * d_v) 128 | 129 | self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout) 130 | 131 | self.fc = nn.Linear(n_head * d_v, d_model) 132 | self.dropout = nn.Dropout(dropout) 133 | 134 | if spectral_norm: 135 | self.w_qs = nn.utils.spectral_norm(self.w_qs) 136 | self.w_ks = nn.utils.spectral_norm(self.w_ks) 137 | self.w_vs = nn.utils.spectral_norm(self.w_vs) 138 | self.fc = nn.utils.spectral_norm(self.fc) 139 | 140 | def forward(self, x, mask=None): 141 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 142 | sz_b, len_x, _ = x.size() 143 | 144 | residual = x 145 | 146 | q = self.w_qs(x).view(sz_b, len_x, n_head, d_k) 147 | k = self.w_ks(x).view(sz_b, len_x, n_head, d_k) 148 | v = self.w_vs(x).view(sz_b, len_x, n_head, d_v) 149 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, 150 | len_x, d_k) # (n*b) 
x lq x dk 151 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, 152 | len_x, d_k) # (n*b) x lk x dk 153 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, 154 | len_x, d_v) # (n*b) x lv x dv 155 | 156 | if mask is not None: 157 | slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 158 | else: 159 | slf_mask = None 160 | output, attn = self.attention(q, k, v, mask=slf_mask) 161 | 162 | output = output.view(n_head, sz_b, len_x, d_v) 163 | output = output.permute(1, 2, 0, 3).contiguous().view( 164 | sz_b, len_x, -1) # b x lq x (n*dv) 165 | 166 | output = self.fc(output) 167 | 168 | output = self.dropout(output) + residual 169 | return output, attn 170 | 171 | 172 | class ScaledDotProductAttention(nn.Module): 173 | ''' Scaled Dot-Product Attention ''' 174 | 175 | def __init__(self, temperature, dropout): 176 | super().__init__() 177 | self.temperature = temperature 178 | self.softmax = nn.Softmax(dim=2) 179 | self.dropout = nn.Dropout(dropout) 180 | 181 | def forward(self, q, k, v, mask=None): 182 | 183 | attn = torch.bmm(q, k.transpose(1, 2)) 184 | attn = attn / self.temperature 185 | 186 | if mask is not None: 187 | attn = attn.masked_fill(mask, -np.inf) 188 | 189 | attn = self.softmax(attn) 190 | p_attn = self.dropout(attn) 191 | 192 | output = torch.bmm(p_attn, v) 193 | return output, attn 194 | 195 | 196 | class Conv1dGLU(nn.Module): 197 | ''' 198 | Conv1d + GLU(Gated Linear Unit) with residual connection. 199 | For GLU refer to https://arxiv.org/abs/1612.08083 paper. 200 | ''' 201 | def __init__(self, in_channels, out_channels, kernel_size, dropout): 202 | super(Conv1dGLU, self).__init__() 203 | self.out_channels = out_channels 204 | self.conv1 = ConvNorm(in_channels, 2*out_channels, kernel_size=kernel_size) 205 | self.dropout = nn.Dropout(dropout) 206 | 207 | def forward(self, x): 208 | residual = x 209 | x = self.conv1(x) 210 | x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) 211 | x = x1 * torch.sigmoid(x2) 212 | x = residual + self.dropout(x) 213 | return x 214 | 215 | 216 | -------------------------------------------------------------------------------- /models/StyleSpeech.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from text.symbols import symbols 5 | import models.Constants as Constants 6 | from models.Modules import Mish, LinearNorm, ConvNorm, Conv1dGLU, \ 7 | MultiHeadAttention, StyleAdaptiveLayerNorm, get_sinusoid_encoding_table 8 | from models.VarianceAdaptor import VarianceAdaptor 9 | from models.Loss import StyleSpeechLoss 10 | from utils import get_mask_from_lengths 11 | 12 | 13 | class StyleSpeech(nn.Module): 14 | ''' StyleSpeech ''' 15 | def __init__(self, config): 16 | super(StyleSpeech, self).__init__() 17 | self.style_encoder = MelStyleEncoder(config) 18 | self.encoder = Encoder(config) 19 | self.variance_adaptor = VarianceAdaptor(config) 20 | self.decoder = Decoder(config) 21 | 22 | def parse_batch(self, batch): 23 | sid = torch.from_numpy(batch["sid"]).long().cuda() 24 | text = torch.from_numpy(batch["text"]).long().cuda() 25 | mel_target = torch.from_numpy(batch["mel_target"]).float().cuda() 26 | D = torch.from_numpy(batch["D"]).long().cuda() 27 | log_D = torch.from_numpy(batch["log_D"]).float().cuda() 28 | f0 = torch.from_numpy(batch["f0"]).float().cuda() 29 | energy = torch.from_numpy(batch["energy"]).float().cuda() 30 | src_len = torch.from_numpy(batch["src_len"]).long().cuda() 31 | mel_len = 
torch.from_numpy(batch["mel_len"]).long().cuda() 32 | max_src_len = np.max(batch["src_len"]).astype(np.int32) 33 | max_mel_len = np.max(batch["mel_len"]).astype(np.int32) 34 | return sid, text, mel_target, D, log_D, f0, energy, src_len, mel_len, max_src_len, max_mel_len 35 | 36 | def forward(self, src_seq, src_len, mel_target, mel_len=None, 37 | d_target=None, p_target=None, e_target=None, max_src_len=None, max_mel_len=None): 38 | src_mask = get_mask_from_lengths(src_len, max_src_len) 39 | mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None 40 | 41 | # Extract Style Vector 42 | style_vector = self.style_encoder(mel_target, mel_mask) 43 | # Encoding 44 | encoder_output, src_embedded, _ = self.encoder(src_seq, style_vector, src_mask) 45 | # Variance Adaptor 46 | acoustic_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor( 47 | encoder_output, src_mask, mel_len, mel_mask, 48 | d_target, p_target, e_target, max_mel_len) 49 | # Deocoding 50 | mel_prediction, _ = self.decoder(acoustic_adaptor_output, style_vector, mel_mask) 51 | 52 | return mel_prediction, src_embedded, style_vector, d_prediction, p_prediction, e_prediction, src_mask, mel_mask, mel_len 53 | 54 | def inference(self, style_vector, src_seq, src_len=None, max_src_len=None, return_attn=False): 55 | src_mask = get_mask_from_lengths(src_len, max_src_len) 56 | 57 | # Encoding 58 | encoder_output, src_embedded, enc_slf_attn = self.encoder(src_seq, style_vector, src_mask) 59 | 60 | # Variance Adaptor 61 | acoustic_adaptor_output, d_prediction, p_prediction, e_prediction, \ 62 | mel_len, mel_mask = self.variance_adaptor(encoder_output, src_mask) 63 | 64 | # Deocoding 65 | mel_output, dec_slf_attn = self.decoder(acoustic_adaptor_output, style_vector, mel_mask) 66 | 67 | if return_attn: 68 | return enc_slf_attn, dec_slf_attn 69 | 70 | return mel_output, src_embedded, d_prediction, p_prediction, e_prediction, src_mask, mel_mask, mel_len 71 | 72 | def get_style_vector(self, mel_target, mel_len=None): 73 | mel_mask = get_mask_from_lengths(mel_len) if mel_len is not None else None 74 | style_vector = self.style_encoder(mel_target, mel_mask) 75 | 76 | return style_vector 77 | 78 | def get_criterion(self): 79 | return StyleSpeechLoss() 80 | 81 | 82 | class Encoder(nn.Module): 83 | ''' Encoder ''' 84 | def __init__(self, config, n_src_vocab=len(symbols)+1): 85 | super(Encoder, self).__init__() 86 | self.max_seq_len = config.max_seq_len 87 | self.n_layers = config.encoder_layer 88 | self.d_model = config.encoder_hidden 89 | self.n_head = config.encoder_head 90 | self.d_k = config.encoder_hidden // config.encoder_head 91 | self.d_v = config.encoder_hidden // config.encoder_head 92 | self.d_inner = config.fft_conv1d_filter_size 93 | self.fft_conv1d_kernel_size = config.fft_conv1d_kernel_size 94 | self.d_out = config.decoder_hidden 95 | self.style_dim = config.style_vector_dim 96 | self.dropout = config.dropout 97 | 98 | self.src_word_emb = nn.Embedding(n_src_vocab, self.d_model, padding_idx=Constants.PAD) 99 | self.prenet = Prenet(self.d_model, self.d_model, self.dropout) 100 | 101 | n_position = self.max_seq_len + 1 102 | self.position_enc = nn.Parameter( 103 | get_sinusoid_encoding_table(n_position, self.d_model).unsqueeze(0), requires_grad = False) 104 | 105 | self.layer_stack = nn.ModuleList([FFTBlock( 106 | self.d_model, self.d_inner, self.n_head, self.d_k, self.d_v, 107 | self.fft_conv1d_kernel_size, self.style_dim, self.dropout) for _ in range(self.n_layers)]) 
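        # Each FFTBlock conditions on the style vector through Style-Adaptive Layer Norm (SALN),
        # applied after both the self-attention and the position-wise feed-forward sub-layers.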
108 | 109 | self.fc_out = nn.Linear(self.d_model, self.d_out) 110 | 111 | def forward(self, src_seq, style_vector, mask): 112 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 113 | 114 | # -- Prepare masks 115 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 116 | 117 | # -- Forward 118 | # word embedding 119 | src_embedded = self.src_word_emb(src_seq) 120 | # prenet 121 | src_seq = self.prenet(src_embedded, mask) 122 | # position encoding 123 | if src_seq.shape[1] > self.max_seq_len: 124 | position_embedded = get_sinusoid_encoding_table(src_seq.shape[1], self.d_model)[:src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(src_seq.device) 125 | else: 126 | position_embedded = self.position_enc[:, :max_len, :].expand(batch_size, -1, -1) 127 | enc_output = src_seq + position_embedded 128 | # fft blocks 129 | slf_attn = [] 130 | for enc_layer in self.layer_stack: 131 | enc_output, enc_slf_attn = enc_layer( 132 | enc_output, style_vector, 133 | mask=mask, 134 | slf_attn_mask=slf_attn_mask) 135 | slf_attn.append(enc_slf_attn) 136 | # last fc 137 | enc_output = self.fc_out(enc_output) 138 | return enc_output, src_embedded, slf_attn 139 | 140 | 141 | class Decoder(nn.Module): 142 | """ Decoder """ 143 | def __init__(self, config): 144 | super(Decoder, self).__init__() 145 | self.max_seq_len = config.max_seq_len 146 | self.n_layers = config.decoder_layer 147 | self.d_model = config.decoder_hidden 148 | self.n_head = config.decoder_head 149 | self.d_k = config.decoder_hidden // config.decoder_head 150 | self.d_v = config.decoder_hidden // config.decoder_head 151 | self.d_inner = config.fft_conv1d_filter_size 152 | self.fft_conv1d_kernel_size = config.fft_conv1d_kernel_size 153 | self.d_out = config.n_mel_channels 154 | self.style_dim = config.style_vector_dim 155 | self.dropout = config.dropout 156 | 157 | self.prenet = nn.Sequential( 158 | nn.Linear(self.d_model, self.d_model//2), 159 | Mish(), 160 | nn.Dropout(self.dropout), 161 | nn.Linear(self.d_model//2, self.d_model) 162 | ) 163 | 164 | n_position = self.max_seq_len + 1 165 | self.position_enc = nn.Parameter( 166 | get_sinusoid_encoding_table(n_position, self.d_model).unsqueeze(0), requires_grad = False) 167 | 168 | self.layer_stack = nn.ModuleList([FFTBlock( 169 | self.d_model, self.d_inner, self.n_head, self.d_k, self.d_v, 170 | self.fft_conv1d_kernel_size, self.style_dim, self.dropout) for _ in range(self.n_layers)]) 171 | 172 | self.fc_out = nn.Linear(self.d_model, self.d_out) 173 | 174 | def forward(self, enc_seq, style_code, mask): 175 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 176 | # -- Prepare masks 177 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 178 | 179 | # -- Forward 180 | # prenet 181 | dec_embedded = self.prenet(enc_seq) 182 | # poistion encoding 183 | if enc_seq.shape[1] > self.max_seq_len: 184 | position_embedded = get_sinusoid_encoding_table(enc_seq.shape[1], self.d_model)[:enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(enc_seq.device) 185 | else: 186 | position_embedded = self.position_enc[:, :max_len, :].expand(batch_size, -1, -1) 187 | dec_output = dec_embedded + position_embedded 188 | # fft blocks 189 | slf_attn = [] 190 | for dec_layer in self.layer_stack: 191 | dec_output, dec_slf_attn = dec_layer( 192 | dec_output, style_code, 193 | mask=mask, 194 | slf_attn_mask=slf_attn_mask) 195 | slf_attn.append(dec_slf_attn) 196 | # last fc 197 | dec_output = self.fc_out(dec_output) 198 | return dec_output, slf_attn 199 | 200 | 201 | class 
FFTBlock(nn.Module): 202 | ''' FFT Block ''' 203 | def __init__(self, d_model,d_inner, 204 | n_head, d_k, d_v, fft_conv1d_kernel_size, style_dim, dropout): 205 | super(FFTBlock, self).__init__() 206 | self.slf_attn = MultiHeadAttention( 207 | n_head, d_model, d_k, d_v, dropout=dropout) 208 | self.saln_0 = StyleAdaptiveLayerNorm(d_model, style_dim) 209 | 210 | self.pos_ffn = PositionwiseFeedForward( 211 | d_model, d_inner, fft_conv1d_kernel_size, dropout=dropout) 212 | self.saln_1 = StyleAdaptiveLayerNorm(d_model, style_dim) 213 | 214 | def forward(self, input, style_vector, mask=None, slf_attn_mask=None): 215 | # multi-head self attn 216 | slf_attn_output, slf_attn = self.slf_attn(input, mask=slf_attn_mask) 217 | slf_attn_output = self.saln_0(slf_attn_output, style_vector) 218 | if mask is not None: 219 | slf_attn_output = slf_attn_output.masked_fill(mask.unsqueeze(-1), 0) 220 | 221 | # position wise FF 222 | output = self.pos_ffn(slf_attn_output) 223 | output = self.saln_1(output, style_vector) 224 | if mask is not None: 225 | output = output.masked_fill(mask.unsqueeze(-1), 0) 226 | 227 | return output, slf_attn 228 | 229 | 230 | class PositionwiseFeedForward(nn.Module): 231 | ''' A two-feed-forward-layer module ''' 232 | def __init__(self, d_in, d_hid, fft_conv1d_kernel_size, dropout=0.1): 233 | super().__init__() 234 | self.w_1 = ConvNorm(d_in, d_hid, kernel_size=fft_conv1d_kernel_size[0]) 235 | self.w_2 = ConvNorm(d_hid, d_in, kernel_size=fft_conv1d_kernel_size[1]) 236 | 237 | self.mish = Mish() 238 | self.dropout = nn.Dropout(dropout) 239 | 240 | def forward(self, input): 241 | residual = input 242 | 243 | output = input.transpose(1, 2) 244 | output = self.w_2(self.dropout(self.mish(self.w_1(output)))) 245 | output = output.transpose(1, 2) 246 | 247 | output = self.dropout(output) + residual 248 | return output 249 | 250 | 251 | class MelStyleEncoder(nn.Module): 252 | ''' MelStyleEncoder ''' 253 | def __init__(self, config): 254 | super(MelStyleEncoder, self).__init__() 255 | self.in_dim = config.n_mel_channels 256 | self.hidden_dim = config.style_hidden 257 | self.out_dim = config.style_vector_dim 258 | self.kernel_size = config.style_kernel_size 259 | self.n_head = config.style_head 260 | self.dropout = config.dropout 261 | 262 | self.spectral = nn.Sequential( 263 | LinearNorm(self.in_dim, self.hidden_dim), 264 | Mish(), 265 | nn.Dropout(self.dropout), 266 | LinearNorm(self.hidden_dim, self.hidden_dim), 267 | Mish(), 268 | nn.Dropout(self.dropout) 269 | ) 270 | 271 | self.temporal = nn.Sequential( 272 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 273 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 274 | ) 275 | 276 | self.slf_attn = MultiHeadAttention(self.n_head, self.hidden_dim, 277 | self.hidden_dim//self.n_head, self.hidden_dim//self.n_head, self.dropout) 278 | 279 | self.fc = LinearNorm(self.hidden_dim, self.out_dim) 280 | 281 | def temporal_avg_pool(self, x, mask=None): 282 | if mask is None: 283 | out = torch.mean(x, dim=1) 284 | else: 285 | len_ = (~mask).sum(dim=1).unsqueeze(1) 286 | x = x.masked_fill(mask.unsqueeze(-1), 0) 287 | x = x.sum(dim=1) 288 | out = torch.div(x, len_) 289 | return out 290 | 291 | def forward(self, x, mask=None): 292 | max_len = x.shape[1] 293 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None 294 | 295 | # spectral 296 | x = self.spectral(x) 297 | # temporal 298 | x = x.transpose(1,2) 299 | x = self.temporal(x) 300 | x = x.transpose(1,2) 301 | # 
self-attention 302 | if mask is not None: 303 | x = x.masked_fill(mask.unsqueeze(-1), 0) 304 | x, _ = self.slf_attn(x, mask=slf_attn_mask) 305 | # fc 306 | x = self.fc(x) 307 | # temoral average pooling 308 | w = self.temporal_avg_pool(x, mask=mask) 309 | 310 | return w 311 | 312 | 313 | class Prenet(nn.Module): 314 | ''' Prenet ''' 315 | def __init__(self, hidden_dim, out_dim, dropout): 316 | super(Prenet, self).__init__() 317 | 318 | self.convs = nn.Sequential( 319 | ConvNorm(hidden_dim, hidden_dim, kernel_size=3), 320 | Mish(), 321 | nn.Dropout(dropout), 322 | ConvNorm(hidden_dim, hidden_dim, kernel_size=3), 323 | Mish(), 324 | nn.Dropout(dropout), 325 | ) 326 | self.fc = LinearNorm(hidden_dim, out_dim) 327 | 328 | def forward(self, input, mask=None): 329 | residual = input 330 | # convs 331 | output = input.transpose(1,2) 332 | output = self.convs(output) 333 | output = output.transpose(1,2) 334 | # fc & residual 335 | output = self.fc(output) + residual 336 | 337 | if mask is not None: 338 | output = output.masked_fill(mask.unsqueeze(-1), 0) 339 | return output -------------------------------------------------------------------------------- /models/VarianceAdaptor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.Modules import LinearNorm, ConvNorm, get_sinusoid_encoding_table 4 | import utils 5 | 6 | 7 | class VarianceAdaptor(nn.Module): 8 | """ Variance Adaptor """ 9 | def __init__(self, config): 10 | super(VarianceAdaptor, self).__init__() 11 | 12 | self.hidden_dim = config.variance_predictor_filter_size 13 | self.predictor_kernel_size = config.variance_predictor_kernel_size 14 | self.embedding_kernel_size = config.variance_embedding_kernel_size 15 | self.dropout = config.variance_dropout 16 | 17 | # Duration 18 | self.duration_predictor = VariancePredictor(self.hidden_dim, self.hidden_dim, 19 | self.predictor_kernel_size, dropout=self.dropout) 20 | # Pitch 21 | self.pitch_predictor = VariancePredictor(self.hidden_dim, self.hidden_dim, self.predictor_kernel_size, 22 | dropout=self.dropout) 23 | self.pitch_embedding = VarianceEmbedding(1, self.hidden_dim, self.embedding_kernel_size, self.dropout) 24 | # Energy 25 | self.energy_predictor = VariancePredictor(self.hidden_dim, self.hidden_dim, self.predictor_kernel_size, 26 | dropout=self.dropout) 27 | self.energy_embedding = VarianceEmbedding(1, self.hidden_dim, self.embedding_kernel_size, self.dropout) 28 | # Phoneme 29 | self.ln = nn.LayerNorm(self.hidden_dim) 30 | 31 | # Length regulator 32 | self.length_regulator = LengthRegulator(self.hidden_dim, config.max_seq_len) 33 | 34 | def forward(self, x, src_mask, mel_len=None, mel_mask=None, 35 | duration_target=None, pitch_target=None, energy_target=None, max_len=None): 36 | # Duration 37 | log_duration_prediction = self.duration_predictor(x, src_mask) 38 | 39 | # Pitch & Energy 40 | pitch_prediction = self.pitch_predictor(x, src_mask) 41 | if pitch_target is not None: 42 | pitch_embedding = self.pitch_embedding(pitch_target.unsqueeze(-1)) 43 | else: 44 | pitch_embedding = self.pitch_embedding(pitch_prediction.unsqueeze(-1)) 45 | 46 | energy_prediction = self.energy_predictor(x, src_mask) 47 | if energy_target is not None: 48 | energy_embedding = self.energy_embedding(energy_target.unsqueeze(-1)) 49 | else: 50 | energy_embedding = self.energy_embedding(energy_prediction.unsqueeze(-1)) 51 | 52 | x = self.ln(x) + pitch_embedding + energy_embedding 53 | 54 | # Length regulate 55 | if 
duration_target is not None: 56 | output, pe, mel_len = self.length_regulator(x, duration_target, max_len) 57 | mel_mask = utils.get_mask_from_lengths(mel_len) 58 | else: 59 | duration_rounded = torch.clamp(torch.round(torch.exp(log_duration_prediction)-1.0), min=0) 60 | duration_rounded = duration_rounded.masked_fill(src_mask, 0).long() 61 | output, pe, mel_len = self.length_regulator(x, duration_rounded) 62 | mel_mask = utils.get_mask_from_lengths(mel_len) 63 | 64 | # Phoneme-wise positional encoding 65 | output = output + pe 66 | return output, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask 67 | 68 | 69 | class LengthRegulator(nn.Module): 70 | """ Length Regulator """ 71 | def __init__(self, hidden_size, max_pos): 72 | super(LengthRegulator, self).__init__() 73 | self.position_enc = nn.Parameter( 74 | get_sinusoid_encoding_table(max_pos+1, hidden_size), requires_grad=False) 75 | 76 | def LR(self, x, duration, max_len): 77 | output = list() 78 | position = list() 79 | mel_len = list() 80 | for batch, expand_target in zip(x, duration): 81 | expanded, pos = self.expand(batch, expand_target) 82 | output.append(expanded) 83 | position.append(pos) 84 | mel_len.append(expanded.shape[0]) 85 | 86 | if max_len is not None: 87 | output = utils.pad(output, max_len) 88 | position = utils.pad(position, max_len) 89 | else: 90 | output = utils.pad(output) 91 | position = utils.pad(position) 92 | return output, position, torch.LongTensor(mel_len).cuda() 93 | 94 | def expand(self, batch, predicted): 95 | out = list() 96 | pos = list() 97 | for i, vec in enumerate(batch): 98 | expand_size = predicted[i].item() 99 | out.append(vec.expand(int(expand_size), -1)) 100 | pos.append(self.position_enc[:expand_size, :]) 101 | out = torch.cat(out, 0) 102 | pos = torch.cat(pos, 0) 103 | return out, pos 104 | 105 | def forward(self, x, duration, max_len=None): 106 | output, position, mel_len = self.LR(x, duration, max_len) 107 | return output, position, mel_len 108 | 109 | 110 | class VariancePredictor(nn.Module): 111 | """ Variance Predictor """ 112 | def __init__(self, input_size, filter_size, kernel_size, output_size=1, n_layers=2, dropout=0.5): 113 | super(VariancePredictor, self).__init__() 114 | 115 | convs = [ConvNorm(input_size, filter_size, kernel_size)] 116 | for _ in range(n_layers-1): 117 | convs.append(ConvNorm(filter_size, filter_size, kernel_size)) 118 | self.convs = nn.ModuleList(convs) 119 | self.lns = nn.ModuleList([nn.LayerNorm(filter_size) for _ in range(n_layers)]) 120 | self.linear_layer = nn.Linear(filter_size, output_size) 121 | 122 | self.relu = nn.ReLU() 123 | self.dropout = nn.Dropout(dropout) 124 | 125 | def forward(self, x, mask): 126 | 127 | for conv, ln in zip(self.convs, self.lns): 128 | x = x.transpose(1,2) 129 | x = self.relu(conv(x)) 130 | x = x.transpose(1,2) 131 | x = ln(x) 132 | x = self.dropout(x) 133 | 134 | out = self.linear_layer(x) 135 | 136 | if mask is not None: 137 | out = out.masked_fill(mask.unsqueeze(-1), 0) 138 | return out.squeeze(-1) 139 | 140 | 141 | class VarianceEmbedding(nn.Module): 142 | """ Variance Embedding """ 143 | def __init__(self, input_size, embed_size, kernel_size, dropout): 144 | super(VarianceEmbedding, self).__init__() 145 | self.conv1 = ConvNorm(input_size, embed_size, kernel_size) 146 | self.conv2 = ConvNorm(embed_size, embed_size, kernel_size) 147 | self.fc = LinearNorm(embed_size, embed_size) 148 | 149 | self.relu = nn.ReLU() 150 | self.dropout = nn.Dropout(dropout) 151 | 152 | def forward(self, x): 153 | x = 
x.transpose(1,2) 154 | x = self.dropout(self.relu(self.conv1(x))) 155 | x = self.dropout(self.relu(self.conv2(x))) 156 | x = x.transpose(1,2) 157 | 158 | out = self.dropout(self.fc(x)) 159 | return out -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import models.Constants 2 | import models.Modules 3 | import models.StyleSpeech 4 | import models.VarianceAdaptor 5 | import models.Discriminators 6 | import models.Loss -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class ScheduledOptim(): 4 | ''' A simple wrapper class for learning rate scheduling ''' 5 | 6 | def __init__(self, optimizer, d_model, n_warmup_steps, current_steps): 7 | self._optimizer = optimizer 8 | self.n_warmup_steps = n_warmup_steps 9 | self.n_current_steps = current_steps 10 | self.init_lr = np.power(d_model, -0.5) 11 | 12 | def state_dict(self): 13 | return self._optimizer.state_dict() 14 | 15 | def step(self): 16 | self._optimizer.step() 17 | 18 | def step_and_update_lr(self, step=None): 19 | self._update_learning_rate(step=step) 20 | self._optimizer.step() 21 | 22 | def zero_grad(self): 23 | self._optimizer.zero_grad() 24 | 25 | def _get_lr_scale(self): 26 | return np.min([ 27 | np.power(self.n_current_steps, -0.5), 28 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 29 | 30 | def _update_learning_rate(self, step=None): 31 | ''' Learning rate scheduling per step ''' 32 | if step is None: 33 | self.n_current_steps += 1 34 | else: 35 | self.n_current_steps += step 36 | lr = self.init_lr * self._get_lr_scale() 37 | 38 | for param_group in self._optimizer.param_groups: 39 | param_group['lr'] = lr 40 | -------------------------------------------------------------------------------- /prepare_align.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import preprocessors.libritts as libritts 3 | 4 | def main(data_path, sr): 5 | libritts.prepare_align_and_resample(data_path, sr) 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--data_path', type=str, default='dataset/') 11 | parser.add_argument('--resample_rate', '-sr', type=int, default=16000) 12 | 13 | args = parser.parse_args() 14 | 15 | main(args.data_path, args.resample_rate) 16 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import preprocessors.libritts as libritts 5 | 6 | 7 | def make_train_files(out_dir, datas): 8 | random.shuffle(datas) 9 | num_train = int(len(datas)*0.95) 10 | train_set = datas[:num_train] 11 | val_set = datas[num_train:] 12 | with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f: 13 | for m in train_set: 14 | f.write(m + '\n') 15 | with open(os.path.join(out_dir, 'val.txt'), 'w', encoding='utf-8') as f: 16 | for m in val_set: 17 | f.write(m + '\n') 18 | 19 | 20 | def make_folders(out_dir): 21 | mel_out_dir = os.path.join(out_dir, "mel") 22 | if not os.path.exists(mel_out_dir): 23 | os.makedirs(mel_out_dir, exist_ok=True) 24 | ali_out_dir = os.path.join(out_dir, "alignment") 25 | if not os.path.exists(ali_out_dir): 26 | 
os.makedirs(ali_out_dir, exist_ok=True) 27 | f0_out_dir = os.path.join(out_dir, "f0") 28 | if not os.path.exists(f0_out_dir): 29 | os.makedirs(f0_out_dir, exist_ok=True) 30 | energy_out_dir = os.path.join(out_dir, "energy") 31 | if not os.path.exists(energy_out_dir): 32 | os.makedirs(energy_out_dir, exist_ok=True) 33 | 34 | 35 | def main(data_dir, out_dir): 36 | libritts.write_metadata(data_dir, out_dir) 37 | make_folders(out_dir) 38 | datas = libritts.build_from_path(data_dir, out_dir) 39 | make_train_files(out_dir, datas) 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--data_path', type=str, default='dataset/') 45 | parser.add_argument('--output_path', type=str, default='dataset/') 46 | args = parser.parse_args() 47 | 48 | main(args.data_path, args.output_path) 49 | -------------------------------------------------------------------------------- /preprocessors/libritts.py: -------------------------------------------------------------------------------- 1 | import audio as Audio 2 | from text import _clean_text 3 | import numpy as np 4 | import librosa 5 | import os 6 | from pathlib import Path 7 | from scipy.io.wavfile import write 8 | from joblib import Parallel, delayed 9 | import tgt 10 | import pyworld as pw 11 | from preprocessors.utils import remove_outlier, get_alignment, average_by_duration 12 | from scipy.interpolate import interp1d 13 | import json 14 | 15 | 16 | def write_single(output_folder, wav_fname, text, resample_rate, top_db=None): 17 | data, sample_rate = librosa.load(wav_fname, sr=None) 18 | # trim audio 19 | if top_db is not None: 20 | trimmed, _ = librosa.effects.trim(data, top_db=top_db) 21 | else: 22 | trimmed = data 23 | # resample audio 24 | resampled = librosa.resample(trimmed, sample_rate, resample_rate) 25 | y = (resampled * 32767.0).astype(np.int16) 26 | wav_fname = wav_fname.split('/')[-1] 27 | target_wav_fname = os.path.join(output_folder, wav_fname) 28 | target_txt_fname = os.path.join(output_folder, wav_fname.replace('.wav', '.txt')) 29 | if not os.path.exists(output_folder): 30 | os.makedirs(output_folder, exist_ok=True) 31 | 32 | write(target_wav_fname, resample_rate, y) 33 | with open(target_txt_fname, 'wt') as f: 34 | f.write(text) 35 | f.close() 36 | return y.shape[0] / float(resample_rate) 37 | 38 | 39 | def prepare_align_and_resample(data_dir, sr): 40 | wav_foder_names = ['train-clean-100', 'train-clean-360'] 41 | wavs = [] 42 | for wav_folder in wav_foder_names: 43 | wav_folder = os.path.join(data_dir, wav_folder) 44 | wav_fname_list = [str(f) for f in list(Path(wav_folder).rglob('*.wav'))] 45 | 46 | output_wavs_folder_name = 'wav{}'.format(sr//1000) 47 | output_wavs_folder = os.path.join(data_dir, output_wavs_folder_name) 48 | if not os.path.exists(output_wavs_folder): 49 | os.mkdir(output_wavs_folder) 50 | 51 | for wav_fname in wav_fname_list: 52 | _sid = wav_fname.split('/')[-3] 53 | output_folder = os.path.join(output_wavs_folder, _sid) 54 | txt_fname = wav_fname.replace('.wav','.normalized.txt') 55 | with open(txt_fname, 'r') as f: 56 | text = f.readline().strip() 57 | text = _clean_text(text, ['english_cleaners']) 58 | wavs.append((output_folder, wav_fname, text)) 59 | 60 | lengths = Parallel(n_jobs=10, verbose=1)( 61 | delayed(write_single)(wav[0], wav[1], wav[2], sr) for wav in wavs 62 | ) 63 | 64 | 65 | class Preprocessor: 66 | def __init__(self, config): 67 | self.config = config 68 | self.sampling_rate = config["sampling_rate"] 69 | 70 | self.n_mel_channels = 
config["n_mel_channels"] 71 | self.filter_length = config["filter_length"] 72 | self.hop_length = config["hop_length"] 73 | self.win_length = config["win_length"] 74 | self.max_wav_value = config["max_wav_value"] 75 | self.mel_fmin = config["mel_fmin"] 76 | self.mel_fmax = config["mel_fmax"] 77 | 78 | self.max_seq_len = config["max_seq_len"] 79 | 80 | self.STFT = Audio.stft.TacotronSTFT( 81 | config["preprocessing"]["stft"]["filter_length"], 82 | config["preprocessing"]["stft"]["hop_length"], 83 | config["preprocessing"]["stft"]["win_length"], 84 | config["preprocessing"]["mel"]["n_mel_channels"], 85 | config["preprocessing"]["audio"]["sampling_rate"], 86 | config["preprocessing"]["mel"]["mel_fmin"], 87 | config["preprocessing"]["mel"]["mel_fmax"], 88 | ) 89 | 90 | def write_metadata(self, data_dir, out_dir): 91 | metadata = os.path.join(out_dir, 'metadata.csv') 92 | if not os.path.exists(metadata): 93 | wav_fname_list = [str(f) for f in list(Path(data_dir).rglob('*.wav'))] 94 | lines = [] 95 | for wav_fname in wav_fname_list: 96 | basename = wav_fname.split('/')[-1].replace('.wav', '') 97 | sid = wav_fname.split('/')[-2] 98 | assert sid in basename 99 | txt_fname = wav_fname.replace('.wav', '.txt') 100 | with open(txt_fname, 'r') as f: 101 | text = f.readline().strip() 102 | f.close() 103 | lines.append('{}|{}|{}'.format(basename, text, sid)) 104 | with open(metadata, 'wt') as f: 105 | f.writelines('\n'.join(lines)) 106 | f.close() 107 | 108 | def build_from_path(self, data_dir, out_dir): 109 | datas = list() 110 | f0 = list() 111 | energy = list() 112 | n_frames = 0 113 | with open(os.path.join(out_dir, 'metadata.csv'), encoding='utf-8') as f: 114 | basenames = [] 115 | for line in f: 116 | parts = line.strip().split('|') 117 | basename = parts[0] 118 | basenames.append(basename) 119 | 120 | results = Parallel(n_jobs=10, verbose=1)( 121 | delayed(self.process_utterance)(data_dir, out_dir, basename) for basename in basenames 122 | ) 123 | results = [ r for r in results if r is not None ] 124 | for r in results: 125 | datas.append(r[0]) 126 | f0.extend(r[1]) 127 | energy.extend(r[2]) 128 | n_frames += r[3] 129 | 130 | f0 = remove_outlier(f0) 131 | energy = remove_outlier(energy) 132 | 133 | f0_max = np.max(f0) 134 | f0_min = np.min(f0) 135 | f0_mean = np.mean(f0) 136 | f0_std = np.std(f0) 137 | energy_max = np.max(energy) 138 | energy_min = np.min(energy) 139 | energy_mean = np.mean(energy) 140 | energy_std = np.std(energy) 141 | 142 | total_time = n_frames*self.hop_length/self.sampling_rate/3600 143 | f_json = { 144 | "total_time": total_time, 145 | "n_frames": n_frames, 146 | "f0_stat": [f0_max, f0_min, f0_mean, f0_std], 147 | "energy_state": [energy_max, energy_min, energy_mean, energy_std] 148 | } 149 | with open(os.path.join(out_dir, 'stats.json'), 'w') as f: 150 | json.dump(f_json, f) 151 | 152 | return datas 153 | 154 | 155 | def process_utterance(self, in_dir, out_dir, basename, dataset='libritts'): 156 | sid = basename.split('_')[0] 157 | wav_path = os.path.join(in_dir, 'wav{}'.format(self.sampling_rate//1000), sid, '{}.wav'.format(basename)) 158 | tg_path = os.path.join(out_dir, 'TextGrid', sid, '{}.TextGrid'.format(basename)) 159 | 160 | if not os.path.exists(wav_path) or not os.path.exists(tg_path): 161 | return None 162 | 163 | # Get alignments 164 | textgrid = tgt.io.read_textgrid(tg_path) 165 | phone, duration, start, end = get_alignment(textgrid.get_tier_by_name('phones'), self.sampling_rate, self.hop_length) 166 | text = '{'+ '}{'.join(phone) + '}' # '{A}{B}{$}{C}', $ represents 
silent phones 167 | text = text.replace('{$}', ' ') # '{A}{B} {C}' 168 | text = text.replace('}{', ' ') # '{A B} {C}' 169 | 170 | if start >= end: 171 | return None 172 | 173 | # Read and trim wav files 174 | wav, _ = librosa.load(wav_path, sr=None) 175 | wav = wav[int(self.sampling_rate*start):int(self.sampling_rate*end)].astype(np.float32) 176 | 177 | # Compute fundamental frequency 178 | _f0, t = pw.dio(wav.astype(np.float64), self.sampling_rate, frame_period=self.hop_length/self.sampling_rate*1000) 179 | f0 = pw.stonemask(wav.astype(np.float64), _f0, t, self.sampling_rate) 180 | f0 = f0[:sum(duration)] 181 | 182 | # Compute mel-scale spectrogram and energy 183 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT) 184 | mel_spectrogram = mel_spectrogram[:, :sum(duration)] 185 | energy = energy[:sum(duration)] 186 | 187 | if mel_spectrogram.shape[1] >= self.max_seq_len: 188 | return None 189 | 190 | # Pitch: linear interpolation over unvoiced frames 191 | nonzero_ids = np.where(f0 != 0)[0] 192 | if len(nonzero_ids)>=2: 193 | interp_fn = interp1d( 194 | nonzero_ids, 195 | f0[nonzero_ids], 196 | fill_value=(f0[nonzero_ids[0]], f0[nonzero_ids[-1]]), 197 | bounds_error=False, 198 | ) 199 | f0 = interp_fn(np.arange(0, len(f0))) 200 | # Pitch phoneme-level average 201 | f0 = average_by_duration(np.array(f0), np.array(duration)) 202 | 203 | # Energy phoneme-level average 204 | energy = average_by_duration(np.array(energy), np.array(duration)) 205 | 206 | if len([f for f in f0 if f != 0]) == 0 or len([e for e in energy if e != 0]) == 0: 207 | return None 208 | 209 | # Save alignment 210 | ali_filename = '{}-ali-{}.npy'.format(dataset, basename) 211 | np.save(os.path.join(out_dir, 'alignment', ali_filename), duration, allow_pickle=False) 212 | 213 | # Save fundamental frequency 214 | f0_filename = '{}-f0-{}.npy'.format(dataset, basename) 215 | np.save(os.path.join(out_dir, 'f0', f0_filename), f0, allow_pickle=False) 216 | 217 | # Save energy 218 | energy_filename = '{}-energy-{}.npy'.format(dataset, basename) 219 | np.save(os.path.join(out_dir, 'energy', energy_filename), energy, allow_pickle=False) 220 | 221 | # Save spectrogram 222 | mel_filename = '{}-mel-{}.npy'.format(dataset, basename) 223 | np.save(os.path.join(out_dir, 'mel', mel_filename), mel_spectrogram.T, allow_pickle=False) 224 | 225 | return '|'.join([basename, text, sid]), list(f0), list(energy), mel_spectrogram.shape[1] 226 | -------------------------------------------------------------------------------- /preprocessors/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_alignment(tier, sampling_rate, hop_length): 4 | sil_phones = ['sil', 'sp', 'spn', ''] 5 | phones = [] 6 | durations = [] 7 | start_time = 0 8 | end_time = 0 9 | end_idx = 0 10 | for t in tier._objects: 11 | s, e, p = t.start_time, t.end_time, t.text 12 | 13 | # Trimming leading silences 14 | if phones == []: 15 | if p in sil_phones: 16 | continue 17 | else: 18 | start_time = s 19 | if p not in sil_phones: 20 | phones.append(p) 21 | end_time = e 22 | end_idx = len(phones) 23 | else: 24 | phones.append(p) 25 | durations.append(round(e*sampling_rate/hop_length)-round(s*sampling_rate/hop_length)) 26 | 27 | # Trimming trailing silences 28 | phones = phones[:end_idx] 29 | durations = durations[:end_idx] 30 | 31 | return phones, durations, start_time, end_time 32 | 33 | 34 | def is_outlier(x, p25, p75): 35 | """Check if value is an outlier.""" 36 | lower = p25 - 1.5 * (p75 - p25) 37 | upper = p75 
+ 1.5 * (p75 - p25) 38 | return x <= lower or x >= upper 39 | 40 | 41 | def remove_outlier(x, p_bottom: int = 25, p_top: int = 75): 42 | """Remove outlier from x.""" 43 | x = np.array(x) 44 | p_bottom = np.percentile(x, p_bottom) 45 | p_top = np.percentile(x, p_top) 46 | indices_of_outliers = [] 47 | for ind, value in enumerate(x): 48 | if is_outlier(value, p_bottom, p_top): 49 | indices_of_outliers.append(ind) 50 | 51 | x[indices_of_outliers] = 0.0 52 | x[indices_of_outliers] = np.max(x) 53 | return x 54 | 55 | def average_by_duration(x, durs): 56 | length = sum(durs) 57 | durs_cum = np.cumsum(np.pad(durs, (1, 0), mode='constant')) 58 | 59 | # calculate phoneme-level f0/energy 60 | if len(x.shape) == 2: 61 | x_char = np.zeros((durs.shape[0], x.shape[1]), dtype=np.float32) 62 | else: 63 | x_char = np.zeros((durs.shape[0],), dtype=np.float32) 64 | for idx, start, end in zip(range(length), durs_cum[:-1], durs_cum[1:]): 65 | values = x[start:end][np.where(x[start:end] != 0.0)[0]] 66 | x_char[idx] = np.mean(values, axis=0) if len(values) > 0 else 0.0 # np.mean([]) = nan. 67 | 68 | return x_char.astype(np.float32) 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python==3.7.10 2 | pytorch==1.8.1 3 | scipy==1.4.1 4 | numpy==1.19.4 5 | librosa==0.8.1 6 | pyworld==0.2.11 7 | joblib==0.14.1 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import argparse 5 | import librosa 6 | import re 7 | import json 8 | from string import punctuation 9 | from g2p_en import G2p 10 | 11 | from models.StyleSpeech import StyleSpeech 12 | from text import text_to_sequence 13 | import audio as Audio 14 | import utils 15 | 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | def read_lexicon(lex_path): 21 | lexicon = {} 22 | with open(lex_path) as f: 23 | for line in f: 24 | temp = re.split(r"\s+", line.strip("\n")) 25 | word = temp[0] 26 | phones = temp[1:] 27 | if word.lower() not in lexicon: 28 | lexicon[word.lower()] = phones 29 | return lexicon 30 | 31 | 32 | def preprocess_english(text, lexicon_path): 33 | text = text.rstrip(punctuation) 34 | lexicon = read_lexicon(lexicon_path) 35 | 36 | g2p = G2p() 37 | phones = [] 38 | words = re.split(r"([,;.\-\?\!\s+])", text) 39 | for w in words: 40 | if w.lower() in lexicon: 41 | phones += lexicon[w.lower()] 42 | else: 43 | phones += list(filter(lambda p: p != " ", g2p(w))) 44 | phones = "{" + "}{".join(phones) + "}" 45 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones) 46 | phones = phones.replace("}{", " ") 47 | 48 | print("Raw Text Sequence: {}".format(text)) 49 | print("Phoneme Sequence: {}".format(phones)) 50 | sequence = np.array(text_to_sequence(phones, ['english_cleaners'])) 51 | 52 | return torch.from_numpy(sequence).to(device=device) 53 | 54 | 55 | def preprocess_audio(audio_file, _stft): 56 | wav, sample_rate = librosa.load(audio_file, sr=None) 57 | if sample_rate != 16000: 58 | wav = librosa.resample(wav, sample_rate, 16000) 59 | mel_spectrogram, _ = Audio.tools.get_mel_from_wav(wav, _stft) 60 | return torch.from_numpy(mel_spectrogram).to(device=device) 61 | 62 | 63 | def get_StyleSpeech(config, checkpoint_path): 64 | model = StyleSpeech(config).to(device=device) 65 | 
model.load_state_dict(torch.load(checkpoint_path)['model']) 66 | model.eval() 67 | return model 68 | 69 | 70 | def synthesize(args, model, _stft): 71 | # preprocess audio and text 72 | ref_mel = preprocess_audio(args.ref_audio, _stft).transpose(0,1).unsqueeze(0) 73 | src = preprocess_english(args.text, args.lexicon_path).unsqueeze(0) 74 | src_len = torch.from_numpy(np.array([src.shape[1]])).to(device=device) 75 | 76 | save_path = args.save_path 77 | if not os.path.exists(save_path): 78 | os.makedirs(save_path, exist_ok=True) 79 | 80 | # Extract style vector 81 | style_vector = model.get_style_vector(ref_mel) 82 | 83 | # Forward 84 | mel_output = model.inference(style_vector, src, src_len)[0] 85 | 86 | mel_ref_ = ref_mel.cpu().squeeze().transpose(0, 1).detach() 87 | mel_ = mel_output.cpu().squeeze().transpose(0, 1).detach() 88 | 89 | # plotting 90 | utils.plot_data([mel_ref_.numpy(), mel_.numpy()], 91 | ['Ref Spectrogram', 'Synthesized Spectrogram'], filename=os.path.join(save_path, 'plot.png')) 92 | print('Generate done!') 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--checkpoint_path", type=str, required=True, 98 | help="Path to the pretrained model") 99 | parser.add_argument('--config', default='configs/config.json') 100 | parser.add_argument("--save_path", type=str, default='results/') 101 | parser.add_argument("--ref_audio", type=str, required=True, 102 | help="path to an reference speech audio sample") 103 | parser.add_argument("--text", type=str, required=True, 104 | help="raw text to synthesize") 105 | parser.add_argument("--lexicon_path", type=str, default='lexicon/librispeech-lexicon.txt') 106 | args = parser.parse_args() 107 | 108 | with open(args.config) as f: 109 | data = f.read() 110 | json_config = json.loads(data) 111 | config = utils.AttrDict(json_config) 112 | 113 | # Get model 114 | model = get_StyleSpeech(config, args.checkpoint_path) 115 | print('model is prepared') 116 | 117 | _stft = Audio.stft.TacotronSTFT( 118 | config.filter_length, 119 | config.hop_length, 120 | config.win_length, 121 | config.n_mel_channels, 122 | config.sampling_rate, 123 | config.mel_fmin, 124 | config.mel_fmax) 125 | 126 | # Synthesize 127 | synthesize(args, model, _stft) 128 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | ''' 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | 34 | if not m: 35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 36 | break 37 | sequence += _symbols_to_sequence( 38 | _clean_text(m.group(1), cleaner_names)) 39 | sequence += _arpabet_to_sequence(m.group(2)) 40 | text = m.group(3) 41 | 42 | return sequence 43 | 44 | 45 | def sequence_to_text(sequence): 46 | '''Converts a sequence of IDs back to a string''' 47 | result = '' 48 | for symbol_id in sequence: 49 | if symbol_id in _id_to_symbol: 50 | s = _id_to_symbol[symbol_id] 51 | # Enclose ARPAbet back in curly braces: 52 | if len(s) > 1 and s[0] == '@': 53 | s = '{%s}' % s[1:] 54 | result += s 55 | return result.replace('}{', ' ') 56 | 57 | 58 | def _clean_text(text, cleaner_names): 59 | for name in cleaner_names: 60 | cleaner = getattr(cleaners, name) 61 | if not cleaner: 62 | raise Exception('Unknown cleaner: %s' % name) 63 | text = cleaner(text) 64 | return text 65 | 66 | 67 | def _symbols_to_sequence(symbols): 68 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 69 | 70 | 71 | def _arpabet_to_sequence(text): 72 | return _symbols_to_sequence(['@' + s for s in text.split()]) 73 | 74 | 75 | def _should_keep_symbol(s): 76 | return s in _symbol_to_id and s is not '_' and s is not '~' 77 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | _whitespace_re = re.compile(r'\s+') 21 | 22 | # List of (regular expression, replacement) pairs for abbreviations: 23 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 69 | text = lowercase(text) 70 | text = collapse_whitespace(text) 71 | return text 72 | 73 | 74 | def transliteration_cleaners(text): 75 | '''Pipeline for non-English text that transliterates to ASCII.''' 76 | text = convert_to_ascii(text) 77 | text = lowercase(text) 78 | text = collapse_whitespace(text) 79 | return text 80 | 81 | 82 | def english_cleaners(text): 83 | '''Pipeline for English text, including number and abbreviation expansion.''' 84 | text = convert_to_ascii(text) 85 | text = lowercase(text) 86 | text = expand_numbers(text) 87 | text = expand_abbreviations(text) 88 | text = collapse_whitespace(text) 89 | return text 90 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | 22 | def __init__(self, file_or_path, keep_ambiguous=True): 23 | if isinstance(file_or_path, str): 24 | with open(file_or_path, encoding='latin-1') as f: 25 | entries = _parse_cmudict(f) 26 | else: 27 | entries = _parse_cmudict(file_or_path) 28 | if not keep_ambiguous: 29 | entries = {word: pron for word, 30 | pron in entries.items() if len(pron) == 1} 31 | self._entries = entries 32 | 33 | def __len__(self): 34 | return len(self._entries) 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | _alt_re = re.compile(r'\([0-9]+\)') 42 | 43 | 44 | def _parse_cmudict(file): 45 | cmudict = {} 46 | for line in file: 47 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 48 | parts = line.split(' ') 49 | word = re.sub(_alt_re, '', parts[0]) 50 | pronunciation = _get_pronunciation(parts[1]) 51 | if pronunciation: 52 | if word in cmudict: 53 | cmudict[word].append(pronunciation) 54 | else: 55 | cmudict[word] = [pronunciation] 56 | return cmudict 57 | 58 | 59 | def _get_pronunciation(s): 60 | parts = s.strip().split(' ') 61 | for part in parts: 62 | if part not in _valid_symbol_set: 63 | return None 64 | return ' '.join(parts) 65 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, 
_remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | 8 | from text import cmudict 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? ' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | _silences = ['@sp', '@spn', '@sil'] 14 | 15 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 16 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 17 | 18 | # Export all symbols: 19 | symbols = [_pad] + list(_special) + list(_punctuation) + \ 20 | list(_letters) + _arpabet + _silences 21 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.tensorboard import SummaryWriter 4 | import argparse 5 | import os 6 | import json 7 | from models.StyleSpeech import StyleSpeech 8 | from dataloader import prepare_dataloader 9 | from optimizer import ScheduledOptim 10 | from evaluate import evaluate 11 | import utils 12 | 13 | def load_checkpoint(checkpoint_path, model, optimizer): 14 | assert os.path.isfile(checkpoint_path) 15 | print("Starting model from checkpoint '{}'".format(checkpoint_path)) 16 | checkpoint_dict = torch.load(checkpoint_path) 17 | if 'model' in checkpoint_dict: 18 | model.load_state_dict(checkpoint_dict['model']) 19 | print('Model is loaded!') 20 | if 'optimizer' in checkpoint_dict: 21 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 22 | print('Optimizer is loaded!') 23 | current_step = checkpoint_dict['step'] + 1 24 | return model, optimizer, current_step 25 | 26 | 27 | def main(args, c): 28 | 29 | # Define model 30 | model = StyleSpeech(c).cuda() 31 | print("StyleSpeech Has Been Defined") 32 | num_param = utils.get_param_num(model) 33 | print('Number of StyleSpeech Parameters:', num_param) 34 | with open(os.path.join(args.save_path, "model.txt"), "w") as f_log: 35 | f_log.write(str(model)) 36 | 37 | # Optimizer 38 | optimizer = torch.optim.Adam(model.parameters(), betas=c.betas, eps=c.eps) 39 | # Loss 40 | Loss = model.get_criterion() 41 | print("Optimizer and Loss Function Defined.") 42 | 43 | # Get dataset 44 | data_loader = prepare_dataloader(args.data_path, "train.txt", shuffle=True, batch_size=c.batch_size) 45 | print("Data Loader is Prepared.") 46 | 47 | # Load checkpoint if exists 48 | if args.checkpoint_path is not None: 49 | assert os.path.exists(args.checkpoint_path) 50 | model, optimizer, current_step= load_checkpoint(args.checkpoint_path, model, optimizer) 51 | print("\n---Model Restored at Step {}---\n".format(current_step)) 52 | else: 53 | print("\n---Start New Training---\n") 54 | current_step = 0 55 | checkpoint_path = 
os.path.join(args.save_path, 'ckpt') 56 | os.makedirs(checkpoint_path, exist_ok=True) 57 | 58 | # Scheduled optimizer 59 | scheduled_optim = ScheduledOptim(optimizer, c.decoder_hidden, c.n_warm_up_step, current_step) 60 | 61 | # Init logger 62 | log_path = os.path.join(args.save_path, 'log') 63 | logger = SummaryWriter(os.path.join(log_path, 'board')) 64 | with open(os.path.join(log_path, "log.txt"), "a") as f_log: 65 | f_log.write("Dataset :{}\n Number of Parameters: {}\n".format(c.dataset, num_param)) 66 | 67 | # Init synthesis directory 68 | synth_path = os.path.join(args.save_path, 'synth') 69 | os.makedirs(synth_path, exist_ok=True) 70 | 71 | # Training 72 | model.train() 73 | while current_step < args.max_iter: 74 | # Get Training Loader 75 | for idx, batch in enumerate(data_loader): 76 | 77 | if current_step == args.max_iter: 78 | break 79 | 80 | # Get Data 81 | sid, text, mel_target, D, log_D, f0, energy, \ 82 | src_len, mel_len, max_src_len, max_mel_len = model.parse_batch(batch) 83 | 84 | # Forward 85 | scheduled_optim.zero_grad() 86 | mel_output, src_output, style_vector, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model( 87 | text, src_len, mel_target, mel_len, D, f0, energy, max_src_len, max_mel_len) 88 | 89 | mel_loss, d_loss, f_loss, e_loss = Loss(mel_output, mel_target, 90 | log_duration_output, log_D, f0_output, f0, energy_output, energy, src_len, mel_len) 91 | 92 | # Total loss 93 | total_loss = mel_loss + d_loss + f_loss + e_loss 94 | # Backward 95 | total_loss.backward() 96 | # Clipping gradients to avoid gradient explosion 97 | nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip_thresh) 98 | # Update weights 99 | scheduled_optim.step_and_update_lr() 100 | 101 | # Print log 102 | if current_step % args.log_step == 0 and current_step != 0: 103 | t_l = total_loss.item() 104 | m_l = mel_loss.item() 105 | d_l = d_loss.item() 106 | f_l = f_loss.item() 107 | e_l = e_loss.item() 108 | 109 | str1 = "Step [{}/{}]:".format(current_step, args.max_iter) 110 | str2 = "Total Loss: {:.4f}\nMel Loss: {:.4f},\n" \ 111 | "Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f} ;" \ 112 | .format(t_l, m_l, d_l, f_l, e_l) 113 | print(str1 + "\n" + str2 +"\n") 114 | with open(os.path.join(log_path, "log.txt"), "a") as f_log: 115 | f_log.write(str1 + "\n" + str2 +"\n") 116 | 117 | logger.add_scalar('Train/total_loss', t_l, current_step) 118 | logger.add_scalar('Train/mel_loss', m_l, current_step) 119 | logger.add_scalar('Train/duration_loss', d_l, current_step) 120 | logger.add_scalar('Train/f0_loss', f_l, current_step) 121 | logger.add_scalar('Train/energy_loss', e_l, current_step) 122 | 123 | # Save Checkpoint 124 | if current_step % args.save_step == 0 and current_step != 0: 125 | torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'step': current_step}, 126 | os.path.join(checkpoint_path, 'checkpoint_{}.pth.tar'.format(current_step))) 127 | print("*** Save Checkpoint ***") 128 | print("Save model at step {}...\n".format(current_step)) 129 | 130 | if current_step % args.synth_step == 0 and current_step != 0: 131 | length = mel_len[0].item() 132 | mel_target = mel_target[0, :length].detach().cpu().transpose(0, 1) 133 | mel = mel_output[0, :length].detach().cpu().transpose(0, 1) 134 | # plotting 135 | utils.plot_data([mel.numpy(), mel_target.numpy()], 136 | ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(synth_path, 'step_{}.png'.format(current_step))) 137 | print("Synth spectrograms at step 
{}...\n".format(current_step)) 138 | 139 | if current_step % args.eval_step == 0 and current_step != 0: 140 | model.eval() 141 | with torch.no_grad(): 142 | m_l, d_l, f_l, e_l = evaluate(args, model, current_step) 143 | str_v = "*** Validation ***\n" \ 144 | "StyleSpeech Step {},\n" \ 145 | "Mel Loss: {}\nDuration Loss:{}\nF0 Loss: {}\nEnergy Loss: {}" \ 146 | .format(current_step, m_l, d_l, f_l, e_l) 147 | print(str_v + "\n" ) 148 | with open(os.path.join(log_path, "eval.txt"), "a") as f_log: 149 | f_log.write(str_v + "\n") 150 | logger.add_scalar('Validation/mel_loss', m_l, current_step) 151 | logger.add_scalar('Validation/duration_loss', d_l, current_step) 152 | logger.add_scalar('Validation/f0_loss', f_l, current_step) 153 | logger.add_scalar('Validation/energy_loss', e_l, current_step) 154 | model.train() 155 | 156 | current_step += 1 157 | 158 | print("Training Done at Step : {}".format(current_step)) 159 | torch.save({'model': model.state_dict(), 'optimizer': scheduled_optim.state_dict(), 'step': current_step}, 160 | os.path.join(checkpoint_path, 'checkpoint_last_{}.pth.tar'.format(current_step))) 161 | 162 | 163 | if __name__ == "__main__": 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('--data_path', default='dataset/LibriTTS/preprocessed') 166 | parser.add_argument('--save_path', default='exp_stylespeech') 167 | parser.add_argument('--config', default='configs/config.json') 168 | parser.add_argument('--max_iter', default=100000, type=int) 169 | parser.add_argument('--save_step', default=5000, type=int) 170 | parser.add_argument('--synth_step', default=1000, type=int) 171 | parser.add_argument('--eval_step', default=5000, type=int) 172 | parser.add_argument('--log_step', default=100, type=int) 173 | parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to the pretrained model') 174 | 175 | args = parser.parse_args() 176 | 177 | torch.backends.cudnn.enabled = True 178 | 179 | with open(args.config) as f: 180 | data = f.read() 181 | json_config = json.loads(data) 182 | config = utils.AttrDict(json_config) 183 | utils.build_env(args.config, 'config.json', args.save_path) 184 | 185 | main(args, config) 186 | -------------------------------------------------------------------------------- /train_meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | from torch.utils.tensorboard import SummaryWriter 5 | import torch.multiprocessing as mp 6 | import torch.distributed as dist 7 | import torch.utils.data.distributed 8 | 9 | import argparse 10 | import os 11 | import json 12 | 13 | from models.StyleSpeech import StyleSpeech 14 | from models.Discriminators import Discriminator 15 | from dataloader import prepare_dataloader 16 | from optimizer import ScheduledOptim 17 | from evaluate import evaluate 18 | import utils 19 | 20 | def load_checkpoint(checkpoint_path, model, discriminator, G_optim, D_optim, rank, distributed=False): 21 | assert os.path.isfile(checkpoint_path) 22 | print("Starting model from checkpoint '{}'".format(checkpoint_path)) 23 | checkpoint_dict = torch.load(checkpoint_path, map_location='cuda:{}'.format(rank)) 24 | if 'model' in checkpoint_dict: 25 | if distributed: 26 | state_dict = {} 27 | for k,v in checkpoint_dict['model'].items(): 28 | state_dict['module.{}'.format(k)] = v 29 | model.load_state_dict(state_dict) 30 | else: 31 | model.load_state_dict(checkpoint_dict['model']) 32 | print('Model is loaded!') 33 | 
if 'discriminator' in checkpoint_dict: 34 | if distributed: 35 | state_dict = {} 36 | for k,v in checkpoint_dict['discriminator'].items(): 37 | state_dict['module.{}'.format(k)] = v 38 | discriminator.load_state_dict(state_dict) 39 | else: 40 | discriminator.load_state_dict(checkpoint_dict['discriminator']) 41 | print('Discriminator is loaded!') 42 | if 'G_optim' in checkpoint_dict or 'optimizer' in checkpoint_dict: 43 | if 'optimizer' in checkpoint_dict: 44 | G_optim.load_state_dict(checkpoint_dict['optimizer']) 45 | if 'G_optim' in checkpoint_dict: 46 | G_optim.load_state_dict(checkpoint_dict['G_optim']) 47 | print('G_optim is loaded!') 48 | if 'D_optim' in checkpoint_dict: 49 | D_optim.load_state_dict(checkpoint_dict['D_optim']) 50 | print('D_optim is loaded!') 51 | current_step = checkpoint_dict['step'] + 1 52 | del checkpoint_dict 53 | return model, discriminator, G_optim, D_optim, current_step 54 | 55 | 56 | def main(rank, args, c): 57 | 58 | print('Use GPU: {} for training'.format(rank)) 59 | 60 | ngpus = args.ngpus 61 | if args.distributed: 62 | torch.cuda.set_device(rank % ngpus) 63 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=rank) 64 | 65 | # Define model & loss 66 | model = StyleSpeech(c).cuda() 67 | discriminator = Discriminator(c).cuda() 68 | num_param = utils.get_param_num(model) 69 | D_num_param = utils.get_param_num(discriminator) 70 | 71 | if rank==0: 72 | print('Number of Meta-StyleSpeech Parameters:', num_param) 73 | print('Number of Discriminator Parameters:', D_num_param) 74 | with open(os.path.join(args.save_path, "model.txt"), "w") as f_log: 75 | f_log.write(str(model)) 76 | f_log.write(str(discriminator)) 77 | print("Model Has Been Defined") 78 | 79 | model_without_ddp = model 80 | discriminator_without_ddp = discriminator 81 | if args.distributed: 82 | c.meta_batch_size = c.meta_batch_size // ngpus 83 | model = nn.parallel.DistributedDataParallel(model, device_ids=[rank]) 84 | model_without_ddp = model.module 85 | discriminator = nn.parallel.DistributedDataParallel(discriminator, device_ids=[rank]) 86 | discriminator_without_ddp = discriminator.module 87 | 88 | # Optimizer 89 | G_optim = torch.optim.Adam(model.parameters(), betas=c.betas, eps=c.eps) 90 | D_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=c.betas, eps=c.eps) 91 | # Loss 92 | Loss = model_without_ddp.get_criterion() 93 | adversarial_loss = discriminator_without_ddp.get_criterion() 94 | print("Optimizer and Loss Function Defined.") 95 | 96 | # Get dataset 97 | data_loader = prepare_dataloader(args.data_path, "train.txt", batch_size=c.meta_batch_size, meta_learning=True, seed=rank) 98 | print("Data Loader is Prepared") 99 | 100 | # Load checkpoint if exists 101 | if args.checkpoint_path is not None: 102 | assert os.path.exists(args.checkpoint_path) 103 | model, discriminator, G_optim, D_optim, current_step = load_checkpoint( 104 | args.checkpoint_path, model, discriminator, G_optim, D_optim, rank, args.distributed) 105 | print("\n---Model Restored at Step {}---\n".format(current_step)) 106 | else: 107 | print("\n---Start New Training---\n") 108 | current_step = 0 109 | if rank == 0: 110 | checkpoint_path = os.path.join(args.save_path, 'ckpt') 111 | os.makedirs(checkpoint_path, exist_ok=True) 112 | 113 | # scheduled optimizer 114 | G_optim = ScheduledOptim(G_optim, c.decoder_hidden, c.n_warm_up_step, current_step) 115 | 116 | # Init logger 117 | if rank == 0: 118 | log_path = os.path.join(args.save_path, 'log') 119 | 
logger = SummaryWriter(os.path.join(log_path, 'board')) 120 | with open(os.path.join(log_path, "log.txt"), "a") as f_log: 121 | f_log.write("Dataset :{}\n Number of Parameters: {}\n".format(c.dataset, num_param)) 122 | 123 | # Init synthesis directory 124 | if rank == 0: 125 | synth_path = os.path.join(args.save_path, 'synth') 126 | os.makedirs(synth_path, exist_ok=True) 127 | 128 | model.train() 129 | while current_step < args.max_iter: 130 | # Get Training Loader 131 | for idx, batch in enumerate(data_loader): 132 | 133 | if current_step == args.max_iter: 134 | break 135 | 136 | losses = {} 137 | #### Generator #### 138 | G_optim.zero_grad() 139 | # Get Support Data 140 | sid, text, mel_target, D, log_D, f0, energy, \ 141 | src_len, mel_len, max_src_len, max_mel_len = model_without_ddp.parse_batch(batch) 142 | 143 | # Support Forward 144 | mel_output, src_output, style_vector, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model( 145 | text, src_len, mel_target, mel_len, D, f0, energy, max_src_len, max_mel_len) 146 | src_target, _, _ = model_without_ddp.variance_adaptor.length_regulator(src_output, D) 147 | 148 | # Reconstruction loss 149 | mel_loss, d_loss, f_loss, e_loss = Loss(mel_output, mel_target, 150 | log_duration_output, log_D, f0_output, f0, energy_output, energy, src_len, mel_len) 151 | losses['G_recon'] = mel_loss 152 | losses['d_loss'] = d_loss 153 | losses['f_loss'] = f_loss 154 | losses['e_loss'] = e_loss 155 | 156 | 157 | #### META LEARNING #### 158 | # Get query text 159 | B = mel_target.shape[0] 160 | perm_idx = torch.randperm(B) 161 | q_text, q_src_len = text[perm_idx], src_len[perm_idx] 162 | # Generate query speech 163 | q_mel_output, q_src_output, q_log_duration_output, \ 164 | _, _, q_src_mask, q_mel_mask, q_mel_len = model_without_ddp.inference(style_vector, q_text, q_src_len) 165 | # Legulate length of query src 166 | q_duration_rounded = torch.clamp(torch.round(torch.exp(q_log_duration_output.detach())-1.), min=0) 167 | q_duration = q_duration_rounded.masked_fill(q_src_mask, 0).long() 168 | q_src, _, _ = model_without_ddp.variance_adaptor.length_regulator(q_src_output, q_duration) 169 | # Adverserial loss 170 | t_val, s_val, _= discriminator(q_mel_output, q_src, None, sid, q_mel_mask) 171 | losses['G_GAN_query_t'] = adversarial_loss(t_val, is_real=True) 172 | losses['G_GAN_query_s'] = adversarial_loss(s_val, is_real=True) 173 | 174 | # Total generator loss 175 | alpha = 10.0 176 | G_loss = alpha*losses['G_recon'] + losses['d_loss'] + losses['f_loss'] + losses['e_loss'] + \ 177 | losses['G_GAN_query_t'] + losses['G_GAN_query_s'] 178 | # Backward loss 179 | G_loss.backward() 180 | # Update weights 181 | G_optim.step_and_update_lr() 182 | 183 | 184 | #### Discriminator #### 185 | D_optim.zero_grad() 186 | # Real 187 | real_t_pred, real_s_pred, cls_loss = discriminator( 188 | mel_target, src_target.detach(), style_vector.detach(), sid, mask=mel_mask) 189 | # Fake 190 | fake_t_pred, fake_s_pred, _ = discriminator( 191 | q_mel_output.detach(), q_src.detach(), None, sid, mask=q_mel_mask) 192 | losses['D_t_loss'] = adversarial_loss(real_t_pred, is_real=True) + adversarial_loss(fake_t_pred, is_real=False) 193 | losses['D_s_loss'] = adversarial_loss(real_s_pred, is_real=True) + adversarial_loss(fake_s_pred, is_real=False) 194 | losses['cls_loss'] = cls_loss 195 | # Total discriminator Loss 196 | D_loss = losses['D_t_loss'] + losses['D_s_loss'] + losses['cls_loss'] 197 | # Backward 198 | D_loss.backward() 199 | # Update weights 200 | D_optim.step() 
201 | 202 | # Print log 203 | if current_step % args.log_step == 0 and current_step != 0 and rank == 0 : 204 | m_l = losses['G_recon'].item() 205 | d_l = losses['d_loss'].item() 206 | f_l = losses['f_loss'].item() 207 | e_l = losses['e_loss'].item() 208 | g_t_l = losses['G_GAN_query_t'].item() 209 | g_s_l = losses['G_GAN_query_s'].item() 210 | d_t_l = losses['D_t_loss'].item() / 2 211 | d_s_l = losses['D_s_loss'].item() / 2 212 | cls_l = losses['cls_loss'].item() 213 | 214 | str1 = "Step [{}/{}]:".format(current_step, args.max_iter) 215 | str2 = "Mel Loss: {:.4f},\n" \ 216 | "Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f}\n" \ 217 | "T G Loss: {:.4f}, T D Loss: {:.4f}, S G Loss: {:.4f}, S D Loss: {:.4f} \n" \ 218 | "cls_Loss: {:.4f};" \ 219 | .format(m_l, d_l, f_l, e_l, g_t_l, d_t_l, g_s_l, d_s_l, cls_l) 220 | print(str1 + "\n" + str2 +"\n") 221 | with open(os.path.join(log_path, "log.txt"), "a") as f_log: 222 | f_log.write(str1 + "\n" + str2 +"\n") 223 | 224 | logger.add_scalar('Train/mel_loss', m_l, current_step) 225 | logger.add_scalar('Train/duration_loss', d_l, current_step) 226 | logger.add_scalar('Train/f0_loss', f_l, current_step) 227 | logger.add_scalar('Train/energy_loss', e_l, current_step) 228 | logger.add_scalar('Train/G_t_loss', g_t_l, current_step) 229 | logger.add_scalar('Train/D_t_loss', d_t_l, current_step) 230 | logger.add_scalar('Train/G_s_loss', g_s_l, current_step) 231 | logger.add_scalar('Train/D_s_loss', d_s_l, current_step) 232 | logger.add_scalar('Train/cls_loss', cls_l, current_step) 233 | 234 | # Save Checkpoint 235 | if current_step % args.save_step == 0 and current_step != 0 and rank == 0: 236 | torch.save({'model': model_without_ddp.state_dict(), 237 | 'discriminator': discriminator_without_ddp.state_dict(), 238 | 'G_optim': G_optim.state_dict(),'D_optim': D_optim.state_dict(), 239 | 'step': current_step}, 240 | os.path.join(checkpoint_path, 'checkpoint_{}.pth.tar'.format(current_step))) 241 | print("*** Save Checkpoint ***") 242 | print("Save model at step {}...\n".format(current_step)) 243 | 244 | if current_step % args.synth_step == 0 and current_step != 0 and rank == 0: 245 | length = mel_len[0].item() 246 | mel_target = mel_target[0, :length].detach().cpu().transpose(0, 1) 247 | mel = mel_output[0, :length].detach().cpu().transpose(0, 1) 248 | q_length = q_mel_len[0].item() 249 | q_mel = q_mel_output[0, :q_length].detach().cpu().transpose(0, 1) 250 | # plotting 251 | utils.plot_data([q_mel.numpy(), mel.numpy(), mel_target.numpy()], 252 | ['Query Spectrogram', 'Recon Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(synth_path, 'step_{}.png'.format(current_step))) 253 | print("Synth audios at step {}...\n".format(current_step)) 254 | 255 | # Evaluate 256 | if current_step % args.eval_step == 0 and current_step != 0 and rank == 0: 257 | model.eval() 258 | with torch.no_grad(): 259 | m_l, d_l, f_l, e_l = evaluate(args, model_without_ddp, current_step) 260 | str_v = "*** Validation ***\n" \ 261 | "Meta-StyleSpeech Step {},\n" \ 262 | "Mel Loss: {}\nDuration Loss:{}\nF0 Loss: {}\nEnergy Loss: {}" \ 263 | .format(current_step, m_l, d_l, f_l, e_l) 264 | print(str_v + "\n" ) 265 | with open(os.path.join(log_path, "eval.txt"), "a") as f_log: 266 | f_log.write(str_v + "\n") 267 | logger.add_scalar('Validation/mel_loss', m_l, current_step) 268 | logger.add_scalar('Validation/duration_loss', d_l, current_step) 269 | logger.add_scalar('Validation/f0_loss', f_l, current_step) 270 | logger.add_scalar('Validation/energy_loss', e_l, 
current_step) 271 | model.train() 272 | 273 | current_step += 1 274 | 275 | if rank == 0: 276 | print("Training Done at Step : {}".format(current_step)) 277 | torch.save({'model': model_without_ddp.state_dict(), 278 | 'discriminator': discriminator_without_ddp.state_dict(), 279 | 'G_optim': G_optim.state_dict(), 'D_optim': D_optim.state_dict(), 280 | 'step': current_step}, 281 | os.path.join(checkpoint_path, 'checkpoint_last_{}.pth.tar'.format(current_step))) 282 | 283 | 284 | if __name__ == "__main__": 285 | parser = argparse.ArgumentParser() 286 | parser.add_argument('--data_path', default='dataset/LibriTTS/preprocessed') 287 | parser.add_argument('--save_path', default='exp_meta_stylespeech') 288 | parser.add_argument('--config', default='configs/config.json') 289 | parser.add_argument('--max_iter', default=100000, type=int) 290 | parser.add_argument('--save_step', default=5000, type=int) 291 | parser.add_argument('--synth_step', default=1000, type=int) 292 | parser.add_argument('--eval_step', default=5000, type=int) 293 | parser.add_argument('--log_step', default=100, type=int) 294 | parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pretrained model') 295 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:3456', type=str, help='url for setting up distributed training') 296 | parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training') 297 | parser.add_argument('--rank', default=-1, type=int, help='distributed backend') 298 | parser.add_argument('--dist-backend', default='nccl', type=str, help='node rank for distributed training') 299 | 300 | args = parser.parse_args() 301 | 302 | torch.backends.cudnn.enabled = True 303 | 304 | with open(args.config) as f: 305 | data = f.read() 306 | json_config = json.loads(data) 307 | config = utils.AttrDict(json_config) 308 | utils.build_env(args.config, 'config.json', args.save_path) 309 | 310 | ngpus = torch.cuda.device_count() 311 | args.ngpus = ngpus 312 | args.distributed = ngpus > 1 313 | 314 | if args.distributed: 315 | args.world_size = ngpus 316 | mp.spawn(main, nprocs=ngpus, args=(args, config)) 317 | else: 318 | main(0, args, config) 319 | 320 | 321 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import matplotlib 7 | matplotlib.use("Agg") 8 | from matplotlib import pyplot as plt 9 | 10 | 11 | def process_meta(meta_path): 12 | with open(meta_path, "r", encoding="utf-8") as f: 13 | text = [] 14 | name = [] 15 | sid = [] 16 | for line in f.readlines(): 17 | n, t, s = line.strip('\n').split('|') 18 | name.append(n) 19 | text.append(t) 20 | sid.append(s) 21 | return name, text, sid 22 | 23 | 24 | def get_param_num(model): 25 | num_param = sum(param.numel() for param in model.parameters()) 26 | return num_param 27 | 28 | 29 | def plot_data(data, titles=None, filename=None): 30 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 31 | fig.tight_layout() 32 | if titles is None: 33 | titles = [None for i in range(len(data))] 34 | for i in range(len(data)): 35 | spectrogram = data[i] 36 | axes[i][0].imshow(spectrogram, origin='lower') 37 | axes[i][0].set_aspect(2.5, adjustable='box') 38 | axes[i][0].set_ylim(0, 80) 39 | axes[i][0].set_title(titles[i], fontsize='medium') 40 | axes[i][0].tick_params(labelsize='x-small', 
left=False, labelleft=False) 41 | axes[i][0].set_anchor('W') 42 | 43 | plt.savefig(filename, dpi=200) 44 | plt.close() 45 | 46 | 47 | def get_mask_from_lengths(lengths, max_len=None): 48 | batch_size = lengths.shape[0] 49 | if max_len is None: 50 | max_len = torch.max(lengths).item() 51 | 52 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).cuda() 53 | mask = (ids >= lengths.unsqueeze(1).expand(-1, max_len)) 54 | return mask 55 | 56 | 57 | def pad_1D(inputs, PAD=0): 58 | 59 | def pad_data(x, length, PAD): 60 | x_padded = np.pad(x, (0, length - x.shape[0]), 61 | mode='constant', 62 | constant_values=PAD) 63 | return x_padded 64 | 65 | max_len = max((len(x) for x in inputs)) 66 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 67 | return padded 68 | 69 | 70 | def pad_2D(inputs, maxlen=None): 71 | 72 | def pad(x, max_len): 73 | PAD = 0 74 | if np.shape(x)[0] > max_len: 75 | raise ValueError("not max_len") 76 | 77 | s = np.shape(x)[1] 78 | x_padded = np.pad(x, (0, max_len - np.shape(x)[0]), 79 | mode='constant', 80 | constant_values=PAD) 81 | return x_padded[:, :s] 82 | 83 | if maxlen: 84 | output = np.stack([pad(x, maxlen) for x in inputs]) 85 | else: 86 | max_len = max(np.shape(x)[0] for x in inputs) 87 | output = np.stack([pad(x, max_len) for x in inputs]) 88 | return output 89 | 90 | 91 | def pad(input_ele, mel_max_length=None): 92 | if mel_max_length: 93 | max_len = mel_max_length 94 | else: 95 | max_len = max([input_ele[i].size(0)for i in range(len(input_ele))]) 96 | 97 | out_list = list() 98 | for i, batch in enumerate(input_ele): 99 | if len(batch.shape) == 1: 100 | one_batch_padded = F.pad( 101 | batch, (0, max_len-batch.size(0)), "constant", 0.0) 102 | elif len(batch.shape) == 2: 103 | one_batch_padded = F.pad( 104 | batch, (0, 0, 0, max_len-batch.size(0)), "constant", 0.0) 105 | out_list.append(one_batch_padded) 106 | out_padded = torch.stack(out_list) 107 | return out_padded 108 | 109 | 110 | class AttrDict(dict): 111 | def __init__(self, *args, **kwargs): 112 | super(AttrDict, self).__init__(*args, **kwargs) 113 | self.__dict__ = self 114 | 115 | 116 | def build_env(config, config_name, path): 117 | t_path = os.path.join(path, config_name) 118 | if config != t_path: 119 | os.makedirs(path, exist_ok=True) 120 | shutil.copyfile(config, t_path) 121 | --------------------------------------------------------------------------------
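
Supplementary note (not part of the repository source): `optimizer.py` implements the Transformer-style warm-up learning-rate schedule, and `VarianceAdaptor.py` converts predicted log-durations into integer frame counts before length regulation. The sketch below restates the two rules as stand-alone functions so they can be sanity-checked in isolation; the names `noam_lr` and `round_durations` and the example values (`d_model=256`, `n_warmup_steps=4000`) are illustrative assumptions, not identifiers or settings taken from this codebase.
```python
import numpy as np
import torch


def noam_lr(step, d_model, n_warmup_steps):
    # Same rule as ScheduledOptim: d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    init_lr = np.power(d_model, -0.5)
    return init_lr * min(np.power(step, -0.5), np.power(n_warmup_steps, -1.5) * step)


def round_durations(log_duration_prediction, src_mask):
    # Same rounding as the VarianceAdaptor at inference time:
    # frames = clamp(round(exp(log_d) - 1), min=0), with padded positions zeroed.
    d = torch.clamp(torch.round(torch.exp(log_duration_prediction) - 1.0), min=0)
    return d.masked_fill(src_mask, 0).long()


if __name__ == "__main__":
    # The LR grows linearly over the first n_warmup_steps updates, then decays as step^-0.5.
    # d_model=256 and n_warmup_steps=4000 are example values, not read from configs/config.json.
    for step in [1, 1000, 4000, 8000, 100000]:
        print(step, noam_lr(step, d_model=256, n_warmup_steps=4000))

    # Three phonemes; the last position is padding and gets zero frames.
    log_d = torch.tensor([[1.2, 0.3, 2.0]])
    pad_mask = torch.tensor([[False, False, True]])
    print(round_durations(log_d, pad_mask))  # tensor([[2, 0, 0]])
```
Under this schedule the learning rate peaks at roughly `d_model^-0.5 * n_warmup_steps^-0.5` at the end of warm-up, which is why `train.py` and `train_meta.py` pass the restored `current_step` into `ScheduledOptim` when resuming from a checkpoint, so the decay continues from the correct point.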