├── .gitignore ├── LICENSE ├── README.md ├── audio ├── __init__.py ├── audio_processing.py ├── stft.py └── tools.py ├── config ├── AIHub-MMV │ ├── model.yaml │ ├── preprocess.yaml │ └── train.yaml ├── IEMOCAP │ ├── model.yaml │ ├── preprocess.yaml │ └── train.yaml └── README.md ├── dataset.py ├── evaluate.py ├── hifigan ├── LICENSE ├── __init__.py ├── config.json ├── generator_LJSpeech.pth.tar.zip ├── generator_universal.pth.tar.zip └── models.py ├── img ├── emotional-fastspeech2-audios.png ├── emotional-fastspeech2-images.png ├── emotional-fastspeech2-scalars.png ├── model.png ├── model_conversational_tts.png ├── model_conversational_tts_chat_history.png └── model_emotional_tts.png ├── model ├── __init__.py ├── fastspeech2.py ├── loss.py ├── modules.py └── optimizer.py ├── preparation ├── aihub_mmv.py └── iemocap.py ├── prepare_align.py ├── prepare_data.py ├── preprocess.py ├── preprocessor ├── aihub_mmv.py ├── iemocap.py └── preprocessor.py ├── requirements.txt ├── synthesize.py ├── text ├── __init__.py ├── cleaners.py ├── cmudict.py ├── korean.py ├── korean_dict.py ├── numbers.py ├── pinyin.py └── symbols.py ├── train.py ├── transformer ├── Constants.py ├── Layers.py ├── Models.py ├── Modules.py ├── SubLayers.py └── __init__.py └── utils ├── model.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | __pycache__ 107 | .vscode 108 | .DS_Store 109 | 110 | # nohup 111 | *.out 112 | 113 | # data, checkpoint, and models 114 | raw_data/ 115 | output/ 116 | *.npy 117 | TextGrid/ 118 | hifigan/*.pth.tar 119 | lexicon/ 120 | montreal-forced-aligner/ 121 | preprocessed_data/ 122 | preparation/*.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Keon Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | MIT License 24 | 25 | Copyright (c) 2020 Chung-Ming Chien 26 | 27 | Permission is hereby granted, free of charge, to any person obtaining a copy 28 | of this software and associated documentation files (the "Software"), to deal 29 | in the Software without restriction, including without limitation the rights 30 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 31 | copies of the Software, and to permit persons to whom the Software is 32 | furnished to do so, subject to the following conditions: 33 | 34 | The above copyright notice and this permission notice shall be included in all 35 | copies or substantial portions of the Software. 
36 | 37 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 38 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 39 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 40 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 41 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 42 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 43 | SOFTWARE. 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Expressive-FastSpeech2 - PyTorch Implementation 2 | 3 | ## Contributions 4 | 5 | 1. **`Non-autoregressive Expressive TTS`**: This project aims to provide a cornerstone for future research on and applications of non-autoregressive expressive TTS, including `Emotional TTS` and `Conversational TTS`. For datasets, the [AIHub Multimodal Video AI datasets](https://www.aihub.or.kr/aidata/137) and the [IEMOCAP database](https://sail.usc.edu/iemocap/) are used for Korean and English, respectively. 6 | 7 | **Note**: If you are interested in a [GST-Tacotron](https://arxiv.org/abs/1803.09017)- or [VAE-Tacotron](https://arxiv.org/abs/1812.04342)-like expressive, stylistic TTS model but with non-autoregressive decoding, you may also be interested in [STYLER](https://arxiv.org/abs/2103.09474) [[demo](https://keonlee9420.github.io/STYLER-Demo/), [code](https://github.com/keonlee9420/STYLER)]. 8 | 9 | 2. **`Annotated Data Processing`**: This project sheds light on how to handle a new dataset, even in a different language, for successful training of non-autoregressive emotional TTS. 10 | 3. **`English and Korean TTS`**: In addition to English, this project gives a broad view of handling Korean for non-autoregressive TTS, where additional data processing must account for language-specific features (e.g., training Montreal Forced Aligner with your own language and dataset). Please look closely into `text/`. 11 | 4. **`Adopting Own Language`**: For those interested in adapting other languages, please refer to the ["Training with your own dataset (own language)" section](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/categorical#training-with-your-own-dataset-own-language) of the [categorical branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/categorical). 12 | 13 | ## Repository Structure 14 | 15 | In this project, FastSpeech2 is adapted as the base non-autoregressive multi-speaker TTS framework, so it would be helpful to read [the paper](https://arxiv.org/abs/2006.04558) and [code](https://github.com/ming024/FastSpeech2) first (also see the [FastSpeech2 branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/FastSpeech2)). 16 | 17 |

18 | [figure placeholder: overall model architecture; see img/model.png] 19 |

20 | 21 | 1. `Emotional TTS`: The following branches contain implementations of the basic paradigm introduced by [Emotional End-to-End Neural Speech synthesizer](https://arxiv.org/pdf/1711.05447.pdf). 22 | 23 |

24 | [figure placeholder: emotional TTS model; see img/model_emotional_tts.png] 25 |

26 | 27 | - [categorical branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/categorical): conditioning only on categorical emotional descriptors (such as happy, sad, etc.) 28 | - [continuous branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/continuous): conditioning on continuous emotional descriptors (such as arousal and valence) in addition to the categorical descriptors 29 | 2. `Conversational TTS`: The following branch contains an implementation of [Conversational End-to-End TTS for Voice Agent](https://arxiv.org/abs/2005.10438). 30 | 31 |

32 | [figure placeholder: conversational TTS model; see img/model_conversational_tts.png and img/model_conversational_tts_chat_history.png] 33 |

34 | 35 | - [conversational branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/conversational): conditioning on chat history 36 | 37 | ## Citation 38 | 39 | If you would like to use or refer to this implementation, please cite the repo. 40 | 41 | ``` 42 | @misc{lee2021expressive_fastspeech2, 43 | author = {Lee, Keon}, 44 | title = {Expressive-FastSpeech2}, 45 | year = {2021}, 46 | publisher = {GitHub}, 47 | journal = {GitHub repository}, 48 | howpublished = {\url{https://github.com/keonlee9420/Expressive-FastSpeech2}} 49 | } 50 | ``` 51 | 52 | ## References 53 | 54 | - [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) (later than the 2021.02.26 version) 55 | - [HGU-DLLAB's Korean-FastSpeech2-Pytorch](https://github.com/HGU-DLLAB/Korean-FastSpeech2-Pytorch) 56 | - [hccho2's Tacotron2-Wavenet-Korean-TTS](https://github.com/hccho2/Tacotron2-Wavenet-Korean-TTS) 57 | - [carpedm20's multi-speaker-tacotron-tensorflow](https://github.com/carpedm20/multi-speaker-tacotron-tensorflow) -------------------------------------------------------------------------------- /audio/__init__.py: -------------------------------------------------------------------------------- 1 | import audio.tools 2 | import audio.stft 3 | import audio.audio_processing 4 | -------------------------------------------------------------------------------- /audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time Fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame.
39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return torch.log(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 
| inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data.cuda(), 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes): 152 | output = dynamic_range_compression(magnitudes) 153 | return output 154 | 155 | def spectral_de_normalize(self, 
magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1 170 | assert torch.max(y.data) <= 1 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | return mel_output, energy 179 | -------------------------------------------------------------------------------- /audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | 5 | from audio.audio_processing import griffin_lim 6 | 7 | 8 | def get_mel_from_wav(audio, _stft): 9 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 10 | audio = torch.autograd.Variable(audio, requires_grad=False) 11 | melspec, energy = _stft.mel_spectrogram(audio) 12 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 13 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 14 | 15 | return melspec, energy 16 | 17 | 18 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 19 | mel = torch.stack([mel]) 20 | mel_decompress = _stft.spectral_de_normalize(mel) 21 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 22 | spec_from_mel_scaling = 1000 23 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 24 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 25 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 26 | 27 | audio = griffin_lim( 28 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters 29 | ) 30 | 31 | audio = audio.squeeze() 32 | audio = audio.cpu().numpy() 33 | audio_path = out_filename 34 | write(audio_path, _stft.sampling_rate, audio) 35 | -------------------------------------------------------------------------------- /config/AIHub-MMV/model.yaml: -------------------------------------------------------------------------------- 1 | transformer: 2 | encoder_layer: 4 3 | encoder_head: 2 4 | encoder_hidden: 256 5 | decoder_layer: 6 6 | decoder_head: 2 7 | decoder_hidden: 256 8 | conv_filter_size: 1024 9 | conv_kernel_size: [9, 1] 10 | encoder_dropout: 0.2 11 | decoder_dropout: 0.2 12 | 13 | variance_predictor: 14 | filter_size: 256 15 | kernel_size: 3 16 | dropout: 0.5 17 | 18 | variance_embedding: 19 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 21 | n_bins: 256 22 | 23 | # gst: 24 | # use_gst: False 25 | # conv_filters: [32, 32, 64, 64, 128, 128] 26 | # gru_hidden: 128 27 | # token_size: 128 28 | # n_style_token: 10 29 | # attn_head: 4 30 | 31 | multi_speaker: True 32 | 33 | multi_emotion: True 34 | 35 | max_seq_len: 2000 36 | 37 | vocoder: 38 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN' 39 | speaker: "universal" # support 'LJSpeech', 'universal' 40 |
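The `model.yaml` above is one of three YAML files per dataset (`preprocess.yaml`, `model.yaml`, `train.yaml`) that the scripts in this repo load together. A minimal sketch of loading them, mirroring the `yaml.load(..., Loader=yaml.FullLoader)` calls in `evaluate.py` and the test block of `dataset.py` (the dataset choice and the printed keys below are only examples):

```python
import yaml

# Load the three per-dataset config files (AIHub-MMV here; IEMOCAP works the same way).
dataset = "AIHub-MMV"
preprocess_config = yaml.load(
    open("./config/{}/preprocess.yaml".format(dataset), "r"), Loader=yaml.FullLoader
)
model_config = yaml.load(
    open("./config/{}/model.yaml".format(dataset), "r"), Loader=yaml.FullLoader
)
train_config = yaml.load(
    open("./config/{}/train.yaml".format(dataset), "r"), Loader=yaml.FullLoader
)
configs = (preprocess_config, model_config, train_config)

# Individual hyper-parameters are then read with plain dict indexing, e.g.:
print(model_config["transformer"]["decoder_hidden"])                # 256
print(train_config["optimizer"]["batch_size"])                      # 16
print(preprocess_config["preprocessing"]["mel"]["n_mel_channels"])  # 80
```

This `(preprocess_config, model_config, train_config)` tuple is the `configs` object passed around by `dataset.py` and `evaluate.py` later in this listing.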
-------------------------------------------------------------------------------- /config/AIHub-MMV/preprocess.yaml: -------------------------------------------------------------------------------- 1 | dataset: "AIHub-MMV" 2 | 3 | path: 4 | corpus_path: "/path/to/AIHub-MMV" 5 | sub_dir_name: "clips" 6 | lexicon_path: "lexicon/aihub-mmv-lexicon.txt" 7 | fixed_text_path: "./preparation/aihub_mmv_fixed.txt" 8 | raw_path: "./raw_data/AIHub-MMV" 9 | preprocessed_path: "./preprocessed_data/AIHub-MMV" 10 | 11 | preprocessing: 12 | val_size: 512 13 | text: 14 | text_cleaners: ["korean_cleaners"] 15 | language: "kr" 16 | audio: 17 | sampling_rate: 22050 18 | max_wav_value: 32768.0 19 | stft: 20 | filter_length: 1024 21 | hop_length: 256 22 | win_length: 1024 23 | mel: 24 | n_mel_channels: 80 25 | mel_fmin: 0 26 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder 27 | pitch: 28 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 29 | normalization: True 30 | energy: 31 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 32 | normalization: True 33 | -------------------------------------------------------------------------------- /config/AIHub-MMV/train.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt/AIHub-MMV" 3 | log_path: "./output/log/AIHub-MMV" 4 | result_path: "./output/result/AIHub-MMV" 5 | optimizer: 6 | batch_size: 16 7 | betas: [0.9, 0.98] 8 | eps: 0.000000001 9 | weight_decay: 0.0 10 | grad_clip_thresh: 1.0 11 | grad_acc_step: 1 12 | warm_up_step: 4000 13 | anneal_steps: [300000, 400000, 500000] 14 | anneal_rate: 0.3 15 | step: 16 | total_step: 900000 17 | log_step: 100 18 | synth_step: 1000 19 | val_step: 1000 20 | save_step: 100000 21 | -------------------------------------------------------------------------------- /config/IEMOCAP/model.yaml: -------------------------------------------------------------------------------- 1 | transformer: 2 | encoder_layer: 4 3 | encoder_head: 2 4 | encoder_hidden: 256 5 | decoder_layer: 6 6 | decoder_head: 2 7 | decoder_hidden: 256 8 | conv_filter_size: 1024 9 | conv_kernel_size: [9, 1] 10 | encoder_dropout: 0.2 11 | decoder_dropout: 0.2 12 | 13 | variance_predictor: 14 | filter_size: 256 15 | kernel_size: 3 16 | dropout: 0.5 17 | 18 | variance_embedding: 19 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 21 | n_bins: 256 22 | 23 | # gst: 24 | # use_gst: False 25 | # conv_filters: [32, 32, 64, 64, 128, 128] 26 | # gru_hidden: 128 27 | # token_size: 128 28 | # n_style_token: 10 29 | # attn_head: 4 30 | 31 | multi_speaker: True 32 | 33 | multi_emotion: True 34 | 35 | max_seq_len: 2000 36 | 37 | vocoder: 38 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN' 39 | speaker: "universal" # support 'LJSpeech', 'universal' 40 | -------------------------------------------------------------------------------- /config/IEMOCAP/preprocess.yaml: -------------------------------------------------------------------------------- 1 | dataset: "IEMOCAP" 2 | 3 | path: 4 | corpus_path: "/path/to/IEMOCAP_full_release" 5 | sub_dir_name: "sessions" 6 | lexicon_path: "lexicon/iemocap-lexicon.txt" 7 | fixed_text_path: "./preparation/iemocap_fixed.txt" 8 | raw_path: "./raw_data/IEMOCAP" 9 | 
preprocessed_path: "./preprocessed_data/IEMOCAP" 10 | 11 | preprocessing: 12 | val_size: 512 13 | text: 14 | text_cleaners: ["english_cleaners"] 15 | language: "en" 16 | audio: 17 | sampling_rate: 22050 18 | max_wav_value: 32768.0 19 | stft: 20 | filter_length: 1024 21 | hop_length: 256 22 | win_length: 1024 23 | mel: 24 | n_mel_channels: 80 25 | mel_fmin: 0 26 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder 27 | pitch: 28 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 29 | normalization: True 30 | energy: 31 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 32 | normalization: True 33 | -------------------------------------------------------------------------------- /config/IEMOCAP/train.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt/IEMOCAP" 3 | log_path: "./output/log/IEMOCAP" 4 | result_path: "./output/result/IEMOCAP" 5 | optimizer: 6 | batch_size: 16 7 | betas: [0.9, 0.98] 8 | eps: 0.000000001 9 | weight_decay: 0.0 10 | grad_clip_thresh: 1.0 11 | grad_acc_step: 1 12 | warm_up_step: 4000 13 | anneal_steps: [300000, 400000, 500000] 14 | anneal_rate: 0.3 15 | step: 16 | total_step: 900000 17 | log_step: 100 18 | synth_step: 1000 19 | val_step: 1000 20 | save_step: 100000 21 | -------------------------------------------------------------------------------- /config/README.md: -------------------------------------------------------------------------------- 1 | # Config 2 | 3 | Here are the config files used to train the multi-speaker emotional TTS models. Two different configurations are given: 4 | 5 | - AIHub-MMV: suggested configuration for the AIHub-MMV dataset. 6 | - IEMOCAP: suggested configuration for the IEMOCAP dataset. 7 | 8 | Some important hyper-parameters are explained here. 9 | 10 | ## preprocess.yaml 11 | 12 | - **path.lexicon_path**: the lexicon (which maps words to phonemes) used by Montreal Forced Aligner. 13 | - **mel.stft.mel_fmax**: set it to 8000 if the HiFi-GAN vocoder is used, and to null if MelGAN is used. 14 | - **pitch.feature & energy.feature**: the original paper proposed to predict and apply frame-level pitch and energy features to the inputs of the TTS decoder to control the pitch and energy of the synthesized utterances. 15 | However, in our experiments, we find that using phoneme-level features makes the prosody of the synthesized utterances more natural. 16 | - **pitch.normalization & energy.normalization**: whether or not to normalize the pitch and energy values. 17 | The original paper did not normalize these values. 18 | 19 | ## train.yaml 20 | 21 | - **optimizer.grad_acc_step**: the number of batches over which gradients are accumulated before updating the model parameters and calling optimizer.zero_grad(); this is useful if you wish to train with a large effective batch size but do not have sufficient GPU memory. 22 | - **optimizer.anneal_steps & optimizer.anneal_rate**: the learning rate is reduced at each of the **anneal_steps** by the ratio specified by **anneal_rate**. 23 | 24 | ## model.yaml 25 | 26 | - **transformer.decoder_layer**: the original paper used a 4-layer decoder, but we find it better to use a 6-layer decoder, especially for multi-speaker TTS. 27 | - **variance_embedding.pitch_quantization**: when the pitch values are normalized as specified in ``preprocess.yaml``, it is not valid to use log-scale quantization bins as proposed in the original paper, so we use linear-scaled bins instead.
28 | - **multi_speaker**: to apply a speaker embedding table to enable multi-speaker TTS or not. 29 | - **multi_emotion**: to apply a emotion embedding table to enable multi-emotion TTS or not. 30 | - **vocoder.speaker**: should be set to 'universal'. 31 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from tqdm import tqdm 5 | 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | 9 | from text import text_to_sequence 10 | from utils.tools import pad_1D, pad_2D 11 | 12 | 13 | class Dataset(Dataset): 14 | def __init__( 15 | self, filename, preprocess_config, model_config, train_config, sort=False, drop_last=False 16 | ): 17 | self.dataset_name = preprocess_config["dataset"] 18 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"] 19 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 20 | self.max_seq_len = model_config["max_seq_len"] 21 | self.batch_size = train_config["optimizer"]["batch_size"] 22 | 23 | self.basename, self.speaker, self.text, self.raw_text, self.aux_data = self.process_meta( 24 | filename 25 | ) 26 | with open(os.path.join(self.preprocessed_path, "speakers.json")) as f: 27 | self.speaker_map = json.load(f) 28 | with open(os.path.join(self.preprocessed_path, "emotions.json")) as f: 29 | json_raw = json.load(f) 30 | self.emotion_map = json_raw["emotion_dict"] 31 | self.arousal_map = json_raw["arousal_dict"] 32 | self.valence_map = json_raw["valence_dict"] 33 | self.sort = sort 34 | self.drop_last = drop_last 35 | 36 | def __len__(self): 37 | return len(self.text) 38 | 39 | def __getitem__(self, idx): 40 | basename = self.basename[idx] 41 | speaker = self.speaker[idx] 42 | speaker_id = self.speaker_map[speaker] 43 | aux_data = self.aux_data[idx].split("|") 44 | emotion = self.emotion_map[aux_data[-3]] 45 | arousal = self.arousal_map[aux_data[-2]] 46 | valence = self.valence_map[aux_data[-1]] 47 | raw_text = self.raw_text[idx] 48 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 49 | mel_path = os.path.join( 50 | self.preprocessed_path, 51 | "mel", 52 | "{}-mel-{}.npy".format(speaker, basename), 53 | ) 54 | mel = np.load(mel_path) 55 | pitch_path = os.path.join( 56 | self.preprocessed_path, 57 | "pitch", 58 | "{}-pitch-{}.npy".format(speaker, basename), 59 | ) 60 | pitch = np.load(pitch_path) 61 | energy_path = os.path.join( 62 | self.preprocessed_path, 63 | "energy", 64 | "{}-energy-{}.npy".format(speaker, basename), 65 | ) 66 | energy = np.load(energy_path) 67 | duration_path = os.path.join( 68 | self.preprocessed_path, 69 | "duration", 70 | "{}-duration-{}.npy".format(speaker, basename), 71 | ) 72 | duration = np.load(duration_path) 73 | 74 | sample = { 75 | "id": basename, 76 | "speaker": speaker_id, 77 | "emotion": emotion, 78 | "arousal": arousal, 79 | "valence": valence, 80 | "text": phone, 81 | "raw_text": raw_text, 82 | "mel": mel, 83 | "pitch": pitch, 84 | "energy": energy, 85 | "duration": duration, 86 | } 87 | 88 | return sample 89 | 90 | def process_meta(self, filename): 91 | with open( 92 | os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8" 93 | ) as f: 94 | name = [] 95 | speaker = [] 96 | text = [] 97 | raw_text = [] 98 | aux_data = [] 99 | for line in tqdm(f.readlines()): 100 | line_split = line.strip("\n").split("|") 101 | n, s, t, r = line_split[:4] 102 | mel_path = os.path.join( 103 | 
self.preprocessed_path, 104 | "mel", 105 | "{}-mel-{}.npy".format(s, n), 106 | ) 107 | mel = np.load(mel_path) 108 | if mel.shape[0] > self.max_seq_len: 109 | continue 110 | a = "|".join(line_split[4:]) 111 | name.append(n) 112 | speaker.append(s) 113 | text.append(t) 114 | raw_text.append(r) 115 | aux_data.append(a) 116 | return name, speaker, text, raw_text, aux_data 117 | 118 | def reprocess(self, data, idxs): 119 | ids = [data[idx]["id"] for idx in idxs] 120 | speakers = [data[idx]["speaker"] for idx in idxs] 121 | emotions = [data[idx]["emotion"] for idx in idxs] 122 | arousals = [data[idx]["arousal"] for idx in idxs] 123 | valences = [data[idx]["valence"] for idx in idxs] 124 | texts = [data[idx]["text"] for idx in idxs] 125 | raw_texts = [data[idx]["raw_text"] for idx in idxs] 126 | mels = [data[idx]["mel"] for idx in idxs] 127 | pitches = [data[idx]["pitch"] for idx in idxs] 128 | energies = [data[idx]["energy"] for idx in idxs] 129 | durations = [data[idx]["duration"] for idx in idxs] 130 | 131 | text_lens = np.array([text.shape[0] for text in texts]) 132 | mel_lens = np.array([mel.shape[0] for mel in mels]) 133 | 134 | speakers = np.array(speakers) 135 | emotions = np.array(emotions) 136 | arousals = np.array(arousals) 137 | valences = np.array(valences) 138 | texts = pad_1D(texts) 139 | mels = pad_2D(mels) 140 | pitches = pad_1D(pitches) 141 | energies = pad_1D(energies) 142 | durations = pad_1D(durations) 143 | 144 | return ( 145 | ids, 146 | raw_texts, 147 | speakers, 148 | emotions, 149 | arousals, 150 | valences, 151 | texts, 152 | text_lens, 153 | max(text_lens), 154 | mels, 155 | mel_lens, 156 | max(mel_lens), 157 | pitches, 158 | energies, 159 | durations, 160 | ) 161 | 162 | def collate_fn(self, data): 163 | data_size = len(data) 164 | 165 | if self.sort: 166 | len_arr = np.array([d["text"].shape[0] for d in data]) 167 | idx_arr = np.argsort(-len_arr) 168 | else: 169 | idx_arr = np.arange(data_size) 170 | 171 | tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :] 172 | idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)] 173 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist() 174 | if not self.drop_last and len(tail) > 0: 175 | idx_arr += [tail.tolist()] 176 | 177 | output = list() 178 | for idx in idx_arr: 179 | output.append(self.reprocess(data, idx)) 180 | 181 | return output 182 | 183 | 184 | class TextDataset(Dataset): 185 | def __init__(self, filepath, preprocess_config, model_config): 186 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 187 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"] 188 | self.max_seq_len = model_config["max_seq_len"] 189 | 190 | self.basename, self.speaker, self.text, self.raw_text, self.aux_data = self.process_meta( 191 | filepath 192 | ) 193 | with open( 194 | os.path.join( 195 | preprocess_config["path"]["preprocessed_path"], "speakers.json" 196 | ) 197 | ) as f: 198 | self.speaker_map = json.load(f) 199 | with open(os.path.join(self.preprocessed_path, "emotions.json")) as f: 200 | json_raw = json.load(f) 201 | self.emotion_map = json_raw["emotion_dict"] 202 | self.arousal_map = json_raw["arousal_dict"] 203 | self.valence_map = json_raw["valence_dict"] 204 | 205 | def __len__(self): 206 | return len(self.text) 207 | 208 | def __getitem__(self, idx): 209 | basename = self.basename[idx] 210 | speaker = self.speaker[idx] 211 | speaker_id = self.speaker_map[speaker] 212 | aux_data = self.aux_data[idx].split("|") 213 | emotion = 
self.emotion_map[aux_data[-3]] 214 | arousal = self.arousal_map[aux_data[-2]] 215 | valence = self.valence_map[aux_data[-1]] 216 | raw_text = self.raw_text[idx] 217 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 218 | 219 | return (basename, speaker_id, emotion, arousal, valence, phone, raw_text) 220 | 221 | def process_meta(self, filename): 222 | with open(filename, "r", encoding="utf-8") as f: 223 | name = [] 224 | speaker = [] 225 | text = [] 226 | raw_text = [] 227 | aux_data = [] 228 | for line in tqdm(f.readlines()): 229 | line_split = line.strip("\n").split("|") 230 | n, s, t, r = line_split[:4] 231 | mel_path = os.path.join( 232 | self.preprocessed_path, 233 | "mel", 234 | "{}-mel-{}.npy".format(s, n), 235 | ) 236 | mel = np.load(mel_path) 237 | if mel.shape[0] > self.max_seq_len: 238 | continue 239 | a = "|".join(line_split[4:]) 240 | name.append(n) 241 | speaker.append(s) 242 | text.append(t) 243 | raw_text.append(r) 244 | aux_data.append(a) 245 | return name, speaker, text, raw_text, aux_data 246 | 247 | def collate_fn(self, data): 248 | ids = [d[0] for d in data] 249 | speakers = np.array([d[1] for d in data]) 250 | emotions = np.array([d[2] for d in data]) 251 | arousals = np.array([d[3] for d in data]) 252 | valences = np.array([d[4] for d in data]) 253 | texts = [d[5] for d in data] 254 | raw_texts = [d[6] for d in data] 255 | text_lens = np.array([text.shape[0] for text in texts]) 256 | 257 | texts = pad_1D(texts) 258 | 259 | return ids, raw_texts, speakers, emotions, arousals, valences, texts, text_lens, max(text_lens) 260 | 261 | 262 | if __name__ == "__main__": 263 | # Test 264 | import torch 265 | import yaml 266 | from torch.utils.data import DataLoader 267 | from utils.utils import to_device 268 | 269 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 270 | preprocess_config = yaml.load( 271 | open("./config/LJSpeech/preprocess.yaml", "r"), Loader=yaml.FullLoader 272 | ) 273 | model_config = yaml.load( 274 | open("./config/LJSpeech/model.yaml", "r"), Loader=yaml.FullLoader 275 | ) 276 | train_config = yaml.load( 277 | open("./config/LJSpeech/train.yaml", "r"), Loader=yaml.FullLoader 278 | ) 279 | 280 | train_dataset = Dataset( 281 | "train.txt", preprocess_config, model_config, train_config, sort=True, drop_last=True 282 | ) 283 | val_dataset = Dataset( 284 | "val.txt", preprocess_config, model_config, train_config, sort=False, drop_last=False 285 | ) 286 | train_loader = DataLoader( 287 | train_dataset, 288 | batch_size=train_config["optimizer"]["batch_size"] * 4, 289 | shuffle=True, 290 | collate_fn=train_dataset.collate_fn, 291 | ) 292 | val_loader = DataLoader( 293 | val_dataset, 294 | batch_size=train_config["optimizer"]["batch_size"], 295 | shuffle=False, 296 | collate_fn=val_dataset.collate_fn, 297 | ) 298 | 299 | n_batch = 0 300 | for batchs in train_loader: 301 | for batch in batchs: 302 | to_device(batch, device) 303 | n_batch += 1 304 | print( 305 | "Training set with size {} is composed of {} batches.".format( 306 | len(train_dataset), n_batch 307 | ) 308 | ) 309 | 310 | n_batch = 0 311 | for batchs in val_loader: 312 | for batch in batchs: 313 | to_device(batch, device) 314 | n_batch += 1 315 | print( 316 | "Validation set with size {} is composed of {} batches.".format( 317 | len(val_dataset), n_batch 318 | ) 319 | ) -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 
import os 3 | 4 | import torch 5 | import yaml 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | 9 | from utils.model import get_model, get_vocoder 10 | from utils.tools import to_device, log, synth_one_sample 11 | from model import FastSpeech2Loss 12 | from dataset import Dataset 13 | 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | def evaluate(model, step, configs, logger=None, vocoder=None): 19 | preprocess_config, model_config, train_config = configs 20 | 21 | # Get dataset 22 | dataset = Dataset( 23 | "val.txt", preprocess_config, model_config, train_config, sort=False, drop_last=False 24 | ) 25 | batch_size = train_config["optimizer"]["batch_size"] 26 | loader = DataLoader( 27 | dataset, 28 | batch_size=batch_size, 29 | shuffle=False, 30 | collate_fn=dataset.collate_fn, 31 | ) 32 | 33 | # Get loss function 34 | Loss = FastSpeech2Loss(preprocess_config, model_config).to(device) 35 | 36 | # Evaluation 37 | loss_sums = [0 for _ in range(6)] 38 | for batchs in loader: 39 | for batch in batchs: 40 | batch = to_device(batch, device) 41 | with torch.no_grad(): 42 | # Forward 43 | output = model(*(batch[2:])) 44 | 45 | # Cal Loss 46 | losses = Loss(batch, output) 47 | 48 | for i in range(len(losses)): 49 | loss_sums[i] += losses[i].item() * len(batch[0]) 50 | 51 | loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums] 52 | 53 | message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format( 54 | *([step] + [l for l in loss_means]) 55 | ) 56 | 57 | if logger is not None: 58 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample( 59 | batch, 60 | output, 61 | vocoder, 62 | model_config, 63 | preprocess_config, 64 | ) 65 | 66 | log(logger, step, losses=loss_means) 67 | log( 68 | logger, 69 | fig=fig, 70 | tag="Validation/step_{}_{}".format(step, tag), 71 | ) 72 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 73 | log( 74 | logger, 75 | audio=wav_reconstruction, 76 | sampling_rate=sampling_rate, 77 | tag="Validation/step_{}_{}_reconstructed".format(step, tag), 78 | ) 79 | log( 80 | logger, 81 | audio=wav_prediction, 82 | sampling_rate=sampling_rate, 83 | tag="Validation/step_{}_{}_synthesized".format(step, tag), 84 | ) 85 | 86 | return message 87 | 88 | 89 | if __name__ == "__main__": 90 | 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--restore_step", type=int, default=30000) 93 | parser.add_argument( 94 | "-p", 95 | "--preprocess_config", 96 | type=str, 97 | required=True, 98 | help="path to preprocess.yaml", 99 | ) 100 | parser.add_argument( 101 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 102 | ) 103 | parser.add_argument( 104 | "-t", "--train_config", type=str, required=True, help="path to train.yaml" 105 | ) 106 | args = parser.parse_args() 107 | 108 | # Read Config 109 | preprocess_config = yaml.load( 110 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 111 | ) 112 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 113 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 114 | configs = (preprocess_config, model_config, train_config) 115 | 116 | # Get model 117 | model = get_model(args, configs, device, train=False).to(device) 118 | 119 | message = evaluate(model, args.restore_step, configs) 120 | print(message) 
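For reference, each `batch` iterated over in `evaluate()` above is one of the groups built by `Dataset.reprocess` in `dataset.py`: a 15-element tuple, of which the model receives `batch[2:]` and the loss receives the whole tuple. A sketch of the field order (the names below are descriptive labels, not identifiers from the source):

```python
# Field order of one training batch as produced by Dataset.reprocess (dataset.py).
# Indices 2..14 are what model(*(batch[2:])) passes to FastSpeech2.forward;
# indices 9..14 are what FastSpeech2Loss reads via inputs[9:].
BATCH_FIELDS = (
    "ids",           # 0
    "raw_texts",     # 1
    "speakers",      # 2  <- FastSpeech2.forward arguments start here
    "emotions",      # 3
    "arousals",      # 4
    "valences",      # 5
    "texts",         # 6
    "text_lens",     # 7
    "max_text_len",  # 8
    "mels",          # 9  <- loss targets start here (mel_targets, _, _, ...)
    "mel_lens",      # 10
    "max_mel_len",   # 11
    "pitches",       # 12
    "energies",      # 13
    "durations",     # 14
)
```

This is why `FastSpeech2Loss.forward` in `model/loss.py` below unpacks its targets from `inputs[9:]`.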
-------------------------------------------------------------------------------- /hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Generator 2 | 3 | 4 | class AttrDict(dict): 5 | def __init__(self, *args, **kwargs): 6 | super(AttrDict, self).__init__(*args, **kwargs) 7 | self.__dict__ = self -------------------------------------------------------------------------------- /hifigan/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /hifigan/generator_LJSpeech.pth.tar.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/hifigan/generator_LJSpeech.pth.tar.zip -------------------------------------------------------------------------------- /hifigan/generator_universal.pth.tar.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/hifigan/generator_universal.pth.tar.zip -------------------------------------------------------------------------------- /hifigan/models.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2 ** i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, 
x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) -------------------------------------------------------------------------------- /img/emotional-fastspeech2-audios.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/emotional-fastspeech2-audios.png -------------------------------------------------------------------------------- /img/emotional-fastspeech2-images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/emotional-fastspeech2-images.png -------------------------------------------------------------------------------- /img/emotional-fastspeech2-scalars.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/emotional-fastspeech2-scalars.png -------------------------------------------------------------------------------- /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/model.png -------------------------------------------------------------------------------- /img/model_conversational_tts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/model_conversational_tts.png -------------------------------------------------------------------------------- /img/model_conversational_tts_chat_history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/model_conversational_tts_chat_history.png -------------------------------------------------------------------------------- /img/model_emotional_tts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/model_emotional_tts.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .fastspeech2 import FastSpeech2 2 | from .loss import FastSpeech2Loss 3 | from .optimizer import ScheduledOptim -------------------------------------------------------------------------------- /model/fastspeech2.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from transformer import Encoder, Decoder, PostNet 9 | from .modules import VarianceAdaptor 10 | from utils.tools import get_mask_from_lengths 11 | 12 | 13 | class FastSpeech2(nn.Module): 14 | """ FastSpeech2 """ 15 | 16 | def __init__(self, preprocess_config, model_config): 17 | super(FastSpeech2, self).__init__() 18 | self.model_config = model_config 19 | 20 | self.encoder = Encoder(model_config) 21 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config) 22 | self.decoder = Decoder(model_config) 23 | self.mel_linear = nn.Linear( 24 | model_config["transformer"]["decoder_hidden"], 25 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"], 26 | ) 27 | self.postnet = PostNet() 28 | 29 | self.speaker_emb = None 30 | if model_config["multi_speaker"]: 31 | with open( 32 | os.path.join( 33 | preprocess_config["path"]["preprocessed_path"], "speakers.json" 34 | ), 35 | "r", 36 | ) as f: 37 | n_speaker = len(json.load(f)) 38 | self.speaker_emb = nn.Embedding( 39 | n_speaker, 40 | model_config["transformer"]["encoder_hidden"], 41 | ) 42 | 43 | self.emotion_emb = None 44 | if model_config["multi_emotion"]: 45 | with open( 46 | os.path.join( 47 | preprocess_config["path"]["preprocessed_path"], "emotions.json" 48 | ), 49 | "r", 50 | ) as f: 51 | json_raw = json.load(f) 52 | n_emotion = len(json_raw["emotion_dict"]) 53 | n_arousal = len(json_raw["arousal_dict"]) 54 | n_valence = len(json_raw["valence_dict"]) 55 | encoder_hidden = model_config["transformer"]["encoder_hidden"] 56 | self.emotion_emb = nn.Embedding( 57 | n_emotion, 58 | encoder_hidden//2, 59 | ) 60 | self.arousal_emb = nn.Embedding( 61 | n_arousal, 62 | encoder_hidden//4, 63 | ) 64 | self.valence_emb = nn.Embedding( 65 | n_valence, 66 | encoder_hidden//4, 67 | ) 68 | self.emotion_linear = nn.Sequential( 69 | nn.Linear(encoder_hidden, encoder_hidden), 70 | nn.ReLU() 71 | ) 72 | 73 | def forward( 74 | self, 75 | speakers, 76 | emotions, 77 | arousals, 78 | valences, 79 | texts, 80 | src_lens, 81 | max_src_len, 82 | mels=None, 83 | mel_lens=None, 84 | max_mel_len=None, 85 | p_targets=None, 86 | e_targets=None, 87 | d_targets=None, 88 | p_control=1.0, 89 | e_control=1.0, 90 | d_control=1.0, 91 | ): 92 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 93 | mel_masks = ( 94 | get_mask_from_lengths(mel_lens, max_mel_len) 95 | if mel_lens is not None 96 | else None 97 | ) 98 | 99 | output = self.encoder(texts, src_masks) 100 | 101 | if self.speaker_emb is not None: 102 | output = output + self.speaker_emb(speakers).unsqueeze(1).expand( 103 | -1, max_src_len, -1 104 | ) 105 | 106 | if self.emotion_emb is not None: 107 | emb = torch.cat((self.emotion_emb(emotions), self.arousal_emb(arousals), self.valence_emb(valences)), dim=-1) 108 | output = output + self.emotion_linear(emb).unsqueeze(1).expand( 109 | -1, max_src_len, -1 110 | ) 111 | 112 | ( 113 | output, 114 | p_predictions, 115 | e_predictions, 116 | log_d_predictions, 117 | d_rounded, 118 | mel_lens, 119 | mel_masks, 120 | ) = self.variance_adaptor( 121 | output, 122 | src_masks, 123 | mel_masks, 124 | max_mel_len, 125 | p_targets, 126 | e_targets, 127 | d_targets, 128 | p_control, 129 | e_control, 130 | d_control, 131 | ) 132 | 133 | output, mel_masks = self.decoder(output, mel_masks) 134 | output = self.mel_linear(output) 135 | 136 | postnet_output = 
self.postnet(output) + output 137 | 138 | return ( 139 | output, 140 | postnet_output, 141 | p_predictions, 142 | e_predictions, 143 | log_d_predictions, 144 | d_rounded, 145 | src_masks, 146 | mel_masks, 147 | src_lens, 148 | mel_lens, 149 | ) -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FastSpeech2Loss(nn.Module): 6 | """ FastSpeech2 Loss """ 7 | 8 | def __init__(self, preprocess_config, model_config): 9 | super(FastSpeech2Loss, self).__init__() 10 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 11 | "feature" 12 | ] 13 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 14 | "feature" 15 | ] 16 | self.mse_loss = nn.MSELoss() 17 | self.mae_loss = nn.L1Loss() 18 | 19 | def forward(self, inputs, predictions): 20 | ( 21 | mel_targets, 22 | _, 23 | _, 24 | pitch_targets, 25 | energy_targets, 26 | duration_targets, 27 | ) = inputs[9:] 28 | ( 29 | mel_predictions, 30 | postnet_mel_predictions, 31 | pitch_predictions, 32 | energy_predictions, 33 | log_duration_predictions, 34 | _, 35 | src_masks, 36 | mel_masks, 37 | _, 38 | _, 39 | ) = predictions 40 | src_masks = ~src_masks 41 | mel_masks = ~mel_masks 42 | log_duration_targets = torch.log(duration_targets.float() + 1) 43 | mel_targets = mel_targets[:, : mel_masks.shape[1], :] 44 | mel_masks = mel_masks[:, :mel_masks.shape[1]] 45 | 46 | log_duration_targets.requires_grad = False 47 | pitch_targets.requires_grad = False 48 | energy_targets.requires_grad = False 49 | mel_targets.requires_grad = False 50 | 51 | if self.pitch_feature_level == "phoneme_level": 52 | pitch_predictions = pitch_predictions.masked_select(src_masks) 53 | pitch_targets = pitch_targets.masked_select(src_masks) 54 | elif self.pitch_feature_level == "frame_level": 55 | pitch_predictions = pitch_predictions.masked_select(mel_masks) 56 | pitch_targets = pitch_targets.masked_select(mel_masks) 57 | 58 | if self.energy_feature_level == "phoneme_level": 59 | energy_predictions = energy_predictions.masked_select(src_masks) 60 | energy_targets = energy_targets.masked_select(src_masks) 61 | if self.energy_feature_level == "frame_level": 62 | energy_predictions = energy_predictions.masked_select(mel_masks) 63 | energy_targets = energy_targets.masked_select(mel_masks) 64 | 65 | log_duration_predictions = log_duration_predictions.masked_select(src_masks) 66 | log_duration_targets = log_duration_targets.masked_select(src_masks) 67 | 68 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1)) 69 | postnet_mel_predictions = postnet_mel_predictions.masked_select( 70 | mel_masks.unsqueeze(-1) 71 | ) 72 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1)) 73 | 74 | mel_loss = self.mae_loss(mel_predictions, mel_targets) 75 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets) 76 | 77 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets) 78 | energy_loss = self.mse_loss(energy_predictions, energy_targets) 79 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets) 80 | 81 | total_loss = ( 82 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss 83 | ) 84 | 85 | return ( 86 | total_loss, 87 | mel_loss, 88 | postnet_mel_loss, 89 | pitch_loss, 90 | energy_loss, 91 | duration_loss, 92 | ) 93 | 
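Editor's note (illustrative sketch, not part of the original sources): FastSpeech2.forward above returns the ten-element prediction tuple, and FastSpeech2Loss.forward compares it against the targets stored from position 9 onward of the training batch. The minimal training-step sketch below assumes a batch ordered as (ids, raw_texts, speakers, emotions, arousals, valences, texts, src_lens, max_src_len, mels, mel_lens, max_mel_len, pitch, energy, duration), which is the ordering implied by the batch[2:] and inputs[9:] slices used in this repository; the function name and variable names are hypothetical.

def train_step(model, loss_fn, scheduled_optim, batch):
    # Forward pass: every element after (ids, raw_texts) feeds FastSpeech2.forward.
    predictions = model(*batch[2:])
    # FastSpeech2Loss returns (total, mel, postnet_mel, pitch, energy, duration) losses.
    total_loss = loss_fn(batch, predictions)[0]
    # ScheduledOptim (model/optimizer.py) wraps Adam with warm-up and annealing.
    scheduled_optim.zero_grad()
    total_loss.backward()
    scheduled_optim.step_and_update_lr()
    return total_loss.item()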
-------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import math 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | import torch.nn.functional as F 11 | 12 | from utils.tools import get_mask_from_lengths, pad 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | class VarianceAdaptor(nn.Module): 18 | """ Variance Adaptor """ 19 | 20 | def __init__(self, preprocess_config, model_config): 21 | super(VarianceAdaptor, self).__init__() 22 | self.duration_predictor = VariancePredictor(model_config) 23 | self.length_regulator = LengthRegulator() 24 | self.pitch_predictor = VariancePredictor(model_config) 25 | self.energy_predictor = VariancePredictor(model_config) 26 | 27 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 28 | "feature" 29 | ] 30 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 31 | "feature" 32 | ] 33 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"] 34 | assert self.energy_feature_level in ["phoneme_level", "frame_level"] 35 | 36 | pitch_quantization = model_config["variance_embedding"]["pitch_quantization"] 37 | energy_quantization = model_config["variance_embedding"]["energy_quantization"] 38 | n_bins = model_config["variance_embedding"]["n_bins"] 39 | assert pitch_quantization in ["linear", "log"] 40 | assert energy_quantization in ["linear", "log"] 41 | with open( 42 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 43 | ) as f: 44 | stats = json.load(f) 45 | pitch_min, pitch_max = stats["pitch"][:2] 46 | energy_min, energy_max = stats["energy"][:2] 47 | 48 | if pitch_quantization == "log": 49 | self.pitch_bins = nn.Parameter( 50 | torch.exp( 51 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1) 52 | ), 53 | requires_grad=False, 54 | ) 55 | else: 56 | self.pitch_bins = nn.Parameter( 57 | torch.linspace(pitch_min, pitch_max, n_bins - 1), 58 | requires_grad=False, 59 | ) 60 | if energy_quantization == "log": 61 | self.energy_bins = nn.Parameter( 62 | torch.exp( 63 | torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1) 64 | ), 65 | requires_grad=False, 66 | ) 67 | else: 68 | self.energy_bins = nn.Parameter( 69 | torch.linspace(energy_min, energy_max, n_bins - 1), 70 | requires_grad=False, 71 | ) 72 | 73 | self.pitch_embedding = nn.Embedding( 74 | n_bins, model_config["transformer"]["encoder_hidden"] 75 | ) 76 | self.energy_embedding = nn.Embedding( 77 | n_bins, model_config["transformer"]["encoder_hidden"] 78 | ) 79 | 80 | def get_pitch_embedding(self, x, target, mask, control): 81 | prediction = self.pitch_predictor(x, mask) 82 | if target is not None: 83 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins)) 84 | else: 85 | prediction = prediction * control 86 | embedding = self.pitch_embedding( 87 | torch.bucketize(prediction, self.pitch_bins) 88 | ) 89 | return prediction, embedding 90 | 91 | def get_energy_embedding(self, x, target, mask, control): 92 | prediction = self.energy_predictor(x, mask) 93 | if target is not None: 94 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins)) 95 | else: 96 | prediction = prediction * control 97 | embedding = self.energy_embedding( 98 | torch.bucketize(prediction, self.energy_bins) 99 | 
) 100 | return prediction, embedding 101 | 102 | def forward( 103 | self, 104 | x, 105 | src_mask, 106 | mel_mask=None, 107 | max_len=None, 108 | pitch_target=None, 109 | energy_target=None, 110 | duration_target=None, 111 | p_control=1.0, 112 | e_control=1.0, 113 | d_control=1.0, 114 | ): 115 | 116 | log_duration_prediction = self.duration_predictor(x, src_mask) 117 | if self.pitch_feature_level == "phoneme_level": 118 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 119 | x, pitch_target, src_mask, p_control 120 | ) 121 | x = x + pitch_embedding 122 | if self.energy_feature_level == "phoneme_level": 123 | energy_prediction, energy_embedding = self.get_energy_embedding( 124 | x, energy_target, src_mask, e_control 125 | ) 126 | x = x + energy_embedding 127 | 128 | if duration_target is not None: 129 | x, mel_len = self.length_regulator(x, duration_target, max_len) 130 | duration_rounded = duration_target 131 | else: 132 | duration_rounded = torch.clamp( 133 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control), 134 | min=0, 135 | ) 136 | x, mel_len = self.length_regulator(x, duration_rounded, max_len) 137 | mel_mask = get_mask_from_lengths(mel_len) 138 | 139 | if self.pitch_feature_level == "frame_level": 140 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 141 | x, pitch_target, mel_mask, p_control 142 | ) 143 | x = x + pitch_embedding 144 | if self.energy_feature_level == "frame_level": 145 | energy_prediction, energy_embedding = self.get_energy_embedding( 146 | x, energy_target, mel_mask, e_control 147 | ) 148 | x = x + energy_embedding 149 | 150 | return ( 151 | x, 152 | pitch_prediction, 153 | energy_prediction, 154 | log_duration_prediction, 155 | duration_rounded, 156 | mel_len, 157 | mel_mask, 158 | ) 159 | 160 | 161 | class LengthRegulator(nn.Module): 162 | """ Length Regulator """ 163 | 164 | def __init__(self): 165 | super(LengthRegulator, self).__init__() 166 | 167 | def LR(self, x, duration, max_len): 168 | output = list() 169 | mel_len = list() 170 | for batch, expand_target in zip(x, duration): 171 | expanded = self.expand(batch, expand_target) 172 | output.append(expanded) 173 | mel_len.append(expanded.shape[0]) 174 | 175 | if max_len is not None: 176 | output = pad(output, max_len) 177 | else: 178 | output = pad(output) 179 | 180 | return output, torch.LongTensor(mel_len).to(device) 181 | 182 | def expand(self, batch, predicted): 183 | out = list() 184 | 185 | for i, vec in enumerate(batch): 186 | expand_size = predicted[i].item() 187 | out.append(vec.expand(max(int(expand_size), 0), -1)) 188 | out = torch.cat(out, 0) 189 | 190 | return out 191 | 192 | def forward(self, x, duration, max_len): 193 | output, mel_len = self.LR(x, duration, max_len) 194 | return output, mel_len 195 | 196 | 197 | class VariancePredictor(nn.Module): 198 | """ Duration, Pitch and Energy Predictor """ 199 | 200 | def __init__(self, model_config): 201 | super(VariancePredictor, self).__init__() 202 | 203 | self.input_size = model_config["transformer"]["encoder_hidden"] 204 | self.filter_size = model_config["variance_predictor"]["filter_size"] 205 | self.kernel = model_config["variance_predictor"]["kernel_size"] 206 | self.conv_output_size = model_config["variance_predictor"]["filter_size"] 207 | self.dropout = model_config["variance_predictor"]["dropout"] 208 | 209 | self.conv_layer = nn.Sequential( 210 | OrderedDict( 211 | [ 212 | ( 213 | "conv1d_1", 214 | Conv( 215 | self.input_size, 216 | self.filter_size, 217 | kernel_size=self.kernel, 218 | 
padding=(self.kernel - 1) // 2, 219 | ), 220 | ), 221 | ("relu_1", nn.ReLU()), 222 | ("layer_norm_1", nn.LayerNorm(self.filter_size)), 223 | ("dropout_1", nn.Dropout(self.dropout)), 224 | ( 225 | "conv1d_2", 226 | Conv( 227 | self.filter_size, 228 | self.filter_size, 229 | kernel_size=self.kernel, 230 | padding=1, 231 | ), 232 | ), 233 | ("relu_2", nn.ReLU()), 234 | ("layer_norm_2", nn.LayerNorm(self.filter_size)), 235 | ("dropout_2", nn.Dropout(self.dropout)), 236 | ] 237 | ) 238 | ) 239 | 240 | self.linear_layer = nn.Linear(self.conv_output_size, 1) 241 | 242 | def forward(self, encoder_output, mask): 243 | out = self.conv_layer(encoder_output) 244 | out = self.linear_layer(out) 245 | out = out.squeeze(-1) 246 | 247 | if mask is not None: 248 | out = out.masked_fill(mask, 0.0) 249 | 250 | return out 251 | 252 | 253 | class Conv(nn.Module): 254 | """ 255 | Convolution Module 256 | """ 257 | 258 | def __init__( 259 | self, 260 | in_channels, 261 | out_channels, 262 | kernel_size=1, 263 | stride=1, 264 | padding=0, 265 | dilation=1, 266 | bias=True, 267 | w_init="linear", 268 | ): 269 | """ 270 | :param in_channels: dimension of input 271 | :param out_channels: dimension of output 272 | :param kernel_size: size of kernel 273 | :param stride: size of stride 274 | :param padding: size of padding 275 | :param dilation: dilation rate 276 | :param bias: boolean. if True, bias is included. 277 | :param w_init: str. weight inits with xavier initialization. 278 | """ 279 | super(Conv, self).__init__() 280 | 281 | self.conv = nn.Conv1d( 282 | in_channels, 283 | out_channels, 284 | kernel_size=kernel_size, 285 | stride=stride, 286 | padding=padding, 287 | dilation=dilation, 288 | bias=bias, 289 | ) 290 | 291 | def forward(self, x): 292 | x = x.contiguous().transpose(1, 2) 293 | x = self.conv(x) 294 | x = x.contiguous().transpose(1, 2) 295 | 296 | return x 297 | -------------------------------------------------------------------------------- /model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim: 6 | """ A simple wrapper class for learning rate scheduling """ 7 | 8 | def __init__(self, model, train_config, model_config, current_step): 9 | 10 | self._optimizer = torch.optim.Adam( 11 | model.parameters(), 12 | betas=train_config["optimizer"]["betas"], 13 | eps=train_config["optimizer"]["eps"], 14 | weight_decay=train_config["optimizer"]["weight_decay"], 15 | ) 16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"] 17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"] 18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"] 19 | self.current_step = current_step 20 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5) 21 | 22 | def step_and_update_lr(self): 23 | self._update_learning_rate() 24 | self._optimizer.step() 25 | 26 | def zero_grad(self): 27 | # print(self.init_lr) 28 | self._optimizer.zero_grad() 29 | 30 | def load_state_dict(self, path): 31 | self._optimizer.load_state_dict(path) 32 | 33 | def _get_lr_scale(self): 34 | lr = np.min( 35 | [ 36 | np.power(self.current_step, -0.5), 37 | np.power(self.n_warmup_steps, -1.5) * self.current_step, 38 | ] 39 | ) 40 | for s in self.anneal_steps: 41 | if self.current_step > s: 42 | lr = lr * self.anneal_rate 43 | return lr 44 | 45 | def _update_learning_rate(self): 46 | """ Learning rate scheduling per step """ 47 | self.current_step += 1 48 | lr = self.init_lr * 
self._get_lr_scale() 49 | 50 | for param_group in self._optimizer.param_groups: 51 | param_group["lr"] = lr 52 | -------------------------------------------------------------------------------- /preparation/aihub_mmv.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | import yaml 4 | import os 5 | import shutil 6 | import json 7 | import librosa 8 | import soundfile 9 | from glob import glob 10 | from tqdm import tqdm 11 | from moviepy.editor import VideoFileClip 12 | from text import _clean_text 13 | from text.korean import tokenize, normalize_nonchar 14 | 15 | 16 | def write_text(txt_path, text): 17 | with open(txt_path, 'w', encoding='utf-8') as f: 18 | f.write(text) 19 | 20 | 21 | def get_sorted_items(items): 22 | # sort by key 23 | return sorted(items, key=lambda x:int(x[0])) 24 | 25 | 26 | def get_emotion(emo_dict): 27 | e, a, v = 0, 0, 0 28 | if 'emotion' in emo_dict: 29 | e = emo_dict['emotion'] 30 | a = emo_dict['arousal'] 31 | v = emo_dict['valence'] 32 | return e, a, v 33 | 34 | 35 | def pad_spk_id(speaker_id): 36 | return 'p{}'.format("0"*(3-len(speaker_id))+speaker_id) 37 | 38 | 39 | def create_dataset(preprocess_config): 40 | """ 41 | See https://github.com/Kyumin-Park/aihub_multimodal_speech 42 | """ 43 | in_dir = preprocess_config["path"]["corpus_path"] 44 | audio_dir = os.path.join(os.path.dirname(in_dir), os.path.basename(in_dir)+"_audio") 45 | out_dir = os.path.join(os.path.dirname(in_dir), os.path.basename(in_dir)+"_preprocessed") 46 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 47 | 48 | print("Gather audio...") 49 | video_files = glob(f'{in_dir}/**/*.mp4', recursive=True) 50 | 51 | if os.path.exists(out_dir): 52 | shutil.rmtree(out_dir) 53 | os.makedirs(out_dir) 54 | filelist = open(f'{out_dir}/filelist.txt', 'w', encoding='utf-8') 55 | speaker_info, speaker_done = dict(), set() 56 | total_duration = 0 57 | 58 | print("Create dataset...") 59 | for video_path in tqdm(video_files): 60 | # Load annotation file 61 | file_name = os.path.splitext(os.path.basename(video_path))[0] 62 | json_path = video_path.replace('mp4', 'json') 63 | try: 64 | with open(json_path, 'r', encoding='utf-8') as f: 65 | annotation = json.load(f) 66 | except UnicodeDecodeError: 67 | continue 68 | 69 | # Load video clip 70 | audio_path = video_path.replace(in_dir, audio_dir, 1).replace('mp4', 'wav') 71 | orig_sr = librosa.get_samplerate(audio_path) 72 | y, sr = librosa.load(audio_path, sr=orig_sr) 73 | duration = librosa.get_duration(y, sr=sr) 74 | new_sr = sampling_rate 75 | new_y = librosa.resample(y, sr, new_sr) 76 | 77 | # Metadata 78 | n_frames = float(annotation['nr_frame']) 79 | fps = n_frames / duration 80 | for spk_id, spk_info in annotation['actor'].items(): 81 | if spk_id not in speaker_done: 82 | speaker_info[spk_id] = spk_info 83 | speaker_done.add(spk_id) 84 | 85 | turn_id = 0 86 | done = set() 87 | for frame, frame_data in get_sorted_items(annotation['data'].items()): 88 | for sub_id, info_data in frame_data.items(): 89 | if 'text' not in info_data.keys(): 90 | continue 91 | 92 | # Extract data 93 | text_data = info_data['text'] 94 | emotion_data = info_data['emotion'] 95 | speaker_id = info_data['person_id'] 96 | start_frame = text_data['script_start'] 97 | end_frame = text_data['script_end'] 98 | intent = text_data['intent'] 99 | strategy = text_data['strategy'] 100 | 101 | et_e, et_a, et_v = get_emotion(emotion_data['text']) 102 | es_e, es_a, es_v = 
get_emotion(emotion_data['sound']) 103 | ei_e, ei_a, ei_v = get_emotion(emotion_data['image']) 104 | em_e, em_a, em_v = get_emotion(emotion_data['multimodal']) 105 | 106 | script = refine_text(text_data['script']) 107 | 108 | start_idx = int(float(start_frame) / fps * new_sr) 109 | end_idx = int(float(end_frame) / fps * new_sr) 110 | 111 | # Write wav 112 | y_part = new_y[start_idx:end_idx] 113 | speaker_id = pad_spk_id(speaker_id) 114 | file_name = file_name.replace('clip_', 'c') 115 | framename = f'{start_frame}-{end_frame}' 116 | basename = f'{turn_id}_{speaker_id}_{file_name}_{framename}' 117 | wav_path = os.path.join(os.path.dirname(audio_path).replace(audio_dir, out_dir), 118 | f'{basename}.wav') 119 | if framename not in done: 120 | os.makedirs(os.path.dirname(wav_path), exist_ok=True) 121 | soundfile.write(wav_path, y_part, new_sr) 122 | write_text(wav_path.replace('.wav', '.txt'), script) 123 | 124 | # Write filelist 125 | filelist.write(f'{basename}|{script}|{speaker_id}|{intent}|{strategy}|{et_e}|{et_a}|{et_v}|{es_e}|{es_a}|{es_v}|{ei_e}|{ei_a}|{ei_v}|{em_e}|{em_a}|{em_v}\n') 126 | total_duration += (end_idx - start_idx) / float(new_sr) 127 | 128 | done.add(f'{framename}') 129 | turn_id += 1 130 | 131 | filelist.close() 132 | 133 | # Save Speaker Info 134 | with open(f'{out_dir}/speaker_info.txt', 'w', encoding='utf-8') as f: 135 | for spk_id, spk_info in get_sorted_items(speaker_info.items()): 136 | gender = 'F' if spk_info['gender'] == 'female' else 'M' 137 | age = spk_info['age'] 138 | spk_id = pad_spk_id(spk_id) 139 | f.write(f'{spk_id}|{gender}|{age}\n') 140 | 141 | print(f'End parsing, total duration: {total_duration}') 142 | 143 | 144 | def refine_text(text): 145 | # Fix invalid characters in text 146 | text = text.replace('…', ',') 147 | text = text.replace('\t', '') 148 | text = text.replace('-', ',') 149 | text = text.replace('–', ',') 150 | text = ' '.join(text.split()) 151 | return text 152 | 153 | 154 | def extract_audio(preprocess_config): 155 | in_dir = preprocess_config["path"]["corpus_path"] 156 | out_dir = os.path.join(os.path.dirname(in_dir), os.path.basename(in_dir)+"_tmp") 157 | video_files = glob(f'{in_dir}/**/*.mp4', recursive=True) 158 | 159 | print("Extract audio...") 160 | for video_path in tqdm(video_files): 161 | audio_path = video_path.replace(in_dir, out_dir, 1).replace('mp4', 'wav') 162 | os.makedirs(os.path.dirname(audio_path), exist_ok=True) 163 | 164 | clip = VideoFileClip(video_path) 165 | clip.audio.write_audiofile(audio_path, verbose=False) 166 | clip.close() 167 | 168 | 169 | def extract_nonkr(preprocess_config): 170 | in_dir = preprocess_config["path"]["raw_path"] 171 | filelist = open(f'{in_dir}/nonkr.txt', 'w', encoding='utf-8') 172 | 173 | count = 0 174 | nonkr = set() 175 | print("Extract non korean charactors...") 176 | with open(f'{in_dir}/filelist.txt', 'r', encoding='utf-8') as f: 177 | lines = f.readlines() 178 | total_count = len(lines) 179 | for line in tqdm(lines): 180 | wav = line.split('|')[0] 181 | text = line.split('|')[1] 182 | reg = re.compile("""[^ ㄱ-ㅣ가-힣~!.,?:{}`"'"“‘’”’()\[\]]+""") 183 | impurities = reg.findall(text) 184 | if len(impurities) == 0: 185 | count+=1 186 | continue 187 | norm = _clean_text(text, preprocess_config["preprocessing"]["text"]["text_cleaners"]) 188 | impurities_str = ','.join(impurities) 189 | filelist.write(f'{norm}|{text}|{impurities_str}|{wav}\n') 190 | for imp in impurities: 191 | nonkr.add(imp) 192 | filelist.close() 193 | print('Total {} non korean charactors from {} 
lines'.format(len(nonkr), total_count-count)) 194 | print(sorted(list(nonkr))) 195 | 196 | 197 | def extract_lexicon(preprocess_config): 198 | """ 199 | Extract lexicon and build grapheme-phoneme dictionary for MFA training 200 | See https://github.com/HGU-DLLAB/Korean-FastSpeech2-Pytorch 201 | """ 202 | in_dir = preprocess_config["path"]["raw_path"] 203 | lexicon_path = preprocess_config["path"]["lexicon_path"] 204 | filelist = open(lexicon_path, 'a+', encoding='utf-8') 205 | 206 | # Load Lexicon Dictionary 207 | done = set() 208 | if os.path.isfile(lexicon_path): 209 | filelist.seek(0) 210 | for line in filelist.readlines(): 211 | grapheme = line.split("\t")[0] 212 | done.add(grapheme) 213 | 214 | print("Extract lexicon...") 215 | for lab in tqdm(glob(f'{in_dir}/**/*.lab', recursive=True)): 216 | with open(lab, 'r', encoding='utf-8') as f: 217 | text = f.readline().strip("\n") 218 | assert text == normalize_nonchar(text), "No special token should be left." 219 | 220 | for grapheme in text.split(" "): 221 | if not grapheme in done: 222 | phoneme = " ".join(tokenize(grapheme, norm=False)) 223 | filelist.write("{}\t{}\n".format(grapheme, phoneme)) 224 | done.add(grapheme) 225 | filelist.close() 226 | 227 | 228 | def apply_fixed_text(preprocess_config): 229 | in_dir = preprocess_config["path"]["corpus_path"] 230 | sub_dir = preprocess_config["path"]["sub_dir_name"] 231 | out_dir = preprocess_config["path"]["raw_path"] 232 | fixed_text_path = preprocess_config["path"]["fixed_text_path"] 233 | cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 234 | 235 | fixed_text_dict = dict() 236 | print("Fixing transcripts...") 237 | with open(fixed_text_path, 'r', encoding='utf-8') as f: 238 | for line in tqdm(f.readlines()): 239 | wav, fixed_text = line.split('|')[0], line.split('|')[1] 240 | clip_name = wav.split('_')[2].replace('c', 'clip_') 241 | fixed_text_dict[wav] = fixed_text.replace('\n', '') 242 | 243 | text = _clean_text(fixed_text, cleaners) 244 | with open( 245 | os.path.join(out_dir, sub_dir, clip_name, "{}.lab".format(wav)), 246 | "w", 247 | ) as f1: 248 | f1.write(text) 249 | 250 | filelist_fixed = open(f'{out_dir}/filelist.txt', 'w', encoding='utf-8') 251 | with open(f'{in_dir}/filelist.txt', 'r', encoding='utf-8') as filelist: 252 | for line in tqdm(filelist.readlines()): 253 | wav = line.split('|')[0] 254 | if wav in fixed_text_dict: 255 | filelist_fixed.write("|".join([line.split("|")[0]] + [fixed_text_dict[wav]] + line.split("|")[2:])) 256 | else: 257 | filelist_fixed.write(line) 258 | filelist_fixed.close() 259 | 260 | extract_lexicon(preprocess_config) -------------------------------------------------------------------------------- /preparation/iemocap.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | import yaml 4 | import os 5 | import shutil 6 | import json 7 | import librosa 8 | import soundfile 9 | from glob import glob 10 | from tqdm import tqdm 11 | from moviepy.editor import VideoFileClip 12 | from text import _clean_text 13 | from text.korean import normalize_nonchar 14 | from g2p_en import G2p 15 | 16 | 17 | def extract_nonen(preprocess_config): 18 | in_dir = preprocess_config["path"]["raw_path"] 19 | filelist = open(f'{in_dir}/nonen.txt', 'w', encoding='utf-8') 20 | 21 | count = 0 22 | nonen = set() 23 | print("Extract non english charactors...") 24 | with open(f'{in_dir}/filelist.txt', 'r', encoding='utf-8') as f: 25 | lines = f.readlines() 26 | total_count = len(lines) 27 | for 
line in tqdm(lines): 28 | wav = line.split('|')[0] 29 | text = line.split('|')[1] 30 | 31 | reg = re.compile("""[^ a-zA-Z~!.,?:`"'"“‘’”’]+""") 32 | impurities = reg.findall(text) 33 | if len(impurities) == 0: 34 | count+=1 35 | continue 36 | norm = _clean_text(text, preprocess_config["preprocessing"]["text"]["text_cleaners"]) 37 | impurities_str = ','.join(impurities) 38 | filelist.write(f'{norm}|{text}|{impurities_str}|{wav}\n') 39 | for imp in impurities: 40 | nonen.add(imp) 41 | filelist.close() 42 | print('Total {} non english charactors from {} lines'.format(len(nonen), total_count-count)) 43 | print(sorted(list(nonen))) 44 | 45 | 46 | def extract_lexicon(preprocess_config): 47 | """ 48 | Extract lexicon and build grapheme-phoneme dictionary for MFA training 49 | """ 50 | in_dir = preprocess_config["path"]["raw_path"] 51 | lexicon_path = preprocess_config["path"]["lexicon_path"] 52 | filelist = open(lexicon_path, 'a+', encoding='utf-8') 53 | 54 | # Load Lexicon Dictionary 55 | done = set() 56 | if os.path.isfile(lexicon_path): 57 | filelist.seek(0) 58 | for line in filelist.readlines(): 59 | grapheme = line.split("\t")[0] 60 | done.add(grapheme) 61 | 62 | print("Extract lexicon...") 63 | g2p = G2p() 64 | for lab in tqdm(glob(f'{in_dir}/**/*.lab', recursive=True)): 65 | with open(lab, 'r', encoding='utf-8') as f: 66 | text = f.readline().strip("\n") 67 | text = normalize_nonchar(text) 68 | 69 | for grapheme in text.split(" "): 70 | if not grapheme in done: 71 | phoneme = " ".join(g2p(grapheme)) 72 | filelist.write("{}\t{}\n".format(grapheme, phoneme)) 73 | done.add(grapheme) 74 | filelist.close() 75 | 76 | 77 | def apply_fixed_text(preprocess_config): 78 | in_dir = preprocess_config["path"]["corpus_path"] 79 | sub_dir = preprocess_config["path"]["sub_dir_name"] 80 | out_dir = preprocess_config["path"]["raw_path"] 81 | fixed_text_path = preprocess_config["path"]["fixed_text_path"] 82 | cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 83 | 84 | fixed_text_dict = dict() 85 | print("Fixing transcripts...") 86 | with open(fixed_text_path, 'r', encoding='utf-8') as f: 87 | for line in tqdm(f.readlines()): 88 | wav, fixed_text = line.split('|')[0], line.split('|')[1] 89 | session = '_'.join(wav.split('_')[1:]) 90 | fixed_text_dict[wav] = fixed_text.replace('\n', '') 91 | 92 | text = _clean_text(fixed_text, cleaners) 93 | with open( 94 | os.path.join(out_dir, sub_dir, session, "{}.lab".format(wav)), 95 | "w", 96 | ) as f1: 97 | f1.write(text) 98 | 99 | filelist_fixed = open(f'{out_dir}/filelist_fixed.txt', 'w', encoding='utf-8') 100 | with open(f'{out_dir}/filelist.txt', 'r', encoding='utf-8') as filelist: 101 | for line in tqdm(filelist.readlines()): 102 | wav = line.split('|')[0] 103 | if wav in fixed_text_dict: 104 | filelist_fixed.write("|".join([line.split("|")[0]] + [fixed_text_dict[wav]] + line.split("|")[2:])) 105 | else: 106 | filelist_fixed.write(line) 107 | filelist_fixed.close() 108 | 109 | os.remove(f'{out_dir}/filelist.txt') 110 | os.rename(f'{out_dir}/filelist_fixed.txt', f'{out_dir}/filelist.txt') 111 | 112 | extract_lexicon(preprocess_config) -------------------------------------------------------------------------------- /prepare_align.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | from preprocessor import aihub_mmv, iemocap 6 | 7 | 8 | def main(config): 9 | if "AIHub-MMV" in config["dataset"]: 10 | aihub_mmv.prepare_align(config) 11 | if "IEMOCAP" in config["dataset"]: 
12 | iemocap.prepare_align(config) 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("config", type=str, help="path to preprocess.yaml") 18 | args = parser.parse_args() 19 | 20 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 21 | main(config) 22 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import yaml 4 | 5 | from preparation import aihub_mmv, iemocap 6 | 7 | 8 | def main(args, preprocess_config): 9 | os.makedirs("./lexicon", exist_ok=True) 10 | os.makedirs("./preprocessed_data", exist_ok=True) 11 | os.makedirs("./montreal-forced-aligner", exist_ok=True) 12 | 13 | if "AIHub-MMV" in preprocess_config["dataset"]: 14 | if args.extract_nonkr: 15 | aihub_mmv.extract_nonkr(preprocess_config) 16 | elif args.extract_lexicon: 17 | aihub_mmv.extract_lexicon(preprocess_config) 18 | elif args.apply_fixed_text: 19 | aihub_mmv.apply_fixed_text(preprocess_config) 20 | else: 21 | if args.extract_audio: 22 | aihub_mmv.extract_audio(preprocess_config) 23 | aihub_mmv.create_dataset(preprocess_config) 24 | elif "IEMOCAP" in preprocess_config["dataset"]: 25 | if args.extract_nonen: 26 | iemocap.extract_nonen(preprocess_config) 27 | elif args.extract_lexicon: 28 | iemocap.extract_lexicon(preprocess_config) 29 | elif args.apply_fixed_text: 30 | iemocap.apply_fixed_text(preprocess_config) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "-p", 37 | "--preprocess_config", 38 | type=str, 39 | required=True, 40 | help="path to preprocess.yaml", 41 | ) 42 | parser.add_argument( 43 | '--extract_audio', 44 | help='convert video into .wav file', 45 | action='store_true', 46 | ) 47 | parser.add_argument( 48 | '--extract_nonkr', 49 | help='extract non korean charactor', 50 | action='store_true', 51 | ) 52 | parser.add_argument( 53 | '--extract_nonen', 54 | help='extract non english charactor', 55 | action='store_true', 56 | ) 57 | parser.add_argument( 58 | '--extract_lexicon', 59 | help='extract lexicon and build grapheme-phoneme dictionary', 60 | action='store_true', 61 | ) 62 | parser.add_argument( 63 | '--apply_fixed_text', 64 | help='apply fixed text to both raw data and filelist', 65 | action='store_true', 66 | ) 67 | args = parser.parse_args() 68 | 69 | preprocess_config = yaml.load( 70 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 71 | ) 72 | 73 | main(args, preprocess_config) 74 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | from preprocessor.preprocessor import Preprocessor 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("config", type=str, help="path to preprocess.yaml") 11 | args = parser.parse_args() 12 | 13 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 14 | preprocessor = Preprocessor(config) 15 | preprocessor.build_from_path() 16 | -------------------------------------------------------------------------------- /preprocessor/aihub_mmv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import 
tqdm 7 | from shutil import copyfile 8 | 9 | from text import _clean_text 10 | 11 | 12 | def prepare_align(config): 13 | in_dir = config["path"]["corpus_path"] 14 | sub_dir = config["path"]["sub_dir_name"] 15 | out_dir = config["path"]["raw_path"] 16 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 17 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 18 | fixed_text_path = config["path"]["fixed_text_path"] 19 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 20 | 21 | fixed_text_dict = dict() 22 | with open(fixed_text_path, 'r', encoding='utf-8') as f: 23 | for line in tqdm(f.readlines()): 24 | wav, fixed_text = line.split('|')[0], line.split('|')[1] 25 | fixed_text_dict[wav] = fixed_text.replace('\n', '') 26 | 27 | for sep_dir in tqdm(next(os.walk(in_dir))[1]): 28 | for clip_name in os.listdir(os.path.join(in_dir, sep_dir)): 29 | for file_name in os.listdir(os.path.join(in_dir, sep_dir, clip_name)): 30 | if file_name[-4:] != ".wav": 31 | continue 32 | base_name = file_name[:-4] 33 | text_path = os.path.join( 34 | in_dir, sep_dir, clip_name, "{}.txt".format(base_name) 35 | ) 36 | wav_path = os.path.join( 37 | in_dir, sep_dir, clip_name, "{}.wav".format(base_name) 38 | ) 39 | if base_name in fixed_text_dict: 40 | text = fixed_text_dict[base_name] 41 | else: 42 | with open(text_path) as f: 43 | text = f.readline().strip("\n") 44 | text = _clean_text(text, cleaners) 45 | 46 | os.makedirs(os.path.join(out_dir, sub_dir, clip_name), exist_ok=True) 47 | wav, _ = librosa.load(wav_path, sampling_rate) 48 | wav = wav / max(abs(wav)) * max_wav_value 49 | wavfile.write( 50 | os.path.join(out_dir, sub_dir, clip_name, "{}.wav".format(base_name)), 51 | sampling_rate, 52 | wav.astype(np.int16), 53 | ) 54 | with open( 55 | os.path.join(out_dir, sub_dir, clip_name, "{}.lab".format(base_name)), 56 | "w", 57 | ) as f1: 58 | f1.write(text) 59 | 60 | # Filelist 61 | filelist_fixed = open(f'{out_dir}/filelist.txt', 'w', encoding='utf-8') 62 | with open(f'{in_dir}/filelist.txt', 'r', encoding='utf-8') as filelist: 63 | for line in tqdm(filelist.readlines()): 64 | wav = line.split('|')[0] 65 | if wav in fixed_text_dict: 66 | filelist_fixed.write("|".join([line.split("|")[0]] + [fixed_text_dict[wav]] + line.split("|")[2:])) 67 | else: 68 | filelist_fixed.write(line) 69 | filelist_fixed.close() 70 | 71 | # Speaker Info 72 | copyfile(f'{in_dir}/speaker_info.txt', f'{out_dir}/speaker_info.txt') -------------------------------------------------------------------------------- /preprocessor/iemocap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | 8 | from text import _clean_text 9 | 10 | _square_brackets_re = re.compile(r"\[[\w\d\s]+\]") 11 | _inv_square_brackets_re = re.compile(r"(.*?)\](.+?)\[(.*)") 12 | 13 | 14 | def get_sorted_items(items): 15 | # sort by key 16 | return sorted(items, key=lambda x:x[0]) 17 | 18 | 19 | def prepare_align(config): 20 | in_dir = config["path"]["corpus_path"] 21 | sub_dir = config["path"]["sub_dir_name"] 22 | out_dir = config["path"]["raw_path"] 23 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 24 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 25 | fixed_text_path = config["path"]["fixed_text_path"] 26 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 27 | 28 | os.makedirs(os.path.join(out_dir), exist_ok=True) 29 | 
filelist_fixed = open(f'{out_dir}/filelist.txt', 'w', encoding='utf-8') 30 | speaker_info, speaker_done = dict(), set() 31 | 32 | fixed_text_dict = dict() 33 | with open(fixed_text_path, 'r', encoding='utf-8') as f: 34 | for line in tqdm(f.readlines()): 35 | wav, fixed_text = line.split('|')[0], line.split('|')[1] 36 | fixed_text_dict[wav] = fixed_text.replace('\n', '') 37 | 38 | for sep_dir in tqdm(next(os.walk(in_dir))[1]): 39 | if sub_dir[:-1] not in sep_dir.lower(): 40 | continue 41 | for wav_dir in tqdm((next(os.walk(os.path.join(in_dir, sep_dir, "sentences", "wav")))[1])): 42 | 43 | # Build Text Dict 44 | text_dict = dict() 45 | text_raw_path = os.path.join( 46 | in_dir, sep_dir, "dialog", "transcriptions", "{}.txt".format(wav_dir) 47 | ) 48 | with open(text_raw_path) as f: 49 | for line in f.readlines(): 50 | base_name = line.split("[")[0].strip() 51 | transcript = line.split("]:")[-1].strip() 52 | text_dict[base_name] = transcript 53 | 54 | # Build Emotion Dict 55 | emo_dict = dict() 56 | emo_raw_path = os.path.join( 57 | in_dir, sep_dir, "dialog", "EmoEvaluation", "{}.txt".format(wav_dir) 58 | ) 59 | with open(emo_raw_path) as f: 60 | for line in f.readlines()[1:]: 61 | if "[" not in line or "%" in line: 62 | continue 63 | m = _inv_square_brackets_re.match(" ".join(line.split())) 64 | base_name, emo_gt = m.group(2).strip().split(" ") 65 | valence, arousal = m.group(3).split(",")[0].strip(), m.group(3).split(",")[1].strip() 66 | emo_dict[base_name] = { 67 | "e": emo_gt, 68 | "a": arousal, 69 | "v": valence, 70 | } 71 | 72 | for file_name in os.listdir(os.path.join(in_dir, sep_dir, "sentences", "wav", wav_dir)): 73 | if file_name[0] == "." or file_name[-4:] != ".wav": 74 | continue 75 | base_name = file_name[:-4] 76 | if len(base_name.split("_")) == 3: 77 | spk_id, dialog_type, turn = base_name.split("_") 78 | elif len(base_name.split("_")) == 4: 79 | spk_id, dialog_type, turn = base_name.split("_")[0], "_".join(base_name.split("_")[1:3]), base_name.split("_")[3] 80 | base_name_new = "_".join([turn, spk_id, dialog_type]) 81 | 82 | if spk_id not in speaker_done: 83 | speaker_info[spk_id] = { 84 | 'gender': spk_id[-1] 85 | } 86 | speaker_done.add(spk_id) 87 | 88 | wav_path = os.path.join( 89 | in_dir, sep_dir, "sentences", "wav", wav_dir, "{}.wav".format(base_name) 90 | ) 91 | if base_name in fixed_text_dict: 92 | text = fixed_text_dict[base_name] 93 | else: 94 | text = text_dict[base_name] 95 | text = re.sub(_square_brackets_re, "", text) 96 | text = ' '.join(text.split()) 97 | text = _clean_text(text, cleaners) 98 | 99 | os.makedirs(os.path.join(out_dir, sub_dir, wav_dir), exist_ok=True) 100 | wav, _ = librosa.load(wav_path, sampling_rate) 101 | wav = wav / max(abs(wav)) * max_wav_value 102 | wavfile.write( 103 | os.path.join(out_dir, sub_dir, wav_dir, "{}.wav".format(base_name_new)), 104 | sampling_rate, 105 | wav.astype(np.int16), 106 | ) 107 | with open( 108 | os.path.join(out_dir, sub_dir, wav_dir, "{}.lab".format(base_name_new)), 109 | "w", 110 | ) as f1: 111 | f1.write(text) 112 | 113 | # Filelist 114 | emo_ = emo_dict[base_name] 115 | emotion, arousal, valence = emo_["e"], emo_["a"], emo_["v"] 116 | filelist_fixed.write("|".join([base_name_new, text, spk_id, emotion, arousal, valence]) + "\n") 117 | filelist_fixed.close() 118 | 119 | # Save Speaker Info 120 | with open(f'{out_dir}/speaker_info.txt', 'w', encoding='utf-8') as f: 121 | for spk_id, spk_info in get_sorted_items(speaker_info.items()): 122 | gender = spk_info['gender'] 123 | f.write(f'{spk_id}|{gender}\n') 
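Editor's note (illustrative sketch, not part of the original sources): prepare_align above writes one pipe-separated row per utterance, "basename|text|speaker|emotion|arousal|valence", and the Preprocessor class that follows keeps column 0 as the basename and columns 3 onward as the emotion, arousal, and valence labels when building its lookup dictionaries. The example row below is hypothetical; the slicing mirrors load_filelist_dict in preprocessor/preprocessor.py.

row = "F000_Ses01F_impro01|I did.|Ses01F|neu|2.5|2.5\n"   # hypothetical example row
fields = row.split("|")
basename, aux_data = fields[0], fields[3:]
emotion, arousal, valence = aux_data[-3], aux_data[-2], aux_data[-1].strip("\n")
# -> basename "F000_Ses01F_impro01" with labels ("neu", "2.5", "2.5")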
-------------------------------------------------------------------------------- /preprocessor/preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | 5 | import tgt 6 | import librosa 7 | import numpy as np 8 | import pyworld as pw 9 | from scipy.interpolate import interp1d 10 | from sklearn.preprocessing import StandardScaler 11 | from tqdm import tqdm 12 | 13 | import audio as Audio 14 | 15 | random.seed(1234) 16 | 17 | class Preprocessor: 18 | def __init__(self, config): 19 | self.config = config 20 | self.dataset = config["dataset"] 21 | self.sub_dir = "" 22 | self.speakers = dict() 23 | self.emotions = dict() 24 | self.sub_dir = config["path"]["sub_dir_name"] 25 | self.speakers = self.load_speaker_dict() 26 | self.filelist, self.emotions = self.load_filelist_dict() 27 | self.in_dir = os.path.join(config["path"]["raw_path"], self.sub_dir) 28 | self.out_dir = config["path"]["preprocessed_path"] 29 | self.val_size = config["preprocessing"]["val_size"] 30 | self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 31 | self.hop_length = config["preprocessing"]["stft"]["hop_length"] 32 | 33 | assert config["preprocessing"]["pitch"]["feature"] in [ 34 | "phoneme_level", 35 | "frame_level", 36 | ] 37 | assert config["preprocessing"]["energy"]["feature"] in [ 38 | "phoneme_level", 39 | "frame_level", 40 | ] 41 | self.pitch_phoneme_averaging = ( 42 | config["preprocessing"]["pitch"]["feature"] == "phoneme_level" 43 | ) 44 | self.energy_phoneme_averaging = ( 45 | config["preprocessing"]["energy"]["feature"] == "phoneme_level" 46 | ) 47 | 48 | self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"] 49 | self.energy_normalization = config["preprocessing"]["energy"]["normalization"] 50 | 51 | self.STFT = Audio.stft.TacotronSTFT( 52 | config["preprocessing"]["stft"]["filter_length"], 53 | config["preprocessing"]["stft"]["hop_length"], 54 | config["preprocessing"]["stft"]["win_length"], 55 | config["preprocessing"]["mel"]["n_mel_channels"], 56 | config["preprocessing"]["audio"]["sampling_rate"], 57 | config["preprocessing"]["mel"]["mel_fmin"], 58 | config["preprocessing"]["mel"]["mel_fmax"], 59 | ) 60 | 61 | def load_speaker_dict(self): 62 | spk_dir = os.path.join(self.config["path"]["raw_path"], 'speaker_info.txt') 63 | spk_dict = dict() 64 | with open(spk_dir, 'r', encoding='utf-8') as f: 65 | for i, line in enumerate(f.readlines()): 66 | spk_id = line.split("|")[0] 67 | spk_dict[spk_id] = i 68 | return spk_dict 69 | 70 | def load_filelist_dict(self): 71 | filelist_dir = os.path.join(self.config["path"]["raw_path"], 'filelist.txt') 72 | filelist_dict, emotion_dict, arousal_dict, valence_dict = dict(), dict(), dict(), dict() 73 | emotions, arousals, valences = set(), set(), set() 74 | with open(filelist_dir, 'r', encoding='utf-8') as f: 75 | for i, line in enumerate(f.readlines()): 76 | basename, aux_data = line.split("|")[0], line.split("|")[3:] 77 | filelist_dict[basename] = "|".join(aux_data).strip("\n") 78 | emotions.add(aux_data[-3]) 79 | arousals.add(aux_data[-2]) 80 | valences.add(aux_data[-1].strip("\n")) 81 | for i, emotion in enumerate(list(emotions)): 82 | emotion_dict[emotion] = i 83 | for i, arousal in enumerate(list(arousals)): 84 | arousal_dict[arousal] = i 85 | for i, valence in enumerate(list(valences)): 86 | valence_dict[valence] = i 87 | emotion_dict = { 88 | "emotion_dict": emotion_dict, 89 | "arousal_dict": arousal_dict, 90 | "valence_dict": 
valence_dict, 91 | } 92 | return filelist_dict, emotion_dict 93 | 94 | def build_from_path(self): 95 | os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True) 96 | os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True) 97 | os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True) 98 | os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True) 99 | 100 | print("Processing Data ...") 101 | out = list() 102 | n_frames = 0 103 | pitch_scaler = StandardScaler() 104 | energy_scaler = StandardScaler() 105 | 106 | # Compute pitch, energy, duration, and mel-spectrogram 107 | speakers = self.speakers.copy() 108 | for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))): 109 | if len(self.speakers) == 0: 110 | speakers[speaker] = i 111 | for wav_name in os.listdir(os.path.join(self.in_dir, speaker)): 112 | if ".wav" not in wav_name: 113 | continue 114 | 115 | basename = wav_name.split(".")[0] 116 | tg_path = os.path.join( 117 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename) 118 | ) 119 | if os.path.exists(tg_path): 120 | ret = self.process_utterance(speaker, basename) 121 | if ret is None: 122 | continue 123 | else: 124 | info, pitch, energy, n = ret 125 | out.append(info) 126 | 127 | if len(pitch) > 0: 128 | pitch_scaler.partial_fit(pitch.reshape((-1, 1))) 129 | if len(energy) > 0: 130 | energy_scaler.partial_fit(energy.reshape((-1, 1))) 131 | 132 | n_frames += n 133 | 134 | print("Computing statistic quantities ...") 135 | # Perform normalization if necessary 136 | if self.pitch_normalization: 137 | pitch_mean = pitch_scaler.mean_[0] 138 | pitch_std = pitch_scaler.scale_[0] 139 | else: 140 | # A numerical trick to avoid normalization... 141 | pitch_mean = 0 142 | pitch_std = 1 143 | if self.energy_normalization: 144 | energy_mean = energy_scaler.mean_[0] 145 | energy_std = energy_scaler.scale_[0] 146 | else: 147 | energy_mean = 0 148 | energy_std = 1 149 | 150 | pitch_min, pitch_max = self.normalize( 151 | os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std 152 | ) 153 | energy_min, energy_max = self.normalize( 154 | os.path.join(self.out_dir, "energy"), energy_mean, energy_std 155 | ) 156 | 157 | # Save files 158 | with open(os.path.join(self.out_dir, "speakers.json"), "w") as f: 159 | f.write(json.dumps(speakers)) 160 | 161 | if len(self.emotions) != 0: 162 | with open(os.path.join(self.out_dir, "emotions.json"), "w") as f: 163 | f.write(json.dumps(self.emotions)) 164 | 165 | with open(os.path.join(self.out_dir, "stats.json"), "w") as f: 166 | stats = { 167 | "pitch": [ 168 | float(pitch_min), 169 | float(pitch_max), 170 | float(pitch_mean), 171 | float(pitch_std), 172 | ], 173 | "energy": [ 174 | float(energy_min), 175 | float(energy_max), 176 | float(energy_mean), 177 | float(energy_std), 178 | ], 179 | } 180 | f.write(json.dumps(stats)) 181 | 182 | print( 183 | "Total time: {} hours".format( 184 | n_frames * self.hop_length / self.sampling_rate / 3600 185 | ) 186 | ) 187 | 188 | random.shuffle(out) 189 | out = [r for r in out if r is not None] 190 | 191 | # Write metadata 192 | with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f: 193 | for m in out[self.val_size :]: 194 | f.write(m + "\n") 195 | with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f: 196 | for m in out[: self.val_size]: 197 | f.write(m + "\n") 198 | 199 | return out 200 | 201 | def process_utterance(self, speaker, basename): 202 | aux_data = "" 203 | wav_path = os.path.join(self.in_dir, speaker, 
"{}.wav".format(basename)) 204 | text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename)) 205 | tg_path = os.path.join( 206 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename) 207 | ) 208 | speaker = basename.split("_")[1] 209 | aux_data = self.filelist[basename] 210 | 211 | # Get alignments 212 | textgrid = tgt.io.read_textgrid(tg_path) 213 | phone, duration, start, end = self.get_alignment( 214 | textgrid.get_tier_by_name("phones") 215 | ) 216 | text = "{" + " ".join(phone) + "}" 217 | if start >= end: 218 | return None 219 | 220 | # Read and trim wav files 221 | wav, _ = librosa.load(wav_path) 222 | wav = wav[ 223 | int(self.sampling_rate * start) : int(self.sampling_rate * end) 224 | ].astype(np.float32) 225 | 226 | # Read raw text 227 | with open(text_path, "r") as f: 228 | raw_text = f.readline().strip("\n") 229 | 230 | # Compute fundamental frequency 231 | pitch, t = pw.dio( 232 | wav.astype(np.float64), 233 | self.sampling_rate, 234 | frame_period=self.hop_length / self.sampling_rate * 1000, 235 | ) 236 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate) 237 | 238 | pitch = pitch[: sum(duration)] 239 | if np.sum(pitch != 0) <= 1: 240 | return None 241 | 242 | # Compute mel-scale spectrogram and energy 243 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT) 244 | mel_spectrogram = mel_spectrogram[:, : sum(duration)] 245 | energy = energy[: sum(duration)] 246 | 247 | if self.pitch_phoneme_averaging: 248 | # perform linear interpolation 249 | nonzero_ids = np.where(pitch != 0)[0] 250 | interp_fn = interp1d( 251 | nonzero_ids, 252 | pitch[nonzero_ids], 253 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), 254 | bounds_error=False, 255 | ) 256 | pitch = interp_fn(np.arange(0, len(pitch))) 257 | 258 | # Phoneme-level average 259 | pos = 0 260 | for i, d in enumerate(duration): 261 | if d > 0: 262 | pitch[i] = np.mean(pitch[pos : pos + d]) 263 | else: 264 | pitch[i] = 0 265 | pos += d 266 | pitch = pitch[: len(duration)] 267 | 268 | if self.energy_phoneme_averaging: 269 | # Phoneme-level average 270 | pos = 0 271 | for i, d in enumerate(duration): 272 | if d > 0: 273 | energy[i] = np.mean(energy[pos : pos + d]) 274 | else: 275 | energy[i] = 0 276 | pos += d 277 | energy = energy[: len(duration)] 278 | 279 | # Save files 280 | dur_filename = "{}-duration-{}.npy".format(speaker, basename) 281 | np.save(os.path.join(self.out_dir, "duration", dur_filename), duration) 282 | 283 | pitch_filename = "{}-pitch-{}.npy".format(speaker, basename) 284 | np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch) 285 | 286 | energy_filename = "{}-energy-{}.npy".format(speaker, basename) 287 | np.save(os.path.join(self.out_dir, "energy", energy_filename), energy) 288 | 289 | mel_filename = "{}-mel-{}.npy".format(speaker, basename) 290 | np.save( 291 | os.path.join(self.out_dir, "mel", mel_filename), 292 | mel_spectrogram.T, 293 | ) 294 | 295 | return ( 296 | "|".join([basename, speaker, text, raw_text, aux_data]), 297 | self.remove_outlier(pitch), 298 | self.remove_outlier(energy), 299 | mel_spectrogram.shape[1], 300 | ) 301 | 302 | def get_alignment(self, tier): 303 | sil_phones = ["sil", "sp", "spn"] 304 | 305 | phones = [] 306 | durations = [] 307 | start_time = 0 308 | end_time = 0 309 | end_idx = 0 310 | for t in tier._objects: 311 | s, e, p = t.start_time, t.end_time, t.text 312 | 313 | # Trim leading silences 314 | if phones == []: 315 | if p in sil_phones: 316 | continue 317 | else: 318 | 
start_time = s 319 | 320 | if p not in sil_phones: 321 | # For ordinary phones 322 | phones.append(p) 323 | end_time = e 324 | end_idx = len(phones) 325 | else: 326 | # For silent phones 327 | phones.append(p) 328 | 329 | durations.append( 330 | int( 331 | np.round(e * self.sampling_rate / self.hop_length) 332 | - np.round(s * self.sampling_rate / self.hop_length) 333 | ) 334 | ) 335 | 336 | # Trim tailing silences 337 | phones = phones[:end_idx] 338 | durations = durations[:end_idx] 339 | 340 | return phones, durations, start_time, end_time 341 | 342 | def remove_outlier(self, values): 343 | values = np.array(values) 344 | p25 = np.percentile(values, 25) 345 | p75 = np.percentile(values, 75) 346 | lower = p25 - 1.5 * (p75 - p25) 347 | upper = p75 + 1.5 * (p75 - p25) 348 | normal_indices = np.logical_and(values > lower, values < upper) 349 | 350 | return values[normal_indices] 351 | 352 | def normalize(self, in_dir, mean, std): 353 | max_value = np.finfo(np.float64).min 354 | min_value = np.finfo(np.float64).max 355 | for filename in os.listdir(in_dir): 356 | filename = os.path.join(in_dir, filename) 357 | values = (np.load(filename) - mean) / std 358 | np.save(filename, values) 359 | 360 | max_value = max(max_value, max(values)) 361 | min_value = min(min_value, min(values)) 362 | 363 | return min_value, max_value 364 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python==3.6.3 2 | g2pk==0.9.4 3 | torch==1.7.0 4 | numpy==1.19.5 5 | tgt==1.4.3 6 | scipy==1.5.0 7 | pyworld==0.2.10 8 | librosa==0.7.2 9 | numba==0.53.1 10 | matplotlib==3.2.2 11 | unidecode==1.1.1 12 | inflect==4.1.0 13 | g2p-en==2.1.0 14 | tensorboard==2.4.1 15 | python_speech_features==0.6 16 | pandas==1.1.3 17 | pysptk==0.1.18 18 | tensorflow==2.4.0 19 | PyYAML==5.4.1 20 | tqdm==4.46.1 21 | moviepy==1.0.3 22 | quickspacer==1.0.4 -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | from string import punctuation 4 | import os 5 | import json 6 | 7 | import torch 8 | import yaml 9 | import numpy as np 10 | from torch.utils.data import DataLoader 11 | from g2p_en import G2p 12 | 13 | from utils.model import get_model, get_vocoder 14 | from utils.tools import to_device, synth_samples 15 | from dataset import TextDataset 16 | from text import text_to_sequence 17 | from text.korean import tokenize, normalize_nonchar 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | def read_lexicon(lex_path): 23 | lexicon = {} 24 | with open(lex_path) as f: 25 | for line in f: 26 | temp = re.split(r"\s+", line.strip("\n")) 27 | word = temp[0] 28 | phones = temp[1:] 29 | if word not in lexicon: 30 | lexicon[word] = phones 31 | return lexicon 32 | 33 | 34 | def preprocess_korean(text, preprocess_config): 35 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"]) 36 | 37 | phones = [] 38 | words = filter(None, re.split(r"([,;.\-\?\!\s+])", text)) 39 | for w in words: 40 | if w in lexicon: 41 | phones += lexicon[w] 42 | else: 43 | phones += list(filter(lambda p: p != " ", tokenize(w, norm=False))) 44 | phones = "{" + "}{".join(phones) + "}" 45 | phones = normalize_nonchar(phones, inference=True) 46 | phones = phones.replace("}{", " ") 47 | 48 | print("Raw Text Sequence: {}".format(text)) 49 | 
print("Phoneme Sequence: {}".format(phones)) 50 | sequence = np.array( 51 | text_to_sequence( 52 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"] 53 | ) 54 | ) 55 | 56 | return np.array(sequence) 57 | 58 | 59 | def preprocess_english(text, preprocess_config): 60 | text = text.rstrip(punctuation) 61 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"]) 62 | 63 | g2p = G2p() 64 | phones = [] 65 | words = filter(None, re.split(r"([,;.\-\?\!\s+])", text)) 66 | for w in words: 67 | if w.lower() in lexicon: 68 | phones += lexicon[w.lower()] 69 | else: 70 | phones += list(filter(lambda p: p != " ", g2p(w))) 71 | phones = "{" + "}{".join(phones) + "}" 72 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones) 73 | phones = phones.replace("}{", " ") 74 | 75 | print("Raw Text Sequence: {}".format(text)) 76 | print("Phoneme Sequence: {}".format(phones)) 77 | sequence = np.array( 78 | text_to_sequence( 79 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"] 80 | ) 81 | ) 82 | 83 | return np.array(sequence) 84 | 85 | 86 | def synthesize(model, step, configs, vocoder, batchs, control_values, tag): 87 | preprocess_config, model_config, train_config = configs 88 | pitch_control, energy_control, duration_control = control_values 89 | 90 | for batch in batchs: 91 | batch = to_device(batch, device) 92 | with torch.no_grad(): 93 | # Forward 94 | output = model( 95 | *(batch[2:]), 96 | p_control=pitch_control, 97 | e_control=energy_control, 98 | d_control=duration_control 99 | ) 100 | synth_samples( 101 | batch, 102 | output, 103 | vocoder, 104 | model_config, 105 | preprocess_config, 106 | train_config["path"]["result_path"], 107 | tag, 108 | ) 109 | 110 | 111 | if __name__ == "__main__": 112 | 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument("--restore_step", type=int, required=True) 115 | parser.add_argument( 116 | "--mode", 117 | type=str, 118 | choices=["batch", "single"], 119 | required=True, 120 | help="Synthesize a whole dataset or a single sentence", 121 | ) 122 | parser.add_argument( 123 | "--source", 124 | type=str, 125 | default=None, 126 | help="path to a source file with format like train.txt and val.txt, for batch mode only", 127 | ) 128 | parser.add_argument( 129 | "--text", 130 | type=str, 131 | default=None, 132 | help="raw text to synthesize, for single-sentence mode only", 133 | ) 134 | parser.add_argument( 135 | "--speaker_id", 136 | type=str, 137 | default="p001", 138 | help="speaker ID for multi-speaker synthesis, for single-sentence mode only", 139 | ) 140 | parser.add_argument( 141 | "--emotion_id", 142 | type=str, 143 | default="happy", 144 | help="emotion ID for multi-emotion synthesis, for single-sentence mode only", 145 | ) 146 | parser.add_argument( 147 | "--arousal", 148 | type=str, 149 | default="3", 150 | help="arousal value for multi-emotion synthesis, for single-sentence mode only", 151 | ) 152 | parser.add_argument( 153 | "--valence", 154 | type=str, 155 | default="3", 156 | help="valence value for multi-emotion synthesis, for single-sentence mode only", 157 | ) 158 | parser.add_argument( 159 | "-p", 160 | "--preprocess_config", 161 | type=str, 162 | required=True, 163 | help="path to preprocess.yaml", 164 | ) 165 | parser.add_argument( 166 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 167 | ) 168 | parser.add_argument( 169 | "-t", "--train_config", type=str, required=True, help="path to train.yaml" 170 | ) 171 | parser.add_argument( 172 | "--pitch_control", 173 | type=float, 174 
| default=1.0, 175 | help="control the pitch of the whole utterance, larger value for higher pitch", 176 | ) 177 | parser.add_argument( 178 | "--energy_control", 179 | type=float, 180 | default=1.0, 181 | help="control the energy of the whole utterance, larger value for larger volume", 182 | ) 183 | parser.add_argument( 184 | "--duration_control", 185 | type=float, 186 | default=1.0, 187 | help="control the speed of the whole utterance, larger value for slower speaking rate", 188 | ) 189 | args = parser.parse_args() 190 | 191 | # Check source texts 192 | if args.mode == "batch": 193 | assert args.source is not None and args.text is None 194 | if args.mode == "single": 195 | assert args.source is None and args.text is not None 196 | 197 | # Read Config 198 | preprocess_config = yaml.load( 199 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 200 | ) 201 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 202 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 203 | configs = (preprocess_config, model_config, train_config) 204 | 205 | # Get model 206 | model = get_model(args, configs, device, train=False) 207 | 208 | # Load vocoder 209 | vocoder = get_vocoder(model_config, device) 210 | 211 | # Preprocess texts 212 | if args.mode == "batch": 213 | # Get dataset 214 | dataset = TextDataset(args.source, preprocess_config, model_config) 215 | batchs = DataLoader( 216 | dataset, 217 | batch_size=8, 218 | collate_fn=dataset.collate_fn, 219 | ) 220 | tag = None 221 | if args.mode == "single": 222 | emotions = arousals = valences = None 223 | ids = raw_texts = [args.text[:100]] 224 | with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "speakers.json")) as f: 225 | speaker_map = json.load(f) 226 | speakers = np.array([speaker_map[args.speaker_id]]) 227 | if model_config["multi_emotion"]: 228 | with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "emotions.json")) as f: 229 | json_raw = json.load(f) 230 | emotion_map = json_raw["emotion_dict"] 231 | arousal_map = json_raw["arousal_dict"] 232 | valence_map = json_raw["valence_dict"] 233 | emotions = np.array([emotion_map[args.emotion_id]]) 234 | arousals = np.array([arousal_map[args.arousal]]) 235 | valences = np.array([valence_map[args.valence]]) 236 | if preprocess_config["preprocessing"]["text"]["language"] == "kr": 237 | texts = np.array([preprocess_korean(args.text, preprocess_config)]) 238 | elif preprocess_config["preprocessing"]["text"]["language"] == "en": 239 | texts = np.array([preprocess_english(args.text, preprocess_config)]) 240 | text_lens = np.array([len(texts[0])]) 241 | batchs = [(ids, raw_texts, speakers, emotions, arousals, valences, texts, text_lens, max(text_lens))] 242 | tag = f"{args.speaker_id}_{args.emotion_id}" 243 | 244 | control_values = args.pitch_control, args.energy_control, args.duration_control 245 | 246 | synthesize(model, args.restore_step, configs, vocoder, batchs, control_values, tag) -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | from .korean_dict import char_to_id, id_to_char 6 | 7 | # Regular expression matching text enclosed in curly braces: 8 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 9 | 10 | 11 | def text_to_sequence(text, 
cleaner_names): 12 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 13 | 14 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 15 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 16 | 17 | Args: 18 | text: string to convert to a sequence 19 | cleaner_names: names of the cleaner functions to run the text through 20 | 21 | Returns: 22 | List of integers corresponding to the symbols in the text 23 | """ 24 | sequence = [] 25 | 26 | # Mappings from symbol to numeric ID and vice versa: 27 | _language, _symbol_to_id, _ = ("kr", char_to_id, id_to_char) if "korean_cleaners" in cleaner_names\ 28 | else ("en", {s: i for i, s in enumerate(symbols)}, {i: s for i, s in enumerate(symbols)}) 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | 34 | if not m: 35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names), _symbol_to_id) 36 | break 37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names), _symbol_to_id) 38 | sequence += _arpabet_to_sequence(m.group(2), _language, _symbol_to_id) 39 | text = m.group(3) 40 | 41 | return sequence 42 | 43 | 44 | def _clean_text(text, cleaner_names): 45 | for name in cleaner_names: 46 | cleaner = getattr(cleaners, name) 47 | if not cleaner: 48 | raise Exception("Unknown cleaner: %s" % name) 49 | text = cleaner(text) 50 | return text 51 | 52 | 53 | def _symbols_to_sequence(symbols, _symbol_to_id): 54 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s, _symbol_to_id)] 55 | 56 | 57 | def _arpabet_to_sequence(text, _language, _symbol_to_id): 58 | if _language == "kr": 59 | return _symbols_to_sequence([s for s in text.split()], _symbol_to_id) 60 | return _symbols_to_sequence(["@" + s for s in text.split()], _symbol_to_id) 61 | 62 | 63 | def _should_keep_symbol(s, _symbol_to_id): 64 | return s in _symbol_to_id and s != "_" and s != "~" 65 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | from .korean import normalize 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including number and abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_numbers(text) 88 | text = expand_abbreviations(text) 89 | text = collapse_whitespace(text) 90 | return text 91 | 92 | 93 | def korean_cleaners(text): 94 | '''Pipeline for Korean (Hangul) text, including number and abbreviation expansion.''' 95 | text = normalize(text) 96 | text = collapse_whitespace(text) 97 | return text -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /text/korean.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Code based on 3 | 4 | import re 5 | import os 6 | import ast 7 | import json 8 | from quickspacer import Spacer 9 | from g2pk import G2p 10 | from jamo import hangul_to_jamo, h2j, j2h 11 | from jamo.jamo import _jamo_char_to_hcj 12 | from .korean_dict import JAMO_LEADS, JAMO_VOWELS, JAMO_TAILS, ko_dict 13 | 14 | g2p=G2p() 15 | spacer = Spacer(level=3) 16 | 17 | 18 | def tokenize(text, norm=True): 19 | """ 20 | Input --- Grapheme in string 21 | Output --- Phoneme in list 22 | 23 | Example: 24 | '한글은 위대하다.' --> ['ᄒ', 'ᅡ', 'ᆫ', 'ᄀ', 'ᅳ', 'ᄅ', 'ᅳ', ' ', 'ᄂ', 'ᅱ', 'ᄃ', 'ᅢ', 'ᄒ', 'ᅡ', 'ᄃ', 'ᅡ', '.'] 25 | """ 26 | if norm: 27 | text = normalize(text) 28 | text = g2p(text) 29 | tokens = list(hangul_to_jamo(text)) 30 | 31 | return tokens 32 | 33 | 34 | def detokenize(tokens): 35 | """s 36 | Input --- Grapheme or Phoneme in list 37 | Output --- Grapheme or Phoneme in string 38 | 39 | Example: 40 | ['ᄒ', 'ᅡ', 'ᆫ', 'ᄀ', 'ᅳ', 'ᆯ', 'ᄋ', 'ᅳ', 'ᆫ', ' ', 'ᄋ', 'ᅱ', 'ᄃ', 'ᅢ', 'ᄒ', 'ᅡ', 'ᄃ', 'ᅡ', '.'] --> '한글은 위대하다.' 41 | ['ᄒ', 'ᅡ', 'ᆫ', 'ᄀ', 'ᅳ', 'ᄅ', 'ᅳ', ' ', 'ᄂ', 'ᅱ', 'ᄃ', 'ᅢ', 'ᄒ', 'ᅡ', 'ᄃ', 'ᅡ', '.'] --> '한그르 뉘대하다.' 
42 | """ 43 | tokens = h2j(tokens) 44 | 45 | idx = 0 46 | text = "" 47 | candidates = [] 48 | 49 | while True: 50 | if idx >= len(tokens): 51 | text += _get_text_from_candidates(candidates) 52 | break 53 | 54 | char = tokens[idx] 55 | mode = _get_mode(char) 56 | 57 | if mode == 0: 58 | text += _get_text_from_candidates(candidates) 59 | candidates = [char] 60 | elif mode == -1: 61 | text += _get_text_from_candidates(candidates) 62 | text += char 63 | candidates = [] 64 | else: 65 | candidates.append(char) 66 | 67 | idx += 1 68 | return text 69 | 70 | 71 | def _get_mode(char): 72 | if char in JAMO_LEADS: 73 | return 0 74 | elif char in JAMO_VOWELS: 75 | return 1 76 | elif char in JAMO_TAILS: 77 | return 2 78 | else: 79 | return -1 80 | 81 | 82 | def _get_text_from_candidates(candidates): 83 | if len(candidates) == 0: 84 | return "" 85 | elif len(candidates) == 1: 86 | return _jamo_char_to_hcj(candidates[0]) 87 | else: 88 | return j2h(**dict(zip(["lead", "vowel", "tail"], candidates))) 89 | 90 | 91 | def compare_sentence_with_jamo(text1, text2): 92 | return h2j(text1) != h2j(text2) 93 | 94 | 95 | def normalize(text): 96 | """ 97 | Transliterate input text into Hangul grapheme. 98 | """ 99 | text = text.strip() 100 | text = re.sub('\(\d+일\)', '', text) 101 | text = re.sub('\([⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]+\)', '', text) 102 | 103 | text = normalize_with_dictionary(text, ko_dict["etc_dictionary"]) 104 | text = normalize_english(text) 105 | text = re.sub('[a-zA-Z]+', normalize_upper, text) 106 | 107 | text = normalize_quote(text) 108 | text = normalize_number(text) 109 | text = normalize_nonchar(text) 110 | text = spacer.space([text])[0] 111 | 112 | return text 113 | 114 | 115 | def normalize_nonchar(text, inference=False): 116 | return re.sub(r"\{[^\w\s]?\}", "{sp}", text) if inference else\ 117 | re.sub(r"[^\w\s]?", "", text) 118 | 119 | 120 | def normalize_with_dictionary(text, dic): 121 | if any(key in text for key in dic.keys()): 122 | pattern = re.compile('|'.join(re.escape(key) for key in dic.keys())) 123 | return pattern.sub(lambda x: dic[x.group()], text) 124 | else: 125 | return text 126 | 127 | 128 | def normalize_english(text): 129 | def fn(m): 130 | word = m.group() 131 | if word in ko_dict["english_dictionary"]: 132 | return ko_dict["english_dictionary"].get(word) 133 | else: 134 | return word 135 | 136 | text = re.sub("([A-Za-z]+)", fn, text) 137 | return text 138 | 139 | 140 | def normalize_upper(text): 141 | text = text.group(0) 142 | 143 | if all([char.isupper() for char in text]): 144 | return "".join(ko_dict["upper_to_kor"][char] for char in text) 145 | else: 146 | return text 147 | 148 | 149 | def normalize_quote(text): 150 | def fn(found_text): 151 | from nltk import sent_tokenize # NLTK doesn't along with multiprocessing 152 | 153 | found_text = found_text.group() 154 | unquoted_text = found_text[1:-1] 155 | 156 | sentences = sent_tokenize(unquoted_text) 157 | return " ".join(["'{}'".format(sent) for sent in sentences]) 158 | 159 | return re.sub(ko_dict["quote_checker"], fn, text) 160 | 161 | 162 | def normalize_number(text): 163 | text = normalize_with_dictionary(text, ko_dict["unit_to_kor"]) 164 | text = re.sub(ko_dict["number_checker"] + ko_dict["count_checker"], 165 | lambda x: number_to_korean(x, True), text) 166 | text = re.sub(ko_dict["number_checker"], 167 | lambda x: number_to_korean(x, False), text) 168 | return text 169 | 170 | 171 | def number_to_korean(num_str, is_count=False): 172 | zero_cnt = 0 173 | if is_count: 174 | num_str, unit_str = 
num_str.group(1), num_str.group(2) 175 | else: 176 | num_str, unit_str = num_str.group(), "" 177 | 178 | num_str = num_str.replace(',', '') 179 | 180 | if is_count and len(num_str) > 2: 181 | is_count = False 182 | 183 | if len(num_str) > 1 and num_str.startswith("0") and '.' not in num_str: 184 | for n in num_str: 185 | zero_cnt += 1 if n == "0" else 0 186 | num_str = num_str[zero_cnt:] 187 | 188 | kor = "" 189 | if num_str != '': 190 | num = ast.literal_eval(num_str) 191 | 192 | if num == 0: 193 | return "영" + (unit_str if unit_str else "") 194 | 195 | check_float = num_str.split('.') 196 | if len(check_float) == 2: 197 | digit_str, float_str = check_float 198 | elif len(check_float) >= 3: 199 | raise Exception(" [!] Wrong number format") 200 | else: 201 | digit_str, float_str = check_float[0], None 202 | 203 | if is_count and float_str is not None: 204 | raise Exception(" [!] `is_count` and float number does not fit each other") 205 | 206 | digit = int(digit_str) 207 | 208 | if digit_str.startswith("-") or digit_str.startswith("+"): 209 | digit, digit_str = abs(digit), str(abs(digit)) 210 | 211 | size = len(str(digit)) 212 | tmp = [] 213 | 214 | for i, v in enumerate(digit_str, start=1): 215 | v = int(v) 216 | 217 | if v != 0: 218 | if is_count: 219 | tmp += ko_dict["count_to_kor1"][v] 220 | else: 221 | tmp += ko_dict["num_to_kor1"][v] 222 | if v == 1 and i != 1 and i != len(digit_str): 223 | tmp = tmp[:-1] 224 | tmp += ko_dict["num_to_kor3"][(size - i) % 4] 225 | 226 | if (size - i) % 4 == 0 and len(tmp) != 0: 227 | kor += "".join(tmp) 228 | tmp = [] 229 | kor += ko_dict["num_to_kor2"][int((size - i) / 4)] 230 | 231 | if is_count: 232 | if kor.startswith("한") and len(kor) > 1: 233 | kor = kor[1:] 234 | 235 | if any(word in kor for word in ko_dict["count_tenth_dict"]): 236 | kor = re.sub( 237 | '|'.join(ko_dict["count_tenth_dict"].keys()), 238 | lambda x: ko_dict["count_tenth_dict"][x.group()], kor) 239 | 240 | if not is_count and kor.startswith("일") and len(kor) > 1: 241 | kor = kor[1:] 242 | 243 | if float_str is not None and float_str != "": 244 | kor += "영" if kor == "" else "" 245 | kor += "쩜 " 246 | kor += re.sub('\d', lambda x: ko_dict["num_to_kor"][x.group()], float_str) 247 | 248 | if num_str.startswith("+"): 249 | kor = "플러스 " + kor 250 | elif num_str.startswith("-"): 251 | kor = "마이너스 " + kor 252 | if zero_cnt > 0: 253 | kor = "공"*zero_cnt + kor 254 | 255 | return kor + unit_str 256 | 257 | 258 | def test_normalize(texts): 259 | for text in texts: 260 | raw = text 261 | norm = normalize(text) 262 | 263 | print("="*30) 264 | print(raw) 265 | print(norm) 266 | 267 | 268 | if __name__ == "__main__": 269 | test_inputs = [ 270 | "JTBC는 JTBCs를 DY는 A가 Absolute", 271 | "오늘(13일) 3,600마리 강아지가", 272 | "60.3%", 273 | '"저돌"(猪突) 입니다.', 274 | '비대위원장이 지난 1월 이런 말을 했습니다. “난 그냥 산돼지처럼 돌파하는 스타일이다”', 275 | "지금은 -12.35%였고 종류는 5가지와 19가지, 그리고 55가지였다", 276 | "JTBC는 TH와 K 양이 2017년 9월 12일 오후 12시에 24살이 된다", 277 | "이렇게 세트로 98,000원인데, 지금 세일 중이어서, 78,400원이에요.", 278 | "이렇게 세트로 98000원인데, 지금 세일 중이어서, 78400원이에요.", 279 | "저, 토익 970점이요.", 280 | "원래대로라면은 0점 처리해야 하는데.", 281 | "진짜? 그럼 너한테 한 두 마리만 줘도 돼?", 282 | "내가 화분이 좀 많아서. 그래도 17평에서 20평은 됐으면 좋겠어. 요즘 애들, 많이 사는 원룸, 그런데는 말고.", 283 | "매매는 3억까지. 전세는 1억 5천. 그 이상은 안돼.", 284 | "1억 3천이요.", 285 | "지금 3개월입니다.", 286 | "기계값 200만원 짜리를. 30개월 할부로 300만원에 파셨잖아요!", 287 | "오늘(13일) 99통 강아지가", 288 | "이제 55개.. 째예요.", 289 | "이제 55개월.. 
째예요.", 290 | "한 근에 3만 5천 원이나 하는 1++ 등급 한우라니까!", 291 | "한 근에 3만 5천 원이나 하는 A+ 등급 한우라니까!", 292 | "19,22,30,34,39,44+36", 293 | "그거 1+1으로 프로모션 때려버리자.", 294 | "아 테이프는 너무 우리 때 얘기인가? 그 MP3 파일 같은 거 있잖아. 영어 중국어 이런 거. 영어 책 읽어주고 그런 거.", 295 | "231 cm야.", 296 | "1 cm야.", 297 | "21 cm야.", 298 | "110 cm야.", 299 | "21마리야.", 300 | "아, 시력은 알고 있어요. 왼쪽 0.3이고 오른쪽 0.1이요.", 301 | "왼쪽 0점", 302 | "우리 스마트폰 쓰기 전에 공일일 번호였을때. 그 때 썼던 전화기를 2G라고 하거든?", 303 | "102마리 강아지.", 304 | "87. 105. 120. 네. 100 넘었어요!", 305 | ] 306 | 307 | test_normalize(test_inputs) -------------------------------------------------------------------------------- /text/korean_dict.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Symbols 3 | PAD = '_' 4 | EOS = '~' 5 | PUNC = '!\'(),-.:;?' 6 | SPACE = ' ' 7 | _SILENCES = ['sp', 'spn', 'sil'] 8 | 9 | JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)]) 10 | JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)]) 11 | JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)]) 12 | 13 | VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + PUNC + SPACE 14 | ALL_SYMBOLS = list(PAD + EOS + VALID_CHARS) + _SILENCES 15 | 16 | char_to_id = {c: i for i, c in enumerate(ALL_SYMBOLS)} 17 | id_to_char = {i: c for i, c in enumerate(ALL_SYMBOLS)} 18 | 19 | # Dictionaries 20 | ko_dict = { 21 | "quote_checker": """([`"'"“‘’”’])(.+?)([`"'"“‘’”’])""", 22 | "number_checker": "([+-]?\d{1,3},\d{3}(?!\d)|[+-]?\d+)[\.]?\d*", # "([+-]?\d[\d,]*)[\.]?\d*" 23 | "count_checker": "(시|명|가지|살|마리|포기|송이|수|톨|통|점|개(?!월)|벌|척|채|다발|그루|자루|줄|켤레|그릇|잔|마디|상자|사람|곡|병|판)", 24 | 25 | "num_to_kor": { 26 | '0': '영', 27 | '1': '일', 28 | '2': '이', 29 | '3': '삼', 30 | '4': '사', 31 | '5': '오', 32 | '6': '육', 33 | '7': '칠', 34 | '8': '팔', 35 | '9': '구', 36 | }, 37 | 38 | "num_to_kor1": [""] + list("일이삼사오육칠팔구"), 39 | "num_to_kor2": [""] + list("만억조경해"), 40 | "num_to_kor3": [""] + list("십백천"), 41 | "count_to_kor1": [""] + ["한","두","세","네","다섯","여섯","일곱","여덟","아홉"], # [""] + ["하나","둘","셋","넷","다섯","여섯","일곱","여덟","아홉"] 42 | 43 | "count_tenth_dict": { 44 | "십": "열", 45 | "두십": "스물", 46 | "세십": "서른", 47 | "네십": "마흔", 48 | "다섯십": "쉰", 49 | "여섯십": "예순", 50 | "일곱십": "일흔", 51 | "여덟십": "여든", 52 | "아홉십": "아흔", 53 | }, 54 | 55 | "unit_to_kor": { 56 | '%': '퍼센트', 57 | 'ml': '밀리리터', 58 | 'cm': '센치미터', 59 | 'mm': '밀리미터', 60 | 'km': '킬로미터', 61 | 'kg': '킬로그람', 62 | 'm': '미터', 63 | }, 64 | 65 | "upper_to_kor": { 66 | 'A': '에이', 67 | 'B': '비', 68 | 'C': '씨', 69 | 'D': '디', 70 | 'E': '이', 71 | 'F': '에프', 72 | 'G': '지', 73 | 'H': '에이치', 74 | 'I': '아이', 75 | 'J': '제이', 76 | 'K': '케이', 77 | 'L': '엘', 78 | 'M': '엠', 79 | 'N': '엔', 80 | 'O': '오', 81 | 'P': '피', 82 | 'Q': '큐', 83 | 'R': '알', 84 | 'S': '에스', 85 | 'T': '티', 86 | 'U': '유', 87 | 'V': '브이', 88 | 'W': '더블유', 89 | 'X': '엑스', 90 | 'Y': '와이', 91 | 'Z': '지', 92 | }, 93 | 94 | "english_dictionary": { 95 | 'TV': '티비', 96 | 'CCTV': '씨씨티비', 97 | 'cctv': '씨씨티비', 98 | 'cc': '씨씨', 99 | 'Apple': '애플', 100 | 'lte': '엘티이', 101 | 'KG': '킬로그람', 102 | 'x': '엑스', 103 | 'z': '제트', 104 | 'Yo': '요', 105 | 'YOLO': '욜로', 106 | 'Gone': '건', 107 | 'gone': '건', 108 | 'Have': '헤브', 109 | 'p': '피', 110 | 'ppt': '피피티', 111 | 'suv': '에스유브이', 112 | }, 113 | 114 | "etc_dictionary": { 115 | '1+1': '원 플러스 원', 116 | '+': '플러스', 117 | 'MP3': '엠피쓰리', 118 | '5G': '파이브지', 119 | '4G': '포지', 120 | '3G': '쓰리지', 121 | '2G': '투지', 122 | 'A/S': '에이 에스', 123 | '1/3':'삼분의 일', 124 | 'greentea907': '그린티 구공칠', 125 | 'CNT 123': '씨엔티 일이삼', 126 | '14학번': '일사 학번', 127 | '7011번': '칠공일일번', 128 | 
'P8학원': '피에잇 학원', 129 | '102마리': '백 두마리', 130 | '20명': '스무명', 131 | } 132 | } 133 | 134 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /text/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 
65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | "v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from text import cmudict, pinyin 9 | 10 | _pad = "_" 11 | _punctuation = "!'(),.:;? 
" 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 18 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = ( 22 | [_pad] 23 | + list(_special) 24 | + list(_punctuation) 25 | + list(_letters) 26 | + _arpabet 27 | + _pinyin 28 | + _silences 29 | ) 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import yaml 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tqdm import tqdm 10 | 11 | from utils.model import get_model, get_vocoder, get_param_num 12 | from utils.tools import to_device, log, synth_one_sample 13 | from model import FastSpeech2Loss 14 | from dataset import Dataset 15 | 16 | from evaluate import evaluate 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | def main(args, configs): 22 | print("Prepare training ...") 23 | 24 | preprocess_config, model_config, train_config = configs 25 | 26 | # Get dataset 27 | dataset = Dataset( 28 | "train.txt", preprocess_config, model_config, train_config, sort=True, drop_last=True 29 | ) 30 | batch_size = train_config["optimizer"]["batch_size"] 31 | group_size = 4 # Set this larger than 1 to enable sorting in Dataset 32 | assert batch_size * group_size < len(dataset) 33 | loader = DataLoader( 34 | dataset, 35 | batch_size=batch_size * group_size, 36 | shuffle=True, 37 | collate_fn=dataset.collate_fn, 38 | ) 39 | 40 | # Prepare model 41 | model, optimizer = get_model(args, configs, device, train=True) 42 | model = nn.DataParallel(model) 43 | num_param = get_param_num(model) 44 | Loss = FastSpeech2Loss(preprocess_config, model_config).to(device) 45 | print("Number of FastSpeech2 Parameters:", num_param) 46 | 47 | # Load vocoder 48 | vocoder = get_vocoder(model_config, device) 49 | 50 | # Init logger 51 | for p in train_config["path"].values(): 52 | os.makedirs(p, exist_ok=True) 53 | train_log_path = os.path.join(train_config["path"]["log_path"], "train") 54 | val_log_path = os.path.join(train_config["path"]["log_path"], "val") 55 | os.makedirs(train_log_path, exist_ok=True) 56 | os.makedirs(val_log_path, exist_ok=True) 57 | train_logger = SummaryWriter(train_log_path) 58 | val_logger = SummaryWriter(val_log_path) 59 | 60 | # Training 61 | step = args.restore_step + 1 62 | epoch = 1 63 | grad_acc_step = train_config["optimizer"]["grad_acc_step"] 64 | grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"] 65 | total_step = train_config["step"]["total_step"] 66 | log_step = train_config["step"]["log_step"] 67 | save_step = train_config["step"]["save_step"] 68 | synth_step = train_config["step"]["synth_step"] 69 | val_step = train_config["step"]["val_step"] 70 | 71 | outer_bar = tqdm(total=total_step, desc="Training", position=0) 72 | outer_bar.n = args.restore_step 73 | outer_bar.update() 74 | 75 | while True: 76 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1) 77 | for batchs in loader: 78 | for batch in batchs: 79 | batch = to_device(batch, device) 80 | 81 | # Forward 82 | output = model(*(batch[2:])) 83 | 84 | # 
Cal Loss 85 | losses = Loss(batch, output) 86 | total_loss = losses[0] 87 | 88 | # Backward 89 | total_loss = total_loss / grad_acc_step 90 | total_loss.backward() 91 | if step % grad_acc_step == 0: 92 | # Clipping gradients to avoid gradient explosion 93 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh) 94 | 95 | # Update weights 96 | optimizer.step_and_update_lr() 97 | optimizer.zero_grad() 98 | 99 | if step % log_step == 0: 100 | losses = [l.item() for l in losses] 101 | message1 = "Step {}/{}, ".format(step, total_step) 102 | message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format( 103 | *losses 104 | ) 105 | 106 | with open(os.path.join(train_log_path, "log.txt"), "a") as f: 107 | f.write(message1 + message2 + "\n") 108 | 109 | outer_bar.write(message1 + message2) 110 | 111 | log(train_logger, step, losses=losses) 112 | 113 | if step % synth_step == 0: 114 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample( 115 | batch, 116 | output, 117 | vocoder, 118 | model_config, 119 | preprocess_config, 120 | ) 121 | log( 122 | train_logger, 123 | fig=fig, 124 | tag="Training/step_{}_{}".format(step, tag), 125 | ) 126 | sampling_rate = preprocess_config["preprocessing"]["audio"][ 127 | "sampling_rate" 128 | ] 129 | log( 130 | train_logger, 131 | audio=wav_reconstruction, 132 | sampling_rate=sampling_rate, 133 | tag="Training/step_{}_{}_reconstructed".format(step, tag), 134 | ) 135 | log( 136 | train_logger, 137 | audio=wav_prediction, 138 | sampling_rate=sampling_rate, 139 | tag="Training/step_{}_{}_synthesized".format(step, tag), 140 | ) 141 | 142 | if step % val_step == 0: 143 | model.eval() 144 | message = evaluate(model, step, configs, val_logger, vocoder) 145 | with open(os.path.join(val_log_path, "log.txt"), "a") as f: 146 | f.write(message + "\n") 147 | outer_bar.write(message) 148 | 149 | model.train() 150 | 151 | if step % save_step == 0: 152 | torch.save( 153 | { 154 | "model": model.module.state_dict(), 155 | "optimizer": optimizer._optimizer.state_dict(), 156 | }, 157 | os.path.join( 158 | train_config["path"]["ckpt_path"], 159 | "{}.pth.tar".format(step), 160 | ), 161 | ) 162 | 163 | if step == total_step: 164 | quit() 165 | step += 1 166 | outer_bar.update(1) 167 | 168 | inner_bar.update(1) 169 | epoch += 1 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = argparse.ArgumentParser() 174 | parser.add_argument("--restore_step", type=int, default=0) 175 | parser.add_argument( 176 | "-p", 177 | "--preprocess_config", 178 | type=str, 179 | required=True, 180 | help="path to preprocess.yaml", 181 | ) 182 | parser.add_argument( 183 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 184 | ) 185 | parser.add_argument( 186 | "-t", "--train_config", type=str, required=True, help="path to train.yaml" 187 | ) 188 | args = parser.parse_args() 189 | 190 | # Read Config 191 | preprocess_config = yaml.load( 192 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 193 | ) 194 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 195 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 196 | configs = (preprocess_config, model_config, train_config) 197 | 198 | main(args, configs) 199 | -------------------------------------------------------------------------------- /transformer/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | 
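# Special-token indices; PAD (0) also serves as the padding_idx of the
# encoder's nn.Embedding in transformer/Models.py.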
UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = "" 7 | UNK_WORD = "" 8 | BOS_WORD = "" 9 | EOS_WORD = "" 10 | -------------------------------------------------------------------------------- /transformer/Layers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.nn import functional as F 7 | 8 | from .SubLayers import MultiHeadAttention, PositionwiseFeedForward 9 | 10 | 11 | class FFTBlock(torch.nn.Module): 12 | """FFT Block""" 13 | 14 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 15 | super(FFTBlock, self).__init__() 16 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 17 | self.pos_ffn = PositionwiseFeedForward( 18 | d_model, d_inner, kernel_size, dropout=dropout 19 | ) 20 | 21 | def forward(self, enc_input, mask=None, slf_attn_mask=None): 22 | enc_output, enc_slf_attn = self.slf_attn( 23 | enc_input, enc_input, enc_input, mask=slf_attn_mask 24 | ) 25 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 26 | 27 | enc_output = self.pos_ffn(enc_output) 28 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 29 | 30 | return enc_output, enc_slf_attn 31 | 32 | 33 | class ConvNorm(torch.nn.Module): 34 | def __init__( 35 | self, 36 | in_channels, 37 | out_channels, 38 | kernel_size=1, 39 | stride=1, 40 | padding=None, 41 | dilation=1, 42 | bias=True, 43 | w_init_gain="linear", 44 | ): 45 | super(ConvNorm, self).__init__() 46 | 47 | if padding is None: 48 | assert kernel_size % 2 == 1 49 | padding = int(dilation * (kernel_size - 1) / 2) 50 | 51 | self.conv = torch.nn.Conv1d( 52 | in_channels, 53 | out_channels, 54 | kernel_size=kernel_size, 55 | stride=stride, 56 | padding=padding, 57 | dilation=dilation, 58 | bias=bias, 59 | ) 60 | 61 | def forward(self, signal): 62 | conv_signal = self.conv(signal) 63 | 64 | return conv_signal 65 | 66 | 67 | class PostNet(nn.Module): 68 | """ 69 | PostNet: Five 1-d convolution with 512 channels and kernel size 5 70 | """ 71 | 72 | def __init__( 73 | self, 74 | n_mel_channels=80, 75 | postnet_embedding_dim=512, 76 | postnet_kernel_size=5, 77 | postnet_n_convolutions=5, 78 | ): 79 | 80 | super(PostNet, self).__init__() 81 | self.convolutions = nn.ModuleList() 82 | 83 | self.convolutions.append( 84 | nn.Sequential( 85 | ConvNorm( 86 | n_mel_channels, 87 | postnet_embedding_dim, 88 | kernel_size=postnet_kernel_size, 89 | stride=1, 90 | padding=int((postnet_kernel_size - 1) / 2), 91 | dilation=1, 92 | w_init_gain="tanh", 93 | ), 94 | nn.BatchNorm1d(postnet_embedding_dim), 95 | ) 96 | ) 97 | 98 | for i in range(1, postnet_n_convolutions - 1): 99 | self.convolutions.append( 100 | nn.Sequential( 101 | ConvNorm( 102 | postnet_embedding_dim, 103 | postnet_embedding_dim, 104 | kernel_size=postnet_kernel_size, 105 | stride=1, 106 | padding=int((postnet_kernel_size - 1) / 2), 107 | dilation=1, 108 | w_init_gain="tanh", 109 | ), 110 | nn.BatchNorm1d(postnet_embedding_dim), 111 | ) 112 | ) 113 | 114 | self.convolutions.append( 115 | nn.Sequential( 116 | ConvNorm( 117 | postnet_embedding_dim, 118 | n_mel_channels, 119 | kernel_size=postnet_kernel_size, 120 | stride=1, 121 | padding=int((postnet_kernel_size - 1) / 2), 122 | dilation=1, 123 | w_init_gain="linear", 124 | ), 125 | nn.BatchNorm1d(n_mel_channels), 126 | ) 127 | ) 128 | 129 | def forward(self, x): 130 | x = x.contiguous().transpose(1, 2) 131 | 132 | for i in 
range(len(self.convolutions) - 1): 133 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 134 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 135 | 136 | x = x.contiguous().transpose(1, 2) 137 | return x 138 | -------------------------------------------------------------------------------- /transformer/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import transformer.Constants as Constants 6 | from .Layers import FFTBlock 7 | from text.symbols import symbols 8 | 9 | 10 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 11 | """ Sinusoid position encoding table """ 12 | 13 | def cal_angle(position, hid_idx): 14 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 15 | 16 | def get_posi_angle_vec(position): 17 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 18 | 19 | sinusoid_table = np.array( 20 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)] 21 | ) 22 | 23 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 24 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 25 | 26 | if padding_idx is not None: 27 | # zero vector for padding dimension 28 | sinusoid_table[padding_idx] = 0.0 29 | 30 | return torch.FloatTensor(sinusoid_table) 31 | 32 | 33 | class Encoder(nn.Module): 34 | """ Encoder """ 35 | 36 | def __init__(self, config): 37 | super(Encoder, self).__init__() 38 | 39 | n_position = config["max_seq_len"] + 1 40 | n_src_vocab = len(symbols) + 1 41 | d_word_vec = config["transformer"]["encoder_hidden"] 42 | n_layers = config["transformer"]["encoder_layer"] 43 | n_head = config["transformer"]["encoder_head"] 44 | d_k = d_v = ( 45 | config["transformer"]["encoder_hidden"] 46 | // config["transformer"]["encoder_head"] 47 | ) 48 | d_model = config["transformer"]["encoder_hidden"] 49 | d_inner = config["transformer"]["conv_filter_size"] 50 | kernel_size = config["transformer"]["conv_kernel_size"] 51 | dropout = config["transformer"]["encoder_dropout"] 52 | 53 | self.max_seq_len = config["max_seq_len"] 54 | self.d_model = d_model 55 | 56 | self.src_word_emb = nn.Embedding( 57 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD 58 | ) 59 | self.position_enc = nn.Parameter( 60 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 61 | requires_grad=False, 62 | ) 63 | 64 | self.layer_stack = nn.ModuleList( 65 | [ 66 | FFTBlock( 67 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 68 | ) 69 | for _ in range(n_layers) 70 | ] 71 | ) 72 | 73 | def forward(self, src_seq, mask, return_attns=False): 74 | 75 | enc_slf_attn_list = [] 76 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 77 | 78 | # -- Prepare masks 79 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 80 | 81 | # -- Forward 82 | if not self.training and src_seq.shape[1] > self.max_seq_len: 83 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table( 84 | src_seq.shape[1], self.d_model 85 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 86 | src_seq.device 87 | ) 88 | else: 89 | enc_output = self.src_word_emb(src_seq) + self.position_enc[ 90 | :, :max_len, : 91 | ].expand(batch_size, -1, -1) 92 | 93 | for enc_layer in self.layer_stack: 94 | enc_output, enc_slf_attn = enc_layer( 95 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask 96 | ) 97 | if return_attns: 98 | enc_slf_attn_list += [enc_slf_attn] 99 | 100 
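        # Note: only the hidden states are returned; the self-attention maps
        # collected in enc_slf_attn_list are not propagated to the caller,
        # even when return_attns=True.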
| return enc_output 101 | 102 | 103 | class Decoder(nn.Module): 104 | """ Decoder """ 105 | 106 | def __init__(self, config): 107 | super(Decoder, self).__init__() 108 | 109 | n_position = config["max_seq_len"] + 1 110 | d_word_vec = config["transformer"]["decoder_hidden"] 111 | n_layers = config["transformer"]["decoder_layer"] 112 | n_head = config["transformer"]["decoder_head"] 113 | d_k = d_v = ( 114 | config["transformer"]["decoder_hidden"] 115 | // config["transformer"]["decoder_head"] 116 | ) 117 | d_model = config["transformer"]["decoder_hidden"] 118 | d_inner = config["transformer"]["conv_filter_size"] 119 | kernel_size = config["transformer"]["conv_kernel_size"] 120 | dropout = config["transformer"]["decoder_dropout"] 121 | 122 | self.max_seq_len = config["max_seq_len"] 123 | self.d_model = d_model 124 | 125 | self.position_enc = nn.Parameter( 126 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 127 | requires_grad=False, 128 | ) 129 | 130 | self.layer_stack = nn.ModuleList( 131 | [ 132 | FFTBlock( 133 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 134 | ) 135 | for _ in range(n_layers) 136 | ] 137 | ) 138 | 139 | def forward(self, enc_seq, mask, return_attns=False): 140 | 141 | dec_slf_attn_list = [] 142 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 143 | 144 | # -- Forward 145 | if not self.training and enc_seq.shape[1] > self.max_seq_len: 146 | # -- Prepare masks 147 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 148 | dec_output = enc_seq + get_sinusoid_encoding_table( 149 | enc_seq.shape[1], self.d_model 150 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 151 | enc_seq.device 152 | ) 153 | else: 154 | max_len = min(max_len, self.max_seq_len) 155 | 156 | # -- Prepare masks 157 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 158 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[ 159 | :, :max_len, : 160 | ].expand(batch_size, -1, -1) 161 | mask = mask[:, :max_len] 162 | slf_attn_mask = slf_attn_mask[:, :, :max_len] 163 | 164 | for dec_layer in self.layer_stack: 165 | dec_output, dec_slf_attn = dec_layer( 166 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask 167 | ) 168 | if return_attns: 169 | dec_slf_attn_list += [dec_slf_attn] 170 | 171 | return dec_output, mask 172 | -------------------------------------------------------------------------------- /transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class ScaledDotProductAttention(nn.Module): 7 | """ Scaled Dot-Product Attention """ 8 | 9 | def __init__(self, temperature): 10 | super().__init__() 11 | self.temperature = temperature 12 | self.softmax = nn.Softmax(dim=2) 13 | 14 | def forward(self, q, k, v, mask=None): 15 | 16 | attn = torch.bmm(q, k.transpose(1, 2)) 17 | attn = attn / self.temperature 18 | 19 | if mask is not None: 20 | attn = attn.masked_fill(mask, -np.inf) 21 | 22 | attn = self.softmax(attn) 23 | output = torch.bmm(attn, v) 24 | 25 | return output, attn 26 | -------------------------------------------------------------------------------- /transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from .Modules import ScaledDotProductAttention 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | """ Multi-Head Attention module """ 
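    # q/k/v of shape (batch, len, d_model) are projected into n_head heads of
    # size d_k (queries/keys) and d_v (values), attended per head with scaled
    # dot-product attention, concatenated, and mapped back to d_model through a
    # residual connection + LayerNorm.
    #
    # Minimal call sketch (sizes are illustrative, not taken from the repo configs):
    #   attn = MultiHeadAttention(n_head=2, d_model=256, d_k=128, d_v=128)
    #   out, weights = attn(x, x, x, mask=pad_mask)  # x: (B, T, 256), pad_mask: (B, T, T) bool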
10 | 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 12 | super().__init__() 13 | 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v) 21 | 22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 23 | self.layer_norm = nn.LayerNorm(d_model) 24 | 25 | self.fc = nn.Linear(n_head * d_v, d_model) 26 | 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, q, k, v, mask=None): 30 | 31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 32 | 33 | sz_b, len_q, _ = q.size() 34 | sz_b, len_k, _ = k.size() 35 | sz_b, len_v, _ = v.size() 36 | 37 | residual = q 38 | 39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 45 | 46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 47 | output, attn = self.attention(q, k, v, mask=mask) 48 | 49 | output = output.view(n_head, sz_b, len_q, d_v) 50 | output = ( 51 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) 52 | ) # b x lq x (n*dv) 53 | 54 | output = self.dropout(self.fc(output)) 55 | output = self.layer_norm(output + residual) 56 | 57 | return output, attn 58 | 59 | 60 | class PositionwiseFeedForward(nn.Module): 61 | """ A two-feed-forward-layer module """ 62 | 63 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1): 64 | super().__init__() 65 | 66 | # Use Conv1D 67 | # position-wise 68 | self.w_1 = nn.Conv1d( 69 | d_in, 70 | d_hid, 71 | kernel_size=kernel_size[0], 72 | padding=(kernel_size[0] - 1) // 2, 73 | ) 74 | # position-wise 75 | self.w_2 = nn.Conv1d( 76 | d_hid, 77 | d_in, 78 | kernel_size=kernel_size[1], 79 | padding=(kernel_size[1] - 1) // 2, 80 | ) 81 | 82 | self.layer_norm = nn.LayerNorm(d_in) 83 | self.dropout = nn.Dropout(dropout) 84 | 85 | def forward(self, x): 86 | residual = x 87 | output = x.transpose(1, 2) 88 | output = self.w_2(F.relu(self.w_1(output))) 89 | output = output.transpose(1, 2) 90 | output = self.dropout(output) 91 | output = self.layer_norm(output + residual) 92 | 93 | return output 94 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Models import Encoder, Decoder 2 | from .Layers import PostNet -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import hifigan 8 | from model import FastSpeech2, ScheduledOptim 9 | 10 | 11 | def get_model(args, configs, device, train=False): 12 | (preprocess_config, model_config, train_config) = configs 13 | 14 | model = FastSpeech2(preprocess_config, model_config).to(device) 15 | if args.restore_step: 16 | ckpt_path = os.path.join( 17 | train_config["path"]["ckpt_path"], 18 | "{}.pth.tar".format(args.restore_step), 19 | ) 20 | ckpt = torch.load(ckpt_path) 21 | model.load_state_dict(ckpt["model"]) 22 | 23 | if train: 24 | scheduled_optim = 
ScheduledOptim( 25 | model, train_config, model_config, args.restore_step 26 | ) 27 | if args.restore_step: 28 | scheduled_optim.load_state_dict(ckpt["optimizer"]) 29 | model.train() 30 | return model, scheduled_optim 31 | 32 | model.eval() 33 | model.requires_grad_ = False 34 | return model 35 | 36 | 37 | def get_param_num(model): 38 | num_param = sum(param.numel() for param in model.parameters()) 39 | return num_param 40 | 41 | 42 | def get_vocoder(config, device): 43 | name = config["vocoder"]["model"] 44 | speaker = config["vocoder"]["speaker"] 45 | 46 | if name == "MelGAN": 47 | if speaker == "LJSpeech": 48 | vocoder = torch.hub.load( 49 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" 50 | ) 51 | elif speaker == "universal": 52 | vocoder = torch.hub.load( 53 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 54 | ) 55 | vocoder.mel2wav.eval() 56 | vocoder.mel2wav.to(device) 57 | elif name == "HiFi-GAN": 58 | with open("hifigan/config.json", "r") as f: 59 | config = json.load(f) 60 | config = hifigan.AttrDict(config) 61 | vocoder = hifigan.Generator(config) 62 | if speaker == "LJSpeech": 63 | ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar") 64 | elif speaker == "universal": 65 | ckpt = torch.load("hifigan/generator_universal.pth.tar") 66 | vocoder.load_state_dict(ckpt["generator"]) 67 | vocoder.eval() 68 | vocoder.remove_weight_norm() 69 | vocoder.to(device) 70 | 71 | return vocoder 72 | 73 | 74 | def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None): 75 | name = model_config["vocoder"]["model"] 76 | with torch.no_grad(): 77 | if name == "MelGAN": 78 | wavs = vocoder.inverse(mels / np.log(10)) 79 | elif name == "HiFi-GAN": 80 | wavs = vocoder(mels).squeeze(1) 81 | 82 | wavs = ( 83 | wavs.cpu().numpy() 84 | * preprocess_config["preprocessing"]["audio"]["max_wav_value"] 85 | ).astype("int16") 86 | wavs = [wav for wav in wavs] 87 | 88 | for i in range(len(mels)): 89 | if lengths is not None: 90 | wavs[i] = wavs[i][: lengths[i]] 91 | 92 | return wavs 93 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import matplotlib 8 | from scipy.io import wavfile 9 | from matplotlib import pyplot as plt 10 | 11 | 12 | matplotlib.use("Agg") 13 | 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | def to_device(data, device): 19 | if len(data) == 12: 20 | ( 21 | ids, 22 | raw_texts, 23 | speakers, 24 | texts, 25 | src_lens, 26 | max_src_len, 27 | mels, 28 | mel_lens, 29 | max_mel_len, 30 | pitches, 31 | energies, 32 | durations, 33 | ) = data 34 | 35 | speakers = torch.from_numpy(speakers).long().to(device) 36 | texts = torch.from_numpy(texts).long().to(device) 37 | src_lens = torch.from_numpy(src_lens).to(device) 38 | mels = torch.from_numpy(mels).float().to(device) 39 | mel_lens = torch.from_numpy(mel_lens).to(device) 40 | pitches = torch.from_numpy(pitches).float().to(device) 41 | energies = torch.from_numpy(energies).to(device) 42 | durations = torch.from_numpy(durations).long().to(device) 43 | 44 | return ( 45 | ids, 46 | raw_texts, 47 | speakers, 48 | texts, 49 | src_lens, 50 | max_src_len, 51 | mels, 52 | mel_lens, 53 | max_mel_len, 54 | pitches, 55 | energies, 56 | durations, 57 | ) 58 | 59 | if len(data) == 15: 60 | ( 61 | ids, 62 | raw_texts, 63 | 
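                # 15-element batches extend the 12-element FastSpeech2 batch with
                # per-utterance emotion, arousal and valence IDs; they are converted
                # to LongTensors just like the speaker IDs below.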
speakers, 64 | emotions, 65 | arousals, 66 | valences, 67 | texts, 68 | src_lens, 69 | max_src_len, 70 | mels, 71 | mel_lens, 72 | max_mel_len, 73 | pitches, 74 | energies, 75 | durations, 76 | ) = data 77 | 78 | speakers = torch.from_numpy(speakers).long().to(device) 79 | emotions = torch.from_numpy(emotions).long().to(device) 80 | arousals = torch.from_numpy(arousals).long().to(device) 81 | valences = torch.from_numpy(valences).long().to(device) 82 | texts = torch.from_numpy(texts).long().to(device) 83 | src_lens = torch.from_numpy(src_lens).to(device) 84 | mels = torch.from_numpy(mels).float().to(device) 85 | mel_lens = torch.from_numpy(mel_lens).to(device) 86 | pitches = torch.from_numpy(pitches).float().to(device) 87 | energies = torch.from_numpy(energies).to(device) 88 | durations = torch.from_numpy(durations).long().to(device) 89 | 90 | return ( 91 | ids, 92 | raw_texts, 93 | speakers, 94 | emotions, 95 | arousals, 96 | valences, 97 | texts, 98 | src_lens, 99 | max_src_len, 100 | mels, 101 | mel_lens, 102 | max_mel_len, 103 | pitches, 104 | energies, 105 | durations, 106 | ) 107 | 108 | if len(data) == 6: 109 | (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data 110 | 111 | speakers = torch.from_numpy(speakers).long().to(device) 112 | texts = torch.from_numpy(texts).long().to(device) 113 | src_lens = torch.from_numpy(src_lens).to(device) 114 | 115 | return (ids, raw_texts, speakers, texts, src_lens, max_src_len) 116 | 117 | if len(data) == 9: 118 | (ids, raw_texts, speakers, emotions, arousals, valences, texts, src_lens, max_src_len) = data 119 | 120 | speakers = torch.from_numpy(speakers).long().to(device) 121 | emotions = torch.from_numpy(emotions).long().to(device) 122 | arousals = torch.from_numpy(arousals).long().to(device) 123 | valences = torch.from_numpy(valences).long().to(device) 124 | texts = torch.from_numpy(texts).long().to(device) 125 | src_lens = torch.from_numpy(src_lens).to(device) 126 | 127 | return (ids, raw_texts, speakers, emotions, arousals, valences, texts, src_lens, max_src_len) 128 | 129 | 130 | def log( 131 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag="" 132 | ): 133 | if losses is not None: 134 | logger.add_scalar("Loss/total_loss", losses[0], step) 135 | logger.add_scalar("Loss/mel_loss", losses[1], step) 136 | logger.add_scalar("Loss/mel_postnet_loss", losses[2], step) 137 | logger.add_scalar("Loss/pitch_loss", losses[3], step) 138 | logger.add_scalar("Loss/energy_loss", losses[4], step) 139 | logger.add_scalar("Loss/duration_loss", losses[5], step) 140 | 141 | if fig is not None: 142 | logger.add_figure(tag, fig) 143 | 144 | if audio is not None: 145 | logger.add_audio( 146 | tag, 147 | audio / max(abs(audio)), 148 | sample_rate=sampling_rate, 149 | ) 150 | 151 | 152 | def get_mask_from_lengths(lengths, max_len=None): 153 | batch_size = lengths.shape[0] 154 | if max_len is None: 155 | max_len = torch.max(lengths).item() 156 | 157 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) 158 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 159 | 160 | return mask 161 | 162 | 163 | def expand(values, durations): 164 | out = list() 165 | for value, d in zip(values, durations): 166 | out += [value] * max(0, int(d)) 167 | return np.array(out) 168 | 169 | 170 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config): 171 | 172 | basename = targets[0][0] 173 | src_len = predictions[8][0].item() 174 | mel_len = predictions[9][0].item() 175 | mel_target = 
targets[9][0, :mel_len].detach().transpose(0, 1) 176 | mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1) 177 | duration = targets[14][0, :src_len].detach().cpu().numpy() 178 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 179 | pitch = targets[12][0, :src_len].detach().cpu().numpy() 180 | pitch = expand(pitch, duration) 181 | else: 182 | pitch = targets[12][0, :mel_len].detach().cpu().numpy() 183 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 184 | energy = targets[13][0, :src_len].detach().cpu().numpy() 185 | energy = expand(energy, duration) 186 | else: 187 | energy = targets[13][0, :mel_len].detach().cpu().numpy() 188 | 189 | with open( 190 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 191 | ) as f: 192 | stats = json.load(f) 193 | stats = stats["pitch"] + stats["energy"][:2] 194 | 195 | fig = plot_mel( 196 | [ 197 | (mel_prediction.cpu().numpy(), pitch, energy), 198 | (mel_target.cpu().numpy(), pitch, energy), 199 | ], 200 | stats, 201 | ["Synthetized Spectrogram", "Ground-Truth Spectrogram"], 202 | ) 203 | 204 | if vocoder is not None: 205 | from .model import vocoder_infer 206 | 207 | wav_reconstruction = vocoder_infer( 208 | mel_target.unsqueeze(0), 209 | vocoder, 210 | model_config, 211 | preprocess_config, 212 | )[0] 213 | wav_prediction = vocoder_infer( 214 | mel_prediction.unsqueeze(0), 215 | vocoder, 216 | model_config, 217 | preprocess_config, 218 | )[0] 219 | else: 220 | wav_reconstruction = wav_prediction = None 221 | 222 | return fig, wav_reconstruction, wav_prediction, basename 223 | 224 | 225 | def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path, tag=None): 226 | 227 | basenames = targets[0] 228 | for i in range(len(predictions[0])): 229 | basename = basenames[i] 230 | src_len = predictions[8][i].item() 231 | mel_len = predictions[9][i].item() 232 | mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) 233 | duration = predictions[5][i, :src_len].detach().cpu().numpy() 234 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 235 | pitch = predictions[2][i, :src_len].detach().cpu().numpy() 236 | pitch = expand(pitch, duration) 237 | else: 238 | pitch = predictions[2][i, :mel_len].detach().cpu().numpy() 239 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 240 | energy = predictions[3][i, :src_len].detach().cpu().numpy() 241 | energy = expand(energy, duration) 242 | else: 243 | energy = predictions[3][i, :mel_len].detach().cpu().numpy() 244 | 245 | with open( 246 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 247 | ) as f: 248 | stats = json.load(f) 249 | stats = stats["pitch"] + stats["energy"][:2] 250 | 251 | fig = plot_mel( 252 | [ 253 | (mel_prediction.cpu().numpy(), pitch, energy), 254 | ], 255 | stats, 256 | ["Synthetized Spectrogram"], 257 | ) 258 | plt.savefig(os.path.join(path, "{}{}.png".format(basename, f"_{tag}" if tag is not None else ""))) 259 | plt.close() 260 | 261 | from .model import vocoder_infer 262 | 263 | mel_predictions = predictions[1].transpose(1, 2) 264 | lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"] 265 | wav_predictions = vocoder_infer( 266 | mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths 267 | ) 268 | 269 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 270 | for wav, basename in zip(wav_predictions, 
basenames): 271 | wavfile.write(os.path.join(path, "{}{}.wav".format(basename, f"_{tag}" if tag is not None else "")), sampling_rate, wav) 272 | 273 | 274 | def plot_mel(data, stats, titles): 275 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 276 | if titles is None: 277 | titles = [None for i in range(len(data))] 278 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats 279 | pitch_min = pitch_min * pitch_std + pitch_mean 280 | pitch_max = pitch_max * pitch_std + pitch_mean 281 | 282 | def add_axis(fig, old_ax): 283 | ax = fig.add_axes(old_ax.get_position(), anchor="W") 284 | ax.set_facecolor("None") 285 | return ax 286 | 287 | for i in range(len(data)): 288 | mel, pitch, energy = data[i] 289 | pitch = pitch * pitch_std + pitch_mean 290 | axes[i][0].imshow(mel, origin="lower") 291 | axes[i][0].set_aspect(2.5, adjustable="box") 292 | axes[i][0].set_ylim(0, mel.shape[0]) 293 | axes[i][0].set_title(titles[i], fontsize="medium") 294 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 295 | axes[i][0].set_anchor("W") 296 | 297 | ax1 = add_axis(fig, axes[i][0]) 298 | ax1.plot(pitch, color="tomato") 299 | ax1.set_xlim(0, mel.shape[1]) 300 | ax1.set_ylim(0, pitch_max) 301 | ax1.set_ylabel("F0", color="tomato") 302 | ax1.tick_params( 303 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False 304 | ) 305 | 306 | ax2 = add_axis(fig, axes[i][0]) 307 | ax2.plot(energy, color="darkviolet") 308 | ax2.set_xlim(0, mel.shape[1]) 309 | ax2.set_ylim(energy_min, energy_max) 310 | ax2.set_ylabel("Energy", color="darkviolet") 311 | ax2.yaxis.set_label_position("right") 312 | ax2.tick_params( 313 | labelsize="x-small", 314 | colors="darkviolet", 315 | bottom=False, 316 | labelbottom=False, 317 | left=False, 318 | labelleft=False, 319 | right=True, 320 | labelright=True, 321 | ) 322 | 323 | return fig 324 | 325 | 326 | def pad_1D(inputs, PAD=0): 327 | def pad_data(x, length, PAD): 328 | x_padded = np.pad( 329 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD 330 | ) 331 | return x_padded 332 | 333 | max_len = max((len(x) for x in inputs)) 334 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 335 | 336 | return padded 337 | 338 | 339 | def pad_2D(inputs, maxlen=None): 340 | def pad(x, max_len): 341 | PAD = 0 342 | if np.shape(x)[0] > max_len: 343 | raise ValueError("not max_len") 344 | 345 | s = np.shape(x)[1] 346 | x_padded = np.pad( 347 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD 348 | ) 349 | return x_padded[:, :s] 350 | 351 | if maxlen: 352 | output = np.stack([pad(x, maxlen) for x in inputs]) 353 | else: 354 | max_len = max(np.shape(x)[0] for x in inputs) 355 | output = np.stack([pad(x, max_len) for x in inputs]) 356 | 357 | return output 358 | 359 | 360 | def pad(input_ele, mel_max_length=None): 361 | if mel_max_length: 362 | max_len = mel_max_length 363 | else: 364 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 365 | 366 | out_list = list() 367 | for i, batch in enumerate(input_ele): 368 | if len(batch.shape) == 1: 369 | one_batch_padded = F.pad( 370 | batch, (0, max_len - batch.size(0)), "constant", 0.0 371 | ) 372 | elif len(batch.shape) == 2: 373 | one_batch_padded = F.pad( 374 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 375 | ) 376 | out_list.append(one_batch_padded) 377 | out_padded = torch.stack(out_list) 378 | return out_padded 379 | --------------------------------------------------------------------------------
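A minimal usage sketch for the padding and masking helpers defined in utils/tools.py above. This is not a file in the repository; it assumes the repository root is on PYTHONPATH, that the dependencies from requirements.txt (torch, numpy, scipy, matplotlib) are installed, and the tensor values are purely illustrative.

# usage_sketch.py -- illustrative only; filename and values are assumptions.
import numpy as np
import torch

from utils.tools import device, get_mask_from_lengths, pad_1D

# Two sequences of length 3 and 5; True in the mask marks padded positions.
lengths = torch.tensor([3, 5], device=device)
mask = get_mask_from_lengths(lengths)   # shape (2, 5)
# -> [[False, False, False,  True,  True],
#     [False, False, False, False, False]]

# pad_1D right-pads 1-D numpy arrays to a common length with zeros.
batch = [np.array([1, 2, 3]), np.array([4, 5, 6, 7, 8])]
padded = pad_1D(batch)                  # -> [[1 2 3 0 0], [4 5 6 7 8]]

print(mask)
print(padded)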