├── .gitignore
├── LICENSE
├── README.md
├── audio
│   ├── __init__.py
│   ├── audio_processing.py
│   ├── stft.py
│   └── tools.py
├── config
│   ├── AIHub-MMV
│   │   ├── model.yaml
│   │   ├── preprocess.yaml
│   │   └── train.yaml
│   ├── IEMOCAP
│   │   ├── model.yaml
│   │   ├── preprocess.yaml
│   │   └── train.yaml
│   └── README.md
├── dataset.py
├── evaluate.py
├── hifigan
│   ├── LICENSE
│   ├── __init__.py
│   ├── config.json
│   ├── generator_LJSpeech.pth.tar.zip
│   ├── generator_universal.pth.tar.zip
│   └── models.py
├── img
│   ├── emotional-fastspeech2-audios.png
│   ├── emotional-fastspeech2-images.png
│   ├── emotional-fastspeech2-scalars.png
│   ├── model.png
│   ├── model_conversational_tts.png
│   ├── model_conversational_tts_chat_history.png
│   └── model_emotional_tts.png
├── model
│   ├── __init__.py
│   ├── fastspeech2.py
│   ├── loss.py
│   ├── modules.py
│   └── optimizer.py
├── preparation
│   ├── aihub_mmv.py
│   └── iemocap.py
├── prepare_align.py
├── prepare_data.py
├── preprocess.py
├── preprocessor
│   ├── aihub_mmv.py
│   ├── iemocap.py
│   └── preprocessor.py
├── requirements.txt
├── synthesize.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── korean.py
│   ├── korean_dict.py
│   ├── numbers.py
│   ├── pinyin.py
│   └── symbols.py
├── train.py
├── transformer
│   ├── Constants.py
│   ├── Layers.py
│   ├── Models.py
│   ├── Modules.py
│   ├── SubLayers.py
│   └── __init__.py
└── utils
    ├── model.py
    └── tools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | __pycache__
107 | .vscode
108 | .DS_Store
109 |
110 | # nohup
111 | *.out
112 |
113 | # data, checkpoint, and models
114 | raw_data/
115 | output/
116 | *.npy
117 | TextGrid/
118 | hifigan/*.pth.tar
119 | lexicon/
120 | montreal-forced-aligner/
121 | preprocessed_data/
122 | preparation/*.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Keon Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 | MIT License
24 |
25 | Copyright (c) 2020 Chung-Ming Chien
26 |
27 | Permission is hereby granted, free of charge, to any person obtaining a copy
28 | of this software and associated documentation files (the "Software"), to deal
29 | in the Software without restriction, including without limitation the rights
30 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
31 | copies of the Software, and to permit persons to whom the Software is
32 | furnished to do so, subject to the following conditions:
33 |
34 | The above copyright notice and this permission notice shall be included in all
35 | copies or substantial portions of the Software.
36 |
37 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
38 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
39 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
40 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
41 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
42 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
43 | SOFTWARE.
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Expressive-FastSpeech2 - PyTorch Implementation
2 |
3 | ## Contributions
4 |
5 | 1. **`Non-autoregressive Expressive TTS`**: This project aims to provide a cornerstone for future research on and application of non-autoregressive expressive TTS, including `Emotional TTS` and `Conversational TTS`. For datasets, the [AIHub Multimodal Video AI datasets](https://www.aihub.or.kr/aidata/137) and the [IEMOCAP database](https://sail.usc.edu/iemocap/) are used for Korean and English, respectively.
6 |
7 | **Note**: If you are interested in a [GST-Tacotron](https://arxiv.org/abs/1803.09017)- or [VAE-Tacotron](https://arxiv.org/abs/1812.04342)-like expressive stylistic TTS model but with non-autoregressive decoding, you may also be interested in [STYLER](https://arxiv.org/abs/2103.09474) [[demo](https://keonlee9420.github.io/STYLER-Demo/), [code](https://github.com/keonlee9420/STYLER)].
8 |
9 | 2. **`Annotated Data Processing`**: This project sheds light on how to handle a new dataset, even in a different language, for successful training of non-autoregressive emotional TTS.
10 | 3. **`English and Korean TTS`**: In addition to English, this project gives a broad view of handling Korean for non-autoregressive TTS, where additional data processing must account for language-specific features (e.g., training the Montreal Forced Aligner with your own language and dataset). Please look closely into `text/`.
11 | 4. **`Adopting Own Language`**: For those interested in adapting this project to other languages, please refer to the ["Training with your own dataset (own language)" section](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/categorical#training-with-your-own-dataset-own-language) of the [categorical branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/categorical).
12 |
13 | ## Repository Structure
14 |
15 | In this project, FastSpeech2 is adapted as the base non-autoregressive multi-speaker TTS framework, so it would be helpful to read [the paper](https://arxiv.org/abs/2006.04558) and [code](https://github.com/ming024/FastSpeech2) first (also see the [FastSpeech2 branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/FastSpeech2)).
16 |
17 |
18 |
19 |
20 |
21 | 1. `Emotional TTS`: The following branches contain implementations of the basic paradigm introduced by [Emotional End-to-End Neural Speech synthesizer](https://arxiv.org/pdf/1711.05447.pdf).
22 |
23 |
24 |
25 |
26 |
27 | - [categorical branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/categorical): conditioning only on categorical emotional descriptors (such as happy, sad, etc.)
28 | - [continuous branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/continuous): conditioning on continuous emotional descriptors (such as arousal, valence, etc.) in addition to categorical emotional descriptors
29 | 2. `Conversational TTS`: The following branch contains an implementation of [Conversational End-to-End TTS for Voice Agent](https://arxiv.org/abs/2005.10438)
30 |
31 |
32 |
33 |
34 |
35 | - [conversational branch](https://github.com/keonlee9420/Expressive-FastSpeech2/tree/conversational): conditioning on chat history
36 |
37 | ## Citation
38 |
39 | If you would like to use or refer to this implementation, please cite the repo.
40 |
41 | ```bibtex
42 | @misc{lee2021expressive_fastspeech2,
43 | author = {Lee, Keon},
44 | title = {Expressive-FastSpeech2},
45 | year = {2021},
46 | publisher = {GitHub},
47 | journal = {GitHub repository},
48 | howpublished = {\url{https://github.com/keonlee9420/Expressive-FastSpeech2}}
49 | }
50 | ```
51 |
52 | ## References
53 |
54 | - [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) (Later than 2021.02.26 ver.)
55 | - [HGU-DLLAB's Korean-FastSpeech2-Pytorch](https://github.com/HGU-DLLAB/Korean-FastSpeech2-Pytorch)
56 | - [hccho2's Tacotron2-Wavenet-Korean-TTS](https://github.com/hccho2/Tacotron2-Wavenet-Korean-TTS)
57 | - [carpedm20's multi-speaker-tacotron-tensorflow](https://github.com/carpedm20/multi-speaker-tacotron-tensorflow)
--------------------------------------------------------------------------------
/audio/__init__.py:
--------------------------------------------------------------------------------
1 | import audio.tools
2 | import audio.stft
3 | import audio.audio_processing
4 |
--------------------------------------------------------------------------------
/audio/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import librosa.util as librosa_util
4 | from scipy.signal import get_window
5 |
6 |
7 | def window_sumsquare(
8 | window,
9 | n_frames,
10 | hop_length,
11 | win_length,
12 | n_fft,
13 | dtype=np.float32,
14 | norm=None,
15 | ):
16 | """
17 | # from librosa 0.6
18 | Compute the sum-square envelope of a window function at a given hop length.
19 |
20 | This is used to estimate modulation effects induced by windowing
21 | observations in short-time fourier transforms.
22 |
23 | Parameters
24 | ----------
25 | window : string, tuple, number, callable, or list-like
26 | Window specification, as in `get_window`
27 |
28 | n_frames : int > 0
29 | The number of analysis frames
30 |
31 | hop_length : int > 0
32 | The number of samples to advance between frames
33 |
34 | win_length : [optional]
35 | The length of the window function. By default, this matches `n_fft`.
36 |
37 | n_fft : int > 0
38 | The length of each analysis frame.
39 |
40 | dtype : np.dtype
41 | The data type of the output
42 |
43 | Returns
44 | -------
45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46 | The sum-squared envelope of the window function
47 | """
48 | if win_length is None:
49 | win_length = n_fft
50 |
51 | n = n_fft + hop_length * (n_frames - 1)
52 | x = np.zeros(n, dtype=dtype)
53 |
54 | # Compute the squared window at the desired length
55 | win_sq = get_window(window, win_length, fftbins=True)
56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57 | win_sq = librosa_util.pad_center(win_sq, n_fft)
58 |
59 | # Fill the envelope
60 | for i in range(n_frames):
61 | sample = i * hop_length
62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63 | return x
64 |
65 |
66 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
67 | """
68 | PARAMS
69 | ------
70 | magnitudes: spectrogram magnitudes
71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72 | """
73 |
74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75 | angles = angles.astype(np.float32)
76 | angles = torch.autograd.Variable(torch.from_numpy(angles))
77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78 |
79 | for i in range(n_iters):
80 | _, angles = stft_fn.transform(signal)
81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82 | return signal
83 |
84 |
85 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
86 | """
87 | PARAMS
88 | ------
89 | C: compression factor
90 | """
91 | return torch.log(torch.clamp(x, min=clip_val) * C)
92 |
93 |
94 | def dynamic_range_decompression(x, C=1):
95 | """
96 | PARAMS
97 | ------
98 | C: compression factor used to compress
99 | """
100 | return torch.exp(x) / C
101 |
--------------------------------------------------------------------------------
/audio/stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy.signal import get_window
5 | from librosa.util import pad_center, tiny
6 | from librosa.filters import mel as librosa_mel_fn
7 |
8 | from audio.audio_processing import (
9 | dynamic_range_compression,
10 | dynamic_range_decompression,
11 | window_sumsquare,
12 | )
13 |
14 |
15 | class STFT(torch.nn.Module):
16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17 |
18 | def __init__(self, filter_length, hop_length, win_length, window="hann"):
19 | super(STFT, self).__init__()
20 | self.filter_length = filter_length
21 | self.hop_length = hop_length
22 | self.win_length = win_length
23 | self.window = window
24 | self.forward_transform = None
25 | scale = self.filter_length / self.hop_length
26 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
27 |
28 | cutoff = int((self.filter_length / 2 + 1))
29 | fourier_basis = np.vstack(
30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31 | )
32 |
33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34 | inverse_basis = torch.FloatTensor(
35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36 | )
37 |
38 | if window is not None:
39 | assert filter_length >= win_length
40 | # get window and zero center pad it to filter_length
41 | fft_window = get_window(window, win_length, fftbins=True)
42 | fft_window = pad_center(fft_window, filter_length)
43 | fft_window = torch.from_numpy(fft_window).float()
44 |
45 | # window the bases
46 | forward_basis *= fft_window
47 | inverse_basis *= fft_window
48 |
49 | self.register_buffer("forward_basis", forward_basis.float())
50 | self.register_buffer("inverse_basis", inverse_basis.float())
51 |
52 | def transform(self, input_data):
53 | num_batches = input_data.size(0)
54 | num_samples = input_data.size(1)
55 |
56 | self.num_samples = num_samples
57 |
58 | # similar to librosa, reflect-pad the input
59 | input_data = input_data.view(num_batches, 1, num_samples)
60 | input_data = F.pad(
61 | input_data.unsqueeze(1),
62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63 | mode="reflect",
64 | )
65 | input_data = input_data.squeeze(1)
66 |
67 | forward_transform = F.conv1d(
68 | input_data.cuda(),
69 | torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(),
70 | stride=self.hop_length,
71 | padding=0,
72 | ).cpu()
73 |
74 | cutoff = int((self.filter_length / 2) + 1)
75 | real_part = forward_transform[:, :cutoff, :]
76 | imag_part = forward_transform[:, cutoff:, :]
77 |
78 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80 |
81 | return magnitude, phase
82 |
83 | def inverse(self, magnitude, phase):
84 | recombine_magnitude_phase = torch.cat(
85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86 | )
87 |
88 | inverse_transform = F.conv_transpose1d(
89 | recombine_magnitude_phase,
90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91 | stride=self.hop_length,
92 | padding=0,
93 | )
94 |
95 | if self.window is not None:
96 | window_sum = window_sumsquare(
97 | self.window,
98 | magnitude.size(-1),
99 | hop_length=self.hop_length,
100 | win_length=self.win_length,
101 | n_fft=self.filter_length,
102 | dtype=np.float32,
103 | )
104 | # remove modulation effects
105 | approx_nonzero_indices = torch.from_numpy(
106 | np.where(window_sum > tiny(window_sum))[0]
107 | )
108 | window_sum = torch.autograd.Variable(
109 | torch.from_numpy(window_sum), requires_grad=False
110 | )
111 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113 | approx_nonzero_indices
114 | ]
115 |
116 | # scale by hop ratio
117 | inverse_transform *= float(self.filter_length) / self.hop_length
118 |
119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121 |
122 | return inverse_transform
123 |
124 | def forward(self, input_data):
125 | self.magnitude, self.phase = self.transform(input_data)
126 | reconstruction = self.inverse(self.magnitude, self.phase)
127 | return reconstruction
128 |
129 |
130 | class TacotronSTFT(torch.nn.Module):
131 | def __init__(
132 | self,
133 | filter_length,
134 | hop_length,
135 | win_length,
136 | n_mel_channels,
137 | sampling_rate,
138 | mel_fmin,
139 | mel_fmax,
140 | ):
141 | super(TacotronSTFT, self).__init__()
142 | self.n_mel_channels = n_mel_channels
143 | self.sampling_rate = sampling_rate
144 | self.stft_fn = STFT(filter_length, hop_length, win_length)
145 | mel_basis = librosa_mel_fn(
146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147 | )
148 | mel_basis = torch.from_numpy(mel_basis).float()
149 | self.register_buffer("mel_basis", mel_basis)
150 |
151 | def spectral_normalize(self, magnitudes):
152 | output = dynamic_range_compression(magnitudes)
153 | return output
154 |
155 | def spectral_de_normalize(self, magnitudes):
156 | output = dynamic_range_decompression(magnitudes)
157 | return output
158 |
159 | def mel_spectrogram(self, y):
160 | """Computes mel-spectrograms from a batch of waves
161 | PARAMS
162 | ------
163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164 |
165 | RETURNS
166 | -------
167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168 | """
169 | assert torch.min(y.data) >= -1
170 | assert torch.max(y.data) <= 1
171 |
172 | magnitudes, phases = self.stft_fn.transform(y)
173 | magnitudes = magnitudes.data
174 | mel_output = torch.matmul(self.mel_basis, magnitudes)
175 | mel_output = self.spectral_normalize(mel_output)
176 | energy = torch.norm(magnitudes, dim=1)
177 |
178 | return mel_output, energy
179 |
--------------------------------------------------------------------------------
/audio/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.io.wavfile import write
4 |
5 | from audio.audio_processing import griffin_lim
6 |
7 |
8 | def get_mel_from_wav(audio, _stft):
9 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
10 | audio = torch.autograd.Variable(audio, requires_grad=False)
11 | melspec, energy = _stft.mel_spectrogram(audio)
12 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
13 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
14 |
15 | return melspec, energy
16 |
17 |
18 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
19 | mel = torch.stack([mel])
20 | mel_decompress = _stft.spectral_de_normalize(mel)
21 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
22 | spec_from_mel_scaling = 1000
23 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
24 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
25 | spec_from_mel = spec_from_mel * spec_from_mel_scaling
26 |
27 | audio = griffin_lim(
28 |         torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters
29 | )
30 |
31 | audio = audio.squeeze()
32 | audio = audio.cpu().numpy()
33 | audio_path = out_filename
34 | write(audio_path, _stft.sampling_rate, audio)
35 |
--------------------------------------------------------------------------------
/config/AIHub-MMV/model.yaml:
--------------------------------------------------------------------------------
1 | transformer:
2 | encoder_layer: 4
3 | encoder_head: 2
4 | encoder_hidden: 256
5 | decoder_layer: 6
6 | decoder_head: 2
7 | decoder_hidden: 256
8 | conv_filter_size: 1024
9 | conv_kernel_size: [9, 1]
10 | encoder_dropout: 0.2
11 | decoder_dropout: 0.2
12 |
13 | variance_predictor:
14 | filter_size: 256
15 | kernel_size: 3
16 | dropout: 0.5
17 |
18 | variance_embedding:
19 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
21 | n_bins: 256
22 |
23 | # gst:
24 | # use_gst: False
25 | # conv_filters: [32, 32, 64, 64, 128, 128]
26 | # gru_hidden: 128
27 | # token_size: 128
28 | # n_style_token: 10
29 | # attn_head: 4
30 |
31 | multi_speaker: True
32 |
33 | multi_emotion: True
34 |
35 | max_seq_len: 2000
36 |
37 | vocoder:
38 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
39 | speaker: "universal" # support 'LJSpeech', 'universal'
40 |
--------------------------------------------------------------------------------
/config/AIHub-MMV/preprocess.yaml:
--------------------------------------------------------------------------------
1 | dataset: "AIHub-MMV"
2 |
3 | path:
4 | corpus_path: "/path/to/AIHub-MMV"
5 | sub_dir_name: "clips"
6 | lexicon_path: "lexicon/aihub-mmv-lexicon.txt"
7 | fixed_text_path: "./preparation/aihub_mmv_fixed.txt"
8 | raw_path: "./raw_data/AIHub-MMV"
9 | preprocessed_path: "./preprocessed_data/AIHub-MMV"
10 |
11 | preprocessing:
12 | val_size: 512
13 | text:
14 | text_cleaners: ["korean_cleaners"]
15 | language: "kr"
16 | audio:
17 | sampling_rate: 22050
18 | max_wav_value: 32768.0
19 | stft:
20 | filter_length: 1024
21 | hop_length: 256
22 | win_length: 1024
23 | mel:
24 | n_mel_channels: 80
25 | mel_fmin: 0
26 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder
27 | pitch:
28 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
29 | normalization: True
30 | energy:
31 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
32 | normalization: True
33 |
--------------------------------------------------------------------------------
/config/AIHub-MMV/train.yaml:
--------------------------------------------------------------------------------
1 | path:
2 | ckpt_path: "./output/ckpt/AIHub-MMV"
3 | log_path: "./output/log/AIHub-MMV"
4 | result_path: "./output/result/AIHub-MMV"
5 | optimizer:
6 | batch_size: 16
7 | betas: [0.9, 0.98]
8 | eps: 0.000000001
9 | weight_decay: 0.0
10 | grad_clip_thresh: 1.0
11 | grad_acc_step: 1
12 | warm_up_step: 4000
13 | anneal_steps: [300000, 400000, 500000]
14 | anneal_rate: 0.3
15 | step:
16 | total_step: 900000
17 | log_step: 100
18 | synth_step: 1000
19 | val_step: 1000
20 | save_step: 100000
21 |
--------------------------------------------------------------------------------
/config/IEMOCAP/model.yaml:
--------------------------------------------------------------------------------
1 | transformer:
2 | encoder_layer: 4
3 | encoder_head: 2
4 | encoder_hidden: 256
5 | decoder_layer: 6
6 | decoder_head: 2
7 | decoder_hidden: 256
8 | conv_filter_size: 1024
9 | conv_kernel_size: [9, 1]
10 | encoder_dropout: 0.2
11 | decoder_dropout: 0.2
12 |
13 | variance_predictor:
14 | filter_size: 256
15 | kernel_size: 3
16 | dropout: 0.5
17 |
18 | variance_embedding:
19 | pitch_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
21 | n_bins: 256
22 |
23 | # gst:
24 | # use_gst: False
25 | # conv_filters: [32, 32, 64, 64, 128, 128]
26 | # gru_hidden: 128
27 | # token_size: 128
28 | # n_style_token: 10
29 | # attn_head: 4
30 |
31 | multi_speaker: True
32 |
33 | multi_emotion: True
34 |
35 | max_seq_len: 2000
36 |
37 | vocoder:
38 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
39 | speaker: "universal" # support 'LJSpeech', 'universal'
40 |
--------------------------------------------------------------------------------
/config/IEMOCAP/preprocess.yaml:
--------------------------------------------------------------------------------
1 | dataset: "IEMOCAP"
2 |
3 | path:
4 | corpus_path: "/path/to/IEMOCAP_full_release"
5 | sub_dir_name: "sessions"
6 | lexicon_path: "lexicon/iemocap-lexicon.txt"
7 | fixed_text_path: "./preparation/iemocap_fixed.txt"
8 | raw_path: "./raw_data/IEMOCAP"
9 | preprocessed_path: "./preprocessed_data/IEMOCAP"
10 |
11 | preprocessing:
12 | val_size: 512
13 | text:
14 | text_cleaners: ["english_cleaners"]
15 | language: "en"
16 | audio:
17 | sampling_rate: 22050
18 | max_wav_value: 32768.0
19 | stft:
20 | filter_length: 1024
21 | hop_length: 256
22 | win_length: 1024
23 | mel:
24 | n_mel_channels: 80
25 | mel_fmin: 0
26 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder
27 | pitch:
28 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
29 | normalization: True
30 | energy:
31 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
32 | normalization: True
33 |
--------------------------------------------------------------------------------
/config/IEMOCAP/train.yaml:
--------------------------------------------------------------------------------
1 | path:
2 | ckpt_path: "./output/ckpt/IEMOCAP"
3 | log_path: "./output/log/IEMOCAP"
4 | result_path: "./output/result/IEMOCAP"
5 | optimizer:
6 | batch_size: 16
7 | betas: [0.9, 0.98]
8 | eps: 0.000000001
9 | weight_decay: 0.0
10 | grad_clip_thresh: 1.0
11 | grad_acc_step: 1
12 | warm_up_step: 4000
13 | anneal_steps: [300000, 400000, 500000]
14 | anneal_rate: 0.3
15 | step:
16 | total_step: 900000
17 | log_step: 100
18 | synth_step: 1000
19 | val_step: 1000
20 | save_step: 100000
21 |
--------------------------------------------------------------------------------
/config/README.md:
--------------------------------------------------------------------------------
1 | # Config
2 |
3 | Here are the config files used to train the multi-speaker emotional TTS models. Two configurations are given:
4 |
5 | - AIHub-MMV: suggested configuration for the AIHub-MMV dataset.
6 | - IEMOCAP: suggested configuration for the IEMOCAP dataset.
7 |
8 | Some important hyper-parameters are explained here.
9 |
10 | ## preprocess.yaml
11 |
12 | - **path.lexicon_path**: the lexicon (which maps words to phonemes) used by the Montreal Forced Aligner.
13 | - **mel.stft.mel_fmax**: set it to 8000 if the HiFi-GAN vocoder is used, and to null if MelGAN is used.
14 | - **pitch.feature & energy.feature**: the original paper proposed to predict frame-level pitch and energy features and apply them to the inputs of the TTS decoder to control the pitch and energy of the synthesized utterances.
15 | However, in our experiments, we find that using phoneme-level features makes the prosody of the synthesized utterances more natural (see the sketch after this list).
16 | - **pitch.normalization & energy.normalization**: whether or not to normalize the pitch and energy values.
17 | The original paper did not normalize these values.
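
As a rough illustration of the phoneme-level option (a hedged sketch, not the repo's actual preprocessing code in `preprocessor/`; `frame_values` and `durations` are hypothetical names), a frame-level contour can be averaged over each phoneme's duration:

```python
import numpy as np

def frames_to_phoneme_level(frame_values, durations):
    """Average a frame-level contour (e.g., pitch or energy) over each
    phoneme's duration so that there is exactly one value per phoneme."""
    phoneme_values = np.zeros(len(durations), dtype=np.float32)
    pos = 0
    for i, d in enumerate(durations):
        if d > 0:
            phoneme_values[i] = frame_values[pos:pos + d].mean()
        pos += d
    return phoneme_values

# Example: 3 phonemes spanning 2, 3, and 1 frames -> [105., 125., 90.]
print(frames_to_phoneme_level(np.array([100., 110., 120., 130., 125., 90.]), [2, 3, 1]))
```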
18 |
19 | ## train.yaml
20 |
21 | - **optimizer.grad_acc_step**: the number of batches over which gradients are accumulated before updating the model parameters and calling optimizer.zero_grad(); this is useful if you wish to train with a large effective batch size but do not have sufficient GPU memory (see the sketch below).
22 | - **optimizer.anneal_steps & optimizer.anneal_rate**: the learning rate is reduced at each of the **anneal_steps** by the ratio specified by **anneal_rate**.
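
A minimal sketch of how gradient accumulation with **grad_acc_step** typically works in a generic PyTorch loop (this is not the repo's `train.py`; `model`, `optimizer`, `loader`, and `loss_fn` are placeholders supplied by the caller):

```python
import torch

def train_with_accumulation(model, optimizer, loader, loss_fn,
                            grad_acc_step=1, grad_clip_thresh=1.0):
    """Illustrates optimizer.grad_acc_step / grad_clip_thresh from train.yaml."""
    for step, batch in enumerate(loader, start=1):
        loss = loss_fn(model, batch)
        # Scale the loss so the accumulated gradient matches one large-batch update
        (loss / grad_acc_step).backward()
        if step % grad_acc_step == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
            optimizer.step()       # effectively a batch of batch_size * grad_acc_step
            optimizer.zero_grad()  # a LR scheduler (e.g., ScheduledOptim) would normally wrap these calls
```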
23 |
24 | ## model.yaml
25 |
26 | - **transformer.decoder_layer**: the original paper used a 4-layer decoder, but we find it better to use a 6-layer decoder, especially for multi-speaker TTS.
27 | - **variance_embedding.pitch_quantization**: when the pitch values are normalized as specified in ``preprocess.yaml``, it is not valid to use log-scale quantization bins as proposed in the original paper, so we use linearly spaced bins instead (see the sketch after this list).
28 | - **multi_speaker**: whether or not to apply a speaker embedding table to enable multi-speaker TTS.
29 | - **multi_emotion**: whether or not to apply an emotion embedding table to enable multi-emotion TTS.
30 | - **vocoder.speaker**: should be set to 'universal'.
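
For intuition about linear vs. log quantization bins, here is a hedged sketch (the repo's actual bucketing is implemented in the variance adaptor in `model/modules.py`, not shown in this section; the `pitch_min`/`pitch_max` statistics below are placeholders):

```python
import math
import torch

n_bins = 256                      # model.yaml: variance_embedding.n_bins
pitch_min, pitch_max = -2.5, 9.0  # placeholder stats; normalized pitch can be negative

# Linearly spaced bins work whether or not the values were normalized.
linear_bins = torch.linspace(pitch_min, pitch_max, n_bins - 1)

# Log-spaced bins are only valid for strictly positive (un-normalized) values.
if pitch_min > 0:
    log_bins = torch.exp(
        torch.linspace(math.log(pitch_min), math.log(pitch_max), n_bins - 1)
    )

# Each pitch value is then mapped to an embedding index, e.g.:
indices = torch.bucketize(torch.tensor([0.0, 3.3]), linear_bins)
```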
31 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 | from tqdm import tqdm
5 |
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 |
9 | from text import text_to_sequence
10 | from utils.tools import pad_1D, pad_2D
11 |
12 |
13 | class Dataset(Dataset):
14 | def __init__(
15 | self, filename, preprocess_config, model_config, train_config, sort=False, drop_last=False
16 | ):
17 | self.dataset_name = preprocess_config["dataset"]
18 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
19 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
20 | self.max_seq_len = model_config["max_seq_len"]
21 | self.batch_size = train_config["optimizer"]["batch_size"]
22 |
23 | self.basename, self.speaker, self.text, self.raw_text, self.aux_data = self.process_meta(
24 | filename
25 | )
26 | with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
27 | self.speaker_map = json.load(f)
28 | with open(os.path.join(self.preprocessed_path, "emotions.json")) as f:
29 | json_raw = json.load(f)
30 | self.emotion_map = json_raw["emotion_dict"]
31 | self.arousal_map = json_raw["arousal_dict"]
32 | self.valence_map = json_raw["valence_dict"]
33 | self.sort = sort
34 | self.drop_last = drop_last
35 |
36 | def __len__(self):
37 | return len(self.text)
38 |
39 | def __getitem__(self, idx):
40 | basename = self.basename[idx]
41 | speaker = self.speaker[idx]
42 | speaker_id = self.speaker_map[speaker]
43 | aux_data = self.aux_data[idx].split("|")
44 | emotion = self.emotion_map[aux_data[-3]]
45 | arousal = self.arousal_map[aux_data[-2]]
46 | valence = self.valence_map[aux_data[-1]]
47 | raw_text = self.raw_text[idx]
48 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
49 | mel_path = os.path.join(
50 | self.preprocessed_path,
51 | "mel",
52 | "{}-mel-{}.npy".format(speaker, basename),
53 | )
54 | mel = np.load(mel_path)
55 | pitch_path = os.path.join(
56 | self.preprocessed_path,
57 | "pitch",
58 | "{}-pitch-{}.npy".format(speaker, basename),
59 | )
60 | pitch = np.load(pitch_path)
61 | energy_path = os.path.join(
62 | self.preprocessed_path,
63 | "energy",
64 | "{}-energy-{}.npy".format(speaker, basename),
65 | )
66 | energy = np.load(energy_path)
67 | duration_path = os.path.join(
68 | self.preprocessed_path,
69 | "duration",
70 | "{}-duration-{}.npy".format(speaker, basename),
71 | )
72 | duration = np.load(duration_path)
73 |
74 | sample = {
75 | "id": basename,
76 | "speaker": speaker_id,
77 | "emotion": emotion,
78 | "arousal": arousal,
79 | "valence": valence,
80 | "text": phone,
81 | "raw_text": raw_text,
82 | "mel": mel,
83 | "pitch": pitch,
84 | "energy": energy,
85 | "duration": duration,
86 | }
87 |
88 | return sample
89 |
90 | def process_meta(self, filename):
91 | with open(
92 | os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8"
93 | ) as f:
94 | name = []
95 | speaker = []
96 | text = []
97 | raw_text = []
98 | aux_data = []
99 | for line in tqdm(f.readlines()):
100 | line_split = line.strip("\n").split("|")
101 | n, s, t, r = line_split[:4]
102 | mel_path = os.path.join(
103 | self.preprocessed_path,
104 | "mel",
105 | "{}-mel-{}.npy".format(s, n),
106 | )
107 | mel = np.load(mel_path)
108 | if mel.shape[0] > self.max_seq_len:
109 | continue
110 | a = "|".join(line_split[4:])
111 | name.append(n)
112 | speaker.append(s)
113 | text.append(t)
114 | raw_text.append(r)
115 | aux_data.append(a)
116 | return name, speaker, text, raw_text, aux_data
117 |
118 | def reprocess(self, data, idxs):
119 | ids = [data[idx]["id"] for idx in idxs]
120 | speakers = [data[idx]["speaker"] for idx in idxs]
121 | emotions = [data[idx]["emotion"] for idx in idxs]
122 | arousals = [data[idx]["arousal"] for idx in idxs]
123 | valences = [data[idx]["valence"] for idx in idxs]
124 | texts = [data[idx]["text"] for idx in idxs]
125 | raw_texts = [data[idx]["raw_text"] for idx in idxs]
126 | mels = [data[idx]["mel"] for idx in idxs]
127 | pitches = [data[idx]["pitch"] for idx in idxs]
128 | energies = [data[idx]["energy"] for idx in idxs]
129 | durations = [data[idx]["duration"] for idx in idxs]
130 |
131 | text_lens = np.array([text.shape[0] for text in texts])
132 | mel_lens = np.array([mel.shape[0] for mel in mels])
133 |
134 | speakers = np.array(speakers)
135 | emotions = np.array(emotions)
136 | arousals = np.array(arousals)
137 | valences = np.array(valences)
138 | texts = pad_1D(texts)
139 | mels = pad_2D(mels)
140 | pitches = pad_1D(pitches)
141 | energies = pad_1D(energies)
142 | durations = pad_1D(durations)
143 |
144 | return (
145 | ids,
146 | raw_texts,
147 | speakers,
148 | emotions,
149 | arousals,
150 | valences,
151 | texts,
152 | text_lens,
153 | max(text_lens),
154 | mels,
155 | mel_lens,
156 | max(mel_lens),
157 | pitches,
158 | energies,
159 | durations,
160 | )
161 |
162 | def collate_fn(self, data):
163 | data_size = len(data)
164 |
165 | if self.sort:
166 | len_arr = np.array([d["text"].shape[0] for d in data])
167 | idx_arr = np.argsort(-len_arr)
168 | else:
169 | idx_arr = np.arange(data_size)
170 |
171 | tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :]
172 | idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]
173 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
174 | if not self.drop_last and len(tail) > 0:
175 | idx_arr += [tail.tolist()]
176 |
177 | output = list()
178 | for idx in idx_arr:
179 | output.append(self.reprocess(data, idx))
180 |
181 | return output
182 |
183 |
184 | class TextDataset(Dataset):
185 | def __init__(self, filepath, preprocess_config, model_config):
186 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
187 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
188 | self.max_seq_len = model_config["max_seq_len"]
189 |
190 | self.basename, self.speaker, self.text, self.raw_text, self.aux_data = self.process_meta(
191 | filepath
192 | )
193 | with open(
194 | os.path.join(
195 | preprocess_config["path"]["preprocessed_path"], "speakers.json"
196 | )
197 | ) as f:
198 | self.speaker_map = json.load(f)
199 | with open(os.path.join(self.preprocessed_path, "emotions.json")) as f:
200 | json_raw = json.load(f)
201 | self.emotion_map = json_raw["emotion_dict"]
202 | self.arousal_map = json_raw["arousal_dict"]
203 | self.valence_map = json_raw["valence_dict"]
204 |
205 | def __len__(self):
206 | return len(self.text)
207 |
208 | def __getitem__(self, idx):
209 | basename = self.basename[idx]
210 | speaker = self.speaker[idx]
211 | speaker_id = self.speaker_map[speaker]
212 | aux_data = self.aux_data[idx].split("|")
213 | emotion = self.emotion_map[aux_data[-3]]
214 | arousal = self.arousal_map[aux_data[-2]]
215 | valence = self.valence_map[aux_data[-1]]
216 | raw_text = self.raw_text[idx]
217 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
218 |
219 | return (basename, speaker_id, emotion, arousal, valence, phone, raw_text)
220 |
221 | def process_meta(self, filename):
222 | with open(filename, "r", encoding="utf-8") as f:
223 | name = []
224 | speaker = []
225 | text = []
226 | raw_text = []
227 | aux_data = []
228 | for line in tqdm(f.readlines()):
229 | line_split = line.strip("\n").split("|")
230 | n, s, t, r = line_split[:4]
231 | mel_path = os.path.join(
232 | self.preprocessed_path,
233 | "mel",
234 | "{}-mel-{}.npy".format(s, n),
235 | )
236 | mel = np.load(mel_path)
237 | if mel.shape[0] > self.max_seq_len:
238 | continue
239 | a = "|".join(line_split[4:])
240 | name.append(n)
241 | speaker.append(s)
242 | text.append(t)
243 | raw_text.append(r)
244 | aux_data.append(a)
245 | return name, speaker, text, raw_text, aux_data
246 |
247 | def collate_fn(self, data):
248 | ids = [d[0] for d in data]
249 | speakers = np.array([d[1] for d in data])
250 | emotions = np.array([d[2] for d in data])
251 | arousals = np.array([d[3] for d in data])
252 | valences = np.array([d[4] for d in data])
253 | texts = [d[5] for d in data]
254 | raw_texts = [d[6] for d in data]
255 | text_lens = np.array([text.shape[0] for text in texts])
256 |
257 | texts = pad_1D(texts)
258 |
259 | return ids, raw_texts, speakers, emotions, arousals, valences, texts, text_lens, max(text_lens)
260 |
261 |
262 | if __name__ == "__main__":
263 | # Test
264 | import torch
265 | import yaml
266 | from torch.utils.data import DataLoader
267 |     from utils.tools import to_device
268 |
269 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
270 |     preprocess_config = yaml.load(
271 |         open("./config/IEMOCAP/preprocess.yaml", "r"), Loader=yaml.FullLoader
272 |     )
273 |     model_config = yaml.load(
274 |         open("./config/IEMOCAP/model.yaml", "r"), Loader=yaml.FullLoader
275 |     )
276 |     train_config = yaml.load(
277 |         open("./config/IEMOCAP/train.yaml", "r"), Loader=yaml.FullLoader
278 |     )
279 |
280 | train_dataset = Dataset(
281 | "train.txt", preprocess_config, model_config, train_config, sort=True, drop_last=True
282 | )
283 | val_dataset = Dataset(
284 | "val.txt", preprocess_config, model_config, train_config, sort=False, drop_last=False
285 | )
286 | train_loader = DataLoader(
287 | train_dataset,
288 | batch_size=train_config["optimizer"]["batch_size"] * 4,
289 | shuffle=True,
290 | collate_fn=train_dataset.collate_fn,
291 | )
292 | val_loader = DataLoader(
293 | val_dataset,
294 | batch_size=train_config["optimizer"]["batch_size"],
295 | shuffle=False,
296 | collate_fn=val_dataset.collate_fn,
297 | )
298 |
299 | n_batch = 0
300 | for batchs in train_loader:
301 | for batch in batchs:
302 | to_device(batch, device)
303 | n_batch += 1
304 | print(
305 | "Training set with size {} is composed of {} batches.".format(
306 | len(train_dataset), n_batch
307 | )
308 | )
309 |
310 | n_batch = 0
311 | for batchs in val_loader:
312 | for batch in batchs:
313 | to_device(batch, device)
314 | n_batch += 1
315 | print(
316 | "Validation set with size {} is composed of {} batches.".format(
317 | len(val_dataset), n_batch
318 | )
319 | )
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import yaml
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 |
9 | from utils.model import get_model, get_vocoder
10 | from utils.tools import to_device, log, synth_one_sample
11 | from model import FastSpeech2Loss
12 | from dataset import Dataset
13 |
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 |
18 | def evaluate(model, step, configs, logger=None, vocoder=None):
19 | preprocess_config, model_config, train_config = configs
20 |
21 | # Get dataset
22 | dataset = Dataset(
23 | "val.txt", preprocess_config, model_config, train_config, sort=False, drop_last=False
24 | )
25 | batch_size = train_config["optimizer"]["batch_size"]
26 | loader = DataLoader(
27 | dataset,
28 | batch_size=batch_size,
29 | shuffle=False,
30 | collate_fn=dataset.collate_fn,
31 | )
32 |
33 | # Get loss function
34 | Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
35 |
36 | # Evaluation
37 | loss_sums = [0 for _ in range(6)]
38 | for batchs in loader:
39 | for batch in batchs:
40 | batch = to_device(batch, device)
41 | with torch.no_grad():
42 | # Forward
43 | output = model(*(batch[2:]))
44 |
45 | # Cal Loss
46 | losses = Loss(batch, output)
47 |
48 | for i in range(len(losses)):
49 | loss_sums[i] += losses[i].item() * len(batch[0])
50 |
51 | loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums]
52 |
53 | message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
54 | *([step] + [l for l in loss_means])
55 | )
56 |
57 | if logger is not None:
58 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
59 | batch,
60 | output,
61 | vocoder,
62 | model_config,
63 | preprocess_config,
64 | )
65 |
66 | log(logger, step, losses=loss_means)
67 | log(
68 | logger,
69 | fig=fig,
70 | tag="Validation/step_{}_{}".format(step, tag),
71 | )
72 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
73 | log(
74 | logger,
75 | audio=wav_reconstruction,
76 | sampling_rate=sampling_rate,
77 | tag="Validation/step_{}_{}_reconstructed".format(step, tag),
78 | )
79 | log(
80 | logger,
81 | audio=wav_prediction,
82 | sampling_rate=sampling_rate,
83 | tag="Validation/step_{}_{}_synthesized".format(step, tag),
84 | )
85 |
86 | return message
87 |
88 |
89 | if __name__ == "__main__":
90 |
91 | parser = argparse.ArgumentParser()
92 | parser.add_argument("--restore_step", type=int, default=30000)
93 | parser.add_argument(
94 | "-p",
95 | "--preprocess_config",
96 | type=str,
97 | required=True,
98 | help="path to preprocess.yaml",
99 | )
100 | parser.add_argument(
101 | "-m", "--model_config", type=str, required=True, help="path to model.yaml"
102 | )
103 | parser.add_argument(
104 | "-t", "--train_config", type=str, required=True, help="path to train.yaml"
105 | )
106 | args = parser.parse_args()
107 |
108 | # Read Config
109 | preprocess_config = yaml.load(
110 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader
111 | )
112 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
113 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
114 | configs = (preprocess_config, model_config, train_config)
115 |
116 | # Get model
117 | model = get_model(args, configs, device, train=False).to(device)
118 |
119 | message = evaluate(model, args.restore_step, configs)
120 | print(message)
--------------------------------------------------------------------------------
/hifigan/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jungil Kong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/hifigan/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import Generator
2 |
3 |
4 | class AttrDict(dict):
5 | def __init__(self, *args, **kwargs):
6 | super(AttrDict, self).__init__(*args, **kwargs)
7 | self.__dict__ = self
--------------------------------------------------------------------------------
/hifigan/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/hifigan/generator_LJSpeech.pth.tar.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/hifigan/generator_LJSpeech.pth.tar.zip
--------------------------------------------------------------------------------
/hifigan/generator_universal.pth.tar.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/hifigan/generator_universal.pth.tar.zip
--------------------------------------------------------------------------------
/hifigan/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Conv1d, ConvTranspose1d
5 | from torch.nn.utils import weight_norm, remove_weight_norm
6 |
7 | LRELU_SLOPE = 0.1
8 |
9 |
10 | def init_weights(m, mean=0.0, std=0.01):
11 | classname = m.__class__.__name__
12 | if classname.find("Conv") != -1:
13 | m.weight.data.normal_(mean, std)
14 |
15 |
16 | def get_padding(kernel_size, dilation=1):
17 | return int((kernel_size * dilation - dilation) / 2)
18 |
19 |
20 | class ResBlock(torch.nn.Module):
21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22 | super(ResBlock, self).__init__()
23 | self.h = h
24 | self.convs1 = nn.ModuleList(
25 | [
26 | weight_norm(
27 | Conv1d(
28 | channels,
29 | channels,
30 | kernel_size,
31 | 1,
32 | dilation=dilation[0],
33 | padding=get_padding(kernel_size, dilation[0]),
34 | )
35 | ),
36 | weight_norm(
37 | Conv1d(
38 | channels,
39 | channels,
40 | kernel_size,
41 | 1,
42 | dilation=dilation[1],
43 | padding=get_padding(kernel_size, dilation[1]),
44 | )
45 | ),
46 | weight_norm(
47 | Conv1d(
48 | channels,
49 | channels,
50 | kernel_size,
51 | 1,
52 | dilation=dilation[2],
53 | padding=get_padding(kernel_size, dilation[2]),
54 | )
55 | ),
56 | ]
57 | )
58 | self.convs1.apply(init_weights)
59 |
60 | self.convs2 = nn.ModuleList(
61 | [
62 | weight_norm(
63 | Conv1d(
64 | channels,
65 | channels,
66 | kernel_size,
67 | 1,
68 | dilation=1,
69 | padding=get_padding(kernel_size, 1),
70 | )
71 | ),
72 | weight_norm(
73 | Conv1d(
74 | channels,
75 | channels,
76 | kernel_size,
77 | 1,
78 | dilation=1,
79 | padding=get_padding(kernel_size, 1),
80 | )
81 | ),
82 | weight_norm(
83 | Conv1d(
84 | channels,
85 | channels,
86 | kernel_size,
87 | 1,
88 | dilation=1,
89 | padding=get_padding(kernel_size, 1),
90 | )
91 | ),
92 | ]
93 | )
94 | self.convs2.apply(init_weights)
95 |
96 | def forward(self, x):
97 | for c1, c2 in zip(self.convs1, self.convs2):
98 | xt = F.leaky_relu(x, LRELU_SLOPE)
99 | xt = c1(xt)
100 | xt = F.leaky_relu(xt, LRELU_SLOPE)
101 | xt = c2(xt)
102 | x = xt + x
103 | return x
104 |
105 | def remove_weight_norm(self):
106 | for l in self.convs1:
107 | remove_weight_norm(l)
108 | for l in self.convs2:
109 | remove_weight_norm(l)
110 |
111 |
112 | class Generator(torch.nn.Module):
113 | def __init__(self, h):
114 | super(Generator, self).__init__()
115 | self.h = h
116 | self.num_kernels = len(h.resblock_kernel_sizes)
117 | self.num_upsamples = len(h.upsample_rates)
118 | self.conv_pre = weight_norm(
119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
120 | )
121 | resblock = ResBlock
122 |
123 | self.ups = nn.ModuleList()
124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125 | self.ups.append(
126 | weight_norm(
127 | ConvTranspose1d(
128 | h.upsample_initial_channel // (2 ** i),
129 | h.upsample_initial_channel // (2 ** (i + 1)),
130 | k,
131 | u,
132 | padding=(k - u) // 2,
133 | )
134 | )
135 | )
136 |
137 | self.resblocks = nn.ModuleList()
138 | for i in range(len(self.ups)):
139 | ch = h.upsample_initial_channel // (2 ** (i + 1))
140 | for j, (k, d) in enumerate(
141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142 | ):
143 | self.resblocks.append(resblock(h, ch, k, d))
144 |
145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146 | self.ups.apply(init_weights)
147 | self.conv_post.apply(init_weights)
148 |
149 | def forward(self, x):
150 | x = self.conv_pre(x)
151 | for i in range(self.num_upsamples):
152 | x = F.leaky_relu(x, LRELU_SLOPE)
153 | x = self.ups[i](x)
154 | xs = None
155 | for j in range(self.num_kernels):
156 | if xs is None:
157 | xs = self.resblocks[i * self.num_kernels + j](x)
158 | else:
159 | xs += self.resblocks[i * self.num_kernels + j](x)
160 | x = xs / self.num_kernels
161 | x = F.leaky_relu(x)
162 | x = self.conv_post(x)
163 | x = torch.tanh(x)
164 |
165 | return x
166 |
167 | def remove_weight_norm(self):
168 | print("Removing weight norm...")
169 | for l in self.ups:
170 | remove_weight_norm(l)
171 | for l in self.resblocks:
172 | l.remove_weight_norm()
173 | remove_weight_norm(self.conv_pre)
174 | remove_weight_norm(self.conv_post)
--------------------------------------------------------------------------------
/img/emotional-fastspeech2-audios.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/emotional-fastspeech2-audios.png
--------------------------------------------------------------------------------
/img/emotional-fastspeech2-images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/emotional-fastspeech2-images.png
--------------------------------------------------------------------------------
/img/emotional-fastspeech2-scalars.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/emotional-fastspeech2-scalars.png
--------------------------------------------------------------------------------
/img/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/model.png
--------------------------------------------------------------------------------
/img/model_conversational_tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/model_conversational_tts.png
--------------------------------------------------------------------------------
/img/model_conversational_tts_chat_history.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/model_conversational_tts_chat_history.png
--------------------------------------------------------------------------------
/img/model_emotional_tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Expressive-FastSpeech2/7f1c463d0f10053596de62e5c112ee952f58d924/img/model_emotional_tts.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .fastspeech2 import FastSpeech2
2 | from .loss import FastSpeech2Loss
3 | from .optimizer import ScheduledOptim
--------------------------------------------------------------------------------
/model/fastspeech2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from transformer import Encoder, Decoder, PostNet
9 | from .modules import VarianceAdaptor
10 | from utils.tools import get_mask_from_lengths
11 |
12 |
13 | class FastSpeech2(nn.Module):
14 | """ FastSpeech2 """
15 |
16 | def __init__(self, preprocess_config, model_config):
17 | super(FastSpeech2, self).__init__()
18 | self.model_config = model_config
19 |
20 | self.encoder = Encoder(model_config)
21 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
22 | self.decoder = Decoder(model_config)
23 | self.mel_linear = nn.Linear(
24 | model_config["transformer"]["decoder_hidden"],
25 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
26 | )
27 | self.postnet = PostNet()
28 |
29 | self.speaker_emb = None
30 | if model_config["multi_speaker"]:
31 | with open(
32 | os.path.join(
33 | preprocess_config["path"]["preprocessed_path"], "speakers.json"
34 | ),
35 | "r",
36 | ) as f:
37 | n_speaker = len(json.load(f))
38 | self.speaker_emb = nn.Embedding(
39 | n_speaker,
40 | model_config["transformer"]["encoder_hidden"],
41 | )
42 |
43 | self.emotion_emb = None
44 | if model_config["multi_emotion"]:
45 | with open(
46 | os.path.join(
47 | preprocess_config["path"]["preprocessed_path"], "emotions.json"
48 | ),
49 | "r",
50 | ) as f:
51 | json_raw = json.load(f)
52 | n_emotion = len(json_raw["emotion_dict"])
53 | n_arousal = len(json_raw["arousal_dict"])
54 | n_valence = len(json_raw["valence_dict"])
55 | encoder_hidden = model_config["transformer"]["encoder_hidden"]
56 | self.emotion_emb = nn.Embedding(
57 | n_emotion,
58 | encoder_hidden//2,
59 | )
60 | self.arousal_emb = nn.Embedding(
61 | n_arousal,
62 | encoder_hidden//4,
63 | )
64 | self.valence_emb = nn.Embedding(
65 | n_valence,
66 | encoder_hidden//4,
67 | )
68 | self.emotion_linear = nn.Sequential(
69 | nn.Linear(encoder_hidden, encoder_hidden),
70 | nn.ReLU()
71 | )
72 |
73 | def forward(
74 | self,
75 | speakers,
76 | emotions,
77 | arousals,
78 | valences,
79 | texts,
80 | src_lens,
81 | max_src_len,
82 | mels=None,
83 | mel_lens=None,
84 | max_mel_len=None,
85 | p_targets=None,
86 | e_targets=None,
87 | d_targets=None,
88 | p_control=1.0,
89 | e_control=1.0,
90 | d_control=1.0,
91 | ):
92 | src_masks = get_mask_from_lengths(src_lens, max_src_len)
93 | mel_masks = (
94 | get_mask_from_lengths(mel_lens, max_mel_len)
95 | if mel_lens is not None
96 | else None
97 | )
98 |
99 | output = self.encoder(texts, src_masks)
100 |
101 | if self.speaker_emb is not None:
102 | output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
103 | -1, max_src_len, -1
104 | )
105 |
106 | if self.emotion_emb is not None:
107 | emb = torch.cat((self.emotion_emb(emotions), self.arousal_emb(arousals), self.valence_emb(valences)), dim=-1)
108 | output = output + self.emotion_linear(emb).unsqueeze(1).expand(
109 | -1, max_src_len, -1
110 | )
111 |
112 | (
113 | output,
114 | p_predictions,
115 | e_predictions,
116 | log_d_predictions,
117 | d_rounded,
118 | mel_lens,
119 | mel_masks,
120 | ) = self.variance_adaptor(
121 | output,
122 | src_masks,
123 | mel_masks,
124 | max_mel_len,
125 | p_targets,
126 | e_targets,
127 | d_targets,
128 | p_control,
129 | e_control,
130 | d_control,
131 | )
132 |
133 | output, mel_masks = self.decoder(output, mel_masks)
134 | output = self.mel_linear(output)
135 |
136 | postnet_output = self.postnet(output) + output
137 |
138 | return (
139 | output,
140 | postnet_output,
141 | p_predictions,
142 | e_predictions,
143 | log_d_predictions,
144 | d_rounded,
145 | src_masks,
146 | mel_masks,
147 | src_lens,
148 | mel_lens,
149 | )
--------------------------------------------------------------------------------
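
A minimal usage sketch of the FastSpeech2 module defined above. It assumes the IEMOCAP configs shipped under config/ and that preprocessing has already produced speakers.json, emotions.json, and stats.json under the configured preprocessed path; the batch contents below are illustrative placeholders, not real data.

import yaml
import torch

from model import FastSpeech2

preprocess_config = yaml.load(open("config/IEMOCAP/preprocess.yaml", "r"), Loader=yaml.FullLoader)
model_config = yaml.load(open("config/IEMOCAP/model.yaml", "r"), Loader=yaml.FullLoader)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FastSpeech2(preprocess_config, model_config).to(device).eval()

B, T = 1, 12                                                  # batch size, phoneme-sequence length
speakers = torch.zeros(B, dtype=torch.long, device=device)    # index into speakers.json
emotions = torch.zeros(B, dtype=torch.long, device=device)    # indices into emotions.json dicts
arousals = torch.zeros(B, dtype=torch.long, device=device)
valences = torch.zeros(B, dtype=torch.long, device=device)
texts = torch.randint(1, 10, (B, T), device=device)           # phoneme IDs from text_to_sequence
src_lens = torch.tensor([T], device=device)

with torch.no_grad():
    outputs = model(speakers, emotions, arousals, valences, texts, src_lens, T)
mel_before, mel_after = outputs[0], outputs[1]                # (B, mel_len, n_mel_channels)
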
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class FastSpeech2Loss(nn.Module):
6 | """ FastSpeech2 Loss """
7 |
8 | def __init__(self, preprocess_config, model_config):
9 | super(FastSpeech2Loss, self).__init__()
10 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
11 | "feature"
12 | ]
13 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
14 | "feature"
15 | ]
16 | self.mse_loss = nn.MSELoss()
17 | self.mae_loss = nn.L1Loss()
18 |
19 | def forward(self, inputs, predictions):
20 | (
21 | mel_targets,
22 | _,
23 | _,
24 | pitch_targets,
25 | energy_targets,
26 | duration_targets,
27 | ) = inputs[9:]
28 | (
29 | mel_predictions,
30 | postnet_mel_predictions,
31 | pitch_predictions,
32 | energy_predictions,
33 | log_duration_predictions,
34 | _,
35 | src_masks,
36 | mel_masks,
37 | _,
38 | _,
39 | ) = predictions
40 | src_masks = ~src_masks
41 | mel_masks = ~mel_masks
42 | log_duration_targets = torch.log(duration_targets.float() + 1)
43 | mel_targets = mel_targets[:, : mel_masks.shape[1], :]
44 | mel_masks = mel_masks[:, :mel_masks.shape[1]]
45 |
46 | log_duration_targets.requires_grad = False
47 | pitch_targets.requires_grad = False
48 | energy_targets.requires_grad = False
49 | mel_targets.requires_grad = False
50 |
51 | if self.pitch_feature_level == "phoneme_level":
52 | pitch_predictions = pitch_predictions.masked_select(src_masks)
53 | pitch_targets = pitch_targets.masked_select(src_masks)
54 | elif self.pitch_feature_level == "frame_level":
55 | pitch_predictions = pitch_predictions.masked_select(mel_masks)
56 | pitch_targets = pitch_targets.masked_select(mel_masks)
57 |
58 | if self.energy_feature_level == "phoneme_level":
59 | energy_predictions = energy_predictions.masked_select(src_masks)
60 | energy_targets = energy_targets.masked_select(src_masks)
61 | if self.energy_feature_level == "frame_level":
62 | energy_predictions = energy_predictions.masked_select(mel_masks)
63 | energy_targets = energy_targets.masked_select(mel_masks)
64 |
65 | log_duration_predictions = log_duration_predictions.masked_select(src_masks)
66 | log_duration_targets = log_duration_targets.masked_select(src_masks)
67 |
68 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
69 | postnet_mel_predictions = postnet_mel_predictions.masked_select(
70 | mel_masks.unsqueeze(-1)
71 | )
72 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
73 |
74 | mel_loss = self.mae_loss(mel_predictions, mel_targets)
75 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
76 |
77 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
78 | energy_loss = self.mse_loss(energy_predictions, energy_targets)
79 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
80 |
81 | total_loss = (
82 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
83 | )
84 |
85 | return (
86 | total_loss,
87 | mel_loss,
88 | postnet_mel_loss,
89 | pitch_loss,
90 | energy_loss,
91 | duration_loss,
92 | )
93 |
--------------------------------------------------------------------------------
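
In symbols, the objective assembled above is the unweighted sum \mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{mel}} + \mathcal{L}_{\mathrm{postnet}} + \mathcal{L}_{\mathrm{duration}} + \mathcal{L}_{\mathrm{pitch}} + \mathcal{L}_{\mathrm{energy}}, where the two mel terms use L1 (MAE), the pitch/energy/log-duration terms use MSE, and every term is computed only over positions kept by the source and mel masks.
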
/model/modules.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import copy
4 | import math
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import torch.nn as nn
9 | import numpy as np
10 | import torch.nn.functional as F
11 |
12 | from utils.tools import get_mask_from_lengths, pad
13 |
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 |
16 |
17 | class VarianceAdaptor(nn.Module):
18 | """ Variance Adaptor """
19 |
20 | def __init__(self, preprocess_config, model_config):
21 | super(VarianceAdaptor, self).__init__()
22 | self.duration_predictor = VariancePredictor(model_config)
23 | self.length_regulator = LengthRegulator()
24 | self.pitch_predictor = VariancePredictor(model_config)
25 | self.energy_predictor = VariancePredictor(model_config)
26 |
27 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
28 | "feature"
29 | ]
30 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
31 | "feature"
32 | ]
33 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
34 | assert self.energy_feature_level in ["phoneme_level", "frame_level"]
35 |
36 | pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
37 | energy_quantization = model_config["variance_embedding"]["energy_quantization"]
38 | n_bins = model_config["variance_embedding"]["n_bins"]
39 | assert pitch_quantization in ["linear", "log"]
40 | assert energy_quantization in ["linear", "log"]
41 | with open(
42 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
43 | ) as f:
44 | stats = json.load(f)
45 | pitch_min, pitch_max = stats["pitch"][:2]
46 | energy_min, energy_max = stats["energy"][:2]
47 |
48 | if pitch_quantization == "log":
49 | self.pitch_bins = nn.Parameter(
50 | torch.exp(
51 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
52 | ),
53 | requires_grad=False,
54 | )
55 | else:
56 | self.pitch_bins = nn.Parameter(
57 | torch.linspace(pitch_min, pitch_max, n_bins - 1),
58 | requires_grad=False,
59 | )
60 | if energy_quantization == "log":
61 | self.energy_bins = nn.Parameter(
62 | torch.exp(
63 | torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
64 | ),
65 | requires_grad=False,
66 | )
67 | else:
68 | self.energy_bins = nn.Parameter(
69 | torch.linspace(energy_min, energy_max, n_bins - 1),
70 | requires_grad=False,
71 | )
72 |
73 | self.pitch_embedding = nn.Embedding(
74 | n_bins, model_config["transformer"]["encoder_hidden"]
75 | )
76 | self.energy_embedding = nn.Embedding(
77 | n_bins, model_config["transformer"]["encoder_hidden"]
78 | )
79 |
80 | def get_pitch_embedding(self, x, target, mask, control):
81 | prediction = self.pitch_predictor(x, mask)
82 | if target is not None:
83 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
84 | else:
85 | prediction = prediction * control
86 | embedding = self.pitch_embedding(
87 | torch.bucketize(prediction, self.pitch_bins)
88 | )
89 | return prediction, embedding
90 |
91 | def get_energy_embedding(self, x, target, mask, control):
92 | prediction = self.energy_predictor(x, mask)
93 | if target is not None:
94 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
95 | else:
96 | prediction = prediction * control
97 | embedding = self.energy_embedding(
98 | torch.bucketize(prediction, self.energy_bins)
99 | )
100 | return prediction, embedding
101 |
102 | def forward(
103 | self,
104 | x,
105 | src_mask,
106 | mel_mask=None,
107 | max_len=None,
108 | pitch_target=None,
109 | energy_target=None,
110 | duration_target=None,
111 | p_control=1.0,
112 | e_control=1.0,
113 | d_control=1.0,
114 | ):
115 |
116 | log_duration_prediction = self.duration_predictor(x, src_mask)
117 | if self.pitch_feature_level == "phoneme_level":
118 | pitch_prediction, pitch_embedding = self.get_pitch_embedding(
119 | x, pitch_target, src_mask, p_control
120 | )
121 | x = x + pitch_embedding
122 | if self.energy_feature_level == "phoneme_level":
123 | energy_prediction, energy_embedding = self.get_energy_embedding(
124 |                 x, energy_target, src_mask, e_control
125 | )
126 | x = x + energy_embedding
127 |
128 | if duration_target is not None:
129 | x, mel_len = self.length_regulator(x, duration_target, max_len)
130 | duration_rounded = duration_target
131 | else:
132 | duration_rounded = torch.clamp(
133 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
134 | min=0,
135 | )
136 | x, mel_len = self.length_regulator(x, duration_rounded, max_len)
137 | mel_mask = get_mask_from_lengths(mel_len)
138 |
139 | if self.pitch_feature_level == "frame_level":
140 | pitch_prediction, pitch_embedding = self.get_pitch_embedding(
141 | x, pitch_target, mel_mask, p_control
142 | )
143 | x = x + pitch_embedding
144 | if self.energy_feature_level == "frame_level":
145 | energy_prediction, energy_embedding = self.get_energy_embedding(
146 |                 x, energy_target, mel_mask, e_control
147 | )
148 | x = x + energy_embedding
149 |
150 | return (
151 | x,
152 | pitch_prediction,
153 | energy_prediction,
154 | log_duration_prediction,
155 | duration_rounded,
156 | mel_len,
157 | mel_mask,
158 | )
159 |
160 |
161 | class LengthRegulator(nn.Module):
162 | """ Length Regulator """
163 |
164 | def __init__(self):
165 | super(LengthRegulator, self).__init__()
166 |
167 | def LR(self, x, duration, max_len):
168 | output = list()
169 | mel_len = list()
170 | for batch, expand_target in zip(x, duration):
171 | expanded = self.expand(batch, expand_target)
172 | output.append(expanded)
173 | mel_len.append(expanded.shape[0])
174 |
175 | if max_len is not None:
176 | output = pad(output, max_len)
177 | else:
178 | output = pad(output)
179 |
180 | return output, torch.LongTensor(mel_len).to(device)
181 |
182 | def expand(self, batch, predicted):
183 | out = list()
184 |
185 | for i, vec in enumerate(batch):
186 | expand_size = predicted[i].item()
187 | out.append(vec.expand(max(int(expand_size), 0), -1))
188 | out = torch.cat(out, 0)
189 |
190 | return out
191 |
192 | def forward(self, x, duration, max_len):
193 | output, mel_len = self.LR(x, duration, max_len)
194 | return output, mel_len
195 |
196 |
197 | class VariancePredictor(nn.Module):
198 | """ Duration, Pitch and Energy Predictor """
199 |
200 | def __init__(self, model_config):
201 | super(VariancePredictor, self).__init__()
202 |
203 | self.input_size = model_config["transformer"]["encoder_hidden"]
204 | self.filter_size = model_config["variance_predictor"]["filter_size"]
205 | self.kernel = model_config["variance_predictor"]["kernel_size"]
206 | self.conv_output_size = model_config["variance_predictor"]["filter_size"]
207 | self.dropout = model_config["variance_predictor"]["dropout"]
208 |
209 | self.conv_layer = nn.Sequential(
210 | OrderedDict(
211 | [
212 | (
213 | "conv1d_1",
214 | Conv(
215 | self.input_size,
216 | self.filter_size,
217 | kernel_size=self.kernel,
218 | padding=(self.kernel - 1) // 2,
219 | ),
220 | ),
221 | ("relu_1", nn.ReLU()),
222 | ("layer_norm_1", nn.LayerNorm(self.filter_size)),
223 | ("dropout_1", nn.Dropout(self.dropout)),
224 | (
225 | "conv1d_2",
226 | Conv(
227 | self.filter_size,
228 | self.filter_size,
229 | kernel_size=self.kernel,
230 | padding=1,
231 | ),
232 | ),
233 | ("relu_2", nn.ReLU()),
234 | ("layer_norm_2", nn.LayerNorm(self.filter_size)),
235 | ("dropout_2", nn.Dropout(self.dropout)),
236 | ]
237 | )
238 | )
239 |
240 | self.linear_layer = nn.Linear(self.conv_output_size, 1)
241 |
242 | def forward(self, encoder_output, mask):
243 | out = self.conv_layer(encoder_output)
244 | out = self.linear_layer(out)
245 | out = out.squeeze(-1)
246 |
247 | if mask is not None:
248 | out = out.masked_fill(mask, 0.0)
249 |
250 | return out
251 |
252 |
253 | class Conv(nn.Module):
254 | """
255 | Convolution Module
256 | """
257 |
258 | def __init__(
259 | self,
260 | in_channels,
261 | out_channels,
262 | kernel_size=1,
263 | stride=1,
264 | padding=0,
265 | dilation=1,
266 | bias=True,
267 | w_init="linear",
268 | ):
269 | """
270 | :param in_channels: dimension of input
271 | :param out_channels: dimension of output
272 | :param kernel_size: size of kernel
273 | :param stride: size of stride
274 | :param padding: size of padding
275 | :param dilation: dilation rate
276 | :param bias: boolean. if True, bias is included.
277 |         :param w_init: str. weight init scheme name (accepted for interface compatibility; not applied here).
278 | """
279 | super(Conv, self).__init__()
280 |
281 | self.conv = nn.Conv1d(
282 | in_channels,
283 | out_channels,
284 | kernel_size=kernel_size,
285 | stride=stride,
286 | padding=padding,
287 | dilation=dilation,
288 | bias=bias,
289 | )
290 |
291 | def forward(self, x):
292 | x = x.contiguous().transpose(1, 2)
293 | x = self.conv(x)
294 | x = x.contiguous().transpose(1, 2)
295 |
296 | return x
297 |
--------------------------------------------------------------------------------
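
A toy illustration of the LengthRegulator defined above: each phoneme-level vector is repeated according to its duration, which is how the variance adaptor stretches encoder frames to mel-frame length. The snippet assumes it is run from the repository root so that model/ and utils/ are importable; the inputs are made up.

import torch

from model.modules import LengthRegulator

lr = LengthRegulator()
x = torch.arange(3, dtype=torch.float32).view(1, 3, 1)   # (batch=1, src_len=3, hidden=1)
durations = torch.tensor([[2, 1, 3]])                     # repeat counts per phoneme
expanded, mel_len = lr(x, durations, max_len=None)
print(expanded.squeeze(-1))   # [[0., 0., 1., 2., 2., 2.]] -- each frame repeated d_i times
print(mel_len)                # [6] -- total expanded (mel) length
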
/model/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class ScheduledOptim:
6 | """ A simple wrapper class for learning rate scheduling """
7 |
8 | def __init__(self, model, train_config, model_config, current_step):
9 |
10 | self._optimizer = torch.optim.Adam(
11 | model.parameters(),
12 | betas=train_config["optimizer"]["betas"],
13 | eps=train_config["optimizer"]["eps"],
14 | weight_decay=train_config["optimizer"]["weight_decay"],
15 | )
16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"]
17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"]
18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"]
19 | self.current_step = current_step
20 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5)
21 |
22 | def step_and_update_lr(self):
23 | self._update_learning_rate()
24 | self._optimizer.step()
25 |
26 | def zero_grad(self):
27 | # print(self.init_lr)
28 | self._optimizer.zero_grad()
29 |
30 | def load_state_dict(self, path):
31 | self._optimizer.load_state_dict(path)
32 |
33 | def _get_lr_scale(self):
34 | lr = np.min(
35 | [
36 | np.power(self.current_step, -0.5),
37 | np.power(self.n_warmup_steps, -1.5) * self.current_step,
38 | ]
39 | )
40 | for s in self.anneal_steps:
41 | if self.current_step > s:
42 | lr = lr * self.anneal_rate
43 | return lr
44 |
45 | def _update_learning_rate(self):
46 | """ Learning rate scheduling per step """
47 | self.current_step += 1
48 | lr = self.init_lr * self._get_lr_scale()
49 |
50 | for param_group in self._optimizer.param_groups:
51 | param_group["lr"] = lr
52 |
--------------------------------------------------------------------------------
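
The schedule implemented by _get_lr_scale above is the Transformer-style warm-up: the learning rate is encoder_hidden^(-0.5) * min(step^(-0.5), step * warm_up_step^(-1.5)), multiplied by anneal_rate once for every anneal step already passed. A stand-alone sketch for inspecting the curve follows; the hyperparameter defaults here are placeholders, not values taken from train.yaml.

import numpy as np

def lr_at(step, encoder_hidden=256, warm_up_step=4000,
          anneal_steps=(300000, 400000, 500000), anneal_rate=0.3):
    # Mirrors ScheduledOptim: init_lr * min(step^-0.5, step * warm_up_step^-1.5),
    # then an extra anneal_rate factor for each anneal step that has been passed.
    init_lr = np.power(encoder_hidden, -0.5)
    scale = min(np.power(step, -0.5), np.power(warm_up_step, -1.5) * step)
    for s in anneal_steps:
        if step > s:
            scale *= anneal_rate
    return init_lr * scale

print(lr_at(1000), lr_at(4000), lr_at(50000))  # ramps up during warm-up, then decays ~ 1/sqrt(step)
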
/preparation/aihub_mmv.py:
--------------------------------------------------------------------------------
1 | import re
2 | import argparse
3 | import yaml
4 | import os
5 | import shutil
6 | import json
7 | import librosa
8 | import soundfile
9 | from glob import glob
10 | from tqdm import tqdm
11 | from moviepy.editor import VideoFileClip
12 | from text import _clean_text
13 | from text.korean import tokenize, normalize_nonchar
14 |
15 |
16 | def write_text(txt_path, text):
17 | with open(txt_path, 'w', encoding='utf-8') as f:
18 | f.write(text)
19 |
20 |
21 | def get_sorted_items(items):
22 | # sort by key
23 | return sorted(items, key=lambda x:int(x[0]))
24 |
25 |
26 | def get_emotion(emo_dict):
27 | e, a, v = 0, 0, 0
28 | if 'emotion' in emo_dict:
29 | e = emo_dict['emotion']
30 | a = emo_dict['arousal']
31 | v = emo_dict['valence']
32 | return e, a, v
33 |
34 |
35 | def pad_spk_id(speaker_id):
36 | return 'p{}'.format("0"*(3-len(speaker_id))+speaker_id)
37 |
38 |
39 | def create_dataset(preprocess_config):
40 | """
41 | See https://github.com/Kyumin-Park/aihub_multimodal_speech
42 | """
43 | in_dir = preprocess_config["path"]["corpus_path"]
44 | audio_dir = os.path.join(os.path.dirname(in_dir), os.path.basename(in_dir)+"_audio")
45 | out_dir = os.path.join(os.path.dirname(in_dir), os.path.basename(in_dir)+"_preprocessed")
46 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
47 |
48 | print("Gather audio...")
49 | video_files = glob(f'{in_dir}/**/*.mp4', recursive=True)
50 |
51 | if os.path.exists(out_dir):
52 | shutil.rmtree(out_dir)
53 | os.makedirs(out_dir)
54 | filelist = open(f'{out_dir}/filelist.txt', 'w', encoding='utf-8')
55 | speaker_info, speaker_done = dict(), set()
56 | total_duration = 0
57 |
58 | print("Create dataset...")
59 | for video_path in tqdm(video_files):
60 | # Load annotation file
61 | file_name = os.path.splitext(os.path.basename(video_path))[0]
62 | json_path = video_path.replace('mp4', 'json')
63 | try:
64 | with open(json_path, 'r', encoding='utf-8') as f:
65 | annotation = json.load(f)
66 | except UnicodeDecodeError:
67 | continue
68 |
69 | # Load video clip
70 | audio_path = video_path.replace(in_dir, audio_dir, 1).replace('mp4', 'wav')
71 | orig_sr = librosa.get_samplerate(audio_path)
72 | y, sr = librosa.load(audio_path, sr=orig_sr)
73 | duration = librosa.get_duration(y, sr=sr)
74 | new_sr = sampling_rate
75 | new_y = librosa.resample(y, sr, new_sr)
76 |
77 | # Metadata
78 | n_frames = float(annotation['nr_frame'])
79 | fps = n_frames / duration
80 | for spk_id, spk_info in annotation['actor'].items():
81 | if spk_id not in speaker_done:
82 | speaker_info[spk_id] = spk_info
83 | speaker_done.add(spk_id)
84 |
85 | turn_id = 0
86 | done = set()
87 | for frame, frame_data in get_sorted_items(annotation['data'].items()):
88 | for sub_id, info_data in frame_data.items():
89 | if 'text' not in info_data.keys():
90 | continue
91 |
92 | # Extract data
93 | text_data = info_data['text']
94 | emotion_data = info_data['emotion']
95 | speaker_id = info_data['person_id']
96 | start_frame = text_data['script_start']
97 | end_frame = text_data['script_end']
98 | intent = text_data['intent']
99 | strategy = text_data['strategy']
100 |
101 | et_e, et_a, et_v = get_emotion(emotion_data['text'])
102 | es_e, es_a, es_v = get_emotion(emotion_data['sound'])
103 | ei_e, ei_a, ei_v = get_emotion(emotion_data['image'])
104 | em_e, em_a, em_v = get_emotion(emotion_data['multimodal'])
105 |
106 | script = refine_text(text_data['script'])
107 |
108 | start_idx = int(float(start_frame) / fps * new_sr)
109 | end_idx = int(float(end_frame) / fps * new_sr)
110 |
111 | # Write wav
112 | y_part = new_y[start_idx:end_idx]
113 | speaker_id = pad_spk_id(speaker_id)
114 | file_name = file_name.replace('clip_', 'c')
115 | framename = f'{start_frame}-{end_frame}'
116 | basename = f'{turn_id}_{speaker_id}_{file_name}_{framename}'
117 | wav_path = os.path.join(os.path.dirname(audio_path).replace(audio_dir, out_dir),
118 | f'{basename}.wav')
119 | if framename not in done:
120 | os.makedirs(os.path.dirname(wav_path), exist_ok=True)
121 | soundfile.write(wav_path, y_part, new_sr)
122 | write_text(wav_path.replace('.wav', '.txt'), script)
123 |
124 | # Write filelist
125 | filelist.write(f'{basename}|{script}|{speaker_id}|{intent}|{strategy}|{et_e}|{et_a}|{et_v}|{es_e}|{es_a}|{es_v}|{ei_e}|{ei_a}|{ei_v}|{em_e}|{em_a}|{em_v}\n')
126 | total_duration += (end_idx - start_idx) / float(new_sr)
127 |
128 | done.add(f'{framename}')
129 | turn_id += 1
130 |
131 | filelist.close()
132 |
133 | # Save Speaker Info
134 | with open(f'{out_dir}/speaker_info.txt', 'w', encoding='utf-8') as f:
135 | for spk_id, spk_info in get_sorted_items(speaker_info.items()):
136 | gender = 'F' if spk_info['gender'] == 'female' else 'M'
137 | age = spk_info['age']
138 |             spk_id = pad_spk_id(spk_id)
139 | f.write(f'{spk_id}|{gender}|{age}\n')
140 |
141 |     print(f'End parsing, total duration: {total_duration} seconds')
142 |
143 |
144 | def refine_text(text):
145 | # Fix invalid characters in text
146 | text = text.replace('…', ',')
147 | text = text.replace('\t', '')
148 | text = text.replace('-', ',')
149 | text = text.replace('–', ',')
150 | text = ' '.join(text.split())
151 | return text
152 |
153 |
154 | def extract_audio(preprocess_config):
155 | in_dir = preprocess_config["path"]["corpus_path"]
156 | out_dir = os.path.join(os.path.dirname(in_dir), os.path.basename(in_dir)+"_tmp")
157 | video_files = glob(f'{in_dir}/**/*.mp4', recursive=True)
158 |
159 | print("Extract audio...")
160 | for video_path in tqdm(video_files):
161 | audio_path = video_path.replace(in_dir, out_dir, 1).replace('mp4', 'wav')
162 | os.makedirs(os.path.dirname(audio_path), exist_ok=True)
163 |
164 | clip = VideoFileClip(video_path)
165 | clip.audio.write_audiofile(audio_path, verbose=False)
166 | clip.close()
167 |
168 |
169 | def extract_nonkr(preprocess_config):
170 | in_dir = preprocess_config["path"]["raw_path"]
171 | filelist = open(f'{in_dir}/nonkr.txt', 'w', encoding='utf-8')
172 |
173 | count = 0
174 | nonkr = set()
175 |     print("Extract non-Korean characters...")
176 | with open(f'{in_dir}/filelist.txt', 'r', encoding='utf-8') as f:
177 | lines = f.readlines()
178 | total_count = len(lines)
179 | for line in tqdm(lines):
180 | wav = line.split('|')[0]
181 | text = line.split('|')[1]
182 | reg = re.compile("""[^ ㄱ-ㅣ가-힣~!.,?:{}`"'"“‘’”’()\[\]]+""")
183 | impurities = reg.findall(text)
184 | if len(impurities) == 0:
185 | count+=1
186 | continue
187 | norm = _clean_text(text, preprocess_config["preprocessing"]["text"]["text_cleaners"])
188 | impurities_str = ','.join(impurities)
189 | filelist.write(f'{norm}|{text}|{impurities_str}|{wav}\n')
190 | for imp in impurities:
191 | nonkr.add(imp)
192 | filelist.close()
193 |     print('Total {} non-Korean characters from {} lines'.format(len(nonkr), total_count-count))
194 | print(sorted(list(nonkr)))
195 |
196 |
197 | def extract_lexicon(preprocess_config):
198 | """
199 | Extract lexicon and build grapheme-phoneme dictionary for MFA training
200 | See https://github.com/HGU-DLLAB/Korean-FastSpeech2-Pytorch
201 | """
202 | in_dir = preprocess_config["path"]["raw_path"]
203 | lexicon_path = preprocess_config["path"]["lexicon_path"]
204 | filelist = open(lexicon_path, 'a+', encoding='utf-8')
205 |
206 | # Load Lexicon Dictionary
207 | done = set()
208 | if os.path.isfile(lexicon_path):
209 | filelist.seek(0)
210 | for line in filelist.readlines():
211 | grapheme = line.split("\t")[0]
212 | done.add(grapheme)
213 |
214 | print("Extract lexicon...")
215 | for lab in tqdm(glob(f'{in_dir}/**/*.lab', recursive=True)):
216 | with open(lab, 'r', encoding='utf-8') as f:
217 | text = f.readline().strip("\n")
218 | assert text == normalize_nonchar(text), "No special token should be left."
219 |
220 | for grapheme in text.split(" "):
221 | if not grapheme in done:
222 | phoneme = " ".join(tokenize(grapheme, norm=False))
223 | filelist.write("{}\t{}\n".format(grapheme, phoneme))
224 | done.add(grapheme)
225 | filelist.close()
226 |
227 |
228 | def apply_fixed_text(preprocess_config):
229 | in_dir = preprocess_config["path"]["corpus_path"]
230 | sub_dir = preprocess_config["path"]["sub_dir_name"]
231 | out_dir = preprocess_config["path"]["raw_path"]
232 | fixed_text_path = preprocess_config["path"]["fixed_text_path"]
233 | cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
234 |
235 | fixed_text_dict = dict()
236 | print("Fixing transcripts...")
237 | with open(fixed_text_path, 'r', encoding='utf-8') as f:
238 | for line in tqdm(f.readlines()):
239 | wav, fixed_text = line.split('|')[0], line.split('|')[1]
240 | clip_name = wav.split('_')[2].replace('c', 'clip_')
241 | fixed_text_dict[wav] = fixed_text.replace('\n', '')
242 |
243 | text = _clean_text(fixed_text, cleaners)
244 | with open(
245 | os.path.join(out_dir, sub_dir, clip_name, "{}.lab".format(wav)),
246 | "w",
247 | ) as f1:
248 | f1.write(text)
249 |
250 | filelist_fixed = open(f'{out_dir}/filelist.txt', 'w', encoding='utf-8')
251 | with open(f'{in_dir}/filelist.txt', 'r', encoding='utf-8') as filelist:
252 | for line in tqdm(filelist.readlines()):
253 | wav = line.split('|')[0]
254 | if wav in fixed_text_dict:
255 | filelist_fixed.write("|".join([line.split("|")[0]] + [fixed_text_dict[wav]] + line.split("|")[2:]))
256 | else:
257 | filelist_fixed.write(line)
258 | filelist_fixed.close()
259 |
260 | extract_lexicon(preprocess_config)
--------------------------------------------------------------------------------
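
Each line that create_dataset() appends to filelist.txt carries 17 '|'-separated fields, in the order of the filelist.write(...) call above. A small helper for reading them back; the field labels are descriptive names chosen here for illustration, not identifiers used elsewhere in the repo.

# Field order follows create_dataset()'s filelist.write(...) call.
FILELIST_FIELDS = [
    "basename", "script", "speaker_id", "intent", "strategy",
    "text_emotion", "text_arousal", "text_valence",
    "sound_emotion", "sound_arousal", "sound_valence",
    "image_emotion", "image_arousal", "image_valence",
    "multimodal_emotion", "multimodal_arousal", "multimodal_valence",
]

def parse_filelist_line(line):
    return dict(zip(FILELIST_FIELDS, line.rstrip("\n").split("|")))
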
/preparation/iemocap.py:
--------------------------------------------------------------------------------
1 | import re
2 | import argparse
3 | import yaml
4 | import os
5 | import shutil
6 | import json
7 | import librosa
8 | import soundfile
9 | from glob import glob
10 | from tqdm import tqdm
11 | from moviepy.editor import VideoFileClip
12 | from text import _clean_text
13 | from text.korean import normalize_nonchar
14 | from g2p_en import G2p
15 |
16 |
17 | def extract_nonen(preprocess_config):
18 | in_dir = preprocess_config["path"]["raw_path"]
19 | filelist = open(f'{in_dir}/nonen.txt', 'w', encoding='utf-8')
20 |
21 | count = 0
22 | nonen = set()
23 |     print("Extract non-English characters...")
24 | with open(f'{in_dir}/filelist.txt', 'r', encoding='utf-8') as f:
25 | lines = f.readlines()
26 | total_count = len(lines)
27 | for line in tqdm(lines):
28 | wav = line.split('|')[0]
29 | text = line.split('|')[1]
30 |
31 | reg = re.compile("""[^ a-zA-Z~!.,?:`"'"“‘’”’]+""")
32 | impurities = reg.findall(text)
33 | if len(impurities) == 0:
34 | count+=1
35 | continue
36 | norm = _clean_text(text, preprocess_config["preprocessing"]["text"]["text_cleaners"])
37 | impurities_str = ','.join(impurities)
38 | filelist.write(f'{norm}|{text}|{impurities_str}|{wav}\n')
39 | for imp in impurities:
40 | nonen.add(imp)
41 | filelist.close()
42 |     print('Total {} non-English characters from {} lines'.format(len(nonen), total_count-count))
43 | print(sorted(list(nonen)))
44 |
45 |
46 | def extract_lexicon(preprocess_config):
47 | """
48 | Extract lexicon and build grapheme-phoneme dictionary for MFA training
49 | """
50 | in_dir = preprocess_config["path"]["raw_path"]
51 | lexicon_path = preprocess_config["path"]["lexicon_path"]
52 | filelist = open(lexicon_path, 'a+', encoding='utf-8')
53 |
54 | # Load Lexicon Dictionary
55 | done = set()
56 | if os.path.isfile(lexicon_path):
57 | filelist.seek(0)
58 | for line in filelist.readlines():
59 | grapheme = line.split("\t")[0]
60 | done.add(grapheme)
61 |
62 | print("Extract lexicon...")
63 | g2p = G2p()
64 | for lab in tqdm(glob(f'{in_dir}/**/*.lab', recursive=True)):
65 | with open(lab, 'r', encoding='utf-8') as f:
66 | text = f.readline().strip("\n")
67 | text = normalize_nonchar(text)
68 |
69 | for grapheme in text.split(" "):
70 | if not grapheme in done:
71 | phoneme = " ".join(g2p(grapheme))
72 | filelist.write("{}\t{}\n".format(grapheme, phoneme))
73 | done.add(grapheme)
74 | filelist.close()
75 |
76 |
77 | def apply_fixed_text(preprocess_config):
78 | in_dir = preprocess_config["path"]["corpus_path"]
79 | sub_dir = preprocess_config["path"]["sub_dir_name"]
80 | out_dir = preprocess_config["path"]["raw_path"]
81 | fixed_text_path = preprocess_config["path"]["fixed_text_path"]
82 | cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
83 |
84 | fixed_text_dict = dict()
85 | print("Fixing transcripts...")
86 | with open(fixed_text_path, 'r', encoding='utf-8') as f:
87 | for line in tqdm(f.readlines()):
88 | wav, fixed_text = line.split('|')[0], line.split('|')[1]
89 | session = '_'.join(wav.split('_')[1:])
90 | fixed_text_dict[wav] = fixed_text.replace('\n', '')
91 |
92 | text = _clean_text(fixed_text, cleaners)
93 | with open(
94 | os.path.join(out_dir, sub_dir, session, "{}.lab".format(wav)),
95 | "w",
96 | ) as f1:
97 | f1.write(text)
98 |
99 | filelist_fixed = open(f'{out_dir}/filelist_fixed.txt', 'w', encoding='utf-8')
100 | with open(f'{out_dir}/filelist.txt', 'r', encoding='utf-8') as filelist:
101 | for line in tqdm(filelist.readlines()):
102 | wav = line.split('|')[0]
103 | if wav in fixed_text_dict:
104 | filelist_fixed.write("|".join([line.split("|")[0]] + [fixed_text_dict[wav]] + line.split("|")[2:]))
105 | else:
106 | filelist_fixed.write(line)
107 | filelist_fixed.close()
108 |
109 | os.remove(f'{out_dir}/filelist.txt')
110 | os.rename(f'{out_dir}/filelist_fixed.txt', f'{out_dir}/filelist.txt')
111 |
112 | extract_lexicon(preprocess_config)
--------------------------------------------------------------------------------
/prepare_align.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import yaml
4 |
5 | from preprocessor import aihub_mmv, iemocap
6 |
7 |
8 | def main(config):
9 | if "AIHub-MMV" in config["dataset"]:
10 | aihub_mmv.prepare_align(config)
11 | if "IEMOCAP" in config["dataset"]:
12 | iemocap.prepare_align(config)
13 |
14 |
15 | if __name__ == "__main__":
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("config", type=str, help="path to preprocess.yaml")
18 | args = parser.parse_args()
19 |
20 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
21 | main(config)
22 |
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import yaml
4 |
5 | from preparation import aihub_mmv, iemocap
6 |
7 |
8 | def main(args, preprocess_config):
9 | os.makedirs("./lexicon", exist_ok=True)
10 | os.makedirs("./preprocessed_data", exist_ok=True)
11 | os.makedirs("./montreal-forced-aligner", exist_ok=True)
12 |
13 | if "AIHub-MMV" in preprocess_config["dataset"]:
14 | if args.extract_nonkr:
15 | aihub_mmv.extract_nonkr(preprocess_config)
16 | elif args.extract_lexicon:
17 | aihub_mmv.extract_lexicon(preprocess_config)
18 | elif args.apply_fixed_text:
19 | aihub_mmv.apply_fixed_text(preprocess_config)
20 | else:
21 | if args.extract_audio:
22 | aihub_mmv.extract_audio(preprocess_config)
23 | aihub_mmv.create_dataset(preprocess_config)
24 | elif "IEMOCAP" in preprocess_config["dataset"]:
25 | if args.extract_nonen:
26 | iemocap.extract_nonen(preprocess_config)
27 | elif args.extract_lexicon:
28 | iemocap.extract_lexicon(preprocess_config)
29 | elif args.apply_fixed_text:
30 | iemocap.apply_fixed_text(preprocess_config)
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "-p",
37 | "--preprocess_config",
38 | type=str,
39 | required=True,
40 | help="path to preprocess.yaml",
41 | )
42 | parser.add_argument(
43 | '--extract_audio',
44 |         help='convert videos into .wav files',
45 | action='store_true',
46 | )
47 | parser.add_argument(
48 | '--extract_nonkr',
49 |         help='extract non-Korean characters',
50 | action='store_true',
51 | )
52 | parser.add_argument(
53 | '--extract_nonen',
54 |         help='extract non-English characters',
55 | action='store_true',
56 | )
57 | parser.add_argument(
58 | '--extract_lexicon',
59 | help='extract lexicon and build grapheme-phoneme dictionary',
60 | action='store_true',
61 | )
62 | parser.add_argument(
63 | '--apply_fixed_text',
64 | help='apply fixed text to both raw data and filelist',
65 | action='store_true',
66 | )
67 | args = parser.parse_args()
68 |
69 | preprocess_config = yaml.load(
70 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader
71 | )
72 |
73 | main(args, preprocess_config)
74 |
--------------------------------------------------------------------------------
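
The same steps can be driven programmatically instead of through the CLI. A hedged sketch for the IEMOCAP branch (the config path assumes the layout under config/; only the flags inspected inside that branch matter, but all are set for completeness):

import yaml
from types import SimpleNamespace

from prepare_data import main

preprocess_config = yaml.load(
    open("config/IEMOCAP/preprocess.yaml", "r"), Loader=yaml.FullLoader
)
# Equivalent to: python3 prepare_data.py -p config/IEMOCAP/preprocess.yaml --extract_lexicon
args = SimpleNamespace(
    extract_audio=False, extract_nonkr=False, extract_nonen=False,
    extract_lexicon=True, apply_fixed_text=False,
)
main(args, preprocess_config)
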
/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import yaml
4 |
5 | from preprocessor.preprocessor import Preprocessor
6 |
7 |
8 | if __name__ == "__main__":
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("config", type=str, help="path to preprocess.yaml")
11 | args = parser.parse_args()
12 |
13 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
14 | preprocessor = Preprocessor(config)
15 | preprocessor.build_from_path()
16 |
--------------------------------------------------------------------------------
/preprocessor/aihub_mmv.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | from scipy.io import wavfile
6 | from tqdm import tqdm
7 | from shutil import copyfile
8 |
9 | from text import _clean_text
10 |
11 |
12 | def prepare_align(config):
13 | in_dir = config["path"]["corpus_path"]
14 | sub_dir = config["path"]["sub_dir_name"]
15 | out_dir = config["path"]["raw_path"]
16 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
17 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
18 | fixed_text_path = config["path"]["fixed_text_path"]
19 | cleaners = config["preprocessing"]["text"]["text_cleaners"]
20 |
21 | fixed_text_dict = dict()
22 | with open(fixed_text_path, 'r', encoding='utf-8') as f:
23 | for line in tqdm(f.readlines()):
24 | wav, fixed_text = line.split('|')[0], line.split('|')[1]
25 | fixed_text_dict[wav] = fixed_text.replace('\n', '')
26 |
27 | for sep_dir in tqdm(next(os.walk(in_dir))[1]):
28 | for clip_name in os.listdir(os.path.join(in_dir, sep_dir)):
29 | for file_name in os.listdir(os.path.join(in_dir, sep_dir, clip_name)):
30 | if file_name[-4:] != ".wav":
31 | continue
32 | base_name = file_name[:-4]
33 | text_path = os.path.join(
34 | in_dir, sep_dir, clip_name, "{}.txt".format(base_name)
35 | )
36 | wav_path = os.path.join(
37 | in_dir, sep_dir, clip_name, "{}.wav".format(base_name)
38 | )
39 | if base_name in fixed_text_dict:
40 | text = fixed_text_dict[base_name]
41 | else:
42 | with open(text_path) as f:
43 | text = f.readline().strip("\n")
44 | text = _clean_text(text, cleaners)
45 |
46 | os.makedirs(os.path.join(out_dir, sub_dir, clip_name), exist_ok=True)
47 | wav, _ = librosa.load(wav_path, sampling_rate)
48 | wav = wav / max(abs(wav)) * max_wav_value
49 | wavfile.write(
50 | os.path.join(out_dir, sub_dir, clip_name, "{}.wav".format(base_name)),
51 | sampling_rate,
52 | wav.astype(np.int16),
53 | )
54 | with open(
55 | os.path.join(out_dir, sub_dir, clip_name, "{}.lab".format(base_name)),
56 | "w",
57 | ) as f1:
58 | f1.write(text)
59 |
60 | # Filelist
61 | filelist_fixed = open(f'{out_dir}/filelist.txt', 'w', encoding='utf-8')
62 | with open(f'{in_dir}/filelist.txt', 'r', encoding='utf-8') as filelist:
63 | for line in tqdm(filelist.readlines()):
64 | wav = line.split('|')[0]
65 | if wav in fixed_text_dict:
66 | filelist_fixed.write("|".join([line.split("|")[0]] + [fixed_text_dict[wav]] + line.split("|")[2:]))
67 | else:
68 | filelist_fixed.write(line)
69 | filelist_fixed.close()
70 |
71 | # Speaker Info
72 | copyfile(f'{in_dir}/speaker_info.txt', f'{out_dir}/speaker_info.txt')
--------------------------------------------------------------------------------
/preprocessor/iemocap.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import librosa
4 | import numpy as np
5 | from scipy.io import wavfile
6 | from tqdm import tqdm
7 |
8 | from text import _clean_text
9 |
10 | _square_brackets_re = re.compile(r"\[[\w\d\s]+\]")
11 | _inv_square_brackets_re = re.compile(r"(.*?)\](.+?)\[(.*)")
12 |
13 |
14 | def get_sorted_items(items):
15 | # sort by key
16 | return sorted(items, key=lambda x:x[0])
17 |
18 |
19 | def prepare_align(config):
20 | in_dir = config["path"]["corpus_path"]
21 | sub_dir = config["path"]["sub_dir_name"]
22 | out_dir = config["path"]["raw_path"]
23 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
24 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
25 | fixed_text_path = config["path"]["fixed_text_path"]
26 | cleaners = config["preprocessing"]["text"]["text_cleaners"]
27 |
28 | os.makedirs(os.path.join(out_dir), exist_ok=True)
29 | filelist_fixed = open(f'{out_dir}/filelist.txt', 'w', encoding='utf-8')
30 | speaker_info, speaker_done = dict(), set()
31 |
32 | fixed_text_dict = dict()
33 | with open(fixed_text_path, 'r', encoding='utf-8') as f:
34 | for line in tqdm(f.readlines()):
35 | wav, fixed_text = line.split('|')[0], line.split('|')[1]
36 | fixed_text_dict[wav] = fixed_text.replace('\n', '')
37 |
38 | for sep_dir in tqdm(next(os.walk(in_dir))[1]):
39 | if sub_dir[:-1] not in sep_dir.lower():
40 | continue
41 | for wav_dir in tqdm((next(os.walk(os.path.join(in_dir, sep_dir, "sentences", "wav")))[1])):
42 |
43 | # Build Text Dict
44 | text_dict = dict()
45 | text_raw_path = os.path.join(
46 | in_dir, sep_dir, "dialog", "transcriptions", "{}.txt".format(wav_dir)
47 | )
48 | with open(text_raw_path) as f:
49 | for line in f.readlines():
50 | base_name = line.split("[")[0].strip()
51 | transcript = line.split("]:")[-1].strip()
52 | text_dict[base_name] = transcript
53 |
54 | # Build Emotion Dict
55 | emo_dict = dict()
56 | emo_raw_path = os.path.join(
57 | in_dir, sep_dir, "dialog", "EmoEvaluation", "{}.txt".format(wav_dir)
58 | )
59 | with open(emo_raw_path) as f:
60 | for line in f.readlines()[1:]:
61 | if "[" not in line or "%" in line:
62 | continue
63 | m = _inv_square_brackets_re.match(" ".join(line.split()))
64 | base_name, emo_gt = m.group(2).strip().split(" ")
65 | valence, arousal = m.group(3).split(",")[0].strip(), m.group(3).split(",")[1].strip()
66 | emo_dict[base_name] = {
67 | "e": emo_gt,
68 | "a": arousal,
69 | "v": valence,
70 | }
71 |
72 | for file_name in os.listdir(os.path.join(in_dir, sep_dir, "sentences", "wav", wav_dir)):
73 | if file_name[0] == "." or file_name[-4:] != ".wav":
74 | continue
75 | base_name = file_name[:-4]
76 | if len(base_name.split("_")) == 3:
77 | spk_id, dialog_type, turn = base_name.split("_")
78 | elif len(base_name.split("_")) == 4:
79 | spk_id, dialog_type, turn = base_name.split("_")[0], "_".join(base_name.split("_")[1:3]), base_name.split("_")[3]
80 | base_name_new = "_".join([turn, spk_id, dialog_type])
81 |
82 | if spk_id not in speaker_done:
83 | speaker_info[spk_id] = {
84 | 'gender': spk_id[-1]
85 | }
86 | speaker_done.add(spk_id)
87 |
88 | wav_path = os.path.join(
89 | in_dir, sep_dir, "sentences", "wav", wav_dir, "{}.wav".format(base_name)
90 | )
91 | if base_name in fixed_text_dict:
92 | text = fixed_text_dict[base_name]
93 | else:
94 | text = text_dict[base_name]
95 | text = re.sub(_square_brackets_re, "", text)
96 | text = ' '.join(text.split())
97 | text = _clean_text(text, cleaners)
98 |
99 | os.makedirs(os.path.join(out_dir, sub_dir, wav_dir), exist_ok=True)
100 | wav, _ = librosa.load(wav_path, sampling_rate)
101 | wav = wav / max(abs(wav)) * max_wav_value
102 | wavfile.write(
103 | os.path.join(out_dir, sub_dir, wav_dir, "{}.wav".format(base_name_new)),
104 | sampling_rate,
105 | wav.astype(np.int16),
106 | )
107 | with open(
108 | os.path.join(out_dir, sub_dir, wav_dir, "{}.lab".format(base_name_new)),
109 | "w",
110 | ) as f1:
111 | f1.write(text)
112 |
113 | # Filelist
114 | emo_ = emo_dict[base_name]
115 | emotion, arousal, valence = emo_["e"], emo_["a"], emo_["v"]
116 | filelist_fixed.write("|".join([base_name_new, text, spk_id, emotion, arousal, valence]) + "\n")
117 | filelist_fixed.close()
118 |
119 | # Save Speaker Info
120 | with open(f'{out_dir}/speaker_info.txt', 'w', encoding='utf-8') as f:
121 | for spk_id, spk_info in get_sorted_items(speaker_info.items()):
122 | gender = spk_info['gender']
123 | f.write(f'{spk_id}|{gender}\n')
--------------------------------------------------------------------------------
/preprocessor/preprocessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 |
5 | import tgt
6 | import librosa
7 | import numpy as np
8 | import pyworld as pw
9 | from scipy.interpolate import interp1d
10 | from sklearn.preprocessing import StandardScaler
11 | from tqdm import tqdm
12 |
13 | import audio as Audio
14 |
15 | random.seed(1234)
16 |
17 | class Preprocessor:
18 | def __init__(self, config):
19 | self.config = config
20 | self.dataset = config["dataset"]
21 | self.sub_dir = ""
22 | self.speakers = dict()
23 | self.emotions = dict()
24 | self.sub_dir = config["path"]["sub_dir_name"]
25 | self.speakers = self.load_speaker_dict()
26 | self.filelist, self.emotions = self.load_filelist_dict()
27 | self.in_dir = os.path.join(config["path"]["raw_path"], self.sub_dir)
28 | self.out_dir = config["path"]["preprocessed_path"]
29 | self.val_size = config["preprocessing"]["val_size"]
30 | self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
31 | self.hop_length = config["preprocessing"]["stft"]["hop_length"]
32 |
33 | assert config["preprocessing"]["pitch"]["feature"] in [
34 | "phoneme_level",
35 | "frame_level",
36 | ]
37 | assert config["preprocessing"]["energy"]["feature"] in [
38 | "phoneme_level",
39 | "frame_level",
40 | ]
41 | self.pitch_phoneme_averaging = (
42 | config["preprocessing"]["pitch"]["feature"] == "phoneme_level"
43 | )
44 | self.energy_phoneme_averaging = (
45 | config["preprocessing"]["energy"]["feature"] == "phoneme_level"
46 | )
47 |
48 | self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"]
49 | self.energy_normalization = config["preprocessing"]["energy"]["normalization"]
50 |
51 | self.STFT = Audio.stft.TacotronSTFT(
52 | config["preprocessing"]["stft"]["filter_length"],
53 | config["preprocessing"]["stft"]["hop_length"],
54 | config["preprocessing"]["stft"]["win_length"],
55 | config["preprocessing"]["mel"]["n_mel_channels"],
56 | config["preprocessing"]["audio"]["sampling_rate"],
57 | config["preprocessing"]["mel"]["mel_fmin"],
58 | config["preprocessing"]["mel"]["mel_fmax"],
59 | )
60 |
61 | def load_speaker_dict(self):
62 | spk_dir = os.path.join(self.config["path"]["raw_path"], 'speaker_info.txt')
63 | spk_dict = dict()
64 | with open(spk_dir, 'r', encoding='utf-8') as f:
65 | for i, line in enumerate(f.readlines()):
66 | spk_id = line.split("|")[0]
67 | spk_dict[spk_id] = i
68 | return spk_dict
69 |
70 | def load_filelist_dict(self):
71 | filelist_dir = os.path.join(self.config["path"]["raw_path"], 'filelist.txt')
72 | filelist_dict, emotion_dict, arousal_dict, valence_dict = dict(), dict(), dict(), dict()
73 | emotions, arousals, valences = set(), set(), set()
74 | with open(filelist_dir, 'r', encoding='utf-8') as f:
75 | for i, line in enumerate(f.readlines()):
76 | basename, aux_data = line.split("|")[0], line.split("|")[3:]
77 | filelist_dict[basename] = "|".join(aux_data).strip("\n")
78 | emotions.add(aux_data[-3])
79 | arousals.add(aux_data[-2])
80 | valences.add(aux_data[-1].strip("\n"))
81 | for i, emotion in enumerate(list(emotions)):
82 | emotion_dict[emotion] = i
83 | for i, arousal in enumerate(list(arousals)):
84 | arousal_dict[arousal] = i
85 | for i, valence in enumerate(list(valences)):
86 | valence_dict[valence] = i
87 | emotion_dict = {
88 | "emotion_dict": emotion_dict,
89 | "arousal_dict": arousal_dict,
90 | "valence_dict": valence_dict,
91 | }
92 | return filelist_dict, emotion_dict
93 |
94 | def build_from_path(self):
95 | os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True)
96 | os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True)
97 | os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True)
98 | os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True)
99 |
100 | print("Processing Data ...")
101 | out = list()
102 | n_frames = 0
103 | pitch_scaler = StandardScaler()
104 | energy_scaler = StandardScaler()
105 |
106 | # Compute pitch, energy, duration, and mel-spectrogram
107 | speakers = self.speakers.copy()
108 | for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))):
109 | if len(self.speakers) == 0:
110 | speakers[speaker] = i
111 | for wav_name in os.listdir(os.path.join(self.in_dir, speaker)):
112 | if ".wav" not in wav_name:
113 | continue
114 |
115 | basename = wav_name.split(".")[0]
116 | tg_path = os.path.join(
117 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
118 | )
119 | if os.path.exists(tg_path):
120 | ret = self.process_utterance(speaker, basename)
121 | if ret is None:
122 | continue
123 | else:
124 | info, pitch, energy, n = ret
125 | out.append(info)
126 |
127 | if len(pitch) > 0:
128 | pitch_scaler.partial_fit(pitch.reshape((-1, 1)))
129 | if len(energy) > 0:
130 | energy_scaler.partial_fit(energy.reshape((-1, 1)))
131 |
132 | n_frames += n
133 |
134 |         print("Computing statistics ...")
135 | # Perform normalization if necessary
136 | if self.pitch_normalization:
137 | pitch_mean = pitch_scaler.mean_[0]
138 | pitch_std = pitch_scaler.scale_[0]
139 | else:
140 | # A numerical trick to avoid normalization...
141 | pitch_mean = 0
142 | pitch_std = 1
143 | if self.energy_normalization:
144 | energy_mean = energy_scaler.mean_[0]
145 | energy_std = energy_scaler.scale_[0]
146 | else:
147 | energy_mean = 0
148 | energy_std = 1
149 |
150 | pitch_min, pitch_max = self.normalize(
151 | os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std
152 | )
153 | energy_min, energy_max = self.normalize(
154 | os.path.join(self.out_dir, "energy"), energy_mean, energy_std
155 | )
156 |
157 | # Save files
158 | with open(os.path.join(self.out_dir, "speakers.json"), "w") as f:
159 | f.write(json.dumps(speakers))
160 |
161 | if len(self.emotions) != 0:
162 | with open(os.path.join(self.out_dir, "emotions.json"), "w") as f:
163 | f.write(json.dumps(self.emotions))
164 |
165 | with open(os.path.join(self.out_dir, "stats.json"), "w") as f:
166 | stats = {
167 | "pitch": [
168 | float(pitch_min),
169 | float(pitch_max),
170 | float(pitch_mean),
171 | float(pitch_std),
172 | ],
173 | "energy": [
174 | float(energy_min),
175 | float(energy_max),
176 | float(energy_mean),
177 | float(energy_std),
178 | ],
179 | }
180 | f.write(json.dumps(stats))
181 |
182 | print(
183 | "Total time: {} hours".format(
184 | n_frames * self.hop_length / self.sampling_rate / 3600
185 | )
186 | )
187 |
188 | random.shuffle(out)
189 | out = [r for r in out if r is not None]
190 |
191 | # Write metadata
192 | with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f:
193 | for m in out[self.val_size :]:
194 | f.write(m + "\n")
195 | with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f:
196 | for m in out[: self.val_size]:
197 | f.write(m + "\n")
198 |
199 | return out
200 |
201 | def process_utterance(self, speaker, basename):
202 | aux_data = ""
203 | wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename))
204 | text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename))
205 | tg_path = os.path.join(
206 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
207 | )
208 | speaker = basename.split("_")[1]
209 | aux_data = self.filelist[basename]
210 |
211 | # Get alignments
212 | textgrid = tgt.io.read_textgrid(tg_path)
213 | phone, duration, start, end = self.get_alignment(
214 | textgrid.get_tier_by_name("phones")
215 | )
216 | text = "{" + " ".join(phone) + "}"
217 | if start >= end:
218 | return None
219 |
220 | # Read and trim wav files
221 | wav, _ = librosa.load(wav_path)
222 | wav = wav[
223 | int(self.sampling_rate * start) : int(self.sampling_rate * end)
224 | ].astype(np.float32)
225 |
226 | # Read raw text
227 | with open(text_path, "r") as f:
228 | raw_text = f.readline().strip("\n")
229 |
230 | # Compute fundamental frequency
231 | pitch, t = pw.dio(
232 | wav.astype(np.float64),
233 | self.sampling_rate,
234 | frame_period=self.hop_length / self.sampling_rate * 1000,
235 | )
236 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate)
237 |
238 | pitch = pitch[: sum(duration)]
239 | if np.sum(pitch != 0) <= 1:
240 | return None
241 |
242 | # Compute mel-scale spectrogram and energy
243 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT)
244 | mel_spectrogram = mel_spectrogram[:, : sum(duration)]
245 | energy = energy[: sum(duration)]
246 |
247 | if self.pitch_phoneme_averaging:
248 | # perform linear interpolation
249 | nonzero_ids = np.where(pitch != 0)[0]
250 | interp_fn = interp1d(
251 | nonzero_ids,
252 | pitch[nonzero_ids],
253 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
254 | bounds_error=False,
255 | )
256 | pitch = interp_fn(np.arange(0, len(pitch)))
257 |
258 | # Phoneme-level average
259 | pos = 0
260 | for i, d in enumerate(duration):
261 | if d > 0:
262 | pitch[i] = np.mean(pitch[pos : pos + d])
263 | else:
264 | pitch[i] = 0
265 | pos += d
266 | pitch = pitch[: len(duration)]
267 |
268 | if self.energy_phoneme_averaging:
269 | # Phoneme-level average
270 | pos = 0
271 | for i, d in enumerate(duration):
272 | if d > 0:
273 | energy[i] = np.mean(energy[pos : pos + d])
274 | else:
275 | energy[i] = 0
276 | pos += d
277 | energy = energy[: len(duration)]
278 |
279 | # Save files
280 | dur_filename = "{}-duration-{}.npy".format(speaker, basename)
281 | np.save(os.path.join(self.out_dir, "duration", dur_filename), duration)
282 |
283 | pitch_filename = "{}-pitch-{}.npy".format(speaker, basename)
284 | np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch)
285 |
286 | energy_filename = "{}-energy-{}.npy".format(speaker, basename)
287 | np.save(os.path.join(self.out_dir, "energy", energy_filename), energy)
288 |
289 | mel_filename = "{}-mel-{}.npy".format(speaker, basename)
290 | np.save(
291 | os.path.join(self.out_dir, "mel", mel_filename),
292 | mel_spectrogram.T,
293 | )
294 |
295 | return (
296 | "|".join([basename, speaker, text, raw_text, aux_data]),
297 | self.remove_outlier(pitch),
298 | self.remove_outlier(energy),
299 | mel_spectrogram.shape[1],
300 | )
301 |
302 | def get_alignment(self, tier):
303 | sil_phones = ["sil", "sp", "spn"]
304 |
305 | phones = []
306 | durations = []
307 | start_time = 0
308 | end_time = 0
309 | end_idx = 0
310 | for t in tier._objects:
311 | s, e, p = t.start_time, t.end_time, t.text
312 |
313 | # Trim leading silences
314 | if phones == []:
315 | if p in sil_phones:
316 | continue
317 | else:
318 | start_time = s
319 |
320 | if p not in sil_phones:
321 | # For ordinary phones
322 | phones.append(p)
323 | end_time = e
324 | end_idx = len(phones)
325 | else:
326 | # For silent phones
327 | phones.append(p)
328 |
329 | durations.append(
330 | int(
331 | np.round(e * self.sampling_rate / self.hop_length)
332 | - np.round(s * self.sampling_rate / self.hop_length)
333 | )
334 | )
335 |
336 | # Trim tailing silences
337 | phones = phones[:end_idx]
338 | durations = durations[:end_idx]
339 |
340 | return phones, durations, start_time, end_time
341 |
342 | def remove_outlier(self, values):
343 | values = np.array(values)
344 | p25 = np.percentile(values, 25)
345 | p75 = np.percentile(values, 75)
346 | lower = p25 - 1.5 * (p75 - p25)
347 | upper = p75 + 1.5 * (p75 - p25)
348 | normal_indices = np.logical_and(values > lower, values < upper)
349 |
350 | return values[normal_indices]
351 |
352 | def normalize(self, in_dir, mean, std):
353 | max_value = np.finfo(np.float64).min
354 | min_value = np.finfo(np.float64).max
355 | for filename in os.listdir(in_dir):
356 | filename = os.path.join(in_dir, filename)
357 | values = (np.load(filename) - mean) / std
358 | np.save(filename, values)
359 |
360 | max_value = max(max_value, max(values))
361 | min_value = min(min_value, min(values))
362 |
363 | return min_value, max_value
364 |
--------------------------------------------------------------------------------
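
For orientation, a hedged sketch of reading back what build_from_path() writes: stats.json holds [min, max, mean, std] for pitch and energy, speakers.json and emotions.json hold the label-to-index maps, and per-utterance features land in mel/, pitch/, energy/, and duration/ as "{speaker}-{feature}-{basename}.npy". The path below is an assumption standing in for preprocess.yaml's preprocessed_path.

import json
import os

preprocessed_path = "./preprocessed_data/IEMOCAP"  # assumption: matches preprocess.yaml

with open(os.path.join(preprocessed_path, "stats.json")) as f:
    stats = json.load(f)
pitch_min, pitch_max, pitch_mean, pitch_std = stats["pitch"]
energy_min, energy_max, energy_mean, energy_std = stats["energy"]

with open(os.path.join(preprocessed_path, "speakers.json")) as f:
    speaker_map = json.load(f)  # {speaker_id: index}, consumed by FastSpeech2 and synthesize.py
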
/requirements.txt:
--------------------------------------------------------------------------------
1 | python==3.6.3
2 | g2pk==0.9.4
3 | torch==1.7.0
4 | numpy==1.19.5
5 | tgt==1.4.3
6 | scipy==1.5.0
7 | pyworld==0.2.10
8 | librosa==0.7.2
9 | numba==0.53.1
10 | matplotlib==3.2.2
11 | unidecode==1.1.1
12 | inflect==4.1.0
13 | g2p-en==2.1.0
14 | tensorboard==2.4.1
15 | python_speech_features==0.6
16 | pandas==1.1.3
17 | pysptk==0.1.18
18 | tensorflow==2.4.0
19 | PyYAML==5.4.1
20 | tqdm==4.46.1
21 | moviepy==1.0.3
22 | quickspacer==1.0.4
--------------------------------------------------------------------------------
/synthesize.py:
--------------------------------------------------------------------------------
1 | import re
2 | import argparse
3 | from string import punctuation
4 | import os
5 | import json
6 |
7 | import torch
8 | import yaml
9 | import numpy as np
10 | from torch.utils.data import DataLoader
11 | from g2p_en import G2p
12 |
13 | from utils.model import get_model, get_vocoder
14 | from utils.tools import to_device, synth_samples
15 | from dataset import TextDataset
16 | from text import text_to_sequence
17 | from text.korean import tokenize, normalize_nonchar
18 |
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |
21 |
22 | def read_lexicon(lex_path):
23 | lexicon = {}
24 | with open(lex_path) as f:
25 | for line in f:
26 | temp = re.split(r"\s+", line.strip("\n"))
27 | word = temp[0]
28 | phones = temp[1:]
29 | if word not in lexicon:
30 | lexicon[word] = phones
31 | return lexicon
32 |
33 |
34 | def preprocess_korean(text, preprocess_config):
35 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
36 |
37 | phones = []
38 | words = filter(None, re.split(r"([,;.\-\?\!\s+])", text))
39 | for w in words:
40 | if w in lexicon:
41 | phones += lexicon[w]
42 | else:
43 | phones += list(filter(lambda p: p != " ", tokenize(w, norm=False)))
44 | phones = "{" + "}{".join(phones) + "}"
45 | phones = normalize_nonchar(phones, inference=True)
46 | phones = phones.replace("}{", " ")
47 |
48 | print("Raw Text Sequence: {}".format(text))
49 | print("Phoneme Sequence: {}".format(phones))
50 | sequence = np.array(
51 | text_to_sequence(
52 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
53 | )
54 | )
55 |
56 | return np.array(sequence)
57 |
58 |
59 | def preprocess_english(text, preprocess_config):
60 | text = text.rstrip(punctuation)
61 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
62 |
63 | g2p = G2p()
64 | phones = []
65 | words = filter(None, re.split(r"([,;.\-\?\!\s+])", text))
66 | for w in words:
67 | if w.lower() in lexicon:
68 | phones += lexicon[w.lower()]
69 | else:
70 | phones += list(filter(lambda p: p != " ", g2p(w)))
71 | phones = "{" + "}{".join(phones) + "}"
72 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones)
73 | phones = phones.replace("}{", " ")
74 |
75 | print("Raw Text Sequence: {}".format(text))
76 | print("Phoneme Sequence: {}".format(phones))
77 | sequence = np.array(
78 | text_to_sequence(
79 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
80 | )
81 | )
82 |
83 | return np.array(sequence)
84 |
85 |
86 | def synthesize(model, step, configs, vocoder, batchs, control_values, tag):
87 | preprocess_config, model_config, train_config = configs
88 | pitch_control, energy_control, duration_control = control_values
89 |
90 | for batch in batchs:
91 | batch = to_device(batch, device)
92 | with torch.no_grad():
93 | # Forward
94 | output = model(
95 | *(batch[2:]),
96 | p_control=pitch_control,
97 | e_control=energy_control,
98 | d_control=duration_control
99 | )
100 | synth_samples(
101 | batch,
102 | output,
103 | vocoder,
104 | model_config,
105 | preprocess_config,
106 | train_config["path"]["result_path"],
107 | tag,
108 | )
109 |
110 |
111 | if __name__ == "__main__":
112 |
113 | parser = argparse.ArgumentParser()
114 | parser.add_argument("--restore_step", type=int, required=True)
115 | parser.add_argument(
116 | "--mode",
117 | type=str,
118 | choices=["batch", "single"],
119 | required=True,
120 | help="Synthesize a whole dataset or a single sentence",
121 | )
122 | parser.add_argument(
123 | "--source",
124 | type=str,
125 | default=None,
126 | help="path to a source file with format like train.txt and val.txt, for batch mode only",
127 | )
128 | parser.add_argument(
129 | "--text",
130 | type=str,
131 | default=None,
132 | help="raw text to synthesize, for single-sentence mode only",
133 | )
134 | parser.add_argument(
135 | "--speaker_id",
136 | type=str,
137 | default="p001",
138 | help="speaker ID for multi-speaker synthesis, for single-sentence mode only",
139 | )
140 | parser.add_argument(
141 | "--emotion_id",
142 | type=str,
143 | default="happy",
144 | help="emotion ID for multi-emotion synthesis, for single-sentence mode only",
145 | )
146 | parser.add_argument(
147 | "--arousal",
148 | type=str,
149 | default="3",
150 | help="arousal value for multi-emotion synthesis, for single-sentence mode only",
151 | )
152 | parser.add_argument(
153 | "--valence",
154 | type=str,
155 | default="3",
156 | help="valence value for multi-emotion synthesis, for single-sentence mode only",
157 | )
158 | parser.add_argument(
159 | "-p",
160 | "--preprocess_config",
161 | type=str,
162 | required=True,
163 | help="path to preprocess.yaml",
164 | )
165 | parser.add_argument(
166 | "-m", "--model_config", type=str, required=True, help="path to model.yaml"
167 | )
168 | parser.add_argument(
169 | "-t", "--train_config", type=str, required=True, help="path to train.yaml"
170 | )
171 | parser.add_argument(
172 | "--pitch_control",
173 | type=float,
174 | default=1.0,
175 | help="control the pitch of the whole utterance, larger value for higher pitch",
176 | )
177 | parser.add_argument(
178 | "--energy_control",
179 | type=float,
180 | default=1.0,
181 | help="control the energy of the whole utterance, larger value for larger volume",
182 | )
183 | parser.add_argument(
184 | "--duration_control",
185 | type=float,
186 | default=1.0,
187 | help="control the speed of the whole utterance, larger value for slower speaking rate",
188 | )
189 | args = parser.parse_args()
190 |
191 | # Check source texts
192 | if args.mode == "batch":
193 | assert args.source is not None and args.text is None
194 | if args.mode == "single":
195 | assert args.source is None and args.text is not None
196 |
197 | # Read Config
198 | preprocess_config = yaml.load(
199 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader
200 | )
201 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
202 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
203 | configs = (preprocess_config, model_config, train_config)
204 |
205 | # Get model
206 | model = get_model(args, configs, device, train=False)
207 |
208 | # Load vocoder
209 | vocoder = get_vocoder(model_config, device)
210 |
211 | # Preprocess texts
212 | if args.mode == "batch":
213 | # Get dataset
214 | dataset = TextDataset(args.source, preprocess_config, model_config)
215 | batchs = DataLoader(
216 | dataset,
217 | batch_size=8,
218 | collate_fn=dataset.collate_fn,
219 | )
220 | tag = None
221 | if args.mode == "single":
222 | emotions = arousals = valences = None
223 | ids = raw_texts = [args.text[:100]]
224 | with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "speakers.json")) as f:
225 | speaker_map = json.load(f)
226 | speakers = np.array([speaker_map[args.speaker_id]])
227 | if model_config["multi_emotion"]:
228 | with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "emotions.json")) as f:
229 | json_raw = json.load(f)
230 | emotion_map = json_raw["emotion_dict"]
231 | arousal_map = json_raw["arousal_dict"]
232 | valence_map = json_raw["valence_dict"]
233 | emotions = np.array([emotion_map[args.emotion_id]])
234 | arousals = np.array([arousal_map[args.arousal]])
235 | valences = np.array([valence_map[args.valence]])
236 | if preprocess_config["preprocessing"]["text"]["language"] == "kr":
237 | texts = np.array([preprocess_korean(args.text, preprocess_config)])
238 | elif preprocess_config["preprocessing"]["text"]["language"] == "en":
239 | texts = np.array([preprocess_english(args.text, preprocess_config)])
240 | text_lens = np.array([len(texts[0])])
241 | batchs = [(ids, raw_texts, speakers, emotions, arousals, valences, texts, text_lens, max(text_lens))]
242 | tag = f"{args.speaker_id}_{args.emotion_id}"
243 |
244 | control_values = args.pitch_control, args.energy_control, args.duration_control
245 |
246 | synthesize(model, args.restore_step, configs, vocoder, batchs, control_values, tag)
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | import re
3 | from text import cleaners
4 | from text.symbols import symbols
5 | from .korean_dict import char_to_id, id_to_char
6 |
7 | # Regular expression matching text enclosed in curly braces:
8 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
9 |
10 |
11 | def text_to_sequence(text, cleaner_names):
12 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
13 |
14 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
15 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
16 |
17 | Args:
18 | text: string to convert to a sequence
19 | cleaner_names: names of the cleaner functions to run the text through
20 |
21 | Returns:
22 | List of integers corresponding to the symbols in the text
23 | """
24 | sequence = []
25 |
26 | # Mappings from symbol to numeric ID and vice versa:
27 | _language, _symbol_to_id, _ = ("kr", char_to_id, id_to_char) if "korean_cleaners" in cleaner_names\
28 | else ("en", {s: i for i, s in enumerate(symbols)}, {i: s for i, s in enumerate(symbols)})
29 |
30 | # Check for curly braces and treat their contents as ARPAbet:
31 | while len(text):
32 | m = _curly_re.match(text)
33 |
34 | if not m:
35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names), _symbol_to_id)
36 | break
37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names), _symbol_to_id)
38 | sequence += _arpabet_to_sequence(m.group(2), _language, _symbol_to_id)
39 | text = m.group(3)
40 |
41 | return sequence
42 |
43 |
44 | def _clean_text(text, cleaner_names):
45 | for name in cleaner_names:
46 | cleaner = getattr(cleaners, name)
47 | if not cleaner:
48 | raise Exception("Unknown cleaner: %s" % name)
49 | text = cleaner(text)
50 | return text
51 |
52 |
53 | def _symbols_to_sequence(symbols, _symbol_to_id):
54 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s, _symbol_to_id)]
55 |
56 |
57 | def _arpabet_to_sequence(text, _language, _symbol_to_id):
58 | if _language == "kr":
59 | return _symbols_to_sequence([s for s in text.split()], _symbol_to_id)
60 | return _symbols_to_sequence(["@" + s for s in text.split()], _symbol_to_id)
61 |
62 |
63 | def _should_keep_symbol(s, _symbol_to_id):
64 | return s in _symbol_to_id and s != "_" and s != "~"
65 |
--------------------------------------------------------------------------------
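Usage sketch for the module above (a sketch, assuming the repo and its requirements.txt dependencies, including the Korean G2P stack pulled in via text.cleaners, are installed): phoneme spans wrapped in curly braces bypass the cleaners and map to the "@"-prefixed ARPAbet symbols, which is how synthesize.py passes its pre-computed phoneme strings.

from text import text_to_sequence

# Braced spans are treated as ARPAbet; the surrounding text goes through the named cleaners.
ids = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
print(ids)  # a flat list of integer symbol IDs
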
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 | import re
18 | from unidecode import unidecode
19 | from .numbers import normalize_numbers
20 | from .korean import normalize
21 | _whitespace_re = re.compile(r'\s+')
22 |
23 | # List of (regular expression, replacement) pairs for abbreviations:
24 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
25 | ('mrs', 'misess'),
26 | ('mr', 'mister'),
27 | ('dr', 'doctor'),
28 | ('st', 'saint'),
29 | ('co', 'company'),
30 | ('jr', 'junior'),
31 | ('maj', 'major'),
32 | ('gen', 'general'),
33 | ('drs', 'doctors'),
34 | ('rev', 'reverend'),
35 | ('lt', 'lieutenant'),
36 | ('hon', 'honorable'),
37 | ('sgt', 'sergeant'),
38 | ('capt', 'captain'),
39 | ('esq', 'esquire'),
40 | ('ltd', 'limited'),
41 | ('col', 'colonel'),
42 | ('ft', 'fort'),
43 | ]]
44 |
45 |
46 | def expand_abbreviations(text):
47 | for regex, replacement in _abbreviations:
48 | text = re.sub(regex, replacement, text)
49 | return text
50 |
51 |
52 | def expand_numbers(text):
53 | return normalize_numbers(text)
54 |
55 |
56 | def lowercase(text):
57 | return text.lower()
58 |
59 |
60 | def collapse_whitespace(text):
61 | return re.sub(_whitespace_re, ' ', text)
62 |
63 |
64 | def convert_to_ascii(text):
65 | return unidecode(text)
66 |
67 |
68 | def basic_cleaners(text):
69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
70 | text = lowercase(text)
71 | text = collapse_whitespace(text)
72 | return text
73 |
74 |
75 | def transliteration_cleaners(text):
76 | '''Pipeline for non-English text that transliterates to ASCII.'''
77 | text = convert_to_ascii(text)
78 | text = lowercase(text)
79 | text = collapse_whitespace(text)
80 | return text
81 |
82 |
83 | def english_cleaners(text):
84 | '''Pipeline for English text, including number and abbreviation expansion.'''
85 | text = convert_to_ascii(text)
86 | text = lowercase(text)
87 | text = expand_numbers(text)
88 | text = expand_abbreviations(text)
89 | text = collapse_whitespace(text)
90 | return text
91 |
92 |
93 | def korean_cleaners(text):
94 | '''Pipeline for Korean (Hangul) text, including number and abbreviation expansion.'''
95 | text = normalize(text)
96 | text = collapse_whitespace(text)
97 | return text
--------------------------------------------------------------------------------
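A quick sketch of what english_cleaners does to raw English text (same installation assumption as above; the printed output is shown approximately):

from text.cleaners import english_cleaners

# Lowercases, expands currency/ordinals/numbers and abbreviations, then collapses whitespace.
print(english_cleaners("Dr. Smith paid $16.50 on   Jan 2nd."))
# -> "doctor smith paid sixteen dollars, fifty cents on jan second."
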
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | "AA",
8 | "AA0",
9 | "AA1",
10 | "AA2",
11 | "AE",
12 | "AE0",
13 | "AE1",
14 | "AE2",
15 | "AH",
16 | "AH0",
17 | "AH1",
18 | "AH2",
19 | "AO",
20 | "AO0",
21 | "AO1",
22 | "AO2",
23 | "AW",
24 | "AW0",
25 | "AW1",
26 | "AW2",
27 | "AY",
28 | "AY0",
29 | "AY1",
30 | "AY2",
31 | "B",
32 | "CH",
33 | "D",
34 | "DH",
35 | "EH",
36 | "EH0",
37 | "EH1",
38 | "EH2",
39 | "ER",
40 | "ER0",
41 | "ER1",
42 | "ER2",
43 | "EY",
44 | "EY0",
45 | "EY1",
46 | "EY2",
47 | "F",
48 | "G",
49 | "HH",
50 | "IH",
51 | "IH0",
52 | "IH1",
53 | "IH2",
54 | "IY",
55 | "IY0",
56 | "IY1",
57 | "IY2",
58 | "JH",
59 | "K",
60 | "L",
61 | "M",
62 | "N",
63 | "NG",
64 | "OW",
65 | "OW0",
66 | "OW1",
67 | "OW2",
68 | "OY",
69 | "OY0",
70 | "OY1",
71 | "OY2",
72 | "P",
73 | "R",
74 | "S",
75 | "SH",
76 | "T",
77 | "TH",
78 | "UH",
79 | "UH0",
80 | "UH1",
81 | "UH2",
82 | "UW",
83 | "UW0",
84 | "UW1",
85 | "UW2",
86 | "V",
87 | "W",
88 | "Y",
89 | "Z",
90 | "ZH",
91 | ]
92 |
93 | _valid_symbol_set = set(valid_symbols)
94 |
95 |
96 | class CMUDict:
97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
98 |
99 | def __init__(self, file_or_path, keep_ambiguous=True):
100 | if isinstance(file_or_path, str):
101 | with open(file_or_path, encoding="latin-1") as f:
102 | entries = _parse_cmudict(f)
103 | else:
104 | entries = _parse_cmudict(file_or_path)
105 | if not keep_ambiguous:
106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
107 | self._entries = entries
108 |
109 | def __len__(self):
110 | return len(self._entries)
111 |
112 | def lookup(self, word):
113 | """Returns list of ARPAbet pronunciations of the given word."""
114 | return self._entries.get(word.upper())
115 |
116 |
117 | _alt_re = re.compile(r"\([0-9]+\)")
118 |
119 |
120 | def _parse_cmudict(file):
121 | cmudict = {}
122 | for line in file:
123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
124 | parts = line.split(" ")
125 | word = re.sub(_alt_re, "", parts[0])
126 | pronunciation = _get_pronunciation(parts[1])
127 | if pronunciation:
128 | if word in cmudict:
129 | cmudict[word].append(pronunciation)
130 | else:
131 | cmudict[word] = [pronunciation]
132 | return cmudict
133 |
134 |
135 | def _get_pronunciation(s):
136 | parts = s.strip().split(" ")
137 | for part in parts:
138 | if part not in _valid_symbol_set:
139 | return None
140 | return " ".join(parts)
141 |
--------------------------------------------------------------------------------
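CMUDict also accepts a file-like object, so a tiny in-memory dictionary is enough to exercise lookup() (a sketch, assuming the repo's requirements are installed; the entry below is made up for illustration):

from io import StringIO
from text.cmudict import CMUDict

# CMUdict format: WORD, two spaces, then space-separated ARPAbet symbols.
cmu = CMUDict(StringIO("HOUSTON  HH Y UW1 S T AH0 N\n"))
print(len(cmu))               # 1
print(cmu.lookup("houston"))  # ['HH Y UW1 S T AH0 N']
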
/text/korean.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Code based on
3 |
4 | import re
5 | import os
6 | import ast
7 | import json
8 | from quickspacer import Spacer
9 | from g2pk import G2p
10 | from jamo import hangul_to_jamo, h2j, j2h
11 | from jamo.jamo import _jamo_char_to_hcj
12 | from .korean_dict import JAMO_LEADS, JAMO_VOWELS, JAMO_TAILS, ko_dict
13 |
14 | g2p = G2p()
15 | spacer = Spacer(level=3)
16 |
17 |
18 | def tokenize(text, norm=True):
19 | """
20 | Input --- Grapheme in string
21 | Output --- Phoneme in list
22 |
23 | Example:
24 | '한글은 위대하다.' --> ['ᄒ', 'ᅡ', 'ᆫ', 'ᄀ', 'ᅳ', 'ᄅ', 'ᅳ', ' ', 'ᄂ', 'ᅱ', 'ᄃ', 'ᅢ', 'ᄒ', 'ᅡ', 'ᄃ', 'ᅡ', '.']
25 | """
26 | if norm:
27 | text = normalize(text)
28 | text = g2p(text)
29 | tokens = list(hangul_to_jamo(text))
30 |
31 | return tokens
32 |
33 |
34 | def detokenize(tokens):
35 | """s
36 | Input --- Grapheme or Phoneme in list
37 | Output --- Grapheme or Phoneme in string
38 |
39 | Example:
40 | ['ᄒ', 'ᅡ', 'ᆫ', 'ᄀ', 'ᅳ', 'ᆯ', 'ᄋ', 'ᅳ', 'ᆫ', ' ', 'ᄋ', 'ᅱ', 'ᄃ', 'ᅢ', 'ᄒ', 'ᅡ', 'ᄃ', 'ᅡ', '.'] --> '한글은 위대하다.'
41 | ['ᄒ', 'ᅡ', 'ᆫ', 'ᄀ', 'ᅳ', 'ᄅ', 'ᅳ', ' ', 'ᄂ', 'ᅱ', 'ᄃ', 'ᅢ', 'ᄒ', 'ᅡ', 'ᄃ', 'ᅡ', '.'] --> '한그르 뉘대하다.'
42 | """
43 | tokens = h2j(tokens)
44 |
45 | idx = 0
46 | text = ""
47 | candidates = []
48 |
49 | while True:
50 | if idx >= len(tokens):
51 | text += _get_text_from_candidates(candidates)
52 | break
53 |
54 | char = tokens[idx]
55 | mode = _get_mode(char)
56 |
57 | if mode == 0:
58 | text += _get_text_from_candidates(candidates)
59 | candidates = [char]
60 | elif mode == -1:
61 | text += _get_text_from_candidates(candidates)
62 | text += char
63 | candidates = []
64 | else:
65 | candidates.append(char)
66 |
67 | idx += 1
68 | return text
69 |
70 |
71 | def _get_mode(char):
72 | if char in JAMO_LEADS:
73 | return 0
74 | elif char in JAMO_VOWELS:
75 | return 1
76 | elif char in JAMO_TAILS:
77 | return 2
78 | else:
79 | return -1
80 |
81 |
82 | def _get_text_from_candidates(candidates):
83 | if len(candidates) == 0:
84 | return ""
85 | elif len(candidates) == 1:
86 | return _jamo_char_to_hcj(candidates[0])
87 | else:
88 | return j2h(**dict(zip(["lead", "vowel", "tail"], candidates)))
89 |
90 |
91 | def compare_sentence_with_jamo(text1, text2):
92 | return h2j(text1) != h2j(text2)
93 |
94 |
95 | def normalize(text):
96 | """
97 | Transliterate input text into Hangul grapheme.
98 | """
99 | text = text.strip()
100 | text = re.sub('\(\d+일\)', '', text)
101 | text = re.sub('\([⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]+\)', '', text)
102 |
103 | text = normalize_with_dictionary(text, ko_dict["etc_dictionary"])
104 | text = normalize_english(text)
105 | text = re.sub('[a-zA-Z]+', normalize_upper, text)
106 |
107 | text = normalize_quote(text)
108 | text = normalize_number(text)
109 | text = normalize_nonchar(text)
110 | text = spacer.space([text])[0]
111 |
112 | return text
113 |
114 |
115 | def normalize_nonchar(text, inference=False):
116 | return re.sub(r"\{[^\w\s]?\}", "{sp}", text) if inference else\
117 | re.sub(r"[^\w\s]?", "", text)
118 |
119 |
120 | def normalize_with_dictionary(text, dic):
121 | if any(key in text for key in dic.keys()):
122 | pattern = re.compile('|'.join(re.escape(key) for key in dic.keys()))
123 | return pattern.sub(lambda x: dic[x.group()], text)
124 | else:
125 | return text
126 |
127 |
128 | def normalize_english(text):
129 | def fn(m):
130 | word = m.group()
131 | if word in ko_dict["english_dictionary"]:
132 | return ko_dict["english_dictionary"].get(word)
133 | else:
134 | return word
135 |
136 | text = re.sub("([A-Za-z]+)", fn, text)
137 | return text
138 |
139 |
140 | def normalize_upper(text):
141 | text = text.group(0)
142 |
143 | if all([char.isupper() for char in text]):
144 | return "".join(ko_dict["upper_to_kor"][char] for char in text)
145 | else:
146 | return text
147 |
148 |
149 | def normalize_quote(text):
150 | def fn(found_text):
151 | from nltk import sent_tokenize # NLTK doesn't play well with multiprocessing
152 |
153 | found_text = found_text.group()
154 | unquoted_text = found_text[1:-1]
155 |
156 | sentences = sent_tokenize(unquoted_text)
157 | return " ".join(["'{}'".format(sent) for sent in sentences])
158 |
159 | return re.sub(ko_dict["quote_checker"], fn, text)
160 |
161 |
162 | def normalize_number(text):
163 | text = normalize_with_dictionary(text, ko_dict["unit_to_kor"])
164 | text = re.sub(ko_dict["number_checker"] + ko_dict["count_checker"],
165 | lambda x: number_to_korean(x, True), text)
166 | text = re.sub(ko_dict["number_checker"],
167 | lambda x: number_to_korean(x, False), text)
168 | return text
169 |
170 |
171 | def number_to_korean(num_str, is_count=False):
172 | zero_cnt = 0
173 | if is_count:
174 | num_str, unit_str = num_str.group(1), num_str.group(2)
175 | else:
176 | num_str, unit_str = num_str.group(), ""
177 |
178 | num_str = num_str.replace(',', '')
179 |
180 | if is_count and len(num_str) > 2:
181 | is_count = False
182 |
183 | if len(num_str) > 1 and num_str.startswith("0") and '.' not in num_str:
184 | for n in num_str:
185 | zero_cnt += 1 if n == "0" else 0
186 | num_str = num_str[zero_cnt:]
187 |
188 | kor = ""
189 | if num_str != '':
190 | num = ast.literal_eval(num_str)
191 |
192 | if num == 0:
193 | return "영" + (unit_str if unit_str else "")
194 |
195 | check_float = num_str.split('.')
196 | if len(check_float) == 2:
197 | digit_str, float_str = check_float
198 | elif len(check_float) >= 3:
199 | raise Exception(" [!] Wrong number format")
200 | else:
201 | digit_str, float_str = check_float[0], None
202 |
203 | if is_count and float_str is not None:
204 | raise Exception(" [!] `is_count` cannot be used with a float number")
205 |
206 | digit = int(digit_str)
207 |
208 | if digit_str.startswith("-") or digit_str.startswith("+"):
209 | digit, digit_str = abs(digit), str(abs(digit))
210 |
211 | size = len(str(digit))
212 | tmp = []
213 |
214 | for i, v in enumerate(digit_str, start=1):
215 | v = int(v)
216 |
217 | if v != 0:
218 | if is_count:
219 | tmp += ko_dict["count_to_kor1"][v]
220 | else:
221 | tmp += ko_dict["num_to_kor1"][v]
222 | if v == 1 and i != 1 and i != len(digit_str):
223 | tmp = tmp[:-1]
224 | tmp += ko_dict["num_to_kor3"][(size - i) % 4]
225 |
226 | if (size - i) % 4 == 0 and len(tmp) != 0:
227 | kor += "".join(tmp)
228 | tmp = []
229 | kor += ko_dict["num_to_kor2"][int((size - i) / 4)]
230 |
231 | if is_count:
232 | if kor.startswith("한") and len(kor) > 1:
233 | kor = kor[1:]
234 |
235 | if any(word in kor for word in ko_dict["count_tenth_dict"]):
236 | kor = re.sub(
237 | '|'.join(ko_dict["count_tenth_dict"].keys()),
238 | lambda x: ko_dict["count_tenth_dict"][x.group()], kor)
239 |
240 | if not is_count and kor.startswith("일") and len(kor) > 1:
241 | kor = kor[1:]
242 |
243 | if float_str is not None and float_str != "":
244 | kor += "영" if kor == "" else ""
245 | kor += "쩜 "
246 | kor += re.sub('\d', lambda x: ko_dict["num_to_kor"][x.group()], float_str)
247 |
248 | if num_str.startswith("+"):
249 | kor = "플러스 " + kor
250 | elif num_str.startswith("-"):
251 | kor = "마이너스 " + kor
252 | if zero_cnt > 0:
253 | kor = "공"*zero_cnt + kor
254 |
255 | return kor + unit_str
256 |
257 |
258 | def test_normalize(texts):
259 | for text in texts:
260 | raw = text
261 | norm = normalize(text)
262 |
263 | print("="*30)
264 | print(raw)
265 | print(norm)
266 |
267 |
268 | if __name__ == "__main__":
269 | test_inputs = [
270 | "JTBC는 JTBCs를 DY는 A가 Absolute",
271 | "오늘(13일) 3,600마리 강아지가",
272 | "60.3%",
273 | '"저돌"(猪突) 입니다.',
274 | '비대위원장이 지난 1월 이런 말을 했습니다. “난 그냥 산돼지처럼 돌파하는 스타일이다”',
275 | "지금은 -12.35%였고 종류는 5가지와 19가지, 그리고 55가지였다",
276 | "JTBC는 TH와 K 양이 2017년 9월 12일 오후 12시에 24살이 된다",
277 | "이렇게 세트로 98,000원인데, 지금 세일 중이어서, 78,400원이에요.",
278 | "이렇게 세트로 98000원인데, 지금 세일 중이어서, 78400원이에요.",
279 | "저, 토익 970점이요.",
280 | "원래대로라면은 0점 처리해야 하는데.",
281 | "진짜? 그럼 너한테 한 두 마리만 줘도 돼?",
282 | "내가 화분이 좀 많아서. 그래도 17평에서 20평은 됐으면 좋겠어. 요즘 애들, 많이 사는 원룸, 그런데는 말고.",
283 | "매매는 3억까지. 전세는 1억 5천. 그 이상은 안돼.",
284 | "1억 3천이요.",
285 | "지금 3개월입니다.",
286 | "기계값 200만원 짜리를. 30개월 할부로 300만원에 파셨잖아요!",
287 | "오늘(13일) 99통 강아지가",
288 | "이제 55개.. 째예요.",
289 | "이제 55개월.. 째예요.",
290 | "한 근에 3만 5천 원이나 하는 1++ 등급 한우라니까!",
291 | "한 근에 3만 5천 원이나 하는 A+ 등급 한우라니까!",
292 | "19,22,30,34,39,44+36",
293 | "그거 1+1으로 프로모션 때려버리자.",
294 | "아 테이프는 너무 우리 때 얘기인가? 그 MP3 파일 같은 거 있잖아. 영어 중국어 이런 거. 영어 책 읽어주고 그런 거.",
295 | "231 cm야.",
296 | "1 cm야.",
297 | "21 cm야.",
298 | "110 cm야.",
299 | "21마리야.",
300 | "아, 시력은 알고 있어요. 왼쪽 0.3이고 오른쪽 0.1이요.",
301 | "왼쪽 0점",
302 | "우리 스마트폰 쓰기 전에 공일일 번호였을때. 그 때 썼던 전화기를 2G라고 하거든?",
303 | "102마리 강아지.",
304 | "87. 105. 120. 네. 100 넘었어요!",
305 | ]
306 |
307 | test_normalize(test_inputs)
--------------------------------------------------------------------------------
/text/korean_dict.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Symbols
3 | PAD = '_'
4 | EOS = '~'
5 | PUNC = '!\'(),-.:;?'
6 | SPACE = ' '
7 | _SILENCES = ['sp', 'spn', 'sil']
8 |
9 | JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
10 | JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
11 | JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])
12 |
13 | VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + PUNC + SPACE
14 | ALL_SYMBOLS = list(PAD + EOS + VALID_CHARS) + _SILENCES
15 |
16 | char_to_id = {c: i for i, c in enumerate(ALL_SYMBOLS)}
17 | id_to_char = {i: c for i, c in enumerate(ALL_SYMBOLS)}
18 |
19 | # Dictionaries
20 | ko_dict = {
21 | "quote_checker": """([`"'"“‘’”’])(.+?)([`"'"“‘’”’])""",
22 | "number_checker": "([+-]?\d{1,3},\d{3}(?!\d)|[+-]?\d+)[\.]?\d*", # "([+-]?\d[\d,]*)[\.]?\d*"
23 | "count_checker": "(시|명|가지|살|마리|포기|송이|수|톨|통|점|개(?!월)|벌|척|채|다발|그루|자루|줄|켤레|그릇|잔|마디|상자|사람|곡|병|판)",
24 |
25 | "num_to_kor": {
26 | '0': '영',
27 | '1': '일',
28 | '2': '이',
29 | '3': '삼',
30 | '4': '사',
31 | '5': '오',
32 | '6': '육',
33 | '7': '칠',
34 | '8': '팔',
35 | '9': '구',
36 | },
37 |
38 | "num_to_kor1": [""] + list("일이삼사오육칠팔구"),
39 | "num_to_kor2": [""] + list("만억조경해"),
40 | "num_to_kor3": [""] + list("십백천"),
41 | "count_to_kor1": [""] + ["한","두","세","네","다섯","여섯","일곱","여덟","아홉"], # [""] + ["하나","둘","셋","넷","다섯","여섯","일곱","여덟","아홉"]
42 |
43 | "count_tenth_dict": {
44 | "십": "열",
45 | "두십": "스물",
46 | "세십": "서른",
47 | "네십": "마흔",
48 | "다섯십": "쉰",
49 | "여섯십": "예순",
50 | "일곱십": "일흔",
51 | "여덟십": "여든",
52 | "아홉십": "아흔",
53 | },
54 |
55 | "unit_to_kor": {
56 | '%': '퍼센트',
57 | 'ml': '밀리리터',
58 | 'cm': '센치미터',
59 | 'mm': '밀리미터',
60 | 'km': '킬로미터',
61 | 'kg': '킬로그람',
62 | 'm': '미터',
63 | },
64 |
65 | "upper_to_kor": {
66 | 'A': '에이',
67 | 'B': '비',
68 | 'C': '씨',
69 | 'D': '디',
70 | 'E': '이',
71 | 'F': '에프',
72 | 'G': '지',
73 | 'H': '에이치',
74 | 'I': '아이',
75 | 'J': '제이',
76 | 'K': '케이',
77 | 'L': '엘',
78 | 'M': '엠',
79 | 'N': '엔',
80 | 'O': '오',
81 | 'P': '피',
82 | 'Q': '큐',
83 | 'R': '알',
84 | 'S': '에스',
85 | 'T': '티',
86 | 'U': '유',
87 | 'V': '브이',
88 | 'W': '더블유',
89 | 'X': '엑스',
90 | 'Y': '와이',
91 | 'Z': '지',
92 | },
93 |
94 | "english_dictionary": {
95 | 'TV': '티비',
96 | 'CCTV': '씨씨티비',
97 | 'cctv': '씨씨티비',
98 | 'cc': '씨씨',
99 | 'Apple': '애플',
100 | 'lte': '엘티이',
101 | 'KG': '킬로그람',
102 | 'x': '엑스',
103 | 'z': '제트',
104 | 'Yo': '요',
105 | 'YOLO': '욜로',
106 | 'Gone': '건',
107 | 'gone': '건',
108 | 'Have': '헤브',
109 | 'p': '피',
110 | 'ppt': '피피티',
111 | 'suv': '에스유브이',
112 | },
113 |
114 | "etc_dictionary": {
115 | '1+1': '원 플러스 원',
116 | '+': '플러스',
117 | 'MP3': '엠피쓰리',
118 | '5G': '파이브지',
119 | '4G': '포지',
120 | '3G': '쓰리지',
121 | '2G': '투지',
122 | 'A/S': '에이 에스',
123 | '1/3':'삼분의 일',
124 | 'greentea907': '그린티 구공칠',
125 | 'CNT 123': '씨엔티 일이삼',
126 | '14학번': '일사 학번',
127 | '7011번': '칠공일일번',
128 | 'P8학원': '피에잇 학원',
129 | '102마리': '백 두마리',
130 | '20명': '스무명',
131 | }
132 | }
133 |
134 |
--------------------------------------------------------------------------------
/text/numbers.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import inflect
4 | import re
5 |
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13 | _number_re = re.compile(r"[0-9]+")
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(",", "")
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace(".", " point ")
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split(".")
27 | if len(parts) > 2:
28 | return match + " dollars" # Unexpected format
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = "dollar" if dollars == 1 else "dollars"
33 | cent_unit = "cent" if cents == 1 else "cents"
34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
35 | elif dollars:
36 | dollar_unit = "dollar" if dollars == 1 else "dollars"
37 | return "%s %s" % (dollars, dollar_unit)
38 | elif cents:
39 | cent_unit = "cent" if cents == 1 else "cents"
40 | return "%s %s" % (cents, cent_unit)
41 | else:
42 | return "zero dollars"
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return "two thousand"
54 | elif num > 2000 and num < 2010:
55 | return "two thousand " + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + " hundred"
58 | else:
59 | return _inflect.number_to_words(
60 | num, andword="", zero="oh", group=2
61 | ).replace(", ", " ")
62 | else:
63 | return _inflect.number_to_words(num, andword="")
64 |
65 |
66 | def normalize_numbers(text):
67 | text = re.sub(_comma_number_re, _remove_commas, text)
68 | text = re.sub(_pounds_re, r"\1 pounds", text)
69 | text = re.sub(_dollars_re, _expand_dollars, text)
70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
71 | text = re.sub(_ordinal_re, _expand_ordinal, text)
72 | text = re.sub(_number_re, _expand_number, text)
73 | return text
74 |
--------------------------------------------------------------------------------
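A short sketch of the normalization rules above, including the special wording for year-like values between 1000 and 3000 (assumes the repo's requirements are installed; outputs shown approximately):

from text.numbers import normalize_numbers

print(normalize_numbers("1,500 meters"))   # -> "fifteen hundred meters"
print(normalize_numbers("back in 2007"))   # -> "back in two thousand seven"
print(normalize_numbers("the 3rd of 12"))  # -> "the third of twelve"
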
/text/pinyin.py:
--------------------------------------------------------------------------------
1 | initials = [
2 | "b",
3 | "c",
4 | "ch",
5 | "d",
6 | "f",
7 | "g",
8 | "h",
9 | "j",
10 | "k",
11 | "l",
12 | "m",
13 | "n",
14 | "p",
15 | "q",
16 | "r",
17 | "s",
18 | "sh",
19 | "t",
20 | "w",
21 | "x",
22 | "y",
23 | "z",
24 | "zh",
25 | ]
26 | finals = [
27 | "a1",
28 | "a2",
29 | "a3",
30 | "a4",
31 | "a5",
32 | "ai1",
33 | "ai2",
34 | "ai3",
35 | "ai4",
36 | "ai5",
37 | "an1",
38 | "an2",
39 | "an3",
40 | "an4",
41 | "an5",
42 | "ang1",
43 | "ang2",
44 | "ang3",
45 | "ang4",
46 | "ang5",
47 | "ao1",
48 | "ao2",
49 | "ao3",
50 | "ao4",
51 | "ao5",
52 | "e1",
53 | "e2",
54 | "e3",
55 | "e4",
56 | "e5",
57 | "ei1",
58 | "ei2",
59 | "ei3",
60 | "ei4",
61 | "ei5",
62 | "en1",
63 | "en2",
64 | "en3",
65 | "en4",
66 | "en5",
67 | "eng1",
68 | "eng2",
69 | "eng3",
70 | "eng4",
71 | "eng5",
72 | "er1",
73 | "er2",
74 | "er3",
75 | "er4",
76 | "er5",
77 | "i1",
78 | "i2",
79 | "i3",
80 | "i4",
81 | "i5",
82 | "ia1",
83 | "ia2",
84 | "ia3",
85 | "ia4",
86 | "ia5",
87 | "ian1",
88 | "ian2",
89 | "ian3",
90 | "ian4",
91 | "ian5",
92 | "iang1",
93 | "iang2",
94 | "iang3",
95 | "iang4",
96 | "iang5",
97 | "iao1",
98 | "iao2",
99 | "iao3",
100 | "iao4",
101 | "iao5",
102 | "ie1",
103 | "ie2",
104 | "ie3",
105 | "ie4",
106 | "ie5",
107 | "ii1",
108 | "ii2",
109 | "ii3",
110 | "ii4",
111 | "ii5",
112 | "iii1",
113 | "iii2",
114 | "iii3",
115 | "iii4",
116 | "iii5",
117 | "in1",
118 | "in2",
119 | "in3",
120 | "in4",
121 | "in5",
122 | "ing1",
123 | "ing2",
124 | "ing3",
125 | "ing4",
126 | "ing5",
127 | "iong1",
128 | "iong2",
129 | "iong3",
130 | "iong4",
131 | "iong5",
132 | "iou1",
133 | "iou2",
134 | "iou3",
135 | "iou4",
136 | "iou5",
137 | "o1",
138 | "o2",
139 | "o3",
140 | "o4",
141 | "o5",
142 | "ong1",
143 | "ong2",
144 | "ong3",
145 | "ong4",
146 | "ong5",
147 | "ou1",
148 | "ou2",
149 | "ou3",
150 | "ou4",
151 | "ou5",
152 | "u1",
153 | "u2",
154 | "u3",
155 | "u4",
156 | "u5",
157 | "ua1",
158 | "ua2",
159 | "ua3",
160 | "ua4",
161 | "ua5",
162 | "uai1",
163 | "uai2",
164 | "uai3",
165 | "uai4",
166 | "uai5",
167 | "uan1",
168 | "uan2",
169 | "uan3",
170 | "uan4",
171 | "uan5",
172 | "uang1",
173 | "uang2",
174 | "uang3",
175 | "uang4",
176 | "uang5",
177 | "uei1",
178 | "uei2",
179 | "uei3",
180 | "uei4",
181 | "uei5",
182 | "uen1",
183 | "uen2",
184 | "uen3",
185 | "uen4",
186 | "uen5",
187 | "uo1",
188 | "uo2",
189 | "uo3",
190 | "uo4",
191 | "uo5",
192 | "v1",
193 | "v2",
194 | "v3",
195 | "v4",
196 | "v5",
197 | "van1",
198 | "van2",
199 | "van3",
200 | "van4",
201 | "van5",
202 | "ve1",
203 | "ve2",
204 | "ve3",
205 | "ve4",
206 | "ve5",
207 | "vn1",
208 | "vn2",
209 | "vn3",
210 | "vn4",
211 | "vn5",
212 | ]
213 | valid_symbols = initials + finals + ["rr"]
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | """
4 | Defines the set of symbols used in text input to the model.
5 |
6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """
7 |
8 | from text import cmudict, pinyin
9 |
10 | _pad = "_"
11 | _punctuation = "!'(),.:;? "
12 | _special = "-"
13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
14 | _silences = ["@sp", "@spn", "@sil"]
15 |
16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
17 | _arpabet = ["@" + s for s in cmudict.valid_symbols]
18 | _pinyin = ["@" + s for s in pinyin.valid_symbols]
19 |
20 | # Export all symbols:
21 | symbols = (
22 | [_pad]
23 | + list(_special)
24 | + list(_punctuation)
25 | + list(_letters)
26 | + _arpabet
27 | + _pinyin
28 | + _silences
29 | )
30 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import yaml
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 | from torch.utils.tensorboard import SummaryWriter
9 | from tqdm import tqdm
10 |
11 | from utils.model import get_model, get_vocoder, get_param_num
12 | from utils.tools import to_device, log, synth_one_sample
13 | from model import FastSpeech2Loss
14 | from dataset import Dataset
15 |
16 | from evaluate import evaluate
17 |
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 |
21 | def main(args, configs):
22 | print("Prepare training ...")
23 |
24 | preprocess_config, model_config, train_config = configs
25 |
26 | # Get dataset
27 | dataset = Dataset(
28 | "train.txt", preprocess_config, model_config, train_config, sort=True, drop_last=True
29 | )
30 | batch_size = train_config["optimizer"]["batch_size"]
31 | group_size = 4 # Set this larger than 1 to enable sorting in Dataset
32 | assert batch_size * group_size < len(dataset)
33 | loader = DataLoader(
34 | dataset,
35 | batch_size=batch_size * group_size,
36 | shuffle=True,
37 | collate_fn=dataset.collate_fn,
38 | )
39 |
40 | # Prepare model
41 | model, optimizer = get_model(args, configs, device, train=True)
42 | model = nn.DataParallel(model)
43 | num_param = get_param_num(model)
44 | Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
45 | print("Number of FastSpeech2 Parameters:", num_param)
46 |
47 | # Load vocoder
48 | vocoder = get_vocoder(model_config, device)
49 |
50 | # Init logger
51 | for p in train_config["path"].values():
52 | os.makedirs(p, exist_ok=True)
53 | train_log_path = os.path.join(train_config["path"]["log_path"], "train")
54 | val_log_path = os.path.join(train_config["path"]["log_path"], "val")
55 | os.makedirs(train_log_path, exist_ok=True)
56 | os.makedirs(val_log_path, exist_ok=True)
57 | train_logger = SummaryWriter(train_log_path)
58 | val_logger = SummaryWriter(val_log_path)
59 |
60 | # Training
61 | step = args.restore_step + 1
62 | epoch = 1
63 | grad_acc_step = train_config["optimizer"]["grad_acc_step"]
64 | grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
65 | total_step = train_config["step"]["total_step"]
66 | log_step = train_config["step"]["log_step"]
67 | save_step = train_config["step"]["save_step"]
68 | synth_step = train_config["step"]["synth_step"]
69 | val_step = train_config["step"]["val_step"]
70 |
71 | outer_bar = tqdm(total=total_step, desc="Training", position=0)
72 | outer_bar.n = args.restore_step
73 | outer_bar.update()
74 |
75 | while True:
76 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
77 | for batchs in loader:
78 | for batch in batchs:
79 | batch = to_device(batch, device)
80 |
81 | # Forward
82 | output = model(*(batch[2:]))
83 |
84 | # Cal Loss
85 | losses = Loss(batch, output)
86 | total_loss = losses[0]
87 |
88 | # Backward
89 | total_loss = total_loss / grad_acc_step
90 | total_loss.backward()
91 | if step % grad_acc_step == 0:
92 | # Clipping gradients to avoid gradient explosion
93 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
94 |
95 | # Update weights
96 | optimizer.step_and_update_lr()
97 | optimizer.zero_grad()
98 |
99 | if step % log_step == 0:
100 | losses = [l.item() for l in losses]
101 | message1 = "Step {}/{}, ".format(step, total_step)
102 | message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
103 | *losses
104 | )
105 |
106 | with open(os.path.join(train_log_path, "log.txt"), "a") as f:
107 | f.write(message1 + message2 + "\n")
108 |
109 | outer_bar.write(message1 + message2)
110 |
111 | log(train_logger, step, losses=losses)
112 |
113 | if step % synth_step == 0:
114 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
115 | batch,
116 | output,
117 | vocoder,
118 | model_config,
119 | preprocess_config,
120 | )
121 | log(
122 | train_logger,
123 | fig=fig,
124 | tag="Training/step_{}_{}".format(step, tag),
125 | )
126 | sampling_rate = preprocess_config["preprocessing"]["audio"][
127 | "sampling_rate"
128 | ]
129 | log(
130 | train_logger,
131 | audio=wav_reconstruction,
132 | sampling_rate=sampling_rate,
133 | tag="Training/step_{}_{}_reconstructed".format(step, tag),
134 | )
135 | log(
136 | train_logger,
137 | audio=wav_prediction,
138 | sampling_rate=sampling_rate,
139 | tag="Training/step_{}_{}_synthesized".format(step, tag),
140 | )
141 |
142 | if step % val_step == 0:
143 | model.eval()
144 | message = evaluate(model, step, configs, val_logger, vocoder)
145 | with open(os.path.join(val_log_path, "log.txt"), "a") as f:
146 | f.write(message + "\n")
147 | outer_bar.write(message)
148 |
149 | model.train()
150 |
151 | if step % save_step == 0:
152 | torch.save(
153 | {
154 | "model": model.module.state_dict(),
155 | "optimizer": optimizer._optimizer.state_dict(),
156 | },
157 | os.path.join(
158 | train_config["path"]["ckpt_path"],
159 | "{}.pth.tar".format(step),
160 | ),
161 | )
162 |
163 | if step == total_step:
164 | quit()
165 | step += 1
166 | outer_bar.update(1)
167 |
168 | inner_bar.update(1)
169 | epoch += 1
170 |
171 |
172 | if __name__ == "__main__":
173 | parser = argparse.ArgumentParser()
174 | parser.add_argument("--restore_step", type=int, default=0)
175 | parser.add_argument(
176 | "-p",
177 | "--preprocess_config",
178 | type=str,
179 | required=True,
180 | help="path to preprocess.yaml",
181 | )
182 | parser.add_argument(
183 | "-m", "--model_config", type=str, required=True, help="path to model.yaml"
184 | )
185 | parser.add_argument(
186 | "-t", "--train_config", type=str, required=True, help="path to train.yaml"
187 | )
188 | args = parser.parse_args()
189 |
190 | # Read Config
191 | preprocess_config = yaml.load(
192 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader
193 | )
194 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
195 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
196 | configs = (preprocess_config, model_config, train_config)
197 |
198 | main(args, configs)
199 |
--------------------------------------------------------------------------------
/transformer/Constants.py:
--------------------------------------------------------------------------------
1 | PAD = 0
2 | UNK = 1
3 | BOS = 2
4 | EOS = 3
5 |
6 | PAD_WORD = "<blank>"
7 | UNK_WORD = "<unk>"
8 | BOS_WORD = "<s>"
9 | EOS_WORD = "</s>"
10 |
--------------------------------------------------------------------------------
/transformer/Layers.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from torch.nn import functional as F
7 |
8 | from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
9 |
10 |
11 | class FFTBlock(torch.nn.Module):
12 | """FFT Block"""
13 |
14 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
15 | super(FFTBlock, self).__init__()
16 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
17 | self.pos_ffn = PositionwiseFeedForward(
18 | d_model, d_inner, kernel_size, dropout=dropout
19 | )
20 |
21 | def forward(self, enc_input, mask=None, slf_attn_mask=None):
22 | enc_output, enc_slf_attn = self.slf_attn(
23 | enc_input, enc_input, enc_input, mask=slf_attn_mask
24 | )
25 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
26 |
27 | enc_output = self.pos_ffn(enc_output)
28 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
29 |
30 | return enc_output, enc_slf_attn
31 |
32 |
33 | class ConvNorm(torch.nn.Module):
34 | def __init__(
35 | self,
36 | in_channels,
37 | out_channels,
38 | kernel_size=1,
39 | stride=1,
40 | padding=None,
41 | dilation=1,
42 | bias=True,
43 | w_init_gain="linear",
44 | ):
45 | super(ConvNorm, self).__init__()
46 |
47 | if padding is None:
48 | assert kernel_size % 2 == 1
49 | padding = int(dilation * (kernel_size - 1) / 2)
50 |
51 | self.conv = torch.nn.Conv1d(
52 | in_channels,
53 | out_channels,
54 | kernel_size=kernel_size,
55 | stride=stride,
56 | padding=padding,
57 | dilation=dilation,
58 | bias=bias,
59 | )
60 |
61 | def forward(self, signal):
62 | conv_signal = self.conv(signal)
63 |
64 | return conv_signal
65 |
66 |
67 | class PostNet(nn.Module):
68 | """
69 | PostNet: Five 1-d convolutions with 512 channels and kernel size 5
70 | """
71 |
72 | def __init__(
73 | self,
74 | n_mel_channels=80,
75 | postnet_embedding_dim=512,
76 | postnet_kernel_size=5,
77 | postnet_n_convolutions=5,
78 | ):
79 |
80 | super(PostNet, self).__init__()
81 | self.convolutions = nn.ModuleList()
82 |
83 | self.convolutions.append(
84 | nn.Sequential(
85 | ConvNorm(
86 | n_mel_channels,
87 | postnet_embedding_dim,
88 | kernel_size=postnet_kernel_size,
89 | stride=1,
90 | padding=int((postnet_kernel_size - 1) / 2),
91 | dilation=1,
92 | w_init_gain="tanh",
93 | ),
94 | nn.BatchNorm1d(postnet_embedding_dim),
95 | )
96 | )
97 |
98 | for i in range(1, postnet_n_convolutions - 1):
99 | self.convolutions.append(
100 | nn.Sequential(
101 | ConvNorm(
102 | postnet_embedding_dim,
103 | postnet_embedding_dim,
104 | kernel_size=postnet_kernel_size,
105 | stride=1,
106 | padding=int((postnet_kernel_size - 1) / 2),
107 | dilation=1,
108 | w_init_gain="tanh",
109 | ),
110 | nn.BatchNorm1d(postnet_embedding_dim),
111 | )
112 | )
113 |
114 | self.convolutions.append(
115 | nn.Sequential(
116 | ConvNorm(
117 | postnet_embedding_dim,
118 | n_mel_channels,
119 | kernel_size=postnet_kernel_size,
120 | stride=1,
121 | padding=int((postnet_kernel_size - 1) / 2),
122 | dilation=1,
123 | w_init_gain="linear",
124 | ),
125 | nn.BatchNorm1d(n_mel_channels),
126 | )
127 | )
128 |
129 | def forward(self, x):
130 | x = x.contiguous().transpose(1, 2)
131 |
132 | for i in range(len(self.convolutions) - 1):
133 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
134 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
135 |
136 | x = x.contiguous().transpose(1, 2)
137 | return x
138 |
--------------------------------------------------------------------------------
/transformer/Models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | import transformer.Constants as Constants
6 | from .Layers import FFTBlock
7 | from text.symbols import symbols
8 |
9 |
10 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
11 | """ Sinusoid position encoding table """
12 |
13 | def cal_angle(position, hid_idx):
14 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
15 |
16 | def get_posi_angle_vec(position):
17 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
18 |
19 | sinusoid_table = np.array(
20 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)]
21 | )
22 |
23 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
24 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
25 |
26 | if padding_idx is not None:
27 | # zero vector for padding dimension
28 | sinusoid_table[padding_idx] = 0.0
29 |
30 | return torch.FloatTensor(sinusoid_table)
31 |
32 |
33 | class Encoder(nn.Module):
34 | """ Encoder """
35 |
36 | def __init__(self, config):
37 | super(Encoder, self).__init__()
38 |
39 | n_position = config["max_seq_len"] + 1
40 | n_src_vocab = len(symbols) + 1
41 | d_word_vec = config["transformer"]["encoder_hidden"]
42 | n_layers = config["transformer"]["encoder_layer"]
43 | n_head = config["transformer"]["encoder_head"]
44 | d_k = d_v = (
45 | config["transformer"]["encoder_hidden"]
46 | // config["transformer"]["encoder_head"]
47 | )
48 | d_model = config["transformer"]["encoder_hidden"]
49 | d_inner = config["transformer"]["conv_filter_size"]
50 | kernel_size = config["transformer"]["conv_kernel_size"]
51 | dropout = config["transformer"]["encoder_dropout"]
52 |
53 | self.max_seq_len = config["max_seq_len"]
54 | self.d_model = d_model
55 |
56 | self.src_word_emb = nn.Embedding(
57 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD
58 | )
59 | self.position_enc = nn.Parameter(
60 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
61 | requires_grad=False,
62 | )
63 |
64 | self.layer_stack = nn.ModuleList(
65 | [
66 | FFTBlock(
67 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
68 | )
69 | for _ in range(n_layers)
70 | ]
71 | )
72 |
73 | def forward(self, src_seq, mask, return_attns=False):
74 |
75 | enc_slf_attn_list = []
76 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1]
77 |
78 | # -- Prepare masks
79 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
80 |
81 | # -- Forward
82 | if not self.training and src_seq.shape[1] > self.max_seq_len:
83 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table(
84 | src_seq.shape[1], self.d_model
85 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
86 | src_seq.device
87 | )
88 | else:
89 | enc_output = self.src_word_emb(src_seq) + self.position_enc[
90 | :, :max_len, :
91 | ].expand(batch_size, -1, -1)
92 |
93 | for enc_layer in self.layer_stack:
94 | enc_output, enc_slf_attn = enc_layer(
95 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask
96 | )
97 | if return_attns:
98 | enc_slf_attn_list += [enc_slf_attn]
99 |
100 | return enc_output
101 |
102 |
103 | class Decoder(nn.Module):
104 | """ Decoder """
105 |
106 | def __init__(self, config):
107 | super(Decoder, self).__init__()
108 |
109 | n_position = config["max_seq_len"] + 1
110 | d_word_vec = config["transformer"]["decoder_hidden"]
111 | n_layers = config["transformer"]["decoder_layer"]
112 | n_head = config["transformer"]["decoder_head"]
113 | d_k = d_v = (
114 | config["transformer"]["decoder_hidden"]
115 | // config["transformer"]["decoder_head"]
116 | )
117 | d_model = config["transformer"]["decoder_hidden"]
118 | d_inner = config["transformer"]["conv_filter_size"]
119 | kernel_size = config["transformer"]["conv_kernel_size"]
120 | dropout = config["transformer"]["decoder_dropout"]
121 |
122 | self.max_seq_len = config["max_seq_len"]
123 | self.d_model = d_model
124 |
125 | self.position_enc = nn.Parameter(
126 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
127 | requires_grad=False,
128 | )
129 |
130 | self.layer_stack = nn.ModuleList(
131 | [
132 | FFTBlock(
133 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
134 | )
135 | for _ in range(n_layers)
136 | ]
137 | )
138 |
139 | def forward(self, enc_seq, mask, return_attns=False):
140 |
141 | dec_slf_attn_list = []
142 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1]
143 |
144 | # -- Forward
145 | if not self.training and enc_seq.shape[1] > self.max_seq_len:
146 | # -- Prepare masks
147 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
148 | dec_output = enc_seq + get_sinusoid_encoding_table(
149 | enc_seq.shape[1], self.d_model
150 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
151 | enc_seq.device
152 | )
153 | else:
154 | max_len = min(max_len, self.max_seq_len)
155 |
156 | # -- Prepare masks
157 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
158 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[
159 | :, :max_len, :
160 | ].expand(batch_size, -1, -1)
161 | mask = mask[:, :max_len]
162 | slf_attn_mask = slf_attn_mask[:, :, :max_len]
163 |
164 | for dec_layer in self.layer_stack:
165 | dec_output, dec_slf_attn = dec_layer(
166 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask
167 | )
168 | if return_attns:
169 | dec_slf_attn_list += [dec_slf_attn]
170 |
171 | return dec_output, mask
172 |
--------------------------------------------------------------------------------
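A small sanity check for get_sinusoid_encoding_table: even columns hold sin(pos / 10000^(2i/d_hid)) and odd columns the matching cosine (a sketch, assuming torch, numpy, and the repo's text dependencies are importable):

import numpy as np
from transformer.Models import get_sinusoid_encoding_table

table = get_sinusoid_encoding_table(n_position=4, d_hid=6)   # FloatTensor of shape (4, 6)
pos, idx = 3, 2                                              # even column -> sine
expected = np.sin(pos / np.power(10000, 2 * (idx // 2) / 6))
assert abs(table[pos, idx].item() - expected) < 1e-5
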
/transformer/Modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class ScaledDotProductAttention(nn.Module):
7 | """ Scaled Dot-Product Attention """
8 |
9 | def __init__(self, temperature):
10 | super().__init__()
11 | self.temperature = temperature
12 | self.softmax = nn.Softmax(dim=2)
13 |
14 | def forward(self, q, k, v, mask=None):
15 |
16 | attn = torch.bmm(q, k.transpose(1, 2))
17 | attn = attn / self.temperature
18 |
19 | if mask is not None:
20 | attn = attn.masked_fill(mask, -np.inf)
21 |
22 | attn = self.softmax(attn)
23 | output = torch.bmm(attn, v)
24 |
25 | return output, attn
26 |
--------------------------------------------------------------------------------
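A minimal sketch of ScaledDotProductAttention on toy tensors (assumes the repo's requirements are installed): masking the last key position drives its post-softmax weight to zero for every query.

import torch
from transformer.Modules import ScaledDotProductAttention

attn_fn = ScaledDotProductAttention(temperature=8 ** 0.5)
q = k = v = torch.randn(2, 5, 8)               # (batch, length, d_k)
mask = torch.zeros(2, 5, 5, dtype=torch.bool)
mask[:, :, -1] = True                          # hide the last key from every query
out, attn = attn_fn(q, k, v, mask=mask)
print(out.shape, attn[0, 0, -1].item())        # torch.Size([2, 5, 8]) 0.0
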
/transformer/SubLayers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from .Modules import ScaledDotProductAttention
6 |
7 |
8 | class MultiHeadAttention(nn.Module):
9 | """ Multi-Head Attention module """
10 |
11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
12 | super().__init__()
13 |
14 | self.n_head = n_head
15 | self.d_k = d_k
16 | self.d_v = d_v
17 |
18 | self.w_qs = nn.Linear(d_model, n_head * d_k)
19 | self.w_ks = nn.Linear(d_model, n_head * d_k)
20 | self.w_vs = nn.Linear(d_model, n_head * d_v)
21 |
22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
23 | self.layer_norm = nn.LayerNorm(d_model)
24 |
25 | self.fc = nn.Linear(n_head * d_v, d_model)
26 |
27 | self.dropout = nn.Dropout(dropout)
28 |
29 | def forward(self, q, k, v, mask=None):
30 |
31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
32 |
33 | sz_b, len_q, _ = q.size()
34 | sz_b, len_k, _ = k.size()
35 | sz_b, len_v, _ = v.size()
36 |
37 | residual = q
38 |
39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
45 |
46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
47 | output, attn = self.attention(q, k, v, mask=mask)
48 |
49 | output = output.view(n_head, sz_b, len_q, d_v)
50 | output = (
51 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)
52 | ) # b x lq x (n*dv)
53 |
54 | output = self.dropout(self.fc(output))
55 | output = self.layer_norm(output + residual)
56 |
57 | return output, attn
58 |
59 |
60 | class PositionwiseFeedForward(nn.Module):
61 | """ A two-feed-forward-layer module """
62 |
63 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1):
64 | super().__init__()
65 |
66 | # Use Conv1D
67 | # position-wise
68 | self.w_1 = nn.Conv1d(
69 | d_in,
70 | d_hid,
71 | kernel_size=kernel_size[0],
72 | padding=(kernel_size[0] - 1) // 2,
73 | )
74 | # position-wise
75 | self.w_2 = nn.Conv1d(
76 | d_hid,
77 | d_in,
78 | kernel_size=kernel_size[1],
79 | padding=(kernel_size[1] - 1) // 2,
80 | )
81 |
82 | self.layer_norm = nn.LayerNorm(d_in)
83 | self.dropout = nn.Dropout(dropout)
84 |
85 | def forward(self, x):
86 | residual = x
87 | output = x.transpose(1, 2)
88 | output = self.w_2(F.relu(self.w_1(output)))
89 | output = output.transpose(1, 2)
90 | output = self.dropout(output)
91 | output = self.layer_norm(output + residual)
92 |
93 | return output
94 |
--------------------------------------------------------------------------------
/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .Models import Encoder, Decoder
2 | from .Layers import PostNet
--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import numpy as np
6 |
7 | import hifigan
8 | from model import FastSpeech2, ScheduledOptim
9 |
10 |
11 | def get_model(args, configs, device, train=False):
12 | (preprocess_config, model_config, train_config) = configs
13 |
14 | model = FastSpeech2(preprocess_config, model_config).to(device)
15 | if args.restore_step:
16 | ckpt_path = os.path.join(
17 | train_config["path"]["ckpt_path"],
18 | "{}.pth.tar".format(args.restore_step),
19 | )
20 | ckpt = torch.load(ckpt_path)
21 | model.load_state_dict(ckpt["model"])
22 |
23 | if train:
24 | scheduled_optim = ScheduledOptim(
25 | model, train_config, model_config, args.restore_step
26 | )
27 | if args.restore_step:
28 | scheduled_optim.load_state_dict(ckpt["optimizer"])
29 | model.train()
30 | return model, scheduled_optim
31 |
32 | model.eval()
33 | model.requires_grad_(False)
34 | return model
35 |
36 |
37 | def get_param_num(model):
38 | num_param = sum(param.numel() for param in model.parameters())
39 | return num_param
40 |
41 |
42 | def get_vocoder(config, device):
43 | name = config["vocoder"]["model"]
44 | speaker = config["vocoder"]["speaker"]
45 |
46 | if name == "MelGAN":
47 | if speaker == "LJSpeech":
48 | vocoder = torch.hub.load(
49 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
50 | )
51 | elif speaker == "universal":
52 | vocoder = torch.hub.load(
53 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
54 | )
55 | vocoder.mel2wav.eval()
56 | vocoder.mel2wav.to(device)
57 | elif name == "HiFi-GAN":
58 | with open("hifigan/config.json", "r") as f:
59 | config = json.load(f)
60 | config = hifigan.AttrDict(config)
61 | vocoder = hifigan.Generator(config)
62 | if speaker == "LJSpeech":
63 | ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar")
64 | elif speaker == "universal":
65 | ckpt = torch.load("hifigan/generator_universal.pth.tar")
66 | vocoder.load_state_dict(ckpt["generator"])
67 | vocoder.eval()
68 | vocoder.remove_weight_norm()
69 | vocoder.to(device)
70 |
71 | return vocoder
72 |
73 |
74 | def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
75 | name = model_config["vocoder"]["model"]
76 | with torch.no_grad():
77 | if name == "MelGAN":
78 | wavs = vocoder.inverse(mels / np.log(10))
79 | elif name == "HiFi-GAN":
80 | wavs = vocoder(mels).squeeze(1)
81 |
82 | wavs = (
83 | wavs.cpu().numpy()
84 | * preprocess_config["preprocessing"]["audio"]["max_wav_value"]
85 | ).astype("int16")
86 | wavs = [wav for wav in wavs]
87 |
88 | for i in range(len(mels)):
89 | if lengths is not None:
90 | wavs[i] = wavs[i][: lengths[i]]
91 |
92 | return wavs
93 |
--------------------------------------------------------------------------------
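A hedged usage sketch for the vocoder path above. It assumes the bundled HiFi-GAN checkpoint has been unzipped into hifigan/ and that the IEMOCAP configs select HiFi-GAN; the mel input is random noise, used only to show shapes and dtypes.

import torch
import yaml
from utils.model import get_vocoder, vocoder_infer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = yaml.load(open("config/IEMOCAP/model.yaml", "r"), Loader=yaml.FullLoader)
preprocess_config = yaml.load(open("config/IEMOCAP/preprocess.yaml", "r"), Loader=yaml.FullLoader)

vocoder = get_vocoder(model_config, device)
mels = torch.randn(1, 80, 200).to(device)      # (batch, n_mels, frames), dummy values
wavs = vocoder_infer(mels, vocoder, model_config, preprocess_config)
print(len(wavs), wavs[0].dtype)                # 1 int16
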
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import matplotlib
8 | from scipy.io import wavfile
9 | from matplotlib import pyplot as plt
10 |
11 |
12 | matplotlib.use("Agg")
13 |
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 |
18 | def to_device(data, device):
19 | if len(data) == 12:
20 | (
21 | ids,
22 | raw_texts,
23 | speakers,
24 | texts,
25 | src_lens,
26 | max_src_len,
27 | mels,
28 | mel_lens,
29 | max_mel_len,
30 | pitches,
31 | energies,
32 | durations,
33 | ) = data
34 |
35 | speakers = torch.from_numpy(speakers).long().to(device)
36 | texts = torch.from_numpy(texts).long().to(device)
37 | src_lens = torch.from_numpy(src_lens).to(device)
38 | mels = torch.from_numpy(mels).float().to(device)
39 | mel_lens = torch.from_numpy(mel_lens).to(device)
40 | pitches = torch.from_numpy(pitches).float().to(device)
41 | energies = torch.from_numpy(energies).to(device)
42 | durations = torch.from_numpy(durations).long().to(device)
43 |
44 | return (
45 | ids,
46 | raw_texts,
47 | speakers,
48 | texts,
49 | src_lens,
50 | max_src_len,
51 | mels,
52 | mel_lens,
53 | max_mel_len,
54 | pitches,
55 | energies,
56 | durations,
57 | )
58 |
59 | if len(data) == 15:
60 | (
61 | ids,
62 | raw_texts,
63 | speakers,
64 | emotions,
65 | arousals,
66 | valences,
67 | texts,
68 | src_lens,
69 | max_src_len,
70 | mels,
71 | mel_lens,
72 | max_mel_len,
73 | pitches,
74 | energies,
75 | durations,
76 | ) = data
77 |
78 | speakers = torch.from_numpy(speakers).long().to(device)
79 | emotions = torch.from_numpy(emotions).long().to(device)
80 | arousals = torch.from_numpy(arousals).long().to(device)
81 | valences = torch.from_numpy(valences).long().to(device)
82 | texts = torch.from_numpy(texts).long().to(device)
83 | src_lens = torch.from_numpy(src_lens).to(device)
84 | mels = torch.from_numpy(mels).float().to(device)
85 | mel_lens = torch.from_numpy(mel_lens).to(device)
86 | pitches = torch.from_numpy(pitches).float().to(device)
87 | energies = torch.from_numpy(energies).to(device)
88 | durations = torch.from_numpy(durations).long().to(device)
89 |
90 | return (
91 | ids,
92 | raw_texts,
93 | speakers,
94 | emotions,
95 | arousals,
96 | valences,
97 | texts,
98 | src_lens,
99 | max_src_len,
100 | mels,
101 | mel_lens,
102 | max_mel_len,
103 | pitches,
104 | energies,
105 | durations,
106 | )
107 |
108 | if len(data) == 6:
109 | (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data
110 |
111 | speakers = torch.from_numpy(speakers).long().to(device)
112 | texts = torch.from_numpy(texts).long().to(device)
113 | src_lens = torch.from_numpy(src_lens).to(device)
114 |
115 | return (ids, raw_texts, speakers, texts, src_lens, max_src_len)
116 |
117 | if len(data) == 9:
118 | (ids, raw_texts, speakers, emotions, arousals, valences, texts, src_lens, max_src_len) = data
119 |
120 | speakers = torch.from_numpy(speakers).long().to(device)
121 | emotions = torch.from_numpy(emotions).long().to(device)
122 | arousals = torch.from_numpy(arousals).long().to(device)
123 | valences = torch.from_numpy(valences).long().to(device)
124 | texts = torch.from_numpy(texts).long().to(device)
125 | src_lens = torch.from_numpy(src_lens).to(device)
126 |
127 | return (ids, raw_texts, speakers, emotions, arousals, valences, texts, src_lens, max_src_len)
128 |
129 |
130 | def log(
131 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""
132 | ):
133 | if losses is not None:
134 | logger.add_scalar("Loss/total_loss", losses[0], step)
135 | logger.add_scalar("Loss/mel_loss", losses[1], step)
136 | logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
137 | logger.add_scalar("Loss/pitch_loss", losses[3], step)
138 | logger.add_scalar("Loss/energy_loss", losses[4], step)
139 | logger.add_scalar("Loss/duration_loss", losses[5], step)
140 |
141 | if fig is not None:
142 | logger.add_figure(tag, fig)
143 |
144 | if audio is not None:
145 | logger.add_audio(
146 | tag,
147 | audio / max(abs(audio)),
148 | sample_rate=sampling_rate,
149 | )
150 |
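# Usage sketch for log(), assuming a torch.utils.tensorboard SummaryWriter is
# passed as the logger and losses follows the six-element order read above:
#   writer = SummaryWriter("output/log/example")   # hypothetical log directory
#   log(writer, step=1000, losses=losses)
#   log(writer, fig=fig, tag="Validation/step_1000")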
151 |
152 | def get_mask_from_lengths(lengths, max_len=None):
153 | batch_size = lengths.shape[0]
154 | if max_len is None:
155 | max_len = torch.max(lengths).item()
156 |
157 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
158 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
159 |
160 | return mask
161 |
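# Usage sketch: get_mask_from_lengths(torch.tensor([3, 5]).to(device)) returns a
# (2, 5) boolean mask where True marks padded positions, e.g. the first row is
# [False, False, False, True, True].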
162 |
163 | def expand(values, durations):
164 | out = list()
165 | for value, d in zip(values, durations):
166 | out += [value] * max(0, int(d))
167 | return np.array(out)
168 |
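# Usage sketch: expand(np.array([1.0, 2.0]), np.array([2, 3])) returns
# array([1., 1., 2., 2., 2.]), i.e. phoneme-level values repeated per frame
# according to the durations.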
169 |
170 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config):
171 |
172 | basename = targets[0][0]
173 | src_len = predictions[8][0].item()
174 | mel_len = predictions[9][0].item()
175 | mel_target = targets[9][0, :mel_len].detach().transpose(0, 1)
176 | mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1)
177 | duration = targets[14][0, :src_len].detach().cpu().numpy()
178 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
179 | pitch = targets[12][0, :src_len].detach().cpu().numpy()
180 | pitch = expand(pitch, duration)
181 | else:
182 | pitch = targets[12][0, :mel_len].detach().cpu().numpy()
183 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
184 | energy = targets[13][0, :src_len].detach().cpu().numpy()
185 | energy = expand(energy, duration)
186 | else:
187 | energy = targets[13][0, :mel_len].detach().cpu().numpy()
188 |
189 | with open(
190 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
191 | ) as f:
192 | stats = json.load(f)
193 | stats = stats["pitch"] + stats["energy"][:2]
194 |
195 | fig = plot_mel(
196 | [
197 | (mel_prediction.cpu().numpy(), pitch, energy),
198 | (mel_target.cpu().numpy(), pitch, energy),
199 | ],
200 | stats,
201 | ["Synthetized Spectrogram", "Ground-Truth Spectrogram"],
202 | )
203 |
204 | if vocoder is not None:
205 | from .model import vocoder_infer
206 |
207 | wav_reconstruction = vocoder_infer(
208 | mel_target.unsqueeze(0),
209 | vocoder,
210 | model_config,
211 | preprocess_config,
212 | )[0]
213 | wav_prediction = vocoder_infer(
214 | mel_prediction.unsqueeze(0),
215 | vocoder,
216 | model_config,
217 | preprocess_config,
218 | )[0]
219 | else:
220 | wav_reconstruction = wav_prediction = None
221 |
222 | return fig, wav_reconstruction, wav_prediction, basename
223 |
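# Note on the integer indices used in synth_one_sample / synth_samples, assuming
# the 15-element target layout from to_device() and the output tuple order of
# FastSpeech2.forward() in model/fastspeech2.py:
#   targets:     9 = mels, 12 = pitches, 13 = energies, 14 = durations
#   predictions: 1 = postnet mel, 2 = pitch, 3 = energy, 5 = rounded duration,
#                8 = src_lens, 9 = mel_lens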
224 |
225 | def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path, tag=None):
226 |
227 | basenames = targets[0]
228 | for i in range(len(predictions[0])):
229 | basename = basenames[i]
230 | src_len = predictions[8][i].item()
231 | mel_len = predictions[9][i].item()
232 | mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
233 | duration = predictions[5][i, :src_len].detach().cpu().numpy()
234 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
235 | pitch = predictions[2][i, :src_len].detach().cpu().numpy()
236 | pitch = expand(pitch, duration)
237 | else:
238 | pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
239 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
240 | energy = predictions[3][i, :src_len].detach().cpu().numpy()
241 | energy = expand(energy, duration)
242 | else:
243 | energy = predictions[3][i, :mel_len].detach().cpu().numpy()
244 |
245 | with open(
246 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
247 | ) as f:
248 | stats = json.load(f)
249 | stats = stats["pitch"] + stats["energy"][:2]
250 |
251 | fig = plot_mel(
252 | [
253 | (mel_prediction.cpu().numpy(), pitch, energy),
254 | ],
255 | stats,
256 | ["Synthetized Spectrogram"],
257 | )
258 | plt.savefig(os.path.join(path, "{}{}.png".format(basename, f"_{tag}" if tag is not None else "")))
259 | plt.close()
260 |
261 | from .model import vocoder_infer
262 |
263 | mel_predictions = predictions[1].transpose(1, 2)
264 | lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
265 | wav_predictions = vocoder_infer(
266 | mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
267 | )
268 |
269 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
270 | for wav, basename in zip(wav_predictions, basenames):
271 | wavfile.write(os.path.join(path, "{}{}.wav".format(basename, f"_{tag}" if tag is not None else "")), sampling_rate, wav)
272 |
273 |
274 | def plot_mel(data, stats, titles):
275 | fig, axes = plt.subplots(len(data), 1, squeeze=False)
276 | if titles is None:
277 | titles = [None for i in range(len(data))]
278 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
279 | pitch_min = pitch_min * pitch_std + pitch_mean
280 | pitch_max = pitch_max * pitch_std + pitch_mean
281 |
282 | def add_axis(fig, old_ax):
283 | ax = fig.add_axes(old_ax.get_position(), anchor="W")
284 | ax.set_facecolor("None")
285 | return ax
286 |
287 | for i in range(len(data)):
288 | mel, pitch, energy = data[i]
289 | pitch = pitch * pitch_std + pitch_mean
290 | axes[i][0].imshow(mel, origin="lower")
291 | axes[i][0].set_aspect(2.5, adjustable="box")
292 | axes[i][0].set_ylim(0, mel.shape[0])
293 | axes[i][0].set_title(titles[i], fontsize="medium")
294 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
295 | axes[i][0].set_anchor("W")
296 |
297 | ax1 = add_axis(fig, axes[i][0])
298 | ax1.plot(pitch, color="tomato")
299 | ax1.set_xlim(0, mel.shape[1])
300 | ax1.set_ylim(0, pitch_max)
301 | ax1.set_ylabel("F0", color="tomato")
302 | ax1.tick_params(
303 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
304 | )
305 |
306 | ax2 = add_axis(fig, axes[i][0])
307 | ax2.plot(energy, color="darkviolet")
308 | ax2.set_xlim(0, mel.shape[1])
309 | ax2.set_ylim(energy_min, energy_max)
310 | ax2.set_ylabel("Energy", color="darkviolet")
311 | ax2.yaxis.set_label_position("right")
312 | ax2.tick_params(
313 | labelsize="x-small",
314 | colors="darkviolet",
315 | bottom=False,
316 | labelbottom=False,
317 | left=False,
318 | labelleft=False,
319 | right=True,
320 | labelright=True,
321 | )
322 |
323 | return fig
324 |
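# Note: plot_mel() expects stats = [pitch_min, pitch_max, pitch_mean, pitch_std,
# energy_min, energy_max]; pitch statistics are de-normalized with mean/std
# before plotting, while the energy bounds are used as-is for the right-hand axis.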
325 |
326 | def pad_1D(inputs, PAD=0):
327 | def pad_data(x, length, PAD):
328 | x_padded = np.pad(
329 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
330 | )
331 | return x_padded
332 |
333 | max_len = max((len(x) for x in inputs))
334 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
335 |
336 | return padded
337 |
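# Usage sketch: pad_1D([np.array([1, 2, 3]), np.array([4])]) returns
# array([[1, 2, 3], [4, 0, 0]]) -- 1D sequences right-padded with PAD to the
# longest length in the batch.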
338 |
339 | def pad_2D(inputs, maxlen=None):
340 | def pad(x, max_len):
341 | PAD = 0
342 | if np.shape(x)[0] > max_len:
343 |             raise ValueError("sequence is longer than max_len")
344 |
345 | s = np.shape(x)[1]
346 | x_padded = np.pad(
347 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
348 | )
349 | return x_padded[:, :s]
350 |
351 | if maxlen:
352 | output = np.stack([pad(x, maxlen) for x in inputs])
353 | else:
354 | max_len = max(np.shape(x)[0] for x in inputs)
355 | output = np.stack([pad(x, max_len) for x in inputs])
356 |
357 | return output
358 |
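# Usage sketch: pad_2D pads (T_i, n_mels) arrays along the time axis to a common
# length; e.g. inputs of shape (100, 80) and (80, 80) stack into (2, 100, 80).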
359 |
360 | def pad(input_ele, mel_max_length=None):
361 | if mel_max_length:
362 | max_len = mel_max_length
363 | else:
364 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
365 |
366 | out_list = list()
367 | for i, batch in enumerate(input_ele):
368 | if len(batch.shape) == 1:
369 | one_batch_padded = F.pad(
370 | batch, (0, max_len - batch.size(0)), "constant", 0.0
371 | )
372 | elif len(batch.shape) == 2:
373 | one_batch_padded = F.pad(
374 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
375 | )
376 | out_list.append(one_batch_padded)
377 | out_padded = torch.stack(out_list)
378 | return out_padded
379 |
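# Usage sketch: pad() is the torch counterpart of pad_1D / pad_2D, e.g.
# pad([torch.ones(3, 4), torch.ones(5, 4)]) returns a (2, 5, 4) tensor with the
# first element zero-padded along its first dimension.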
--------------------------------------------------------------------------------