├── .gitignore ├── LICENSE ├── README.md ├── configs ├── svs │ ├── data.yaml │ ├── model.yaml │ └── train.yaml └── tts │ ├── data.yaml │ ├── model.yaml │ └── train.yaml ├── dataset ├── __init__.py ├── dataset_svs.py ├── dataset_tts.py ├── espnet_texts │ ├── __init__.py │ ├── cleaners.py │ ├── cmudict.py │ ├── dict.py │ ├── numbers.py │ └── symbols.py └── texts │ ├── __init__.py │ ├── cleaners.py │ ├── cmudict.py │ ├── numbers.py │ ├── pinyin.py │ └── symbols.py ├── lexicon ├── libritts-lexicon.txt └── pinyin-lexicon.txt ├── loss ├── __init__.py ├── fastspeech2_loss.py └── loss.py ├── models ├── __init__.py ├── discriminator.py ├── fastspeech2.py └── xiaoice2.py ├── modules ├── __init__.py ├── conv │ └── __init__.py ├── transformer │ ├── Constants.py │ ├── Layers.py │ ├── Models.py │ ├── Modules.py │ ├── SubLayers.py │ └── __init__.py └── variance │ ├── __init__.py │ └── modules.py ├── pics ├── 2085003136_145600.png ├── after_2085003136_145600.png ├── before_2085003136_145600.png ├── before_mel_l2_loss.png ├── post_mel_l2_loss.png └── xs1_before_2085003136_145600.png ├── preprocess ├── audio_preprocess.py └── data_prep.py ├── pyutils ├── __init__.py ├── gen_duration_from_tg.py ├── logger.py ├── mask.py ├── optimizer.py ├── parse_options.sh ├── plot.py ├── save_and_load.py └── scheduler.py ├── run.sh ├── train.py ├── train_gan.py └── utils /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | # 162 | job.sh 163 | data 164 | data/* 165 | exp 166 | exp/* 167 | *.out 168 | wandb 169 | wandb/* 170 | .nfs* 171 | local_run.sh 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, zengchang233 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [XiaoiceSing2](https://www.isca-speech.org/archive/interspeech_2023/chunhui23_interspeech.html) 2 | The source code for the paper [XiaoiceSing2](https://www.isca-speech.org/archive/interspeech_2023/chunhui23_interspeech.html) (Interspeech 2023). 3 | 4 | [Demo page](https://wavelandspeech.github.io/xiaoice2/) 5 | 6 | ## Notice 7 | 8 | I am currently busy with job hunting. I will update the remaining modules, including the [HiFi-WaveGAN](https://arxiv.org/abs/2210.12740) vocoder, after I make my final decision. 9 | 10 | ## Implementation (in development) 11 | 12 | - [x] FastSpeech2-based generator 13 | - [x] discriminator group, including segment discriminators and detail discriminators 14 | - [ ] ConvFFT block 15 | 16 | ## Dataset and preparation 17 | 18 | - [x] opencpop ![cn](https://raw.githubusercontent.com/gosquared/flags/master/flags/flags/shiny/24/China.png) 19 | - [ ] kiritan ![jp](https://raw.githubusercontent.com/gosquared/flags/master/flags/flags/shiny/24/Japan.png) 20 | - [ ] CSD ![kr](https://raw.githubusercontent.com/gosquared/flags/master/flags/flags/shiny/24/South-Korea.png) 21 | - [ ] m4singer ![cn](https://raw.githubusercontent.com/gosquared/flags/master/flags/flags/shiny/24/China.png) 22 | - [ ] NUS48E 23 | 24 | Kaldi-style data preparation requires the following files: 25 | 26 | - wav.scp 27 | - utt2spk 28 | - spk2utt 29 | - text 30 | 31 | ``` 32 | ./run.sh --start-stage 1 --stop-stage 1 # extract mel-spectrogram, F0, energy, and their statistics 33 | ``` 34 | 35 | ## Training 36 | 37 | ``` 38 | ./run.sh --start-stage 2 --stop-stage 2 39 | ``` 40 | 41 | ### Real and generated mel-spectrograms (145600 training steps) 42 | 43 | Real (left), XiaoiceSing (middle), XiaoiceSing2 (right) 44 | 45 |
46 | real 47 | xs1 48 | xs2 49 |
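The comparison above is rendered from the cached `.mel.npy` features. Below is a minimal plotting sketch for reproducing such a side-by-side figure; it assumes mel-spectrograms are stored as `[frames, n_mels]` float arrays (as they are loaded in `dataset/dataset_svs.py`), uses plain matplotlib rather than the repo's `pyutils/plot.py` helpers, and the file paths in the usage comment are placeholders.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_mel_comparison(paths, titles, out_png):
    """Plot several [frames, n_mels] mel-spectrograms side by side."""
    fig, axes = plt.subplots(1, len(paths), figsize=(4 * len(paths), 4), sharey=True)
    axes = np.atleast_1d(axes)
    for ax, path, title in zip(axes, paths, titles):
        mel = np.load(path)  # [frames, n_mels]
        ax.imshow(mel.T, origin="lower", aspect="auto", interpolation="none")
        ax.set_title(title)
        ax.set_xlabel("frame")
    axes[0].set_ylabel("mel bin")
    fig.tight_layout()
    fig.savefig(out_png)

# Usage (placeholder paths):
# plot_mel_comparison(
#     ["real.mel.npy", "xs1.mel.npy", "xs2.mel.npy"],
#     ["Real", "XiaoiceSing", "XiaoiceSing2"],
#     "mel_comparison.png",
# )
```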
50 | 51 | ### L2 loss curves for the mel-spectrogram 52 | 53 | L2 loss before post-processing (left), L2 loss after post-processing (right) 54 | 55 |
56 | before 57 | after 58 |
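For context on the two curves, here is a minimal sketch of how the two mel L2 terms are typically computed in a FastSpeech2-style generator with a postnet: one on the decoder output ("before post-processing") and one on the postnet output ("after post-processing"), with padded frames masked out. This is an illustration only, not the exact code in `loss/fastspeech2_loss.py`, and the tensor names are assumptions.

```python
import torch

def mel_l2_terms(mel_before, mel_after, mel_target, mel_lens):
    """mel_before / mel_after / mel_target: [batch, frames, n_mels]; mel_lens: [batch]."""
    max_len = mel_target.size(1)
    # [batch, frames, 1] mask of valid (non-padded) frames
    mask = (torch.arange(max_len, device=mel_target.device)[None, :]
            < mel_lens[:, None]).unsqueeze(-1).float()

    def masked_mse(pred, target):
        # mean squared error over valid frames and mel bins only
        return ((pred - target) ** 2 * mask).sum() / (mask.sum() * target.size(-1))

    before_l2 = masked_mse(mel_before, mel_target)  # "before post-processing" curve
    after_l2 = masked_mse(mel_after, mel_target)    # "after post-processing" curve
    return before_l2, after_l2
```

Tracking both terms separately is common because the postnet only adds a residual refinement on top of the decoder output, so the two curves show how much that refinement contributes.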
59 | 60 | ## Inference 61 | 62 | ``` 63 | ./run.sh --start-stage 3 --stop-stage 3 64 | ``` 65 | -------------------------------------------------------------------------------- /configs/svs/data.yaml: -------------------------------------------------------------------------------- 1 | audio_manifest: 'data/opencpop/train.scp' 2 | svs_manifest: 'data/opencpop/train.txt' 3 | spk_manifest: 'data/opencpop/utt2spk' 4 | f0_min_max: 'data/opencpop/f0_min_max.npy' 5 | f0_mean: 'data/opencpop/f0_mean.npy' 6 | f0_std: 'data/opencpop/f0_std.npy' 7 | energy_min_max: 'data/opencpop/energy_min_max.npy' 8 | energy_mean: 'data/opencpop/energy_mean.npy' 9 | energy_std: 'data/opencpop/energy_std.npy' 10 | phone_set: 'data/opencpop/phone_set.txt' 11 | 12 | n_fft: 1024 13 | n_mels: 120 14 | hop_length: 256 15 | win_length: 1024 16 | sampling_rate: 44100 17 | seg_size: 700 18 | fmin: 0.0 19 | fmax: 22050 20 | 21 | tts_cleaner_names: [] 22 | use_phonemes: True 23 | eos: False 24 | 25 | pitch: 26 | feature: "frame_level" # support 'phoneme_level' or 'frame_level' 27 | normalization: True 28 | energy: 29 | feature: "frame_level" # support 'phoneme_level' or 'frame_level' 30 | normalization: True 31 | -------------------------------------------------------------------------------- /configs/svs/model.yaml: -------------------------------------------------------------------------------- 1 | generator: 2 | transformer: 3 | encoder: 4 | max_seq_len: 5000 5 | n_src_vocab: 100 # random number, will be reassigned in train.py 6 | d_word_vec: 512 # dimension of word vector 7 | n_layers: 6 8 | n_head: 8 9 | d_model: 512 10 | d_inner: 2048 11 | kernel_size: [9, 1] 12 | dropout: 0.2 13 | max_note_pitch: 88 14 | max_note_duration: 2000 15 | decoder: 16 | max_seq_len: 5000 17 | d_word_vec: 512 18 | n_layers: 6 19 | n_head: 8 20 | d_model: 512 21 | d_inner: 2048 22 | kernel_size: [9, 1] 23 | dropout: 0.2 24 | 25 | variance_predictor: 26 | input_size: 512 27 | filter_size: 512 28 | kernel_size: 3 29 | dropout: 0.5 30 | 31 | variance_embedding: 32 | pitch_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 33 | energy_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 34 | n_bins: 256 35 | 36 | multi_speaker: False 37 | uv_threshold: 0.5 38 | 39 | postnet: 40 | postnet_embedding_dim: 512 41 | postnet_kernel_size: 5 42 | postnet_n_convolutions: 5 43 | 44 | discriminator: 45 | segment_disc: 46 | pass 47 | 48 | detail_disc: 49 | pass 50 | 51 | vocoder: 52 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN' 53 | speaker: "universal" # support 'LJSpeech', 'universal' 54 | -------------------------------------------------------------------------------- /configs/svs/train.yaml: -------------------------------------------------------------------------------- 1 | epochs: 800 2 | batch_size: 8 3 | grad_clip: 1.0 4 | num_workers: 8 5 | 6 | feat_loss_weight: [1.0, 1.0, 1.0] 7 | adv_g_loss_weight: [0.1, 0.1, 0.1] 8 | start_disc_steps: 5000 9 | 10 | g_optimizer: 'Adam' 11 | g_optimizer_args: 12 | lr: 0.0001 13 | betas: [0.9, 0.98] 14 | eps: 0.000000001 15 | weight_decay: 0.0 16 | 17 | g_scheduler: 'WarmupLR' 18 | g_scheduler_args: 19 | warmup_steps: 4000 20 | last_epoch: -1 21 | 22 | d_optimizer: 'Adam' 23 | d_optimizer_args: 24 | lr: 0.0001 25 | betas: [0.9, 0.98] 26 | eps: 0.000000001 27 | weight_decay: 0.0 28 | 29 | d_scheduler: 'WarmupLR' 30 | d_scheduler_args: 31 | warmup_steps: 4000 
32 | last_epoch: -1 33 | 34 | wandb: True 35 | wandb_args: 36 | project: 'svs' 37 | group: 'xiaoicesing2' 38 | job_type: 'opencpop' 39 | name: 'warmup4k-disc5k' 40 | 41 | log_interval: 10 42 | save_interval: 200 43 | ckpt_clean: 10 44 | 45 | -------------------------------------------------------------------------------- /configs/tts/data.yaml: -------------------------------------------------------------------------------- 1 | audio_manifest: 'data/wav.scp' 2 | duration_manifest: 'data/aishell3/train/duration.txt' 3 | raw_text_manifest: 'data/aishell3/train/raw_text' 4 | spk_manifest: 'data/aishell3/train/utt2spk' 5 | f0_min_max: 'data/f0_min_max.npy' 6 | energy_min_max: 'data/energy_min_max.npy' 7 | 8 | n_fft: 1024 9 | n_mels: 120 10 | hop_length: 256 11 | win_length: 1024 12 | sampling_rate: 44100 13 | seg_size: 700 14 | fmin: 0.0 15 | fmax: 22050 16 | 17 | tts_cleaner_names: [] 18 | use_phonemes: True 19 | eos: False 20 | 21 | pitch: 22 | feature: "frame_level" # support 'phoneme_level' or 'frame_level' 23 | normalization: True 24 | energy: 25 | feature: "frame_level" # support 'phoneme_level' or 'frame_level' 26 | normalization: True 27 | -------------------------------------------------------------------------------- /configs/tts/model.yaml: -------------------------------------------------------------------------------- 1 | generator: 2 | transformer: 3 | encoder: 4 | max_seq_len: 5000 5 | n_src_vocab: 100 # random number, will be reassigned in train.py 6 | d_word_vec: 512 # dimension of word vector 7 | n_layers: 6 8 | n_head: 8 9 | d_model: 512 10 | d_inner: 2048 11 | kernel_size: [9, 1] 12 | dropout: 0.2 13 | decoder: 14 | max_seq_len: 5000 15 | d_word_vec: 512 16 | n_layers: 6 17 | n_head: 8 18 | d_model: 512 19 | d_inner: 2048 20 | kernel_size: [9, 1] 21 | dropout: 0.2 22 | 23 | variance_predictor: 24 | input_size: 512 25 | filter_size: 512 26 | kernel_size: 3 27 | dropout: 0.5 28 | 29 | variance_embedding: 30 | pitch_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 31 | energy_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 32 | n_bins: 256 33 | 34 | multi_speaker: False 35 | uv_threshold: 0.5 36 | 37 | postnet: 38 | postnet_embedding_dim: 512 39 | postnet_kernel_size: 5 40 | postnet_n_convolutions: 5 41 | 42 | discriminator: 43 | segment_disc: 44 | pass 45 | 46 | detail_disc: 47 | pass 48 | 49 | vocoder: 50 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN' 51 | speaker: "universal" # support 'LJSpeech', 'universal' 52 | -------------------------------------------------------------------------------- /configs/tts/train.yaml: -------------------------------------------------------------------------------- 1 | epochs: 800 2 | batch_size: 8 3 | grad_clip: 1.0 4 | num_workers: 8 5 | 6 | feat_loss_weight: [1.0, 1.0, 1.0] 7 | adv_g_loss_weight: [0.1, 0.1, 0.1] 8 | start_disc_steps: 5000 9 | 10 | g_optimizer: 'Adam' 11 | g_optimizer_args: 12 | lr: 0.0001 13 | betas: [0.9, 0.98] 14 | eps: 0.000000001 15 | weight_decay: 0.0 16 | 17 | g_scheduler: 'WarmupLR' 18 | g_scheduler_args: 19 | warmup_steps: 4000 20 | last_epoch: -1 21 | 22 | d_optimizer: 'Adam' 23 | d_optimizer_args: 24 | lr: 0.0001 25 | betas: [0.9, 0.98] 26 | eps: 0.000000001 27 | weight_decay: 0.0 28 | 29 | d_scheduler: 'WarmupLR' 30 | d_scheduler_args: 31 | warmup_steps: 4000 32 | last_epoch: -1 33 | 34 | wandb: True 35 | wandb_args: 36 | project: 'tts' 37 | 
group: 'cross-lingual' 38 | job_type: 'fs2_GAN' 39 | name: 'fs2GAN_aishell3-warmup4k-disc5k' 40 | 41 | log_interval: 10 42 | save_interval: 200 43 | ckpt_clean: 10 44 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_tts import * 2 | from .dataset_svs import * 3 | from .texts import * 4 | -------------------------------------------------------------------------------- /dataset/dataset_svs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torch.nn.functional as F 4 | 5 | import sys 6 | sys.path.insert(0, '/home/smg/zengchang/code/svs/xiaoicesing2') 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import numpy as np 11 | from pyutils import pad_list, remove_outlier 12 | from librosa import note_to_midi 13 | 14 | import json 15 | import os 16 | 17 | def f02pitch(f0): 18 | #f0 =f0 + 0.01 19 | return np.log2(f0 / 27.5) * 12 + 21 20 | 21 | def pitch2f0(pitch): 22 | f0 = np.exp2((pitch - 21 ) / 12) * 27.5 23 | for i in range(len(f0)): 24 | if f0[i] <= 10: 25 | f0[i] = 0 26 | return f0 27 | 28 | def pitchxuv(pitch, uv, to_f0 = False): 29 | result = pitch * uv 30 | if to_f0: 31 | result = pitch2f0(result) 32 | return result 33 | 34 | def pad1d(x, max_len): 35 | return np.pad(x, (0, max_len - len(x)), mode="constant") 36 | 37 | def pad2d(x, max_len): 38 | return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant") 39 | 40 | def interpolate_f0(f0): 41 | data = np.reshape(f0, (f0.size, 1)) 42 | 43 | vuv_vector = np.zeros((data.size, 1), dtype=np.float32) 44 | vuv_vector[data > 0.0] = 1.0 45 | vuv_vector[data <= 0.0] = 0.0 46 | 47 | ip_data = data 48 | 49 | frame_number = data.size 50 | last_value = 0.0 51 | for i in range(frame_number): 52 | if data[i] <= 0.0: 53 | j = i + 1 54 | for j in range(i + 1, frame_number): 55 | if data[j] > 0.0: 56 | break 57 | if j < frame_number - 1: 58 | if last_value > 0.0: 59 | step = (data[j] - data[i - 1]) / float(j - i) 60 | for k in range(i, j): 61 | ip_data[k] = data[i - 1] + step * (k - i + 1) 62 | else: 63 | for k in range(i, j): 64 | ip_data[k] = data[j] 65 | else: 66 | for k in range(i, frame_number): 67 | ip_data[k] = last_value 68 | else: 69 | ip_data[i] = data[i] # this may not be necessary 70 | last_value = data[i] 71 | 72 | return ip_data[:,0], vuv_vector[:,0] 73 | 74 | class SVSDataset(Dataset): 75 | def __init__(self, configs): 76 | audio_manifest = configs['audio_manifest'] 77 | transcription_manifest = configs['svs_manifest'] 78 | spk_manifest = configs['spk_manifest'] 79 | self.sampling_rate = configs['sampling_rate'] 80 | self.utt2path = {} 81 | self.utt2raw_text = {} 82 | self.utt2phone_seq = {} 83 | self.utt2note_pitch = {} 84 | self.utt2note_dur = {} 85 | self.utt2dur = {} 86 | self.utt2spk = {} 87 | hop_length = configs['hop_length'] / self.sampling_rate 88 | with open(audio_manifest, 'r') as f: 89 | for line in f: 90 | line = line.rstrip().split(' ') 91 | self.utt2path[line[0]] = line[1] 92 | with open(transcription_manifest, 'r') as f: 93 | for line in f: 94 | line = line.rstrip().split('|') 95 | self.utt2raw_text[line[0]] = line[1] 96 | self.utt2phone_seq[line[0]] = line[2].split(' ') 97 | self.utt2note_pitch[line[0]] = [note_to_midi(note.split('/')[0]) if note != 'rest' else 0 for note in line[3].split(' ')] 98 | self.utt2note_dur[line[0]] = 
[round(eval(dur) / hop_length) for dur in line[4].split(' ')] 99 | self.utt2dur[line[0]] = [round(eval(dur) / hop_length) for dur in line[5].split(' ')] 100 | with open(spk_manifest, 'r') as f: 101 | for line in f: 102 | line = line.rstrip().split(' ') 103 | self.utt2spk[line[0]] = line[1] 104 | if not os.path.exists(configs['phone_set']): 105 | phone_set = set() 106 | for phone_seq in self.utt2phone_seq.values(): 107 | phone_set.update(phone_seq) 108 | phone_set = list(phone_set) 109 | phone_set.sort() 110 | with open(configs['phone_set'], 'w') as f: 111 | json.dump(phone_set, f) 112 | self.phone_set = phone_set 113 | else: 114 | with open(configs['phone_set'], 'r') as f: 115 | self.phone_set = json.load(f) 116 | 117 | self.spk2int = {spk: idx for idx, spk in enumerate(set(self.utt2spk.values()))} 118 | self.int2spk = {idx: spk for spk, idx in self.spk2int.items()} 119 | self.phone2idx = {phone: idx for idx, phone in enumerate(self.phone_set)} 120 | self.utt = list(self.utt2path.keys()) 121 | 122 | def _norm_mean_std(self, x, mean, std, is_remove_outlier=False): 123 | if is_remove_outlier: 124 | x = remove_outlier(x) 125 | zero_idxs = np.where(x == 0.0)[0] 126 | x = (x - mean) / std 127 | x[zero_idxs] = 0.0 128 | return x 129 | 130 | def get_spk_number(self): 131 | return len(self.spk2int) 132 | 133 | def get_phone_number(self): 134 | return len(self.phone2idx) 135 | 136 | def __len__(self): 137 | return len(self.utt) 138 | 139 | def __getitem__(self, idx): 140 | uttid = self.utt[idx] 141 | mel_path = self.utt2path[uttid].replace('.wav', '.mel.npy') 142 | f0_path = self.utt2path[uttid].replace('.wav', '.f0.npy') 143 | energy_path = self.utt2path[uttid].replace('.wav', '.en.npy') 144 | 145 | mel = np.load(mel_path) #.transpose(1, 0) 146 | f0 = np.load(f0_path) 147 | f0, uv = interpolate_f0(f0) 148 | # unnormalized_f0 = self.f0_std * f0 + self.f0_mean 149 | pitch = f02pitch(f0) 150 | energy = np.load(energy_path) 151 | # energy = self.energy_std * energy + self.energy_mean 152 | 153 | raw_text = self.utt2raw_text[uttid] 154 | phone_text = self.utt2phone_seq[uttid] 155 | phone_seq = np.array([self.phone2idx[phone] for phone in phone_text]) 156 | note_pitch = np.array(self.utt2note_pitch[uttid]) 157 | note_duration = np.array(self.utt2note_dur[uttid]) 158 | duration = np.array(self.utt2dur[uttid]) 159 | 160 | mel_len = mel.shape[0] 161 | duration = duration[: len(phone_seq)] 162 | duration[-1] = duration[-1] + (mel.shape[0] - sum(duration)) 163 | assert mel_len == sum(duration), f'{mel_len} != {sum(duration)}' 164 | 165 | return { 166 | 'uttid': uttid, 167 | 'raw_text': raw_text, 168 | 'text': phone_seq, 169 | 'note_pitch': note_pitch, 170 | 'note_duration': note_duration, 171 | 'mel': mel, 172 | 'duration': duration, 173 | 'pitch': pitch, 174 | 'uv': uv, 175 | 'energy': energy 176 | } 177 | 178 | class SVSCollate(): 179 | def __init__(self): 180 | pass 181 | 182 | def __call__(self, batch): 183 | ilens = torch.from_numpy(np.array([x['text'].shape[0] for x in batch])).long() 184 | olens = torch.from_numpy(np.array([y['mel'].shape[0] for y in batch])).long() 185 | ids = [x['uttid'] for x in batch] 186 | raw_texts = [x['raw_text'] for x in batch] 187 | 188 | # perform padding and conversion to tensor 189 | inputs = pad_list([torch.from_numpy(x['text']).long() for x in batch], 0) 190 | note_pitchs = pad_list([torch.from_numpy(x['note_pitch']).long() for x in batch], 0) 191 | note_durations = pad_list([torch.from_numpy(x['note_duration']).long() for x in batch], 0) 192 | 193 | mels = 
pad_list([torch.from_numpy(y['mel']).float() for y in batch], 0) 194 | durations = pad_list([torch.from_numpy(x['duration']).long() for x in batch], 0) 195 | energys = pad_list([torch.from_numpy(y['energy']).float() for y in batch], 0).squeeze(-1) 196 | pitchs = pad_list([torch.from_numpy(y['pitch']).float() for y in batch], 0).squeeze(-1) 197 | uvs = pad_list([torch.from_numpy(y['uv']).float() for y in batch], 0).squeeze(-1) 198 | 199 | return { 200 | 'uttids': ids, 201 | 'raw_texts': raw_texts, 202 | 'texts': inputs, 203 | 'note_pitchs': note_pitchs, 204 | 'note_durations': note_durations, 205 | 'src_lens': ilens, 206 | 'max_src_len': ilens.max(), 207 | 'mels': mels, 208 | 'mel_lens': olens, 209 | 'max_mel_len': olens.max(), 210 | 'p_targets': pitchs, 211 | 'e_targets': energys, 212 | 'uv_targets': uvs, 213 | 'd_targets': durations 214 | } 215 | 216 | if __name__ == '__main__': 217 | import yaml 218 | from tqdm import tqdm 219 | from torch.utils.data import DataLoader 220 | with open('./configs/data.yaml', 'r') as f: 221 | configs = yaml.load(f, Loader = yaml.FullLoader) 222 | dataset = SVSDataset(configs) 223 | collate_fn = SVSCollate() 224 | dataloader = DataLoader(dataset, batch_size = 32, shuffle = True, collate_fn = collate_fn, num_workers = 8) 225 | for data in tqdm(dataloader): 226 | assert data['note_pitchs'].shape[-1] == data['note_durations'].shape[-1] 227 | assert data['uv_targets'].shape[1] == data['p_targets'].shape[-1] 228 | pass 229 | -------------------------------------------------------------------------------- /dataset/dataset_tts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torch.nn.functional as F 4 | 5 | import sys 6 | sys.path.insert(0, '/home/zengchang/code/acoustic_v2') 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import numpy as np 11 | 12 | from pyutils import pad_list, str_to_int_list, remove_outlier 13 | from dataset.texts import text_to_sequence 14 | 15 | def f02pitch(f0): 16 | #f0 =f0 + 0.01 17 | return np.log2(f0 / 27.5) * 12 + 21 18 | 19 | def pitch2f0(pitch): 20 | f0 = np.exp2((pitch - 21 ) / 12) * 27.5 21 | for i in range(len(f0)): 22 | if f0[i] <= 10: 23 | f0[i] = 0 24 | return f0 25 | 26 | def pitchxuv(pitch, uv, to_f0 = False): 27 | result = pitch * uv 28 | if to_f0: 29 | result = pitch2f0(result) 30 | return result 31 | 32 | def pad1d(x, max_len): 33 | return np.pad(x, (0, max_len - len(x)), mode="constant") 34 | 35 | def pad2d(x, max_len): 36 | return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant") 37 | 38 | class TTSDataset(Dataset): 39 | def __init__(self, config): 40 | audio_manifest = config['audio_manifest'] 41 | raw_text_manifest = config['raw_text_manifest'] 42 | duration_manifest = config['duration_manifest'] 43 | spk_manifest = config['spk_manifest'] 44 | self.sampling_rate = config['sampling_rate'] 45 | self.utt2path = {} 46 | self.utt2text = {} 47 | self.utt2duration = {} 48 | self.utt2raw_text = {} 49 | self.utt2spk = {} 50 | with open(audio_manifest, 'r') as f: 51 | for line in f: 52 | line = line.rstrip().split(' ') 53 | self.utt2path[line[0]] = line[1] 54 | with open(duration_manifest, 'r') as f: 55 | for line in f: 56 | line = line.rstrip().split('|') 57 | self.utt2text[line[0]] = ' '.join(line[2].split(' ')[0::2]) 58 | self.utt2duration[line[0]] = ' '.join(line[2].split(' ')[1::2]) 59 | with open(raw_text_manifest, 'r') as f: 60 | for line in f: 61 | line = 
line.rstrip().split(' ') 62 | self.utt2raw_text[line[0]] = line[1] 63 | with open(spk_manifest, 'r') as f: 64 | for line in f: 65 | line = line.rstrip().split(' ') 66 | self.utt2spk[line[0]] = line[1] 67 | self.utt = list(self.utt2path.keys()) 68 | self.spk2int = {spk: idx for idx, spk in enumerate(set(self.utt2spk.values()))} 69 | self.int2spk = {idx: spk for spk, idx in self.spk2int.items()} 70 | 71 | self.use_phonemes = config['use_phonemes'] 72 | self.tts_cleaner_names = config['tts_cleaner_names'] 73 | self.eos = config['eos'] 74 | 75 | def get_spk_number(self): 76 | return len(self.spk2int) 77 | 78 | def _norm_mean_std(self, x, mean, std, is_remove_outlier=False): 79 | if is_remove_outlier: 80 | x = remove_outlier(x) 81 | zero_idxs = np.where(x == 0.0)[0] 82 | x = (x - mean) / std 83 | x[zero_idxs] = 0.0 84 | return x 85 | 86 | def __len__(self): 87 | return len(self.utt) 88 | 89 | def __getitem__(self, idx): 90 | # set_trace() 91 | uttid = self.utt[idx] 92 | mel_path = self.utt2path[uttid].replace('.wav', '.mel.npy') 93 | f0_path = self.utt2path[uttid].replace('.wav', '.f0.npy') 94 | energy_path = self.utt2path[uttid].replace('.wav', '.en.npy') 95 | 96 | mel = np.load(mel_path) #.transpose(1, 0) 97 | f0 = np.load(f0_path) 98 | # pitch = f02pitch(f0) 99 | energy = np.load(energy_path) 100 | raw_text = self.utt2raw_text[uttid] 101 | phone_text = self.utt2text[uttid] 102 | phone_seq = np.array(text_to_sequence(phone_text, self.tts_cleaner_names)) 103 | duration = np.array(str_to_int_list(self.utt2duration[uttid])) 104 | spk = self.spk2int[self.utt2spk[uttid]] 105 | 106 | mel_len = mel.shape[0] 107 | duration = duration[: len(phone_seq)] 108 | duration[-1] = duration[-1] + (mel.shape[0] - sum(duration)) 109 | assert mel_len == sum(duration), f'{mel_len} != {sum(duration)}' 110 | 111 | return { 112 | 'uttid': uttid, 113 | 'raw_text': raw_text, 114 | 'text': phone_seq, 115 | 'mel': mel, 116 | 'duration': duration, 117 | 'f0': f0, 118 | 'energy': energy, 119 | 'spk': spk 120 | } 121 | 122 | class TTSCollate(): 123 | def __init__(self): 124 | pass 125 | 126 | def __call__(self, batch): 127 | ilens = torch.from_numpy(np.array([x['text'].shape[0] for x in batch])).long() 128 | olens = torch.from_numpy(np.array([y['mel'].shape[0] for y in batch])).long() 129 | ids = [x['uttid'] for x in batch] 130 | raw_texts = [x['raw_text'] for x in batch] 131 | 132 | # perform padding and conversion to tensor 133 | inputs = pad_list([torch.from_numpy(x['text']).long() for x in batch], 0) 134 | mels = pad_list([torch.from_numpy(y['mel']).float() for y in batch], 0) 135 | 136 | durations = pad_list([torch.from_numpy(x['duration']).long() for x in batch], 0) 137 | energys = pad_list([torch.from_numpy(y['energy']).float() for y in batch], 0).squeeze(-1) 138 | f0 = pad_list([torch.from_numpy(y['f0']).float() for y in batch], 0).squeeze(-1) 139 | # pitch = pad_list([torch.from_numpy(y['pitch']).float() for y in batch], 0).squeeze(-1) 140 | 141 | spks = torch.tensor([x['spk'] for x in batch], dtype = torch.int64) 142 | 143 | return { 144 | 'uttids': ids, 145 | 'spks': spks, 146 | 'raw_texts': raw_texts, 147 | 'texts': inputs, 148 | 'src_lens': ilens, 149 | 'max_src_len': ilens.max(), 150 | 'mels': mels, 151 | 'mel_lens': olens, 152 | 'max_mel_len': olens.max(), 153 | 'p_targets': f0, 154 | 'e_targets': energys, 155 | 'd_targets': durations 156 | } 157 | 158 | if __name__ == '__main__': 159 | import yaml 160 | from tqdm import tqdm 161 | from torch.utils.data import DataLoader 162 | with open('./configs/data.yaml', 
'r') as f: 163 | config = yaml.load(f, Loader = yaml.FullLoader) 164 | dataset = TTSDataset(config) 165 | print(dataset[0]['text']) 166 | print(dataset[0]['duration']) 167 | collate_fn = TTSCollate() 168 | dataloader = DataLoader(dataset, batch_size = 64, shuffle = True, collate_fn = collate_fn, num_workers = 8) 169 | for data in tqdm(dataloader): 170 | assert data['texts'].shape[1] == data['d_targets'].shape[1], "{} != {}".format(data['texts'].shape[1], data['d_targets'].shape[1]) 171 | pass 172 | # print(data['texts'].shape) 173 | # print(data['texts']) 174 | # print(data['input_len']) 175 | # print(data['mels'].shape) 176 | # print(data['labels'].shape) 177 | # print(data['output_len']) 178 | # print(data['uttids']) 179 | # print(data['durations'].shape) 180 | # print(data['durations'].sum(dim = 1)) 181 | # print(data['energys'].shape) 182 | # print(data['f0s'].shape) 183 | # print(data['raw_texts']) 184 | # break 185 | # print(dataset[0]['mel'].shape) 186 | -------------------------------------------------------------------------------- /dataset/espnet_texts/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from dataset.texts import cleaners 4 | from dataset.texts.symbols import ( 5 | symbols, 6 | _eos, 7 | phonemes_symbols, 8 | PAD, 9 | EOS, 10 | _PHONEME_SEP, 11 | ) 12 | from dataset.texts.dict_ import symbols_ 13 | import nltk 14 | from g2p_en import G2p 15 | 16 | # Mappings from symbol to numeric ID and vice versa: 17 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 18 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 19 | 20 | # Regular expression matching text enclosed in curly braces: 21 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 22 | 23 | symbols_inv = {v: k for k, v in symbols_.items()} 24 | 25 | valid_symbols = [ 26 | "AA", 27 | "AA1", 28 | "AE", 29 | "AE0", 30 | "AE1", 31 | "AH", 32 | "AH0", 33 | "AH1", 34 | "AO", 35 | "AO1", 36 | "AW", 37 | "AW0", 38 | "AW1", 39 | "AY", 40 | "AY0", 41 | "AY1", 42 | "B", 43 | "CH", 44 | "D", 45 | "DH", 46 | "EH", 47 | "EH0", 48 | "EH1", 49 | "ER", 50 | "EY", 51 | "EY0", 52 | "EY1", 53 | "F", 54 | "G", 55 | "HH", 56 | "IH", 57 | "IH0", 58 | "IH1", 59 | "IY", 60 | "IY0", 61 | "IY1", 62 | "JH", 63 | "K", 64 | "L", 65 | "M", 66 | "N", 67 | "NG", 68 | "OW", 69 | "OW0", 70 | "OW1", 71 | "OY", 72 | "OY0", 73 | "OY1", 74 | "P", 75 | "R", 76 | "S", 77 | "SH", 78 | "T", 79 | "TH", 80 | "UH", 81 | "UH0", 82 | "UH1", 83 | "UW", 84 | "UW0", 85 | "UW1", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | "pau", 92 | "sil", 93 | "spn" 94 | ] 95 | 96 | 97 | def pad_with_eos_bos(_sequence): 98 | return _sequence + [_symbol_to_id[_eos]] 99 | 100 | 101 | def text_to_sequence(text, cleaner_names, eos): 102 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 103 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 104 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
105 | Args: 106 | text: string to convert to a sequence 107 | cleaner_names: names of the cleaner functions to run the text through 108 | Returns: 109 | List of integers corresponding to the symbols in the text 110 | """ 111 | sequence = [] 112 | if eos: 113 | text = text + "~" 114 | try: 115 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 116 | except KeyError: 117 | print("text : ", text) 118 | exit(0) 119 | 120 | return sequence 121 | 122 | 123 | def sequence_to_text(sequence): 124 | """Converts a sequence of IDs back to a string""" 125 | result = "" 126 | for symbol_id in sequence: 127 | if symbol_id in symbols_inv: 128 | s = symbols_inv[symbol_id] 129 | # Enclose ARPAbet back in curly braces: 130 | if len(s) > 1 and s[0] == "@": 131 | s = "{%s}" % s[1:] 132 | result += s 133 | return result.replace("}{", " ") 134 | 135 | 136 | def _clean_text(text, cleaner_names): 137 | for name in cleaner_names: 138 | cleaner = getattr(cleaners, name) 139 | if not cleaner: 140 | raise Exception("Unknown cleaner: %s" % name) 141 | text = cleaner(text) 142 | return text 143 | 144 | 145 | def _symbols_to_sequence(symbols): 146 | return [symbols_[s.upper()] for s in symbols] 147 | 148 | 149 | def _arpabet_to_sequence(text): 150 | return _symbols_to_sequence(["@" + s for s in text.split()]) 151 | 152 | 153 | def _should_keep_symbol(s): 154 | return s in _symbol_to_id and s != "_" and s != "~" 155 | 156 | 157 | # For phonemes 158 | _phoneme_to_id = {s: i for i, s in enumerate(valid_symbols)} 159 | _id_to_phoneme = {i: s for i, s in enumerate(valid_symbols)} 160 | 161 | 162 | def _should_keep_token(token, token_dict): 163 | return ( 164 | token in token_dict 165 | and token != PAD 166 | and token != EOS 167 | and token != _phoneme_to_id[PAD] 168 | and token != _phoneme_to_id[EOS] 169 | ) 170 | 171 | 172 | def phonemes_to_sequence(phonemes): 173 | string = phonemes.split() if isinstance(phonemes, str) else phonemes 174 | # string.append(EOS) 175 | sequence = list(map(convert_phoneme_CMU, string)) 176 | sequence = [_phoneme_to_id[s] for s in sequence] 177 | # if _should_keep_token(s, _phoneme_to_id)] 178 | return sequence 179 | 180 | 181 | def sequence_to_phonemes(sequence, use_eos=False): 182 | string = [_id_to_phoneme[idx] for idx in sequence] 183 | # if _should_keep_token(idx, _id_to_phoneme)] 184 | string = _PHONEME_SEP.join(string) 185 | if use_eos: 186 | string = string.replace(EOS, "") 187 | return string 188 | 189 | 190 | def convert_phoneme_CMU(phoneme): 191 | REMAPPING = { 192 | 'AA0': 'AA1', 193 | 'AA2': 'AA1', 194 | 'AE2': 'AE1', 195 | 'AH2': 'AH1', 196 | 'AO0': 'AO1', 197 | 'AO2': 'AO1', 198 | 'AW2': 'AW1', 199 | 'AY2': 'AY1', 200 | 'EH2': 'EH1', 201 | 'ER0': 'EH1', 202 | 'ER1': 'EH1', 203 | 'ER2': 'EH1', 204 | 'EY2': 'EY1', 205 | 'IH2': 'IH1', 206 | 'IY2': 'IY1', 207 | 'OW2': 'OW1', 208 | 'OY2': 'OY1', 209 | 'UH2': 'UH1', 210 | 'UW2': 'UW1', 211 | } 212 | return REMAPPING.get(phoneme, phoneme) 213 | 214 | 215 | def text_to_phonemes(text, custom_words={}): 216 | """ 217 | Convert text into ARPAbet. 218 | For known words use CMUDict; for the rest try 'espeak' (to IPA) followed by 'listener'. 219 | :param text: str, input text. 220 | :param custom_words: 221 | dict {str: list of str}, optional 222 | Pronounciations (a list of ARPAbet phonemes) you'd like to override. 
223 | Example: {'word': ['W', 'EU1', 'R', 'D']} 224 | :return: list of str, phonemes 225 | """ 226 | g2p = G2p() 227 | 228 | """def convert_phoneme_CMU(phoneme): 229 | REMAPPING = { 230 | 'AA0': 'AA1', 231 | 'AA2': 'AA1', 232 | 'AE2': 'AE1', 233 | 'AH2': 'AH1', 234 | 'AO0': 'AO1', 235 | 'AO2': 'AO1', 236 | 'AW2': 'AW1', 237 | 'AY2': 'AY1', 238 | 'EH2': 'EH1', 239 | 'ER0': 'EH1', 240 | 'ER1': 'EH1', 241 | 'ER2': 'EH1', 242 | 'EY2': 'EY1', 243 | 'IH2': 'IH1', 244 | 'IY2': 'IY1', 245 | 'OW2': 'OW1', 246 | 'OY2': 'OY1', 247 | 'UH2': 'UH1', 248 | 'UW2': 'UW1', 249 | } 250 | return REMAPPING.get(phoneme, phoneme) 251 | """ 252 | 253 | def convert_phoneme_listener(phoneme): 254 | VOWELS = ['A', 'E', 'I', 'O', 'U'] 255 | if phoneme[0] in VOWELS: 256 | phoneme += '1' 257 | return phoneme # convert_phoneme_CMU(phoneme) 258 | 259 | try: 260 | known_words = nltk.corpus.cmudict.dict() 261 | except LookupError: 262 | nltk.download("cmudict") 263 | known_words = nltk.corpus.cmudict.dict() 264 | 265 | for word, phonemes in custom_words.items(): 266 | known_words[word.lower()] = [phonemes] 267 | 268 | words = nltk.tokenize.WordPunctTokenizer().tokenize(text.lower()) 269 | 270 | phonemes = [] 271 | PUNCTUATION = "!?.,-:;\"'()" 272 | for word in words: 273 | if all(c in PUNCTUATION for c in word): 274 | pronounciation = ["pau"] 275 | elif word in known_words: 276 | pronounciation = known_words[word][0] 277 | pronounciation = list( 278 | pronounciation 279 | ) # map(convert_phoneme_CMU, pronounciation)) 280 | else: 281 | pronounciation = g2p(word) 282 | pronounciation = list( 283 | pronounciation 284 | ) # (map(convert_phoneme_CMU, pronounciation)) 285 | 286 | phonemes += pronounciation 287 | 288 | return phonemes -------------------------------------------------------------------------------- /dataset/espnet_texts/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | """ 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | 21 | _whitespace_re = re.compile(r"\s+") 22 | punctuations = """+-!()[]{};:'"\<>/?@#^&*_~""" 23 | 24 | # List of (regular expression, replacement) pairs for abbreviations: 25 | _abbreviations = [ 26 | (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) 27 | for x in [ 28 | ("mrs", "misess"), 29 | ("mr", "mister"), 30 | ("dr", "doctor"), 31 | ("st", "saint"), 32 | ("co", "company"), 33 | ("jr", "junior"), 34 | ("maj", "major"), 35 | ("gen", "general"), 36 | ("drs", "doctors"), 37 | ("rev", "reverend"), 38 | ("lt", "lieutenant"), 39 | ("hon", "honorable"), 40 | ("sgt", "sergeant"), 41 | ("capt", "captain"), 42 | ("esq", "esquire"), 43 | ("ltd", "limited"), 44 | ("col", "colonel"), 45 | ("ft", "fort"), 46 | ] 47 | ] 48 | 49 | 50 | def expand_abbreviations(text): 51 | for regex, replacement in _abbreviations: 52 | text = re.sub(regex, replacement, text) 53 | return text 54 | 55 | 56 | def expand_numbers(text): 57 | return normalize_numbers(text) 58 | 59 | 60 | def lowercase(text): 61 | return text.lower() 62 | 63 | 64 | def collapse_whitespace(text): 65 | return re.sub(_whitespace_re, " ", text) 66 | 67 | 68 | def convert_to_ascii(text): 69 | return unidecode(text) 70 | 71 | 72 | def basic_cleaners(text): 73 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 74 | text = lowercase(text) 75 | text = collapse_whitespace(text) 76 | return text 77 | 78 | 79 | def transliteration_cleaners(text): 80 | """Pipeline for non-English text that transliterates to ASCII.""" 81 | text = convert_to_ascii(text) 82 | text = lowercase(text) 83 | text = collapse_whitespace(text) 84 | return text 85 | 86 | 87 | def english_cleaners(text): 88 | """Pipeline for English text, including number and abbreviation expansion.""" 89 | text = convert_to_ascii(text) 90 | text = lowercase(text) 91 | text = expand_numbers(text) 92 | text = expand_abbreviations(text) 93 | text = collapse_whitespace(text) 94 | return text 95 | 96 | 97 | def punctuation_removers(text): 98 | no_punct = "" 99 | for char in text: 100 | if char not in punctuations: 101 | no_punct = no_punct + char 102 | return no_punct -------------------------------------------------------------------------------- /dataset/espnet_texts/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) -------------------------------------------------------------------------------- /dataset/espnet_texts/dict.py: -------------------------------------------------------------------------------- 1 | symbols_ = { 2 | "": 1, 3 | "!": 2, 4 | "'": 3, 5 | ",": 4, 6 | ".": 5, 7 | " ": 6, 8 | "?": 7, 9 | "A": 8, 10 | "B": 9, 11 | "C": 10, 12 | "D": 11, 13 | "E": 12, 14 | "F": 13, 15 | "G": 14, 16 | "H": 15, 17 | "I": 16, 18 | "J": 17, 19 | "K": 18, 20 | "L": 19, 21 | "M": 20, 22 | "N": 21, 23 | "O": 22, 24 | "P": 23, 25 | "Q": 24, 26 | "R": 25, 27 | "S": 26, 28 | "T": 27, 29 | "U": 28, 30 | "V": 29, 31 | "W": 30, 32 | "X": 31, 33 | "Y": 32, 34 | "Z": 33, 35 | "~": 34, 36 | } -------------------------------------------------------------------------------- /dataset/espnet_texts/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, 
cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text -------------------------------------------------------------------------------- /dataset/espnet_texts/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from dataset.texts import cmudict 9 | 10 | _pad = "_" 11 | _eos = "~" 12 | _bos = "^" 13 | _punctuation = "!'(),.:;? " 14 | _special = "-" 15 | _letters = "abcdefghijklmnopqrstuvwxyz" 16 | 17 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 18 | # _arpabet = ['@' + s for s in cmudict.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + [_eos] 22 | 23 | # For Phonemes 24 | 25 | PAD = "#" 26 | EOS = "~" 27 | PHONEME_CODES = "AA1 AE0 AE1 AH0 AH1 AO0 AO1 AW0 AW1 AY0 AY1 B CH D DH EH0 EH1 EU0 EU1 EY0 EY1 F G HH IH0 IH1 IY0 IY1 JH K L M N NG OW0 OW1 OY0 OY1 P R S SH T TH UH0 UH1 UW0 UW1 V W Y Z ZH pau".split() 28 | _PHONEME_SEP = " " 29 | 30 | phonemes_symbols = [PAD, EOS] + PHONEME_CODES # PAD should be first to have zero id -------------------------------------------------------------------------------- /dataset/texts/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | # import sys 4 | # sys.path.insert(0, '/home/zengchang/code/acoustic_v2') 5 | from dataset.texts import cleaners 6 | from dataset.texts.symbols import symbols 7 | 8 | from ipdb import set_trace 9 | 10 | # Mappings from symbol to numeric ID and vice versa: 11 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 12 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 13 | 14 | # Regular expression matching text enclosed in curly braces: 15 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 16 | 17 | 18 | def text_to_sequence(text, cleaner_names): 19 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 20 | 21 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 22 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
23 | 24 | Args: 25 | text: string to convert to a sequence 26 | cleaner_names: names of the cleaner functions to run the text through 27 | 28 | Returns: 29 | List of integers corresponding to the symbols in the text 30 | """ 31 | # set_trace() 32 | sequence = [] 33 | 34 | # Check for curly braces and treat their contents as ARPAbet: 35 | # while len(text): 36 | # m = _curly_re.match(text) 37 | 38 | # if not m: 39 | # clean_text = _clean_text(text, cleaner_names).split(' ') 40 | # sequence += _symbols_to_sequence(clean_text) 41 | # break 42 | # clean_text1 = _clean_text(m.group(1), cleaner_names).split(' ') 43 | # sequence += _symbols_to_sequence(clean_text1) 44 | # clean_text2 = m.group(2).split(' ') 45 | # sequence += _arpabet_to_sequence(clean_text2) 46 | # text = m.group(3) 47 | sequence += _arpabet_to_sequence(text) 48 | 49 | return sequence 50 | 51 | 52 | def sequence_to_text(sequence): 53 | """Converts a sequence of IDs back to a string""" 54 | result = "" 55 | for symbol_id in sequence: 56 | if symbol_id in _id_to_symbol: 57 | s = _id_to_symbol[symbol_id] 58 | # Enclose ARPAbet back in curly braces: 59 | if len(s) > 1 and s[0] == "@": 60 | s = "{%s}" % s[1:] 61 | result += s 62 | return result.replace("}{", " ") 63 | 64 | def _clean_text(text, cleaner_names): 65 | for name in cleaner_names: 66 | cleaner = getattr(cleaners, name) 67 | if not cleaner: 68 | raise Exception("Unknown cleaner: %s" % name) 69 | text = cleaner(text) 70 | return text 71 | 72 | def _symbols_to_sequence(symbols): 73 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 74 | 75 | def _arpabet_to_sequence(text): 76 | return _symbols_to_sequence(["@" + s for s in text.split()]) 77 | 78 | def _should_keep_symbol(s): 79 | return s in _symbol_to_id and s != "_" and s != "~" 80 | 81 | if __name__ == "__main__": 82 | text = 'Turn left on {HH AW1 S S T AH0 N} Street.' 83 | print(text_to_sequence(text, ['english_cleaners'])) -------------------------------------------------------------------------------- /dataset/texts/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | _whitespace_re = re.compile(r'\s+') 21 | 22 | # List of (regular expression, replacement) pairs for abbreviations: 23 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 69 | text = lowercase(text) 70 | text = collapse_whitespace(text) 71 | return text 72 | 73 | 74 | def transliteration_cleaners(text): 75 | '''Pipeline for non-English text that transliterates to ASCII.''' 76 | text = convert_to_ascii(text) 77 | text = lowercase(text) 78 | text = collapse_whitespace(text) 79 | return text 80 | 81 | 82 | def english_cleaners(text): 83 | '''Pipeline for English text, including number and abbreviation expansion.''' 84 | text = convert_to_ascii(text) 85 | text = lowercase(text) 86 | text = expand_numbers(text) 87 | text = expand_abbreviations(text) 88 | text = collapse_whitespace(text) 89 | return text -------------------------------------------------------------------------------- /dataset/texts/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) -------------------------------------------------------------------------------- /dataset/texts/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def 
normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text -------------------------------------------------------------------------------- /dataset/texts/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | "v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /dataset/texts/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | 
Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from dataset.texts import cmudict, pinyin 9 | 10 | _pad = "_~" 11 | _punctuation = "!'(),.:;? " 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | # _silences = ["sp", "spn", "sil"] 16 | 17 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 18 | # _arpabet = [s for s in cmudict.valid_symbols] 19 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 20 | # _pinyin = [s for s in pinyin.valid_symbols] 21 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 22 | 23 | # Export all symbols: 24 | symbols = ( 25 | [_pad] 26 | + list(_special) 27 | + list(_punctuation) 28 | + list(_letters) 29 | + _arpabet 30 | + _pinyin 31 | + _silences 32 | ) 33 | 34 | # symbols 35 | ''' 36 | ['_', 37 | '~', 38 | '-', 39 | '!', 40 | "'", 41 | '(', 42 | ')', 43 | ',', 44 | '.', 45 | ':', 46 | ';', 47 | '?', 48 | ' ', 49 | 'A', 50 | 'B', 51 | 'C', 52 | 'D', 53 | 'E', 54 | 'F', 55 | 'G', 56 | 'H', 57 | 'I', 58 | 'J', 59 | 'K', 60 | 'L', 61 | 'M', 62 | 'N', 63 | 'O', 64 | 'P', 65 | 'Q', 66 | 'R', 67 | 'S', 68 | 'T', 69 | 'U', 70 | 'V', 71 | 'W', 72 | 'X', 73 | 'Y', 74 | 'Z', 75 | 'a', 76 | 'b', 77 | 'c', 78 | 'd', 79 | 'e', 80 | 'f', 81 | 'g', 82 | 'h', 83 | 'i', 84 | 'j', 85 | 'k', 86 | 'l', 87 | 'm', 88 | 'n', 89 | 'o', 90 | 'p', 91 | 'q', 92 | 'r', 93 | 's', 94 | 't', 95 | 'u', 96 | 'v', 97 | 'w', 98 | 'x', 99 | 'y', 100 | 'z', 101 | '@AA', 102 | '@AA0', 103 | '@AA1', 104 | '@AA2', 105 | '@AE', 106 | '@AE0', 107 | '@AE1', 108 | '@AE2', 109 | '@AH', 110 | '@AH0', 111 | '@AH1', 112 | '@AH2', 113 | '@AO', 114 | '@AO0', 115 | '@AO1', 116 | '@AO2', 117 | '@AW', 118 | '@AW0', 119 | '@AW1', 120 | '@AW2', 121 | '@AY', 122 | '@AY0', 123 | '@AY1', 124 | '@AY2', 125 | '@B', 126 | '@CH', 127 | '@D', 128 | '@DH', 129 | '@EH', 130 | '@EH0', 131 | '@EH1', 132 | '@EH2', 133 | '@ER', 134 | '@ER0', 135 | '@ER1', 136 | '@ER2', 137 | '@EY', 138 | '@EY0', 139 | '@EY1', 140 | '@EY2', 141 | '@F', 142 | '@G', 143 | '@HH', 144 | '@IH', 145 | '@IH0', 146 | '@IH1', 147 | '@IH2', 148 | '@IY', 149 | '@IY0', 150 | '@IY1', 151 | '@IY2', 152 | '@JH', 153 | '@K', 154 | '@L', 155 | '@M', 156 | '@N', 157 | '@NG', 158 | '@OW', 159 | '@OW0', 160 | '@OW1', 161 | '@OW2', 162 | '@OY', 163 | '@OY0', 164 | '@OY1', 165 | '@OY2', 166 | '@P', 167 | '@R', 168 | '@S', 169 | '@SH', 170 | '@T', 171 | '@TH', 172 | '@UH', 173 | '@UH0', 174 | '@UH1', 175 | '@UH2', 176 | '@UW', 177 | '@UW0', 178 | '@UW1', 179 | '@UW2', 180 | '@V', 181 | '@W', 182 | '@Y', 183 | '@Z', 184 | '@ZH', 185 | '@b', 186 | '@c', 187 | '@ch', 188 | '@d', 189 | '@f', 190 | '@g', 191 | '@h', 192 | '@j', 193 | '@k', 194 | '@l', 195 | '@m', 196 | '@n', 197 | '@p', 198 | '@q', 199 | '@r', 200 | '@s', 201 | '@sh', 202 | '@t', 203 | '@w', 204 | '@x', 205 | '@y', 206 | '@z', 207 | '@zh', 208 | '@a1', 209 | '@a2', 210 | '@a3', 211 | '@a4', 212 | '@a5', 213 | '@ai1', 214 | '@ai2', 215 | '@ai3', 216 | '@ai4', 217 | '@ai5', 218 | '@an1', 219 | '@an2', 220 | '@an3', 221 | '@an4', 222 | '@an5', 223 | '@ang1', 224 | '@ang2', 225 | '@ang3', 226 | '@ang4', 227 | '@ang5', 228 | '@ao1', 229 | '@ao2', 230 | '@ao3', 231 | '@ao4', 232 | '@ao5', 233 | '@e1', 234 | '@e2', 235 | '@e3', 236 | '@e4', 237 | '@e5', 238 | 
'@ei1', 239 | '@ei2', 240 | '@ei3', 241 | '@ei4', 242 | '@ei5', 243 | '@en1', 244 | '@en2', 245 | '@en3', 246 | '@en4', 247 | '@en5', 248 | '@eng1', 249 | '@eng2', 250 | '@eng3', 251 | '@eng4', 252 | '@eng5', 253 | '@er1', 254 | '@er2', 255 | '@er3', 256 | '@er4', 257 | '@er5', 258 | '@i1', 259 | '@i2', 260 | '@i3', 261 | '@i4', 262 | '@i5', 263 | '@ia1', 264 | '@ia2', 265 | '@ia3', 266 | '@ia4', 267 | '@ia5', 268 | '@ian1', 269 | '@ian2', 270 | '@ian3', 271 | '@ian4', 272 | '@ian5', 273 | '@iang1', 274 | '@iang2', 275 | '@iang3', 276 | '@iang4', 277 | '@iang5', 278 | '@iao1', 279 | '@iao2', 280 | '@iao3', 281 | '@iao4', 282 | '@iao5', 283 | '@ie1', 284 | '@ie2', 285 | '@ie3', 286 | '@ie4', 287 | '@ie5', 288 | '@ii1', 289 | '@ii2', 290 | '@ii3', 291 | '@ii4', 292 | '@ii5', 293 | '@iii1', 294 | '@iii2', 295 | '@iii3', 296 | '@iii4', 297 | '@iii5', 298 | '@in1', 299 | '@in2', 300 | '@in3', 301 | '@in4', 302 | '@in5', 303 | '@ing1', 304 | '@ing2', 305 | '@ing3', 306 | '@ing4', 307 | '@ing5', 308 | '@iong1', 309 | '@iong2', 310 | '@iong3', 311 | '@iong4', 312 | '@iong5', 313 | '@iou1', 314 | '@iou2', 315 | '@iou3', 316 | '@iou4', 317 | '@iou5', 318 | '@o1', 319 | '@o2', 320 | '@o3', 321 | '@o4', 322 | '@o5', 323 | '@ong1', 324 | '@ong2', 325 | '@ong3', 326 | '@ong4', 327 | '@ong5', 328 | '@ou1', 329 | '@ou2', 330 | '@ou3', 331 | '@ou4', 332 | '@ou5', 333 | '@u1', 334 | '@u2', 335 | '@u3', 336 | '@u4', 337 | '@u5', 338 | '@ua1', 339 | '@ua2', 340 | '@ua3', 341 | '@ua4', 342 | '@ua5', 343 | '@uai1', 344 | '@uai2', 345 | '@uai3', 346 | '@uai4', 347 | '@uai5', 348 | '@uan1', 349 | '@uan2', 350 | '@uan3', 351 | '@uan4', 352 | '@uan5', 353 | '@uang1', 354 | '@uang2', 355 | '@uang3', 356 | '@uang4', 357 | '@uang5', 358 | '@uei1', 359 | '@uei2', 360 | '@uei3', 361 | '@uei4', 362 | '@uei5', 363 | '@uen1', 364 | '@uen2', 365 | '@uen3', 366 | '@uen4', 367 | '@uen5', 368 | '@uo1', 369 | '@uo2', 370 | '@uo3', 371 | '@uo4', 372 | '@uo5', 373 | '@v1', 374 | '@v2', 375 | '@v3', 376 | '@v4', 377 | '@v5', 378 | '@van1', 379 | '@van2', 380 | '@van3', 381 | '@van4', 382 | '@van5', 383 | '@ve1', 384 | '@ve2', 385 | '@ve3', 386 | '@ve4', 387 | '@ve5', 388 | '@vn1', 389 | '@vn2', 390 | '@vn3', 391 | '@vn4', 392 | '@vn5', 393 | '@rr', 394 | '@sp', 395 | '@spn', 396 | '@sil'] 397 | ''' 398 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .fastspeech2_loss import * 2 | from .loss import FastSpeech2Loss, FeatLoss, LSGANDLoss, LSGANGLoss 3 | -------------------------------------------------------------------------------- /loss/fastspeech2_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional 4 | 5 | class PitchPredictorLoss(nn.Module): 6 | """Loss function module for duration predictor. 7 | 8 | The loss value is Calculated in log domain to make it Gaussian. 9 | 10 | """ 11 | 12 | def __init__(self, offset=1.0): 13 | """Initilize duration predictor loss module. 14 | 15 | Args: 16 | offset (float, optional): Offset value to avoid nan in log domain. 17 | 18 | """ 19 | super(PitchPredictorLoss, self).__init__() 20 | self.criterion = nn.MSELoss() 21 | self.offset = offset 22 | 23 | def forward(self, outputs, targets): 24 | """Calculate forward propagation. 
25 | 26 | Args: 27 | outputs (Tensor): Batch of prediction durations in log domain (B, T) 28 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) 29 | 30 | Returns: 31 | Tensor: Mean squared error loss value. 32 | 33 | Note: 34 | `outputs` is in log domain but `targets` is in linear domain. 35 | 36 | """ 37 | # NOTE: We convert the output in log domain low error value 38 | # print("Output :", outputs[0]) 39 | # print("Before Output :", targets[0]) 40 | # targets = torch.log(targets.float() + self.offset) 41 | # print("Before Output :", targets[0]) 42 | # outputs = torch.log(outputs.float() + self.offset) 43 | loss = self.criterion(outputs, targets) 44 | # print(loss) 45 | return loss 46 | 47 | 48 | class EnergyPredictorLoss(nn.Module): 49 | """Loss function module for duration predictor. 50 | 51 | The loss value is Calculated in log domain to make it Gaussian. 52 | 53 | """ 54 | 55 | def __init__(self, offset=1.0): 56 | """Initilize duration predictor loss module. 57 | 58 | Args: 59 | offset (float, optional): Offset value to avoid nan in log domain. 60 | 61 | """ 62 | super(EnergyPredictorLoss, self).__init__() 63 | self.criterion = nn.MSELoss() 64 | self.offset = offset 65 | 66 | def forward(self, outputs, targets): 67 | """Calculate forward propagation. 68 | 69 | Args: 70 | outputs (Tensor): Batch of prediction durations in log domain (B, T) 71 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) 72 | 73 | Returns: 74 | Tensor: Mean squared error loss value. 75 | 76 | Note: 77 | `outputs` is in log domain but `targets` is in linear domain. 78 | 79 | """ 80 | # NOTE: outputs is in log domain while targets in linear 81 | # targets = torch.log(targets.float() + self.offset) 82 | loss = self.criterion(outputs, targets) 83 | 84 | return loss 85 | 86 | class DurationPredictorLoss(nn.Module): 87 | """Loss function module for duration predictor. 88 | 89 | The loss value is Calculated in log domain to make it Gaussian. 90 | 91 | """ 92 | 93 | def __init__(self, offset=1.0): 94 | """Initilize duration predictor loss module. 95 | 96 | Args: 97 | offset (float, optional): Offset value to avoid nan in log domain. 98 | 99 | """ 100 | super(DurationPredictorLoss, self).__init__() 101 | self.criterion = nn.MSELoss() 102 | self.offset = offset 103 | 104 | def forward(self, outputs, targets): 105 | """Calculate forward propagation. 106 | 107 | Args: 108 | outputs (Tensor): Batch of prediction durations in log domain (B, T) 109 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) 110 | 111 | Returns: 112 | Tensor: Mean squared error loss value. 113 | 114 | Note: 115 | `outputs` is in log domain but `targets` is in linear domain. 
116 | 117 | """ 118 | # NOTE: outputs is in log domain while targets in linear 119 | targets = torch.log(targets.float() + self.offset) 120 | loss = self.criterion(outputs, targets) 121 | 122 | return loss -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FastSpeech2Loss(nn.Module): 5 | """ FastSpeech2 Loss """ 6 | 7 | def __init__(self, data_config): 8 | super(FastSpeech2Loss, self).__init__() 9 | self.pitch_feature_level = data_config["pitch"]["feature"] 10 | self.energy_feature_level = data_config["energy"]["feature"] 11 | self.mse_loss = nn.MSELoss() 12 | self.mae_loss = nn.L1Loss() 13 | 14 | def forward(self, inputs, predictions): 15 | ( 16 | mel_targets, 17 | _, 18 | _, 19 | pitch_targets, 20 | energy_targets, 21 | uv_targets, 22 | duration_targets, 23 | ) = inputs 24 | 25 | ( 26 | mel_predictions, 27 | postnet_mel_predictions, 28 | pitch_predictions, 29 | energy_predictions, 30 | uv_predictions, 31 | log_duration_predictions, 32 | _, 33 | src_masks, 34 | mel_masks, 35 | _, 36 | _, 37 | ) = predictions 38 | 39 | src_masks = ~src_masks 40 | mel_masks = ~mel_masks 41 | log_duration_targets = torch.log(duration_targets.float() + 1) 42 | mel_targets = mel_targets[:, : mel_masks.shape[1], :] 43 | mel_masks = mel_masks[:, :mel_masks.shape[1]] 44 | 45 | log_duration_targets.requires_grad = False 46 | pitch_targets.requires_grad = False 47 | energy_targets.requires_grad = False 48 | mel_targets.requires_grad = False 49 | if not uv_targets is None: 50 | uv_targets.requires_grad = False 51 | 52 | if self.pitch_feature_level == "phoneme_level": 53 | pitch_predictions = pitch_predictions.masked_select(src_masks) 54 | pitch_targets = pitch_targets.masked_select(src_masks) 55 | elif self.pitch_feature_level == "frame_level": 56 | pitch_predictions = pitch_predictions.masked_select(mel_masks) 57 | pitch_targets = pitch_targets.masked_select(mel_masks) 58 | 59 | if self.energy_feature_level == "phoneme_level": 60 | energy_predictions = energy_predictions.masked_select(src_masks) 61 | energy_targets = energy_targets.masked_select(src_masks) 62 | if self.energy_feature_level == "frame_level": 63 | energy_predictions = energy_predictions.masked_select(mel_masks) 64 | energy_targets = energy_targets.masked_select(mel_masks) 65 | 66 | log_duration_predictions = log_duration_predictions.masked_select(src_masks) 67 | log_duration_targets = log_duration_targets.masked_select(src_masks) 68 | 69 | if not uv_targets is None: 70 | uv_predictions = uv_predictions.masked_select(mel_masks) 71 | uv_targets = uv_targets.masked_select(mel_masks) 72 | 73 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1)) 74 | postnet_mel_predictions = postnet_mel_predictions.masked_select( 75 | mel_masks.unsqueeze(-1) 76 | ) 77 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1)) 78 | 79 | mel_loss = self.mae_loss(mel_predictions, mel_targets) 80 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets) 81 | 82 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets) 83 | energy_loss = self.mse_loss(energy_predictions, energy_targets) 84 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets) 85 | total_loss = (mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss) 86 | uv_loss = None 87 | if not uv_targets is None: 88 | uv_loss = 
self.mse_loss(uv_predictions, uv_targets) 89 | total_loss = ( 90 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss + 0.1 * uv_loss 91 | ) 92 | 93 | return ( 94 | total_loss, 95 | mel_loss, 96 | postnet_mel_loss, 97 | pitch_loss, 98 | energy_loss, 99 | 0 if uv_loss is None else uv_loss, 100 | duration_loss, 101 | ) 102 | 103 | class FeatLoss(nn.Module): 104 | ''' 105 | feature loss (multi-band discriminator) 106 | ''' 107 | def __init__(self, feat_loss_weight = (1.0, 1.0, 1.0)): 108 | super(FeatLoss, self).__init__() 109 | self.loss_d = nn.MSELoss() #.to(self.device) 110 | self.feat_loss_weight = feat_loss_weight 111 | 112 | def forward(self, D_fake): 113 | feat_g_loss = 0.0 114 | feat_loss = [0.0] * len(D_fake) 115 | report_keys = {} 116 | for j in range(len(D_fake)): 117 | for k in range(len(D_fake[j][0])): 118 | for n in range(len(D_fake[j][0][k][1])): 119 | if len(D_fake[j][0][k][1][n].shape) == 4: 120 | t_batch = D_fake[j][0][k][1][n].shape[0] 121 | t_length = D_fake[j][0][k][1][n].shape[-1] 122 | D_fake[j][0][k][1][n] = D_fake[j][0][k][1][n].view(t_batch, t_length,-1) 123 | D_fake[j][1][k][1][n] = D_fake[j][1][k][1][n].view(t_batch, t_length,-1) 124 | feat_loss[j] += self.loss_d(D_fake[j][0][k][1][n], D_fake[j][1][k][1][n]) * 2 125 | feat_loss[j] /= (n + 1) 126 | feat_loss[j] /= (k + 1) 127 | feat_loss[j] *= self.feat_loss_weight[j] 128 | report_keys['feat_loss_' + str(j)] = feat_loss[j] 129 | feat_g_loss += feat_loss[j] 130 | 131 | return feat_g_loss, report_keys 132 | 133 | class LSGANGLoss(nn.Module): 134 | def __init__(self, adv_loss_weight): 135 | super(LSGANGLoss, self).__init__() 136 | self.loss_d = nn.MSELoss() #.to(self.device) 137 | self.adv_loss_weight = adv_loss_weight 138 | 139 | def forward(self, D_fake): 140 | adv_g_loss = 0.0 141 | adv_loss = [0.0] * len(D_fake) 142 | report_keys = {} 143 | for j in range(len(D_fake)): 144 | for k in range(len(D_fake[j][0])): 145 | adv_loss[j] += self.loss_d(D_fake[j][0][k][0], D_fake[j][0][k][0].new_ones(D_fake[j][0][k][0].size())) 146 | adv_loss[j] /= (k + 1) 147 | adv_loss[j] *= self.adv_loss_weight[j] 148 | report_keys['adv_g_loss_' + str(j)] = adv_loss[j] 149 | adv_g_loss += adv_loss[j] 150 | return adv_g_loss, report_keys 151 | 152 | class LSGANDLoss(nn.Module): 153 | def __init__(self): 154 | super(LSGANDLoss, self).__init__() 155 | self.loss_d = nn.MSELoss() 156 | 157 | def forward(self, D_fake): 158 | adv_d_loss = 0.0 159 | adv_loss = [0.0] * len(D_fake) 160 | real_loss = [0.0] * len(D_fake) 161 | fake_loss = [0.0] * len(D_fake) 162 | report_keys = {} 163 | for j in range(len(D_fake)): 164 | for k in range(len(D_fake[j][0])): 165 | real_loss[j] += self.loss_d(D_fake[j][1][k][0], D_fake[j][1][k][0].new_ones(D_fake[j][1][k][0].size())) 166 | fake_loss[j] += self.loss_d(D_fake[j][0][k][0], D_fake[j][0][k][0].new_zeros(D_fake[j][0][k][0].size())) 167 | real_loss[j] /= (k + 1) 168 | fake_loss[j] /= (k + 1) 169 | adv_loss[j] = 0.5 * (real_loss[j] + fake_loss[j]) 170 | report_keys['adv_d_loss_' + str(j)] = adv_loss[j] 171 | adv_d_loss += adv_loss[j] 172 | return adv_d_loss, report_keys 173 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .fastspeech2 import * 2 | from .xiaoice2 import * 3 | from .discriminator import * 4 | -------------------------------------------------------------------------------- /models/discriminator.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import logging 5 | 6 | class GLU(nn.Module): 7 | def __init__(self): 8 | super(GLU, self).__init__() 9 | # Custom Implementation because the Voice Conversion Cycle GAN 10 | # paper assumes GLU won't reduce the dimension of tensor by 2. 11 | 12 | def forward(self, input): 13 | return input * torch.sigmoid(input) 14 | 15 | class DiscriminatorFactory(nn.Module): 16 | def __init__(self, 17 | time_length, 18 | freq_length, 19 | conv_channel, 20 | ): 21 | super(DiscriminatorFactory, self).__init__() 22 | 23 | layers = 10 24 | conv_channels = conv_channel 25 | kernel_size = 3 26 | conv_in_channels = 60 27 | use_weight_norm = True 28 | 29 | self.conv_layers = torch.nn.ModuleList() 30 | for i in range(layers - 1): 31 | if i == 0: 32 | dilation = 1 33 | else: 34 | dilation = 1 35 | conv_in_channels = conv_channels 36 | padding = (kernel_size - 1) // 2 * dilation 37 | conv_layer = [ 38 | nn.Conv1d(conv_in_channels, conv_channels, 39 | kernel_size=kernel_size, padding=padding, 40 | dilation=dilation, bias=True), 41 | nn.LeakyReLU(0.2, inplace=True), 42 | #nn.BatchNorm1d(conv_channels) 43 | ] 44 | self.conv_layers += conv_layer 45 | padding = (kernel_size - 1) // 2 46 | last_conv_layer = nn.Conv1d( 47 | conv_in_channels, 1, 48 | kernel_size=kernel_size, padding=padding, bias=True) 49 | self.conv_layers += [last_conv_layer] 50 | 51 | # apply weight norm 52 | if use_weight_norm: 53 | self.apply_weight_norm() 54 | 55 | def apply_weight_norm(self): 56 | """Apply weight normalization module from all of the layers.""" 57 | def _apply_weight_norm(m): 58 | if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): 59 | torch.nn.utils.weight_norm(m) 60 | logging.debug(f"weight norm is applied to {m}.") 61 | self.apply(_apply_weight_norm) 62 | 63 | def remove_weight_norm(self): 64 | """Remove weight normalization module from all of the layers.""" 65 | def _remove_weight_norm(m): 66 | try: 67 | logging.debug(f"weight norm is removed from {m}.") 68 | torch.nn.utils.remove_weight_norm(m) 69 | except ValueError: # this module didn't have weight norm 70 | return 71 | self.apply(_remove_weight_norm) 72 | 73 | def forward(self, x): 74 | """ 75 | Args: 76 | x: (B, C, T), by default, C = 40. 77 | 78 | Returns: 79 | tensor: (B, 1, T) 80 | """ 81 | feature_list = [] 82 | i = 1 83 | for f in self.conv_layers: 84 | x = f(x) 85 | if i % 2 == 1: 86 | feature_list.append(x) 87 | i += 1 88 | return [x, feature_list] 89 | 90 | 91 | class MultiWindowDiscriminator(nn.Module): 92 | """docstring for MultiWindowDiscriminator""" 93 | def __init__(self, 94 | time_lengths, 95 | freq_lengths, 96 | conv_channels, 97 | ): 98 | super(MultiWindowDiscriminator, self).__init__() 99 | self.win_lengths = time_lengths 100 | 101 | self.conv_layers = nn.ModuleList() 102 | self.patch_layers = nn.ModuleList() 103 | for time_length, freq_length, conv_channel in zip(time_lengths, freq_lengths, conv_channels): 104 | conv_layer = [ 105 | DiscriminatorFactory(np.abs(time_length), freq_length, conv_channel), 106 | ] # 1d 107 | self.conv_layers += conv_layer 108 | patch_layer = [PatchGAN()] 109 | self.patch_layers += patch_layer 110 | 111 | 112 | def clip(self, x, x_len, win_length, y=None, random_N=None): 113 | '''Ramdom clip x to win_length. 114 | Args: 115 | x (tensor) : (B, T, C). 116 | x_len (tensor) : (B,). 
117 | win_length (int): target clip length 118 | 119 | Returns: 120 | (tensor) : (B, win_length, C). 121 | 122 | ''' 123 | x_batch = [] 124 | y_batch = [] 125 | T_end = win_length 126 | if T_end > 0: 127 | cursor = 1 128 | else: 129 | cursor = -1 130 | min_a = min(x_len) 131 | if np.abs(T_end) + random_N > min_a: 132 | T_end = min_a - random_N - 1 133 | T_end = T_end * cursor 134 | #print(x_len, random_N, win_length, T_end) 135 | for i in range(x.size(0)): 136 | if T_end < 0: 137 | x_batch += [x[i, x_len[i].cpu() + T_end - random_N: x_len[i].cpu() - random_N, :].unsqueeze(0)] 138 | else: 139 | x_batch += [x[i, random_N : T_end + random_N, :].unsqueeze(0)] 140 | if y != None: 141 | if T_end < 0: 142 | y_batch += [y[i, x_len[i].cpu() + T_end - random_N: x_len[i].cpu() - random_N, :].unsqueeze(0)] 143 | else: 144 | y_batch += [y[i, random_N : T_end+ random_N, :].unsqueeze(0)] 145 | 146 | x_batch = torch.cat(x_batch, 0) 147 | if y != None: 148 | y_batch = torch.cat(y_batch, 0) 149 | if y != None: 150 | return x_batch, y_batch 151 | else: 152 | return x_batch 153 | 154 | def forward(self, x, x_len, y=None, random_N=None): 155 | ''' 156 | Args: 157 | x (tensor): input mel, (B, T, C). 158 | x_length (tensor): len of per mel. (B,). 159 | 160 | Returns: 161 | tensor : (B). 162 | ''' 163 | validity_x = list() 164 | validity_y = list() 165 | #validity = 0.0 166 | for i in range(len(self.conv_layers)): 167 | if y != None: 168 | if self.win_lengths[i] != 1: 169 | x_clip,y_clip = self.clip(x, x_len, self.win_lengths[i], y, random_N[i]) # (B, win_length, C) 170 | else: 171 | #print(x.shape, y.shape) 172 | #x_clip, y_clip = x[:,:1300,:], y[:,:1300,:] 173 | x_clip, y_clip = x, y 174 | y_clip = y_clip.transpose(2,1) 175 | 176 | else: 177 | if self.win_lengths[i] != 1: 178 | x_clip = self.clip(x, x_len, self.win_lengths[i], y, random_N[i]) # (B, win_length, C) 179 | else: 180 | #print(x.shape) 181 | #x_clip = x[:,:1300, :] 182 | x_clip = x 183 | 184 | x_clip = x_clip.transpose(2, 1) # (B, C, win_length) 185 | x_clip_r = self.conv_layers[i](x_clip) # 1d 186 | validity_x += [x_clip_r] 187 | x_clip_r = self.patch_layers[i](x_clip) # 2d 188 | validity_x += [x_clip_r] 189 | if y!= None: 190 | y_clip_r = self.conv_layers[i](y_clip) 191 | validity_y += [y_clip_r] 192 | y_clip_r = self.patch_layers[i](y_clip) 193 | validity_y += [y_clip_r] 194 | 195 | #validity += x_clip 196 | if y == None: 197 | return validity_x 198 | else: 199 | return validity_x, validity_y 200 | 201 | class PatchGAN(nn.Module): # 202 | def __init__(self): 203 | super(PatchGAN, self).__init__() 204 | 205 | self.convLayer1 = nn.Sequential(nn.Conv2d(in_channels=1, 206 | out_channels=32, 207 | kernel_size=(3, 3), 208 | stride=(1, 1), 209 | padding=(1, 1)), 210 | GLU()) 211 | 212 | # DownSample Layer 213 | self.downSample1 = self.downSample(in_channels=32, 214 | out_channels=64, 215 | kernel_size=(3, 3), 216 | stride=(2, 2), 217 | padding=1) 218 | 219 | self.downSample2 = self.downSample(in_channels=64, 220 | out_channels=128, 221 | kernel_size=(3, 3), 222 | stride=[2, 2], 223 | padding=1) 224 | # Conv Layer 225 | self.outputConvLayer_2 = nn.Sequential(nn.Conv2d(in_channels=128, 226 | out_channels=1, 227 | kernel_size=(1, 3), 228 | stride=[1, 1], 229 | padding=[0, 1]) 230 | ) 231 | 232 | self.downSample3 = self.downSample(in_channels=128, 233 | out_channels=256, 234 | kernel_size=[3, 3], 235 | stride=[2, 2], 236 | padding=1) 237 | # Conv Layer 238 | self.outputConvLayer_3 = nn.Sequential(nn.Conv2d(in_channels=256, 239 | out_channels=1, 240 | 
kernel_size=(1, 3), 241 | stride=[1, 1], 242 | padding=[0, 1])) 243 | 244 | self.downSample4 = self.downSample(in_channels=256, 245 | out_channels=512, 246 | kernel_size=[3, 3], 247 | stride=[2, 2], 248 | padding=1) 249 | # Conv Layer 250 | self.outputConvLayer_4 = nn.Sequential(nn.Conv2d(in_channels=512, 251 | out_channels=1, 252 | kernel_size=(1, 3), 253 | stride=[1, 1], 254 | padding=[0, 1])) 255 | self.downSample5 = self.downSample(in_channels=512, 256 | out_channels=1024, 257 | kernel_size=[3, 3], 258 | stride=[2, 2], 259 | padding=1) 260 | 261 | 262 | # Conv Layer 263 | self.outputConvLayer = nn.Sequential(nn.Conv2d(in_channels=1024, 264 | out_channels=1, 265 | kernel_size=(1, 3), 266 | stride=[1, 1], 267 | padding=[0, 1])) 268 | 269 | def downSample(self, in_channels, out_channels, kernel_size, stride, padding): 270 | convLayer = nn.Sequential(nn.Conv2d(in_channels=in_channels, 271 | out_channels=out_channels, 272 | kernel_size=kernel_size, 273 | stride=stride, 274 | padding=padding), 275 | nn.InstanceNorm2d(num_features=out_channels, 276 | affine=True), 277 | GLU()) 278 | return convLayer 279 | 280 | def forward(self, input): 281 | # input has shape [batch_size, num_features, time] 282 | # discriminator requires shape [batchSize, 1, num_features, time] 283 | input = input.unsqueeze(1) 284 | #print("input : {}".format(input.shape)) 285 | feature_list = [] 286 | conv_layer_1 = self.convLayer1(input) 287 | feature_list.append(conv_layer_1) 288 | #print("conv_layer_1: {}".format(conv_layer_1.shape)) 289 | 290 | downsample1 = self.downSample1(conv_layer_1) 291 | feature_list.append(downsample1) 292 | #output_1 = torch.sigmoid(self.outputConvLayer_1(downsample1)) 293 | #print("downsample1 {} output_1 {}".format(downsample1.shape, output_1.shape)) 294 | downsample2 = self.downSample2(downsample1) 295 | feature_list.append(downsample2) 296 | output_2 = torch.sigmoid(self.outputConvLayer_2(downsample2)) 297 | #print("downsample2 {} output_2 {}".format(downsample2.shape, output_2.shape)) 298 | downsample3 = self.downSample3(downsample2) 299 | feature_list.append(downsample3) 300 | output_3 = torch.sigmoid(self.outputConvLayer_3(downsample3)) 301 | #print("downsample3 {} output_3 {}".format(downsample3.shape, output_3.shape)) 302 | downsample4 = self.downSample4(downsample3) 303 | feature_list.append(downsample4) 304 | output_4 = torch.sigmoid(self.outputConvLayer_4(downsample4)) 305 | #print("downsample4 {} output_4 {}".format(downsample4.shape, output_4.shape)) 306 | downsample5 = self.downSample5(downsample4) 307 | feature_list.append(downsample5) 308 | #print("downsample5 {}".format(downsample5.shape)) 309 | 310 | output = torch.sigmoid(self.outputConvLayer(downsample5)) 311 | #print("output {} ".format(output.shape)) 312 | output = output.view(output.shape[0], output.shape[1], -1) 313 | output_4 = output_4.view(output.shape[0], output.shape[1], -1) 314 | output_3 = output_3.view(output.shape[0], output.shape[1], -1) 315 | output_2 = output_2.view(output.shape[0], output.shape[1], -1) 316 | #output_1 = output_1.view(output.shape[0], output.shape[1], -1) 317 | output = torch.cat((output,output_4,output_3, output_2), axis=2) 318 | #output = output.view(output.shape[0], output.shape[1], -1) 319 | return [output, feature_list] 320 | 321 | class MultibandFrequencyDiscriminator(nn.Module): 322 | def __init__(self, 323 | time_lengths=[200, 400, 600, 800, 1], 324 | freq_lengths=[ 60, 60, 60, 60, 60], 325 | multi_channels=[[87, 87, 87, 87, 87 ], [87, 87, 87, 87, 87], [87,87, 87, 87, 87]] 326 | ): 
327 | super(MultibandFrequencyDiscriminator, self).__init__() 328 | 329 | self.time_lengths = time_lengths 330 | self.multi_win_discriminator_low = MultiWindowDiscriminator(time_lengths, freq_lengths, multi_channels[0]) 331 | self.multi_win_discriminator_middle = MultiWindowDiscriminator(time_lengths, freq_lengths, multi_channels[1]) 332 | self.multi_win_discriminator_high = MultiWindowDiscriminator(time_lengths, freq_lengths, multi_channels[2]) 333 | 334 | def forward(self, x, x_len, y=None, random_N=[]): 335 | ''' 336 | Args: 337 | x (tensor): input mel, (B, T, C). 338 | x_length (tensor): len of per mel. (B,). 339 | 340 | Returns: 341 | list : [(B), (B,), (B,)]. 342 | ''' 343 | if len(random_N) == 0: 344 | len_min = min(x_len.cpu()) 345 | time_max = max(self.time_lengths) 346 | start = 0 347 | end = len_min - time_max 348 | if end <= 0: 349 | end = int(len_min / 2) 350 | random_N = np.random.randint(start, end , len(self.time_lengths)) 351 | 352 | #print(x_len) 353 | base_mel = x[:,:,:120] 354 | xa = base_mel[:,:,:60] 355 | xb = base_mel[:,:,30:90] 356 | xc = base_mel[:,:,60:120] 357 | if y != None: 358 | y_mel = y[:, :, :120] 359 | ya = y_mel[:,:,:60] 360 | yb = y_mel[:,:, 30:90] 361 | yc = y_mel[:,:,60:120] 362 | else: 363 | ya = yb = yc = None 364 | 365 | 366 | x_list = [ 367 | self.multi_win_discriminator_low(xa, x_len, ya, random_N), 368 | self.multi_win_discriminator_middle(xb, x_len,yb, random_N), 369 | self.multi_win_discriminator_high(xc, x_len, yc, random_N), 370 | ] 371 | return x_list, random_N 372 | 373 | class Discriminator(nn.Module): 374 | def __init__(self): 375 | super(Discriminator, self).__init__() 376 | 377 | self.discriminator = MultibandFrequencyDiscriminator() 378 | 379 | def forward(self, x, x_len, y=None, random_N=[]): 380 | return self.discriminator(x, x_len, y, random_N) 381 | 382 | 383 | if __name__ == "__main__": 384 | inputs = torch.randn(4, 1200, 120).cuda() 385 | tgt = torch.randn(4, 1200, 120).cuda() 386 | inputs_len = torch.tensor([1200]).cuda() 387 | net = Discriminator().cuda() 388 | print(net) 389 | outputs, random_n = net(inputs, inputs_len, tgt) 390 | # import pdb; pdb.set_trace() 391 | for output in outputs: 392 | for a in output[0]: 393 | for aa in a: 394 | for aaa in aa: 395 | print(aaa.shape) 396 | for b in output[1]: 397 | for bb in b: 398 | for bbb in bb: 399 | print(bbb.shape) 400 | # print(output[0].shape) 401 | # print(output[1].shape) 402 | -------------------------------------------------------------------------------- /models/fastspeech2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from loss import FastSpeech2Loss 5 | 6 | from modules.transformer import Encoder, Decoder, PostNet 7 | from modules.variance.modules import VarianceAdaptor 8 | from pyutils import get_mask_from_lengths 9 | 10 | class FastSpeech2(nn.Module): 11 | """ FastSpeech2 """ 12 | 13 | def __init__(self, data_config, model_config): 14 | super(FastSpeech2, self).__init__() 15 | self.model_config = model_config 16 | 17 | self.encoder = Encoder(**model_config['transformer']['encoder']) 18 | self.variance_adaptor = VarianceAdaptor(data_config, model_config) 19 | self.decoder = Decoder(**model_config['transformer']['decoder']) 20 | self.mel_linear = nn.Linear( 21 | model_config["transformer"]["decoder"]['d_word_vec'], 22 | data_config["n_mels"], 23 | ) 24 | self.postnet = PostNet(data_config['n_mels'], **model_config['postnet']) 25 | 26 | self.speaker_emb = None 27 
| if model_config["multi_speaker"]: 28 | n_speaker = model_config['spk_num'] 29 | self.speaker_emb = nn.Embedding( 30 | n_speaker, 31 | model_config["transformer"]["encoder"]["d_word_vec"], 32 | ) 33 | self.loss = FastSpeech2Loss(data_config, model_config) 34 | 35 | def forward( 36 | self, 37 | spks, 38 | texts, 39 | src_lens, 40 | max_src_len, 41 | mels=None, 42 | mel_lens=None, 43 | max_mel_len=None, 44 | p_targets=None, 45 | e_targets=None, 46 | d_targets=None, 47 | p_control=1.0, 48 | e_control=1.0, 49 | d_control=1.0, 50 | ): 51 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 52 | mel_masks = ( 53 | get_mask_from_lengths(mel_lens, max_mel_len) 54 | if mel_lens is not None 55 | else None 56 | ) 57 | output = self.encoder(texts, src_masks) 58 | 59 | if self.speaker_emb is not None: 60 | output = output + self.speaker_emb(spks).unsqueeze(1).expand( 61 | -1, max_src_len, -1 62 | ) 63 | 64 | ( 65 | output, 66 | p_predictions, 67 | e_predictions, 68 | log_d_predictions, 69 | d_rounded, 70 | mel_lens, 71 | mel_masks, 72 | ) = self.variance_adaptor( 73 | output, 74 | src_masks, 75 | mel_masks, 76 | max_mel_len, 77 | p_targets, 78 | e_targets, 79 | d_targets, 80 | p_control, 81 | e_control, 82 | d_control, 83 | ) 84 | 85 | output, mel_masks = self.decoder(output, mel_masks) 86 | output = self.mel_linear(output) 87 | 88 | postnet_output = self.postnet(output) + output 89 | 90 | outputs = (output, postnet_output, p_predictions, e_predictions, log_d_predictions, d_rounded, src_masks, mel_masks, src_lens, mel_lens) 91 | (total_loss, mel_loss, post_mel_loss, pitch_loss, energy_loss, duration_loss) = self.loss((mels, mel_lens, max_mel_len, p_targets, e_targets, d_targets), outputs) 92 | report_keys = { 93 | 'loss': total_loss, 94 | 'mel_loss': mel_loss, 95 | 'post_mel_loss': post_mel_loss, 96 | 'pitch_loss': pitch_loss, 97 | 'energy_loss': energy_loss, 98 | 'duration_loss': duration_loss 99 | } 100 | return total_loss, report_keys, output, postnet_output 101 | 102 | if __name__ == "__main__": 103 | import yaml 104 | with open('/home/zengchang/code/acoustic_v2/configs/data.yaml', 'r') as f: 105 | data_config = yaml.load(f, Loader = yaml.FullLoader) 106 | with open('/home/zengchang/code/acoustic_v2/configs/model.yaml', 'r') as f: 107 | model_config = yaml.load(f, Loader = yaml.FullLoader) 108 | model = FastSpeech2(data_config, model_config['generator']) 109 | print(model) 110 | -------------------------------------------------------------------------------- /models/xiaoice2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from loss import FastSpeech2Loss 5 | 6 | from modules.transformer import Encoder, Decoder, PostNet 7 | from modules.variance.modules import VarianceAdaptor 8 | from pyutils import get_mask_from_lengths 9 | 10 | class Xiaoice2(nn.Module): 11 | """ Xiaoice2 """ 12 | 13 | def __init__(self, data_config, model_config): 14 | super(Xiaoice2, self).__init__() 15 | self.model_config = model_config 16 | 17 | self.encoder = Encoder(**model_config['transformer']['encoder']) 18 | self.variance_adaptor = VarianceAdaptor(data_config, model_config) 19 | self.decoder = Decoder(**model_config['transformer']['decoder']) 20 | self.mel_linear = nn.Linear( 21 | model_config["transformer"]["decoder"]['d_word_vec'], 22 | data_config["n_mels"], 23 | ) 24 | self.postnet = PostNet(data_config['n_mels'], **model_config['postnet']) 25 | 26 | self.speaker_emb = None 27 | if 
model_config["multi_speaker"]: 28 | n_speaker = model_config['spk_num'] 29 | self.speaker_emb = nn.Embedding( 30 | n_speaker, 31 | model_config["transformer"]["encoder"]["d_word_vec"], 32 | ) 33 | self.loss = FastSpeech2Loss(data_config) 34 | 35 | def forward( 36 | self, 37 | texts, 38 | note_pitchs, 39 | note_durations, 40 | src_lens, 41 | max_src_len, 42 | mels=None, 43 | mel_lens=None, 44 | max_mel_len=None, 45 | p_targets=None, 46 | e_targets=None, 47 | uv_targets=None, 48 | d_targets=None, 49 | p_control=1.0, 50 | e_control=1.0, 51 | d_control=1.0, 52 | spks=None 53 | ): 54 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 55 | mel_masks = ( 56 | get_mask_from_lengths(mel_lens, max_mel_len) 57 | if mel_lens is not None 58 | else None 59 | ) 60 | output = self.encoder(texts, note_pitchs, note_durations, src_masks) 61 | 62 | if self.speaker_emb is not None: 63 | output = output + self.speaker_emb(spks).unsqueeze(1).expand( 64 | -1, max_src_len, -1 65 | ) 66 | 67 | ( 68 | output, 69 | p_predictions, 70 | e_predictions, 71 | uv_predictions, 72 | log_d_predictions, 73 | d_rounded, 74 | mel_lens, 75 | mel_masks, 76 | ) = self.variance_adaptor( 77 | output, 78 | src_masks, 79 | mel_masks, 80 | max_mel_len, 81 | p_targets, 82 | e_targets, 83 | uv_targets, 84 | d_targets, 85 | p_control, 86 | e_control, 87 | d_control, 88 | ) 89 | 90 | output, mel_masks = self.decoder(output, mel_masks) 91 | output = self.mel_linear(output) 92 | 93 | postnet_output = self.postnet(output) + output 94 | 95 | outputs = (output, postnet_output, p_predictions, e_predictions, uv_predictions, log_d_predictions, d_rounded, src_masks, mel_masks, src_lens, mel_lens) 96 | 97 | (total_loss, mel_loss, post_mel_loss, pitch_loss, energy_loss, uv_loss, duration_loss) = self.loss((mels, mel_lens, max_mel_len, p_targets, e_targets, uv_targets, d_targets), outputs) 98 | 99 | report_keys = { 100 | 'loss': total_loss, 101 | 'mel_loss': mel_loss, 102 | 'post_mel_loss': post_mel_loss, 103 | 'pitch_loss': pitch_loss, 104 | 'energy_loss': energy_loss, 105 | 'uv_loss': uv_loss, 106 | 'duration_loss': duration_loss 107 | } 108 | 109 | return total_loss, report_keys, output, postnet_output 110 | 111 | if __name__ == "__main__": 112 | import yaml 113 | with open('/home/zengchang/code/acoustic_v2/configs/data.yaml', 'r') as f: 114 | data_config = yaml.load(f, Loader = yaml.FullLoader) 115 | with open('/home/zengchang/code/acoustic_v2/configs/model.yaml', 'r') as f: 116 | model_config = yaml.load(f, Loader = yaml.FullLoader) 117 | model = Xiaoice2(data_config, model_config['generator']) 118 | print(model) 119 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer.Models import Encoder, Decoder 2 | from .transformer.Layers import PostNet -------------------------------------------------------------------------------- /modules/conv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/modules/conv/__init__.py -------------------------------------------------------------------------------- /modules/transformer/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = "" 7 | UNK_WORD = "" 8 | BOS_WORD = "" 9 | EOS_WORD = "" 10 | 
-------------------------------------------------------------------------------- /modules/transformer/Layers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.nn import functional as F 7 | 8 | try: 9 | from modules_v2.transformer.sublayer import MultiHeadAttention, PositionwiseFeedForward, MultiLayeredConv1d 10 | from modules_v2.transformer.embedding import PositionalEncoding 11 | from modules_v2.transformer.layer import EncoderLayer 12 | except (ImportError, ModuleNotFoundError): 13 | import sys 14 | import os 15 | filepath = os.path.dirname(os.path.abspath(__file__)) 16 | sys.path.insert(0, filepath) 17 | from SubLayers import MultiHeadAttention, PositionwiseFeedForward 18 | 19 | class FFTBlock(torch.nn.Module): 20 | """FFT Block""" 21 | 22 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 23 | super(FFTBlock, self).__init__() 24 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 25 | self.pos_ffn = PositionwiseFeedForward( 26 | d_model, d_inner, kernel_size, dropout=dropout 27 | ) 28 | 29 | def forward(self, enc_input, mask=None, slf_attn_mask=None): 30 | enc_output, enc_slf_attn = self.slf_attn( 31 | enc_input, enc_input, enc_input, mask=slf_attn_mask 32 | ) 33 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 34 | 35 | enc_output = self.pos_ffn(enc_output) 36 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 37 | 38 | return enc_output, enc_slf_attn 39 | 40 | 41 | class ConvNorm(torch.nn.Module): 42 | def __init__( 43 | self, 44 | in_channels, 45 | out_channels, 46 | kernel_size=1, 47 | stride=1, 48 | padding=None, 49 | dilation=1, 50 | bias=True, 51 | w_init_gain="linear", 52 | ): 53 | super(ConvNorm, self).__init__() 54 | 55 | if padding is None: 56 | assert kernel_size % 2 == 1 57 | padding = int(dilation * (kernel_size - 1) / 2) 58 | 59 | self.conv = torch.nn.Conv1d( 60 | in_channels, 61 | out_channels, 62 | kernel_size=kernel_size, 63 | stride=stride, 64 | padding=padding, 65 | dilation=dilation, 66 | bias=bias, 67 | ) 68 | 69 | def forward(self, signal): 70 | conv_signal = self.conv(signal) 71 | 72 | return conv_signal 73 | 74 | 75 | class PostNet(nn.Module): 76 | """ 77 | PostNet: Five 1-d convolution with 512 channels and kernel size 5 78 | """ 79 | 80 | def __init__( 81 | self, 82 | n_mels=80, 83 | postnet_embedding_dim=512, 84 | postnet_kernel_size=5, 85 | postnet_n_convolutions=5, 86 | ): 87 | 88 | super(PostNet, self).__init__() 89 | self.convolutions = nn.ModuleList() 90 | 91 | self.convolutions.append( 92 | nn.Sequential( 93 | ConvNorm( 94 | n_mels, 95 | postnet_embedding_dim, 96 | kernel_size=postnet_kernel_size, 97 | stride=1, 98 | padding=int((postnet_kernel_size - 1) / 2), 99 | dilation=1, 100 | w_init_gain="tanh", 101 | ), 102 | nn.BatchNorm1d(postnet_embedding_dim), 103 | ) 104 | ) 105 | 106 | for i in range(1, postnet_n_convolutions - 1): 107 | self.convolutions.append( 108 | nn.Sequential( 109 | ConvNorm( 110 | postnet_embedding_dim, 111 | postnet_embedding_dim, 112 | kernel_size=postnet_kernel_size, 113 | stride=1, 114 | padding=int((postnet_kernel_size - 1) / 2), 115 | dilation=1, 116 | w_init_gain="tanh", 117 | ), 118 | nn.BatchNorm1d(postnet_embedding_dim), 119 | ) 120 | ) 121 | 122 | self.convolutions.append( 123 | nn.Sequential( 124 | ConvNorm( 125 | postnet_embedding_dim, 126 | n_mels, 127 | 
kernel_size=postnet_kernel_size, 128 | stride=1, 129 | padding=int((postnet_kernel_size - 1) / 2), 130 | dilation=1, 131 | w_init_gain="linear", 132 | ), 133 | nn.BatchNorm1d(n_mels), 134 | ) 135 | ) 136 | 137 | def forward(self, x): 138 | x = x.contiguous().transpose(1, 2) 139 | 140 | for i in range(len(self.convolutions) - 1): 141 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 142 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 143 | 144 | x = x.contiguous().transpose(1, 2) 145 | return x 146 | 147 | if __name__ == "__main__": 148 | import sys 149 | sys.path.insert(0, '/home/zengchang/code/acoustic_v2/modules_v2/transformer') 150 | fft_block = FFTBlock(512, 8, 64, 64, 2048, [3,3]) 151 | x = torch.randn(2, 100, 512) 152 | mask = torch.ones(2, 100).bool() 153 | slf_attn_mask = torch.ones(2, 100, 100).bool() 154 | y, attn = fft_block(x, mask, slf_attn_mask) 155 | print(y.shape) 156 | print(attn.shape) -------------------------------------------------------------------------------- /modules/transformer/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | try: 6 | import modules.transformer.Constants as Constants 7 | from modules.transformer.Layers import FFTBlock 8 | from dataset.texts.symbols import symbols 9 | except: 10 | import sys 11 | import os 12 | filepath = '/'.join(os.path.dirname(os.path.abspath(__file__)).split('/')[:-2]) 13 | print(filepath) 14 | sys.path.insert(0, filepath) 15 | import modules.transformer.Constants as Constants 16 | from modules.transformer.Layers import FFTBlock 17 | from dataset.texts.symbols import symbols 18 | 19 | 20 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 21 | """ Sinusoid position encoding table """ 22 | 23 | def cal_angle(position, hid_idx): 24 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 25 | 26 | def get_posi_angle_vec(position): 27 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 28 | 29 | sinusoid_table = np.array( 30 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)] 31 | ) 32 | 33 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 34 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 35 | 36 | if padding_idx is not None: 37 | # zero vector for padding dimension 38 | sinusoid_table[padding_idx] = 0.0 39 | 40 | return torch.FloatTensor(sinusoid_table) 41 | 42 | 43 | class Encoder(nn.Module): 44 | """ Encoder """ 45 | 46 | def __init__( 47 | self, max_seq_len, n_src_vocab, d_word_vec, 48 | n_layers, n_head, d_model, d_inner, max_note_pitch, 49 | max_note_duration, kernel_size, dropout=0.1 50 | ): 51 | super(Encoder, self).__init__() 52 | 53 | n_position = max_seq_len + 1 54 | n_src_vocab = n_src_vocab 55 | d_word_vec = d_word_vec 56 | n_layers = n_layers 57 | n_head = n_head 58 | d_k = d_v = (d_word_vec // n_head) 59 | d_model = d_model 60 | d_inner = d_inner 61 | kernel_size = kernel_size 62 | dropout = dropout 63 | 64 | # self.max_seq_len = config["max_seq_len"] 65 | self.max_seq_len = max_seq_len 66 | self.d_model = d_model 67 | 68 | self.src_word_emb = nn.Embedding( 69 | n_src_vocab, d_word_vec, padding_idx = Constants.PAD 70 | ) 71 | self.note_pitch_emb = nn.Embedding( 72 | max_note_pitch, d_word_vec, padding_idx = Constants.PAD 73 | ) 74 | self.note_duration_emb = nn.Embedding( 75 | max_note_duration, d_word_vec, padding_idx = Constants.PAD 76 | ) 77 | self.position_enc = 
nn.Parameter( 78 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 79 | requires_grad=False, 80 | ) 81 | 82 | self.layer_stack = nn.ModuleList( 83 | [ 84 | FFTBlock( 85 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 86 | ) 87 | for _ in range(n_layers) 88 | ] 89 | ) 90 | 91 | def forward(self, src_seq, note_pitchs, note_durations, mask, return_attns=False): 92 | 93 | enc_slf_attn_list = [] 94 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 95 | 96 | # -- Prepare masks 97 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 98 | 99 | # -- Forward 100 | if not self.training and src_seq.shape[1] > self.max_seq_len: 101 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table( 102 | src_seq.shape[1], self.d_model 103 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 104 | src_seq.device 105 | ) 106 | else: 107 | # print("training!!!!!!!!!!") 108 | enc_output = self.src_word_emb(src_seq) \ 109 | + self.note_pitch_emb(note_pitchs) \ 110 | + self.note_duration_emb(note_durations) \ 111 | + self.position_enc[:, :max_len, :].expand(batch_size, -1, -1) 112 | 113 | for enc_layer in self.layer_stack: 114 | enc_output, enc_slf_attn = enc_layer( 115 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask 116 | ) 117 | if return_attns: 118 | enc_slf_attn_list += [enc_slf_attn] 119 | 120 | return enc_output 121 | 122 | 123 | class Decoder(nn.Module): 124 | """ Decoder """ 125 | 126 | def __init__( 127 | self, max_seq_len, d_word_vec, 128 | n_layers, n_head, d_model, d_inner, 129 | kernel_size, dropout=0.1 130 | ): 131 | super(Decoder, self).__init__() 132 | 133 | n_position = max_seq_len + 1 134 | d_word_vec = d_word_vec 135 | n_layers = n_layers 136 | n_head = n_head 137 | d_k = d_v = (d_word_vec // n_head) 138 | d_model = d_model 139 | d_inner = d_inner 140 | kernel_size = kernel_size 141 | dropout = dropout 142 | 143 | # self.max_seq_len = config["max_seq_len"] 144 | self.max_seq_len = max_seq_len 145 | self.d_model = d_model 146 | 147 | self.position_enc = nn.Parameter( 148 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 149 | requires_grad=False, 150 | ) 151 | 152 | self.layer_stack = nn.ModuleList( 153 | [ 154 | FFTBlock( 155 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 156 | ) 157 | for _ in range(n_layers) 158 | ] 159 | ) 160 | 161 | def forward(self, enc_seq, mask, return_attns=False): 162 | 163 | dec_slf_attn_list = [] 164 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 165 | 166 | # -- Forward 167 | if not self.training and enc_seq.shape[1] > self.max_seq_len: 168 | # -- Prepare masks 169 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 170 | dec_output = enc_seq + get_sinusoid_encoding_table( 171 | enc_seq.shape[1], self.d_model 172 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 173 | enc_seq.device 174 | ) 175 | else: 176 | max_len = min(max_len, self.max_seq_len) 177 | 178 | # -- Prepare masks 179 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 180 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[ 181 | :, :max_len, : 182 | ].expand(batch_size, -1, -1) 183 | mask = mask[:, :max_len] 184 | slf_attn_mask = slf_attn_mask[:, :, :max_len] 185 | 186 | for dec_layer in self.layer_stack: 187 | dec_output, dec_slf_attn = dec_layer( 188 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask 189 | ) 190 | if return_attns: 191 | dec_slf_attn_list += [dec_slf_attn] 192 | 193 | return dec_output, mask 194 | 195 
| if __name__ == "__main__": 196 | encoder = Encoder({"max_seq_len": 100, "transformer": {"encoder_hidden": 256, "encoder_layer": 6, "encoder_head": 4, "conv_filter_size": 1024, "conv_kernel_size": [3,3], "encoder_dropout": 0.1}}) 197 | decoder = Decoder({"max_seq_len": 100, "transformer": {"decoder_hidden": 256, "decoder_layer": 6, "decoder_head": 4, "conv_filter_size": 1024, "conv_kernel_size": [3,3], "decoder_dropout": 0.1}}) 198 | src_seq = torch.randint(0, 100, (2, 100)) 199 | mask = torch.ones((2, 100)).bool() 200 | enc_output = encoder(src_seq, mask) 201 | dec_output, mask = decoder(enc_output, mask) 202 | print(dec_output.shape, mask.shape) 203 | -------------------------------------------------------------------------------- /modules/transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class ScaledDotProductAttention(nn.Module): 7 | """ Scaled Dot-Product Attention """ 8 | 9 | def __init__(self, temperature): 10 | super().__init__() 11 | self.temperature = temperature 12 | self.softmax = nn.Softmax(dim=2) 13 | 14 | def forward(self, q, k, v, mask=None): 15 | 16 | attn = torch.bmm(q, k.transpose(1, 2)) 17 | attn = attn / self.temperature 18 | 19 | if mask is not None: 20 | attn = attn.masked_fill(mask, -np.inf) 21 | 22 | attn = self.softmax(attn) 23 | output = torch.bmm(attn, v) 24 | 25 | return output, attn 26 | -------------------------------------------------------------------------------- /modules/transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from Modules import ScaledDotProductAttention 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | """ Multi-Head Attention module """ 10 | 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 12 | super().__init__() 13 | 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v) 21 | 22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 23 | self.layer_norm = nn.LayerNorm(d_model) 24 | 25 | self.fc = nn.Linear(n_head * d_v, d_model) 26 | 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, q, k, v, mask=None): 30 | 31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 32 | 33 | sz_b, len_q, _ = q.size() 34 | sz_b, len_k, _ = k.size() 35 | sz_b, len_v, _ = v.size() 36 | 37 | residual = q 38 | 39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 45 | 46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
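        # After the reshapes above, q/k/v are (n_head * B, len, d_k or d_v): the head axis has
        # been folded into the batch axis so one batched matmul attends over all heads at once.
        # Repeating the mask n_head times along dim 0 keeps it aligned with that folded layout;
        # the permute/view below then unfolds the heads and concatenates them back to
        # (B, len_q, n_head * d_v) before the output projection.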
47 | output, attn = self.attention(q, k, v, mask=mask) 48 | 49 | output = output.view(n_head, sz_b, len_q, d_v) 50 | output = ( 51 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) 52 | ) # b x lq x (n*dv) 53 | 54 | output = self.dropout(self.fc(output)) 55 | output = self.layer_norm(output + residual) 56 | 57 | return output, attn 58 | 59 | 60 | class PositionwiseFeedForward(nn.Module): 61 | """ A two-feed-forward-layer module """ 62 | 63 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1): 64 | super().__init__() 65 | 66 | # Use Conv1D 67 | # position-wise 68 | self.w_1 = nn.Conv1d( 69 | d_in, 70 | d_hid, 71 | kernel_size=kernel_size[0], 72 | padding=(kernel_size[0] - 1) // 2, 73 | ) 74 | # position-wise 75 | self.w_2 = nn.Conv1d( 76 | d_hid, 77 | d_in, 78 | kernel_size=kernel_size[1], 79 | padding=(kernel_size[1] - 1) // 2, 80 | ) 81 | 82 | self.layer_norm = nn.LayerNorm(d_in) 83 | self.dropout = nn.Dropout(dropout) 84 | 85 | def forward(self, x): 86 | residual = x 87 | output = x.transpose(1, 2) 88 | output = self.w_2(F.relu(self.w_1(output))) 89 | output = output.transpose(1, 2) 90 | output = self.dropout(output) 91 | output = self.layer_norm(output + residual) 92 | 93 | return output 94 | -------------------------------------------------------------------------------- /modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Constants import * 2 | from .Layers import * 3 | from .SubLayers import * 4 | from .Models import * 5 | from .Modules import * -------------------------------------------------------------------------------- /modules/variance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/modules/variance/__init__.py -------------------------------------------------------------------------------- /modules/variance/modules.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from pyutils import get_mask_from_lengths, pad 8 | 9 | def f02pitch(f0): 10 | #f0 =f0 + 0.01 11 | return np.log2(f0 / 27.5) * 12 + 21 12 | 13 | class VarianceAdaptor(nn.Module): 14 | """Variance Adaptor""" 15 | 16 | def __init__(self, data_config, model_config): 17 | super(VarianceAdaptor, self).__init__() 18 | self.duration_predictor = VariancePredictor(**model_config['variance_predictor']) 19 | self.length_regulator = LengthRegulator() 20 | self.pitch_predictor = VariancePredictor(**model_config['variance_predictor']) 21 | self.uv_predictor = VariancePredictor(**model_config['variance_predictor']) 22 | self.energy_predictor = VariancePredictor(**model_config['variance_predictor']) 23 | 24 | self.uv_threshold = model_config['uv_threshold'] 25 | 26 | self.pitch_feature_level = data_config["pitch"]["feature"] 27 | self.energy_feature_level = data_config["energy"]["feature"] 28 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"] 29 | assert self.energy_feature_level in ["phoneme_level", "frame_level"] 30 | 31 | pitch_quantization = model_config["variance_embedding"]["pitch_quantization"] 32 | energy_quantization = model_config["variance_embedding"]["energy_quantization"] 33 | n_bins = model_config["variance_embedding"]["n_bins"] 34 | assert pitch_quantization in ["linear", "log"] 35 | assert energy_quantization in 
["linear", "log"] 36 | 37 | pitch_min_max = f02pitch(np.load(data_config['f0_min_max'])) 38 | pitch_min, pitch_max = pitch_min_max[0][0], pitch_min_max[0][1] 39 | # print(np.load(data_config['energy_min_max'])) 40 | 41 | energy_min_max = np.load(data_config['energy_min_max']) 42 | energy_min, energy_max = energy_min_max[0][0] + 1e-4, energy_min_max[0][1] 43 | 44 | if pitch_quantization == "log": 45 | self.pitch_bins = nn.Parameter( 46 | torch.exp( 47 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1) 48 | ), 49 | requires_grad=False, 50 | ) 51 | else: 52 | self.pitch_bins = nn.Parameter( 53 | torch.linspace(pitch_min, pitch_max, n_bins - 1), 54 | requires_grad=False, 55 | ) 56 | if energy_quantization == "log": 57 | self.energy_bins = nn.Parameter( 58 | torch.exp( 59 | torch.linspace(np.log(energy_min + 1e-6), np.log(energy_max), n_bins - 1) 60 | ), 61 | requires_grad=False, 62 | ) 63 | else: 64 | self.energy_bins = nn.Parameter( 65 | torch.linspace(energy_min, energy_max, n_bins - 1), 66 | requires_grad=False, 67 | ) 68 | 69 | self.pitch_embedding = nn.Embedding( 70 | n_bins, model_config["transformer"]["encoder"]["d_word_vec"] 71 | ) 72 | self.energy_embedding = nn.Embedding( 73 | n_bins, model_config["transformer"]["encoder"]["d_word_vec"] 74 | ) 75 | self.uv_embedding = nn.Embedding( 76 | 2, model_config['transformer']['encoder']['d_word_vec'] 77 | ) 78 | 79 | def get_uv_embedding(self, x, target, mask, control=1.0): 80 | prediction = self.uv_predictor(x, mask) 81 | if target is not None: 82 | embedding = self.uv_embedding(target.to(torch.int64)) 83 | else: 84 | prediction = prediction * control 85 | prediction = torch.sigmoid(prediction) 86 | for i in range(prediction.shape[0]): 87 | prediction[i] = prediction[i] >= self.uv_threshold # (B, max_frames, 1) 88 | 89 | embedding = self.uv_embedding(prediction.long()) 90 | 91 | return prediction, embedding 92 | 93 | def get_pitch_embedding(self, x, target, mask, control): 94 | prediction = self.pitch_predictor(x, mask) 95 | if target is not None: 96 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins)) 97 | else: 98 | prediction = prediction * control 99 | embedding = self.pitch_embedding( 100 | torch.bucketize(prediction, self.pitch_bins) 101 | ) 102 | return prediction, embedding 103 | 104 | def get_energy_embedding(self, x, target, mask, control): 105 | prediction = self.energy_predictor(x, mask) 106 | if target is not None: 107 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins)) 108 | else: 109 | prediction = prediction * control 110 | embedding = self.energy_embedding( 111 | torch.bucketize(prediction, self.energy_bins) 112 | ) 113 | return prediction, embedding 114 | 115 | def forward( 116 | self, 117 | x, 118 | src_mask, 119 | mel_mask=None, 120 | max_len=None, 121 | pitch_target=None, 122 | energy_target=None, 123 | uv_target=None, 124 | duration_target=None, 125 | p_control=1.0, 126 | e_control=1.0, 127 | d_control=1.0, 128 | uv_control=1.0 129 | ): 130 | 131 | log_duration_prediction = self.duration_predictor(x, src_mask) 132 | if self.pitch_feature_level == "phoneme_level": 133 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 134 | x, pitch_target, src_mask, p_control 135 | ) 136 | x = x + pitch_embedding 137 | if self.energy_feature_level == "phoneme_level": 138 | energy_prediction, energy_embedding = self.get_energy_embedding( 139 | x, energy_target, src_mask, p_control 140 | ) 141 | x = x + energy_embedding 142 | 143 | if duration_target is not 
None: 144 | x, mel_len = self.length_regulator(x, duration_target, max_len) 145 | duration_rounded = duration_target 146 | else: 147 | duration_rounded = torch.clamp( 148 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control), 149 | min=0, 150 | ) 151 | x, mel_len = self.length_regulator(x, duration_rounded, max_len) 152 | mel_mask = get_mask_from_lengths(mel_len) 153 | 154 | if self.pitch_feature_level == "frame_level": 155 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 156 | x, pitch_target, mel_mask, p_control 157 | ) 158 | x = x + pitch_embedding 159 | if self.energy_feature_level == "frame_level": 160 | energy_prediction, energy_embedding = self.get_energy_embedding( 161 | x, energy_target, mel_mask, p_control 162 | ) 163 | x = x + energy_embedding 164 | 165 | uv_prediction, uv_embedding = self.get_uv_embedding( 166 | x, uv_target, mel_mask, uv_control 167 | ) 168 | x = x + uv_embedding 169 | 170 | return ( 171 | x, 172 | pitch_prediction, 173 | energy_prediction, 174 | uv_prediction, 175 | log_duration_prediction, 176 | duration_rounded, 177 | mel_len, 178 | mel_mask, 179 | ) 180 | 181 | 182 | class LengthRegulator(nn.Module): 183 | """Length Regulator""" 184 | 185 | def __init__(self): 186 | super(LengthRegulator, self).__init__() 187 | 188 | def LR(self, x, duration, max_len): 189 | device = x.device 190 | output = list() 191 | mel_len = list() 192 | for batch, expand_target in zip(x, duration): 193 | expanded = self.expand(batch, expand_target) 194 | output.append(expanded) 195 | mel_len.append(expanded.shape[0]) 196 | 197 | if max_len is not None: 198 | output = pad(output, max_len) 199 | else: 200 | output = pad(output) 201 | 202 | return output, torch.LongTensor(mel_len).to(device) 203 | 204 | def expand(self, batch, predicted): 205 | out = list() 206 | 207 | for i, vec in enumerate(batch): 208 | expand_size = predicted[i].item() 209 | out.append(vec.expand(max(int(expand_size), 0), -1)) 210 | out = torch.cat(out, 0) 211 | 212 | return out 213 | 214 | def forward(self, x, duration, max_len): 215 | output, mel_len = self.LR(x, duration, max_len) 216 | return output, mel_len 217 | 218 | 219 | class VariancePredictor(nn.Module): 220 | """Duration, Pitch and Energy Predictor""" 221 | 222 | def __init__( 223 | self, input_size, filter_size, 224 | kernel_size, dropout 225 | ): 226 | super(VariancePredictor, self).__init__() 227 | 228 | # self.input_size = model_config["transformer"]["encoder_hidden"] 229 | # self.filter_size = model_config["variance_predictor"]["filter_size"] 230 | # self.kernel = model_config["variance_predictor"]["kernel_size"] 231 | # self.conv_output_size = model_config["variance_predictor"]["filter_size"] 232 | # self.dropout = model_config["variance_predictor"]["dropout"] 233 | self.input_size = input_size 234 | self.filter_size = filter_size 235 | self.kernel = kernel_size 236 | self.conv_output_size = filter_size 237 | self.dropout = dropout 238 | 239 | self.conv_layer = nn.Sequential( 240 | OrderedDict( 241 | [ 242 | ( 243 | "conv1d_1", 244 | Conv( 245 | self.input_size, 246 | self.filter_size, 247 | kernel_size=self.kernel, 248 | padding=(self.kernel - 1) // 2, 249 | ), 250 | ), 251 | ("relu_1", nn.ReLU()), 252 | ("layer_norm_1", nn.LayerNorm(self.filter_size)), 253 | ("dropout_1", nn.Dropout(self.dropout)), 254 | ( 255 | "conv1d_2", 256 | Conv( 257 | self.filter_size, 258 | self.filter_size, 259 | kernel_size=self.kernel, 260 | padding=1, 261 | ), 262 | ), 263 | ("relu_2", nn.ReLU()), 264 | ("layer_norm_2", 
nn.LayerNorm(self.filter_size)), 265 | ("dropout_2", nn.Dropout(self.dropout)), 266 | ] 267 | ) 268 | ) 269 | 270 | self.linear_layer = nn.Linear(self.conv_output_size, 1) 271 | 272 | def forward(self, encoder_output, mask): 273 | out = self.conv_layer(encoder_output) 274 | out = self.linear_layer(out) 275 | out = out.squeeze(-1) 276 | 277 | if mask is not None: 278 | out = out.masked_fill(mask, 0.0) 279 | 280 | return out 281 | 282 | 283 | class Conv(nn.Module): 284 | """ 285 | Convolution Module 286 | """ 287 | 288 | def __init__( 289 | self, 290 | in_channels, 291 | out_channels, 292 | kernel_size=1, 293 | stride=1, 294 | padding=0, 295 | dilation=1, 296 | bias=True, 297 | w_init="linear", 298 | ): 299 | """ 300 | :param in_channels: dimension of input 301 | :param out_channels: dimension of output 302 | :param kernel_size: size of kernel 303 | :param stride: size of stride 304 | :param padding: size of padding 305 | :param dilation: dilation rate 306 | :param bias: boolean. if True, bias is included. 307 | :param w_init: str. weight inits with xavier initialization. 308 | """ 309 | super(Conv, self).__init__() 310 | 311 | self.conv = nn.Conv1d( 312 | in_channels, 313 | out_channels, 314 | kernel_size=kernel_size, 315 | stride=stride, 316 | padding=padding, 317 | dilation=dilation, 318 | bias=bias, 319 | ) 320 | 321 | def forward(self, x): 322 | x = x.contiguous().transpose(1, 2) 323 | x = self.conv(x) 324 | x = x.contiguous().transpose(1, 2) 325 | 326 | return x 327 | 328 | if __name__ == "__main__": 329 | import yaml 330 | with open('/home/zengchang/code/acoustic_v2/configs/data.yaml', 'r') as f: 331 | data_config = yaml.load(f, Loader = yaml.FullLoader) 332 | with open('/home/zengchang/code/acoustic_v2/configs/model.yaml', 'r') as f: 333 | model_config = yaml.load(f, Loader = yaml.FullLoader) 334 | model = VarianceAdaptor(data_config, model_config['generator']) 335 | print(model) 336 | -------------------------------------------------------------------------------- /pics/2085003136_145600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/2085003136_145600.png -------------------------------------------------------------------------------- /pics/after_2085003136_145600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/after_2085003136_145600.png -------------------------------------------------------------------------------- /pics/before_2085003136_145600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/before_2085003136_145600.png -------------------------------------------------------------------------------- /pics/before_mel_l2_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/before_mel_l2_loss.png -------------------------------------------------------------------------------- /pics/post_mel_l2_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/post_mel_l2_loss.png 
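
Looking back at modules/variance/modules.py above, here is a minimal, hedged sanity-check sketch (not part of the repository) for VariancePredictor and the f02pitch helper; the hidden size, filter size, kernel size and dropout are illustrative placeholders, and the real values live in configs/svs/model.yaml.

import torch
import numpy as np
from modules.variance.modules import VariancePredictor, f02pitch

# Illustrative hyper-parameters; the actual ones come from configs/svs/model.yaml.
predictor = VariancePredictor(input_size=256, filter_size=256, kernel_size=3, dropout=0.5)

x = torch.randn(2, 37, 256)                   # (batch, phoneme/frame steps, encoder hidden)
mask = torch.zeros(2, 37, dtype=torch.bool)   # True would mark padded positions
out = predictor(x, mask)                      # (batch, steps); masked positions are forced to 0.0
print(out.shape)                              # torch.Size([2, 37])

# f02pitch maps Hz to a MIDI-style note number (A0 = 27.5 Hz -> note 21):
print(f02pitch(np.array([440.0])))            # [69.]  i.e. A4
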
-------------------------------------------------------------------------------- /pics/xs1_before_2085003136_145600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengchang233/xiaoicesing2/b2d0436ac73337123a7fe75a7be21ed385b48098/pics/xs1_before_2085003136_145600.png -------------------------------------------------------------------------------- /preprocess/audio_preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import librosa 5 | import pyworld 6 | import parselmouth 7 | import soundfile as sf 8 | import numpy as np 9 | import yaml 10 | from tqdm import tqdm 11 | from sklearn.preprocessing import StandardScaler 12 | from ipdb import set_trace 13 | from pyutils import f02pitch, pitch2f0, pitchxuv 14 | cwd=os.path.dirname(os.path.realpath(__file__)) 15 | sys.path.insert(0, cwd) 16 | 17 | def resample_wav(wav, src_sr, tgt_sr): 18 | return librosa.resample(wav, orig_sr=src_sr, target_sr=tgt_sr) 19 | 20 | def _resize_f0(x, target_len): 21 | source = np.array(x) 22 | source[source < 0.001] = np.nan 23 | target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)), source) 24 | res = np.nan_to_num(target) 25 | return res 26 | 27 | def compute_f0_dio(wav, p_len=None, sampling_rate=48000, hop_length=240): 28 | if p_len is None: 29 | p_len = wav.shape[0]//hop_length 30 | f0, t = pyworld.dio( 31 | wav.astype(np.double), 32 | fs=sampling_rate, 33 | f0_ceil=800, 34 | frame_period=1000 * hop_length / sampling_rate 35 | ) 36 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, sampling_rate) 37 | for index, pitch in enumerate(f0): 38 | f0[index] = round(pitch, 1) 39 | return _resize_f0(f0, p_len) 40 | 41 | def compute_f0_parselmouth(wav, p_len=None, sampling_rate=48000, hop_length=240): 42 | x = wav 43 | if p_len is None: 44 | p_len = x.shape[0]//hop_length 45 | else: 46 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 47 | time_step = hop_length / sampling_rate * 1000 48 | f0_min = 50 49 | f0_max = 1100 50 | f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac( 51 | time_step=time_step / 1000, voicing_threshold=0.6, 52 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] 53 | 54 | pad_size=(p_len - len(f0) + 1) // 2 55 | if(pad_size>0 or p_len - len(f0) - pad_size>0): 56 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') 57 | return f0 58 | 59 | def interpolate_f0(f0): 60 | data = np.reshape(f0, (f0.size, 1)) 61 | 62 | vuv_vector = np.zeros((data.size, 1), dtype=np.float32) 63 | vuv_vector[data > 0.0] = 1.0 64 | vuv_vector[data <= 0.0] = 0.0 65 | 66 | ip_data = data 67 | 68 | frame_number = data.size 69 | last_value = 0.0 70 | for i in range(frame_number): 71 | if data[i] <= 0.0: 72 | j = i + 1 73 | for j in range(i + 1, frame_number): 74 | if data[j] > 0.0: 75 | break 76 | if j < frame_number - 1: 77 | if last_value > 0.0: 78 | step = (data[j] - data[i - 1]) / float(j - i) 79 | for k in range(i, j): 80 | ip_data[k] = data[i - 1] + step * (k - i + 1) 81 | else: 82 | for k in range(i, j): 83 | ip_data[k] = data[j] 84 | else: 85 | for k in range(i, frame_number): 86 | ip_data[k] = last_value 87 | else: 88 | ip_data[i] = data[i] # this may not be necessary 89 | last_value = data[i] 90 | 91 | return ip_data[:,0], vuv_vector[:,0] 92 | 93 | def read_scp(scp_file): 94 | filelists = [] 95 | with open(scp_file, 'r') as f: 96 | for 
line in f: 97 | line = line.rstrip() 98 | filelists.append(line) 99 | return filelists 100 | 101 | def spec_normalize(feat): 102 | ''' 103 | params: 104 | feat: T, F 105 | ''' 106 | return (feat - feat.mean(axis = 0, keepdims = True)) / (feat.std(axis = 0, keepdims = True) + 2e-12) 107 | 108 | def pad_wav(wav, config): 109 | padded_wav = np.pad(wav, (int((config['n_fft']-config['hop_length'])/2), int((config['n_fft']-config['hop_length'])/2)), mode='reflect') 110 | return padded_wav 111 | 112 | def extract_spec_with_energy(wav, filepath, config, spec_scaler = None, energy_scaler = None): 113 | ''' 114 | (T, F/C) 115 | ''' 116 | wav = pad_wav(wav, config) 117 | stft = librosa.stft( 118 | wav, 119 | n_fft = config['n_fft'], 120 | hop_length = config['hop_length'], 121 | win_length = config['win_length'], 122 | window = 'hann', 123 | center = False, 124 | pad_mode = 'reflect' 125 | ) 126 | # set_trace() 127 | spec = np.abs(stft).transpose(1, 0) 128 | energy = np.sqrt((spec**2).sum(axis = 1)) 129 | energy = energy.reshape(-1, 1) 130 | if spec_scaler is not None: 131 | spec_scaler.partial_fit(spec) 132 | if energy_scaler is not None: 133 | energy_scaler.partial_fit(energy) 134 | suffix = filepath.split('.')[-1] 135 | spec_filepath = filepath.replace(f'.{suffix}', '.spec.npy') 136 | np.save(spec_filepath, spec) 137 | suffix = filepath.split('.')[-1] 138 | energy_filepath = filepath.replace(f'.{suffix}', '.en.npy') 139 | np.save(energy_filepath, energy) 140 | # return spec, energy 141 | 142 | def extract_mel(wav, filepath, config, mel_scaler): 143 | ''' 144 | log mel + spec normalization 145 | (T, F/C) 146 | ''' 147 | wav = pad_wav(wav, config) 148 | mel_spec = librosa.feature.melspectrogram( 149 | y = wav, 150 | sr = config['sampling_rate'], 151 | n_fft = config['n_fft'], 152 | hop_length = config['hop_length'], 153 | win_length = config['win_length'], 154 | window = 'hann', 155 | n_mels = config['n_mels'], 156 | fmin = config['fmin'], 157 | fmax = config['fmax'], 158 | center = False, 159 | pad_mode = 'reflect' 160 | ) 161 | log_mel_spec = np.log(mel_spec + 1e-9).transpose(1, 0) 162 | # normalized_log_mel_spec = spec_normalize(log_mel_spec) 163 | mel_scaler.partial_fit(log_mel_spec) 164 | suffix = filepath.split('.')[-1] 165 | mel_filepath = filepath.replace(f'.{suffix}', '.mel.npy') 166 | np.save(mel_filepath, log_mel_spec) 167 | 168 | def extract_f0(filepath, config, f0_scaler = None): 169 | ''' 170 | (T, 1) 171 | ''' 172 | wav, sr = sf.read(filepath) 173 | wav = resample_wav(wav, sr, config['sampling_rate']) 174 | sr = config['sampling_rate'] 175 | assert sr == config['sampling_rate'], "Sampling rate ({}) != {}, please fix it!".format(sr, config['sampling_rate']) 176 | # wav = pad_wav(wav, config) # don't padding for computing f0 177 | f0 = compute_f0_dio( 178 | wav, 179 | sampling_rate = config["sampling_rate"], 180 | hop_length = config["hop_length"] 181 | ) 182 | f0, uv = interpolate_f0(f0) 183 | f0 = f0.reshape(-1, 1) 184 | if f0_scaler is not None: 185 | f0_scaler.partial_fit(f0) 186 | suffix = filepath.split('.')[-1] 187 | f0_filepath = filepath.replace(f'.{suffix}', '.f0.npy') 188 | uv_filepath = filepath.replace(f'.{suffix}', '.uv.npy') 189 | np.save(f0_filepath, f0) 190 | np.save(uv_filepath, uv) 191 | 192 | def process_one_utterance_spec(filepath, config, spec_scaler, mel_scaler, energy_scaler = None): 193 | wav, sr = sf.read(filepath) 194 | wav = resample_wav(wav, sr, config['sampling_rate']) 195 | sr = config['sampling_rate'] 196 | assert sr == config['sampling_rate'], "Sampling 
rate ({}) != {}, please fix it!".format(sr, config['sampling_rate']) 197 | if args.spec: 198 | extract_spec_with_energy(wav, filepath, config, spec_scaler, energy_scaler) 199 | if args.mel: 200 | extract_mel(wav, filepath, config, mel_scaler) 201 | 202 | def normalize(filelists, mean, std, feature = 'f0'): 203 | ''' 204 | normalize spec/mel_spec 205 | unnormalize f0/energy 206 | ''' 207 | min_value = np.finfo(np.float64).max 208 | max_value = np.finfo(np.float64).min 209 | for filepath in filelists: 210 | suffix = filepath.split('.')[-1] 211 | filepath = filepath.replace(f'.{suffix}', f'.{feature}.npy') 212 | values = np.load(filepath) 213 | if feature in ['f0', 'en']: 214 | min_value = min(min_value, min(values)) 215 | max_value = max(max_value, max(values)) 216 | else: 217 | values = (np.load(filepath) - mean) / std 218 | np.save(filepath, values) 219 | return np.array([min_value, max_value]).reshape(1, -1) 220 | 221 | def parse_args(): 222 | parser = argparse.ArgumentParser() 223 | parser.add_argument("--data-config", dest = "data_config", type = str, default = "", help = "data config path") 224 | parser.add_argument("--spec", action = "store_true", help = "extract stft spec feature") 225 | parser.add_argument("--mel", action = "store_true", help = "extract mel feature") 226 | parser.add_argument("--f0", action = "store_true", help = "extract f0 and uv") 227 | parser.add_argument("--energy", action = "store_true", help = "extract energy") 228 | parser.add_argument("--stat", action = "store_true", help = "Count the statistical numbers (mean and std) for energy and f0") 229 | 230 | args = parser.parse_args() 231 | return args 232 | 233 | def main(args): 234 | with open(args.data_config, 'r') as f: 235 | data_config = yaml.load(f, Loader = yaml.FullLoader) 236 | filelists = [] 237 | with open(data_config['audio_manifest'], 'r') as f: 238 | for line in f: 239 | line = line.rstrip().split(' ')[-1] 240 | filelists.append(line) 241 | args.scp_file = data_config['audio_manifest'] 242 | 243 | spec_scaler = StandardScaler() 244 | mel_scaler = StandardScaler() 245 | f0_scaler = None 246 | energy_scaler = None 247 | if args.stat: 248 | f0_scaler = StandardScaler() 249 | energy_scaler = StandardScaler() 250 | 251 | print("Extracting features...") 252 | for filepath in tqdm(filelists): 253 | if args.spec or args.mel: 254 | try: 255 | process_one_utterance_spec(filepath, data_config, spec_scaler, mel_scaler, energy_scaler) 256 | except: 257 | print(filepath) 258 | if args.f0: 259 | try: 260 | extract_f0(filepath, data_config, f0_scaler) 261 | except: 262 | print(filepath) 263 | 264 | if args.stat: 265 | if args.spec: 266 | spec_mean = spec_scaler.mean_.reshape(1, -1) 267 | spec_std = spec_scaler.scale_.reshape(1, -1) 268 | np.save(os.path.join(os.path.dirname(args.scp_file), 'spec_mean.npy'), spec_mean) 269 | np.save(os.path.join(os.path.dirname(args.scp_file), 'spec_std.npy'), spec_std) 270 | normalize(filelists, spec_mean, spec_std, feature = 'spec') 271 | 272 | if args.mel: 273 | mel_mean = mel_scaler.mean_.reshape(1, -1) 274 | mel_std = mel_scaler.scale_.reshape(1, -1) 275 | np.save(os.path.join(os.path.dirname(args.scp_file), 'mel_mean.npy'), mel_mean) 276 | np.save(os.path.join(os.path.dirname(args.scp_file), 'mel_std.npy'), mel_std) 277 | normalize(filelists, mel_mean, mel_std, feature = 'mel') 278 | 279 | if args.f0: 280 | print("Calculating f0 stats...") 281 | f0_mean = f0_scaler.mean_.reshape(1, -1) 282 | f0_std = f0_scaler.scale_.reshape(1, -1) 283 | 
np.save(os.path.join(os.path.dirname(args.scp_file), 'f0_mean.npy'), f0_mean) 284 | np.save(os.path.join(os.path.dirname(args.scp_file), 'f0_std.npy'), f0_std) 285 | f0_min_max = normalize(filelists, f0_mean, f0_std, feature = 'f0') 286 | np.save(os.path.join(os.path.dirname(args.scp_file), 'f0_min_max.npy'), f0_min_max) 287 | 288 | if args.energy: 289 | print("Calculating energy stats...") 290 | energy_mean = energy_scaler.mean_.reshape(1, -1) 291 | energy_std = energy_scaler.scale_.reshape(1, -1) 292 | np.save(os.path.join(os.path.dirname(args.scp_file), 'energy_mean.npy'), energy_mean) 293 | np.save(os.path.join(os.path.dirname(args.scp_file), 'energy_std.npy'), energy_std) 294 | energy_min_max = normalize(filelists, energy_mean, energy_std, feature = 'en') 295 | np.save(os.path.join(os.path.dirname(args.scp_file), 'energy_min_max.npy'), energy_min_max) 296 | 297 | if __name__ == "__main__": 298 | args = parse_args() 299 | print(args) 300 | main(args) 301 | -------------------------------------------------------------------------------- /preprocess/data_prep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import librosa 3 | import numpy as np 4 | from scipy.io import wavfile 5 | from tqdm import tqdm 6 | 7 | def prepare_aishell3(config): 8 | pass 9 | 10 | 11 | def prepare_align(config): 12 | in_dir = config["path"]["corpus_path"] 13 | out_dir = config["path"]["raw_path"] 14 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 15 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 16 | for dataset in ["train", "test"]: 17 | print("Processing {}ing set...".format(dataset)) 18 | with open(os.path.join(in_dir, dataset, "content.txt"), encoding="utf-8") as f: 19 | for line in tqdm(f): 20 | wav_name, text = line.strip("\n").split("\t") 21 | speaker = wav_name[:7] 22 | text = text.split(" ")[1::2] 23 | wav_path = os.path.join(in_dir, dataset, "wav", speaker, wav_name) 24 | if os.path.exists(wav_path): 25 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 26 | wav, _ = librosa.load(wav_path, sampling_rate) 27 | wav = wav / max(abs(wav)) * max_wav_value 28 | wavfile.write( 29 | os.path.join(out_dir, speaker, wav_name), 30 | sampling_rate, 31 | wav.astype(np.int16), 32 | ) 33 | with open( 34 | os.path.join(out_dir, speaker, "{}.lab".format(wav_name[:11])), 35 | "w", 36 | ) as f1: 37 | f1.write(" ".join(text)) 38 | -------------------------------------------------------------------------------- /pyutils/__init__.py: -------------------------------------------------------------------------------- 1 | from .save_and_load import * 2 | from .plot import * 3 | from .logger import * 4 | from .mask import * 5 | from .logger import * 6 | from .optimizer import * 7 | from . 
import scheduler 8 | import torch 9 | import numpy as np 10 | 11 | def f02pitch(f0): 12 | #f0 =f0 + 0.01 13 | return np.log2(f0 / 27.5) * 12 + 21 14 | 15 | def pitch2f0(pitch): 16 | f0 = np.exp2((pitch - 21 ) / 12) * 27.5 17 | for i in range(len(f0)): 18 | if f0[i] <= 10: 19 | f0[i] = 0 20 | return f0 21 | 22 | def pitchxuv(pitch, uv, to_f0 = False): 23 | result = pitch * uv 24 | if to_f0: 25 | result = pitch2f0(result) 26 | return result 27 | 28 | def initialize(model, init_type="pytorch"): 29 | """Initialize Transformer module 30 | 31 | :param torch.nn.Module model: core instance 32 | :param str init_type: initialization type 33 | """ 34 | if init_type == "pytorch": 35 | return 36 | 37 | # weight init 38 | for p in model.parameters(): 39 | if p.dim() > 1: 40 | if init_type == "xavier_uniform": 41 | torch.nn.init.xavier_uniform_(p.data) 42 | elif init_type == "xavier_normal": 43 | torch.nn.init.xavier_normal_(p.data) 44 | elif init_type == "kaiming_uniform": 45 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") 46 | elif init_type == "kaiming_normal": 47 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") 48 | else: 49 | raise ValueError("Unknown initialization: " + init_type) 50 | # bias init 51 | for p in model.parameters(): 52 | if p.dim() == 1: 53 | p.data.zero_() 54 | 55 | # reset some loss with default init 56 | for m in model.modules(): 57 | if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm)): 58 | m.reset_parameters() 59 | 60 | def get_mask_from_lengths(lengths, max_len=None): 61 | device = lengths.device 62 | batch_size = lengths.shape[0] 63 | if max_len is None: 64 | max_len = torch.max(lengths).item() 65 | 66 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) 67 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 68 | 69 | return mask 70 | 71 | def pad(input_ele, mel_max_length=None): 72 | if mel_max_length: 73 | max_len = mel_max_length 74 | else: 75 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 76 | 77 | out_list = list() 78 | for i, batch in enumerate(input_ele): 79 | if len(batch.shape) == 1: 80 | one_batch_padded = F.pad( 81 | batch, (0, max_len - batch.size(0)), "constant", 0.0 82 | ) 83 | elif len(batch.shape) == 2: 84 | one_batch_padded = F.pad( 85 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 86 | ) 87 | out_list.append(one_batch_padded) 88 | out_padded = torch.stack(out_list) 89 | return out_padded 90 | -------------------------------------------------------------------------------- /pyutils/gen_duration_from_tg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
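# Overview (hedged summary; the example utterance id below is made up):
# this script converts MFA TextGrid alignments laid out as
# <inputdir>/<speaker>/*.TextGrid into a single duration manifest.
# Each output line has the form
#     <utt_id>|<speaker>|<phone> <frames> <phone> <frames> ...
# e.g. 2085003136_0001|spk01|sil 12 uei 34 sp 8
# Frame boundaries come from librosa.time_to_frames() with the configured
# sample rate and hop length (function defaults: 24000 Hz, n_shift=300).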
14 | import argparse 15 | import os 16 | from pathlib import Path 17 | 18 | import librosa 19 | import numpy as np 20 | import yaml 21 | from praatio import textgrid 22 | from yacs.config import CfgNode 23 | from tqdm import tqdm 24 | 25 | 26 | def readtg(tg_path, sample_rate=24000, n_shift=300): 27 | alignment = textgrid.openTextgrid(tg_path, includeEmptyIntervals=True) 28 | phones = [] 29 | ends = [] 30 | for interval in alignment.tierDict["phones"].entryList: 31 | phone = interval.label 32 | phones.append(phone) 33 | ends.append(interval.end) 34 | frame_pos = librosa.time_to_frames(ends, sr=sample_rate, hop_length=n_shift) 35 | durations = np.diff(frame_pos, prepend=0) 36 | assert len(durations) == len(phones) 37 | # merge "" and sp in the end 38 | if phones[-1] == "" and len(phones) > 1 and phones[-2] == "sp": 39 | phones = phones[:-1] 40 | durations[-2] += durations[-1] 41 | durations = durations[:-1] 42 | # replace the last "sp" with "sil" in MFA1.x 43 | phones[-1] = "sil" if phones[-1] == "sp" else phones[-1] 44 | # replace the edge "" with "sil", replace the inner "" with "sp" 45 | new_phones = [] 46 | for i, phn in enumerate(phones): 47 | if phn == "": 48 | if i in {0, len(phones) - 1}: 49 | new_phones.append("sil") 50 | else: 51 | new_phones.append("sp") 52 | else: 53 | new_phones.append(phn) 54 | phones = new_phones 55 | results = "" 56 | for (p, d) in zip(phones, durations): 57 | results += p + " " + str(d) + " " 58 | return results.strip() 59 | 60 | 61 | # assume that the directory structure of inputdir is inputdir/speaker/*.TextGrid 62 | # in MFA1.x, there are blank labels("") in the end, and maybe "sp" before it 63 | # in MFA2.x, there are blank labels("") in the begin and the end, while no "sp" and "sil" anymore 64 | # we replace it with "sil" 65 | def gen_duration_from_textgrid(inputdir, output, sample_rate=24000, 66 | n_shift=300): 67 | # key: utt_id, value: (speaker, phn_durs) 68 | durations_dict = {} 69 | list_dir = os.listdir(inputdir) 70 | speakers = [dir for dir in list_dir if os.path.isdir(inputdir / dir)] 71 | for speaker in speakers: 72 | subdir = inputdir / speaker 73 | for file in tqdm(os.listdir(subdir)): 74 | if file.endswith(".TextGrid"): 75 | tg_path = subdir / file 76 | name = file.split(".")[0] 77 | durations_dict[name] = (speaker, readtg( 78 | tg_path, sample_rate=sample_rate, n_shift=n_shift)) 79 | with open(output, "w") as wf: 80 | for name in sorted(durations_dict.keys()): 81 | wf.write(name + "|" + durations_dict[name][0] + "|" + 82 | durations_dict[name][1] + "\n") 83 | 84 | 85 | def main(): 86 | # parse config and args 87 | parser = argparse.ArgumentParser( 88 | description="Preprocess audio and then extract features.") 89 | parser.add_argument( 90 | "--inputdir", 91 | default=None, 92 | type=str, 93 | help="directory to alignment files.") 94 | parser.add_argument( 95 | "--output", type=str, required=True, help="output duration file.") 96 | parser.add_argument("--sample-rate", type=int, help="the sample of wavs.") 97 | parser.add_argument( 98 | "--n-shift", 99 | type=int, 100 | help="the n_shift of time_to_freames, also called hop_length.") 101 | parser.add_argument( 102 | "--config", type=str, help="config file with fs and n_shift.") 103 | 104 | args = parser.parse_args() 105 | with open(args.config) as f: 106 | config = CfgNode(yaml.safe_load(f)) 107 | 108 | inputdir = Path(args.inputdir).expanduser() 109 | output = Path(args.output).expanduser() 110 | print(config) 111 | # import sys 112 | # sys.exit(0) 113 | gen_duration_from_textgrid(inputdir, 
output, config.sampling_rate, config.hop_length) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() -------------------------------------------------------------------------------- /pyutils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def get_logger(logging_file): 4 | logger = logging.getLogger() 5 | logger.setLevel(logging.INFO) 6 | 7 | formatter = logging.Formatter( 8 | "%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s: %(message)s" 9 | ) 10 | 11 | file_log_handler = logging.FileHandler(logging_file, mode = 'w') 12 | file_log_handler.setLevel(logging.INFO) 13 | file_log_handler.setFormatter(formatter) 14 | 15 | # stream_log_handler = logging.StreamHandler() 16 | # stream_log_handler.setLevel(logging.INFO) 17 | # stream_log_handler.setFormatter(formatter) 18 | 19 | logger.addHandler(file_log_handler) 20 | # logger.addHandler(stream_log_handler) 21 | 22 | return logger -------------------------------------------------------------------------------- /pyutils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from ipdb import set_trace 4 | from torch.optim import ( 5 | SGD, 6 | Adam, 7 | AdamW, 8 | RMSprop, 9 | RAdam, 10 | NAdam, 11 | ASGD 12 | ) 13 | 14 | class NoamOpt(): 15 | "Optim wrapper that implements rate." 16 | 17 | def __init__(self, optimizer, model_size, warmup, factor = 1.0): 18 | ''' 19 | model_size: d_model 20 | factor: factor 21 | warmup: warmup step 22 | optimizer: optimizer (Adam default) 23 | ''' 24 | self.optimizer = optimizer 25 | self._step = 0 26 | self.warmup = warmup 27 | self.factor = factor 28 | self.model_size = model_size 29 | self._rate = 0 30 | 31 | @property 32 | def param_groups(self): 33 | return self.optimizer.param_groups 34 | 35 | def step(self): 36 | "Update parameters and rate" 37 | self._step += 1 38 | rate = self.rate() 39 | for p in self.optimizer.param_groups: 40 | p["lr"] = rate 41 | self._rate = rate 42 | self.optimizer.step() 43 | 44 | def rate(self, step=None): 45 | "Implement `lrate` above" 46 | if step is None: 47 | step = self._step 48 | return ( 49 | self.factor 50 | * self.model_size ** (-0.5) 51 | * min(step ** (-0.5), step * self.warmup ** (-1.5)) 52 | ) 53 | 54 | def zero_grad(self): 55 | self.optimizer.zero_grad() 56 | 57 | def state_dict(self): 58 | return { 59 | "_step": self._step, 60 | "warmup": self.warmup, 61 | "factor": self.factor, 62 | "model_size": self.model_size, 63 | "_rate": self._rate, 64 | "optimizer": self.optimizer.state_dict(), 65 | } 66 | 67 | def load_state_dict(self, state_dict): 68 | for key, value in state_dict.items(): 69 | if key == "optimizer": 70 | self.optimizer.load_state_dict(state_dict["optimizer"]) 71 | else: 72 | setattr(self, key, value) 73 | 74 | class ScheduledOptimD(): 75 | ''' A simple wrapper class for learning rate scheduling ''' 76 | 77 | def __init__(self, optimizer, init_lr, n_warmup_steps, current_steps): 78 | self.optimizer = optimizer 79 | self.n_warmup_steps = n_warmup_steps 80 | self.n_current_steps = current_steps 81 | self.init_lr = init_lr 82 | 83 | def step_and_update_lr_frozen(self, learning_rate_frozen): 84 | for param_group in self.optimizer.param_groups: 85 | param_group['lr'] = learning_rate_frozen 86 | self.optimizer.step() 87 | 88 | def step_and_update_lr(self): 89 | self._update_learning_rate() 90 | self.optimizer.step() 91 | 92 | def get_learning_rate(self): 93 | learning_rate = 0.0 94 | for param_group 
in self.optimizer.param_groups: 95 | learning_rate = param_group['lr'] 96 | 97 | return learning_rate 98 | 99 | def zero_grad(self): 100 | # print(self.init_lr) 101 | self.optimizer.zero_grad() 102 | 103 | def set_current_steps(self, step): 104 | self.n_current_steps = step 105 | 106 | def _get_lr_scale(self): 107 | # set_trace() 108 | return np.min([ 109 | np.power(self.n_current_steps, -0.5), 110 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 111 | 112 | def _update_learning_rate(self): 113 | ''' Learning rate scheduling per step ''' 114 | 115 | lr = self.init_lr * self._get_lr_scale() 116 | 117 | for param_group in self.optimizer.param_groups: 118 | param_group['lr'] = lr 119 | 120 | def state_dict(self): 121 | return { 122 | "_step": self.n_current_steps, 123 | "warmup": self.n_warmup_steps, 124 | "factor": self.init_lr, 125 | "_rate": self.get_learning_rate(), 126 | "optimizer": self.optimizer.state_dict(), 127 | } 128 | 129 | def load_state_dict(self, state_dict): 130 | for key, value in state_dict.items(): 131 | if key == "optimizer": 132 | self.optimizer.load_state_dict(state_dict["optimizer"]) 133 | else: 134 | setattr(self, key, value) 135 | 136 | def get_g_opt(model, optim, d_model, warmup, factor): 137 | base = torch.optim.Adam(model.parameters(), lr = 0, betas = (0.9, 0.98), eps = 1e-9) 138 | return NoamOpt(base, d_model, warmup, factor) 139 | 140 | def get_d_opt(model, optim, warmup, factor, current_step): 141 | base = torch.optim.Adam(model.parameters(), lr = 0, betas = (0.9, 0.98), eps = 1e-9) 142 | return ScheduledOptimD(base, factor, warmup, current_step) 143 | -------------------------------------------------------------------------------- /pyutils/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 
40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 
-------------------------------------------------------------------------------- /pyutils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import librosa 3 | import soundfile as sf 4 | import numpy as np 5 | import logging 6 | import argparse 7 | logging.getLogger('matplotlib.font_manager').disabled = True 8 | 9 | def specplot(spec, 10 | pic_path = 'exp/test/melspectrograms/spec.png', 11 | **kwargs): 12 | """Plot the log mel spectrogram of audio.""" 13 | fig = plt.figure() 14 | plt.imshow(spec, origin = 'lower', cmap = plt.cm.magma, aspect='auto') 15 | plt.colorbar() 16 | fig.savefig(pic_path) 17 | plt.close() 18 | 19 | def specplot_from_audio(filename = None, 20 | audio = None, 21 | rate = None, 22 | rotate=False, 23 | n_ffts = 1024, 24 | pic_path = 'exp/test/spectrograms/spec.png', 25 | **kwargs): 26 | """Plot the log magnitude spectrogram of audio.""" 27 | if filename is not None: 28 | audio, rate = sf.read(filename) 29 | hop_length = kwargs.get('hop_length', None) 30 | win_length = kwargs.get('win_length', None) 31 | stft = librosa.stft( 32 | audio, 33 | n_fft = n_ffts, 34 | hop_length = hop_length, 35 | win_length = win_length 36 | ) 37 | mag, phase = librosa.magphase(stft) 38 | logmag = np.log10(mag) 39 | fig = plt.figure() 40 | plt.imshow(logmag, cmap = plt.cm.magma, origin = 'lower', aspect = 'auto') 41 | plt.colorbar() 42 | fig.savefig(pic_path) 43 | plt.close() 44 | 45 | def melspecplot(mel_spec, 46 | pic_path = 'exp/test/melspectrograms/melspec.png', 47 | **kwargs): 48 | """Plot the log mel spectrogram of audio.""" 49 | fig = plt.figure() 50 | plt.imshow(mel_spec, origin = 'lower', cmap = plt.cm.magma, aspect='auto') 51 | plt.colorbar() 52 | fig.savefig(pic_path) 53 | plt.close() 54 | 55 | def melspecplot_from_audio(filename = None, 56 | audio = None, 57 | rate = None, 58 | rotate = False, 59 | n_ffts = 1024, 60 | pic_path = 'exp/test/melspectrograms/melspec.png', 61 | **kwargs): 62 | """Plot the log mel spectrogram of audio.""" 63 | if filename is not None: 64 | audio, rate = sf.read(filename) 65 | hop_length = kwargs.get('hop_length', None) 66 | win_length = kwargs.get('win_length', None) 67 | n_mels = kwargs.get('n_mels', 23) 68 | mel_spec = librosa.feature.melspectrogram( 69 | y = audio, 70 | sr = rate, 71 | n_fft = n_ffts, 72 | hop_length = hop_length, 73 | win_length = win_length, 74 | n_mels = n_mels 75 | ) 76 | mel_spec = librosa.power_to_db(mel_spec, ref=np.max) 77 | plt.imshow(mel_spec, cmap = plt.cm.magma, origin = 'lower', aspect = 'auto') 78 | plt.savefig(pic_path) 79 | plt.close() 80 | 81 | def get_args(): 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--filename', type=str, default=None) 84 | parser.add_argument('--output', type=str, default=None) 85 | parser.add_argument('--mean', type=str, default=None) 86 | parser.add_argument('--std', type=str, default=None) 87 | return parser.parse_args() 88 | 89 | if __name__ == "__main__": 90 | args = get_args() 91 | import numpy as np 92 | # specplot_from_audio(args.filename, pic_path = args.output + '.aspec.png') 93 | data = np.load(args.filename).T 94 | mean = np.load(args.mean).T # (n_fft + 1, T) 95 | std = np.load(args.std).T 96 | if 'spec' in args.filename: 97 | specplot(data, mean, std, args.output + '.spec.png') 98 | if 'mel' in args.filename: 99 | melspecplot(data, args.output + '.melspec.png') -------------------------------------------------------------------------------- /pyutils/save_and_load.py: 
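
A hedged usage sketch for the checkpoint helpers defined in this file (directory names, step counts, and the stand-in model below are made up for illustration; the real generator and schedulers come from train.py):

import os
import torch
from pyutils.save_and_load import (
    latest_checkpoint_path,
    load_checkpoint,
    save_checkpoint,
)

# Stand-ins for the real generator / optimizer / scheduler.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

os.makedirs("exp/demo/models", exist_ok=True)
save_checkpoint(model, optimizer, scheduler, 1e-3, 100, "exp/demo/models/G_100.pth")

ckpt = latest_checkpoint_path("exp/demo/models", "G_*.pth")   # picks the highest-step G_*.pth
model, optimizer, scheduler, lr, iteration = load_checkpoint(ckpt, model, optimizer, scheduler)
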
-------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import re 4 | import sys 5 | import argparse 6 | import logging 7 | import json 8 | import subprocess 9 | import warnings 10 | import random 11 | import functools 12 | 13 | import librosa 14 | import numpy as np 15 | from scipy.io.wavfile import read 16 | import torch 17 | from torch.nn import functional as F 18 | # from modules.commons import sequence_mask 19 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 20 | logger = logging 21 | 22 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 23 | f_list = glob.glob(os.path.join(dir_path, regex)) 24 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 25 | x = f_list[-1] 26 | print(x) 27 | return x 28 | 29 | def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None, skip_optimizer=False): 30 | assert os.path.isfile(checkpoint_path) 31 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 32 | iteration = checkpoint_dict['iteration'] 33 | learning_rate = checkpoint_dict['learning_rate'] 34 | if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None: 35 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 36 | if scheduler is not None and not skip_optimizer and checkpoint_dict['scheduler'] is not None: 37 | scheduler.load_state_dict(checkpoint_dict['scheduler']) 38 | saved_state_dict = checkpoint_dict['model'] 39 | if hasattr(model, 'module'): 40 | state_dict = model.module.state_dict() 41 | else: 42 | state_dict = model.state_dict() 43 | new_state_dict = {} 44 | for k, v in state_dict.items(): 45 | try: 46 | # assert "dec" in k or "disc" in k 47 | # print("load", k) 48 | new_state_dict[k] = saved_state_dict[k] 49 | assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) 50 | except: 51 | print("error, %s is not in the checkpoint" % k) 52 | logger.info("%s is not in the checkpoint" % k) 53 | new_state_dict[k] = v 54 | if hasattr(model, 'module'): 55 | model.module.load_state_dict(new_state_dict) 56 | else: 57 | model.load_state_dict(new_state_dict) 58 | print("load") 59 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 60 | checkpoint_path, iteration)) 61 | return model, optimizer, scheduler, learning_rate, iteration 62 | 63 | def save_checkpoint(model, optimizer, scheduler, learning_rate, iteration, checkpoint_path): 64 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 65 | iteration, checkpoint_path)) 66 | if hasattr(model, 'module'): 67 | state_dict = model.module.state_dict() 68 | else: 69 | state_dict = model.state_dict() 70 | torch.save({'model': state_dict, 71 | 'iteration': iteration, 72 | 'optimizer': optimizer.state_dict(), 73 | 'scheduler': scheduler.state_dict(), 74 | 'learning_rate': learning_rate}, 75 | checkpoint_path) 76 | 77 | def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True): 78 | """Freeing up space by deleting saved ckpts 79 | 80 | Arguments: 81 | path_to_models -- Path to the model directory 82 | n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth 83 | sort_by_time -- True -> chronologically delete ckpts 84 | False -> lexicographically delete ckpts 85 | """ 86 | ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))] 87 | name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1))) 88 | time_key = (lambda _f: 
os.path.getmtime(os.path.join(path_to_models, _f))) 89 | sort_key = time_key if sort_by_time else name_key 90 | x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key) 91 | to_del = [os.path.join(path_to_models, fn) for fn in 92 | (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])] 93 | del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}") 94 | del_routine = lambda x: [os.remove(x), del_info(x)] 95 | rs = [del_routine(fn) for fn in to_del] 96 | 97 | class HParams(): 98 | def __init__(self, **kwargs): 99 | for k, v in kwargs.items(): 100 | if type(v) == dict: 101 | v = HParams(**v) 102 | self[k] = v 103 | 104 | def keys(self): 105 | return self.__dict__.keys() 106 | 107 | def items(self): 108 | return self.__dict__.items() 109 | 110 | def values(self): 111 | return self.__dict__.values() 112 | 113 | def __len__(self): 114 | return len(self.__dict__) 115 | 116 | def __getitem__(self, key): 117 | return getattr(self, key) 118 | 119 | def __setitem__(self, key, value): 120 | return setattr(self, key, value) 121 | 122 | def __contains__(self, key): 123 | return key in self.__dict__ 124 | 125 | def __repr__(self): 126 | return self.__dict__.__repr__() 127 | 128 | -------------------------------------------------------------------------------- /pyutils/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | import torch 4 | 5 | class WarmupLR(_LRScheduler): 6 | """The WarmupLR scheduler 7 | 8 | This scheduler is almost same as NoamLR Scheduler except for following 9 | difference: 10 | 11 | NoamLR: 12 | lr = optimizer.lr * model_size ** -0.5 13 | * min(step ** -0.5, step * warmup_step ** -1.5) 14 | WarmupLR: 15 | lr = optimizer.lr * warmup_step ** 0.5 16 | * min(step ** -0.5, step * warmup_step ** -1.5) 17 | 18 | Note that the maximum lr equals to optimizer.lr in this scheduler. 
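    For example (illustrative numbers, not taken from the configs): with
    optimizer.lr = 1e-3 and warmup_steps = 25000, the lr ramps linearly from 0
    to 1e-3 at step 25000 (about 1e-4 at step 2500) and afterwards decays
    proportionally to step ** -0.5.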
19 | 20 | """ 21 | 22 | def __init__( 23 | self, 24 | optimizer, 25 | warmup_steps = 25000, 26 | last_epoch = -1, 27 | ): 28 | self.warmup_steps = warmup_steps 29 | 30 | # __init__() must be invoked before setting field 31 | # because step() is also invoked in __init__() 32 | super().__init__(optimizer, last_epoch) 33 | 34 | def __repr__(self): 35 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" 36 | 37 | def get_lr(self): 38 | step_num = self.last_epoch + 1 39 | if self.warmup_steps == 0: 40 | return [ 41 | lr * step_num ** -0.5 42 | for lr in self.base_lrs 43 | ] 44 | else: 45 | return [ 46 | lr 47 | * self.warmup_steps ** 0.5 48 | * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) 49 | for lr in self.base_lrs 50 | ] 51 | 52 | def set_step(self, step: int): 53 | self.last_epoch = step 54 | 55 | class BaseClass: 56 | ''' 57 | Base Class for learning rate scheduler 58 | ''' 59 | 60 | def __init__(self, 61 | optimizer, 62 | num_epochs, 63 | epoch_iter, 64 | initial_lr, 65 | final_lr, 66 | warm_up_epoch=6, 67 | scale_ratio=1.0, 68 | warm_from_zero=False): 69 | ''' 70 | warm_up_epoch: the first warm_up_epoch is the multiprocess warm-up stage 71 | scale_ratio: multiplied to the current lr in the multiprocess training 72 | process 73 | ''' 74 | self.optimizer = optimizer 75 | self.max_iter = num_epochs * epoch_iter 76 | self.initial_lr = initial_lr 77 | self.final_lr = final_lr 78 | self.scale_ratio = scale_ratio 79 | self.current_iter = 0 80 | self.warm_up_iter = warm_up_epoch * epoch_iter 81 | self.warm_from_zero = warm_from_zero 82 | 83 | def get_multi_process_coeff(self): 84 | lr_coeff = 1.0 * self.scale_ratio 85 | if self.current_iter < self.warm_up_iter: 86 | if self.warm_from_zero: 87 | lr_coeff = self.scale_ratio * self.current_iter / self.warm_up_iter 88 | elif self.scale_ratio > 1: 89 | lr_coeff = (self.scale_ratio - 90 | 1) * self.current_iter / self.warm_up_iter + 1.0 91 | 92 | return lr_coeff 93 | 94 | def get_current_lr(self): 95 | ''' 96 | This function should be implemented in the child class 97 | ''' 98 | return 0.0 99 | 100 | def get_lr(self): 101 | return self.optimizer.param_groups[0]['lr'] 102 | 103 | def set_lr(self): 104 | current_lr = self.get_current_lr() 105 | for param_group in self.optimizer.param_groups: 106 | param_group['lr'] = current_lr 107 | 108 | def step(self, current_iter=None): 109 | if current_iter is not None: 110 | self.current_iter = current_iter 111 | 112 | self.set_lr() 113 | self.current_iter += 1 114 | 115 | def step_return_lr(self, current_iter=None): 116 | if current_iter is not None: 117 | self.current_iter = current_iter 118 | 119 | current_lr = self.get_current_lr() 120 | self.current_iter += 1 121 | 122 | return current_lr 123 | 124 | class ExponentialDecrease(BaseClass): 125 | 126 | def __init__(self, 127 | optimizer, 128 | num_epochs, 129 | epoch_iter, 130 | initial_lr, 131 | final_lr, 132 | warm_up_epoch=6, 133 | scale_ratio=1.0, 134 | warm_from_zero=False): 135 | super().__init__(optimizer, num_epochs, epoch_iter, initial_lr, 136 | final_lr, warm_up_epoch, scale_ratio, warm_from_zero) 137 | 138 | def get_current_lr(self): 139 | lr_coeff = self.get_multi_process_coeff() 140 | current_lr = lr_coeff * self.initial_lr * math.exp( 141 | (self.current_iter / self.max_iter) * 142 | math.log(self.final_lr / self.initial_lr)) 143 | return current_lr 144 | 145 | class TriAngular2(BaseClass): 146 | ''' 147 | The implementation of https://arxiv.org/pdf/1506.01186.pdf 148 | ''' 149 | 150 | def __init__(self, 151 | 
optimizer, 152 | num_epochs, 153 | epoch_iter, 154 | initial_lr, 155 | final_lr, 156 | warm_up_epoch=6, 157 | scale_ratio=1.0, 158 | cycle_step=2, 159 | reduce_lr_diff_ratio=0.5): 160 | super().__init__(optimizer, num_epochs, epoch_iter, initial_lr, 161 | final_lr, warm_up_epoch, scale_ratio) 162 | 163 | self.reduce_lr_diff_ratio = reduce_lr_diff_ratio 164 | self.cycle_iter = cycle_step * epoch_iter 165 | self.step_size = self.cycle_iter // 2 166 | 167 | self.max_lr = initial_lr 168 | self.min_lr = final_lr 169 | self.gap = self.max_lr - self.min_lr 170 | 171 | def get_current_lr(self): 172 | lr_coeff = self.get_multi_process_coeff() 173 | point = self.current_iter % self.cycle_iter 174 | cycle_index = self.current_iter // self.cycle_iter 175 | 176 | self.max_lr = self.min_lr + self.gap * self.reduce_lr_diff_ratio**cycle_index 177 | 178 | if point <= self.step_size: 179 | current_lr = self.min_lr + (self.max_lr - 180 | self.min_lr) * point / self.step_size 181 | else: 182 | current_lr = self.max_lr - (self.max_lr - self.min_lr) * ( 183 | point - self.step_size) / self.step_size 184 | 185 | current_lr = lr_coeff * current_lr 186 | 187 | return current_lr 188 | 189 | 190 | def show_lr_curve(scheduler): 191 | import matplotlib.pyplot as plt 192 | 193 | lr_list = [] 194 | for current_lr in range(0, scheduler.max_iter): 195 | lr_list.append(scheduler.step_return_lr(current_lr)) 196 | data_index = list(range(1, len(lr_list) + 1)) 197 | 198 | plt.plot(data_index, lr_list, '-o', markersize=1) 199 | plt.legend(loc='best') 200 | plt.xlabel("Iteration") 201 | plt.ylabel("LR") 202 | 203 | plt.show() 204 | 205 | 206 | if __name__ == '__main__': 207 | optimizer = None 208 | num_epochs = 6 209 | epoch_iter = 500 210 | initial_lr = 0.6 211 | final_lr = 0.1 212 | warm_up_epoch = 2 213 | scale_ratio = 4 214 | scheduler = ExponentialDecrease(optimizer, num_epochs, epoch_iter, 215 | initial_lr, final_lr, warm_up_epoch, 216 | scale_ratio) 217 | # scheduler = TriAngular2(optimizer, 218 | # num_epochs, 219 | # epoch_iter, 220 | # initial_lr, 221 | # final_lr, 222 | # warm_up_epoch, 223 | # scale_ratio, 224 | # cycle_step=2, 225 | # reduce_lr_diff_ratio=0.5) 226 | 227 | show_lr_curve(scheduler) 228 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | export PYTHONPATH=$(pwd):$PYTHONPATH 4 | start_stage=2 5 | stop_stage=2 6 | step= 7 | input= 8 | output= 9 | 10 | echo $PYTHONPATH 11 | 12 | echo "$0 $@" # Print the command line for logging 13 | 14 | . 
./utils/parse_options.sh 15 | 16 | if [ ${start_stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 17 | echo "stage 0: Data Preparation" 18 | python3 preprocess/data_prep.py 19 | fi 20 | 21 | if [ ${start_stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 22 | echo "stage 1: Feature Extraction" 23 | python preprocess/audio_preprocess.py --data-config configs/data.yaml \ 24 | --spec \ 25 | --mel \ 26 | --f0 \ 27 | --energy \ 28 | --stat 29 | fi 30 | 31 | if [ ${start_stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 32 | echo "stage 2: Training" 33 | export CUDA_VISIBLE_DEVICES=0,1 34 | python train_gan.py --data-config configs/svs/data.yaml \ 35 | --model-config configs/svs/model.yaml \ 36 | --train-config configs/svs/train.yaml \ 37 | --num-gpus 2 \ 38 | --dist-url 'tcp://localhost:30305' 39 | fi 40 | 41 | if [ ${start_stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 42 | echo "stage 3: Synthesizing" 43 | python synthesize.py --exp-name ${name} \ 44 | --step ${step} \ 45 | --input ${input} \ 46 | --output ${output} 47 | fi 48 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | import math 5 | import logging 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | import torch.multiprocessing as mp 11 | 12 | from dataset import SVSDataset, SVSCollate 13 | from models import Xiaoice2 as Generator 14 | from loss import FastSpeech2Loss 15 | import pyutils 16 | from pyutils import ( 17 | load_checkpoint, 18 | save_checkpoint, 19 | clean_checkpoints, 20 | latest_checkpoint_path, 21 | melspecplot, 22 | get_logger 23 | ) 24 | 25 | import wandb 26 | 27 | logging.basicConfig(format = "%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s: %(message)s", level = logging.INFO) 28 | 29 | class Trainer(): 30 | def __init__(self, rank, args, data_configs, model_configs, train_configs): 31 | self.rank = rank 32 | self.device = torch.device('cuda:{:d}'.format(rank)) 33 | trainset = SVSDataset(data_configs) 34 | collate_fn = SVSCollate() 35 | sampler = torch.utils.data.DistributedSampler(trainset) if args.num_gpus > 1 else None 36 | self.trainloader = DataLoader( 37 | trainset, 38 | shuffle = False, 39 | sampler = sampler, 40 | collate_fn = collate_fn, 41 | batch_size = train_configs['batch_size'], 42 | pin_memory = True, 43 | num_workers = train_configs['num_workers'], 44 | prefetch_factor = 10 45 | ) 46 | 47 | model_configs['generator']['transformer']['encoder']['n_src_vocab'] = trainset.get_phone_number() + 1 48 | model_configs['generator']['spk_num'] = trainset.get_spk_number() 49 | 50 | if args.num_gpus > 1: 51 | self.models = ( 52 | nn.parallel.DistributedDataParallel( 53 | Generator( 54 | data_configs, 55 | model_configs['generator'] 56 | ).to(self.device), 57 | device_ids = [rank] 58 | ), 59 | ) 60 | else: 61 | self.models = ( 62 | Generator( 63 | data_configs, 64 | model_configs['generator'] 65 | ).to(self.device), 66 | ) 67 | self.data_configs = data_configs 68 | self.model_configs = model_configs 69 | self.train_configs = train_configs 70 | self.args = args 71 | 72 | try: 73 | self.g_optimizer = getattr( 74 | torch.optim, train_configs['g_optimizer'] 75 | )(self.models[0].parameters(), **train_configs['g_optimizer_args']) 76 | 77 | self.g_scheduler = getattr( 78 | pyutils.scheduler, train_configs['g_scheduler'] 79 | )(self.g_optimizer, **train_configs['g_scheduler_args']) 80 | except: 81 | 
raise NotImplementedError("Unknown optimizer or scheduler") 82 | 83 | self.fs2loss = FastSpeech2Loss(data_configs) 84 | 85 | if self.rank == 0: 86 | self._make_exp_dir() 87 | self.logger = get_logger(os.path.join(self.args.exp_name, 'logs/train.log')) 88 | 89 | try: 90 | latest_ckpt_path = latest_checkpoint_path( 91 | os.path.join(self.args.exp_name, 'models'), 92 | 'G_*.pth' 93 | ) 94 | _, _, _, _, epoch_str = load_checkpoint( 95 | latest_ckpt_path, 96 | self.models[0], 97 | self.g_optimizer, 98 | self.g_scheduler, 99 | False 100 | ) 101 | self.start_epoch = max(epoch_str, 1) 102 | name = latest_ckpt_path 103 | self.total_step = int(name[name.rfind("_")+1:name.rfind(".")]) + 1 104 | except Exception: 105 | print("Load old checkpoint failed...") 106 | print("Start a new training...") 107 | self.start_epoch = 1 108 | self.total_step = 0 109 | 110 | self.epochs = self.train_configs['epochs'] 111 | 112 | def _dump_args_and_config(self, filename, config): 113 | with open(os.path.join(self.args.exp_name, 'configs', filename) + '.yaml', 'w') as f: 114 | yaml.dump(config, f) 115 | 116 | def _make_exp_dir(self): 117 | os.makedirs(self.args.exp_name, exist_ok=True) 118 | os.makedirs(os.path.join(self.args.exp_name, 'configs'), exist_ok=True) 119 | os.makedirs(os.path.join(self.args.exp_name, 'models'), exist_ok=True) 120 | os.makedirs(os.path.join(self.args.exp_name, 'audios'), exist_ok=True) 121 | os.makedirs(os.path.join(self.args.exp_name, 'spectrograms'), exist_ok=True) 122 | os.makedirs(os.path.join(self.args.exp_name, 'melspectrograms'), exist_ok=True) 123 | os.makedirs(os.path.join(self.args.exp_name, 'eval_results'), exist_ok=True) 124 | os.makedirs(os.path.join(self.args.exp_name, 'logs'), exist_ok = True) 125 | with open(os.path.join(self.args.exp_name, 'model_arch.txt'), 'w') as f: 126 | for model in self.models: 127 | print(model, file = f) 128 | self._dump_args_and_config('args', vars(self.args)) 129 | self._dump_args_and_config('data', self.data_configs) 130 | self._dump_args_and_config('model', self.model_configs) 131 | self._dump_args_and_config('train', self.train_configs) 132 | 133 | def train(self): 134 | for epoch in range(self.start_epoch, self.epochs + 1, 1): 135 | self.train_epoch(epoch) 136 | 137 | def train_epoch(self, epoch): 138 | self.total_loss = 0.0 139 | for batch_idx, data in enumerate(self.trainloader): 140 | output, postnet_output = self.train_batch(data, epoch, batch_idx) 141 | self.g_scheduler.step() 142 | 143 | if self.rank == 0 and self.total_step % self.train_configs['save_interval'] == 0: 144 | ckpt_path = os.path.join(self.args.exp_name, 'models', 'G_{}.pth'.format(self.total_step)) 145 | save_checkpoint( 146 | self.models[0], 147 | self.g_optimizer, 148 | self.g_scheduler, 149 | self.g_scheduler.get_lr()[0], 150 | epoch, 151 | ckpt_path 152 | ) 153 | length = data['mel_lens'][0] 154 | real_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', '{}_{}.png'.format(data['uttids'][0], self.total_step)) 155 | before_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', 'before_{}_{}.png'.format(data['uttids'][0], self.total_step)) 156 | after_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', 'after_{}_{}.png'.format(data['uttids'][0], self.total_step)) 157 | melspecplot(data['mels'][0][:length, :].transpose(1, 0).numpy(), real_pic_path) # (n_mels, T) 158 | melspecplot(output[0][:length, :].transpose(1, 0).detach().cpu().numpy(), before_pic_path) 159 | melspecplot(postnet_output[0][:length, :].transpose(1, 0).detach().cpu().numpy(), 
after_pic_path) 160 | 161 | if self.rank == 0: 162 | clean_checkpoints(os.path.join(self.args.exp_name, 'models'), n_ckpts_to_keep = self.train_configs['ckpt_clean']) 163 | 164 | def _move_to_device(self, data): 165 | new_data = {} 166 | for k, v in data.items(): 167 | if type(v) is torch.Tensor: 168 | new_data[k] = v.to(self.device) 169 | return new_data 170 | 171 | def train_batch(self, data, epoch, step): 172 | for model in self.models: 173 | model.train() 174 | new_data = self._move_to_device(data) 175 | 176 | self.g_optimizer.zero_grad() 177 | loss, report_keys, output, postnet_output = self.models[0](**new_data) 178 | loss.backward() 179 | grad_norm = nn.utils.clip_grad_norm_(self.models[0].parameters(), self.train_configs['grad_clip']) 180 | if math.isnan(grad_norm): 181 | raise ZeroDivisionError('Grad norm is nan') 182 | self.g_optimizer.step() 183 | self.total_loss += loss.item() 184 | 185 | self.total_step += 1 186 | if self.rank == 0: 187 | self.print_msg(epoch, step, report_keys) #, accuracy.item()) 188 | wandb_log_dict = { 189 | 'train/avg_g_loss': self.total_loss / (step + 1), 190 | 'train/g_lr': self.g_scheduler.get_lr()[0] 191 | } 192 | for k, v in report_keys.items(): 193 | wandb_log_dict['train/' + k] = v 194 | wandb.log(wandb_log_dict) 195 | return output, postnet_output 196 | 197 | def print_msg(self, epoch, step, report_keys): 198 | if self.total_step % self.train_configs['log_interval'] == 0: 199 | temp = '' 200 | for k, v in report_keys.items(): 201 | temp += '{}: {:.6f} '.format(k, v) 202 | message = ('[Epoch: {} Step: {} Total steps: {}] ' + temp).format( 203 | epoch, step + 1, self.total_step 204 | ) 205 | self.logger.info(message) 206 | 207 | def parse_args(): 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument('--data-config', dest = 'data_config', type = str, default = './conf/data.yaml') 210 | parser.add_argument('--model-config', dest = 'model_config', type = str, default = './conf/model.yaml') 211 | parser.add_argument('--train-config', dest = 'train_config', type = str, default = './conf/train.yaml') 212 | parser.add_argument('--num-gpus', dest = 'num_gpus', type = int, default = 1) 213 | parser.add_argument('--exp-name', dest = 'exp_name', type = str, default = 'default') 214 | parser.add_argument('--dist-backend', dest = 'dist_backend', type = str, default = 'nccl') 215 | parser.add_argument('--dist-url', dest = 'dist_url', type = str, default = 'tcp://localhost:30302') 216 | return parser.parse_args() 217 | 218 | def main(rank, args, configs): 219 | if args.num_gpus > 1: 220 | torch.cuda.set_device(rank) 221 | torch.distributed.init_process_group( 222 | backend = args.dist_backend, 223 | init_method = args.dist_url, 224 | world_size = args.num_gpus, 225 | rank = rank 226 | ) 227 | 228 | data_configs, model_configs, train_configs = configs 229 | args.exp_name = train_configs['wandb_args']['group'] + '-' + \ 230 | train_configs['wandb_args']['job_type'] + '-' + \ 231 | train_configs['wandb_args']['name'] 232 | args.exp_name = os.path.join('exp', args.exp_name) 233 | 234 | # wandb initialization 235 | if train_configs['wandb']: 236 | wandb_configs = vars(args) 237 | for config in configs: 238 | wandb_configs.update(config) 239 | wandb.init( 240 | **train_configs['wandb_args'], 241 | config = wandb_configs 242 | ) 243 | 244 | trainer = Trainer(rank, args, data_configs, model_configs, train_configs) 245 | trainer.train() 246 | 247 | if train_configs['wandb']: 248 | wandb.finish() 249 | 250 | if __name__ == "__main__": 251 | args = parse_args() 
252 | args.exp_name = os.path.join('exp', args.exp_name) 253 | with open(args.data_config, 'r') as f: 254 | data_configs = yaml.load(f, Loader = yaml.FullLoader) 255 | with open(args.model_config, 'r') as f: 256 | model_configs = yaml.load(f, Loader = yaml.FullLoader) 257 | with open(args.train_config, 'r') as f: 258 | train_configs = yaml.load(f, Loader = yaml.FullLoader) 259 | configs = (data_configs, model_configs, train_configs) 260 | 261 | num_gpus = torch.cuda.device_count() 262 | if args.num_gpus > 1: 263 | mp.spawn(main, nprocs = num_gpus, args = (args, configs)) 264 | else: 265 | main(0, args, configs) 266 | -------------------------------------------------------------------------------- /train_gan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | import math 5 | import shutil 6 | import logging 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import torch.multiprocessing as mp 12 | 13 | from dataset import SVSDataset as Dataset 14 | from dataset import SVSCollate as Collate 15 | from models import Xiaoice2 as Generator 16 | from models import Discriminator as Discriminator 17 | from loss import FastSpeech2Loss, FeatLoss, LSGANGLoss, LSGANDLoss 18 | import pyutils 19 | from pyutils import ( 20 | load_checkpoint, 21 | save_checkpoint, 22 | clean_checkpoints, 23 | latest_checkpoint_path, 24 | melspecplot, 25 | get_logger 26 | ) 27 | 28 | import wandb 29 | 30 | logging.basicConfig(format = "%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s: %(message)s", level = logging.INFO) 31 | 32 | python_script = os.path.realpath(__file__) 33 | 34 | class Trainer(): 35 | def __init__(self, rank, args, data_configs, model_configs, train_configs): 36 | self.rank = rank 37 | self.device = torch.device('cuda:{:d}'.format(rank)) 38 | trainset = Dataset(data_configs) 39 | collate_fn = Collate() 40 | sampler = torch.utils.data.DistributedSampler(trainset) if args.num_gpus > 1 else None 41 | self.trainloader = DataLoader( 42 | trainset, 43 | shuffle = False, 44 | sampler = sampler, 45 | collate_fn = collate_fn, 46 | batch_size = train_configs['batch_size'], 47 | pin_memory = True, 48 | num_workers = train_configs['num_workers'], 49 | prefetch_factor = 10 50 | ) 51 | 52 | model_configs['generator']['transformer']['encoder']['n_src_vocab'] = trainset.get_phone_number() + 1 53 | model_configs['generator']['spk_num'] = trainset.get_spk_number() 54 | 55 | if args.num_gpus > 1: 56 | self.models = ( 57 | nn.parallel.DistributedDataParallel( 58 | Generator( 59 | data_configs, 60 | model_configs['generator'] 61 | ).to(self.device), 62 | device_ids = [rank] 63 | ), 64 | nn.parallel.DistributedDataParallel( 65 | Discriminator().to(self.device), 66 | device_ids = [rank] 67 | ) 68 | ) 69 | else: 70 | self.models = ( 71 | Generator( 72 | data_configs, 73 | model_configs['generator'] 74 | ).to(self.device), 75 | Discriminator().to(self.device) 76 | ) 77 | self.data_configs = data_configs 78 | self.model_configs = model_configs 79 | self.train_configs = train_configs 80 | self.args = args 81 | 82 | try: 83 | self.g_optimizer = getattr( 84 | torch.optim, train_configs['g_optimizer'] 85 | )(self.models[0].parameters(), **train_configs['g_optimizer_args']) 86 | 87 | self.g_scheduler = getattr( 88 | pyutils.scheduler, train_configs['g_scheduler'] 89 | )(self.g_optimizer, **train_configs['g_scheduler_args']) 90 | 91 | self.d_optimizer = getattr( 92 | torch.optim, 
train_configs['d_optimizer'] 93 | )(self.models[1].parameters(), **train_configs['d_optimizer_args']) 94 | 95 | self.d_scheduler = getattr( 96 | pyutils.scheduler, train_configs['d_scheduler'] 97 | )(self.d_optimizer, **train_configs['d_scheduler_args']) 98 | except: 99 | raise NotImplementedError("Unknown optimizer or scheduler") 100 | 101 | self.fs2loss = FastSpeech2Loss(data_configs) 102 | self.feat_loss = FeatLoss(train_configs['feat_loss_weight']) 103 | self.adv_g_loss = LSGANGLoss(train_configs['adv_g_loss_weight']) 104 | self.adv_d_loss = LSGANDLoss() 105 | 106 | if self.rank == 0: 107 | self._make_exp_dir() 108 | self.logger = get_logger(os.path.join(self.args.exp_name, 'logs/train.log')) 109 | 110 | try: 111 | latest_gckpt_path = latest_checkpoint_path( 112 | os.path.join(self.args.exp_name, 'models'), 113 | 'G_*.pth' 114 | ) 115 | latest_dckpt_path = latest_checkpoint_path( 116 | os.path.join(self.args.exp_name, 'models'), 117 | 'D_*.pth' 118 | ) 119 | _, _, _, _, epoch_str = load_checkpoint( 120 | latest_gckpt_path, 121 | self.models[0], 122 | self.g_optimizer, 123 | self.g_scheduler, 124 | False 125 | ) 126 | _, _, _, _, epoch_str = load_checkpoint( 127 | latest_dckpt_path, 128 | self.models[1], 129 | self.d_optimizer, 130 | self.d_scheduler, 131 | False 132 | ) 133 | self.start_epoch = max(epoch_str, 1) 134 | name = latest_gckpt_path 135 | self.total_step = int(name[name.rfind("_")+1:name.rfind(".")])+1 136 | except Exception: 137 | print("Load old checkpoint failed...") 138 | print("Start a new training...") 139 | self.start_epoch = 1 140 | self.total_step = 0 141 | 142 | self.epochs = self.train_configs['epochs'] 143 | self.start_disc_steps = self.train_configs['start_disc_steps'] 144 | 145 | def _dump_args_and_config(self, filename, config): 146 | with open(os.path.join(self.args.exp_name, 'conf', filename) + '.yaml', 'w') as f: 147 | yaml.dump(config, f) 148 | 149 | def _make_exp_dir(self): 150 | os.makedirs(self.args.exp_name, exist_ok=True) 151 | os.makedirs(os.path.join(self.args.exp_name, 'conf'), exist_ok=True) 152 | os.makedirs(os.path.join(self.args.exp_name, 'models'), exist_ok=True) 153 | os.makedirs(os.path.join(self.args.exp_name, 'audios'), exist_ok=True) 154 | os.makedirs(os.path.join(self.args.exp_name, 'spectrograms'), exist_ok=True) 155 | os.makedirs(os.path.join(self.args.exp_name, 'melspectrograms'), exist_ok=True) 156 | os.makedirs(os.path.join(self.args.exp_name, 'eval_results'), exist_ok=True) 157 | os.makedirs(os.path.join(self.args.exp_name, 'logs'), exist_ok = True) 158 | with open(os.path.join(self.args.exp_name, 'model_arch.txt'), 'w') as f: 159 | for model in self.models: 160 | print(model, file = f) 161 | self._dump_args_and_config('args', vars(self.args)) 162 | self._dump_args_and_config('data', self.data_configs) 163 | self._dump_args_and_config('model', self.model_configs) 164 | self._dump_args_and_config('train', self.train_configs) 165 | basename = os.path.basename(python_script) 166 | shutil.copyfile(python_script, os.path.join(self.args.exp_name, basename)) 167 | 168 | def train(self): 169 | for epoch in range(self.start_epoch, self.epochs + 1, 1): 170 | self.train_epoch(epoch) 171 | 172 | def train_epoch(self, epoch): 173 | self.total_g_loss = 0.0 174 | self.total_d_loss = 0.0 175 | for batch_idx, data in enumerate(self.trainloader): 176 | output, postnet_output = self.train_batch(data, epoch, batch_idx) 177 | self.g_scheduler.step() 178 | self.d_scheduler.step() 179 | 180 | if self.rank == 0 and self.total_step % 
self.train_configs['save_interval'] == 0: 181 | gckpt_path = os.path.join(self.args.exp_name, 'models', 'G_{}.pth'.format(self.total_step)) 182 | save_checkpoint( 183 | self.models[0], 184 | self.g_optimizer, 185 | self.g_scheduler, 186 | self.g_scheduler.get_lr()[0], 187 | epoch, 188 | gckpt_path 189 | ) 190 | dckpt_path = os.path.join(self.args.exp_name, 'models', 'D_{}.pth'.format(self.total_step)) 191 | save_checkpoint( 192 | self.models[1], 193 | self.d_optimizer, 194 | self.d_scheduler, 195 | self.d_scheduler.get_lr()[0], 196 | epoch, 197 | dckpt_path 198 | ) 199 | length = data['mel_lens'][0] 200 | real_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', '{}_{}.png'.format(data['uttids'][0], self.total_step)) 201 | before_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', 'before_{}_{}.png'.format(data['uttids'][0], self.total_step)) 202 | after_pic_path = os.path.join(self.args.exp_name, 'melspectrograms', 'after_{}_{}.png'.format(data['uttids'][0], self.total_step)) 203 | melspecplot(data['mels'][0][:length, :].transpose(1, 0).numpy(), real_pic_path) # (n_mels, T) 204 | melspecplot(output[0][:length, :].transpose(1, 0).detach().cpu().numpy(), before_pic_path) 205 | melspecplot(postnet_output[0][:length, :].transpose(1, 0).detach().cpu().numpy(), after_pic_path) 206 | 207 | if self.rank == 0: 208 | clean_checkpoints( 209 | os.path.join(self.args.exp_name, 'models'), 210 | n_ckpts_to_keep = self.train_configs['ckpt_clean'] 211 | ) 212 | 213 | def _move_to_device(self, data): 214 | new_data = {} 215 | for k, v in data.items(): 216 | if type(v) is torch.Tensor: 217 | new_data[k] = v.to(self.device) 218 | return new_data 219 | 220 | def train_batch(self, data, epoch, step): 221 | for model in self.models: 222 | model.train() 223 | new_data = self._move_to_device(data) 224 | 225 | # loss, report_keys, output, postnet_output = self.models[0](**new_data) 226 | self.g_optimizer.zero_grad() 227 | g_loss, report_keys, output, postnet_output = self.models[0](**new_data) 228 | if self.total_step >= self.train_configs['start_disc_steps']: 229 | d_fake, random_N = self.models[1](output, new_data['mel_lens'], new_data['mels']) 230 | feat_loss, feat_loss_report_keys = self.feat_loss(d_fake) 231 | adv_g_loss, adv_gloss_report_keys = self.adv_g_loss(d_fake) 232 | g_loss += feat_loss 233 | g_loss += adv_g_loss 234 | report_keys.update(feat_loss_report_keys) 235 | report_keys.update(adv_gloss_report_keys) 236 | g_loss.backward() 237 | 238 | grad_norm = nn.utils.clip_grad_norm_(self.models[0].parameters(), self.train_configs['grad_clip']) 239 | if math.isnan(grad_norm): 240 | raise ZeroDivisionError('Grad norm is nan') 241 | self.g_optimizer.step() 242 | self.total_g_loss += g_loss.item() 243 | 244 | if self.total_step >= self.train_configs['start_disc_steps']: 245 | self.d_optimizer.zero_grad() 246 | d_fake, _ = self.models[1]( 247 | output.detach(), 248 | new_data['mel_lens'], 249 | new_data['mels'], 250 | random_N 251 | ) 252 | adv_d_loss, adv_dloss_report_keys = self.adv_d_loss(d_fake) 253 | adv_d_loss.backward() 254 | report_keys.update(adv_dloss_report_keys) 255 | grad_norm = nn.utils.clip_grad_norm_(self.models[1].parameters(), self.train_configs['grad_clip']) 256 | if math.isnan(grad_norm): 257 | raise ZeroDivisionError('Grad norm is nan') 258 | self.d_optimizer.step() 259 | self.total_d_loss += adv_d_loss.item() 260 | 261 | self.total_step += 1 262 | if self.rank == 0: 263 | self.print_msg(epoch, step, report_keys) #, accuracy.item()) 264 | wandb_log_dict = { 265 | 
'train/avg_g_loss': self.total_g_loss / (step + 1), 266 | 'train/avg_d_loss': self.total_d_loss / (step + 1), 267 | 'train/g_lr': self.g_scheduler.get_lr()[0], 268 | 'train/d_lr': self.d_scheduler.get_lr()[0] 269 | } 270 | for k, v in report_keys.items(): 271 | wandb_log_dict['train/' + k] = v 272 | if self.train_configs['wandb']: 273 | wandb.log(wandb_log_dict) 274 | return output, postnet_output 275 | 276 | def print_msg(self, epoch, step, report_keys): 277 | if self.total_step % self.train_configs['log_interval'] == 0: 278 | temp = '' 279 | for k, v in report_keys.items(): 280 | temp += '{}: {:.6f} '.format(k, v) 281 | message = ('[Epoch: {} Step: {} Total steps: {}] ' + temp).format( 282 | epoch, step + 1, self.total_step 283 | ) 284 | self.logger.info(message) 285 | 286 | def parse_args(): 287 | parser = argparse.ArgumentParser() 288 | parser.add_argument('--data-config', dest = 'data_config', type = str, default = './conf/data.yaml') 289 | parser.add_argument('--model-config', dest = 'model_config', type = str, default = './conf/model.yaml') 290 | parser.add_argument('--train-config', dest = 'train_config', type = str, default = './conf/train.yaml') 291 | parser.add_argument('--num-gpus', dest = 'num_gpus', type = int, default = 1) 292 | # parser.add_argument('--exp-name', dest = 'exp_name', type = str, default = 'default') 293 | parser.add_argument('--dist-backend', dest = 'dist_backend', type = str, default = 'nccl') 294 | parser.add_argument('--dist-url', dest = 'dist_url', type = str, default = 'tcp://localhost:30302') 295 | return parser.parse_args() 296 | 297 | def main(rank, args, configs): 298 | if args.num_gpus > 1: 299 | torch.cuda.set_device(rank) 300 | torch.distributed.init_process_group( 301 | backend = args.dist_backend, 302 | init_method = args.dist_url, 303 | world_size = args.num_gpus, 304 | rank = rank 305 | ) 306 | 307 | data_configs, model_configs, train_configs = configs 308 | args.exp_name = train_configs['wandb_args']['group'] + '-' + \ 309 | train_configs['wandb_args']['job_type'] + '-' + \ 310 | train_configs['wandb_args']['name'] 311 | args.exp_name = os.path.join('exp', args.exp_name) 312 | 313 | # wandb initialization 314 | if train_configs['wandb']: 315 | wandb_configs = vars(args) 316 | for config in configs: 317 | wandb_configs.update(config) 318 | wandb.init( 319 | **train_configs['wandb_args'], 320 | config = wandb_configs 321 | ) 322 | 323 | trainer = Trainer(rank, args, data_configs, model_configs, train_configs) 324 | trainer.train() 325 | 326 | if train_configs['wandb']: 327 | wandb.finish() 328 | 329 | if __name__ == "__main__": 330 | args = parse_args() 331 | with open(args.data_config, 'r') as f: 332 | data_configs = yaml.load(f, Loader = yaml.FullLoader) 333 | with open(args.model_config, 'r') as f: 334 | model_configs = yaml.load(f, Loader = yaml.FullLoader) 335 | with open(args.train_config, 'r') as f: 336 | train_configs = yaml.load(f, Loader = yaml.FullLoader) 337 | configs = (data_configs, model_configs, train_configs) 338 | 339 | num_gpus = torch.cuda.device_count() 340 | if args.num_gpus > 1: 341 | mp.spawn(main, nprocs = num_gpus, args = (args, configs)) 342 | else: 343 | main(0, args, configs) 344 | -------------------------------------------------------------------------------- /utils: -------------------------------------------------------------------------------- 1 | /home/smg/zengchang/apps/kaldi/egs/wsj/s5/utils --------------------------------------------------------------------------------
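
A note on how the training scripts consume configs/*/train.yaml: both train.py and train_gan.py resolve their optimizers and schedulers by name, via getattr(torch.optim, ...) and getattr(pyutils.scheduler, ...), and unpack the matching *_args dictionaries as keyword arguments. The sketch below is a minimal, self-contained illustration of that wiring only; the config values and the nn.Linear stand-in are placeholders, and the scheduler keyword names are assumed to mirror the positional arguments used in the __main__ demo of pyutils/scheduler.py. Run it with the repository root on PYTHONPATH, as run.sh arranges.

import torch
import torch.nn as nn

import pyutils.scheduler

# Placeholder settings; the real values live in configs/svs/train.yaml
# or configs/tts/train.yaml.
train_configs = {
    'g_optimizer': 'Adam',
    'g_optimizer_args': {'lr': 1e-3},
    'g_scheduler': 'ExponentialDecrease',
    'g_scheduler_args': {
        'num_epochs': 100,    # assumed keyword names, mirroring the positional
        'epoch_iter': 500,    # arguments in scheduler.py's __main__ demo
        'initial_lr': 1e-3,
        'final_lr': 1e-5,
        'warm_up_epoch': 6,
        'scale_ratio': 1.0,
    },
}

model = nn.Linear(8, 8)  # stand-in for the Xiaoice2 generator

# Same resolution pattern as Trainer.__init__ in train.py / train_gan.py.
optimizer = getattr(torch.optim, train_configs['g_optimizer'])(
    model.parameters(), **train_configs['g_optimizer_args'])
scheduler = getattr(pyutils.scheduler, train_configs['g_scheduler'])(
    optimizer, **train_configs['g_scheduler_args'])

# train_epoch() steps the scheduler once per batch and logs get_lr()[0].
scheduler.step()
print(scheduler.get_lr()[0])

If a key in either *_args dictionary does not match a constructor parameter, the resulting TypeError is swallowed by the bare except in Trainer.__init__ and re-raised as NotImplementedError("Unknown optimizer or scheduler"), so the YAML keyword names have to match the constructor signatures exactly.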