├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── svs
│   │   ├── data.yaml
│   │   ├── model.yaml
│   │   └── train.yaml
│   └── tts
│       ├── data.yaml
│       ├── model.yaml
│       └── train.yaml
├── dataset
│   ├── __init__.py
│   ├── dataset_svs.py
│   ├── dataset_tts.py
│   ├── espnet_texts
│   │   ├── __init__.py
│   │   ├── cleaners.py
│   │   ├── cmudict.py
│   │   ├── dict.py
│   │   ├── numbers.py
│   │   └── symbols.py
│   └── texts
│       ├── __init__.py
│       ├── cleaners.py
│       ├── cmudict.py
│       ├── numbers.py
│       ├── pinyin.py
│       └── symbols.py
├── lexicon
│   ├── libritts-lexicon.txt
│   └── pinyin-lexicon.txt
├── loss
│   ├── __init__.py
│   ├── fastspeech2_loss.py
│   └── loss.py
├── models
│   ├── __init__.py
│   ├── discriminator.py
│   ├── fastspeech2.py
│   └── xiaoice2.py
├── modules
│   ├── __init__.py
│   ├── conv
│   │   └── __init__.py
│   ├── transformer
│   │   ├── Constants.py
│   │   ├── Layers.py
│   │   ├── Models.py
│   │   ├── Modules.py
│   │   ├── SubLayers.py
│   │   └── __init__.py
│   └── variance
│       ├── __init__.py
│       └── modules.py
├── pics
│   ├── 2085003136_145600.png
│   ├── after_2085003136_145600.png
│   ├── before_2085003136_145600.png
│   ├── before_mel_l2_loss.png
│   ├── post_mel_l2_loss.png
│   └── xs1_before_2085003136_145600.png
├── preprocess
│   ├── audio_preprocess.py
│   └── data_prep.py
├── pyutils
│   ├── __init__.py
│   ├── gen_duration_from_tg.py
│   ├── logger.py
│   ├── mask.py
│   ├── optimizer.py
│   ├── parse_options.sh
│   ├── plot.py
│   ├── save_and_load.py
│   └── scheduler.py
├── run.sh
├── train.py
├── train_gan.py
└── utils
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | #
162 | job.sh
163 | data
164 | data/*
165 | exp
166 | exp/*
167 | *.out
168 | wandb
169 | wandb/*
170 | .nfs*
171 | local_run.sh
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, zengchang233
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [XiaoiceSing2](https://www.isca-speech.org/archive/interspeech_2023/chunhui23_interspeech.html)
2 | Source code for the paper [XiaoiceSing2](https://www.isca-speech.org/archive/interspeech_2023/chunhui23_interspeech.html) (Interspeech 2023).
3 |
4 | [Demo page](https://wavelandspeech.github.io/xiaoice2/)
5 |
6 | ## Notice
7 |
8 | I am currently busy with job hunting. I will update the remaining modules, including [HiFi-WaveGAN](https://arxiv.org/abs/2210.12740), once that is settled.
9 |
10 | ## Implementation (in development)
11 |
12 | - [x] fastspeech2-based generator
13 | - [x] discriminator group, including segment discriminators and detail discriminators
14 | - [ ] ConvFFT block
15 |
16 | ## Dataset and preparation
17 |
18 | - [x] opencpop 
19 | - [ ] kiritan 
20 | - [ ] CSD 
21 | - [ ] m4singer 
22 | - [ ] NUS48E
23 |
24 | Kaldi-style data preparation (example file formats are sketched below):
25 |
26 | - wav.scp
27 | - utt2spk
28 | - spk2utt
29 | - text
30 |
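Each manifest follows the usual Kaldi conventions: `wav.scp` maps an utterance ID to its audio path, `utt2spk`/`spk2utt` map utterances to speakers and back, and `text` holds the transcription. A minimal sketch (the IDs, paths, and speaker label below are made up for illustration):

```
# wav.scp   <utt-id> <path-to-wav>
utt0001 data/opencpop/wavs/utt0001.wav
# utt2spk   <utt-id> <speaker-id>
utt0001 opencpop
# spk2utt   <speaker-id> <utt-id> [<utt-id> ...]
opencpop utt0001 utt0002
# text      <utt-id> <transcription>
utt0001 <lyrics or text for this utterance>
```

With these manifests in place, stage 1 extracts the acoustic features: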
31 | ```
32 | ./run.sh --start-stage 1 --stop-stage 1 # extract mel-spectrogram, F0, energy, and their statistics
33 | ```
34 |
35 | ## Training
36 |
37 | ```
38 | ./run.sh --start-stage 2 --stop-stage 2
39 | ```
40 |
41 | ### Real and generated mel-spectrograms (145,600 training steps)
42 |
43 | Real (left), XiaoiceSing (middle), XiaoiceSing2 (right)
44 |
45 |
50 |
51 | ### L2 loss curve for melspectrogram
52 |
53 | L2 loss before post-processing (left), L2 loss after post-processing (right)
54 |
55 |
56 |
57 |
58 |
59 |
60 | ## Inference
61 |
62 | ```
63 | ./run.sh --start-stage 3 --stop-stage 3
64 | ```
65 |
--------------------------------------------------------------------------------
/configs/svs/data.yaml:
--------------------------------------------------------------------------------
1 | audio_manifest: 'data/opencpop/train.scp'
2 | svs_manifest: 'data/opencpop/train.txt'
3 | spk_manifest: 'data/opencpop/utt2spk'
4 | f0_min_max: 'data/opencpop/f0_min_max.npy'
5 | f0_mean: 'data/opencpop/f0_mean.npy'
6 | f0_std: 'data/opencpop/f0_std.npy'
7 | energy_min_max: 'data/opencpop/energy_min_max.npy'
8 | energy_mean: 'data/opencpop/energy_mean.npy'
9 | energy_std: 'data/opencpop/energy_std.npy'
10 | phone_set: 'data/opencpop/phone_set.txt'
11 |
12 | n_fft: 1024
13 | n_mels: 120
14 | hop_length: 256
15 | win_length: 1024
16 | sampling_rate: 44100
17 | seg_size: 700
18 | fmin: 0.0
19 | fmax: 22050
20 |
21 | tts_cleaner_names: []
22 | use_phonemes: True
23 | eos: False
24 |
25 | pitch:
26 |   feature: "frame_level" # support 'phoneme_level' or 'frame_level'
27 |   normalization: True
28 | energy:
29 |   feature: "frame_level" # support 'phoneme_level' or 'frame_level'
30 |   normalization: True
31 |
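These parameters presumably drive the feature extraction in `preprocess/audio_preprocess.py` (not included in this excerpt). As a rough sketch of how a matching 120-bin log-mel spectrogram could be computed with librosa (the function name `extract_logmel` and the log-compression details are illustrative assumptions, not the repo's actual code):

```python
import librosa
import numpy as np

def extract_logmel(wav_path, sr=44100, n_fft=1024, hop_length=256,
                   win_length=1024, n_mels=120, fmin=0.0, fmax=22050.0):
    """Log-mel spectrogram with the settings from data.yaml, shaped (frames, n_mels)."""
    y, _ = librosa.load(wav_path, sr=sr)
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=n_fft, hop_length=hop_length,
        win_length=win_length, n_mels=n_mels, fmin=fmin, fmax=fmax)
    return np.log(np.clip(mel, 1e-5, None)).T  # frames-first, as dataset_svs.py expects
```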
--------------------------------------------------------------------------------
/configs/svs/model.yaml:
--------------------------------------------------------------------------------
1 | generator:
2 | transformer:
3 | encoder:
4 | max_seq_len: 5000
5 | n_src_vocab: 100 # random number, will be reassigned in train.py
6 | d_word_vec: 512 # dimension of word vector
7 | n_layers: 6
8 | n_head: 8
9 | d_model: 512
10 | d_inner: 2048
11 | kernel_size: [9, 1]
12 | dropout: 0.2
13 | max_note_pitch: 88
14 | max_note_duration: 2000
15 | decoder:
16 | max_seq_len: 5000
17 | d_word_vec: 512
18 | n_layers: 6
19 | n_head: 8
20 | d_model: 512
21 | d_inner: 2048
22 | kernel_size: [9, 1]
23 | dropout: 0.2
24 |
25 | variance_predictor:
26 | input_size: 512
27 | filter_size: 512
28 | kernel_size: 3
29 | dropout: 0.5
30 |
31 | variance_embedding:
32 | pitch_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
33 | energy_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
34 | n_bins: 256
35 |
36 | multi_speaker: False
37 | uv_threshold: 0.5
38 |
39 | postnet:
40 | postnet_embedding_dim: 512
41 | postnet_kernel_size: 5
42 | postnet_n_convolutions: 5
43 |
44 | discriminator:
45 | segment_disc:
46 | pass
47 |
48 | detail_disc:
49 | pass
50 |
51 | vocoder:
52 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
53 | speaker: "universal" # support 'LJSpeech', 'universal'
54 |
--------------------------------------------------------------------------------
/configs/svs/train.yaml:
--------------------------------------------------------------------------------
1 | epochs: 800
2 | batch_size: 8
3 | grad_clip: 1.0
4 | num_workers: 8
5 |
6 | feat_loss_weight: [1.0, 1.0, 1.0]
7 | adv_g_loss_weight: [0.1, 0.1, 0.1]
8 | start_disc_steps: 5000
9 |
10 | g_optimizer: 'Adam'
11 | g_optimizer_args:
12 |   lr: 0.0001
13 |   betas: [0.9, 0.98]
14 |   eps: 0.000000001
15 |   weight_decay: 0.0
16 |
17 | g_scheduler: 'WarmupLR'
18 | g_scheduler_args:
19 |   warmup_steps: 4000
20 |   last_epoch: -1
21 |
22 | d_optimizer: 'Adam'
23 | d_optimizer_args:
24 |   lr: 0.0001
25 |   betas: [0.9, 0.98]
26 |   eps: 0.000000001
27 |   weight_decay: 0.0
28 |
29 | d_scheduler: 'WarmupLR'
30 | d_scheduler_args:
31 |   warmup_steps: 4000
32 |   last_epoch: -1
33 |
34 | wandb: True
35 | wandb_args:
36 |   project: 'svs'
37 |   group: 'xiaoicesing2'
38 |   job_type: 'opencpop'
39 |   name: 'warmup4k-disc5k'
40 |
41 | log_interval: 10
42 | save_interval: 200
43 | ckpt_clean: 10
44 |
45 |
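`WarmupLR` is looked up from `pyutils/scheduler.py`, which is not part of this excerpt. A common Noam-style warmup schedule that such a scheduler typically implements (a sketch under that assumption, not the repo's exact code):

```python
def warmup_lr(step, base_lr=1e-4, warmup_steps=4000):
    """Linear ramp-up for warmup_steps, then inverse-square-root decay."""
    step = max(step, 1)
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```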
--------------------------------------------------------------------------------
/configs/tts/data.yaml:
--------------------------------------------------------------------------------
1 | audio_manifest: 'data/wav.scp'
2 | duration_manifest: 'data/aishell3/train/duration.txt'
3 | raw_text_manifest: 'data/aishell3/train/raw_text'
4 | spk_manifest: 'data/aishell3/train/utt2spk'
5 | f0_min_max: 'data/f0_min_max.npy'
6 | energy_min_max: 'data/energy_min_max.npy'
7 |
8 | n_fft: 1024
9 | n_mels: 120
10 | hop_length: 256
11 | win_length: 1024
12 | sampling_rate: 44100
13 | seg_size: 700
14 | fmin: 0.0
15 | fmax: 22050
16 |
17 | tts_cleaner_names: []
18 | use_phonemes: True
19 | eos: False
20 |
21 | pitch:
22 |   feature: "frame_level" # support 'phoneme_level' or 'frame_level'
23 |   normalization: True
24 | energy:
25 |   feature: "frame_level" # support 'phoneme_level' or 'frame_level'
26 |   normalization: True
27 |
--------------------------------------------------------------------------------
/configs/tts/model.yaml:
--------------------------------------------------------------------------------
1 | generator:
2 | transformer:
3 | encoder:
4 | max_seq_len: 5000
5 | n_src_vocab: 100 # random number, will be reassigned in train.py
6 | d_word_vec: 512 # dimension of word vector
7 | n_layers: 6
8 | n_head: 8
9 | d_model: 512
10 | d_inner: 2048
11 | kernel_size: [9, 1]
12 | dropout: 0.2
13 | decoder:
14 | max_seq_len: 5000
15 | d_word_vec: 512
16 | n_layers: 6
17 | n_head: 8
18 | d_model: 512
19 | d_inner: 2048
20 | kernel_size: [9, 1]
21 | dropout: 0.2
22 |
23 | variance_predictor:
24 | input_size: 512
25 | filter_size: 512
26 | kernel_size: 3
27 | dropout: 0.5
28 |
29 | variance_embedding:
30 | pitch_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
31 | energy_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
32 | n_bins: 256
33 |
34 | multi_speaker: False
35 | uv_threshold: 0.5
36 |
37 | postnet:
38 | postnet_embedding_dim: 512
39 | postnet_kernel_size: 5
40 | postnet_n_convolutions: 5
41 |
42 | discriminator:
43 | segment_disc:
44 | pass
45 |
46 | detail_disc:
47 | pass
48 |
49 | vocoder:
50 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN'
51 | speaker: "universal" # support 'LJSpeech', 'universal'
52 |
--------------------------------------------------------------------------------
/configs/tts/train.yaml:
--------------------------------------------------------------------------------
1 | epochs: 800
2 | batch_size: 8
3 | grad_clip: 1.0
4 | num_workers: 8
5 |
6 | feat_loss_weight: [1.0, 1.0, 1.0]
7 | adv_g_loss_weight: [0.1, 0.1, 0.1]
8 | start_disc_steps: 5000
9 |
10 | g_optimizer: 'Adam'
11 | g_optimizer_args:
12 |   lr: 0.0001
13 |   betas: [0.9, 0.98]
14 |   eps: 0.000000001
15 |   weight_decay: 0.0
16 |
17 | g_scheduler: 'WarmupLR'
18 | g_scheduler_args:
19 |   warmup_steps: 4000
20 |   last_epoch: -1
21 |
22 | d_optimizer: 'Adam'
23 | d_optimizer_args:
24 |   lr: 0.0001
25 |   betas: [0.9, 0.98]
26 |   eps: 0.000000001
27 |   weight_decay: 0.0
28 |
29 | d_scheduler: 'WarmupLR'
30 | d_scheduler_args:
31 |   warmup_steps: 4000
32 |   last_epoch: -1
33 |
34 | wandb: True
35 | wandb_args:
36 |   project: 'tts'
37 |   group: 'cross-lingual'
38 |   job_type: 'fs2_GAN'
39 |   name: 'fs2GAN_aishell3-warmup4k-disc5k'
40 |
41 | log_interval: 10
42 | save_interval: 200
43 | ckpt_clean: 10
44 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset_tts import *
2 | from .dataset_svs import *
3 | from .texts import *
4 |
--------------------------------------------------------------------------------
/dataset/dataset_svs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import torch.nn.functional as F
4 |
5 | import sys
6 | sys.path.insert(0, '/home/smg/zengchang/code/svs/xiaoicesing2')
7 |
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | import numpy as np
11 | from pyutils import pad_list, remove_outlier
12 | from librosa import note_to_midi
13 |
14 | import json
15 | import os
16 |
17 | def f02pitch(f0):
18 | #f0 =f0 + 0.01
19 | return np.log2(f0 / 27.5) * 12 + 21
20 |
21 | def pitch2f0(pitch):
22 | f0 = np.exp2((pitch - 21 ) / 12) * 27.5
23 | for i in range(len(f0)):
24 | if f0[i] <= 10:
25 | f0[i] = 0
26 | return f0
27 |
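# Quick sanity check for the Hz <-> MIDI-style mapping above (A0 = 27.5 Hz = MIDI note 21):
#   f02pitch(440.0)              -> 69.0  (A4)
#   pitch2f0(np.array([69.0]))   -> array([440.])
# pitch2f0 also zeroes anything below 10 Hz so unvoiced frames stay at 0.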
28 | def pitchxuv(pitch, uv, to_f0 = False):
29 | result = pitch * uv
30 | if to_f0:
31 | result = pitch2f0(result)
32 | return result
33 |
34 | def pad1d(x, max_len):
35 | return np.pad(x, (0, max_len - len(x)), mode="constant")
36 |
37 | def pad2d(x, max_len):
38 | return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant")
39 |
40 | def interpolate_f0(f0):
41 | data = np.reshape(f0, (f0.size, 1))
42 |
43 | vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
44 | vuv_vector[data > 0.0] = 1.0
45 | vuv_vector[data <= 0.0] = 0.0
46 |
47 | ip_data = data
48 |
49 | frame_number = data.size
50 | last_value = 0.0
51 | for i in range(frame_number):
52 | if data[i] <= 0.0:
53 | j = i + 1
54 | for j in range(i + 1, frame_number):
55 | if data[j] > 0.0:
56 | break
57 | if j < frame_number - 1:
58 | if last_value > 0.0:
59 | step = (data[j] - data[i - 1]) / float(j - i)
60 | for k in range(i, j):
61 | ip_data[k] = data[i - 1] + step * (k - i + 1)
62 | else:
63 | for k in range(i, j):
64 | ip_data[k] = data[j]
65 | else:
66 | for k in range(i, frame_number):
67 | ip_data[k] = last_value
68 | else:
69 | ip_data[i] = data[i] # this may not be necessary
70 | last_value = data[i]
71 |
72 | return ip_data[:,0], vuv_vector[:,0]
73 |
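# Illustrative behaviour: for f0 = [0, 220, 0, 0, 233, 0], interpolate_f0 fills the unvoiced
# gaps (roughly [220, 220, 226.5, 233, 233, 233]) and returns uv = [0, 1, 0, 0, 1, 0],
# marking which frames were voiced in the original contour.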
74 | class SVSDataset(Dataset):
75 | def __init__(self, configs):
76 | audio_manifest = configs['audio_manifest']
77 | transcription_manifest = configs['svs_manifest']
78 | spk_manifest = configs['spk_manifest']
79 | self.sampling_rate = configs['sampling_rate']
80 | self.utt2path = {}
81 | self.utt2raw_text = {}
82 | self.utt2phone_seq = {}
83 | self.utt2note_pitch = {}
84 | self.utt2note_dur = {}
85 | self.utt2dur = {}
86 | self.utt2spk = {}
87 | hop_length = configs['hop_length'] / self.sampling_rate
88 | with open(audio_manifest, 'r') as f:
89 | for line in f:
90 | line = line.rstrip().split(' ')
91 | self.utt2path[line[0]] = line[1]
92 | with open(transcription_manifest, 'r') as f:
93 | for line in f:
94 | line = line.rstrip().split('|')
95 | self.utt2raw_text[line[0]] = line[1]
96 | self.utt2phone_seq[line[0]] = line[2].split(' ')
97 | self.utt2note_pitch[line[0]] = [note_to_midi(note.split('/')[0]) if note != 'rest' else 0 for note in line[3].split(' ')]
98 | self.utt2note_dur[line[0]] = [round(eval(dur) / hop_length) for dur in line[4].split(' ')]
99 | self.utt2dur[line[0]] = [round(eval(dur) / hop_length) for dur in line[5].split(' ')]
100 | with open(spk_manifest, 'r') as f:
101 | for line in f:
102 | line = line.rstrip().split(' ')
103 | self.utt2spk[line[0]] = line[1]
104 | if not os.path.exists(configs['phone_set']):
105 | phone_set = set()
106 | for phone_seq in self.utt2phone_seq.values():
107 | phone_set.update(phone_seq)
108 | phone_set = list(phone_set)
109 | phone_set.sort()
110 | with open(configs['phone_set'], 'w') as f:
111 | json.dump(phone_set, f)
112 | self.phone_set = phone_set
113 | else:
114 | with open(configs['phone_set'], 'r') as f:
115 | self.phone_set = json.load(f)
116 |
117 | self.spk2int = {spk: idx for idx, spk in enumerate(set(self.utt2spk.values()))}
118 | self.int2spk = {idx: spk for spk, idx in self.spk2int.items()}
119 | self.phone2idx = {phone: idx for idx, phone in enumerate(self.phone_set)}
120 | self.utt = list(self.utt2path.keys())
121 |
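# Note (inferred from the parsing above): each transcription line is '|'-separated as
#   utt_id | raw_text | phoneme seq | note names ('rest' or e.g. 'C#4/Db4') | note durations (s) | phone durations (s)
# and durations in seconds become frame counts via hop_length / sampling_rate,
# e.g. 0.5 s at 44100 Hz with hop 256 -> round(0.5 / (256 / 44100)) = 86 frames.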
122 | def _norm_mean_std(self, x, mean, std, is_remove_outlier=False):
123 | if is_remove_outlier:
124 | x = remove_outlier(x)
125 | zero_idxs = np.where(x == 0.0)[0]
126 | x = (x - mean) / std
127 | x[zero_idxs] = 0.0
128 | return x
129 |
130 | def get_spk_number(self):
131 | return len(self.spk2int)
132 |
133 | def get_phone_number(self):
134 | return len(self.phone2idx)
135 |
136 | def __len__(self):
137 | return len(self.utt)
138 |
139 | def __getitem__(self, idx):
140 | uttid = self.utt[idx]
141 | mel_path = self.utt2path[uttid].replace('.wav', '.mel.npy')
142 | f0_path = self.utt2path[uttid].replace('.wav', '.f0.npy')
143 | energy_path = self.utt2path[uttid].replace('.wav', '.en.npy')
144 |
145 | mel = np.load(mel_path) #.transpose(1, 0)
146 | f0 = np.load(f0_path)
147 | f0, uv = interpolate_f0(f0)
148 | # unnormalized_f0 = self.f0_std * f0 + self.f0_mean
149 | pitch = f02pitch(f0)
150 | energy = np.load(energy_path)
151 | # energy = self.energy_std * energy + self.energy_mean
152 |
153 | raw_text = self.utt2raw_text[uttid]
154 | phone_text = self.utt2phone_seq[uttid]
155 | phone_seq = np.array([self.phone2idx[phone] for phone in phone_text])
156 | note_pitch = np.array(self.utt2note_pitch[uttid])
157 | note_duration = np.array(self.utt2note_dur[uttid])
158 | duration = np.array(self.utt2dur[uttid])
159 |
160 | mel_len = mel.shape[0]
161 | duration = duration[: len(phone_seq)]
162 | duration[-1] = duration[-1] + (mel.shape[0] - sum(duration))
163 | assert mel_len == sum(duration), f'{mel_len} != {sum(duration)}'
164 |
165 | return {
166 | 'uttid': uttid,
167 | 'raw_text': raw_text,
168 | 'text': phone_seq,
169 | 'note_pitch': note_pitch,
170 | 'note_duration': note_duration,
171 | 'mel': mel,
172 | 'duration': duration,
173 | 'pitch': pitch,
174 | 'uv': uv,
175 | 'energy': energy
176 | }
177 |
178 | class SVSCollate():
179 | def __init__(self):
180 | pass
181 |
182 | def __call__(self, batch):
183 | ilens = torch.from_numpy(np.array([x['text'].shape[0] for x in batch])).long()
184 | olens = torch.from_numpy(np.array([y['mel'].shape[0] for y in batch])).long()
185 | ids = [x['uttid'] for x in batch]
186 | raw_texts = [x['raw_text'] for x in batch]
187 |
188 | # perform padding and conversion to tensor
189 | inputs = pad_list([torch.from_numpy(x['text']).long() for x in batch], 0)
190 | note_pitchs = pad_list([torch.from_numpy(x['note_pitch']).long() for x in batch], 0)
191 | note_durations = pad_list([torch.from_numpy(x['note_duration']).long() for x in batch], 0)
192 |
193 | mels = pad_list([torch.from_numpy(y['mel']).float() for y in batch], 0)
194 | durations = pad_list([torch.from_numpy(x['duration']).long() for x in batch], 0)
195 | energys = pad_list([torch.from_numpy(y['energy']).float() for y in batch], 0).squeeze(-1)
196 | pitchs = pad_list([torch.from_numpy(y['pitch']).float() for y in batch], 0).squeeze(-1)
197 | uvs = pad_list([torch.from_numpy(y['uv']).float() for y in batch], 0).squeeze(-1)
198 |
199 | return {
200 | 'uttids': ids,
201 | 'raw_texts': raw_texts,
202 | 'texts': inputs,
203 | 'note_pitchs': note_pitchs,
204 | 'note_durations': note_durations,
205 | 'src_lens': ilens,
206 | 'max_src_len': ilens.max(),
207 | 'mels': mels,
208 | 'mel_lens': olens,
209 | 'max_mel_len': olens.max(),
210 | 'p_targets': pitchs,
211 | 'e_targets': energys,
212 | 'uv_targets': uvs,
213 | 'd_targets': durations
214 | }
215 |
216 | if __name__ == '__main__':
217 | import yaml
218 | from tqdm import tqdm
219 | from torch.utils.data import DataLoader
220 | with open('./configs/data.yaml', 'r') as f:
221 | configs = yaml.load(f, Loader = yaml.FullLoader)
222 | dataset = SVSDataset(configs)
223 | collate_fn = SVSCollate()
224 | dataloader = DataLoader(dataset, batch_size = 32, shuffle = True, collate_fn = collate_fn, num_workers = 8)
225 | for data in tqdm(dataloader):
226 | assert data['note_pitchs'].shape[-1] == data['note_durations'].shape[-1]
227 | assert data['uv_targets'].shape[1] == data['p_targets'].shape[-1]
228 | pass
229 |
--------------------------------------------------------------------------------
/dataset/dataset_tts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import torch.nn.functional as F
4 |
5 | import sys
6 | sys.path.insert(0, '/home/zengchang/code/acoustic_v2')
7 |
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | import numpy as np
11 |
12 | from pyutils import pad_list, str_to_int_list, remove_outlier
13 | from dataset.texts import text_to_sequence
14 |
15 | def f02pitch(f0):
16 | #f0 =f0 + 0.01
17 | return np.log2(f0 / 27.5) * 12 + 21
18 |
19 | def pitch2f0(pitch):
20 | f0 = np.exp2((pitch - 21 ) / 12) * 27.5
21 | for i in range(len(f0)):
22 | if f0[i] <= 10:
23 | f0[i] = 0
24 | return f0
25 |
26 | def pitchxuv(pitch, uv, to_f0 = False):
27 | result = pitch * uv
28 | if to_f0:
29 | result = pitch2f0(result)
30 | return result
31 |
32 | def pad1d(x, max_len):
33 | return np.pad(x, (0, max_len - len(x)), mode="constant")
34 |
35 | def pad2d(x, max_len):
36 | return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant")
37 |
38 | class TTSDataset(Dataset):
39 | def __init__(self, config):
40 | audio_manifest = config['audio_manifest']
41 | raw_text_manifest = config['raw_text_manifest']
42 | duration_manifest = config['duration_manifest']
43 | spk_manifest = config['spk_manifest']
44 | self.sampling_rate = config['sampling_rate']
45 | self.utt2path = {}
46 | self.utt2text = {}
47 | self.utt2duration = {}
48 | self.utt2raw_text = {}
49 | self.utt2spk = {}
50 | with open(audio_manifest, 'r') as f:
51 | for line in f:
52 | line = line.rstrip().split(' ')
53 | self.utt2path[line[0]] = line[1]
54 | with open(duration_manifest, 'r') as f:
55 | for line in f:
56 | line = line.rstrip().split('|')
57 | self.utt2text[line[0]] = ' '.join(line[2].split(' ')[0::2])
58 | self.utt2duration[line[0]] = ' '.join(line[2].split(' ')[1::2])
59 | with open(raw_text_manifest, 'r') as f:
60 | for line in f:
61 | line = line.rstrip().split(' ')
62 | self.utt2raw_text[line[0]] = line[1]
63 | with open(spk_manifest, 'r') as f:
64 | for line in f:
65 | line = line.rstrip().split(' ')
66 | self.utt2spk[line[0]] = line[1]
67 | self.utt = list(self.utt2path.keys())
68 | self.spk2int = {spk: idx for idx, spk in enumerate(set(self.utt2spk.values()))}
69 | self.int2spk = {idx: spk for spk, idx in self.spk2int.items()}
70 |
71 | self.use_phonemes = config['use_phonemes']
72 | self.tts_cleaner_names = config['tts_cleaner_names']
73 | self.eos = config['eos']
74 |
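# Note (inferred from the [0::2] / [1::2] slicing above): field 2 of each duration_manifest line
# interleaves phoneme symbols with their frame counts, e.g. 'sil 10 HH 4 AW1 12 S 9'
# -> phonemes 'sil HH AW1 S' and durations '10 4 12 9'.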
75 | def get_spk_number(self):
76 | return len(self.spk2int)
77 |
78 | def _norm_mean_std(self, x, mean, std, is_remove_outlier=False):
79 | if is_remove_outlier:
80 | x = remove_outlier(x)
81 | zero_idxs = np.where(x == 0.0)[0]
82 | x = (x - mean) / std
83 | x[zero_idxs] = 0.0
84 | return x
85 |
86 | def __len__(self):
87 | return len(self.utt)
88 |
89 | def __getitem__(self, idx):
90 | # set_trace()
91 | uttid = self.utt[idx]
92 | mel_path = self.utt2path[uttid].replace('.wav', '.mel.npy')
93 | f0_path = self.utt2path[uttid].replace('.wav', '.f0.npy')
94 | energy_path = self.utt2path[uttid].replace('.wav', '.en.npy')
95 |
96 | mel = np.load(mel_path) #.transpose(1, 0)
97 | f0 = np.load(f0_path)
98 | # pitch = f02pitch(f0)
99 | energy = np.load(energy_path)
100 | raw_text = self.utt2raw_text[uttid]
101 | phone_text = self.utt2text[uttid]
102 | phone_seq = np.array(text_to_sequence(phone_text, self.tts_cleaner_names))
103 | duration = np.array(str_to_int_list(self.utt2duration[uttid]))
104 | spk = self.spk2int[self.utt2spk[uttid]]
105 |
106 | mel_len = mel.shape[0]
107 | duration = duration[: len(phone_seq)]
108 | duration[-1] = duration[-1] + (mel.shape[0] - sum(duration))
109 | assert mel_len == sum(duration), f'{mel_len} != {sum(duration)}'
110 |
111 | return {
112 | 'uttid': uttid,
113 | 'raw_text': raw_text,
114 | 'text': phone_seq,
115 | 'mel': mel,
116 | 'duration': duration,
117 | 'f0': f0,
118 | 'energy': energy,
119 | 'spk': spk
120 | }
121 |
122 | class TTSCollate():
123 | def __init__(self):
124 | pass
125 |
126 | def __call__(self, batch):
127 | ilens = torch.from_numpy(np.array([x['text'].shape[0] for x in batch])).long()
128 | olens = torch.from_numpy(np.array([y['mel'].shape[0] for y in batch])).long()
129 | ids = [x['uttid'] for x in batch]
130 | raw_texts = [x['raw_text'] for x in batch]
131 |
132 | # perform padding and conversion to tensor
133 | inputs = pad_list([torch.from_numpy(x['text']).long() for x in batch], 0)
134 | mels = pad_list([torch.from_numpy(y['mel']).float() for y in batch], 0)
135 |
136 | durations = pad_list([torch.from_numpy(x['duration']).long() for x in batch], 0)
137 | energys = pad_list([torch.from_numpy(y['energy']).float() for y in batch], 0).squeeze(-1)
138 | f0 = pad_list([torch.from_numpy(y['f0']).float() for y in batch], 0).squeeze(-1)
139 | # pitch = pad_list([torch.from_numpy(y['pitch']).float() for y in batch], 0).squeeze(-1)
140 |
141 | spks = torch.tensor([x['spk'] for x in batch], dtype = torch.int64)
142 |
143 | return {
144 | 'uttids': ids,
145 | 'spks': spks,
146 | 'raw_texts': raw_texts,
147 | 'texts': inputs,
148 | 'src_lens': ilens,
149 | 'max_src_len': ilens.max(),
150 | 'mels': mels,
151 | 'mel_lens': olens,
152 | 'max_mel_len': olens.max(),
153 | 'p_targets': f0,
154 | 'e_targets': energys,
155 | 'd_targets': durations
156 | }
157 |
158 | if __name__ == '__main__':
159 | import yaml
160 | from tqdm import tqdm
161 | from torch.utils.data import DataLoader
162 | with open('./configs/data.yaml', 'r') as f:
163 | config = yaml.load(f, Loader = yaml.FullLoader)
164 | dataset = TTSDataset(config)
165 | print(dataset[0]['text'])
166 | print(dataset[0]['duration'])
167 | collate_fn = TTSCollate()
168 | dataloader = DataLoader(dataset, batch_size = 64, shuffle = True, collate_fn = collate_fn, num_workers = 8)
169 | for data in tqdm(dataloader):
170 | assert data['texts'].shape[1] == data['d_targets'].shape[1], "{} != {}".format(data['texts'].shape[1], data['d_targets'].shape[1])
171 | pass
172 | # print(data['texts'].shape)
173 | # print(data['texts'])
174 | # print(data['input_len'])
175 | # print(data['mels'].shape)
176 | # print(data['labels'].shape)
177 | # print(data['output_len'])
178 | # print(data['uttids'])
179 | # print(data['durations'].shape)
180 | # print(data['durations'].sum(dim = 1))
181 | # print(data['energys'].shape)
182 | # print(data['f0s'].shape)
183 | # print(data['raw_texts'])
184 | # break
185 | # print(dataset[0]['mel'].shape)
186 |
--------------------------------------------------------------------------------
/dataset/espnet_texts/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | import re
3 | from dataset.texts import cleaners
4 | from dataset.texts.symbols import (
5 | symbols,
6 | _eos,
7 | phonemes_symbols,
8 | PAD,
9 | EOS,
10 | _PHONEME_SEP,
11 | )
12 | from dataset.texts.dict_ import symbols_
13 | import nltk
14 | from g2p_en import G2p
15 |
16 | # Mappings from symbol to numeric ID and vice versa:
17 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
18 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
19 |
20 | # Regular expression matching text enclosed in curly braces:
21 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
22 |
23 | symbols_inv = {v: k for k, v in symbols_.items()}
24 |
25 | valid_symbols = [
26 | "AA",
27 | "AA1",
28 | "AE",
29 | "AE0",
30 | "AE1",
31 | "AH",
32 | "AH0",
33 | "AH1",
34 | "AO",
35 | "AO1",
36 | "AW",
37 | "AW0",
38 | "AW1",
39 | "AY",
40 | "AY0",
41 | "AY1",
42 | "B",
43 | "CH",
44 | "D",
45 | "DH",
46 | "EH",
47 | "EH0",
48 | "EH1",
49 | "ER",
50 | "EY",
51 | "EY0",
52 | "EY1",
53 | "F",
54 | "G",
55 | "HH",
56 | "IH",
57 | "IH0",
58 | "IH1",
59 | "IY",
60 | "IY0",
61 | "IY1",
62 | "JH",
63 | "K",
64 | "L",
65 | "M",
66 | "N",
67 | "NG",
68 | "OW",
69 | "OW0",
70 | "OW1",
71 | "OY",
72 | "OY0",
73 | "OY1",
74 | "P",
75 | "R",
76 | "S",
77 | "SH",
78 | "T",
79 | "TH",
80 | "UH",
81 | "UH0",
82 | "UH1",
83 | "UW",
84 | "UW0",
85 | "UW1",
86 | "V",
87 | "W",
88 | "Y",
89 | "Z",
90 | "ZH",
91 | "pau",
92 | "sil",
93 | "spn"
94 | ]
95 |
96 |
97 | def pad_with_eos_bos(_sequence):
98 | return _sequence + [_symbol_to_id[_eos]]
99 |
100 |
101 | def text_to_sequence(text, cleaner_names, eos):
102 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
103 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
104 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
105 | Args:
106 | text: string to convert to a sequence
107 | cleaner_names: names of the cleaner functions to run the text through
108 | Returns:
109 | List of integers corresponding to the symbols in the text
110 | """
111 | sequence = []
112 | if eos:
113 | text = text + "~"
114 | try:
115 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
116 | except KeyError:
117 | print("text : ", text)
118 | exit(0)
119 |
120 | return sequence
121 |
122 |
123 | def sequence_to_text(sequence):
124 | """Converts a sequence of IDs back to a string"""
125 | result = ""
126 | for symbol_id in sequence:
127 | if symbol_id in symbols_inv:
128 | s = symbols_inv[symbol_id]
129 | # Enclose ARPAbet back in curly braces:
130 | if len(s) > 1 and s[0] == "@":
131 | s = "{%s}" % s[1:]
132 | result += s
133 | return result.replace("}{", " ")
134 |
135 |
136 | def _clean_text(text, cleaner_names):
137 | for name in cleaner_names:
138 | cleaner = getattr(cleaners, name)
139 | if not cleaner:
140 | raise Exception("Unknown cleaner: %s" % name)
141 | text = cleaner(text)
142 | return text
143 |
144 |
145 | def _symbols_to_sequence(symbols):
146 | return [symbols_[s.upper()] for s in symbols]
147 |
148 |
149 | def _arpabet_to_sequence(text):
150 | return _symbols_to_sequence(["@" + s for s in text.split()])
151 |
152 |
153 | def _should_keep_symbol(s):
154 | return s in _symbol_to_id and s != "_" and s != "~"
155 |
156 |
157 | # For phonemes
158 | _phoneme_to_id = {s: i for i, s in enumerate(valid_symbols)}
159 | _id_to_phoneme = {i: s for i, s in enumerate(valid_symbols)}
160 |
161 |
162 | def _should_keep_token(token, token_dict):
163 | return (
164 | token in token_dict
165 | and token != PAD
166 | and token != EOS
167 | and token != _phoneme_to_id[PAD]
168 | and token != _phoneme_to_id[EOS]
169 | )
170 |
171 |
172 | def phonemes_to_sequence(phonemes):
173 | string = phonemes.split() if isinstance(phonemes, str) else phonemes
174 | # string.append(EOS)
175 | sequence = list(map(convert_phoneme_CMU, string))
176 | sequence = [_phoneme_to_id[s] for s in sequence]
177 | # if _should_keep_token(s, _phoneme_to_id)]
178 | return sequence
179 |
180 |
181 | def sequence_to_phonemes(sequence, use_eos=False):
182 | string = [_id_to_phoneme[idx] for idx in sequence]
183 | # if _should_keep_token(idx, _id_to_phoneme)]
184 | string = _PHONEME_SEP.join(string)
185 | if use_eos:
186 | string = string.replace(EOS, "")
187 | return string
188 |
189 |
190 | def convert_phoneme_CMU(phoneme):
191 | REMAPPING = {
192 | 'AA0': 'AA1',
193 | 'AA2': 'AA1',
194 | 'AE2': 'AE1',
195 | 'AH2': 'AH1',
196 | 'AO0': 'AO1',
197 | 'AO2': 'AO1',
198 | 'AW2': 'AW1',
199 | 'AY2': 'AY1',
200 | 'EH2': 'EH1',
201 | 'ER0': 'EH1',
202 | 'ER1': 'EH1',
203 | 'ER2': 'EH1',
204 | 'EY2': 'EY1',
205 | 'IH2': 'IH1',
206 | 'IY2': 'IY1',
207 | 'OW2': 'OW1',
208 | 'OY2': 'OY1',
209 | 'UH2': 'UH1',
210 | 'UW2': 'UW1',
211 | }
212 | return REMAPPING.get(phoneme, phoneme)
213 |
214 |
215 | def text_to_phonemes(text, custom_words={}):
216 | """
217 | Convert text into ARPAbet.
218 | For known words use CMUDict; for the rest fall back to g2p_en.
219 | :param text: str, input text.
220 | :param custom_words:
221 | dict {str: list of str}, optional
222 | Pronunciations (a list of ARPAbet phonemes) you'd like to override.
223 | Example: {'word': ['W', 'EU1', 'R', 'D']}
224 | :return: list of str, phonemes
225 | """
226 | g2p = G2p()
227 |
228 | """def convert_phoneme_CMU(phoneme):
229 | REMAPPING = {
230 | 'AA0': 'AA1',
231 | 'AA2': 'AA1',
232 | 'AE2': 'AE1',
233 | 'AH2': 'AH1',
234 | 'AO0': 'AO1',
235 | 'AO2': 'AO1',
236 | 'AW2': 'AW1',
237 | 'AY2': 'AY1',
238 | 'EH2': 'EH1',
239 | 'ER0': 'EH1',
240 | 'ER1': 'EH1',
241 | 'ER2': 'EH1',
242 | 'EY2': 'EY1',
243 | 'IH2': 'IH1',
244 | 'IY2': 'IY1',
245 | 'OW2': 'OW1',
246 | 'OY2': 'OY1',
247 | 'UH2': 'UH1',
248 | 'UW2': 'UW1',
249 | }
250 | return REMAPPING.get(phoneme, phoneme)
251 | """
252 |
253 | def convert_phoneme_listener(phoneme):
254 | VOWELS = ['A', 'E', 'I', 'O', 'U']
255 | if phoneme[0] in VOWELS:
256 | phoneme += '1'
257 | return phoneme # convert_phoneme_CMU(phoneme)
258 |
259 | try:
260 | known_words = nltk.corpus.cmudict.dict()
261 | except LookupError:
262 | nltk.download("cmudict")
263 | known_words = nltk.corpus.cmudict.dict()
264 |
265 | for word, phonemes in custom_words.items():
266 | known_words[word.lower()] = [phonemes]
267 |
268 | words = nltk.tokenize.WordPunctTokenizer().tokenize(text.lower())
269 |
270 | phonemes = []
271 | PUNCTUATION = "!?.,-:;\"'()"
272 | for word in words:
273 | if all(c in PUNCTUATION for c in word):
274 | pronounciation = ["pau"]
275 | elif word in known_words:
276 | pronounciation = known_words[word][0]
277 | pronounciation = list(
278 | pronounciation
279 | ) # map(convert_phoneme_CMU, pronounciation))
280 | else:
281 | pronounciation = g2p(word)
282 | pronounciation = list(
283 | pronounciation
284 | ) # (map(convert_phoneme_CMU, pronounciation))
285 |
286 | phonemes += pronounciation
287 |
288 | return phonemes
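# Illustrative usage (needs nltk's cmudict corpus and g2p_en; this example goes through the CMUdict path):
#   text_to_phonemes("turn left")  ->  ['T', 'ER1', 'N', 'L', 'EH1', 'F', 'T']
# Unknown words fall back to g2p_en, and purely punctuation tokens map to ['pau'].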
--------------------------------------------------------------------------------
/dataset/espnet_texts/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | """
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | """
14 |
15 |
16 | # Regular expression matching whitespace:
17 | import re
18 | from unidecode import unidecode
19 | from .numbers import normalize_numbers
20 |
21 | _whitespace_re = re.compile(r"\s+")
22 | punctuations = """+-!()[]{};:'"\<>/?@#^&*_~"""
23 |
24 | # List of (regular expression, replacement) pairs for abbreviations:
25 | _abbreviations = [
26 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
27 | for x in [
28 | ("mrs", "misess"),
29 | ("mr", "mister"),
30 | ("dr", "doctor"),
31 | ("st", "saint"),
32 | ("co", "company"),
33 | ("jr", "junior"),
34 | ("maj", "major"),
35 | ("gen", "general"),
36 | ("drs", "doctors"),
37 | ("rev", "reverend"),
38 | ("lt", "lieutenant"),
39 | ("hon", "honorable"),
40 | ("sgt", "sergeant"),
41 | ("capt", "captain"),
42 | ("esq", "esquire"),
43 | ("ltd", "limited"),
44 | ("col", "colonel"),
45 | ("ft", "fort"),
46 | ]
47 | ]
48 |
49 |
50 | def expand_abbreviations(text):
51 | for regex, replacement in _abbreviations:
52 | text = re.sub(regex, replacement, text)
53 | return text
54 |
55 |
56 | def expand_numbers(text):
57 | return normalize_numbers(text)
58 |
59 |
60 | def lowercase(text):
61 | return text.lower()
62 |
63 |
64 | def collapse_whitespace(text):
65 | return re.sub(_whitespace_re, " ", text)
66 |
67 |
68 | def convert_to_ascii(text):
69 | return unidecode(text)
70 |
71 |
72 | def basic_cleaners(text):
73 | """Basic pipeline that lowercases and collapses whitespace without transliteration."""
74 | text = lowercase(text)
75 | text = collapse_whitespace(text)
76 | return text
77 |
78 |
79 | def transliteration_cleaners(text):
80 | """Pipeline for non-English text that transliterates to ASCII."""
81 | text = convert_to_ascii(text)
82 | text = lowercase(text)
83 | text = collapse_whitespace(text)
84 | return text
85 |
86 |
87 | def english_cleaners(text):
88 | """Pipeline for English text, including number and abbreviation expansion."""
89 | text = convert_to_ascii(text)
90 | text = lowercase(text)
91 | text = expand_numbers(text)
92 | text = expand_abbreviations(text)
93 | text = collapse_whitespace(text)
94 | return text
95 |
96 |
97 | def punctuation_removers(text):
98 | no_punct = ""
99 | for char in text:
100 | if char not in punctuations:
101 | no_punct = no_punct + char
102 | return no_punct
--------------------------------------------------------------------------------
/dataset/espnet_texts/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | "AA",
8 | "AA0",
9 | "AA1",
10 | "AA2",
11 | "AE",
12 | "AE0",
13 | "AE1",
14 | "AE2",
15 | "AH",
16 | "AH0",
17 | "AH1",
18 | "AH2",
19 | "AO",
20 | "AO0",
21 | "AO1",
22 | "AO2",
23 | "AW",
24 | "AW0",
25 | "AW1",
26 | "AW2",
27 | "AY",
28 | "AY0",
29 | "AY1",
30 | "AY2",
31 | "B",
32 | "CH",
33 | "D",
34 | "DH",
35 | "EH",
36 | "EH0",
37 | "EH1",
38 | "EH2",
39 | "ER",
40 | "ER0",
41 | "ER1",
42 | "ER2",
43 | "EY",
44 | "EY0",
45 | "EY1",
46 | "EY2",
47 | "F",
48 | "G",
49 | "HH",
50 | "IH",
51 | "IH0",
52 | "IH1",
53 | "IH2",
54 | "IY",
55 | "IY0",
56 | "IY1",
57 | "IY2",
58 | "JH",
59 | "K",
60 | "L",
61 | "M",
62 | "N",
63 | "NG",
64 | "OW",
65 | "OW0",
66 | "OW1",
67 | "OW2",
68 | "OY",
69 | "OY0",
70 | "OY1",
71 | "OY2",
72 | "P",
73 | "R",
74 | "S",
75 | "SH",
76 | "T",
77 | "TH",
78 | "UH",
79 | "UH0",
80 | "UH1",
81 | "UH2",
82 | "UW",
83 | "UW0",
84 | "UW1",
85 | "UW2",
86 | "V",
87 | "W",
88 | "Y",
89 | "Z",
90 | "ZH",
91 | ]
92 |
93 | _valid_symbol_set = set(valid_symbols)
94 |
95 |
96 | class CMUDict:
97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
98 |
99 | def __init__(self, file_or_path, keep_ambiguous=True):
100 | if isinstance(file_or_path, str):
101 | with open(file_or_path, encoding="latin-1") as f:
102 | entries = _parse_cmudict(f)
103 | else:
104 | entries = _parse_cmudict(file_or_path)
105 | if not keep_ambiguous:
106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
107 | self._entries = entries
108 |
109 | def __len__(self):
110 | return len(self._entries)
111 |
112 | def lookup(self, word):
113 | """Returns list of ARPAbet pronunciations of the given word."""
114 | return self._entries.get(word.upper())
115 |
116 |
117 | _alt_re = re.compile(r"\([0-9]+\)")
118 |
119 |
120 | def _parse_cmudict(file):
121 | cmudict = {}
122 | for line in file:
123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
124 | parts = line.split(" ")
125 | word = re.sub(_alt_re, "", parts[0])
126 | pronunciation = _get_pronunciation(parts[1])
127 | if pronunciation:
128 | if word in cmudict:
129 | cmudict[word].append(pronunciation)
130 | else:
131 | cmudict[word] = [pronunciation]
132 | return cmudict
133 |
134 |
135 | def _get_pronunciation(s):
136 | parts = s.strip().split(" ")
137 | for part in parts:
138 | if part not in _valid_symbol_set:
139 | return None
140 | return " ".join(parts)
--------------------------------------------------------------------------------
/dataset/espnet_texts/dict.py:
--------------------------------------------------------------------------------
1 | symbols_ = {
2 | "": 1,
3 | "!": 2,
4 | "'": 3,
5 | ",": 4,
6 | ".": 5,
7 | " ": 6,
8 | "?": 7,
9 | "A": 8,
10 | "B": 9,
11 | "C": 10,
12 | "D": 11,
13 | "E": 12,
14 | "F": 13,
15 | "G": 14,
16 | "H": 15,
17 | "I": 16,
18 | "J": 17,
19 | "K": 18,
20 | "L": 19,
21 | "M": 20,
22 | "N": 21,
23 | "O": 22,
24 | "P": 23,
25 | "Q": 24,
26 | "R": 25,
27 | "S": 26,
28 | "T": 27,
29 | "U": 28,
30 | "V": 29,
31 | "W": 30,
32 | "X": 31,
33 | "Y": 32,
34 | "Z": 33,
35 | "~": 34,
36 | }
--------------------------------------------------------------------------------
/dataset/espnet_texts/numbers.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import inflect
4 | import re
5 |
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13 | _number_re = re.compile(r"[0-9]+")
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(",", "")
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace(".", " point ")
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split(".")
27 | if len(parts) > 2:
28 | return match + " dollars" # Unexpected format
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = "dollar" if dollars == 1 else "dollars"
33 | cent_unit = "cent" if cents == 1 else "cents"
34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
35 | elif dollars:
36 | dollar_unit = "dollar" if dollars == 1 else "dollars"
37 | return "%s %s" % (dollars, dollar_unit)
38 | elif cents:
39 | cent_unit = "cent" if cents == 1 else "cents"
40 | return "%s %s" % (cents, cent_unit)
41 | else:
42 | return "zero dollars"
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return "two thousand"
54 | elif num > 2000 and num < 2010:
55 | return "two thousand " + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + " hundred"
58 | else:
59 | return _inflect.number_to_words(
60 | num, andword="", zero="oh", group=2
61 | ).replace(", ", " ")
62 | else:
63 | return _inflect.number_to_words(num, andword="")
64 |
65 |
66 | def normalize_numbers(text):
67 | text = re.sub(_comma_number_re, _remove_commas, text)
68 | text = re.sub(_pounds_re, r"\1 pounds", text)
69 | text = re.sub(_dollars_re, _expand_dollars, text)
70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
71 | text = re.sub(_ordinal_re, _expand_ordinal, text)
72 | text = re.sub(_number_re, _expand_number, text)
73 | return text
--------------------------------------------------------------------------------
/dataset/espnet_texts/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | """
4 | Defines the set of symbols used in text input to the model.
5 |
6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """
7 |
8 | from dataset.texts import cmudict
9 |
10 | _pad = "_"
11 | _eos = "~"
12 | _bos = "^"
13 | _punctuation = "!'(),.:;? "
14 | _special = "-"
15 | _letters = "abcdefghijklmnopqrstuvwxyz"
16 |
17 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
18 | # _arpabet = ['@' + s for s in cmudict.valid_symbols]
19 |
20 | # Export all symbols:
21 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + [_eos]
22 |
23 | # For Phonemes
24 |
25 | PAD = "#"
26 | EOS = "~"
27 | PHONEME_CODES = "AA1 AE0 AE1 AH0 AH1 AO0 AO1 AW0 AW1 AY0 AY1 B CH D DH EH0 EH1 EU0 EU1 EY0 EY1 F G HH IH0 IH1 IY0 IY1 JH K L M N NG OW0 OW1 OY0 OY1 P R S SH T TH UH0 UH1 UW0 UW1 V W Y Z ZH pau".split()
28 | _PHONEME_SEP = " "
29 |
30 | phonemes_symbols = [PAD, EOS] + PHONEME_CODES # PAD should be first to have zero id
--------------------------------------------------------------------------------
/dataset/texts/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | import re
3 | # import sys
4 | # sys.path.insert(0, '/home/zengchang/code/acoustic_v2')
5 | from dataset.texts import cleaners
6 | from dataset.texts.symbols import symbols
7 |
8 | # from ipdb import set_trace  # debugging helper; the call below is commented out
9 |
10 | # Mappings from symbol to numeric ID and vice versa:
11 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
12 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
13 |
14 | # Regular expression matching text enclosed in curly braces:
15 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
16 |
17 |
18 | def text_to_sequence(text, cleaner_names):
19 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
20 |
21 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
22 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
23 |
24 | Args:
25 | text: string to convert to a sequence
26 | cleaner_names: names of the cleaner functions to run the text through
27 |
28 | Returns:
29 | List of integers corresponding to the symbols in the text
30 | """
31 | # set_trace()
32 | sequence = []
33 |
34 | # Check for curly braces and treat their contents as ARPAbet:
35 | # while len(text):
36 | # m = _curly_re.match(text)
37 |
38 | # if not m:
39 | # clean_text = _clean_text(text, cleaner_names).split(' ')
40 | # sequence += _symbols_to_sequence(clean_text)
41 | # break
42 | # clean_text1 = _clean_text(m.group(1), cleaner_names).split(' ')
43 | # sequence += _symbols_to_sequence(clean_text1)
44 | # clean_text2 = m.group(2).split(' ')
45 | # sequence += _arpabet_to_sequence(clean_text2)
46 | # text = m.group(3)
47 | sequence += _arpabet_to_sequence(text)
48 |
49 | return sequence
50 |
51 |
52 | def sequence_to_text(sequence):
53 | """Converts a sequence of IDs back to a string"""
54 | result = ""
55 | for symbol_id in sequence:
56 | if symbol_id in _id_to_symbol:
57 | s = _id_to_symbol[symbol_id]
58 | # Enclose ARPAbet back in curly braces:
59 | if len(s) > 1 and s[0] == "@":
60 | s = "{%s}" % s[1:]
61 | result += s
62 | return result.replace("}{", " ")
63 |
64 | def _clean_text(text, cleaner_names):
65 | for name in cleaner_names:
66 | cleaner = getattr(cleaners, name)
67 | if not cleaner:
68 | raise Exception("Unknown cleaner: %s" % name)
69 | text = cleaner(text)
70 | return text
71 |
72 | def _symbols_to_sequence(symbols):
73 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
74 |
75 | def _arpabet_to_sequence(text):
76 | return _symbols_to_sequence(["@" + s for s in text.split()])
77 |
78 | def _should_keep_symbol(s):
79 | return s in _symbol_to_id and s != "_" and s != "~"
80 |
81 | if __name__ == "__main__":
82 | text = 'Turn left on {HH AW1 S S T AH0 N} Street.'
83 | print(text_to_sequence(text, ['english_cleaners']))
--------------------------------------------------------------------------------
/dataset/texts/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 | import re
18 | from unidecode import unidecode
19 | from .numbers import normalize_numbers
20 | _whitespace_re = re.compile(r'\s+')
21 |
22 | # List of (regular expression, replacement) pairs for abbreviations:
23 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
24 | ('mrs', 'misess'),
25 | ('mr', 'mister'),
26 | ('dr', 'doctor'),
27 | ('st', 'saint'),
28 | ('co', 'company'),
29 | ('jr', 'junior'),
30 | ('maj', 'major'),
31 | ('gen', 'general'),
32 | ('drs', 'doctors'),
33 | ('rev', 'reverend'),
34 | ('lt', 'lieutenant'),
35 | ('hon', 'honorable'),
36 | ('sgt', 'sergeant'),
37 | ('capt', 'captain'),
38 | ('esq', 'esquire'),
39 | ('ltd', 'limited'),
40 | ('col', 'colonel'),
41 | ('ft', 'fort'),
42 | ]]
43 |
44 |
45 | def expand_abbreviations(text):
46 | for regex, replacement in _abbreviations:
47 | text = re.sub(regex, replacement, text)
48 | return text
49 |
50 |
51 | def expand_numbers(text):
52 | return normalize_numbers(text)
53 |
54 |
55 | def lowercase(text):
56 | return text.lower()
57 |
58 |
59 | def collapse_whitespace(text):
60 | return re.sub(_whitespace_re, ' ', text)
61 |
62 |
63 | def convert_to_ascii(text):
64 | return unidecode(text)
65 |
66 |
67 | def basic_cleaners(text):
68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
69 | text = lowercase(text)
70 | text = collapse_whitespace(text)
71 | return text
72 |
73 |
74 | def transliteration_cleaners(text):
75 | '''Pipeline for non-English text that transliterates to ASCII.'''
76 | text = convert_to_ascii(text)
77 | text = lowercase(text)
78 | text = collapse_whitespace(text)
79 | return text
80 |
81 |
82 | def english_cleaners(text):
83 | '''Pipeline for English text, including number and abbreviation expansion.'''
84 | text = convert_to_ascii(text)
85 | text = lowercase(text)
86 | text = expand_numbers(text)
87 | text = expand_abbreviations(text)
88 | text = collapse_whitespace(text)
89 | return text
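# Example of the full English pipeline:
#   english_cleaners("Mrs. Müller bought 2 apples")  ->  "misess muller bought two apples"
# (unidecode strips the accent, numbers are spelled out, known abbreviations are expanded,
# and whitespace is collapsed).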
--------------------------------------------------------------------------------
/dataset/texts/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | "AA",
8 | "AA0",
9 | "AA1",
10 | "AA2",
11 | "AE",
12 | "AE0",
13 | "AE1",
14 | "AE2",
15 | "AH",
16 | "AH0",
17 | "AH1",
18 | "AH2",
19 | "AO",
20 | "AO0",
21 | "AO1",
22 | "AO2",
23 | "AW",
24 | "AW0",
25 | "AW1",
26 | "AW2",
27 | "AY",
28 | "AY0",
29 | "AY1",
30 | "AY2",
31 | "B",
32 | "CH",
33 | "D",
34 | "DH",
35 | "EH",
36 | "EH0",
37 | "EH1",
38 | "EH2",
39 | "ER",
40 | "ER0",
41 | "ER1",
42 | "ER2",
43 | "EY",
44 | "EY0",
45 | "EY1",
46 | "EY2",
47 | "F",
48 | "G",
49 | "HH",
50 | "IH",
51 | "IH0",
52 | "IH1",
53 | "IH2",
54 | "IY",
55 | "IY0",
56 | "IY1",
57 | "IY2",
58 | "JH",
59 | "K",
60 | "L",
61 | "M",
62 | "N",
63 | "NG",
64 | "OW",
65 | "OW0",
66 | "OW1",
67 | "OW2",
68 | "OY",
69 | "OY0",
70 | "OY1",
71 | "OY2",
72 | "P",
73 | "R",
74 | "S",
75 | "SH",
76 | "T",
77 | "TH",
78 | "UH",
79 | "UH0",
80 | "UH1",
81 | "UH2",
82 | "UW",
83 | "UW0",
84 | "UW1",
85 | "UW2",
86 | "V",
87 | "W",
88 | "Y",
89 | "Z",
90 | "ZH",
91 | ]
92 |
93 | _valid_symbol_set = set(valid_symbols)
94 |
95 |
96 | class CMUDict:
97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
98 |
99 | def __init__(self, file_or_path, keep_ambiguous=True):
100 | if isinstance(file_or_path, str):
101 | with open(file_or_path, encoding="latin-1") as f:
102 | entries = _parse_cmudict(f)
103 | else:
104 | entries = _parse_cmudict(file_or_path)
105 | if not keep_ambiguous:
106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
107 | self._entries = entries
108 |
109 | def __len__(self):
110 | return len(self._entries)
111 |
112 | def lookup(self, word):
113 | """Returns list of ARPAbet pronunciations of the given word."""
114 | return self._entries.get(word.upper())
115 |
116 |
117 | _alt_re = re.compile(r"\([0-9]+\)")
118 |
119 |
120 | def _parse_cmudict(file):
121 | cmudict = {}
122 | for line in file:
123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
124 |             parts = line.split("  ")  # cmudict entries separate the word from the phones with two spaces
125 | word = re.sub(_alt_re, "", parts[0])
126 | pronunciation = _get_pronunciation(parts[1])
127 | if pronunciation:
128 | if word in cmudict:
129 | cmudict[word].append(pronunciation)
130 | else:
131 | cmudict[word] = [pronunciation]
132 | return cmudict
133 |
134 |
135 | def _get_pronunciation(s):
136 | parts = s.strip().split(" ")
137 | for part in parts:
138 | if part not in _valid_symbol_set:
139 | return None
140 | return " ".join(parts)
--------------------------------------------------------------------------------
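CMUDict accepts either a path or an open file object, so a couple of in-memory entries (in the cmudict-0.7b layout: word, two spaces, ARPAbet phones) are enough to sketch how lookup() behaves:

# Hypothetical example; the entries below are illustrative, not taken from the shipped lexicon.
import io
from dataset.texts.cmudict import CMUDict

cmu = CMUDict(io.StringIO("HELLO  HH AH0 L OW1\nHELLO(1)  HH EH0 L OW1\nWORLD  W ER1 L D\n"))
print(len(cmu))             # 2 distinct words
print(cmu.lookup("hello"))  # ['HH AH0 L OW1', 'HH EH0 L OW1']
print(cmu.lookup("xyzzy"))  # None

--------------------------------------------------------------------------------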
/dataset/texts/numbers.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import inflect
4 | import re
5 |
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13 | _number_re = re.compile(r"[0-9]+")
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(",", "")
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace(".", " point ")
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split(".")
27 | if len(parts) > 2:
28 | return match + " dollars" # Unexpected format
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = "dollar" if dollars == 1 else "dollars"
33 | cent_unit = "cent" if cents == 1 else "cents"
34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
35 | elif dollars:
36 | dollar_unit = "dollar" if dollars == 1 else "dollars"
37 | return "%s %s" % (dollars, dollar_unit)
38 | elif cents:
39 | cent_unit = "cent" if cents == 1 else "cents"
40 | return "%s %s" % (cents, cent_unit)
41 | else:
42 | return "zero dollars"
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return "two thousand"
54 | elif num > 2000 and num < 2010:
55 | return "two thousand " + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + " hundred"
58 | else:
59 | return _inflect.number_to_words(
60 | num, andword="", zero="oh", group=2
61 | ).replace(", ", " ")
62 | else:
63 | return _inflect.number_to_words(num, andword="")
64 |
65 |
66 | def normalize_numbers(text):
67 | text = re.sub(_comma_number_re, _remove_commas, text)
68 | text = re.sub(_pounds_re, r"\1 pounds", text)
69 | text = re.sub(_dollars_re, _expand_dollars, text)
70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
71 | text = re.sub(_ordinal_re, _expand_ordinal, text)
72 | text = re.sub(_number_re, _expand_number, text)
73 | return text
--------------------------------------------------------------------------------
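A worked example of normalize_numbers (assuming `inflect` is installed); the substitutions run in a fixed order (commas, pounds, dollars, decimals, ordinals, plain numbers), and numbers strictly between 1000 and 3000 are expanded year-style:

from dataset.texts.numbers import normalize_numbers

print(normalize_numbers("On March 3rd I paid $1,250 and walked 2.5 km"))
# -> roughly "On March third I paid twelve fifty dollars and walked two point five km"

--------------------------------------------------------------------------------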
/dataset/texts/pinyin.py:
--------------------------------------------------------------------------------
1 | initials = [
2 | "b",
3 | "c",
4 | "ch",
5 | "d",
6 | "f",
7 | "g",
8 | "h",
9 | "j",
10 | "k",
11 | "l",
12 | "m",
13 | "n",
14 | "p",
15 | "q",
16 | "r",
17 | "s",
18 | "sh",
19 | "t",
20 | "w",
21 | "x",
22 | "y",
23 | "z",
24 | "zh",
25 | ]
26 | finals = [
27 | "a1",
28 | "a2",
29 | "a3",
30 | "a4",
31 | "a5",
32 | "ai1",
33 | "ai2",
34 | "ai3",
35 | "ai4",
36 | "ai5",
37 | "an1",
38 | "an2",
39 | "an3",
40 | "an4",
41 | "an5",
42 | "ang1",
43 | "ang2",
44 | "ang3",
45 | "ang4",
46 | "ang5",
47 | "ao1",
48 | "ao2",
49 | "ao3",
50 | "ao4",
51 | "ao5",
52 | "e1",
53 | "e2",
54 | "e3",
55 | "e4",
56 | "e5",
57 | "ei1",
58 | "ei2",
59 | "ei3",
60 | "ei4",
61 | "ei5",
62 | "en1",
63 | "en2",
64 | "en3",
65 | "en4",
66 | "en5",
67 | "eng1",
68 | "eng2",
69 | "eng3",
70 | "eng4",
71 | "eng5",
72 | "er1",
73 | "er2",
74 | "er3",
75 | "er4",
76 | "er5",
77 | "i1",
78 | "i2",
79 | "i3",
80 | "i4",
81 | "i5",
82 | "ia1",
83 | "ia2",
84 | "ia3",
85 | "ia4",
86 | "ia5",
87 | "ian1",
88 | "ian2",
89 | "ian3",
90 | "ian4",
91 | "ian5",
92 | "iang1",
93 | "iang2",
94 | "iang3",
95 | "iang4",
96 | "iang5",
97 | "iao1",
98 | "iao2",
99 | "iao3",
100 | "iao4",
101 | "iao5",
102 | "ie1",
103 | "ie2",
104 | "ie3",
105 | "ie4",
106 | "ie5",
107 | "ii1",
108 | "ii2",
109 | "ii3",
110 | "ii4",
111 | "ii5",
112 | "iii1",
113 | "iii2",
114 | "iii3",
115 | "iii4",
116 | "iii5",
117 | "in1",
118 | "in2",
119 | "in3",
120 | "in4",
121 | "in5",
122 | "ing1",
123 | "ing2",
124 | "ing3",
125 | "ing4",
126 | "ing5",
127 | "iong1",
128 | "iong2",
129 | "iong3",
130 | "iong4",
131 | "iong5",
132 | "iou1",
133 | "iou2",
134 | "iou3",
135 | "iou4",
136 | "iou5",
137 | "o1",
138 | "o2",
139 | "o3",
140 | "o4",
141 | "o5",
142 | "ong1",
143 | "ong2",
144 | "ong3",
145 | "ong4",
146 | "ong5",
147 | "ou1",
148 | "ou2",
149 | "ou3",
150 | "ou4",
151 | "ou5",
152 | "u1",
153 | "u2",
154 | "u3",
155 | "u4",
156 | "u5",
157 | "ua1",
158 | "ua2",
159 | "ua3",
160 | "ua4",
161 | "ua5",
162 | "uai1",
163 | "uai2",
164 | "uai3",
165 | "uai4",
166 | "uai5",
167 | "uan1",
168 | "uan2",
169 | "uan3",
170 | "uan4",
171 | "uan5",
172 | "uang1",
173 | "uang2",
174 | "uang3",
175 | "uang4",
176 | "uang5",
177 | "uei1",
178 | "uei2",
179 | "uei3",
180 | "uei4",
181 | "uei5",
182 | "uen1",
183 | "uen2",
184 | "uen3",
185 | "uen4",
186 | "uen5",
187 | "uo1",
188 | "uo2",
189 | "uo3",
190 | "uo4",
191 | "uo5",
192 | "v1",
193 | "v2",
194 | "v3",
195 | "v4",
196 | "v5",
197 | "van1",
198 | "van2",
199 | "van3",
200 | "van4",
201 | "van5",
202 | "ve1",
203 | "ve2",
204 | "ve3",
205 | "ve4",
206 | "ve5",
207 | "vn1",
208 | "vn2",
209 | "vn3",
210 | "vn4",
211 | "vn5",
212 | ]
213 | valid_symbols = initials + finals + ["rr"]
--------------------------------------------------------------------------------
/dataset/texts/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | """
4 | Defines the set of symbols used in text input to the model.
5 |
6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _letters and the phoneme lists below. """
7 |
8 | from dataset.texts import cmudict, pinyin
9 |
10 | _pad = "_~"
11 | _punctuation = "!'(),.:;? "
12 | _special = "-"
13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
14 | _silences = ["@sp", "@spn", "@sil"]
15 | # _silences = ["sp", "spn", "sil"]
16 |
17 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
18 | # _arpabet = [s for s in cmudict.valid_symbols]
19 | _arpabet = ["@" + s for s in cmudict.valid_symbols]
20 | # _pinyin = [s for s in pinyin.valid_symbols]
21 | _pinyin = ["@" + s for s in pinyin.valid_symbols]
22 |
23 | # Export all symbols:
24 | symbols = (
25 | [_pad]
26 | + list(_special)
27 | + list(_punctuation)
28 | + list(_letters)
29 | + _arpabet
30 | + _pinyin
31 | + _silences
32 | )
33 |
34 | # symbols
35 | '''
36 | ['_',
37 | '~',
38 | '-',
39 | '!',
40 | "'",
41 | '(',
42 | ')',
43 | ',',
44 | '.',
45 | ':',
46 | ';',
47 | '?',
48 | ' ',
49 | 'A',
50 | 'B',
51 | 'C',
52 | 'D',
53 | 'E',
54 | 'F',
55 | 'G',
56 | 'H',
57 | 'I',
58 | 'J',
59 | 'K',
60 | 'L',
61 | 'M',
62 | 'N',
63 | 'O',
64 | 'P',
65 | 'Q',
66 | 'R',
67 | 'S',
68 | 'T',
69 | 'U',
70 | 'V',
71 | 'W',
72 | 'X',
73 | 'Y',
74 | 'Z',
75 | 'a',
76 | 'b',
77 | 'c',
78 | 'd',
79 | 'e',
80 | 'f',
81 | 'g',
82 | 'h',
83 | 'i',
84 | 'j',
85 | 'k',
86 | 'l',
87 | 'm',
88 | 'n',
89 | 'o',
90 | 'p',
91 | 'q',
92 | 'r',
93 | 's',
94 | 't',
95 | 'u',
96 | 'v',
97 | 'w',
98 | 'x',
99 | 'y',
100 | 'z',
101 | '@AA',
102 | '@AA0',
103 | '@AA1',
104 | '@AA2',
105 | '@AE',
106 | '@AE0',
107 | '@AE1',
108 | '@AE2',
109 | '@AH',
110 | '@AH0',
111 | '@AH1',
112 | '@AH2',
113 | '@AO',
114 | '@AO0',
115 | '@AO1',
116 | '@AO2',
117 | '@AW',
118 | '@AW0',
119 | '@AW1',
120 | '@AW2',
121 | '@AY',
122 | '@AY0',
123 | '@AY1',
124 | '@AY2',
125 | '@B',
126 | '@CH',
127 | '@D',
128 | '@DH',
129 | '@EH',
130 | '@EH0',
131 | '@EH1',
132 | '@EH2',
133 | '@ER',
134 | '@ER0',
135 | '@ER1',
136 | '@ER2',
137 | '@EY',
138 | '@EY0',
139 | '@EY1',
140 | '@EY2',
141 | '@F',
142 | '@G',
143 | '@HH',
144 | '@IH',
145 | '@IH0',
146 | '@IH1',
147 | '@IH2',
148 | '@IY',
149 | '@IY0',
150 | '@IY1',
151 | '@IY2',
152 | '@JH',
153 | '@K',
154 | '@L',
155 | '@M',
156 | '@N',
157 | '@NG',
158 | '@OW',
159 | '@OW0',
160 | '@OW1',
161 | '@OW2',
162 | '@OY',
163 | '@OY0',
164 | '@OY1',
165 | '@OY2',
166 | '@P',
167 | '@R',
168 | '@S',
169 | '@SH',
170 | '@T',
171 | '@TH',
172 | '@UH',
173 | '@UH0',
174 | '@UH1',
175 | '@UH2',
176 | '@UW',
177 | '@UW0',
178 | '@UW1',
179 | '@UW2',
180 | '@V',
181 | '@W',
182 | '@Y',
183 | '@Z',
184 | '@ZH',
185 | '@b',
186 | '@c',
187 | '@ch',
188 | '@d',
189 | '@f',
190 | '@g',
191 | '@h',
192 | '@j',
193 | '@k',
194 | '@l',
195 | '@m',
196 | '@n',
197 | '@p',
198 | '@q',
199 | '@r',
200 | '@s',
201 | '@sh',
202 | '@t',
203 | '@w',
204 | '@x',
205 | '@y',
206 | '@z',
207 | '@zh',
208 | '@a1',
209 | '@a2',
210 | '@a3',
211 | '@a4',
212 | '@a5',
213 | '@ai1',
214 | '@ai2',
215 | '@ai3',
216 | '@ai4',
217 | '@ai5',
218 | '@an1',
219 | '@an2',
220 | '@an3',
221 | '@an4',
222 | '@an5',
223 | '@ang1',
224 | '@ang2',
225 | '@ang3',
226 | '@ang4',
227 | '@ang5',
228 | '@ao1',
229 | '@ao2',
230 | '@ao3',
231 | '@ao4',
232 | '@ao5',
233 | '@e1',
234 | '@e2',
235 | '@e3',
236 | '@e4',
237 | '@e5',
238 | '@ei1',
239 | '@ei2',
240 | '@ei3',
241 | '@ei4',
242 | '@ei5',
243 | '@en1',
244 | '@en2',
245 | '@en3',
246 | '@en4',
247 | '@en5',
248 | '@eng1',
249 | '@eng2',
250 | '@eng3',
251 | '@eng4',
252 | '@eng5',
253 | '@er1',
254 | '@er2',
255 | '@er3',
256 | '@er4',
257 | '@er5',
258 | '@i1',
259 | '@i2',
260 | '@i3',
261 | '@i4',
262 | '@i5',
263 | '@ia1',
264 | '@ia2',
265 | '@ia3',
266 | '@ia4',
267 | '@ia5',
268 | '@ian1',
269 | '@ian2',
270 | '@ian3',
271 | '@ian4',
272 | '@ian5',
273 | '@iang1',
274 | '@iang2',
275 | '@iang3',
276 | '@iang4',
277 | '@iang5',
278 | '@iao1',
279 | '@iao2',
280 | '@iao3',
281 | '@iao4',
282 | '@iao5',
283 | '@ie1',
284 | '@ie2',
285 | '@ie3',
286 | '@ie4',
287 | '@ie5',
288 | '@ii1',
289 | '@ii2',
290 | '@ii3',
291 | '@ii4',
292 | '@ii5',
293 | '@iii1',
294 | '@iii2',
295 | '@iii3',
296 | '@iii4',
297 | '@iii5',
298 | '@in1',
299 | '@in2',
300 | '@in3',
301 | '@in4',
302 | '@in5',
303 | '@ing1',
304 | '@ing2',
305 | '@ing3',
306 | '@ing4',
307 | '@ing5',
308 | '@iong1',
309 | '@iong2',
310 | '@iong3',
311 | '@iong4',
312 | '@iong5',
313 | '@iou1',
314 | '@iou2',
315 | '@iou3',
316 | '@iou4',
317 | '@iou5',
318 | '@o1',
319 | '@o2',
320 | '@o3',
321 | '@o4',
322 | '@o5',
323 | '@ong1',
324 | '@ong2',
325 | '@ong3',
326 | '@ong4',
327 | '@ong5',
328 | '@ou1',
329 | '@ou2',
330 | '@ou3',
331 | '@ou4',
332 | '@ou5',
333 | '@u1',
334 | '@u2',
335 | '@u3',
336 | '@u4',
337 | '@u5',
338 | '@ua1',
339 | '@ua2',
340 | '@ua3',
341 | '@ua4',
342 | '@ua5',
343 | '@uai1',
344 | '@uai2',
345 | '@uai3',
346 | '@uai4',
347 | '@uai5',
348 | '@uan1',
349 | '@uan2',
350 | '@uan3',
351 | '@uan4',
352 | '@uan5',
353 | '@uang1',
354 | '@uang2',
355 | '@uang3',
356 | '@uang4',
357 | '@uang5',
358 | '@uei1',
359 | '@uei2',
360 | '@uei3',
361 | '@uei4',
362 | '@uei5',
363 | '@uen1',
364 | '@uen2',
365 | '@uen3',
366 | '@uen4',
367 | '@uen5',
368 | '@uo1',
369 | '@uo2',
370 | '@uo3',
371 | '@uo4',
372 | '@uo5',
373 | '@v1',
374 | '@v2',
375 | '@v3',
376 | '@v4',
377 | '@v5',
378 | '@van1',
379 | '@van2',
380 | '@van3',
381 | '@van4',
382 | '@van5',
383 | '@ve1',
384 | '@ve2',
385 | '@ve3',
386 | '@ve4',
387 | '@ve5',
388 | '@vn1',
389 | '@vn2',
390 | '@vn3',
391 | '@vn4',
392 | '@vn5',
393 | '@rr',
394 | '@sp',
395 | '@spn',
396 | '@sil']
397 | '''
398 |
--------------------------------------------------------------------------------
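The exported `symbols` list is normally turned into integer lookup tables by the text frontend; a minimal sketch of that mapping (the helper names below are illustrative and may differ from what dataset/texts/__init__.py actually defines):

from dataset.texts.symbols import symbols

symbol_to_id = {s: i for i, s in enumerate(symbols)}
id_to_symbol = {i: s for i, s in enumerate(symbols)}
print(len(symbols))          # vocabulary size used for the text embedding table
print(symbol_to_id["@AH0"])  # ARPAbet and pinyin symbols are prefixed with "@" to keep them unique

--------------------------------------------------------------------------------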
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .fastspeech2_loss import *
2 | from .loss import FastSpeech2Loss, FeatLoss, LSGANDLoss, LSGANGLoss
3 |
--------------------------------------------------------------------------------
/loss/fastspeech2_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from typing import Optional
4 |
5 | class PitchPredictorLoss(nn.Module):
6 |     """Loss function module for pitch predictor.
7 | 
8 |     The loss is the mean squared error between predicted and target pitch values.
9 |
10 | """
11 |
12 | def __init__(self, offset=1.0):
13 |         """Initialize pitch predictor loss module.
14 |
15 | Args:
16 | offset (float, optional): Offset value to avoid nan in log domain.
17 |
18 | """
19 | super(PitchPredictorLoss, self).__init__()
20 | self.criterion = nn.MSELoss()
21 | self.offset = offset
22 |
23 | def forward(self, outputs, targets):
24 | """Calculate forward propagation.
25 |
26 | Args:
27 |             outputs (Tensor): Batch of predicted pitch values (B, T).
28 |             targets (Tensor): Batch of ground-truth pitch values (B, T).
29 |
30 | Returns:
31 | Tensor: Mean squared error loss value.
32 |
33 |         Note:
34 |             Both `outputs` and `targets` are compared in the linear domain; the log conversion below is commented out.
35 |
36 | """
37 |         # NOTE: the log-domain conversion is disabled; pitch is compared directly in the linear domain
38 | # print("Output :", outputs[0])
39 | # print("Before Output :", targets[0])
40 | # targets = torch.log(targets.float() + self.offset)
41 | # print("Before Output :", targets[0])
42 | # outputs = torch.log(outputs.float() + self.offset)
43 | loss = self.criterion(outputs, targets)
44 | # print(loss)
45 | return loss
46 |
47 |
48 | class EnergyPredictorLoss(nn.Module):
49 |     """Loss function module for energy predictor.
50 | 
51 |     The loss is the mean squared error between predicted and target energy values.
52 |
53 | """
54 |
55 | def __init__(self, offset=1.0):
56 |         """Initialize energy predictor loss module.
57 |
58 | Args:
59 | offset (float, optional): Offset value to avoid nan in log domain.
60 |
61 | """
62 | super(EnergyPredictorLoss, self).__init__()
63 | self.criterion = nn.MSELoss()
64 | self.offset = offset
65 |
66 | def forward(self, outputs, targets):
67 | """Calculate forward propagation.
68 |
69 | Args:
70 |             outputs (Tensor): Batch of predicted energy values (B, T).
71 |             targets (Tensor): Batch of ground-truth energy values (B, T).
72 |
73 | Returns:
74 | Tensor: Mean squared error loss value.
75 |
76 |         Note:
77 |             Both `outputs` and `targets` are compared in the linear domain; the log conversion below is commented out.
78 |
79 | """
80 |         # NOTE: the log-domain conversion is disabled; energy is compared directly in the linear domain
81 | # targets = torch.log(targets.float() + self.offset)
82 | loss = self.criterion(outputs, targets)
83 |
84 | return loss
85 |
86 | class DurationPredictorLoss(nn.Module):
87 | """Loss function module for duration predictor.
88 |
89 |     The loss value is calculated in the log domain to make it closer to Gaussian.
90 |
91 | """
92 |
93 | def __init__(self, offset=1.0):
94 |         """Initialize duration predictor loss module.
95 |
96 | Args:
97 | offset (float, optional): Offset value to avoid nan in log domain.
98 |
99 | """
100 | super(DurationPredictorLoss, self).__init__()
101 | self.criterion = nn.MSELoss()
102 | self.offset = offset
103 |
104 | def forward(self, outputs, targets):
105 | """Calculate forward propagation.
106 |
107 | Args:
108 |             outputs (Tensor): Batch of predicted durations in the log domain (B, T).
109 |             targets (LongTensor): Batch of ground-truth durations in the linear domain (B, T).
110 |
111 | Returns:
112 | Tensor: Mean squared error loss value.
113 |
114 | Note:
115 | `outputs` is in log domain but `targets` is in linear domain.
116 |
117 | """
118 | # NOTE: outputs is in log domain while targets in linear
119 | targets = torch.log(targets.float() + self.offset)
120 | loss = self.criterion(outputs, targets)
121 |
122 | return loss
--------------------------------------------------------------------------------
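A small numeric sketch of what DurationPredictorLoss compares: integer frame counts are mapped to log(d + offset), so the MSE is taken between a log-domain prediction and a log-domain target (the values below are made up):

import torch
from loss.fastspeech2_loss import DurationPredictorLoss

criterion = DurationPredictorLoss(offset=1.0)
log_d_pred = torch.tensor([[1.10, 2.30, 0.70]])  # predictions, already in the log domain
d_target = torch.tensor([[2, 9, 1]])             # ground-truth durations in frames
print(criterion(log_d_pred, d_target).item())    # targets become log([3., 10., 2.]) before the MSE

--------------------------------------------------------------------------------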
/loss/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class FastSpeech2Loss(nn.Module):
5 | """ FastSpeech2 Loss """
6 |
7 | def __init__(self, data_config):
8 | super(FastSpeech2Loss, self).__init__()
9 | self.pitch_feature_level = data_config["pitch"]["feature"]
10 | self.energy_feature_level = data_config["energy"]["feature"]
11 | self.mse_loss = nn.MSELoss()
12 | self.mae_loss = nn.L1Loss()
13 |
14 | def forward(self, inputs, predictions):
15 | (
16 | mel_targets,
17 | _,
18 | _,
19 | pitch_targets,
20 | energy_targets,
21 | uv_targets,
22 | duration_targets,
23 | ) = inputs
24 |
25 | (
26 | mel_predictions,
27 | postnet_mel_predictions,
28 | pitch_predictions,
29 | energy_predictions,
30 | uv_predictions,
31 | log_duration_predictions,
32 | _,
33 | src_masks,
34 | mel_masks,
35 | _,
36 | _,
37 | ) = predictions
38 |
39 | src_masks = ~src_masks
40 | mel_masks = ~mel_masks
41 | log_duration_targets = torch.log(duration_targets.float() + 1)
42 | mel_targets = mel_targets[:, : mel_masks.shape[1], :]
43 | mel_masks = mel_masks[:, :mel_masks.shape[1]]
44 |
45 | log_duration_targets.requires_grad = False
46 | pitch_targets.requires_grad = False
47 | energy_targets.requires_grad = False
48 | mel_targets.requires_grad = False
49 |         if uv_targets is not None:
50 | uv_targets.requires_grad = False
51 |
52 | if self.pitch_feature_level == "phoneme_level":
53 | pitch_predictions = pitch_predictions.masked_select(src_masks)
54 | pitch_targets = pitch_targets.masked_select(src_masks)
55 | elif self.pitch_feature_level == "frame_level":
56 | pitch_predictions = pitch_predictions.masked_select(mel_masks)
57 | pitch_targets = pitch_targets.masked_select(mel_masks)
58 |
59 | if self.energy_feature_level == "phoneme_level":
60 | energy_predictions = energy_predictions.masked_select(src_masks)
61 | energy_targets = energy_targets.masked_select(src_masks)
62 | if self.energy_feature_level == "frame_level":
63 | energy_predictions = energy_predictions.masked_select(mel_masks)
64 | energy_targets = energy_targets.masked_select(mel_masks)
65 |
66 | log_duration_predictions = log_duration_predictions.masked_select(src_masks)
67 | log_duration_targets = log_duration_targets.masked_select(src_masks)
68 |
69 |         if uv_targets is not None:
70 | uv_predictions = uv_predictions.masked_select(mel_masks)
71 | uv_targets = uv_targets.masked_select(mel_masks)
72 |
73 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
74 | postnet_mel_predictions = postnet_mel_predictions.masked_select(
75 | mel_masks.unsqueeze(-1)
76 | )
77 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
78 |
79 | mel_loss = self.mae_loss(mel_predictions, mel_targets)
80 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
81 |
82 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
83 | energy_loss = self.mse_loss(energy_predictions, energy_targets)
84 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
85 | total_loss = (mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss)
86 | uv_loss = None
87 |         if uv_targets is not None:
88 | uv_loss = self.mse_loss(uv_predictions, uv_targets)
89 | total_loss = (
90 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss + 0.1 * uv_loss
91 | )
92 |
93 | return (
94 | total_loss,
95 | mel_loss,
96 | postnet_mel_loss,
97 | pitch_loss,
98 | energy_loss,
99 | 0 if uv_loss is None else uv_loss,
100 | duration_loss,
101 | )
102 |
103 | class FeatLoss(nn.Module):
104 | '''
105 | feature loss (multi-band discriminator)
106 | '''
107 | def __init__(self, feat_loss_weight = (1.0, 1.0, 1.0)):
108 | super(FeatLoss, self).__init__()
109 | self.loss_d = nn.MSELoss() #.to(self.device)
110 | self.feat_loss_weight = feat_loss_weight
111 |
112 | def forward(self, D_fake):
113 | feat_g_loss = 0.0
114 | feat_loss = [0.0] * len(D_fake)
115 | report_keys = {}
116 | for j in range(len(D_fake)):
117 | for k in range(len(D_fake[j][0])):
118 | for n in range(len(D_fake[j][0][k][1])):
119 | if len(D_fake[j][0][k][1][n].shape) == 4:
120 | t_batch = D_fake[j][0][k][1][n].shape[0]
121 | t_length = D_fake[j][0][k][1][n].shape[-1]
122 | D_fake[j][0][k][1][n] = D_fake[j][0][k][1][n].view(t_batch, t_length,-1)
123 | D_fake[j][1][k][1][n] = D_fake[j][1][k][1][n].view(t_batch, t_length,-1)
124 | feat_loss[j] += self.loss_d(D_fake[j][0][k][1][n], D_fake[j][1][k][1][n]) * 2
125 | feat_loss[j] /= (n + 1)
126 | feat_loss[j] /= (k + 1)
127 | feat_loss[j] *= self.feat_loss_weight[j]
128 | report_keys['feat_loss_' + str(j)] = feat_loss[j]
129 | feat_g_loss += feat_loss[j]
130 |
131 | return feat_g_loss, report_keys
132 |
133 | class LSGANGLoss(nn.Module):
134 | def __init__(self, adv_loss_weight):
135 | super(LSGANGLoss, self).__init__()
136 | self.loss_d = nn.MSELoss() #.to(self.device)
137 | self.adv_loss_weight = adv_loss_weight
138 |
139 | def forward(self, D_fake):
140 | adv_g_loss = 0.0
141 | adv_loss = [0.0] * len(D_fake)
142 | report_keys = {}
143 | for j in range(len(D_fake)):
144 | for k in range(len(D_fake[j][0])):
145 | adv_loss[j] += self.loss_d(D_fake[j][0][k][0], D_fake[j][0][k][0].new_ones(D_fake[j][0][k][0].size()))
146 | adv_loss[j] /= (k + 1)
147 | adv_loss[j] *= self.adv_loss_weight[j]
148 | report_keys['adv_g_loss_' + str(j)] = adv_loss[j]
149 | adv_g_loss += adv_loss[j]
150 | return adv_g_loss, report_keys
151 |
152 | class LSGANDLoss(nn.Module):
153 | def __init__(self):
154 | super(LSGANDLoss, self).__init__()
155 | self.loss_d = nn.MSELoss()
156 |
157 | def forward(self, D_fake):
158 | adv_d_loss = 0.0
159 | adv_loss = [0.0] * len(D_fake)
160 | real_loss = [0.0] * len(D_fake)
161 | fake_loss = [0.0] * len(D_fake)
162 | report_keys = {}
163 | for j in range(len(D_fake)):
164 | for k in range(len(D_fake[j][0])):
165 | real_loss[j] += self.loss_d(D_fake[j][1][k][0], D_fake[j][1][k][0].new_ones(D_fake[j][1][k][0].size()))
166 | fake_loss[j] += self.loss_d(D_fake[j][0][k][0], D_fake[j][0][k][0].new_zeros(D_fake[j][0][k][0].size()))
167 | real_loss[j] /= (k + 1)
168 | fake_loss[j] /= (k + 1)
169 | adv_loss[j] = 0.5 * (real_loss[j] + fake_loss[j])
170 | report_keys['adv_d_loss_' + str(j)] = adv_loss[j]
171 | adv_d_loss += adv_loss[j]
172 | return adv_d_loss, report_keys
173 |
--------------------------------------------------------------------------------
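All of the sub-losses above exclude padded positions by inverting the padding masks and calling masked_select before averaging; a stripped-down illustration of that pattern (the tensors below are made up and much smaller than real batches):

import torch

pred = torch.tensor([[0.5, 0.9, 0.0], [0.2, 0.0, 0.0]])
target = torch.tensor([[0.6, 1.0, 7.0], [0.1, 7.0, 7.0]])         # 7.0 marks padded slots
keep = torch.tensor([[True, True, False], [True, False, False]])  # ~padding_mask
print(torch.nn.MSELoss()(pred.masked_select(keep), target.masked_select(keep)))  # averages only the 3 valid entries

--------------------------------------------------------------------------------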
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .fastspeech2 import *
2 | from .xiaoice2 import *
3 | from .discriminator import *
4 |
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import logging
5 |
6 | class GLU(nn.Module):
7 | def __init__(self):
8 | super(GLU, self).__init__()
9 |         # Custom implementation: unlike the standard GLU, this sigmoid gating
10 |         # (x * sigmoid(x)) keeps the channel dimension intact, following the CycleGAN-VC discriminator.
11 |
12 | def forward(self, input):
13 | return input * torch.sigmoid(input)
14 |
15 | class DiscriminatorFactory(nn.Module):
16 | def __init__(self,
17 | time_length,
18 | freq_length,
19 | conv_channel,
20 | ):
21 | super(DiscriminatorFactory, self).__init__()
22 |
23 | layers = 10
24 | conv_channels = conv_channel
25 | kernel_size = 3
26 | conv_in_channels = 60
27 | use_weight_norm = True
28 |
29 | self.conv_layers = torch.nn.ModuleList()
30 | for i in range(layers - 1):
31 | if i == 0:
32 | dilation = 1
33 | else:
34 | dilation = 1
35 | conv_in_channels = conv_channels
36 | padding = (kernel_size - 1) // 2 * dilation
37 | conv_layer = [
38 | nn.Conv1d(conv_in_channels, conv_channels,
39 | kernel_size=kernel_size, padding=padding,
40 | dilation=dilation, bias=True),
41 | nn.LeakyReLU(0.2, inplace=True),
42 | #nn.BatchNorm1d(conv_channels)
43 | ]
44 | self.conv_layers += conv_layer
45 | padding = (kernel_size - 1) // 2
46 | last_conv_layer = nn.Conv1d(
47 | conv_in_channels, 1,
48 | kernel_size=kernel_size, padding=padding, bias=True)
49 | self.conv_layers += [last_conv_layer]
50 |
51 | # apply weight norm
52 | if use_weight_norm:
53 | self.apply_weight_norm()
54 |
55 | def apply_weight_norm(self):
56 | """Apply weight normalization module from all of the layers."""
57 | def _apply_weight_norm(m):
58 | if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
59 | torch.nn.utils.weight_norm(m)
60 | logging.debug(f"weight norm is applied to {m}.")
61 | self.apply(_apply_weight_norm)
62 |
63 | def remove_weight_norm(self):
64 | """Remove weight normalization module from all of the layers."""
65 | def _remove_weight_norm(m):
66 | try:
67 | logging.debug(f"weight norm is removed from {m}.")
68 | torch.nn.utils.remove_weight_norm(m)
69 | except ValueError: # this module didn't have weight norm
70 | return
71 | self.apply(_remove_weight_norm)
72 |
73 | def forward(self, x):
74 | """
75 | Args:
76 |             x: (B, C, T); here C = 60 (one 60-bin mel sub-band).
77 |
78 | Returns:
79 | tensor: (B, 1, T)
80 | """
81 | feature_list = []
82 | i = 1
83 | for f in self.conv_layers:
84 | x = f(x)
85 | if i % 2 == 1:
86 | feature_list.append(x)
87 | i += 1
88 | return [x, feature_list]
89 |
90 |
91 | class MultiWindowDiscriminator(nn.Module):
92 |     """Discriminator that scores randomly clipped mel segments at several window lengths."""
93 | def __init__(self,
94 | time_lengths,
95 | freq_lengths,
96 | conv_channels,
97 | ):
98 | super(MultiWindowDiscriminator, self).__init__()
99 | self.win_lengths = time_lengths
100 |
101 | self.conv_layers = nn.ModuleList()
102 | self.patch_layers = nn.ModuleList()
103 | for time_length, freq_length, conv_channel in zip(time_lengths, freq_lengths, conv_channels):
104 | conv_layer = [
105 | DiscriminatorFactory(np.abs(time_length), freq_length, conv_channel),
106 | ] # 1d
107 | self.conv_layers += conv_layer
108 | patch_layer = [PatchGAN()]
109 | self.patch_layers += patch_layer
110 |
111 |
112 | def clip(self, x, x_len, win_length, y=None, random_N=None):
113 |         '''Randomly clip x to win_length.
114 | Args:
115 | x (tensor) : (B, T, C).
116 | x_len (tensor) : (B,).
117 | win_length (int): target clip length
118 |
119 | Returns:
120 | (tensor) : (B, win_length, C).
121 |
122 | '''
123 | x_batch = []
124 | y_batch = []
125 | T_end = win_length
126 | if T_end > 0:
127 | cursor = 1
128 | else:
129 | cursor = -1
130 | min_a = min(x_len)
131 | if np.abs(T_end) + random_N > min_a:
132 | T_end = min_a - random_N - 1
133 | T_end = T_end * cursor
134 | #print(x_len, random_N, win_length, T_end)
135 | for i in range(x.size(0)):
136 | if T_end < 0:
137 | x_batch += [x[i, x_len[i].cpu() + T_end - random_N: x_len[i].cpu() - random_N, :].unsqueeze(0)]
138 | else:
139 | x_batch += [x[i, random_N : T_end + random_N, :].unsqueeze(0)]
140 |             if y is not None:
141 | if T_end < 0:
142 | y_batch += [y[i, x_len[i].cpu() + T_end - random_N: x_len[i].cpu() - random_N, :].unsqueeze(0)]
143 | else:
144 | y_batch += [y[i, random_N : T_end+ random_N, :].unsqueeze(0)]
145 |
146 | x_batch = torch.cat(x_batch, 0)
147 |         if y is not None:
148 |             y_batch = torch.cat(y_batch, 0)
149 |         if y is not None:
150 | return x_batch, y_batch
151 | else:
152 | return x_batch
153 |
154 | def forward(self, x, x_len, y=None, random_N=None):
155 | '''
156 | Args:
157 | x (tensor): input mel, (B, T, C).
158 |             x_len (tensor): length of each mel in frames, (B,).
159 |
160 | Returns:
161 | tensor : (B).
162 | '''
163 | validity_x = list()
164 | validity_y = list()
165 | #validity = 0.0
166 | for i in range(len(self.conv_layers)):
167 |             if y is not None:
168 | if self.win_lengths[i] != 1:
169 | x_clip,y_clip = self.clip(x, x_len, self.win_lengths[i], y, random_N[i]) # (B, win_length, C)
170 | else:
171 | #print(x.shape, y.shape)
172 | #x_clip, y_clip = x[:,:1300,:], y[:,:1300,:]
173 | x_clip, y_clip = x, y
174 | y_clip = y_clip.transpose(2,1)
175 |
176 | else:
177 | if self.win_lengths[i] != 1:
178 | x_clip = self.clip(x, x_len, self.win_lengths[i], y, random_N[i]) # (B, win_length, C)
179 | else:
180 | #print(x.shape)
181 | #x_clip = x[:,:1300, :]
182 | x_clip = x
183 |
184 | x_clip = x_clip.transpose(2, 1) # (B, C, win_length)
185 | x_clip_r = self.conv_layers[i](x_clip) # 1d
186 | validity_x += [x_clip_r]
187 | x_clip_r = self.patch_layers[i](x_clip) # 2d
188 | validity_x += [x_clip_r]
189 |             if y is not None:
190 | y_clip_r = self.conv_layers[i](y_clip)
191 | validity_y += [y_clip_r]
192 | y_clip_r = self.patch_layers[i](y_clip)
193 | validity_y += [y_clip_r]
194 |
195 | #validity += x_clip
196 |         if y is None:
197 | return validity_x
198 | else:
199 | return validity_x, validity_y
200 |
201 | class PatchGAN(nn.Module):
202 | def __init__(self):
203 | super(PatchGAN, self).__init__()
204 |
205 | self.convLayer1 = nn.Sequential(nn.Conv2d(in_channels=1,
206 | out_channels=32,
207 | kernel_size=(3, 3),
208 | stride=(1, 1),
209 | padding=(1, 1)),
210 | GLU())
211 |
212 | # DownSample Layer
213 | self.downSample1 = self.downSample(in_channels=32,
214 | out_channels=64,
215 | kernel_size=(3, 3),
216 | stride=(2, 2),
217 | padding=1)
218 |
219 | self.downSample2 = self.downSample(in_channels=64,
220 | out_channels=128,
221 | kernel_size=(3, 3),
222 | stride=[2, 2],
223 | padding=1)
224 | # Conv Layer
225 | self.outputConvLayer_2 = nn.Sequential(nn.Conv2d(in_channels=128,
226 | out_channels=1,
227 | kernel_size=(1, 3),
228 | stride=[1, 1],
229 | padding=[0, 1])
230 | )
231 |
232 | self.downSample3 = self.downSample(in_channels=128,
233 | out_channels=256,
234 | kernel_size=[3, 3],
235 | stride=[2, 2],
236 | padding=1)
237 | # Conv Layer
238 | self.outputConvLayer_3 = nn.Sequential(nn.Conv2d(in_channels=256,
239 | out_channels=1,
240 | kernel_size=(1, 3),
241 | stride=[1, 1],
242 | padding=[0, 1]))
243 |
244 | self.downSample4 = self.downSample(in_channels=256,
245 | out_channels=512,
246 | kernel_size=[3, 3],
247 | stride=[2, 2],
248 | padding=1)
249 | # Conv Layer
250 | self.outputConvLayer_4 = nn.Sequential(nn.Conv2d(in_channels=512,
251 | out_channels=1,
252 | kernel_size=(1, 3),
253 | stride=[1, 1],
254 | padding=[0, 1]))
255 | self.downSample5 = self.downSample(in_channels=512,
256 | out_channels=1024,
257 | kernel_size=[3, 3],
258 | stride=[2, 2],
259 | padding=1)
260 |
261 |
262 | # Conv Layer
263 | self.outputConvLayer = nn.Sequential(nn.Conv2d(in_channels=1024,
264 | out_channels=1,
265 | kernel_size=(1, 3),
266 | stride=[1, 1],
267 | padding=[0, 1]))
268 |
269 | def downSample(self, in_channels, out_channels, kernel_size, stride, padding):
270 | convLayer = nn.Sequential(nn.Conv2d(in_channels=in_channels,
271 | out_channels=out_channels,
272 | kernel_size=kernel_size,
273 | stride=stride,
274 | padding=padding),
275 | nn.InstanceNorm2d(num_features=out_channels,
276 | affine=True),
277 | GLU())
278 | return convLayer
279 |
280 | def forward(self, input):
281 | # input has shape [batch_size, num_features, time]
282 | # discriminator requires shape [batchSize, 1, num_features, time]
283 | input = input.unsqueeze(1)
284 | #print("input : {}".format(input.shape))
285 | feature_list = []
286 | conv_layer_1 = self.convLayer1(input)
287 | feature_list.append(conv_layer_1)
288 | #print("conv_layer_1: {}".format(conv_layer_1.shape))
289 |
290 | downsample1 = self.downSample1(conv_layer_1)
291 | feature_list.append(downsample1)
292 | #output_1 = torch.sigmoid(self.outputConvLayer_1(downsample1))
293 | #print("downsample1 {} output_1 {}".format(downsample1.shape, output_1.shape))
294 | downsample2 = self.downSample2(downsample1)
295 | feature_list.append(downsample2)
296 | output_2 = torch.sigmoid(self.outputConvLayer_2(downsample2))
297 | #print("downsample2 {} output_2 {}".format(downsample2.shape, output_2.shape))
298 | downsample3 = self.downSample3(downsample2)
299 | feature_list.append(downsample3)
300 | output_3 = torch.sigmoid(self.outputConvLayer_3(downsample3))
301 | #print("downsample3 {} output_3 {}".format(downsample3.shape, output_3.shape))
302 | downsample4 = self.downSample4(downsample3)
303 | feature_list.append(downsample4)
304 | output_4 = torch.sigmoid(self.outputConvLayer_4(downsample4))
305 | #print("downsample4 {} output_4 {}".format(downsample4.shape, output_4.shape))
306 | downsample5 = self.downSample5(downsample4)
307 | feature_list.append(downsample5)
308 | #print("downsample5 {}".format(downsample5.shape))
309 |
310 | output = torch.sigmoid(self.outputConvLayer(downsample5))
311 | #print("output {} ".format(output.shape))
312 | output = output.view(output.shape[0], output.shape[1], -1)
313 | output_4 = output_4.view(output.shape[0], output.shape[1], -1)
314 | output_3 = output_3.view(output.shape[0], output.shape[1], -1)
315 | output_2 = output_2.view(output.shape[0], output.shape[1], -1)
316 | #output_1 = output_1.view(output.shape[0], output.shape[1], -1)
317 | output = torch.cat((output,output_4,output_3, output_2), axis=2)
318 | #output = output.view(output.shape[0], output.shape[1], -1)
319 | return [output, feature_list]
320 |
321 | class MultibandFrequencyDiscriminator(nn.Module):
322 | def __init__(self,
323 | time_lengths=[200, 400, 600, 800, 1],
324 | freq_lengths=[ 60, 60, 60, 60, 60],
325 | multi_channels=[[87, 87, 87, 87, 87 ], [87, 87, 87, 87, 87], [87,87, 87, 87, 87]]
326 | ):
327 | super(MultibandFrequencyDiscriminator, self).__init__()
328 |
329 | self.time_lengths = time_lengths
330 | self.multi_win_discriminator_low = MultiWindowDiscriminator(time_lengths, freq_lengths, multi_channels[0])
331 | self.multi_win_discriminator_middle = MultiWindowDiscriminator(time_lengths, freq_lengths, multi_channels[1])
332 | self.multi_win_discriminator_high = MultiWindowDiscriminator(time_lengths, freq_lengths, multi_channels[2])
333 |
334 | def forward(self, x, x_len, y=None, random_N=[]):
335 | '''
336 | Args:
337 | x (tensor): input mel, (B, T, C).
338 |             x_len (tensor): length of each mel in frames, (B,).
339 |
340 | Returns:
341 | list : [(B), (B,), (B,)].
342 | '''
343 | if len(random_N) == 0:
344 | len_min = min(x_len.cpu())
345 | time_max = max(self.time_lengths)
346 | start = 0
347 | end = len_min - time_max
348 | if end <= 0:
349 | end = int(len_min / 2)
350 | random_N = np.random.randint(start, end , len(self.time_lengths))
351 |
352 | #print(x_len)
353 | base_mel = x[:,:,:120]
354 | xa = base_mel[:,:,:60]
355 | xb = base_mel[:,:,30:90]
356 | xc = base_mel[:,:,60:120]
357 |         if y is not None:
358 | y_mel = y[:, :, :120]
359 | ya = y_mel[:,:,:60]
360 | yb = y_mel[:,:, 30:90]
361 | yc = y_mel[:,:,60:120]
362 | else:
363 | ya = yb = yc = None
364 |
365 |
366 | x_list = [
367 | self.multi_win_discriminator_low(xa, x_len, ya, random_N),
368 | self.multi_win_discriminator_middle(xb, x_len,yb, random_N),
369 | self.multi_win_discriminator_high(xc, x_len, yc, random_N),
370 | ]
371 | return x_list, random_N
372 |
373 | class Discriminator(nn.Module):
374 | def __init__(self):
375 | super(Discriminator, self).__init__()
376 |
377 | self.discriminator = MultibandFrequencyDiscriminator()
378 |
379 | def forward(self, x, x_len, y=None, random_N=[]):
380 | return self.discriminator(x, x_len, y, random_N)
381 |
382 |
383 | if __name__ == "__main__":
384 | inputs = torch.randn(4, 1200, 120).cuda()
385 | tgt = torch.randn(4, 1200, 120).cuda()
386 | inputs_len = torch.tensor([1200]).cuda()
387 | net = Discriminator().cuda()
388 | print(net)
389 | outputs, random_n = net(inputs, inputs_len, tgt)
390 | # import pdb; pdb.set_trace()
391 | for output in outputs:
392 | for a in output[0]:
393 | for aa in a:
394 | for aaa in aa:
395 | print(aaa.shape)
396 | for b in output[1]:
397 | for bb in b:
398 | for bbb in bb:
399 | print(bbb.shape)
400 | # print(output[0].shape)
401 | # print(output[1].shape)
402 |
--------------------------------------------------------------------------------
/models/fastspeech2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from loss import FastSpeech2Loss
5 |
6 | from modules.transformer import Encoder, Decoder, PostNet
7 | from modules.variance.modules import VarianceAdaptor
8 | from pyutils import get_mask_from_lengths
9 |
10 | class FastSpeech2(nn.Module):
11 | """ FastSpeech2 """
12 |
13 | def __init__(self, data_config, model_config):
14 | super(FastSpeech2, self).__init__()
15 | self.model_config = model_config
16 |
17 | self.encoder = Encoder(**model_config['transformer']['encoder'])
18 | self.variance_adaptor = VarianceAdaptor(data_config, model_config)
19 | self.decoder = Decoder(**model_config['transformer']['decoder'])
20 | self.mel_linear = nn.Linear(
21 | model_config["transformer"]["decoder"]['d_word_vec'],
22 | data_config["n_mels"],
23 | )
24 | self.postnet = PostNet(data_config['n_mels'], **model_config['postnet'])
25 |
26 | self.speaker_emb = None
27 | if model_config["multi_speaker"]:
28 | n_speaker = model_config['spk_num']
29 | self.speaker_emb = nn.Embedding(
30 | n_speaker,
31 | model_config["transformer"]["encoder"]["d_word_vec"],
32 | )
33 |         self.loss = FastSpeech2Loss(data_config)  # FastSpeech2Loss only takes the data config
34 |
35 | def forward(
36 | self,
37 | spks,
38 | texts,
39 | src_lens,
40 | max_src_len,
41 | mels=None,
42 | mel_lens=None,
43 | max_mel_len=None,
44 | p_targets=None,
45 | e_targets=None,
46 | d_targets=None,
47 | p_control=1.0,
48 | e_control=1.0,
49 | d_control=1.0,
50 | ):
51 | src_masks = get_mask_from_lengths(src_lens, max_src_len)
52 | mel_masks = (
53 | get_mask_from_lengths(mel_lens, max_mel_len)
54 | if mel_lens is not None
55 | else None
56 | )
57 | output = self.encoder(texts, src_masks)
58 |
59 | if self.speaker_emb is not None:
60 | output = output + self.speaker_emb(spks).unsqueeze(1).expand(
61 | -1, max_src_len, -1
62 | )
63 |
64 | (
65 | output,
66 | p_predictions,
67 | e_predictions,
68 | log_d_predictions,
69 | d_rounded,
70 | mel_lens,
71 | mel_masks,
72 | ) = self.variance_adaptor(
73 | output,
74 | src_masks,
75 | mel_masks,
76 | max_mel_len,
77 | p_targets,
78 | e_targets,
79 | d_targets,
80 | p_control,
81 | e_control,
82 | d_control,
83 | )
84 |
85 | output, mel_masks = self.decoder(output, mel_masks)
86 | output = self.mel_linear(output)
87 |
88 | postnet_output = self.postnet(output) + output
89 |
90 |         outputs = (output, postnet_output, p_predictions, e_predictions, None, log_d_predictions, d_rounded, src_masks, mel_masks, src_lens, mel_lens)  # None: the TTS model has no UV prediction
91 |         (total_loss, mel_loss, post_mel_loss, pitch_loss, energy_loss, _, duration_loss) = self.loss((mels, mel_lens, max_mel_len, p_targets, e_targets, None, d_targets), outputs)  # None: no UV target
92 | report_keys = {
93 | 'loss': total_loss,
94 | 'mel_loss': mel_loss,
95 | 'post_mel_loss': post_mel_loss,
96 | 'pitch_loss': pitch_loss,
97 | 'energy_loss': energy_loss,
98 | 'duration_loss': duration_loss
99 | }
100 | return total_loss, report_keys, output, postnet_output
101 |
102 | if __name__ == "__main__":
103 | import yaml
104 | with open('/home/zengchang/code/acoustic_v2/configs/data.yaml', 'r') as f:
105 | data_config = yaml.load(f, Loader = yaml.FullLoader)
106 | with open('/home/zengchang/code/acoustic_v2/configs/model.yaml', 'r') as f:
107 | model_config = yaml.load(f, Loader = yaml.FullLoader)
108 | model = FastSpeech2(data_config, model_config['generator'])
109 | print(model)
110 |
--------------------------------------------------------------------------------
/models/xiaoice2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from loss import FastSpeech2Loss
5 |
6 | from modules.transformer import Encoder, Decoder, PostNet
7 | from modules.variance.modules import VarianceAdaptor
8 | from pyutils import get_mask_from_lengths
9 |
10 | class Xiaoice2(nn.Module):
11 | """ Xiaoice2 """
12 |
13 | def __init__(self, data_config, model_config):
14 | super(Xiaoice2, self).__init__()
15 | self.model_config = model_config
16 |
17 | self.encoder = Encoder(**model_config['transformer']['encoder'])
18 | self.variance_adaptor = VarianceAdaptor(data_config, model_config)
19 | self.decoder = Decoder(**model_config['transformer']['decoder'])
20 | self.mel_linear = nn.Linear(
21 | model_config["transformer"]["decoder"]['d_word_vec'],
22 | data_config["n_mels"],
23 | )
24 | self.postnet = PostNet(data_config['n_mels'], **model_config['postnet'])
25 |
26 | self.speaker_emb = None
27 | if model_config["multi_speaker"]:
28 | n_speaker = model_config['spk_num']
29 | self.speaker_emb = nn.Embedding(
30 | n_speaker,
31 | model_config["transformer"]["encoder"]["d_word_vec"],
32 | )
33 | self.loss = FastSpeech2Loss(data_config)
34 |
35 | def forward(
36 | self,
37 | texts,
38 | note_pitchs,
39 | note_durations,
40 | src_lens,
41 | max_src_len,
42 | mels=None,
43 | mel_lens=None,
44 | max_mel_len=None,
45 | p_targets=None,
46 | e_targets=None,
47 | uv_targets=None,
48 | d_targets=None,
49 | p_control=1.0,
50 | e_control=1.0,
51 | d_control=1.0,
52 | spks=None
53 | ):
54 | src_masks = get_mask_from_lengths(src_lens, max_src_len)
55 | mel_masks = (
56 | get_mask_from_lengths(mel_lens, max_mel_len)
57 | if mel_lens is not None
58 | else None
59 | )
60 | output = self.encoder(texts, note_pitchs, note_durations, src_masks)
61 |
62 | if self.speaker_emb is not None:
63 | output = output + self.speaker_emb(spks).unsqueeze(1).expand(
64 | -1, max_src_len, -1
65 | )
66 |
67 | (
68 | output,
69 | p_predictions,
70 | e_predictions,
71 | uv_predictions,
72 | log_d_predictions,
73 | d_rounded,
74 | mel_lens,
75 | mel_masks,
76 | ) = self.variance_adaptor(
77 | output,
78 | src_masks,
79 | mel_masks,
80 | max_mel_len,
81 | p_targets,
82 | e_targets,
83 | uv_targets,
84 | d_targets,
85 | p_control,
86 | e_control,
87 | d_control,
88 | )
89 |
90 | output, mel_masks = self.decoder(output, mel_masks)
91 | output = self.mel_linear(output)
92 |
93 | postnet_output = self.postnet(output) + output
94 |
95 | outputs = (output, postnet_output, p_predictions, e_predictions, uv_predictions, log_d_predictions, d_rounded, src_masks, mel_masks, src_lens, mel_lens)
96 |
97 | (total_loss, mel_loss, post_mel_loss, pitch_loss, energy_loss, uv_loss, duration_loss) = self.loss((mels, mel_lens, max_mel_len, p_targets, e_targets, uv_targets, d_targets), outputs)
98 |
99 | report_keys = {
100 | 'loss': total_loss,
101 | 'mel_loss': mel_loss,
102 | 'post_mel_loss': post_mel_loss,
103 | 'pitch_loss': pitch_loss,
104 | 'energy_loss': energy_loss,
105 | 'uv_loss': uv_loss,
106 | 'duration_loss': duration_loss
107 | }
108 |
109 | return total_loss, report_keys, output, postnet_output
110 |
111 | if __name__ == "__main__":
112 | import yaml
113 | with open('/home/zengchang/code/acoustic_v2/configs/data.yaml', 'r') as f:
114 | data_config = yaml.load(f, Loader = yaml.FullLoader)
115 | with open('/home/zengchang/code/acoustic_v2/configs/model.yaml', 'r') as f:
116 | model_config = yaml.load(f, Loader = yaml.FullLoader)
117 | model = Xiaoice2(data_config, model_config['generator'])
118 | print(model)
119 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer.Models import Encoder, Decoder
2 | from .transformer.Layers import PostNet
--------------------------------------------------------------------------------
/modules/conv/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/modules/conv/__init__.py
--------------------------------------------------------------------------------
/modules/transformer/Constants.py:
--------------------------------------------------------------------------------
1 | PAD = 0
2 | UNK = 1
3 | BOS = 2
4 | EOS = 3
5 |
6 | PAD_WORD = "<blank>"
7 | UNK_WORD = "<unk>"
8 | BOS_WORD = "<s>"
9 | EOS_WORD = "</s>"
10 |
--------------------------------------------------------------------------------
/modules/transformer/Layers.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from torch.nn import functional as F
7 |
8 | try:
9 | from modules_v2.transformer.sublayer import MultiHeadAttention, PositionwiseFeedForward, MultiLayeredConv1d
10 | from modules_v2.transformer.embedding import PositionalEncoding
11 | from modules_v2.transformer.layer import EncoderLayer
12 | except (ImportError, ModuleNotFoundError):
13 | import sys
14 | import os
15 | filepath = os.path.dirname(os.path.abspath(__file__))
16 | sys.path.insert(0, filepath)
17 | from SubLayers import MultiHeadAttention, PositionwiseFeedForward
18 |
19 | class FFTBlock(torch.nn.Module):
20 | """FFT Block"""
21 |
22 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
23 | super(FFTBlock, self).__init__()
24 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
25 | self.pos_ffn = PositionwiseFeedForward(
26 | d_model, d_inner, kernel_size, dropout=dropout
27 | )
28 |
29 | def forward(self, enc_input, mask=None, slf_attn_mask=None):
30 | enc_output, enc_slf_attn = self.slf_attn(
31 | enc_input, enc_input, enc_input, mask=slf_attn_mask
32 | )
33 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
34 |
35 | enc_output = self.pos_ffn(enc_output)
36 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
37 |
38 | return enc_output, enc_slf_attn
39 |
40 |
41 | class ConvNorm(torch.nn.Module):
42 | def __init__(
43 | self,
44 | in_channels,
45 | out_channels,
46 | kernel_size=1,
47 | stride=1,
48 | padding=None,
49 | dilation=1,
50 | bias=True,
51 | w_init_gain="linear",
52 | ):
53 | super(ConvNorm, self).__init__()
54 |
55 | if padding is None:
56 | assert kernel_size % 2 == 1
57 | padding = int(dilation * (kernel_size - 1) / 2)
58 |
59 | self.conv = torch.nn.Conv1d(
60 | in_channels,
61 | out_channels,
62 | kernel_size=kernel_size,
63 | stride=stride,
64 | padding=padding,
65 | dilation=dilation,
66 | bias=bias,
67 | )
68 |
69 | def forward(self, signal):
70 | conv_signal = self.conv(signal)
71 |
72 | return conv_signal
73 |
74 |
75 | class PostNet(nn.Module):
76 | """
77 |     PostNet: five 1-D convolutions with 512 channels and kernel size 5
78 | """
79 |
80 | def __init__(
81 | self,
82 | n_mels=80,
83 | postnet_embedding_dim=512,
84 | postnet_kernel_size=5,
85 | postnet_n_convolutions=5,
86 | ):
87 |
88 | super(PostNet, self).__init__()
89 | self.convolutions = nn.ModuleList()
90 |
91 | self.convolutions.append(
92 | nn.Sequential(
93 | ConvNorm(
94 | n_mels,
95 | postnet_embedding_dim,
96 | kernel_size=postnet_kernel_size,
97 | stride=1,
98 | padding=int((postnet_kernel_size - 1) / 2),
99 | dilation=1,
100 | w_init_gain="tanh",
101 | ),
102 | nn.BatchNorm1d(postnet_embedding_dim),
103 | )
104 | )
105 |
106 | for i in range(1, postnet_n_convolutions - 1):
107 | self.convolutions.append(
108 | nn.Sequential(
109 | ConvNorm(
110 | postnet_embedding_dim,
111 | postnet_embedding_dim,
112 | kernel_size=postnet_kernel_size,
113 | stride=1,
114 | padding=int((postnet_kernel_size - 1) / 2),
115 | dilation=1,
116 | w_init_gain="tanh",
117 | ),
118 | nn.BatchNorm1d(postnet_embedding_dim),
119 | )
120 | )
121 |
122 | self.convolutions.append(
123 | nn.Sequential(
124 | ConvNorm(
125 | postnet_embedding_dim,
126 | n_mels,
127 | kernel_size=postnet_kernel_size,
128 | stride=1,
129 | padding=int((postnet_kernel_size - 1) / 2),
130 | dilation=1,
131 | w_init_gain="linear",
132 | ),
133 | nn.BatchNorm1d(n_mels),
134 | )
135 | )
136 |
137 | def forward(self, x):
138 | x = x.contiguous().transpose(1, 2)
139 |
140 | for i in range(len(self.convolutions) - 1):
141 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
142 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
143 |
144 | x = x.contiguous().transpose(1, 2)
145 | return x
146 |
147 | if __name__ == "__main__":
148 | import sys
149 | sys.path.insert(0, '/home/zengchang/code/acoustic_v2/modules_v2/transformer')
150 | fft_block = FFTBlock(512, 8, 64, 64, 2048, [3,3])
151 | x = torch.randn(2, 100, 512)
153 |     mask = torch.zeros(2, 100).bool()  # False = not padded (an all-True mask would turn the attention into NaNs)
154 |     slf_attn_mask = torch.zeros(2, 100, 100).bool()
154 | y, attn = fft_block(x, mask, slf_attn_mask)
155 | print(y.shape)
156 | print(attn.shape)
--------------------------------------------------------------------------------
/modules/transformer/Models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | try:
6 | import modules.transformer.Constants as Constants
7 | from modules.transformer.Layers import FFTBlock
8 | from dataset.texts.symbols import symbols
9 | except (ImportError, ModuleNotFoundError):
10 | import sys
11 | import os
12 | filepath = '/'.join(os.path.dirname(os.path.abspath(__file__)).split('/')[:-2])
13 | print(filepath)
14 | sys.path.insert(0, filepath)
15 | import modules.transformer.Constants as Constants
16 | from modules.transformer.Layers import FFTBlock
17 | from dataset.texts.symbols import symbols
18 |
19 |
20 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
21 | """ Sinusoid position encoding table """
22 |
23 | def cal_angle(position, hid_idx):
24 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
25 |
26 | def get_posi_angle_vec(position):
27 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
28 |
29 | sinusoid_table = np.array(
30 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)]
31 | )
32 |
33 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
34 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
35 |
36 | if padding_idx is not None:
37 | # zero vector for padding dimension
38 | sinusoid_table[padding_idx] = 0.0
39 |
40 | return torch.FloatTensor(sinusoid_table)
41 |
42 |
43 | class Encoder(nn.Module):
44 | """ Encoder """
45 |
46 | def __init__(
47 | self, max_seq_len, n_src_vocab, d_word_vec,
48 | n_layers, n_head, d_model, d_inner, max_note_pitch,
49 | max_note_duration, kernel_size, dropout=0.1
50 | ):
51 | super(Encoder, self).__init__()
52 |
53 | n_position = max_seq_len + 1
54 | n_src_vocab = n_src_vocab
55 | d_word_vec = d_word_vec
56 | n_layers = n_layers
57 | n_head = n_head
58 | d_k = d_v = (d_word_vec // n_head)
59 | d_model = d_model
60 | d_inner = d_inner
61 | kernel_size = kernel_size
62 | dropout = dropout
63 |
64 | # self.max_seq_len = config["max_seq_len"]
65 | self.max_seq_len = max_seq_len
66 | self.d_model = d_model
67 |
68 | self.src_word_emb = nn.Embedding(
69 | n_src_vocab, d_word_vec, padding_idx = Constants.PAD
70 | )
71 | self.note_pitch_emb = nn.Embedding(
72 | max_note_pitch, d_word_vec, padding_idx = Constants.PAD
73 | )
74 | self.note_duration_emb = nn.Embedding(
75 | max_note_duration, d_word_vec, padding_idx = Constants.PAD
76 | )
77 | self.position_enc = nn.Parameter(
78 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
79 | requires_grad=False,
80 | )
81 |
82 | self.layer_stack = nn.ModuleList(
83 | [
84 | FFTBlock(
85 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
86 | )
87 | for _ in range(n_layers)
88 | ]
89 | )
90 |
91 | def forward(self, src_seq, note_pitchs, note_durations, mask, return_attns=False):
92 |
93 | enc_slf_attn_list = []
94 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1]
95 |
96 | # -- Prepare masks
97 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
98 |
99 | # -- Forward
100 | if not self.training and src_seq.shape[1] > self.max_seq_len:
101 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table(
102 | src_seq.shape[1], self.d_model
103 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
104 | src_seq.device
105 | )
106 | else:
107 | # print("training!!!!!!!!!!")
108 | enc_output = self.src_word_emb(src_seq) \
109 | + self.note_pitch_emb(note_pitchs) \
110 | + self.note_duration_emb(note_durations) \
111 | + self.position_enc[:, :max_len, :].expand(batch_size, -1, -1)
112 |
113 | for enc_layer in self.layer_stack:
114 | enc_output, enc_slf_attn = enc_layer(
115 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask
116 | )
117 | if return_attns:
118 | enc_slf_attn_list += [enc_slf_attn]
119 |
120 | return enc_output
121 |
122 |
123 | class Decoder(nn.Module):
124 | """ Decoder """
125 |
126 | def __init__(
127 | self, max_seq_len, d_word_vec,
128 | n_layers, n_head, d_model, d_inner,
129 | kernel_size, dropout=0.1
130 | ):
131 | super(Decoder, self).__init__()
132 |
133 | n_position = max_seq_len + 1
134 | d_word_vec = d_word_vec
135 | n_layers = n_layers
136 | n_head = n_head
137 | d_k = d_v = (d_word_vec // n_head)
138 | d_model = d_model
139 | d_inner = d_inner
140 | kernel_size = kernel_size
141 | dropout = dropout
142 |
143 | # self.max_seq_len = config["max_seq_len"]
144 | self.max_seq_len = max_seq_len
145 | self.d_model = d_model
146 |
147 | self.position_enc = nn.Parameter(
148 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
149 | requires_grad=False,
150 | )
151 |
152 | self.layer_stack = nn.ModuleList(
153 | [
154 | FFTBlock(
155 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
156 | )
157 | for _ in range(n_layers)
158 | ]
159 | )
160 |
161 | def forward(self, enc_seq, mask, return_attns=False):
162 |
163 | dec_slf_attn_list = []
164 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1]
165 |
166 | # -- Forward
167 | if not self.training and enc_seq.shape[1] > self.max_seq_len:
168 | # -- Prepare masks
169 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
170 | dec_output = enc_seq + get_sinusoid_encoding_table(
171 | enc_seq.shape[1], self.d_model
172 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
173 | enc_seq.device
174 | )
175 | else:
176 | max_len = min(max_len, self.max_seq_len)
177 |
178 | # -- Prepare masks
179 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
180 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[
181 | :, :max_len, :
182 | ].expand(batch_size, -1, -1)
183 | mask = mask[:, :max_len]
184 | slf_attn_mask = slf_attn_mask[:, :, :max_len]
185 |
186 | for dec_layer in self.layer_stack:
187 | dec_output, dec_slf_attn = dec_layer(
188 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask
189 | )
190 | if return_attns:
191 | dec_slf_attn_list += [dec_slf_attn]
192 |
193 | return dec_output, mask
194 |
195 | if __name__ == "__main__":
196 |     # Smoke test for the note-conditioned Encoder/Decoder; the dimensions below are illustrative.
197 |     encoder = Encoder(max_seq_len=100, n_src_vocab=len(symbols) + 1, d_word_vec=256, n_layers=4,
198 |                       n_head=2, d_model=256, d_inner=1024, max_note_pitch=128,
199 |                       max_note_duration=512, kernel_size=[9, 1])
200 |     decoder = Decoder(max_seq_len=100, d_word_vec=256, n_layers=4, n_head=2,
201 |                       d_model=256, d_inner=1024, kernel_size=[9, 1])
202 |     src_seq = torch.randint(0, len(symbols), (2, 100))
203 |     note_pitchs = torch.randint(0, 128, (2, 100))
204 |     note_durations = torch.randint(0, 512, (2, 100))
205 |     mask = torch.zeros((2, 100)).bool()  # False = not padded
206 |     enc_output = encoder(src_seq, note_pitchs, note_durations, mask)
207 |     dec_output, mask = decoder(enc_output, mask)
208 |     print(dec_output.shape, mask.shape)
209 | 
--------------------------------------------------------------------------------
/modules/transformer/Modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class ScaledDotProductAttention(nn.Module):
7 | """ Scaled Dot-Product Attention """
8 |
9 | def __init__(self, temperature):
10 | super().__init__()
11 | self.temperature = temperature
12 | self.softmax = nn.Softmax(dim=2)
13 |
14 | def forward(self, q, k, v, mask=None):
15 |
16 | attn = torch.bmm(q, k.transpose(1, 2))
17 | attn = attn / self.temperature
18 |
19 | if mask is not None:
20 | attn = attn.masked_fill(mask, -np.inf)
21 |
22 | attn = self.softmax(attn)
23 | output = torch.bmm(attn, v)
24 |
25 | return output, attn
26 |
--------------------------------------------------------------------------------
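A quick shape-and-mask sketch for ScaledDotProductAttention (the dimensions are illustrative; MultiHeadAttention in SubLayers.py sets the temperature to sqrt(d_k)): positions where the mask is True receive zero attention weight.

import torch
from modules.transformer.Modules import ScaledDotProductAttention

attn = ScaledDotProductAttention(temperature=64 ** 0.5)  # d_k = 64
q = k = v = torch.randn(2, 5, 64)
mask = torch.zeros(2, 5, 5, dtype=torch.bool)
mask[:, :, -1] = True                          # ignore the last key position
out, weights = attn(q, k, v, mask=mask)
print(out.shape, weights.shape)                # (2, 5, 64) and (2, 5, 5)
print(weights[..., -1].abs().max().item())     # ~0.0: masked keys get no attention

--------------------------------------------------------------------------------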
/modules/transformer/SubLayers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from Modules import ScaledDotProductAttention
6 |
7 |
8 | class MultiHeadAttention(nn.Module):
9 | """ Multi-Head Attention module """
10 |
11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
12 | super().__init__()
13 |
14 | self.n_head = n_head
15 | self.d_k = d_k
16 | self.d_v = d_v
17 |
18 | self.w_qs = nn.Linear(d_model, n_head * d_k)
19 | self.w_ks = nn.Linear(d_model, n_head * d_k)
20 | self.w_vs = nn.Linear(d_model, n_head * d_v)
21 |
22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
23 | self.layer_norm = nn.LayerNorm(d_model)
24 |
25 | self.fc = nn.Linear(n_head * d_v, d_model)
26 |
27 | self.dropout = nn.Dropout(dropout)
28 |
29 | def forward(self, q, k, v, mask=None):
30 |
31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
32 |
33 | sz_b, len_q, _ = q.size()
34 | sz_b, len_k, _ = k.size()
35 | sz_b, len_v, _ = v.size()
36 |
37 | residual = q
38 |
39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
45 |
46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
47 | output, attn = self.attention(q, k, v, mask=mask)
48 |
49 | output = output.view(n_head, sz_b, len_q, d_v)
50 | output = (
51 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)
52 | ) # b x lq x (n*dv)
53 |
54 | output = self.dropout(self.fc(output))
55 | output = self.layer_norm(output + residual)
56 |
57 | return output, attn
58 |
59 |
60 | class PositionwiseFeedForward(nn.Module):
61 | """ A two-feed-forward-layer module """
62 |
63 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1):
64 | super().__init__()
65 |
66 | # Use Conv1D
67 | # position-wise
68 | self.w_1 = nn.Conv1d(
69 | d_in,
70 | d_hid,
71 | kernel_size=kernel_size[0],
72 | padding=(kernel_size[0] - 1) // 2,
73 | )
74 | # position-wise
75 | self.w_2 = nn.Conv1d(
76 | d_hid,
77 | d_in,
78 | kernel_size=kernel_size[1],
79 | padding=(kernel_size[1] - 1) // 2,
80 | )
81 |
82 | self.layer_norm = nn.LayerNorm(d_in)
83 | self.dropout = nn.Dropout(dropout)
84 |
85 | def forward(self, x):
86 | residual = x
87 | output = x.transpose(1, 2)
88 | output = self.w_2(F.relu(self.w_1(output)))
89 | output = output.transpose(1, 2)
90 | output = self.dropout(output)
91 | output = self.layer_norm(output + residual)
92 |
93 | return output
94 |
--------------------------------------------------------------------------------
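
A quick sanity check of the two sub-layers above (a sketch under the same PYTHONPATH assumption). MultiHeadAttention expects a (batch, len_q, len_k) padding mask and repeats it per head internally; PositionwiseFeedForward applies its two 1-D convolutions along the time axis:

import torch
from modules.transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward

mha = MultiHeadAttention(n_head=2, d_model=256, d_k=64, d_v=64, dropout=0.1)
ffn = PositionwiseFeedForward(d_in=256, d_hid=1024, kernel_size=[9, 1], dropout=0.1)

x = torch.randn(2, 50, 256)
mask = torch.zeros(2, 50, 50).bool()  # no padded positions in this toy batch
y, attn = mha(x, x, x, mask=mask)     # self-attention: q = k = v
y = ffn(y)
print(y.shape, attn.shape)            # torch.Size([2, 50, 256]) torch.Size([4, 50, 50])
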
/modules/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .Constants import *
2 | from .Layers import *
3 | from .SubLayers import *
4 | from .Models import *
5 | from .Modules import *
--------------------------------------------------------------------------------
/modules/variance/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/modules/variance/__init__.py
--------------------------------------------------------------------------------
/modules/variance/modules.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 | from pyutils import get_mask_from_lengths, pad
8 |
9 | def f02pitch(f0):
10 | #f0 =f0 + 0.01
11 | return np.log2(f0 / 27.5) * 12 + 21
12 |
13 | class VarianceAdaptor(nn.Module):
14 | """Variance Adaptor"""
15 |
16 | def __init__(self, data_config, model_config):
17 | super(VarianceAdaptor, self).__init__()
18 | self.duration_predictor = VariancePredictor(**model_config['variance_predictor'])
19 | self.length_regulator = LengthRegulator()
20 | self.pitch_predictor = VariancePredictor(**model_config['variance_predictor'])
21 | self.uv_predictor = VariancePredictor(**model_config['variance_predictor'])
22 | self.energy_predictor = VariancePredictor(**model_config['variance_predictor'])
23 |
24 | self.uv_threshold = model_config['uv_threshold']
25 |
26 | self.pitch_feature_level = data_config["pitch"]["feature"]
27 | self.energy_feature_level = data_config["energy"]["feature"]
28 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
29 | assert self.energy_feature_level in ["phoneme_level", "frame_level"]
30 |
31 | pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
32 | energy_quantization = model_config["variance_embedding"]["energy_quantization"]
33 | n_bins = model_config["variance_embedding"]["n_bins"]
34 | assert pitch_quantization in ["linear", "log"]
35 | assert energy_quantization in ["linear", "log"]
36 |
37 | pitch_min_max = f02pitch(np.load(data_config['f0_min_max']))
38 | pitch_min, pitch_max = pitch_min_max[0][0], pitch_min_max[0][1]
39 | # print(np.load(data_config['energy_min_max']))
40 |
41 | energy_min_max = np.load(data_config['energy_min_max'])
42 | energy_min, energy_max = energy_min_max[0][0] + 1e-4, energy_min_max[0][1]
43 |
44 | if pitch_quantization == "log":
45 | self.pitch_bins = nn.Parameter(
46 | torch.exp(
47 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
48 | ),
49 | requires_grad=False,
50 | )
51 | else:
52 | self.pitch_bins = nn.Parameter(
53 | torch.linspace(pitch_min, pitch_max, n_bins - 1),
54 | requires_grad=False,
55 | )
56 | if energy_quantization == "log":
57 | self.energy_bins = nn.Parameter(
58 | torch.exp(
59 | torch.linspace(np.log(energy_min + 1e-6), np.log(energy_max), n_bins - 1)
60 | ),
61 | requires_grad=False,
62 | )
63 | else:
64 | self.energy_bins = nn.Parameter(
65 | torch.linspace(energy_min, energy_max, n_bins - 1),
66 | requires_grad=False,
67 | )
68 |
69 | self.pitch_embedding = nn.Embedding(
70 | n_bins, model_config["transformer"]["encoder"]["d_word_vec"]
71 | )
72 | self.energy_embedding = nn.Embedding(
73 | n_bins, model_config["transformer"]["encoder"]["d_word_vec"]
74 | )
75 | self.uv_embedding = nn.Embedding(
76 | 2, model_config['transformer']['encoder']['d_word_vec']
77 | )
78 |
79 | def get_uv_embedding(self, x, target, mask, control=1.0):
80 | prediction = self.uv_predictor(x, mask)
81 | if target is not None:
82 | embedding = self.uv_embedding(target.to(torch.int64))
83 | else:
84 | prediction = prediction * control
85 | prediction = torch.sigmoid(prediction)
86 | for i in range(prediction.shape[0]):
87 | prediction[i] = prediction[i] >= self.uv_threshold # (B, max_frames, 1)
88 |
89 | embedding = self.uv_embedding(prediction.long())
90 |
91 | return prediction, embedding
92 |
93 | def get_pitch_embedding(self, x, target, mask, control):
94 | prediction = self.pitch_predictor(x, mask)
95 | if target is not None:
96 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
97 | else:
98 | prediction = prediction * control
99 | embedding = self.pitch_embedding(
100 | torch.bucketize(prediction, self.pitch_bins)
101 | )
102 | return prediction, embedding
103 |
104 | def get_energy_embedding(self, x, target, mask, control):
105 | prediction = self.energy_predictor(x, mask)
106 | if target is not None:
107 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
108 | else:
109 | prediction = prediction * control
110 | embedding = self.energy_embedding(
111 | torch.bucketize(prediction, self.energy_bins)
112 | )
113 | return prediction, embedding
114 |
115 | def forward(
116 | self,
117 | x,
118 | src_mask,
119 | mel_mask=None,
120 | max_len=None,
121 | pitch_target=None,
122 | energy_target=None,
123 | uv_target=None,
124 | duration_target=None,
125 | p_control=1.0,
126 | e_control=1.0,
127 | d_control=1.0,
128 | uv_control=1.0
129 | ):
130 |
131 | log_duration_prediction = self.duration_predictor(x, src_mask)
132 | if self.pitch_feature_level == "phoneme_level":
133 | pitch_prediction, pitch_embedding = self.get_pitch_embedding(
134 | x, pitch_target, src_mask, p_control
135 | )
136 | x = x + pitch_embedding
137 | if self.energy_feature_level == "phoneme_level":
138 | energy_prediction, energy_embedding = self.get_energy_embedding(
139 |                 x, energy_target, src_mask, e_control
140 | )
141 | x = x + energy_embedding
142 |
143 | if duration_target is not None:
144 | x, mel_len = self.length_regulator(x, duration_target, max_len)
145 | duration_rounded = duration_target
146 | else:
147 | duration_rounded = torch.clamp(
148 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
149 | min=0,
150 | )
151 | x, mel_len = self.length_regulator(x, duration_rounded, max_len)
152 | mel_mask = get_mask_from_lengths(mel_len)
153 |
154 | if self.pitch_feature_level == "frame_level":
155 | pitch_prediction, pitch_embedding = self.get_pitch_embedding(
156 | x, pitch_target, mel_mask, p_control
157 | )
158 | x = x + pitch_embedding
159 | if self.energy_feature_level == "frame_level":
160 | energy_prediction, energy_embedding = self.get_energy_embedding(
161 |                 x, energy_target, mel_mask, e_control
162 | )
163 | x = x + energy_embedding
164 |
165 | uv_prediction, uv_embedding = self.get_uv_embedding(
166 | x, uv_target, mel_mask, uv_control
167 | )
168 | x = x + uv_embedding
169 |
170 | return (
171 | x,
172 | pitch_prediction,
173 | energy_prediction,
174 | uv_prediction,
175 | log_duration_prediction,
176 | duration_rounded,
177 | mel_len,
178 | mel_mask,
179 | )
180 |
181 |
182 | class LengthRegulator(nn.Module):
183 | """Length Regulator"""
184 |
185 | def __init__(self):
186 | super(LengthRegulator, self).__init__()
187 |
188 | def LR(self, x, duration, max_len):
189 | device = x.device
190 | output = list()
191 | mel_len = list()
192 | for batch, expand_target in zip(x, duration):
193 | expanded = self.expand(batch, expand_target)
194 | output.append(expanded)
195 | mel_len.append(expanded.shape[0])
196 |
197 | if max_len is not None:
198 | output = pad(output, max_len)
199 | else:
200 | output = pad(output)
201 |
202 | return output, torch.LongTensor(mel_len).to(device)
203 |
204 | def expand(self, batch, predicted):
205 | out = list()
206 |
207 | for i, vec in enumerate(batch):
208 | expand_size = predicted[i].item()
209 | out.append(vec.expand(max(int(expand_size), 0), -1))
210 | out = torch.cat(out, 0)
211 |
212 | return out
213 |
214 | def forward(self, x, duration, max_len):
215 | output, mel_len = self.LR(x, duration, max_len)
216 | return output, mel_len
217 |
218 |
219 | class VariancePredictor(nn.Module):
220 | """Duration, Pitch and Energy Predictor"""
221 |
222 | def __init__(
223 | self, input_size, filter_size,
224 | kernel_size, dropout
225 | ):
226 | super(VariancePredictor, self).__init__()
227 |
228 | # self.input_size = model_config["transformer"]["encoder_hidden"]
229 | # self.filter_size = model_config["variance_predictor"]["filter_size"]
230 | # self.kernel = model_config["variance_predictor"]["kernel_size"]
231 | # self.conv_output_size = model_config["variance_predictor"]["filter_size"]
232 | # self.dropout = model_config["variance_predictor"]["dropout"]
233 | self.input_size = input_size
234 | self.filter_size = filter_size
235 | self.kernel = kernel_size
236 | self.conv_output_size = filter_size
237 | self.dropout = dropout
238 |
239 | self.conv_layer = nn.Sequential(
240 | OrderedDict(
241 | [
242 | (
243 | "conv1d_1",
244 | Conv(
245 | self.input_size,
246 | self.filter_size,
247 | kernel_size=self.kernel,
248 | padding=(self.kernel - 1) // 2,
249 | ),
250 | ),
251 | ("relu_1", nn.ReLU()),
252 | ("layer_norm_1", nn.LayerNorm(self.filter_size)),
253 | ("dropout_1", nn.Dropout(self.dropout)),
254 | (
255 | "conv1d_2",
256 | Conv(
257 | self.filter_size,
258 | self.filter_size,
259 | kernel_size=self.kernel,
260 | padding=1,
261 | ),
262 | ),
263 | ("relu_2", nn.ReLU()),
264 | ("layer_norm_2", nn.LayerNorm(self.filter_size)),
265 | ("dropout_2", nn.Dropout(self.dropout)),
266 | ]
267 | )
268 | )
269 |
270 | self.linear_layer = nn.Linear(self.conv_output_size, 1)
271 |
272 | def forward(self, encoder_output, mask):
273 | out = self.conv_layer(encoder_output)
274 | out = self.linear_layer(out)
275 | out = out.squeeze(-1)
276 |
277 | if mask is not None:
278 | out = out.masked_fill(mask, 0.0)
279 |
280 | return out
281 |
282 |
283 | class Conv(nn.Module):
284 | """
285 | Convolution Module
286 | """
287 |
288 | def __init__(
289 | self,
290 | in_channels,
291 | out_channels,
292 | kernel_size=1,
293 | stride=1,
294 | padding=0,
295 | dilation=1,
296 | bias=True,
297 | w_init="linear",
298 | ):
299 | """
300 | :param in_channels: dimension of input
301 | :param out_channels: dimension of output
302 | :param kernel_size: size of kernel
303 | :param stride: size of stride
304 | :param padding: size of padding
305 | :param dilation: dilation rate
306 | :param bias: boolean. if True, bias is included.
307 | :param w_init: str. weight inits with xavier initialization.
308 | """
309 | super(Conv, self).__init__()
310 |
311 | self.conv = nn.Conv1d(
312 | in_channels,
313 | out_channels,
314 | kernel_size=kernel_size,
315 | stride=stride,
316 | padding=padding,
317 | dilation=dilation,
318 | bias=bias,
319 | )
320 |
321 | def forward(self, x):
322 | x = x.contiguous().transpose(1, 2)
323 | x = self.conv(x)
324 | x = x.contiguous().transpose(1, 2)
325 |
326 | return x
327 |
328 | if __name__ == "__main__":
329 | import yaml
330 | with open('/home/zengchang/code/acoustic_v2/configs/data.yaml', 'r') as f:
331 | data_config = yaml.load(f, Loader = yaml.FullLoader)
332 | with open('/home/zengchang/code/acoustic_v2/configs/model.yaml', 'r') as f:
333 | model_config = yaml.load(f, Loader = yaml.FullLoader)
334 | model = VarianceAdaptor(data_config, model_config['generator'])
335 | print(model)
336 |
--------------------------------------------------------------------------------
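
The LengthRegulator above is the FastSpeech-style bridge from phoneme rate to frame rate: each phoneme vector is repeated duration[i] times and the batch is padded to a common length. A minimal sketch, independent of the YAML configs assumed in the `__main__` block:

import torch
from modules.variance.modules import LengthRegulator

regulator = LengthRegulator()
x = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)  # 3 "phonemes", 2-dim features
durations = torch.tensor([[2, 1, 3]])                      # frames per phoneme
out, mel_len = regulator(x, durations, max_len=None)
print(out.shape, mel_len)  # torch.Size([1, 6, 2]) tensor([6])
# out[0] repeats row 0 twice, row 1 once and row 2 three times
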
/pics/2085003136_145600.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/2085003136_145600.png
--------------------------------------------------------------------------------
/pics/after_2085003136_145600.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/after_2085003136_145600.png
--------------------------------------------------------------------------------
/pics/before_2085003136_145600.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/before_2085003136_145600.png
--------------------------------------------------------------------------------
/pics/before_mel_l2_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/before_mel_l2_loss.png
--------------------------------------------------------------------------------
/pics/post_mel_l2_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/post_mel_l2_loss.png
--------------------------------------------------------------------------------
/pics/xs1_before_2085003136_145600.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/xs1_before_2085003136_145600.png
--------------------------------------------------------------------------------
/preprocess/audio_preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import librosa
5 | import pyworld
6 | import parselmouth
7 | import soundfile as sf
8 | import numpy as np
9 | import yaml
10 | from tqdm import tqdm
11 | from sklearn.preprocessing import StandardScaler
12 | from ipdb import set_trace
13 | from pyutils import f02pitch, pitch2f0, pitchxuv
14 | cwd=os.path.dirname(os.path.realpath(__file__))
15 | sys.path.insert(0, cwd)
16 |
17 | def resample_wav(wav, src_sr, tgt_sr):
18 | return librosa.resample(wav, orig_sr=src_sr, target_sr=tgt_sr)
19 |
20 | def _resize_f0(x, target_len):
21 | source = np.array(x)
22 | source[source < 0.001] = np.nan
23 | target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)), source)
24 | res = np.nan_to_num(target)
25 | return res
26 |
27 | def compute_f0_dio(wav, p_len=None, sampling_rate=48000, hop_length=240):
28 | if p_len is None:
29 | p_len = wav.shape[0]//hop_length
30 | f0, t = pyworld.dio(
31 | wav.astype(np.double),
32 | fs=sampling_rate,
33 | f0_ceil=800,
34 | frame_period=1000 * hop_length / sampling_rate
35 | )
36 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, sampling_rate)
37 | for index, pitch in enumerate(f0):
38 | f0[index] = round(pitch, 1)
39 | return _resize_f0(f0, p_len)
40 |
41 | def compute_f0_parselmouth(wav, p_len=None, sampling_rate=48000, hop_length=240):
42 | x = wav
43 | if p_len is None:
44 | p_len = x.shape[0]//hop_length
45 | else:
46 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error"
47 | time_step = hop_length / sampling_rate * 1000
48 | f0_min = 50
49 | f0_max = 1100
50 | f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac(
51 | time_step=time_step / 1000, voicing_threshold=0.6,
52 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
53 |
54 | pad_size=(p_len - len(f0) + 1) // 2
55 | if(pad_size>0 or p_len - len(f0) - pad_size>0):
56 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
57 | return f0
58 |
59 | def interpolate_f0(f0):
60 | data = np.reshape(f0, (f0.size, 1))
61 |
62 | vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
63 | vuv_vector[data > 0.0] = 1.0
64 | vuv_vector[data <= 0.0] = 0.0
65 |
66 | ip_data = data
67 |
68 | frame_number = data.size
69 | last_value = 0.0
70 | for i in range(frame_number):
71 | if data[i] <= 0.0:
72 | j = i + 1
73 | for j in range(i + 1, frame_number):
74 | if data[j] > 0.0:
75 | break
76 | if j < frame_number - 1:
77 | if last_value > 0.0:
78 | step = (data[j] - data[i - 1]) / float(j - i)
79 | for k in range(i, j):
80 | ip_data[k] = data[i - 1] + step * (k - i + 1)
81 | else:
82 | for k in range(i, j):
83 | ip_data[k] = data[j]
84 | else:
85 | for k in range(i, frame_number):
86 | ip_data[k] = last_value
87 | else:
88 | ip_data[i] = data[i] # this may not be necessary
89 | last_value = data[i]
90 |
91 | return ip_data[:,0], vuv_vector[:,0]
92 |
93 | def read_scp(scp_file):
94 | filelists = []
95 | with open(scp_file, 'r') as f:
96 | for line in f:
97 | line = line.rstrip()
98 | filelists.append(line)
99 | return filelists
100 |
101 | def spec_normalize(feat):
102 | '''
103 | params:
104 | feat: T, F
105 | '''
106 | return (feat - feat.mean(axis = 0, keepdims = True)) / (feat.std(axis = 0, keepdims = True) + 2e-12)
107 |
108 | def pad_wav(wav, config):
109 | padded_wav = np.pad(wav, (int((config['n_fft']-config['hop_length'])/2), int((config['n_fft']-config['hop_length'])/2)), mode='reflect')
110 | return padded_wav
111 |
112 | def extract_spec_with_energy(wav, filepath, config, spec_scaler = None, energy_scaler = None):
113 | '''
114 | (T, F/C)
115 | '''
116 | wav = pad_wav(wav, config)
117 | stft = librosa.stft(
118 | wav,
119 | n_fft = config['n_fft'],
120 | hop_length = config['hop_length'],
121 | win_length = config['win_length'],
122 | window = 'hann',
123 | center = False,
124 | pad_mode = 'reflect'
125 | )
126 | # set_trace()
127 | spec = np.abs(stft).transpose(1, 0)
128 | energy = np.sqrt((spec**2).sum(axis = 1))
129 | energy = energy.reshape(-1, 1)
130 | if spec_scaler is not None:
131 | spec_scaler.partial_fit(spec)
132 | if energy_scaler is not None:
133 | energy_scaler.partial_fit(energy)
134 | suffix = filepath.split('.')[-1]
135 | spec_filepath = filepath.replace(f'.{suffix}', '.spec.npy')
136 | np.save(spec_filepath, spec)
137 | suffix = filepath.split('.')[-1]
138 | energy_filepath = filepath.replace(f'.{suffix}', '.en.npy')
139 | np.save(energy_filepath, energy)
140 | # return spec, energy
141 |
142 | def extract_mel(wav, filepath, config, mel_scaler):
143 | '''
144 | log mel + spec normalization
145 | (T, F/C)
146 | '''
147 | wav = pad_wav(wav, config)
148 | mel_spec = librosa.feature.melspectrogram(
149 | y = wav,
150 | sr = config['sampling_rate'],
151 | n_fft = config['n_fft'],
152 | hop_length = config['hop_length'],
153 | win_length = config['win_length'],
154 | window = 'hann',
155 | n_mels = config['n_mels'],
156 | fmin = config['fmin'],
157 | fmax = config['fmax'],
158 | center = False,
159 | pad_mode = 'reflect'
160 | )
161 | log_mel_spec = np.log(mel_spec + 1e-9).transpose(1, 0)
162 | # normalized_log_mel_spec = spec_normalize(log_mel_spec)
163 | mel_scaler.partial_fit(log_mel_spec)
164 | suffix = filepath.split('.')[-1]
165 | mel_filepath = filepath.replace(f'.{suffix}', '.mel.npy')
166 | np.save(mel_filepath, log_mel_spec)
167 |
168 | def extract_f0(filepath, config, f0_scaler = None):
169 | '''
170 | (T, 1)
171 | '''
172 | wav, sr = sf.read(filepath)
173 | wav = resample_wav(wav, sr, config['sampling_rate'])
174 | sr = config['sampling_rate']
175 | assert sr == config['sampling_rate'], "Sampling rate ({}) != {}, please fix it!".format(sr, config['sampling_rate'])
176 | # wav = pad_wav(wav, config) # don't padding for computing f0
177 | f0 = compute_f0_dio(
178 | wav,
179 | sampling_rate = config["sampling_rate"],
180 | hop_length = config["hop_length"]
181 | )
182 | f0, uv = interpolate_f0(f0)
183 | f0 = f0.reshape(-1, 1)
184 | if f0_scaler is not None:
185 | f0_scaler.partial_fit(f0)
186 | suffix = filepath.split('.')[-1]
187 | f0_filepath = filepath.replace(f'.{suffix}', '.f0.npy')
188 | uv_filepath = filepath.replace(f'.{suffix}', '.uv.npy')
189 | np.save(f0_filepath, f0)
190 | np.save(uv_filepath, uv)
191 |
192 | def process_one_utterance_spec(filepath, config, spec_scaler, mel_scaler, energy_scaler = None):
193 | wav, sr = sf.read(filepath)
194 | wav = resample_wav(wav, sr, config['sampling_rate'])
195 | sr = config['sampling_rate']
196 | assert sr == config['sampling_rate'], "Sampling rate ({}) != {}, please fix it!".format(sr, config['sampling_rate'])
197 | if args.spec:
198 | extract_spec_with_energy(wav, filepath, config, spec_scaler, energy_scaler)
199 | if args.mel:
200 | extract_mel(wav, filepath, config, mel_scaler)
201 |
202 | def normalize(filelists, mean, std, feature = 'f0'):
203 | '''
204 |     normalize spec/mel features in place with the given mean/std;
205 |     for 'f0'/'en', leave values unchanged and only return their global min/max
206 | '''
207 | min_value = np.finfo(np.float64).max
208 | max_value = np.finfo(np.float64).min
209 | for filepath in filelists:
210 | suffix = filepath.split('.')[-1]
211 | filepath = filepath.replace(f'.{suffix}', f'.{feature}.npy')
212 | values = np.load(filepath)
213 | if feature in ['f0', 'en']:
214 | min_value = min(min_value, min(values))
215 | max_value = max(max_value, max(values))
216 | else:
217 | values = (np.load(filepath) - mean) / std
218 | np.save(filepath, values)
219 | return np.array([min_value, max_value]).reshape(1, -1)
220 |
221 | def parse_args():
222 | parser = argparse.ArgumentParser()
223 | parser.add_argument("--data-config", dest = "data_config", type = str, default = "", help = "data config path")
224 | parser.add_argument("--spec", action = "store_true", help = "extract stft spec feature")
225 | parser.add_argument("--mel", action = "store_true", help = "extract mel feature")
226 | parser.add_argument("--f0", action = "store_true", help = "extract f0 and uv")
227 | parser.add_argument("--energy", action = "store_true", help = "extract energy")
228 | parser.add_argument("--stat", action = "store_true", help = "Count the statistical numbers (mean and std) for energy and f0")
229 |
230 | args = parser.parse_args()
231 | return args
232 |
233 | def main(args):
234 | with open(args.data_config, 'r') as f:
235 | data_config = yaml.load(f, Loader = yaml.FullLoader)
236 | filelists = []
237 | with open(data_config['audio_manifest'], 'r') as f:
238 | for line in f:
239 | line = line.rstrip().split(' ')[-1]
240 | filelists.append(line)
241 | args.scp_file = data_config['audio_manifest']
242 |
243 | spec_scaler = StandardScaler()
244 | mel_scaler = StandardScaler()
245 | f0_scaler = None
246 | energy_scaler = None
247 | if args.stat:
248 | f0_scaler = StandardScaler()
249 | energy_scaler = StandardScaler()
250 |
251 | print("Extracting features...")
252 | for filepath in tqdm(filelists):
253 | if args.spec or args.mel:
254 | try:
255 | process_one_utterance_spec(filepath, data_config, spec_scaler, mel_scaler, energy_scaler)
256 |             except Exception as e:
257 |                 print(f"spec/mel extraction failed for {filepath}: {e}")
258 | if args.f0:
259 | try:
260 | extract_f0(filepath, data_config, f0_scaler)
261 |             except Exception as e:
262 |                 print(f"f0 extraction failed for {filepath}: {e}")
263 |
264 | if args.stat:
265 | if args.spec:
266 | spec_mean = spec_scaler.mean_.reshape(1, -1)
267 | spec_std = spec_scaler.scale_.reshape(1, -1)
268 | np.save(os.path.join(os.path.dirname(args.scp_file), 'spec_mean.npy'), spec_mean)
269 | np.save(os.path.join(os.path.dirname(args.scp_file), 'spec_std.npy'), spec_std)
270 | normalize(filelists, spec_mean, spec_std, feature = 'spec')
271 |
272 | if args.mel:
273 | mel_mean = mel_scaler.mean_.reshape(1, -1)
274 | mel_std = mel_scaler.scale_.reshape(1, -1)
275 | np.save(os.path.join(os.path.dirname(args.scp_file), 'mel_mean.npy'), mel_mean)
276 | np.save(os.path.join(os.path.dirname(args.scp_file), 'mel_std.npy'), mel_std)
277 | normalize(filelists, mel_mean, mel_std, feature = 'mel')
278 |
279 | if args.f0:
280 | print("Calculating f0 stats...")
281 | f0_mean = f0_scaler.mean_.reshape(1, -1)
282 | f0_std = f0_scaler.scale_.reshape(1, -1)
283 | np.save(os.path.join(os.path.dirname(args.scp_file), 'f0_mean.npy'), f0_mean)
284 | np.save(os.path.join(os.path.dirname(args.scp_file), 'f0_std.npy'), f0_std)
285 | f0_min_max = normalize(filelists, f0_mean, f0_std, feature = 'f0')
286 | np.save(os.path.join(os.path.dirname(args.scp_file), 'f0_min_max.npy'), f0_min_max)
287 |
288 | if args.energy:
289 | print("Calculating energy stats...")
290 | energy_mean = energy_scaler.mean_.reshape(1, -1)
291 | energy_std = energy_scaler.scale_.reshape(1, -1)
292 | np.save(os.path.join(os.path.dirname(args.scp_file), 'energy_mean.npy'), energy_mean)
293 | np.save(os.path.join(os.path.dirname(args.scp_file), 'energy_std.npy'), energy_std)
294 | energy_min_max = normalize(filelists, energy_mean, energy_std, feature = 'en')
295 | np.save(os.path.join(os.path.dirname(args.scp_file), 'energy_min_max.npy'), energy_min_max)
296 |
297 | if __name__ == "__main__":
298 | args = parse_args()
299 | print(args)
300 | main(args)
301 |
--------------------------------------------------------------------------------
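
pad_wav above reflects (n_fft - hop_length) / 2 samples onto each side so that librosa.stft with center=False yields len(wav) // hop_length frames, which keeps the spectral features aligned with the f0 track (compute_f0_dio uses p_len = len(wav) // hop_length on the unpadded signal). A small worked check of that frame arithmetic; the n_fft value here is only illustrative, while the 48 kHz / hop 240 pair matches the defaults used above:

import numpy as np

n_fft, hop_length = 2048, 240
wav = np.zeros(48000)  # one second at 48 kHz

pad = (n_fft - hop_length) // 2
padded = np.pad(wav, (pad, pad), mode='reflect')
n_frames = 1 + (len(padded) - n_fft) // hop_length  # frame count of stft(center=False)
print(n_frames, len(wav) // hop_length)             # 200 200
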
/preprocess/data_prep.py:
--------------------------------------------------------------------------------
1 | import os
2 | import librosa
3 | import numpy as np
4 | from scipy.io import wavfile
5 | from tqdm import tqdm
6 |
7 | def prepare_aishell3(config):
8 | pass
9 |
10 |
11 | def prepare_align(config):
12 | in_dir = config["path"]["corpus_path"]
13 | out_dir = config["path"]["raw_path"]
14 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
15 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
16 | for dataset in ["train", "test"]:
17 | print("Processing {}ing set...".format(dataset))
18 | with open(os.path.join(in_dir, dataset, "content.txt"), encoding="utf-8") as f:
19 | for line in tqdm(f):
20 | wav_name, text = line.strip("\n").split("\t")
21 | speaker = wav_name[:7]
22 | text = text.split(" ")[1::2]
23 | wav_path = os.path.join(in_dir, dataset, "wav", speaker, wav_name)
24 | if os.path.exists(wav_path):
25 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
26 |                 wav, _ = librosa.load(wav_path, sr=sampling_rate)
27 | wav = wav / max(abs(wav)) * max_wav_value
28 | wavfile.write(
29 | os.path.join(out_dir, speaker, wav_name),
30 | sampling_rate,
31 | wav.astype(np.int16),
32 | )
33 | with open(
34 | os.path.join(out_dir, speaker, "{}.lab".format(wav_name[:11])),
35 | "w",
36 | ) as f1:
37 | f1.write(" ".join(text))
38 |
--------------------------------------------------------------------------------
/pyutils/__init__.py:
--------------------------------------------------------------------------------
1 | from .save_and_load import *
2 | from .plot import *
3 | from .logger import *
4 | from .mask import *
5 | from .logger import *
6 | from .optimizer import *
7 | from . import scheduler
8 | import torch
9 | import torch.nn.functional as F
10 | import numpy as np
11 | def f02pitch(f0):
12 | #f0 =f0 + 0.01
13 | return np.log2(f0 / 27.5) * 12 + 21
14 |
15 | def pitch2f0(pitch):
16 | f0 = np.exp2((pitch - 21 ) / 12) * 27.5
17 | for i in range(len(f0)):
18 | if f0[i] <= 10:
19 | f0[i] = 0
20 | return f0
21 |
22 | def pitchxuv(pitch, uv, to_f0 = False):
23 | result = pitch * uv
24 | if to_f0:
25 | result = pitch2f0(result)
26 | return result
27 |
28 | def initialize(model, init_type="pytorch"):
29 | """Initialize Transformer module
30 |
31 | :param torch.nn.Module model: core instance
32 | :param str init_type: initialization type
33 | """
34 | if init_type == "pytorch":
35 | return
36 |
37 | # weight init
38 | for p in model.parameters():
39 | if p.dim() > 1:
40 | if init_type == "xavier_uniform":
41 | torch.nn.init.xavier_uniform_(p.data)
42 | elif init_type == "xavier_normal":
43 | torch.nn.init.xavier_normal_(p.data)
44 | elif init_type == "kaiming_uniform":
45 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
46 | elif init_type == "kaiming_normal":
47 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
48 | else:
49 | raise ValueError("Unknown initialization: " + init_type)
50 | # bias init
51 | for p in model.parameters():
52 | if p.dim() == 1:
53 | p.data.zero_()
54 |
55 | # reset some loss with default init
56 | for m in model.modules():
57 | if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm)):
58 | m.reset_parameters()
59 |
60 | def get_mask_from_lengths(lengths, max_len=None):
61 | device = lengths.device
62 | batch_size = lengths.shape[0]
63 | if max_len is None:
64 | max_len = torch.max(lengths).item()
65 |
66 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
67 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
68 |
69 | return mask
70 |
71 | def pad(input_ele, mel_max_length=None):
72 | if mel_max_length:
73 | max_len = mel_max_length
74 | else:
75 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
76 |
77 | out_list = list()
78 | for i, batch in enumerate(input_ele):
79 | if len(batch.shape) == 1:
80 | one_batch_padded = F.pad(
81 | batch, (0, max_len - batch.size(0)), "constant", 0.0
82 | )
83 | elif len(batch.shape) == 2:
84 | one_batch_padded = F.pad(
85 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
86 | )
87 | out_list.append(one_batch_padded)
88 | out_padded = torch.stack(out_list)
89 | return out_padded
90 |
--------------------------------------------------------------------------------
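
f02pitch and pitch2f0 convert between Hz and MIDI note numbers: 27.5 Hz is A0 (MIDI 21), so log2(440 / 27.5) * 12 + 21 = 4 * 12 + 21 = 69, the MIDI number of A4. A quick round-trip check (assumes the repo root is on PYTHONPATH and the pyutils dependencies are installed, since the package __init__ pulls in the other helpers):

import numpy as np
from pyutils import f02pitch, pitch2f0

f0 = np.array([110.0, 220.0, 440.0])  # A2, A3, A4
pitch = f02pitch(f0)
print(pitch)                           # [45. 57. 69.]
print(pitch2f0(pitch))                 # [110. 220. 440.]
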
/pyutils/gen_duration_from_tg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | import os
16 | from pathlib import Path
17 |
18 | import librosa
19 | import numpy as np
20 | import yaml
21 | from praatio import textgrid
22 | from yacs.config import CfgNode
23 | from tqdm import tqdm
24 |
25 |
26 | def readtg(tg_path, sample_rate=24000, n_shift=300):
27 | alignment = textgrid.openTextgrid(tg_path, includeEmptyIntervals=True)
28 | phones = []
29 | ends = []
30 | for interval in alignment.tierDict["phones"].entryList:
31 | phone = interval.label
32 | phones.append(phone)
33 | ends.append(interval.end)
34 | frame_pos = librosa.time_to_frames(ends, sr=sample_rate, hop_length=n_shift)
35 | durations = np.diff(frame_pos, prepend=0)
36 | assert len(durations) == len(phones)
37 | # merge "" and sp in the end
38 | if phones[-1] == "" and len(phones) > 1 and phones[-2] == "sp":
39 | phones = phones[:-1]
40 | durations[-2] += durations[-1]
41 | durations = durations[:-1]
42 | # replace the last "sp" with "sil" in MFA1.x
43 | phones[-1] = "sil" if phones[-1] == "sp" else phones[-1]
44 | # replace the edge "" with "sil", replace the inner "" with "sp"
45 | new_phones = []
46 | for i, phn in enumerate(phones):
47 | if phn == "":
48 | if i in {0, len(phones) - 1}:
49 | new_phones.append("sil")
50 | else:
51 | new_phones.append("sp")
52 | else:
53 | new_phones.append(phn)
54 | phones = new_phones
55 | results = ""
56 | for (p, d) in zip(phones, durations):
57 | results += p + " " + str(d) + " "
58 | return results.strip()
59 |
60 |
61 | # assume that the directory structure of inputdir is inputdir/speaker/*.TextGrid
62 | # in MFA1.x, there are blank labels ("") at the end, possibly preceded by "sp"
63 | # in MFA2.x, there are blank labels ("") at the beginning and the end, and no "sp" or "sil" anymore
64 | # we replace them with "sil"
65 | def gen_duration_from_textgrid(inputdir, output, sample_rate=24000,
66 | n_shift=300):
67 | # key: utt_id, value: (speaker, phn_durs)
68 | durations_dict = {}
69 | list_dir = os.listdir(inputdir)
70 | speakers = [dir for dir in list_dir if os.path.isdir(inputdir / dir)]
71 | for speaker in speakers:
72 | subdir = inputdir / speaker
73 | for file in tqdm(os.listdir(subdir)):
74 | if file.endswith(".TextGrid"):
75 | tg_path = subdir / file
76 | name = file.split(".")[0]
77 | durations_dict[name] = (speaker, readtg(
78 | tg_path, sample_rate=sample_rate, n_shift=n_shift))
79 | with open(output, "w") as wf:
80 | for name in sorted(durations_dict.keys()):
81 | wf.write(name + "|" + durations_dict[name][0] + "|" +
82 | durations_dict[name][1] + "\n")
83 |
84 |
85 | def main():
86 | # parse config and args
87 | parser = argparse.ArgumentParser(
88 | description="Preprocess audio and then extract features.")
89 | parser.add_argument(
90 | "--inputdir",
91 | default=None,
92 | type=str,
93 | help="directory to alignment files.")
94 | parser.add_argument(
95 | "--output", type=str, required=True, help="output duration file.")
96 |     parser.add_argument("--sample-rate", type=int, help="the sample rate of wavs.")
97 |     parser.add_argument(
98 |         "--n-shift",
99 |         type=int,
100 |         help="the n_shift of time_to_frames, also called hop_length.")
101 | parser.add_argument(
102 | "--config", type=str, help="config file with fs and n_shift.")
103 |
104 | args = parser.parse_args()
105 | with open(args.config) as f:
106 | config = CfgNode(yaml.safe_load(f))
107 |
108 | inputdir = Path(args.inputdir).expanduser()
109 | output = Path(args.output).expanduser()
110 | print(config)
111 | # import sys
112 | # sys.exit(0)
113 | gen_duration_from_textgrid(inputdir, output, config.sampling_rate, config.hop_length)
114 |
115 |
116 | if __name__ == "__main__":
117 | main()
--------------------------------------------------------------------------------
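
Each line of the duration file written above has the form utt_id|speaker|phn dur phn dur ..., one frame count per phone. A small sketch of reading such a line back into aligned lists (the utterance id and phones in the example are made up):

line = "2001000001|SSB2001|sil 12 n 7 i 9 h 6 ao 11 sil 15"
utt_id, speaker, phn_durs = line.strip().split("|")
tokens = phn_durs.split()
phones = tokens[0::2]
durations = [int(d) for d in tokens[1::2]]
assert len(phones) == len(durations)
print(utt_id, speaker, phones, durations, sum(durations))
# 2001000001 SSB2001 ['sil', 'n', 'i', 'h', 'ao', 'sil'] [12, 7, 9, 6, 11, 15] 60
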
/pyutils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | def get_logger(logging_file):
4 | logger = logging.getLogger()
5 | logger.setLevel(logging.INFO)
6 |
7 | formatter = logging.Formatter(
8 | "%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s: %(message)s"
9 | )
10 |
11 | file_log_handler = logging.FileHandler(logging_file, mode = 'w')
12 | file_log_handler.setLevel(logging.INFO)
13 | file_log_handler.setFormatter(formatter)
14 |
15 | # stream_log_handler = logging.StreamHandler()
16 | # stream_log_handler.setLevel(logging.INFO)
17 | # stream_log_handler.setFormatter(formatter)
18 |
19 | logger.addHandler(file_log_handler)
20 | # logger.addHandler(stream_log_handler)
21 |
22 | return logger
--------------------------------------------------------------------------------
/pyutils/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from ipdb import set_trace
4 | from torch.optim import (
5 | SGD,
6 | Adam,
7 | AdamW,
8 | RMSprop,
9 | RAdam,
10 | NAdam,
11 | ASGD
12 | )
13 |
14 | class NoamOpt():
15 | "Optim wrapper that implements rate."
16 |
17 | def __init__(self, optimizer, model_size, warmup, factor = 1.0):
18 | '''
19 | model_size: d_model
20 | factor: factor
21 | warmup: warmup step
22 | optimizer: optimizer (Adam default)
23 | '''
24 | self.optimizer = optimizer
25 | self._step = 0
26 | self.warmup = warmup
27 | self.factor = factor
28 | self.model_size = model_size
29 | self._rate = 0
30 |
31 | @property
32 | def param_groups(self):
33 | return self.optimizer.param_groups
34 |
35 | def step(self):
36 | "Update parameters and rate"
37 | self._step += 1
38 | rate = self.rate()
39 | for p in self.optimizer.param_groups:
40 | p["lr"] = rate
41 | self._rate = rate
42 | self.optimizer.step()
43 |
44 | def rate(self, step=None):
45 |         "Compute the Noam learning rate for the given step (defaults to the internal step counter)"
46 | if step is None:
47 | step = self._step
48 | return (
49 | self.factor
50 | * self.model_size ** (-0.5)
51 | * min(step ** (-0.5), step * self.warmup ** (-1.5))
52 | )
53 |
54 | def zero_grad(self):
55 | self.optimizer.zero_grad()
56 |
57 | def state_dict(self):
58 | return {
59 | "_step": self._step,
60 | "warmup": self.warmup,
61 | "factor": self.factor,
62 | "model_size": self.model_size,
63 | "_rate": self._rate,
64 | "optimizer": self.optimizer.state_dict(),
65 | }
66 |
67 | def load_state_dict(self, state_dict):
68 | for key, value in state_dict.items():
69 | if key == "optimizer":
70 | self.optimizer.load_state_dict(state_dict["optimizer"])
71 | else:
72 | setattr(self, key, value)
73 |
74 | class ScheduledOptimD():
75 | ''' A simple wrapper class for learning rate scheduling '''
76 |
77 | def __init__(self, optimizer, init_lr, n_warmup_steps, current_steps):
78 | self.optimizer = optimizer
79 | self.n_warmup_steps = n_warmup_steps
80 | self.n_current_steps = current_steps
81 | self.init_lr = init_lr
82 |
83 | def step_and_update_lr_frozen(self, learning_rate_frozen):
84 | for param_group in self.optimizer.param_groups:
85 | param_group['lr'] = learning_rate_frozen
86 | self.optimizer.step()
87 |
88 | def step_and_update_lr(self):
89 | self._update_learning_rate()
90 | self.optimizer.step()
91 |
92 | def get_learning_rate(self):
93 | learning_rate = 0.0
94 | for param_group in self.optimizer.param_groups:
95 | learning_rate = param_group['lr']
96 |
97 | return learning_rate
98 |
99 | def zero_grad(self):
100 | # print(self.init_lr)
101 | self.optimizer.zero_grad()
102 |
103 | def set_current_steps(self, step):
104 | self.n_current_steps = step
105 |
106 | def _get_lr_scale(self):
107 | # set_trace()
108 | return np.min([
109 | np.power(self.n_current_steps, -0.5),
110 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
111 |
112 | def _update_learning_rate(self):
113 | ''' Learning rate scheduling per step '''
114 |
115 | lr = self.init_lr * self._get_lr_scale()
116 |
117 | for param_group in self.optimizer.param_groups:
118 | param_group['lr'] = lr
119 |
120 | def state_dict(self):
121 | return {
122 | "_step": self.n_current_steps,
123 | "warmup": self.n_warmup_steps,
124 | "factor": self.init_lr,
125 | "_rate": self.get_learning_rate(),
126 | "optimizer": self.optimizer.state_dict(),
127 | }
128 |
129 | def load_state_dict(self, state_dict):
130 | for key, value in state_dict.items():
131 | if key == "optimizer":
132 | self.optimizer.load_state_dict(state_dict["optimizer"])
133 | else:
134 | setattr(self, key, value)
135 |
136 | def get_g_opt(model, optim, d_model, warmup, factor):
137 | base = torch.optim.Adam(model.parameters(), lr = 0, betas = (0.9, 0.98), eps = 1e-9)
138 | return NoamOpt(base, d_model, warmup, factor)
139 |
140 | def get_d_opt(model, optim, warmup, factor, current_step):
141 | base = torch.optim.Adam(model.parameters(), lr = 0, betas = (0.9, 0.98), eps = 1e-9)
142 | return ScheduledOptimD(base, factor, warmup, current_step)
143 |
--------------------------------------------------------------------------------
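
NoamOpt implements the usual Noam schedule, lr = factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5): a linear ramp for the first `warmup` steps, then step ** -0.5 decay, with the peak exactly at step == warmup. Since rate() only touches model_size, warmup and factor, it can be inspected without a real optimizer (a sketch; the values below are illustrative and assume the repo's dependencies are importable):

from pyutils.optimizer import NoamOpt

sched = NoamOpt(optimizer=None, model_size=256, warmup=4000, factor=1.0)
for step in (1, 2000, 4000, 8000, 16000):
    print(step, f"{sched.rate(step):.3e}")
# peak learning rate is 256 ** -0.5 * 4000 ** -0.5 ≈ 9.9e-4 at step 4000
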
/pyutils/parse_options.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4 | # Arnab Ghoshal, Karel Vesely
5 |
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15 | # MERCHANTABLITY OR NON-INFRINGEMENT.
16 | # See the Apache 2 License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | # Parse command-line options.
21 | # To be sourced by another script (as in ". parse_options.sh").
22 | # Option format is: --option-name arg
23 | # and shell variable "option_name" gets set to value "arg."
24 | # The exception is --help, which takes no arguments, but prints the
25 | # $help_message variable (if defined).
26 |
27 |
28 | ###
29 | ### The --config file options have lower priority to command line
30 | ### options, so we need to import them first...
31 | ###
32 |
33 | # Now import all the configs specified by command-line, in left-to-right order
34 | for ((argpos=1; argpos<$#; argpos++)); do
35 | if [ "${!argpos}" == "--config" ]; then
36 | argpos_plus1=$((argpos+1))
37 | config=${!argpos_plus1}
38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39 | . $config # source the config file.
40 | fi
41 | done
42 |
43 |
44 | ###
45 | ### Now we process the command line options
46 | ###
47 | while true; do
48 | [ -z "${1:-}" ] && break; # break if there are no arguments
49 | case "$1" in
50 | # If the enclosing script is called with --help option, print the help
51 | # message and exit. Scripts should put help messages in $help_message
52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53 | else printf "$help_message\n" 1>&2 ; fi;
54 | exit 0 ;;
55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56 | exit 1 ;;
57 | # If the first command-line argument begins with "--" (e.g. --foo-bar),
58 | # then work out the variable name as $name, which will equal "foo_bar".
59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60 |       # Next we test whether the variable in question is undefined-- if so it's
61 | # an invalid option and we die. Note: $0 evaluates to the name of the
62 | # enclosing script.
63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64 | # is undefined. We then have to wrap this test inside "eval" because
65 | # foo_bar is itself inside a variable ($name).
66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67 |
68 | oldval="`eval echo \\$$name`";
69 | # Work out whether we seem to be expecting a Boolean argument.
70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71 | was_bool=true;
72 | else
73 | was_bool=false;
74 | fi
75 |
76 | # Set the variable to the right value-- the escaped quotes make it work if
77 | # the option had spaces, like --cmd "queue.pl -sync y"
78 | eval $name=\"$2\";
79 |
80 | # Check that Boolean-valued arguments are really Boolean.
81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83 | exit 1;
84 | fi
85 | shift 2;
86 | ;;
87 | *) break;
88 | esac
89 | done
90 |
91 |
92 | # Check for an empty argument to the --cmd option, which can easily occur as a
93 | # result of scripting errors.
94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95 |
96 |
97 | true; # so this script returns exit code 0.
--------------------------------------------------------------------------------
/pyutils/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import librosa
3 | import soundfile as sf
4 | import numpy as np
5 | import logging
6 | import argparse
7 | logging.getLogger('matplotlib.font_manager').disabled = True
8 |
9 | def specplot(spec,
10 | pic_path = 'exp/test/melspectrograms/spec.png',
11 | **kwargs):
12 | """Plot the log mel spectrogram of audio."""
13 | fig = plt.figure()
14 | plt.imshow(spec, origin = 'lower', cmap = plt.cm.magma, aspect='auto')
15 | plt.colorbar()
16 | fig.savefig(pic_path)
17 | plt.close()
18 |
19 | def specplot_from_audio(filename = None,
20 | audio = None,
21 | rate = None,
22 | rotate=False,
23 | n_ffts = 1024,
24 | pic_path = 'exp/test/spectrograms/spec.png',
25 | **kwargs):
26 | """Plot the log magnitude spectrogram of audio."""
27 | if filename is not None:
28 | audio, rate = sf.read(filename)
29 | hop_length = kwargs.get('hop_length', None)
30 | win_length = kwargs.get('win_length', None)
31 | stft = librosa.stft(
32 | audio,
33 | n_fft = n_ffts,
34 | hop_length = hop_length,
35 | win_length = win_length
36 | )
37 | mag, phase = librosa.magphase(stft)
38 | logmag = np.log10(mag)
39 | fig = plt.figure()
40 | plt.imshow(logmag, cmap = plt.cm.magma, origin = 'lower', aspect = 'auto')
41 | plt.colorbar()
42 | fig.savefig(pic_path)
43 | plt.close()
44 |
45 | def melspecplot(mel_spec,
46 | pic_path = 'exp/test/melspectrograms/melspec.png',
47 | **kwargs):
48 | """Plot the log mel spectrogram of audio."""
49 | fig = plt.figure()
50 | plt.imshow(mel_spec, origin = 'lower', cmap = plt.cm.magma, aspect='auto')
51 | plt.colorbar()
52 | fig.savefig(pic_path)
53 | plt.close()
54 |
55 | def melspecplot_from_audio(filename = None,
56 | audio = None,
57 | rate = None,
58 | rotate = False,
59 | n_ffts = 1024,
60 | pic_path = 'exp/test/melspectrograms/melspec.png',
61 | **kwargs):
62 | """Plot the log mel spectrogram of audio."""
63 | if filename is not None:
64 | audio, rate = sf.read(filename)
65 | hop_length = kwargs.get('hop_length', None)
66 | win_length = kwargs.get('win_length', None)
67 | n_mels = kwargs.get('n_mels', 23)
68 | mel_spec = librosa.feature.melspectrogram(
69 | y = audio,
70 | sr = rate,
71 | n_fft = n_ffts,
72 | hop_length = hop_length,
73 | win_length = win_length,
74 | n_mels = n_mels
75 | )
76 | mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
77 | plt.imshow(mel_spec, cmap = plt.cm.magma, origin = 'lower', aspect = 'auto')
78 | plt.savefig(pic_path)
79 | plt.close()
80 |
81 | def get_args():
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument('--filename', type=str, default=None)
84 | parser.add_argument('--output', type=str, default=None)
85 | parser.add_argument('--mean', type=str, default=None)
86 | parser.add_argument('--std', type=str, default=None)
87 | return parser.parse_args()
88 |
89 | if __name__ == "__main__":
90 | args = get_args()
91 | import numpy as np
92 | # specplot_from_audio(args.filename, pic_path = args.output + '.aspec.png')
93 | data = np.load(args.filename).T
94 |     mean = np.load(args.mean).T # (n_bins, 1) after transpose
95 |     std = np.load(args.std).T
96 |     if 'spec' in args.filename:
97 |         specplot(data * std + mean, pic_path = args.output + '.spec.png') # un-normalize with the saved stats (assumed) before plotting
98 | if 'mel' in args.filename:
99 | melspecplot(data, args.output + '.melspec.png')
--------------------------------------------------------------------------------
/pyutils/save_and_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import re
4 | import sys
5 | import argparse
6 | import logging
7 | import json
8 | import subprocess
9 | import warnings
10 | import random
11 | import functools
12 |
13 | import librosa
14 | import numpy as np
15 | from scipy.io.wavfile import read
16 | import torch
17 | from torch.nn import functional as F
18 | # from modules.commons import sequence_mask
19 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
20 | logger = logging
21 |
22 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
23 | f_list = glob.glob(os.path.join(dir_path, regex))
24 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
25 | x = f_list[-1]
26 | print(x)
27 | return x
28 |
29 | def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None, skip_optimizer=False):
30 | assert os.path.isfile(checkpoint_path)
31 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
32 | iteration = checkpoint_dict['iteration']
33 | learning_rate = checkpoint_dict['learning_rate']
34 | if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
35 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
36 | if scheduler is not None and not skip_optimizer and checkpoint_dict['scheduler'] is not None:
37 | scheduler.load_state_dict(checkpoint_dict['scheduler'])
38 | saved_state_dict = checkpoint_dict['model']
39 | if hasattr(model, 'module'):
40 | state_dict = model.module.state_dict()
41 | else:
42 | state_dict = model.state_dict()
43 | new_state_dict = {}
44 | for k, v in state_dict.items():
45 | try:
46 | # assert "dec" in k or "disc" in k
47 | # print("load", k)
48 | new_state_dict[k] = saved_state_dict[k]
49 | assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
50 | except:
51 | print("error, %s is not in the checkpoint" % k)
52 | logger.info("%s is not in the checkpoint" % k)
53 | new_state_dict[k] = v
54 | if hasattr(model, 'module'):
55 | model.module.load_state_dict(new_state_dict)
56 | else:
57 | model.load_state_dict(new_state_dict)
58 | print("load")
59 | logger.info("Loaded checkpoint '{}' (iteration {})".format(
60 | checkpoint_path, iteration))
61 | return model, optimizer, scheduler, learning_rate, iteration
62 |
63 | def save_checkpoint(model, optimizer, scheduler, learning_rate, iteration, checkpoint_path):
64 | logger.info("Saving model and optimizer state at iteration {} to {}".format(
65 | iteration, checkpoint_path))
66 | if hasattr(model, 'module'):
67 | state_dict = model.module.state_dict()
68 | else:
69 | state_dict = model.state_dict()
70 | torch.save({'model': state_dict,
71 | 'iteration': iteration,
72 | 'optimizer': optimizer.state_dict(),
73 | 'scheduler': scheduler.state_dict(),
74 | 'learning_rate': learning_rate},
75 | checkpoint_path)
76 |
77 | def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
78 | """Freeing up space by deleting saved ckpts
79 |
80 | Arguments:
81 | path_to_models -- Path to the model directory
82 | n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
83 | sort_by_time -- True -> chronologically delete ckpts
84 | False -> lexicographically delete ckpts
85 | """
86 | ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
87 | name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
88 | time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
89 | sort_key = time_key if sort_by_time else name_key
90 | x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
91 | to_del = [os.path.join(path_to_models, fn) for fn in
92 | (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
93 | del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
94 | del_routine = lambda x: [os.remove(x), del_info(x)]
95 | rs = [del_routine(fn) for fn in to_del]
96 |
97 | class HParams():
98 | def __init__(self, **kwargs):
99 | for k, v in kwargs.items():
100 | if type(v) == dict:
101 | v = HParams(**v)
102 | self[k] = v
103 |
104 | def keys(self):
105 | return self.__dict__.keys()
106 |
107 | def items(self):
108 | return self.__dict__.items()
109 |
110 | def values(self):
111 | return self.__dict__.values()
112 |
113 | def __len__(self):
114 | return len(self.__dict__)
115 |
116 | def __getitem__(self, key):
117 | return getattr(self, key)
118 |
119 | def __setitem__(self, key, value):
120 | return setattr(self, key, value)
121 |
122 | def __contains__(self, key):
123 | return key in self.__dict__
124 |
125 | def __repr__(self):
126 | return self.__dict__.__repr__()
127 |
128 |
--------------------------------------------------------------------------------
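
HParams above wraps a nested dict so values can be reached either as attributes or by key, which is handy when passing parsed YAML around. A tiny usage sketch:

from pyutils.save_and_load import HParams

hps = HParams(**{"train": {"batch_size": 16, "lr": 2e-4}, "model": {"hidden": 256}})
print(hps.train.batch_size)      # 16
print(hps["model"]["hidden"])    # 256
print("train" in hps, len(hps))  # True 2
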
/pyutils/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 | import torch
4 |
5 | class WarmupLR(_LRScheduler):
6 | """The WarmupLR scheduler
7 |
8 |     This scheduler is almost the same as the NoamLR scheduler except for the
9 |     following difference:
10 |
11 | NoamLR:
12 | lr = optimizer.lr * model_size ** -0.5
13 | * min(step ** -0.5, step * warmup_step ** -1.5)
14 | WarmupLR:
15 | lr = optimizer.lr * warmup_step ** 0.5
16 | * min(step ** -0.5, step * warmup_step ** -1.5)
17 |
18 |     Note that the maximum lr equals optimizer.lr in this scheduler.
19 |
20 | """
21 |
22 | def __init__(
23 | self,
24 | optimizer,
25 | warmup_steps = 25000,
26 | last_epoch = -1,
27 | ):
28 | self.warmup_steps = warmup_steps
29 |
30 |         # the field must be set before super().__init__() is invoked,
31 |         # because step() (which calls get_lr()) runs inside __init__()
32 | super().__init__(optimizer, last_epoch)
33 |
34 | def __repr__(self):
35 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
36 |
37 | def get_lr(self):
38 | step_num = self.last_epoch + 1
39 | if self.warmup_steps == 0:
40 | return [
41 | lr * step_num ** -0.5
42 | for lr in self.base_lrs
43 | ]
44 | else:
45 | return [
46 | lr
47 | * self.warmup_steps ** 0.5
48 | * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
49 | for lr in self.base_lrs
50 | ]
51 |
52 | def set_step(self, step: int):
53 | self.last_epoch = step
54 |
55 | class BaseClass:
56 | '''
57 | Base Class for learning rate scheduler
58 | '''
59 |
60 | def __init__(self,
61 | optimizer,
62 | num_epochs,
63 | epoch_iter,
64 | initial_lr,
65 | final_lr,
66 | warm_up_epoch=6,
67 | scale_ratio=1.0,
68 | warm_from_zero=False):
69 | '''
70 |         warm_up_epoch: the first warm_up_epoch epochs form the multi-process warm-up stage
71 |         scale_ratio: multiplied into the current lr during multi-process
72 |         training
73 | '''
74 | self.optimizer = optimizer
75 | self.max_iter = num_epochs * epoch_iter
76 | self.initial_lr = initial_lr
77 | self.final_lr = final_lr
78 | self.scale_ratio = scale_ratio
79 | self.current_iter = 0
80 | self.warm_up_iter = warm_up_epoch * epoch_iter
81 | self.warm_from_zero = warm_from_zero
82 |
83 | def get_multi_process_coeff(self):
84 | lr_coeff = 1.0 * self.scale_ratio
85 | if self.current_iter < self.warm_up_iter:
86 | if self.warm_from_zero:
87 | lr_coeff = self.scale_ratio * self.current_iter / self.warm_up_iter
88 | elif self.scale_ratio > 1:
89 | lr_coeff = (self.scale_ratio -
90 | 1) * self.current_iter / self.warm_up_iter + 1.0
91 |
92 | return lr_coeff
93 |
94 | def get_current_lr(self):
95 | '''
96 | This function should be implemented in the child class
97 | '''
98 | return 0.0
99 |
100 | def get_lr(self):
101 | return self.optimizer.param_groups[0]['lr']
102 |
103 | def set_lr(self):
104 | current_lr = self.get_current_lr()
105 | for param_group in self.optimizer.param_groups:
106 | param_group['lr'] = current_lr
107 |
108 | def step(self, current_iter=None):
109 | if current_iter is not None:
110 | self.current_iter = current_iter
111 |
112 | self.set_lr()
113 | self.current_iter += 1
114 |
115 | def step_return_lr(self, current_iter=None):
116 | if current_iter is not None:
117 | self.current_iter = current_iter
118 |
119 | current_lr = self.get_current_lr()
120 | self.current_iter += 1
121 |
122 | return current_lr
123 |
124 | class ExponentialDecrease(BaseClass):
125 |
126 | def __init__(self,
127 | optimizer,
128 | num_epochs,
129 | epoch_iter,
130 | initial_lr,
131 | final_lr,
132 | warm_up_epoch=6,
133 | scale_ratio=1.0,
134 | warm_from_zero=False):
135 | super().__init__(optimizer, num_epochs, epoch_iter, initial_lr,
136 | final_lr, warm_up_epoch, scale_ratio, warm_from_zero)
137 |
138 | def get_current_lr(self):
139 | lr_coeff = self.get_multi_process_coeff()
140 | current_lr = lr_coeff * self.initial_lr * math.exp(
141 | (self.current_iter / self.max_iter) *
142 | math.log(self.final_lr / self.initial_lr))
143 | return current_lr
144 |
145 | class TriAngular2(BaseClass):
146 | '''
147 |     An implementation of the triangular2 cyclical learning rate policy from https://arxiv.org/pdf/1506.01186.pdf
148 | '''
149 |
150 | def __init__(self,
151 | optimizer,
152 | num_epochs,
153 | epoch_iter,
154 | initial_lr,
155 | final_lr,
156 | warm_up_epoch=6,
157 | scale_ratio=1.0,
158 | cycle_step=2,
159 | reduce_lr_diff_ratio=0.5):
160 | super().__init__(optimizer, num_epochs, epoch_iter, initial_lr,
161 | final_lr, warm_up_epoch, scale_ratio)
162 |
163 | self.reduce_lr_diff_ratio = reduce_lr_diff_ratio
164 | self.cycle_iter = cycle_step * epoch_iter
165 | self.step_size = self.cycle_iter // 2
166 |
167 | self.max_lr = initial_lr
168 | self.min_lr = final_lr
169 | self.gap = self.max_lr - self.min_lr
170 |
171 | def get_current_lr(self):
172 | lr_coeff = self.get_multi_process_coeff()
173 | point = self.current_iter % self.cycle_iter
174 | cycle_index = self.current_iter // self.cycle_iter
175 |
176 | self.max_lr = self.min_lr + self.gap * self.reduce_lr_diff_ratio**cycle_index
177 |
178 | if point <= self.step_size:
179 | current_lr = self.min_lr + (self.max_lr -
180 | self.min_lr) * point / self.step_size
181 | else:
182 | current_lr = self.max_lr - (self.max_lr - self.min_lr) * (
183 | point - self.step_size) / self.step_size
184 |
185 | current_lr = lr_coeff * current_lr
186 |
187 | return current_lr
188 |
189 |
190 | def show_lr_curve(scheduler):
191 | import matplotlib.pyplot as plt
192 |
193 | lr_list = []
194 |     for current_iter in range(0, scheduler.max_iter):
195 |         lr_list.append(scheduler.step_return_lr(current_iter))
196 | data_index = list(range(1, len(lr_list) + 1))
197 |
198 |     plt.plot(data_index, lr_list, '-o', markersize=1, label=type(scheduler).__name__)
199 | plt.legend(loc='best')
200 | plt.xlabel("Iteration")
201 | plt.ylabel("LR")
202 |
203 | plt.show()
204 |
205 |
206 | if __name__ == '__main__':
207 | optimizer = None
208 | num_epochs = 6
209 | epoch_iter = 500
210 | initial_lr = 0.6
211 | final_lr = 0.1
212 | warm_up_epoch = 2
213 | scale_ratio = 4
214 | scheduler = ExponentialDecrease(optimizer, num_epochs, epoch_iter,
215 | initial_lr, final_lr, warm_up_epoch,
216 | scale_ratio)
217 | # scheduler = TriAngular2(optimizer,
218 | # num_epochs,
219 | # epoch_iter,
220 | # initial_lr,
221 | # final_lr,
222 | # warm_up_epoch,
223 | # scale_ratio,
224 | # cycle_step=2,
225 | # reduce_lr_diff_ratio=0.5)
226 |
227 | show_lr_curve(scheduler)
228 |
--------------------------------------------------------------------------------
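Usage note for pyutils/scheduler.py above: the WarmupLR docstring says the peak learning rate equals optimizer.lr and is reached at step == warmup_steps. A minimal sketch of attaching it to a throwaway optimizer (the model, lr and warmup_steps values below are illustrative assumptions, not values from the repository's configs):

import torch
from pyutils.scheduler import WarmupLR

# throwaway model/optimizer, only to demonstrate the schedule
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = WarmupLR(optimizer, warmup_steps=4)

for step in range(8):
    optimizer.step()      # the parameter update would happen here in real training
    scheduler.step()      # called once per batch, as in train.py / train_gan.py
    print(step, scheduler.get_last_lr()[0])
# the lr ramps up linearly to 1e-3 at step == warmup_steps, then decays ~ step ** -0.5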
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export PYTHONPATH=$(pwd):$PYTHONPATH
4 | start_stage=2
5 | stop_stage=2
6 | step=
7 | input=
8 | output=
9 | name=
10 | echo $PYTHONPATH
11 |
12 | echo "$0 $@" # Print the command line for logging
13 |
14 | . ./utils/parse_options.sh
15 |
16 | if [ ${start_stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
17 | echo "stage 0: Data Preparation"
18 | python3 preprocess/data_prep.py
19 | fi
20 |
21 | if [ ${start_stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
22 | echo "stage 1: Feature Extraction"
23 |     python preprocess/audio_preprocess.py --data-config configs/svs/data.yaml \
24 | --spec \
25 | --mel \
26 | --f0 \
27 | --energy \
28 | --stat
29 | fi
30 |
31 | if [ ${start_stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
32 | echo "stage 2: Training"
33 | export CUDA_VISIBLE_DEVICES=0,1
34 | python train_gan.py --data-config configs/svs/data.yaml \
35 | --model-config configs/svs/model.yaml \
36 | --train-config configs/svs/train.yaml \
37 | --num-gpus 2 \
38 | --dist-url 'tcp://localhost:30305'
39 | fi
40 |
41 | if [ ${start_stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
42 | echo "stage 3: Synthesizing"
43 | python synthesize.py --exp-name ${name} \
44 | --step ${step} \
45 | --input ${input} \
46 | --output ${output}
47 | fi
48 |
--------------------------------------------------------------------------------
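The stages in run.sh above are selected with start_stage/stop_stage. Because the script sources parse_options.sh (reached through the utils symlink, which points at a Kaldi recipe's utils directory), any variable defined before that line can presumably be overridden on the command line instead of editing the script, e.g.

./run.sh --start-stage 0 --stop-stage 1

would run only data preparation and feature extraction; stage 3 additionally expects --name, --step, --input and --output to be passed through to synthesize.py.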
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import argparse
4 | import math
5 | import logging
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 | import torch.multiprocessing as mp
11 |
12 | from dataset import SVSDataset, SVSCollate
13 | from models import Xiaoice2 as Generator
14 | from loss import FastSpeech2Loss
15 | import pyutils
16 | from pyutils import (
17 | load_checkpoint,
18 | save_checkpoint,
19 | clean_checkpoints,
20 | latest_checkpoint_path,
21 | melspecplot,
22 | get_logger
23 | )
24 |
25 | import wandb
26 |
27 | logging.basicConfig(format = "%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s: %(message)s", level = logging.INFO)
28 |
29 | class Trainer():
30 | def __init__(self, rank, args, data_configs, model_configs, train_configs):
31 | self.rank = rank
32 | self.device = torch.device('cuda:{:d}'.format(rank))
33 | trainset = SVSDataset(data_configs)
34 | collate_fn = SVSCollate()
35 | sampler = torch.utils.data.DistributedSampler(trainset) if args.num_gpus > 1 else None
36 | self.trainloader = DataLoader(
37 | trainset,
38 |             shuffle = (sampler is None),
39 | sampler = sampler,
40 | collate_fn = collate_fn,
41 | batch_size = train_configs['batch_size'],
42 | pin_memory = True,
43 | num_workers = train_configs['num_workers'],
44 | prefetch_factor = 10
45 | )
46 |
47 | model_configs['generator']['transformer']['encoder']['n_src_vocab'] = trainset.get_phone_number() + 1
48 | model_configs['generator']['spk_num'] = trainset.get_spk_number()
49 |
50 | if args.num_gpus > 1:
51 | self.models = (
52 | nn.parallel.DistributedDataParallel(
53 | Generator(
54 | data_configs,
55 | model_configs['generator']
56 | ).to(self.device),
57 | device_ids = [rank]
58 | ),
59 | )
60 | else:
61 | self.models = (
62 | Generator(
63 | data_configs,
64 | model_configs['generator']
65 | ).to(self.device),
66 | )
67 | self.data_configs = data_configs
68 | self.model_configs = model_configs
69 | self.train_configs = train_configs
70 | self.args = args
71 |
72 | try:
73 | self.g_optimizer = getattr(
74 | torch.optim, train_configs['g_optimizer']
75 | )(self.models[0].parameters(), **train_configs['g_optimizer_args'])
76 |
77 | self.g_scheduler = getattr(
78 | pyutils.scheduler, train_configs['g_scheduler']
79 | )(self.g_optimizer, **train_configs['g_scheduler_args'])
80 |         except (AttributeError, TypeError) as e:
81 |             raise NotImplementedError("Unknown optimizer or scheduler") from e
82 |
83 | self.fs2loss = FastSpeech2Loss(data_configs)
84 |
85 | if self.rank == 0:
86 | self._make_exp_dir()
87 | self.logger = get_logger(os.path.join(self.args.exp_name, 'logs/train.log'))
88 |
89 | try:
90 | latest_ckpt_path = latest_checkpoint_path(
91 | os.path.join(self.args.exp_name, 'models'),
92 | 'G_*.pth'
93 | )
94 | _, _, _, _, epoch_str = load_checkpoint(
95 | latest_ckpt_path,
96 | self.models[0],
97 | self.g_optimizer,
98 | self.g_scheduler,
99 | False
100 | )
101 | self.start_epoch = max(epoch_str, 1)
102 | name = latest_ckpt_path
103 | self.total_step = int(name[name.rfind("_")+1:name.rfind(".")]) + 1
104 | except Exception:
105 | print("Load old checkpoint failed...")
106 | print("Start a new training...")
107 | self.start_epoch = 1
108 | self.total_step = 0
109 |
110 | self.epochs = self.train_configs['epochs']
111 |
112 | def _dump_args_and_config(self, filename, config):
113 | with open(os.path.join(self.args.exp_name, 'configs', filename) + '.yaml', 'w') as f:
114 | yaml.dump(config, f)
115 |
116 | def _make_exp_dir(self):
117 | os.makedirs(self.args.exp_name, exist_ok=True)
118 | os.makedirs(os.path.join(self.args.exp_name, 'configs'), exist_ok=True)
119 | os.makedirs(os.path.join(self.args.exp_name, 'models'), exist_ok=True)
120 | os.makedirs(os.path.join(self.args.exp_name, 'audios'), exist_ok=True)
121 | os.makedirs(os.path.join(self.args.exp_name, 'spectrograms'), exist_ok=True)
122 | os.makedirs(os.path.join(self.args.exp_name, 'melspectrograms'), exist_ok=True)
123 | os.makedirs(os.path.join(self.args.exp_name, 'eval_results'), exist_ok=True)
124 | os.makedirs(os.path.join(self.args.exp_name, 'logs'), exist_ok = True)
125 | with open(os.path.join(self.args.exp_name, 'model_arch.txt'), 'w') as f:
126 | for model in self.models:
127 | print(model, file = f)
128 | self._dump_args_and_config('args', vars(self.args))
129 | self._dump_args_and_config('data', self.data_configs)
130 | self._dump_args_and_config('model', self.model_configs)
131 | self._dump_args_and_config('train', self.train_configs)
132 |
133 | def train(self):
134 | for epoch in range(self.start_epoch, self.epochs + 1, 1):
135 | self.train_epoch(epoch)
136 |
137 | def train_epoch(self, epoch):
138 | self.total_loss = 0.0
139 | for batch_idx, data in enumerate(self.trainloader):
140 | output, postnet_output = self.train_batch(data, epoch, batch_idx)
141 | self.g_scheduler.step()
142 |
143 | if self.rank == 0 and self.total_step % self.train_configs['save_interval'] == 0:
144 | ckpt_path = os.path.join(self.args.exp_name, 'models', 'G_{}.pth'.format(self.total_step))
145 | save_checkpoint(
146 | self.models[0],
147 | self.g_optimizer,
148 | self.g_scheduler,
149 | self.g_scheduler.get_lr()[0],
150 | epoch,
151 | ckpt_path
152 | )
153 | length = data['mel_lens'][0]
154 | real_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', '{}_{}.png'.format(data['uttids'][0], self.total_step))
155 | before_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', 'before_{}_{}.png'.format(data['uttids'][0], self.total_step))
156 | after_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', 'after_{}_{}.png'.format(data['uttids'][0], self.total_step))
157 | melspecplot(data['mels'][0][:length, :].transpose(1, 0).numpy(), real_pic_path) # (n_mels, T)
158 | melspecplot(output[0][:length, :].transpose(1, 0).detach().cpu().numpy(), before_pic_path)
159 | melspecplot(postnet_output[0][:length, :].transpose(1, 0).detach().cpu().numpy(), after_pic_path)
160 |
161 | if self.rank == 0:
162 | clean_checkpoints(os.path.join(self.args.exp_name, 'models'), n_ckpts_to_keep = self.train_configs['ckpt_clean'])
163 |
164 | def _move_to_device(self, data):
165 | new_data = {}
166 | for k, v in data.items():
167 |             if isinstance(v, torch.Tensor):
168 | new_data[k] = v.to(self.device)
169 | return new_data
170 |
171 | def train_batch(self, data, epoch, step):
172 | for model in self.models:
173 | model.train()
174 | new_data = self._move_to_device(data)
175 |
176 | self.g_optimizer.zero_grad()
177 | loss, report_keys, output, postnet_output = self.models[0](**new_data)
178 | loss.backward()
179 | grad_norm = nn.utils.clip_grad_norm_(self.models[0].parameters(), self.train_configs['grad_clip'])
180 | if math.isnan(grad_norm):
181 | raise ZeroDivisionError('Grad norm is nan')
182 | self.g_optimizer.step()
183 | self.total_loss += loss.item()
184 |
185 | self.total_step += 1
186 | if self.rank == 0:
187 |             self.print_msg(epoch, step, report_keys)
188 | wandb_log_dict = {
189 | 'train/avg_g_loss': self.total_loss / (step + 1),
190 | 'train/g_lr': self.g_scheduler.get_lr()[0]
191 | }
192 | for k, v in report_keys.items():
193 | wandb_log_dict['train/' + k] = v
194 |             if self.train_configs['wandb']: wandb.log(wandb_log_dict)
195 | return output, postnet_output
196 |
197 | def print_msg(self, epoch, step, report_keys):
198 | if self.total_step % self.train_configs['log_interval'] == 0:
199 | temp = ''
200 | for k, v in report_keys.items():
201 | temp += '{}: {:.6f} '.format(k, v)
202 | message = ('[Epoch: {} Step: {} Total steps: {}] ' + temp).format(
203 | epoch, step + 1, self.total_step
204 | )
205 | self.logger.info(message)
206 |
207 | def parse_args():
208 | parser = argparse.ArgumentParser()
209 | parser.add_argument('--data-config', dest = 'data_config', type = str, default = './conf/data.yaml')
210 | parser.add_argument('--model-config', dest = 'model_config', type = str, default = './conf/model.yaml')
211 | parser.add_argument('--train-config', dest = 'train_config', type = str, default = './conf/train.yaml')
212 | parser.add_argument('--num-gpus', dest = 'num_gpus', type = int, default = 1)
213 | parser.add_argument('--exp-name', dest = 'exp_name', type = str, default = 'default')
214 | parser.add_argument('--dist-backend', dest = 'dist_backend', type = str, default = 'nccl')
215 | parser.add_argument('--dist-url', dest = 'dist_url', type = str, default = 'tcp://localhost:30302')
216 | return parser.parse_args()
217 |
218 | def main(rank, args, configs):
219 | if args.num_gpus > 1:
220 | torch.cuda.set_device(rank)
221 | torch.distributed.init_process_group(
222 | backend = args.dist_backend,
223 | init_method = args.dist_url,
224 | world_size = args.num_gpus,
225 | rank = rank
226 | )
227 |
228 | data_configs, model_configs, train_configs = configs
229 | args.exp_name = train_configs['wandb_args']['group'] + '-' + \
230 | train_configs['wandb_args']['job_type'] + '-' + \
231 | train_configs['wandb_args']['name']
232 | args.exp_name = os.path.join('exp', args.exp_name)
233 |
234 | # wandb initialization
235 | if train_configs['wandb']:
236 | wandb_configs = vars(args)
237 | for config in configs:
238 | wandb_configs.update(config)
239 | wandb.init(
240 | **train_configs['wandb_args'],
241 | config = wandb_configs
242 | )
243 |
244 | trainer = Trainer(rank, args, data_configs, model_configs, train_configs)
245 | trainer.train()
246 |
247 | if train_configs['wandb']:
248 | wandb.finish()
249 |
250 | if __name__ == "__main__":
251 | args = parse_args()
252 | args.exp_name = os.path.join('exp', args.exp_name)
253 | with open(args.data_config, 'r') as f:
254 | data_configs = yaml.load(f, Loader = yaml.FullLoader)
255 | with open(args.model_config, 'r') as f:
256 | model_configs = yaml.load(f, Loader = yaml.FullLoader)
257 | with open(args.train_config, 'r') as f:
258 | train_configs = yaml.load(f, Loader = yaml.FullLoader)
259 | configs = (data_configs, model_configs, train_configs)
260 |
261 | num_gpus = torch.cuda.device_count()
262 | if args.num_gpus > 1:
263 |         mp.spawn(main, nprocs = args.num_gpus, args = (args, configs))
264 | else:
265 | main(0, args, configs)
266 |
--------------------------------------------------------------------------------
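train.py resolves the optimizer and scheduler by name: train_configs['g_optimizer'] must be a class name in torch.optim and train_configs['g_scheduler'] a class name in pyutils.scheduler, each instantiated with the matching *_args mapping. A hedged sketch of the dictionary shape yaml.load is expected to return for the training config; only the keys come from the code above, the values are illustrative assumptions rather than the repository's actual configs/svs/train.yaml:

# illustrative shape only -- values are assumptions, keys are the ones Trainer reads
train_configs = {
    'batch_size': 16,
    'num_workers': 4,
    'epochs': 200,
    'grad_clip': 1.0,
    'log_interval': 100,
    'save_interval': 1000,
    'ckpt_clean': 5,
    'g_optimizer': 'Adam',                      # any class name in torch.optim
    'g_optimizer_args': {'lr': 1e-3, 'betas': [0.9, 0.98], 'eps': 1e-9},
    'g_scheduler': 'WarmupLR',                  # any class name in pyutils.scheduler
    'g_scheduler_args': {'warmup_steps': 25000},
    'wandb': False,
    'wandb_args': {'group': 'svs', 'job_type': 'fs2', 'name': 'baseline'},
}

The experiment directory is then derived as exp/<group>-<job_type>-<name>, so wandb_args is required even when wandb logging itself is disabled.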
/train_gan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import argparse
4 | import math
5 | import shutil
6 | import logging
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | import torch.multiprocessing as mp
12 |
13 | from dataset import SVSDataset as Dataset
14 | from dataset import SVSCollate as Collate
15 | from models import Xiaoice2 as Generator
16 | from models import Discriminator as Discriminator
17 | from loss import FastSpeech2Loss, FeatLoss, LSGANGLoss, LSGANDLoss
18 | import pyutils
19 | from pyutils import (
20 | load_checkpoint,
21 | save_checkpoint,
22 | clean_checkpoints,
23 | latest_checkpoint_path,
24 | melspecplot,
25 | get_logger
26 | )
27 |
28 | import wandb
29 |
30 | logging.basicConfig(format = "%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s: %(message)s", level = logging.INFO)
31 |
32 | python_script = os.path.realpath(__file__)
33 |
34 | class Trainer():
35 | def __init__(self, rank, args, data_configs, model_configs, train_configs):
36 | self.rank = rank
37 | self.device = torch.device('cuda:{:d}'.format(rank))
38 | trainset = Dataset(data_configs)
39 | collate_fn = Collate()
40 | sampler = torch.utils.data.DistributedSampler(trainset) if args.num_gpus > 1 else None
41 | self.trainloader = DataLoader(
42 | trainset,
43 |             shuffle = (sampler is None),
44 | sampler = sampler,
45 | collate_fn = collate_fn,
46 | batch_size = train_configs['batch_size'],
47 | pin_memory = True,
48 | num_workers = train_configs['num_workers'],
49 | prefetch_factor = 10
50 | )
51 |
52 | model_configs['generator']['transformer']['encoder']['n_src_vocab'] = trainset.get_phone_number() + 1
53 | model_configs['generator']['spk_num'] = trainset.get_spk_number()
54 |
55 | if args.num_gpus > 1:
56 | self.models = (
57 | nn.parallel.DistributedDataParallel(
58 | Generator(
59 | data_configs,
60 | model_configs['generator']
61 | ).to(self.device),
62 | device_ids = [rank]
63 | ),
64 | nn.parallel.DistributedDataParallel(
65 | Discriminator().to(self.device),
66 | device_ids = [rank]
67 | )
68 | )
69 | else:
70 | self.models = (
71 | Generator(
72 | data_configs,
73 | model_configs['generator']
74 | ).to(self.device),
75 | Discriminator().to(self.device)
76 | )
77 | self.data_configs = data_configs
78 | self.model_configs = model_configs
79 | self.train_configs = train_configs
80 | self.args = args
81 |
82 | try:
83 | self.g_optimizer = getattr(
84 | torch.optim, train_configs['g_optimizer']
85 | )(self.models[0].parameters(), **train_configs['g_optimizer_args'])
86 |
87 | self.g_scheduler = getattr(
88 | pyutils.scheduler, train_configs['g_scheduler']
89 | )(self.g_optimizer, **train_configs['g_scheduler_args'])
90 |
91 | self.d_optimizer = getattr(
92 | torch.optim, train_configs['d_optimizer']
93 | )(self.models[1].parameters(), **train_configs['d_optimizer_args'])
94 |
95 | self.d_scheduler = getattr(
96 | pyutils.scheduler, train_configs['d_scheduler']
97 | )(self.d_optimizer, **train_configs['d_scheduler_args'])
98 |         except (AttributeError, TypeError) as e:
99 |             raise NotImplementedError("Unknown optimizer or scheduler") from e
100 |
101 | self.fs2loss = FastSpeech2Loss(data_configs)
102 | self.feat_loss = FeatLoss(train_configs['feat_loss_weight'])
103 | self.adv_g_loss = LSGANGLoss(train_configs['adv_g_loss_weight'])
104 | self.adv_d_loss = LSGANDLoss()
105 |
106 | if self.rank == 0:
107 | self._make_exp_dir()
108 | self.logger = get_logger(os.path.join(self.args.exp_name, 'logs/train.log'))
109 |
110 | try:
111 | latest_gckpt_path = latest_checkpoint_path(
112 | os.path.join(self.args.exp_name, 'models'),
113 | 'G_*.pth'
114 | )
115 | latest_dckpt_path = latest_checkpoint_path(
116 | os.path.join(self.args.exp_name, 'models'),
117 | 'D_*.pth'
118 | )
119 | _, _, _, _, epoch_str = load_checkpoint(
120 | latest_gckpt_path,
121 | self.models[0],
122 | self.g_optimizer,
123 | self.g_scheduler,
124 | False
125 | )
126 | _, _, _, _, epoch_str = load_checkpoint(
127 | latest_dckpt_path,
128 | self.models[1],
129 | self.d_optimizer,
130 | self.d_scheduler,
131 | False
132 | )
133 | self.start_epoch = max(epoch_str, 1)
134 | name = latest_gckpt_path
135 | self.total_step = int(name[name.rfind("_")+1:name.rfind(".")])+1
136 | except Exception:
137 | print("Load old checkpoint failed...")
138 | print("Start a new training...")
139 | self.start_epoch = 1
140 | self.total_step = 0
141 |
142 | self.epochs = self.train_configs['epochs']
143 | self.start_disc_steps = self.train_configs['start_disc_steps']
144 |
145 | def _dump_args_and_config(self, filename, config):
146 | with open(os.path.join(self.args.exp_name, 'conf', filename) + '.yaml', 'w') as f:
147 | yaml.dump(config, f)
148 |
149 | def _make_exp_dir(self):
150 | os.makedirs(self.args.exp_name, exist_ok=True)
151 | os.makedirs(os.path.join(self.args.exp_name, 'conf'), exist_ok=True)
152 | os.makedirs(os.path.join(self.args.exp_name, 'models'), exist_ok=True)
153 | os.makedirs(os.path.join(self.args.exp_name, 'audios'), exist_ok=True)
154 | os.makedirs(os.path.join(self.args.exp_name, 'spectrograms'), exist_ok=True)
155 | os.makedirs(os.path.join(self.args.exp_name, 'melspectrograms'), exist_ok=True)
156 | os.makedirs(os.path.join(self.args.exp_name, 'eval_results'), exist_ok=True)
157 | os.makedirs(os.path.join(self.args.exp_name, 'logs'), exist_ok = True)
158 | with open(os.path.join(self.args.exp_name, 'model_arch.txt'), 'w') as f:
159 | for model in self.models:
160 | print(model, file = f)
161 | self._dump_args_and_config('args', vars(self.args))
162 | self._dump_args_and_config('data', self.data_configs)
163 | self._dump_args_and_config('model', self.model_configs)
164 | self._dump_args_and_config('train', self.train_configs)
165 | basename = os.path.basename(python_script)
166 | shutil.copyfile(python_script, os.path.join(self.args.exp_name, basename))
167 |
168 | def train(self):
169 | for epoch in range(self.start_epoch, self.epochs + 1, 1):
170 | self.train_epoch(epoch)
171 |
172 | def train_epoch(self, epoch):
173 | self.total_g_loss = 0.0
174 | self.total_d_loss = 0.0
175 | for batch_idx, data in enumerate(self.trainloader):
176 | output, postnet_output = self.train_batch(data, epoch, batch_idx)
177 | self.g_scheduler.step()
178 | self.d_scheduler.step()
179 |
180 | if self.rank == 0 and self.total_step % self.train_configs['save_interval'] == 0:
181 | gckpt_path = os.path.join(self.args.exp_name, 'models', 'G_{}.pth'.format(self.total_step))
182 | save_checkpoint(
183 | self.models[0],
184 | self.g_optimizer,
185 | self.g_scheduler,
186 | self.g_scheduler.get_lr()[0],
187 | epoch,
188 | gckpt_path
189 | )
190 | dckpt_path = os.path.join(self.args.exp_name, 'models', 'D_{}.pth'.format(self.total_step))
191 | save_checkpoint(
192 | self.models[1],
193 | self.d_optimizer,
194 | self.d_scheduler,
195 | self.d_scheduler.get_lr()[0],
196 | epoch,
197 | dckpt_path
198 | )
199 | length = data['mel_lens'][0]
200 | real_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', '{}_{}.png'.format(data['uttids'][0], self.total_step))
201 | before_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', 'before_{}_{}.png'.format(data['uttids'][0], self.total_step))
202 | after_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', 'after_{}_{}.png'.format(data['uttids'][0], self.total_step))
203 | melspecplot(data['mels'][0][:length, :].transpose(1, 0).numpy(), real_pic_path) # (n_mels, T)
204 | melspecplot(output[0][:length, :].transpose(1, 0).detach().cpu().numpy(), before_pic_path)
205 | melspecplot(postnet_output[0][:length, :].transpose(1, 0).detach().cpu().numpy(), after_pic_path)
206 |
207 | if self.rank == 0:
208 | clean_checkpoints(
209 | os.path.join(self.args.exp_name, 'models'),
210 | n_ckpts_to_keep = self.train_configs['ckpt_clean']
211 | )
212 |
213 | def _move_to_device(self, data):
214 | new_data = {}
215 | for k, v in data.items():
216 |             if isinstance(v, torch.Tensor):
217 | new_data[k] = v.to(self.device)
218 | return new_data
219 |
220 | def train_batch(self, data, epoch, step):
221 | for model in self.models:
222 | model.train()
223 | new_data = self._move_to_device(data)
224 |
225 |         # generator forward: FastSpeech2 loss, logging dict, and mel outputs (before/after postnet)
226 | self.g_optimizer.zero_grad()
227 | g_loss, report_keys, output, postnet_output = self.models[0](**new_data)
228 | if self.total_step >= self.train_configs['start_disc_steps']:
229 | d_fake, random_N = self.models[1](output, new_data['mel_lens'], new_data['mels'])
230 | feat_loss, feat_loss_report_keys = self.feat_loss(d_fake)
231 | adv_g_loss, adv_gloss_report_keys = self.adv_g_loss(d_fake)
232 | g_loss += feat_loss
233 | g_loss += adv_g_loss
234 | report_keys.update(feat_loss_report_keys)
235 | report_keys.update(adv_gloss_report_keys)
236 | g_loss.backward()
237 |
238 | grad_norm = nn.utils.clip_grad_norm_(self.models[0].parameters(), self.train_configs['grad_clip'])
239 | if math.isnan(grad_norm):
240 | raise ZeroDivisionError('Grad norm is nan')
241 | self.g_optimizer.step()
242 | self.total_g_loss += g_loss.item()
243 |
244 | if self.total_step >= self.train_configs['start_disc_steps']:
245 | self.d_optimizer.zero_grad()
246 | d_fake, _ = self.models[1](
247 | output.detach(),
248 | new_data['mel_lens'],
249 | new_data['mels'],
250 | random_N
251 | )
252 | adv_d_loss, adv_dloss_report_keys = self.adv_d_loss(d_fake)
253 | adv_d_loss.backward()
254 | report_keys.update(adv_dloss_report_keys)
255 | grad_norm = nn.utils.clip_grad_norm_(self.models[1].parameters(), self.train_configs['grad_clip'])
256 | if math.isnan(grad_norm):
257 | raise ZeroDivisionError('Grad norm is nan')
258 | self.d_optimizer.step()
259 | self.total_d_loss += adv_d_loss.item()
260 |
261 | self.total_step += 1
262 | if self.rank == 0:
263 |             self.print_msg(epoch, step, report_keys)
264 | wandb_log_dict = {
265 | 'train/avg_g_loss': self.total_g_loss / (step + 1),
266 | 'train/avg_d_loss': self.total_d_loss / (step + 1),
267 | 'train/g_lr': self.g_scheduler.get_lr()[0],
268 | 'train/d_lr': self.d_scheduler.get_lr()[0]
269 | }
270 | for k, v in report_keys.items():
271 | wandb_log_dict['train/' + k] = v
272 | if self.train_configs['wandb']:
273 | wandb.log(wandb_log_dict)
274 | return output, postnet_output
275 |
276 | def print_msg(self, epoch, step, report_keys):
277 | if self.total_step % self.train_configs['log_interval'] == 0:
278 | temp = ''
279 | for k, v in report_keys.items():
280 | temp += '{}: {:.6f} '.format(k, v)
281 | message = ('[Epoch: {} Step: {} Total steps: {}] ' + temp).format(
282 | epoch, step + 1, self.total_step
283 | )
284 | self.logger.info(message)
285 |
286 | def parse_args():
287 | parser = argparse.ArgumentParser()
288 | parser.add_argument('--data-config', dest = 'data_config', type = str, default = './conf/data.yaml')
289 | parser.add_argument('--model-config', dest = 'model_config', type = str, default = './conf/model.yaml')
290 | parser.add_argument('--train-config', dest = 'train_config', type = str, default = './conf/train.yaml')
291 | parser.add_argument('--num-gpus', dest = 'num_gpus', type = int, default = 1)
292 | # parser.add_argument('--exp-name', dest = 'exp_name', type = str, default = 'default')
293 | parser.add_argument('--dist-backend', dest = 'dist_backend', type = str, default = 'nccl')
294 | parser.add_argument('--dist-url', dest = 'dist_url', type = str, default = 'tcp://localhost:30302')
295 | return parser.parse_args()
296 |
297 | def main(rank, args, configs):
298 | if args.num_gpus > 1:
299 | torch.cuda.set_device(rank)
300 | torch.distributed.init_process_group(
301 | backend = args.dist_backend,
302 | init_method = args.dist_url,
303 | world_size = args.num_gpus,
304 | rank = rank
305 | )
306 |
307 | data_configs, model_configs, train_configs = configs
308 | args.exp_name = train_configs['wandb_args']['group'] + '-' + \
309 | train_configs['wandb_args']['job_type'] + '-' + \
310 | train_configs['wandb_args']['name']
311 | args.exp_name = os.path.join('exp', args.exp_name)
312 |
313 | # wandb initialization
314 | if train_configs['wandb']:
315 | wandb_configs = vars(args)
316 | for config in configs:
317 | wandb_configs.update(config)
318 | wandb.init(
319 | **train_configs['wandb_args'],
320 | config = wandb_configs
321 | )
322 |
323 | trainer = Trainer(rank, args, data_configs, model_configs, train_configs)
324 | trainer.train()
325 |
326 | if train_configs['wandb']:
327 | wandb.finish()
328 |
329 | if __name__ == "__main__":
330 | args = parse_args()
331 | with open(args.data_config, 'r') as f:
332 | data_configs = yaml.load(f, Loader = yaml.FullLoader)
333 | with open(args.model_config, 'r') as f:
334 | model_configs = yaml.load(f, Loader = yaml.FullLoader)
335 | with open(args.train_config, 'r') as f:
336 | train_configs = yaml.load(f, Loader = yaml.FullLoader)
337 | configs = (data_configs, model_configs, train_configs)
338 |
339 | num_gpus = torch.cuda.device_count()
340 | if args.num_gpus > 1:
341 |         mp.spawn(main, nprocs = args.num_gpus, args = (args, configs))
342 | else:
343 | main(0, args, configs)
344 |
--------------------------------------------------------------------------------
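In train_gan.py above, once total_step reaches start_disc_steps each batch performs two updates: the generator step adds the feature-matching and LSGAN generator losses (computed from the discriminator's output on the generated mels) to the FastSpeech2 loss, and the discriminator step re-runs the discriminator on output.detach() with the same random_N window so that no gradients reach the generator. A minimal, self-contained sketch of that LSGAN-style alternation with stand-in linear modules (the repository's Discriminator, FeatLoss, LSGANGLoss and LSGANDLoss live in models/ and loss/ and have different signatures):

import torch
import torch.nn as nn
import torch.nn.functional as F

# stand-ins for the repository's Generator/Discriminator; shapes are assumptions
G = nn.Linear(8, 8)
D = nn.Linear(8, 1)
g_opt = torch.optim.Adam(G.parameters(), lr=1e-4)
d_opt = torch.optim.Adam(D.parameters(), lr=1e-4)

x = torch.randn(4, 8)       # conditioning features (stand-in)
real = torch.randn(4, 8)    # ground-truth mels (stand-in)

# 1) generator update: reconstruction loss + LSGAN generator loss, D(fake) pushed towards 1
#    (the feature-matching term used by FeatLoss is omitted for brevity)
g_opt.zero_grad()
fake = G(x)
g_loss = F.l1_loss(fake, real) + ((D(fake) - 1) ** 2).mean()
g_loss.backward()
g_opt.step()

# 2) discriminator update on detached fakes: D(real) -> 1, D(fake) -> 0
d_opt.zero_grad()
d_loss = ((D(real) - 1) ** 2).mean() + (D(fake.detach()) ** 2).mean()
d_loss.backward()
d_opt.step()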
/utils:
--------------------------------------------------------------------------------
1 | /home/smg/zengchang/apps/kaldi/egs/wsj/s5/utils
--------------------------------------------------------------------------------