├── .gitignore
├── LICENSE
├── README.md
├── audio
│   ├── __init__.py
│   ├── audio_processing.py
│   ├── stft.py
│   └── tools.py
├── config
│   └── LibriTTS
│       ├── model.yaml
│       ├── preprocess.yaml
│       └── train.yaml
├── dataset.py
├── demo
│   └── LibriTTS
│       ├── references
│       │   ├── 0_19_198_000000_000000.lab
│       │   ├── 0_19_198_000000_000000.wav
│       │   ├── 1_26_495_000004_000000.lab
│       │   ├── 1_26_495_000004_000000.wav
│       │   ├── 2_27_123349_000001_000000.lab
│       │   ├── 2_27_123349_000001_000000.wav
│       │   ├── 3_32_4137_000005_000001.lab
│       │   ├── 3_32_4137_000005_000001.wav
│       │   ├── 4_39_121916_000015_000005.lab
│       │   ├── 4_39_121916_000015_000005.wav
│       │   ├── 5_40_222_000001_000000.lab
│       │   └── 5_40_222_000001_000000.wav
│       └── results
│           ├── 0_the two children therefore got up, dressed themselves quickly, and went away..png
│           ├── 0_the two children therefore got up, dressed themselves quickly, and went away..wav
│           ├── 1_the two children therefore got up, dressed themselves quickly, and went away..png
│           ├── 1_the two children therefore got up, dressed themselves quickly, and went away..wav
│           ├── 2_the two children therefore got up, dressed themselves quickly, and went away..png
│           ├── 2_the two children therefore got up, dressed themselves quickly, and went away..wav
│           ├── 3_the two children therefore got up, dressed themselves quickly, and went away..png
│           ├── 3_the two children therefore got up, dressed themselves quickly, and went away..wav
│           ├── 4_the two children therefore got up, dressed themselves quickly, and went away..png
│           ├── 4_the two children therefore got up, dressed themselves quickly, and went away..wav
│           ├── 5_the two children therefore got up, dressed themselves quickly, and went away..png
│           └── 5_the two children therefore got up, dressed themselves quickly, and went away..wav
├── evaluate.py
├── filelist_filtering.py
├── hifigan
│   ├── LICENSE
│   ├── __init__.py
│   ├── config.json
│   ├── generator_LJSpeech.pth.tar.zip
│   ├── generator_universal.pth.tar.zip
│   └── models.py
├── img
│   ├── model_1.png
│   ├── model_2.png
│   ├── tensorboard_audio.png
│   ├── tensorboard_loss.png
│   └── tensorboard_spec.png
├── lexicon
│   ├── librispeech-lexicon.txt
│   └── pinyin-lexicon-r.txt
├── model
│   ├── StyleSpeech.py
│   ├── __init__.py
│   ├── blocks.py
│   ├── loss.py
│   ├── modules.py
│   └── optimizer.py
├── prepare_align.py
├── preprocess.py
├── preprocessed_data
│   └── LibriTTS
│       ├── speakers.json
│       ├── stats.json
│       ├── train.txt
│       ├── train_filtered.txt
│       └── val.txt
├── preprocessor
│   ├── aishell3.py
│   ├── libritts.py
│   ├── ljspeech.py
│   └── preprocessor.py
├── requirements.txt
├── synthesize.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   ├── pinyin.py
│   └── symbols.py
├── train.py
└── utils
    ├── model.py
    └── tools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | __pycache__
107 | .vscode
108 | .DS_Store
109 |
110 | # MFA
111 | montreal-forced-aligner/
112 |
113 | # data, checkpoint, and models
114 | raw_data/
115 | output/
116 | *.npy
117 | TextGrid/
118 | hifigan/*.pth.tar
119 | *.out
120 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Keon Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StyleSpeech - PyTorch Implementation
2 |
3 | PyTorch Implementation of [Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation](https://arxiv.org/abs/2106.03153).
4 |
5 | <p align="center">
6 |     <img src="img/model_1.png" width="80%"/>
7 | </p>
8 | 
9 | <p align="center">
10 |     <img src="img/model_2.png" width="80%"/>
11 | </p>
12 | 
13 | # Branch
14 | - [x] StyleSpeech (`naive` branch)
15 | - [x] Meta-StyleSpeech (`main` branch)
16 |
17 | # Quickstart
18 |
19 | ## Dependencies
20 | You can install the Python dependencies with
21 | ```
22 | pip3 install -r requirements.txt
23 | ```
24 |
25 | ## Inference
26 |
27 | You have to download [pretrained models](https://drive.google.com/drive/folders/1fQmu1v7fRgfM-TwxAJ96UUgnl79f1FHt?usp=sharing) and put them in ``output/ckpt/LibriTTS_meta_learner/``.
28 |
29 | For English multi-speaker TTS, run
30 | ```
31 | python3 synthesize.py --text "YOUR_DESIRED_TEXT" --ref_audio path/to/reference_audio.wav --restore_step 200000 --mode single -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml
32 | ```
33 | The generated utterances will be put in ``output/result/``. Your synthesized speech will have `ref_audio`'s style.
34 |
35 |
36 | ## Batch Inference
37 | Batch inference is also supported; try
38 |
39 | ```
40 | python3 synthesize.py --source preprocessed_data/LibriTTS/val.txt --restore_step 200000 --mode batch -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml
41 | ```
42 | to synthesize all utterances in ``preprocessed_data/LibriTTS/val.txt``. This can be viewed as reconstructing the validation set, where each utterance serves as its own style reference.
43 |
44 | ## Controllability
45 | The pitch/volume/speaking rate of the synthesized utterances can be controlled by specifying the desired pitch/energy/duration ratios.
46 | For example, the following shortens the phoneme durations by 20% (faster speech) and reduces the energy (volume) by 20%:
47 |
48 | ```
49 | python3 synthesize.py --text "YOUR_DESIRED_TEXT" --restore_step 200000 --mode single -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml --duration_control 0.8 --energy_control 0.8
50 | ```
51 | Note that this controllability comes from FastSpeech2 and is not a primary focus of StyleSpeech. Please refer to [STYLER](https://arxiv.org/abs/2103.09474) [[demo](https://keonlee9420.github.io/STYLER-Demo/), [code](https://github.com/keonlee9420/STYLER)] for the controllability of each style factor.
52 |
53 | # Training
54 |
55 | ## Datasets
56 |
57 | The supported datasets are
58 |
59 | - [LibriTTS](https://research.google/tools/datasets/libri-tts/): a multi-speaker English dataset containing 585 hours of speech by 2456 speakers.
60 | - (more datasets will be added)
61 |
62 | ## Preprocessing
63 |
64 | Run
65 | ```
66 | python3 prepare_align.py config/LibriTTS/preprocess.yaml
67 | ```
68 | to prepare the wav and transcript files for alignment and preprocessing.
69 |
70 | For the forced alignment, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/) (MFA) is used to obtain the alignments between the utterances and the phoneme sequences.
71 | Pre-extracted alignments for the datasets are provided [here](https://drive.google.com/drive/folders/1fizpyOiQ1lG2UDaMlXnT3Ll4_j6Xwg7K?usp=sharing).
72 | You have to unzip the files into `preprocessed_data/LibriTTS/TextGrid/`. Alternatively, you can [run the aligner by yourself](https://montreal-forced-aligner.readthedocs.io/en/latest/user_guide/workflows/index.html).
73 |
74 | After that, run the preprocessing script by
75 | ```
76 | python3 preprocess.py config/LibriTTS/preprocess.yaml
77 | ```
78 |
79 | ## Training
80 |
81 | Train your model with
82 | ```
83 | python3 train.py -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml
84 | ```
85 | As described in the paper, the script first pre-trains the naive model for `meta_learning_warmup` steps and then meta-trains the model for the remaining steps via episodic training.
86 |
87 | # TensorBoard
88 |
89 | Use
90 | ```
91 | tensorboard --logdir output/log/LibriTTS
92 | ```
93 |
94 | to serve TensorBoard on your localhost.
95 | The loss curves, synthesized mel-spectrograms, and audio samples are shown.
96 |
97 | ![](./img/tensorboard_loss.png)
98 | ![](./img/tensorboard_spec.png)
99 | ![](./img/tensorboard_audio.png)
100 |
101 | # Implementation Issues
102 |
103 | 1. Use a `22050 Hz` sampling rate instead of `16 kHz`.
104 | 2. Add one fully connected layer at the beginning of the Mel-Style Encoder to project the input mel-spectrogram from `80` to `128` channels (illustrated in the sketch after this README).
105 | 3. The model size, including the meta-learner, is `28.197M` parameters.
106 | 4. Use a maximum batch size of `16` for training instead of `48` or `20`, mainly due to the limited memory of a single **24GiB TITAN-RTX**. To fit this, filter out data longer than `max_seq_len` with the following script:
107 | ```
108 | python3 filelist_filtering.py -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml
109 | ```
110 | This will generate `train_filtered.txt` in the same location as `train.txt`.
111 | 5. Since the total batch size is decreased, the number of training steps is doubled compared to the original paper.
112 | 6. Use **HiFi-GAN** instead of **MelGAN** for vocoding.
113 |
114 | # Citation
115 |
116 | ```
117 | @misc{lee2021stylespeech,
118 | author = {Lee, Keon},
119 | title = {StyleSpeech},
120 | year = {2021},
121 | publisher = {GitHub},
122 | journal = {GitHub repository},
123 | howpublished = {\url{https://github.com/keonlee9420/StyleSpeech}}
124 | }
125 | ```
126 |
127 | # References
128 | - [Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation](https://arxiv.org/abs/2106.03153)
129 | - [A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/abs/1812.04948)
130 | - [Matching Networks for One Shot Learning](https://arxiv.org/abs/1606.04080)
131 | - [Prototypical Networks for Few-shot Learning](https://arxiv.org/pdf/1703.05175v2.pdf)
132 | - [TADAM: Task dependent adaptive metric for improved few-shot learning](https://arxiv.org/abs/1805.10123)
133 | - [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2)
134 |
--------------------------------------------------------------------------------
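Regarding item 2 under "Implementation Issues" in the README above, here is a minimal sketch of the extra fully connected layer that maps the 80-bin mel input to the 128-dimensional Mel-Style Encoder hidden size. The values come from `config/LibriTTS/preprocess.yaml` and `config/LibriTTS/model.yaml`; the layer and tensor names below are illustrative, not the repository's actual module.

```
import torch
import torch.nn as nn

n_mel_channels = 80   # preprocess.yaml: preprocessing.mel.n_mel_channels
encoder_hidden = 128  # model.yaml: melencoder.encoder_hidden

# Hypothetical stand-in for the first layer of the Mel-Style Encoder.
fc_in = nn.Linear(n_mel_channels, encoder_hidden)

mel = torch.randn(4, 200, n_mel_channels)  # (batch, frames, mel bins)
hidden = fc_in(mel)                        # (batch, frames, 128)
```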
/audio/__init__.py:
--------------------------------------------------------------------------------
1 | import audio.tools
2 | import audio.stft
3 | import audio.audio_processing
4 |
--------------------------------------------------------------------------------
/audio/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import librosa.util as librosa_util
4 | from scipy.signal import get_window
5 |
6 |
7 | def window_sumsquare(
8 | window,
9 | n_frames,
10 | hop_length,
11 | win_length,
12 | n_fft,
13 | dtype=np.float32,
14 | norm=None,
15 | ):
16 | """
17 | # from librosa 0.6
18 | Compute the sum-square envelope of a window function at a given hop length.
19 |
20 | This is used to estimate modulation effects induced by windowing
21 | observations in short-time fourier transforms.
22 |
23 | Parameters
24 | ----------
25 | window : string, tuple, number, callable, or list-like
26 | Window specification, as in `get_window`
27 |
28 | n_frames : int > 0
29 | The number of analysis frames
30 |
31 | hop_length : int > 0
32 | The number of samples to advance between frames
33 |
34 | win_length : [optional]
35 | The length of the window function. By default, this matches `n_fft`.
36 |
37 | n_fft : int > 0
38 | The length of each analysis frame.
39 |
40 | dtype : np.dtype
41 | The data type of the output
42 |
43 | Returns
44 | -------
45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46 | The sum-squared envelope of the window function
47 | """
48 | if win_length is None:
49 | win_length = n_fft
50 |
51 | n = n_fft + hop_length * (n_frames - 1)
52 | x = np.zeros(n, dtype=dtype)
53 |
54 | # Compute the squared window at the desired length
55 | win_sq = get_window(window, win_length, fftbins=True)
56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57 | win_sq = librosa_util.pad_center(win_sq, n_fft)
58 |
59 | # Fill the envelope
60 | for i in range(n_frames):
61 | sample = i * hop_length
62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63 | return x
64 |
65 |
66 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
67 | """
68 | PARAMS
69 | ------
70 | magnitudes: spectrogram magnitudes
71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72 | """
73 |
74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75 | angles = angles.astype(np.float32)
76 | angles = torch.autograd.Variable(torch.from_numpy(angles))
77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78 |
79 | for i in range(n_iters):
80 | _, angles = stft_fn.transform(signal)
81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82 | return signal
83 |
84 |
85 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
86 | """
87 | PARAMS
88 | ------
89 | C: compression factor
90 | """
91 | return torch.log(torch.clamp(x, min=clip_val) * C)
92 |
93 |
94 | def dynamic_range_decompression(x, C=1):
95 | """
96 | PARAMS
97 | ------
98 | C: compression factor used to compress
99 | """
100 | return torch.exp(x) / C
101 |
--------------------------------------------------------------------------------
/audio/stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy.signal import get_window
5 | from librosa.util import pad_center, tiny
6 | from librosa.filters import mel as librosa_mel_fn
7 |
8 | from audio.audio_processing import (
9 | dynamic_range_compression,
10 | dynamic_range_decompression,
11 | window_sumsquare,
12 | )
13 |
14 |
15 | class STFT(torch.nn.Module):
16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17 |
18 | def __init__(self, filter_length, hop_length, win_length, window="hann"):
19 | super(STFT, self).__init__()
20 | self.filter_length = filter_length
21 | self.hop_length = hop_length
22 | self.win_length = win_length
23 | self.window = window
24 | self.forward_transform = None
25 | scale = self.filter_length / self.hop_length
26 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
27 |
28 | cutoff = int((self.filter_length / 2 + 1))
29 | fourier_basis = np.vstack(
30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31 | )
32 |
33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34 | inverse_basis = torch.FloatTensor(
35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36 | )
37 |
38 | if window is not None:
39 | assert filter_length >= win_length
40 | # get window and zero center pad it to filter_length
41 | fft_window = get_window(window, win_length, fftbins=True)
42 | fft_window = pad_center(fft_window, filter_length)
43 | fft_window = torch.from_numpy(fft_window).float()
44 |
45 | # window the bases
46 | forward_basis *= fft_window
47 | inverse_basis *= fft_window
48 |
49 | self.register_buffer("forward_basis", forward_basis.float())
50 | self.register_buffer("inverse_basis", inverse_basis.float())
51 |
52 | def transform(self, input_data):
53 | num_batches = input_data.size(0)
54 | num_samples = input_data.size(1)
55 |
56 | self.num_samples = num_samples
57 |
58 | # similar to librosa, reflect-pad the input
59 | input_data = input_data.view(num_batches, 1, num_samples)
60 | input_data = F.pad(
61 | input_data.unsqueeze(1),
62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63 | mode="reflect",
64 | )
65 | input_data = input_data.squeeze(1)
66 |
67 | forward_transform = F.conv1d(
68 | input_data.cuda(),
69 | torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(),
70 | stride=self.hop_length,
71 | padding=0,
72 | ).cpu()
73 |
74 | cutoff = int((self.filter_length / 2) + 1)
75 | real_part = forward_transform[:, :cutoff, :]
76 | imag_part = forward_transform[:, cutoff:, :]
77 |
78 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80 |
81 | return magnitude, phase
82 |
83 | def inverse(self, magnitude, phase):
84 | recombine_magnitude_phase = torch.cat(
85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86 | )
87 |
88 | inverse_transform = F.conv_transpose1d(
89 | recombine_magnitude_phase,
90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91 | stride=self.hop_length,
92 | padding=0,
93 | )
94 |
95 | if self.window is not None:
96 | window_sum = window_sumsquare(
97 | self.window,
98 | magnitude.size(-1),
99 | hop_length=self.hop_length,
100 | win_length=self.win_length,
101 | n_fft=self.filter_length,
102 | dtype=np.float32,
103 | )
104 | # remove modulation effects
105 | approx_nonzero_indices = torch.from_numpy(
106 | np.where(window_sum > tiny(window_sum))[0]
107 | )
108 | window_sum = torch.autograd.Variable(
109 | torch.from_numpy(window_sum), requires_grad=False
110 | )
111 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113 | approx_nonzero_indices
114 | ]
115 |
116 | # scale by hop ratio
117 | inverse_transform *= float(self.filter_length) / self.hop_length
118 |
119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121 |
122 | return inverse_transform
123 |
124 | def forward(self, input_data):
125 | self.magnitude, self.phase = self.transform(input_data)
126 | reconstruction = self.inverse(self.magnitude, self.phase)
127 | return reconstruction
128 |
129 |
130 | class TacotronSTFT(torch.nn.Module):
131 | def __init__(
132 | self,
133 | filter_length,
134 | hop_length,
135 | win_length,
136 | n_mel_channels,
137 | sampling_rate,
138 | mel_fmin,
139 | mel_fmax,
140 | ):
141 | super(TacotronSTFT, self).__init__()
142 | self.n_mel_channels = n_mel_channels
143 | self.sampling_rate = sampling_rate
144 | self.stft_fn = STFT(filter_length, hop_length, win_length)
145 | mel_basis = librosa_mel_fn(
146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147 | )
148 | mel_basis = torch.from_numpy(mel_basis).float()
149 | self.register_buffer("mel_basis", mel_basis)
150 |
151 | def spectral_normalize(self, magnitudes):
152 | output = dynamic_range_compression(magnitudes)
153 | return output
154 |
155 | def spectral_de_normalize(self, magnitudes):
156 | output = dynamic_range_decompression(magnitudes)
157 | return output
158 |
159 | def mel_spectrogram(self, y):
160 | """Computes mel-spectrograms from a batch of waves
161 | PARAMS
162 | ------
163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164 |
165 | RETURNS
166 | -------
167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168 | """
169 | assert torch.min(y.data) >= -1
170 | assert torch.max(y.data) <= 1
171 |
172 | magnitudes, phases = self.stft_fn.transform(y)
173 | magnitudes = magnitudes.data
174 | mel_output = torch.matmul(self.mel_basis, magnitudes)
175 | mel_output = self.spectral_normalize(mel_output)
176 | energy = torch.norm(magnitudes, dim=1)
177 |
178 | return mel_output, energy
179 |
--------------------------------------------------------------------------------
/audio/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.io.wavfile import write
4 |
5 | from audio.audio_processing import griffin_lim
6 |
7 |
8 | def get_mel_from_wav(audio, _stft):
9 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
10 | audio = torch.autograd.Variable(audio, requires_grad=False)
11 | melspec, energy = _stft.mel_spectrogram(audio)
12 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
13 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
14 |
15 | return melspec, energy
16 |
17 |
18 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
19 | mel = torch.stack([mel])
20 | mel_decompress = _stft.spectral_de_normalize(mel)
21 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
22 | spec_from_mel_scaling = 1000
23 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
24 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
25 | spec_from_mel = spec_from_mel * spec_from_mel_scaling
26 |
27 | audio = griffin_lim(
28 |         torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters
29 | )
30 |
31 | audio = audio.squeeze()
32 | audio = audio.cpu().numpy()
33 | audio_path = out_filename
34 | write(audio_path, _stft.sampling_rate, audio)
35 |
--------------------------------------------------------------------------------
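A minimal usage sketch (not part of the repository) showing how `TacotronSTFT` and `get_mel_from_wav` fit together, with the STFT/mel settings taken from `config/LibriTTS/preprocess.yaml`. The wav path is hypothetical, and note that `STFT.transform` moves tensors to CUDA, so a GPU is assumed.

```
import librosa

from audio.stft import TacotronSTFT
from audio.tools import get_mel_from_wav

# Values mirror config/LibriTTS/preprocess.yaml.
_stft = TacotronSTFT(
    filter_length=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=80,
    sampling_rate=22050,
    mel_fmin=0,
    mel_fmax=8000,
)

# Hypothetical input file; librosa returns a float waveform in [-1, 1].
wav, _ = librosa.load("path/to/reference_audio.wav", sr=22050)
mel, energy = get_mel_from_wav(wav, _stft)  # mel: (80, T), energy: (T,)
```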
/config/LibriTTS/model.yaml:
--------------------------------------------------------------------------------
1 | prenet:
2 | conv_kernel_size: 3
3 | dropout: 0.1
4 |
5 | transformer:
6 | encoder_layer: 4
7 | encoder_head: 2
8 | encoder_hidden: 256
9 | decoder_layer: 4
10 | decoder_head: 2
11 | decoder_hidden: 256
12 | conv_filter_size: 1024
13 | conv_kernel_size: [9, 1]
14 | encoder_dropout: 0.1
15 | decoder_dropout: 0.1
16 |
17 | melencoder:
18 | encoder_hidden: 128
19 | spectral_layer: 2
20 | temporal_layer: 2
21 | slf_attn_layer: 1
22 | slf_attn_head: 2
23 | conv_kernel_size: 5
24 | encoder_dropout: 0.1
25 |
26 | variance_predictor:
27 | filter_size: 256
28 | kernel_size: 3
29 | dropout: 0.5
30 |
31 | variance_embedding:
32 | kernel_size: 9
33 |
34 | discriminator:
35 | mel_linear_size: 256
36 | phoneme_layer: 3
37 | phoneme_hidden: 512
38 |
39 | multi_speaker: True
40 |
41 | max_seq_len: 1000
42 |
43 | vocoder:
44 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
45 | speaker: "universal" # support 'LJSpeech', 'universal'
46 |
--------------------------------------------------------------------------------
/config/LibriTTS/preprocess.yaml:
--------------------------------------------------------------------------------
1 | dataset: "LibriTTS"
2 |
3 | path:
4 | corpus_path: "/mnt/nfs2/speech-datasets/en/LibriTTS/train-clean-100"
5 | lexicon_path: "lexicon/librispeech-lexicon.txt"
6 | raw_path: "./raw_data/LibriTTS"
7 | preprocessed_path: "./preprocessed_data/LibriTTS"
8 |
9 | preprocessing:
10 | val_size: 512
11 | text:
12 | text_cleaners: ["english_cleaners"]
13 | language: "en"
14 | audio:
15 | sampling_rate: 22050
16 | max_wav_value: 32768.0
17 | stft:
18 | filter_length: 1024
19 | hop_length: 256
20 | win_length: 1024
21 | mel:
22 | n_mel_channels: 80
23 | mel_fmin: 0
24 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder
25 | pitch:
26 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
27 | normalization: True
28 | energy:
29 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level'
30 | normalization: True
31 |
--------------------------------------------------------------------------------
/config/LibriTTS/train.yaml:
--------------------------------------------------------------------------------
1 | path:
2 | ckpt_path: "./output/ckpt/LibriTTS_meta_learner"
3 | log_path: "./output/log/LibriTTS_meta_learner"
4 | result_path: "./output/result/LibriTTS_meta_learner"
5 | optimizer:
6 | batch_size: 16
7 | betas: [0.9, 0.98]
8 | eps: 0.000000001
9 | weight_decay: 0.0
10 | grad_clip_thresh: 1.0
11 | grad_acc_step: 1
12 | warm_up_step: 4000
13 | anneal_steps: [300000, 400000, 500000]
14 | anneal_rate: 0.3
15 | lr_disc: 0.00002
16 | alpha: 10
17 | step:
18 | meta_learning_warmup: 120000
19 | total_step: 200000
20 | log_step: 100
21 | synth_step: 1000
22 | val_step: 1000
23 | save_step: 40000
24 |
--------------------------------------------------------------------------------
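A minimal sketch of how the three YAML files above are typically read into the `(preprocess_config, model_config, train_config)` tuple that `evaluate.py` unpacks, using the same loading pattern as `filelist_filtering.py`; the variable names are illustrative.

```
import yaml

preprocess_config = yaml.load(
    open("config/LibriTTS/preprocess.yaml", "r"), Loader=yaml.FullLoader
)
model_config = yaml.load(open("config/LibriTTS/model.yaml", "r"), Loader=yaml.FullLoader)
train_config = yaml.load(open("config/LibriTTS/train.yaml", "r"), Loader=yaml.FullLoader)

# Tuple order matches what evaluate.py expects.
configs = (preprocess_config, model_config, train_config)

batch_size = train_config["optimizer"]["batch_size"]                           # 16
max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"]   # 32768.0
```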
/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 | import random
5 |
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 |
9 | from text import text_to_sequence
10 | from utils.tools import pad_1D, pad_2D, expand
11 |
12 | random.seed(1234)
13 |
14 |
15 | class Dataset(Dataset):
16 | def __init__(
17 | self, filename, preprocess_config, train_config, sort=False, drop_last=False
18 | ):
19 | self.dataset_name = preprocess_config["dataset"]
20 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
21 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
22 | self.batch_size = train_config["optimizer"]["batch_size"]
23 |
24 | self.basename, self.speaker, self.text, self.raw_text, self.speaker_to_ids = self.process_meta(
25 | filename
26 | )
27 | with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
28 | self.speaker_map = json.load(f)
29 | self.sort = sort
30 | self.drop_last = drop_last
31 |
32 | def __len__(self):
33 | return len(self.text)
34 |
35 | def __getitem__(self, idx):
36 | basename = self.basename[idx]
37 | speaker = self.speaker[idx]
38 | speaker_id = self.speaker_map[speaker]
39 | raw_text = self.raw_text[idx]
40 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
41 | query_idx = random.choice(self.speaker_to_ids[speaker]) # Sample the query text
42 | raw_quary_text = self.raw_text[query_idx]
43 | query_phone = np.array(text_to_sequence(self.text[query_idx], self.cleaners))
44 | mel_path = os.path.join(
45 | self.preprocessed_path,
46 | "mel",
47 | "{}-mel-{}.npy".format(speaker, basename),
48 | )
49 | mel = np.load(mel_path)
50 | pitch_path = os.path.join(
51 | self.preprocessed_path,
52 | "pitch",
53 | "{}-pitch-{}.npy".format(speaker, basename),
54 | )
55 | pitch = np.load(pitch_path)
56 | energy_path = os.path.join(
57 | self.preprocessed_path,
58 | "energy",
59 | "{}-energy-{}.npy".format(speaker, basename),
60 | )
61 | energy = np.load(energy_path)
62 | duration_path = os.path.join(
63 | self.preprocessed_path,
64 | "duration",
65 | "{}-duration-{}.npy".format(speaker, basename),
66 | )
67 | duration = np.load(duration_path)
68 | quary_duration_path = os.path.join(
69 | self.preprocessed_path,
70 | "duration",
71 | "{}-duration-{}.npy".format(self.speaker[query_idx], self.basename[query_idx]),
72 | )
73 | quary_duration = np.load(quary_duration_path)
74 |
75 | sample = {
76 | "id": basename,
77 | "speaker": speaker_id,
78 | "text": phone,
79 | "raw_text": raw_text,
80 | "quary_text": query_phone,
81 | "raw_quary_text": raw_quary_text,
82 | "mel": mel,
83 | "pitch": pitch,
84 | "energy": energy,
85 | "duration": duration,
86 | "quary_duration": quary_duration,
87 | }
88 |
89 | return sample
90 |
91 | def process_meta(self, filename):
92 | with open(
93 | os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8"
94 | ) as f:
95 | name = []
96 | speaker = []
97 | text = []
98 | raw_text = []
99 | speaker_to_ids = dict()
100 | for i, line in enumerate(f.readlines()):
101 | n, s, t, r = line.strip("\n").split("|")
102 | name.append(n)
103 | speaker.append(s)
104 | text.append(t)
105 | raw_text.append(r)
106 | if s not in speaker_to_ids:
107 | speaker_to_ids[s] = [i]
108 | else:
109 | speaker_to_ids[s] += [i]
110 | return name, speaker, text, raw_text, speaker_to_ids
111 |
112 | def reprocess(self, data, idxs):
113 | ids = [data[idx]["id"] for idx in idxs]
114 | speakers = [data[idx]["speaker"] for idx in idxs]
115 | texts = [data[idx]["text"] for idx in idxs]
116 | raw_texts = [data[idx]["raw_text"] for idx in idxs]
117 | quary_texts = [data[idx]["quary_text"] for idx in idxs]
118 | raw_quary_texts = [data[idx]["raw_quary_text"] for idx in idxs]
119 | mels = [data[idx]["mel"] for idx in idxs]
120 | pitches = [data[idx]["pitch"] for idx in idxs]
121 | energies = [data[idx]["energy"] for idx in idxs]
122 | durations = [data[idx]["duration"] for idx in idxs]
123 | quary_durations = [data[idx]["quary_duration"] for idx in idxs]
124 |
125 | text_lens = np.array([text.shape[0] for text in texts])
126 | quary_text_lens = np.array([text.shape[0] for text in quary_texts])
127 | mel_lens = np.array([mel.shape[0] for mel in mels])
128 |
129 | speakers = np.array(speakers)
130 | texts = pad_1D(texts)
131 | quary_texts = pad_1D(quary_texts)
132 | mels = pad_2D(mels)
133 | pitches = pad_1D(pitches)
134 | energies = pad_1D(energies)
135 | durations = pad_1D(durations)
136 | quary_durations = pad_1D(quary_durations)
137 |
138 | return (
139 | ids,
140 | raw_texts,
141 | speakers,
142 | texts,
143 | text_lens,
144 | max(text_lens),
145 | mels,
146 | mel_lens,
147 | max(mel_lens),
148 | pitches,
149 | energies,
150 | durations,
151 | raw_quary_texts,
152 | quary_texts,
153 | quary_text_lens,
154 | max(quary_text_lens),
155 | quary_durations,
156 | )
157 |
158 | def collate_fn(self, data):
159 | data_size = len(data)
160 |
161 | if self.sort:
162 | len_arr = np.array([d["text"].shape[0] for d in data])
163 | idx_arr = np.argsort(-len_arr)
164 | else:
165 | idx_arr = np.arange(data_size)
166 |
167 | tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :]
168 | idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]
169 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
170 | if not self.drop_last and len(tail) > 0:
171 | idx_arr += [tail.tolist()]
172 |
173 | output = list()
174 | for idx in idx_arr:
175 | output.append(self.reprocess(data, idx))
176 |
177 | return output
178 |
179 |
180 | class BatchInferenceDataset(Dataset):
181 | def __init__(self, filepath, preprocess_config):
182 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
183 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"]
184 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"]
185 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
186 |
187 | self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
188 | filepath
189 | )
190 | with open(
191 | os.path.join(
192 | preprocess_config["path"]["preprocessed_path"], "speakers.json"
193 | )
194 | ) as f:
195 | self.speaker_map = json.load(f)
196 |
197 | def __len__(self):
198 | return len(self.text)
199 |
200 | def __getitem__(self, idx):
201 | basename = self.basename[idx]
202 | speaker = self.speaker[idx]
203 | speaker_id = self.speaker_map[speaker]
204 | raw_text = self.raw_text[idx]
205 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
206 | mel_path = os.path.join(
207 | self.preprocessed_path,
208 | "mel",
209 | "{}-mel-{}.npy".format(speaker, basename),
210 | )
211 | mel = np.load(mel_path)
212 | pitch_path = os.path.join(
213 | self.preprocessed_path,
214 | "pitch",
215 | "{}-pitch-{}.npy".format(speaker, basename),
216 | )
217 | pitch = np.load(pitch_path)
218 | energy_path = os.path.join(
219 | self.preprocessed_path,
220 | "energy",
221 | "{}-energy-{}.npy".format(speaker, basename),
222 | )
223 | energy = np.load(energy_path)
224 | duration_path = os.path.join(
225 | self.preprocessed_path,
226 | "duration",
227 | "{}-duration-{}.npy".format(speaker, basename),
228 | )
229 | duration = np.load(duration_path)
230 |
231 | return (basename, speaker_id, phone, raw_text, mel, pitch, energy, duration)
232 |
233 | def process_meta(self, filename):
234 | with open(filename, "r", encoding="utf-8") as f:
235 | name = []
236 | speaker = []
237 | text = []
238 | raw_text = []
239 | for line in f.readlines():
240 | n, s, t, r = line.strip("\n").split("|")
241 | name.append(n)
242 | speaker.append(s)
243 | text.append(t)
244 | raw_text.append(r)
245 | return name, speaker, text, raw_text
246 |
247 | def collate_fn(self, data):
248 | ids = [d[0] for d in data]
249 | speakers = np.array([d[1] for d in data])
250 | texts = [d[2] for d in data]
251 | raw_texts = [d[3] for d in data]
252 | mels = [d[4] for d in data]
253 | pitches = [d[5] for d in data]
254 | energies = [d[6] for d in data]
255 | durations = [d[7] for d in data]
256 |
257 | text_lens = np.array([text.shape[0] for text in texts])
258 | mel_lens = np.array([mel.shape[0] for mel in mels])
259 |
260 | ref_infos = list()
261 | for _, (m, p, e, d) in enumerate(zip(mels, pitches, energies, durations)):
262 | if self.pitch_feature_level == "phoneme_level":
263 | pitch = expand(p, d)
264 | else:
265 | pitch = p
266 | if self.energy_feature_level == "phoneme_level":
267 | energy = expand(e, d)
268 | else:
269 | energy = e
270 | ref_infos.append((m.T, pitch, energy))
271 |
272 | texts = pad_1D(texts)
273 | mels = pad_2D(mels)
274 |
275 | return (
276 | ids,
277 | raw_texts,
278 | speakers,
279 | texts,
280 | text_lens,
281 | max(text_lens),
282 | mels,
283 | mel_lens,
284 | max(mel_lens),
285 | ref_infos,
286 | )
287 |
--------------------------------------------------------------------------------
/demo/LibriTTS/references/0_19_198_000000_000000.lab:
--------------------------------------------------------------------------------
1 | this is a librivox recording.
--------------------------------------------------------------------------------
/demo/LibriTTS/references/0_19_198_000000_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/0_19_198_000000_000000.wav
--------------------------------------------------------------------------------
/demo/LibriTTS/references/1_26_495_000004_000000.lab:
--------------------------------------------------------------------------------
1 | by daniel defoe
--------------------------------------------------------------------------------
/demo/LibriTTS/references/1_26_495_000004_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/1_26_495_000004_000000.wav
--------------------------------------------------------------------------------
/demo/LibriTTS/references/2_27_123349_000001_000000.lab:
--------------------------------------------------------------------------------
1 | at length all differences were compromised.
--------------------------------------------------------------------------------
/demo/LibriTTS/references/2_27_123349_000001_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/2_27_123349_000001_000000.wav
--------------------------------------------------------------------------------
/demo/LibriTTS/references/3_32_4137_000005_000001.lab:
--------------------------------------------------------------------------------
1 | as many of the slaves had been brought up in richmond, and had relations residing there, the slave trader determined to leave the city early in the morning, so as not to witness any of those scenes so common where slaves are separated from their relatives and friends, when about departing for the southern market.
--------------------------------------------------------------------------------
/demo/LibriTTS/references/3_32_4137_000005_000001.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/3_32_4137_000005_000001.wav
--------------------------------------------------------------------------------
/demo/LibriTTS/references/4_39_121916_000015_000005.lab:
--------------------------------------------------------------------------------
1 | ours are all apple tarts.
--------------------------------------------------------------------------------
/demo/LibriTTS/references/4_39_121916_000015_000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/4_39_121916_000015_000005.wav
--------------------------------------------------------------------------------
/demo/LibriTTS/references/5_40_222_000001_000000.lab:
--------------------------------------------------------------------------------
1 | chapter twenty five
--------------------------------------------------------------------------------
/demo/LibriTTS/references/5_40_222_000001_000000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/references/5_40_222_000001_000000.wav
--------------------------------------------------------------------------------
/demo/LibriTTS/results/0_the two children therefore got up, dressed themselves quickly, and went away..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/0_the two children therefore got up, dressed themselves quickly, and went away..png
--------------------------------------------------------------------------------
/demo/LibriTTS/results/0_the two children therefore got up, dressed themselves quickly, and went away..wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/0_the two children therefore got up, dressed themselves quickly, and went away..wav
--------------------------------------------------------------------------------
/demo/LibriTTS/results/1_the two children therefore got up, dressed themselves quickly, and went away..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/1_the two children therefore got up, dressed themselves quickly, and went away..png
--------------------------------------------------------------------------------
/demo/LibriTTS/results/1_the two children therefore got up, dressed themselves quickly, and went away..wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/1_the two children therefore got up, dressed themselves quickly, and went away..wav
--------------------------------------------------------------------------------
/demo/LibriTTS/results/2_the two children therefore got up, dressed themselves quickly, and went away..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/2_the two children therefore got up, dressed themselves quickly, and went away..png
--------------------------------------------------------------------------------
/demo/LibriTTS/results/2_the two children therefore got up, dressed themselves quickly, and went away..wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/2_the two children therefore got up, dressed themselves quickly, and went away..wav
--------------------------------------------------------------------------------
/demo/LibriTTS/results/3_the two children therefore got up, dressed themselves quickly, and went away..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/3_the two children therefore got up, dressed themselves quickly, and went away..png
--------------------------------------------------------------------------------
/demo/LibriTTS/results/3_the two children therefore got up, dressed themselves quickly, and went away..wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/3_the two children therefore got up, dressed themselves quickly, and went away..wav
--------------------------------------------------------------------------------
/demo/LibriTTS/results/4_the two children therefore got up, dressed themselves quickly, and went away..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/4_the two children therefore got up, dressed themselves quickly, and went away..png
--------------------------------------------------------------------------------
/demo/LibriTTS/results/4_the two children therefore got up, dressed themselves quickly, and went away..wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/4_the two children therefore got up, dressed themselves quickly, and went away..wav
--------------------------------------------------------------------------------
/demo/LibriTTS/results/5_the two children therefore got up, dressed themselves quickly, and went away..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/5_the two children therefore got up, dressed themselves quickly, and went away..png
--------------------------------------------------------------------------------
/demo/LibriTTS/results/5_the two children therefore got up, dressed themselves quickly, and went away..wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/demo/LibriTTS/results/5_the two children therefore got up, dressed themselves quickly, and went away..wav
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import yaml
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 |
9 | from utils.model import get_model, get_vocoder
10 | from utils.tools import to_device, log, synth_one_sample
11 | from model import MetaStyleSpeechLossMain
12 | from dataset import Dataset
13 |
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 |
18 | def evaluate(model, step, configs, logger=None, vocoder=None, loss_len=5):
19 | preprocess_config, model_config, train_config = configs
20 |
21 | # Get dataset
22 | dataset = Dataset(
23 | "val.txt", preprocess_config, train_config, sort=False, drop_last=False
24 | )
25 | batch_size = train_config["optimizer"]["batch_size"]
26 | loader = DataLoader(
27 | dataset,
28 | batch_size=batch_size,
29 | shuffle=False,
30 | collate_fn=dataset.collate_fn,
31 | )
32 |
33 | # Get loss function
34 | Loss = MetaStyleSpeechLossMain(preprocess_config, model_config, train_config).to(device)
35 |
36 | # Evaluation
37 | loss_sums = [0 for _ in range(loss_len)]
38 | for batchs in loader:
39 | for batch in batchs:
40 | batch = to_device(batch, device)
41 | with torch.no_grad():
42 | # Forward
43 | output = (None, None, *model(*(batch[2:-5])))
44 |
45 | # Cal Loss
46 | losses = Loss(batch, output)
47 |
48 | for i in range(len(losses)):
49 | loss_sums[i] += losses[i].item() * len(batch[0])
50 |
51 | loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums]
52 |
53 | message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
54 | *([step] + [l for l in loss_means[:5]])
55 | )
56 |
57 | if logger is not None:
58 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
59 | batch,
60 | output[2:],
61 | vocoder,
62 | model_config,
63 | preprocess_config,
64 | )
65 |
66 | log(logger, step, losses=loss_means)
67 | log(
68 | logger,
69 | fig=fig,
70 | tag="Validation/step_{}_{}".format(step, tag),
71 | )
72 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
73 | log(
74 | logger,
75 | audio=wav_reconstruction,
76 | sampling_rate=sampling_rate,
77 | tag="Validation/step_{}_{}_reconstructed".format(step, tag),
78 | )
79 | log(
80 | logger,
81 | audio=wav_prediction,
82 | sampling_rate=sampling_rate,
83 | tag="Validation/step_{}_{}_synthesized".format(step, tag),
84 | )
85 |
86 | return message
87 |
--------------------------------------------------------------------------------
/filelist_filtering.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import yaml
4 | import argparse
5 |
6 | def main(preprocess_config, model_config):
7 | preprocessed_path = preprocess_config["path"]["preprocessed_path"]
8 | max_seq_len = model_config["max_seq_len"]
9 |
10 | with open(
11 | os.path.join(preprocessed_path, "train.txt"), "r", encoding="utf-8"
12 | ) as f:
13 | filtered_list = []
14 | for i, line in enumerate(f.readlines()):
15 | basename, speaker, *_ = line.strip("\n").split("|")
16 | mel_path = os.path.join(
17 | preprocessed_path,
18 | "mel",
19 | "{}-mel-{}.npy".format(speaker, basename),
20 | )
21 | mel = np.load(mel_path)
22 | if mel.shape[0] <= max_seq_len:
23 | filtered_list.append(line)
24 |
25 | # Write Filtered Filelist
26 | with open(os.path.join(preprocessed_path, "train_filtered.txt"), "w", encoding="utf-8") as f:
27 | for line in filtered_list:
28 | f.write(line)
29 |
30 | if __name__ == "__main__":
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument(
33 | "-p",
34 | "--preprocess_config",
35 | type=str,
36 | required=True,
37 | help="path to preprocess.yaml",
38 | )
39 | parser.add_argument(
40 | "-m", "--model_config", type=str, required=True, help="path to model.yaml"
41 | )
42 | args = parser.parse_args()
43 |
44 | # Read Config
45 | preprocess_config = yaml.load(
46 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader
47 | )
48 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
49 |
50 | main(preprocess_config, model_config)
--------------------------------------------------------------------------------
/hifigan/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jungil Kong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/hifigan/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import Generator
2 |
3 |
4 | class AttrDict(dict):
5 | def __init__(self, *args, **kwargs):
6 | super(AttrDict, self).__init__(*args, **kwargs)
7 | self.__dict__ = self
--------------------------------------------------------------------------------
/hifigan/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/hifigan/generator_LJSpeech.pth.tar.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/hifigan/generator_LJSpeech.pth.tar.zip
--------------------------------------------------------------------------------
/hifigan/generator_universal.pth.tar.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/hifigan/generator_universal.pth.tar.zip
--------------------------------------------------------------------------------
/hifigan/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Conv1d, ConvTranspose1d
5 | from torch.nn.utils import weight_norm, remove_weight_norm
6 |
7 | LRELU_SLOPE = 0.1
8 |
9 |
10 | def init_weights(m, mean=0.0, std=0.01):
11 | classname = m.__class__.__name__
12 | if classname.find("Conv") != -1:
13 | m.weight.data.normal_(mean, std)
14 |
15 |
16 | def get_padding(kernel_size, dilation=1):
17 | return int((kernel_size * dilation - dilation) / 2)
18 |
19 |
20 | class ResBlock(torch.nn.Module):
21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22 | super(ResBlock, self).__init__()
23 | self.h = h
24 | self.convs1 = nn.ModuleList(
25 | [
26 | weight_norm(
27 | Conv1d(
28 | channels,
29 | channels,
30 | kernel_size,
31 | 1,
32 | dilation=dilation[0],
33 | padding=get_padding(kernel_size, dilation[0]),
34 | )
35 | ),
36 | weight_norm(
37 | Conv1d(
38 | channels,
39 | channels,
40 | kernel_size,
41 | 1,
42 | dilation=dilation[1],
43 | padding=get_padding(kernel_size, dilation[1]),
44 | )
45 | ),
46 | weight_norm(
47 | Conv1d(
48 | channels,
49 | channels,
50 | kernel_size,
51 | 1,
52 | dilation=dilation[2],
53 | padding=get_padding(kernel_size, dilation[2]),
54 | )
55 | ),
56 | ]
57 | )
58 | self.convs1.apply(init_weights)
59 |
60 | self.convs2 = nn.ModuleList(
61 | [
62 | weight_norm(
63 | Conv1d(
64 | channels,
65 | channels,
66 | kernel_size,
67 | 1,
68 | dilation=1,
69 | padding=get_padding(kernel_size, 1),
70 | )
71 | ),
72 | weight_norm(
73 | Conv1d(
74 | channels,
75 | channels,
76 | kernel_size,
77 | 1,
78 | dilation=1,
79 | padding=get_padding(kernel_size, 1),
80 | )
81 | ),
82 | weight_norm(
83 | Conv1d(
84 | channels,
85 | channels,
86 | kernel_size,
87 | 1,
88 | dilation=1,
89 | padding=get_padding(kernel_size, 1),
90 | )
91 | ),
92 | ]
93 | )
94 | self.convs2.apply(init_weights)
95 |
96 | def forward(self, x):
97 | for c1, c2 in zip(self.convs1, self.convs2):
98 | xt = F.leaky_relu(x, LRELU_SLOPE)
99 | xt = c1(xt)
100 | xt = F.leaky_relu(xt, LRELU_SLOPE)
101 | xt = c2(xt)
102 | x = xt + x
103 | return x
104 |
105 | def remove_weight_norm(self):
106 | for l in self.convs1:
107 | remove_weight_norm(l)
108 | for l in self.convs2:
109 | remove_weight_norm(l)
110 |
111 |
112 | class Generator(torch.nn.Module):
113 | def __init__(self, h):
114 | super(Generator, self).__init__()
115 | self.h = h
116 | self.num_kernels = len(h.resblock_kernel_sizes)
117 | self.num_upsamples = len(h.upsample_rates)
118 | self.conv_pre = weight_norm(
119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
120 | )
121 | resblock = ResBlock
122 |
123 | self.ups = nn.ModuleList()
124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125 | self.ups.append(
126 | weight_norm(
127 | ConvTranspose1d(
128 | h.upsample_initial_channel // (2 ** i),
129 | h.upsample_initial_channel // (2 ** (i + 1)),
130 | k,
131 | u,
132 | padding=(k - u) // 2,
133 | )
134 | )
135 | )
136 |
137 | self.resblocks = nn.ModuleList()
138 | for i in range(len(self.ups)):
139 | ch = h.upsample_initial_channel // (2 ** (i + 1))
140 | for j, (k, d) in enumerate(
141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142 | ):
143 | self.resblocks.append(resblock(h, ch, k, d))
144 |
145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146 | self.ups.apply(init_weights)
147 | self.conv_post.apply(init_weights)
148 |
149 | def forward(self, x):
150 | x = self.conv_pre(x)
151 | for i in range(self.num_upsamples):
152 | x = F.leaky_relu(x, LRELU_SLOPE)
153 | x = self.ups[i](x)
154 | xs = None
155 | for j in range(self.num_kernels):
156 | if xs is None:
157 | xs = self.resblocks[i * self.num_kernels + j](x)
158 | else:
159 | xs += self.resblocks[i * self.num_kernels + j](x)
160 | x = xs / self.num_kernels
161 | x = F.leaky_relu(x)
162 | x = self.conv_post(x)
163 | x = torch.tanh(x)
164 |
165 | return x
166 |
167 | def remove_weight_norm(self):
168 | print("Removing weight norm...")
169 | for l in self.ups:
170 | remove_weight_norm(l)
171 | for l in self.resblocks:
172 | l.remove_weight_norm()
173 | remove_weight_norm(self.conv_pre)
174 | remove_weight_norm(self.conv_post)
--------------------------------------------------------------------------------
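A minimal usage sketch (not part of the repository) for the HiFi-GAN vocoder defined above: load `hifigan/config.json` into `AttrDict` and run the `Generator` on a mel-spectrogram. The checkpoint archives are not unpacked here, so weight loading is omitted and the mel input is a placeholder.

```
import json

import torch

import hifigan

with open("hifigan/config.json", "r") as f:
    config = hifigan.AttrDict(json.load(f))

vocoder = hifigan.Generator(config)
vocoder.eval()
vocoder.remove_weight_norm()

# Placeholder mel input: (batch, n_mels=80, frames); the upsample rates
# [8, 8, 2, 2] give a hop of 256 samples per frame at 22050 Hz.
mel = torch.zeros(1, 80, 100)
with torch.no_grad():
    wav = vocoder(mel)  # (1, 1, 100 * 256)
```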
/img/model_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/model_1.png
--------------------------------------------------------------------------------
/img/model_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/model_2.png
--------------------------------------------------------------------------------
/img/tensorboard_audio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/tensorboard_audio.png
--------------------------------------------------------------------------------
/img/tensorboard_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/tensorboard_loss.png
--------------------------------------------------------------------------------
/img/tensorboard_spec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/StyleSpeech/d11cb07b3e891a18c2a6d1e80fa0f18a389de446/img/tensorboard_spec.png
--------------------------------------------------------------------------------
/model/StyleSpeech.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .modules import (
9 | MelStyleEncoder,
10 | PhonemeEncoder,
11 | MelDecoder,
12 | VarianceAdaptor,
13 | PhonemeDiscriminator,
14 | StyleDiscriminator,
15 | )
16 | from utils.tools import get_mask_from_lengths
17 |
18 |
19 | class StyleSpeech(nn.Module):
20 | """ StyleSpeech """
21 |
22 | def __init__(self, preprocess_config, model_config):
23 | super(StyleSpeech, self).__init__()
24 | self.model_config = model_config
25 |
26 | self.mel_style_encoder = MelStyleEncoder(preprocess_config, model_config)
27 | self.phoneme_encoder = PhonemeEncoder(model_config)
28 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
29 | self.mel_decoder = MelDecoder(model_config)
30 | self.phoneme_linear = nn.Linear(
31 | model_config["transformer"]["encoder_hidden"],
32 | model_config["transformer"]["encoder_hidden"],
33 | )
34 | self.mel_linear = nn.Linear(
35 | model_config["transformer"]["decoder_hidden"],
36 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
37 | )
38 | self.D_t = PhonemeDiscriminator(preprocess_config, model_config)
39 | self.D_s = StyleDiscriminator(preprocess_config, model_config)
40 |
41 | with open(
42 | os.path.join(
43 | preprocess_config["path"]["preprocessed_path"], "speakers.json"
44 | ),
45 | "r",
46 | ) as f:
47 | n_speaker = len(json.load(f))
48 | self.style_prototype = nn.Embedding(
49 | n_speaker,
50 | model_config["melencoder"]["encoder_hidden"],
51 | )
52 |
53 | def G(
54 | self,
55 | style_vector,
56 | texts,
57 | src_masks,
58 | mel_masks,
59 | max_mel_len,
60 | p_targets=None,
61 | e_targets=None,
62 | d_targets=None,
63 | p_control=1.0,
64 | e_control=1.0,
65 | d_control=1.0,
66 | ):
67 | output = self.phoneme_encoder(texts, style_vector, src_masks)
68 | output = self.phoneme_linear(output)
69 |
70 | (
71 | output,
72 | p_predictions,
73 | e_predictions,
74 | log_d_predictions,
75 | d_rounded,
76 | mel_lens,
77 | mel_masks,
78 | ) = self.variance_adaptor(
79 | output,
80 | src_masks,
81 | mel_masks,
82 | max_mel_len,
83 | p_targets,
84 | e_targets,
85 | d_targets,
86 | p_control,
87 | e_control,
88 | d_control,
89 | )
90 |
91 | output, mel_masks = self.mel_decoder(output, style_vector, mel_masks)
92 | output = self.mel_linear(output)
93 |
94 | return (
95 | output,
96 | p_predictions,
97 | e_predictions,
98 | log_d_predictions,
99 | d_rounded,
100 | mel_lens,
101 | mel_masks,
102 | )
103 |
104 | def forward(
105 | self,
106 | _,
107 | texts,
108 | src_lens,
109 | max_src_len,
110 | mels,
111 | mel_lens,
112 | max_mel_len,
113 | p_targets=None,
114 | e_targets=None,
115 | d_targets=None,
116 | p_control=1.0,
117 | e_control=1.0,
118 | d_control=1.0,
119 | ):
120 | src_masks = get_mask_from_lengths(src_lens, max_src_len)
121 | mel_masks = get_mask_from_lengths(mel_lens, max_mel_len)
122 |
123 | style_vector = self.mel_style_encoder(mels, mel_masks)
124 |
125 | (
126 | output,
127 | p_predictions,
128 | e_predictions,
129 | log_d_predictions,
130 | d_rounded,
131 | mel_lens,
132 | mel_masks,
133 | ) = self.G(
134 | style_vector,
135 | texts,
136 | src_masks,
137 | mel_masks,
138 | max_mel_len,
139 | p_targets,
140 | e_targets,
141 | d_targets,
142 | p_control,
143 | e_control,
144 | d_control,
145 | )
146 |
147 | return (
148 | output,
149 | p_predictions,
150 | e_predictions,
151 | log_d_predictions,
152 | d_rounded,
153 | src_masks,
154 | mel_masks,
155 | src_lens,
156 | mel_lens,
157 | )
158 |
159 | def meta_learner_1(
160 | self,
161 | speakers,
162 | texts,
163 | src_lens,
164 | max_src_len,
165 | mels,
166 | mel_lens,
167 | max_mel_len,
168 | p_targets=None,
169 | e_targets=None,
170 | d_targets=None,
171 | raw_quary_texts=None,
172 | quary_texts=None,
173 | quary_src_lens=None,
174 | max_quary_src_len=None,
175 | quary_d_targets=None,
176 | p_control=1.0,
177 | e_control=1.0,
178 | d_control=1.0,
179 | ):
180 | src_masks = get_mask_from_lengths(src_lens, max_src_len)
181 | mel_masks = get_mask_from_lengths(mel_lens, max_mel_len)
182 |
183 | quary_mel_lens = quary_d_targets.sum(dim=-1)
184 | max_quary_mel_len = max(quary_mel_lens).item()
185 | quary_src_masks = get_mask_from_lengths(quary_src_lens, max_quary_src_len)
186 | quary_mel_masks = get_mask_from_lengths(quary_mel_lens, max_quary_mel_len)
187 |
188 | style_vector = self.mel_style_encoder(mels, mel_masks)
189 |
190 | (
191 | output,
192 | _,
193 | _,
194 | _,
195 | d_rounded_adv,
196 | mel_lens_adv,
197 | mel_masks_adv,
198 | ) = self.G(
199 | style_vector,
200 | quary_texts,
201 | quary_src_masks,
202 | quary_mel_masks,
203 | max_quary_mel_len,
204 | None,
205 | None,
206 | None,
207 | p_control,
208 | e_control,
209 | d_control,
210 | )
211 |
212 | D_s = self.D_s(self.style_prototype, speakers, output, mel_masks_adv)
213 |
214 | quary_texts = self.phoneme_encoder.src_word_emb(quary_texts)
215 | D_t = self.D_t(self.variance_adaptor.upsample, quary_texts, output, max(mel_lens_adv).item(), mel_masks_adv, d_rounded_adv)
216 |
217 | (
218 | G,
219 | p_predictions,
220 | e_predictions,
221 | log_d_predictions,
222 | d_rounded,
223 | mel_lens,
224 | mel_masks,
225 | ) = self.G(
226 | style_vector,
227 | texts,
228 | src_masks,
229 | mel_masks,
230 | max_mel_len,
231 | p_targets,
232 | e_targets,
233 | d_targets,
234 | p_control,
235 | e_control,
236 | d_control,
237 | )
238 |
239 | return (
240 | D_s,
241 | D_t,
242 | G,
243 | p_predictions,
244 | e_predictions,
245 | log_d_predictions,
246 | d_rounded,
247 | src_masks,
248 | mel_masks,
249 | src_lens,
250 | mel_lens,
251 | )
252 |
253 | def meta_learner_2(
254 | self,
255 | speakers,
256 | texts,
257 | src_lens,
258 | max_src_len,
259 | mels,
260 | mel_lens,
261 | max_mel_len,
262 | p_targets=None,
263 | e_targets=None,
264 | d_targets=None,
265 | raw_quary_texts=None,
266 | quary_texts=None,
267 | quary_src_lens=None,
268 | max_quary_src_len=None,
269 | quary_d_targets=None,
270 | p_control=1.0,
271 | e_control=1.0,
272 | d_control=1.0,
273 | ):
274 | src_masks = get_mask_from_lengths(src_lens, max_src_len)
275 | mel_masks = get_mask_from_lengths(mel_lens, max_mel_len)
276 |
277 | quary_mel_lens = quary_d_targets.sum(dim=-1)
278 | max_quary_mel_len = max(quary_mel_lens).item()
279 | quary_src_masks = get_mask_from_lengths(quary_src_lens, max_quary_src_len)
280 | quary_mel_masks = get_mask_from_lengths(quary_mel_lens, max_quary_mel_len)
281 |
282 | style_vector = self.mel_style_encoder(mels, mel_masks)
283 |
284 | (
285 | output,
286 | _,
287 | _,
288 | _,
289 | d_rounded_adv,
290 | mel_lens_adv,
291 | mel_masks_adv,
292 | ) = self.G(
293 | style_vector,
294 | quary_texts,
295 | quary_src_masks,
296 | quary_mel_masks,
297 | max_quary_mel_len,
298 | None,
299 | None,
300 | None,
301 | p_control,
302 | e_control,
303 | d_control,
304 | )
305 |
306 | texts = self.phoneme_encoder.src_word_emb(texts)
307 | D_t_s = self.D_t(self.variance_adaptor.upsample, texts, mels, max_mel_len, mel_masks, d_targets)
308 |
309 | quary_texts = self.phoneme_encoder.src_word_emb(quary_texts)
310 | D_t_q = self.D_t(self.variance_adaptor.upsample, quary_texts, output, max(mel_lens_adv).item(), mel_masks_adv, d_rounded_adv)
311 |
312 | D_s_s = self.D_s(self.style_prototype, speakers, mels, mel_masks)
313 | D_s_q = self.D_s(self.style_prototype, speakers, output, mel_masks_adv)
314 |
315 | # Get Style Logit
316 | w = style_vector.squeeze() # [B, H]
317 | style_logit = torch.matmul(w, self.style_prototype.weight.contiguous().transpose(0, 1)) # [B, K]
318 |
319 | return D_t_s, D_t_q, D_s_s, D_s_q, style_logit
320 |
--------------------------------------------------------------------------------
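A CPU smoke-test sketch for StyleSpeech.forward with dummy, teacher-forced inputs. The config and speakers.json paths follow the repo layout shown above; the tensor sizes and durations are arbitrary, and on a CUDA machine the model and all tensors should first be moved to the same device:

    import yaml
    import torch
    from model import StyleSpeech

    preprocess_config = yaml.load(open("config/LibriTTS/preprocess.yaml", "r"), Loader=yaml.FullLoader)
    model_config = yaml.load(open("config/LibriTTS/model.yaml", "r"), Loader=yaml.FullLoader)
    model = StyleSpeech(preprocess_config, model_config).eval()

    B, T_src, T_mel = 2, 15, 120
    n_mels = preprocess_config["preprocessing"]["mel"]["n_mel_channels"]
    # Pitch/energy targets follow the configured feature level (phoneme- or frame-level).
    p_len = T_src if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level" else T_mel
    e_len = T_src if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level" else T_mel

    with torch.no_grad():
        outputs = model(
            torch.zeros(B, dtype=torch.long),            # speakers (unused by forward)
            torch.randint(1, 50, (B, T_src)),            # phoneme ids
            torch.full((B,), T_src, dtype=torch.long),   # src_lens
            T_src,                                       # max_src_len
            torch.randn(B, T_mel, n_mels),               # reference mels
            torch.full((B,), T_mel, dtype=torch.long),   # mel_lens
            T_mel,                                       # max_mel_len
            p_targets=torch.randn(B, p_len),
            e_targets=torch.randn(B, e_len),
            d_targets=torch.full((B, T_src), T_mel // T_src, dtype=torch.long),
        )
    print(outputs[0].shape)                              # mel prediction: [B, T_mel, n_mel_channels]

--------------------------------------------------------------------------------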
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .StyleSpeech import StyleSpeech
2 | from .loss import MetaStyleSpeechLossMain, MetaStyleSpeechLossDisc
3 | from .optimizer import ScheduledOptimMain, ScheduledOptimDisc
--------------------------------------------------------------------------------
/model/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.nn import functional as F
5 |
6 |
7 | class Mish(nn.Module):
8 | def forward(self, x):
9 | return x * torch.tanh(F.softplus(x))
10 |
11 |
12 | class StyleAdaptiveLayerNorm(nn.Module):
13 | """ Style-Adaptive Layer Norm (SALN) """
14 |
15 | def __init__(self, w_size, hidden_size, bias=False):
16 | super(StyleAdaptiveLayerNorm, self).__init__()
17 | self.hidden_size = hidden_size
18 | self.affine_layer = LinearNorm(
19 | w_size,
20 |             2 * hidden_size, # For both b (bias) and g (gain)
21 | bias,
22 | )
23 |
24 | def forward(self, h, w):
25 | """
26 | h --- [B, T, H_m]
27 | w --- [B, 1, H_w]
28 | o --- [B, T, H_m]
29 | """
30 |
31 | # Normalize Input Features
32 | mu, sigma = torch.mean(h, dim=-1, keepdim=True), torch.std(h, dim=-1, keepdim=True)
33 | y = (h - mu) / sigma # [B, T, H_m]
34 |
35 | # Get Bias and Gain
36 | b, g = torch.split(self.affine_layer(w), self.hidden_size, dim=-1) # [B, 1, 2 * H_m] --> 2 * [B, 1, H_m]
37 |
38 |         # Perform Scaling and Shifting
39 | o = g * y + b # [B, T, H_m]
40 |
41 | return o
42 |
43 |
44 | class FCBlock(nn.Module):
45 | """ Fully Connected Block """
46 |
47 | def __init__(self, in_features, out_features, activation=None, bias=False, dropout=None, spectral_norm=False):
48 | super(FCBlock, self).__init__()
49 | self.fc_layer = nn.Sequential()
50 | self.fc_layer.add_module(
51 | "fc_layer",
52 | LinearNorm(
53 | in_features,
54 | out_features,
55 | bias,
56 | spectral_norm,
57 | ),
58 | )
59 | if activation is not None:
60 | self.fc_layer.add_module("activ", activation)
61 | self.dropout = dropout
62 |
63 | def forward(self, x):
64 | x = self.fc_layer(x)
65 | if self.dropout is not None:
66 | x = F.dropout(x, self.dropout, self.training)
67 | return x
68 |
69 |
70 | class LinearNorm(nn.Module):
71 | """ LinearNorm Projection """
72 |
73 | def __init__(self, in_features, out_features, bias=False, spectral_norm=False):
74 | super(LinearNorm, self).__init__()
75 | self.linear = nn.Linear(in_features, out_features, bias)
76 |
77 | nn.init.xavier_uniform_(self.linear.weight)
78 | if bias:
79 | nn.init.constant_(self.linear.bias, 0.0)
80 | if spectral_norm:
81 | self.linear = nn.utils.spectral_norm(self.linear)
82 |
83 | def forward(self, x):
84 | x = self.linear(x)
85 | return x
86 |
87 |
88 | class Conv1DBlock(nn.Module):
89 | """ 1D Convolutional Block """
90 |
91 | def __init__(self, in_channels, out_channels, kernel_size, activation=None, dropout=None, spectral_norm=False):
92 | super(Conv1DBlock, self).__init__()
93 |
94 | self.conv_layer = nn.Sequential()
95 | self.conv_layer.add_module(
96 | "conv_layer",
97 | ConvNorm(
98 | in_channels,
99 | out_channels,
100 | kernel_size=kernel_size,
101 | stride=1,
102 | padding=int((kernel_size - 1) / 2),
103 | dilation=1,
104 | w_init_gain="tanh",
105 | spectral_norm=spectral_norm,
106 | ),
107 | )
108 | if activation is not None:
109 | self.conv_layer.add_module("activ", activation)
110 | self.dropout = dropout
111 |
112 | def forward(self, x, mask=None):
113 | x = x.contiguous().transpose(1, 2)
114 | x = self.conv_layer(x)
115 |
116 | if self.dropout is not None:
117 | x = F.dropout(x, self.dropout, self.training)
118 |
119 | x = x.contiguous().transpose(1, 2)
120 | if mask is not None:
121 | x = x.masked_fill(mask.unsqueeze(-1), 0)
122 |
123 | return x
124 |
125 |
126 | class ConvNorm(nn.Module):
127 | """ 1D Convolution """
128 |
129 | def __init__(
130 | self,
131 | in_channels,
132 | out_channels,
133 | kernel_size=1,
134 | stride=1,
135 | padding=None,
136 | dilation=1,
137 | bias=True,
138 | w_init_gain="linear",
139 | spectral_norm=False,
140 | ):
141 | super(ConvNorm, self).__init__()
142 |
143 | if padding is None:
144 | assert kernel_size % 2 == 1
145 | padding = int(dilation * (kernel_size - 1) / 2)
146 |
147 | self.conv = nn.Conv1d(
148 | in_channels,
149 | out_channels,
150 | kernel_size=kernel_size,
151 | stride=stride,
152 | padding=padding,
153 | dilation=dilation,
154 | bias=bias,
155 | )
156 | if spectral_norm:
157 | self.conv = nn.utils.spectral_norm(self.conv)
158 |
159 | def forward(self, signal):
160 | conv_signal = self.conv(signal)
161 |
162 | return conv_signal
163 |
164 |
165 | class SALNFFTBlock(nn.Module):
166 | """ FFT Block with SALN """
167 |
168 | def __init__(self, d_model, d_w, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
169 | super(SALNFFTBlock, self).__init__()
170 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
171 | self.pos_ffn = PositionwiseFeedForward(
172 | d_model, d_inner, kernel_size, dropout=dropout
173 | )
174 | self.layer_norm_1 = StyleAdaptiveLayerNorm(d_w, d_model)
175 | self.layer_norm_2 = StyleAdaptiveLayerNorm(d_w, d_model)
176 |
177 | def forward(self, enc_input, w, mask=None, slf_attn_mask=None):
178 | enc_output, enc_slf_attn = self.slf_attn(
179 | enc_input, enc_input, enc_input, mask=slf_attn_mask
180 | )
181 | enc_output = self.layer_norm_1(enc_output, w)
182 | if mask is not None:
183 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
184 |
185 | enc_output = self.pos_ffn(enc_output)
186 | enc_output = self.layer_norm_2(enc_output, w)
187 | if mask is not None:
188 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
189 |
190 | return enc_output, enc_slf_attn
191 |
192 |
193 | class MultiHeadAttention(nn.Module):
194 | """ Multi-Head Attention """
195 |
196 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, layer_norm=False, spectral_norm=False):
197 | super(MultiHeadAttention, self).__init__()
198 |
199 | self.n_head = n_head
200 | self.d_k = d_k
201 | self.d_v = d_v
202 |
203 | self.w_qs = LinearNorm(d_model, n_head * d_k, spectral_norm=spectral_norm)
204 | self.w_ks = LinearNorm(d_model, n_head * d_k, spectral_norm=spectral_norm)
205 | self.w_vs = LinearNorm(d_model, n_head * d_v, spectral_norm=spectral_norm)
206 |
207 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
208 | self.layer_norm = nn.LayerNorm(d_model) if layer_norm else None
209 |
210 | self.fc = LinearNorm(n_head * d_v, d_model, spectral_norm=spectral_norm)
211 |
212 | self.dropout = nn.Dropout(dropout)
213 |
214 | def forward(self, q, k, v, mask=None):
215 |
216 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
217 |
218 | sz_b, len_q, _ = q.size()
219 | sz_b, len_k, _ = k.size()
220 | sz_b, len_v, _ = v.size()
221 |
222 | residual = q
223 |
224 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
225 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
226 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
227 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
228 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
229 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
230 |
231 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
232 | output, attn = self.attention(q, k, v, mask=mask)
233 |
234 | output = output.view(n_head, sz_b, len_q, d_v)
235 | output = (
236 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)
237 | ) # b x lq x (n*dv)
238 |
239 | output = self.dropout(self.fc(output))
240 | output = output + residual
241 | if self.layer_norm is not None:
242 | output = self.layer_norm(output)
243 |
244 | return output, attn
245 |
246 |
247 | class ScaledDotProductAttention(nn.Module):
248 | """ Scaled Dot-Product Attention """
249 |
250 | def __init__(self, temperature):
251 | super(ScaledDotProductAttention, self).__init__()
252 | self.temperature = temperature
253 | self.softmax = nn.Softmax(dim=2)
254 |
255 | def forward(self, q, k, v, mask=None):
256 |
257 | attn = torch.bmm(q, k.transpose(1, 2))
258 | attn = attn / self.temperature
259 |
260 | if mask is not None:
261 | attn = attn.masked_fill(mask, -np.inf)
262 |
263 | attn = self.softmax(attn)
264 | output = torch.bmm(attn, v)
265 |
266 | return output, attn
267 |
268 |
269 | class PositionwiseFeedForward(nn.Module):
270 |     """ A two-layer position-wise feed-forward module """
271 |
272 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1):
273 | super(PositionwiseFeedForward, self).__init__()
274 |
275 | # Use Conv1D
276 | # position-wise
277 | self.w_1 = nn.Conv1d(
278 | d_in,
279 | d_hid,
280 | kernel_size=kernel_size[0],
281 | padding=(kernel_size[0] - 1) // 2,
282 | )
283 | # position-wise
284 | self.w_2 = nn.Conv1d(
285 | d_hid,
286 | d_in,
287 | kernel_size=kernel_size[1],
288 | padding=(kernel_size[1] - 1) // 2,
289 | )
290 | self.dropout = nn.Dropout(dropout)
291 |
292 | def forward(self, x):
293 | residual = x
294 | output = x.transpose(1, 2)
295 | output = self.w_2(F.relu(self.w_1(output)))
296 | output = output.transpose(1, 2)
297 | output = self.dropout(output)
298 | output = output + residual
299 |
300 | return output
301 |
--------------------------------------------------------------------------------
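A shape-level sketch of StyleAdaptiveLayerNorm, the block that injects the style vector into every FFT layer by predicting a per-utterance gain and bias; the sizes below are arbitrary:

    import torch
    from model.blocks import StyleAdaptiveLayerNorm

    B, T, H_m, H_w = 2, 7, 256, 128           # arbitrary sizes for illustration
    saln = StyleAdaptiveLayerNorm(H_w, H_m)

    h = torch.randn(B, T, H_m)                # hidden sequence [B, T, H_m]
    w = torch.randn(B, 1, H_w)                # one style vector per utterance [B, 1, H_w]
    o = saln(h, w)                            # gain/bias predicted from w, applied frame-wise
    print(o.shape)                            # torch.Size([2, 7, 256])

--------------------------------------------------------------------------------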
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MetaStyleSpeechLossMain(nn.Module):
6 | """ Meta-StyleSpeech Loss for naive StyleSpeech and Step 1 """
7 |
8 | def __init__(self, preprocess_config, model_config, train_config):
9 | super(MetaStyleSpeechLossMain, self).__init__()
10 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
11 | "feature"
12 | ]
13 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
14 | "feature"
15 | ]
16 | self.alpha = train_config["optimizer"]["alpha"]
17 | self.mse_loss = nn.MSELoss()
18 | self.mae_loss = nn.L1Loss()
19 |
20 | def forward(self, inputs, predictions):
21 | (
22 | mel_targets,
23 | _,
24 | _,
25 | pitch_targets,
26 | energy_targets,
27 | duration_targets,
28 | _,
29 | _,
30 | _,
31 | _,
32 | _,
33 | ) = inputs[6:]
34 | (
35 | D_s,
36 | D_t,
37 | mel_predictions,
38 | pitch_predictions,
39 | energy_predictions,
40 | log_duration_predictions,
41 | _,
42 | src_masks,
43 | mel_masks,
44 | _,
45 | _,
46 | ) = predictions
47 | src_masks = ~src_masks
48 | mel_masks = ~mel_masks
49 | log_duration_targets = torch.log(duration_targets.float() + 1)
50 | mel_targets = mel_targets[:, : mel_masks.shape[1], :]
51 | mel_masks = mel_masks[:, :mel_masks.shape[1]]
52 |
53 | log_duration_targets.requires_grad = False
54 | pitch_targets.requires_grad = False
55 | energy_targets.requires_grad = False
56 | mel_targets.requires_grad = False
57 |
58 | if self.pitch_feature_level == "phoneme_level":
59 | pitch_predictions = pitch_predictions.masked_select(src_masks)
60 | pitch_targets = pitch_targets.masked_select(src_masks)
61 | elif self.pitch_feature_level == "frame_level":
62 | pitch_predictions = pitch_predictions.masked_select(mel_masks)
63 | pitch_targets = pitch_targets.masked_select(mel_masks)
64 |
65 | if self.energy_feature_level == "phoneme_level":
66 | energy_predictions = energy_predictions.masked_select(src_masks)
67 | energy_targets = energy_targets.masked_select(src_masks)
68 |         elif self.energy_feature_level == "frame_level":
69 | energy_predictions = energy_predictions.masked_select(mel_masks)
70 | energy_targets = energy_targets.masked_select(mel_masks)
71 |
72 | log_duration_predictions = log_duration_predictions.masked_select(src_masks)
73 | log_duration_targets = log_duration_targets.masked_select(src_masks)
74 |
75 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
76 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
77 |
78 | mel_loss = self.mae_loss(mel_predictions, mel_targets)
79 |
80 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
81 | energy_loss = self.mse_loss(energy_predictions, energy_targets)
82 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
83 |
84 | alpha = 1
85 | D_s_loss = D_t_loss = torch.tensor([0.], device=mel_predictions.device, requires_grad=False)
86 | if D_s is not None and D_t is not None:
87 | D_s_loss = self.mse_loss(D_s, torch.ones_like(D_s, requires_grad=False))
88 | D_t_loss = self.mse_loss(D_t, torch.ones_like(D_t, requires_grad=False))
89 | alpha = self.alpha
90 |
91 | recon_loss = alpha * (mel_loss + duration_loss + pitch_loss + energy_loss)
92 | total_loss = (
93 | recon_loss + D_s_loss + D_t_loss
94 | )
95 |
96 | return (
97 | total_loss,
98 | mel_loss,
99 | pitch_loss,
100 | energy_loss,
101 | duration_loss,
102 | D_s_loss,
103 | D_t_loss,
104 | )
105 |
106 |
107 | class MetaStyleSpeechLossDisc(nn.Module):
108 | """ Meta-StyleSpeech Loss for Step 2 """
109 |
110 | def __init__(self, preprocess_config, model_config):
111 | super(MetaStyleSpeechLossDisc, self).__init__()
112 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
113 | "feature"
114 | ]
115 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
116 | "feature"
117 | ]
118 | self.mse_loss = nn.MSELoss()
119 | self.mae_loss = nn.L1Loss()
120 | self.cross_entropy_loss = nn.CrossEntropyLoss()
121 |
122 | def forward(self, speakers, predictions):
123 | (
124 | D_t_s,
125 | D_t_q,
126 | D_s_s,
127 | D_s_q,
128 | style_logit,
129 | ) = predictions
130 | speakers.requires_grad = False
131 |
132 | D_t_loss = self.mse_loss(D_t_s, torch.ones_like(D_t_s, requires_grad=False))\
133 | + self.mse_loss(D_t_q, torch.zeros_like(D_t_q, requires_grad=False))
134 | D_s_loss = self.mse_loss(D_s_s, torch.ones_like(D_s_s, requires_grad=False))\
135 | + self.mse_loss(D_s_q, torch.zeros_like(D_s_q, requires_grad=False))
136 | cls_loss = self.cross_entropy_loss(style_logit, speakers)
137 |
138 | total_loss = (
139 | D_t_loss + D_s_loss + cls_loss
140 | )
141 |
142 | return (
143 | total_loss,
144 | D_s_loss,
145 | D_t_loss,
146 | cls_loss,
147 | )
148 |
--------------------------------------------------------------------------------
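A minimal sketch of the discriminator-side loss with dummy tensors. Only the pitch/energy feature-level entries are read by the constructor, so a stripped-down stand-in config is used here instead of the real preprocess.yaml; the 247 mirrors the speaker count in preprocessed_data/LibriTTS/speakers.json but is otherwise arbitrary:

    import torch
    from model.loss import MetaStyleSpeechLossDisc

    preprocess_config = {
        "preprocessing": {
            "pitch": {"feature": "phoneme_level"},
            "energy": {"feature": "phoneme_level"},
        }
    }
    criterion = MetaStyleSpeechLossDisc(preprocess_config, model_config={})

    B, K = 4, 247                                 # batch size, number of training speakers
    speakers = torch.randint(0, K, (B,))
    D_t_s, D_t_q = torch.rand(B), torch.rand(B)   # phoneme discriminator: support / query
    D_s_s, D_s_q = torch.rand(B), torch.rand(B)   # style discriminator: support / query
    style_logit = torch.randn(B, K)               # style vector vs. speaker prototypes

    total, D_s_loss, D_t_loss, cls_loss = criterion(
        speakers, (D_t_s, D_t_q, D_s_s, D_s_q, style_logit)
    )
    print(total.item(), cls_loss.item())

--------------------------------------------------------------------------------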
/model/modules.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import copy
4 | import math
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.parameter import Parameter
10 | import numpy as np
11 | import torch.nn.functional as F
12 |
13 | from utils.tools import get_mask_from_lengths, pad
14 |
15 | from .blocks import (
16 | Mish,
17 | FCBlock,
18 | Conv1DBlock,
19 | SALNFFTBlock,
20 | MultiHeadAttention,
21 | )
22 | from text.symbols import symbols
23 |
24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 |
26 |
27 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
28 | """ Sinusoid position encoding table """
29 |
30 | def cal_angle(position, hid_idx):
31 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
32 |
33 | def get_posi_angle_vec(position):
34 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
35 |
36 | sinusoid_table = np.array(
37 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)]
38 | )
39 |
40 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
41 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
42 |
43 | if padding_idx is not None:
44 | # zero vector for padding dimension
45 | sinusoid_table[padding_idx] = 0.0
46 |
47 | return torch.FloatTensor(sinusoid_table)
48 |
49 |
50 | class MelStyleEncoder(nn.Module):
51 | """ Mel-Style Encoder """
52 |
53 | def __init__(self, preprocess_config, model_config):
54 | super(MelStyleEncoder, self).__init__()
55 | n_position = model_config["max_seq_len"] + 1
56 | n_mel_channels = preprocess_config["preprocessing"]["mel"]["n_mel_channels"]
57 | d_melencoder = model_config["melencoder"]["encoder_hidden"]
58 | n_spectral_layer = model_config["melencoder"]["spectral_layer"]
59 | n_temporal_layer = model_config["melencoder"]["temporal_layer"]
60 | n_slf_attn_layer = model_config["melencoder"]["slf_attn_layer"]
61 | n_slf_attn_head = model_config["melencoder"]["slf_attn_head"]
62 | d_k = d_v = (
63 | model_config["melencoder"]["encoder_hidden"]
64 | // model_config["melencoder"]["slf_attn_head"]
65 | )
66 | kernel_size = model_config["melencoder"]["conv_kernel_size"]
67 | dropout = model_config["melencoder"]["encoder_dropout"]
68 |
69 | self.max_seq_len = model_config["max_seq_len"]
70 |
71 | self.fc_1 = FCBlock(n_mel_channels, d_melencoder)
72 |
73 | self.spectral_stack = nn.ModuleList(
74 | [
75 | FCBlock(
76 | d_melencoder, d_melencoder, activation=Mish()
77 | )
78 | for _ in range(n_spectral_layer)
79 | ]
80 | )
81 |
82 | self.temporal_stack = nn.ModuleList(
83 | [
84 | nn.Sequential(
85 | Conv1DBlock(
86 | d_melencoder, 2 * d_melencoder, kernel_size, activation=Mish(), dropout=dropout
87 | ),
88 | nn.GLU(),
89 | )
90 | for _ in range(n_temporal_layer)
91 | ]
92 | )
93 |
94 | self.slf_attn_stack = nn.ModuleList(
95 | [
96 | MultiHeadAttention(
97 | n_slf_attn_head, d_melencoder, d_k, d_v, dropout=dropout, layer_norm=True
98 | )
99 | for _ in range(n_slf_attn_layer)
100 | ]
101 | )
102 |
103 | self.fc_2 = FCBlock(d_melencoder, d_melencoder)
104 |
105 | def forward(self, mel, mask):
106 |
107 | max_len = mel.shape[1]
108 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
109 |
110 | enc_output = self.fc_1(mel)
111 |
112 | # Spectral Processing
113 | for _, layer in enumerate(self.spectral_stack):
114 | enc_output = layer(enc_output)
115 |
116 | # Temporal Processing
117 | for _, layer in enumerate(self.temporal_stack):
118 | residual = enc_output
119 | enc_output = layer(enc_output)
120 | enc_output = residual + enc_output
121 |
122 | # Multi-head self-attention
123 | for _, layer in enumerate(self.slf_attn_stack):
124 | residual = enc_output
125 | enc_output, _ = layer(
126 | enc_output, enc_output, enc_output, mask=slf_attn_mask
127 | )
128 | enc_output = residual + enc_output
129 |
130 | # Final Layer
131 | enc_output = self.fc_2(enc_output) # [B, T, H]
132 |
133 | # Temporal Average Pooling
134 | enc_output = torch.mean(enc_output, dim=1, keepdim=True) # [B, 1, H]
135 |
136 | return enc_output
137 |
138 |
139 | class PhonemePreNet(nn.Module):
140 | """ Phoneme Encoder PreNet """
141 |
142 | def __init__(self, config):
143 | super(PhonemePreNet, self).__init__()
144 | d_model = config["transformer"]["encoder_hidden"]
145 | kernel_size = config["prenet"]["conv_kernel_size"]
146 | dropout = config["prenet"]["dropout"]
147 |
148 | self.prenet_layer = nn.Sequential(
149 | Conv1DBlock(
150 | d_model, d_model, kernel_size, activation=Mish(), dropout=dropout
151 | ),
152 | Conv1DBlock(
153 | d_model, d_model, kernel_size, activation=Mish(), dropout=dropout
154 | ),
155 | FCBlock(d_model, d_model, dropout=dropout),
156 | )
157 |
158 | def forward(self, x, mask=None):
159 | residual = x
160 | x = self.prenet_layer(x)
161 | if mask is not None:
162 | x = x.masked_fill(mask.unsqueeze(-1), 0)
163 | x = residual + x
164 | return x
165 |
166 |
167 | class PhonemeEncoder(nn.Module):
168 |     """ Phoneme Encoder """
169 |
170 | def __init__(self, config):
171 | super(PhonemeEncoder, self).__init__()
172 |
173 | n_position = config["max_seq_len"] + 1
174 | n_src_vocab = len(symbols) + 1
175 | d_word_vec = config["transformer"]["encoder_hidden"]
176 | n_layers = config["transformer"]["encoder_layer"]
177 | n_head = config["transformer"]["encoder_head"]
178 | d_w = config["melencoder"]["encoder_hidden"]
179 | d_k = d_v = (
180 | config["transformer"]["encoder_hidden"]
181 | // config["transformer"]["encoder_head"]
182 | )
183 | d_model = config["transformer"]["encoder_hidden"]
184 | d_inner = config["transformer"]["conv_filter_size"]
185 | kernel_size = config["transformer"]["conv_kernel_size"]
186 | dropout = config["transformer"]["encoder_dropout"]
187 |
188 | self.max_seq_len = config["max_seq_len"]
189 | self.d_model = d_model
190 |
191 | self.src_word_emb = nn.Embedding(
192 | n_src_vocab, d_word_vec, padding_idx=0
193 | )
194 | self.phoneme_prenet = PhonemePreNet(config)
195 | self.position_enc = nn.Parameter(
196 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
197 | requires_grad=False,
198 | )
199 |
200 | self.layer_stack = nn.ModuleList(
201 | [
202 | SALNFFTBlock(
203 | d_model, d_w, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
204 | )
205 | for _ in range(n_layers)
206 | ]
207 | )
208 |
209 | def forward(self, src_seq, w, mask, return_attns=False):
210 |
211 | enc_slf_attn_list = []
212 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1]
213 |
214 | # -- Prepare masks
215 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
216 |
217 | # -- PreNet
218 | src_seq = self.phoneme_prenet(self.src_word_emb(src_seq), mask)
219 |
220 | # -- Forward
221 | if not self.training and src_seq.shape[1] > self.max_seq_len:
222 | enc_output = src_seq + get_sinusoid_encoding_table(
223 | src_seq.shape[1], self.d_model
224 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
225 | src_seq.device
226 | )
227 | else:
228 | enc_output = src_seq + self.position_enc[
229 | :, :max_len, :
230 | ].expand(batch_size, -1, -1)
231 |
232 | for enc_layer in self.layer_stack:
233 | enc_output, enc_slf_attn = enc_layer(
234 | enc_output, w, mask=mask, slf_attn_mask=slf_attn_mask
235 | )
236 | if return_attns:
237 | enc_slf_attn_list += [enc_slf_attn]
238 |
239 | return enc_output
240 |
241 |
242 | class MelPreNet(nn.Module):
243 | """ Mel-spectrogram Decoder PreNet """
244 |
245 | def __init__(self, config):
246 | super(MelPreNet, self).__init__()
247 | d_model = config["transformer"]["encoder_hidden"]
248 | d_melencoder = config["melencoder"]["encoder_hidden"]
249 | dropout = config["prenet"]["dropout"]
250 |
251 | self.prenet_layer = nn.Sequential(
252 | FCBlock(d_model, d_melencoder, activation=Mish(), dropout=dropout),
253 | FCBlock(d_melencoder, d_model, activation=Mish(), dropout=dropout),
254 | )
255 |
256 | def forward(self, x, mask=None):
257 | x = self.prenet_layer(x)
258 | if mask is not None:
259 | x = x.masked_fill(mask.unsqueeze(-1), 0)
260 | return x
261 |
262 |
263 | class MelDecoder(nn.Module):
264 | """ MelDecoder """
265 |
266 | def __init__(self, config):
267 | super(MelDecoder, self).__init__()
268 |
269 | n_position = config["max_seq_len"] + 1
270 | d_word_vec = config["transformer"]["decoder_hidden"]
271 | n_layers = config["transformer"]["decoder_layer"]
272 | n_head = config["transformer"]["decoder_head"]
273 | d_w = config["melencoder"]["encoder_hidden"]
274 | d_k = d_v = (
275 | config["transformer"]["decoder_hidden"]
276 | // config["transformer"]["decoder_head"]
277 | )
278 | d_model = config["transformer"]["decoder_hidden"]
279 | d_inner = config["transformer"]["conv_filter_size"]
280 | kernel_size = config["transformer"]["conv_kernel_size"]
281 | dropout = config["transformer"]["decoder_dropout"]
282 |
283 | self.max_seq_len = config["max_seq_len"]
284 | self.d_model = d_model
285 |
286 | self.mel_prenet = MelPreNet(config)
287 | self.position_enc = nn.Parameter(
288 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
289 | requires_grad=False,
290 | )
291 |
292 | self.layer_stack = nn.ModuleList(
293 | [
294 | SALNFFTBlock(
295 | d_model, d_w, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
296 | )
297 | for _ in range(n_layers)
298 | ]
299 | )
300 |
301 | def forward(self, enc_seq, w, mask, return_attns=False):
302 |
303 | dec_slf_attn_list = []
304 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1]
305 |
306 | # -- PreNet
307 | enc_seq = self.mel_prenet(enc_seq, mask)
308 |
309 | # -- Forward
310 | if not self.training and enc_seq.shape[1] > self.max_seq_len:
311 | # -- Prepare masks
312 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
313 | dec_output = enc_seq + get_sinusoid_encoding_table(
314 | enc_seq.shape[1], self.d_model
315 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
316 | enc_seq.device
317 | )
318 | else:
319 | max_len = min(max_len, self.max_seq_len)
320 |
321 | # -- Prepare masks
322 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
323 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[
324 | :, :max_len, :
325 | ].expand(batch_size, -1, -1)
326 | mask = mask[:, :max_len]
327 | slf_attn_mask = slf_attn_mask[:, :, :max_len]
328 |
329 | for dec_layer in self.layer_stack:
330 | dec_output, dec_slf_attn = dec_layer(
331 | dec_output, w, mask=mask, slf_attn_mask=slf_attn_mask
332 | )
333 | if return_attns:
334 | dec_slf_attn_list += [dec_slf_attn]
335 |
336 | return dec_output, mask
337 |
338 |
339 | class VarianceAdaptor(nn.Module):
340 | """ Variance Adaptor """
341 |
342 | def __init__(self, preprocess_config, model_config):
343 | super(VarianceAdaptor, self).__init__()
344 | self.duration_predictor = VariancePredictor(model_config)
345 | self.length_regulator = LengthRegulator()
346 | self.pitch_predictor = VariancePredictor(model_config)
347 | self.energy_predictor = VariancePredictor(model_config)
348 |
349 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
350 | "feature"
351 | ]
352 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
353 | "feature"
354 | ]
355 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
356 | assert self.energy_feature_level in ["phoneme_level", "frame_level"]
357 |
358 | d_model = model_config["transformer"]["encoder_hidden"]
359 | kernel_size = model_config["variance_embedding"]["kernel_size"]
360 | self.pitch_embedding = Conv1DBlock(
361 | 1, d_model, kernel_size
362 | )
363 | self.energy_embedding = Conv1DBlock(
364 | 1, d_model, kernel_size
365 | )
366 |
367 | def get_pitch_embedding(self, x, target, mask, control):
368 | prediction = self.pitch_predictor(x, mask)
369 | if target is not None:
370 | embedding = self.pitch_embedding(target.unsqueeze(-1))
371 | else:
372 | prediction = prediction * control
373 | embedding = self.pitch_embedding(prediction.unsqueeze(-1))
374 | return prediction, embedding
375 |
376 | def get_energy_embedding(self, x, target, mask, control):
377 | prediction = self.energy_predictor(x, mask)
378 | if target is not None:
379 | embedding = self.energy_embedding(target.unsqueeze(-1))
380 | else:
381 | prediction = prediction * control
382 | embedding = self.energy_embedding(prediction.unsqueeze(-1))
383 | return prediction, embedding
384 |
385 | def upsample(self, x, mel_mask, max_len, log_duration_prediction=None, duration_target=None, d_control=1.0):
386 | if duration_target is not None:
387 | x, mel_len = self.length_regulator(x, duration_target, max_len)
388 | duration_rounded = duration_target
389 | else:
390 | duration_rounded = torch.clamp(
391 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
392 | min=0,
393 | )
394 | x, mel_len = self.length_regulator(x, duration_rounded, None)
395 | mel_mask = get_mask_from_lengths(mel_len)
396 | return x, duration_rounded, mel_len, mel_mask
397 |
398 | def forward(
399 | self,
400 | x,
401 | src_mask,
402 | mel_mask,
403 | max_len,
404 | pitch_target=None,
405 | energy_target=None,
406 | duration_target=None,
407 | p_control=1.0,
408 | e_control=1.0,
409 | d_control=1.0,
410 | ):
411 | upsampled_text = None
412 | log_duration_prediction = self.duration_predictor(x, src_mask)
413 | if self.pitch_feature_level == "phoneme_level":
414 | pitch_prediction, pitch_embedding = self.get_pitch_embedding(
415 | x, pitch_target, src_mask, p_control
416 | )
417 | x = x + pitch_embedding
418 | if self.energy_feature_level == "phoneme_level":
419 | energy_prediction, energy_embedding = self.get_energy_embedding(
420 |                 x, energy_target, src_mask, e_control
421 | )
422 | x = x + energy_embedding
423 |
424 | x, duration_rounded, mel_len, mel_mask = self.upsample(
425 | x, mel_mask, max_len, log_duration_prediction=log_duration_prediction, duration_target=duration_target, d_control=d_control
426 | )
427 |
428 | if self.pitch_feature_level == "frame_level":
429 | pitch_prediction, pitch_embedding = self.get_pitch_embedding(
430 | x, pitch_target, mel_mask, p_control
431 | )
432 | x = x + pitch_embedding
433 | if self.energy_feature_level == "frame_level":
434 | energy_prediction, energy_embedding = self.get_energy_embedding(
435 |                 x, energy_target, mel_mask, e_control
436 | )
437 | x = x + energy_embedding
438 |
439 | return (
440 | x,
441 | pitch_prediction,
442 | energy_prediction,
443 | log_duration_prediction,
444 | duration_rounded,
445 | mel_len,
446 | mel_mask,
447 | )
448 |
449 |
450 | class LengthRegulator(nn.Module):
451 | """ Length Regulator """
452 |
453 | def __init__(self):
454 | super(LengthRegulator, self).__init__()
455 |
456 | def LR(self, x, duration, max_len):
457 | output = list()
458 | mel_len = list()
459 | for batch, expand_target in zip(x, duration):
460 | expanded = self.expand(batch, expand_target)
461 | output.append(expanded)
462 | mel_len.append(expanded.shape[0])
463 |
464 | if max_len is not None:
465 | output = pad(output, max_len)
466 | else:
467 | output = pad(output)
468 |
469 | return output, torch.LongTensor(mel_len).to(device)
470 |
471 | def expand(self, batch, predicted):
472 | out = list()
473 |
474 | for i, vec in enumerate(batch):
475 | expand_size = predicted[i].item()
476 | out.append(vec.expand(max(int(expand_size), 0), -1))
477 | out = torch.cat(out, 0)
478 |
479 | return out
480 |
481 | def forward(self, x, duration, max_len):
482 | output, mel_len = self.LR(x, duration, max_len)
483 | return output, mel_len
484 |
485 |
486 | class VariancePredictor(nn.Module):
487 | """ Duration, Pitch and Energy Predictor """
488 |
489 | def __init__(self, model_config):
490 | super(VariancePredictor, self).__init__()
491 |
492 | self.input_size = model_config["transformer"]["encoder_hidden"]
493 | self.filter_size = model_config["variance_predictor"]["filter_size"]
494 | self.kernel = model_config["variance_predictor"]["kernel_size"]
495 | self.conv_output_size = model_config["variance_predictor"]["filter_size"]
496 | self.dropout = model_config["variance_predictor"]["dropout"]
497 |
498 | self.conv_layer = nn.Sequential(
499 | OrderedDict(
500 | [
501 | (
502 | "conv1d_1",
503 | Conv(
504 | self.input_size,
505 | self.filter_size,
506 | kernel_size=self.kernel,
507 | padding=(self.kernel - 1) // 2,
508 | ),
509 | ),
510 | ("relu_1", nn.ReLU()),
511 | ("layer_norm_1", nn.LayerNorm(self.filter_size)),
512 | ("dropout_1", nn.Dropout(self.dropout)),
513 | (
514 | "conv1d_2",
515 | Conv(
516 | self.filter_size,
517 | self.filter_size,
518 | kernel_size=self.kernel,
519 |                             padding=(self.kernel - 1) // 2,
520 | ),
521 | ),
522 | ("relu_2", nn.ReLU()),
523 | ("layer_norm_2", nn.LayerNorm(self.filter_size)),
524 | ("dropout_2", nn.Dropout(self.dropout)),
525 | ]
526 | )
527 | )
528 |
529 | self.linear_layer = nn.Linear(self.conv_output_size, 1)
530 |
531 | def forward(self, encoder_output, mask):
532 | out = self.conv_layer(encoder_output)
533 | out = self.linear_layer(out)
534 | out = out.squeeze(-1)
535 |
536 | if mask is not None:
537 | out = out.masked_fill(mask, 0.0)
538 |
539 | return out
540 |
541 |
542 | class Conv(nn.Module):
543 | """
544 | Convolution Module
545 | """
546 |
547 | def __init__(
548 | self,
549 | in_channels,
550 | out_channels,
551 | kernel_size=1,
552 | stride=1,
553 | padding=0,
554 | dilation=1,
555 | bias=True,
556 | w_init="linear",
557 | ):
558 | """
559 | :param in_channels: dimension of input
560 | :param out_channels: dimension of output
561 | :param kernel_size: size of kernel
562 | :param stride: size of stride
563 | :param padding: size of padding
564 | :param dilation: dilation rate
565 | :param bias: boolean. if True, bias is included.
566 | :param w_init: str. weight inits with xavier initialization.
567 | """
568 | super(Conv, self).__init__()
569 |
570 | self.conv = nn.Conv1d(
571 | in_channels,
572 | out_channels,
573 | kernel_size=kernel_size,
574 | stride=stride,
575 | padding=padding,
576 | dilation=dilation,
577 | bias=bias,
578 | )
579 |
580 | def forward(self, x):
581 | x = x.contiguous().transpose(1, 2)
582 | x = self.conv(x)
583 | x = x.contiguous().transpose(1, 2)
584 |
585 | return x
586 |
587 |
588 | class PhonemeDiscriminator(nn.Module):
589 | """ Phoneme Discriminator """
590 |
591 | def __init__(self, preprocess_config, model_config):
592 | super(PhonemeDiscriminator, self).__init__()
593 | n_mel_channels = preprocess_config["preprocessing"]["mel"]["n_mel_channels"]
594 | d_mel_linear = model_config["discriminator"]["mel_linear_size"]
595 | d_model = model_config["discriminator"]["phoneme_hidden"]
596 | d_layer = model_config["discriminator"]["phoneme_layer"]
597 |
598 | self.max_seq_len = model_config["max_seq_len"]
599 | self.mel_linear = nn.Sequential(
600 | FCBlock(n_mel_channels, d_mel_linear, activation=nn.LeakyReLU(), spectral_norm=True),
601 | FCBlock(d_mel_linear, d_mel_linear, activation=nn.LeakyReLU(), spectral_norm=True),
602 | )
603 | self.discriminator_stack = nn.ModuleList(
604 | [
605 | FCBlock(
606 | d_model, d_model, activation=nn.LeakyReLU(), spectral_norm=True
607 | )
608 | for _ in range(d_layer)
609 | ]
610 | )
611 | self.final_linear = FCBlock(d_model, 1, spectral_norm=True)
612 |
613 | def forward(self, upsampler, text, mel, max_len, mask, duration_target):
614 |
615 | # Prepare Upsampled Text
616 | upsampled_text, _, _, _ = upsampler(
617 | text, mask, max_len, duration_target=duration_target
618 | )
619 | max_len = min(max_len, self.max_seq_len)
620 | upsampled_text = upsampled_text[:, :max_len, :]
621 |
622 | # Prepare Mel
623 | mel = self.mel_linear(mel)[:, :max_len, :]
624 | mel = mel.masked_fill(mask.unsqueeze(-1)[:, :max_len, :], 0)
625 |
626 | # Prepare Input
627 | x = torch.cat([upsampled_text, mel], dim=-1)
628 |
629 | # Phoneme Discriminator
630 | for _, layer in enumerate(self.discriminator_stack):
631 | x = layer(x)
632 | x = self.final_linear(x) # [B, T, 1]
633 | x = x.masked_fill(mask.unsqueeze(-1)[:, :max_len, :], 0)
634 |
635 | # Temporal Average Pooling
636 | x = torch.mean(x, dim=1, keepdim=True) # [B, 1, 1]
637 | x = x.squeeze() # [B,]
638 |
639 | return x
640 |
641 |
642 | class StyleDiscriminator(nn.Module):
643 | """ Style Discriminator """
644 |
645 | def __init__(self, preprocess_config, model_config):
646 | super(StyleDiscriminator, self).__init__()
647 | n_position = model_config["max_seq_len"] + 1
648 | n_mel_channels = preprocess_config["preprocessing"]["mel"]["n_mel_channels"]
649 | d_melencoder = model_config["melencoder"]["encoder_hidden"]
650 | n_spectral_layer = model_config["melencoder"]["spectral_layer"]
651 | n_temporal_layer = model_config["melencoder"]["temporal_layer"]
652 | n_slf_attn_layer = model_config["melencoder"]["slf_attn_layer"]
653 | n_slf_attn_head = model_config["melencoder"]["slf_attn_head"]
654 | d_k = d_v = (
655 | model_config["melencoder"]["encoder_hidden"]
656 | // model_config["melencoder"]["slf_attn_head"]
657 | )
658 | kernel_size = model_config["melencoder"]["conv_kernel_size"]
659 |
660 | self.max_seq_len = model_config["max_seq_len"]
661 |
662 | self.fc_1 = FCBlock(n_mel_channels, d_melencoder, spectral_norm=True)
663 |
664 | self.spectral_stack = nn.ModuleList(
665 | [
666 | FCBlock(
667 | d_melencoder, d_melencoder, activation=nn.LeakyReLU(), spectral_norm=True
668 | )
669 | for _ in range(n_spectral_layer)
670 | ]
671 | )
672 |
673 | self.temporal_stack = nn.ModuleList(
674 | [
675 | Conv1DBlock(
676 | d_melencoder, d_melencoder, kernel_size, activation=nn.LeakyReLU(), spectral_norm=True
677 | )
678 | for _ in range(n_temporal_layer)
679 | ]
680 | )
681 |
682 | self.slf_attn_stack = nn.ModuleList(
683 | [
684 | MultiHeadAttention(
685 | n_slf_attn_head, d_melencoder, d_k, d_v, layer_norm=True, spectral_norm=True
686 | )
687 | for _ in range(n_slf_attn_layer)
688 | ]
689 | )
690 |
691 | self.fc_2 = FCBlock(d_melencoder, d_melencoder, spectral_norm=True)
692 |
693 | self.V = FCBlock(d_melencoder, d_melencoder)
694 | self.w_b_0 = FCBlock(1, 1, bias=True)
695 |
696 | def forward(self, style_prototype, speakers, mel, mask):
697 |
698 | max_len = mel.shape[1]
699 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
700 |
701 | x = self.fc_1(mel)
702 |
703 | # Spectral Processing
704 | for _, layer in enumerate(self.spectral_stack):
705 | x = layer(x)
706 |
707 | # Temporal Processing
708 | for _, layer in enumerate(self.temporal_stack):
709 | residual = x
710 | x = layer(x)
711 | x = residual + x
712 |
713 | # Multi-head self-attention
714 | for _, layer in enumerate(self.slf_attn_stack):
715 | residual = x
716 | x, _ = layer(
717 | x, x, x, mask=slf_attn_mask
718 | )
719 | x = residual + x
720 |
721 | # Final Layer
722 | x = self.fc_2(x) # [B, T, H]
723 |
724 | # Temporal Average Pooling, h(x)
725 | x = torch.mean(x, dim=1, keepdim=True) # [B, 1, H]
726 |
727 | # Output Computation
728 | s_i = style_prototype(speakers) # [B, H]
729 | V = self.V(s_i).unsqueeze(2) # [B, H, 1]
730 | o = torch.matmul(x, V).squeeze(2) # [B, 1]
731 | o = self.w_b_0(o).squeeze() # [B,]
732 |
733 | return o
--------------------------------------------------------------------------------
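The LengthRegulator above simply repeats each phoneme vector by its integer duration and pads the batch to a common length; a tiny sketch with made-up durations:

    import torch
    from model.modules import LengthRegulator

    length_regulator = LengthRegulator()
    x = torch.randn(1, 4, 8)                     # [B, T_phoneme, H]
    durations = torch.tensor([[2, 0, 3, 1]])     # frames per phoneme (a zero drops that phoneme)
    out, mel_len = length_regulator(x, durations, max_len=None)
    print(out.shape, mel_len)                    # torch.Size([1, 6, 8]); total frames = 6

--------------------------------------------------------------------------------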
/model/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class ScheduledOptimMain:
6 | """ A simple wrapper class for learning rate scheduling """
7 |
8 | def __init__(self, model, train_config, model_config, current_step):
9 | self._optimizer = torch.optim.Adam(
10 | [param for name, param in model.named_parameters()
11 | if not any([filtered_name in name for filtered_name in ['D_s', 'D_t']])],
12 | betas=train_config["optimizer"]["betas"],
13 | eps=train_config["optimizer"]["eps"],
14 | weight_decay=train_config["optimizer"]["weight_decay"],
15 | )
16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"]
17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"]
18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"]
19 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5)
20 | # meta_learning_warmup = train_config["step"]["meta_learning_warmup"]
21 |         self.current_step = current_step  # if current_step <= meta_learning_warmup else current_step - meta_learning_warmup
22 |
23 | def step_and_update_lr(self):
24 | self._update_learning_rate()
25 | self._optimizer.step()
26 |
27 | def zero_grad(self):
28 | # print(self.init_lr)
29 | self._optimizer.zero_grad()
30 |
31 | def load_state_dict(self, state_dict):
32 | state_dict['param_groups'] = self._optimizer.state_dict()['param_groups']
33 | self._optimizer.load_state_dict(state_dict)
34 |
35 | def _get_lr_scale(self):
36 | lr = np.min(
37 | [
38 | np.power(self.current_step, -0.5),
39 | np.power(self.n_warmup_steps, -1.5) * self.current_step,
40 | ]
41 | )
42 | for s in self.anneal_steps:
43 | if self.current_step > s:
44 | lr = lr * self.anneal_rate
45 | return lr
46 |
47 | def _update_learning_rate(self):
48 | """ Learning rate scheduling per step """
49 | self.current_step += 1
50 | lr = self.init_lr * self._get_lr_scale()
51 |
52 | for param_group in self._optimizer.param_groups:
53 | param_group["lr"] = lr
54 |
55 |
56 | class ScheduledOptimDisc:
57 | """ A simple wrapper class for learning rate scheduling """
58 |
59 | def __init__(self, model, train_config):
60 |
61 | self._optimizer = torch.optim.Adam(
62 | [param for name, param in model.named_parameters()
63 | if any([filtered_name in name for filtered_name in ['D_s', 'D_t']])],
64 | betas=train_config["optimizer"]["betas"],
65 | eps=train_config["optimizer"]["eps"],
66 | weight_decay=train_config["optimizer"]["weight_decay"],
67 | )
68 | self.init_lr = train_config["optimizer"]["lr_disc"]
69 | self._init_learning_rate()
70 |
71 | def step_and_update_lr(self):
72 | self._optimizer.step()
73 |
74 | def zero_grad(self):
75 | # print(self.init_lr)
76 | self._optimizer.zero_grad()
77 |
78 | def load_state_dict(self, state_dict):
79 | state_dict['param_groups'] = self._optimizer.state_dict()['param_groups']
80 | self._optimizer.load_state_dict(state_dict)
81 |
82 | def _init_learning_rate(self):
83 | lr = self.init_lr
84 | for param_group in self._optimizer.param_groups:
85 | param_group["lr"] = lr
86 |
--------------------------------------------------------------------------------
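ScheduledOptimMain follows the Transformer-style "Noam" warmup rule scaled by encoder_hidden ** -0.5, with step-wise annealing applied once training passes each anneal step. A stand-alone recomputation of that schedule; the hyperparameter values below are illustrative, not the repo's train.yaml values:

    import numpy as np

    encoder_hidden = 256
    init_lr = np.power(encoder_hidden, -0.5)
    n_warmup_steps = 4000
    anneal_steps, anneal_rate = [300000, 400000, 500000], 0.3

    def lr_at(step):
        # Same rule as ScheduledOptimMain._get_lr_scale / _update_learning_rate
        scale = min(np.power(step, -0.5), np.power(n_warmup_steps, -1.5) * step)
        for s in anneal_steps:
            if step > s:
                scale *= anneal_rate
        return init_lr * scale

    for step in (100, 4000, 50000, 350000):
        print(step, lr_at(step))

--------------------------------------------------------------------------------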
/prepare_align.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import yaml
4 |
5 | from preprocessor import ljspeech, aishell3, libritts
6 |
7 |
8 | def main(config):
9 | if "LJSpeech" in config["dataset"]:
10 | ljspeech.prepare_align(config)
11 | if "AISHELL3" in config["dataset"]:
12 | aishell3.prepare_align(config)
13 | if "LibriTTS" in config["dataset"]:
14 | libritts.prepare_align(config)
15 |
16 |
17 | if __name__ == "__main__":
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("config", type=str, help="path to preprocess.yaml")
20 | args = parser.parse_args()
21 |
22 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
23 | main(config)
24 |
--------------------------------------------------------------------------------
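prepare_align.py takes the path to a preprocess.yaml as its single positional argument and dispatches on the config's "dataset" field; a usage sketch with the LibriTTS config shipped in this repo:

    python3 prepare_align.py config/LibriTTS/preprocess.yaml

--------------------------------------------------------------------------------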
/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import yaml
4 |
5 | from preprocessor.preprocessor import Preprocessor
6 |
7 |
8 | if __name__ == "__main__":
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("config", type=str, help="path to preprocess.yaml")
11 | args = parser.parse_args()
12 |
13 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
14 | preprocessor = Preprocessor(config)
15 | preprocessor.build_from_path()
16 |
--------------------------------------------------------------------------------
/preprocessed_data/LibriTTS/speakers.json:
--------------------------------------------------------------------------------
1 | {"229": 0, "5789": 1, "1034": 2, "3983": 3, "5463": 4, "4481": 5, "1355": 6, "1970": 7, "7264": 8, "625": 9, "5688": 10, "6880": 11, "2002": 12, "1502": 13, "7517": 14, "298": 15, "2989": 16, "8468": 17, "3235": 18, "103": 19, "1455": 20, "3947": 21, "887": 22, "669": 23, "5393": 24, "8838": 25, "5750": 26, "3242": 27, "3857": 28, "696": 29, "60": 30, "78": 31, "125": 32, "1898": 33, "2514": 34, "6454": 35, "6836": 36, "3112": 37, "6925": 38, "839": 39, "8630": 40, "4859": 41, "7148": 42, "6078": 43, "2007": 44, "3723": 45, "1447": 46, "4788": 47, "4340": 48, "8051": 49, "458": 50, "5456": 51, "39": 52, "2893": 53, "4898": 54, "8629": 55, "3259": 56, "5514": 57, "3879": 58, "8580": 59, "3664": 60, "6529": 61, "307": 62, "481": 63, "5022": 64, "2843": 65, "3440": 66, "83": 67, "3436": 68, "7078": 69, "311": 70, "405": 71, "2136": 72, "118": 73, "6385": 74, "1963": 75, "730": 76, "3830": 77, "8238": 78, "412": 79, "6437": 80, "587": 81, "89": 82, "6272": 83, "1081": 84, "6848": 85, "254": 86, "4297": 87, "4680": 88, "7113": 89, "27": 90, "374": 91, "7794": 92, "1926": 93, "8123": 94, "3486": 95, "2416": 96, "6818": 97, "8095": 98, "7278": 99, "332": 100, "200": 101, "7800": 102, "6147": 103, "2952": 104, "7190": 105, "1088": 106, "248": 107, "32": 108, "6000": 109, "6531": 110, "909": 111, "1578": 112, "7635": 113, "6476": 114, "5678": 115, "3526": 116, "1867": 117, "446": 118, "8609": 119, "5049": 120, "302": 121, "250": 122, "1743": 123, "4853": 124, "1363": 125, "8797": 126, "4406": 127, "6563": 128, "2159": 129, "2436": 130, "8465": 131, "426": 132, "7178": 133, "4088": 134, "8312": 135, "8098": 136, "3240": 137, "2836": 138, "198": 139, "2384": 140, "7067": 141, "5778": 142, "7402": 143, "5808": 144, "7505": 145, "1841": 146, "5163": 147, "1098": 148, "4397": 149, "3982": 150, "7226": 151, "26": 152, "6019": 153, "2196": 154, "4214": 155, "4160": 156, "8419": 157, "8226": 158, "1183": 159, "1235": 160, "1334": 161, "7447": 162, "40": 163, "8770": 164, "196": 165, "5339": 166, "3214": 167, "6367": 168, "1069": 169, "5703": 170, "201": 171, "4137": 172, "5104": 173, "7859": 174, "8088": 175, "87": 176, "2691": 177, "3607": 178, "7302": 179, "1263": 180, "1553": 181, "4640": 182, "2182": 183, "5322": 184, "2817": 185, "8108": 186, "19": 187, "211": 188, "1624": 189, "150": 190, "4195": 191, "8975": 192, "6081": 193, "1246": 194, "3807": 195, "5867": 196, "8747": 197, "1737": 198, "8063": 199, "5192": 200, "2764": 201, "6209": 202, "2518": 203, "8014": 204, "1992": 205, "1594": 206, "3374": 207, "460": 208, "4267": 209, "163": 210, "289": 211, "403": 212, "3699": 213, "5390": 214, "233": 215, "7511": 216, "2391": 217, "4014": 218, "7312": 219, "4018": 220, "4441": 221, "2289": 222, "7780": 223, "6181": 224, "5652": 225, "911": 226, "7059": 227, "1040": 228, "831": 229, "1116": 230, "4830": 231, "4813": 232, "6415": 233, "7367": 234, "8425": 235, "5561": 236, "2911": 237, "6064": 238, "8324": 239, "2092": 240, "3168": 241, "2910": 242, "4051": 243, "4362": 244, "322": 245, "226": 246}
--------------------------------------------------------------------------------
/preprocessed_data/LibriTTS/stats.json:
--------------------------------------------------------------------------------
1 | {"pitch": [-2.7442070957317246, 11.15693693509074, 165.6901007628454, 60.37813291152693], "energy": [-1.2591148614883423, 10.99685001373291, 41.67014254956833, 33.09479306846138], "mel": [-11.512925148010254, 2.3185698986053467]}
--------------------------------------------------------------------------------
/preprocessor/aishell3.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | from scipy.io import wavfile
6 | from tqdm import tqdm
7 |
8 |
9 | def prepare_align(config):
10 | in_dir = config["path"]["corpus_path"]
11 | out_dir = config["path"]["raw_path"]
12 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
13 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
14 | for dataset in ["train", "test"]:
15 | print("Processing {}ing set...".format(dataset))
16 | with open(os.path.join(in_dir, dataset, "content.txt"), encoding="utf-8") as f:
17 | for line in tqdm(f):
18 | wav_name, text = line.strip("\n").split("\t")
19 | speaker = wav_name[:7]
20 | text = text.split(" ")[1::2]
21 | wav_path = os.path.join(in_dir, dataset, "wav", speaker, wav_name)
22 | if os.path.exists(wav_path):
23 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
24 | wav, _ = librosa.load(wav_path, sampling_rate)
25 | wav = wav / max(abs(wav)) * max_wav_value
26 | wavfile.write(
27 | os.path.join(out_dir, speaker, wav_name),
28 | sampling_rate,
29 | wav.astype(np.int16),
30 | )
31 | with open(
32 | os.path.join(out_dir, speaker, "{}.lab".format(wav_name[:11])),
33 | "w",
34 | ) as f1:
35 | f1.write(" ".join(text))
--------------------------------------------------------------------------------
/preprocessor/libritts.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | from scipy.io import wavfile
6 | from tqdm import tqdm
7 |
8 | from text import _clean_text
9 |
10 |
11 | def prepare_align(config):
12 | in_dir = config["path"]["corpus_path"]
13 | out_dir = config["path"]["raw_path"]
14 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
15 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
16 | cleaners = config["preprocessing"]["text"]["text_cleaners"]
17 | for speaker in tqdm(os.listdir(in_dir)):
18 | for chapter in os.listdir(os.path.join(in_dir, speaker)):
19 | for file_name in os.listdir(os.path.join(in_dir, speaker, chapter)):
20 | if file_name[-4:] != ".wav":
21 | continue
22 | base_name = file_name[:-4]
23 | text_path = os.path.join(
24 | in_dir, speaker, chapter, "{}.normalized.txt".format(base_name)
25 | )
26 | wav_path = os.path.join(
27 | in_dir, speaker, chapter, "{}.wav".format(base_name)
28 | )
29 | with open(text_path) as f:
30 | text = f.readline().strip("\n")
31 | text = _clean_text(text, cleaners)
32 |
33 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
34 | wav, _ = librosa.load(wav_path, sampling_rate)
35 | wav = wav / max(abs(wav)) * max_wav_value
36 | wavfile.write(
37 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
38 | sampling_rate,
39 | wav.astype(np.int16),
40 | )
41 | with open(
42 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
43 | "w",
44 | ) as f1:
45 | f1.write(text)
--------------------------------------------------------------------------------
/preprocessor/ljspeech.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | from scipy.io import wavfile
6 | from tqdm import tqdm
7 |
8 | from text import _clean_text
9 |
10 |
11 | def prepare_align(config):
12 | in_dir = config["path"]["corpus_path"]
13 | out_dir = config["path"]["raw_path"]
14 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
15 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
16 | cleaners = config["preprocessing"]["text"]["text_cleaners"]
17 | speaker = "LJSpeech"
18 | with open(os.path.join(in_dir, "metadata.csv"), encoding="utf-8") as f:
19 | for line in tqdm(f):
20 | parts = line.strip().split("|")
21 | base_name = parts[0]
22 | text = parts[2]
23 | text = _clean_text(text, cleaners)
24 |
25 | wav_path = os.path.join(in_dir, "wavs", "{}.wav".format(base_name))
26 | if os.path.exists(wav_path):
27 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
28 | wav, _ = librosa.load(wav_path, sampling_rate)
29 | wav = wav / max(abs(wav)) * max_wav_value
30 | wavfile.write(
31 | os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
32 | sampling_rate,
33 | wav.astype(np.int16),
34 | )
35 | with open(
36 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
37 | "w",
38 | ) as f1:
39 | f1.write(text)
--------------------------------------------------------------------------------
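
The three prepare_align implementations above (aishell3, libritts, ljspeech) share one signature and differ only in corpus layout; each resamples and peak-normalizes the audio and writes matching .wav/.lab pairs under raw_path. A minimal driving sketch, assuming a preprocess.yaml at the path below exposes the keys these functions read:

import yaml
from preprocessor import libritts  # or aishell3 / ljspeech

config = yaml.load(open("config/LibriTTS/preprocess.yaml", "r"), Loader=yaml.FullLoader)
libritts.prepare_align(config)  # writes <raw_path>/<speaker>/<basename>.wav and .lab
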
/preprocessor/preprocessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 |
5 | import tgt
6 | import librosa
7 | import numpy as np
8 | import pyworld as pw
9 | from scipy.interpolate import interp1d
10 | from sklearn.preprocessing import StandardScaler
11 | from tqdm import tqdm
12 |
13 | import audio as Audio
14 |
15 | random.seed(1234)
16 |
17 |
18 | class Preprocessor:
19 | def __init__(self, config):
20 | self.config = config
21 | self.in_dir = config["path"]["raw_path"]
22 | self.out_dir = config["path"]["preprocessed_path"]
23 | self.val_size = config["preprocessing"]["val_size"]
24 | self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
25 | self.hop_length = config["preprocessing"]["stft"]["hop_length"]
26 |
27 | assert config["preprocessing"]["pitch"]["feature"] in [
28 | "phoneme_level",
29 | "frame_level",
30 | ]
31 | assert config["preprocessing"]["energy"]["feature"] in [
32 | "phoneme_level",
33 | "frame_level",
34 | ]
35 | self.pitch_phoneme_averaging = (
36 | config["preprocessing"]["pitch"]["feature"] == "phoneme_level"
37 | )
38 | self.energy_phoneme_averaging = (
39 | config["preprocessing"]["energy"]["feature"] == "phoneme_level"
40 | )
41 |
42 | self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"]
43 | self.energy_normalization = config["preprocessing"]["energy"]["normalization"]
44 |
45 | self.STFT = Audio.stft.TacotronSTFT(
46 | config["preprocessing"]["stft"]["filter_length"],
47 | config["preprocessing"]["stft"]["hop_length"],
48 | config["preprocessing"]["stft"]["win_length"],
49 | config["preprocessing"]["mel"]["n_mel_channels"],
50 | config["preprocessing"]["audio"]["sampling_rate"],
51 | config["preprocessing"]["mel"]["mel_fmin"],
52 | config["preprocessing"]["mel"]["mel_fmax"],
53 | )
54 |
55 | def build_from_path(self):
56 | os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True)
57 | os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True)
58 | os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True)
59 | os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True)
60 |
61 | print("Processing Data ...")
62 | out = list()
63 | n_frames = 0
64 | mel_min = float('inf')
65 | mel_max = -float('inf')
66 | pitch_scaler = StandardScaler()
67 | energy_scaler = StandardScaler()
68 |
69 | # Compute pitch, energy, duration, and mel-spectrogram
70 | speakers = {}
71 | for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))):
72 | speakers[speaker] = i
73 | for wav_name in tqdm(os.listdir(os.path.join(self.in_dir, speaker))):
74 | if ".wav" not in wav_name:
75 | continue
76 |
77 | basename = wav_name.split(".")[0]
78 | chapter = basename.split("_")[1]
79 | tg_path = os.path.join(
80 | self.out_dir, "TextGrid", speaker, chapter, "{}.TextGrid".format(basename)
81 | )
82 | if os.path.exists(tg_path):
83 | ret = self.process_utterance(speaker, chapter, basename)
84 | if ret is None:
85 | continue
86 | else:
87 | info, pitch, energy, n, m_min, m_max = ret
88 | out.append(info)
89 |
90 | if len(pitch) > 0:
91 | pitch_scaler.partial_fit(pitch.reshape((-1, 1)))
92 | if len(energy) > 0:
93 | energy_scaler.partial_fit(energy.reshape((-1, 1)))
94 | if mel_min > m_min:
95 | mel_min = m_min
96 | if mel_max < m_max:
97 | mel_max = m_max
98 |
99 | n_frames += n
100 |
101 |         print("Computing statistics ...")
102 | # Perform normalization if necessary
103 | if self.pitch_normalization:
104 | pitch_mean = pitch_scaler.mean_[0]
105 | pitch_std = pitch_scaler.scale_[0]
106 | else:
107 | # A numerical trick to avoid normalization...
108 | pitch_mean = 0
109 | pitch_std = 1
110 | if self.energy_normalization:
111 | energy_mean = energy_scaler.mean_[0]
112 | energy_std = energy_scaler.scale_[0]
113 | else:
114 | energy_mean = 0
115 | energy_std = 1
116 |
117 | pitch_min, pitch_max = self.normalize(
118 | os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std
119 | )
120 | energy_min, energy_max = self.normalize(
121 | os.path.join(self.out_dir, "energy"), energy_mean, energy_std
122 | )
123 |
124 | # Save files
125 | with open(os.path.join(self.out_dir, "speakers.json"), "w") as f:
126 | f.write(json.dumps(speakers))
127 |
128 | with open(os.path.join(self.out_dir, "stats.json"), "w") as f:
129 | stats = {
130 | "pitch": [
131 | float(pitch_min),
132 | float(pitch_max),
133 | float(pitch_mean),
134 | float(pitch_std),
135 | ],
136 | "energy": [
137 | float(energy_min),
138 | float(energy_max),
139 | float(energy_mean),
140 | float(energy_std),
141 | ],
142 | "mel": [
143 | float(mel_min),
144 | float(mel_max),
145 | ],
146 | }
147 | f.write(json.dumps(stats))
148 |
149 | print(
150 | "Total time: {} hours".format(
151 | n_frames * self.hop_length / self.sampling_rate / 3600
152 | )
153 | )
154 |
155 | random.shuffle(out)
156 | out = [r for r in out if r is not None]
157 |
158 | # Write metadata
159 | with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f:
160 | for m in out[self.val_size :]:
161 | f.write(m + "\n")
162 | with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f:
163 | for m in out[: self.val_size]:
164 | f.write(m + "\n")
165 |
166 | return out
167 |
168 | def process_utterance(self, speaker, chapter, basename):
169 | wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename))
170 | text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename))
171 | tg_path = os.path.join(
172 | self.out_dir, "TextGrid", speaker, chapter, "{}.TextGrid".format(basename)
173 | )
174 |
175 | # Get alignments
176 | textgrid = tgt.io.read_textgrid(tg_path)
177 | phone, duration, start, end = self.get_alignment(
178 | textgrid.get_tier_by_name("phones")
179 | )
180 | text = "{" + " ".join(phone) + "}"
181 | if start >= end:
182 | return None
183 |
184 | # Read and trim wav files
185 | wav, _ = librosa.load(wav_path)
186 | wav = wav[
187 | int(self.sampling_rate * start) : int(self.sampling_rate * end)
188 | ].astype(np.float32)
189 |
190 | # Read raw text
191 | with open(text_path, "r") as f:
192 | raw_text = f.readline().strip("\n")
193 |
194 | # Compute fundamental frequency
195 | pitch, t = pw.dio(
196 | wav.astype(np.float64),
197 | self.sampling_rate,
198 | frame_period=self.hop_length / self.sampling_rate * 1000,
199 | )
200 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate)
201 |
202 | pitch = pitch[: sum(duration)]
203 | if np.sum(pitch != 0) <= 1:
204 | return None
205 |
206 | # Compute mel-scale spectrogram and energy
207 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT)
208 | mel_spectrogram = mel_spectrogram[:, : sum(duration)]
209 | energy = energy[: sum(duration)]
210 |
211 | if self.pitch_phoneme_averaging:
212 | # perform linear interpolation
213 | nonzero_ids = np.where(pitch != 0)[0]
214 | interp_fn = interp1d(
215 | nonzero_ids,
216 | pitch[nonzero_ids],
217 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
218 | bounds_error=False,
219 | )
220 | pitch = interp_fn(np.arange(0, len(pitch)))
221 |
222 | # Phoneme-level average
223 | pos = 0
224 | for i, d in enumerate(duration):
225 | if d > 0:
226 | pitch[i] = np.mean(pitch[pos : pos + d])
227 | else:
228 | pitch[i] = 0
229 | pos += d
230 | pitch = pitch[: len(duration)]
231 |
232 | if self.energy_phoneme_averaging:
233 | # Phoneme-level average
234 | pos = 0
235 | for i, d in enumerate(duration):
236 | if d > 0:
237 | energy[i] = np.mean(energy[pos : pos + d])
238 | else:
239 | energy[i] = 0
240 | pos += d
241 | energy = energy[: len(duration)]
242 |
243 | # Save files
244 | dur_filename = "{}-duration-{}.npy".format(speaker, basename)
245 | np.save(os.path.join(self.out_dir, "duration", dur_filename), duration)
246 |
247 | pitch_filename = "{}-pitch-{}.npy".format(speaker, basename)
248 | np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch)
249 |
250 | energy_filename = "{}-energy-{}.npy".format(speaker, basename)
251 | np.save(os.path.join(self.out_dir, "energy", energy_filename), energy)
252 |
253 | mel_filename = "{}-mel-{}.npy".format(speaker, basename)
254 | np.save(
255 | os.path.join(self.out_dir, "mel", mel_filename),
256 | mel_spectrogram.T,
257 | )
258 |
259 | return (
260 | "|".join([basename, speaker, text, raw_text]),
261 | self.remove_outlier(pitch),
262 | self.remove_outlier(energy),
263 | mel_spectrogram.shape[1],
264 | np.min(mel_spectrogram),
265 | np.max(mel_spectrogram),
266 | )
267 |
268 | def get_alignment(self, tier):
269 | sil_phones = ["sil", "sp", "spn"]
270 |
271 | phones = []
272 | durations = []
273 | start_time = 0
274 | end_time = 0
275 | end_idx = 0
276 | for t in tier._objects:
277 | s, e, p = t.start_time, t.end_time, t.text
278 |
279 | # Trim leading silences
280 | if phones == []:
281 | if p in sil_phones:
282 | continue
283 | else:
284 | start_time = s
285 |
286 | if p not in sil_phones:
287 | # For ordinary phones
288 | phones.append(p)
289 | end_time = e
290 | end_idx = len(phones)
291 | else:
292 | # For silent phones
293 | phones.append(p)
294 |
295 | durations.append(
296 | int(
297 | np.round(e * self.sampling_rate / self.hop_length)
298 | - np.round(s * self.sampling_rate / self.hop_length)
299 | )
300 | )
301 |
302 |         # Trim trailing silences
303 | phones = phones[:end_idx]
304 | durations = durations[:end_idx]
305 |
306 | return phones, durations, start_time, end_time
307 |
308 | def remove_outlier(self, values):
309 | values = np.array(values)
310 | p25 = np.percentile(values, 25)
311 | p75 = np.percentile(values, 75)
312 | lower = p25 - 1.5 * (p75 - p25)
313 | upper = p75 + 1.5 * (p75 - p25)
314 | normal_indices = np.logical_and(values > lower, values < upper)
315 |
316 | return values[normal_indices]
317 |
318 | def normalize(self, in_dir, mean, std):
319 | max_value = np.finfo(np.float64).min
320 | min_value = np.finfo(np.float64).max
321 | for filename in os.listdir(in_dir):
322 | filename = os.path.join(in_dir, filename)
323 | values = (np.load(filename) - mean) / std
324 | np.save(filename, values)
325 |
326 | max_value = max(max_value, max(values))
327 | min_value = min(min_value, min(values))
328 |
329 | return min_value, max_value
330 |
--------------------------------------------------------------------------------
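
Preprocessor.build_from_path expects Montreal Forced Aligner TextGrids under <preprocessed_path>/TextGrid/<speaker>/<chapter>/ and converts each utterance into per-phoneme durations (in frames, round(end*sr/hop) - round(start*sr/hop) as computed in get_alignment), pitch, energy, and mel features. A minimal driving sketch, assuming the config path below; preprocess.py presumably wires it up the same way:

import yaml
from preprocessor.preprocessor import Preprocessor

config = yaml.load(open("config/LibriTTS/preprocess.yaml", "r"), Loader=yaml.FullLoader)
metadata = Preprocessor(config).build_from_path()
# Writes duration/pitch/energy/mel .npy files plus speakers.json, stats.json,
# train.txt and val.txt under preprocessed_path.
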
/requirements.txt:
--------------------------------------------------------------------------------
1 | g2p-en==2.1.0
2 | inflect==4.1.0
3 | librosa==0.7.2
4 | matplotlib==3.4.2
5 | numba==0.48
6 | numpy==1.19.0
7 | pypinyin==0.39.0
8 | pyworld==0.3.0
9 | PyYAML==5.4.1
10 | scikit-learn==0.23.2
11 | scipy==1.6.3
12 | soundfile==0.10.3.post1
13 | tensorboard==2.2.2
14 | tgt==1.4.4
15 | torch==1.8.1
16 | tqdm==4.46.1
17 | unidecode==1.1.1
--------------------------------------------------------------------------------
/synthesize.py:
--------------------------------------------------------------------------------
1 | import re
2 | import argparse
3 | from string import punctuation
4 | import torch
5 | import yaml
6 | import numpy as np
7 | import os
8 | import json
9 |
10 | import librosa
11 | import pyworld as pw
12 | import audio as Audio
13 |
14 | from torch.utils.data import DataLoader
15 | from g2p_en import G2p
16 | from pypinyin import pinyin, Style
17 |
18 | from utils.model import get_model, get_vocoder
19 | from utils.tools import to_device, synth_samples
20 | from dataset import BatchInferenceDataset
21 | from text import text_to_sequence
22 |
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 |
25 |
26 | def read_lexicon(lex_path):
27 | lexicon = {}
28 | with open(lex_path) as f:
29 | for line in f:
30 | temp = re.split(r"\s+", line.strip("\n"))
31 | word = temp[0]
32 | phones = temp[1:]
33 | if word.lower() not in lexicon:
34 | lexicon[word.lower()] = phones
35 | return lexicon
36 |
37 |
38 | def get_audio(preprocess_config, wav_path):
39 |
40 | hop_length = preprocess_config["preprocessing"]["stft"]["hop_length"]
41 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
42 | STFT = Audio.stft.TacotronSTFT(
43 | preprocess_config["preprocessing"]["stft"]["filter_length"],
44 | hop_length,
45 | preprocess_config["preprocessing"]["stft"]["win_length"],
46 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
47 | sampling_rate,
48 | preprocess_config["preprocessing"]["mel"]["mel_fmin"],
49 | preprocess_config["preprocessing"]["mel"]["mel_fmax"],
50 | )
51 | with open(
52 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
53 | ) as f:
54 | stats = json.load(f)
55 | stats = stats["pitch"][2:] + stats["energy"][2:]
56 | pitch_mean, pitch_std, energy_mean, energy_std = stats
57 |
58 | # Read and trim wav files
59 | wav, _ = librosa.load(wav_path)
60 |
61 | # Compute fundamental frequency
62 | pitch, t = pw.dio(
63 | wav.astype(np.float64),
64 | sampling_rate,
65 | frame_period=hop_length / sampling_rate * 1000,
66 | )
67 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, sampling_rate)
68 |
69 | # Compute mel-scale spectrogram and energy
70 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav.astype(np.float32), STFT)
71 |
72 | # Normalize Variance
73 | pitch = (pitch - pitch_mean) / pitch_std
74 | energy = (energy - energy_mean) / energy_std
75 |
76 | mels = mel_spectrogram.T[None]
77 | mel_lens = np.array([len(mels[0])])
78 |
79 | mel_spectrogram = mel_spectrogram.astype(np.float32)
80 | energy = energy.astype(np.float32)
81 |
82 | return mels, mel_lens, (mel_spectrogram, pitch, energy)
83 |
84 |
85 | def preprocess_english(text, preprocess_config):
86 | text = text.rstrip(punctuation)
87 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
88 |
89 | g2p = G2p()
90 | phones = []
91 | words = re.split(r"([,;.\-\?\!\s+])", text)
92 | for w in words:
93 | if w.lower() in lexicon:
94 | phones += lexicon[w.lower()]
95 | else:
96 | phones += list(filter(lambda p: p != " ", g2p(w)))
97 | phones = "{" + "}{".join(phones) + "}"
98 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones)
99 | phones = phones.replace("}{", " ")
100 |
101 | print("Raw Text Sequence: {}".format(text))
102 | print("Phoneme Sequence: {}".format(phones))
103 | sequence = np.array(
104 | text_to_sequence(
105 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
106 | )
107 | )
108 |
109 | return np.array(sequence)
110 |
111 |
112 | def preprocess_mandarin(text, preprocess_config):
113 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
114 |
115 | phones = []
116 | pinyins = [
117 | p[0]
118 | for p in pinyin(
119 | text, style=Style.TONE3, strict=False, neutral_tone_with_five=True
120 | )
121 | ]
122 | for p in pinyins:
123 | if p in lexicon:
124 | phones += lexicon[p]
125 | else:
126 | phones.append("sp")
127 |
128 | phones = "{" + " ".join(phones) + "}"
129 | print("Raw Text Sequence: {}".format(text))
130 | print("Phoneme Sequence: {}".format(phones))
131 | sequence = np.array(
132 | text_to_sequence(
133 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
134 | )
135 | )
136 |
137 | return np.array(sequence)
138 |
139 |
140 | def synthesize(model, step, configs, vocoder, batchs, control_values):
141 | preprocess_config, model_config, train_config = configs
142 | pitch_control, energy_control, duration_control = control_values
143 |
144 | for batch in batchs:
145 | batch = to_device(batch, device)
146 | with torch.no_grad():
147 | # Forward
148 | output = model(
149 | *(batch[2:-1]),
150 | p_control=pitch_control,
151 | e_control=energy_control,
152 | d_control=duration_control
153 | )
154 | synth_samples(
155 | batch,
156 | output,
157 | vocoder,
158 | model_config,
159 | preprocess_config,
160 | train_config["path"]["result_path"],
161 | )
162 |
163 |
164 | if __name__ == "__main__":
165 |
166 | parser = argparse.ArgumentParser()
167 | parser.add_argument("--restore_step", type=int, required=True)
168 | parser.add_argument(
169 | "--mode",
170 | type=str,
171 | choices=["batch", "single"],
172 | required=True,
173 | help="Synthesize a whole dataset or a single sentence",
174 | )
175 | parser.add_argument(
176 | "--source",
177 | type=str,
178 | default=None,
179 | help="path to a source file with format like train.txt and val.txt, for batch mode only",
180 | )
181 | parser.add_argument(
182 | "--text",
183 | type=str,
184 | default=None,
185 | help="raw text to synthesize, for single-sentence mode only",
186 | )
187 | parser.add_argument(
188 | "--ref_audio",
189 | type=str,
190 | default=None,
191 | help="reference audio path to extract the speech style, for single-sentence mode only",
192 | )
193 | parser.add_argument(
194 | "-p",
195 | "--preprocess_config",
196 | type=str,
197 | required=True,
198 | help="path to preprocess.yaml",
199 | )
200 | parser.add_argument(
201 | "-m", "--model_config", type=str, required=True, help="path to model.yaml"
202 | )
203 | parser.add_argument(
204 | "-t", "--train_config", type=str, required=True, help="path to train.yaml"
205 | )
206 | parser.add_argument(
207 | "--pitch_control",
208 | type=float,
209 | default=1.0,
210 | help="control the pitch of the whole utterance, larger value for higher pitch",
211 | )
212 | parser.add_argument(
213 | "--energy_control",
214 | type=float,
215 | default=1.0,
216 | help="control the energy of the whole utterance, larger value for larger volume",
217 | )
218 | parser.add_argument(
219 | "--duration_control",
220 | type=float,
221 | default=1.0,
222 | help="control the speed of the whole utterance, larger value for slower speaking rate",
223 | )
224 | args = parser.parse_args()
225 |
226 | # Check source texts
227 | if args.mode == "batch":
228 | assert args.source is not None and args.text is None
229 | if args.mode == "single":
230 | assert args.source is None and args.text is not None
231 |
232 | # Read Config
233 | preprocess_config = yaml.load(
234 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader
235 | )
236 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
237 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
238 | configs = (preprocess_config, model_config, train_config)
239 |
240 | # Get model
241 | model = get_model(args, configs, device, train=False)
242 |
243 | # Load vocoder
244 | vocoder = get_vocoder(model_config, device)
245 |
246 | # Preprocess texts
247 | if args.mode == "batch":
248 | # Get dataset
249 | dataset = BatchInferenceDataset(args.source, preprocess_config)
250 | batchs = DataLoader(
251 | dataset,
252 | batch_size=8,
253 | collate_fn=dataset.collate_fn,
254 | )
255 | if args.mode == "single":
256 | ids = raw_texts = [args.text[:100]]
257 | if preprocess_config["preprocessing"]["text"]["language"] == "en":
258 | texts = np.array([preprocess_english(args.text, preprocess_config)])
259 | elif preprocess_config["preprocessing"]["text"]["language"] == "zh":
260 | texts = np.array([preprocess_mandarin(args.text, preprocess_config)])
261 | text_lens = np.array([len(texts[0])])
262 | mels, mel_lens, ref_info = get_audio(preprocess_config, args.ref_audio)
263 | batchs = [(["_".join([os.path.basename(args.ref_audio).strip(".wav"), id]) for id in ids], \
264 | raw_texts, None, texts, text_lens, max(text_lens), mels, mel_lens, max(mel_lens), [ref_info])]
265 |
266 | control_values = args.pitch_control, args.energy_control, args.duration_control
267 |
268 | synthesize(model, args.restore_step, configs, vocoder, batchs, control_values)
269 |
--------------------------------------------------------------------------------
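
In single-sentence mode, preprocess_english looks each word up in the lexicon, falls back to g2p-en, and encodes the resulting ARPAbet string with text_to_sequence. A minimal sketch of that front end alone; the config dict is a hypothetical stand-in covering only the keys preprocess_english reads, and it assumes the LibriSpeech lexicon shipped under lexicon/:

from synthesize import preprocess_english

preprocess_config = {
    "path": {"lexicon_path": "lexicon/librispeech-lexicon.txt"},
    "preprocessing": {"text": {"text_cleaners": ["english_cleaners"]}},
}
sequence = preprocess_english("Turn left on Houston Street.", preprocess_config)
print(sequence.shape)  # 1-D array of integer phoneme IDs
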
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | import re
3 | from text import cleaners
4 | from text.symbols import symbols
5 |
6 |
7 | # Mappings from symbol to numeric ID and vice versa:
8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
10 |
11 | # Regular expression matching text enclosed in curly braces:
12 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
13 |
14 |
15 | def text_to_sequence(text, cleaner_names):
16 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
17 |
18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
20 |
21 | Args:
22 | text: string to convert to a sequence
23 | cleaner_names: names of the cleaner functions to run the text through
24 |
25 | Returns:
26 | List of integers corresponding to the symbols in the text
27 | """
28 | sequence = []
29 |
30 | # Check for curly braces and treat their contents as ARPAbet:
31 | while len(text):
32 | m = _curly_re.match(text)
33 |
34 | if not m:
35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
36 | break
37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
38 | sequence += _arpabet_to_sequence(m.group(2))
39 | text = m.group(3)
40 |
41 | return sequence
42 |
43 |
44 | def sequence_to_text(sequence):
45 | """Converts a sequence of IDs back to a string"""
46 | result = ""
47 | for symbol_id in sequence:
48 | if symbol_id in _id_to_symbol:
49 | s = _id_to_symbol[symbol_id]
50 | # Enclose ARPAbet back in curly braces:
51 | if len(s) > 1 and s[0] == "@":
52 | s = "{%s}" % s[1:]
53 | result += s
54 | return result.replace("}{", " ")
55 |
56 |
57 | def _clean_text(text, cleaner_names):
58 | for name in cleaner_names:
59 | cleaner = getattr(cleaners, name)
60 | if not cleaner:
61 | raise Exception("Unknown cleaner: %s" % name)
62 | text = cleaner(text)
63 | return text
64 |
65 |
66 | def _symbols_to_sequence(symbols):
67 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
68 |
69 |
70 | def _arpabet_to_sequence(text):
71 | return _symbols_to_sequence(["@" + s for s in text.split()])
72 |
73 |
74 | def _should_keep_symbol(s):
75 | return s in _symbol_to_id and s != "_" and s != "~"
76 |
--------------------------------------------------------------------------------
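
A small usage sketch for the two conversions above; the integer IDs depend on the symbol table in text/symbols.py, so only the round trip is shown:

from text import text_to_sequence, sequence_to_text

ids = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
print(ids)                    # list of integer symbol IDs
print(sequence_to_text(ids))  # cleaned text with the ARPAbet chunk restored in braces
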
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 | import re
18 | from unidecode import unidecode
19 | from .numbers import normalize_numbers
20 | _whitespace_re = re.compile(r'\s+')
21 |
22 | # List of (regular expression, replacement) pairs for abbreviations:
23 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
24 | ('mrs', 'misess'),
25 | ('mr', 'mister'),
26 | ('dr', 'doctor'),
27 | ('st', 'saint'),
28 | ('co', 'company'),
29 | ('jr', 'junior'),
30 | ('maj', 'major'),
31 | ('gen', 'general'),
32 | ('drs', 'doctors'),
33 | ('rev', 'reverend'),
34 | ('lt', 'lieutenant'),
35 | ('hon', 'honorable'),
36 | ('sgt', 'sergeant'),
37 | ('capt', 'captain'),
38 | ('esq', 'esquire'),
39 | ('ltd', 'limited'),
40 | ('col', 'colonel'),
41 | ('ft', 'fort'),
42 | ]]
43 |
44 |
45 | def expand_abbreviations(text):
46 | for regex, replacement in _abbreviations:
47 | text = re.sub(regex, replacement, text)
48 | return text
49 |
50 |
51 | def expand_numbers(text):
52 | return normalize_numbers(text)
53 |
54 |
55 | def lowercase(text):
56 | return text.lower()
57 |
58 |
59 | def collapse_whitespace(text):
60 | return re.sub(_whitespace_re, ' ', text)
61 |
62 |
63 | def convert_to_ascii(text):
64 | return unidecode(text)
65 |
66 |
67 | def basic_cleaners(text):
68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
69 | text = lowercase(text)
70 | text = collapse_whitespace(text)
71 | return text
72 |
73 |
74 | def transliteration_cleaners(text):
75 | '''Pipeline for non-English text that transliterates to ASCII.'''
76 | text = convert_to_ascii(text)
77 | text = lowercase(text)
78 | text = collapse_whitespace(text)
79 | return text
80 |
81 |
82 | def english_cleaners(text):
83 | '''Pipeline for English text, including number and abbreviation expansion.'''
84 | text = convert_to_ascii(text)
85 | text = lowercase(text)
86 | text = expand_numbers(text)
87 | text = expand_abbreviations(text)
88 | text = collapse_whitespace(text)
89 | return text
90 |
--------------------------------------------------------------------------------
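
A short sketch of english_cleaners, the pipeline used for English text here; numbers are expanded before abbreviations, and the expected output is shown (approximately) in the comment:

from text.cleaners import english_cleaners

print(english_cleaners("Dr. Smith paid $16 on June 3rd."))
# -> "doctor smith paid sixteen dollars on june third."
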
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | "AA",
8 | "AA0",
9 | "AA1",
10 | "AA2",
11 | "AE",
12 | "AE0",
13 | "AE1",
14 | "AE2",
15 | "AH",
16 | "AH0",
17 | "AH1",
18 | "AH2",
19 | "AO",
20 | "AO0",
21 | "AO1",
22 | "AO2",
23 | "AW",
24 | "AW0",
25 | "AW1",
26 | "AW2",
27 | "AY",
28 | "AY0",
29 | "AY1",
30 | "AY2",
31 | "B",
32 | "CH",
33 | "D",
34 | "DH",
35 | "EH",
36 | "EH0",
37 | "EH1",
38 | "EH2",
39 | "ER",
40 | "ER0",
41 | "ER1",
42 | "ER2",
43 | "EY",
44 | "EY0",
45 | "EY1",
46 | "EY2",
47 | "F",
48 | "G",
49 | "HH",
50 | "IH",
51 | "IH0",
52 | "IH1",
53 | "IH2",
54 | "IY",
55 | "IY0",
56 | "IY1",
57 | "IY2",
58 | "JH",
59 | "K",
60 | "L",
61 | "M",
62 | "N",
63 | "NG",
64 | "OW",
65 | "OW0",
66 | "OW1",
67 | "OW2",
68 | "OY",
69 | "OY0",
70 | "OY1",
71 | "OY2",
72 | "P",
73 | "R",
74 | "S",
75 | "SH",
76 | "T",
77 | "TH",
78 | "UH",
79 | "UH0",
80 | "UH1",
81 | "UH2",
82 | "UW",
83 | "UW0",
84 | "UW1",
85 | "UW2",
86 | "V",
87 | "W",
88 | "Y",
89 | "Z",
90 | "ZH",
91 | ]
92 |
93 | _valid_symbol_set = set(valid_symbols)
94 |
95 |
96 | class CMUDict:
97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
98 |
99 | def __init__(self, file_or_path, keep_ambiguous=True):
100 | if isinstance(file_or_path, str):
101 | with open(file_or_path, encoding="latin-1") as f:
102 | entries = _parse_cmudict(f)
103 | else:
104 | entries = _parse_cmudict(file_or_path)
105 | if not keep_ambiguous:
106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
107 | self._entries = entries
108 |
109 | def __len__(self):
110 | return len(self._entries)
111 |
112 | def lookup(self, word):
113 | """Returns list of ARPAbet pronunciations of the given word."""
114 | return self._entries.get(word.upper())
115 |
116 |
117 | _alt_re = re.compile(r"\([0-9]+\)")
118 |
119 |
120 | def _parse_cmudict(file):
121 | cmudict = {}
122 | for line in file:
123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
124 | parts = line.split(" ")
125 | word = re.sub(_alt_re, "", parts[0])
126 | pronunciation = _get_pronunciation(parts[1])
127 | if pronunciation:
128 | if word in cmudict:
129 | cmudict[word].append(pronunciation)
130 | else:
131 | cmudict[word] = [pronunciation]
132 | return cmudict
133 |
134 |
135 | def _get_pronunciation(s):
136 | parts = s.strip().split(" ")
137 | for part in parts:
138 | if part not in _valid_symbol_set:
139 | return None
140 | return " ".join(parts)
141 |
--------------------------------------------------------------------------------
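
CMUDict is not called directly by the synthesis scripts (they use g2p-en and the lexicon files); its valid_symbols list feeds text/symbols.py. A self-contained sketch with a small in-memory dictionary in CMUdict format (word, two spaces, ARPAbet pronunciation):

import io
from text.cmudict import CMUDict

entries = io.StringIO("HELLO  HH AH0 L OW1\nWORLD  W ER1 L D\n")  # hypothetical entries
cmu = CMUDict(entries)
print(len(cmu))             # 2
print(cmu.lookup("hello"))  # ['HH AH0 L OW1']
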
/text/numbers.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import inflect
4 | import re
5 |
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13 | _number_re = re.compile(r"[0-9]+")
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(",", "")
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace(".", " point ")
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split(".")
27 | if len(parts) > 2:
28 | return match + " dollars" # Unexpected format
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = "dollar" if dollars == 1 else "dollars"
33 | cent_unit = "cent" if cents == 1 else "cents"
34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
35 | elif dollars:
36 | dollar_unit = "dollar" if dollars == 1 else "dollars"
37 | return "%s %s" % (dollars, dollar_unit)
38 | elif cents:
39 | cent_unit = "cent" if cents == 1 else "cents"
40 | return "%s %s" % (cents, cent_unit)
41 | else:
42 | return "zero dollars"
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return "two thousand"
54 | elif num > 2000 and num < 2010:
55 | return "two thousand " + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + " hundred"
58 | else:
59 | return _inflect.number_to_words(
60 | num, andword="", zero="oh", group=2
61 | ).replace(", ", " ")
62 | else:
63 | return _inflect.number_to_words(num, andword="")
64 |
65 |
66 | def normalize_numbers(text):
67 | text = re.sub(_comma_number_re, _remove_commas, text)
68 | text = re.sub(_pounds_re, r"\1 pounds", text)
69 | text = re.sub(_dollars_re, _expand_dollars, text)
70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
71 | text = re.sub(_ordinal_re, _expand_ordinal, text)
72 | text = re.sub(_number_re, _expand_number, text)
73 | return text
74 |
--------------------------------------------------------------------------------
/text/pinyin.py:
--------------------------------------------------------------------------------
1 | initials = [
2 | "b",
3 | "c",
4 | "ch",
5 | "d",
6 | "f",
7 | "g",
8 | "h",
9 | "j",
10 | "k",
11 | "l",
12 | "m",
13 | "n",
14 | "p",
15 | "q",
16 | "r",
17 | "s",
18 | "sh",
19 | "t",
20 | "w",
21 | "x",
22 | "y",
23 | "z",
24 | "zh",
25 | ]
26 | finals = [
27 | "a1",
28 | "a2",
29 | "a3",
30 | "a4",
31 | "a5",
32 | "ai1",
33 | "ai2",
34 | "ai3",
35 | "ai4",
36 | "ai5",
37 | "an1",
38 | "an2",
39 | "an3",
40 | "an4",
41 | "an5",
42 | "ang1",
43 | "ang2",
44 | "ang3",
45 | "ang4",
46 | "ang5",
47 | "ao1",
48 | "ao2",
49 | "ao3",
50 | "ao4",
51 | "ao5",
52 | "e1",
53 | "e2",
54 | "e3",
55 | "e4",
56 | "e5",
57 | "ei1",
58 | "ei2",
59 | "ei3",
60 | "ei4",
61 | "ei5",
62 | "en1",
63 | "en2",
64 | "en3",
65 | "en4",
66 | "en5",
67 | "eng1",
68 | "eng2",
69 | "eng3",
70 | "eng4",
71 | "eng5",
72 | "er1",
73 | "er2",
74 | "er3",
75 | "er4",
76 | "er5",
77 | "i1",
78 | "i2",
79 | "i3",
80 | "i4",
81 | "i5",
82 | "ia1",
83 | "ia2",
84 | "ia3",
85 | "ia4",
86 | "ia5",
87 | "ian1",
88 | "ian2",
89 | "ian3",
90 | "ian4",
91 | "ian5",
92 | "iang1",
93 | "iang2",
94 | "iang3",
95 | "iang4",
96 | "iang5",
97 | "iao1",
98 | "iao2",
99 | "iao3",
100 | "iao4",
101 | "iao5",
102 | "ie1",
103 | "ie2",
104 | "ie3",
105 | "ie4",
106 | "ie5",
107 | "ii1",
108 | "ii2",
109 | "ii3",
110 | "ii4",
111 | "ii5",
112 | "iii1",
113 | "iii2",
114 | "iii3",
115 | "iii4",
116 | "iii5",
117 | "in1",
118 | "in2",
119 | "in3",
120 | "in4",
121 | "in5",
122 | "ing1",
123 | "ing2",
124 | "ing3",
125 | "ing4",
126 | "ing5",
127 | "iong1",
128 | "iong2",
129 | "iong3",
130 | "iong4",
131 | "iong5",
132 | "iou1",
133 | "iou2",
134 | "iou3",
135 | "iou4",
136 | "iou5",
137 | "o1",
138 | "o2",
139 | "o3",
140 | "o4",
141 | "o5",
142 | "ong1",
143 | "ong2",
144 | "ong3",
145 | "ong4",
146 | "ong5",
147 | "ou1",
148 | "ou2",
149 | "ou3",
150 | "ou4",
151 | "ou5",
152 | "u1",
153 | "u2",
154 | "u3",
155 | "u4",
156 | "u5",
157 | "ua1",
158 | "ua2",
159 | "ua3",
160 | "ua4",
161 | "ua5",
162 | "uai1",
163 | "uai2",
164 | "uai3",
165 | "uai4",
166 | "uai5",
167 | "uan1",
168 | "uan2",
169 | "uan3",
170 | "uan4",
171 | "uan5",
172 | "uang1",
173 | "uang2",
174 | "uang3",
175 | "uang4",
176 | "uang5",
177 | "uei1",
178 | "uei2",
179 | "uei3",
180 | "uei4",
181 | "uei5",
182 | "uen1",
183 | "uen2",
184 | "uen3",
185 | "uen4",
186 | "uen5",
187 | "uo1",
188 | "uo2",
189 | "uo3",
190 | "uo4",
191 | "uo5",
192 | "v1",
193 | "v2",
194 | "v3",
195 | "v4",
196 | "v5",
197 | "van1",
198 | "van2",
199 | "van3",
200 | "van4",
201 | "van5",
202 | "ve1",
203 | "ve2",
204 | "ve3",
205 | "ve4",
206 | "ve5",
207 | "vn1",
208 | "vn2",
209 | "vn3",
210 | "vn4",
211 | "vn5",
212 | ]
213 | valid_symbols = initials + finals + ["rr"]
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | """
4 | Defines the set of symbols used in text input to the model.
5 |
6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify the symbol lists defined below. """
7 |
8 | from text import cmudict, pinyin
9 |
10 | _pad = "_"
11 | _punctuation = "!'(),.:;? "
12 | _special = "-"
13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
14 | _silences = ["@sp", "@spn", "@sil"]
15 |
16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
17 | _arpabet = ["@" + s for s in cmudict.valid_symbols]
18 | _pinyin = ["@" + s for s in pinyin.valid_symbols]
19 |
20 | # Export all symbols:
21 | symbols = (
22 | [_pad]
23 | + list(_special)
24 | + list(_punctuation)
25 | + list(_letters)
26 | + _arpabet
27 | + _pinyin
28 | + _silences
29 | )
30 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import yaml
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 | from torch.utils.tensorboard import SummaryWriter
9 | from tqdm import tqdm
10 |
11 | from utils.model import get_model, get_vocoder, get_param_num
12 | from utils.tools import to_device, log, synth_one_sample
13 | from model import MetaStyleSpeechLossMain, MetaStyleSpeechLossDisc
14 | from dataset import Dataset
15 |
16 | from evaluate import evaluate
17 |
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 |
21 | def backward(model, optimizer, total_loss, step, grad_acc_step, grad_clip_thresh):
22 | total_loss = total_loss / grad_acc_step
23 | total_loss.backward()
24 | if step % grad_acc_step == 0:
25 | # Clipping gradients to avoid gradient explosion
26 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
27 |
28 | # Update weights
29 | optimizer.step_and_update_lr()
30 | optimizer.zero_grad()
31 |
32 |
33 | def main(args, configs):
34 | print("Prepare training ...")
35 |
36 | preprocess_config, model_config, train_config = configs
37 |
38 | # Get dataset
39 | dataset = Dataset(
40 | "train_filtered.txt", preprocess_config, train_config, sort=True, drop_last=True
41 | )
42 | batch_size = train_config["optimizer"]["batch_size"]
43 | group_size = 4 # Set this larger than 1 to enable sorting in Dataset
44 | assert batch_size * group_size < len(dataset)
45 | loader = DataLoader(
46 | dataset,
47 | batch_size=batch_size * group_size,
48 | shuffle=True,
49 | collate_fn=dataset.collate_fn,
50 | )
51 |
52 | # Prepare model
53 | model, optimizer_main, optimizer_disc = get_model(args, configs, device, train=True)
54 | model = nn.DataParallel(model)
55 | num_param = get_param_num(model)
56 | Loss_1 = MetaStyleSpeechLossMain(preprocess_config, model_config, train_config).to(device)
57 | Loss_2 = MetaStyleSpeechLossDisc(preprocess_config, model_config).to(device)
58 | print("Number of StyleSpeech Parameters:", num_param)
59 |
60 | # Load vocoder
61 | vocoder = get_vocoder(model_config, device)
62 |
63 | # Init logger
64 | for p in train_config["path"].values():
65 | os.makedirs(p, exist_ok=True)
66 | train_log_path = os.path.join(train_config["path"]["log_path"], "train")
67 | val_log_path = os.path.join(train_config["path"]["log_path"], "val")
68 | os.makedirs(train_log_path, exist_ok=True)
69 | os.makedirs(val_log_path, exist_ok=True)
70 | train_logger = SummaryWriter(train_log_path)
71 | val_logger = SummaryWriter(val_log_path)
72 |
73 | # Training
74 | step = args.restore_step + 1
75 | epoch = 1
76 | meta_learning_warmup = train_config["step"]["meta_learning_warmup"]
77 | grad_acc_step = train_config["optimizer"]["grad_acc_step"]
78 | grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
79 | total_step = train_config["step"]["total_step"]
80 | log_step = train_config["step"]["log_step"]
81 | save_step = train_config["step"]["save_step"]
82 | synth_step = train_config["step"]["synth_step"]
83 | val_step = train_config["step"]["val_step"]
84 |
85 | outer_bar = tqdm(total=total_step, desc="Training", position=0)
86 | outer_bar.n = args.restore_step
87 | outer_bar.update()
88 |
89 | while True:
90 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
91 | for batchs in loader:
92 | for batch in batchs:
93 | batch = to_device(batch, device)
94 |
95 | # Warm-up Stage
96 | if step <= meta_learning_warmup:
97 | # Forward
98 | output = (None, None, *model(*(batch[2:-5])))
99 | # Meta Learning
100 | else:
101 | # Step 1: Update Enc_s and G
102 | output = model.module.meta_learner_1(*(batch[2:]))
103 |
104 | # Cal Loss
105 | losses_1 = Loss_1(batch, output)
106 | total_loss = losses_1[0]
107 |
108 | # Backward
109 | backward(model, optimizer_main, total_loss, step, grad_acc_step, grad_clip_thresh)
110 |
111 | # Meta Learning
112 | if step > meta_learning_warmup:
113 | # Step 2: Update D_t and D_s
114 | output_disc = model.module.meta_learner_2(*(batch[2:]))
115 |
116 | losses_2 = Loss_2(batch[2], output_disc)
117 | total_loss_disc = losses_2[0]
118 |
119 | backward(model, optimizer_disc, total_loss_disc, step, grad_acc_step, grad_clip_thresh)
120 |
121 | if step % log_step == 0:
122 | if step > meta_learning_warmup:
123 | losses = [l.item() for l in (losses_1+losses_2[1:])]
124 | else:
125 | losses = [l.item() for l in (losses_1+tuple([torch.zeros(1).to(device) for _ in range(3)]))]
126 | message1 = "Step {}/{}, ".format(step, total_step)
127 | message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}, Adversarial_D_s Loss: {:.4f}, Adversarial_D_t Loss: {:.4f}, D_s Loss: {:.4f}, D_t Loss: {:.4f}, cls Loss: {:.4f}".format(
128 | *losses
129 | )
130 |
131 | with open(os.path.join(train_log_path, "log.txt"), "a") as f:
132 | f.write(message1 + message2 + "\n")
133 |
134 | outer_bar.write(message1 + message2)
135 |
136 | log(train_logger, step, losses=losses)
137 |
138 | if step % synth_step == 0:
139 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
140 | batch,
141 | output[2:],
142 | vocoder,
143 | model_config,
144 | preprocess_config,
145 | )
146 | log(
147 | train_logger,
148 | fig=fig,
149 | tag="Training/step_{}_{}".format(step, tag),
150 | )
151 | sampling_rate = preprocess_config["preprocessing"]["audio"][
152 | "sampling_rate"
153 | ]
154 | log(
155 | train_logger,
156 | audio=wav_reconstruction,
157 | sampling_rate=sampling_rate,
158 | tag="Training/step_{}_{}_reconstructed".format(step, tag),
159 | )
160 | log(
161 | train_logger,
162 | audio=wav_prediction,
163 | sampling_rate=sampling_rate,
164 | tag="Training/step_{}_{}_synthesized".format(step, tag),
165 | )
166 |
167 | if step % val_step == 0:
168 | model.eval()
169 | message = evaluate(model, step, configs, val_logger, vocoder, len(losses))
170 | with open(os.path.join(val_log_path, "log.txt"), "a") as f:
171 | f.write(message + "\n")
172 | outer_bar.write(message)
173 |
174 | model.train()
175 |
176 | if step % save_step == 0:
177 | torch.save(
178 | {
179 | "model": model.module.state_dict(),
180 | "optimizer_main": optimizer_main._optimizer.state_dict(),
181 | "optimizer_disc": optimizer_disc._optimizer.state_dict(),
182 | },
183 | os.path.join(
184 | train_config["path"]["ckpt_path"],
185 | "{}.pth.tar".format(step),
186 | ),
187 | )
188 |
189 | if step == total_step:
190 | quit()
191 | step += 1
192 | outer_bar.update(1)
193 |
194 | inner_bar.update(1)
195 | epoch += 1
196 |
197 |
198 | if __name__ == "__main__":
199 | parser = argparse.ArgumentParser()
200 | parser.add_argument("--restore_step", type=int, default=0)
201 | parser.add_argument(
202 | "-p",
203 | "--preprocess_config",
204 | type=str,
205 | required=True,
206 | help="path to preprocess.yaml",
207 | )
208 | parser.add_argument(
209 | "-m", "--model_config", type=str, required=True, help="path to model.yaml"
210 | )
211 | parser.add_argument(
212 | "-t", "--train_config", type=str, required=True, help="path to train.yaml"
213 | )
214 | args = parser.parse_args()
215 |
216 | # Read Config
217 | preprocess_config = yaml.load(
218 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader
219 | )
220 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
221 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
222 | configs = (preprocess_config, model_config, train_config)
223 |
224 | main(args, configs)
225 |
--------------------------------------------------------------------------------
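
Checkpoints written by the save_step block above bundle the model weights and both optimizer states, which is exactly what utils/model.get_model restores when --restore_step is given. A small inspection sketch; the path is hypothetical and depends on ckpt_path in train.yaml:

import torch

ckpt = torch.load("output/ckpt/LibriTTS/100000.pth.tar", map_location="cpu")
print(ckpt.keys())  # dict_keys(['model', 'optimizer_main', 'optimizer_disc'])
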
/utils/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import numpy as np
6 |
7 | import hifigan
8 | from model import StyleSpeech, ScheduledOptimMain, ScheduledOptimDisc
9 |
10 |
11 | def get_model(args, configs, device, train=False):
12 | (preprocess_config, model_config, train_config) = configs
13 |
14 | model = StyleSpeech(preprocess_config, model_config).to(device)
15 | if args.restore_step:
16 | ckpt_path = os.path.join(
17 | train_config["path"]["ckpt_path"],
18 | "{}.pth.tar".format(args.restore_step),
19 | )
20 |         ckpt = torch.load(ckpt_path, map_location=device)
21 | model.load_state_dict(ckpt["model"])
22 |
23 | if train:
24 | scheduled_optim_main = ScheduledOptimMain(
25 | model, train_config, model_config, args.restore_step
26 | )
27 | scheduled_optim_disc = ScheduledOptimDisc(
28 | model, train_config
29 | )
30 | if args.restore_step:
31 | scheduled_optim_main.load_state_dict(ckpt["optimizer_main"])
32 | scheduled_optim_disc.load_state_dict(ckpt["optimizer_disc"])
33 | model.train()
34 | return model, scheduled_optim_main, scheduled_optim_disc
35 |
36 | model.eval()
37 |     model.requires_grad_(False)  # freeze all parameters for inference
38 | return model
39 |
40 |
41 | def get_param_num(model):
42 | num_param = sum(param.numel() for param in model.parameters())
43 | return num_param
44 |
45 |
46 | def get_vocoder(config, device):
47 | name = config["vocoder"]["model"]
48 | speaker = config["vocoder"]["speaker"]
49 |
50 | if name == "MelGAN":
51 | if speaker == "LJSpeech":
52 | vocoder = torch.hub.load(
53 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
54 | )
55 | elif speaker == "universal":
56 | vocoder = torch.hub.load(
57 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
58 | )
59 | vocoder.mel2wav.eval()
60 | vocoder.mel2wav.to(device)
61 | elif name == "HiFi-GAN":
62 | with open("hifigan/config.json", "r") as f:
63 | config = json.load(f)
64 | config = hifigan.AttrDict(config)
65 | vocoder = hifigan.Generator(config)
66 | if speaker == "LJSpeech":
67 |             ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar", map_location=device)
68 |         elif speaker == "universal":
69 |             ckpt = torch.load("hifigan/generator_universal.pth.tar", map_location=device)
70 | vocoder.load_state_dict(ckpt["generator"])
71 | vocoder.eval()
72 | vocoder.remove_weight_norm()
73 | vocoder.to(device)
74 |
75 | return vocoder
76 |
77 |
78 | def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
79 | name = model_config["vocoder"]["model"]
80 | with torch.no_grad():
81 | if name == "MelGAN":
82 | wavs = vocoder.inverse(mels / np.log(10))
83 | elif name == "HiFi-GAN":
84 | wavs = vocoder(mels).squeeze(1)
85 |
86 | wavs = (
87 | wavs.cpu().numpy()
88 | * preprocess_config["preprocessing"]["audio"]["max_wav_value"]
89 | ).astype("int16")
90 | wavs = [wav for wav in wavs]
91 |
92 | for i in range(len(mels)):
93 | if lengths is not None:
94 | wavs[i] = wavs[i][: lengths[i]]
95 |
96 | return wavs
97 |
--------------------------------------------------------------------------------
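
A minimal sketch, not part of the repository, of running the HiFi-GAN path above on a dummy mel batch. It assumes the zipped universal generator has been extracted to hifigan/generator_universal.pth.tar, that hifigan/config.json is present, and that n_mel_channels is 80 and max_wav_value is 32768.0 as in the LibriTTS preprocess.yaml:

import torch
from utils.model import get_vocoder, vocoder_infer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical minimal configs; only the keys read above are required.
model_config = {"vocoder": {"model": "HiFi-GAN", "speaker": "universal"}}
preprocess_config = {"preprocessing": {"audio": {"max_wav_value": 32768.0}}}

vocoder = get_vocoder(model_config, device)
mels = torch.randn(1, 80, 200).to(device)  # (batch, n_mel_channels, frames), log-mel scale
wavs = vocoder_infer(mels, vocoder, model_config, preprocess_config)
print(wavs[0].shape, wavs[0].dtype)  # (200 * hop_length,) int16
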
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import matplotlib
8 | from scipy.io import wavfile
9 | from matplotlib import pyplot as plt
10 |
11 |
12 | matplotlib.use("Agg")
13 |
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 |
18 | def to_device(data, device):
19 | if len(data) == 17:
20 | (
21 | ids,
22 | raw_texts,
23 | speakers,
24 | texts,
25 | src_lens,
26 | max_src_len,
27 | mels,
28 | mel_lens,
29 | max_mel_len,
30 | pitches,
31 | energies,
32 | durations,
33 | raw_quary_texts,
34 | quary_texts,
35 | quary_src_lens,
36 | max_quary_src_len,
37 | quary_durations,
38 | ) = data
39 |
40 | speakers = torch.from_numpy(speakers).long().to(device)
41 | texts = torch.from_numpy(texts).long().to(device)
42 | src_lens = torch.from_numpy(src_lens).to(device)
43 | quary_texts = torch.from_numpy(quary_texts).long().to(device)
44 | quary_src_lens = torch.from_numpy(quary_src_lens).to(device)
45 | mels = torch.from_numpy(mels).float().to(device)
46 | mel_lens = torch.from_numpy(mel_lens).to(device)
47 | pitches = torch.from_numpy(pitches).float().to(device)
48 | energies = torch.from_numpy(energies).to(device)
49 | durations = torch.from_numpy(durations).long().to(device)
50 | quary_durations = torch.from_numpy(quary_durations).long().to(device)
51 |
52 | return (
53 | ids,
54 | raw_texts,
55 | speakers,
56 | texts,
57 | src_lens,
58 | max_src_len,
59 | mels,
60 | mel_lens,
61 | max_mel_len,
62 | pitches,
63 | energies,
64 | durations,
65 | raw_quary_texts,
66 | quary_texts,
67 | quary_src_lens,
68 | max_quary_src_len,
69 | quary_durations,
70 | )
71 |
72 | if len(data) == 10:
73 | (
74 | ids,
75 | raw_texts,
76 | speakers,
77 | texts,
78 | src_lens,
79 | max_src_len,
80 | mels,
81 | mel_lens,
82 | max_mel_len,
83 | ref_infos,
84 | ) = data
85 |
86 | texts = torch.from_numpy(texts).long().to(device)
87 | src_lens = torch.from_numpy(src_lens).to(device)
88 | mels = torch.from_numpy(mels).float().to(device)
89 | mel_lens = torch.from_numpy(mel_lens).to(device)
90 |
91 | return (
92 | ids,
93 | raw_texts,
94 | speakers,
95 | texts,
96 | src_lens,
97 | max_src_len,
98 | mels,
99 | mel_lens,
100 | max_mel_len,
101 | ref_infos,
102 | )
103 |
104 |
105 | def log(
106 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""
107 | ):
108 | if losses is not None:
109 | logger.add_scalar("Loss/total_loss", losses[0], step)
110 | logger.add_scalar("Loss/mel_loss", losses[1], step)
111 | logger.add_scalar("Loss/pitch_loss", losses[2], step)
112 | logger.add_scalar("Loss/energy_loss", losses[3], step)
113 | logger.add_scalar("Loss/duration_loss", losses[4], step)
114 | logger.add_scalar("Loss/adv_D_s_loss", losses[5], step)
115 | logger.add_scalar("Loss/adv_D_t_loss", losses[6], step)
116 | logger.add_scalar("Loss/D_s_loss", losses[7], step)
117 | logger.add_scalar("Loss/D_t_loss", losses[8], step)
118 | logger.add_scalar("Loss/cls_loss", losses[9], step)
119 |
120 | if fig is not None:
121 | logger.add_figure(tag, fig)
122 |
123 | if audio is not None:
124 | logger.add_audio(
125 | tag,
126 | audio / max(abs(audio)),
127 | sample_rate=sampling_rate,
128 | )
129 |
130 |
131 | def get_mask_from_lengths(lengths, max_len=None):
132 | batch_size = lengths.shape[0]
133 | if max_len is None:
134 | max_len = torch.max(lengths).item()
135 |
136 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
137 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
138 |
139 | return mask
140 |
141 |
142 | def expand(values, durations):
143 | out = list()
144 | for value, d in zip(values, durations):
145 | out += [value] * max(0, int(d))
146 | return np.array(out)
147 |
148 |
149 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config):
150 |
151 | basename = targets[0][0]
152 | src_len = predictions[7][0].item()
153 | mel_len = predictions[8][0].item()
154 | mel_target = targets[6][0, :mel_len].detach().transpose(0, 1)
155 | mel_prediction = predictions[0][0, :mel_len].detach().transpose(0, 1)
156 | duration = targets[11][0, :src_len].detach().cpu().numpy()
157 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
158 | pitch = targets[9][0, :src_len].detach().cpu().numpy()
159 | pitch = expand(pitch, duration)
160 | else:
161 | pitch = targets[9][0, :mel_len].detach().cpu().numpy()
162 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
163 | energy = targets[10][0, :src_len].detach().cpu().numpy()
164 | energy = expand(energy, duration)
165 | else:
166 | energy = targets[10][0, :mel_len].detach().cpu().numpy()
167 |
168 | with open(
169 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
170 | ) as f:
171 | stats = json.load(f)
172 | stats = stats["pitch"] + stats["energy"][:2]
173 |
174 | fig = plot_mel(
175 | [
176 | (mel_prediction.cpu().numpy(), pitch, energy),
177 | (mel_target.cpu().numpy(), pitch, energy),
178 | ],
179 | stats,
180 |         ["Synthesized Spectrogram", "Ground-Truth Spectrogram"],
181 | )
182 |
183 | if vocoder is not None:
184 | from .model import vocoder_infer
185 |
186 | wav_reconstruction = vocoder_infer(
187 | mel_target.unsqueeze(0),
188 | vocoder,
189 | model_config,
190 | preprocess_config,
191 | )[0]
192 | wav_prediction = vocoder_infer(
193 | mel_prediction.unsqueeze(0),
194 | vocoder,
195 | model_config,
196 | preprocess_config,
197 | )[0]
198 | else:
199 | wav_reconstruction = wav_prediction = None
200 |
201 | return fig, wav_reconstruction, wav_prediction, basename
202 |
203 |
204 | def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
205 |
206 | basenames = targets[0]
207 | for i in range(len(predictions[0])):
208 | basename = basenames[i]
209 | src_len = predictions[7][i].item()
210 | mel_len = predictions[8][i].item()
211 | mel_prediction = predictions[0][i, :mel_len].detach().transpose(0, 1)
212 | duration = predictions[4][i, :src_len].detach().cpu().numpy()
213 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
214 | pitch = predictions[1][i, :src_len].detach().cpu().numpy()
215 | pitch = expand(pitch, duration)
216 | else:
217 | pitch = predictions[1][i, :mel_len].detach().cpu().numpy()
218 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
219 | energy = predictions[2][i, :src_len].detach().cpu().numpy()
220 | energy = expand(energy, duration)
221 | else:
222 | energy = predictions[2][i, :mel_len].detach().cpu().numpy()
223 |
224 | with open(
225 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
226 | ) as f:
227 | stats = json.load(f)
228 | stats = stats["pitch"] + stats["energy"][:2]
229 |
230 | fig = plot_mel(
231 | [
232 | (mel_prediction.cpu().numpy(), pitch, energy),
233 | targets[-1][i],
234 | ],
235 | stats,
236 |             ["Synthesized Spectrogram", "Reference Spectrogram"],
237 | )
238 | plt.savefig(os.path.join(path, "{}.png".format(basename)))
239 | plt.close()
240 |
241 | from .model import vocoder_infer
242 |
243 | mel_predictions = predictions[0].transpose(1, 2)
244 | lengths = predictions[8] * preprocess_config["preprocessing"]["stft"]["hop_length"]
245 | wav_predictions = vocoder_infer(
246 | mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
247 | )
248 |
249 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
250 | for wav, basename in zip(wav_predictions, basenames):
251 | wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
252 |
253 |
254 | def plot_mel(data, stats, titles):
255 | fig, axes = plt.subplots(len(data), 1, squeeze=False)
256 | if titles is None:
257 | titles = [None for i in range(len(data))]
258 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
259 | pitch_min = pitch_min * pitch_std + pitch_mean
260 | pitch_max = pitch_max * pitch_std + pitch_mean
261 |
262 | def add_axis(fig, old_ax):
263 | ax = fig.add_axes(old_ax.get_position(), anchor="W")
264 | ax.set_facecolor("None")
265 | return ax
266 |
267 | for i in range(len(data)):
268 | mel, pitch, energy = data[i]
269 | pitch = pitch * pitch_std + pitch_mean
270 | axes[i][0].imshow(mel, origin="lower")
271 | axes[i][0].set_aspect(2.5, adjustable="box")
272 | axes[i][0].set_ylim(0, mel.shape[0])
273 | axes[i][0].set_title(titles[i], fontsize="medium")
274 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
275 | axes[i][0].set_anchor("W")
276 |
277 | ax1 = add_axis(fig, axes[i][0])
278 | ax1.plot(pitch, color="tomato", linewidth=.7)
279 | ax1.set_xlim(0, mel.shape[1])
280 | ax1.set_ylim(0, pitch_max)
281 | ax1.set_ylabel("F0", color="tomato")
282 | ax1.tick_params(
283 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
284 | )
285 |
286 | ax2 = add_axis(fig, axes[i][0])
287 | ax2.plot(energy, color="darkviolet", linewidth=.7)
288 | ax2.set_xlim(0, mel.shape[1])
289 | ax2.set_ylim(energy_min, energy_max)
290 | ax2.set_ylabel("Energy", color="darkviolet")
291 | ax2.yaxis.set_label_position("right")
292 | ax2.tick_params(
293 | labelsize="x-small",
294 | colors="darkviolet",
295 | bottom=False,
296 | labelbottom=False,
297 | left=False,
298 | labelleft=False,
299 | right=True,
300 | labelright=True,
301 | )
302 |
303 | return fig
304 |
305 |
306 | def pad_1D(inputs, PAD=0):
307 | def pad_data(x, length, PAD):
308 | x_padded = np.pad(
309 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
310 | )
311 | return x_padded
312 |
313 | max_len = max((len(x) for x in inputs))
314 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
315 |
316 | return padded
317 |
318 |
319 | def pad_2D(inputs, maxlen=None):
320 | def pad(x, max_len):
321 | PAD = 0
322 | if np.shape(x)[0] > max_len:
323 | raise ValueError("not max_len")
324 |
325 | s = np.shape(x)[1]
326 | x_padded = np.pad(
327 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
328 | )
329 | return x_padded[:, :s]
330 |
331 | if maxlen:
332 | output = np.stack([pad(x, maxlen) for x in inputs])
333 | else:
334 | max_len = max(np.shape(x)[0] for x in inputs)
335 | output = np.stack([pad(x, max_len) for x in inputs])
336 |
337 | return output
338 |
339 |
340 | def pad(input_ele, mel_max_length=None):
341 | if mel_max_length:
342 | max_len = mel_max_length
343 | else:
344 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
345 |
346 | out_list = list()
347 | for i, batch in enumerate(input_ele):
348 | if len(batch.shape) == 1:
349 | one_batch_padded = F.pad(
350 | batch, (0, max_len - batch.size(0)), "constant", 0.0
351 | )
352 | elif len(batch.shape) == 2:
353 | one_batch_padded = F.pad(
354 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
355 | )
356 | out_list.append(one_batch_padded)
357 | out_padded = torch.stack(out_list)
358 | return out_padded
359 |
--------------------------------------------------------------------------------
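
A small self-contained sketch of the padding and masking helpers above:

import torch
import numpy as np
from utils.tools import get_mask_from_lengths, pad_1D, pad_2D

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lengths = torch.tensor([3, 5]).to(device)  # two sequences, 3 and 5 frames long
mask = get_mask_from_lengths(lengths)      # True marks padded positions
print(mask.cpu())
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])

print(pad_1D([np.array([1, 2, 3]), np.array([1, 2])]))   # -> [[1 2 3] [1 2 0]]
print(pad_2D([np.ones((2, 4)), np.ones((3, 4))]).shape)  # -> (2, 3, 4)
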