├── data └── .gitkeep ├── checkpoints └── .gitkeep ├── tasks ├── vocoder_infer │ ├── __init__.py │ └── base_vocoder.py ├── runs │ └── run.py ├── utils.py ├── dataset_utils.py └── visinger.py ├── preprocessor ├── text │ ├── __init__.py │ ├── base_text_processor.py │ ├── dict │ │ └── korean.json │ └── ko_sing.py ├── wave │ ├── __init__.py │ ├── base_wave_processor.py │ └── common_processor.py └── runs │ ├── base_preprocess.py │ └── base_binarize.py ├── utils ├── audio │ ├── __init__.py │ ├── io.py │ ├── pitch │ │ └── utils.py │ ├── pitch_extractors.py │ ├── mel_processing.py │ ├── vad.py │ └── align.py ├── commons │ ├── single_thread_env.py │ ├── meters.py │ ├── indexed_datasets.py │ ├── ckpt_utils.py │ ├── tensor_utils.py │ ├── multiprocess_utils.py │ ├── hparams.py │ ├── ddp_utils.py │ ├── dataset_utils.py │ └── base_task.py ├── text │ ├── ko_symbols.py │ └── text_encoder.py ├── nn │ ├── model_utils.py │ └── seq_utils.py ├── os_utils.py └── plot │ └── plot.py ├── assets └── architecture.png ├── .gitignore ├── requirements.txt ├── models ├── commons │ └── align_ops.py └── visinger.py ├── config ├── models │ ├── base_config.yaml │ ├── base_task.yaml │ └── visinger.yaml └── datasets │ └── svs │ └── csd │ ├── preprocess.py │ └── preprocess.yaml ├── LICENSE ├── modules ├── visinger │ ├── predictor.py │ ├── decoder.py │ ├── encoder.py │ └── flow.py ├── discriminator.py └── commons │ └── utils.py ├── README.md └── inference └── visinger.py /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tasks/vocoder_infer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocessor/text/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ko_sing -------------------------------------------------------------------------------- /utils/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.audio.vad import trim_long_silences -------------------------------------------------------------------------------- /preprocessor/wave/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base_wave_processor 2 | from . 
import common_processor -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jisang93/VISinger/HEAD/assets/architecture.png -------------------------------------------------------------------------------- /utils/commons/single_thread_env.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import os 3 | 4 | os.environ["OMP_NUM_THREADS"] = "1" 5 | os.environ['TF_NUM_INTEROP_THREADS'] = '1' 6 | os.environ['TF_NUM_INTRAOP_THREADS'] = '1' 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # mypy 2 | __pycache__/ 3 | */__pycache__/ 4 | # checkpoints 5 | /checkpoints/* 6 | !/checkpoints/.gitkeep 7 | # infer 8 | /infer_out/* 9 | # data 10 | /data/* 11 | !/data/.gitkeep 12 | # tmp 13 | tmp/* 14 | # Shell files 15 | *.sh -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | essentia 3 | g2pk 4 | jamo 5 | librosa 6 | matplotlib 7 | miditoolkit 8 | praat-parselmouth==0.3.3 9 | pyloudnorm 10 | pyworld 11 | pyyaml 12 | resemblyzer 13 | scikit-image 14 | scipy 15 | tensorboard 16 | textgrid 17 | torchaudio 18 | tqdm 19 | webrtcvad -------------------------------------------------------------------------------- /utils/text/ko_symbols.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/keithito/tacotron 2 | # Defines the set of symbols used in text input to the model.
3 | 4 | _JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)]) 5 | _JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)]) 6 | _JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)]) 7 | 8 | _VALID_CHARS = _JAMO_LEADS + _JAMO_VOWELS + _JAMO_TAILS 9 | symbols = list(_VALID_CHARS) 10 | -------------------------------------------------------------------------------- /utils/nn/model_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import numpy as np 3 | 4 | 5 | def print_arch(model, model_name="model"): 6 | print(f"| {model_name} Arch: {model}") 7 | num_params(model, model_name=model_name) 8 | 9 | 10 | def num_params(model, print_out=True, model_name="model"): 11 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 12 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 13 | if print_out: 14 | print(f"| {model_name} Trainable Parameters: {parameters:.3f}M") 15 | return parameters 16 | -------------------------------------------------------------------------------- /preprocessor/runs/base_preprocess.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import utils.commons.single_thread_env # NOQA 3 | 4 | import importlib 5 | 6 | from utils.commons.hparams import hparams, set_hparams 7 | 8 | 9 | def preprocess(): 10 | assert hparams["preprocess_cls"] != "" 11 | 12 | pkg = ".".join(hparams["preprocess_cls"].split(".")[:-1]) 13 | cls_name = hparams["preprocess_cls"].split(".")[-1] 14 | process_cls = getattr(importlib.import_module(pkg), cls_name) 15 | process_cls().process() 16 | 17 | 18 | if __name__ == "__main__": 19 | set_hparams() 20 | preprocess() 21 | -------------------------------------------------------------------------------- /preprocessor/runs/base_binarize.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import utils.commons.single_thread_env # NOQA 3 | import importlib 4 | 5 | from utils.commons.hparams import hparams, set_hparams 6 | 7 | 8 | def binarize(): 9 | binarizer_cls = hparams.get("binarizer_cls", "preprocessor.base_binarizer.BaseBinarizer") 10 | pkg = ".".join(binarizer_cls.split(".")[:-1]) 11 | cls_name = binarizer_cls.split(".")[-1] 12 | binarizer_cls = getattr(importlib.import_module(pkg), cls_name) 13 | print("| Binarizer: ", binarizer_cls) 14 | binarizer_cls().process() 15 | 16 | 17 | if __name__ == "__main__": 18 | set_hparams() 19 | binarize() 20 | -------------------------------------------------------------------------------- /preprocessor/wave/base_wave_processor.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | 3 | REGISTERED_WAV_PROCESSORS = {} 4 | 5 | def register_wav_processors(name): 6 | def _f(cls): 7 | REGISTERED_WAV_PROCESSORS[name] = cls 8 | return cls 9 | 10 | return _f 11 | 12 | 13 | def get_wav_processor_cls(name): 14 | return REGISTERED_WAV_PROCESSORS.get(name, None) 15 | 16 | 17 | class BaseWavProcessor: 18 | @property 19 | def name(self): 20 | raise NotImplementedError 21 | 22 | def output_fn(self, input_fn: str): 23 | return f"{input_fn[:-4]}_{self.name}.wav" 24 | 25 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 26 | raise NotImplementedError 27 | 
-------------------------------------------------------------------------------- /tasks/runs/run.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import os 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | import torch 6 | import importlib 7 | 8 | from utils.commons.hparams import hparams, set_hparams 9 | 10 | 11 | def run_task(): 12 | assert hparams['task_cls'] != '' 13 | pkg = ".".join(hparams["task_cls"].split(".")[:-1]) 14 | cls_name = hparams["task_cls"].split(".")[-1] 15 | task_cls = getattr(importlib.import_module(pkg), cls_name) 16 | task_cls.start() 17 | 18 | 19 | if __name__ == '__main__': 20 | if os.environ.get("CUDA_VISIBLE_DEVICES", None) is None: 21 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in range(torch.cuda.device_count())]) 22 | set_hparams() 23 | run_task() 24 | 25 | -------------------------------------------------------------------------------- /utils/audio/io.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import numpy as np 3 | import subprocess 4 | 5 | from scipy.io import wavfile 6 | 7 | 8 | def save_wav(wav, path, sr, norm=False): 9 | if norm: 10 | wav = wav / np.abs(wav).max() 11 | wav = wav * 32767 12 | wavfile.write(path[:-4] + ".wav", sr, wav.astype(np.int16)) 13 | if path[-4:] == ".mp3": 14 | to_mp3(path[:-4]) 15 | 16 | 17 | def to_mp3(out_path): 18 | if out_path[-4:] == ".wav": 19 | out_path = out_path[:-4] 20 | subprocess.check_call( 21 | f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -b:a 192k -y -hide_banner -async 1 "{out_path}.mp3"', 22 | shell=True, stdin=subprocess.PIPE) 23 | subprocess.check_call(f"rm -f '{out_path}.wav'", shell=True) 24 | -------------------------------------------------------------------------------- /models/commons/align_ops.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def build_word_mask(x2word, y2word): 7 | return (x2word[:, :, None] == y2word[:, None, :]).long() 8 | 9 | 10 | def mel2ph_to_mel2word(mel2ph, ph2word): 11 | mel2word = (ph2word - 1).gather(1, (mel2ph - 1).clamp(min=0)) + 1 12 | mel2word = mel2word * (mel2ph > 0).long() 13 | return mel2word 14 | 15 | 16 | def clip_mel2token_to_multiple(mel2token, frames_multiple): 17 | max_frames = mel2token.shape[1] // frames_multiple * frames_multiple 18 | mel2token = mel2token[:, :max_frames] 19 | return mel2token 20 | 21 | 22 | def expand_states(h, mel2token): 23 | h = F.pad(h, [0, 0, 1, 0]) 24 | mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]]) 25 | h = torch.gather(h, 1, mel2token_) 26 | return h # [Batch, T_mels, Hidden] 27 | -------------------------------------------------------------------------------- /utils/os_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import subprocess 3 | 4 | from inspect import isfunction 5 | 6 | 7 | def link_file(from_file, to_file): 8 | subprocess.check_call( 9 | f'ln -s "`realpath --relative-to="{to_file}" "{from_file}"`" "{to_file}"', shell=True) 10 | 11 | 12 | def move_file(from_file, to_file): 13 | """ Move from_file to to_file.
""" 14 | subprocess.check_call(f'mv "{from_file}" "{to_file}"', shell=True) 15 | 16 | 17 | def copy_file(from_file, to_file): 18 | """ Copy from_file to to_file. """ 19 | subprocess.check_call(f'cp -r "{from_file}" "{to_file}"', shell=True) 20 | 21 | 22 | def remove_file(*fns): 23 | """ Remove files from fns. """ 24 | for f in fns: 25 | subprocess.check_call(f'rm -rf "{f}"', shell=True) 26 | 27 | 28 | def default(val, d): 29 | if val is not None: 30 | return val 31 | return d() if isfunction(d) else d 32 | -------------------------------------------------------------------------------- /config/models/base_config.yaml: -------------------------------------------------------------------------------- 1 | # task 2 | binary_data_dir: './data/binary/' 3 | work_dir: './checkpoints/' # experiment directory. 4 | infer: false # infer 5 | amp: false 6 | seed: 1234 7 | debug: false 8 | save_codes: ['tasks', 'models', 'modules'] 9 | 10 | ############# 11 | # dataset 12 | ############# 13 | ds_workers: 1 14 | endless_ds: true 15 | sort_by_len: true 16 | 17 | ######### 18 | # train and eval 19 | ######### 20 | print_nan_grads: false 21 | load_ckpt: '' 22 | save_best: true 23 | num_ckpt_keep: 100 24 | clip_grad_norm: 0 25 | accumulate_grad_batches: 1 26 | tb_log_interval: 100 27 | num_sanity_val_steps: 5 # steps of validation at the beginning 28 | check_val_every_n_epoch: 10 29 | val_check_interval: 500 30 | ckpt_save_interval: 1000 31 | valid_monitor_key: 'val_loss' 32 | valid_monitor_mode: 'min' 33 | max_epochs: 100000 34 | max_updates: 100000 35 | max_tokens: 20000 36 | max_sentences: 64 37 | eval_max_batches: 128 38 | resume_from_checkpoint: 0 39 | rename_tmux: true -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ji 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config/models/base_task.yaml: -------------------------------------------------------------------------------- 1 | # task 2 | base_config: 3 | - ./base_config.yaml 4 | 5 | ############# 6 | # dataset in training 7 | ############# 8 | endless_ds: true 9 | min_frames: 0 10 | max_frames: 2000 11 | frames_multiple: 1 12 | max_input_tokens: 4000 13 | ds_workers: 1 14 | 15 | ######### 16 | # model 17 | ######### 18 | use_spk_id: true 19 | use_spk_embed: false 20 | mel_losses: "ssim:0.5|l1:0.5" 21 | 22 | ########### 23 | # optimization 24 | ########### 25 | lr: 0.0005 26 | scheduler: warmup # rsqrt|warmup|none 27 | warmup_updates: 4000 28 | optimizer_adam_beta1: 0.9 29 | optimizer_adam_beta2: 0.98 30 | weight_decay: 0 31 | clip_grad_norm: 1 32 | clip_grad_value: 0 33 | 34 | 35 | ########### 36 | # train and eval 37 | ########### 38 | use_word_input: false 39 | max_valid_sentences: 1 40 | max_valid_tokens: 60000 41 | valid_infer_interval: 1000 42 | train_set_name: 'train' 43 | train_sets: '' 44 | valid_set_name: 'valid' 45 | test_set_name: 'test' 46 | num_valid_plots: 10 47 | test_ids: [ ] 48 | test_input_yaml: '' 49 | vocoder: HifiGAN 50 | vocoder_ckpt: './checkpoints/pretrain' 51 | vocoder_config: 'hifigan' 52 | profile_infer: false 53 | out_wav_norm: true 54 | save_gt: true 55 | save_f0: false 56 | gen_dir_name: '' -------------------------------------------------------------------------------- /utils/commons/meters.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import time 3 | import torch 4 | 5 | 6 | class AvgrageMeter(object): 7 | """ Calculate average, summation, count for training result. """ 8 | 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.avg = 0 14 | self.sum = 0 15 | self.cnt = 0 16 | 17 | def update(self, val, n=1): 18 | self.sum += val * n 19 | self.cnt += n 20 | self.avg = self.sum / self.cnt 21 | 22 | 23 | class Timer: 24 | timer_map = {} 25 | 26 | def __init__(self, name, enable=False): 27 | if name not in Timer.timer_map: 28 | Timer.timer_map[name] = 0 29 | self.name = name 30 | self.enable = enable 31 | 32 | def __enter__(self): 33 | if self.enable: 34 | if torch.cuda.is_available(): 35 | torch.cuda.synchronize() 36 | self.t = time.time() 37 | 38 | def __exit__(self, exc_type, exc_val, exc_tb): 39 | if self.enable: 40 | if torch.cuda.is_available(): 41 | torch.cuda.synchronize() 42 | Timer.timer_map[self.name] += time.time() - self.t 43 | if self.enable: 44 | print(f"| {Timer} {self.name}: {Timer.timer_map[self.name]}") 45 | -------------------------------------------------------------------------------- /modules/visinger/predictor.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from modules.rel_transformer import RelativeEncoder 5 | 6 | 7 | class PitchPredictor(nn.Module): 8 | """ Pitch predictor for VISinger. 
""" 9 | def __init__(self, in_dim, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels, out_dim=2): 10 | super().__init__() 11 | # Pitch Predictor 12 | self.pitch_predictor = RelativeEncoder(in_dim, filter_channels, n_heads, n_layers=n_layers, 13 | gin_channels=gin_channels, kernel_size=kernel_size, p_dropout=p_dropout) 14 | self.linear = nn.Conv1d(in_dim, out_dim, 1) 15 | 16 | def forward(self, x, x_mask, spk_emb): 17 | x = self.pitch_predictor(x, x_mask, g=spk_emb) 18 | x = self.linear(x).transpose(1, 2) # [Batch, T_len, Out_dim] 19 | return x 20 | 21 | 22 | class PhonemePredictor(nn.Module): 23 | """ Phoneme predictor for VISinger. """ 24 | def __init__(self, dict_size, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout): 25 | super().__init__() 26 | # Phoneme predictor 27 | self.phoneme_predictor = RelativeEncoder(hidden_channels, filter_channels, n_heads, n_layers=n_layers, 28 | kernel_size=kernel_size, p_dropout=p_dropout) 29 | self.ph_proj = nn.Conv1d(hidden_channels, dict_size, 1) 30 | 31 | def forward(self, x, x_mask): 32 | x = self.phoneme_predictor(x, x_mask) 33 | ph_pred = self.ph_proj(x) # [Batch, Dict_size, T_len] 34 | ph_pred = F.log_softmax(ph_pred, dim=1) 35 | return ph_pred 36 | -------------------------------------------------------------------------------- /utils/plot/plot.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import matplotlib 3 | 4 | matplotlib.use('Agg') 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | 10 | LINE_COLORS = ['w', 'r', 'orange', 'k', 'cyan', 'm', 'b', 'lime', 'g', 'brown', 'navy'] 11 | 12 | 13 | def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_info=None): 14 | if isinstance(spec, torch.Tensor): 15 | spec = spec.cpu().numpy() 16 | H = spec.shape[1] // 2 17 | fig = plt.figure(figsize=(12, 6)) 18 | plt.title(title) 19 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 20 | if dur_info is not None: 21 | assert isinstance(dur_info, dict) 22 | dur_gt = dur_info['duration_gt'] 23 | if isinstance(dur_gt, torch.Tensor): 24 | dur_gt = dur_gt.cpu().numpy() 25 | dur_gt = np.cumsum(dur_gt).astype(int) 26 | for i in range(len(dur_gt)): 27 | plt.vlines(dur_gt[i], 0, H // 2, colors='b') # blue is gt 28 | plt.xlim(0, dur_gt[-1]) 29 | if 'durartion_pred' in dur_info: 30 | dur_pred = dur_info['durartion_pred'] 31 | if isinstance(dur_pred, torch.Tensor): 32 | dur_pred = dur_pred.cpu().numpy() 33 | dur_pred = np.cumsum(dur_pred).astype(int) 34 | for i in range(len(dur_pred)): 35 | plt.vlines(dur_pred[i], H, H * 1.5, colors='r') # red is pred 36 | plt.xlim(0, max(dur_gt[-1], dur_pred[-1])) 37 | if f0s is not None: 38 | ax = plt.gca() 39 | ax2 = ax.twinx() 40 | if not isinstance(f0s, dict): 41 | f0s = {'f0': f0s} 42 | for i, (k, f0) in enumerate(f0s.items()): 43 | if isinstance(f0, torch.Tensor): 44 | f0 = f0.cpu().numpy() 45 | ax2.plot(f0, label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.5) 46 | ax2.set_ylim(0, 1250) 47 | ax2.legend() 48 | return fig 49 | -------------------------------------------------------------------------------- /preprocessor/text/base_text_processor.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | from utils.text.text_encoder import is_sil_phoneme 3 | 4 | REGISTERED_TEXT_PROCESSORS = {} 5 | 6 | 7 | def register_text_processors(name): 8 | # Regist text_processors to 
REGISTERED_TEXT_PROCESSORS 9 | def _f(cls): 10 | REGISTERED_TEXT_PROCESSORS[name] = cls 11 | return cls 12 | 13 | return _f 14 | 15 | 16 | def get_text_processor_cls(name): 17 | # Get text_processors from REGISTERED_TEXT_PROCESSORS 18 | return REGISTERED_TEXT_PROCESSORS.get(name, None) 19 | 20 | 21 | class BaseTextProcessor: 22 | @staticmethod 23 | def sp_phonemes(): 24 | return ['|'] 25 | 26 | @classmethod 27 | def process(cls, text, preprocess_args): 28 | raise NotImplementedError 29 | 30 | @classmethod 31 | def postprocess(cls, text_struct, preprocess_args): 32 | # Remove sil_phoneme in head and tail 33 | while len(text_struct) > 0 and is_sil_phoneme(text_struct[0][0]): 34 | text_struct = text_struct[1:] 35 | while len(text_struct) > 0 and (is_sil_phoneme(text_struct[-1][0]) and text_struct[-1][0] not in cls.PUNCS): 36 | text_struct = text_struct[:-1] 37 | if preprocess_args["with_phsep"]: # Add ph to each word 38 | text_struct = cls.add_bdr(text_struct) 39 | if preprocess_args["add_eos_bos"]: # Add EOS and BOS token 40 | text_struct = [["<BOS>", ["<BOS>"]]] + text_struct + [["<EOS>", ["<EOS>"]]] 41 | return text_struct 42 | 43 | @classmethod 44 | def add_bdr(cls, text_struct): 45 | text_struct_ = [] 46 | for i, ts in enumerate(text_struct): 47 | text_struct_.append(ts) 48 | if i != len(text_struct) - 1 and \ 49 | not is_sil_phoneme(text_struct[i][0]) and not is_sil_phoneme(text_struct[i + 1][0]): 50 | text_struct_.append(['|', ['|']]) 51 | return text_struct_ 52 | -------------------------------------------------------------------------------- /config/datasets/svs/csd/preprocess.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import glob 3 | import os 4 | import miditoolkit 5 | import traceback 6 | 7 | from preprocessor.base_preprocessor import BasePreprocessor 8 | 9 | 10 | class CSDPreprocessor(BasePreprocessor): 11 | """ We already split the notes, lyrics, and waveforms before preprocessing. 12 | In this repository, we only use the Korean singing voice data in the CSD dataset.""" 13 | def meta_data(self): 14 | # Get data 15 | base_dir = f"{self.raw_data_dir}" 16 | file_dirs = glob.glob(f"{base_dir}/midi/*.mid") 17 | # Get song information 18 | for dir in file_dirs: 19 | filename = os.path.basename(dir) 20 | item_name = filename.split(".")[0] 21 | spk_name = "csd" 22 | # file directory 23 | wav_fn = f"{base_dir}/wav/{item_name}.wav" 24 | with open(f"{base_dir}/text/{item_name}.txt", "r") as f: 25 | text = f.readline().strip().replace(" ", "") 26 | # Refine midi file 27 | try: 28 | midi_obj = miditoolkit.midi.parser.MidiFile(dir) 29 | midi_obj = self.refine_midi_file(midi_obj, text) 30 | yield {"item_name": item_name, "wav_fn": wav_fn, "midi_fn": dir, "midi_obj": midi_obj, 31 | "text": text, "spk_name": spk_name} 32 | except: 33 | traceback.print_exc() 34 | print(f"| Error is caught.
item_name: {item_name}.") 35 | pass 36 | 37 | @staticmethod 38 | def refine_midi_file(midi_obj, lyrics): 39 | notes = midi_obj.instruments[0].notes 40 | assert len(notes) == len(lyrics), f"| Note: {len(notes)}, lyrics: {len(lyrics)}" 41 | lyric_list = [] 42 | for i, lyr in enumerate(lyrics): 43 | lyric = miditoolkit.Lyric(lyr, notes[i].start) 44 | lyric_list.append(lyric) 45 | 46 | midi_obj.lyrics = lyric_list 47 | return midi_obj 48 | -------------------------------------------------------------------------------- /tasks/utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import importlib 3 | 4 | from preprocessor.base_binarizer import BaseBinarizer 5 | from preprocessor.base_preprocessor import BasePreprocessor 6 | from utils.commons.hparams import hparams 7 | 8 | 9 | def parse_dataset_configs(): 10 | max_tokens = hparams["max_tokens"] 11 | max_sentences = hparams["max_sentences"] 12 | max_valid_tokens = hparams["max_valid_tokens"] 13 | if max_valid_tokens == -1: 14 | hparams["max_valid_tokens"] = max_valid_tokens = max_tokens 15 | max_valid_sentences = hparams["max_valid_sentences"] 16 | if max_valid_sentences == -1: 17 | hparams["max_valid_sentences"] = max_valid_sentences = max_sentences 18 | 19 | return max_tokens, max_sentences, max_valid_tokens, max_valid_sentences 20 | 21 | 22 | def parse_mel_losses(): 23 | mel_losses = hparams['mel_losses'].split("|") 24 | loss_and_lambda = {} 25 | for i, l in enumerate(mel_losses): 26 | if l == '': 27 | continue 28 | if ':' in l: 29 | l, lbd = l.split(":") 30 | lbd = float(lbd) 31 | else: 32 | lbd = 1.0 33 | loss_and_lambda[l] = lbd 34 | print(f"| Mel losses: {loss_and_lambda}") 35 | 36 | return loss_and_lambda 37 | 38 | 39 | def load_data_preprocessor(): 40 | """ Load preprocessor. """ 41 | preprocess_cls = hparams["preprocess_cls"] 42 | pkg = ".".join(preprocess_cls.split(".")[:-1]) 43 | cls_name = preprocess_cls.split(".")[-1] 44 | preprocessor: BasePreprocessor = getattr(importlib.import_module(pkg), cls_name)() 45 | preprocess_args = {} 46 | preprocess_args.update(hparams["preprocess_args"]) 47 | 48 | return preprocessor, preprocess_args 49 | 50 | 51 | def load_data_binarizer(): 52 | """ Load binarizer. 
""" 53 | binarizer_cls = hparams["binarizer_cls"] 54 | pkg = ".".join(binarizer_cls.split(".")[:-1]) 55 | cls_name = binarizer_cls.split(".")[-1] 56 | binarizer: BaseBinarizer = getattr(importlib.import_module(pkg), cls_name)() 57 | binarization_args = {} 58 | binarization_args.update(hparams["binarization_args"]) 59 | 60 | return binarizer, binarization_args 61 | -------------------------------------------------------------------------------- /utils/commons/indexed_datasets.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import numpy as np 3 | import os 4 | import pickle 5 | 6 | from copy import deepcopy 7 | 8 | 9 | class IndexedDataset: 10 | def __init__(self, path: str, num_cache=1): 11 | super().__init__() 12 | self.path = path 13 | self.data_file = None 14 | self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()["offsets"] 15 | self.data_file = open(f"{path}.data", "rb", buffering=-1) 16 | self.cache = [] 17 | self.num_cache = num_cache 18 | 19 | def check_index(self, i: int): 20 | if i < 0 or i > len(self.data_offsets) - 1: 21 | raise IndexError("index out of range") 22 | 23 | def __del__(self): 24 | if self.data_file: 25 | self.data_file.close() 26 | 27 | def __getitem__(self, i): 28 | self.check_index(i) 29 | if self.num_cache > 0: 30 | for c in self.cache: 31 | if c[0] == i: 32 | return c[1] 33 | self.data_file.seek(self.data_offsets[i]) 34 | b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i]) 35 | item = pickle.loads(b) 36 | if self.num_cache > 0: 37 | self.cache = [(i, deepcopy(item))] + self.cache[:-1] 38 | return item 39 | 40 | def __len__(self): 41 | return len(self.data_offsets) - 1 42 | 43 | 44 | class IndexedDatasetBuilder: 45 | def __init__(self, path): 46 | self.path = path 47 | self.read_type = "wb" 48 | if os.path.exists(f"{path}.data"): 49 | self.read_type = "ab" 50 | self.out_file = open(f"{path}.data", self.read_type) 51 | self.byte_offsets = [0] 52 | if os.path.exists(f"{self.path}.idx"): 53 | self.byte_offsets = np.load(f"{path}.idx", allow_pickle=True).item()["offsets"] 54 | 55 | def add_item(self, item): 56 | s = pickle.dumps(item) 57 | bytes = self.out_file.write(s) 58 | self.byte_offsets.append(self.byte_offsets[-1] + bytes) 59 | 60 | def finalize(self): 61 | self.out_file.close() 62 | np.save(open(f"{self.path}.idx", "wb"), {"offsets": self.byte_offsets}) 63 | -------------------------------------------------------------------------------- /tasks/vocoder_infer/base_vocoder.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import librosa 3 | import numpy as np 4 | import torch 5 | 6 | from utils.audio.mel_processing import load_wav_to_torch, MelSpectrogramFixed 7 | from utils.commons.hparams import hparams 8 | 9 | REGISTERED_VOCODERS = {} 10 | 11 | 12 | def register_vocoder(name): 13 | def _f(cls): 14 | REGISTERED_VOCODERS[name] = cls 15 | return cls 16 | 17 | return _f 18 | 19 | 20 | def get_vocoder_cls(vocoder_name): 21 | return REGISTERED_VOCODERS.get(vocoder_name) 22 | 23 | 24 | class BaseVocoder: 25 | def spec2wav(self, mel): 26 | """ 27 | Parameter 28 | --------- 29 | mel: torch.Tensor([T, 80]) 30 | 31 | Return 32 | wav: torch.Tensor([T']) 33 | """ 34 | raise NotImplementedError 35 | 36 | @staticmethod 37 | def wav2spec(wav_fn): 38 | """ 39 | Parameter 40 | --------- 41 | wav_fn: str 42 | 43 | Return 44 | ------ 45 | wav, mel: 
torch.Tensor([T, 80]) 46 | """ 47 | wav, _ = load_wav_to_torch(wav_fn, hop_size=hparams['hop_size']) 48 | mel_fn = MelSpectrogramFixed(sample_rate=hparams["sample_rate"], n_fft=hparams["fft_size"], 49 | win_length=hparams["win_size"], hop_length=hparams["hop_size"], 50 | f_min=hparams["fmin"], f_max=hparams["fmax"], n_mels=hparams["num_mel_bins"], 51 | window_fn=torch.hann_window).to(device=wav.device) 52 | mel = mel_fn(wav) 53 | return wav, mel 54 | 55 | @staticmethod 56 | def wav2mfcc(wav_fn): 57 | fft_size = hparams["audio"]['fft_size'] 58 | hop_size = hparams["audio"]['hop_size'] 59 | win_length = hparams["audio"]['win_size'] 60 | sample_rate = hparams["audio"]['audio_sample_rate'] 61 | wav, _ = librosa.core.load(wav_fn, sr=sample_rate) 62 | mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13, 63 | n_fft=fft_size, hop_length=hop_size, 64 | win_length=win_length, pad_mode="constant", power=1.0) 65 | mfcc_delta = librosa.feature.delta(mfcc, order=1) 66 | mfcc_delta_delta = librosa.feature.delta(mfcc, order=2) 67 | mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T 68 | return mfcc 69 | -------------------------------------------------------------------------------- /utils/audio/pitch/utils.py: -------------------------------------------------------------------------------- 1 | # https://github.com/MoonInTheRiver/DiffSinger 2 | import numpy as np 3 | import torch 4 | 5 | gamma = 0 6 | mcepInput = 3 # 0 for dB, 3 for magnitude 7 | alpha = 0.45 8 | en_floor = 10 ** (-80 / 20) 9 | FFT_SIZE = 2048 10 | 11 | f0_bin = 300 12 | f0_max = 1250.0 13 | f0_min = 50.0 14 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 15 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 16 | 17 | 18 | def f0_to_coarse(f0): 19 | is_torch = isinstance(f0, torch.Tensor) 20 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 21 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 22 | 23 | f0_mel[f0_mel <= 1] = 1 24 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 25 | f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int) 26 | assert f0_coarse.max() < f0_bin and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) 27 | return f0_coarse 28 | 29 | 30 | def norm_f0(f0, uv, pitch_norm='log', f0_mean=6000, f0_std=100, use_uv=False): 31 | is_torch = isinstance(f0, torch.Tensor) 32 | if pitch_norm == 'standard': 33 | f0 = (f0 - f0_mean) / f0_std 34 | if pitch_norm == 'log': 35 | f0 = f0 + 1 36 | f0 = torch.log2(f0) if is_torch else np.log2(f0) 37 | if uv is not None and use_uv: 38 | f0[uv > 0] = 0 39 | return f0 40 | 41 | 42 | def norm_interp_f0(f0): 43 | is_torch = isinstance(f0, torch.Tensor) 44 | if is_torch: 45 | device = f0.device 46 | f0 = f0.data.cpu().numpy() 47 | uv = f0 == 0 48 | f0 = norm_f0(f0, uv) 49 | if sum(uv) == len(f0): 50 | f0[uv] = 0 51 | elif sum(uv) > 0: 52 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 53 | uv = torch.FloatTensor(uv) 54 | f0 = torch.FloatTensor(f0) 55 | if is_torch: 56 | f0 = f0.to(device) 57 | return f0, uv 58 | 59 | 60 | def denorm_f0(f0, uv, pitch_norm='log', pitch_padding=None, f0_mean=6000, f0_std=100, min=f0_min, max=f0_max, use_uv=False): 61 | if pitch_norm == 'standard': 62 | f0 = f0 * f0_std + f0_mean 63 | if pitch_norm == 'log': 64 | f0 = 2 ** f0 65 | f0 = f0 - 1 66 | if min is not None: 67 | f0 = f0.clamp(min=min) 68 | if max is not None: 69 | f0 = f0.clamp(max=max) 70 | if uv is not None and use_uv: 71 | f0[uv > 0] = 0 72 | if pitch_padding
is not None: 73 | f0[pitch_padding] = 0 74 | return f0 -------------------------------------------------------------------------------- /utils/audio/pitch_extractors.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import numpy as np 3 | 4 | PITCH_EXTRACTOR = {} 5 | 6 | 7 | def register_pitch_extractor(name): 8 | def register_pitch_extractor_(cls): 9 | PITCH_EXTRACTOR[name] = cls 10 | return cls 11 | 12 | return register_pitch_extractor_ 13 | 14 | 15 | def get_pitch_extractor(name): 16 | return PITCH_EXTRACTOR[name] 17 | 18 | 19 | def extract_pitch_simple(wav): 20 | from utils.commons.hparams import hparams 21 | return extract_pitch(hparams["pitch_extractor"], 22 | wav, 23 | hparams["hop_size"], 24 | hparams["sample_rate"], 25 | f0_resolution=hparams.get("f0_resolution", 1), 26 | f0_min=hparams["f0_min"], 27 | f0_max=hparams["f0_max"]) 28 | 29 | 30 | def extract_pitch(extractor_name, wav_data: np.array, hop_size: int, sample_rate: int, 31 | f0_resolution=1, f0_min=50, f0_max=1250, **kwargs): 32 | return get_pitch_extractor(extractor_name)( 33 | wav_data, hop_size, sample_rate, f0_resolution, f0_min, f0_max, **kwargs) 34 | 35 | 36 | @register_pitch_extractor("parselmouth") 37 | def parselmouth_pitch(wav_data: np.array, hop_size: int, sample_rate: int, f0_resolution: int, 38 | f0_min: int, f0_max: int, voicing_threshold=0.6, *args, **kwargs): 39 | import parselmouth 40 | time_step = (hop_size // f0_resolution) / sample_rate * 1000 41 | n_mel_frames = int(len(wav_data) // (hop_size // f0_resolution)) 42 | f0_pm = parselmouth.Sound(wav_data, sample_rate).to_pitch_ac( 43 | time_step=time_step / 1000, 44 | voicing_threshold=voicing_threshold, 45 | pitch_floor=f0_min, 46 | pitch_ceiling=f0_max).selected_array["frequency"] 47 | pad_size = (n_mel_frames - len(f0_pm) + 1) // 2 48 | f0 = np.pad(f0_pm, [[pad_size, n_mel_frames - len(f0_pm) - pad_size]], mode="constant") 49 | 50 | return f0 51 | 52 | 53 | @register_pitch_extractor("pyworld") 54 | def compute_f0(wav_data, hop_size, sample_rate, f0_resolution=1, f0_min=0.0, f0_max=8000, 55 | voicing_threshold=0.6, *args, **kwargs): 56 | import pyworld as pw 57 | time_step = (hop_size // f0_resolution) / sample_rate * 1000 58 | f0, t = pw.dio(wav_data.astype(np.double), 59 | fs=sample_rate, 60 | f0_ceil=f0_max, 61 | frame_period=time_step) 62 | f0 = pw.stonemask(wav_data.astype(np.double), f0, t, sample_rate) 63 | f0 = f0[:len(wav_data)//(hop_size // f0_resolution)] 64 | f0 = np.maximum(f0, 1) 65 | f0 = f0.astype(np.float32) 66 | return f0 67 | -------------------------------------------------------------------------------- /utils/commons/ckpt_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import glob 3 | import os 4 | import re 5 | import torch 6 | 7 | 8 | def get_last_checkpoint(work_dir, steps=None): 9 | checkpoint = None 10 | last_ckpt_path = None 11 | ckpt_paths = get_all_ckpts(work_dir, steps) 12 | if len(ckpt_paths) > 0: 13 | last_ckpt_path = ckpt_paths[0] 14 | checkpoint = torch.load(last_ckpt_path, map_location="cpu") 15 | return checkpoint, last_ckpt_path 16 | 17 | 18 | def get_all_ckpts(work_dir, steps=None): 19 | if steps is None: 20 | ckpt_path_pattern = f"{work_dir}/model_ckpt_steps_*.ckpt" 21 | else: 22 | ckpt_path_pattern = f"{work_dir}/model_ckpt_steps_{steps}.ckpt" 23 | # Return the file name in order of steps 24 | return 
sorted(glob.glob(ckpt_path_pattern), 25 | key=lambda x: -int(re.findall(".*steps\_(\d+)\.ckpt", x)[0])) 26 | 27 | 28 | def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True): 29 | if os.path.isfile(ckpt_base_dir): 30 | base_dir = os.path.dirname(ckpt_base_dir) 31 | ckpt_path = ckpt_base_dir 32 | checkpoint = torch.load(ckpt_base_dir, map_location='cpu') 33 | else: 34 | base_dir = ckpt_base_dir 35 | checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir) 36 | if checkpoint is not None: 37 | state_dict = checkpoint["state_dict"] 38 | if "." in model_name: 39 | state_dict = state_dict[model_name.split(".")[0]] 40 | if len([k for k in state_dict.keys() if "." in k]) > 0: 41 | state_dict = {k[len(model_name.split(".")[1]) + 1:]: v for k, v in state_dict.items() 42 | if k.startswith(f"{model_name.split('.')[1]}.")} 43 | else: 44 | state_dict = state_dict[model_name] 45 | if not strict: 46 | cur_model_state_dict = cur_model.state_dict() # 현재 모델 47 | unmatched_keys = [] 48 | for key, param in state_dict.items(): # 저장된 모델 49 | if key in cur_model_state_dict: # 현재 모델 50 | new_param = cur_model_state_dict[key] 51 | if new_param.shape != param.shape: 52 | unmatched_keys.append(key) 53 | print("| Unmatched keys: ", key, new_param.shape, param.shape) 54 | for key in unmatched_keys: 55 | del state_dict[key] 56 | cur_model.load_state_dict(state_dict, strict=strict) 57 | print(f"| load '{model_name}' from '{ckpt_path}'.") 58 | else: 59 | e_msg = f"| ckpt not found in {base_dir}." 60 | if force: 61 | assert False, e_msg 62 | else: 63 | print(e_msg) 64 | -------------------------------------------------------------------------------- /config/models/visinger.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - ./base_task.yaml 3 | - ../datasets/svs/csd/preprocess.yaml 4 | task_cls: tasks.visinger.VISingerTask 5 | ################################################ 6 | # Model 7 | ################################################ 8 | hidden_size: 192 9 | p_dropout: 0.1 10 | encoder_type: rel_fft # fft|ffn|rel_fft 11 | 12 | # FFT encoder 13 | enc_layers: 6 14 | ffn_kernel_size: 9 15 | ffn_filter_channels: 768 # hidden_size * 4 16 | enc_prenet: true 17 | enc_pre_ln: true 18 | num_heads: 2 19 | ffn_act: gelu 20 | use_pos_embed: true 21 | 22 | # Waveform Decoder 23 | dec_blocks: "1" 24 | dec_kernel_size: [3,7,11] 25 | dec_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 26 | upsample_rates: [5,5,3,2,2] # for compute 300 hop-size 27 | initial_upsample_channels: 512 28 | upsample_kernel_sizes: [11,11,7,4,4] # for compute 300 hop-size 29 | gin_channels: 256 30 | 31 | # Prior encoder 32 | use_pitch_encoder: true 33 | frame_prior_layers: 4 34 | 35 | # Pitch predictor 36 | use_pitch_embed: true 37 | pitch_predictor_layers: 6 38 | pitch_predictor_kernel_size: 9 39 | pitch_type: frame 40 | 41 | # Phoneme predictor 42 | use_phoneme_pred: true 43 | phoneme_predictor_layers: 2 44 | 45 | # Discriminator 46 | use_spectral_norm: false 47 | disc_win_num: 3 48 | mel_disc_hidden_size: 256 49 | disc_norm: in 50 | disc_reduction: stack 51 | disc_interval: 1 52 | disc_start_steps: 0 53 | 54 | # mel loss 55 | mel_losses: l1:45.0 56 | 57 | # Loss lambda 58 | lambda_pitch: 10.0 59 | lambda_ctc: 45.0 60 | lambda_mel_adv: 1.0 61 | lambda_kl: 1.0 62 | lambda_fm: 2.0 63 | kl_start_steps: 1 64 | kl_min: 0.0 65 | posterior_start_steps: 0 66 | predictor_grad: 0.1 67 | 68 | ################################################ 69 | # Optimization 70 | 
################################################ 71 | optimizer: AdamW 72 | lr: 0.0002 73 | scheduler: steplr 74 | optimizer_adam_beta1: 0.8 75 | optimizer_adam_beta2: 0.99 76 | eps: 1.0e-9 77 | generator_scheduler_params: 78 | gamma: 0.999875 79 | discriminator_scheduler_params: 80 | gamma: 0.999875 81 | discriminator_optimizer_params: 82 | eps: 1.0e-09 83 | weight_decay: 0.0 84 | weight_decay: 0.001 85 | clip_grad_norm: 1 86 | clip_grad_value: 0 87 | 88 | ################################################ 89 | # Train and evaluate 90 | ################################################ 91 | use_pesq: true 92 | segment_size: 32 93 | max_frames: 1280 # max sequence sizes 94 | max_sentences: 4 # max batch size ( 16 * 4 = 64 ) 95 | max_updates: 600000 96 | max_tokens: 60000 97 | tb_log_interval: 100 98 | val_check_interval: 1000 99 | ckpt_save_interval: 1000 100 | eval_max_batches: 50 101 | 102 | #################### 103 | # Datasets 104 | #################### 105 | ds_workers: 0 106 | endless_ds: false # If want to use exponentialLR with decay, should be `false` here -------------------------------------------------------------------------------- /config/datasets/svs/csd/preprocess.yaml: -------------------------------------------------------------------------------- 1 | # | valid total files: 50, total duration: 324.262s, max duration: 14.600s 2 | # | test total files: 50, total duration: 311.587s, max duration: 13.012s 3 | # | train total files: 1123, total duration: 7010.387s, max duration: 19.038s 4 | speaker: csd 5 | 6 | num_mel_bins: 128 7 | num_linear_bins: 1025 8 | raw_sample_rate: 48000 9 | sample_rate: 24000 10 | max_wav_value: 32768.0 11 | hop_size: 300 # For 24000Hz, 300 ~= 12.5 ms (0.0125 * sample_rate) 12 | win_size: 1200 # For 24000Hz, 1200 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) 13 | fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter 14 | fmin: 20 15 | fmax: 12000 # To be increased/reduced depending on data. 16 | f0_min: 50 17 | f0_max: 1250 18 | griffin_lim_iters: 30 19 | pitch_extractor: parselmouth 20 | num_spk: 1 21 | mel_vmin: -7 22 | mel_vmax: 12 23 | loud_norm: false 24 | 25 | # Prerprocess arguments 26 | raw_data_dir: './data/source/svs/csd' 27 | processed_data_dir: './data/preprocessed/svs/csd' 28 | binary_data_dir: './data/binarize/svs/csd' 29 | preprocess_cls: config.datasets.svs.csd.preprocess.CSDPreprocessor 30 | binarizer_cls: preprocessor.base_binarizer.BaseBinarizer 31 | preprocess_args: 32 | nsample_per_mfa_group: 1000 33 | # text process 34 | use_text: true 35 | text_processor: ko_sing 36 | use_mfa: false 37 | with_phsep: true 38 | reset_phone_dict: false 39 | reset_word_dict: true 40 | reset_spk_dict: true 41 | add_eos_bos: true 42 | # data-specific process 43 | use_midi: true 44 | divided: true 45 | DEFAULT_TEMPO: 120 46 | pos_resolution: 16 # per beat (quarter note) 47 | max_durations: 8 # 2 ** 8 * beat 48 | max_ts_denominator: 6 # x/1 x/2 x/4 ... x/64 49 | max_notes_per_bar: 2 # 1/64 ... 
128/64 50 | max_note_dur: 5.0 51 | beat_note_factor: 4 # In MIDI format a note is always 4 beats 52 | filter_symbolic: false 53 | filter_symbolic_ppl: 16 54 | melody: true 55 | max_bar: 2 56 | melody_num: 0 57 | min_sil_dur: 8 # 64th note * min_sir_dur (now 8th note) 58 | num_frame: 3 59 | # mfa 60 | mfa_group_shuffle: false 61 | mfa_offset: 0.02 62 | # wav processors 63 | wav_processors: [sox_resample] 64 | save_sil_mask: true 65 | vad_max_silence_length: 12 66 | binarization_args: 67 | shuffle: true 68 | # text settings 69 | min_text: 6 70 | # note settings 71 | max_durations: 8 # 2 ** 8 * beat 72 | pos_resolution: 16 # per beat (quarter note) 73 | tempo_range: [16, 256] 74 | # wav process 75 | with_wav: false 76 | with_midi_align: true 77 | with_mfa_align: false 78 | with_spk_embed: false 79 | with_f0: true 80 | with_f0cwt: false 81 | with_spk_f0_norm: false 82 | with_linear: false 83 | with_mel: false 84 | trim_eos_bos: false 85 | # dataset range settings 86 | dataset_range: 'index' # index|title 87 | train_range: [ 100, -1 ] 88 | test_range: [ 0, 50 ] 89 | valid_range: [ 50, 100 ] 90 | pitch_key: pitch 91 | note_range: [ 12, 128 ] -------------------------------------------------------------------------------- /utils/audio/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | 4 | from torchaudio.transforms import MelSpectrogram, Spectrogram 5 | 6 | 7 | def load_wav_to_torch(full_path, hop_size=0, slice_train=False): 8 | wav, sampling_rate = torchaudio.load(full_path, normalize=True) 9 | if not slice_train: 10 | p = (wav.shape[-1] // hop_size + 1) * hop_size - wav.shape[-1] 11 | wav = torch.nn.functional.pad(wav, (0, p), mode="constant").data 12 | return wav.squeeze(0), sampling_rate 13 | 14 | 15 | class SpectrogramFixed(torch.nn.Module): 16 | """In order to remove padding of torchaudio package + add log10 scale.""" 17 | 18 | def __init__(self, **kwargs): 19 | super(SpectrogramFixed, self).__init__() 20 | self.torchaudio_backend = Spectrogram(**kwargs) 21 | 22 | def forward(self, x): 23 | outputs = self.torchaudio_backend(x) 24 | 25 | return outputs[..., :-1] 26 | 27 | 28 | class MelSpectrogramFixed(torch.nn.Module): 29 | """In order to remove padding of torchaudio package + add log10 scale.""" 30 | 31 | def __init__(self, **kwargs): 32 | super(MelSpectrogramFixed, self).__init__() 33 | self.torchaudio_backend = MelSpectrogram(**kwargs) 34 | 35 | def forward(self, x): 36 | outputs = torch.log(self.torchaudio_backend(x) + 0.001) 37 | 38 | return outputs[..., :-1] 39 | 40 | 41 | def torch_wav2spec(wav_fn, fft_size, hop_size, win_length, num_mels, fmin, fmax, sample_rate): 42 | """ Waveform to linear-spectrogram and mel-sepctrogram. 
""" 43 | # Read wavform 44 | wav, sr = load_wav_to_torch(wav_fn, hop_size, slice_train=False) 45 | if sr != sample_rate: 46 | raise ValueError(f"{sr} SR doesn't match target {sample_rate} SR") 47 | if torch.min(wav) < -1.: 48 | print('min value is ', torch.min(wav)) 49 | if torch.max(wav) > 1.: 50 | print('max value is ', torch.max(wav)) 51 | # Spectrogram process 52 | spec_fn = SpectrogramFixed(n_fft=fft_size, win_length=win_length, hop_length=hop_size, 53 | window_fn=torch.hann_window).to(device=wav.device) 54 | spec = spec_fn(wav) 55 | # Mel-spectrogram 56 | mel_fn = MelSpectrogramFixed(sample_rate=sample_rate, n_fft=fft_size, win_length=win_length, 57 | hop_length=hop_size, f_min=fmin, f_max=fmax, n_mels=num_mels, 58 | window_fn=torch.hann_window).to(device=wav.device) 59 | mel = mel_fn(wav) 60 | # Wav-processing 61 | wav = wav.squeeze(0)[:mel.shape[-1]*hop_size] 62 | # Check wav and spectorgram 63 | assert wav.shape[-1] == mel.shape[-1] * hop_size, f"| wav: {wav.shape}, spec: {spec.shape}, mel: {mel.shape}" 64 | assert mel.shape[-1] == spec.shape[-1], f"| wav: {wav.shape}, spec: {spec.shape}, mel: {mel.shape}" 65 | return {"wav": wav.cpu().detach().numpy(), "linear": spec.squeeze(0).T.cpu().detach().numpy(), 66 | "mel": mel.squeeze(0).T.cpu().detach().numpy()} 67 | -------------------------------------------------------------------------------- /modules/discriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.nn.utils import weight_norm, spectral_norm 7 | 8 | from modules.commons.utils import get_padding 9 | 10 | LRELU_SLOPE = 0.1 11 | 12 | 13 | class DiscriminatorP(nn.Module): 14 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 15 | super(DiscriminatorP, self).__init__() 16 | self.period = period 17 | self.use_spectral_norm = use_spectral_norm 18 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 19 | self.convs = nn.ModuleList([ 20 | norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 21 | norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 22 | norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 23 | norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 24 | norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 25 | ]) 26 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 27 | 28 | def forward(self, x): 29 | fmap = [] 30 | 31 | # 1d to 2d 32 | b, c, t = x.shape 33 | if t % self.period != 0: # pad first 34 | n_pad = self.period - (t % self.period) 35 | x = F.pad(x, (0, n_pad), "reflect") 36 | t = t + n_pad 37 | x = x.view(b, c, t // self.period, self.period) 38 | 39 | for l in self.convs: 40 | x = l(x) 41 | x = F.leaky_relu(x, LRELU_SLOPE) 42 | fmap.append(x) 43 | x = self.conv_post(x) 44 | fmap.append(x) 45 | x = torch.flatten(x, 1, -1) 46 | 47 | return x, fmap 48 | 49 | 50 | class DiscriminatorS(torch.nn.Module): 51 | def __init__(self, use_spectral_norm=False): 52 | super(DiscriminatorS, self).__init__() 53 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 54 | self.convs = nn.ModuleList([ 55 | norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)), 56 | norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, 
padding=20)), 57 | norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), 58 | norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 59 | norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 60 | norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), 61 | ]) 62 | self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) 63 | 64 | def forward(self, x): 65 | fmap = [] 66 | 67 | for l in self.convs: 68 | x = l(x) 69 | x = F.leaky_relu(x, LRELU_SLOPE) 70 | fmap.append(x) 71 | x = self.conv_post(x) 72 | fmap.append(x) 73 | x = torch.flatten(x, 1, -1) 74 | 75 | return x, fmap 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unofficial implementation of Korean VISinger 2 | 3 | VISinger: Variational Inference with Adversarial Learning for End-to-End Singing Voice Synthesis [[paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9747664)] 4 | 5 | ## Overview 6 | This repository contains a PyTorch implementation of the Korean VISinger architecture, along with examples. Feel free to use/modify the code. 7 | 8 | ![Architecture of VISinger](assets/architecture.png) 9 | 10 | *Architecture of VISinger* 11 |
12 | 13 | ## Install Dependencies 14 | ``` 15 | ## We tested on Linux/Ubuntu 20.04. 16 | ## Install Python 3.8+ first (Anaconda recommended). 17 | 18 | export PYTHONPATH=. 19 | # build a virtual env (recommended). 20 | conda create -n venv python=3.8 21 | conda activate venv 22 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 23 | pip install -r requirements.txt 24 | sudo apt install -y sox libsox-fmt-mp3 25 | ``` 26 | 27 | ## Training 28 | ### 1. Datasets 29 | The supported datasets are: 30 | - [CSD](https://program.ismir2020.net/static/lbd/ISMIR2020-LBD-435-abstract.pdf): a single-singer Korean dataset containing 2.12 hours of audio in total. 31 | 32 | ### 2. Preprocessing 33 | Run base_preprocess.py for preprocessing. 34 | ``` 35 | python preprocessor/runs/base_preprocess.py --config config/datasets/svs/csd/preprocess.yaml 36 | ``` 37 | After that, run base_binarize.py to binarize the data for training. 38 | ``` 39 | python preprocessor/runs/base_binarize.py --config config/datasets/svs/csd/preprocess.yaml 40 | ``` 41 | 42 | ### 3. Training 43 | Train the model with 44 | ``` 45 | CUDA_VISIBLE_DEVICES=0 python tasks/runs/run.py --config config/models/visinger.yaml --exp_name "[dir]/[folder_name]" 46 | ``` 47 | 48 | ### 4. Inference 49 | You have to download the [pretrained models]() (will be uploaded) and put them in `./checkpoints/svs/visinger`. You also have to prepare MIDI data that contains lyrics with the same number of notes. We uploaded a sample file in `./data/source/svs/new_midi/` (will be uploaded). 50 | You can synthesize a new singing voice with 51 | ``` 52 | python inference/visinger.py 53 | ``` 54 | Please set the file path of the MIDI data in `./inference/visinger.py`. 55 | 56 | ## Note 57 | - Korean singing voice synthesis (SVS) does not require duration prediction. We simply split each syllable into three components: `onset`, `nucleus`, and `coda`. Singing voices have long vowel durations, and the `nucleus` of a Korean syllable corresponds to the vowel. In this repository, we assign `onset` and `coda` a maximum of three frames each and assign the remaining frames to the `nucleus`.
58 | - We will upload the checkpoints of VISinger trained on the CSD dataset (to be uploaded after March 2023). 59 | 60 | ## Acknowledgments 61 | Our code is influenced by the following repos: 62 | - [NATSpeech](https://github.com/NATSpeech/NATSpeech) 63 | - [HiFi-GAN](https://github.com/jik876/hifi-gan) 64 | - [VITS](https://github.com/jaywalnut310/vits) 65 | - [BigVGAN unofficial implementation](https://github.com/sh-lee-prml/BigVGAN) 66 | -------------------------------------------------------------------------------- /preprocessor/wave/common_processor.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import librosa 3 | import numpy as np 4 | import os 5 | import subprocess 6 | 7 | from preprocessor.wave.base_wave_processor import BaseWavProcessor, register_wav_processors 8 | from utils.audio import trim_long_silences 9 | from utils.audio.io import save_wav 10 | 11 | 12 | @register_wav_processors(name="sox_to_wav") 13 | class ConvertToWavProcessor(BaseWavProcessor): 14 | @property 15 | def name(self): 16 | return "ToWav" 17 | 18 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 19 | if input_fn[-4:] == ".wav": 20 | return input_fn, sr 21 | else: 22 | output_fn = self.output_fn(input_fn) 23 | subprocess.check_call( 24 | f"sox -v 0.95 '{input_fn}' -t wav '{output_fn}'", 25 | shell=True) 26 | return output_fn, sr 27 | 28 | 29 | @register_wav_processors(name="sox_resample") 30 | class ResampleProcessor(BaseWavProcessor): 31 | @property 32 | def name(self): 33 | return "Resample" 34 | 35 | def process(self, input_fn, tgt_sr, tmp_dir, processed_dir, item_name): 36 | output_fn = self.output_fn(input_fn) 37 | sr = librosa.core.get_samplerate(input_fn) 38 | if tgt_sr != sr: 39 | try: 40 | subprocess.check_call(f"sox '{input_fn}' -r {tgt_sr} '{output_fn}'", 41 | shell=True) 42 | y, _ = librosa.core.load(input_fn, sr=tgt_sr) 43 | save_wav(y, output_fn, tgt_sr, norm=True) 44 | return input_fn, sr, output_fn 45 | except: 46 | return None 47 | else: 48 | return input_fn, sr 49 | 50 | 51 | @register_wav_processors(name="trim_sil") 52 | class TrimSILProcessor(BaseWavProcessor): 53 | @property 54 | def name(self): 55 | return "TrimSIL" 56 | 57 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, audio_args): 58 | output_fn = self.output_fn(input_fn) 59 | sr = librosa.core.get_samplerate(input_fn) 60 | y, _ = librosa.core.load(input_fn, sr=sr) 61 | y, _ = librosa.effects.trim(y) 62 | save_wav(y, output_fn, sr) 63 | return input_fn, sr, output_fn 64 | 65 | 66 | @register_wav_processors(name="trim_all_sil") 67 | class TrimALLSILProcessor(BaseWavProcessor): 68 | @property 69 | def name(self): 70 | return "TrimALLSIL" 71 | 72 | def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): 73 | output_fn = self.output_fn(input_fn) 74 | y, audio_mask, _ = trim_long_silences( 75 | input_fn, 76 | vad_max_silence_length=preprocess_args.get("vad_max_silence_length", 12)) 77 | save_wav(y, output_fn, sr) 78 | if preprocess_args["save_sil_mask"]: 79 | os.makedirs(f"{processed_dir}/sil_mask", exist_ok=True) 80 | np.save(f"{processed_dir}/sil_mask/{item_name}.npy", audio_mask) 81 | return output_fn, sr 82 | -------------------------------------------------------------------------------- /modules/commons/utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import
einops 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class LayerNorm(nn.LayerNorm): 8 | """ 9 | Layer normalization module. 10 | 11 | Parameters 12 | ---------- 13 | nout: int 14 | output dimension size 15 | dim: int 16 | dimension to be normalized 17 | """ 18 | def __init__(self, nout, dim=-1, eps=1e-5): 19 | """ Construct a layernorm object. """ 20 | super(LayerNorm, self).__init__(nout, eps=eps) 21 | self.dim = dim 22 | 23 | def forward(self, x): 24 | """ 25 | Apply layer normalization. 26 | 27 | Parameter 28 | --------- 29 | x: torch.Tensor 30 | input tensor 31 | 32 | Returns 33 | ------- 34 | x: torch.Tensor 35 | layer normalized tensor 36 | """ 37 | if self.dim == -1: 38 | return super(LayerNorm, self).forward(x) 39 | return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) 40 | 41 | 42 | class ConvLayerNorm(nn.LayerNorm): 43 | def __init__(self, normalized_shape, **kwargs): 44 | super().__init__(normalized_shape, **kwargs) 45 | 46 | def forward(self, x): 47 | x = einops.rearrange(x, 'b ... t -> b t ...') 48 | x = super().forward(x) 49 | x = einops.rearrange(x, 'b t ... -> b ... t') 50 | return x 51 | 52 | 53 | class Reshape(nn.Module): 54 | def __init__(self, *args): 55 | super(Reshape, self).__init__() 56 | self.shape = args 57 | 58 | def forward(self, x): 59 | return x.view(self.shape) 60 | 61 | 62 | class Permute(nn.Module): 63 | def __init__(self, *args): 64 | super(Permute, self).__init__() 65 | self.args = args 66 | 67 | def forward(self, x): 68 | return x.permute(self) 69 | 70 | 71 | def Embedding(num_embeddings, embedding_dim, padding_idx=None): 72 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 73 | nn.init.normal_(m.weight, mean=0.0, std=embedding_dim ** -0.5) 74 | if padding_idx is not None: 75 | nn.init.constant_(m.weight[padding_idx], 0) 76 | return m 77 | 78 | 79 | def sequence_mask(length, max_length=None): 80 | if max_length is None: 81 | max_length = length.max() 82 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 83 | return x.unsqueeze(0) < length.unsqueeze(1) 84 | 85 | 86 | def slice_segments(x, ids_str, segment_size=4): 87 | ret = torch.zeros_like(x[:, :, :segment_size]) 88 | for i in range(x.size(0)): 89 | idx_str = ids_str[i] 90 | idx_end = idx_str + segment_size 91 | ret[i] = x[i, :, idx_str:idx_end] 92 | return ret 93 | 94 | 95 | def rand_slice_segments(x, segment_size=4): 96 | batch, _, t_len = x.size() # [Batch, Hidden, T_len] 97 | ids_str_max = t_len - segment_size + 1 98 | ids_str = (torch.rand([batch]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 99 | ret = slice_segments(x, ids_str, segment_size) 100 | return ret, ids_str 101 | 102 | 103 | def init_weights(m, mean=0.0, std=0.01): 104 | classname = m.__class__.__name__ 105 | if classname.find("Conv") != -1: 106 | m.weight.data.normal_(mean, std) 107 | 108 | 109 | def get_padding(kernel_size, dilation=1): 110 | return int((kernel_size*dilation - dilation)/2) 111 | -------------------------------------------------------------------------------- /utils/nn/seq_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from collections import defaultdict 6 | 7 | 8 | def make_positions(tensor, padding_idx): 9 | """ Replace non-padding symbols with their position numbers. 10 | 11 | Position numbers begin at padding_idx=1. Padding symbols are ignored. 
12 | """ 13 | # The series of casts and type-conversions here are carefully balanced 14 | # to both work with ONNX export and XLA. 15 | # In particular XLA prefers ints, cumsum defaults to output longs, and 16 | # ONNX doesn't know how to handle the dtype kwarg in cumsum. 17 | mask = tensor.ne(padding_idx).int() 18 | return (torch.cumsum(mask, dim=1).type_as(mask) *mask).long() + padding_idx 19 | 20 | 21 | def softmax(x, dim): 22 | return F.softmax(x, dim=dim, dtype=torch.float32) 23 | 24 | 25 | def sequence_mask(lengths, maxlen, dtype=torch.bool): 26 | if maxlen is None: 27 | maxlen = lengths.max() 28 | mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t() 29 | mask.type(dtype) 30 | return mask 31 | 32 | 33 | def weights_nonzero_speech(target): 34 | # target: [B, T, mel] 35 | # Assign weight 1.0 to all labels for padding (id=0). 36 | dim = target.size(-1) 37 | return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim) 38 | 39 | INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) 40 | 41 | 42 | def _get_full_incremental_state_key(module_instance, key): 43 | module_name = module_instance.__class__.__name__ 44 | 45 | # Assign a unique ID to each module instance, so that incremental state 46 | # is not shared across module instance 47 | if not hasattr(module_instance, "_instance_id"): 48 | INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1 49 | module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name] 50 | 51 | return f"{module_name}.{module_instance._instance_id}.{key}" 52 | 53 | 54 | def get_incremental_state(module, incremental_state, key): 55 | """ Helper for getting incremental state for an nn.Module. """ 56 | full_key = _get_full_incremental_state_key(module, key) 57 | if incremental_state is None or full_key not in incremental_state: 58 | return None 59 | 60 | return incremental_state[full_key] 61 | 62 | 63 | def set_incremental_state(module, incremental_state, key, value): 64 | """ Helper for setting incremental state for an nn.Module. 
""" 65 | if incremental_state is not None: 66 | full_key = _get_full_incremental_state_key(module, key) 67 | incremental_state[full_key] = value 68 | 69 | 70 | def group_hidden_by_segs(h, seg_ids, max_len): 71 | """ 72 | Parameters 73 | ---------- 74 | h: torch.Tensor([Batch, T_len, Hidden]) 75 | seg_ids: torch.Tensor([Batch, T_len]) 76 | max_len: int 77 | 78 | Return 79 | ------- 80 | h_ph: torch.Tensor([Batch, T_phoneme, Hidden]) 81 | """ 82 | B, T, H = h.shape 83 | h_grouby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h) 84 | all_ones = h.new_ones(h.shape[:2]) 85 | contigous_groupby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous() 86 | h_grouby_segs = h_grouby_segs[:, 1:] 87 | contigous_groupby_segs = contigous_groupby_segs[:, 1:] 88 | h_grouby_segs = h_grouby_segs / torch.clamp(contigous_groupby_segs[:, :, None], min=1) 89 | 90 | return h_grouby_segs, contigous_groupby_segs 91 | -------------------------------------------------------------------------------- /utils/audio/vad.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import librosa 3 | import numpy as np 4 | import pyloudnorm as pyln 5 | import struct 6 | import warnings 7 | import webrtcvad 8 | 9 | from scipy.ndimage.morphology import binary_dilation 10 | from skimage.transform import resize 11 | 12 | warnings.filterwarnings("ignore", message="Possible clipped samples in output") 13 | 14 | int16_max = (2 ** 15) - 1 15 | 16 | 17 | def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, 18 | vad_max_silence_length=12): 19 | """ Ensure that segmetns without voice in the waveform remain no longer than a 20 | threshold determined by the VAD parameters 21 | 22 | Parameters 23 | ---------- 24 | path: str 25 | Path of the raw waveform 26 | 27 | return_raw_wav: boolean 28 | Wheter return raw waveform data 29 | 30 | vad_max_silence_length: int 31 | Maximum number of consecutive silent frames a segment can have. 32 | 33 | Return 34 | ------ 35 | wav: np.array 36 | the same waveform with silences trimmed away (length <= original wav length) 37 | """ 38 | ## Voice Activation Detaction 39 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 40 | # This sets the granularity of the VAD. Should not need to be changed. 41 | sampling_rate = 16000 42 | wav_raw, sr = librosa.core.load(path, sr=sr) 43 | 44 | if norm: 45 | meter = pyln.Meter(sr) # Create BS.1770 meter 46 | loudness = meter.integrated_loudness(wav_raw) 47 | wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0) 48 | if np.abs(wav_raw).max() > 1.0: 49 | wav_raw = wav_raw / np.abs(wav_raw).max() 50 | 51 | wav = librosa.resample(wav_raw, sr, sampling_rate, res_type="kaiser_best") 52 | 53 | vad_window_length = 10 # Milliseconds domain 54 | # Number of frames to average together when performing the moving average smoothing. 55 | # The larger this value, the larger the VAD variations must be to no get smoothed out. 
56 | vad_moving_average_width = 8
57 |
58 | # Compute the voice detection window size
59 | samples_per_window = (vad_window_length * sampling_rate) // 1000
60 |
61 | # Trim the end of the audio to have a multiple of the window size
62 | wav = wav[:len(wav) - (len(wav) % samples_per_window)]
63 |
64 | # Convert the float waveform to 16-bit mono PCM
65 | pcm_wav = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
66 |
67 | # Perform voice activity detection
68 | voice_flags = []
69 | vad = webrtcvad.Vad(mode=3)
70 | for window_start in range(0, len(wav), samples_per_window):
71 | window_end = window_start + samples_per_window
72 | voice_flags.append(vad.is_speech(pcm_wav[window_start * 2:window_end * 2],
73 | sample_rate=sampling_rate))
74 | voice_flags = np.array(voice_flags)
75 |
76 | # Smooth the voice detection with a moving average
77 | def moving_average(array, width):
78 | array_padded = \
79 | np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
80 | ret = np.cumsum(array_padded, dtype=float)
81 | ret[width:] = ret[width:] - ret[:-width]
82 | return ret[width - 1:] / width
83 |
84 | audio_mask = moving_average(voice_flags, vad_moving_average_width)
85 | audio_mask = np.round(audio_mask).astype(bool)
86 |
87 | # Dilate the voiced regions
88 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
89 | audio_mask = np.repeat(audio_mask, samples_per_window)
90 | audio_mask = resize(audio_mask, (len(wav_raw), )) > 0
91 | if return_raw_wav:
92 | return wav_raw, audio_mask, sr
93 | return wav_raw[audio_mask], audio_mask, sr
94 |
--------------------------------------------------------------------------------
/utils/commons/tensor_utils.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/NATSpeech/NATSpeech
2 | import torch
3 | import torch.distributed as dist
4 |
5 |
6 | def reduce_tensors(metrics):
7 | """ Reduce metrics across all machines. """
8 | new_metrics = {}
9 | for k, v in metrics.items():
10 | # Check tensor
11 | if isinstance(v, torch.Tensor):
12 | dist.all_reduce(v)
13 | v = v / dist.get_world_size()
14 | # Check dictionary
15 | if type(v) is dict:
16 | v = reduce_tensors(v) # Apply recursive
17 | new_metrics[k] = v
18 | return new_metrics
19 |
20 |
21 | def tensors_to_scalars(tensors):
22 | """ Get items from tensors. """
23 | if isinstance(tensors, torch.Tensor):
24 | tensors = tensors.item()
25 | return tensors
26 | elif isinstance(tensors, dict):
27 | new_tensors = {}
28 | for k, v in tensors.items():
29 | v = tensors_to_scalars(v) # Apply recursive
30 | new_tensors[k] = v
31 | return new_tensors
32 | elif isinstance(tensors, list):
33 | return [tensors_to_scalars(v) for v in tensors]
34 | else:
35 | return tensors
36 |
37 |
38 | def tensors_to_np(tensors):
39 | """ Convert to numpy.
""" 40 | if isinstance(tensors, dict): 41 | new_np = {} 42 | for k, v in tensors.items(): 43 | if isinstance(v, torch.Tensor): 44 | v = v.cpu().numpy() 45 | if type(v) is dict: 46 | v = tensors_to_np(v) # Apply recursive 47 | new_np[k] = v 48 | elif isinstance(tensors, list): 49 | new_np = {} 50 | for v in tensors: 51 | if isinstance(v, torch.Tensor): 52 | v = v.cpu().numpy() 53 | if type(v) is dict: 54 | v = tensors_to_np(v) 55 | new_np.append(v) 56 | elif isinstance(tensors, torch.Tensor): 57 | v = tensors 58 | if isinstance(tensors, torch.Tensor): 59 | v = v.cpu().numpy() 60 | if type(v) is dict: 61 | v = tensors_to_np(v) # Apply recursive 62 | new_np = v 63 | else: 64 | raise Exception(f"tensor_to_np does not support type {type(tensors)}.") 65 | return new_np 66 | 67 | 68 | def move_to_cpu(tensors): 69 | """ Move data from GPU to CPU. """ 70 | ret = {} 71 | for k, v in tensors.items(): 72 | if isinstance(v, torch.Tensor): 73 | v = v.cpu() 74 | if type(v) is dict: 75 | v = move_to_cpu(v) # Apply recursive 76 | ret[k] = v 77 | return ret 78 | 79 | 80 | def move_to_cuda(batch, gpu_id=0): 81 | """ Move data from CPU to GPU. """ 82 | # base case: object can be directly moved using "cuda" or "to" 83 | if callable(getattr(batch, "cuda", None)): 84 | return batch.cuda(gpu_id, non_blocking=True) 85 | elif callable(getattr(batch, "to", None)): 86 | return batch.to(torch.device("cuda", gpu_id), non_blocking=True) 87 | elif isinstance(batch, list): 88 | for i, x in enumerate(batch): 89 | batch[i] = move_to_cuda(x, gpu_id) # Apply recursive 90 | return batch 91 | elif isinstance(batch, tuple): 92 | batch = list(batch) 93 | for i, x in enumerate(batch): 94 | batch[i] = move_to_cuda(x, gpu_id) # Apply recursive 95 | elif isinstance(batch, dict): 96 | for k, v in batch.items(): 97 | batch[k] = move_to_cuda(v, gpu_id) # Apply recursive 98 | return batch 99 | return batch 100 | 101 | 102 | def sequence_mask(length, max_length=None): 103 | if max_length is None: 104 | max_length = length.max() 105 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 106 | return x.unsqueeze(0) < length.unsqueeze(1) 107 | -------------------------------------------------------------------------------- /preprocessor/text/dict/korean.json: -------------------------------------------------------------------------------- 1 | {"etc_dict" : { 2 | "2 30대": "이삼십대", 3 | "20~30대": "이삼십대", 4 | "20, 30대": "이십대 삼십대", 5 | "1+1": "원플러스원", 6 | "+": "플러스", 7 | "3에서 6개월인": "3개월에서 육개월인", 8 | "㎍/㎥": "마이크로미터 퍼 세제곱미터", 9 | "MP3": "엠피쓰리", 10 | "5G": "파이브지", 11 | "4G": "포지", 12 | "3G": "쓰리지", 13 | "2G": "투지", 14 | "A/S": "에이 에스", 15 | "1/3":"삼분의 일", 16 | "greentea907": "그린티 구공칠", 17 | "CNT 123": "씨엔티 일이삼", 18 | "14학번": "일사 학번", 19 | "7011번": "칠공일일번", 20 | "P8학원": "피에잇 학원", 21 | "102마리": "백두 마리", 22 | "20명": "스무명" 23 | }, 24 | 25 | "num_dict": { 26 | "0": "영", 27 | "1": "일", 28 | "2": "이", 29 | "3": "삼", 30 | "4": "사", 31 | "5": "오", 32 | "6": "육", 33 | "7": "칠", 34 | "8": "팔", 35 | "9": "구" 36 | }, 37 | 38 | "num_ten_dict": ["", "십", "백", "천"], 39 | 40 | "num_tenthousand_dict": ["", "만", "억", "조", "해", "경"], 41 | 42 | "count_checker": "(시|명|가지|살|마리|포기|송이|수|톨|통|점|개(?!월)|벌|척|채|다발|그루|자루|줄|켤레|그릇|잔|마디|상자|사람|곡|병|판)", 43 | 44 | "count_dict": ["", "한", "두", "세", "네", "다섯", "여섯", "일곱", "여덟", "아홉"], 45 | 46 | "count_tenth_dict": { 47 | "십": "열", 48 | "두십": "스물", 49 | "세십": "서른", 50 | "네십": "마흔", 51 | "다섯십": "쉰", 52 | "여섯십": "예순", 53 | "일곱십": "일흔", 54 | "여덟십": "여든", 55 | "아홉십": "아흔" 56 | }, 57 | 58 | "unit_dict": { 59 | "%": "퍼센트", 60 | 
"ml": "밀리리터", 61 | "mm": "밀리미터", 62 | "cm": "센치미터", 63 | "km": "킬로미터", 64 | "kg": "킬로그램", 65 | "℃": "도", 66 | "㎢": "제곱킬로미터", 67 | "㎥": "세제곱미터", 68 | "m": "미터", 69 | "퍼센트": " 퍼센트" 70 | }, 71 | 72 | "upper_dict":{ 73 | "A": "에이", 74 | "B": "비", 75 | "C": "씨", 76 | "D": "디", 77 | "E": "이", 78 | "F": "에프", 79 | "G": "지", 80 | "H": "에이치", 81 | "I": "아이", 82 | "J": "제이", 83 | "K": "케이", 84 | "L": "엘", 85 | "M": "엠", 86 | "N": "엔", 87 | "O": "오", 88 | "P": "피", 89 | "Q": "큐", 90 | "R": "알", 91 | "S": "에스", 92 | "T": "티", 93 | "U": "유", 94 | "V": "브이", 95 | "W": "더블유", 96 | "X": "엑스", 97 | "Y": "와이", 98 | "Z": "지" 99 | }, 100 | 101 | "eng_dict": { 102 | "JTBC": "제이티비씨", 103 | "Devsisters": "데브시스터즈", 104 | "track": "트랙", 105 | "KOREA": "코리아", 106 | "idol": "아이돌", 107 | "trickle down effect": "트리클 다운 이펙트", 108 | "trickle up effect": "트리클 업 이펙트", 109 | "down": "다운", 110 | "up": "업", 111 | "WHERETHEWILDTHINGSARE": "", 112 | "Rashomon Effect": "", 113 | "bill": "빌", 114 | "Halmuny": "하모니", 115 | "ability": "어빌리티", 116 | "shy": "샤이", 117 | "the tenth man": "더 텐쓰 맨", 118 | "Content Attitude Timing": "컨텐트 애티튜드 타이밍", 119 | "CAT": "캣", 120 | "PPropertyPositionPowerPrisonP": "", 121 | "francisco": "프란시스코", 122 | "III": "아이아이", 123 | "No joke": "노 조크", 124 | "Don't worry be happy": "돈 워리 비 해피", 125 | "it was our sky": "잇 워즈 아워 스카이", 126 | "it is our sky": "잇 이즈 아워 스카이", 127 | "apology": "어폴로지", 128 | "humble": "험블", 129 | "Nowhere Man": "노웨어 맨", 130 | "The Tenth Man": "더 텐쓰 맨", 131 | "Pick me up": "픽 미 업", 132 | "STOP": "스탑", 133 | "PRESS": "프레스", 134 | "not to be": "낫 투비", 135 | "Denial": "디나이얼", 136 | "Time flies like an arrow": "타임 플라이즈 라이크 언 애로우", 137 | "MZ": "엠제트", 138 | "Z세대": "제트세대", 139 | "TV": "티비", 140 | "I love America": "아이 러브 아메리카", 141 | "Prime Minister": "프라임 미니스터", 142 | "Swordline": "스워드라인", 143 | "Reflecting Absence": "리플렉팅 앱센스", 144 | "Drum being beaten by everyone": "드럼 빙 비튼 바이 에브리원", 145 | "negative pressure": "네거티브 프레셔", 146 | "KIA": "기아", 147 | "Que sais-je": "", 148 | "Chaebol": "채벌", 149 | "who are you": "후 얼 유", 150 | "The Devils Advocate": "더 데빌즈 어드보카트", 151 | "so sorry": "쏘 쏘리", 152 | "Santa": "산타", 153 | "Big Endian": "빅 엔디안", 154 | "Small Endian": "스몰 엔디안", 155 | "Oh Captain My Captain": "오 캡틴 마이 캡틴" 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /utils/commons/multiprocess_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import os 3 | import traceback 4 | 5 | from functools import partial 6 | from tqdm import tqdm 7 | 8 | 9 | def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None): 10 | ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None 11 | while True: 12 | args = args_queue.get() 13 | if args == '': 14 | return 15 | job_idx, map_func, arg = args 16 | try: 17 | map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func 18 | if isinstance(arg, dict): 19 | res = map_func_(**arg) 20 | elif isinstance(arg, (list, tuple)): 21 | res = map_func_(*arg) 22 | else: 23 | res = map_func_(arg) 24 | results_queue.put((job_idx, res)) 25 | except: 26 | traceback.print_exc() 27 | results_queue.put((job_idx, None)) 28 | 29 | 30 | class MultiprocessManager: 31 | """ Multi-process Manager """ 32 | def __init__(self, num_workers=None, init_ctx_func=None, multithread=False, queue_max=-1): 33 | if multithread: 34 | from multiprocessing.dummy import Queue, Process 35 | else: 36 | from 
multiprocessing import Queue, Process 37 | if num_workers is None: 38 | num_workers = int(os.getenv('N_PROC', os.cpu_count())) 39 | self.num_workers = num_workers 40 | self.results_queue = Queue(maxsize=-1) 41 | self.jobs_pending = [] 42 | self.args_queue = Queue(maxsize=queue_max) 43 | self.workers = [] 44 | self.total_jobs = 0 45 | self.multithread = multithread 46 | for i in range(num_workers): 47 | if multithread: 48 | p = Process(target=chunked_worker, 49 | args=(i, self.args_queue, self.results_queue, init_ctx_func)) 50 | else: 51 | p = Process(target=chunked_worker, 52 | args=(i, self.args_queue, self.results_queue, init_ctx_func), 53 | daemon=True) 54 | self.workers.append(p) 55 | p.start() 56 | 57 | def add_job(self, func, args): 58 | if not self.args_queue.full(): 59 | self.args_queue.put((self.total_jobs, func, args)) 60 | else: 61 | self.jobs_pending.append((self.total_jobs, func, args)) 62 | self.total_jobs += 1 63 | 64 | def get_results(self): 65 | self.n_finished = 0 66 | while self.n_finished < self.total_jobs: 67 | while len(self.jobs_pending) > 0 and not self.args_queue.full(): 68 | self.args_queue.put(self.jobs_pending[0]) 69 | self.jobs_pending = self.jobs_pending[1:] 70 | job_id, res = self.results_queue.get() 71 | yield job_id, res 72 | self.n_finished += 1 73 | for w in range(self.num_workers): 74 | self.args_queue.put("") 75 | for w in self.workers: 76 | w.join() 77 | 78 | def close(self): 79 | if not self.multithread: 80 | for w in self.workers: 81 | w.terminate() 82 | 83 | def __len__(self): 84 | return self.total_jobs 85 | 86 | 87 | def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, 88 | multithread=False, queue_max=-1, desc=None): 89 | for i, res in tqdm( 90 | multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread, 91 | queue_max=queue_max), 92 | total=len(args), desc=desc): 93 | yield i, res 94 | 95 | 96 | def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False, 97 | queue_max=-1): 98 | """ Multiprocessing running chunked jobs. 
99 | 100 | Examples 101 | -------- 102 | >>>> for res in tqdm(multiprocess_run(job_func, args)): 103 | >>>> print(res) 104 | 105 | Parameters 106 | ---------- 107 | map_func: function 108 | args: dict 109 | num_workers: int 110 | ordered: bool 111 | init_ctx_func: function 112 | multithread: bool 113 | queue_max: int 114 | """ 115 | if num_workers is None: 116 | num_workers = int(os.getenv('N_PROC', os.cpu_count())) 117 | 118 | # Setting Multi-process Manager 119 | manager = MultiprocessManager(num_workers, init_ctx_func, multithread, queue_max=queue_max) 120 | for arg in args: 121 | manager.add_job(map_func, arg) 122 | if ordered: 123 | n_jobs = len(args) 124 | results = ['' for _ in range(n_jobs)] 125 | i_now = 0 126 | for job_i, res in manager.get_results(): 127 | results[job_i] = res 128 | while i_now < n_jobs and (not isinstance(results[i_now], str) or results[i_now] != ''): 129 | yield i_now, results[i_now] 130 | results[i_now] = None 131 | i_now += 1 132 | else: 133 | for job_i, res in manager.get_results(): 134 | yield job_i, res 135 | manager.close() 136 | -------------------------------------------------------------------------------- /utils/audio/align.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def get_mel2note(midi_info, mel, hop_size, sample_rate, min_sil_duration=0): 7 | # Check intervals 8 | midi_info_ = [] 9 | # midi_info: (Bar, Pos, Pitch, Duration_midi, start_time, end_time, Tempo, phone_token, phone) 10 | for i, midi in enumerate(midi_info): 11 | # Check (i-th start time - i-1th end time) 12 | if i > 0 and midi[4] - midi_info_[-1][5] < min_sil_duration: 13 | midi_info_[-1][5] = midi[4] 14 | # Remove | token after or before or repeat 15 | if i > 0 and midi[8][0] == "|" and (midi_info_[-1][8][0] == "" or midi_info_[-1][8][0] == "|"): 16 | midi_info_[-1][5] = midi[5] 17 | midi_info_[-1][2] = 0 18 | elif i > 0 and midi[8][0] == "" and midi_info_[-1][8][0] == "|": 19 | midi_info_[-1][5] = midi[5] 20 | midi_info_[-1][2] = 0 21 | else: 22 | if midi[8][0] == "|": 23 | midi[2] = 0 24 | midi_info_.append(midi) 25 | # For remove, zero duration BOS token 26 | midi_info_ = [midi for midi in midi_info_ if not (midi[8][0] == "" and midi[5] - midi[4] < 0.001)] 27 | # Check phoneme 28 | mel2phone = np.zeros([mel.shape[0]]) 29 | mel2note = np.zeros([mel.shape[0]]) 30 | ph_token_list = [] 31 | ph_list = [] 32 | note_token_list = [] 33 | i_note = 0 34 | while i_note < len(midi_info_): 35 | # midi_info: (Bar, Pos, Pitch, Duration_midi, start_time, end_time, Tempo, phone_token, phone) 36 | midi = midi_info_[i_note] 37 | start_frame = int(midi[4] * sample_rate / hop_size + 0.5) 38 | end_frame = int(midi[5] * sample_rate / hop_size + 0.5) 39 | assert end_frame - start_frame > 0, f"| Wrong note: {end_frame - start_frame}" 40 | mel2phone[start_frame:end_frame] = i_note + 1 41 | mel2note[start_frame:end_frame] = i_note + 1 42 | ph_token_list.extend(midi[7]) 43 | ph_list.extend(midi[8]) 44 | note_token_list.append(midi[3]) 45 | i_note += 1 46 | 47 | mel2phone[-1] = mel2phone[-2] 48 | mel2note[-1] = mel2note[-2] 49 | assert not np.any(mel2phone == 0) and not np.any(mel2note == 0), f"| mel2phone: {mel2phone}, mel2note: {mel2note}, midi_info: {midi_info}" 50 | assert mel2phone[-1] == len(ph_token_list), f"| last melphone index: {mel2phone[-1]}, length ph_list: {len(ph_token_list)}, midi_info: {len(midi_info_)}" 51 | 52 | T_ph = len(ph_list) 53 | duration = 
mel2token_to_dur(mel2phone, T_ph) 54 | 55 | return mel2phone.tolist(), mel2note.tolist(), duration.tolist(), ph_token_list, ph_list, note_token_list, midi_info_ 56 | 57 | 58 | def get_note2dur(midi_info, hop_size, sample_rate, min_sil_duration=0): 59 | # Check intervals 60 | midi_info_ = [] 61 | for i, midi in enumerate(midi_info): 62 | # Check (i-th start time - i-1th end time) 63 | if i > 0 and midi[4] - midi_info_[-1][5] < min_sil_duration: 64 | midi_info_[-1][5] = midi[4] 65 | if i > 0 and midi[8] == "|" and midi_info_[-1][8] == "|": 66 | midi_info_[-1][5] = midi[5] 67 | else: 68 | midi_info_.append(midi) 69 | # Check phoneme 70 | last_frame = int(midi_info_[-1][5] * sample_rate / hop_size + 0.5) 71 | mel2phone = np.zeros([last_frame], dtype=int) 72 | mel2note = np.zeros([last_frame], dtype=int) 73 | ph_list = [] 74 | i_note = 0 75 | i_ph = 0 76 | while i_note < len(midi_info_): 77 | # midi_info: (Bar, Pos, Pitch, Duration_midi, start_time, end_time, Tempo, Syllable) 78 | midi = midi_info_[i_note] 79 | start_frame = int(midi[4] * sample_rate / hop_size + 0.5) 80 | end_frame = int(midi[5] * sample_rate / hop_size + 0.5) 81 | if len(midi[7]) == 1: 82 | mel2phone[start_frame:end_frame] = i_ph + 1 83 | i_ph += 1 84 | elif len(midi[7]) == 2: 85 | mel2phone[start_frame:start_frame+3] = i_ph + 1 86 | mel2phone[start_frame+3:end_frame] = i_ph + 2 87 | i_ph += 2 88 | elif len(midi[7]) == 3: 89 | # Korean syllable consist of consonant, vowel, coda 90 | mel2phone[start_frame:start_frame+3] = i_ph + 1 91 | mel2phone[start_frame+3:end_frame-3] = i_ph + 2 92 | mel2phone[end_frame-3:end_frame] = i_ph + 3 93 | i_ph += 3 94 | ph_list.extend(midi[7]) 95 | mel2note[start_frame:end_frame] = i_note + 1 96 | i_note += 1 97 | 98 | mel2phone[-1] = mel2phone[-2] 99 | mel2note[-1] = mel2note[-2] 100 | assert not np.any(mel2phone == 0) and not np.any(mel2note == 0), f"| mel2phone: {mel2phone}, mel2note: {mel2note}, midi_info: {midi_info}" 101 | T_ph = len(ph_list) 102 | duration = mel2token_to_dur(mel2phone, T_ph) 103 | 104 | return mel2phone.tolist(), mel2note.tolist(), duration.tolist(), ph_list, midi_info_ 105 | 106 | 107 | def mel2token_to_dur(mel2token: torch.Tensor, T_txt=None, max_dur=None): 108 | # Check input data settings 109 | is_torch = isinstance(mel2token, torch.Tensor) 110 | has_batch_dim = True 111 | if not is_torch: 112 | mel2token = torch.LongTensor(mel2token) 113 | if T_txt is None: 114 | T_txt = mel2token.max() 115 | if len(mel2token.shape) == 1: 116 | mel2token = mel2token[None, ...] 
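# Added worked example (illustrative): for mel2token = [1, 1, 2, 2, 2, 3] and T_txt = 3,
# the scatter_add below counts frames per token index and yields dur = [2, 3, 1];
# slot 0 of the scratch buffer only collects frames mapped to padding (value 0)
# and is dropped by dur[:, 1:].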
117 | has_batch_dim = False 118 | 119 | B, _ = mel2token.shape 120 | dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token)) 121 | dur = dur[:, 1:] 122 | if max_dur is not None: 123 | dur = dur.clamp(max=max_dur) 124 | if not is_torch: 125 | dur = dur.numpy() 126 | if not has_batch_dim: 127 | dur = dur[0] 128 | 129 | return dur 130 | -------------------------------------------------------------------------------- /modules/visinger/decoder.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/jaywalnut310/vits 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.nn.utils import weight_norm, remove_weight_norm 7 | 8 | from modules.commons.utils import init_weights, get_padding 9 | 10 | LRELU_SLOPE = 0.1 11 | 12 | 13 | class Generator(nn.Module): 14 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, 15 | upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 16 | super(Generator, self).__init__() 17 | self.num_kernels = len(resblock_kernel_sizes) 18 | self.num_upsamples = len(upsample_rates) 19 | self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) 20 | resblock = ResBlock1 if resblock == '1' else ResBlock2 21 | 22 | self.ups = nn.ModuleList() 23 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 24 | self.ups.append(weight_norm( 25 | nn.ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), 26 | k, u, padding=(k - u) // 2))) 27 | 28 | self.resblocks = nn.ModuleList() 29 | for i in range(len(self.ups)): 30 | ch = upsample_initial_channel // (2 ** (i + 1)) 31 | for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 32 | self.resblocks.append(resblock(ch, k, d)) 33 | 34 | self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) 35 | self.ups.apply(init_weights) 36 | 37 | if gin_channels != 0: 38 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 39 | 40 | def forward(self, x, g=None): 41 | x = self.conv_pre(x) 42 | if g is not None: 43 | x = x + self.cond(g) 44 | 45 | for i in range(self.num_upsamples): 46 | x = F.leaky_relu(x, LRELU_SLOPE) 47 | x = self.ups[i](x) 48 | xs = None 49 | for j in range(self.num_kernels): 50 | if xs is None: 51 | xs = self.resblocks[i * self.num_kernels + j](x) 52 | else: 53 | xs += self.resblocks[i * self.num_kernels + j](x) 54 | x = xs / self.num_kernels 55 | x = F.leaky_relu(x, LRELU_SLOPE) 56 | x = self.conv_post(x) 57 | x = torch.tanh(x) 58 | 59 | return x 60 | 61 | def remove_weight_norm(self): 62 | for l in self.ups: 63 | remove_weight_norm(l) 64 | for l in self.resblocks: 65 | l.remove_weight_norm() 66 | 67 | 68 | class ResBlock1(torch.nn.Module): 69 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 70 | super(ResBlock1, self).__init__() 71 | self.convs1 = nn.ModuleList([ 72 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 73 | padding=get_padding(kernel_size, dilation[0]))), 74 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 75 | padding=get_padding(kernel_size, dilation[1]))), 76 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 77 | padding=get_padding(kernel_size, dilation[2]))) 78 | ]) 79 | self.convs1.apply(init_weights) 80 | 81 | self.convs2 = nn.ModuleList([ 82 | 
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, 83 | padding=get_padding(kernel_size, 1))), 84 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, 85 | padding=get_padding(kernel_size, 1))), 86 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, 87 | padding=get_padding(kernel_size, 1))) 88 | ]) 89 | self.convs2.apply(init_weights) 90 | 91 | def forward(self, x, x_mask=None): 92 | for c1, c2 in zip(self.convs1, self.convs2): 93 | xt = F.leaky_relu(x, LRELU_SLOPE) 94 | if x_mask is not None: 95 | xt = xt * x_mask 96 | xt = c1(xt) 97 | xt = F.leaky_relu(xt, LRELU_SLOPE) 98 | if x_mask is not None: 99 | xt = xt * x_mask 100 | xt = c2(xt) 101 | x = xt + x 102 | if x_mask is not None: 103 | x = x * x_mask 104 | return x 105 | 106 | def remove_weight_norm(self): 107 | for l in self.convs1: 108 | remove_weight_norm(l) 109 | for l in self.convs2: 110 | remove_weight_norm(l) 111 | 112 | 113 | class ResBlock2(nn.Module): 114 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 115 | super(ResBlock2, self).__init__() 116 | self.convs = nn.ModuleList([ 117 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 118 | padding=get_padding(kernel_size, dilation[0]))), 119 | weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 120 | padding=get_padding(kernel_size, dilation[1]))) 121 | ]) 122 | self.convs.apply(init_weights) 123 | 124 | def forward(self, x, x_mask=None): 125 | for c in self.convs: 126 | xt = F.leaky_relu(x, LRELU_SLOPE) 127 | if x_mask is not None: 128 | xt = xt * x_mask 129 | xt = c(xt) 130 | x = xt + x 131 | if x_mask is not None: 132 | x = x * x_mask 133 | return x 134 | 135 | def remove_weight_norm(self): 136 | for l in self.convs: 137 | remove_weight_norm(l) 138 | -------------------------------------------------------------------------------- /utils/commons/hparams.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import argparse 3 | import os 4 | import yaml 5 | 6 | from utils.os_utils import remove_file 7 | 8 | global_print_hparams = True 9 | hparams = {} 10 | 11 | 12 | class Args: 13 | def __init__(self, **kwargs): 14 | for k, v in kwargs.items(): 15 | self.__setattr__(k, v) 16 | 17 | 18 | def override_config(old_config: dict, new_config: dict): 19 | for k, v in new_config.items(): 20 | if isinstance(v, dict) and k in old_config: 21 | override_config(old_config[k], new_config[k]) 22 | else: 23 | old_config[k] = v 24 | 25 | 26 | def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True): 27 | if config == '' and exp_name == '': 28 | parser = argparse.ArgumentParser(description='') 29 | parser.add_argument('--config', type=str, default='', 30 | help='location of the data corpus') 31 | parser.add_argument('--exp_name', type=str, default='', help='exp_name') 32 | parser.add_argument("--save_path", type=str, default="", help="saved path") 33 | parser.add_argument('-hp', '--hparams', type=str, default='', 34 | help='location of the data corpus') 35 | parser.add_argument('--infer', action='store_true', help='infer') 36 | parser.add_argument('--validate', action='store_true', help='validate') 37 | parser.add_argument('--reset', action='store_true', help='reset hparams') 38 | parser.add_argument('--remove', action='store_true', help='remove old ckpt') 39 | parser.add_argument('--debug', action='store_true', help='debug') 40 | args, 
unknown = parser.parse_known_args() 41 | print("| Unknow hparams: ", unknown) 42 | else: 43 | args = Args(config=config, exp_name=exp_name, hparams=hparams_str, 44 | infer=False, validate=False, reset=False, debug=False, remove=False) 45 | global hparams 46 | assert args.config != '' or args.exp_name != '' 47 | if args.config != '': 48 | assert os.path.exists(args.config) 49 | 50 | config_chains = [] 51 | loaded_config = set() 52 | 53 | def load_config(config_fn): 54 | # deep first inheritance and avoid the second visit of one node 55 | if not os.path.exists(config_fn): 56 | return {} 57 | with open(config_fn) as f: 58 | hparams_ = yaml.safe_load(f) 59 | loaded_config.add(config_fn) 60 | if 'base_config' in hparams_: 61 | ret_hparams = {} 62 | if not isinstance(hparams_['base_config'], list): 63 | hparams_['base_config'] = [hparams_['base_config']] 64 | for c in hparams_['base_config']: 65 | if c.startswith('.'): 66 | c = f'{os.path.dirname(config_fn)}/{c}' 67 | c = os.path.normpath(c) 68 | if c not in loaded_config: 69 | override_config(ret_hparams, load_config(c)) 70 | override_config(ret_hparams, hparams_) 71 | else: 72 | ret_hparams = hparams_ 73 | config_chains.append(config_fn) 74 | return ret_hparams 75 | 76 | saved_hparams = {} 77 | args_work_dir = '' 78 | if args.exp_name != '': 79 | args_work_dir = f'checkpoints/{args.exp_name}' 80 | ckpt_config_path = f'{args_work_dir}/config.yaml' 81 | if os.path.exists(ckpt_config_path): 82 | with open(ckpt_config_path) as f: 83 | saved_hparams_ = yaml.safe_load(f) 84 | if saved_hparams_ is not None: 85 | saved_hparams.update(saved_hparams_) 86 | hparams_ = {} 87 | if args.config != '': 88 | hparams_.update(load_config(args.config)) 89 | if not args.reset: 90 | hparams_.update(saved_hparams) 91 | hparams_['work_dir'] = args_work_dir 92 | print("args_work: ", args_work_dir) 93 | 94 | # Support config overriding in command line. Support list type config overriding. 95 | # Examples: --hparams="a=1,b.c=2,d=[1 1 1]" 96 | if args.hparams != "": 97 | for new_hparam in args.hparams.split(","): 98 | k, v = new_hparam.split("=") 99 | v = v.strip("\'\" ") 100 | config_node = hparams_ 101 | for k_ in k.split(".")[:-1]: 102 | config_node = config_node[k_] 103 | k = k.split(".")[-1] 104 | if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]: 105 | if type(config_node[k]) == list: 106 | v = v.replace(" ", ",") 107 | config_node[k] = eval(v) 108 | else: 109 | config_node[k] = type(config_node[k])(v) 110 | if args_work_dir != '' and args.remove: 111 | answer = input("REMOVE old checkpoint? 
Y/N [Default: N]: ") 112 | if answer.lower() == "y": 113 | remove_file(args_work_dir) 114 | if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer: 115 | os.makedirs(args_work_dir, exist_ok=True) 116 | with open(ckpt_config_path, 'w') as f: 117 | yaml.safe_dump(hparams_, f) 118 | 119 | hparams_['infer'] = args.infer 120 | hparams_['debug'] = args.debug 121 | hparams_['validate'] = args.validate 122 | hparams_['exp_name'] = args.exp_name 123 | global global_print_hparams 124 | if global_hparams: 125 | hparams.clear() 126 | hparams.update(hparams_) 127 | if print_hparams and global_print_hparams and global_hparams: 128 | print('| Hparams chains: ', config_chains) 129 | print('| Hparams: ') 130 | for i, (k, v) in enumerate(sorted(hparams_.items())): 131 | print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "") 132 | print("") 133 | global_print_hparams = False 134 | return hparams_ 135 | -------------------------------------------------------------------------------- /utils/commons/ddp_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import torch 3 | import torch.optim 4 | import torch.utils.data 5 | 6 | from torch.nn.parallel import DistributedDataParallel 7 | from torch.nn.parallel.distributed import _find_tensors 8 | from packaging import version 9 | 10 | 11 | class DDP(DistributedDataParallel): 12 | """ 13 | Override the forward call in lightning so it goes to training and validation step respectively 14 | """ 15 | 16 | def forward(self, *inputs, **kwargs): # pragma: no cover 17 | if version.parse(torch.__version__[:6]) < version.parse("1.11"): 18 | self._sync_params() 19 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 20 | assert len(self.device_ids) == 1 21 | if self.module.training: 22 | output = self.module.training_step(*inputs[0], **kwargs[0]) 23 | elif self.module.testing: 24 | output = self.module.test_step(*inputs[0], **kwargs[0]) 25 | else: 26 | output = self.module.validation_step(*inputs[0], **kwargs[0]) 27 | if torch.is_grad_enabled(): 28 | # We'll return the output object verbatim since it is a freeform 29 | # object. We need to find any tensors in this object, though, 30 | # because we need to figure out which parameters were used during 31 | # this forward pass, to ensure we short circuit reduction for any 32 | # unused parameters. Only if `find_unused_parameters` is set. 33 | if self.find_unused_parameters: 34 | self.reducer.prepare_for_backward(list(_find_tensors(output))) 35 | else: 36 | self.reducer.prepare_for_backward([]) 37 | else: 38 | from torch.nn.parallel.distributed import \ 39 | logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref 40 | with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): 41 | if torch.is_grad_enabled() and self.require_backward_grad_sync: 42 | self.logger.set_runtime_stats_and_log() 43 | self.num_iterations += 1 44 | self.reducer.prepare_for_forward() 45 | 46 | # Notify the join context that this process has not joined, if 47 | # needed 48 | work = Join.notify_join_context(self) 49 | if work: 50 | self.reducer._set_forward_pass_work_handle( 51 | work, self._divide_by_initial_world_size 52 | ) 53 | 54 | # Calling _rebuild_buckets before forward compuation, 55 | # It may allocate new buckets before deallocating old buckets 56 | # inside _rebuild_buckets. 
To save peak memory usage, 57 | # call _rebuild_buckets before the peak memory usage increases 58 | # during forward computation. 59 | # This should be called only once during whole training period. 60 | if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): 61 | logging.info("Reducer buckets have been rebuilt in this iteration.") 62 | self._has_rebuilt_buckets = True 63 | 64 | # sync params according to location (before/after forward) user 65 | # specified as part of hook, if hook was specified. 66 | buffer_hook_registered = hasattr(self, 'buffer_hook') 67 | if self._check_sync_bufs_pre_fwd(): 68 | self._sync_buffers() 69 | 70 | if self._join_config.enable: 71 | # Notify joined ranks whether they should sync in backwards pass or not. 72 | self._check_global_requires_backward_grad_sync(is_joined_rank=False) 73 | 74 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 75 | if self.module.training: 76 | output = self.module.training_step(*inputs[0], **kwargs[0]) 77 | elif self.module.testing: 78 | output = self.module.test_step(*inputs[0], **kwargs[0]) 79 | else: 80 | output = self.module.validation_step(*inputs[0], **kwargs[0]) 81 | 82 | # sync params according to location (before/after forward) user 83 | # specified as part of hook, if hook was specified. 84 | if self._check_sync_bufs_post_fwd(): 85 | self._sync_buffers() 86 | 87 | if torch.is_grad_enabled() and self.require_backward_grad_sync: 88 | self.require_forward_param_sync = True 89 | # We'll return the output object verbatim since it is a freeform 90 | # object. We need to find any tensors in this object, though, 91 | # because we need to figure out which parameters were used during 92 | # this forward pass, to ensure we short circuit reduction for any 93 | # unused parameters. Only if `find_unused_parameters` is set. 94 | if self.find_unused_parameters and not self.static_graph: 95 | # Do not need to populate this for static graph. 96 | self.reducer.prepare_for_backward(list(_find_tensors(output))) 97 | else: 98 | self.reducer.prepare_for_backward([]) 99 | else: 100 | self.require_forward_param_sync = False 101 | 102 | # TODO: DDPSink is currently enabled for unused parameter detection and 103 | # static graph training for first iteration. 104 | if (self.find_unused_parameters and not self.static_graph) or ( 105 | self.static_graph and self.num_iterations == 1 106 | ): 107 | state_dict = { 108 | 'static_graph': self.static_graph, 109 | 'num_iterations': self.num_iterations, 110 | } 111 | 112 | output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref( 113 | output 114 | ) 115 | output_placeholders = [None for _ in range(len(output_tensor_list))] 116 | # Do not touch tensors that have no grad_fn, which can cause issues 117 | # such as https://github.com/pytorch/pytorch/issues/60733 118 | for i, output in enumerate(output_tensor_list): 119 | if torch.is_tensor(output) and output.grad_fn is None: 120 | output_placeholders[i] = output 121 | 122 | # When find_unused_parameters=True, makes tensors which require grad 123 | # run through the DDPSink backward pass. When not all outputs are 124 | # used in loss, this makes those corresponding tensors receive 125 | # undefined gradient which the reducer then handles to ensure 126 | # param.grad field is not touched and we don't error out. 
127 | passthrough_tensor_list = _DDPSink.apply( 128 | self.reducer, 129 | state_dict, 130 | *output_tensor_list, 131 | ) 132 | for i in range(len(output_placeholders)): 133 | if output_placeholders[i] is None: 134 | output_placeholders[i] = passthrough_tensor_list[i] 135 | 136 | # Reconstruct output data structure. 137 | output = _tree_unflatten_with_rref( 138 | output_placeholders, treespec, output_is_rref 139 | ) 140 | return output 141 | -------------------------------------------------------------------------------- /utils/commons/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import os 3 | import numpy as np 4 | import sys 5 | import traceback 6 | import types 7 | import torch 8 | import torchaudio 9 | 10 | from functools import wraps 11 | from itertools import chain 12 | from torch.utils.data import ConcatDataset, Dataset 13 | 14 | from utils.commons.hparams import hparams 15 | 16 | 17 | def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): 18 | if len(values[0].shape) == 1: 19 | return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id) 20 | else: 21 | return collate_2d(values, pad_idx, left_pad, shift_right, max_len) 22 | 23 | 24 | def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): 25 | """Convert a list of 1d tensors into a padded 2d tensor.""" 26 | size = max(v.size(0) for v in values) if max_len is None else max_len 27 | res = values[0].new(len(values), size).fill_(pad_idx) 28 | 29 | def copy_tensor(src, dst): 30 | assert dst.numel() == src.numel() 31 | if shift_right: 32 | dst[1:] = src[:-1] 33 | dst[0] = shift_id 34 | else: 35 | dst.copy_(src) 36 | 37 | for i, v in enumerate(values): 38 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 39 | return res 40 | 41 | 42 | def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None): 43 | """Convert a list of 2d tensors into a padded 3d tensor.""" 44 | size = max(v.size(0) for v in values) if max_len is None else max_len 45 | res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx) 46 | 47 | def copy_tensor(src, dst): 48 | assert dst.numel() == src.numel() 49 | if shift_right: 50 | dst[1:] = src[:-1] 51 | else: 52 | dst.copy_(src) 53 | 54 | for i, v in enumerate(values): 55 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 56 | return res 57 | 58 | 59 | def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): 60 | if len(batch) == 0: 61 | return 0 62 | if len(batch) == max_sentences: 63 | return 1 64 | if num_tokens > max_tokens: 65 | return 1 66 | return 0 67 | 68 | 69 | def batch_by_size(indices, num_tokens_fn, max_tokens=None, max_sentences=None, 70 | required_batch_size_multiple=1, distributed=False): 71 | """ Yield mini-batches of indices bucketed by size. Batches may contain 72 | sequences of different lengths. 73 | 74 | Parameters 75 | ---------- 76 | indices: List[int] 77 | ordered list of dataset indices 78 | num_tokens_fn: callable 79 | function that returns the number of tokens at a given index 80 | max_tokens: int, optional 81 | max number of tokens in each batch (default: None). 82 | max_sentences: int, optional 83 | max number of sentences in each batch (default: None). 84 | required_batch_size_multiple: int, optional 85 | require batch size to be a multiple of N (default: 1). 
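Illustrative example (added, not from the original source): with max_tokens=6 and
token counts [2, 2, 3, 5] for indices [0, 1, 2, 3], the function returns the batches
[[0, 1], [2], [3]]; a batch is closed as soon as
(len(batch) + 1) * longest_sample_length would exceed max_tokens.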
86 | """ 87 | max_tokens = max_tokens if max_tokens is not None else sys.maxsize 88 | max_sentences = max_sentences if max_sentences is not None else sys.maxsize 89 | batch_size_mult = required_batch_size_multiple 90 | 91 | if isinstance(indices, types.GeneratorType): 92 | indices = np.fromiter(indices, dtype=np.int64, count=-1) 93 | 94 | sample_len = 0 95 | sample_lens = [] 96 | batch = [] 97 | batches = [] 98 | for i in range(len(indices)): 99 | idx = indices[i] 100 | num_tokens = num_tokens_fn(idx) 101 | sample_lens.append(num_tokens) 102 | sample_len = max(sample_len, num_tokens) 103 | 104 | assert sample_len <= max_tokens, ( 105 | f"sentence at index {idx} of size {sample_len} exceeds max_tokens limit of {max_tokens}!") 106 | num_tokens = (len(batch) + 1) * sample_len 107 | 108 | if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): 109 | mod_len = max( 110 | batch_size_mult * (len(batch) // batch_size_mult), len(batch) % batch_size_mult) 111 | batches.append(batch[:mod_len]) 112 | batch = batch[mod_len:] 113 | sample_lens = sample_lens[mod_len:] 114 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 115 | batch.append(idx) 116 | if len(batch) > 0: 117 | batches.append(batch) 118 | return batches 119 | 120 | 121 | def data_loader(fn): 122 | """ 123 | Decorator to make any fix with this use the lazy property 124 | 125 | Parameters 126 | ---------- 127 | fn: callable 128 | function for data loader 129 | """ 130 | wraps(fn) # Update function inofrmation 131 | attr_name = "_lazy_" + fn.__name__ 132 | 133 | def _get_data_loader(self): 134 | try: 135 | value = getattr(self, attr_name) 136 | except AttributeError: 137 | try: 138 | # Lazy evaluation, done only once. 139 | value = fn(self) 140 | except AttributeError as e: 141 | # Guard against AttributeError suppression. 142 | traceback.print_exc() 143 | error = f"{fn.__name__}: An AttributeError was encoutered: {str(e)}" 144 | raise RuntimeError(error) from e 145 | # Memorize evaluation. 146 | setattr(self, attr_name, value) 147 | return value 148 | 149 | return _get_data_loader 150 | 151 | 152 | class BaseDataset(Dataset): 153 | def __init__(self, shuffle): 154 | super().__init__() 155 | self.hparams = hparams 156 | self.shuffle = shuffle 157 | self.sort_by_len = hparams["sort_by_len"] 158 | self.sizes = None 159 | 160 | @property 161 | def _sizes(self): 162 | return self.sizes 163 | 164 | def __getitem__(self, index): 165 | raise NotImplemented 166 | 167 | def collater(self, samples): 168 | raise NotImplementedError 169 | 170 | def __len__(self): 171 | return len(self._sizes) 172 | 173 | def num_tokens(self, index): 174 | return self.size(index) 175 | 176 | def size(self, index): 177 | """ Return an example's size as a float or tuple, This value is used 178 | when filtering a dataset with ``--max-positions``. """ 179 | return min(self._sizes[index], self.hparams["max_tokens"]) 180 | 181 | def ordered_indices(self): 182 | """ Return an ordered list of indices. Batches will be constructed 183 | based on this order. """ 184 | if self.shuffle: 185 | indices = np.random.permutation(len(self)) 186 | if self.sort_by_len: 187 | indices = indices[np.argsort(np.array(self._sizes)[indices], kind="mergesort")] 188 | else: 189 | indices = np.arange(len(self)) 190 | 191 | return indices 192 | 193 | @property 194 | def num_workers(self): 195 | return int(os.getenv("NUM_WORKERS", self.hparams["ds_workers"])) 196 | 197 | def load_audio_to_torch(self, audio_path): 198 | """ [WARN] You have to normalize waveform by max wav value. 
""" 199 | wav, sample_rate = torchaudio.load(audio_path, format="wav", normalize=False) 200 | # To ensure upsampling/downsampling will be processed in a right way for full signals 201 | return wav.squeeze(0), sample_rate 202 | 203 | 204 | class BaseConcatDataset(ConcatDataset): 205 | def collater(self, samples): 206 | return self.datasets[0].collater(samples) 207 | 208 | @property 209 | def _sizes(self): 210 | if not hasattr(self, "sizes"): 211 | self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets])) 212 | return self.sizes 213 | 214 | def size(self, index): 215 | return min(self._sizes[index], hparams["max_frames"]) 216 | 217 | def num_tokens(self, index): 218 | return self.size(index) 219 | 220 | def ordered_indices(self): 221 | """ Return an ordered list of indices. Batches will be constructed 222 | based on this order. """ 223 | if self.datasets[0].shuffle: 224 | indices = np.random.permutation(len(self)) 225 | if self.datasets[0].sort_by_len: 226 | indices = indices[np.argsort(np.array(self._sizes)[indices], kind="mergesort")] 227 | else: 228 | indices = np.arange(len(self)) 229 | 230 | @property 231 | def num_workers(self): 232 | return self.datasets[0].num_workers 233 | -------------------------------------------------------------------------------- /models/visinger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from copy import deepcopy 5 | 6 | from modules.commons.utils import Embedding, rand_slice_segments 7 | from modules.discriminator import DiscriminatorP, DiscriminatorS 8 | from modules.rel_transformer import SinusoidalPositionalEmbedding 9 | from modules.visinger.encoder import TextEncoder, PosteriorEncoder, FramePriorNetwork 10 | from modules.visinger.decoder import Generator 11 | from modules.visinger.predictor import PitchPredictor, PhonemePredictor 12 | from modules.visinger.flow import ResidualCouplingBlock 13 | from utils.audio.pitch.utils import denorm_f0, f0_to_coarse 14 | 15 | DEFAULT_MAX_TARGET_POSITIONS = 2000 16 | 17 | 18 | class VISinger(nn.Module): 19 | """ VISinger Implementation [Y. 
Zang et al., 2022] """ 20 | def __init__(self, ph_dict_size, pitch_size, dur_size, hparams, out_dims=None): 21 | super().__init__() 22 | self.hparams = deepcopy(hparams) 23 | self.enc_layers = hparams["enc_layers"] 24 | self.dec_blocks = hparams["dec_blocks"] 25 | self.hidden_size = hparams["hidden_size"] 26 | self.use_pos_embed = hparams["use_pos_embed"] 27 | self.segment_size = hparams["segment_size"] 28 | self.out_dims = hparams["num_mel_bins"] if out_dims is None else out_dims 29 | # Multi-speaker settings 30 | if hparams["use_spk_id"]: 31 | self.spk_id_proj = Embedding(hparams["num_spk"], hparams["gin_channels"]) 32 | if hparams['use_spk_embed']: 33 | self.spk_embed_proj = nn.Linear(256, hparams["gin_channels"], bias=True) 34 | #################### 35 | # Prior encoder 36 | #################### 37 | # Text encoder 38 | self.text_encoder = TextEncoder(ph_dict_size, pitch_size, dur_size, self.hidden_size, hparams["ffn_filter_channels"], 39 | hparams["num_heads"], self.enc_layers, hparams["ffn_kernel_size"], hparams["p_dropout"], True) 40 | # Position encoding 41 | self.embed_positions = SinusoidalPositionalEmbedding(self.hidden_size, 0, init_size=DEFAULT_MAX_TARGET_POSITIONS) 42 | # Pitch predictor 43 | if hparams["use_pitch_embed"]: 44 | self.pitch_predictor = PitchPredictor(self.hidden_size, hparams["ffn_filter_channels"], hparams["num_heads"], 45 | n_layers=hparams["pitch_predictor_layers"], kernel_size=hparams['ffn_kernel_size'], 46 | p_dropout=hparams["p_dropout"], gin_channels=hparams["gin_channels"], out_dim=2) 47 | # Phoneme predictor 48 | if hparams["use_phoneme_pred"]: 49 | self.phoneme_predictor = PhonemePredictor(ph_dict_size, self.hidden_size, hparams["ffn_filter_channels"], hparams["num_heads"], 50 | n_layers=hparams["phoneme_predictor_layers"], kernel_size=hparams["ffn_kernel_size"], 51 | p_dropout=hparams["p_dropout"]) 52 | # Frame prior network 53 | self.frame_prior = FramePriorNetwork(self.hidden_size, hparams["ffn_filter_channels"], hparams["num_heads"], hparams["frame_prior_layers"], 54 | hparams["ffn_kernel_size"], p_dropout=hparams["p_dropout"], gin_channels=1) 55 | 56 | #################### 57 | # Posterior encoder 58 | #################### 59 | self.posterior_encoder = PosteriorEncoder(hparams["num_linear_bins"], self.hidden_size, self.hidden_size, 5, 1, 16, 60 | gin_channels=hparams["gin_channels"]) 61 | #################### 62 | # Generator 63 | #################### 64 | # Flow 65 | self.flow = ResidualCouplingBlock(self.hidden_size, self.hidden_size, 5, 1, 4, gin_channels=hparams["gin_channels"]) 66 | # Raw Waveform Deocder 67 | self.decoder = Generator(self.hidden_size, hparams["dec_blocks"], hparams["dec_kernel_size"], 68 | hparams["dec_dilation_sizes"], hparams["upsample_rates"], hparams["initial_upsample_channels"], 69 | hparams["upsample_kernel_sizes"], gin_channels=hparams["gin_channels"]) 70 | 71 | def forward(self, text_tokens, pitch_tokens, dur_tokens, mel2ph, spk_embed=None, spk_id=None, f0=None, uv=None, 72 | mel=None, infer=False, **kwargs): 73 | ret = {} 74 | # Encoder 75 | tgt_nonpadding = (mel2ph > 0).float().unsqueeze(1) 76 | prior_inp = self.text_encoder(text_tokens, pitch_tokens, dur_tokens, mel2ph) # [Batch, Hidden, T_len] 77 | prior_inp = prior_inp * tgt_nonpadding 78 | # Positional encoding 79 | if self.use_pos_embed: 80 | pos_in = prior_inp.transpose(1, 2)[..., 0] 81 | positions = self.embed_positions(prior_inp.shape[0], prior_inp.shape[2], pos_in) 82 | prior_inp = prior_inp + positions.transpose(1, 2) 83 | # Multi-speaker settings 84 | 
spk_emb = self.speaker_embedding(spk_embed, spk_id).transpose(1, 2) 85 | # Pitch prediction 86 | cond_pitch = None 87 | if self.hparams["use_pitch_embed"]: 88 | cond_pitch = self.forward_pitch(prior_inp, f0, uv, spk_emb, tgt_nonpadding, ret) # [Batch, Hidden, T_len] 89 | # Frame prior network 90 | mu_p, logs_p = self.frame_prior(prior_inp, tgt_nonpadding, cond_pitch) 91 | if not infer: 92 | # Posterior encoder 93 | z_q, _, logs_q = self.posterior_encoder(mel.transpose(1, 2), tgt_nonpadding, g=spk_emb) 94 | # Phoneme prediction 95 | if self.hparams["use_phoneme_pred"]: 96 | ret["ph_pred"] = self.phoneme_predictor(z_q, tgt_nonpadding) * tgt_nonpadding 97 | # Normalizing Flow for posterior to prior 98 | z_p = ret["z_p"] = self.flow(z_q, tgt_nonpadding, g=spk_emb) * tgt_nonpadding # [Batch, Hidden, T_mel] 99 | # KL-divergence between prior from posterior and prior from prior encoder 100 | kl = (logs_p - logs_q - 0.5) + 0.5 * ((z_p - mu_p) ** 2) * torch.exp(-2. * logs_p) 101 | ret["kl"] = (kl * tgt_nonpadding).sum() / tgt_nonpadding.sum() 102 | # Waveform decoder 103 | z_slice, ret["ids_slice"] = rand_slice_segments(z_q, self.segment_size) 104 | ret["wav_out"] = self.decoder(z_slice, g=spk_emb).squeeze(1) 105 | else: 106 | # Reparameterization trick for prior 107 | z_p = (mu_p + torch.randn_like(mu_p) * torch.exp(logs_p)) * tgt_nonpadding # [Batch, Hidden, T_len] 108 | # Normalizing flow for prior to posterior 109 | z_q = self.flow(z_p, tgt_nonpadding, g=spk_emb, reverse=True) * tgt_nonpadding 110 | # Waveform decoder 111 | ret["wav_out"] = self.decoder(z_q * tgt_nonpadding, g=spk_emb).squeeze(1) 112 | return ret 113 | 114 | def speaker_embedding(self, spk_embed=None, spk_id=None): 115 | # Add speaker embedding 116 | speaker_embed = 0 117 | if self.hparams['use_spk_embed']: 118 | speaker_embed = speaker_embed + self.spk_embed_proj(spk_embed)[:, None, :] 119 | if self.hparams['use_spk_id']: 120 | speaker_embed = speaker_embed + self.spk_id_proj(spk_id)[:, None, :] 121 | return speaker_embed 122 | 123 | def forward_pitch(self, pitch_inp, f0, uv, spk_emb, tgt_nonpadding, ret): 124 | # Pitch prediction 125 | if self.hparams['predictor_grad'] != 1: 126 | pitch_inp = pitch_inp.detach() + self.hparams['predictor_grad'] * (pitch_inp - pitch_inp.detach()) 127 | ret['f0_pred'] = pitch_pred = self.pitch_predictor(pitch_inp, tgt_nonpadding, spk_emb) 128 | # Teacher forcing settings 129 | if f0 is None: 130 | f0 = pitch_pred[:, :, 0] 131 | v = (pitch_pred[:, :, 1] <= 0) # consider voiced part 132 | else: 133 | v = (uv == 0) 134 | f0 = (f0 * v).unsqueeze(1) * tgt_nonpadding 135 | return f0 # using log_f0 136 | 137 | 138 | class MultiPeriodDiscriminator(nn.Module): 139 | def __init__(self, use_spectral_norm=False): 140 | super(MultiPeriodDiscriminator, self).__init__() 141 | periods = [2, 3, 5, 7, 11] 142 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 143 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 144 | self.discriminators = nn.ModuleList(discs) 145 | 146 | def forward(self, y, y_hat): 147 | y_d_rs = [] 148 | y_d_gs = [] 149 | fmap_rs = [] 150 | fmap_gs = [] 151 | for _, d in enumerate(self.discriminators): 152 | y_d_r, fmap_r = d(y) 153 | y_d_g, fmap_g = d(y_hat) 154 | y_d_rs.append(y_d_r) 155 | y_d_gs.append(y_d_g) 156 | fmap_rs.append(fmap_r) 157 | fmap_gs.append(fmap_g) 158 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 159 | -------------------------------------------------------------------------------- /utils/commons/base_task.py: 
-------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import numpy as np 7 | import torch.utils.data 8 | 9 | from torch import nn 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | from utils.commons.dataset_utils import data_loader 13 | from utils.commons.hparams import hparams 14 | from utils.commons.meters import AvgrageMeter 15 | from utils.commons.tensor_utils import tensors_to_scalars 16 | from utils.commons.trainer import Trainer 17 | 18 | # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # for degugging 19 | torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system')) 20 | 21 | log_format = '%(asctime)s %(message)s' 22 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 23 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 24 | 25 | 26 | class BaseTask(nn.Module): 27 | def __init__(self, *args, **kwargs): 28 | super(BaseTask, self).__init__() 29 | self.current_epoch = 0 30 | self.global_step = 0 31 | self.trainer = None 32 | self.use_ddp = False 33 | self.gradient_clip_norm = hparams['clip_grad_norm'] 34 | self.gradient_clip_val = hparams.get('clip_grad_value', 0) 35 | self.model = None 36 | self.training_losses_meter = None 37 | self.logger: SummaryWriter = None 38 | 39 | ###################### 40 | # build model, dataloaders, optimizer, scheduler and tensorboard 41 | ###################### 42 | def build_model(self): 43 | raise NotImplementedError 44 | 45 | @data_loader 46 | def train_dataloader(self): 47 | raise NotImplementedError 48 | 49 | @data_loader 50 | def test_dataloader(self): 51 | raise NotImplementedError 52 | 53 | @data_loader 54 | def val_dataloader(self): 55 | raise NotImplementedError 56 | 57 | def build_scheduler(self, optimizer): 58 | return None 59 | 60 | def build_optimizer(self, model): 61 | raise NotImplementedError 62 | 63 | def configure_optimizers(self): 64 | optm = self.build_optimizer(self.model) 65 | self.scheduler = self.build_scheduler(optm) 66 | if isinstance(optm, (list, tuple)): 67 | return optm 68 | return [optm] 69 | 70 | def build_tensorboard(self, save_dir, name, **kwargs): 71 | log_dir = os.path.join(save_dir, name) 72 | os.makedirs(log_dir, exist_ok=True) 73 | self.logger = SummaryWriter(log_dir=log_dir, **kwargs) 74 | 75 | ###################### 76 | # training 77 | ###################### 78 | def on_train_start(self): 79 | pass 80 | 81 | def on_train_end(self): 82 | pass 83 | 84 | def on_epoch_start(self): 85 | self.training_losses_meter = {'total_loss': AvgrageMeter()} 86 | 87 | def on_epoch_end(self): 88 | loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()} 89 | print(f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. 
{loss_outputs}") 90 | 91 | def _training_step(self, sample, batch_idx, optimizer_idx): 92 | """ 93 | :param sample: 94 | :param batch_idx: 95 | :return: total loss: torch.Tensor, loss_log: dict 96 | """ 97 | raise NotImplementedError 98 | 99 | def training_step(self, sample, batch_idx, optimizer_idx=-1): 100 | """ 101 | :param sample: 102 | :param batch_idx: 103 | :param optimizer_idx: 104 | :return: {'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict} 105 | """ 106 | loss_ret = self._training_step(sample, batch_idx, optimizer_idx) 107 | if loss_ret is None: 108 | return {'loss': None} 109 | total_loss, log_outputs = loss_ret 110 | log_outputs = tensors_to_scalars(log_outputs) 111 | for k, v in log_outputs.items(): 112 | if k not in self.training_losses_meter: 113 | self.training_losses_meter[k] = AvgrageMeter() 114 | if not np.isnan(v): 115 | self.training_losses_meter[k].update(v) 116 | self.training_losses_meter['total_loss'].update(total_loss.item()) 117 | 118 | if optimizer_idx >= 0: 119 | log_outputs[f'lr_{optimizer_idx}'] = self.trainer.optimizers[optimizer_idx].param_groups[0]['lr'] 120 | 121 | progress_bar_log = log_outputs 122 | tb_log = {f'train/{k}': v for k, v in log_outputs.items()} 123 | return {'loss': total_loss, 124 | 'progress_bar': progress_bar_log, 125 | 'tb_log': tb_log} 126 | 127 | def on_before_optimization(self, opt_idx): 128 | if self.gradient_clip_norm > 0: 129 | torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm) 130 | if self.gradient_clip_val > 0: 131 | torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val) 132 | 133 | def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx): 134 | if self.scheduler is not None: 135 | self.scheduler.step(self.global_step // hparams['accumulate_grad_batches']) 136 | 137 | ###################### 138 | # validation 139 | ###################### 140 | def validation_start(self): 141 | pass 142 | 143 | def validation_step(self, sample, batch_idx): 144 | """ 145 | 146 | :param sample: 147 | :param batch_idx: 148 | :return: output: {"losses": {...}, "total_loss": float, ...} or (total loss: torch.Tensor, loss_log: dict) 149 | """ 150 | raise NotImplementedError 151 | 152 | def validation_end(self, outputs): 153 | """ 154 | 155 | :param outputs: 156 | :return: loss_output: dict 157 | """ 158 | all_losses_meter = {'total_loss': AvgrageMeter()} 159 | for output in outputs: 160 | if len(output) == 0 or output is None: 161 | continue 162 | if isinstance(output, dict): 163 | assert 'losses' in output, 'Key "losses" should exist in validation output.' 
164 | n = output.pop('nsamples', 1) 165 | losses = tensors_to_scalars(output['losses']) 166 | total_loss = output.get('total_loss', sum(losses.values())) 167 | else: 168 | assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)' 169 | n = 1 170 | total_loss, losses = output 171 | losses = tensors_to_scalars(losses) 172 | if isinstance(total_loss, torch.Tensor): 173 | total_loss = total_loss.item() 174 | for k, v in losses.items(): 175 | if k not in all_losses_meter: 176 | all_losses_meter[k] = AvgrageMeter() 177 | all_losses_meter[k].update(v, n) 178 | all_losses_meter['total_loss'].update(total_loss, n) 179 | loss_output = {k: round(v.avg, 4) for k, v in all_losses_meter.items()} 180 | print(f"| Validation results@{self.global_step}: {loss_output}") 181 | return {'tb_log': {f'val/{k}': v for k, v in loss_output.items()}, 182 | 'val_loss': loss_output['total_loss']} 183 | 184 | ###################### 185 | # testing 186 | ###################### 187 | def test_start(self): 188 | pass 189 | 190 | def test_step(self, sample, batch_idx): 191 | return self.validation_step(sample, batch_idx) 192 | 193 | def test_end(self, outputs): 194 | return self.validation_end(outputs) 195 | 196 | ###################### 197 | # start training/testing 198 | ###################### 199 | @classmethod 200 | def start(cls): 201 | os.environ['MASTER_PORT'] = str(random.randint(15000, 30000)) 202 | random.seed(hparams['seed']) 203 | np.random.seed(hparams['seed']) 204 | work_dir = hparams['work_dir'] 205 | trainer = Trainer( 206 | work_dir=work_dir, 207 | val_check_interval=hparams['val_check_interval'], 208 | tb_log_interval=hparams['tb_log_interval'], 209 | max_updates=hparams['max_updates'], 210 | num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000, 211 | accumulate_grad_batches=hparams['accumulate_grad_batches'], 212 | print_nan_grads=hparams['print_nan_grads'], 213 | resume_from_checkpoint=hparams.get('resume_from_checkpoint', 0), 214 | amp=hparams['amp'], 215 | monitor_key=hparams['valid_monitor_key'], 216 | monitor_mode=hparams['valid_monitor_mode'], 217 | num_ckpt_keep=hparams['num_ckpt_keep'], 218 | save_best=hparams['save_best'], 219 | seed=hparams['seed'], 220 | debug=hparams['debug'] 221 | ) 222 | if not hparams['infer']: # train 223 | trainer.fit(cls) 224 | trainer.test(cls) 225 | else: 226 | trainer.test(cls) 227 | 228 | def on_keyboard_interrupt(self): 229 | pass 230 | -------------------------------------------------------------------------------- /modules/visinger/encoder.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/jaywalnut310/vits 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.commons.align_ops import expand_states 7 | from modules.rel_transformer import RelativeEncoder, SinusoidalPositionalEmbedding 8 | from modules.commons.utils import Embedding 9 | 10 | DEFAULT_MAX_TARGET_POSITIONS = 2000 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class TextEncoder(nn.Module): 15 | """ Text encoder of VISinger """ 16 | def __init__(self, ph_dict_size, note_pitch_size, note_dur_size, hidden_channels, filter_channels, 17 | n_heads, n_layers, kernel_size, p_dropout, use_pos_embed=False): 18 | super().__init__() 19 | self.dropout = p_dropout 20 | self.use_pos_embed = use_pos_embed 21 | # Input settings 22 | self.ph_emb = Embedding(ph_dict_size, hidden_channels) 23 | self.pitch_emb = Embedding(note_pitch_size, hidden_channels) 24 | 
self.dur_emb = Embedding(note_dur_size, hidden_channels) 25 | self.embed_scale = math.sqrt(hidden_channels) 26 | self.padding_idx = 0 27 | self.linear = nn.Linear(hidden_channels * 3, hidden_channels) 28 | self.text_encoder = RelativeEncoder(hidden_channels, filter_channels, n_heads, n_layers, 29 | kernel_size=kernel_size, p_dropout=p_dropout) 30 | # Position embedding 31 | if self.use_pos_embed: 32 | self.embed_positions = SinusoidalPositionalEmbedding(hidden_channels, 0, init_size=DEFAULT_MAX_TARGET_POSITIONS) 33 | 34 | def forward(self, text_tokens, pitch_tokens, dur_tokens, mel2ph): 35 | tgt_nonpadding = (text_tokens > 0).float().unsqueeze(1) 36 | # Text encoder 37 | token_emb = self.forward_text_embedding(text_tokens, pitch_tokens, dur_tokens, tgt_nonpadding.transpose(1, 2)) 38 | enc_out = self.text_encoder(token_emb.transpose(1, 2), tgt_nonpadding) 39 | enc_out = expand_states(enc_out.transpose(1, 2), mel2ph) 40 | return enc_out.transpose(1, 2) 41 | 42 | def forward_text_embedding(self, text_tokens, pitch_tokens, dur_tokens, nonpadding): 43 | # Inputs embedding settings 44 | text_emb = self.ph_emb(text_tokens) * self.embed_scale 45 | pitch_emb = self.pitch_emb(pitch_tokens) * self.embed_scale 46 | dur_emb = self.dur_emb(dur_tokens) * self.embed_scale 47 | # Concatenation and linear projection for text encoder 48 | token_emb = torch.cat([text_emb, pitch_emb, dur_emb], 2) 49 | token_emb = self.linear(token_emb) * nonpadding 50 | # Use position embedding 51 | if self.use_pos_embed: 52 | pos_in = token_emb[..., 0] 53 | positions = self.embed_positions(token_emb.shape[0], token_emb.shape[2], pos_in) 54 | token_emb = token_emb + positions.transpose(1, 2) 55 | return token_emb * nonpadding 56 | 57 | 58 | class FramePriorNetwork(nn.Module): 59 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, gin_channels, p_dropout): 60 | super().__init__() 61 | self.hidden_channels = hidden_channels 62 | # Frame Prior Network 63 | self.encoder = RelativeEncoder(hidden_channels, filter_channels, n_heads, n_layers=n_layers, 64 | kernel_size=kernel_size, gin_channels=gin_channels, p_dropout=p_dropout) 65 | self.proj = nn.Conv1d(self.hidden_channels, self.hidden_channels * 2, 1) 66 | 67 | def forward(self, x, x_mask, g=None): 68 | if g is not None: 69 | g = g.transpose(1, 2) 70 | prior_out = self.encoder(x, x_mask, g) 71 | prior_out = self.proj(prior_out) * x_mask 72 | mu_p, logs_p = torch.split(prior_out, self.hidden_channels, dim=1) 73 | return mu_p, logs_p 74 | 75 | 76 | class PosteriorEncoder(nn.Module): 77 | def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, 78 | gin_channels): 79 | super().__init__() 80 | self.in_channels = in_channels 81 | self.out_channels = out_channels 82 | self.hidden_channels = hidden_channels 83 | self.kernel_size = kernel_size 84 | self.dilation_rate = dilation_rate 85 | self.n_layers = n_layers 86 | self.gin_channels = gin_channels 87 | 88 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 89 | self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 90 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 91 | 92 | def forward(self, x, nonpadding, g=None): 93 | x = self.pre(x) * nonpadding 94 | x = self.enc(x, nonpadding, g=g) 95 | stats = self.proj(x) * nonpadding 96 | mu_q, logs_q = torch.split(stats, self.out_channels, dim=1) 97 | z_q = (mu_q + torch.randn_like(mu_q) * torch.exp(logs_q)) * nonpadding 98 | return z_q, mu_q, logs_q 99 
| 100 | def remove_weight_norm(self): 101 | self.enc.remove_weight_norm() 102 | 103 | 104 | class PitchEncoder(nn.Module): 105 | def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, 106 | gin_channels=0): 107 | super().__init__() 108 | self.in_channels = in_channels 109 | self.out_channels = out_channels 110 | self.hidden_channels = hidden_channels 111 | self.kernel_size = kernel_size 112 | self.dilation_rate = dilation_rate 113 | self.n_layers = n_layers 114 | self.gin_channels = gin_channels 115 | 116 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 117 | self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 118 | self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1) 119 | 120 | def forward(self, x, nonpadding, g=None): 121 | x = self.pre(x) * nonpadding 122 | x = self.enc(x, nonpadding, g=g) 123 | x = self.proj(x) * nonpadding 124 | return x 125 | 126 | def remove_weight_norm(self): 127 | self.enc.remove_weight_norm() 128 | 129 | 130 | class WaveNet(torch.nn.Module): 131 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 132 | super(WaveNet, self).__init__() 133 | assert(kernel_size % 2 == 1) 134 | self.hidden_channels =hidden_channels 135 | self.kernel_size = kernel_size, 136 | self.dilation_rate = dilation_rate 137 | self.n_layers = n_layers 138 | self.gin_channels = gin_channels 139 | self.p_dropout = p_dropout 140 | 141 | self.in_layers = nn.ModuleList() 142 | self.res_skip_layers = nn.ModuleList() 143 | self.drop = nn.Dropout(p_dropout) 144 | 145 | if gin_channels != 0: 146 | cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) 147 | self.cond_layer = nn.utils.weight_norm(cond_layer, name='weight') 148 | 149 | for i in range(n_layers): 150 | dilation = dilation_rate ** i 151 | padding = int((kernel_size * dilation - dilation) / 2) 152 | in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, 153 | dilation=dilation, padding=padding) 154 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 155 | self.in_layers.append(in_layer) 156 | 157 | # last one is not necessary 158 | if i < n_layers - 1: 159 | res_skip_channels = 2 * hidden_channels 160 | else: 161 | res_skip_channels = hidden_channels 162 | 163 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 164 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 165 | self.res_skip_layers.append(res_skip_layer) 166 | 167 | def forward(self, x, x_mask, g=None, **kwargs): 168 | output = torch.zeros_like(x) 169 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 170 | 171 | if g is not None: 172 | g = self.cond_layer(g) 173 | 174 | for i in range(self.n_layers): 175 | x_in = self.in_layers[i](x) 176 | if g is not None: 177 | cond_offset = i * 2 * self.hidden_channels 178 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 179 | else: 180 | g_l = torch.zeros_like(x_in) 181 | 182 | acts = fused_add_tanh_sigmoid_multiply( 183 | x_in, 184 | g_l, 185 | n_channels_tensor) 186 | acts = self.drop(acts) 187 | 188 | res_skip_acts = self.res_skip_layers[i](acts) 189 | if i < self.n_layers - 1: 190 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 191 | x = (x + res_acts) * x_mask 192 | output = output + res_skip_acts[:,self.hidden_channels:,:] 193 | else: 194 | output = output + res_skip_acts 195 | return output * x_mask 196 | 197 | def remove_weight_norm(self): 198 | if 
self.gin_channels != 0: 199 | torch.nn.utils.remove_weight_norm(self.cond_layer) 200 | for l in self.in_layers: 201 | torch.nn.utils.remove_weight_norm(l) 202 | for l in self.res_skip_layers: 203 | torch.nn.utils.remove_weight_norm(l) 204 | 205 | 206 | @torch.jit.script 207 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 208 | n_channels_int = n_channels[0] 209 | in_act = input_a + input_b 210 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 211 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 212 | acts = t_act * s_act 213 | return acts 214 | -------------------------------------------------------------------------------- /inference/visinger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import miditoolkit 3 | import os 4 | import torch 5 | 6 | from tqdm import tqdm 7 | 8 | from models.visinger import VISinger 9 | from preprocessor.base_preprocessor import BasePreprocessor 10 | from preprocessor.base_binarizer import BaseBinarizer 11 | from preprocessor.text.base_text_processor import get_text_processor_cls 12 | from tasks.dataset_utils import VISingerDataset 13 | from tasks.utils import load_data_preprocessor 14 | from utils.audio.align import get_note2dur 15 | from utils.audio.io import save_wav 16 | from utils.commons.ckpt_utils import load_ckpt 17 | from utils.commons.hparams import hparams, set_hparams 18 | 19 | 20 | class VISingerInfer: 21 | def __init__(self, hparams, work_dir, device=None): 22 | if device is None: 23 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 24 | self.hparams = hparams 25 | self.work_dir = work_dir 26 | self.device = device 27 | self.data_dir = hparams['binary_data_dir'] 28 | self.preprocessor, self.preprocess_args = load_data_preprocessor() 29 | self.text_processor = get_text_processor_cls("ko_sing") 30 | self.ph_encoder = self.preprocessor.load_dict(self.data_dir) 31 | # Dictionary settings 32 | self.pitch_dict = json.load(open(f"{self.data_dir}/pitch_map.json")) 33 | self.dur_dict = json.load(open(f"{self.data_dir}/dur_map.json")) 34 | self.spk_map = self.preprocessor.load_spk_map(self.data_dir) 35 | self.ds_cls = VISingerDataset 36 | self.model = self.build_model() 37 | self.model.eval() 38 | self.model.to(self.device) 39 | 40 | def build_model(self): 41 | dict_size = len(self.ph_encoder) 42 | model = VISinger(dict_size, len(self.pitch_dict), len(self.dur_dict), self.hparams) 43 | model.eval() 44 | load_ckpt(model, f"{self.hparams['work_dir']}/{self.work_dir}") 45 | self.gen_dir = f"{self.hparams['work_dir']}/{self.work_dir}/unseen_wav" 46 | os.makedirs(self.gen_dir, exist_ok=True) 47 | return model 48 | 49 | def preprocess_input(self, inp, divided=True, pitch_control=0): 50 | """ 51 | Parameters 52 | ---------- 53 | inp: dict 54 | {'item_name': (str, optional), 'spk_name': (str, optional), 'midi_fn': str} 55 | """ 56 | midi_obj = miditoolkit.midi.parser.MidiFile(inp["midi_fn"], charset="korean") 57 | midi_info = BasePreprocessor.MIDI_to_encoding(midi_obj, 0, self.hparams["preprocess_args"]) 58 | midi_info = self.text_process(midi_info, self.hparams) 59 | ret = [] 60 | item = self.process_second_pass(midi_info, inp.get("spk_name", "csd"), self.ph_encoder, self.spk_map, pitch_control) 61 | mel2phone, mel2note, _, ph_token, _, _, item["midi_info"] = get_note2dur(item["midi_info"], self.hparams["hop_size"], 62 | self.hparams["sample_rate"], 63 | self.hparams["binarization_args"]["min_sil_duration"]) 64 | BaseBinarizer.process_note(item, self.pitch_dict, 
self.dur_dict, self.tempo_dict, self.hparams["binarization_args"]) 65 | item["item_name"] = f"{inp['item_name']}" 66 | item["spk_name"] = inp["spk_name"] 67 | item["ph_token"] = ph_token 68 | item["mel2ph"] = mel2phone 69 | item["mel2note"] = mel2note 70 | ret.append(item) 71 | return ret 72 | 73 | def input_to_batch(self, sample): 74 | item_names = [sample['item_name']] 75 | ph_token = torch.LongTensor(sample["ph_token"])[None, :].to(self.device) 76 | note_pitch = torch.LongTensor(sample["note_pitch"])[None, :].to(self.device) 77 | note_dur = torch.LongTensor(sample["note_duration"])[None, :].to(self.device) 78 | mel2ph = torch.LongTensor(sample["mel2ph"])[None, :].to(self.device) 79 | mel2note = torch.LongTensor(sample["mel2note"])[None, :].to(self.device) 80 | batch = {"item_name": item_names, 81 | "spk_name": [sample["spk_name"]], 82 | "text_token": ph_token, 83 | "note_pitch": note_pitch, 84 | "note_dur": note_dur, 85 | "mel2ph": mel2ph, 86 | "mel2note": mel2note,} 87 | if hparams["use_spk_id"]: 88 | batch["spk_id"] = torch.LongTensor([int(sample["spk_id"])]).to(self.device) 89 | return batch 90 | 91 | def forward_model(self, datasets): 92 | gen_dir = self.gen_dir 93 | for sample in tqdm(datasets): 94 | batch = self.input_to_batch(sample) 95 | output = self.model(batch["text_token"], batch["note_pitch"], batch["note_dur"], mel2ph=batch["mel2ph"], mel2note=batch["mel2note"], 96 | spk_embed=batch.get("spk_embed"), spk_id=batch.get("spk_id"), infer=True) 97 | item_name = sample['item_name'] 98 | wav_pred = output['wav_out'][0].detach().cpu().numpy() 99 | input_fn = f"{gen_dir}/spk{sample['spk_name']}_{item_name}.wav" 100 | save_wav(wav_pred, input_fn, self.hparams["sample_rate"], norm=self.hparams["out_wav_norm"]) 101 | 102 | def text_process(self, midi_info, hparmas): 103 | _, midi_info = self.text_processor.process(midi_info, hparmas) 104 | return midi_info 105 | 106 | def divide_info(self, midi_info, preprocess_args): 107 | # Divide midi information 108 | midi_infos = [] 109 | ph = [] 110 | bar_info = [] 111 | ph_bar_info = [] 112 | phrase_idx = 0 113 | for i, midi in enumerate(midi_info): 114 | phrase_num = midi[0] // preprocess_args["max_bar"] 115 | if phrase_num == phrase_idx: 116 | bar_info.append(midi) 117 | ph_bar_info.extend(midi[-1]) 118 | elif phrase_num > phrase_idx: 119 | midi_infos.append(bar_info) 120 | ph.append(" ".join(ph_bar_info)) 121 | bar_info = [] 122 | ph_bar_info = [] 123 | phrase_idx += 1 124 | bar_info.append(midi) 125 | ph_bar_info.extend(midi[-1]) 126 | midi_infos.append(bar_info) 127 | ph.append(" ".join(ph_bar_info)) 128 | end_time_ = 0.0 129 | midi_infos_ = [] 130 | phrase_ = [] 131 | for _, phrase in enumerate(midi_infos): 132 | if len(phrase) > 0: 133 | assert len(phrase[-1]) == 8, f"| Wrong data construction :{phrase[-1]}" 134 | end_time = phrase[-1][5] 135 | # Time settings 136 | max_note_dur = 0 137 | for _, midi in enumerate(phrase): 138 | midi_ = midi 139 | midi_[4] -= end_time_ 140 | midi_[5] -= end_time_ 141 | max_note_dur = (midi_[5] - midi_[4]) if (midi_[5] - midi_[4]) > max_note_dur else max_note_dur 142 | phrase_.append(midi_) 143 | if max_note_dur <= preprocess_args["max_note_dur"]: 144 | midi_infos_.append(phrase_) 145 | phrase_ = [] 146 | end_time_ = (end_time - 0.2) 147 | return midi_infos, ph 148 | 149 | def process_second_pass(self, midi_info, spk_name, ph_encoder, spk_map, pitch_control=0): 150 | midi_ = [] 151 | ph_token = [] 152 | phs = [] 153 | for i, (bar, _, pitch, duration, start_time, end_time, tempo, ph_) in 
enumerate(midi_info): 154 | if i == 0: 155 | # Add [BOS] token 156 | phs.extend([""]) 157 | ph = ph_encoder.encode("") 158 | midi = [bar, 0, 0, 0, 0.0, start_time, tempo, ph, [""]] 159 | midi_.append(midi) 160 | ph_token.extend(ph) 161 | ph_ = [p for p in ph_ if p != "" and p != " "] 162 | phs.extend(ph_) 163 | ph = ph_encoder.encode(" ".join(ph_)) 164 | if pitch > 0: 165 | pitch = pitch + pitch_control 166 | midi = [bar, i + 1, pitch, duration, start_time, end_time, tempo, ph, ph_] 167 | midi_.append(midi) 168 | ph_token.extend(ph) 169 | if i == len(midi_info) - 1: 170 | # Add [EOS] token 171 | phs.extend([""]) 172 | ph = ph_encoder.encode("") 173 | midi = [bar, i + 2, 0, 0, end_time, end_time + 0.1, tempo, ph, [""]] 174 | midi_.append(midi) 175 | ph_token.extend(ph) 176 | assert len(midi_info) < len(midi_), print(f"| Original token: {len(midi_info)}. Additional token: {len(midi_)}.") 177 | assert len(phs) == len(ph_token), print(f"| Phonmem token: {len(ph_token)}, Phonemes: {len(phs)}") 178 | midi_[-1][-1] = [ph_encoder.encode("")] if midi_[-1][-2] == ["|"] else midi_[-1][-1] 179 | midi_[-1][-2] = [""] if midi_[-1][-2] == ["|"] else midi_[-1][-2] 180 | phs[-1] = [""] if phs[-1] == ["|"] else phs[-1] 181 | spk_id = spk_map[spk_name] 182 | return {"midi_info": midi_, "ph": phs, "ph_token": ph_token, "spk_id": spk_id} 183 | 184 | def inference(self, inp, pitch_control=0): 185 | items = self.preprocess_input(inp, pitch_control=pitch_control) 186 | output = self.forward_model(items) 187 | return output 188 | 189 | 190 | if __name__ == "__main__": 191 | config = set_hparams("./config/models/visinger.yaml") 192 | work_dir = "svs/visinger" 193 | pitch_control = 0 # 1 is half-note 194 | generator = VISingerInfer(config, work_dir, device=0) 195 | midi_nm = "" # MIDI data have to get lyrics 196 | generator.inference({"item_name": f"{midi_nm}_{pitch_control}", 197 | "midi_fn": f"./data/source/new_svs/{midi_nm}.mid", 198 | "spk_name": str(0)}, 199 | pitch_control=pitch_control) 200 | -------------------------------------------------------------------------------- /tasks/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import numpy as np 3 | import torch 4 | import torch.distributions 5 | import torch.optim 6 | import torch.utils.data 7 | 8 | 9 | from utils.audio.pitch.utils import norm_interp_f0 10 | from utils.audio.mel_processing import SpectrogramFixed, load_wav_to_torch 11 | from utils.commons.dataset_utils import BaseDataset, collate_1d_or_2d 12 | from utils.commons.hparams import hparams 13 | from utils.commons.indexed_datasets import IndexedDataset 14 | 15 | 16 | class BaseSpeechDataset(BaseDataset): 17 | """ Base dataset. 
""" 18 | def __init__(self, prefix, shuffle=False, items=None, data_dir=None): 19 | super().__init__(shuffle) 20 | self.data_dir = hparams["preprocess"]["binary_data_dir"] if data_dir is None else data_dir 21 | self.prefix = prefix 22 | self.hparams = hparams 23 | self.indexed_ds = None 24 | if items is not None: 25 | self.indexed_ds = items 26 | self.sizes = [1] * len(items) 27 | self.avail_idxs = list(range(len(self.sizes))) 28 | else: 29 | self.sizes = np.load(f"{self.data_dir}/{self.prefix}_lengths.npy") 30 | if prefix == "test" and len(hparams["test_ids"]) > 0: 31 | self.avail_idxs = hparams["test_ids"] 32 | else: 33 | self.avail_idxs = list(range(len(self.sizes))) 34 | if prefix == "train" and hparams["min_frames"] > 0: 35 | self.avail_idxs = [x for x in self.avail_idxs if self.sizes[x] >= hparams["min_frames"]] 36 | self.sizes = [self.sizes[i] for i in self.avail_idxs] 37 | 38 | def _get_item(self, index): 39 | if hasattr(self, "avail_idxs") and self.avail_idxs is not None: 40 | index = self.avail_idxs[index] 41 | if self.indexed_ds is None: 42 | self.indexed_ds = IndexedDataset(f"{self.data_dir}/{self.prefix}") 43 | return self.indexed_ds[index] 44 | 45 | def __getitem__(self, index): 46 | hparams = self.hparams 47 | item = self._get_item(index) 48 | assert len(item["mel"]) == self.sizes[index], (len(item["mel"]), self.sizes[index]) 49 | max_frames = hparams["max_frames"] 50 | spec = torch.Tensor(item["mel"])[:max_frames] 51 | max_frames = spec.shape[0] // hparams["frames_multiple"] * hparams["frames_multiple"] 52 | spec = spec[:max_frames] 53 | ph_token = torch.LongTensor(item["ph_token"][:hparams["max_input_tokens"]]) 54 | sample = {"id": index, 55 | "item_name": item["item_name"], 56 | "text": item["text"], 57 | "text_token": ph_token, 58 | "mel": spec, 59 | "mel_nonpadding": spec.abs().sum(-1) > 0} 60 | if hparams["use_spk_embed"]: 61 | sample["spk_embed"] = torch.Tensor(item["spk_embed"]) 62 | if hparams["use_spk_id"]: 63 | sample["spk_id"] = int(item["spk_id"]) 64 | 65 | return sample 66 | 67 | def collater(self, samples): 68 | if len(samples) == 0: 69 | return {} 70 | hparams = self.hparams 71 | id = torch.LongTensor([s["id"] for s in samples]) 72 | item_names = [s["item_name"] for s in samples] 73 | text =[s["text"] for s in samples] 74 | text_tokens = collate_1d_or_2d([s["text_token"] for s in samples], 0) 75 | mels = collate_1d_or_2d([s["mel"] for s in samples], 0.0) 76 | text_lengths = torch.LongTensor([s["text_token"].numel() for s in samples]) # Return the number of total elements 77 | mel_lengths = torch.LongTensor([s["mel"].shape[0] for s in samples]) 78 | 79 | batch = {"id": id, 80 | "item_name": item_names, 81 | "nsamples": len(samples), 82 | "text": text, 83 | "text_tokens": text_tokens, 84 | "text_lengths": text_lengths, 85 | "mels": mels, 86 | "mel_lenghts": mel_lengths} 87 | 88 | if hparams["use_spk_embed"]: 89 | spk_embed = torch.stack([s["spk_embed"] for s in samples]) 90 | batch["spk_embed"] = spk_embed 91 | if hparams["use_spk_id"]: 92 | spk_ids = torch.LongTensor([s["spk_id"] for s in samples]) 93 | batch["spk_ids"] = spk_ids 94 | 95 | return batch 96 | 97 | 98 | class VISingerDataset(BaseDataset): 99 | """ Dataset of VISinger. 
""" 100 | def __init__(self, prefix, shuffle=False, items=None, data_dir=None): 101 | super().__init__(shuffle) 102 | self.data_dir = hparams["binary_data_dir"] if data_dir is None else data_dir 103 | self.prefix = prefix 104 | self.hparams = hparams 105 | self.segment_size = hparams["segment_size"] 106 | self.spec_fn = SpectrogramFixed(n_fft=hparams["fft_size"], win_length=hparams["win_size"], 107 | hop_length=hparams["hop_size"], window_fn=torch.hann_window) 108 | self.indexed_ds = None 109 | if items is not None: 110 | self.indexed_ds = items 111 | self.sizes = [1] * len(items) 112 | self.avail_idxs = list(range(len(self.sizes))) 113 | else: 114 | self.sizes = np.load(f"{self.data_dir}/{self.prefix}_lengths.npy") 115 | if prefix == "test" and len(hparams["test_ids"]) > 0: 116 | self.avail_idxs = hparams["test_ids"] 117 | else: 118 | self.avail_idxs = list(range(len(self.sizes))) 119 | if prefix == "train" and self.segment_size > 0: 120 | self.avail_idxs = [x for x in self.avail_idxs if self.sizes[x] > self.segment_size 121 | and self.sizes[x] <= hparams["max_frames"]] 122 | self.sizes = [self.sizes[i] for i in self.avail_idxs] 123 | 124 | def _get_item(self, index): 125 | if hasattr(self, "avail_idxs") and self.avail_idxs is not None: 126 | index = self.avail_idxs[index] 127 | if self.indexed_ds is None: 128 | self.indexed_ds = IndexedDataset(f"{self.data_dir}/{self.prefix}") 129 | return self.indexed_ds[index] 130 | 131 | def __getitem__(self, index): 132 | hparams = self.hparams 133 | item = self._get_item(index) 134 | # Phoneme & Note settings 135 | ph_token = torch.LongTensor(item["ph_token"][:hparams["max_input_tokens"]]) 136 | note_pitch = torch.LongTensor(item["note_pitch"][:hparams["max_input_tokens"]]) 137 | note_dur = torch.LongTensor(item["note_duration"][:hparams["max_input_tokens"]]) 138 | # Waveform and linear-spectrogram with max frames 139 | max_frames = hparams["max_frames"] 140 | wav, _ = load_wav_to_torch(item["wav_fn"], hparams["hop_size"]) 141 | spec = torch.Tensor(self.spec_fn(wav)).transpose(0, 1) 142 | assert spec.shape[0] == self.sizes[index], (spec.shape, self.sizes[index]) 143 | # Mapping function 144 | mel2ph = torch.LongTensor(item["mel2ph"][:max_frames]) 145 | # Sample settings 146 | sample = {"id": index, 147 | "item_name": item["item_name"], 148 | "wav_fn": item["wav_fn"], 149 | "text_token": ph_token, 150 | "wav": wav, 151 | "mel": spec, 152 | "note_pitch": note_pitch, 153 | "note_dur": note_dur, 154 | "mel2ph": mel2ph} 155 | # Multi-speaker settings 156 | if hparams["use_spk_embed"]: 157 | sample["spk_embed"] = torch.Tensor(item["spk_embed"]) 158 | if hparams["use_spk_id"]: 159 | sample["spk_id"] = int(item["spk_id"]) 160 | # Pitch settings 161 | uv, f0 = None, None 162 | if hparams["use_pitch_embed"]: 163 | T = spec.shape[0] 164 | f0, uv = norm_interp_f0(item["f0"][:T]) 165 | uv = torch.FloatTensor(uv) 166 | f0 = torch.FloatTensor(f0) 167 | sample["f0"], sample["uv"] = f0, uv 168 | return sample 169 | 170 | def collater(self, samples): 171 | if len(samples) == 0: 172 | return {} 173 | hparams = self.hparams 174 | id = torch.LongTensor([s["id"] for s in samples]) 175 | item_names = [s["item_name"] for s in samples] 176 | wav_fns = [s["wav_fn"] for s in samples] 177 | text_tokens = collate_1d_or_2d([s["text_token"] for s in samples], 0) 178 | text_lengths = torch.LongTensor([s["text_token"].numel() for s in samples]) # Return the number of total elements 179 | wavs = collate_1d_or_2d([s["wav"] for s in samples], 0.0) 180 | mels = 
collate_1d_or_2d([s["mel"] for s in samples], 0.0) 181 | mel_lengths = torch.LongTensor([s["mel"].shape[0] for s in samples]) 182 | note_pitches = collate_1d_or_2d([s["note_pitch"] for s in samples], 0) 183 | note_durations = collate_1d_or_2d([s["note_dur"] for s in samples], 0) 184 | mel2phs = collate_1d_or_2d([s["mel2ph"] for s in samples], 0) 185 | batch = {"id": id, 186 | "item_name": item_names, 187 | "wav_fn": wav_fns, 188 | "nsamples": len(samples), 189 | "text_tokens": text_tokens, 190 | "text_lengths": text_lengths, 191 | "wavs": wavs, 192 | "mels": mels, 193 | "mel_lengths": mel_lengths, 194 | "note_pitch": note_pitches, 195 | "note_dur": note_durations, 196 | "mel2ph": mel2phs,} 197 | if hparams["use_spk_embed"]: 198 | spk_embed = torch.stack([s["spk_embed"] for s in samples]) 199 | batch["spk_embed"] = spk_embed 200 | if hparams["use_spk_id"]: 201 | spk_ids = torch.LongTensor([s["spk_id"] for s in samples]) 202 | batch["spk_ids"] = spk_ids 203 | f0, uv, pitch = None, None, None 204 | if hparams['use_pitch_embed']: 205 | f0 = collate_1d_or_2d([s['f0'] for s in samples], 0.0) 206 | uv = collate_1d_or_2d([s['uv'] for s in samples]) 207 | batch.update({'pitch': pitch, 'f0': f0, 'uv': uv}) 208 | return batch 209 | -------------------------------------------------------------------------------- /utils/text/text_encoder.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import json 3 | import re 4 | import six 5 | 6 | from six.moves import range 7 | 8 | from utils.text.ko_symbols import symbols 9 | 10 | 11 | PAD = "<pad>" 12 | EOS = "<EOS>" 13 | UNK = "<UNK>" 14 | SEG = "|" 15 | PUNCS = "!./?;:" 16 | RESERVED_TOKENS = [PAD, EOS, UNK, SEG] 17 | NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) 18 | PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0 19 | EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1 20 | UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2 21 | 22 | if six.PY2: 23 | RESERVED_TOKENS_BYTES = RESERVED_TOKENS 24 | else: 25 | RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] 26 | 27 | 28 | def strip_ids(ids: list, ids_to_strip: list): 29 | """ Strip ids_to_strip from the end of ids. """ 30 | ids = list(ids) 31 | while ids and ids[-1] in ids_to_strip: 32 | ids.pop() 33 | return ids 34 | 35 | 36 | class TextEncoder(object): 37 | """ Base class for converting from ints to/from human readable strings.""" 38 | 39 | def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): 40 | self._num_reserved_ids = num_reserved_ids 41 | 42 | @property 43 | def num_reserved_ids(self): 44 | return self._num_reserved_ids 45 | 46 | def encode(self, s: str): 47 | """ Transform a human-readable string into a sequence of int ids. 48 | 49 | The ids should be in the range [num_reserved_ids, vocab_size). Ids [0, 50 | num_reserved_ids) are reserved. 51 | 52 | EOS is not appended. 53 | 54 | Parameters 55 | ---------- 56 | s: str 57 | human-readable string to be converted. 58 | 59 | Returns 60 | ------- 61 | ids: list 62 | list of integers 63 | """ 64 | return [int(w) + self._num_reserved_ids for w in s.split()] 65 | 66 | def decode(self, ids: list, strip_extraneous=False): 67 | """ Transform a sequence of int ids into their string versions. 68 | 69 | This method supports transforming individual input/output ids to their 70 | string versions so that sequence to/from text conversions can be visualized 71 | in a human-readable format.
72 | 73 | Parameters 74 | ---------- 75 | ids: list 76 | list of integers to be converted. 77 | strip_extraneous: bool 78 | whether to strip off extraneous tokens (BOS and PAD). 79 | 80 | Returns 81 | ------- 82 | strs: str 83 | human-readable string. 84 | """ 85 | if strip_extraneous: 86 | ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) 87 | return " ".join(self.decode_list(ids)) 88 | 89 | def decode_list(self, ids: list): 90 | """ Transform a sequence of int ids into their string versions. 91 | 92 | This method supports transforming individual input/output ids to their 93 | string versions so that sequence to/from text conversions can be visualized 94 | in a human-readable format. 95 | 96 | Parameters 97 | ---------- 98 | ids: list 99 | list of integers to be converted. 100 | 101 | Returns 102 | ------- 103 | strs: list 104 | list of human-readable strings. 105 | """ 106 | decoded_ids = [] 107 | for id_ in ids: 108 | if 0 <= id_ < self._num_reserved_ids: 109 | decoded_ids.append(RESERVED_TOKENS[int(id_)]) 110 | else: 111 | decoded_ids.append(id_ - self.num_reserved_ids) 112 | 113 | return [str(d) for d in decoded_ids] 114 | 115 | @property 116 | def vocab_size(self): 117 | raise NotImplementedError() 118 | 119 | 120 | class TokenTextEncoder(TextEncoder): 121 | """ Encoder based on a user-supplied vocabulary (file or list). """ 122 | 123 | def __init__(self, 124 | vocab_filename: str, 125 | reverse=False, 126 | vocab_list=None, 127 | replace_oov=None, 128 | num_reserved_ids=NUM_RESERVED_TOKENS): 129 | """ Initialize from a file or list, one token per line. 130 | 131 | Handling of reserved tokens works as follows: 132 | - When initializing from a list, we add reserved tokens to the vocab. 133 | - When initializing from a file, we do not add reserved tokens to the vocab. 134 | - When saving vocab files, we save reserved tokens to the file. 135 | 136 | Parameters 137 | ---------- 138 | vocab_filename: str 139 | If not None, the full filename to read vocab from. If this is not None, 140 | then vocab_list should be None. 141 | reverse: bool 142 | Indicating if tokens should be reversed during encoding and decoding. 143 | vocab_list: list 144 | If not None, a list of elements of the vocabulary. If this is not None, 145 | then vocab_filename should be None. 146 | replace_oov: str 147 | If not None, every out-of-vocabulary token seen when encoding will be 148 | replaced by this string (which must be in vocab). 149 | num_reserved_ids: int 150 | Number of IDs to save for reserved tokens like <EOS>. 151 | """ 152 | super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) 153 | self._reverse = reverse 154 | self._replace_oov = replace_oov 155 | if vocab_filename: 156 | self._init_vocab_from_file(vocab_filename) 157 | else: 158 | assert vocab_list is not None 159 | self._init_vocab_from_list(vocab_list) 160 | self.pad_index = self.token_to_id[PAD] 161 | self.eos_index = self.token_to_id[EOS] 162 | self.unk_index = self.token_to_id[UNK] 163 | self.seg_index = self.token_to_id[SEG] if SEG in self.token_to_id else self.eos_index 164 | 165 | def encode(self, s: str): 166 | """ Converts a space-separated string of tokens to a list of ids.
""" 167 | sentence = s 168 | tokens = sentence.strip().split() 169 | if self._replace_oov is not None: 170 | tokens = [t if t in self.token_to_id else self._replace_oov 171 | for t in tokens] 172 | ret = [self.token_to_id[tok] for tok in tokens] 173 | 174 | return ret[::-1] if self._reverse else ret 175 | 176 | def decode(self, ids: list, strip_eos=False, strip_padding=False): 177 | if strip_padding and self.pad() in list(ids): 178 | pad_pos = list(ids).index(self.pad()) 179 | ids = ids[:pad_pos] 180 | if strip_eos and self.eos() in list(ids): 181 | eos_pos = list(ids).index(self.eos()) 182 | ids = ids[:eos_pos] 183 | 184 | return " ".join(self.decode_list(ids)) 185 | 186 | def decode_list(self, ids: list): 187 | seq = reversed(ids) if self._reverse else ids 188 | 189 | return [self._safe_id_to_token(i) for i in seq] 190 | 191 | @property 192 | def vocab_size(self): 193 | return len(self.id_to_token) 194 | 195 | def __len__(self): 196 | return self.vocab_size 197 | 198 | def _safe_id_to_token(self, idx: list): 199 | return self.id_to_token.get(idx, "ID_%d" % idx) 200 | 201 | def _init_vocab_from_file(self, filename: str): 202 | """ Load vocab from a file. 203 | 204 | Parameters 205 | ---------- 206 | filename: str 207 | The file to load vocabulary from. 208 | """ 209 | with open(filename) as f: 210 | tokens = [token.strip() for token in f.readlines()] 211 | 212 | def token_gen(): 213 | for token in tokens: 214 | yield token 215 | 216 | self._init_vocab(token_gen(), add_reserved_tokens=False) 217 | 218 | def _init_vocab_from_list(self, vocab_list: list): 219 | """ Initialize tokens from a list of tokens. 220 | 221 | It is ok if reserved tokens appear in the vocab list. They will be 222 | removed. The set of tokens in vocab_list should be unique. 223 | 224 | Parameters 225 | ---------- 226 | vocab_list: list 227 | A list of tokens 228 | """ 229 | def token_gen(): 230 | for token in vocab_list: 231 | if token not in RESERVED_TOKENS: 232 | yield token 233 | 234 | self._init_vocab(token_gen()) 235 | 236 | def _init_vocab(self, token_generator, add_reserved_tokens=True): 237 | """ Initialize vocabulary with tokens from token_generator. """ 238 | 239 | self.id_to_token = {} 240 | non_reserved_start_index = 0 241 | 242 | if add_reserved_tokens: 243 | self.id_to_token.update(enumerate(RESERVED_TOKENS)) 244 | non_reserved_start_index = len(RESERVED_TOKENS) 245 | 246 | self.id_to_token.update( 247 | enumerate(token_generator, start=non_reserved_start_index)) 248 | 249 | # _token_to_id is the reverse of _id_to_token 250 | self.token_to_id = dict((v, k) for k, v in six.iteritems(self.id_to_token)) 251 | 252 | def pad(self): 253 | return self.pad_index 254 | 255 | def eos(self): 256 | return self.eos_index 257 | 258 | def unk(self): 259 | return self.unk_index 260 | 261 | def seg(self): 262 | return self.seg_index 263 | 264 | def store_to_file(self, filename: str): 265 | """ Write vocab file to disk. 266 | 267 | Vocab files have one token per lien. The file ends in a newline. Reserved 268 | tokens are written to the vocab file as well. 269 | 270 | Parameters 271 | ---------- 272 | filename: str 273 | Full path of the file to store the vocab to. 
274 | """ 275 | with open(filename, "w") as f: 276 | for i in range(len(self.id_to_token)): 277 | f.write(self.id_to_token[i] + "\n") 278 | 279 | def sil_phonemes(self): 280 | return [p for p in self.id_to_token.values() if is_sil_phoneme(p)] 281 | 282 | 283 | def build_token_encoder(token_list_file: str): 284 | """ 285 | Parameters 286 | ---------- 287 | token_list_file: str 288 | Path of phoneme_set file 289 | """ 290 | token_list = json.load(open(token_list_file)) 291 | 292 | return TokenTextEncoder(None, vocab_list=token_list, replace_oov="") 293 | 294 | 295 | def is_sil_phoneme(p: str): 296 | return p == "" or not (p[0].isalpha() or ishangul(p[0])) 297 | 298 | 299 | def ishangul(phoneme): 300 | if phoneme not in symbols: 301 | # Consider Hangul syllable 302 | hanCount = len(re.findall(u"[\u3130-\u318F\uAC00-\uD7A3]+", phoneme)) 303 | return hanCount > 0 304 | else: 305 | # Consider Hanguel jamo 306 | return phoneme in symbols 307 | -------------------------------------------------------------------------------- /preprocessor/text/ko_sing.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/NATSpeech/NATSpeech 2 | import json 3 | import re 4 | 5 | from g2pk import G2p # Docker can't install G2p. Do not use this library in the docker. 6 | from jamo import h2j 7 | 8 | from preprocessor.text.base_text_processor import BaseTextProcessor, register_text_processors 9 | from utils.commons.hparams import hparams 10 | from utils.text.text_encoder import PUNCS 11 | 12 | 13 | @register_text_processors("ko_sing") 14 | class KoreanSingingProcessor(BaseTextProcessor): 15 | # G2pk settings 16 | g2p = G2p() 17 | # Dictionary settings 18 | dictionary = json.load(open("./preprocessor/text/dict/korean.json", "r")) 19 | num_checker = "([+-]?\d{1,3},\d{3}(?!\d)|[+-]?\d+)[\.]?\d*" 20 | PUNCS += ",\'\"" 21 | 22 | @staticmethod 23 | def sp_phonemes(): 24 | return ['|'] 25 | 26 | @classmethod 27 | def preprocess_text(cls, text): 28 | # Normalize basic pattern 29 | text = text.strip() 30 | text = re.sub("[\'\"()]+", "", text) 31 | text = re.sub("[-]+", " ", text) 32 | text = re.sub(f"[^ A-Za-z가-힣]", "", text) 33 | text = re.sub(f" ?([{cls.PUNCS}]) ?", r"\1", text) # !! -> ! 34 | text = re.sub(f"([{cls.PUNCS}])+", r"\1", text) # !! -> ! 
35 | text = re.sub('\(\d+일\)', '', text) 36 | text = re.sub('\([⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]+\)', '', text) 37 | text = re.sub(f"([{cls.PUNCS}])", r" \1 ", text) 38 | text = re.sub(rf"\s+", r" ", text) 39 | # Normalize with prepared dictionaries 40 | text = cls.normalize_with_dictionary(text, cls.dictionary["etc_dict"]) 41 | text = cls.normalize_english(text, cls.dictionary["eng_dict"]) 42 | text = cls.normalize_upper(text, cls.dictionary["upper_dict"]) 43 | # Number to Hangul 44 | text = cls.normalize_number(text, cls.num_checker, cls.dictionary) 45 | return text 46 | 47 | @staticmethod 48 | def normalize_with_dictionary(text, dictionary): 49 | """ Check special Korean pronunciation in dictionary """ 50 | if any(key in text for key in dictionary.keys()): 51 | pattern = re.compile("|".join(re.escape(key) for key in dictionary.keys())) 52 | return pattern.sub(lambda x: dictionary[x.group()], text) 53 | else: 54 | return text 55 | 56 | @staticmethod 57 | def normalize_english(text, dictionary): 58 | """ Convert English to Korean pronunciation """ 59 | def _eng_replace(w): 60 | word = w.group() 61 | if word in dictionary: 62 | return dictionary[word] 63 | else: 64 | return word 65 | text = re.sub("([A-Za-z]+)", _eng_replace, text) 66 | return text 67 | 68 | @staticmethod 69 | def normalize_upper(text, dictionary): 70 | """ Convert lowercase English to uppercase and change it to Korean pronunciation """ 71 | def upper_replace(w): 72 | word = w.group() 73 | if all([char.isupper() for char in word]): 74 | return "".join(dictionary[char] for char in word) 75 | else: 76 | return word 77 | text = re.sub("[A-Za-z]+", upper_replace, text) 78 | 79 | return text 80 | 81 | @classmethod 82 | def normalize_number(cls, text, num_checker, dictionary): 83 | """ Convert numbers to Korean pronunciation """ 84 | text = cls.normalize_with_dictionary(text, dictionary["unit_dict"]) 85 | text = re.sub(num_checker + dictionary["count_checker"], 86 | lambda x: cls.num_to_hangeul(x, dictionary, True), text) 87 | text = re.sub(num_checker, 88 | lambda x: cls.num_to_hangeul(x, dictionary, False), text) 89 | return text 90 | 91 | @staticmethod 92 | def num_to_hangeul(num_str, dictionary, is_count=False): 93 | """ Following https://github.com/keonlee9420/Expressive-FastSpeech2/blob/main/text/korean.py 94 | Normalize number pronunciation. """ 95 | zero_cnt = 0 96 | # Check Korean count unit 97 | if is_count: 98 | num_str, unit_str = num_str.group(1), num_str.group(2) 99 | else: 100 | num_str, unit_str = num_str.group(), "" 101 | # Remove thousands separator 102 | num_str = num_str.replace(",", "") 103 | 104 | if is_count and len(num_str) > 2: 105 | is_count = False 106 | 107 | if len(num_str) > 1 and num_str.startswith("0") and "."
not in num_str: 108 | for n in num_str: 109 | zero_cnt += 1 if n == "0" else 0 110 | num_str = num_str[zero_cnt:] 111 | 112 | kor = "" 113 | if num_str != "": 114 | if num_str == "0": 115 | return "영 " + (unit_str if unit_str else "") 116 | # Split float number 117 | check_float = num_str.split(".") 118 | if len(check_float) == 2: 119 | digit_str, float_str = check_float 120 | elif len(check_float) >= 3: 121 | raise Exception(f"| Wrong number format: {num_str}") 122 | else: 123 | digit_str, float_str = check_float[0], None 124 | if is_count and float_str is not None: 125 | raise Exception(f"| 'is_count' and float number does not fit each other") 126 | # Check minus or plus symbol 127 | digit = int(digit_str) 128 | if digit_str.startswith("-") or digit_str.startswith("+"): 129 | digit, digit_str = abs(digit), str(abs(digit)) 130 | size = len(str(digit)) 131 | tmp = [] 132 | for i, v in enumerate(digit_str, start=1): 133 | v = int(v) 134 | if v != 0: 135 | if is_count: 136 | tmp += dictionary["count_dict"][v] 137 | else: 138 | tmp += dictionary["num_dict"][str(v)] 139 | if v == 1 and i != 1 and i != len(digit_str): 140 | tmp = tmp[:-1] 141 | tmp += dictionary["num_ten_dict"][(size - i) % 4] 142 | if (size - i) % 4 == 0 and len(tmp) != 0: 143 | kor += "".join(tmp) 144 | tmp = [] 145 | kor += dictionary["num_tenthousand_dict"][int((size - i) / 4)] 146 | if is_count: 147 | if kor.startswith("한") and len(kor) > 1: 148 | kor = kor[1:] 149 | 150 | if any(word in kor for word in dictionary["count_tenth_dict"]): 151 | kor = re.sub("|".join(dictionary["count_tenth_dict"].keys()), 152 | lambda x: dictionary["count_tenth_dict"][x.group()], kor) 153 | if not is_count and kor.startswith("일") and len(kor) > 1: 154 | kor = kor[1:] 155 | if float_str is not None and float_str != "": 156 | kor += "영" if kor == "" else "" 157 | kor += "쩜 " 158 | kor += re.sub("\d", lambda x: dictionary["num_dict"][x.group()], float_str) 159 | if num_str.startswith("+"): 160 | kor = "플러스 " + kor 161 | elif num_str.startswith("-"): 162 | kor = "마이너스 " + kor 163 | if zero_cnt > 0: 164 | kor = "공" * zero_cnt + kor 165 | return kor + unit_str 166 | 167 | @classmethod 168 | def process(cls, midi_info, hparams): 169 | midi_info_ = [] 170 | ph_list = [] 171 | n_frame = hparams["preprocess_args"]["num_frame"] 172 | sr = hparams["sample_rate"] 173 | hop_size = hparams["hop_size"] 174 | frame_time = n_frame * hop_size / sr 175 | text = "".join([midi[7] for midi in midi_info]) 176 | text = [cls.g2p(word) for word in text.split("|")] 177 | text = "|".join(text) 178 | assert len(text) == len(midi_info), f"| Wrong text process: {len(text)}, {len(midi_info)}" 179 | # Korean singing voice processing 180 | for i, (bar, pos, pitch, duration, start_time, end_time, tempo, _) in enumerate(midi_info): 181 | phs = h2j(cls.preprocess_text(text[i])) 182 | ph = [p for p in phs if p != " " or p != ""] if len(phs) != 0 else ["|"] 183 | if len(ph) == 1: 184 | notes = [[bar, pos, pitch, duration, start_time, end_time, tempo, ph]] 185 | elif len(ph) == 2: 186 | notes = [] 187 | if int((end_time - start_time) * sr / hop_size + 0.5) > n_frame: 188 | for j, p in enumerate(ph): 189 | if j == 0: 190 | note = [bar, pos, pitch, duration, start_time, start_time + frame_time, tempo, p] 191 | else: 192 | note = [bar, pos, pitch, duration, start_time + frame_time, end_time, tempo, p] 193 | notes.append(note) 194 | else: 195 | except_frame_time = (n_frame - 2) * hop_size / sr 196 | for j, p in enumerate(ph): 197 | if j == 0: 198 | note = [bar, pos, pitch, duration, 
start_time, start_time + except_frame_time, tempo, p] 199 | else: 200 | note = [bar, pos, pitch, duration, start_time + except_frame_time, end_time, tempo, p] 201 | notes.append(note) 202 | elif len(ph) == 3: 203 | notes = [] 204 | if int((end_time - start_time) * sr / hop_size + 0.5) >= n_frame * 3: 205 | for j, p in enumerate(ph): 206 | if j == 0: 207 | note = [bar, pos, pitch, duration, start_time, start_time + frame_time, tempo, p] 208 | elif j == 1: 209 | note = [bar, pos, pitch, duration, start_time + frame_time, end_time - frame_time, tempo, p] 210 | else: 211 | note = [bar, pos, pitch, duration, end_time - frame_time, end_time, tempo, p] 212 | notes.append(note) 213 | elif int((end_time - start_time) * sr / hop_size + 0.5) >= n_frame * 2: 214 | except_frame_time = (n_frame - 1) * hop_size / sr 215 | for j, p in enumerate(ph): 216 | if j == 0: 217 | note = [bar, pos, pitch, duration, start_time, start_time + except_frame_time, tempo, p] 218 | elif j == 1: 219 | note = [bar, pos, pitch, duration, start_time + except_frame_time, end_time - except_frame_time, tempo, p] 220 | else: 221 | note = [bar, pos, pitch, duration, end_time - except_frame_time, end_time, tempo, p] 222 | notes.append(note) 223 | elif int((end_time - start_time) * sr / hop_size + 0.5) >= n_frame: 224 | except_frame_time = (n_frame - 2) * hop_size / sr 225 | for j, p in enumerate(ph): 226 | if j == 0: 227 | note = [bar, pos, pitch, duration, start_time, start_time + except_frame_time, tempo, p] 228 | elif j == 1: 229 | note = [bar, pos, pitch, duration, start_time + except_frame_time, end_time - except_frame_time, tempo, p] 230 | else: 231 | note = [bar, pos, pitch, duration, end_time - except_frame_time, end_time, tempo, p] 232 | notes.append(note) 233 | else: 234 | for j, p in enumerate(ph): 235 | except_frame_time = (n_frame - 2) * hop_size / sr 236 | if j == 0: 237 | note = [bar, pos, pitch, duration, start_time, start_time + 1, tempo, p] 238 | elif j == 1: 239 | note = [bar, pos, pitch, duration, start_time + 1, end_time - 1, tempo, p] 240 | elif j == 2: 241 | note = [bar, pos, pitch, duration, end_time - 1, end_time, tempo, p] 242 | notes.append(note) 243 | assert len(ph) == len(notes), f"| Wrong settings: ph = {len(ph)}, notes = {len(notes)}" 244 | ph_list.extend(ph) 245 | midi_info_.extend(notes) 246 | return ph_list, midi_info_ 247 | -------------------------------------------------------------------------------- /tasks/visinger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from models.visinger import VISinger, MultiPeriodDiscriminator 7 | from modules.commons.utils import slice_segments 8 | from tasks.base import SpeechBaseTask 9 | from tasks.dataset_utils import VISingerDataset 10 | from utils.audio.io import save_wav 11 | from utils.audio.mel_processing import MelSpectrogramFixed, SpectrogramFixed 12 | from utils.commons.hparams import hparams 13 | from utils.commons.multiprocess_utils import MultiprocessManager 14 | from utils.commons.tensor_utils import tensors_to_scalars 15 | from utils.nn.model_utils import num_params 16 | 17 | 18 | class VISingerTask(SpeechBaseTask): 19 | def __init__(self): 20 | super().__init__() 21 | self.dataset_cls = VISingerDataset 22 | self.sil_ph = self.token_encoder.sil_phonemes() 23 | data_dir = hparams["binary_data_dir"] 24 | # Discriminator settings 25 | self.build_disc_model() 26 | # Dictionary settings 27 | self.pitch_dict = 
json.load(open(f"{data_dir}/pitch_map.json")) 28 | self.dur_dict = json.load(open(f"{data_dir}/dur_map.json")) 29 | # Spectrogram Function 30 | self.spec_fn = SpectrogramFixed(n_fft=hparams["fft_size"], win_length=hparams["win_size"], 31 | hop_length=hparams["hop_size"], window_fn=torch.hann_window) 32 | self.mel_fn = MelSpectrogramFixed(sample_rate=hparams["sample_rate"], n_fft=hparams["fft_size"], 33 | win_length=hparams["win_size"], hop_length=hparams["hop_size"], 34 | f_min=hparams["fmin"], f_max=hparams["fmax"], 35 | n_mels=hparams["num_mel_bins"], window_fn=torch.hann_window) 36 | 37 | def build_tts_model(self): 38 | ph_dict_size = len(self.token_encoder) 39 | self.model = VISinger(ph_dict_size, len(self.pitch_dict), len(self.dur_dict), hparams) 40 | 41 | def build_disc_model(self): 42 | self.mel_disc = MultiPeriodDiscriminator(hparams["use_spectral_norm"]) 43 | self.disc_params = list(self.mel_disc.parameters()) 44 | 45 | def on_train_start(self): 46 | super().on_train_start() 47 | for n, m in self.model.named_children(): 48 | num_params(m, model_name=n) 49 | if hasattr(self.model, "visinger"): 50 | for n, m in self.model.visinger.named_children(): 51 | num_params(m, model_name=f"visinger.{n}") 52 | 53 | def _training_step(self, sample, batch_idx, optimizer_idx): 54 | loss_output = {} 55 | loss_weights = {} 56 | disc_start = self.global_step >= hparams["disc_start_steps"] and hparams["lambda_mel_adv"] > 0 57 | if optimizer_idx == 0: 58 | ######################### 59 | # Generator # 60 | ######################### 61 | loss_output, model_out = self.run_model(sample, infer=False) 62 | self.model_out_gt = self.model_out = \ 63 | {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)} 64 | if disc_start: 65 | slice_wavs = slice_segments(sample["wavs"].unsqueeze(1), model_out["ids_slice"] * hparams["hop_size"], 66 | hparams["segment_size"] * hparams["hop_size"]) 67 | _, mel_disc_gen, fmap_tgt, fmap_gen = self.mel_disc(slice_wavs, model_out["wav_out"].unsqueeze(1)) 68 | if mel_disc_gen is not None: 69 | loss_output["generator"] = self.add_generator_loss(mel_disc_gen) 70 | loss_weights["generator"] = hparams["lambda_mel_adv"] 71 | if fmap_tgt is not None and fmap_gen is not None: 72 | loss_output["feature_match"] = self.add_feature_matching_loss(fmap_tgt, fmap_gen) 73 | loss_weights["feature_match"] = hparams["lambda_fm"] 74 | else: 75 | ######################### 76 | # Discriminator # 77 | ######################### 78 | if disc_start and self.global_step % hparams["disc_interval"] == 0: 79 | model_out = self.model_out_gt 80 | # Slicing wavs 81 | slice_wavs = slice_segments(sample["wavs"].unsqueeze(1), model_out["ids_slice"] * hparams["hop_size"], 82 | hparams["segment_size"] * hparams["hop_size"]) 83 | mel_disc_tgt, mel_disc_gen, _, _ = self.mel_disc(slice_wavs, model_out["wav_out"].unsqueeze(1)) 84 | if mel_disc_gen is not None: 85 | loss_output["discriminator"] = self.add_discriminator_loss(mel_disc_tgt, mel_disc_gen) 86 | loss_weights["discriminator"] = 1.0 87 | total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) 88 | loss_output["batch_size"] = sample["text_tokens"].shape[0] 89 | return total_loss, loss_output 90 | 91 | def run_model(self, sample, infer=False, *args): 92 | text_tokens = sample["text_tokens"] # [B, T_text] 93 | pitch_tokens = sample["note_pitch"] 94 | dur_tokens = sample["note_dur"] 95 | mel2ph = sample["mel2ph"] # [Batch, T_mels] 96 | spk_embed = sample.get("spk_embed") 
97 | spk_id = sample.get("spk_ids") 98 | if not infer: 99 | f0 = sample.get('f0') 100 | mel = sample["mels"] 101 | output = self.model(text_tokens, pitch_tokens, dur_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id, f0=f0, mel=mel, infer=False) 102 | # Losses 103 | losses = {} 104 | # KL divergence losses 105 | losses["kl_v"] = output["kl"].detach() 106 | losses_kl = output["kl"] 107 | losses_kl = torch.clamp(losses_kl, min=hparams["kl_min"]) 108 | losses_kl = min(self.global_step / hparams["kl_start_steps"], 1) * losses_kl 109 | losses_kl = losses_kl * hparams["lambda_kl"] 110 | losses["kl"] = losses_kl 111 | sample["tgt_mel"] = self.mel_fn(sample["wavs"]) # [Batch, mel_bins, T_len] 112 | tgt_slice_mel = slice_segments(sample["tgt_mel"], output["ids_slice"], hparams["segment_size"]) # [Batch, mel_bins, T_slice] 113 | output["mel_out"] = mel_out = self.mel_fn(output["wav_out"].squeeze(1)) # [Batch, mel_bins, T_slice] 114 | self.add_mel_loss(mel_out.transpose(1, 2), tgt_slice_mel.transpose(1, 2), losses) 115 | # Pitch losses 116 | if hparams["use_pitch_embed"]: 117 | self.add_pitch_loss(output, sample, losses) 118 | # CTC loss 119 | if hparams["use_phoneme_pred"]: 120 | self.add_ctc_loss(output, sample, losses) 121 | return losses, output 122 | else: 123 | output = self.model(text_tokens, pitch_tokens, dur_tokens, mel2ph=mel2ph, spk_embed=spk_embed, 124 | spk_id=spk_id, f0=None, mel=None, infer=True) 125 | return output 126 | 127 | def add_pitch_loss(self, output, sample, losses): 128 | f0 = sample['f0'] 129 | uv = sample["uv"] 130 | nonpadding = (sample["mel2ph"] != 0).float() # [B, T_mels] 131 | p_pred = output['f0_pred'] 132 | assert p_pred[..., 0].shape == f0.shape, f"| f0_diff: {f0.shape}, pred_diff: {p_pred.shape}" 133 | # Loss for voice/unvoice flag 134 | losses["uv"] = (F.binary_cross_entropy_with_logits(p_pred[:, :, 1], uv, reduction='none') * nonpadding).sum() \ 135 | / nonpadding.sum() * hparams['lambda_uv'] 136 | nonpadding = nonpadding * (uv == 0).float() 137 | # Loss for f0 difference 138 | losses["f0"] = (F.l1_loss(p_pred[:, :, 0], f0, reduction="none") * nonpadding).sum() \ 139 | / nonpadding.sum() * hparams["lambda_f0"] 140 | 141 | def add_ctc_loss(self, output, sample, losses): 142 | ph_pred = output["ph_pred"].float().permute(2, 0, 1) # [T_mel, Batch, Dict_size] 143 | input_length = sample["mel_lengths"] 144 | text_tokens = sample["text_tokens"] # [Batch, T_ph] 145 | target_length = sample["text_lengths"] 146 | losses["ctc"] = F.ctc_loss(ph_pred, text_tokens, input_length, target_length, zero_infinity=True) * hparams["lambda_ctc"] 147 | 148 | def add_discriminator_loss(self, tgt_output, gen_output): 149 | disc_loss = 0 150 | for tgt, gen in zip(tgt_output, gen_output): 151 | r_loss = torch.mean((1 - tgt.float()) ** 2) # (D(y) - 1)^2 152 | g_loss = torch.mean(gen.float() ** 2) # (D(G(z)))^2 153 | disc_loss += (r_loss + g_loss) 154 | return disc_loss 155 | 156 | def add_generator_loss(self, gen_output): 157 | gen_loss = 0 158 | for gen in gen_output: 159 | gen = gen.float() 160 | gen_loss += torch.mean((1 - gen) ** 2) # (D(G(z) - 1)^2 161 | return gen_loss 162 | 163 | def add_feature_matching_loss(self, fmap_tgt, fmap_gen): 164 | feature_loss = 0 165 | for tgt, gen in zip(fmap_tgt, fmap_gen): 166 | for tgt_layer, gen_layer in zip(tgt, gen): 167 | tgt_layer = tgt_layer.float().detach() 168 | gen_layer = gen_layer.float() 169 | feature_loss += torch.mean(torch.abs(tgt_layer - gen_layer)) 170 | return feature_loss 171 | 172 | def validation_start(self): 173 | pass 174 
| 175 | def save_valid_result(self, sample, batch_idx, model_out): 176 | sr = hparams["sample_rate"] 177 | mel_out = model_out["mel_out"] 178 | self.plot_mel(batch_idx, mel_out, sample["tgt_mel"].transpose(1, 2)) 179 | if self.global_step > 0: 180 | self.logger.add_audio(f"wav_val_{batch_idx}", model_out["wav_out"], self.global_step, sr) 181 | if "wav_full" in model_out: 182 | self.logger.add_audio(f"wav_val_full_{batch_idx}", model_out["wav_full"], self.global_step, sr) 183 | # Ground-truth waveforms 184 | if self.global_step <= hparams["valid_infer_interval"]: 185 | self.logger.add_audio(f"wav_gt_{batch_idx}", sample["wavs"], self.global_step, sr) 186 | 187 | def validation_step(self, sample, batch_idx): 188 | outputs = {} 189 | outputs["losses"] = {} 190 | outputs["losses"], model_out = self.run_model(sample) 191 | outputs["total_loss"] = sum(outputs["losses"].values()) 192 | outputs["nsamples"] = sample["nsamples"] 193 | outputs = tensors_to_scalars(outputs) 194 | if self.global_step % hparams["valid_infer_interval"] == 0 \ 195 | and batch_idx < hparams["num_valid_plots"]: 196 | model_out = self.run_model(sample, infer=True) 197 | model_out["mel_out"] = self.mel_fn(model_out["wav_out"].squeeze(1)).transpose(1, 2) 198 | self.save_valid_result(sample, batch_idx, model_out) 199 | return outputs 200 | 201 | def build_optimizer(self, model): 202 | optimizer_gen = torch.optim.AdamW(self.model.parameters(), 203 | lr=hparams["lr"], 204 | betas=(hparams["optimizer_adam_beta1"], hparams["optimizer_adam_beta2"]), 205 | weight_decay=hparams["weight_decay"], 206 | eps=hparams["eps"]) 207 | optimizer_disc = torch.optim.AdamW(self.disc_params, 208 | lr=hparams["lr"], 209 | betas=(hparams["optimizer_adam_beta1"], hparams["optimizer_adam_beta2"]), 210 | **hparams["discriminator_optimizer_params"]) if len(self.disc_params) > 0 else None 211 | return [optimizer_gen, optimizer_disc] 212 | 213 | def build_scheduler(self, optimizer): 214 | return [torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer[0], # Generator scheduler 215 | last_epoch=self.current_epoch-1, 216 | **hparams["generator_scheduler_params"]), 217 | torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer[1], # Discriminator scheduler 218 | last_epoch=self.current_epoch-1, 219 | **hparams["discriminator_scheduler_params"])] 220 | 221 | def on_after_optimization(self, epoch, batch_idx, optimizer, opt_idx): 222 | if self.scheduler is not None and hparams["endless_ds"]: 223 | self.scheduler[0].step(self.global_step // hparams['accumulate_grad_batches']) 224 | self.scheduler[1].step(self.global_step // hparams['accumulate_grad_batches']) 225 | elif self.scheduler is not None and not hparams["endless_ds"]: 226 | self.scheduler[0].step(epoch) 227 | self.scheduler[1].step(epoch) 228 | 229 | ############################## 230 | # inference 231 | ############################## 232 | def test_start(self): 233 | self.saving_result_pool = MultiprocessManager(int(os.getenv('N_PROC', os.cpu_count()))) 234 | self.saving_results_futures = [] 235 | self.gen_dir = os.path.join( 236 | hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') 237 | self.tgt_dir = hparams["processed_data_dir"] 238 | os.makedirs(self.gen_dir, exist_ok=True) 239 | os.makedirs(f'{self.gen_dir}/wavs', exist_ok=True) 240 | os.makedirs(f'{self.gen_dir}/plot', exist_ok=True) 241 | if hparams.get('save_mel_npy', False): 242 | os.makedirs(f'{self.gen_dir}/mel_npy', exist_ok=True) 243 | 244 | def test_step(self, sample, batch_idx): 245 | import time
246 | assert sample['text_tokens'].shape[0] == 1, 'only support batch_size=1 in inference' 247 | # Inference for test step 248 | start_time = time.time() 249 | output = self.run_model(sample, infer=True) 250 | running_time = time.time() - start_time 251 | # Inference settings 252 | item_name = sample['item_name'][0] 253 | wav_fn = f"{hparams['processed_data_dir']}/{'/'.join(sample['wav_fn'][0].split('/')[-2:])}" 254 | tokens = sample['text_tokens'][0].cpu().numpy() 255 | wav_pred = output['wav_out'][0].cpu().numpy() 256 | gen_dir = self.gen_dir 257 | input_fn = f"{gen_dir}/wavs/{item_name}_synth.wav" 258 | save_wav(wav_pred, input_fn, hparams["sample_rate"], norm=hparams["out_wav_norm"]) 259 | return {'item_name': item_name, 260 | 'ph_tokens': self.token_encoder.decode(tokens.tolist()), 261 | "wav_fn": wav_fn, 262 | 'wav_fn_pred': f"{item_name}_synth.wav", 263 | "rtf": running_time} 264 | 265 | def test_end(self): 266 | pass 267 | -------------------------------------------------------------------------------- /modules/visinger/flow.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/jaywalnut310/vits 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from modules.visinger.encoder import WaveNet 9 | 10 | DEFAULT_MIN_BIN_WIDTH = 1e-3 11 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 12 | DEFAULT_MIN_DERIVATIVE = 1e-3 13 | 14 | 15 | class ResidualCouplingBlock(nn.Module): 16 | def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, 17 | n_layers, n_flows=4, gin_channels=0): 18 | super().__init__() 19 | self.channels = channels 20 | self.hidden_channels = hidden_channels 21 | self.kernel_size = kernel_size 22 | self.dilation_rate = dilation_rate 23 | self.n_layers = n_layers 24 | self.n_flows = n_flows 25 | self.gin_channels = gin_channels 26 | 27 | self.flows = nn.ModuleList() 28 | for _ in range(n_flows): 29 | self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, 30 | gin_channels=gin_channels, mean_only=True)) 31 | self.flows.append(Flip()) 32 | 33 | def forward(self, x, x_mask, g=None, reverse=False): 34 | if not reverse: 35 | for flow in self.flows: 36 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 37 | else: 38 | for flow in reversed(self.flows): 39 | x = flow(x, x_mask, g=g, reverse=reverse) 40 | return x 41 | 42 | def remove_weight_norm(self): 43 | for i in range(self.n_flows): 44 | self.flows[i * 2].remove_weight_norm() 45 | 46 | 47 | class ResidualCouplingLayer(nn.Module): 48 | def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, 49 | p_dropout=0, gin_channels=0, mean_only=False): 50 | assert channels % 2 == 0, "channels should be divisible by 2" 51 | super().__init__() 52 | self.channels = channels 53 | self.hidden_channels = hidden_channels 54 | self.kernel_size = kernel_size 55 | self.dilation_rate = dilation_rate 56 | self.n_layers = n_layers 57 | self.half_channels = channels // 2 58 | self.mean_only = mean_only 59 | 60 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 61 | self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 62 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 63 | self.post.weight.data.zero_() 64 | self.post.bias.data.zero_() 65 | 66 | def forward(self, x, x_mask, g=None, reverse=False): 67 | x0, x1 = torch.split(x, 
[self.half_channels]*2, 1) 68 | h = self.pre(x0) * x_mask 69 | h = self.enc(h, x_mask, g=g) 70 | stats = self.post(h) * x_mask 71 | if not self.mean_only: 72 | m, logs = torch.split(stats, [self.half_channels]*2, 1) 73 | else: 74 | m = stats 75 | logs = torch.zeros_like(m) 76 | 77 | if not reverse: 78 | x1 = m + x1 * torch.exp(logs) * x_mask 79 | x = torch.cat([x0, x1], 1) 80 | logdet = torch.sum(logs, [1,2]) 81 | return x, logdet 82 | else: 83 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 84 | x = torch.cat([x0, x1], 1) 85 | return x 86 | 87 | 88 | class Flip(nn.Module): 89 | def forward(self, x, *args, reverse=False, **kwargs): 90 | x = torch.flip(x, [1]) 91 | if not reverse: 92 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 93 | return x, logdet 94 | else: 95 | return x 96 | 97 | 98 | class ConvFlow(nn.Module): 99 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 100 | super().__init__() 101 | self.in_channels = in_channels 102 | self.filter_channels = filter_channels 103 | self.kernel_size = kernel_size 104 | self.n_layers = n_layers 105 | self.num_bins = num_bins 106 | self.tail_bound = tail_bound 107 | self.half_channels = in_channels // 2 108 | 109 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 110 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 111 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 112 | self.proj.weight.data.zero_() 113 | self.proj.bias.data.zero_() 114 | 115 | def forward(self, x, x_mask, g=None, reverse=False): 116 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 117 | h = self.pre(x0) 118 | h = self.convs(h, x_mask, g=g) 119 | h = self.proj(h) * x_mask 120 | 121 | b, c, t = x0.shape 122 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
123 | 124 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 125 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) 126 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 127 | 128 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, unnormalized_widths, 129 | unnormalized_heights, 130 | unnormalized_derivatives, 131 | inverse=reverse, 132 | tails='linear', 133 | tail_bound=self.tail_bound 134 | ) 135 | 136 | x = torch.cat([x0, x1], 1) * x_mask 137 | logdet = torch.sum(logabsdet * x_mask, [1,2]) 138 | if not reverse: 139 | return x, logdet 140 | else: 141 | return x 142 | 143 | 144 | class LayerNorm(nn.Module): 145 | def __init__(self, channels, eps=1e-5): 146 | super().__init__() 147 | self.channels = channels 148 | self.eps = eps 149 | self.gamma = nn.Parameter(torch.ones(channels)) 150 | self.beta = nn.Parameter(torch.zeros(channels)) 151 | 152 | def forward(self, x): 153 | x = x.transpose(1, -1) 154 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 155 | return x.transpose(1, -1) 156 | 157 | 158 | class DDSConv(nn.Module): 159 | """ 160 | Dilated and Depth-Separable Convolution 161 | """ 162 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 163 | super().__init__() 164 | self.channels = channels 165 | self.kernel_size = kernel_size 166 | self.n_layers = n_layers 167 | self.p_dropout = p_dropout 168 | 169 | self.drop = nn.Dropout(p_dropout) 170 | self.convs_sep = nn.ModuleList() 171 | self.convs_1x1 = nn.ModuleList() 172 | self.norms_1 = nn.ModuleList() 173 | self.norms_2 = nn.ModuleList() 174 | for i in range(n_layers): 175 | dilation = kernel_size ** i 176 | padding = (kernel_size * dilation - dilation) // 2 177 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 178 | groups=channels, dilation=dilation, padding=padding)) 179 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 180 | self.norms_1.append(LayerNorm(channels)) 181 | self.norms_2.append(LayerNorm(channels)) 182 | 183 | def forward(self, x, x_mask, g=None): 184 | if g is not None: 185 | x = x + g 186 | for i in range(self.n_layers): 187 | y = self.convs_sep[i](x * x_mask) 188 | y = self.norms_1[i](y) 189 | y = F.gelu(y) 190 | y = self.convs_1x1[i](y) 191 | y = self.norms_2[i](y) 192 | y = F.gelu(y) 193 | y = self.drop(y) 194 | x = x + y 195 | return x * x_mask 196 | 197 | 198 | def piecewise_rational_quadratic_transform(inputs, unnormalized_widths, unnormalized_heights, 199 | unnormalized_derivatives, inverse=False, 200 | tails=None, tail_bound=1., 201 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 202 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 203 | min_derivative=DEFAULT_MIN_DERIVATIVE): 204 | if tails is None: 205 | spline_fn = rational_quadratic_spline 206 | spline_kwargs = {} 207 | else: 208 | spline_fn = unconstrained_rational_quadratic_spline 209 | spline_kwargs = {'tails': tails, 'tail_bound': tail_bound} 210 | 211 | outputs, logabsdet = spline_fn(inputs=inputs, 212 | unnormalized_widths=unnormalized_widths, 213 | unnormalized_heights=unnormalized_heights, 214 | unnormalized_derivatives=unnormalized_derivatives, 215 | inverse=inverse, 216 | min_bin_width=min_bin_width, 217 | min_bin_height=min_bin_height, 218 | min_derivative=min_derivative, 219 | **spline_kwargs) 220 | return outputs, logabsdet 221 | 222 | 223 | def unconstrained_rational_quadratic_spline(inputs, unnormalized_widths, unnormalized_heights, 224 | unnormalized_derivatives, inverse=False, 225 | tails='linear',
tail_bound=1., 226 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 227 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 228 | min_derivative=DEFAULT_MIN_DERIVATIVE): 229 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 230 | outside_interval_mask = ~inside_interval_mask 231 | 232 | outputs = torch.zeros_like(inputs) 233 | logabsdet = torch.zeros_like(inputs) 234 | 235 | if tails == 'linear': 236 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 237 | constant = np.log(np.exp(1 - min_derivative) - 1) 238 | unnormalized_derivatives[..., 0] = constant 239 | unnormalized_derivatives[..., -1] = constant 240 | 241 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 242 | logabsdet[outside_interval_mask] = 0 243 | else: 244 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 245 | 246 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 247 | inputs=inputs[inside_interval_mask], 248 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 249 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 250 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 251 | inverse=inverse, 252 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 253 | min_bin_width=min_bin_width, 254 | min_bin_height=min_bin_height, 255 | min_derivative=min_derivative) 256 | 257 | return outputs, logabsdet 258 | 259 | 260 | def rational_quadratic_spline(inputs, unnormalized_widths, unnormalized_heights, 261 | unnormalized_derivatives, inverse=False, 262 | left=0., right=1., bottom=0., top=1., 263 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 264 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 265 | min_derivative=DEFAULT_MIN_DERIVATIVE): 266 | if torch.min(inputs) < left or torch.max(inputs) > right: 267 | raise ValueError('Input to a transform is not within its domain') 268 | 269 | num_bins = unnormalized_widths.shape[-1] 270 | 271 | if min_bin_width * num_bins > 1.0: 272 | raise ValueError('Minimal bin width too large for the number of bins') 273 | if min_bin_height * num_bins > 1.0: 274 | raise ValueError('Minimal bin height too large for the number of bins') 275 | 276 | widths = F.softmax(unnormalized_widths, dim=-1) 277 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 278 | cumwidths = torch.cumsum(widths, dim=-1) 279 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 280 | cumwidths = (right - left) * cumwidths + left 281 | cumwidths[..., 0] = left 282 | cumwidths[..., -1] = right 283 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 284 | 285 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 286 | 287 | heights = F.softmax(unnormalized_heights, dim=-1) 288 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 289 | cumheights = torch.cumsum(heights, dim=-1) 290 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 291 | cumheights = (top - bottom) * cumheights + bottom 292 | cumheights[..., 0] = bottom 293 | cumheights[..., -1] = top 294 | heights = cumheights[..., 1:] - cumheights[..., :-1] 295 | 296 | if inverse: 297 | bin_idx = searchsorted(cumheights, inputs)[..., None] 298 | else: 299 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 300 | 301 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 302 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 303 | 304 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 305 | delta = heights / widths 
306 | input_delta = delta.gather(-1, bin_idx)[..., 0] 307 | 308 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 309 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 310 | 311 | input_heights = heights.gather(-1, bin_idx)[..., 0] 312 | 313 | if inverse: 314 | a = (((inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one 315 | - 2 * input_delta) 316 | + input_heights * (input_delta - input_derivatives))) 317 | b = (input_heights * input_derivatives 318 | - (inputs - input_cumheights) * (input_derivatives 319 | + input_derivatives_plus_one 320 | - 2 * input_delta)) 321 | c = - input_delta * (inputs - input_cumheights) 322 | 323 | discriminant = b.pow(2) - 4 * a * c 324 | assert (discriminant >= 0).all() 325 | 326 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 327 | outputs = root * input_bin_widths + input_cumwidths 328 | 329 | theta_one_minus_theta = root * (1 - root) 330 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 331 | * theta_one_minus_theta) 332 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 333 | + 2 * input_delta * theta_one_minus_theta 334 | + input_derivatives * (1 - root).pow(2)) 335 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 336 | 337 | return outputs, -logabsdet 338 | else: 339 | theta = (inputs - input_cumwidths) / input_bin_widths 340 | theta_one_minus_theta = theta * (1 - theta) 341 | 342 | numerator = input_heights * (input_delta * theta.pow(2) 343 | + input_derivatives * theta_one_minus_theta) 344 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 345 | * theta_one_minus_theta) 346 | outputs = input_cumheights + numerator / denominator 347 | 348 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 349 | + 2 * input_delta * theta_one_minus_theta 350 | + input_derivatives * (1 - theta).pow(2)) 351 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 352 | 353 | return outputs, logabsdet 354 | 355 | 356 | def searchsorted(bin_locations, inputs, eps=1e-6): 357 | bin_locations[..., -1] += eps 358 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 359 | --------------------------------------------------------------------------------
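Since flow.py only defines modules, a short usage sketch may help. It is not part of the repository: it assumes the repository root is on PYTHONPATH (so that `modules.visinger.flow` and the `WaveNet` it imports from `modules.visinger.encoder` resolve) and uses illustrative channel and frame sizes rather than values from the project's configs. It checks the two properties the training code relies on: the residual coupling block is invertible under `reverse=True`, and the piecewise rational-quadratic spline round-trips under `inverse=True`.

```python
import torch

from modules.visinger.flow import (ResidualCouplingBlock,
                                   piecewise_rational_quadratic_transform)

B, C, T = 2, 192, 40                                 # illustrative batch / channel / frame sizes
x_mask = torch.ones(B, 1, T)                         # no padding in this toy batch

# 1) Residual coupling block: the reverse pass should undo the forward pass.
flow = ResidualCouplingBlock(channels=C, hidden_channels=192, kernel_size=5,
                             dilation_rate=1, n_layers=4, n_flows=4)
flow.eval()
with torch.no_grad():
    z = torch.randn(B, C, T)
    z_p = flow(z, x_mask)                            # forward direction (training)
    z_rec = flow(z_p, x_mask, reverse=True)          # reverse direction (inference)
print(torch.allclose(z, z_rec, atol=1e-4))           # True up to numerical error

# 2) Spline transform: inverse=True recovers the inputs of the forward call.
num_bins, half = 10, C // 2
inputs = torch.randn(B, half, T).clamp(-4.0, 4.0)    # keep inputs inside the tail bound
uw = torch.randn(B, half, T, num_bins)               # unnormalized widths
uh = torch.randn(B, half, T, num_bins)               # unnormalized heights
ud = torch.randn(B, half, T, num_bins - 1)           # unnormalized derivatives (padded internally)
with torch.no_grad():
    y, _ = piecewise_rational_quadratic_transform(inputs, uw, uh, ud,
                                                  tails="linear", tail_bound=5.0)
    x_rec, _ = piecewise_rational_quadratic_transform(y, uw, uh, ud, inverse=True,
                                                      tails="linear", tail_bound=5.0)
print(torch.allclose(inputs, x_rec, atol=1e-4))      # True up to numerical error
```

The same reverse path is what `ResidualCouplingBlock` uses at synthesis time in the task's `infer=True` branch, where the flow maps prior samples back to the decoder's latent space.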