├── README.md ├── assets └── fig1.png ├── checkpoints └── hifi_vctk │ ├── config.json │ ├── config.yaml │ └── generator_v1 ├── data_gen └── tts │ ├── base_binarizer.py │ ├── base_preprocess.py │ ├── runs │ ├── align_and_binarize.py │ ├── binarize.py │ ├── preprocess.py │ └── train_mfa_align.py │ ├── txt_processors │ ├── __init__.py │ ├── base_text_processor.py │ └── en.py │ └── wav_processors │ ├── __init__.py │ ├── base_processor.py │ └── common_processors.py ├── egs ├── datasets │ └── audio │ │ └── vctk │ │ ├── base_text2mel.yaml │ │ ├── diffprosody.yaml │ │ ├── hifigan.yaml │ │ ├── preprocess.py │ │ └── prosody_generator.yaml └── egs_bases │ ├── config_base.yaml │ └── tts │ ├── base.yaml │ ├── dataset_params.yaml │ ├── dp.yaml │ ├── fs.yaml │ ├── pg.yaml │ └── vocoder │ ├── base.yaml │ └── hifigan.yaml ├── extract_lpv.py ├── inference └── tts │ ├── base_tts_infer.py │ └── dp.py ├── mfa_usr ├── adapt.py ├── adapt_config.yaml ├── install_mfa.sh ├── mfa.py ├── mfa_train_config.yaml ├── run_mfa_align.py └── run_mfa_train_align.sh ├── modules ├── commons │ ├── conv.py │ ├── layers.py │ ├── nar_tts_modules.py │ ├── rel_transformer.py │ ├── rnn.py │ ├── transformer.py │ └── wavenet.py ├── tts │ ├── commons │ │ └── align_ops.py │ ├── diffprosody │ │ ├── diffprosody.py │ │ ├── diffusion.py │ │ ├── discriminator.py │ │ └── prosody_encoder.py │ ├── fs.py │ └── fs2_orig.py └── vocoder │ └── hifigan │ ├── hifigan.py │ ├── mel_utils.py │ └── stft_loss.py ├── requirements.txt ├── run.sh ├── tasks ├── run.py ├── tts │ ├── dataset_utils.py │ ├── diffprosody.py │ ├── fs.py │ ├── prosody_generator.py │ ├── speech_base.py │ ├── tts_utils.py │ └── vocoder_infer │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── base_vocoder.cpython-38.pyc │ │ └── hifigan.cpython-38.pyc │ │ ├── base_vocoder.py │ │ └── hifigan.py └── vocoder │ ├── dataset_utils.py │ ├── hifigan.py │ └── vocoder_base.py └── utils ├── audio ├── align.py ├── cwt.py ├── griffin_lim.py ├── io.py ├── pitch │ ├── __pycache__ │ │ └── utils.cpython-38.pyc │ └── utils.py ├── pitch_extractors.py ├── rnnoise.py └── vad.py ├── commons ├── base_task.py ├── ckpt_utils.py ├── dataset_utils.py ├── ddp_utils.py ├── hparams.py ├── indexed_datasets.py ├── meters.py ├── multiprocess_utils.py ├── single_thread_env.py ├── tensor_utils.py └── trainer.py ├── metrics ├── diagonal_metrics.py ├── dtw.py ├── laplace_var.py ├── mcd.py ├── pitch_distance.py └── ssim.py ├── nn ├── model_utils.py ├── schedulers.py └── seq_utils.py ├── os_utils.py ├── plot └── plot.py └── text ├── encoding.py └── text_encoder.py /README.md: -------------------------------------------------------------------------------- 1 | ### DiffProsody: Diffusion-based Latent Prosody Generation for Expressive Speech Synthesis with Prosody Conditional Adversarial Training [[Demo]](https://prml-lab-speech-team.github.io/demo/DiffProsody/) 2 | 3 | ## Abstract 4 | 5 | Expressive text-to-speech systems have undergone significant advancements owing to prosody modeling, but conventional methods can still be improved. Traditional approaches have relied on the autoregressive method to predict the quantized prosody vector; however, it suffers from the issues of long-term dependency and slow inference. This study proposes a novel approach called DiffProsody in which expressive speech is synthesized using a diffusion-based latent prosody generator and prosody conditional adversarial training. Our findings confirm the effectiveness of our prosody generator in generating a prosody vector. 
Furthermore, our prosody conditional discriminator significantly improves the quality of the generated speech by accurately emulating prosody. We use denoising diffusion generative adversarial networks to improve the prosody generation speed. Consequently, DiffProsody is capable of generating prosody 16 times faster than the conventional diffusion model. The superior performance of our proposed method has been demonstrated via experiments.
6 |
7 | ## Model
8 | ![image](assets/fig1.png)
9 |
10 | ## Training Procedure
11 | ### Environments
12 | ```
13 | pip install -r requirements.txt
14 | sudo apt install -y sox libsox-fmt-mp3
15 | bash mfa_usr/install_mfa.sh # install forced alignment tools
16 | ```
17 |
18 | ### 1. Preprocess data
19 |
20 | - Download the [VCTK](https://datashare.ed.ac.uk/handle/10283/2651) dataset
21 |
22 | ```bash
23 | # Preprocess step: normalize the text and unify the file structure.
24 | python data_gen/tts/runs/preprocess.py --config "egs/datasets/audio/vctk/diffprosody.yaml"
25 | # Align step: MFA alignment.
26 | python data_gen/tts/runs/train_mfa_align.py --config "egs/datasets/audio/vctk/diffprosody.yaml"
27 | # Binarization step: binarize data for fast IO. You only need to rerun this step when switching to a different task if you have already `preprocess`ed and `align`ed the dataset.
28 | python data_gen/tts/runs/binarize.py --config "egs/datasets/audio/vctk/diffprosody.yaml"
29 | ```
30 |
31 | ### 2. Training the TTS module and prosody encoder
32 | ```bash
33 | export PYTHONPATH=.
34 | CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config "egs/datasets/audio/vctk/diffprosody.yaml" --exp_name "DiffProsody"
35 | ```
36 |
37 | ### 3. Extracting latent prosody vectors
38 | ```bash
39 | CUDA_VISIBLE_DEVICES=0 python extract_lpv.py --config "egs/datasets/audio/vctk/diffprosody.yaml" --exp_name "DiffProsody"
40 | ```
41 |
42 | ### 4. Training the diffusion-based latent prosody generator
43 | - Set the checkpoint path (`{ckpt dir}`) according to your environment
44 | ```bash
45 | CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config "egs/datasets/audio/vctk/prosody_generator.yaml" --exp_name "DiffProsodyGenerator" --reset --hparams="tts_model=/{ckpt dir}/DiffProsody"
46 | ```
47 |
48 | ### 5. Inference
49 |
50 | ```bash
51 | CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config "egs/datasets/audio/vctk/prosody_generator.yaml" --exp_name "DiffProsodyGenerator" --infer --hparams="tts_model=/{ckpt dir}/DiffProsody"
52 | ```
53 |
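For zero-shot inference from a reference utterance, the repository also ships `inference/tts/dp.py`, which reads `reference wav name|text` pairs (one per line) from `./zeroshot_libri.txt`, looks up the reference wavs in a directory hard-coded in the script, and writes the synthesized audio to `{work_dir}/infer`. The block below is a minimal sketch, not a verified recipe: the list-file entry is a placeholder and the invocation is assumed to follow the same `--config`/`--exp_name`/`--hparams` pattern as the commands above. Note that this script additionally imports `resemblyzer` and `speechbrain` for speaker embeddings.

```bash
# Minimal sketch (assumed usage): one "<reference wav name>|<text>" pair per line.
# The reference wav must exist in the wav directory hard-coded in inference/tts/dp.py.
echo "some_reference_utt|please call stella." > zeroshot_libri.txt
export PYTHONPATH=.
CUDA_VISIBLE_DEVICES=0 python inference/tts/dp.py \
    --config "egs/datasets/audio/vctk/prosody_generator.yaml" \
    --exp_name "DiffProsodyGenerator" \
    --hparams="tts_model=/{ckpt dir}/DiffProsody"
```

54 | ### 6.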
Pretrained checkpoints 55 | - TTS module trained on 160k [[Download]](https://works.do/xsBlIw8) 56 | - Diffusion-based prosody generator trained on 320k [[Download]](https://works.do/5CAF6E0) 57 | 58 | ## Acknowledgements 59 | **Our codes are based on the following repos:** 60 | * [NATSpeech](https://github.com/NATSpeech/NATSpeech) 61 | * [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) 62 | * [HifiGAN](https://github.com/jik876/hifi-gan) 63 | -------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hs-oh-prml/DiffProsody/6d5b6dbb58497fdff791d06fca09a4fae2a2cc11/assets/fig1.png -------------------------------------------------------------------------------- /checkpoints/hifi_vctk/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /checkpoints/hifi_vctk/config.yaml: -------------------------------------------------------------------------------- 1 | accumulate_grad_batches: 1 2 | adam_b1: 0.8 3 | adam_b2: 0.99 4 | amp: false 5 | audio_num_mel_bins: 80 6 | audio_sample_rate: 22050 7 | base_config: 8 | - configs/tts/hifigan.yaml 9 | - configs/tts/libritts/base_mel2wav.yaml 10 | binarization_args: 11 | shuffle: false 12 | trim_eos_bos: false 13 | trim_sil: false 14 | with_align: false 15 | with_f0: true 16 | with_f0cwt: false 17 | with_linear: false 18 | with_spk_embed: false 19 | with_txt: true 20 | with_wav: true 21 | binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer 22 | binary_data_dir: data/binary/ljspeech_wav 23 | check_val_every_n_epoch: 10 24 | clip_grad_norm: 1 25 | clip_grad_value: 0 26 | debug: false 27 | dec_ffn_kernel_size: 9 28 | dec_layers: 4 29 | dict_dir: '' 30 | disc_start_steps: 40000 31 | discriminator_grad_norm: 1 32 | discriminator_optimizer_params: 33 | eps: 1.0e-06 34 | lr: 0.0002 35 | weight_decay: 0.0 36 | discriminator_params: 37 | bias: true 38 | conv_channels: 64 39 | in_channels: 1 40 | kernel_size: 3 41 | layers: 10 42 | nonlinear_activation: LeakyReLU 43 | nonlinear_activation_params: 44 | negative_slope: 0.2 45 | out_channels: 1 46 | use_weight_norm: true 47 | discriminator_scheduler_params: 48 | gamma: 0.999 49 | step_size: 600 50 | dropout: 0.1 51 | ds_workers: 1 52 | enc_ffn_kernel_size: 9 53 | enc_layers: 4 54 | endless_ds: true 55 | ffn_act: gelu 56 | ffn_padding: SAME 57 | fft_size: 1024 58 | fm_loss: false 59 | fmax: 8000 60 | fmin: 80 61 | frames_multiple: 1 62 | gen_dir_name: '' 63 | generator_grad_norm: 10 64 | generator_optimizer_params: 65 | eps: 
1.0e-06 66 | lr: 0.0002 67 | weight_decay: 0.0 68 | generator_params: 69 | aux_channels: 80 70 | aux_context_window: 0 71 | dropout: 0.0 72 | gate_channels: 128 73 | in_channels: 1 74 | kernel_size: 3 75 | layers: 30 76 | out_channels: 1 77 | residual_channels: 64 78 | skip_channels: 64 79 | stacks: 3 80 | upsample_net: ConvInUpsampleNetwork 81 | upsample_params: 82 | upsample_scales: 83 | - 4 84 | - 4 85 | - 4 86 | - 4 87 | use_nsf: false 88 | use_pitch_embed: false 89 | use_weight_norm: true 90 | generator_scheduler_params: 91 | gamma: 0.999 92 | step_size: 600 93 | griffin_lim_iters: 60 94 | hidden_size: 256 95 | hop_size: 256 96 | infer: false 97 | lambda_adv: 4.0 98 | lambda_mel: 45.0 99 | load_ckpt: '' 100 | loud_norm: false 101 | lr: 2.0 102 | max_epochs: 1000 103 | max_eval_sentences: 1 104 | max_eval_tokens: 60000 105 | max_frames: 1548 106 | max_input_tokens: 1550 107 | max_samples: 8192 108 | max_sentences: 24 109 | max_tokens: 30000 110 | max_updates: 3000000 111 | mel_vmax: 1.5 112 | mel_vmin: -6 113 | min_level_db: -100 114 | num_ckpt_keep: 3 115 | num_heads: 2 116 | num_mels: 80 117 | num_sanity_val_steps: 5 118 | num_spk: 1 119 | optimizer_adam_beta1: 0.9 120 | optimizer_adam_beta2: 0.98 121 | out_wav_norm: false 122 | pitch_extractor: parselmouth 123 | pre_align_args: 124 | allow_no_txt: false 125 | denoise: false 126 | forced_align: mfa 127 | sox_resample: false 128 | trim_sil: false 129 | txt_processor: en 130 | use_tone: true 131 | pre_align_cls: '' 132 | print_nan_grads: false 133 | processed_data_dir: /workspace/dataset/processed/libritts 134 | profile_infer: false 135 | raw_data_dir: /workspace/dataset/libritts 136 | ref_level_db: 20 137 | rerun_gen: true 138 | resblock: '1' 139 | resblock_dilation_sizes: 140 | - - 1 141 | - 3 142 | - 5 143 | - - 1 144 | - 3 145 | - 5 146 | - - 1 147 | - 3 148 | - 5 149 | resblock_kernel_sizes: 150 | - 3 151 | - 7 152 | - 11 153 | reset_phone_dict: true 154 | resume_from_checkpoint: 0 155 | sampling_rate: 22050 156 | save_best: true 157 | save_codes: [] 158 | save_f0: false 159 | save_gt: true 160 | seed: 1234 161 | sort_by_len: true 162 | stft_loss_params: 163 | fft_sizes: 164 | - 1024 165 | - 2048 166 | - 512 167 | hop_sizes: 168 | - 120 169 | - 240 170 | - 50 171 | win_lengths: 172 | - 600 173 | - 1200 174 | - 240 175 | window: hann_window 176 | stop_token_weight: 5.0 177 | task_cls: tasks.vocoder.hifigan.HifiGanTask 178 | tb_log_interval: 100 179 | test_input_dir: '' 180 | test_num: 100 181 | test_set_name: test 182 | train_set_name: train 183 | upsample_initial_channel: 512 184 | upsample_kernel_sizes: 185 | - 16 186 | - 16 187 | - 4 188 | - 4 189 | upsample_rates: 190 | - 8 191 | - 8 192 | - 2 193 | - 2 194 | use_mel_loss: false 195 | use_pitch_embed: false 196 | val_check_interval: 2000 197 | valid_monitor_key: val_loss 198 | valid_monitor_mode: min 199 | valid_set_name: valid 200 | vocoder: pwg 201 | vocoder_ckpt: '' 202 | warmup_updates: 8000 203 | weight_decay: 0 204 | win_length: null 205 | win_size: 1024 206 | window: hann 207 | work_dir: /workspace/checkpoints/hifi 208 | -------------------------------------------------------------------------------- /checkpoints/hifi_vctk/generator_v1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hs-oh-prml/DiffProsody/6d5b6dbb58497fdff791d06fca09a4fae2a2cc11/checkpoints/hifi_vctk/generator_v1 -------------------------------------------------------------------------------- /data_gen/tts/runs/align_and_binarize.py: 
-------------------------------------------------------------------------------- 1 | import utils.commons.single_thread_env # NOQA 2 | from utils.commons.hparams import set_hparams, hparams 3 | from data_gen.tts.runs.binarize import binarize 4 | from data_gen.tts.runs.preprocess import preprocess 5 | from data_gen.tts.runs.train_mfa_align import train_mfa_align 6 | 7 | if __name__ == '__main__': 8 | set_hparams() 9 | preprocess() 10 | if hparams['preprocess_args']['use_mfa']: 11 | train_mfa_align() 12 | binarize() 13 | -------------------------------------------------------------------------------- /data_gen/tts/runs/binarize.py: -------------------------------------------------------------------------------- 1 | import utils.commons.single_thread_env # NOQA 2 | from utils.commons.hparams import hparams, set_hparams 3 | import importlib 4 | 5 | 6 | def binarize(): 7 | binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer') 8 | pkg = ".".join(binarizer_cls.split(".")[:-1]) 9 | cls_name = binarizer_cls.split(".")[-1] 10 | binarizer_cls = getattr(importlib.import_module(pkg), cls_name) 11 | print("| Binarizer: ", binarizer_cls) 12 | binarizer_cls().process() 13 | 14 | 15 | if __name__ == '__main__': 16 | set_hparams() 17 | binarize() 18 | -------------------------------------------------------------------------------- /data_gen/tts/runs/preprocess.py: -------------------------------------------------------------------------------- 1 | import utils.commons.single_thread_env # NOQA 2 | from utils.commons.hparams import hparams, set_hparams 3 | import importlib 4 | 5 | 6 | def preprocess(): 7 | assert hparams['preprocess_cls'] != '' 8 | 9 | pkg = ".".join(hparams["preprocess_cls"].split(".")[:-1]) 10 | cls_name = hparams["preprocess_cls"].split(".")[-1] 11 | process_cls = getattr(importlib.import_module(pkg), cls_name) 12 | process_cls().process() 13 | 14 | 15 | if __name__ == '__main__': 16 | set_hparams() 17 | preprocess() 18 | -------------------------------------------------------------------------------- /data_gen/tts/runs/train_mfa_align.py: -------------------------------------------------------------------------------- 1 | import utils.commons.single_thread_env # NOQA 2 | import glob 3 | import subprocess 4 | from textgrid import TextGrid 5 | import os 6 | from utils.commons.hparams import hparams, set_hparams 7 | 8 | 9 | def train_mfa_align(mfa_outputs="mfa_outputs", 10 | mfa_inputs="mfa_inputs", 11 | model_name=None, pretrain_model_name=None, 12 | mfa_cmd='train'): 13 | CORPUS = hparams['processed_data_dir'].split("/")[-1] 14 | NUM_JOB = int(os.getenv('N_PROC', os.cpu_count())) 15 | env_vars = [f'CORPUS={CORPUS}', f'NUM_JOB={NUM_JOB}'] 16 | if mfa_outputs is not None: 17 | env_vars.append(f'MFA_OUTPUTS={mfa_outputs}') 18 | if mfa_inputs is not None: 19 | env_vars.append(f'MFA_INPUTS={mfa_inputs}') 20 | if model_name is not None: 21 | env_vars.append(f'MODEL_NAME={model_name}') 22 | if pretrain_model_name is not None: 23 | env_vars.append(f'PRETRAIN_MODEL_NAME={pretrain_model_name}') 24 | if mfa_cmd is not None: 25 | env_vars.append(f'MFA_CMD={mfa_cmd}') 26 | env_str = ' '.join(env_vars) 27 | print(f"| Run MFA for {CORPUS}. 
Env vars: {env_str}") 28 | subprocess.check_call(f'{env_str} bash mfa_usr/run_mfa_train_align.sh', shell=True) 29 | mfa_offset = hparams['preprocess_args']['mfa_offset'] 30 | if mfa_offset > 0: 31 | for tg_fn in glob.glob(f'{hparams["processed_data_dir"]}/{mfa_outputs}/*.TextGrid'): 32 | tg = TextGrid.fromFile(tg_fn) 33 | max_time = tg.maxTime 34 | for tier in tg.tiers: 35 | for interval in tier.intervals: 36 | interval.maxTime = min(interval.maxTime + mfa_offset, max_time) 37 | interval.minTime = min(interval.minTime + mfa_offset, max_time) 38 | tier.intervals[0].minTime = 0 39 | tier.maxTime = min(tier.maxTime + mfa_offset, max_time) 40 | tg.write(tg_fn) 41 | TextGrid.fromFile(tg_fn) 42 | 43 | 44 | if __name__ == '__main__': 45 | set_hparams(print_hparams=False) 46 | train_mfa_align() 47 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/__init__.py: -------------------------------------------------------------------------------- 1 | from . import en -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/base_text_processor.py: -------------------------------------------------------------------------------- 1 | from utils.text.text_encoder import is_sil_phoneme 2 | 3 | REGISTERED_TEXT_PROCESSORS = {} 4 | 5 | 6 | def register_txt_processors(name): 7 | def _f(cls): 8 | REGISTERED_TEXT_PROCESSORS[name] = cls 9 | return cls 10 | 11 | return _f 12 | 13 | 14 | def get_txt_processor_cls(name): 15 | return REGISTERED_TEXT_PROCESSORS.get(name, None) 16 | 17 | 18 | class BaseTxtProcessor: 19 | @staticmethod 20 | def sp_phonemes(): 21 | return ['|'] 22 | 23 | @classmethod 24 | def process(cls, txt, preprocess_args): 25 | raise NotImplementedError 26 | 27 | @classmethod 28 | def postprocess(cls, txt_struct, preprocess_args): 29 | # remove sil phoneme in head and tail 30 | while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[0][0]): 31 | txt_struct = txt_struct[1:] 32 | while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[-1][0]): 33 | txt_struct = txt_struct[:-1] 34 | if preprocess_args['with_phsep']: 35 | txt_struct = cls.add_bdr(txt_struct) 36 | if preprocess_args['add_eos_bos']: 37 | txt_struct = [["", [""]]] + txt_struct + [["", [""]]] 38 | return txt_struct 39 | 40 | @classmethod 41 | def add_bdr(cls, txt_struct): 42 | txt_struct_ = [] 43 | for i, ts in enumerate(txt_struct): 44 | txt_struct_.append(ts) 45 | if i != len(txt_struct) - 1 and \ 46 | not is_sil_phoneme(txt_struct[i][0]) and not is_sil_phoneme(txt_struct[i + 1][0]): 47 | txt_struct_.append(['|', ['|']]) 48 | return txt_struct_ 49 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/en.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | from g2p_en import G2p 5 | from g2p_en.expand import normalize_numbers 6 | from nltk import pos_tag 7 | from nltk.tokenize import TweetTokenizer 8 | 9 | from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor, register_txt_processors 10 | from utils.text.text_encoder import PUNCS, is_sil_phoneme 11 | 12 | 13 | class EnG2p(G2p): 14 | word_tokenize = TweetTokenizer().tokenize 15 | 16 | def __call__(self, text): 17 | # preprocessing 18 | words = EnG2p.word_tokenize(text) 19 | tokens = pos_tag(words) # tuples of (word, tag) 20 | 21 | # steps 22 | prons = [] 23 | for word, pos in tokens: 24 | if re.search("[a-z]", word) is None: 25 
| pron = [word] 26 | 27 | elif word in self.homograph2features: # Check homograph 28 | pron1, pron2, pos1 = self.homograph2features[word] 29 | if pos.startswith(pos1): 30 | pron = pron1 31 | else: 32 | pron = pron2 33 | elif word in self.cmu: # lookup CMU dict 34 | pron = self.cmu[word][0] 35 | else: # predict for oov 36 | pron = self.predict(word) 37 | 38 | prons.extend(pron) 39 | prons.extend([" "]) 40 | 41 | return prons[:-1] 42 | 43 | 44 | @register_txt_processors('en') 45 | class TxtProcessor(BaseTxtProcessor): 46 | g2p = EnG2p() 47 | 48 | @staticmethod 49 | def preprocess_text(text): 50 | text = normalize_numbers(text) 51 | text = ''.join(char for char in unicodedata.normalize('NFD', text) 52 | if unicodedata.category(char) != 'Mn') # Strip accents 53 | text = text.lower() 54 | text = re.sub("[\'\"()]+", "", text) 55 | text = re.sub("[-]+", " ", text) 56 | text = re.sub(f"[^ a-z{PUNCS}]", "", text) 57 | text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> ! 58 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 59 | text = text.replace("i.e.", "that is") 60 | text = text.replace("i.e.", "that is") 61 | text = text.replace("etc.", "etc") 62 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 63 | text = re.sub(rf"\s+", r" ", text) 64 | return text 65 | 66 | @classmethod 67 | def process(cls, txt, preprocess_args): 68 | txt = cls.preprocess_text(txt).strip() 69 | phs = cls.g2p(txt) 70 | txt_struct = [[w, []] for w in txt.split(" ")] 71 | i_word = 0 72 | for p in phs: 73 | if p == ' ': 74 | i_word += 1 75 | else: 76 | txt_struct[i_word][1].append(p) 77 | txt_struct = cls.postprocess(txt_struct, preprocess_args) 78 | return txt_struct, txt 79 | -------------------------------------------------------------------------------- /data_gen/tts/wav_processors/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base_processor 2 | from . 
import common_processors 3 | -------------------------------------------------------------------------------- /data_gen/tts/wav_processors/base_processor.py: -------------------------------------------------------------------------------- 1 | REGISTERED_WAV_PROCESSORS = {} 2 | 3 | 4 | def register_wav_processors(name): 5 | def _f(cls): 6 | REGISTERED_WAV_PROCESSORS[name] = cls 7 | return cls 8 | 9 | return _f 10 | 11 | 12 | def get_wav_processor_cls(name): 13 | return REGISTERED_WAV_PROCESSORS.get(name, None) 14 | 15 | 16 | class BaseWavProcessor: 17 | @property 18 | def name(self): 19 | raise NotImplementedError 20 | 21 | def output_fn(self, input_fn): 22 | return f'{input_fn[:-4]}_{self.name}.wav' 23 | 24 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 25 | raise NotImplementedError 26 | -------------------------------------------------------------------------------- /data_gen/tts/wav_processors/common_processors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import librosa 4 | import numpy as np 5 | from data_gen.tts.wav_processors.base_processor import BaseWavProcessor, register_wav_processors 6 | from utils.audio import trim_long_silences 7 | from utils.audio.io import save_wav 8 | from utils.audio.rnnoise import rnnoise 9 | from utils.commons.hparams import hparams 10 | 11 | 12 | @register_wav_processors(name='sox_to_wav') 13 | class ConvertToWavProcessor(BaseWavProcessor): 14 | @property 15 | def name(self): 16 | return 'ToWav' 17 | 18 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 19 | if input_fn[-4:] == '.wav': 20 | return input_fn, sr 21 | else: 22 | output_fn = self.output_fn(input_fn) 23 | subprocess.check_call(f'sox -v 0.95 "{input_fn}" -t wav "{output_fn}"', shell=True) 24 | return output_fn, sr 25 | 26 | 27 | @register_wav_processors(name='sox_resample') 28 | class ResampleProcessor(BaseWavProcessor): 29 | @property 30 | def name(self): 31 | return 'Resample' 32 | 33 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 34 | output_fn = self.output_fn(input_fn) 35 | sr_file = librosa.core.get_samplerate(input_fn) 36 | if sr != sr_file: 37 | subprocess.check_call(f'sox -v 0.95 "{input_fn}" -r{sr} "{output_fn}"', shell=True) 38 | y, _ = librosa.core.load(input_fn, sr=sr) 39 | y, _ = librosa.effects.trim(y) 40 | save_wav(y, output_fn, sr) 41 | return output_fn, sr 42 | else: 43 | return input_fn, sr 44 | 45 | 46 | @register_wav_processors(name='trim_sil') 47 | class TrimSILProcessor(BaseWavProcessor): 48 | @property 49 | def name(self): 50 | return 'TrimSIL' 51 | 52 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 53 | output_fn = self.output_fn(input_fn) 54 | y, _ = librosa.core.load(input_fn, sr=sr) 55 | y, _ = librosa.effects.trim(y) 56 | save_wav(y, output_fn, sr) 57 | return output_fn 58 | 59 | 60 | @register_wav_processors(name='trim_all_sil') 61 | class TrimAllSILProcessor(BaseWavProcessor): 62 | @property 63 | def name(self): 64 | return 'TrimSIL' 65 | 66 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 67 | output_fn = self.output_fn(input_fn) 68 | y, audio_mask, _ = trim_long_silences( 69 | input_fn, vad_max_silence_length=preprocess_args.get('vad_max_silence_length', 12)) 70 | save_wav(y, output_fn, sr) 71 | if preprocess_args['save_sil_mask']: 72 | os.makedirs(f'{processed_dir}/sil_mask', exist_ok=True) 
73 | np.save(f'{processed_dir}/sil_mask/{item_name}.npy', audio_mask) 74 | return input_fn, sr, output_fn 75 | 76 | 77 | @register_wav_processors(name='denoise') 78 | class DenoiseProcessor(BaseWavProcessor): 79 | @property 80 | def name(self): 81 | return 'Denoise' 82 | 83 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 84 | output_fn = self.output_fn(input_fn) 85 | rnnoise(input_fn, output_fn, out_sample_rate=sr) 86 | return output_fn, sr 87 | -------------------------------------------------------------------------------- /egs/datasets/audio/vctk/base_text2mel.yaml: -------------------------------------------------------------------------------- 1 | base_config: egs/egs_bases/tts/base.yaml 2 | raw_data_dir: '/workspace/dataset/VCTK/VCTK-Corpus' 3 | processed_data_dir: '/workspace/dataset/processed/vctk' 4 | binary_data_dir: '/workspace/dataset/binary/vctk' 5 | preprocess_cls: egs.datasets.audio.vctk.preprocess.VCTKPreprocess 6 | binarization_args: 7 | train_range: [ 2725, -1 ] 8 | test_range: [ 0, 2180 ] 9 | valid_range: [ 2180, 2725 ] 10 | test_ids: [] 11 | f0_min: 80 12 | f0_max: 800 13 | vocoder_ckpt: checkpoints/hifi_vctk 14 | -------------------------------------------------------------------------------- /egs/datasets/audio/vctk/diffprosody.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - egs/egs_bases/tts/dp.yaml 3 | - ./base_text2mel.yaml 4 | 5 | task_cls: tasks.tts.diffprosody.DiffProsodyTask 6 | 7 | prosody_mel_bins: 20 8 | valid_infer_interval: 2000 9 | vq_warmup: 20000 10 | commitment_cost: 0.25 11 | lambda_mel_adv: 0.05 12 | 13 | disc_win_num: 3 14 | mel_disc_hidden_size: 192 15 | disc_norm: in 16 | disc_reduction: stack 17 | disc_interval: 1 18 | disc_start_steps: 0 19 | discriminator_scheduler_params: 20 | gamma: 0.5 21 | step_size: 40000 22 | discriminator_optimizer_params: 23 | eps: 1.0e-06 24 | weight_decay: 0.0 25 | 26 | max_sentences: 48 27 | lr: 0.0005 28 | disc_lr: 0.0001 29 | ema_decay: 0.998 -------------------------------------------------------------------------------- /egs/datasets/audio/vctk/hifigan.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - egs/egs_bases/tts/vocoder/hifigan.yaml 3 | - ./base_mel2wav.yaml -------------------------------------------------------------------------------- /egs/datasets/audio/vctk/preprocess.py: -------------------------------------------------------------------------------- 1 | from data_gen.tts.base_preprocess import BasePreprocessor 2 | 3 | 4 | class VCTKPreprocess(BasePreprocessor): 5 | def meta_data(self): 6 | for l in open(f'{self.raw_data_dir}/metadata.csv').readlines(): 7 | item_name, txt, spk = l.strip().split("|") 8 | item_name = item_name.replace(".wav", "") 9 | items = item_name.split("/") 10 | item_name = items[-1] 11 | spk_name = items[-2] 12 | wav_fn = f"{self.raw_data_dir}/wav48/{spk_name}/{item_name}.wav" 13 | 14 | yield {'item_name': item_name, 'wav_fn': wav_fn, 'txt': txt, 'spk_name': spk} 15 | -------------------------------------------------------------------------------- /egs/datasets/audio/vctk/prosody_generator.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - egs/egs_bases/tts/pg.yaml 3 | - ./base_text2mel.yaml 4 | tts_model: "" 5 | 6 | task_cls: tasks.tts.prosody_generator.ProsodyGeneratorTask 7 | 8 | timesteps: 4 9 | K_step: 4 10 | residual_channels: 384 11 | disc_hidden_size: 192 
12 | n_layer: 4 13 | n_uncond_layer: 2 14 | n_cond_layer: 2 15 | disc_n_channels: [384, 192, 96, 48] 16 | disc_kernel_sizes: [3, 3, 3, 3] 17 | disc_strides: [1, 1, 1, 1] 18 | 19 | lambda_adv: 1.0 20 | lambda_fm: 0.0 21 | lambda_lpv: 1.0 22 | 23 | lr: 0.0002 24 | disc_lr: 0.0001 25 | discriminator_scheduler_params: 26 | gamma: 0.5 27 | step_size: 40000 28 | generator_scheduler_params: 29 | gamma: 0.5 30 | step_size: 40000 31 | 32 | discriminator_optimizer_params: 33 | eps: 1.0e-06 34 | weight_decay: 0.0 35 | 36 | disc_win_num: 3 37 | mel_disc_hidden_size: 192 38 | disc_norm: in 39 | disc_reduction: stack 40 | disc_interval: 1 41 | disc_start_steps: 0 42 | 43 | lpv_loss: l1 44 | 45 | max_tokens: 40000000 46 | max_sentences: 48 47 | max_updates: 320000 48 | 49 | max_beta: 20 50 | min_beta: 0.1 51 | ema_decay: 0.998 52 | -------------------------------------------------------------------------------- /egs/egs_bases/config_base.yaml: -------------------------------------------------------------------------------- 1 | # task 2 | binary_data_dir: '' 3 | work_dir: '' # experiment directory. 4 | infer: false # infer 5 | amp: false 6 | seed: 1234 7 | debug: false 8 | save_codes: ['tasks', 'modules', 'egs'] 9 | 10 | ############# 11 | # dataset 12 | ############# 13 | ds_workers: 1 14 | test_num: 100 15 | endless_ds: true 16 | sort_by_len: true 17 | 18 | ######### 19 | # train and eval 20 | ######### 21 | print_nan_grads: false 22 | load_ckpt: '' 23 | save_best: false 24 | num_ckpt_keep: 3 25 | clip_grad_norm: 0 26 | accumulate_grad_batches: 1 27 | tb_log_interval: 100 28 | num_sanity_val_steps: 5 # steps of validation at the beginning 29 | check_val_every_n_epoch: 10 30 | val_check_interval: 2000 31 | valid_monitor_key: 'val_loss' 32 | valid_monitor_mode: 'min' 33 | max_epochs: 1000 34 | max_updates: 1000000 35 | max_tokens: 40000 36 | max_sentences: 100000 37 | max_valid_tokens: -1 38 | max_valid_sentences: -1 39 | eval_max_batches: -1 40 | resume_from_checkpoint: 0 41 | rename_tmux: true -------------------------------------------------------------------------------- /egs/egs_bases/tts/base.yaml: -------------------------------------------------------------------------------- 1 | # task 2 | base_config: 3 | - ../config_base.yaml 4 | - ./dataset_params.yaml 5 | 6 | ############# 7 | # dataset in training 8 | ############# 9 | endless_ds: true 10 | min_frames: 0 11 | max_frames: 1548 12 | frames_multiple: 1 13 | max_input_tokens: 1550 14 | ds_workers: 1 15 | 16 | ######### 17 | # model 18 | ######### 19 | use_spk_id: false 20 | use_spk_embed: true 21 | mel_losses: "ssim:0.5|l1:0.5" 22 | 23 | ########### 24 | # optimization 25 | ########### 26 | lr: 0.0005 27 | scheduler: warmup # rsqrt|warmup|none 28 | warmup_updates: 4000 29 | optimizer_adam_beta1: 0.9 30 | optimizer_adam_beta2: 0.98 31 | weight_decay: 0 32 | clip_grad_norm: 1 33 | clip_grad_value: 0 34 | 35 | 36 | ########### 37 | # train and eval 38 | ########### 39 | use_word_input: false 40 | max_valid_sentences: 1 41 | max_valid_tokens: 60000 42 | valid_infer_interval: 10000 43 | train_set_name: 'train' 44 | train_sets: '' 45 | valid_set_name: 'valid' 46 | test_set_name: 'test' 47 | num_valid_plots: 10 48 | test_ids: [ ] 49 | test_input_yaml: '' 50 | vocoder: HifiGAN 51 | vocoder_ckpt: '' 52 | profile_infer: false 53 | out_wav_norm: false 54 | save_gt: true 55 | save_f0: false 56 | gen_dir_name: '' -------------------------------------------------------------------------------- /egs/egs_bases/tts/dataset_params.yaml: 
-------------------------------------------------------------------------------- 1 | audio_num_mel_bins: 80 2 | audio_sample_rate: 22050 3 | hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) 4 | win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) 5 | fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter 6 | fmin: 0 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 7 | fmax: 8000 # To be increased/reduced depending on data. 8 | f0_min: 80 9 | f0_max: 800 10 | griffin_lim_iters: 30 11 | pitch_extractor: parselmouth 12 | num_spk: 2456 13 | mel_vmin: -6 14 | mel_vmax: 1.5 15 | loud_norm: false 16 | 17 | raw_data_dir: '' 18 | processed_data_dir: '' 19 | binary_data_dir: '' 20 | preprocess_cls: '' 21 | binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer 22 | preprocess_args: 23 | nsample_per_mfa_group: 1000 24 | # text process 25 | txt_processor: en 26 | use_mfa: true 27 | with_phsep: false # add '|' 28 | reset_phone_dict: true 29 | reset_word_dict: true 30 | add_eos_bos: true # add , 31 | # mfa 32 | mfa_group_shuffle: false 33 | mfa_offset: 0.02 34 | # wav processors 35 | wav_processors: ['trim_all_sil'] 36 | save_sil_mask: true 37 | vad_max_silence_length: 12 38 | binarization_args: 39 | shuffle: false 40 | with_wav: true 41 | with_align: true 42 | with_spk_embed: true 43 | with_f0: true 44 | with_f0cwt: true 45 | with_linear: false 46 | trim_eos_bos: true 47 | min_sil_duration: 0.1 48 | train_range: [ 200, -1 ] 49 | test_range: [ 0, 100 ] 50 | valid_range: [ 100, 200 ] 51 | word_dict_size: 30000 52 | pitch_key: pitch -------------------------------------------------------------------------------- /egs/egs_bases/tts/dp.yaml: -------------------------------------------------------------------------------- 1 | base_config: ./fs.yaml 2 | task_cls: tasks.tts.diffprosody.DiffProsodyTask 3 | 4 | ########################### 5 | # models 6 | ########################### 7 | encoder_type: fft 8 | decoder_type: fft 9 | 10 | # encoders 11 | hidden_size: 192 12 | ffn_hidden_size: 384 13 | enc_ffn_kernel_size: 5 14 | enc_layers: 4 15 | use_word_encoder: true 16 | 17 | # decoders 18 | dec_layers: 4 19 | dec_ffn_kernel_size: 5 20 | 21 | # mix ling encoder 22 | word_enc_layers: 4 23 | word_encoder_type: fft 24 | text_encoder_postnet: true 25 | dropout: 0.1 26 | 27 | max_updates: 160000 28 | mel_losses: mse|ssim -------------------------------------------------------------------------------- /egs/egs_bases/tts/fs.yaml: -------------------------------------------------------------------------------- 1 | base_config: ./base.yaml 2 | task_cls: tasks.tts.fs.FastSpeechTask 3 | 4 | # model 5 | hidden_size: 256 6 | dropout: 0.2 7 | encoder_type: fft # rel_fft|fft|tacotron|tacotron2|conformer 8 | decoder_type: fft # fft|rnn|conv|conformer|wn 9 | 10 | # rnn enc/dec 11 | encoder_K: 8 12 | decoder_rnn_dim: 0 # for rnn decoder, 0 -> hidden_size * 2 13 | 14 | # fft enc/dec 15 | enc_layers: 4 16 | enc_ffn_kernel_size: 9 17 | enc_prenet: true 18 | enc_pre_ln: true 19 | dec_layers: 4 20 | dec_ffn_kernel_size: 9 21 | num_heads: 2 22 | ffn_act: gelu 23 | ffn_hidden_size: 1024 24 | use_pos_embed: true 25 | 26 | # conv enc/dec 27 | enc_dec_norm: ln 28 | conv_use_pos: false 29 | layers_in_block: 2 30 | enc_dilations: [ 1, 1, 1, 1 ] 31 | enc_kernel_size: 5 32 | enc_post_net_kernel: 3 33 | dec_dilations: [ 1, 1, 1, 1 ] # for 
conv decoder 34 | dec_kernel_size: 5 35 | dec_post_net_kernel: 3 36 | 37 | # duration 38 | predictor_hidden: -1 39 | dur_predictor_kernel: 3 40 | dur_predictor_layers: 2 41 | predictor_kernel: 5 42 | predictor_layers: 5 43 | predictor_dropout: 0.5 44 | 45 | # pitch and energy 46 | use_pitch_embed: false 47 | pitch_type: frame # frame|ph|cwt 48 | use_uv: true 49 | 50 | # reference encoder and speaker embedding 51 | lambda_commit: 0.25 52 | ref_norm_layer: bn 53 | dec_inp_add_noise: false 54 | 55 | # mel 56 | # mel_losses: mse|ssim # l1|l2|gdl|ssim or l1:0.5|ssim:0.5 57 | # mel_losses: mse|ssim 58 | mel_losses: l1 59 | 60 | # loss lambda 61 | lambda_f0: 1.0 62 | lambda_uv: 1.0 63 | lambda_energy: 0.1 64 | lambda_ph_dur: 0.1 65 | lambda_sent_dur: 1.0 66 | lambda_word_dur: 1.0 67 | predictor_grad: 0.1 68 | 69 | # train and eval 70 | warmup_updates: 4000 71 | max_tokens: 40000 72 | max_sentences: 48 73 | max_valid_sentences: 1 74 | max_updates: 160000 75 | use_gt_dur: false 76 | use_gt_f0: false 77 | ds_workers: 2 -------------------------------------------------------------------------------- /egs/egs_bases/tts/pg.yaml: -------------------------------------------------------------------------------- 1 | base_config: ./fs.yaml 2 | task_cls: tasks.tts.prosody_generator.ProsodyGeneratorTask 3 | ########################### 4 | # models 5 | ########################### 6 | 7 | encoder_type: fft 8 | decoder_type: fft 9 | 10 | # encoders 11 | hidden_size: 192 12 | ffn_hidden_size: 384 13 | enc_ffn_kernel_size: 5 14 | enc_layers: 4 15 | use_word_encoder: true 16 | 17 | # decoders 18 | dec_layers: 4 19 | dec_ffn_kernel_size: 5 20 | 21 | # mix ling encoder 22 | word_enc_layers: 4 23 | word_encoder_type: fft 24 | text_encoder_postnet: true 25 | dropout: 0.1 26 | 27 | 28 | ## model configs for diffspeech 29 | residual_layers: 20 30 | residual_channels: 256 31 | dilation_cycle_length: 1 32 | lr: 0.0002 33 | timesteps: 100 34 | K_step: 100 35 | diff_loss_type: l1 36 | diff_decoder_type: 'wavenet' 37 | schedule_type: 'linear' 38 | max_beta: 0.06 39 | 40 | prosody_mel_bins: 20 41 | valid_infer_interval: 2000 42 | vq_warmup: 2000000 43 | commitment_cost: 0 44 | 45 | ########################### 46 | # training and inference 47 | ########################### 48 | num_valid_plots: 10 49 | warmup_updates: 4000 50 | max_tokens: 40000 51 | max_sentences: 64 52 | max_updates: 100000 53 | keep_bins: 192 -------------------------------------------------------------------------------- /egs/egs_bases/tts/vocoder/base.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - egs/egs_bases/config_base.yaml 3 | - ../dataset_params.yaml 4 | binarization_args: 5 | with_wav: true 6 | with_spk_embed: false 7 | with_align: false 8 | 9 | generator_grad_norm: 10.0 # Generator's gradient norm. 10 | discriminator_grad_norm: 1.0 # Discriminator's gradient norm. 
11 | 12 | ########### 13 | # train and eval 14 | ########### 15 | max_samples: 20480 16 | max_sentences: 8 17 | max_valid_sentences: 1 18 | max_updates: 2000000 19 | val_check_interval: 5000 20 | valid_infer_interval: 50000 21 | -------------------------------------------------------------------------------- /egs/egs_bases/tts/vocoder/hifigan.yaml: -------------------------------------------------------------------------------- 1 | base_config: ./base.yaml 2 | task_cls: tasks.vocoder.hifigan.HifiGanTask 3 | resblock: "1" 4 | adam_b1: 0.8 5 | adam_b2: 0.99 6 | upsample_rates: [ 8,8,2,2 ] 7 | upsample_kernel_sizes: [ 16,16,4,4 ] 8 | upsample_initial_channel: 512 9 | resblock_kernel_sizes: [ 3,7,11 ] 10 | resblock_dilation_sizes: [ [ 1,3,5 ], [ 1,3,5 ], [ 1,3,5 ] ] 11 | 12 | use_pitch_embed: false 13 | use_fm_loss: false 14 | use_ms_stft: false 15 | 16 | lambda_mel: 5.0 17 | lambda_mel_adv: 1.0 18 | lambda_cdisc: 4.0 19 | lambda_adv: 1.0 20 | 21 | lr: 0.0002 # Generator's learning rate. 22 | generator_scheduler_params: 23 | step_size: 600 24 | gamma: 0.999 25 | discriminator_scheduler_params: 26 | step_size: 600 27 | gamma: 0.999 28 | max_updates: 3000000 -------------------------------------------------------------------------------- /extract_lpv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from inference.tts.base_tts_infer import BaseTTSInfer 3 | from modules.tts.diffprosody.diffprosody import DiffProsody 4 | from utils.commons.ckpt_utils import load_ckpt 5 | from utils.commons.hparams import hparams, set_hparams 6 | from tqdm import tqdm 7 | from tasks.tts.tts_utils import load_data_preprocessor 8 | from tasks.tts.dataset_utils import FastSpeechWordDataset 9 | from utils.commons.tensor_utils import move_to_cuda 10 | import os 11 | import numpy as np 12 | 13 | class get_LPV(BaseTTSInfer): 14 | def __init__(self, hparams, device=None): 15 | if device is None: 16 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | self.hparams = hparams 18 | self.device = device 19 | self.data_dir = hparams['binary_data_dir'] 20 | self.preprocessor, self.preprocess_args = load_data_preprocessor() 21 | self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir) 22 | self.spk_map = self.preprocessor.load_spk_map(self.data_dir) 23 | self.ds_cls = FastSpeechWordDataset 24 | self.model = self.build_model() 25 | self.model.eval() 26 | self.model.to(self.device) 27 | 28 | def build_model(self): 29 | ph_dict_size = len(self.ph_encoder) 30 | word_dict_size = len(self.word_encoder) 31 | model = DiffProsody(ph_dict_size, word_dict_size, self.hparams) 32 | if os.path.exists("{}/lpvs.npy".format(hparams["work_dir"])): 33 | model.prosody_encoder.init_vq(None) 34 | load_ckpt(model, hparams['work_dir'], 'model') 35 | model.eval() 36 | return model 37 | 38 | def forward_model(self, i): 39 | with torch.no_grad(): 40 | i = move_to_cuda(i, self.device) 41 | lpv, lpv_idx = self.model.get_lpv( 42 | i["txt_tokens"], i["word_tokens"], 43 | i["ph2word"], i["word_lengths"].max(), 44 | i["mel2word"], i["mel2ph"], 45 | i.get('spk_embed'), 46 | i['mels'], 160000 47 | ) 48 | return lpv, lpv_idx 49 | 50 | def postprocess_output(self, output): 51 | return output 52 | 53 | def infer_once(self, inp): 54 | output = self.forward_model(inp) 55 | output = self.postprocess_output(output) 56 | return output 57 | 58 | @classmethod 59 | def example_run(cls): 60 | from utils.commons.hparams import hparams as hp 61 | import json 62 | dataset_cls = 
FastSpeechWordDataset 63 | l = [hparams['train_set_name'], hparams['valid_set_name'], hparams['test_set_name']] 64 | # l = [hparams['test_set_name']] 65 | 66 | lpv_dir = os.path.join(hparams['work_dir'], "lpvs") 67 | 68 | os.makedirs(lpv_dir, exist_ok=True) 69 | for d in l: 70 | dataset = dataset_cls(prefix=d, shuffle=True) 71 | dataloader = torch.utils.data.DataLoader(dataset, 72 | collate_fn=dataset.collater, 73 | batch_size=1, # Numer of Sample 74 | num_workers=1, 75 | pin_memory=False) 76 | print("###### Dataloader: ", len(dataloader)) 77 | infer_ins = cls(hp) 78 | 79 | lpv_min = None 80 | lpv_max = None 81 | 82 | for idx, i in tqdm(enumerate(dataloader)): 83 | lpv, out_idx = infer_ins.infer_once(i) 84 | 85 | j_obj = { 86 | "lpv": lpv.cpu().numpy()[0], 87 | "lpv_idx": out_idx.cpu().numpy()[0] 88 | } 89 | 90 | npz_name = os.path.join(lpv_dir, i["item_name"][0]+".npz") 91 | np.savez(npz_name, **j_obj) 92 | 93 | lpv = lpv.cpu().numpy().tolist() 94 | lpv = np.array(lpv).squeeze(0) 95 | if lpv_min is None: 96 | lpv_min = lpv 97 | lpv_max = lpv 98 | temp = np.concatenate((lpv_min, lpv), axis=0) 99 | temp2 = np.concatenate((lpv_max, lpv), axis=0) 100 | lpv_min = np.min(temp, axis=0) 101 | lpv_max = np.max(temp2, axis=0) 102 | lpv_min = np.expand_dims(lpv_min, axis=0) 103 | lpv_max = np.expand_dims(lpv_max, axis=0) 104 | j_obj = { 105 | "lpv_min": lpv_min.tolist(), 106 | "lpv_max": lpv_max.tolist(), 107 | } 108 | with open(os.path.join(hparams['work_dir'], "stats_lpv_{}.json".format(d)), "w") as w: 109 | json.dump(j_obj, w) 110 | 111 | if __name__ == '__main__': 112 | set_hparams() 113 | get_LPV.example_run() 114 | -------------------------------------------------------------------------------- /inference/tts/base_tts_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | # from modules.vocoder.hifigan.hifigan import HifiGanGenerator 6 | from tasks.tts.vocoder_infer.base_vocoder import BaseVocoder, get_vocoder_cls 7 | from tasks.tts.dataset_utils import FastSpeechWordDataset 8 | from tasks.tts.tts_utils import load_data_preprocessor 9 | from utils.commons.ckpt_utils import load_ckpt 10 | from utils.commons.hparams import set_hparams 11 | 12 | 13 | class BaseTTSInfer: 14 | def __init__(self, hparams, device=None): 15 | if device is None: 16 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | self.hparams = hparams 18 | self.device = device 19 | self.data_dir = hparams['binary_data_dir'] 20 | self.preprocessor, self.preprocess_args = load_data_preprocessor() 21 | self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir) 22 | self.spk_map = self.preprocessor.load_spk_map(self.data_dir) 23 | self.ds_cls = FastSpeechWordDataset 24 | self.model = self.build_model() 25 | self.model.eval() 26 | self.model.to(self.device) 27 | self.vocoder = self.build_vocoder() 28 | 29 | def build_model(self): 30 | raise NotImplementedError 31 | 32 | def forward_model(self, inp): 33 | raise NotImplementedError 34 | 35 | def build_vocoder(self): 36 | vocoder = get_vocoder_cls(self.hparams['vocoder'])() 37 | 38 | return vocoder 39 | 40 | def run_vocoder(self, c): 41 | c = c.transpose(2, 1) 42 | y = self.vocoder(c)[:, 0] 43 | return y 44 | 45 | def preprocess_input(self, inp): 46 | """ 47 | :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)} 48 | :return: 49 | """ 50 | preprocessor, preprocess_args = self.preprocessor, self.preprocess_args 51 | text_raw = inp['text'] 52 | 
item_name = inp.get('item_name', '') 53 | spk_name = inp.get('spk_name', '') 54 | ph, txt, word, ph2word, ph_gb_word = preprocessor.txt_to_ph( 55 | preprocessor.txt_processor, text_raw, preprocess_args) 56 | word_token = self.word_encoder.encode(word) 57 | ph_token = self.ph_encoder.encode(ph) 58 | spk_id = self.spk_map[spk_name] 59 | item = {'item_name': item_name, 'text': txt, 'ph': ph, 'spk_id': spk_id, 60 | 'ph_token': ph_token, 'word_token': word_token, 'ph2word': ph2word} 61 | item['ph_len'] = len(item['ph_token']) 62 | return item 63 | 64 | def input_to_batch(self, item): 65 | item_names = [item['item_name']] 66 | text = [item['text']] 67 | ph = [item['ph']] 68 | txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device) 69 | txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device) 70 | word_tokens = torch.LongTensor(item['word_token'])[None, :].to(self.device) 71 | word_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device) 72 | ph2word = torch.LongTensor(item['ph2word'])[None, :].to(self.device) 73 | spk_ids = torch.LongTensor(item['spk_id'])[None, :].to(self.device) 74 | batch = { 75 | 'item_name': item_names, 76 | 'text': text, 77 | 'ph': ph, 78 | 'txt_tokens': txt_tokens, 79 | 'txt_lengths': txt_lengths, 80 | 'word_tokens': word_tokens, 81 | 'word_lengths': word_lengths, 82 | 'ph2word': ph2word, 83 | 'spk_ids': spk_ids, 84 | } 85 | return batch 86 | 87 | def postprocess_output(self, output): 88 | return output 89 | 90 | def infer_once(self, inp): 91 | inp = self.preprocess_input(inp) 92 | output = self.forward_model(inp) 93 | output = self.postprocess_output(output) 94 | return output 95 | 96 | @classmethod 97 | def example_run(cls): 98 | from utils.commons.hparams import set_hparams 99 | from utils.commons.hparams import hparams as hp 100 | from utils.audio.io import save_wav 101 | 102 | set_hparams() 103 | inp = { 104 | 'text': 'the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.' 
105 | } 106 | infer_ins = cls(hp) 107 | out = infer_ins.infer_once(inp) 108 | os.makedirs('infer_out', exist_ok=True) 109 | save_wav(out, f'infer_out/example_out.wav', hp['audio_sample_rate']) 110 | -------------------------------------------------------------------------------- /inference/tts/dp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from inference.tts.base_tts_infer import BaseTTSInfer 3 | from modules.tts.diffprosody.diffprosody import DiffProsody 4 | from modules.tts.diffprosody.diffusion import DiffusionProsodyGenerator 5 | from utils.commons.ckpt_utils import load_ckpt 6 | from utils.commons.hparams import hparams 7 | from tqdm import tqdm 8 | from utils.nn.seq_utils import group_hidden_by_segs 9 | import os 10 | from utils.audio.io import save_wav 11 | from resemblyzer import VoiceEncoder 12 | from data_gen.tts.txt_processors.base_text_processor import get_txt_processor_cls 13 | from utils.text.text_encoder import build_token_encoder 14 | import librosa 15 | from speechbrain.pretrained import EncoderClassifier 16 | 17 | class DiffProsodyInfer(BaseTTSInfer): 18 | def build_model(self): 19 | ph_dict_size = len(self.ph_encoder) 20 | word_dict_size = len(self.word_encoder) 21 | model = DiffProsody(ph_dict_size, word_dict_size, self.hparams) 22 | model.prosody_encoder.init_vq(None) 23 | load_ckpt(model, hparams['tts_model'], 'model') 24 | model.eval() 25 | self.pd = DiffusionProsodyGenerator(self.hparams) 26 | load_ckpt(self.pd, hparams['work_dir'], 'model') 27 | self.pd.cuda() 28 | self.pd.eval() 29 | 30 | txt_processor = self.preprocess_args['txt_processor'] 31 | self.txt_processor = get_txt_processor_cls(txt_processor) 32 | self.ph_encoder = build_token_encoder('/workspace/dataset/binary/vctk/phone_set.json') 33 | self.word_encoder = build_token_encoder('/workspace/dataset/binary/vctk/word_set.json') 34 | self.resem = VoiceEncoder().cuda() 35 | return model 36 | 37 | def forward_model(self, inp): 38 | items = inp.rstrip().split("|") 39 | 40 | wav_name = items[0] 41 | txt_raw = items[1] 42 | 43 | txt_struct, txt = self.txt_processor.process(txt_raw, hparams['preprocess_args']) 44 | ph = [p for w in txt_struct for p in w[1]] 45 | words = [w[0] for w in txt_struct] 46 | ph2word = [w_id + 1 for w_id, w in enumerate(txt_struct) for _ in range(len(w[1]))] 47 | ph = " ".join(ph) 48 | word = " ".join(words) 49 | 50 | word_token = self.word_encoder.encode(word) 51 | ph_token = self.ph_encoder.encode(ph) 52 | 53 | infer_dir = os.path.join(hparams['work_dir'], "infer") 54 | os.makedirs(infer_dir, exist_ok=True) 55 | 56 | wav_dir = "/workspace/dataset/libritts/wavs" 57 | 58 | with torch.no_grad(): 59 | y, _ = librosa.load(os.path.join(wav_dir, wav_name+'.wav')) 60 | y, _ = librosa.effects.trim(y) 61 | resem = self.resem.embed_utterance(y.astype(float)) 62 | y = torch.Tensor(y).to("cuda:0") 63 | y = y.to("cuda:0") 64 | spk_embed = torch.Tensor(resem).unsqueeze(0).cuda() 65 | spk_embed = self.model.forward_style_embed(spk_embed, None).squeeze(1) 66 | 67 | txt_tokens = torch.LongTensor(ph_token).unsqueeze(0).cuda() 68 | ph2word = torch.LongTensor(ph2word).unsqueeze(0).cuda() 69 | word_tokens = torch.LongTensor(word_token).unsqueeze(0).cuda() 70 | word_lengths = len(word_token) 71 | 72 | h_ling = self.model.run_text_encoder(txt_tokens, word_tokens, 73 | ph2word, word_lengths, 74 | None, None, {}) 75 | h_ling = group_hidden_by_segs(h_ling, ph2word, word_lengths)[0] 76 | 77 | wrd_nonpadding = (word_tokens > 0).float() 78 | output = 
self.pd(h_ling, 79 | spk_embed=spk_embed, 80 | ph2word=ph2word, 81 | infer=True, 82 | padding=wrd_nonpadding) 83 | lpv = self.model.prosody_encoder.vector_quantization(output["lpv_out"])[1] 84 | output = self.model( 85 | txt_tokens, 86 | word_tokens, 87 | ph2word=ph2word, 88 | word_len=word_lengths, 89 | infer=True, 90 | spk_embed=spk_embed.squeeze(1), 91 | lpv=lpv 92 | ) 93 | 94 | mel = output['mel_out'][0].cpu().numpy() 95 | wav = self.vocoder.spec2wav(mel) 96 | n = wav_name 97 | save_wav(wav, os.path.join(infer_dir, n+".wav"), 98 | hparams['audio_sample_rate'], 99 | norm=hparams['out_wav_norm']) 100 | 101 | 102 | def infer_once(self, inp): 103 | output = self.forward_model(inp) 104 | return output 105 | 106 | @classmethod 107 | def example_run(cls): 108 | from utils.commons.hparams import set_hparams 109 | from utils.commons.hparams import hparams as hp 110 | 111 | set_hparams() 112 | 113 | f = open("./zeroshot_libri.txt", "r") 114 | lines = f.readlines() 115 | infer_ins = cls(hp) 116 | for idx, i in tqdm(enumerate(lines)): 117 | infer_ins.infer_once(i) 118 | # if idx > 10: break 119 | if __name__ == '__main__': 120 | DiffProsodyInfer.example_run() 121 | -------------------------------------------------------------------------------- /mfa_usr/adapt_config.yaml: -------------------------------------------------------------------------------- 1 | beam: 10 2 | retry_beam: 40 3 | 4 | features: 5 | type: "mfcc" 6 | use_energy: false 7 | frame_shift: 10 8 | snip_edges: true 9 | 10 | training: 11 | - sat: 12 | num_iterations: 5 13 | num_leaves: 4200 14 | max_gaussians: 40000 15 | power: 0.2 16 | silence_weight: 0.0 17 | fmllr_update_type: "full" 18 | -------------------------------------------------------------------------------- /mfa_usr/install_mfa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | pip uninstall -y typing 4 | pip install --ignore-requires-python git+https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git@v2.0.0b3 5 | mfa thirdparty download 6 | apt install -y libopenblas-base libsox-fmt-mp3 libfst8 libfst-tools -------------------------------------------------------------------------------- /mfa_usr/mfa_train_config.yaml: -------------------------------------------------------------------------------- 1 | beam: 10 2 | retry_beam: 40 3 | 4 | features: 5 | type: "mfcc" 6 | use_energy: false 7 | frame_shift: 10 8 | 9 | training: 10 | - monophone: 11 | num_iterations: 40 12 | max_gaussians: 1000 13 | subset: 0 14 | boost_silence: 1.25 15 | 16 | - triphone: 17 | num_iterations: 35 18 | num_leaves: 2000 19 | max_gaussians: 10000 20 | cluster_threshold: -1 21 | subset: 0 22 | boost_silence: 1.25 23 | power: 0.25 24 | 25 | - lda: 26 | num_leaves: 2500 27 | max_gaussians: 15000 28 | subset: 0 29 | num_iterations: 35 30 | features: 31 | splice_left_context: 3 32 | splice_right_context: 3 33 | 34 | - sat: 35 | num_leaves: 2500 36 | max_gaussians: 15000 37 | power: 0.2 38 | silence_weight: 0.0 39 | fmllr_update_type: "diag" 40 | subset: 0 41 | features: 42 | lda: true 43 | 44 | - sat: 45 | num_leaves: 4200 46 | max_gaussians: 40000 47 | power: 0.2 48 | silence_weight: 0.0 49 | fmllr_update_type: "diag" 50 | subset: 0 51 | features: 52 | lda: true 53 | fmllr: true 54 | 55 | # - monophone: 56 | # num_iterations: 40 57 | # max_gaussians: 1000 58 | # boost_silence: 1.0 59 | # 60 | # - triphone: 61 | # num_iterations: 35 62 | # num_leaves: 3100 63 | # max_gaussians: 50000 64 | # cluster_threshold: 100 65 | # 
boost_silence: 1.0 66 | # power: 0.25 67 | # 68 | # - sat: 69 | # num_leaves: 3100 70 | # max_gaussians: 50000 71 | # power: 0.2 72 | # silence_weight: 0.0 73 | # cluster_threshold: 100 74 | # fmllr_update_type: "full" -------------------------------------------------------------------------------- /mfa_usr/run_mfa_align.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import importlib 4 | import os 5 | import subprocess 6 | from utils.commons.hparams import set_hparams, hparams 7 | from utils.commons.multiprocess_utils import multiprocess_run_tqdm 8 | from utils.os_utils import remove_file 9 | from utils.text.encoding import get_encoding 10 | 11 | 12 | def process_item(idx, txt_fn): 13 | base_fn = os.path.splitext(txt_fn)[0] 14 | basename = os.path.basename(base_fn) 15 | if os.path.exists(base_fn + '.wav'): 16 | wav_fn = base_fn + '.wav' 17 | elif os.path.exists(base_fn + '.mp3'): 18 | wav_fn = base_fn + '.mp3' 19 | else: 20 | return 21 | # process text 22 | encoding = get_encoding(txt_fn) 23 | with open(txt_fn, encoding=encoding) as f: 24 | txt_raw = " ".join(f.readlines()).strip() 25 | phs, _, phs_for_align, _ = preprocesser.process_text(txt_processor, txt_raw, hparams['preprocess_args']) 26 | os.makedirs(f'{mfa_process_dir}/{basename}', exist_ok=True) 27 | with open(f'{mfa_process_dir}/{basename}/{basename}.lab', 'w') as f: 28 | f.write(phs_for_align) 29 | # process wav 30 | new_wav_fn = preprocesser.process_wav(basename, wav_fn, mfa_process_dir, preprocess_args) 31 | subprocess.check_call(f'cp "{new_wav_fn}" "{mfa_process_dir}/{basename}/{basename}.wav"', shell=True) 32 | 33 | 34 | if __name__ == "__main__": 35 | set_hparams() 36 | parser = argparse.ArgumentParser(description='') 37 | parser.add_argument('--input_dir', type=str, default='', help='input dir') 38 | args, unknown = parser.parse_known_args() 39 | input_dir = args.input_dir 40 | processed_data_dir = hparams['processed_data_dir'] 41 | preprocess_args = hparams['preprocess_args'] 42 | preprocess_args['sox_to_wav'] = True 43 | preprocess_args['trim_all_sil'] = True 44 | # preprocess_args['trim_sil'] = True 45 | # preprocess_args['denoise'] = True 46 | 47 | pkg = ".".join(hparams["preprocess_cls"].split(".")[:-1]) 48 | cls_name = hparams["preprocess_cls"].split(".")[-1] 49 | process_cls = getattr(importlib.import_module(pkg), cls_name) 50 | preprocesser = process_cls() 51 | txt_processor = preprocesser.txt_processor 52 | num_workers = int(os.getenv('N_PROC', os.cpu_count())) 53 | 54 | mfa_process_dir = f'{input_dir}/mfa_inputs' 55 | remove_file(mfa_process_dir, f'{input_dir}/mfa_tmp') 56 | os.makedirs(mfa_process_dir, exist_ok=True) 57 | os.makedirs(f'{mfa_process_dir}/processed_tmp', exist_ok=True) 58 | for res in multiprocess_run_tqdm( 59 | process_item, list(enumerate(glob.glob(f'{input_dir}/*.txt')))): 60 | pass 61 | remove_file(f'{mfa_process_dir}/processed_tmp') 62 | subprocess.check_call( 63 | f'mfa align {mfa_process_dir} ' # process dir 64 | f'{hparams["processed_data_dir"]}/mfa_dict.txt ' # dict 65 | f'{input_dir}/mfa_model.zip ' # model 66 | f'{input_dir}/mfa_outputs -t {input_dir}/mfa_tmp -j {num_workers} ' 67 | f' && cp -rf {input_dir}/mfa_outputs/*/* {input_dir}/' 68 | f' && cp -rf {mfa_process_dir}/*/* {input_dir}/' 69 | f' && rm -rf {input_dir}/mfa_tmp {input_dir}/mfa_outputs {mfa_process_dir}', # remove tmp dir 70 | shell=True) 71 | -------------------------------------------------------------------------------- 
/mfa_usr/run_mfa_train_align.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | NUM_JOB=${NUM_JOB:-36} 5 | echo "| Training MFA using ${NUM_JOB} cores." 6 | 7 | BASE_DIR=/workspace/dataset/processed/$CORPUS 8 | MODEL_NAME=${MODEL_NAME:-"mfa_model"} 9 | PRETRAIN_MODEL_NAME=${PRETRAIN_MODEL_NAME:-"mfa_model_pretrain"} 10 | MFA_INPUTS=${MFA_INPUTS:-"mfa_inputs"} 11 | MFA_OUTPUTS=${MFA_OUTPUTS:-"mfa_outputs"} 12 | MFA_CMD=${MFA_CMD:-"train"} 13 | rm -rf $BASE_DIR/mfa_outputs_tmp 14 | if [ "$MFA_CMD" = "train" ]; then 15 | mfa train $BASE_DIR/$MFA_INPUTS $BASE_DIR/mfa_dict.txt $BASE_DIR/mfa_outputs_tmp -t $BASE_DIR/mfa_tmp -o $BASE_DIR/$MODEL_NAME.zip --clean -j $NUM_JOB --config_path mfa_usr/mfa_train_config.yaml 16 | elif [ "$MFA_CMD" = "adapt" ]; then 17 | python mfa_usr/mfa.py adapt \ 18 | $BASE_DIR/$MFA_INPUTS \ 19 | $BASE_DIR/mfa_dict.txt \ 20 | $BASE_DIR/$PRETRAIN_MODEL_NAME.zip \ 21 | $BASE_DIR/$MODEL_NAME.zip \ 22 | $BASE_DIR/mfa_outputs_tmp \ 23 | -t $BASE_DIR/mfa_tmp --clean -j $NUM_JOB 24 | fi 25 | rm -rf $BASE_DIR/mfa_tmp $BASE_DIR/$MFA_OUTPUTS 26 | mkdir -p $BASE_DIR/$MFA_OUTPUTS 27 | find $BASE_DIR/mfa_outputs_tmp -regex ".*\.TextGrid" -print0 | xargs -0 -i mv {} $BASE_DIR/$MFA_OUTPUTS/ 28 | if [ -e "$BASE_DIR/mfa_outputs_tmp/unaligned.txt" ]; then 29 | cp $BASE_DIR/mfa_outputs_tmp/unaligned.txt $BASE_DIR/ 30 | fi 31 | rm -rf $BASE_DIR/mfa_outputs_tmp -------------------------------------------------------------------------------- /modules/commons/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from modules.commons.layers import LayerNorm, Embedding 7 | 8 | 9 | class LambdaLayer(nn.Module): 10 | def __init__(self, lambd): 11 | super(LambdaLayer, self).__init__() 12 | self.lambd = lambd 13 | 14 | def forward(self, x): 15 | return self.lambd(x) 16 | 17 | 18 | def init_weights_func(m): 19 | classname = m.__class__.__name__ 20 | if classname.find("Conv1d") != -1: 21 | torch.nn.init.xavier_uniform_(m.weight) 22 | 23 | 24 | class ResidualBlock(nn.Module): 25 | """Implements conv->PReLU->norm n-times""" 26 | 27 | def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0, 28 | c_multiple=2, ln_eps=1e-12): 29 | super(ResidualBlock, self).__init__() 30 | 31 | if norm_type == 'bn': 32 | norm_builder = lambda: nn.BatchNorm1d(channels) 33 | elif norm_type == 'in': 34 | norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True) 35 | elif norm_type == 'gn': 36 | norm_builder = lambda: nn.GroupNorm(8, channels) 37 | elif norm_type == 'ln': 38 | norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps) 39 | else: 40 | norm_builder = lambda: nn.Identity() 41 | 42 | self.blocks = [ 43 | nn.Sequential( 44 | norm_builder(), 45 | nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation, 46 | padding=(dilation * (kernel_size - 1)) // 2), 47 | LambdaLayer(lambda x: x * kernel_size ** -0.5), 48 | nn.GELU(), 49 | nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation), 50 | ) 51 | for i in range(n) 52 | ] 53 | 54 | self.blocks = nn.ModuleList(self.blocks) 55 | self.dropout = dropout 56 | 57 | def forward(self, x): 58 | nonpadding = (x.abs().sum(1) > 0).float()[:, None, :] 59 | for b in self.blocks: 60 | x_ = b(x) 61 | if self.dropout > 0 and self.training: 62 | x_ = F.dropout(x_, self.dropout, training=self.training) 63 | x = x + x_ 
64 | x = x * nonpadding 65 | return x 66 | 67 | 68 | class ConvBlocks(nn.Module): 69 | """Decodes the expanded phoneme encoding into spectrograms""" 70 | 71 | def __init__(self, hidden_size, out_dims, dilations, kernel_size, 72 | norm_type='ln', layers_in_block=2, c_multiple=2, 73 | dropout=0.0, ln_eps=1e-5, 74 | init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3): 75 | super(ConvBlocks, self).__init__() 76 | self.is_BTC = is_BTC 77 | if num_layers is not None: 78 | dilations = [1] * num_layers 79 | self.res_blocks = nn.Sequential( 80 | *[ResidualBlock(hidden_size, kernel_size, d, 81 | n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple, 82 | dropout=dropout, ln_eps=ln_eps) 83 | for d in dilations], 84 | ) 85 | if norm_type == 'bn': 86 | norm = nn.BatchNorm1d(hidden_size) 87 | elif norm_type == 'in': 88 | norm = nn.InstanceNorm1d(hidden_size, affine=True) 89 | elif norm_type == 'gn': 90 | norm = nn.GroupNorm(8, hidden_size) 91 | elif norm_type == 'ln': 92 | norm = LayerNorm(hidden_size, dim=1, eps=ln_eps) 93 | self.last_norm = norm 94 | self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel, 95 | padding=post_net_kernel // 2) 96 | if init_weights: 97 | self.apply(init_weights_func) 98 | 99 | def forward(self, x, nonpadding=None): 100 | """ 101 | 102 | :param x: [B, T, H] 103 | :return: [B, T, H] 104 | """ 105 | if self.is_BTC: 106 | x = x.transpose(1, 2) 107 | if nonpadding is None: 108 | nonpadding = (x.abs().sum(1) > 0).float()[:, None, :] 109 | elif self.is_BTC: 110 | nonpadding = nonpadding.transpose(1, 2) 111 | x = self.res_blocks(x) * nonpadding 112 | x = self.last_norm(x) * nonpadding 113 | x = self.post_net1(x) * nonpadding 114 | if self.is_BTC: 115 | x = x.transpose(1, 2) 116 | return x 117 | 118 | 119 | class TextConvEncoder(ConvBlocks): 120 | def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size, 121 | norm_type='ln', layers_in_block=2, c_multiple=2, 122 | dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3): 123 | super().__init__(hidden_size, out_dims, dilations, kernel_size, 124 | norm_type, layers_in_block, c_multiple, 125 | dropout, ln_eps, init_weights, num_layers=num_layers, 126 | post_net_kernel=post_net_kernel) 127 | self.embed_tokens = Embedding(dict_size, hidden_size, 0) 128 | self.embed_scale = math.sqrt(hidden_size) 129 | 130 | def forward(self, txt_tokens): 131 | """ 132 | 133 | :param txt_tokens: [B, T] 134 | :return: { 135 | 'encoder_out': [B x T x C] 136 | } 137 | """ 138 | x = self.embed_scale * self.embed_tokens(txt_tokens) 139 | return super().forward(x) 140 | 141 | 142 | class ConditionalConvBlocks(ConvBlocks): 143 | def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size, 144 | norm_type='ln', layers_in_block=2, c_multiple=2, 145 | dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None): 146 | super().__init__(hidden_size, c_out, dilations, kernel_size, 147 | norm_type, layers_in_block, c_multiple, 148 | dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers) 149 | self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1) 150 | self.is_BTC_ = is_BTC 151 | if init_weights: 152 | self.g_prenet.apply(init_weights_func) 153 | 154 | def forward(self, x, cond, nonpadding=None): 155 | if self.is_BTC_: 156 | x = x.transpose(1, 2) 157 | cond = cond.transpose(1, 2) 158 | if nonpadding is not None: 159 | nonpadding = nonpadding.transpose(1, 2) 160 | if nonpadding is None: 161 | nonpadding = x.abs().sum(1)[:, None] 162 | x = 
x + self.g_prenet(cond) 163 | x = x * nonpadding 164 | x = super(ConditionalConvBlocks, self).forward(x) # input needs to be BTC 165 | if self.is_BTC_: 166 | x = x.transpose(1, 2) 167 | return x 168 | -------------------------------------------------------------------------------- /modules/commons/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LayerNorm(torch.nn.LayerNorm): 6 | """Layer normalization module. 7 | :param int nout: output dim size 8 | :param int dim: dimension to be normalized 9 | """ 10 | 11 | def __init__(self, nout, dim=-1, eps=1e-5): 12 | """Construct an LayerNorm object.""" 13 | super(LayerNorm, self).__init__(nout, eps=eps) 14 | self.dim = dim 15 | 16 | def forward(self, x): 17 | """Apply layer normalization. 18 | :param torch.Tensor x: input tensor 19 | :return: layer normalized tensor 20 | :rtype torch.Tensor 21 | """ 22 | if self.dim == -1: 23 | return super(LayerNorm, self).forward(x) 24 | return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) 25 | 26 | 27 | class Reshape(nn.Module): 28 | def __init__(self, *args): 29 | super(Reshape, self).__init__() 30 | self.shape = args 31 | 32 | def forward(self, x): 33 | return x.view(self.shape) 34 | 35 | 36 | class Permute(nn.Module): 37 | def __init__(self, *args): 38 | super(Permute, self).__init__() 39 | self.args = args 40 | 41 | def forward(self, x): 42 | return x.permute(self.args) 43 | 44 | 45 | def Embedding(num_embeddings, embedding_dim, padding_idx=None): 46 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 47 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 48 | if padding_idx is not None: 49 | nn.init.constant_(m.weight[padding_idx], 0) 50 | return m 51 | -------------------------------------------------------------------------------- /modules/commons/nar_tts_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from modules.commons.layers import LayerNorm 5 | import torch.nn.functional as F 6 | 7 | 8 | class DurationPredictor(torch.nn.Module): 9 | def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0): 10 | super(DurationPredictor, self).__init__() 11 | self.offset = offset 12 | self.conv = torch.nn.ModuleList() 13 | self.kernel_size = kernel_size 14 | for idx in range(n_layers): 15 | in_chans = idim if idx == 0 else n_chans 16 | self.conv += [torch.nn.Sequential( 17 | torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2), 18 | torch.nn.ReLU(), 19 | LayerNorm(n_chans, dim=1), 20 | torch.nn.Dropout(dropout_rate) 21 | )] 22 | self.linear = nn.Sequential(torch.nn.Linear(n_chans, 1), nn.Softplus()) 23 | 24 | def forward(self, x, x_padding=None): 25 | x = x.transpose(1, -1) # (B, idim, Tmax) 26 | for f in self.conv: 27 | x = f(x) # (B, C, Tmax) 28 | if x_padding is not None: 29 | x = x * (1 - x_padding.float())[:, None, :] 30 | 31 | x = self.linear(x.transpose(1, -1)) # [B, T, C] 32 | x = x * (1 - x_padding.float())[:, :, None] # (B, T, C) 33 | x = x[..., 0] # (B, Tmax) 34 | return x 35 | 36 | 37 | class LengthRegulator(torch.nn.Module): 38 | def __init__(self, pad_value=0.0): 39 | super(LengthRegulator, self).__init__() 40 | self.pad_value = pad_value 41 | 42 | def forward(self, dur, dur_padding=None, alpha=1.0): 43 | """ 44 | Example (no batch dim version): 45 | 1. dur = [2,2,3] 46 | 2. 
token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4] 47 | 3. token_mask = [[1,1,0,0,0,0,0], 48 | [0,0,1,1,0,0,0], 49 | [0,0,0,0,1,1,1]] 50 | 4. token_idx * token_mask = [[1,1,0,0,0,0,0], 51 | [0,0,2,2,0,0,0], 52 | [0,0,0,0,3,3,3]] 53 | 5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3] 54 | 55 | :param dur: Batch of durations of each frame (B, T_txt) 56 | :param dur_padding: Batch of padding of each frame (B, T_txt) 57 | :param alpha: duration rescale coefficient 58 | :return: 59 | mel2ph (B, T_speech) 60 | """ 61 | assert alpha > 0 62 | dur = torch.round(dur.float() * alpha).long() 63 | if dur_padding is not None: 64 | dur = dur * (1 - dur_padding.long()) 65 | token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device) 66 | dur_cumsum = torch.cumsum(dur, 1) 67 | dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0) 68 | 69 | pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device) 70 | token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None]) 71 | mel2token = (token_idx * token_mask.long()).sum(1) 72 | return mel2token 73 | 74 | 75 | class PitchPredictor(torch.nn.Module): 76 | def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5, dropout_rate=0.1): 77 | super(PitchPredictor, self).__init__() 78 | self.conv = torch.nn.ModuleList() 79 | self.kernel_size = kernel_size 80 | for idx in range(n_layers): 81 | in_chans = idim if idx == 0 else n_chans 82 | self.conv += [torch.nn.Sequential( 83 | torch.nn.Conv1d(in_chans, n_chans, kernel_size, padding=kernel_size // 2), 84 | torch.nn.ReLU(), 85 | LayerNorm(n_chans, dim=1), 86 | torch.nn.Dropout(dropout_rate) 87 | )] 88 | self.linear = torch.nn.Linear(n_chans, odim) 89 | 90 | def forward(self, x): 91 | """ 92 | 93 | :param x: [B, T, H] 94 | :return: [B, T, H] 95 | """ 96 | x = x.transpose(1, -1) # (B, idim, Tmax) 97 | for f in self.conv: 98 | x = f(x) # (B, C, Tmax) 99 | x = self.linear(x.transpose(1, -1)) # (B, Tmax, H) 100 | return x 101 | 102 | 103 | class EnergyPredictor(PitchPredictor): 104 | pass 105 | 106 | 107 | class SyntaDurationPredictor(torch.nn.Module): 108 | def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0): 109 | super(SyntaDurationPredictor, self).__init__() 110 | from modules.tts.syntaspeech.syntactic_graph_encoder import GraphAuxEnc 111 | self.graph_encoder = GraphAuxEnc(in_dim=idim, hid_dim=idim, out_dim=idim) 112 | self.offset = offset 113 | self.conv = torch.nn.ModuleList() 114 | self.kernel_size = kernel_size 115 | for idx in range(n_layers): 116 | in_chans = idim if idx == 0 else n_chans 117 | self.conv += [torch.nn.Sequential( 118 | torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2), 119 | torch.nn.ReLU(), 120 | LayerNorm(n_chans, dim=1), 121 | torch.nn.Dropout(dropout_rate) 122 | )] 123 | self.linear = nn.Sequential(torch.nn.Linear(n_chans, 1), nn.Softplus()) 124 | 125 | def forward(self, x, x_padding=None, ph2word=None, graph_lst=None, etypes_lst=None): 126 | x = x.transpose(1, -1) # (B, idim, Tmax) 127 | assert ph2word is not None and graph_lst is not None and etypes_lst is not None 128 | x_graph = self.graph_encoder(graph_lst, x, ph2word, etypes_lst) 129 | x = x + x_graph * 1.
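# NOTE (editor): x_graph is the syntactic-graph encoding; it is added to the phoneme hidden states at unit scale before the duration conv stack that follows.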
130 | 131 | for f in self.conv: 132 | x = f(x) # (B, C, Tmax) 133 | if x_padding is not None: 134 | x = x * (1 - x_padding.float())[:, None, :] 135 | 136 | x = self.linear(x.transpose(1, -1)) # [B, T, C] 137 | x = x * (1 - x_padding.float())[:, :, None] # (B, T, C) 138 | x = x[..., 0] # (B, Tmax) 139 | return x -------------------------------------------------------------------------------- /modules/commons/wavenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 6 | n_channels_int = n_channels[0] 7 | in_act = input_a + input_b 8 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 9 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 10 | acts = t_act * s_act 11 | return acts 12 | 13 | 14 | class WN(torch.nn.Module): 15 | def __init__(self, hidden_size, kernel_size, dilation_rate, n_layers, c_cond=0, 16 | p_dropout=0, share_cond_layers=False, is_BTC=False): 17 | super(WN, self).__init__() 18 | assert (kernel_size % 2 == 1) 19 | assert (hidden_size % 2 == 0) 20 | self.is_BTC = is_BTC 21 | self.hidden_size = hidden_size 22 | self.kernel_size = kernel_size 23 | self.dilation_rate = dilation_rate 24 | self.n_layers = n_layers 25 | self.gin_channels = c_cond 26 | self.p_dropout = p_dropout 27 | self.share_cond_layers = share_cond_layers 28 | 29 | self.in_layers = torch.nn.ModuleList() 30 | self.res_skip_layers = torch.nn.ModuleList() 31 | self.drop = nn.Dropout(p_dropout) 32 | 33 | if c_cond != 0 and not share_cond_layers: 34 | cond_layer = torch.nn.Conv1d(c_cond, 2 * hidden_size * n_layers, 1) 35 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 36 | 37 | for i in range(n_layers): 38 | dilation = dilation_rate ** i 39 | padding = int((kernel_size * dilation - dilation) / 2) 40 | in_layer = torch.nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size, 41 | dilation=dilation, padding=padding) 42 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 43 | self.in_layers.append(in_layer) 44 | 45 | # last one is not necessary 46 | if i < n_layers - 1: 47 | res_skip_channels = 2 * hidden_size 48 | else: 49 | res_skip_channels = hidden_size 50 | 51 | res_skip_layer = torch.nn.Conv1d(hidden_size, res_skip_channels, 1) 52 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 53 | self.res_skip_layers.append(res_skip_layer) 54 | 55 | def forward(self, x, nonpadding=None, cond=None): 56 | if self.is_BTC: 57 | x = x.transpose(1, 2) 58 | cond = cond.transpose(1, 2) if cond is not None else None 59 | nonpadding = nonpadding.transpose(1, 2) if nonpadding is not None else None 60 | if nonpadding is None: 61 | nonpadding = 1 62 | output = torch.zeros_like(x) 63 | n_channels_tensor = torch.IntTensor([self.hidden_size]) 64 | 65 | if cond is not None and not self.share_cond_layers: 66 | cond = self.cond_layer(cond) 67 | 68 | for i in range(self.n_layers): 69 | x_in = self.in_layers[i](x) 70 | x_in = self.drop(x_in) 71 | if cond is not None: 72 | cond_offset = i * 2 * self.hidden_size 73 | cond_l = cond[:, cond_offset:cond_offset + 2 * self.hidden_size, :] 74 | else: 75 | cond_l = torch.zeros_like(x_in) 76 | 77 | acts = fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor) 78 | 79 | res_skip_acts = self.res_skip_layers[i](acts) 80 | if i < self.n_layers - 1: 81 | x = (x + res_skip_acts[:, :self.hidden_size, :]) * nonpadding 82 | output = output + res_skip_acts[:, self.hidden_size:, :] 83 | 
else: 84 | output = output + res_skip_acts 85 | output = output * nonpadding 86 | if self.is_BTC: 87 | output = output.transpose(1, 2) 88 | return output 89 | 90 | def remove_weight_norm(self): 91 | def remove_weight_norm(m): 92 | try: 93 | nn.utils.remove_weight_norm(m) 94 | except ValueError: # this module didn't have weight norm 95 | return 96 | 97 | self.apply(remove_weight_norm) 98 | -------------------------------------------------------------------------------- /modules/tts/commons/align_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def build_word_mask(x2word, y2word): 6 | return (x2word[:, :, None] == y2word[:, None, :]).long() 7 | 8 | 9 | def mel2ph_to_mel2word(mel2ph, ph2word): 10 | mel2word = (ph2word - 1).gather(1, (mel2ph - 1).clamp(min=0)) + 1 11 | mel2word = mel2word * (mel2ph > 0).long() 12 | return mel2word 13 | 14 | 15 | def clip_mel2token_to_multiple(mel2token, frames_multiple): 16 | max_frames = mel2token.shape[1] // frames_multiple * frames_multiple 17 | mel2token = mel2token[:, :max_frames] 18 | return mel2token 19 | 20 | 21 | def expand_states(h, mel2token): 22 | h = F.pad(h, [0, 0, 1, 0]) 23 | mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]]) 24 | h = torch.gather(h, 1, mel2token_) # [B, T, H] 25 | return h 26 | -------------------------------------------------------------------------------- /modules/tts/diffprosody/diffprosody.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.nn import Linear 6 | from modules.commons.layers import Embedding 7 | from modules.commons.rel_transformer import RelTransformerEncoder 8 | from modules.tts.commons.align_ops import build_word_mask, expand_states 9 | from modules.tts.fs import FS_DECODERS, FastSpeech 10 | from .prosody_encoder import ProsodyEncoder 11 | 12 | class SinusoidalPosEmb(nn.Module): 13 | def __init__(self, dim): 14 | super().__init__() 15 | self.dim = dim 16 | 17 | def forward(self, x): 18 | """ 19 | :param x: [B, T] 20 | :return: [B, T, H] 21 | """ 22 | device = x.device 23 | half_dim = self.dim // 2 24 | emb = math.log(10000) / (half_dim - 1) 25 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 26 | emb = x[:, :, None] * emb[None, :] 27 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 28 | return emb 29 | 30 | class Mish(nn.Module): 31 | def forward(self, x): 32 | return x * torch.tanh(F.softplus(x)) 33 | 34 | class DiffProsody(FastSpeech): 35 | def __init__(self, 36 | ph_dict_size, word_dict_size, 37 | hparams, out_dims=None): 38 | super().__init__(ph_dict_size, hparams, out_dims) 39 | # build linguistic encoder 40 | self.word_encoder = RelTransformerEncoder( 41 | word_dict_size, self.hidden_size, self.hidden_size, self.hidden_size, 2, 42 | hparams['word_enc_layers'], hparams['enc_ffn_kernel_size']) 43 | 44 | self.sin_pos = SinusoidalPosEmb(self.hidden_size) 45 | 46 | self.decoder = FS_DECODERS[hparams['decoder_type']](hparams) 47 | self.prosody_encoder = ProsodyEncoder(hparams) 48 | 49 | self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True) 50 | self.word_pos_proj = Linear(self.hidden_size, self.hidden_size) 51 | 52 | def build_embedding(self, dictionary, embed_dim): 53 | num_embeddings = len(dictionary) 54 | emb = Embedding(num_embeddings, embed_dim, self.padding_idx) 55 | return emb 56 | 57 | def forward(self, txt_tokens, 
word_tokens, ph2word, word_len, mel2word=None, mel2ph=None, 58 | spk_embed=None, spk_id=None, pitch=None, infer=False, tgt_mels=None, bert_tokens=None, 59 | global_step=None, lpv=None, *args, **kwargs): 60 | ret = {} 61 | # print(bert_tokens.shape, word_tokens.shape) 62 | h_ling = self.run_text_encoder( 63 | txt_tokens, word_tokens, ph2word, word_len, mel2word, mel2ph, ret) 64 | h_spk = self.forward_style_embed(spk_embed, spk_id) 65 | 66 | vq_loss = None 67 | if not infer: 68 | wrd_nonpadding = (word_tokens > 0).float()[:, :, None] 69 | vq_loss, lpv, lpv_idx, perplexity, _ = self.prosody_encoder(tgt_mels, h_ling, h_spk, 70 | mel2word, mel2ph, ph2word, word_len, 71 | wrd_nonpadding, global_step) 72 | if global_step > self.hparams["vq_warmup"]: 73 | # print(wrd_nonpadding.shape, lpv_idx.shape) 74 | lpv_idx = lpv_idx.masked_select(wrd_nonpadding.unsqueeze(-1).bool()) 75 | ret['lpv_idx'] = lpv_idx 76 | ret['perplexity'] = perplexity 77 | else: 78 | assert lpv is not None, 'LPV required for inference' 79 | x = h_ling + h_spk + expand_states(lpv, ph2word) 80 | mel2ph = self.forward_dur(x, mel2ph, txt_tokens, ret) 81 | # mel2ph = clip_mel2token_to_multiple(mel2ph, self.hparams['frames_multiple']) 82 | tgt_nonpadding = (mel2ph > 0).float()[:, :, None] 83 | 84 | x = expand_states(x, mel2ph) 85 | x = x * tgt_nonpadding 86 | 87 | ret['nonpadding'] = tgt_nonpadding 88 | ret['decoder_inp'] = x 89 | ret['lpv'] = lpv 90 | ret['lpv_long'] = expand_states(expand_states(lpv, ph2word), mel2ph) 91 | 92 | ret['vq_loss'] = vq_loss 93 | ret['mel_out'] = self.run_decoder(x, tgt_nonpadding, ret, infer, tgt_mels, global_step) 94 | 95 | return ret 96 | 97 | def forward_style_embed(self, spk_embed=None, spk_id=None): 98 | # add spk embed 99 | # style_embed = self.spk_id_proj(spk_id)[:, None, :] 100 | style_embed = self.spk_embed_proj(spk_embed)[:, None, :] 101 | return style_embed 102 | 103 | def get_lpv(self, 104 | txt_tokens, word_tokens, 105 | ph2word, word_len, 106 | mel2word, mel2ph, 107 | spk_embed, 108 | tgt_mels, global_step): 109 | ret = {} 110 | h_ling = self.run_text_encoder( 111 | txt_tokens, word_tokens, ph2word, word_len, mel2word, mel2ph, ret) 112 | h_spk = self.forward_style_embed(spk_embed) 113 | 114 | wrd_nonpadding = (word_tokens > 0).float()[:, :, None] 115 | _, _, idx, _, lpv = self.prosody_encoder(tgt_mels, h_ling, h_spk, 116 | mel2word, mel2ph, ph2word, word_len, 117 | wrd_nonpadding, global_step) 118 | return lpv, idx 119 | 120 | def run_text_encoder(self, txt_tokens, word_tokens, ph2word, word_len, mel2word, mel2ph, ret): 121 | src_nonpadding = (txt_tokens > 0).float()[:, :, None] 122 | ph_encoder_out = self.encoder(txt_tokens) * src_nonpadding 123 | word_encoder_out = self.word_encoder(word_tokens) 124 | ph_encoder_out = ph_encoder_out + expand_states(word_encoder_out, ph2word) 125 | 126 | return ph_encoder_out 127 | 128 | def run_decoder(self, x, tgt_nonpadding, ret, infer, tgt_mels=None, global_step=0): 129 | x = self.decoder(x) 130 | x = self.mel_out(x) 131 | return x * tgt_nonpadding 132 | 133 | def get_pos_embed(self, word2word, x2word): 134 | x_pos = build_word_mask(word2word, x2word).float() # [B, T_word, T_ph] 135 | x_pos = (x_pos.cumsum(-1) / x_pos.sum(-1).clamp(min=1)[..., None] * x_pos).sum(1) 136 | x_pos = self.sin_pos(x_pos.float()) # [B, T_ph, H] 137 | return x_pos 138 | 139 | @property 140 | def device(self): 141 | return next(self.parameters()).device -------------------------------------------------------------------------------- /modules/tts/diffprosody/prosody_encoder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules.commons.conv import ConditionalConvBlocks, ConvBlocks 5 | from utils.nn.seq_utils import group_hidden_by_segs 6 | from modules.tts.commons.align_ops import expand_states 7 | 8 | class VectorQuantizerEMA(nn.Module): 9 | def __init__(self, num_embeddings, embedding_dim, commitment_cost, 10 | cluster_centers=None, decay=0.996, epsilon=1e-5): 11 | super(VectorQuantizerEMA, self).__init__() 12 | 13 | self._embedding_dim = embedding_dim 14 | self._num_embeddings = num_embeddings 15 | 16 | if cluster_centers is not None: 17 | self._embedding = nn.Embedding.from_pretrained(cluster_centers) 18 | else: 19 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) 20 | self._embedding.weight.data.uniform_(-1.0 / self._num_embeddings, 1.0 / self._embedding_dim) 21 | 22 | self._commitment_cost = commitment_cost 23 | 24 | self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings)) 25 | self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) 26 | self._ema_w.data.normal_() 27 | 28 | self._decay = decay 29 | self._epsilon = epsilon 30 | 31 | def forward(self, inputs): 32 | # convert inputs from BCHW -> BHWC 33 | inputs = inputs.contiguous() 34 | input_shape = inputs.shape 35 | 36 | # Flatten input 37 | flat_input = inputs.view(-1, self._embedding_dim) 38 | 39 | # Calculate distances 40 | distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 41 | + torch.sum(self._embedding.weight**2, dim=1) 42 | - 2 * torch.matmul(flat_input, self._embedding.weight.t())) 43 | 44 | # Encoding 45 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 46 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) 47 | encodings.scatter_(1, encoding_indices, 1) 48 | # Quantize and unflatten 49 | quantized = self._embedding(encoding_indices).view(inputs.shape) 50 | 51 | # Use EMA to update the embedding vectors 52 | if self.training: 53 | self._ema_cluster_size = self._ema_cluster_size * self._decay + \ 54 | (1 - self._decay) * torch.sum(encodings, 0) 55 | 56 | # Laplace smoothing of the cluster size 57 | n = torch.sum(self._ema_cluster_size.data) 58 | self._ema_cluster_size = ( 59 | (self._ema_cluster_size + self._epsilon) 60 | / (n + self._num_embeddings * self._epsilon) * n) 61 | 62 | dw = torch.matmul(encodings.t(), flat_input) 63 | self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw) 64 | 65 | self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1)) 66 | 67 | # Loss 68 | e_latent_loss = F.mse_loss(quantized.detach(), inputs) 69 | loss = self._commitment_cost * e_latent_loss 70 | 71 | # Straight Through Estimator 72 | quantized = inputs + (quantized - inputs).detach() 73 | avg_probs = torch.mean(encodings, dim=0) 74 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 75 | encoding_indices = encoding_indices.view(input_shape[0], input_shape[1]) 76 | # convert quantized from BHWC -> BCHW 77 | return loss, quantized, perplexity, encodings, encoding_indices 78 | 79 | class Mish(nn.Module): 80 | def forward(self, x): 81 | return x * torch.tanh(F.softplus(x)) 82 | 83 | class LinearNorm(nn.Module): 84 | """ LinearNorm Projection """ 85 | 86 | def __init__(self, in_features, out_features, bias=False): 87 | super(LinearNorm, self).__init__() 88 | self.linear = 
nn.Linear(in_features, out_features, bias) 89 | 90 | nn.init.xavier_uniform_(self.linear.weight) 91 | if bias: 92 | nn.init.constant_(self.linear.bias, 0.0) 93 | 94 | def forward(self, x): 95 | x = self.linear(x) 96 | return x 97 | 98 | class ProsodyEncoder(nn.Module): 99 | def __init__(self, hparams, ln_eps=1e-12): 100 | super(ProsodyEncoder, self).__init__() 101 | self.hparams = hparams 102 | self.n = 5 103 | self.hidden_size = self.hparams["hidden_size"] 104 | self.kernel_size = 5 105 | self.n_embeddings = 128 106 | self.embedding_dim = self.hparams["hidden_size"] 107 | self.beta = self.hparams["commitment_cost"] 108 | print("Prosody mel bins: ", self.hparams["prosody_mel_bins"]) 109 | self.pre = nn.Sequential( 110 | nn.Linear(self.hparams["prosody_mel_bins"], self.hidden_size // 4), 111 | Mish(), 112 | nn.Linear(self.hidden_size // 4, self.hidden_size) 113 | ) 114 | self.conv1 = ConditionalConvBlocks(self.hidden_size, self.hidden_size, self.hidden_size, 115 | None, self.kernel_size, num_layers=self.n) 116 | self.conv2 = ConditionalConvBlocks(self.hidden_size, self.hidden_size, self.hidden_size, 117 | None, self.kernel_size, num_layers=self.n) 118 | self.vector_quantization = None 119 | self.post_net = ConvBlocks(self.hidden_size, self.hidden_size, None, 1, num_layers=3) 120 | 121 | def init_vq(self, cluster_centers): 122 | self.vector_quantization = VectorQuantizerEMA( 123 | self.n_embeddings, self.embedding_dim, self.beta, cluster_centers=cluster_centers, decay=self.hparams["ema_decay"]).cuda() 124 | print("Initialized Codebook with cluster centers [EMA]") 125 | def forward(self, x, h_lin, h_spk, mel2word, mel2ph, ph2word, word_len, wrd_nonpadding, global_step=None): 126 | x = x[:, :, :self.hparams["prosody_mel_bins"]] # Mel: 80 bin -> 20 bin 127 | x = self.pre(x) 128 | 129 | cond = h_lin + h_spk # Phoneme-level 130 | cond1 = expand_states(cond, mel2ph) # Frame level 131 | x = self.conv1(x, cond1) 132 | x = group_hidden_by_segs(x, mel2word, word_len)[0] # Word-level 133 | cond2 = group_hidden_by_segs(cond, ph2word, word_len)[0] 134 | x = self.conv2(x, cond2) 135 | x = self.post_net(x) 136 | 137 | if global_step > self.hparams["vq_warmup"]: 138 | embedding_loss, x2, perplexity, min_encodings, min_encoding_indices = self.vector_quantization(x) # VQ 139 | x2 = x2 * wrd_nonpadding 140 | return embedding_loss, x2, min_encoding_indices, perplexity, x 141 | else: 142 | return None, x, None, None, x -------------------------------------------------------------------------------- /modules/tts/fs2_orig.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from modules.commons.layers import Embedding 4 | from modules.commons.nar_tts_modules import EnergyPredictor, PitchPredictor 5 | from modules.tts.commons.align_ops import expand_states 6 | from modules.tts.fs import FastSpeech 7 | from utils.audio.cwt import cwt2f0, get_lf0_cwt 8 | from utils.audio.pitch.utils import denorm_f0, f0_to_coarse, norm_f0 9 | import numpy as np 10 | 11 | 12 | class FastSpeech2Orig(FastSpeech): 13 | def __init__(self, dict_size, hparams, out_dims=None): 14 | super().__init__(dict_size, hparams, out_dims) 15 | predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size 16 | if hparams['use_energy_embed']: 17 | self.energy_embed = Embedding(256, self.hidden_size, 0) 18 | self.energy_predictor = EnergyPredictor( 19 | self.hidden_size, n_chans=predictor_hidden, 20 | 
n_layers=hparams['predictor_layers'], dropout_rate=hparams['predictor_dropout'], odim=2, 21 | kernel_size=hparams['predictor_kernel']) 22 | if hparams['pitch_type'] == 'cwt' and hparams['use_pitch_embed']: 23 | self.pitch_predictor = PitchPredictor( 24 | self.hidden_size, n_chans=predictor_hidden, 25 | n_layers=hparams['predictor_layers'], dropout_rate=hparams['predictor_dropout'], odim=11, 26 | kernel_size=hparams['predictor_kernel']) 27 | self.cwt_stats_layers = nn.Sequential( 28 | nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(), 29 | nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, 2)) 30 | 31 | def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None, 32 | f0=None, uv=None, energy=None, infer=False, spk_vec=None, **kwargs): 33 | ret = {} 34 | encoder_out = self.encoder(txt_tokens) # [B, T, C] 35 | src_nonpadding = (txt_tokens > 0).float()[:, :, None] 36 | if spk_vec is not None: 37 | style_embed = spk_vec 38 | else: 39 | style_embed = self.forward_style_embed(spk_embed, spk_id) 40 | 41 | # add dur 42 | dur_inp = (encoder_out + style_embed) * src_nonpadding 43 | mel2ph = self.forward_dur(dur_inp, mel2ph, txt_tokens, ret) 44 | tgt_nonpadding = (mel2ph > 0).float()[:, :, None] 45 | decoder_inp = decoder_inp_ = expand_states(encoder_out, mel2ph) 46 | 47 | # add pitch and energy embed 48 | if self.hparams['use_pitch_embed']: 49 | pitch_inp = (decoder_inp_ + style_embed) * tgt_nonpadding 50 | decoder_inp = decoder_inp + self.forward_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out) 51 | 52 | # add pitch and energy embed 53 | if self.hparams['use_energy_embed']: 54 | energy_inp = (decoder_inp_ + style_embed) * tgt_nonpadding 55 | decoder_inp = decoder_inp + self.forward_energy(energy_inp, energy, ret) 56 | 57 | # decoder input 58 | ret['decoder_inp'] = decoder_inp = (decoder_inp + style_embed) * tgt_nonpadding 59 | if self.hparams['dec_inp_add_noise']: 60 | B, T, _ = decoder_inp.shape 61 | z = kwargs.get('adv_z', torch.randn([B, T, self.z_channels])).to(decoder_inp.device) 62 | ret['adv_z'] = z 63 | decoder_inp = torch.cat([decoder_inp, z], -1) 64 | decoder_inp = self.dec_inp_noise_proj(decoder_inp) * tgt_nonpadding 65 | ret['mel_out'] = self.forward_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs) 66 | return ret 67 | 68 | def forward_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None): 69 | if self.hparams['pitch_type'] == 'cwt': 70 | decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach()) 71 | pitch_padding = mel2ph == 0 72 | ret['cwt'] = cwt_out = self.pitch_predictor(decoder_inp) 73 | stats_out = self.cwt_stats_layers(decoder_inp.mean(1)) # [B, 2] 74 | mean = ret['f0_mean'] = stats_out[:, 0] 75 | std = ret['f0_std'] = stats_out[:, 1] 76 | cwt_spec = cwt_out[:, :, :10] 77 | if f0 is None: 78 | std = std * self.hparams['cwt_std_scale'] 79 | f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph) 80 | if self.hparams['use_uv']: 81 | assert cwt_out.shape[-1] == 11 82 | uv = cwt_out[:, :, -1] > 0 83 | ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv if self.hparams['use_uv'] else None, 84 | pitch_padding=pitch_padding) 85 | pitch = f0_to_coarse(f0_denorm) # start from 0 86 | pitch_embed = self.pitch_embed(pitch) 87 | return pitch_embed 88 | else: 89 | return super(FastSpeech2Orig, self).forward_pitch(decoder_inp, f0, uv, mel2ph, ret, encoder_out) 90 | 91 | def forward_energy(self, decoder_inp, energy, ret): 92 | decoder_inp = decoder_inp.detach() + 
self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach()) 93 | ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0] 94 | energy_embed_inp = energy_pred if energy is None else energy 95 | energy_embed_inp = torch.clamp(energy_embed_inp * 256 // 4, min=0, max=255).long() 96 | energy_embed = self.energy_embed(energy_embed_inp) 97 | return energy_embed 98 | 99 | def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph): 100 | _, cwt_scales = get_lf0_cwt(np.ones(10)) 101 | f0 = cwt2f0(cwt_spec, mean, std, cwt_scales) 102 | f0 = torch.cat( 103 | [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1) 104 | f0_norm = norm_f0(f0, None) 105 | return f0_norm 106 | -------------------------------------------------------------------------------- /modules/vocoder/hifigan/mel_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram(y, hparams, center=False, complex=False): 46 | # hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) 47 | # win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) 48 | # fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 49 | # fmax: 10000 # To be increased/reduced depending on data. 50 | # fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter 51 | # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, 52 | n_fft = hparams['fft_size'] 53 | num_mels = hparams['audio_num_mel_bins'] 54 | sampling_rate = hparams['audio_sample_rate'] 55 | hop_size = hparams['hop_size'] 56 | win_size = hparams['win_size'] 57 | fmin = hparams['fmin'] 58 | fmax = hparams['fmax'] 59 | y = y.clamp(min=-1., max=1.) 
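# NOTE (editor): mel_basis and hann_window are module-level dicts shared across calls; the entries built below are keyed by fmax and device.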
60 | global mel_basis, hann_window 61 | if fmax not in mel_basis: 62 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 63 | mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 64 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 65 | 66 | y = torch.nn.functional.pad(y.unsqueeze(1), [int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)], 67 | mode='reflect') 68 | y = y.squeeze(1) 69 | 70 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 71 | center=center, pad_mode='reflect', normalized=False, onesided=True) 72 | 73 | if not complex: 74 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 75 | spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec) 76 | spec = spectral_normalize_torch(spec) 77 | else: 78 | B, C, T, _ = spec.shape 79 | spec = spec.transpose(1, 2) # [B, T, n_fft, 2] 80 | return spec 81 | -------------------------------------------------------------------------------- /modules/vocoder/hifigan/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def stft(x, fft_size, hop_size, win_length, window): 13 | """Perform STFT and convert to magnitude spectrogram. 14 | Args: 15 | x (Tensor): Input signal tensor (B, T). 16 | fft_size (int): FFT size. 17 | hop_size (int): Hop size. 18 | win_length (int): Window length. 19 | window (str): Window function type. 20 | Returns: 21 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 22 | """ 23 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window) 24 | real = x_stft[..., 0] 25 | imag = x_stft[..., 1] 26 | 27 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 28 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 29 | 30 | 31 | class SpectralConvergengeLoss(torch.nn.Module): 32 | """Spectral convergence loss module.""" 33 | 34 | def __init__(self): 35 | """Initilize spectral convergence loss module.""" 36 | super(SpectralConvergengeLoss, self).__init__() 37 | 38 | def forward(self, x_mag, y_mag): 39 | """Calculate forward propagation. 40 | Args: 41 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 42 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 43 | Returns: 44 | Tensor: Spectral convergence loss value. 45 | """ 46 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 47 | 48 | 49 | class LogSTFTMagnitudeLoss(torch.nn.Module): 50 | """Log STFT magnitude loss module.""" 51 | 52 | def __init__(self): 53 | """Initilize los STFT magnitude loss module.""" 54 | super(LogSTFTMagnitudeLoss, self).__init__() 55 | 56 | def forward(self, x_mag, y_mag): 57 | """Calculate forward propagation. 58 | Args: 59 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 60 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 61 | Returns: 62 | Tensor: Log STFT magnitude loss value. 
63 | """ 64 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 65 | 66 | 67 | class STFTLoss(torch.nn.Module): 68 | """STFT loss module.""" 69 | 70 | def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): 71 | """Initialize STFT loss module.""" 72 | super(STFTLoss, self).__init__() 73 | self.fft_size = fft_size 74 | self.shift_size = shift_size 75 | self.win_length = win_length 76 | self.window = getattr(torch, window)(win_length) 77 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 78 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 79 | 80 | def forward(self, x, y): 81 | """Calculate forward propagation. 82 | Args: 83 | x (Tensor): Predicted signal (B, T). 84 | y (Tensor): Groundtruth signal (B, T). 85 | Returns: 86 | Tensor: Spectral convergence loss value. 87 | Tensor: Log STFT magnitude loss value. 88 | """ 89 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window.to(x.get_device())) 90 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(x.get_device())) 91 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 92 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 93 | 94 | return sc_loss, mag_loss 95 | 96 | 97 | class MultiResolutionSTFTLoss(torch.nn.Module): 98 | """Multi resolution STFT loss module.""" 99 | 100 | def __init__(self, 101 | fft_sizes=[1024, 2048, 512], 102 | hop_sizes=[120, 240, 50], 103 | win_lengths=[600, 1200, 240], 104 | window="hann_window"): 105 | """Initialize Multi resolution STFT loss module. 106 | Args: 107 | fft_sizes (list): List of FFT sizes. 108 | hop_sizes (list): List of hop sizes. 109 | win_lengths (list): List of window lengths. 110 | window (str): Window function type. 111 | """ 112 | super(MultiResolutionSTFTLoss, self).__init__() 113 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 114 | self.stft_losses = torch.nn.ModuleList() 115 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 116 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 117 | 118 | def forward(self, x, y): 119 | """Calculate forward propagation. 120 | Args: 121 | x (Tensor): Predicted signal (B, T). 122 | y (Tensor): Groundtruth signal (B, T). 123 | Returns: 124 | Tensor: Multi resolution spectral convergence loss value. 125 | Tensor: Multi resolution log STFT magnitude loss value. 126 | """ 127 | sc_loss = 0.0 128 | mag_loss = 0.0 129 | for f in self.stft_losses: 130 | sc_l, mag_l = f(x, y) 131 | sc_loss += sc_l 132 | mag_loss += mag_l 133 | sc_loss /= len(self.stft_losses) 134 | mag_loss /= len(self.stft_losses) 135 | 136 | return sc_loss, mag_loss -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | librosa 3 | tqdm 4 | pandas 5 | numba 6 | numpy 7 | scipy 8 | PyYAML 9 | tensorboardX 10 | pyloudnorm 11 | setuptools>=41.0.0 12 | g2p_en 13 | resemblyzer 14 | webrtcvad 15 | tensorboard 16 | scikit-learn 17 | scikit-image 18 | textgrid 19 | jiwer 20 | pycwt 21 | PyWavelets 22 | praat-parselmouth 23 | jieba 24 | einops 25 | chardet -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=. 
2 | DEVICE=0; 3 | DIR_NAME="/workspace/checkpoints/"; 4 | MODEL_NAME="DiffProsody"; 5 | HPARAMS="" 6 | 7 | CONFIG="egs/datasets/audio/vctk/diffprosody.yaml"; 8 | CUDA_VISIBLE_DEVICES=$DEVICE python tasks/run.py --config $CONFIG --exp_name $MODEL_NAME --reset --hparams=$HPARAMS 9 | CUDA_VISIBLE_DEVICES=$DEVICE python tasks/run.py --config $CONFIG --exp_name $MODEL_NAME --infer --hparams=$HPARAMS 10 | 11 | CUDA_VISIBLE_DEVICES=$DEVICE python extract_lpv.py --config $CONFIG --exp_name $MODEL_NAME 12 | 13 | MODEL_NAME2="DiffProsodyGenerator"; 14 | CONFIG2="egs/datasets/audio/vctk/prosody_generator.yaml"; 15 | HPARAMS2="tts_model=$DIR_NAME$MODEL_NAME" 16 | 17 | CUDA_VISIBLE_DEVICES=$DEVICE python tasks/run.py --config $CONFIG2 --exp_name $MODEL_NAME2 --reset --hparams=$HPARAMS2 18 | CUDA_VISIBLE_DEVICES=$DEVICE python tasks/run.py --config $CONFIG2 --exp_name $MODEL_NAME2 --infer --hparams=$HPARAMS2 -------------------------------------------------------------------------------- /tasks/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | from utils.commons.hparams import hparams, set_hparams 6 | import importlib 7 | 8 | 9 | def run_task(): 10 | assert hparams['task_cls'] != '' 11 | pkg = ".".join(hparams["task_cls"].split(".")[:-1]) 12 | cls_name = hparams["task_cls"].split(".")[-1] 13 | task_cls = getattr(importlib.import_module(pkg), cls_name) 14 | task_cls.start() 15 | 16 | 17 | if __name__ == '__main__': 18 | set_hparams() 19 | run_task() 20 | -------------------------------------------------------------------------------- /tasks/tts/tts_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | from data_gen.tts.base_binarizer import BaseBinarizer 4 | from data_gen.tts.base_preprocess import BasePreprocessor 5 | from utils.commons.hparams import hparams 6 | 7 | def parse_dataset_configs(): 8 | max_tokens = hparams['max_tokens'] 9 | max_sentences = hparams['max_sentences'] 10 | max_valid_tokens = hparams['max_valid_tokens'] 11 | if max_valid_tokens == -1: 12 | hparams['max_valid_tokens'] = max_valid_tokens = max_tokens 13 | max_valid_sentences = hparams['max_valid_sentences'] 14 | if max_valid_sentences == -1: 15 | hparams['max_valid_sentences'] = max_valid_sentences = max_sentences 16 | return max_tokens, max_sentences, max_valid_tokens, max_valid_sentences 17 | 18 | 19 | def parse_mel_losses(): 20 | mel_losses = hparams['mel_losses'].split("|") 21 | loss_and_lambda = {} 22 | for i, l in enumerate(mel_losses): 23 | if l == '': 24 | continue 25 | if ':' in l: 26 | l, lbd = l.split(":") 27 | lbd = float(lbd) 28 | else: 29 | lbd = 1.0 30 | loss_and_lambda[l] = lbd 31 | print("| Mel losses:", loss_and_lambda) 32 | return loss_and_lambda 33 | 34 | 35 | def load_data_preprocessor(): 36 | preprocess_cls = hparams["preprocess_cls"] 37 | pkg = ".".join(preprocess_cls.split(".")[:-1]) 38 | cls_name = preprocess_cls.split(".")[-1] 39 | preprocessor: BasePreprocessor = getattr(importlib.import_module(pkg), cls_name)() 40 | preprocess_args = {} 41 | preprocess_args.update(hparams['preprocess_args']) 42 | return preprocessor, preprocess_args 43 | 44 | 45 | def load_data_binarizer(): 46 | binarizer_cls = hparams['binarizer_cls'] 47 | pkg = ".".join(binarizer_cls.split(".")[:-1]) 48 | cls_name = binarizer_cls.split(".")[-1] 49 | binarizer: BaseBinarizer = getattr(importlib.import_module(pkg), cls_name)() 50 | binarization_args = {} 51 | 
binarization_args.update(hparams['binarization_args']) 52 | return binarizer, binarization_args 53 | -------------------------------------------------------------------------------- /tasks/tts/vocoder_infer/__init__.py: -------------------------------------------------------------------------------- 1 | from . import hifigan -------------------------------------------------------------------------------- /tasks/tts/vocoder_infer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hs-oh-prml/DiffProsody/6d5b6dbb58497fdff791d06fca09a4fae2a2cc11/tasks/tts/vocoder_infer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tasks/tts/vocoder_infer/__pycache__/base_vocoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hs-oh-prml/DiffProsody/6d5b6dbb58497fdff791d06fca09a4fae2a2cc11/tasks/tts/vocoder_infer/__pycache__/base_vocoder.cpython-38.pyc -------------------------------------------------------------------------------- /tasks/tts/vocoder_infer/__pycache__/hifigan.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hs-oh-prml/DiffProsody/6d5b6dbb58497fdff791d06fca09a4fae2a2cc11/tasks/tts/vocoder_infer/__pycache__/hifigan.cpython-38.pyc -------------------------------------------------------------------------------- /tasks/tts/vocoder_infer/base_vocoder.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | from utils.audio import librosa_wav2spec 3 | from utils.commons.hparams import hparams 4 | import numpy as np 5 | 6 | REGISTERED_VOCODERS = {} 7 | 8 | 9 | def register_vocoder(name): 10 | def _f(cls): 11 | REGISTERED_VOCODERS[name] = cls 12 | return cls 13 | 14 | return _f 15 | 16 | 17 | def get_vocoder_cls(vocoder_name): 18 | return REGISTERED_VOCODERS.get(vocoder_name) 19 | 20 | 21 | class BaseVocoder: 22 | def spec2wav(self, mel): 23 | """ 24 | 25 | :param mel: [T, 80] 26 | :return: wav: [T'] 27 | """ 28 | 29 | raise NotImplementedError 30 | 31 | @staticmethod 32 | def wav2spec(wav_fn): 33 | """ 34 | 35 | :param wav_fn: str 36 | :return: wav, mel: [T, 80] 37 | """ 38 | wav_spec_dict = librosa_wav2spec(wav_fn, fft_size=hparams['fft_size'], 39 | hop_size=hparams['hop_size'], 40 | win_length=hparams['win_size'], 41 | num_mels=hparams['audio_num_mel_bins'], 42 | fmin=hparams['fmin'], 43 | fmax=hparams['fmax'], 44 | sample_rate=hparams['audio_sample_rate'], 45 | loud_norm=hparams['loud_norm']) 46 | wav = wav_spec_dict['wav'] 47 | mel = wav_spec_dict['mel'] 48 | return wav, mel 49 | 50 | @staticmethod 51 | def wav2mfcc(wav_fn): 52 | fft_size = hparams['fft_size'] 53 | hop_size = hparams['hop_size'] 54 | win_length = hparams['win_size'] 55 | sample_rate = hparams['audio_sample_rate'] 56 | wav, _ = librosa.core.load(wav_fn, sr=sample_rate) 57 | mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13, 58 | n_fft=fft_size, hop_length=hop_size, 59 | win_length=win_length, pad_mode="constant", power=1.0) 60 | mfcc_delta = librosa.feature.delta(mfcc, order=1) 61 | mfcc_delta_delta = librosa.feature.delta(mfcc, order=2) 62 | mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T 63 | return mfcc 64 | -------------------------------------------------------------------------------- /tasks/tts/vocoder_infer/hifigan.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from modules.vocoder.hifigan.hifigan import HifiGanGenerator 3 | from tasks.tts.vocoder_infer.base_vocoder import register_vocoder, BaseVocoder 4 | from utils.commons.hparams import set_hparams, hparams 5 | from utils.commons.meters import Timer 6 | 7 | total_time = 0 8 | @register_vocoder('HifiGAN') 9 | class HifiGAN(BaseVocoder): 10 | def __init__(self): 11 | base_dir = hparams['vocoder_ckpt'] 12 | config_path = f'{base_dir}/config.yaml' 13 | self.config = config = set_hparams(config_path, global_hparams=False) 14 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | self.model = HifiGanGenerator(config) 16 | checkpoint_dict = torch.load("checkpoints/hifi_vctk/generator_v1", map_location=self.device) 17 | self.model.load_state_dict(checkpoint_dict['generator']) 18 | self.model.to(self.device) 19 | self.model.eval() 20 | 21 | def spec2wav(self, mel, **kwargs): 22 | device = self.device 23 | with torch.no_grad(): 24 | c = torch.FloatTensor(mel).unsqueeze(0).to(device) 25 | c = c.transpose(2, 1) 26 | with Timer('hifigan', enable=hparams['profile_infer']): 27 | y = self.model(c).view(-1) 28 | wav_out = y.cpu().numpy() 29 | return wav_out -------------------------------------------------------------------------------- /tasks/vocoder/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributed as dist 4 | from torch.utils.data import DistributedSampler 5 | from utils.commons.dataset_utils import BaseDataset, collate_1d, collate_2d 6 | from utils.commons.hparams import hparams 7 | from utils.commons.indexed_datasets import IndexedDataset 8 | 9 | 10 | class EndlessDistributedSampler(DistributedSampler): 11 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 12 | if num_replicas is None: 13 | if not dist.is_available(): 14 | raise RuntimeError("Requires distributed package to be available") 15 | num_replicas = dist.get_world_size() 16 | if rank is None: 17 | if not dist.is_available(): 18 | raise RuntimeError("Requires distributed package to be available") 19 | rank = dist.get_rank() 20 | self.dataset = dataset 21 | self.num_replicas = num_replicas 22 | self.rank = rank 23 | self.epoch = 0 24 | self.shuffle = shuffle 25 | 26 | g = torch.Generator() 27 | g.manual_seed(self.epoch) 28 | if self.shuffle: 29 | indices = [i for _ in range(1000) for i in torch.randperm( 30 | len(self.dataset), generator=g).tolist()] 31 | else: 32 | indices = [i for _ in range(1000) for i in list(range(len(self.dataset)))] 33 | indices = indices[:len(indices) // self.num_replicas * self.num_replicas] 34 | indices = indices[self.rank::self.num_replicas] 35 | self.indices = indices 36 | 37 | def __iter__(self): 38 | return iter(self.indices) 39 | 40 | def __len__(self): 41 | return len(self.indices) 42 | 43 | 44 | class VocoderDataset(BaseDataset): 45 | def __init__(self, prefix, shuffle=False): 46 | super().__init__(shuffle) 47 | self.hparams = hparams 48 | self.prefix = prefix 49 | self.data_dir = hparams['binary_data_dir'] 50 | self.is_infer = prefix == 'test' 51 | self.batch_max_frames = 0 if self.is_infer else hparams['max_samples'] // hparams['hop_size'] 52 | self.hop_size = hparams['hop_size'] 53 | self.indexed_ds = None 54 | self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy') 55 | self.avail_idxs = [idx for idx, s in enumerate(self.sizes) if s > 
self.batch_max_frames] 56 | print(f"| {len(self.sizes) - len(self.avail_idxs)} short items are skipped in {prefix} set.") 57 | self.sizes = [s for idx, s in enumerate(self.sizes) if s > self.batch_max_frames] 58 | 59 | def _get_item(self, index): 60 | if self.indexed_ds is None: 61 | self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}') 62 | item = self.indexed_ds[index] 63 | return item 64 | 65 | def __getitem__(self, index): 66 | index = self.avail_idxs[index] 67 | item = self._get_item(index) 68 | sample = { 69 | "id": index, 70 | "item_name": item['item_name'], 71 | "mel": torch.FloatTensor(item['mel']), 72 | "wav": torch.FloatTensor(item['wav'].astype(np.float32)), 73 | "pitch": torch.LongTensor(item['pitch']), 74 | "f0": torch.FloatTensor(item['f0']) 75 | } 76 | return sample 77 | 78 | def collater(self, batch): 79 | if len(batch) == 0: 80 | return {} 81 | 82 | y_batch, c_batch, p_batch, f0_batch = [], [], [], [] 83 | item_name = [] 84 | for idx in range(len(batch)): 85 | item_name.append(batch[idx]['item_name']) 86 | x, c = batch[idx]['wav'], batch[idx]['mel'] 87 | p, f0 = batch[idx]['pitch'], batch[idx]['f0'] 88 | self._assert_ready_for_upsampling(x, c, self.hop_size) 89 | if len(c) > self.batch_max_frames: 90 | # randomly pickup with the batch_max_steps length of the part 91 | batch_max_frames = self.batch_max_frames if self.batch_max_frames != 0 else len(c) - 1 92 | batch_max_steps = batch_max_frames * self.hop_size 93 | interval_start = 0 94 | interval_end = len(c) - batch_max_frames 95 | start_frame = np.random.randint(interval_start, interval_end) 96 | start_step = start_frame * self.hop_size 97 | y = x[start_step: start_step + batch_max_steps] 98 | c = c[start_frame: start_frame + batch_max_frames] 99 | p = p[start_frame: start_frame + batch_max_frames] 100 | f0 = f0[start_frame: start_frame + batch_max_frames] 101 | self._assert_ready_for_upsampling(y, c, self.hop_size) 102 | else: 103 | print(f"Removed short sample from batch (length={len(x)}).") 104 | continue 105 | y_batch += [y.reshape(-1, 1)] # [(T, 1), (T, 1), ...] 106 | c_batch += [c] # [(T' C), (T' C), ...] 107 | p_batch += [p] # [(T' C), (T' C), ...] 108 | f0_batch += [f0] # [(T' C), (T' C), ...] 
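# NOTE (editor): items longer than batch_max_frames are randomly cropped above (mel frames plus the matching hop_size-aligned audio span); shorter ones are skipped, so training batches are collated from equal-length crops below.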
109 | 110 | # convert each batch to tensor, asuume that each item in batch has the same length 111 | y_batch = collate_2d(y_batch, 0).transpose(2, 1) # (B, 1, T) 112 | c_batch = collate_2d(c_batch, 0).transpose(2, 1) # (B, C, T') 113 | p_batch = collate_1d(p_batch, 0) # (B, T') 114 | f0_batch = collate_1d(f0_batch, 0) # (B, T') 115 | 116 | # make input noise signal batch tensor 117 | z_batch = torch.randn(y_batch.size()) # (B, 1, T) 118 | return { 119 | 'z': z_batch, 120 | 'mels': c_batch, 121 | 'wavs': y_batch, 122 | 'pitches': p_batch, 123 | 'f0': f0_batch, 124 | 'item_name': item_name 125 | } 126 | 127 | @staticmethod 128 | def _assert_ready_for_upsampling(x, c, hop_size): 129 | """Assert the audio and feature lengths are correctly adjusted for upsamping.""" 130 | assert len(x) == (len(c)) * hop_size 131 | -------------------------------------------------------------------------------- /tasks/vocoder/hifigan.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | 4 | from modules.vocoder.hifigan.hifigan import HifiGanGenerator, MultiPeriodDiscriminator, MultiScaleDiscriminator, \ 5 | generator_loss, feature_loss, discriminator_loss 6 | from modules.vocoder.hifigan.mel_utils import mel_spectrogram 7 | from modules.vocoder.hifigan.stft_loss import MultiResolutionSTFTLoss 8 | from tasks.vocoder.vocoder_base import VocoderBaseTask 9 | from utils.commons.hparams import hparams 10 | from utils.nn.model_utils import print_arch 11 | 12 | 13 | class HifiGanTask(VocoderBaseTask): 14 | def build_model(self): 15 | self.model_gen = HifiGanGenerator(hparams) 16 | self.model_disc = nn.ModuleDict() 17 | self.model_disc['mpd'] = MultiPeriodDiscriminator() 18 | self.model_disc['msd'] = MultiScaleDiscriminator() 19 | self.stft_loss = MultiResolutionSTFTLoss() 20 | print_arch(self.model_gen) 21 | if hparams['load_ckpt'] != '': 22 | self.load_ckpt(hparams['load_ckpt'], 'model_gen', 'model_gen', force=True, strict=True) 23 | self.load_ckpt(hparams['load_ckpt'], 'model_disc', 'model_disc', force=True, strict=True) 24 | return self.model_gen 25 | 26 | def _training_step(self, sample, batch_idx, optimizer_idx): 27 | mel = sample['mels'] 28 | y = sample['wavs'] 29 | f0 = sample['f0'] 30 | loss_output = {} 31 | if optimizer_idx == 0: 32 | ####################### 33 | # Generator # 34 | ####################### 35 | y_ = self.model_gen(mel, f0) 36 | y_mel = mel_spectrogram(y.squeeze(1), hparams).transpose(1, 2) 37 | y_hat_mel = mel_spectrogram(y_.squeeze(1), hparams).transpose(1, 2) 38 | loss_output['mel'] = F.l1_loss(y_hat_mel, y_mel) * hparams['lambda_mel'] 39 | _, y_p_hat_g, fmap_f_r, fmap_f_g = self.model_disc['mpd'](y, y_, mel) 40 | _, y_s_hat_g, fmap_s_r, fmap_s_g = self.model_disc['msd'](y, y_, mel) 41 | loss_output['a_p'] = generator_loss(y_p_hat_g) * hparams['lambda_adv'] 42 | loss_output['a_s'] = generator_loss(y_s_hat_g) * hparams['lambda_adv'] 43 | if hparams['use_fm_loss']: 44 | loss_output['fm_f'] = feature_loss(fmap_f_r, fmap_f_g) 45 | loss_output['fm_s'] = feature_loss(fmap_s_r, fmap_s_g) 46 | if hparams['use_ms_stft']: 47 | loss_output['sc'], loss_output['mag'] = self.stft_loss(y.squeeze(1), y_.squeeze(1)) 48 | self.y_ = y_.detach() 49 | self.y_mel = y_mel.detach() 50 | self.y_hat_mel = y_hat_mel.detach() 51 | else: 52 | ####################### 53 | # Discriminator # 54 | ####################### 55 | y_ = self.y_ 56 | # MPD 57 | y_p_hat_r, y_p_hat_g, _, _ = self.model_disc['mpd'](y, y_.detach(), mel) 58 | 
loss_output['r_p'], loss_output['f_p'] = discriminator_loss(y_p_hat_r, y_p_hat_g) 59 | # MSD 60 | y_s_hat_r, y_s_hat_g, _, _ = self.model_disc['msd'](y, y_.detach(), mel) 61 | loss_output['r_s'], loss_output['f_s'] = discriminator_loss(y_s_hat_r, y_s_hat_g) 62 | total_loss = sum(loss_output.values()) 63 | return total_loss, loss_output 64 | -------------------------------------------------------------------------------- /tasks/vocoder/vocoder_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from torch import nn 5 | from torch.utils.data import DistributedSampler 6 | from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler 7 | from utils.audio.io import save_wav 8 | from utils.commons.base_task import BaseTask 9 | from utils.commons.dataset_utils import data_loader 10 | from utils.commons.hparams import hparams 11 | from utils.commons.tensor_utils import tensors_to_scalars 12 | 13 | 14 | class VocoderBaseTask(BaseTask): 15 | def __init__(self): 16 | super(VocoderBaseTask, self).__init__() 17 | self.max_sentences = hparams['max_sentences'] 18 | self.max_valid_sentences = hparams['max_valid_sentences'] 19 | if self.max_valid_sentences == -1: 20 | hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences 21 | self.dataset_cls = VocoderDataset 22 | 23 | @data_loader 24 | def train_dataloader(self): 25 | train_dataset = self.dataset_cls('train', shuffle=True) 26 | return self.build_dataloader(train_dataset, True, self.max_sentences, hparams['endless_ds']) 27 | 28 | @data_loader 29 | def val_dataloader(self): 30 | valid_dataset = self.dataset_cls('test', shuffle=False) 31 | return self.build_dataloader(valid_dataset, False, self.max_valid_sentences) 32 | 33 | @data_loader 34 | def test_dataloader(self): 35 | test_dataset = self.dataset_cls('test', shuffle=False) 36 | return self.build_dataloader(test_dataset, False, self.max_valid_sentences) 37 | 38 | def build_dataloader(self, dataset, shuffle, max_sentences, endless=False): 39 | world_size = 1 40 | rank = 0 41 | if dist.is_initialized(): 42 | world_size = dist.get_world_size() 43 | rank = dist.get_rank() 44 | sampler_cls = DistributedSampler if not endless else EndlessDistributedSampler 45 | train_sampler = sampler_cls( 46 | dataset=dataset, 47 | num_replicas=world_size, 48 | rank=rank, 49 | shuffle=shuffle, 50 | ) 51 | return torch.utils.data.DataLoader( 52 | dataset=dataset, 53 | shuffle=False, 54 | collate_fn=dataset.collater, 55 | batch_size=max_sentences, 56 | num_workers=dataset.num_workers, 57 | sampler=train_sampler, 58 | pin_memory=True, 59 | ) 60 | 61 | def build_optimizer(self, model): 62 | optimizer_gen = torch.optim.AdamW(self.model_gen.parameters(), lr=hparams['lr'], 63 | betas=[hparams['adam_b1'], hparams['adam_b2']]) 64 | optimizer_disc = torch.optim.AdamW(self.model_disc.parameters(), lr=hparams['lr'], 65 | betas=[hparams['adam_b1'], hparams['adam_b2']]) 66 | return [optimizer_gen, optimizer_disc] 67 | 68 | def build_scheduler(self, optimizer): 69 | return { 70 | "gen": torch.optim.lr_scheduler.StepLR( 71 | optimizer=optimizer[0], 72 | **hparams["generator_scheduler_params"]), 73 | "disc": torch.optim.lr_scheduler.StepLR( 74 | optimizer=optimizer[1], 75 | **hparams["discriminator_scheduler_params"]), 76 | } 77 | 78 | def validation_step(self, sample, batch_idx): 79 | outputs = {} 80 | total_loss, loss_output = self._training_step(sample, batch_idx, 0) 81 | 
outputs['losses'] = tensors_to_scalars(loss_output) 82 | outputs['total_loss'] = tensors_to_scalars(total_loss) 83 | 84 | if self.global_step % hparams['valid_infer_interval'] == 0 and \ 85 | batch_idx < 10: 86 | mels = sample['mels'] 87 | y = sample['wavs'] 88 | f0 = sample['f0'] 89 | y_ = self.model_gen(mels, f0) 90 | for idx, (wav_pred, wav_gt, item_name) in enumerate(zip(y_, y, sample["item_name"])): 91 | wav_pred = wav_pred / wav_pred.abs().max() 92 | if self.global_step == 0: 93 | wav_gt = wav_gt / wav_gt.abs().max() 94 | self.logger.add_audio(f'wav_{batch_idx}_{idx}_gt', wav_gt, self.global_step, 95 | hparams['audio_sample_rate']) 96 | self.logger.add_audio(f'wav_{batch_idx}_{idx}_pred', wav_pred, self.global_step, 97 | hparams['audio_sample_rate']) 98 | return outputs 99 | 100 | def test_start(self): 101 | self.gen_dir = os.path.join(hparams['work_dir'], 102 | f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') 103 | os.makedirs(self.gen_dir, exist_ok=True) 104 | 105 | def test_step(self, sample, batch_idx): 106 | mels = sample['mels'] 107 | y = sample['wavs'] 108 | f0 = sample['f0'] 109 | loss_output = {} 110 | y_ = self.model_gen(mels, f0) 111 | gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') 112 | os.makedirs(gen_dir, exist_ok=True) 113 | for idx, (wav_pred, wav_gt, item_name) in enumerate(zip(y_, y, sample["item_name"])): 114 | wav_gt = wav_gt.clamp(-1, 1) 115 | wav_pred = wav_pred.clamp(-1, 1) 116 | save_wav( 117 | wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav', 118 | hparams['audio_sample_rate']) 119 | save_wav( 120 | wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav', 121 | hparams['audio_sample_rate']) 122 | return loss_output 123 | 124 | def test_end(self, outputs): 125 | return {} 126 | 127 | def on_before_optimization(self, opt_idx): 128 | if opt_idx == 0: 129 | nn.utils.clip_grad_norm_(self.model_gen.parameters(), hparams['generator_grad_norm']) 130 | else: 131 | nn.utils.clip_grad_norm_(self.model_disc.parameters(), hparams["discriminator_grad_norm"]) 132 | 133 | def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx): 134 | if optimizer_idx == 0: 135 | self.scheduler['gen'].step(self.global_step // hparams['accumulate_grad_batches']) 136 | else: 137 | self.scheduler['disc'].step(self.global_step // hparams['accumulate_grad_batches']) 138 | -------------------------------------------------------------------------------- /utils/audio/align.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | import numpy as np 5 | from textgrid import TextGrid 6 | 7 | from utils.text.text_encoder import is_sil_phoneme 8 | 9 | 10 | def get_mel2ph(tg_fn, ph, mel, hop_size, audio_sample_rate, min_sil_duration=0): 11 | ph_list = ph.split(" ") 12 | itvs = TextGrid.fromFile(tg_fn)[1] 13 | itvs_ = [] 14 | for i in range(len(itvs)): 15 | if itvs[i].maxTime - itvs[i].minTime < min_sil_duration and i > 0 and is_sil_phoneme(itvs[i].mark): 16 | itvs_[-1].maxTime = itvs[i].maxTime 17 | else: 18 | itvs_.append(itvs[i]) 19 | itvs.intervals = itvs_ 20 | itv_marks = [itv.mark for itv in itvs] 21 | tg_len = len([x for x in itvs if not is_sil_phoneme(x.mark)]) 22 | ph_len = len([x for x in ph_list if not is_sil_phoneme(x)]) 23 | assert tg_len == ph_len, (tg_len, ph_len, itv_marks, ph_list, tg_fn) 24 | mel2ph = np.zeros([mel.shape[0]], int) 25 | i_itv = 0 26 | i_ph = 0 27 | while i_itv < 
len(itvs): 28 | itv = itvs[i_itv] 29 | ph = ph_list[i_ph] 30 | itv_ph = itv.mark 31 | start_frame = int(itv.minTime * audio_sample_rate / hop_size + 0.5) 32 | end_frame = int(itv.maxTime * audio_sample_rate / hop_size + 0.5) 33 | if is_sil_phoneme(itv_ph) and not is_sil_phoneme(ph): 34 | mel2ph[start_frame:end_frame] = i_ph 35 | i_itv += 1 36 | elif not is_sil_phoneme(itv_ph) and is_sil_phoneme(ph): 37 | i_ph += 1 38 | else: 39 | if not ((is_sil_phoneme(itv_ph) and is_sil_phoneme(ph)) \ 40 | or re.sub(r'\d+', '', itv_ph.lower()) == re.sub(r'\d+', '', ph.lower())): 41 | print(f"| WARN: {tg_fn} phs are not same: ", itv_ph, ph, itv_marks, ph_list) 42 | mel2ph[start_frame:end_frame] = i_ph + 1 43 | i_ph += 1 44 | i_itv += 1 45 | mel2ph[-1] = mel2ph[-2] 46 | assert not np.any(mel2ph == 0) 47 | T_t = len(ph_list) 48 | dur = mel2token_to_dur(mel2ph, T_t) 49 | return mel2ph.tolist(), dur.tolist() 50 | 51 | 52 | def split_audio_by_mel2ph(audio, mel2ph, hop_size, audio_num_mel_bins): 53 | if isinstance(audio, torch.Tensor): 54 | audio = audio.numpy() 55 | if isinstance(mel2ph, torch.Tensor): 56 | mel2ph = mel2ph.numpy() 57 | assert len(audio.shape) == 1, len(mel2ph.shape) == 1 58 | split_locs = [] 59 | for i in range(1, len(mel2ph)): 60 | if mel2ph[i] != mel2ph[i - 1]: 61 | split_loc = i * hop_size 62 | split_locs.append(split_loc) 63 | 64 | new_audio = [] 65 | for i in range(len(split_locs) - 1): 66 | new_audio.append(audio[split_locs[i]:split_locs[i + 1]]) 67 | new_audio.append(np.zeros([0.5 * audio_num_mel_bins])) 68 | return np.concatenate(new_audio) 69 | 70 | 71 | def mel2token_to_dur(mel2token, T_txt=None, max_dur=None): 72 | is_torch = isinstance(mel2token, torch.Tensor) 73 | has_batch_dim = True 74 | if not is_torch: 75 | mel2token = torch.LongTensor(mel2token) 76 | if T_txt is None: 77 | T_txt = mel2token.max() 78 | if len(mel2token.shape) == 1: 79 | mel2token = mel2token[None, ...] 
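# temporarily add a batch dimension so the scatter_add below can count frames per token index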
80 | has_batch_dim = False 81 | B, _ = mel2token.shape 82 | dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token)) 83 | dur = dur[:, 1:] 84 | if max_dur is not None: 85 | dur = dur.clamp(max=max_dur) 86 | if not is_torch: 87 | dur = dur.numpy() 88 | if not has_batch_dim: 89 | dur = dur[0] 90 | return dur 91 | -------------------------------------------------------------------------------- /utils/audio/cwt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pycwt import wavelet 3 | from scipy.interpolate import interp1d 4 | 5 | dt = 0.005 6 | dj = 1 7 | 8 | 9 | def convert_continuos_f0(f0): 10 | '''CONVERT F0 TO CONTINUOUS F0 11 | Args: 12 | f0 (ndarray): original f0 sequence with the shape (T) 13 | Return: 14 | (ndarray): continuous f0 with the shape (T) 15 | ''' 16 | # get uv information as binary 17 | f0 = np.copy(f0) 18 | uv = (f0 == 0).astype(float) 19 | 20 | # get start and end of f0 21 | if (f0 == 0).all(): 22 | print("| all of the f0 values are 0.") 23 | return uv, f0 24 | start_f0 = f0[f0 != 0][0] 25 | end_f0 = f0[f0 != 0][-1] 26 | 27 | # padding start and end of f0 sequence 28 | start_idx = np.where(f0 == start_f0)[0][0] 29 | end_idx = np.where(f0 == end_f0)[0][-1] 30 | f0[:start_idx] = start_f0 31 | f0[end_idx:] = end_f0 32 | 33 | # get non-zero frame index 34 | nz_frames = np.where(f0 != 0)[0] 35 | 36 | # perform linear interpolation 37 | f = interp1d(nz_frames, f0[nz_frames]) 38 | cont_f0 = f(np.arange(0, f0.shape[0])) 39 | 40 | return uv, cont_f0 41 | 42 | 43 | def get_cont_lf0(f0, frame_period=5.0): 44 | uv, cont_f0_lpf = convert_continuos_f0(f0) 45 | # cont_f0_lpf = low_pass_filter(cont_f0_lpf, int(1.0 / (frame_period * 0.001)), cutoff=20) 46 | cont_lf0_lpf = np.log(cont_f0_lpf) 47 | return uv, cont_lf0_lpf 48 | 49 | 50 | def get_lf0_cwt(lf0): 51 | ''' 52 | input: 53 | signal of shape (N) 54 | output: 55 | Wavelet_lf0 of shape(10, N), scales of shape(10) 56 | ''' 57 | mother = wavelet.MexicanHat() 58 | s0 = dt * 2 59 | J = 9 60 | 61 | Wavelet_lf0, scales, _, _, _, _ = wavelet.cwt(np.squeeze(lf0), dt, dj, s0, J, mother) 62 | # Wavelet.shape => (J + 1, len(lf0)) 63 | Wavelet_lf0 = np.real(Wavelet_lf0).T 64 | return Wavelet_lf0, scales 65 | 66 | 67 | def norm_scale(Wavelet_lf0): 68 | mean = Wavelet_lf0.mean(0)[None, :] 69 | std = Wavelet_lf0.std(0)[None, :] 70 | Wavelet_lf0_norm = (Wavelet_lf0 - mean) / std 71 | return Wavelet_lf0_norm, mean, std 72 | 73 | 74 | def normalize_cwt_lf0(f0, mean, std): 75 | uv, cont_lf0_lpf = get_cont_lf0(f0) 76 | cont_lf0_norm = (cont_lf0_lpf - mean) / std 77 | Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_norm) 78 | Wavelet_lf0_norm, _, _ = norm_scale(Wavelet_lf0) 79 | 80 | return Wavelet_lf0_norm 81 | 82 | 83 | def get_lf0_cwt_norm(f0s, mean, std): 84 | uvs = list() 85 | cont_lf0_lpfs = list() 86 | cont_lf0_lpf_norms = list() 87 | Wavelet_lf0s = list() 88 | Wavelet_lf0s_norm = list() 89 | scaless = list() 90 | 91 | means = list() 92 | stds = list() 93 | for f0 in f0s: 94 | uv, cont_lf0_lpf = get_cont_lf0(f0) 95 | cont_lf0_lpf_norm = (cont_lf0_lpf - mean) / std 96 | 97 | Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm) # [560,10] 98 | Wavelet_lf0_norm, mean_scale, std_scale = norm_scale(Wavelet_lf0) # [560,10],[1,10],[1,10] 99 | 100 | Wavelet_lf0s_norm.append(Wavelet_lf0_norm) 101 | uvs.append(uv) 102 | cont_lf0_lpfs.append(cont_lf0_lpf) 103 | cont_lf0_lpf_norms.append(cont_lf0_lpf_norm) 104 | Wavelet_lf0s.append(Wavelet_lf0) 105 | 
scaless.append(scales) 106 | means.append(mean_scale) 107 | stds.append(std_scale) 108 | 109 | return Wavelet_lf0s_norm, scaless, means, stds 110 | 111 | 112 | def inverse_cwt_torch(Wavelet_lf0, scales): 113 | import torch 114 | b = ((torch.arange(0, len(scales)).float().to(Wavelet_lf0.device)[None, None, :] + 1 + 2.5) ** (-2.5)) 115 | lf0_rec = Wavelet_lf0 * b 116 | lf0_rec_sum = lf0_rec.sum(-1) 117 | lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdim=True)) / lf0_rec_sum.std(-1, keepdim=True) 118 | return lf0_rec_sum 119 | 120 | 121 | def inverse_cwt(Wavelet_lf0, scales): 122 | # mother = wavelet.MexicanHat() 123 | # lf0_rec_sum = wavelet.icwt(Wavelet_lf0[0].T, scales, dt, dj, mother) 124 | b = ((np.arange(0, len(scales))[None, None, :] + 1 + 2.5) ** (-2.5)) 125 | lf0_rec = Wavelet_lf0 * b 126 | lf0_rec_sum = lf0_rec.sum(-1) 127 | # lf0_rec_sum = lf0_rec_sum[None, ...] 128 | lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdims=True)) / lf0_rec_sum.std(-1, keepdims=True) 129 | return lf0_rec_sum 130 | 131 | 132 | def cwt2f0(cwt_spec, mean, std, cwt_scales): 133 | assert len(mean.shape) == 1 and len(std.shape) == 1 and len(cwt_spec.shape) == 3 134 | import torch 135 | if isinstance(cwt_spec, torch.Tensor): 136 | f0 = inverse_cwt_torch(cwt_spec, cwt_scales) 137 | f0 = f0 * std[:, None] + mean[:, None] 138 | f0 = f0.exp() # [B, T] 139 | else: 140 | f0 = inverse_cwt(cwt_spec, cwt_scales) 141 | f0 = f0 * std[:, None] + mean[:, None] 142 | f0 = np.exp(f0) # [B, T] 143 | return f0 144 | -------------------------------------------------------------------------------- /utils/audio/griffin_lim.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def _stft(y, hop_size, win_size, fft_size): 8 | return librosa.stft(y=y, n_fft=fft_size, hop_length=hop_size, win_length=win_size, pad_mode='constant') 9 | 10 | 11 | def _istft(y, hop_size, win_size): 12 | return librosa.istft(y, hop_length=hop_size, win_length=win_size) 13 | 14 | 15 | def griffin_lim(S, hop_size, win_size, fft_size, angles=None, n_iters=30): 16 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) if angles is None else angles 17 | S_complex = np.abs(S).astype(np.complex) 18 | y = _istft(S_complex * angles, hop_size, win_size) 19 | for i in range(n_iters): 20 | angles = np.exp(1j * np.angle(_stft(y, hop_size, win_size, fft_size))) 21 | y = _istft(S_complex * angles, hop_size, win_size) 22 | return y 23 | 24 | 25 | def istft(amp, ang, hop_size, win_size, fft_size, pad=False, window=None): 26 | spec = amp * torch.exp(1j * ang) 27 | spec_r = spec.real 28 | spec_i = spec.imag 29 | spec = torch.stack([spec_r, spec_i], -1) 30 | if window is None: 31 | window = torch.hann_window(win_size).to(amp.device) 32 | if pad: 33 | spec = F.pad(spec, [0, 0, 0, 1], mode='reflect') 34 | wav = torch.istft(spec, fft_size, hop_size, win_size) 35 | return wav 36 | 37 | 38 | def griffin_lim_torch(S, hop_size, win_size, fft_size, angles=None, n_iters=30): 39 | """ 40 | 41 | Examples: 42 | >>> x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, win_length=win_length, pad_mode="constant") 43 | >>> x_stft = x_stft[None, ...] 
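>>> # the batch axis added above matches the expected amp shape [B, n_fft, T]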
44 | >>> amp = np.abs(x_stft) 45 | >>> angle_init = np.exp(2j * np.pi * np.random.rand(*x_stft.shape)) 46 | >>> amp = torch.FloatTensor(amp) 47 | >>> wav = griffin_lim_torch(amp, angle_init, hparams) 48 | 49 | :param amp: [B, n_fft, T] 50 | :param ang: [B, n_fft, T] 51 | :return: [B, T_wav] 52 | """ 53 | angles = torch.exp(2j * np.pi * torch.rand(*S.shape)) if angles is None else angles 54 | window = torch.hann_window(win_size).to(S.device) 55 | y = istft(S, angles, hop_size, win_size, fft_size, window=window) 56 | for i in range(n_iters): 57 | x_stft = torch.stft(y, fft_size, hop_size, win_size, window) 58 | x_stft = x_stft[..., 0] + 1j * x_stft[..., 1] 59 | angles = torch.angle(x_stft) 60 | y = istft(S, angles, hop_size, win_size, fft_size, window=window) 61 | return y 62 | 63 | 64 | # Conversions 65 | _mel_basis = None 66 | _inv_mel_basis = None 67 | 68 | 69 | def _build_mel_basis(audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax): 70 | assert fmax <= audio_sample_rate // 2 71 | return librosa.filters.mel(audio_sample_rate, fft_size, n_mels=audio_num_mel_bins, fmin=fmin, fmax=fmax) 72 | 73 | 74 | def _linear_to_mel(spectogram, audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax): 75 | global _mel_basis 76 | if _mel_basis is None: 77 | _mel_basis = _build_mel_basis(audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax) 78 | return np.dot(_mel_basis, spectogram) 79 | 80 | 81 | def _mel_to_linear(mel_spectrogram, audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax): 82 | global _inv_mel_basis 83 | if _inv_mel_basis is None: 84 | _inv_mel_basis = np.linalg.pinv(_build_mel_basis(audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax)) 85 | return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) 86 | -------------------------------------------------------------------------------- /utils/audio/io.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import numpy as np 4 | from scipy.io import wavfile 5 | 6 | 7 | def save_wav(wav, path, sr, norm=False): 8 | if norm: 9 | wav = wav / np.abs(wav).max() 10 | wav = wav * 32767 11 | wavfile.write(path[:-4] + '.wav', sr, wav.astype(np.int16)) 12 | if path[-4:] == '.mp3': 13 | to_mp3(path[:-4]) 14 | 15 | 16 | def to_mp3(out_path): 17 | if out_path[-4:] == '.wav': 18 | out_path = out_path[:-4] 19 | subprocess.check_call( 20 | f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -b:a 192k -y -hide_banner -async 1 "{out_path}.mp3"', 21 | shell=True, stdin=subprocess.PIPE) 22 | subprocess.check_call(f'rm -f "{out_path}.wav"', shell=True) 23 | -------------------------------------------------------------------------------- /utils/audio/pitch/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hs-oh-prml/DiffProsody/6d5b6dbb58497fdff791d06fca09a4fae2a2cc11/utils/audio/pitch/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/audio/pitch/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def to_lf0(f0): 6 | f0[f0 < 1.0e-5] = 1.0e-6 7 | lf0 = f0.log() if isinstance(f0, torch.Tensor) else np.log(f0) 8 | lf0[f0 < 1.0e-5] = - 1.0E+10 9 | return lf0 10 | 11 | 12 | def to_f0(lf0): 13 | f0 = np.where(lf0 <= 0, 0.0, np.exp(lf0)) 14 | return f0.flatten() 15 | 16 | 17 | def f0_to_coarse(f0, f0_bin=256, 
f0_max=900.0, f0_min=50.0): 18 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 19 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 20 | is_torch = isinstance(f0, torch.Tensor) 21 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 22 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 23 | 24 | f0_mel[f0_mel <= 1] = 1 25 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 26 | f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int) 27 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min(), f0.min(), f0.max()) 28 | return f0_coarse 29 | 30 | 31 | def coarse_to_f0(f0_coarse, f0_bin=256, f0_max=900.0, f0_min=50.0): 32 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 33 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 34 | uv = f0_coarse == 1 35 | f0 = f0_mel_min + (f0_coarse - 1) * (f0_mel_max - f0_mel_min) / (f0_bin - 2) 36 | f0 = ((f0 / 1127).exp() - 1) * 700 37 | f0[uv] = 0 38 | return f0 39 | 40 | 41 | def norm_f0(f0, uv, pitch_norm='log', f0_mean=400, f0_std=100): 42 | is_torch = isinstance(f0, torch.Tensor) 43 | if pitch_norm == 'standard': 44 | f0 = (f0 - f0_mean) / f0_std 45 | if pitch_norm == 'log': 46 | f0 = torch.log2(f0 + 1e-8) if is_torch else np.log2(f0 + 1e-8) 47 | if uv is not None: 48 | f0[uv > 0] = 0 49 | return f0 50 | 51 | 52 | def norm_interp_f0(f0, pitch_norm='log', f0_mean=None, f0_std=None): 53 | is_torch = isinstance(f0, torch.Tensor) 54 | if is_torch: 55 | device = f0.device 56 | f0 = f0.data.cpu().numpy() 57 | uv = f0 == 0 58 | f0 = norm_f0(f0, uv, pitch_norm, f0_mean, f0_std) 59 | if sum(uv) == len(f0): 60 | f0[uv] = 0 61 | elif sum(uv) > 0: 62 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 63 | if is_torch: 64 | uv = torch.FloatTensor(uv) 65 | f0 = torch.FloatTensor(f0) 66 | f0 = f0.to(device) 67 | uv = uv.to(device) 68 | return f0, uv 69 | 70 | 71 | def denorm_f0(f0, uv, pitch_norm='log', f0_mean=400, f0_std=100, pitch_padding=None, min=80, max=800): 72 | is_torch = isinstance(f0, torch.Tensor) 73 | if pitch_norm == 'standard': 74 | f0 = f0 * f0_std + f0_mean 75 | if pitch_norm == 'log': 76 | f0 = 2 ** f0 77 | f0 = f0.clamp(min=min, max=max) if is_torch else np.clip(f0, a_min=min, a_max=max) 78 | if uv is not None: 79 | f0[uv > 0] = 0 80 | if pitch_padding is not None: 81 | f0[pitch_padding] = 0 82 | return f0 83 | -------------------------------------------------------------------------------- /utils/audio/pitch_extractors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | PITCH_EXTRACTOR = {} 4 | 5 | 6 | def register_pitch_extractor(name): 7 | def register_pitch_extractor_(cls): 8 | PITCH_EXTRACTOR[name] = cls 9 | return cls 10 | 11 | return register_pitch_extractor_ 12 | 13 | 14 | def get_pitch_extractor(name): 15 | return PITCH_EXTRACTOR[name] 16 | 17 | 18 | def extract_pitch_simple(wav): 19 | from utils.commons.hparams import hparams 20 | return extract_pitch(hparams['pitch_extractor'], wav, 21 | hparams['hop_size'], hparams['audio_sample_rate'], 22 | f0_min=hparams['f0_min'], f0_max=hparams['f0_max']) 23 | 24 | 25 | def extract_pitch(extractor_name, wav_data, hop_size, audio_sample_rate, f0_min=75, f0_max=800, **kwargs): 26 | return get_pitch_extractor(extractor_name)(wav_data, hop_size, audio_sample_rate, f0_min, f0_max, **kwargs) 27 | 28 | 29 | @register_pitch_extractor('parselmouth') 30 | def parselmouth_pitch(wav_data, hop_size, 
audio_sample_rate, f0_min, f0_max, 31 | voicing_threshold=0.6, *args, **kwargs): 32 | import parselmouth 33 | time_step = hop_size / audio_sample_rate * 1000 34 | n_mel_frames = int(len(wav_data) // hop_size) 35 | f0_pm = parselmouth.Sound(wav_data, audio_sample_rate).to_pitch_ac( 36 | time_step=time_step / 1000, voicing_threshold=voicing_threshold, 37 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] 38 | pad_size = (n_mel_frames - len(f0_pm) + 1) // 2 39 | f0 = np.pad(f0_pm, [[pad_size, n_mel_frames - len(f0_pm) - pad_size]], mode='constant') 40 | return f0 41 | -------------------------------------------------------------------------------- /utils/audio/rnnoise.py: -------------------------------------------------------------------------------- 1 | # rnnoise.py, requirements: ffmpeg, sox, rnnoise, python 2 | import os 3 | import subprocess 4 | 5 | INSTALL_STR = """ 6 | RNNoise library not found. Please install RNNoise (https://github.com/xiph/rnnoise) to $REPO/rnnoise: 7 | sudo apt-get install -y autoconf automake libtool ffmpeg sox 8 | git clone https://github.com/xiph/rnnoise.git 9 | rm -rf rnnoise/.git 10 | cd rnnoise 11 | ./autogen.sh && ./configure && make 12 | cd .. 13 | """ 14 | 15 | 16 | def rnnoise(filename, out_fn=None, verbose=False, out_sample_rate=22050): 17 | assert os.path.exists('./rnnoise/examples/rnnoise_demo'), INSTALL_STR 18 | if out_fn is None: 19 | out_fn = f"{filename[:-4]}.denoised.wav" 20 | out_48k_fn = f"{out_fn}.48000.wav" 21 | tmp0_fn = f"{out_fn}.0.wav" 22 | tmp1_fn = f"{out_fn}.1.wav" 23 | tmp2_fn = f"{out_fn}.2.raw" 24 | tmp3_fn = f"{out_fn}.3.raw" 25 | if verbose: 26 | print("Pre-processing audio...") # wav to pcm raw 27 | subprocess.check_call( 28 | f'sox "{filename}" -G -r48000 "{tmp0_fn}"', shell=True, stdin=subprocess.PIPE) # convert to raw 29 | subprocess.check_call( 30 | f'sox -v 0.95 "{tmp0_fn}" "{tmp1_fn}"', shell=True, stdin=subprocess.PIPE) # convert to raw 31 | subprocess.check_call( 32 | f'ffmpeg -y -i "{tmp1_fn}" -loglevel quiet -f s16le -ac 1 -ar 48000 "{tmp2_fn}"', 33 | shell=True, stdin=subprocess.PIPE) # convert to raw 34 | if verbose: 35 | print("Applying rnnoise algorithm to audio...") # rnnoise 36 | subprocess.check_call( 37 | f'./rnnoise/examples/rnnoise_demo "{tmp2_fn}" "{tmp3_fn}"', shell=True) 38 | 39 | if verbose: 40 | print("Post-processing audio...") # pcm raw to wav 41 | if filename == out_fn: 42 | subprocess.check_call(f'rm -f "{out_fn}"', shell=True) 43 | subprocess.check_call( 44 | f'sox -t raw -r 48000 -b 16 -e signed-integer -c 1 "{tmp3_fn}" "{out_48k_fn}"', shell=True) 45 | subprocess.check_call(f'sox "{out_48k_fn}" -G -r{out_sample_rate} "{out_fn}"', shell=True) 46 | subprocess.check_call(f'rm -f "{tmp0_fn}" "{tmp1_fn}" "{tmp2_fn}" "{tmp3_fn}" "{out_48k_fn}"', shell=True) 47 | if verbose: 48 | print("Audio-filtering completed!") 49 | -------------------------------------------------------------------------------- /utils/audio/vad.py: -------------------------------------------------------------------------------- 1 | from skimage.transform import resize 2 | import struct 3 | import webrtcvad 4 | from scipy.ndimage.morphology import binary_dilation 5 | import librosa 6 | import numpy as np 7 | import pyloudnorm as pyln 8 | import warnings 9 | 10 | warnings.filterwarnings("ignore", message="Possible clipped samples in output") 11 | 12 | int16_max = (2 ** 15) - 1 13 | 14 | 15 | def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12): 16 | """ 17 | Ensures that 
segments without voice in the waveform remain no longer than a 18 | threshold determined by the VAD parameters in params.py. 19 | :param wav: the raw waveform as a numpy array of floats 20 | :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have. 21 | :return: the same waveform with silences trimmed away (length <= original wav length) 22 | """ 23 | 24 | ## Voice Activation Detection 25 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 26 | # This sets the granularity of the VAD. Should not need to be changed. 27 | sampling_rate = 16000 28 | wav_raw, sr = librosa.core.load(path, sr=sr) 29 | 30 | if norm: 31 | meter = pyln.Meter(sr) # create BS.1770 meter 32 | loudness = meter.integrated_loudness(wav_raw) 33 | wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0) 34 | if np.abs(wav_raw).max() > 1.0: 35 | wav_raw = wav_raw / np.abs(wav_raw).max() 36 | 37 | wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best') 38 | 39 | vad_window_length = 30 # In milliseconds 40 | # Number of frames to average together when performing the moving average smoothing. 41 | # The larger this value, the larger the VAD variations must be to not get smoothed out. 42 | vad_moving_average_width = 8 43 | 44 | # Compute the voice detection window size 45 | samples_per_window = (vad_window_length * sampling_rate) // 1000 46 | 47 | # Trim the end of the audio to have a multiple of the window size 48 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 49 | 50 | # Convert the float waveform to 16-bit mono PCM 51 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) 52 | 53 | # Perform voice activation detection 54 | voice_flags = [] 55 | vad = webrtcvad.Vad(mode=3) 56 | for window_start in range(0, len(wav), samples_per_window): 57 | window_end = window_start + samples_per_window 58 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 59 | sample_rate=sampling_rate)) 60 | voice_flags = np.array(voice_flags) 61 | 62 | # Smooth the voice detection with a moving average 63 | def moving_average(array, width): 64 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 65 | ret = np.cumsum(array_padded, dtype=float) 66 | ret[width:] = ret[width:] - ret[:-width] 67 | return ret[width - 1:] / width 68 | 69 | audio_mask = moving_average(voice_flags, vad_moving_average_width) 70 | audio_mask = np.round(audio_mask).astype(np.bool) 71 | 72 | # Dilate the voiced regions 73 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) 74 | audio_mask = np.repeat(audio_mask, samples_per_window) 75 | audio_mask = resize(audio_mask, (len(wav_raw),)) > 0 76 | if return_raw_wav: 77 | return wav_raw, audio_mask, sr 78 | return wav_raw[audio_mask], audio_mask, sr 79 | -------------------------------------------------------------------------------- /utils/commons/ckpt_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | import torch 5 | 6 | 7 | def get_last_checkpoint(work_dir, steps=None): 8 | checkpoint = None 9 | last_ckpt_path = None 10 | ckpt_paths = get_all_ckpts(work_dir, steps) 11 | if len(ckpt_paths) > 0: 12 | last_ckpt_path = ckpt_paths[0] 13 | checkpoint = torch.load(last_ckpt_path, map_location='cpu') 14 | return checkpoint, last_ckpt_path 15 | 16 | 17 | def get_all_ckpts(work_dir, steps=None): 18 | if steps is None: 19 | ckpt_path_pattern = 
f'{work_dir}/model_ckpt_steps_*.ckpt' 20 | else: 21 | ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt' 22 | return sorted(glob.glob(ckpt_path_pattern), 23 | key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) 24 | 25 | 26 | def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True): 27 | if os.path.isfile(ckpt_base_dir): 28 | base_dir = os.path.dirname(ckpt_base_dir) 29 | ckpt_path = ckpt_base_dir 30 | checkpoint = torch.load(ckpt_base_dir, map_location='cpu') 31 | else: 32 | base_dir = ckpt_base_dir 33 | checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir) 34 | if checkpoint is not None: 35 | state_dict = checkpoint["state_dict"] 36 | if len([k for k in state_dict.keys() if '.' in k]) > 0: 37 | state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items() 38 | if k.startswith(f'{model_name}.')} 39 | else: 40 | if '.' not in model_name: 41 | state_dict = state_dict[model_name] 42 | else: 43 | base_model_name = model_name.split('.')[0] 44 | rest_model_name = model_name[len(base_model_name) + 1:] 45 | state_dict = { 46 | k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items() 47 | if k.startswith(f'{rest_model_name}.')} 48 | if not strict: 49 | cur_model_state_dict = cur_model.state_dict() 50 | unmatched_keys = [] 51 | for key, param in state_dict.items(): 52 | if key in cur_model_state_dict: 53 | new_param = cur_model_state_dict[key] 54 | if new_param.shape != param.shape: 55 | unmatched_keys.append(key) 56 | print("| Unmatched keys: ", key, new_param.shape, param.shape) 57 | for key in unmatched_keys: 58 | del state_dict[key] 59 | cur_model.load_state_dict(state_dict, strict=strict) #, map_location="cuda:0" 60 | print(f"| load '{model_name}' from '{ckpt_path}'.") 61 | else: 62 | e_msg = f"| ckpt not found in {base_dir}." 
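# with force=True a missing checkpoint aborts loading; otherwise only the warning below is printed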
63 | if force: 64 | assert False, e_msg 65 | else: 66 | print(e_msg) 67 | -------------------------------------------------------------------------------- /utils/commons/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | import types 5 | from functools import wraps 6 | from itertools import chain 7 | import numpy as np 8 | import torch.utils.data 9 | from torch.utils.data import ConcatDataset 10 | from utils.commons.hparams import hparams 11 | 12 | 13 | def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): 14 | if len(values[0].shape) == 1: 15 | return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id) 16 | else: 17 | return collate_2d(values, pad_idx, left_pad, shift_right, max_len) 18 | 19 | 20 | def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): 21 | """Convert a list of 1d tensors into a padded 2d tensor.""" 22 | size = max(v.size(0) for v in values) if max_len is None else max_len 23 | res = values[0].new(len(values), size).fill_(pad_idx) 24 | 25 | def copy_tensor(src, dst): 26 | assert dst.numel() == src.numel() 27 | if shift_right: 28 | dst[1:] = src[:-1] 29 | dst[0] = shift_id 30 | else: 31 | dst.copy_(src) 32 | 33 | for i, v in enumerate(values): 34 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 35 | return res 36 | 37 | 38 | def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None): 39 | """Convert a list of 2d tensors into a padded 3d tensor.""" 40 | size = max(v.size(0) for v in values) if max_len is None else max_len 41 | res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx) 42 | 43 | def copy_tensor(src, dst): 44 | assert dst.numel() == src.numel() 45 | if shift_right: 46 | dst[1:] = src[:-1] 47 | else: 48 | dst.copy_(src) 49 | 50 | for i, v in enumerate(values): 51 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 52 | return res 53 | 54 | 55 | def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): 56 | if len(batch) == 0: 57 | return 0 58 | if len(batch) == max_sentences: 59 | return 1 60 | if num_tokens > max_tokens: 61 | return 1 62 | return 0 63 | 64 | 65 | def batch_by_size( 66 | indices, num_tokens_fn, max_tokens=None, max_sentences=None, 67 | required_batch_size_multiple=1, distributed=False 68 | ): 69 | """ 70 | Yield mini-batches of indices bucketed by size. Batches may contain 71 | sequences of different lengths. 72 | 73 | Args: 74 | indices (List[int]): ordered list of dataset indices 75 | num_tokens_fn (callable): function that returns the number of tokens at 76 | a given index 77 | max_tokens (int, optional): max number of tokens in each batch 78 | (default: None). 79 | max_sentences (int, optional): max number of sentences in each 80 | batch (default: None). 81 | required_batch_size_multiple (int, optional): require batch size to 82 | be a multiple of N (default: 1). 
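distributed (bool, optional): appears unused in this implementation
    (default: False).

Returns:
    List[List[int]]: the generated batches, each a list of dataset indices.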
83 | """ 84 | max_tokens = max_tokens if max_tokens is not None else sys.maxsize 85 | max_sentences = max_sentences if max_sentences is not None else sys.maxsize 86 | bsz_mult = required_batch_size_multiple 87 | 88 | if isinstance(indices, types.GeneratorType): 89 | indices = np.fromiter(indices, dtype=np.int64, count=-1) 90 | 91 | sample_len = 0 92 | sample_lens = [] 93 | batch = [] 94 | batches = [] 95 | for i in range(len(indices)): 96 | idx = indices[i] 97 | num_tokens = num_tokens_fn(idx) 98 | sample_lens.append(num_tokens) 99 | sample_len = max(sample_len, num_tokens) 100 | 101 | assert sample_len <= max_tokens, ( 102 | "sentence at index {} of size {} exceeds max_tokens " 103 | "limit of {}!".format(idx, sample_len, max_tokens) 104 | ) 105 | num_tokens = (len(batch) + 1) * sample_len 106 | 107 | if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): 108 | mod_len = max( 109 | bsz_mult * (len(batch) // bsz_mult), 110 | len(batch) % bsz_mult, 111 | ) 112 | batches.append(batch[:mod_len]) 113 | batch = batch[mod_len:] 114 | sample_lens = sample_lens[mod_len:] 115 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 116 | batch.append(idx) 117 | if len(batch) > 0: 118 | batches.append(batch) 119 | return batches 120 | 121 | 122 | def unpack_dict_to_list(samples): 123 | samples_ = [] 124 | bsz = samples.get('outputs').size(0) 125 | for i in range(bsz): 126 | res = {} 127 | for k, v in samples.items(): 128 | try: 129 | res[k] = v[i] 130 | except: 131 | pass 132 | samples_.append(res) 133 | return samples_ 134 | 135 | 136 | def remove_padding(x, padding_idx=0): 137 | if x is None: 138 | return None 139 | assert len(x.shape) in [1, 2] 140 | if len(x.shape) == 2: # [T, H] 141 | return x[np.abs(x).sum(-1) != padding_idx] 142 | elif len(x.shape) == 1: # [T] 143 | return x[x != padding_idx] 144 | 145 | 146 | def data_loader(fn): 147 | """ 148 | Decorator to make any fx with this use the lazy property 149 | :param fn: 150 | :return: 151 | """ 152 | 153 | wraps(fn) 154 | attr_name = '_lazy_' + fn.__name__ 155 | 156 | def _get_data_loader(self): 157 | try: 158 | value = getattr(self, attr_name) 159 | except AttributeError: 160 | try: 161 | value = fn(self) # Lazy evaluation, done only once. 162 | except AttributeError as e: 163 | # Guard against AttributeError suppression. (Issue #142) 164 | traceback.print_exc() 165 | error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e) 166 | raise RuntimeError(error) from e 167 | setattr(self, attr_name, value) # Memoize evaluation. 168 | return value 169 | 170 | return _get_data_loader 171 | 172 | 173 | class BaseDataset(torch.utils.data.Dataset): 174 | def __init__(self, shuffle): 175 | super().__init__() 176 | self.hparams = hparams 177 | self.shuffle = shuffle 178 | self.sort_by_len = hparams['sort_by_len'] 179 | self.sizes = None 180 | 181 | @property 182 | def _sizes(self): 183 | return self.sizes 184 | 185 | def __getitem__(self, index): 186 | raise NotImplementedError 187 | 188 | def collater(self, samples): 189 | raise NotImplementedError 190 | 191 | def __len__(self): 192 | return len(self._sizes) 193 | 194 | def num_tokens(self, index): 195 | return self.size(index) 196 | 197 | def size(self, index): 198 | """Return an example's size as a float or tuple. This value is used when 199 | filtering a dataset with ``--max-positions``.""" 200 | return min(self._sizes[index], hparams['max_frames']) 201 | 202 | def ordered_indices(self): 203 | """Return an ordered list of indices. 
Batches will be constructed based 204 | on this order.""" 205 | if self.shuffle: 206 | indices = np.random.permutation(len(self)) 207 | if self.sort_by_len: 208 | indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')] 209 | else: 210 | indices = np.arange(len(self)) 211 | return indices 212 | 213 | @property 214 | def num_workers(self): 215 | return int(os.getenv('NUM_WORKERS', hparams['ds_workers'])) 216 | 217 | 218 | class BaseConcatDataset(ConcatDataset): 219 | def collater(self, samples): 220 | return self.datasets[0].collater(samples) 221 | 222 | @property 223 | def _sizes(self): 224 | if not hasattr(self, 'sizes'): 225 | self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets])) 226 | return self.sizes 227 | 228 | def size(self, index): 229 | return min(self._sizes[index], hparams['max_frames']) 230 | 231 | def num_tokens(self, index): 232 | return self.size(index) 233 | 234 | def ordered_indices(self): 235 | """Return an ordered list of indices. Batches will be constructed based 236 | on this order.""" 237 | if self.datasets[0].shuffle: 238 | indices = np.random.permutation(len(self)) 239 | if self.datasets[0].sort_by_len: 240 | indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')] 241 | else: 242 | indices = np.arange(len(self)) 243 | return indices 244 | 245 | @property 246 | def num_workers(self): 247 | return self.datasets[0].num_workers 248 | -------------------------------------------------------------------------------- /utils/commons/ddp_utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parallel import DistributedDataParallel 2 | from torch.nn.parallel.distributed import _find_tensors 3 | import torch.optim 4 | import torch.utils.data 5 | import torch 6 | from packaging import version 7 | 8 | class DDP(DistributedDataParallel): 9 | """ 10 | Override the forward call in lightning so it goes to training and validation step respectively 11 | """ 12 | 13 | def forward(self, *inputs, **kwargs): # pragma: no cover 14 | if version.parse(torch.__version__[:6]) < version.parse("1.11"): 15 | self._sync_params() 16 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 17 | assert len(self.device_ids) == 1 18 | if self.module.training: 19 | output = self.module.training_step(*inputs[0], **kwargs[0]) 20 | elif self.module.testing: 21 | output = self.module.test_step(*inputs[0], **kwargs[0]) 22 | else: 23 | output = self.module.validation_step(*inputs[0], **kwargs[0]) 24 | if torch.is_grad_enabled(): 25 | # We'll return the output object verbatim since it is a freeform 26 | # object. We need to find any tensors in this object, though, 27 | # because we need to figure out which parameters were used during 28 | # this forward pass, to ensure we short circuit reduction for any 29 | # unused parameters. Only if `find_unused_parameters` is set. 
30 | if self.find_unused_parameters: 31 | self.reducer.prepare_for_backward(list(_find_tensors(output))) 32 | else: 33 | self.reducer.prepare_for_backward([]) 34 | else: 35 | from torch.nn.parallel.distributed import \ 36 | logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref 37 | with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): 38 | if torch.is_grad_enabled() and self.require_backward_grad_sync: 39 | self.logger.set_runtime_stats_and_log() 40 | self.num_iterations += 1 41 | self.reducer.prepare_for_forward() 42 | 43 | # Notify the join context that this process has not joined, if 44 | # needed 45 | work = Join.notify_join_context(self) 46 | if work: 47 | self.reducer._set_forward_pass_work_handle( 48 | work, self._divide_by_initial_world_size 49 | ) 50 | 51 | # Calling _rebuild_buckets before forward compuation, 52 | # It may allocate new buckets before deallocating old buckets 53 | # inside _rebuild_buckets. To save peak memory usage, 54 | # call _rebuild_buckets before the peak memory usage increases 55 | # during forward computation. 56 | # This should be called only once during whole training period. 57 | if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): 58 | logging.info("Reducer buckets have been rebuilt in this iteration.") 59 | self._has_rebuilt_buckets = True 60 | 61 | # sync params according to location (before/after forward) user 62 | # specified as part of hook, if hook was specified. 63 | buffer_hook_registered = hasattr(self, 'buffer_hook') 64 | if self._check_sync_bufs_pre_fwd(): 65 | self._sync_buffers() 66 | 67 | if self._join_config.enable: 68 | # Notify joined ranks whether they should sync in backwards pass or not. 69 | self._check_global_requires_backward_grad_sync(is_joined_rank=False) 70 | 71 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 72 | if self.module.training: 73 | output = self.module.training_step(*inputs[0], **kwargs[0]) 74 | elif self.module.testing: 75 | output = self.module.test_step(*inputs[0], **kwargs[0]) 76 | else: 77 | output = self.module.validation_step(*inputs[0], **kwargs[0]) 78 | 79 | # sync params according to location (before/after forward) user 80 | # specified as part of hook, if hook was specified. 81 | if self._check_sync_bufs_post_fwd(): 82 | self._sync_buffers() 83 | 84 | if torch.is_grad_enabled() and self.require_backward_grad_sync: 85 | self.require_forward_param_sync = True 86 | # We'll return the output object verbatim since it is a freeform 87 | # object. We need to find any tensors in this object, though, 88 | # because we need to figure out which parameters were used during 89 | # this forward pass, to ensure we short circuit reduction for any 90 | # unused parameters. Only if `find_unused_parameters` is set. 91 | if self.find_unused_parameters and not self.static_graph: 92 | # Do not need to populate this for static graph. 93 | self.reducer.prepare_for_backward(list(_find_tensors(output))) 94 | else: 95 | self.reducer.prepare_for_backward([]) 96 | else: 97 | self.require_forward_param_sync = False 98 | 99 | # TODO: DDPSink is currently enabled for unused parameter detection and 100 | # static graph training for first iteration. 
101 | if (self.find_unused_parameters and not self.static_graph) or ( 102 | self.static_graph and self.num_iterations == 1 103 | ): 104 | state_dict = { 105 | 'static_graph': self.static_graph, 106 | 'num_iterations': self.num_iterations, 107 | } 108 | 109 | output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref( 110 | output 111 | ) 112 | output_placeholders = [None for _ in range(len(output_tensor_list))] 113 | # Do not touch tensors that have no grad_fn, which can cause issues 114 | # such as https://github.com/pytorch/pytorch/issues/60733 115 | for i, output in enumerate(output_tensor_list): 116 | if torch.is_tensor(output) and output.grad_fn is None: 117 | output_placeholders[i] = output 118 | 119 | # When find_unused_parameters=True, makes tensors which require grad 120 | # run through the DDPSink backward pass. When not all outputs are 121 | # used in loss, this makes those corresponding tensors receive 122 | # undefined gradient which the reducer then handles to ensure 123 | # param.grad field is not touched and we don't error out. 124 | passthrough_tensor_list = _DDPSink.apply( 125 | self.reducer, 126 | state_dict, 127 | *output_tensor_list, 128 | ) 129 | for i in range(len(output_placeholders)): 130 | if output_placeholders[i] is None: 131 | output_placeholders[i] = passthrough_tensor_list[i] 132 | 133 | # Reconstruct output data structure. 134 | output = _tree_unflatten_with_rref( 135 | output_placeholders, treespec, output_is_rref 136 | ) 137 | return output 138 | -------------------------------------------------------------------------------- /utils/commons/hparams.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import yaml 4 | 5 | from utils.os_utils import remove_file 6 | 7 | global_print_hparams = True 8 | hparams = {} 9 | 10 | 11 | class Args: 12 | def __init__(self, **kwargs): 13 | for k, v in kwargs.items(): 14 | self.__setattr__(k, v) 15 | 16 | 17 | def override_config(old_config: dict, new_config: dict): 18 | for k, v in new_config.items(): 19 | if isinstance(v, dict) and k in old_config: 20 | override_config(old_config[k], new_config[k]) 21 | else: 22 | old_config[k] = v 23 | 24 | 25 | def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True): 26 | if config == '' and exp_name == '': 27 | parser = argparse.ArgumentParser(description='') 28 | parser.add_argument('--config', type=str, default='', 29 | help='location of the data corpus') 30 | parser.add_argument('--exp_name', type=str, default='', help='exp_name') 31 | parser.add_argument('-hp', '--hparams', type=str, default='', 32 | help='location of the data corpus') 33 | parser.add_argument('--infer', action='store_true', help='infer') 34 | parser.add_argument('--validate', action='store_true', help='validate') 35 | parser.add_argument('--reset', action='store_true', help='reset hparams') 36 | parser.add_argument('--remove', action='store_true', help='remove old ckpt') 37 | parser.add_argument('--debug', action='store_true', help='debug') 38 | args, unknown = parser.parse_known_args() 39 | print("| Unknow hparams: ", unknown) 40 | else: 41 | args = Args(config=config, exp_name=exp_name, hparams=hparams_str, 42 | infer=False, validate=False, reset=False, debug=False, remove=False) 43 | global hparams 44 | assert args.config != '' or args.exp_name != '' 45 | if args.config != '': 46 | print("#########: ", args.config) 47 | 48 | assert os.path.exists(args.config) 49 | 50 | config_chains = [] 
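# config_chains records the order in which configs are applied; the loaded_config set below
# keeps each base_config file from being loaded twice during the depth-first inheritance walk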
51 | loaded_config = set() 52 | 53 | def load_config(config_fn): 54 | # deep first inheritance and avoid the second visit of one node 55 | if not os.path.exists(config_fn): 56 | return {} 57 | with open(config_fn) as f: 58 | hparams_ = yaml.safe_load(f) 59 | loaded_config.add(config_fn) 60 | if 'base_config' in hparams_: 61 | ret_hparams = {} 62 | if not isinstance(hparams_['base_config'], list): 63 | hparams_['base_config'] = [hparams_['base_config']] 64 | for c in hparams_['base_config']: 65 | if c.startswith('.'): 66 | c = f'{os.path.dirname(config_fn)}/{c}' 67 | c = os.path.normpath(c) 68 | if c not in loaded_config: 69 | override_config(ret_hparams, load_config(c)) 70 | override_config(ret_hparams, hparams_) 71 | else: 72 | ret_hparams = hparams_ 73 | config_chains.append(config_fn) 74 | return ret_hparams 75 | 76 | saved_hparams = {} 77 | args_work_dir = '' 78 | if args.exp_name != '': 79 | args_work_dir = f'/workspace/checkpoints/{args.exp_name}' 80 | ckpt_config_path = f'{args_work_dir}/config.yaml' 81 | if os.path.exists(ckpt_config_path): 82 | with open(ckpt_config_path) as f: 83 | saved_hparams_ = yaml.safe_load(f) 84 | if saved_hparams_ is not None: 85 | saved_hparams.update(saved_hparams_) 86 | hparams_ = {} 87 | if args.config != '': 88 | hparams_.update(load_config(args.config)) 89 | if not args.reset: 90 | hparams_.update(saved_hparams) 91 | hparams_['work_dir'] = args_work_dir 92 | 93 | # Support config overriding in command line. Support list type config overriding. 94 | # Examples: --hparams="a=1,b.c=2,d=[1 1 1]" 95 | if args.hparams != "": 96 | for new_hparam in args.hparams.split(","): 97 | k, v = new_hparam.split("=") 98 | v = v.strip("\'\" ") 99 | config_node = hparams_ 100 | for k_ in k.split(".")[:-1]: 101 | config_node = config_node[k_] 102 | k = k.split(".")[-1] 103 | if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]: 104 | if type(config_node[k]) == list: 105 | v = v.replace(" ", ",") 106 | config_node[k] = eval(v) 107 | else: 108 | config_node[k] = type(config_node[k])(v) 109 | if args_work_dir != '' and args.remove: 110 | answer = input("REMOVE old checkpoint? 
Y/N [Default: N]: ") 111 | if answer.lower() == "y": 112 | remove_file(args_work_dir) 113 | if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer: 114 | os.makedirs(hparams_['work_dir'], exist_ok=True) 115 | with open(ckpt_config_path, 'w') as f: 116 | yaml.safe_dump(hparams_, f) 117 | 118 | hparams_['infer'] = args.infer 119 | hparams_['debug'] = args.debug 120 | hparams_['validate'] = args.validate 121 | hparams_['exp_name'] = args.exp_name 122 | global global_print_hparams 123 | if global_hparams: 124 | hparams.clear() 125 | hparams.update(hparams_) 126 | if print_hparams and global_print_hparams and global_hparams: 127 | print('| Hparams chains: ', config_chains) 128 | print('| Hparams: ') 129 | for i, (k, v) in enumerate(sorted(hparams_.items())): 130 | print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "") 131 | print("") 132 | global_print_hparams = False 133 | return hparams_ 134 | -------------------------------------------------------------------------------- /utils/commons/indexed_datasets.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from copy import deepcopy 3 | 4 | import numpy as np 5 | 6 | 7 | class IndexedDataset: 8 | def __init__(self, path, num_cache=1): 9 | super().__init__() 10 | self.path = path 11 | self.data_file = None 12 | self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets'] 13 | self.data_file = open(f"{path}.data", 'rb', buffering=-1) 14 | self.cache = [] 15 | self.num_cache = num_cache 16 | 17 | def check_index(self, i): 18 | if i < 0 or i >= len(self.data_offsets) - 1: 19 | raise IndexError('index out of range') 20 | 21 | def __del__(self): 22 | if self.data_file: 23 | self.data_file.close() 24 | 25 | def __getitem__(self, i): 26 | self.check_index(i) 27 | if self.num_cache > 0: 28 | for c in self.cache: 29 | if c[0] == i: 30 | return c[1] 31 | self.data_file.seek(self.data_offsets[i]) 32 | b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i]) 33 | item = pickle.loads(b) 34 | if self.num_cache > 0: 35 | self.cache = [(i, deepcopy(item))] + self.cache[:-1] 36 | return item 37 | 38 | def __len__(self): 39 | return len(self.data_offsets) - 1 40 | 41 | class IndexedDatasetBuilder: 42 | def __init__(self, path): 43 | self.path = path 44 | self.out_file = open(f"{path}.data", 'wb') 45 | self.byte_offsets = [0] 46 | 47 | def add_item(self, item): 48 | s = pickle.dumps(item) 49 | bytes = self.out_file.write(s) 50 | self.byte_offsets.append(self.byte_offsets[-1] + bytes) 51 | 52 | def finalize(self): 53 | self.out_file.close() 54 | np.save(open(f"{self.path}.idx", 'wb'), {'offsets': self.byte_offsets}) 55 | 56 | 57 | if __name__ == "__main__": 58 | import random 59 | from tqdm import tqdm 60 | ds_path = '/tmp/indexed_ds_example' 61 | size = 100 62 | items = [{"a": np.random.normal(size=[10000, 10]), 63 | "b": np.random.normal(size=[10000, 10])} for i in range(size)] 64 | builder = IndexedDatasetBuilder(ds_path) 65 | for i in tqdm(range(size)): 66 | builder.add_item(items[i]) 67 | builder.finalize() 68 | ds = IndexedDataset(ds_path) 69 | for i in tqdm(range(10000)): 70 | idx = random.randint(0, size - 1) 71 | assert (ds[idx]['a'] == items[idx]['a']).all() 72 | -------------------------------------------------------------------------------- /utils/commons/meters.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | 5 | class 
AvgrageMeter(object): 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.avg = 0 12 | self.sum = 0 13 | self.cnt = 0 14 | 15 | def update(self, val, n=1): 16 | self.sum += val * n 17 | self.cnt += n 18 | self.avg = self.sum / self.cnt 19 | 20 | 21 | class Timer: 22 | timer_map = {} 23 | 24 | def __init__(self, name, enable=False): 25 | if name not in Timer.timer_map: 26 | Timer.timer_map[name] = 0 27 | self.name = name 28 | self.enable = enable 29 | 30 | def __enter__(self): 31 | if self.enable: 32 | if torch.cuda.is_available(): 33 | torch.cuda.synchronize() 34 | self.t = time.time() 35 | 36 | def __exit__(self, exc_type, exc_val, exc_tb): 37 | if self.enable: 38 | if torch.cuda.is_available(): 39 | torch.cuda.synchronize() 40 | Timer.timer_map[self.name] += time.time() - self.t 41 | if self.enable: 42 | print(f'[Timer] {self.name}: {Timer.timer_map[self.name]}') 43 | -------------------------------------------------------------------------------- /utils/commons/multiprocess_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | from functools import partial 4 | from tqdm import tqdm 5 | 6 | 7 | def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None): 8 | ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None 9 | while True: 10 | args = args_queue.get() 11 | if args == '': 12 | return 13 | job_idx, map_func, arg = args 14 | try: 15 | map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func 16 | if isinstance(arg, dict): 17 | res = map_func_(**arg) 18 | elif isinstance(arg, (list, tuple)): 19 | res = map_func_(*arg) 20 | else: 21 | res = map_func_(arg) 22 | results_queue.put((job_idx, res)) 23 | except: 24 | traceback.print_exc() 25 | results_queue.put((job_idx, None)) 26 | 27 | 28 | class MultiprocessManager: 29 | def __init__(self, num_workers=None, init_ctx_func=None, multithread=False, queue_max=-1): 30 | if multithread: 31 | from multiprocessing.dummy import Queue, Process 32 | else: 33 | from multiprocessing import Queue, Process 34 | if num_workers is None: 35 | num_workers = int(os.getenv('N_PROC', os.cpu_count())) 36 | self.num_workers = num_workers 37 | self.results_queue = Queue(maxsize=-1) 38 | self.jobs_pending = [] 39 | self.args_queue = Queue(maxsize=queue_max) 40 | self.workers = [] 41 | self.total_jobs = 0 42 | self.multithread = multithread 43 | for i in range(num_workers): 44 | if multithread: 45 | p = Process(target=chunked_worker, 46 | args=(i, self.args_queue, self.results_queue, init_ctx_func)) 47 | else: 48 | p = Process(target=chunked_worker, 49 | args=(i, self.args_queue, self.results_queue, init_ctx_func), 50 | daemon=True) 51 | self.workers.append(p) 52 | p.start() 53 | 54 | def add_job(self, func, args): 55 | if not self.args_queue.full(): 56 | self.args_queue.put((self.total_jobs, func, args)) 57 | else: 58 | self.jobs_pending.append((self.total_jobs, func, args)) 59 | self.total_jobs += 1 60 | 61 | def get_results(self): 62 | self.n_finished = 0 63 | while self.n_finished < self.total_jobs: 64 | while len(self.jobs_pending) > 0 and not self.args_queue.full(): 65 | self.args_queue.put(self.jobs_pending[0]) 66 | self.jobs_pending = self.jobs_pending[1:] 67 | job_id, res = self.results_queue.get() 68 | yield job_id, res 69 | self.n_finished += 1 70 | for w in range(self.num_workers): 71 | self.args_queue.put("") 72 | for w in self.workers: 73 | w.join() 74 | 75 | def close(self): 76 | if not 
self.multithread: 77 | for w in self.workers: 78 | w.terminate() 79 | 80 | def __len__(self): 81 | return self.total_jobs 82 | 83 | 84 | def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, 85 | multithread=False, queue_max=-1, desc=None): 86 | for i, res in tqdm( 87 | multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread, 88 | queue_max=queue_max), 89 | total=len(args), desc=desc): 90 | yield i, res 91 | 92 | 93 | def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False, 94 | queue_max=-1): 95 | """ 96 | Run chunked jobs across multiple worker processes (or threads). 97 | 98 | Examples: 99 | >>> for i, res in tqdm(multiprocess_run(job_func, args)): 100 | >>>     print(i, res) 101 | 102 | :param map_func: function applied to each element of `args`. 103 | :param args: iterable of per-job arguments (a dict, a tuple/list, or a single value). 104 | :param num_workers: number of workers; defaults to the `N_PROC` env variable or the CPU count. 105 | :param ordered: if True, yield results in submission order. 106 | :param init_ctx_func: optional per-worker initializer whose return value is passed to `map_func` as `ctx`. 107 | :param queue_max: maximum number of pending jobs in the argument queue (-1 means unbounded). 108 | :param multithread: use threads instead of processes. 109 | :return: generator of (job_index, result) pairs. 110 | """ 111 | if num_workers is None: 112 | num_workers = int(os.getenv('N_PROC', os.cpu_count())) 113 | manager = MultiprocessManager(num_workers, init_ctx_func, multithread, queue_max=queue_max) 114 | for arg in args: 115 | manager.add_job(map_func, arg) 116 | if ordered: 117 | n_jobs = len(args) 118 | results = ['' for _ in range(n_jobs)] 119 | i_now = 0 120 | for job_i, res in manager.get_results(): 121 | results[job_i] = res 122 | while i_now < n_jobs and (not isinstance(results[i_now], str) or results[i_now] != ''): 123 | yield i_now, results[i_now] 124 | results[i_now] = None 125 | i_now += 1 126 | else: 127 | for job_i, res in manager.get_results(): 128 | yield job_i, res 129 | manager.close() 130 | -------------------------------------------------------------------------------- /utils/commons/single_thread_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | os.environ['TF_NUM_INTEROP_THREADS'] = '1' 5 | os.environ['TF_NUM_INTRAOP_THREADS'] = '1' 6 | -------------------------------------------------------------------------------- /utils/commons/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | def reduce_tensors(metrics): 6 | new_metrics = {} 7 | for k, v in metrics.items(): 8 | if isinstance(v, torch.Tensor): 9 | dist.all_reduce(v) 10 | v = v / dist.get_world_size() 11 | if type(v) is dict: 12 | v = reduce_tensors(v) 13 | new_metrics[k] = v 14 | return new_metrics 15 | 16 | 17 | def tensors_to_scalars(tensors): 18 | if isinstance(tensors, torch.Tensor): 19 | tensors = tensors.item() 20 | return tensors 21 | elif isinstance(tensors, dict): 22 | new_tensors = {} 23 | for k, v in tensors.items(): 24 | v = tensors_to_scalars(v) 25 | new_tensors[k] = v 26 | return new_tensors 27 | elif isinstance(tensors, list): 28 | return [tensors_to_scalars(v) for v in tensors] 29 | else: 30 | return tensors 31 | 32 | 33 | def tensors_to_np(tensors): 34 | if isinstance(tensors, dict): 35 | new_np = {} 36 | for k, v in tensors.items(): 37 | if isinstance(v, torch.Tensor): 38 | v = v.cpu().numpy() 39 | if type(v) is dict: 40 | v = tensors_to_np(v) 41 | new_np[k] = v 42 | elif isinstance(tensors, list): 43 | new_np = [] 44 | for v in tensors: 45 | if isinstance(v, torch.Tensor): 46 | v = v.cpu().numpy() 47 | if type(v) is dict: 48 | v = tensors_to_np(v) 49 | new_np.append(v) 50 | elif isinstance(tensors, torch.Tensor): 51 | v = tensors 52 | if 
isinstance(v, torch.Tensor): 53 | v = v.cpu().numpy() 54 | if type(v) is dict: 55 | v = tensors_to_np(v) 56 | new_np = v 57 | else: 58 | raise Exception(f'tensors_to_np does not support type {type(tensors)}.') 59 | return new_np 60 | 61 | 62 | def move_to_cpu(tensors): 63 | ret = {} 64 | for k, v in tensors.items(): 65 | if isinstance(v, torch.Tensor): 66 | v = v.cpu() 67 | if type(v) is dict: 68 | v = move_to_cpu(v) 69 | ret[k] = v 70 | return ret 71 | 72 | 73 | def move_to_cuda(batch, gpu_id=0): 74 | # base case: object can be directly moved using `cuda` or `to` 75 | if callable(getattr(batch, 'cuda', None)): 76 | return batch.cuda(gpu_id, non_blocking=True) 77 | elif callable(getattr(batch, 'to', None)): 78 | return batch.to(torch.device('cuda', gpu_id), non_blocking=True) 79 | elif isinstance(batch, list): 80 | for i, x in enumerate(batch): 81 | batch[i] = move_to_cuda(x, gpu_id) 82 | return batch 83 | elif isinstance(batch, tuple): 84 | batch = list(batch) 85 | for i, x in enumerate(batch): 86 | batch[i] = move_to_cuda(x, gpu_id) 87 | return tuple(batch) 88 | elif isinstance(batch, dict): 89 | for k, v in batch.items(): 90 | batch[k] = move_to_cuda(v, gpu_id) 91 | return batch 92 | return batch 93 | -------------------------------------------------------------------------------- /utils/metrics/diagonal_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_focus_rate(attn, src_padding_mask=None, tgt_padding_mask=None): 5 | ''' 6 | attn: bs x L_t x L_s 7 | ''' 8 | if src_padding_mask is not None: 9 | attn = attn * (1 - src_padding_mask.float())[:, None, :] 10 | 11 | if tgt_padding_mask is not None: 12 | attn = attn * (1 - tgt_padding_mask.float())[:, :, None] 13 | 14 | focus_rate = attn.max(-1).values.sum(-1) 15 | focus_rate = focus_rate / attn.sum(-1).sum(-1) 16 | return focus_rate 17 | 18 | 19 | def get_phone_coverage_rate(attn, src_padding_mask=None, src_seg_mask=None, tgt_padding_mask=None): 20 | ''' 21 | attn: bs x L_t x L_s 22 | ''' 23 | src_mask = attn.new(attn.size(0), attn.size(-1)).bool().fill_(False) 24 | if src_padding_mask is not None: 25 | src_mask |= src_padding_mask 26 | if src_seg_mask is not None: 27 | src_mask |= src_seg_mask 28 | 29 | attn = attn * (1 - src_mask.float())[:, None, :] 30 | if tgt_padding_mask is not None: 31 | attn = attn * (1 - tgt_padding_mask.float())[:, :, None] 32 | 33 | phone_coverage_rate = attn.max(1).values.sum(-1) 34 | # phone_coverage_rate = phone_coverage_rate / attn.sum(-1).sum(-1) 35 | phone_coverage_rate = phone_coverage_rate / (1 - src_mask.float()).sum(-1) 36 | return phone_coverage_rate 37 | 38 | 39 | def get_diagonal_focus_rate(attn, attn_ks, target_len, src_padding_mask=None, tgt_padding_mask=None, 40 | band_mask_factor=5, band_width=50): 41 | ''' 42 | attn: bx x L_t x L_s 43 | attn_ks: shape: tensor with shape [batch_size], input_lens/output_lens 44 | 45 | diagonal: y=k*x (k=attn_ks, x:output, y:input) 46 | 1 0 0 47 | 0 1 0 48 | 0 0 1 49 | y>=k*(x-width) and y<=k*(x+width):1 50 | else:0 51 | ''' 52 | # width = min(target_len/band_mask_factor, 50) 53 | width1 = target_len / band_mask_factor 54 | width2 = target_len.new(target_len.size()).fill_(band_width) 55 | width = torch.where(width1 < width2, width1, width2).float() 56 | base = torch.ones(attn.size()).to(attn.device) 57 | zero = torch.zeros(attn.size()).to(attn.device) 58 | x = torch.arange(0, attn.size(1)).to(attn.device)[None, :, None].float() * base 59 | y = torch.arange(0, 
attn.size(2)).to(attn.device)[None, None, :].float() * base 60 | cond = (y - attn_ks[:, None, None] * x) 61 | cond1 = cond + attn_ks[:, None, None] * width[:, None, None] 62 | cond2 = cond - attn_ks[:, None, None] * width[:, None, None] 63 | mask1 = torch.where(cond1 < 0, zero, base) 64 | mask2 = torch.where(cond2 > 0, zero, base) 65 | mask = mask1 * mask2 66 | 67 | if src_padding_mask is not None: 68 | attn = attn * (1 - src_padding_mask.float())[:, None, :] 69 | if tgt_padding_mask is not None: 70 | attn = attn * (1 - tgt_padding_mask.float())[:, :, None] 71 | 72 | diagonal_attn = attn * mask 73 | diagonal_focus_rate = diagonal_attn.sum(-1).sum(-1) / attn.sum(-1).sum(-1) 74 | return diagonal_focus_rate, mask 75 | -------------------------------------------------------------------------------- /utils/metrics/dtw.py: -------------------------------------------------------------------------------- 1 | from numpy import array, zeros, full, argmin, inf, ndim 2 | from scipy.spatial.distance import cdist 3 | from math import isinf 4 | 5 | 6 | def dtw(x, y, dist, warp=1, w=inf, s=1.0): 7 | """ 8 | Computes Dynamic Time Warping (DTW) of two sequences. 9 | 10 | :param array x: N1*M array 11 | :param array y: N2*M array 12 | :param func dist: distance used as cost measure 13 | :param int warp: how many shifts are computed. 14 | :param int w: window size limiting the maximal distance between indices of matched entries |i,j|. 15 | :param float s: weight applied on off-diagonal moves of the path. As s gets larger, the warping path is increasingly biased towards the diagonal 16 | Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the wrap path. 17 | """ 18 | assert len(x) 19 | assert len(y) 20 | assert isinf(w) or (w >= abs(len(x) - len(y))) 21 | assert s > 0 22 | r, c = len(x), len(y) 23 | if not isinf(w): 24 | D0 = full((r + 1, c + 1), inf) 25 | for i in range(1, r + 1): 26 | D0[i, max(1, i - w):min(c + 1, i + w + 1)] = 0 27 | D0[0, 0] = 0 28 | else: 29 | D0 = zeros((r + 1, c + 1)) 30 | D0[0, 1:] = inf 31 | D0[1:, 0] = inf 32 | D1 = D0[1:, 1:] # view 33 | for i in range(r): 34 | for j in range(c): 35 | if (isinf(w) or (max(0, i - w) <= j <= min(c, i + w))): 36 | D1[i, j] = dist(x[i], y[j]) 37 | C = D1.copy() 38 | jrange = range(c) 39 | for i in range(r): 40 | if not isinf(w): 41 | jrange = range(max(0, i - w), min(c, i + w + 1)) 42 | for j in jrange: 43 | min_list = [D0[i, j]] 44 | for k in range(1, warp + 1): 45 | i_k = min(i + k, r) 46 | j_k = min(j + k, c) 47 | min_list += [D0[i_k, j] * s, D0[i, j_k] * s] 48 | D1[i, j] += min(min_list) 49 | if len(x) == 1: 50 | path = zeros(len(y)), range(len(y)) 51 | elif len(y) == 1: 52 | path = range(len(x)), zeros(len(x)) 53 | else: 54 | path = _traceback(D0) 55 | return D1[-1, -1], C, D1, path 56 | 57 | 58 | def accelerated_dtw(x, y, dist, warp=1): 59 | """ 60 | Computes Dynamic Time Warping (DTW) of two sequences in a faster way. 61 | Instead of iterating through each element and calculating each distance, 62 | this uses the cdist function from scipy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html) 63 | 64 | :param array x: N1*M array 65 | :param array y: N2*M array 66 | :param string or func dist: distance parameter for cdist. When string is given, cdist uses optimized functions for the distance metrics. 
67 | If a string is passed, the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. 68 | :param int warp: how many shifts are computed. 69 | Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the wrap path. 70 | """ 71 | assert len(x) 72 | assert len(y) 73 | if ndim(x) == 1: 74 | x = x.reshape(-1, 1) 75 | if ndim(y) == 1: 76 | y = y.reshape(-1, 1) 77 | r, c = len(x), len(y) 78 | D0 = zeros((r + 1, c + 1)) 79 | D0[0, 1:] = inf 80 | D0[1:, 0] = inf 81 | D1 = D0[1:, 1:] 82 | D0[1:, 1:] = cdist(x, y, dist) 83 | C = D1.copy() 84 | for i in range(r): 85 | for j in range(c): 86 | min_list = [D0[i, j]] 87 | for k in range(1, warp + 1): 88 | min_list += [D0[min(i + k, r), j], 89 | D0[i, min(j + k, c)]] 90 | D1[i, j] += min(min_list) 91 | if len(x) == 1: 92 | path = zeros(len(y)), range(len(y)) 93 | elif len(y) == 1: 94 | path = range(len(x)), zeros(len(x)) 95 | else: 96 | path = _traceback(D0) 97 | return D1[-1, -1], C, D1, path 98 | 99 | 100 | def _traceback(D): 101 | i, j = array(D.shape) - 2 102 | p, q = [i], [j] 103 | while (i > 0) or (j > 0): 104 | tb = argmin((D[i, j], D[i, j + 1], D[i + 1, j])) 105 | if tb == 0: 106 | i -= 1 107 | j -= 1 108 | elif tb == 1: 109 | i -= 1 110 | else: # (tb == 2): 111 | j -= 1 112 | p.insert(0, i) 113 | q.insert(0, j) 114 | return array(p), array(q) 115 | 116 | 117 | if __name__ == '__main__': 118 | w = inf 119 | s = 1.0 120 | if 1: # 1-D numeric 121 | from sklearn.metrics.pairwise import manhattan_distances 122 | 123 | x = [0, 0, 1, 1, 2, 4, 2, 1, 2, 0] 124 | y = [1, 1, 1, 2, 2, 2, 2, 3, 2, 0] 125 | dist_fun = manhattan_distances 126 | w = 1 127 | # s = 1.2 128 | elif 0: # 2-D numeric 129 | from sklearn.metrics.pairwise import euclidean_distances 130 | 131 | x = [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [4, 3], [2, 3], [1, 1], [2, 2], [0, 1]] 132 | y = [[1, 0], [1, 1], [1, 1], [2, 1], [4, 3], [4, 3], [2, 3], [3, 1], [1, 2], [1, 0]] 133 | dist_fun = euclidean_distances 134 | else: # 1-D list of strings 135 | from nltk.metrics.distance import edit_distance 136 | 137 | # x = ['we', 'shelled', 'clams', 'for', 'the', 'chowder'] 138 | # y = ['class', 'too'] 139 | x = ['i', 'soon', 'found', 'myself', 'muttering', 'to', 'the', 'walls'] 140 | y = ['see', 'drown', 'himself'] 141 | # x = 'we talked about the situation'.split() 142 | # y = 'we talked about the situation'.split() 143 | dist_fun = edit_distance 144 | dist, cost, acc, path = dtw(x, y, dist_fun, w=w, s=s) 145 | 146 | # Vizualize 147 | from matplotlib import pyplot as plt 148 | 149 | plt.imshow(cost.T, origin='lower', cmap=plt.cm.Reds, interpolation='nearest') 150 | plt.plot(path[0], path[1], '-o') # relation 151 | plt.xticks(range(len(x)), x) 152 | plt.yticks(range(len(y)), y) 153 | plt.xlabel('x') 154 | plt.ylabel('y') 155 | plt.axis('tight') 156 | if isinf(w): 157 | plt.title('Minimum distance: {}, slope weight: {}'.format(dist, s)) 158 | else: 159 | plt.title('Minimum distance: {}, window widht: {}, slope weight: {}'.format(dist, w, s)) 160 | plt.show() 161 | -------------------------------------------------------------------------------- /utils/metrics/laplace_var.py: -------------------------------------------------------------------------------- 1 | import scipy.ndimage 2 | 3 | def laplace_var(x): 4 | return 
scipy.ndimage.laplace(x).var() 5 | -------------------------------------------------------------------------------- /utils/metrics/mcd.py: -------------------------------------------------------------------------------- 1 | import essentia 2 | import essentia.standard as ess 3 | import matplotlib.pyplot as plt 4 | from dtw import dtw 5 | from numpy.linalg import norm 6 | import numpy as np 7 | from tqdm import tqdm 8 | import librosa 9 | import glob 10 | import os 11 | import pyworld 12 | 13 | # https://github.com/MTG/essentia/blob/master/src/examples/tutorial/example_mfcc_the_htk_way.py 14 | def extractor(audio, numberBands=26): # extract mel-cepstral coefficients (MFCCs) 15 | # fs = 22050 16 | # audio = ess.MonoLoader(filename=filename, 17 | # sampleRate=fs)() 18 | # dynamic range expansion as done in HTK implementation 19 | audio = audio * 2 ** 15 20 | frameSize = 1024 # corresponds to htk default WINDOWSIZE = 250000.0 21 | hopSize = 256 # corresponds to htk default TARGETRATE = 100000.0 22 | fftSize = 1024 23 | spectrumSize = fftSize // 2 + 1 24 | zeroPadding = fftSize - frameSize 25 | 26 | w = ess.Windowing(type='hamming', # corresponds to htk default USEHAMMING = T 27 | size=frameSize, 28 | zeroPadding=zeroPadding, 29 | normalized=False, 30 | zeroPhase=False) 31 | 32 | spectrum = ess.Spectrum(size=fftSize) 33 | 34 | mfcc_htk = ess.MFCC(inputSize=spectrumSize, 35 | type='magnitude', # htk uses mel filterbank magnitude 36 | warpingFormula='htkMel', # htk's mel warping formula 37 | weighting='linear', # computation of filter weights done in Hz domain 38 | highFrequencyBound=8000, # corresponds to htk default 39 | lowFrequencyBound=0, # corresponds to htk default 40 | numberBands=numberBands, # corresponds to htk default NUMCHANS = 26 41 | numberCoefficients=13, 42 | normalize='unit_sum', # htk filter normalization to have constant height = 1 43 | dctType=2, # htk uses DCT type II 44 | logType='log', 45 | liftering=0) # corresponds to htk default CEPLIFTER = 22 46 | 47 | mfccs = [] 48 | # startFromZero = True, validFrameThresholdRatio = 1 : the way htk computes windows 49 | for frame in ess.FrameGenerator(audio, frameSize=frameSize, hopSize=hopSize, startFromZero=True, 50 | validFrameThresholdRatio=1): 51 | spect = spectrum(w(frame)) 52 | mel_bands, mfcc_coeffs = mfcc_htk(spect) 53 | mfccs.append(mfcc_coeffs) 54 | 55 | # transpose to have it in a better shape 56 | # we need to convert the list to an essentia.array first (== numpy.array of floats) 57 | # mfccs = essentia.array(pool['MFCC']).T 58 | mfccs = essentia.array(mfccs).T[1:] 59 | 60 | # plt.imshow(mfccs[1:,:], aspect = 'auto', interpolation='none') # ignore energy 61 | # plt.xlabel('Frame', fontsize=14) 62 | # plt.ylabel('MCC', fontsize=14) 63 | # plt.imshow(mfccs, aspect = 'auto', interpolation='none') 64 | # plt.show() # unnecessary if you started "ipython --pylab" 65 | return mfccs 66 | 67 | def get_pitchContour(wav, hop_size, sample_rate): 68 | 69 | frame_period = (hop_size / (0.001 * sample_rate)) 70 | f0, timeaxis = pyworld.harvest(wav, sample_rate, frame_period=frame_period) 71 | 72 | return f0 73 | 74 | def MCD(audio_one, audio_two, numberBands=13): # compute mel-cepstral distortion (MCD) 75 | # https://github.com/danijel3/PyHTK/blob/master/python-notebooks/HTKFeaturesExplained.ipynb 76 | # normalization 77 | mfcc_one = extractor(audio_one, numberBands) * np.sqrt(2 / numberBands) 78 | mfcc_two = extractor(audio_two, numberBands) * np.sqrt(2 / numberBands) 79 | 80 | if np.isnan(mfcc_one[0][-1]) or np.isinf(mfcc_one[0][-1]): 81 | mfcc_one = mfcc_one[:, :-1] 82 | if 
np.isnan(mfcc_two[0][-1]) or np.isinf(mfcc_two[0][-1]): 83 | mfcc_two = mfcc_two[:, :-1] 84 | 85 | dist, cost, acc_cost, path = dtw(mfcc_one.T, mfcc_two.T, dist=lambda x, y: norm(x - y, ord=1)) 86 | 87 | dtw_one = mfcc_one.T[path[0]] 88 | dtw_two = mfcc_two.T[path[1]] 89 | 90 | mcd = 10 / np.log(10) * np.sqrt(2 * np.sum(((dtw_one - dtw_two) ** 2), axis=1)) 91 | mcd = np.sum(mcd) / len(mcd) 92 | 93 | return mcd 94 | 95 | def F0_RMSE(audio_one, audio_two, numberBands=26): # compute F0 RMSE between two utterances 96 | hop_size = 256 97 | sample_rate = 22050 98 | 99 | p1 = get_pitchContour(audio_one.astype(np.float64), hop_size, sample_rate) 100 | p2 = get_pitchContour(audio_two.astype(np.float64), hop_size, sample_rate) 101 | 102 | p1 = np.nan_to_num(p1) 103 | p2 = np.nan_to_num(p2) 104 | 105 | # numpy.linalg.norm option: https://leebaro.tistory.com/entry/numpylinalgnorm 106 | p1 = np.expand_dims(p1, axis=0) 107 | p2 = np.expand_dims(p2, axis=0) 108 | dist, cost, acc_cost, path = dtw(p1.T, p2.T, dist=lambda x, y: norm(x - y, ord=1)) # 1->0, `ord` is the order option of `norm` 109 | 110 | dtw_one = p1.T[path[0]] 111 | dtw_two = p2.T[path[1]] 112 | 113 | # check the RMSE formula: the square root is applied last 114 | f0_rmse = np.sqrt(np.sum((dtw_one - dtw_two) ** 2) / dtw_one.shape[0]) 115 | 116 | return f0_rmse 117 | 118 | def DDUR(audio_one, audio_two, sample_rate=22050, rescaling_max=0.999): 119 | 120 | audio_one = audio_one / np.abs(audio_one).max() * rescaling_max 121 | audio_two = audio_two / np.abs(audio_two).max() * rescaling_max 122 | 123 | def cal_DDUR(audio): 124 | # compute the average absolute difference between the durations of the converted and target utterances 125 | intervals = librosa.effects.split(audio, top_db=20, frame_length=1024, hop_length=256) # [(s1, e1),..., (sn, en)] 126 | DDUR = np.sum([(e-s) for s, e in intervals], axis=0) / sample_rate # (unit: number of samples -> seconds) 127 | return DDUR 128 | 129 | return np.mean(np.abs(cal_DDUR(audio_one) - cal_DDUR(audio_two))) 130 | -------------------------------------------------------------------------------- /utils/metrics/pitch_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from numba import jit 4 | 5 | import torch 6 | 7 | 8 | @jit 9 | def time_warp(costs): 10 | dtw = np.zeros_like(costs) 11 | dtw[0, 1:] = np.inf 12 | dtw[1:, 0] = np.inf 13 | eps = 1e-4 14 | for i in range(1, costs.shape[0]): 15 | for j in range(1, costs.shape[1]): 16 | dtw[i, j] = costs[i, j] + min(dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1]) 17 | return dtw 18 | 19 | 20 | def align_from_distances(distance_matrix, debug=False, return_mindist=False): 21 | # for each position in spectrum 1, returns best match position in spectrum2 22 | # using monotonic alignment 23 | dtw = time_warp(distance_matrix) 24 | 25 | i = distance_matrix.shape[0] - 1 26 | j = distance_matrix.shape[1] - 1 27 | results = [0] * distance_matrix.shape[0] 28 | while i > 0 and j > 0: 29 | results[i] = j 30 | i, j = min([(i - 1, j), (i, j - 1), (i - 1, j - 1)], key=lambda x: dtw[x[0], x[1]]) 31 | 32 | if debug: 33 | visual = np.zeros_like(dtw) 34 | visual[range(len(results)), results] = 1 35 | plt.matshow(visual) 36 | plt.show() 37 | if return_mindist: 38 | return results, dtw[-1, -1] 39 | return results 40 | 41 | 42 | def get_local_context(input_f, max_window=32, scale_factor=1.): 43 | # input_f: [S, 1], support numpy array or torch tensor 44 | # return hist: [S, max_window * 2], list of list 45 | T = input_f.shape[0] 46 | # max_window = int(max_window * 
scale_factor) 47 | derivative = [[0 for _ in range(max_window * 2)] for _ in range(T)] 48 | 49 | for t in range(T): # travel the time series 50 | for feat_idx in range(-max_window, max_window): 51 | if t + feat_idx < 0 or t + feat_idx >= T: 52 | value = 0 53 | else: 54 | value = input_f[t + feat_idx] 55 | derivative[t][feat_idx + max_window] = value 56 | return derivative 57 | 58 | 59 | def cal_localnorm_dist(src, tgt, src_len, tgt_len): 60 | local_src = torch.tensor(get_local_context(src)) 61 | local_tgt = torch.tensor(get_local_context(tgt, scale_factor=tgt_len / src_len)) 62 | 63 | local_norm_src = (local_src - local_src.mean(-1).unsqueeze(-1)) # / local_src.std(-1).unsqueeze(-1) # [T1, 32] 64 | local_norm_tgt = (local_tgt - local_tgt.mean(-1).unsqueeze(-1)) # / local_tgt.std(-1).unsqueeze(-1) # [T2, 32] 65 | 66 | dists = torch.cdist(local_norm_src[None, :, :], local_norm_tgt[None, :, :]) # [1, T1, T2] 67 | return dists 68 | 69 | 70 | ## here is API for one sample 71 | def LoNDTWDistance(src, tgt): 72 | # src: [S] 73 | # tgt: [T] 74 | dists = cal_localnorm_dist(src, tgt, src.shape[0], tgt.shape[0]) # [1, S, T] 75 | costs = dists.squeeze(0) # [S, T] 76 | alignment, min_distance = align_from_distances(costs.T.cpu().detach().numpy(), return_mindist=True) # [T] 77 | return alignment, min_distance 78 | 79 | # if __name__ == '__main__': 80 | # # utils from ns 81 | # from utils.pitch_utils import denorm_f0 82 | # from tasks.singing.fsinging import FastSingingDataset 83 | # from utils.hparams import hparams, set_hparams 84 | # 85 | # set_hparams() 86 | # 87 | # train_ds = FastSingingDataset('test') 88 | # 89 | # # Test One sample case 90 | # sample = train_ds[0] 91 | # amateur_f0 = sample['f0'] 92 | # prof_f0 = sample['prof_f0'] 93 | # 94 | # amateur_uv = sample['uv'] 95 | # amateur_padding = sample['mel2ph'] == 0 96 | # prof_uv = sample['prof_uv'] 97 | # prof_padding = sample['prof_mel2ph'] == 0 98 | # amateur_f0_denorm = denorm_f0(amateur_f0, amateur_uv, hparams, pitch_padding=amateur_padding) 99 | # prof_f0_denorm = denorm_f0(prof_f0, prof_uv, hparams, pitch_padding=prof_padding) 100 | # alignment, min_distance = LoNDTWDistance(amateur_f0_denorm, prof_f0_denorm) 101 | # print(min_distance) 102 | # python utils/pitch_distance.py --config egs/datasets/audio/molar/svc_ppg.yaml 103 | -------------------------------------------------------------------------------- /utils/metrics/ssim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from math import exp 10 | 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 14 | return gauss / gauss.sum() 15 | 16 | 17 | def create_window(window_size, channel): 18 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 19 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 20 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 21 | return window 22 | 23 | 24 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 25 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 26 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 27 | 28 | mu1_sq = mu1.pow(2) 29 | mu2_sq = mu2.pow(2) 30 | mu1_mu2 = mu1 * mu2 31 | 
32 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 33 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 34 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 35 | 36 | C1 = 0.01 ** 2 37 | C2 = 0.03 ** 2 38 | 39 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 40 | 41 | if size_average: 42 | return ssim_map.mean() 43 | else: 44 | return ssim_map.mean(1) 45 | 46 | 47 | class SSIM(torch.nn.Module): 48 | def __init__(self, window_size=11, size_average=True): 49 | super(SSIM, self).__init__() 50 | self.window_size = window_size 51 | self.size_average = size_average 52 | self.channel = 1 53 | self.window = create_window(window_size, self.channel) 54 | 55 | def forward(self, img1, img2): 56 | (_, channel, _, _) = img1.size() 57 | 58 | if channel == self.channel and self.window.data.type() == img1.data.type(): 59 | window = self.window 60 | else: 61 | window = create_window(self.window_size, channel) 62 | 63 | if img1.is_cuda: 64 | window = window.cuda(img1.get_device()) 65 | window = window.type_as(img1) 66 | 67 | self.window = window 68 | self.channel = channel 69 | 70 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 71 | 72 | 73 | window = None 74 | 75 | 76 | def ssim(img1, img2, window_size=11, size_average=True): 77 | (_, channel, _, _) = img1.size() 78 | global window 79 | if window is None: 80 | window = create_window(window_size, channel) 81 | if img1.is_cuda: 82 | window = window.cuda(img1.get_device()) 83 | window = window.type_as(img1) 84 | return _ssim(img1, img2, window, window_size, channel, size_average) 85 | -------------------------------------------------------------------------------- /utils/nn/model_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def print_arch(model, model_name='model'): 5 | print(f"| {model_name} Arch: ", model) 6 | num_params(model, model_name=model_name) 7 | 8 | 9 | def num_params(model, print_out=True, model_name="model"): 10 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 11 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 12 | if print_out: 13 | print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) 14 | return parameters 15 | -------------------------------------------------------------------------------- /utils/nn/schedulers.py: -------------------------------------------------------------------------------- 1 | class NoneSchedule(object): 2 | def __init__(self, optimizer, lr): 3 | self.optimizer = optimizer 4 | self.constant_lr = lr 5 | self.step(0) 6 | 7 | def step(self, num_updates): 8 | self.lr = self.constant_lr 9 | for param_group in self.optimizer.param_groups: 10 | param_group['lr'] = self.lr 11 | return self.lr 12 | 13 | def get_lr(self): 14 | return self.optimizer.param_groups[0]['lr'] 15 | 16 | def get_last_lr(self): 17 | return self.get_lr() 18 | 19 | 20 | class RSQRTSchedule(NoneSchedule): 21 | def __init__(self, optimizer, lr, warmup_updates, hidden_size): 22 | self.optimizer = optimizer 23 | self.constant_lr = lr 24 | self.warmup_updates = warmup_updates 25 | self.hidden_size = hidden_size 26 | self.lr = lr 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = self.lr 29 | self.step(0) 30 | 31 | def step(self, num_updates): 32 | constant_lr = self.constant_lr 33 
| warmup = min(num_updates / self.warmup_updates, 1.0) 34 | rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5 35 | rsqrt_hidden = self.hidden_size ** -0.5 36 | self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7) 37 | for param_group in self.optimizer.param_groups: 38 | param_group['lr'] = self.lr 39 | return self.lr 40 | 41 | 42 | class WarmupSchedule(NoneSchedule): 43 | def __init__(self, optimizer, lr, warmup_updates): 44 | self.optimizer = optimizer 45 | self.constant_lr = self.lr = lr 46 | self.warmup_updates = warmup_updates 47 | for param_group in optimizer.param_groups: 48 | param_group['lr'] = self.lr 49 | self.step(0) 50 | 51 | def step(self, num_updates): 52 | constant_lr = self.constant_lr 53 | warmup = min(num_updates / self.warmup_updates, 1.0) 54 | self.lr = max(constant_lr * warmup, 1e-7) 55 | for param_group in self.optimizer.param_groups: 56 | param_group['lr'] = self.lr 57 | return self.lr 58 | -------------------------------------------------------------------------------- /utils/os_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | 5 | def link_file(from_file, to_file): 6 | subprocess.check_call( 7 | f'ln -s "`realpath --relative-to="{os.path.dirname(to_file)}" "{from_file}"`" "{to_file}"', shell=True) 8 | 9 | 10 | def move_file(from_file, to_file): 11 | subprocess.check_call(f'mv "{from_file}" "{to_file}"', shell=True) 12 | 13 | 14 | def copy_file(from_file, to_file): 15 | subprocess.check_call(f'cp -r "{from_file}" "{to_file}"', shell=True) 16 | 17 | 18 | def remove_file(*fns): 19 | for f in fns: 20 | subprocess.check_call(f'rm -rf "{f}"', shell=True) 21 | -------------------------------------------------------------------------------- /utils/plot/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | 8 | LINE_COLORS = ['w', 'r', 'orange', 'k', 'cyan', 'm', 'b', 'lime', 'g', 'brown', 'navy'] 9 | 10 | 11 | def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_info=None): 12 | if isinstance(spec, torch.Tensor): 13 | spec = spec.cpu().numpy() 14 | H = spec.shape[1] // 2 15 | fig = plt.figure(figsize=(12, 6)) 16 | plt.title(title) 17 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 18 | if dur_info is not None: 19 | assert isinstance(dur_info, dict) 20 | txt = dur_info['txt'] 21 | dur_gt = dur_info['dur_gt'] 22 | if isinstance(dur_gt, torch.Tensor): 23 | dur_gt = dur_gt.cpu().numpy() 24 | dur_gt = np.cumsum(dur_gt).astype(int) 25 | for i in range(len(dur_gt)): 26 | shift = (i % 8) + 1 27 | plt.text(dur_gt[i], shift * 4, txt[i]) 28 | plt.vlines(dur_gt[i], 0, H // 2, colors='b') # blue is gt 29 | plt.xlim(0, dur_gt[-1]) 30 | if 'dur_pred' in dur_info: 31 | dur_pred = dur_info['dur_pred'] 32 | if isinstance(dur_pred, torch.Tensor): 33 | dur_pred = dur_pred.cpu().numpy() 34 | dur_pred = np.cumsum(dur_pred).astype(int) 35 | for i in range(len(dur_pred)): 36 | shift = (i % 8) + 1 37 | plt.text(dur_pred[i], H + shift * 4, txt[i]) 38 | plt.vlines(dur_pred[i], H, H * 1.5, colors='r') # red is pred 39 | plt.xlim(0, max(dur_gt[-1], dur_pred[-1])) 40 | if f0s is not None: 41 | ax = plt.gca() 42 | ax2 = ax.twinx() 43 | if not isinstance(f0s, dict): 44 | f0s = {'f0': f0s} 45 | for i, (k, f0) in enumerate(f0s.items()): 46 | if isinstance(f0, torch.Tensor): 47 | f0 = f0.cpu().numpy() 48 | ax2.plot(f0, 
label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.5) 49 | ax2.set_ylim(0, 1000) 50 | ax2.legend() 51 | return fig 52 | -------------------------------------------------------------------------------- /utils/text/encoding.py: -------------------------------------------------------------------------------- 1 | import chardet 2 | 3 | 4 | def get_encoding(file): 5 | with open(file, 'rb') as f: 6 | encoding = chardet.detect(f.read())['encoding'] 7 | if encoding == 'GB2312': 8 | encoding = 'GB18030' 9 | return encoding 10 | --------------------------------------------------------------------------------
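
For readers skimming these utility modules, the following is a minimal usage sketch rather than a file from the repository: it shows how `AvgrageMeter` and `Timer` from `utils/commons/meters.py` and `multiprocess_run_tqdm` from `utils/commons/multiprocess_utils.py` are typically combined. The `square` job function, the argument list, and the worker count are illustrative placeholders, and the repository root is assumed to be on `PYTHONPATH`.

```python
# Illustrative sketch only; `square` and the job list are placeholders.
from utils.commons.meters import AvgrageMeter, Timer  # class name spelled as in the repo
from utils.commons.multiprocess_utils import multiprocess_run_tqdm


def square(x):
    # Any picklable function works as a job; it receives one element of the args list.
    return x * x


if __name__ == '__main__':
    meter = AvgrageMeter()
    # Timer accumulates wall-clock time under the given name across `with` blocks
    # and prints the running total when enable=True.
    with Timer('demo', enable=True):
        # multiprocess_run_tqdm yields (job_index, result) pairs, in submission
        # order by default, with a tqdm progress bar.
        for job_i, res in multiprocess_run_tqdm(square, list(range(100)), num_workers=4, desc='squaring'):
            meter.update(res)
    print('average result:', meter.avg)
```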