├── .gitignore ├── assets └── Demo.gif ├── data_gen └── tts │ ├── base_binarizer.py │ ├── base_pre_align.py │ ├── bin │ ├── align_and_binarize.py │ ├── binarize.py │ ├── pre_align.py │ └── train_mfa_align.py │ ├── binarizer_zh.py │ ├── data_gen_utils.py │ ├── mfa_config.yaml │ ├── tacotron │ ├── audio_processing.py │ ├── layers.py │ └── stft.py │ ├── txt_processors │ ├── base_text_processor.py │ ├── en.py │ ├── en_syl.py │ ├── zh.py │ ├── zh_g2pM.py │ ├── zh_g2pM_song_seg.py │ └── zh_song_seg.py │ ├── vocoder_binarizer.py │ ├── vocoder_binarizer_tacotron.py │ └── vocoder_pre_align.py ├── egs ├── audios │ ├── LJ001-0001_gt.wav │ ├── LJ001-0002_gt.wav │ └── LJ001-0003_gt.wav ├── datasets │ └── audio │ │ ├── libritts │ │ └── pre_align.py │ │ ├── lj │ │ ├── base_text2mel.yaml │ │ └── pre_align.py │ │ ├── pre_align.py │ │ └── vctk │ │ └── pre_align.py ├── demo.ipynb ├── demo_tts.ipynb ├── demo_tts.py ├── egs_bases │ ├── config_base.yaml │ └── tts │ │ ├── base.yaml │ │ ├── base_zh.yaml │ │ └── vocoder │ │ └── base.yaml └── tts │ ├── base_tts_infer.py │ ├── ds.py │ ├── fs.py │ ├── fs2_orig.py │ └── ps_flow.py ├── modules ├── FastDiff │ ├── config │ │ ├── FastDiff.yaml │ │ ├── FastDiff_libritts.yaml │ │ ├── FastDiff_tacotron.yaml │ │ ├── FastDiff_vctk.yaml │ │ └── base.yaml │ ├── module │ │ ├── FastDiff_model.py │ │ ├── WaveNet.py │ │ ├── modules.py │ │ └── util.py │ └── task │ │ └── FastDiff.py ├── commons │ ├── common_layers.py │ └── gdl_loss.py ├── parallel_wavegan │ ├── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── causal_conv.py │ │ ├── residual_block.py │ │ ├── residual_stack.py │ │ └── upsample.py │ ├── models │ │ ├── __init__.py │ │ └── parallel_wavegan.py │ └── utils │ │ ├── __init__.py │ │ └── utils.py └── wavenet_vocoder │ ├── conv.py │ ├── mixture.py │ ├── modules.py │ ├── upsample.py │ ├── util.py │ └── wavenet.py ├── readme.md ├── requirements.txt ├── tasks ├── base_task.py ├── run.py ├── tts │ ├── dataset_utils.py │ ├── fs2.py │ └── tts_base.py └── vocoder │ ├── dataset_utils.py │ └── vocoder_base.py ├── utils ├── __init__.py ├── audio.py ├── ckpt_utils.py ├── common_schedulers.py ├── ddp_utils.py ├── hparams.py ├── indexed_datasets.py ├── metrics.py ├── multiprocess_utils.py ├── pitch_distance.py ├── pitch_utils.py ├── plot.py ├── rnnoise.py ├── text_encoder.py ├── text_norm.py ├── torch_stft.py ├── trainer.py └── tts_utils.py └── vocoders ├── __init__.py ├── base_vocoder.py ├── gl_linear.py ├── gl_mel.py ├── pwg.py ├── stft.py └── vocoder_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Project ignore 2 | 3 | /ParallelWaveGAN 4 | /wavegan_pretrained* 5 | /pretrained_models 6 | /checkpoints 7 | /data 8 | rsync 9 | .idea 10 | .DS_Store 11 | bak 12 | tmp 13 | *.tar.gz 14 | mos 15 | nbs 16 | /configs_usr/* 17 | !/configs_usr/.gitkeep 18 | /egs_usr/* 19 | !/egs_usr/.gitkeep 20 | /fast_transformers 21 | /rnnoise 22 | /usr/* 23 | usr_ali 24 | !/usr/.gitkeep 25 | test_scripts 26 | 27 | # mfa and kaldi 28 | mfa 29 | montreal-forced-aligner 30 | kaldi_align 31 | 32 | # Created by .ignore support plugin (hsz.mobi) 33 | ### Python template 34 | # Byte-compiled / optimized / DLL files 35 | __pycache__/ 36 | *.py[cod] 37 | *$py.class 38 | 39 | # C extensions 40 | *.so 41 | 42 | # Distribution / packaging 43 | .Python 44 | build/ 45 | develop-eggs/ 46 | dist/ 47 | downloads/ 48 | eggs/ 49 | .eggs/ 50 | lib/ 51 | lib64/ 52 | parts/ 53 | sdist/ 54 | var/ 55 | wheels/ 56 | pip-wheel-metadata/ 57 | share/python-wheels/ 58 | 
*.egg-info/ 59 | .installed.cfg 60 | *.egg 61 | MANIFEST 62 | 63 | # PyInstaller 64 | # Usually these files are written by a python script from a template 65 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 66 | *.manifest 67 | *.spec 68 | 69 | # Installer logs 70 | pip-log.txt 71 | pip-delete-this-directory.txt 72 | 73 | # Unit test / coverage reports 74 | htmlcov/ 75 | .tox/ 76 | .nox/ 77 | .coverage 78 | .coverage.* 79 | .cache 80 | nosetests.xml 81 | coverage.xml 82 | *.cover 83 | .hypothesis/ 84 | .pytest_cache/ 85 | 86 | # Translations 87 | *.mo 88 | *.pot 89 | 90 | # Django stuff: 91 | *.log 92 | local_settings.py 93 | db.sqlite3 94 | db.sqlite3-journal 95 | 96 | # Flask stuff: 97 | instance/ 98 | .webassets-cache 99 | 100 | # Scrapy stuff: 101 | .scrapy 102 | 103 | # Sphinx documentation 104 | docs/_build/ 105 | 106 | # PyBuilder 107 | target/ 108 | 109 | # Jupyter Notebook 110 | .ipynb_checkpoints 111 | 112 | # IPython 113 | profile_default/ 114 | ipython_config.py 115 | 116 | # pyenv 117 | .python-version 118 | 119 | # pipenv 120 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 121 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 122 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 123 | # install all needed dependencies. 124 | #Pipfile.lock 125 | 126 | # celery beat schedule file 127 | celerybeat-schedule 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | research/vcpitch3/* 156 | 157 | 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 将删除 datasets/remi/test/ 162 | -------------------------------------------------------------------------------- /assets/Demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/FastDiff/c954758364aa845f37642952c7a16ff3e92e436e/assets/Demo.gif -------------------------------------------------------------------------------- /data_gen/tts/base_pre_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | import librosa 6 | from utils import audio 7 | from data_gen.tts.data_gen_utils import is_sil_phoneme 8 | from utils.multiprocess_utils import chunked_multiprocess_run 9 | import traceback 10 | import importlib 11 | from utils.hparams import hparams, set_hparams 12 | import json 13 | import os 14 | import subprocess 15 | from tqdm import tqdm 16 | import pandas as pd 17 | from utils.rnnoise import rnnoise 18 | 19 | 20 | class BasePreAlign: 21 | def __init__(self): 22 | self.pre_align_args = hparams['pre_align_args'] 23 | txt_processor = self.pre_align_args['txt_processor'] 24 | self.txt_processor = importlib.import_module(f'data_gen.tts.txt_processors.{txt_processor}').TxtProcessor 25 | self.raw_data_dir = hparams['raw_data_dir'] 26 | self.processed_dir = hparams['processed_data_dir'] 27 | 28 | def meta_data(self): 29 | raise NotImplementedError 30 | 31 | @staticmethod 32 | def load_txt(txt_fn): 33 | l = open(txt_fn).readlines()[0].strip() 
34 | return l 35 | 36 | @staticmethod 37 | def process_wav(idx, item_name, wav_fn, processed_dir, pre_align_args): 38 | if pre_align_args['sox_to_wav'] or pre_align_args['trim_sil'] or \ 39 | pre_align_args['sox_resample'] or pre_align_args['denoise']: 40 | sr = hparams['audio_sample_rate'] 41 | new_wav_fn = f"{processed_dir}/wav_inputs/{idx}" 42 | subprocess.check_call(f'sox "{wav_fn}" -t wav "{new_wav_fn}.wav"', shell=True) 43 | if pre_align_args['trim_sil']: 44 | y, sr = librosa.core.load(new_wav_fn + '.wav') 45 | y, _ = librosa.effects.trim(y) 46 | audio.save_wav(y, new_wav_fn + '_trim.wav', sr) 47 | new_wav_fn = new_wav_fn + '_trim' 48 | if pre_align_args['sox_resample']: 49 | subprocess.check_call(f'sox -v 0.95 "{new_wav_fn}.wav" -r{sr} "{new_wav_fn}_rs.wav"', shell=True) 50 | new_wav_fn = new_wav_fn + '_rs' 51 | if pre_align_args['denoise']: 52 | rnnoise(wav_fn, new_wav_fn + '_denoise.wav', out_sample_rate=sr) 53 | new_wav_fn = new_wav_fn + '_denoise' 54 | return new_wav_fn + '.wav' 55 | else: 56 | return wav_fn 57 | 58 | def process(self): 59 | set_hparams() 60 | processed_dir = self.processed_dir 61 | subprocess.check_call(f'rm -rf {processed_dir}/mfa_inputs', shell=True) 62 | os.makedirs(f"{processed_dir}/wav_inputs", exist_ok=True) 63 | phone_set = set() 64 | meta_df = [] 65 | word_level_dict = set() 66 | 67 | args = [] 68 | meta_data = [] 69 | for idx, inp_args in enumerate(tqdm(self.meta_data(), desc='Load meta data')): 70 | if len(inp_args) == 4: 71 | inp_args = [*inp_args, {}] 72 | meta_data.append(inp_args) 73 | item_name, wav_fn, txt_or_fn, spk, others = inp_args 74 | args.append([ 75 | idx, item_name, self.txt_processor, txt_or_fn, wav_fn, others, processed_dir, self.pre_align_args 76 | ]) 77 | item_names = [x[1] for x in args] 78 | assert len(item_names) == len(set(item_names)), 'Key `item_name` should be Unique.' 
79 | 80 | for inp_args, res in zip(tqdm(meta_data, 'Processing'), chunked_multiprocess_run(self.process_job, args)): 81 | item_name, wav_fn, txt_or_fn, spk, others = inp_args 82 | if res is None: 83 | print(f"| Skip {wav_fn}.") 84 | continue 85 | phs, phs_for_dict, phs_for_align, txt, txt_raw, wav_fn = res 86 | meta_df.append({ 87 | 'item_name': item_name, 'spk': spk, 'txt': txt, 'txt_raw': txt_raw, 88 | 'ph': phs, 'wav_fn': wav_fn, 'others': json.dumps(others)}) 89 | for ph in phs.split(" "): 90 | phone_set.add(ph) 91 | phs_for_dict.add('SIL') 92 | for t in phs_for_dict: 93 | word_level_dict.add(f"{t.replace(' ', '_')} {t}") 94 | phone_set = sorted(phone_set) 95 | word_level_dict = sorted(word_level_dict) 96 | print("| phone_set[:200]: ", phone_set[:200]) 97 | with open(f'{processed_dir}/dict.txt', 'w') as f: 98 | for ph in phone_set: 99 | f.write(f'{ph} {ph}\n') 100 | json.dump(phone_set, open(f'{processed_dir}/phone_set.json', 'w')) 101 | 102 | # save to csv 103 | meta_df = pd.DataFrame(meta_df) 104 | meta_df.to_csv(f"{processed_dir}/metadata_phone.csv") 105 | with open(f'{processed_dir}/mfa_dict.txt', 'w') as f: 106 | for l in word_level_dict: 107 | f.write(f'{l}\n') 108 | 109 | @staticmethod 110 | def process_text(txt_processor, txt_raw, pre_align_args): 111 | phs, txt = txt_processor.process(txt_raw, pre_align_args) 112 | phs = [p.strip() for p in phs if p.strip() != ""] 113 | 114 | # remove sil phoneme in head and tail 115 | while len(phs) > 0 and is_sil_phoneme(phs[0]): 116 | phs = phs[1:] 117 | while len(phs) > 0 and is_sil_phoneme(phs[-1]): 118 | phs = phs[:-1] 119 | phs = [""] + phs + [""] 120 | phs_ = [] 121 | for i in range(len(phs)): 122 | if len(phs_) == 0 or not is_sil_phoneme(phs[i]) or not is_sil_phoneme(phs_[-1]): 123 | phs_.append(phs[i]) 124 | elif phs_[-1] == '|' and is_sil_phoneme(phs[i]) and phs[i] != '|': 125 | phs_[-1] = phs[i] 126 | cur_word = [] 127 | phs_for_align = [] 128 | phs_for_dict = set() 129 | for p in phs_: 130 | if is_sil_phoneme(p): 131 | if len(cur_word) > 0: 132 | phs_for_align.append('_'.join(cur_word)) 133 | phs_for_dict.add(' '.join(cur_word)) 134 | cur_word = [] 135 | if p not in txt_processor.sp_phonemes(): 136 | phs_for_align.append('SIL') 137 | else: 138 | cur_word.append(p) 139 | phs = " ".join(phs_) 140 | phs_for_align = " ".join(phs_for_align) 141 | return phs, phs_for_dict, phs_for_align, txt 142 | 143 | @classmethod 144 | def process_job(cls, idx, item_name, g2p_func, txt_or_fn, wav_fn, others, processed_dir, pre_align_args): 145 | try: 146 | if isinstance(txt_or_fn, list) or isinstance(txt_or_fn, tuple): 147 | txt_load_func, txt_load_func_args = txt_or_fn 148 | txt_raw = txt_load_func(txt_load_func_args) 149 | else: 150 | txt_raw = txt_or_fn 151 | except Exception as e: 152 | if not pre_align_args['allow_no_txt']: 153 | raise e 154 | else: 155 | txt_raw = 'NO_TEXT' 156 | if txt_raw is None: 157 | return None 158 | try: 159 | phs, phs_for_dict, phs_for_align, txt = cls.process_text(g2p_func, txt_raw, pre_align_args) 160 | wav_fn = cls.process_wav(idx, item_name, wav_fn, processed_dir, pre_align_args) 161 | if wav_fn is None: 162 | return None 163 | except: 164 | traceback.print_exc() 165 | return None 166 | group = idx // pre_align_args['nsample_per_mfa_group'] # group MFA inputs for better parallelism 167 | os.makedirs(f'{processed_dir}/mfa_inputs/{group}', exist_ok=True) 168 | ext = os.path.splitext(wav_fn)[1] 169 | new_wav_fn = f"{processed_dir}/mfa_inputs/{group}/{idx:07d}_{item_name}{ext}" 170 | cp_cmd = 'mv' if 'wav_inputs' in 
wav_fn else 'cp' 171 | subprocess.check_call(f'{cp_cmd} "{wav_fn}" "{new_wav_fn}"', shell=True) 172 | with open(f'{processed_dir}/mfa_inputs/{group}/{idx:07d}_{item_name}.lab', 'w') as f_txt: 173 | f_txt.write(phs_for_align) 174 | wav_fn = new_wav_fn 175 | return phs, phs_for_dict, phs_for_align, txt, txt_raw, wav_fn 176 | -------------------------------------------------------------------------------- /data_gen/tts/bin/align_and_binarize.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | from data_gen.tts.bin.binarize import binarize 6 | from data_gen.tts.bin.pre_align import pre_align 7 | from data_gen.tts.bin.train_mfa_align import train_mfa_align 8 | from utils.hparams import set_hparams 9 | 10 | if __name__ == '__main__': 11 | set_hparams() 12 | pre_align() 13 | train_mfa_align() 14 | binarize() 15 | -------------------------------------------------------------------------------- /data_gen/tts/bin/binarize.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | import importlib 6 | from utils.hparams import set_hparams, hparams 7 | 8 | 9 | def binarize(): 10 | binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer') 11 | pkg = ".".join(binarizer_cls.split(".")[:-1]) 12 | cls_name = binarizer_cls.split(".")[-1] 13 | binarizer_cls = getattr(importlib.import_module(pkg), cls_name) 14 | print("| Binarizer: ", binarizer_cls) 15 | binarizer_cls().process() 16 | 17 | 18 | if __name__ == '__main__': 19 | set_hparams() 20 | binarize() 21 | -------------------------------------------------------------------------------- /data_gen/tts/bin/pre_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | import importlib 6 | from utils.hparams import set_hparams, hparams 7 | 8 | 9 | def pre_align(): 10 | assert hparams['pre_align_cls'] != '' 11 | 12 | pkg = ".".join(hparams["pre_align_cls"].split(".")[:-1]) 13 | cls_name = hparams["pre_align_cls"].split(".")[-1] 14 | process_cls = getattr(importlib.import_module(pkg), cls_name) 15 | process_cls().process() 16 | 17 | 18 | if __name__ == '__main__': 19 | set_hparams() 20 | pre_align() 21 | -------------------------------------------------------------------------------- /data_gen/tts/bin/train_mfa_align.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from utils.hparams import hparams, set_hparams 3 | import os 4 | 5 | 6 | def train_mfa_align(): 7 | CORPUS = hparams['processed_data_dir'].split("/")[-1] 8 | print(f"| Run MFA for {CORPUS}.") 9 | NUM_JOB = int(os.getenv('N_PROC', os.cpu_count())) 10 | subprocess.check_call( 11 | f'CORPUS={CORPUS} NUM_JOB={NUM_JOB} MFA_VERSION={hparams["mfa_version"]} ' 12 | f'bash scripts/run_mfa_train_align.sh', 13 | shell=True) 14 | 15 | 16 | if __name__ == '__main__': 17 | set_hparams(print_hparams=False) 18 | train_mfa_align() 19 | -------------------------------------------------------------------------------- /data_gen/tts/binarizer_zh.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import Counter 4 | 5 | os.environ["OMP_NUM_THREADS"] = "1" 6 | 7 | from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError 8 | from 
data_gen.tts.data_gen_utils import get_mel2ph, PUNCS 9 | from utils.hparams import set_hparams, hparams 10 | import numpy as np 11 | 12 | 13 | class ZhBinarizer(BaseBinarizer): 14 | def _word_encoder(self): 15 | fn = f"{hparams['binary_data_dir']}/word_set.json" 16 | if self.binarization_args['reset_word_dict']: 17 | word_set = [] 18 | for word_sent in self.item2txt.values(): 19 | word_set += list(word_sent) 20 | word_set = Counter(word_set) 21 | total_words = sum(word_set.values()) 22 | word_set = word_set.most_common(hparams['word_size']) 23 | num_unk_words = total_words - sum([x[1] for x in word_set]) 24 | word_set = [x[0] for x in word_set] 25 | json.dump(word_set, open(fn, 'w')) 26 | print(f"| #total words: {total_words}, #unk_words: {num_unk_words}") 27 | else: 28 | word_set = json.load(open(fn, 'r')) 29 | print("| Word dict size: ", len(word_set), word_set[:10]) 30 | from utils.text_encoder import TokenTextEncoder 31 | return TokenTextEncoder(None, vocab_list=word_set, replace_oov='') 32 | 33 | @staticmethod 34 | def get_align(tg_fn, res): 35 | ph = res['ph'] 36 | mel = res['mel'] 37 | phone = res['phone'] 38 | if tg_fn is not None and os.path.exists(tg_fn): 39 | _, dur = get_mel2ph(tg_fn, ph, mel, hparams) 40 | else: 41 | raise BinarizationError(f"Align not found") 42 | ph_list = ph.split(" ") 43 | assert len(dur) == len(ph_list) 44 | mel2ph = [] 45 | for i in range(len(dur)): 46 | mel2ph += [i + 1] * dur[i] 47 | mel2ph = np.array(mel2ph) 48 | if mel2ph.max() - 1 >= len(phone): 49 | raise BinarizationError(f"| Align does not match: {(mel2ph.max() - 1, len(phone))}") 50 | res['mel2ph'] = mel2ph 51 | res['dur'] = dur 52 | 53 | # char-level pitch 54 | if 'f0' in res: 55 | res['f0_ph'] = np.array([0 for _ in res['f0']], dtype=float) 56 | char_start_idx = 0 57 | f0s_char = [] 58 | # ph_list = 0 59 | for idx, (f0_, ph_idx) in enumerate(zip(res['f0'], res['mel2ph'])): 60 | is_pinyin = ph_list[ph_idx - 1][0].isalpha() 61 | if not is_pinyin or ph_idx - res['mel2ph'][idx - 1] > 1: 62 | if len(f0s_char) > 0: 63 | res['f0_ph'][char_start_idx:idx] = sum(f0s_char) / len(f0s_char) 64 | f0s_char = [] 65 | char_start_idx = idx 66 | if not is_pinyin: 67 | char_start_idx += 1 68 | if f0_ > 0: 69 | f0s_char.append(f0_) 70 | 71 | @staticmethod 72 | def get_word(res, word_encoder): 73 | ph_split = res['ph'].split(" ") 74 | # ph side mapping to word 75 | ph_words = [] # ['', 'N_AW1_', ',', 'AE1_Z_|', 'AO1_L_|', 'B_UH1_K_S_|', 'N_AA1_T_|', ....] 76 | ph2word = np.zeros([len(ph_split)], dtype=int) 77 | last_ph_idx_for_word = [] # [2, 11, ...] 
78 | for i, ph in enumerate(ph_split): 79 | if ph in ['|', '#']: 80 | last_ph_idx_for_word.append(i) 81 | elif not ph[0].isalnum(): 82 | if ph not in ['']: 83 | last_ph_idx_for_word.append(i - 1) 84 | last_ph_idx_for_word.append(i) 85 | start_ph_idx_for_word = [0] + [i + 1 for i in last_ph_idx_for_word[:-1]] 86 | for i, (s_w, e_w) in enumerate(zip(start_ph_idx_for_word, last_ph_idx_for_word)): 87 | ph_words.append(ph_split[s_w:e_w + 1]) 88 | ph2word[s_w:e_w + 1] = i 89 | ph2word = ph2word.tolist() 90 | ph_words = ["_".join(w) for w in ph_words] 91 | 92 | # mel side mapping to word 93 | mel2word = [] 94 | dur_word = [0 for _ in range(len(ph_words))] 95 | for i, m2p in enumerate(res['mel2ph']): 96 | word_idx = ph2word[m2p - 1] 97 | mel2word.append(ph2word[m2p - 1]) 98 | dur_word[word_idx] += 1 99 | ph2word = [x + 1 for x in ph2word] # 0预留给padding 100 | mel2word = [x + 1 for x in mel2word] # 0预留给padding 101 | res['ph_words'] = ph_words # [T_word] 102 | res['ph2word'] = ph2word # [T_ph] 103 | res['mel2word'] = mel2word # [T_mel] 104 | res['dur_word'] = dur_word # [T_word] 105 | 106 | words = [x for x in res['txt']] 107 | if words[-1] in PUNCS: 108 | words = words[:-1] 109 | words = [''] + words + [''] 110 | word_tokens = word_encoder.encode(" ".join(words)) 111 | res['words'] = words 112 | res['word_tokens'] = word_tokens 113 | assert len(words) == len(ph_words), [words, ph_words] 114 | 115 | # words = [x for x in res['txt'].split(" ") if x != ''] 116 | # while len(words) > 0 and is_sil_phoneme(words[0]): 117 | # words = words[1:] 118 | # while len(words) > 0 and is_sil_phoneme(words[-1]): 119 | # words = words[:-1] 120 | # words = [''] + words + [''] 121 | # word_tokens = word_encoder.encode(" ".join(words)) 122 | # res['words'] = words 123 | # res['word_tokens'] = word_tokens 124 | # assert len(words) == len(ph_words_nosep), [words, ph_words_nosep] 125 | 126 | 127 | if __name__ == "__main__": 128 | set_hparams() 129 | ZhBinarizer().process() 130 | -------------------------------------------------------------------------------- /data_gen/tts/mfa_config.yaml: -------------------------------------------------------------------------------- 1 | beam: 10 2 | retry_beam: 40 3 | 4 | features: 5 | type: "mfcc" 6 | use_energy: false 7 | frame_shift: 10 8 | 9 | training: 10 | - monophone: 11 | num_iterations: 40 12 | max_gaussians: 1000 13 | subset: 0 14 | boost_silence: 1.25 15 | 16 | - triphone: 17 | num_iterations: 35 18 | num_leaves: 2000 19 | max_gaussians: 10000 20 | cluster_threshold: -1 21 | subset: 0 22 | boost_silence: 1.25 23 | power: 0.25 24 | 25 | - lda: 26 | num_leaves: 2500 27 | max_gaussians: 15000 28 | subset: 0 29 | num_iterations: 35 30 | features: 31 | splice_left_context: 3 32 | splice_right_context: 3 33 | 34 | - sat: 35 | num_leaves: 2500 36 | max_gaussians: 15000 37 | power: 0.2 38 | silence_weight: 0.0 39 | fmllr_update_type: "diag" 40 | subset: 0 41 | features: 42 | lda: true 43 | 44 | - sat: 45 | num_leaves: 4200 46 | max_gaussians: 40000 47 | power: 0.2 48 | silence_weight: 0.0 49 | fmllr_update_type: "diag" 50 | subset: 0 51 | features: 52 | lda: true 53 | fmllr: true -------------------------------------------------------------------------------- /data_gen/tts/tacotron/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | 
n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | 13 | This is used to estimate modulation effects induced by windowing 14 | observations in short-time fourier transforms. 15 | 16 | Parameters 17 | ---------- 18 | window : string, tuple, number, callable, or list-like 19 | Window specification, as in `get_window` 20 | 21 | n_frames : int > 0 22 | The number of analysis frames 23 | 24 | hop_length : int > 0 25 | The number of samples to advance between frames 26 | 27 | win_length : [optional] 28 | The length of the window function. By default, this matches `n_fft`. 29 | 30 | n_fft : int > 0 31 | The length of each analysis frame. 32 | 33 | dtype : np.dtype 34 | The data type of the output 35 | 36 | Returns 37 | ------- 38 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 39 | The sum-squared envelope of the window function 40 | """ 41 | if win_length is None: 42 | win_length = n_fft 43 | 44 | n = n_fft + hop_length * (n_frames - 1) 45 | x = np.zeros(n, dtype=dtype) 46 | 47 | # Compute the squared window at the desired length 48 | win_sq = get_window(window, win_length, fftbins=True) 49 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 50 | win_sq = librosa_util.pad_center(win_sq, n_fft) 51 | 52 | # Fill the envelope 53 | for i in range(n_frames): 54 | sample = i * hop_length 55 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | 78 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 79 | """ 80 | PARAMS 81 | ------ 82 | C: compression factor 83 | """ 84 | return torch.log(torch.clamp(x, min=clip_val) * C) 85 | 86 | 87 | def dynamic_range_decompression(x, C=1): 88 | """ 89 | PARAMS 90 | ------ 91 | C: compression factor used to compress 92 | """ 93 | return torch.exp(x) / C 94 | -------------------------------------------------------------------------------- /data_gen/tts/tacotron/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from librosa.filters import mel as librosa_mel_fn 3 | from data_gen.tts.tacotron.audio_processing import dynamic_range_compression 4 | from data_gen.tts.tacotron.audio_processing import dynamic_range_decompression 5 | from data_gen.tts.tacotron.stft import STFT 6 | 7 | 8 | class LinearNorm(torch.nn.Module): 9 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 10 | super(LinearNorm, self).__init__() 11 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 12 | 13 | torch.nn.init.xavier_uniform_( 14 | self.linear_layer.weight, 15 | gain=torch.nn.init.calculate_gain(w_init_gain)) 16 | 17 | def forward(self, x): 18 | return self.linear_layer(x) 19 | 20 | 21 | class ConvNorm(torch.nn.Module): 22 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 23 | 
padding=None, dilation=1, bias=True, w_init_gain='linear'): 24 | super(ConvNorm, self).__init__() 25 | if padding is None: 26 | assert(kernel_size % 2 == 1) 27 | padding = int(dilation * (kernel_size - 1) / 2) 28 | 29 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 30 | kernel_size=kernel_size, stride=stride, 31 | padding=padding, dilation=dilation, 32 | bias=bias) 33 | 34 | torch.nn.init.xavier_uniform_( 35 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 36 | 37 | def forward(self, signal): 38 | conv_signal = self.conv(signal) 39 | return conv_signal 40 | 41 | 42 | class TacotronSTFT(torch.nn.Module): 43 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 44 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 45 | mel_fmax=8000.0): 46 | super(TacotronSTFT, self).__init__() 47 | self.n_mel_channels = n_mel_channels 48 | self.sampling_rate = sampling_rate 49 | self.stft_fn = STFT(filter_length, hop_length, win_length) 50 | mel_basis = librosa_mel_fn( 51 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 52 | mel_basis = torch.from_numpy(mel_basis).float() 53 | self.register_buffer('mel_basis', mel_basis) 54 | 55 | def spectral_normalize(self, magnitudes): 56 | output = dynamic_range_compression(magnitudes) 57 | return output 58 | 59 | def spectral_de_normalize(self, magnitudes): 60 | output = dynamic_range_decompression(magnitudes) 61 | return output 62 | 63 | def mel_spectrogram(self, y): 64 | """Computes mel-spectrograms from a batch of waves 65 | PARAMS 66 | ------ 67 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 68 | 69 | RETURNS 70 | ------- 71 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 72 | """ 73 | assert(torch.min(y.data) >= -1) 74 | assert(torch.max(y.data) <= 1) 75 | 76 | magnitudes, phases = self.stft_fn.transform(y) 77 | magnitudes = magnitudes.data 78 | mel_output = torch.matmul(self.mel_basis, magnitudes) 79 | mel_output = self.spectral_normalize(mel_output) 80 | return mel_output 81 | -------------------------------------------------------------------------------- /data_gen/tts/tacotron/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | """ 32 | 33 | import torch 34 | import numpy as np 35 | import torch.nn.functional as F 36 | from torch.autograd import Variable 37 | from scipy.signal import get_window 38 | from librosa.util import pad_center, tiny 39 | from data_gen.tts.tacotron.audio_processing import window_sumsquare 40 | 41 | 42 | class STFT(torch.nn.Module): 43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 44 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 45 | window='hann'): 46 | super(STFT, self).__init__() 47 | self.filter_length = filter_length 48 | self.hop_length = hop_length 49 | self.win_length = win_length 50 | self.window = window 51 | self.forward_transform = None 52 | scale = self.filter_length / self.hop_length 53 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 54 | 55 | cutoff = int((self.filter_length / 2 + 1)) 56 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 57 | np.imag(fourier_basis[:cutoff, :])]) 58 | 59 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 60 | inverse_basis = torch.FloatTensor( 61 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 62 | 63 | if window is not None: 64 | assert(filter_length >= win_length) 65 | # get window and zero center pad it to filter_length 66 | fft_window = get_window(window, win_length, fftbins=True) 67 | fft_window = pad_center(fft_window, filter_length) 68 | fft_window = torch.from_numpy(fft_window).float() 69 | 70 | # window the bases 71 | forward_basis *= fft_window 72 | inverse_basis *= fft_window 73 | 74 | self.register_buffer('forward_basis', forward_basis.float()) 75 | self.register_buffer('inverse_basis', inverse_basis.float()) 76 | 77 | def transform(self, input_data): 78 | num_batches = input_data.size(0) 79 | num_samples = input_data.size(1) 80 | 81 | self.num_samples = num_samples 82 | 83 | # similar to librosa, reflect-pad the input 84 | input_data = input_data.view(num_batches, 1, num_samples) 85 | input_data = F.pad( 86 | input_data.unsqueeze(1), 87 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 88 | mode='reflect') 89 | input_data = input_data.squeeze(1) 90 | 91 | forward_transform = F.conv1d( 92 | input_data, 93 | Variable(self.forward_basis, requires_grad=False), 94 | stride=self.hop_length, 95 | padding=0) 96 | 97 | cutoff = int((self.filter_length / 2) + 1) 98 | real_part = forward_transform[:, :cutoff, :] 99 | imag_part = forward_transform[:, cutoff:, :] 100 | 101 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 102 | phase = torch.autograd.Variable( 103 | torch.atan2(imag_part.data, real_part.data)) 104 | 105 | return magnitude, phase 106 | 107 | def inverse(self, magnitude, phase): 108 | recombine_magnitude_phase = torch.cat( 109 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 110 | 111 | inverse_transform = F.conv_transpose1d( 112 | recombine_magnitude_phase, 113 | Variable(self.inverse_basis, requires_grad=False), 114 | stride=self.hop_length, 115 | 
padding=0) 116 | 117 | if self.window is not None: 118 | window_sum = window_sumsquare( 119 | self.window, magnitude.size(-1), hop_length=self.hop_length, 120 | win_length=self.win_length, n_fft=self.filter_length, 121 | dtype=np.float32) 122 | # remove modulation effects 123 | approx_nonzero_indices = torch.from_numpy( 124 | np.where(window_sum > tiny(window_sum))[0]) 125 | window_sum = torch.autograd.Variable( 126 | torch.from_numpy(window_sum), requires_grad=False) 127 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 128 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 129 | 130 | # scale by hop ratio 131 | inverse_transform *= float(self.filter_length) / self.hop_length 132 | 133 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 134 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 135 | 136 | return inverse_transform 137 | 138 | def forward(self, input_data): 139 | self.magnitude, self.phase = self.transform(input_data) 140 | reconstruction = self.inverse(self.magnitude, self.phase) 141 | return reconstruction 142 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/base_text_processor.py: -------------------------------------------------------------------------------- 1 | class BaseTxtProcessor: 2 | @staticmethod 3 | def sp_phonemes(): 4 | return ['|'] 5 | 6 | @classmethod 7 | def process(cls, txt, pre_align_args): 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/en.py: -------------------------------------------------------------------------------- 1 | import re 2 | from data_gen.tts.data_gen_utils import PUNCS 3 | from g2p_en import G2p 4 | import unicodedata 5 | from g2p_en.expand import normalize_numbers 6 | from nltk import pos_tag 7 | from nltk.tokenize import TweetTokenizer 8 | 9 | from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor 10 | 11 | 12 | class EnG2p(G2p): 13 | word_tokenize = TweetTokenizer().tokenize 14 | 15 | def __call__(self, text): 16 | # preprocessing 17 | words = EnG2p.word_tokenize(text) 18 | tokens = pos_tag(words) # tuples of (word, tag) 19 | 20 | # steps 21 | prons = [] 22 | for word, pos in tokens: 23 | if re.search("[a-z]", word) is None: 24 | pron = [word] 25 | 26 | elif word in self.homograph2features: # Check homograph 27 | pron1, pron2, pos1 = self.homograph2features[word] 28 | if pos.startswith(pos1): 29 | pron = pron1 30 | else: 31 | pron = pron2 32 | elif word in self.cmu: # lookup CMU dict 33 | pron = self.cmu[word][0] 34 | else: # predict for oov 35 | pron = self.predict(word) 36 | 37 | prons.extend(pron) 38 | prons.extend([" "]) 39 | 40 | return prons[:-1] 41 | 42 | 43 | class TxtProcessor(BaseTxtProcessor): 44 | g2p = EnG2p() 45 | 46 | @staticmethod 47 | def preprocess_text(text): 48 | text = normalize_numbers(text) 49 | text = ''.join(char for char in unicodedata.normalize('NFD', text) 50 | if unicodedata.category(char) != 'Mn') # Strip accents 51 | text = text.lower() 52 | text = re.sub("[\'\"()]+", "", text) 53 | text = re.sub("[-]+", " ", text) 54 | text = re.sub(f"[^ a-z{PUNCS}]", "", text) 55 | text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> ! 56 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 
57 | text = text.replace("i.e.", "that is") 58 | text = text.replace("i.e.", "that is") 59 | text = text.replace("etc.", "etc") 60 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 61 | text = re.sub(rf"\s+", r" ", text) 62 | return text 63 | 64 | @classmethod 65 | def process(cls, txt, pre_align_args): 66 | txt = cls.preprocess_text(txt).strip() 67 | phs = cls.g2p(txt) 68 | phs_ = [] 69 | n_word_sep = 0 70 | for p in phs: 71 | if p.strip() == '': 72 | phs_ += ['|'] 73 | n_word_sep += 1 74 | else: 75 | phs_ += p.split(" ") 76 | phs = phs_ 77 | assert n_word_sep + 1 == len(txt.split(" ")), (phs, f"\"{txt}\"") 78 | return phs, txt 79 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/en_syl.py: -------------------------------------------------------------------------------- 1 | from syllabipy.sonoripy import SonoriPy 2 | from data_gen.tts.txt_processors import en 3 | 4 | 5 | class TxtProcessor(en.TxtProcessor): 6 | @classmethod 7 | def process(cls, txt, pre_align_args): 8 | txt = cls.preprocess_text(txt) 9 | phs = [] 10 | for p in txt.split(" "): 11 | if len(p) == 0: 12 | continue 13 | syl = SonoriPy(p) 14 | if len(syl) == 0: 15 | phs += list(p) 16 | else: 17 | for x in syl: 18 | phs += list(x) 19 | phs += ['|'] 20 | return phs, txt 21 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/zh.py: -------------------------------------------------------------------------------- 1 | import re 2 | import jieba 3 | from pypinyin import pinyin, Style 4 | from data_gen.tts.data_gen_utils import PUNCS 5 | from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor 6 | from utils.text_norm import NSWNormalizer 7 | 8 | ALL_SHENMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j', 9 | 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w'] 10 | 11 | 12 | class TxtProcessor(BaseTxtProcessor): 13 | table = {ord(f): ord(t) for f, t in zip( 14 | u':,。!?【】()%#@&1234567890', 15 | u':,.!?[]()%#@&1234567890')} 16 | 17 | @staticmethod 18 | def sp_phonemes(): 19 | return ['|', '#'] 20 | 21 | @staticmethod 22 | def preprocess_text(text): 23 | text = text.translate(TxtProcessor.table) 24 | text = NSWNormalizer(text).normalize(remove_punc=False).lower() 25 | text = re.sub("[\'\"()]+", "", text) 26 | text = re.sub("[-]+", " ", text) 27 | text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}]", "", text) 28 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 
29 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 30 | text = re.sub(rf"\s+", r"", text) 31 | text = re.sub(rf"[A-Za-z]+", r"$", text) 32 | return text 33 | 34 | @classmethod 35 | def pinyin_with_en(cls, txt, style): 36 | x = pinyin(txt, style) 37 | x = [t[0] for t in x] 38 | x_ = [] 39 | for t in x: 40 | if '$' not in t: 41 | x_.append(t) 42 | else: 43 | x_ += list(t) 44 | x_ = [t if t != '$' else 'ENG' for t in x_] 45 | return x_ 46 | 47 | @classmethod 48 | def process(cls, txt, pre_align_args): 49 | txt = cls.preprocess_text(txt) 50 | 51 | # https://blog.csdn.net/zhoulei124/article/details/89055403 52 | shengmu = cls.pinyin_with_en(txt, style=Style.INITIALS) 53 | yunmu = cls.pinyin_with_en(txt, style= 54 | Style.FINALS_TONE3 if pre_align_args['use_tone'] else Style.FINALS) 55 | assert len(shengmu) == len(yunmu) 56 | ph_list = [] 57 | for a, b in zip(shengmu, yunmu): 58 | if a == b: 59 | ph_list += [a] 60 | else: 61 | ph_list += [a + "%" + b] 62 | seg_list = '#'.join(jieba.cut(txt)) 63 | assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list) 64 | 65 | # add word boundary marks '#' 66 | ph_list_ = [] 67 | seg_idx = 0 68 | for p in ph_list: 69 | if seg_list[seg_idx] == '#': 70 | ph_list_.append('#') 71 | seg_idx += 1 72 | elif len(ph_list_) > 0: 73 | ph_list_.append("|") 74 | seg_idx += 1 75 | finished = False 76 | if not finished: 77 | ph_list_ += [x for x in p.split("%") if x != ''] 78 | 79 | ph_list = ph_list_ 80 | 81 | # remove word boundary marks around silence symbols [..., '#', ',', '#', ...] 82 | sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes() 83 | ph_list_ = [] 84 | for i in range(0, len(ph_list), 1): 85 | if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes): 86 | ph_list_.append(ph_list[i]) 87 | ph_list = ph_list_ 88 | return ph_list, txt 89 | 90 | 91 | if __name__ == '__main__': 92 | t = 'simon演唱过后,simon还进行了simon精彩的文艺演出simon.'
93 | phs, txt = TxtProcessor.process(t, {'use_tone': True}) 94 | print(phs, txt) 95 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/zh_g2pM.py: -------------------------------------------------------------------------------- 1 | import re 2 | import jieba 3 | from pypinyin import pinyin, Style 4 | from data_gen.tts.data_gen_utils import PUNCS 5 | from data_gen.tts.txt_processors import zh 6 | from g2pM import G2pM 7 | 8 | ALL_SHENMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j', 9 | 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w'] 10 | 11 | 12 | class TxtProcessor(zh.TxtProcessor): 13 | model = G2pM() 14 | 15 | @staticmethod 16 | def sp_phonemes(): 17 | return ['|', '#'] 18 | 19 | @classmethod 20 | def process(cls, txt, pre_align_args): 21 | txt = cls.preprocess_text(txt) 22 | ph_list = cls.model(txt, tone=pre_align_args['use_tone'], char_split=True) 23 | seg_list = '#'.join(jieba.cut(txt)) 24 | assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list) 25 | 26 | # 加入词边界'#' 27 | ph_list_ = [] 28 | seg_idx = 0 29 | for p in ph_list: 30 | p = p.replace("u:", "v") 31 | if seg_list[seg_idx] == '#': 32 | ph_list_.append('#') 33 | seg_idx += 1 34 | else: 35 | ph_list_.append("|") 36 | seg_idx += 1 37 | if re.findall('[\u4e00-\u9fff]', p): 38 | if pre_align_args['use_tone']: 39 | p = pinyin(p, style=Style.TONE3, strict=True)[0][0] 40 | if p[-1] not in ['1', '2', '3', '4', '5']: 41 | p = p + '5' 42 | else: 43 | p = pinyin(p, style=Style.NORMAL, strict=True)[0][0] 44 | 45 | finished = False 46 | if len([c.isalpha() for c in p]) > 1: 47 | for shenmu in ALL_SHENMU: 48 | if p.startswith(shenmu) and not p.lstrip(shenmu).isnumeric(): 49 | ph_list_ += [shenmu, p.lstrip(shenmu)] 50 | finished = True 51 | break 52 | if not finished: 53 | ph_list_.append(p) 54 | 55 | ph_list = ph_list_ 56 | 57 | # 去除静音符号周围的词边界标记 [..., '#', ',', '#', ...] 58 | sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes() 59 | ph_list_ = [] 60 | for i in range(0, len(ph_list), 1): 61 | if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes): 62 | ph_list_.append(ph_list[i]) 63 | ph_list = ph_list_ 64 | return ph_list, txt 65 | 66 | 67 | if __name__ == '__main__': 68 | phs, txt = TxtProcessor.process('他来到了,网易杭研大厦', {'use_tone': True}) 69 | print(phs) 70 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/zh_g2pM_song_seg.py: -------------------------------------------------------------------------------- 1 | import re 2 | from data_gen.tts.data_gen_utils import PUNCS 3 | from data_gen.tts.txt_processors import zh_g2pM 4 | 5 | from utils.text_norm import NSWNormalizer 6 | 7 | 8 | class TxtProcessor(zh_g2pM.TxtProcessor): 9 | @staticmethod 10 | def preprocess_text(text): 11 | text = text.translate(TxtProcessor.table) 12 | text = NSWNormalizer(text).normalize(remove_punc=False) 13 | text = re.sub("[\'\"()]+", "", text) 14 | text = re.sub("[-]+", " ", text) 15 | text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}&]", "", text) 16 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 
17 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 18 | text = re.sub(rf"\s+", r"", text) 19 | return text 20 | 21 | @staticmethod 22 | def sp_phonemes(): 23 | return ['|', '#', '&'] 24 | 25 | @classmethod 26 | def process(cls, txt, pre_align_args): 27 | txt = txt.replace('SEP', '&') 28 | ph_list, txt = super().process(txt, pre_align_args) 29 | txt = txt.replace('&', ' SEP ') 30 | ph_list = [p if p != '&' else 'SEP' for p in ph_list if p not in ['|', '#', '', '']] 31 | return ph_list, txt 32 | -------------------------------------------------------------------------------- /data_gen/tts/txt_processors/zh_song_seg.py: -------------------------------------------------------------------------------- 1 | import re 2 | from data_gen.tts.data_gen_utils import PUNCS 3 | from data_gen.tts.txt_processors import zh 4 | from utils.text_norm import NSWNormalizer 5 | 6 | 7 | class TxtProcessor(zh.TxtProcessor): 8 | @staticmethod 9 | def preprocess_text(text): 10 | text = text.translate(TxtProcessor.table) 11 | text = NSWNormalizer(text).normalize(remove_punc=False) 12 | text = re.sub("[\'\"()]+", "", text) 13 | text = re.sub("[-]+", " ", text) 14 | text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}&]", "", text) 15 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 16 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 17 | text = re.sub(rf"\s+", r"", text) 18 | return text 19 | 20 | @staticmethod 21 | def sp_phonemes(): 22 | return ['|', '#', '&'] 23 | 24 | @classmethod 25 | def process(cls, txt, pre_align_args): 26 | txt = txt.replace('SEP', '&') 27 | ph_list, txt = super().process(txt, pre_align_args) 28 | txt = txt.replace('&', ' SEP ') 29 | ph_list = [p if p != '&' else 'SEP' for p in ph_list if p not in ['|', '#', '', '']] 30 | return ph_list, txt 31 | -------------------------------------------------------------------------------- /data_gen/tts/vocoder_binarizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | from collections import Counter 6 | from utils.text_encoder import TokenTextEncoder 7 | 8 | from utils.multiprocess_utils import chunked_multiprocess_run 9 | import random 10 | import traceback 11 | import json 12 | from resemblyzer import VoiceEncoder 13 | from tqdm import tqdm 14 | from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder, is_sil_phoneme 15 | from utils.hparams import hparams, set_hparams 16 | import numpy as np 17 | from utils.indexed_datasets import IndexedDatasetBuilder 18 | from vocoders.base_vocoder import get_vocoder_cls 19 | import pandas as pd 20 | 21 | 22 | class BinarizationError(Exception): 23 | pass 24 | 25 | 26 | class VocoderBinarizer: 27 | def __init__(self, processed_data_dir=None): 28 | if processed_data_dir is None: 29 | processed_data_dir = hparams['processed_data_dir'] 30 | self.processed_data_dirs = processed_data_dir.split(",") 31 | self.binarization_args = hparams['binarization_args'] 32 | self.pre_align_args = hparams['pre_align_args'] 33 | self.item2wavfn = {} 34 | 35 | def load_meta_data(self): 36 | for ds_id, processed_data_dir in enumerate(self.processed_data_dirs): 37 | self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str) 38 | for r_idx, r in tqdm(self.meta_df.iterrows(), desc='Loading meta data.'): 39 | item_name = raw_item_name = r['item_name'] 40 | if len(self.processed_data_dirs) > 1: 41 | item_name = f'ds{ds_id}_{item_name}' 42 | self.item2wavfn[item_name] = r['wav_fn'] 43 | self.item_names 
= sorted(list(self.item2wavfn.keys())) 44 | if self.binarization_args['shuffle']: 45 | random.seed(1234) 46 | random.shuffle(self.item_names) 47 | 48 | @property 49 | def train_item_names(self): 50 | return self.item_names[hparams['test_num']:] 51 | 52 | @property 53 | def valid_item_names(self): 54 | return self.item_names[:hparams['test_num']] 55 | 56 | @property 57 | def test_item_names(self): 58 | return self.valid_item_names 59 | 60 | def meta_data(self, prefix): 61 | if prefix == 'valid': 62 | item_names = self.valid_item_names 63 | elif prefix == 'test': 64 | item_names = self.test_item_names 65 | else: 66 | item_names = self.train_item_names 67 | for item_name in item_names: 68 | wav_fn = self.item2wavfn[item_name] 69 | yield item_name, wav_fn 70 | 71 | def process(self): 72 | self.load_meta_data() 73 | os.makedirs(hparams['binary_data_dir'], exist_ok=True) 74 | self.process_data('valid') 75 | self.process_data('test') 76 | self.process_data('train') 77 | 78 | def process_data(self, prefix): 79 | data_dir = hparams['binary_data_dir'] 80 | args = [] 81 | builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}') 82 | mel_lengths = [] 83 | total_sec = 0 84 | meta_data = list(self.meta_data(prefix)) 85 | for m in meta_data: 86 | args.append(list(m) + [self.binarization_args]) 87 | num_workers = self.num_workers 88 | for f_id, (_, item) in enumerate( 89 | zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))): 90 | if item is None: 91 | continue 92 | if not self.binarization_args['with_wav'] and 'wav' in item: 93 | del item['wav'] 94 | builder.add_item(item) 95 | mel_lengths.append(item['len']) 96 | total_sec += item['sec'] 97 | builder.finalize() 98 | np.save(f'{data_dir}/{prefix}_lengths.npy', mel_lengths) 99 | print(f"| {prefix} total duration: {total_sec:.3f}s") 100 | 101 | @classmethod 102 | def process_item(cls, item_name, wav_fn, binarization_args): 103 | res = {'item_name': item_name, 'wav_fn': wav_fn} 104 | if binarization_args['with_linear']: 105 | wav, mel, linear_stft = get_vocoder_cls(hparams).wav2spec(wav_fn, return_linear=True) 106 | res['linear'] = linear_stft 107 | else: 108 | wav, mel = get_vocoder_cls(hparams).wav2spec(wav_fn) 109 | wav = wav.astype(np.float16) 110 | res.update({'mel': mel, 'wav': wav, 111 | 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0]}) 112 | 113 | return res 114 | 115 | @classmethod 116 | def process_mel_item(cls, item_name, mel, wav_fn, binarization_args): 117 | res = {'item_name': item_name, 'wav_fn': wav_fn} 118 | mel = mel 119 | wav = np.ones((1,500,100)) 120 | res.update({'mel': mel, 'wav': wav, 121 | 'sec': 0, 'len': mel.shape[0]}) 122 | return res 123 | 124 | @property 125 | def num_workers(self): 126 | return int(os.getenv('N_PROC', hparams.get('N_PROC', os.cpu_count()))) 127 | 128 | 129 | if __name__ == "__main__": 130 | set_hparams() 131 | VocoderBinarizer().process() 132 | -------------------------------------------------------------------------------- /data_gen/tts/vocoder_binarizer_tacotron.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | from collections import Counter 6 | from utils.text_encoder import TokenTextEncoder 7 | import torch 8 | from utils.multiprocess_utils import chunked_multiprocess_run 9 | import random 10 | import traceback 11 | import json 12 | from resemblyzer import VoiceEncoder 13 | from tqdm import tqdm 14 | from data_gen.tts.data_gen_utils import 
get_mel2ph, get_pitch, build_phone_encoder, is_sil_phoneme 15 | from utils.hparams import hparams, set_hparams 16 | import numpy as np 17 | from utils.indexed_datasets import IndexedDatasetBuilder 18 | from vocoders.base_vocoder import get_vocoder_cls 19 | import pandas as pd 20 | from scipy.io.wavfile import read 21 | 22 | MAX_WAV_VALUE = 32768.0 23 | 24 | import torch 25 | from data_gen.tts.tacotron.layers import TacotronSTFT 26 | 27 | 28 | 29 | class BinarizationError(Exception): 30 | pass 31 | 32 | def load_wav_to_torch(full_path): 33 | sampling_rate, data = read(full_path) 34 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 35 | 36 | class VocoderBinarizer_Tacotron: 37 | def __init__(self, processed_data_dir=None): 38 | if processed_data_dir is None: 39 | processed_data_dir = hparams['processed_data_dir'] 40 | self.processed_data_dirs = processed_data_dir.split(",") 41 | self.binarization_args = hparams['binarization_args'] 42 | self.pre_align_args = hparams['pre_align_args'] 43 | self.item2wavfn = {} 44 | self.stft = TacotronSTFT( 45 | hparams['fft_size'], hparams['hop_size'], hparams['win_size'], 46 | hparams['audio_num_mel_bins'], hparams['audio_sample_rate'], hparams['mel_fmin'], 47 | hparams['mel_fmax']) 48 | 49 | def load_meta_data(self): 50 | for ds_id, processed_data_dir in enumerate(self.processed_data_dirs): 51 | self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str) 52 | for r_idx, r in tqdm(self.meta_df.iterrows(), desc='Loading meta data.'): 53 | item_name = raw_item_name = r['item_name'] 54 | if len(self.processed_data_dirs) > 1: 55 | item_name = f'ds{ds_id}_{item_name}' 56 | self.item2wavfn[item_name] = r['wav_fn'] 57 | self.item_names = sorted(list(self.item2wavfn.keys())) 58 | if self.binarization_args['shuffle']: 59 | random.seed(1234) 60 | random.shuffle(self.item_names) 61 | 62 | @property 63 | def train_item_names(self): 64 | return self.item_names[hparams['test_num']:] 65 | 66 | @property 67 | def valid_item_names(self): 68 | return self.item_names[:hparams['test_num']] 69 | 70 | @property 71 | def test_item_names(self): 72 | return self.valid_item_names 73 | 74 | def meta_data(self, prefix): 75 | if prefix == 'valid': 76 | item_names = self.valid_item_names 77 | elif prefix == 'test': 78 | item_names = self.test_item_names 79 | else: 80 | item_names = self.train_item_names 81 | for item_name in item_names: 82 | wav_fn = self.item2wavfn[item_name] 83 | yield item_name, wav_fn 84 | 85 | def process(self): 86 | self.load_meta_data() 87 | os.makedirs(hparams['binary_data_dir'], exist_ok=True) 88 | self.process_data('valid') 89 | self.process_data('test') 90 | self.process_data('train') 91 | 92 | 93 | 94 | def process_data(self, prefix): 95 | data_dir = hparams['binary_data_dir'] 96 | args = [] 97 | builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}') 98 | mel_lengths = [] 99 | total_sec = 0 100 | meta_data = list(self.meta_data(prefix)) 101 | for m in meta_data: 102 | args.append(list(m) + [self.binarization_args]) 103 | num_workers = self.num_workers 104 | for f_id, (_, item) in enumerate( 105 | zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))): 106 | if item is None: 107 | continue 108 | 109 | audio, sampling_rate = load_wav_to_torch(item['wav_fn']) 110 | if sampling_rate != hparams['audio_sample_rate']: 111 | raise ValueError 112 | audio_norm = audio / MAX_WAV_VALUE 113 | audio_norm = audio_norm.unsqueeze(0) 114 | audio_norm = torch.autograd.Variable(audio_norm, 
requires_grad=False) 115 | melspec = self.stft.mel_spectrogram(audio_norm) 116 | melspec = torch.squeeze(melspec, 0) 117 | 118 | wav = audio_norm.squeeze().detach().cpu().numpy().astype(np.float16) 119 | mel = melspec.transpose(0, 1).detach().cpu().numpy() 120 | item.update({'mel': mel, 'wav': wav, 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0]}) 121 | if not self.binarization_args['with_wav'] and 'wav' in item: 122 | del item['wav'] 123 | 124 | builder.add_item(item) 125 | mel_lengths.append(item['len']) 126 | total_sec += item['sec'] 127 | builder.finalize() 128 | np.save(f'{data_dir}/{prefix}_lengths.npy', mel_lengths) 129 | print(f"| {prefix} total duration: {total_sec:.3f}s") 130 | 131 | @classmethod 132 | def process_item(cls, item_name, wav_fn, binarization_args): 133 | res = {'item_name': item_name, 'wav_fn': wav_fn} 134 | return res 135 | 136 | @classmethod 137 | def process_mel_item(cls, item_name, mel, wav_fn, binarization_args): 138 | res = {'item_name': item_name, 'wav_fn': wav_fn} 139 | mel = mel 140 | wav = np.ones((1,500,100)) 141 | res.update({'mel': mel, 'wav': wav, 142 | 'sec': 0, 'len': mel.shape[0]}) 143 | return res 144 | 145 | @property 146 | def num_workers(self): 147 | return int(os.getenv('N_PROC', hparams.get('N_PROC', os.cpu_count()))) 148 | 149 | 150 | if __name__ == "__main__": 151 | set_hparams() 152 | VocoderBinarizer_Tacotron().process() 153 | -------------------------------------------------------------------------------- /data_gen/tts/vocoder_pre_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OMP_NUM_THREADS"] = "1" 4 | 5 | import librosa 6 | from utils import audio 7 | from data_gen.tts.data_gen_utils import is_sil_phoneme 8 | from utils.multiprocess_utils import chunked_multiprocess_run 9 | import traceback 10 | import importlib 11 | from utils.hparams import hparams, set_hparams 12 | import json 13 | import os 14 | import subprocess 15 | from tqdm import tqdm 16 | import pandas as pd 17 | from utils.rnnoise import rnnoise 18 | 19 | 20 | class VocoderPreAlign: 21 | def __init__(self): 22 | self.pre_align_args = hparams['pre_align_args'] 23 | self.raw_data_dir = hparams['raw_data_dir'] 24 | self.processed_dir = hparams['processed_data_dir'] 25 | 26 | def meta_data(self): 27 | raise NotImplementedError 28 | 29 | 30 | @staticmethod 31 | def process_wav(idx, item_name, wav_fn, processed_dir, pre_align_args): 32 | if pre_align_args['sox_to_wav'] or pre_align_args['trim_sil'] or \ 33 | pre_align_args['sox_resample'] or pre_align_args['denoise']: 34 | sr = hparams['audio_sample_rate'] 35 | new_wav_fn = f"{processed_dir}/wav_inputs/{idx}" 36 | subprocess.check_call(f'sox "{wav_fn}" -t wav "{new_wav_fn}.wav"', shell=True) 37 | if pre_align_args['sox_resample']: 38 | subprocess.check_call(f'sox -v 0.95 "{new_wav_fn}.wav" -r{sr} "{new_wav_fn}_rs.wav"', shell=True) 39 | new_wav_fn = new_wav_fn + '_rs' 40 | if pre_align_args['denoise']: 41 | rnnoise(wav_fn, new_wav_fn + '_denoise.wav', out_sample_rate=sr) 42 | new_wav_fn = new_wav_fn + '_denoise' 43 | if pre_align_args['trim_sil']: 44 | y, _ = librosa.core.load(new_wav_fn + '.wav', sr=sr) 45 | y, _ = librosa.effects.trim(y) 46 | audio.save_wav(y, new_wav_fn + '_trim.wav', sr, norm=True) 47 | new_wav_fn = new_wav_fn + '_trim' 48 | return new_wav_fn + '.wav' 49 | else: 50 | return wav_fn 51 | 52 | def process(self): 53 | set_hparams() 54 | processed_dir = self.processed_dir 55 | subprocess.check_call(f'rm -rf
{processed_dir}/mfa_inputs', shell=True) 56 | os.makedirs(f"{processed_dir}/wav_inputs", exist_ok=True) 57 | meta_df = [] 58 | 59 | args = [] 60 | meta_data = [] 61 | for idx, inp_args in enumerate(tqdm(self.meta_data(), desc='Load meta data')): 62 | meta_data.append(inp_args) 63 | item_name, wav_fn = inp_args 64 | args.append([ 65 | idx, item_name, wav_fn, processed_dir, self.pre_align_args 66 | ]) 67 | item_names = [x[1] for x in args] 68 | assert len(item_names) == len(set(item_names)), 'Key `item_name` should be Unique.' 69 | 70 | for inp_args, res in zip(tqdm(meta_data, 'Processing'), chunked_multiprocess_run(self.process_job, args)): 71 | item_name, wav_fn = inp_args 72 | if res is None: 73 | print(f"| Skip {wav_fn}.") 74 | continue 75 | wav_fn = res 76 | meta_df.append({'item_name': item_name, 'wav_fn': wav_fn}) 77 | 78 | 79 | # save to csv 80 | meta_df = pd.DataFrame(meta_df) 81 | meta_df.to_csv(f"{processed_dir}/metadata_phone.csv") 82 | 83 | @classmethod 84 | def process_job(cls, idx, item_name, wav_fn, processed_dir, pre_align_args): 85 | try: 86 | wav_fn = cls.process_wav(idx, item_name, wav_fn, processed_dir, pre_align_args) 87 | if wav_fn is None: 88 | return None 89 | except: 90 | traceback.print_exc() 91 | return None 92 | group = idx // pre_align_args['nsample_per_mfa_group'] # group MFA inputs for better parallelism 93 | os.makedirs(f'{processed_dir}/mfa_inputs/{group}', exist_ok=True) 94 | ext = os.path.splitext(wav_fn)[1] 95 | new_wav_fn = f"{processed_dir}/mfa_inputs/{group}/{idx:07d}_{item_name}{ext}" 96 | cp_cmd = 'mv' if 'wav_inputs' in wav_fn else 'cp' 97 | subprocess.check_call(f'{cp_cmd} "{wav_fn}" "{new_wav_fn}"', shell=True) 98 | wav_fn = new_wav_fn 99 | return wav_fn 100 | -------------------------------------------------------------------------------- /egs/audios/LJ001-0001_gt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/FastDiff/c954758364aa845f37642952c7a16ff3e92e436e/egs/audios/LJ001-0001_gt.wav -------------------------------------------------------------------------------- /egs/audios/LJ001-0002_gt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/FastDiff/c954758364aa845f37642952c7a16ff3e92e436e/egs/audios/LJ001-0002_gt.wav -------------------------------------------------------------------------------- /egs/audios/LJ001-0003_gt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/FastDiff/c954758364aa845f37642952c7a16ff3e92e436e/egs/audios/LJ001-0003_gt.wav -------------------------------------------------------------------------------- /egs/datasets/audio/libritts/pre_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data_gen.tts.vocoder_pre_align import VocoderPreAlign 4 | import glob 5 | 6 | 7 | class LibrittsPreAlign(VocoderPreAlign): 8 | def meta_data(self): 9 | wav_fns = sorted(glob.glob(f'{self.raw_data_dir}/*/*/*.wav')) 10 | for wav_fn in wav_fns: 11 | item_name = os.path.basename(wav_fn)[:-4] 12 | txt = 'Not Needed.' 
13 | spk = item_name.split("_")[0] 14 | yield item_name, wav_fn, txt, spk 15 | 16 | 17 | if __name__ == "__main__": 18 | LibrittsPreAlign().process() 19 | -------------------------------------------------------------------------------- /egs/datasets/audio/lj/base_text2mel.yaml: -------------------------------------------------------------------------------- 1 | raw_data_dir: 'data/raw/LJSpeech-1.1' 2 | processed_data_dir: 'data/processed/ljspeech' 3 | binary_data_dir: 'data/binary/ljspeech' 4 | pre_align_cls: egs.datasets.audio.lj.pre_align.LJPreAlign 5 | binarization_args: 6 | with_spk_embed: false 7 | 8 | pitch_type: frame 9 | mel_loss: "ssim:0.5|l1:0.5" 10 | num_test_samples: 20 11 | test_ids: [ 68, 70, 74, 87, 110, 172, 190, 215, 231, 294, 12 | 316, 324, 402, 422, 485, 500, 505, 508, 509, 519 ] 13 | use_energy_embed: false 14 | test_num: 523 15 | vocoder: vocoders.hifigan.HifiGAN 16 | vocoder_ckpt: 'checkpoints/0414_hifi_lj_1' 17 | -------------------------------------------------------------------------------- /egs/datasets/audio/lj/pre_align.py: -------------------------------------------------------------------------------- 1 | from data_gen.tts.vocoder_pre_align import VocoderPreAlign 2 | 3 | 4 | class LJPreAlign(VocoderPreAlign): 5 | def meta_data(self): 6 | for l in open(f'{self.raw_data_dir}/metadata.csv').readlines(): 7 | item_name, _, txt = l.strip().split("|") 8 | wav_fn = f"{self.raw_data_dir}/wavs/{item_name}.wav" 9 | yield item_name, wav_fn, txt, 'SPK1' 10 | 11 | 12 | if __name__ == "__main__": 13 | LJPreAlign().process() 14 | -------------------------------------------------------------------------------- /egs/datasets/audio/pre_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data_gen.tts.vocoder_pre_align import VocoderPreAlign 4 | import glob 5 | from pathlib import Path 6 | 7 | class PreAlign(VocoderPreAlign): 8 | def meta_data(self): 9 | wav_fns = sorted(glob.glob(f'{self.raw_data_dir}/*/*/*.wav')) + sorted(glob.glob(f'{self.raw_data_dir}/*/*.wav')) + sorted(glob.glob(f'{self.raw_data_dir}/*.wav')) 10 | for wav_fn in wav_fns: 11 | item_name = os.path.basename(wav_fn)[:-4] 12 | if os.path.exists(wav_fn): 13 | yield item_name, wav_fn 14 | 15 | 16 | if __name__ == "__main__": 17 | PreAlign().process() 18 | -------------------------------------------------------------------------------- /egs/datasets/audio/vctk/pre_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data_gen.tts.vocoder_pre_align import VocoderPreAlign 4 | import glob 5 | 6 | 7 | class VCTKPreAlign(VocoderPreAlign): 8 | def meta_data(self): 9 | wav_fns = glob.glob(f'{self.raw_data_dir}/wav48/*/*.wav') 10 | for wav_fn in wav_fns: 11 | item_name = os.path.basename(wav_fn)[:-4] 12 | spk = item_name.split("_")[0] 13 | txt = "Not Needed" 14 | if os.path.exists(wav_fn): 15 | yield item_name, wav_fn, txt, spk 16 | 17 | 18 | if __name__ == "__main__": 19 | VCTKPreAlign().process() 20 | -------------------------------------------------------------------------------- /egs/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true, 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import librosa\n", 15 | "import numpy as np\n", 16 | "import torch\n", 17 | "# from audio_processing import 
dynamic_range_decompression\n", 18 | "# from audio_processing import dynamic_range_compression\n", 19 | "from modules.FastDiff.module.FastDiff_model import FastDiff\n", 20 | "from utils import audio\n", 21 | "from modules.FastDiff.module.util import compute_hyperparams_given_schedule, sampling_given_noise_schedule\n", 22 | "import IPython.display as ipd\n", 23 | "\n", 24 | "# download checkpoint to this folder\n", 25 | "state_dict = torch.load(\"pretrained_models/LJSpeech/model_ckpt_steps_500000.ckpt\")[\"state_dict\"][\"model\"]\n", 26 | "model = FastDiff().cuda()\n", 27 | "model.load_state_dict(state_dict)\n", 28 | "\n", 29 | "# hparams (donot change)\n", 30 | "fft_size, hop_size, win_length = 1024, 256, 1024\n", 31 | "window=\"hann\"\n", 32 | "num_mels=80\n", 33 | "fmin, fmax=80, 7600\n", 34 | "eps=1e-6\n", 35 | "sample_rate=22050" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "outputs": [], 42 | "source": [ 43 | "# get diffusion schedule\n", 44 | "train_noise_schedule = torch.linspace(1e-06, 0.01, 1000).cuda()\n", 45 | "diffusion_hyperparams = compute_hyperparams_given_schedule(train_noise_schedule)\n", 46 | "\n", 47 | "# map diffusion hyperparameters to gpu\n", 48 | "for key in diffusion_hyperparams:\n", 49 | " if key in [\"beta\", \"alpha\", \"sigma\"]:\n", 50 | " diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()\n", 51 | "diffusion_hyperparams = diffusion_hyperparams\n", 52 | "\n", 53 | "# load noise schedule for 6 sampling steps (recommended)\n", 54 | "#noise_schedule = torch.FloatTensor([1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,\n", 55 | " # 0.006634317338466644, 0.09357017278671265, 0.6000000238418579]).cuda()\n", 56 | "# load noise schedule for 4 sampling steps\n", 57 | "noise_schedule = torch.FloatTensor([3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]).cuda()" 58 | ], 59 | "metadata": { 60 | "collapsed": false, 61 | "pycharm": { 62 | "name": "#%%\n" 63 | } 64 | } 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "outputs": [], 70 | "source": [ 71 | "# Direct inference from wavefroms #\n", 72 | "\n", 73 | "wav, _ = librosa.core.load('egs/audios/LJ001-0001_gt.wav', sr=22050)\n", 74 | "# get amplitude spectrogram\n", 75 | "x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,\n", 76 | " win_length=win_length, window=window, pad_mode=\"constant\")\n", 77 | "spc = np.abs(x_stft) # (n_bins, T)\n", 78 | "\n", 79 | "# get mel basis\n", 80 | "fmin = 0 if fmin == -1 else fmin\n", 81 | "fmax = sample_rate / 2 if fmax == -1 else fmax\n", 82 | "mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax)\n", 83 | "mel = mel_basis @ spc\n", 84 | "mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)\n", 85 | "mel = torch.from_numpy(mel).cuda()" 86 | ], 87 | "metadata": { 88 | "collapsed": false, 89 | "pycharm": { 90 | "name": "#%%\n" 91 | } 92 | } 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "outputs": [], 98 | "source": [ 99 | "audio_length = mel.shape[-1] * hop_size\n", 100 | "pred_wav = sampling_given_noise_schedule(\n", 101 | " model, (1, 1, audio_length), diffusion_hyperparams, noise_schedule,\n", 102 | " condition=mel, ddim=False, return_sequence=False)\n", 103 | "\n", 104 | "pred_wav = pred_wav / pred_wav.abs().max()\n", 105 | "pred_wav = pred_wav.view(-1).cpu().float().numpy()\n", 106 | "audio.save_wav(pred_wav, 'egs/audios/test.wav', 22050)\n", 107 | "ipd.Audio(pred_wav, rate=sample_rate) " 108 | ], 109 | "metadata": { 
110 | "collapsed": false, 111 | "pycharm": { 112 | "name": "#%%\n" 113 | } 114 | } 115 | } 116 | ], 117 | "metadata": { 118 | "kernelspec": { 119 | "display_name": "Python 3", 120 | "language": "python", 121 | "name": "python3" 122 | }, 123 | "language_info": { 124 | "codemirror_mode": { 125 | "name": "ipython", 126 | "version": 2 127 | }, 128 | "file_extension": ".py", 129 | "mimetype": "text/x-python", 130 | "name": "python", 131 | "nbconvert_exporter": "python", 132 | "pygments_lexer": "ipython2", 133 | "version": "2.7.6" 134 | } 135 | }, 136 | "nbformat": 4, 137 | "nbformat_minor": 0 138 | } -------------------------------------------------------------------------------- /egs/demo_tts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true, 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import torch\n", 15 | "from modules.FastDiff.module.FastDiff_model import FastDiff\n", 16 | "from utils import audio\n", 17 | "from modules.FastDiff.module.util import compute_hyperparams_given_schedule, sampling_given_noise_schedule\n", 18 | "import numpy as np\n", 19 | "\n", 20 | "HOP_SIZE = 256 # for 22050 frequency\n", 21 | "\n", 22 | "# download checkpoint to this folder\n", 23 | "state_dict = torch.load(\"./checkpoints/FastDiff_tacotron/model_ckpt_steps_500000.ckpt\")[\"state_dict\"][\"model\"]\n", 24 | "\n", 25 | "model = FastDiff().cuda()\n", 26 | "model.load_state_dict(state_dict)\n", 27 | "\n", 28 | "train_noise_schedule = noise_schedule = torch.linspace(1e-06, 0.01, 1000)\n", 29 | "diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)\n", 30 | "\n", 31 | "# map diffusion hyperparameters to gpu\n", 32 | "for key in diffusion_hyperparams:\n", 33 | " if key in [\"beta\", \"alpha\", \"sigma\"]:\n", 34 | " diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()\n", 35 | "diffusion_hyperparams = diffusion_hyperparams\n", 36 | "\n", 37 | "# load noise schedule for 8 sampling steps\n", 38 | "#noise_schedule = torch.FloatTensor([6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513, 0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5]).cuda()\n", 39 | "# load noise schedule for 4 sampling steps\n", 40 | "noise_schedule = torch.FloatTensor([3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]).cuda()\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "outputs": [], 47 | "source": [ 48 | "# Text-to-speech\n", 49 | "tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')\n", 50 | "tacotron2 = tacotron2.to(\"cuda\").eval()\n", 51 | "\n", 52 | "text = \"Welcome to a conditional diffusion probabilistic model capable of generating high fidelity speech efficiently.\"\n", 53 | "utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')\n", 54 | "sequences, lengths = utils.prepare_input_sequence([text])\n", 55 | "\n", 56 | "with torch.no_grad():\n", 57 | " mels, _, _ = tacotron2.infer(sequences, lengths)" 58 | ], 59 | "metadata": { 60 | "collapsed": false, 61 | "pycharm": { 62 | "name": "#%%\n" 63 | } 64 | } 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "outputs": [], 70 | "source": [ 71 | "# Speech-to-waveform\n", 72 | "\n", 73 | "audio_length = mels.shape[-1] * HOP_SIZE\n", 74 | "pred_wav = 
sampling_given_noise_schedule(\n", 75 | " model, (1, 1, audio_length), diffusion_hyperparams, noise_schedule,\n", 76 | " condition=mels, ddim=False, return_sequence=False)\n", 77 | "\n", 78 | "pred_wav = pred_wav / pred_wav.abs().max()\n", 79 | "audio.save_wav(pred_wav.view(-1).cpu().float().numpy(), './test.wav', 22050)" 80 | ], 81 | "metadata": { 82 | "collapsed": false, 83 | "pycharm": { 84 | "name": "#%%\n" 85 | } 86 | } 87 | } 88 | ], 89 | "metadata": { 90 | "kernelspec": { 91 | "display_name": "Python 3", 92 | "language": "python", 93 | "name": "python3" 94 | }, 95 | "language_info": { 96 | "codemirror_mode": { 97 | "name": "ipython", 98 | "version": 2 99 | }, 100 | "file_extension": ".py", 101 | "mimetype": "text/x-python", 102 | "name": "python", 103 | "nbconvert_exporter": "python", 104 | "pygments_lexer": "ipython2", 105 | "version": "2.7.6" 106 | } 107 | }, 108 | "nbformat": 4, 109 | "nbformat_minor": 0 110 | } -------------------------------------------------------------------------------- /egs/demo_tts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | # Below provide the TTS pipeline in LJSpeech dataset. 4 | 5 | def synthesize(choice, N, text): 6 | # If you are curious about these choices: https://huggingface.co/spaces/NATSpeech/PortaSpeech/resolve/main/checkpoints/ 7 | exp = ['ps_normal_exp', 'fs2_exp', 'diffspeech'] 8 | steps = ['406000', '160000', '160000'] 9 | infer = ['ps_flow', 'fs2_orig', 'ds'] 10 | 11 | # Install dependencies 12 | if not os.path.exists('PortaSpeech/'): 13 | print('-----------------------Start installing dependencies-----------------------') 14 | os.system('git clone https://huggingface.co/spaces/NATSpeech/PortaSpeech.git') 15 | os.system('cp -r egs/tts/* PortaSpeech/inference/tts/') 16 | 17 | ckpt = f'PortaSpeech/checkpoints/{exp[choice]}/model_ckpt_steps_{steps[choice]}.ckpt' 18 | if not os.path.exists(ckpt) or os.path.getsize(ckpt) < 1000: 19 | os.system( 20 | f'wget https://huggingface.co/spaces/NATSpeech/PortaSpeech/resolve/main/checkpoints/{exp[choice]}/model_ckpt_steps_{steps[choice]}.ckpt') 21 | os.system(f'mv model_ckpt_steps_406000.ckpt PortaSpeech/checkpoints/{exp[choice]}') 22 | 23 | # TTS 24 | print(f'-----------------------Start text-to-spectrogram synthesis using {exp[choice]}-----------------------') 25 | os.system(f" cd PortaSpeech && CUDA_VISIBLE_DEVICES=0 python inference/tts/{infer[choice]}.py --exp_name {exp[choice]} --hparams='processed_data_dir={text.replace(',', '/')}'") 26 | 27 | # FastDiff 28 | print('-----------------------Start neural vocoding using FastDiff-----------------------') 29 | os.system(f"CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config modules/FastDiff/config/FastDiff.yaml --exp_name FastDiff --infer --hparams='test_mel_dir=PortaSpeech/infer_out/,use_wav=False,N={N}'") 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser( 34 | description="Below provide the TTS pipeline in LJSpeech dataset.") 35 | parser.add_argument("--N", type=str, default='4', help="denoising steps") 36 | parser.add_argument("--text", "-o", type=str, help="input text", 37 | default="the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.") 38 | parser.add_argument("--model", type=int, choices=[0, 1, 2], default=0, help="choice a TTS model.") 39 | args = parser.parse_args() 40 | 41 | try: 42 | synthesize(args.model, args.N, args.text) 43 | except 
KeyboardInterrupt: 44 | print('KeyboardInterrupt.') 45 | -------------------------------------------------------------------------------- /egs/egs_bases/config_base.yaml: -------------------------------------------------------------------------------- 1 | # task 2 | binary_data_dir: '' 3 | work_dir: '' # experiment directory. 4 | infer: false # infer 5 | amp: false 6 | seed: 1234 7 | debug: false 8 | save_codes: [] 9 | # - configs 10 | # - modules 11 | # - tasks 12 | # - utils 13 | # - usr 14 | 15 | ############# 16 | # dataset 17 | ############# 18 | ds_workers: 1 19 | test_num: 100 20 | endless_ds: false 21 | sort_by_len: true 22 | 23 | ######### 24 | # train and eval 25 | ######### 26 | print_nan_grads: false 27 | load_ckpt: '' 28 | save_best: true 29 | num_ckpt_keep: 3 30 | clip_grad_norm: 0 31 | accumulate_grad_batches: 1 32 | tb_log_interval: 100 33 | num_sanity_val_steps: 5 # steps of validation at the beginning 34 | check_val_every_n_epoch: 10 35 | val_check_interval: 2000 36 | valid_monitor_key: 'val_loss' 37 | valid_monitor_mode: 'min' 38 | max_epochs: 1000 39 | max_updates: 1000000 40 | max_tokens: 31250 41 | max_sentences: 100000 42 | max_valid_tokens: -1 43 | max_valid_sentences: -1 44 | eval_max_batches: -1 45 | test_input_dir: '' 46 | resume_from_checkpoint: 0 47 | rename_tmux: true -------------------------------------------------------------------------------- /egs/egs_bases/tts/base.yaml: -------------------------------------------------------------------------------- 1 | # task 2 | base_config: ../config_base.yaml 3 | task_cls: '' 4 | ############# 5 | # dataset 6 | ############# 7 | raw_data_dir: '' 8 | processed_data_dir: '' 9 | binary_data_dir: '' 10 | dict_dir: '' 11 | pre_align_cls: '' 12 | binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer 13 | mfa_version: 2 14 | pre_align_args: 15 | nsample_per_mfa_group: 1000 16 | txt_processor: en 17 | use_tone: true # for ZH 18 | sox_resample: false 19 | sox_to_wav: false 20 | allow_no_txt: false 21 | trim_sil: false 22 | denoise: false 23 | binarization_args: 24 | shuffle: false 25 | with_txt: true 26 | with_wav: false 27 | with_align: true 28 | with_spk_embed: false 29 | with_spk_id: true 30 | with_f0: true 31 | with_f0cwt: false 32 | with_linear: false 33 | with_word: true 34 | trim_eos_bos: false 35 | reset_phone_dict: true 36 | reset_word_dict: true 37 | word_size: 30000 38 | pitch_extractor: parselmouth 39 | 40 | loud_norm: false 41 | endless_ds: true 42 | 43 | test_num: 100 44 | min_frames: 0 45 | max_frames: 1548 46 | frames_multiple: 1 47 | max_input_tokens: 1550 48 | audio_num_mel_bins: 80 49 | audio_sample_rate: 22050 50 | hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) 51 | win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) 52 | fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 53 | fmax: 7600 # To be increased/reduced depending on data. 
54 | fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter 55 | min_level_db: -100 56 | ref_level_db: 20 57 | griffin_lim_iters: 60 58 | num_spk: 1 59 | mel_vmin: -6 60 | mel_vmax: 1.5 61 | ds_workers: 1 62 | 63 | ######### 64 | # models 65 | ######### 66 | dropout: 0.1 67 | enc_layers: 4 68 | dec_layers: 4 69 | hidden_size: 256 70 | num_heads: 2 71 | enc_ffn_kernel_size: 9 72 | dec_ffn_kernel_size: 9 73 | ffn_act: gelu 74 | ffn_padding: 'SAME' 75 | use_spk_id: false 76 | use_split_spk_id: false 77 | use_spk_embed: false 78 | mel_loss: l1 79 | 80 | 81 | ########### 82 | # optimization 83 | ########### 84 | lr: 2.0 85 | scheduler: rsqrt # rsqrt|none 86 | warmup_updates: 8000 87 | optimizer_adam_beta1: 0.9 88 | optimizer_adam_beta2: 0.98 89 | weight_decay: 0 90 | clip_grad_norm: 1 91 | clip_grad_value: 0 92 | 93 | 94 | ########### 95 | # train and eval 96 | ########### 97 | use_word_input: false 98 | max_tokens: 30000 99 | max_sentences: 100000 100 | max_valid_sentences: 1 101 | max_valid_tokens: 60000 102 | valid_infer_interval: 10000 103 | train_set_name: 'train' 104 | train_sets: '' 105 | valid_set_name: 'valid' 106 | test_set_name: 'test' 107 | num_test_samples: 0 108 | num_valid_plots: 10 109 | test_ids: [ ] 110 | vocoder: pwg 111 | vocoder_ckpt: '' 112 | vocoder_denoise_c: 0.0 113 | profile_infer: false 114 | out_wav_norm: false 115 | save_gt: true 116 | save_f0: false 117 | gen_dir_name: '' -------------------------------------------------------------------------------- /egs/egs_bases/tts/base_zh.yaml: -------------------------------------------------------------------------------- 1 | base_config: ./base.yaml 2 | pre_align_args: 3 | txt_processor: zh 4 | binarizer_cls: data_gen.tts.binarizer_zh.ZhBinarizer 5 | word_size: 3000 -------------------------------------------------------------------------------- /egs/egs_bases/tts/vocoder/base.yaml: -------------------------------------------------------------------------------- 1 | base_config: ../base.yaml 2 | binarization_args: 3 | with_wav: true 4 | with_spk_embed: false 5 | with_align: false 6 | with_word: false 7 | with_txt: false 8 | 9 | ########### 10 | # train and eval 11 | ########### 12 | max_samples: 25600 13 | max_sentences: 5 14 | max_valid_sentences: 1 15 | max_updates: 1000000 16 | val_check_interval: 2000 17 | 18 | ########################################################### 19 | # FEATURE EXTRACTION SETTING # 20 | ########################################################### 21 | fft_size: 1024 # FFT size. 22 | hop_size: 256 # Hop size. 23 | win_length: null # Window length. 24 | # If set to null, it will be the same as fft_size. 25 | window: "hann" # Window function. 26 | num_mels: 80 # Number of mel basis. 27 | fmin: 80 # Minimum freq in mel basis calculation. 28 | fmax: 7600 # Maximum frequency in mel basis calculation. 29 | aux_context_window: 0 # Context window size for auxiliary feature. 30 | use_pitch_embed: false 31 | 32 | generator_grad_norm: 10 # Generator's gradient norm. 33 | discriminator_grad_norm: 1 # Discriminator's gradient norm. 34 | disc_start_steps: 40000 # Number of steps to start to train discriminator. 
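# Note: the feature-extraction settings above (fft_size 1024, hop_size 256, num_mels 80,
# fmin 80, fmax 7600) mirror the mel settings in ../base.yaml. If you adapt them for a
# custom dataset, keep the two configs consistent so the vocoder is trained on the same
# mel representation that the acoustic model produces at inference time.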
35 | -------------------------------------------------------------------------------- /egs/tts/base_tts_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from modules.vocoder.hifigan.hifigan import HifiGanGenerator 6 | from tasks.tts.dataset_utils import FastSpeechWordDataset 7 | from tasks.tts.tts_utils import load_data_preprocessor 8 | from utils.commons.ckpt_utils import load_ckpt 9 | from utils.commons.hparams import set_hparams 10 | import numpy as np 11 | 12 | class BaseTTSInfer: 13 | def __init__(self, hparams, device=None): 14 | if device is None: 15 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | self.hparams = hparams 17 | self.device = device 18 | self.data_dir = hparams['binary_data_dir'] 19 | self.preprocessor, self.preprocess_args = load_data_preprocessor() 20 | self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir) 21 | self.spk_map = self.preprocessor.load_spk_map(self.data_dir) 22 | self.ds_cls = FastSpeechWordDataset 23 | self.model = self.build_model() 24 | self.model.eval() 25 | self.model.to(self.device) 26 | 27 | def build_model(self): 28 | raise NotImplementedError 29 | 30 | def forward_model(self, inp): 31 | raise NotImplementedError 32 | 33 | 34 | def preprocess_input(self, inp): 35 | """ 36 | 37 | :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)} 38 | :return: 39 | """ 40 | preprocessor, preprocess_args = self.preprocessor, self.preprocess_args 41 | text_raw = inp['text'] 42 | item_name = inp.get('item_name', '') 43 | spk_name = inp.get('spk_name', '') 44 | ph, txt, word, ph2word, ph_gb_word = preprocessor.txt_to_ph( 45 | preprocessor.txt_processor, text_raw, preprocess_args) 46 | word_token = self.word_encoder.encode(word) 47 | ph_token = self.ph_encoder.encode(ph) 48 | spk_id = self.spk_map[spk_name] 49 | item = {'item_name': item_name, 'text': txt, 'ph': ph, 'spk_id': spk_id, 50 | 'ph_token': ph_token, 'word_token': word_token, 'ph2word': ph2word} 51 | item['ph_len'] = len(item['ph_token']) 52 | return item 53 | 54 | def input_to_batch(self, item): 55 | item_names = [item['item_name']] 56 | text = [item['text']] 57 | ph = [item['ph']] 58 | txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device) 59 | txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device) 60 | word_tokens = torch.LongTensor(item['word_token'])[None, :].to(self.device) 61 | word_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device) 62 | ph2word = torch.LongTensor(item['ph2word'])[None, :].to(self.device) 63 | spk_ids = torch.LongTensor(item['spk_id'])[None, :].to(self.device) 64 | batch = { 65 | 'item_name': item_names, 66 | 'text': text, 67 | 'ph': ph, 68 | 'txt_tokens': txt_tokens, 69 | 'txt_lengths': txt_lengths, 70 | 'word_tokens': word_tokens, 71 | 'word_lengths': word_lengths, 72 | 'ph2word': ph2word, 73 | 'spk_ids': spk_ids, 74 | } 75 | return batch 76 | 77 | def postprocess_output(self, output): 78 | return output 79 | 80 | def infer_once(self, inp): 81 | inp = self.preprocess_input(inp) 82 | output = self.forward_model(inp) 83 | output = self.postprocess_output(output) 84 | return output 85 | 86 | @classmethod 87 | def example_run(cls): 88 | from utils.commons.hparams import set_hparams 89 | from utils.commons.hparams import hparams as hp 90 | 91 | set_hparams() 92 | inp = { 93 | 'text': hp['processed_data_dir'] 94 | } 95 | infer_ins = cls(hp) 96 | out = infer_ins.infer_once(inp) 97 | 
os.makedirs('infer_out', exist_ok=True) 98 | np.save(f'infer_out/example_out.npy', out) 99 | -------------------------------------------------------------------------------- /egs/tts/ds.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # from inference.tts.fs import FastSpeechInfer 3 | # from modules.tts.fs2_orig import FastSpeech2Orig 4 | from inference.tts.base_tts_infer import BaseTTSInfer 5 | from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion 6 | from utils.commons.ckpt_utils import load_ckpt 7 | from utils.commons.hparams import hparams 8 | 9 | 10 | class DiffSpeechInfer(BaseTTSInfer): 11 | def build_model(self): 12 | dict_size = len(self.ph_encoder) 13 | model = GaussianDiffusion(dict_size, self.hparams) 14 | model.eval() 15 | load_ckpt(model, hparams['work_dir'], 'model') 16 | return model 17 | 18 | def forward_model(self, inp): 19 | sample = self.input_to_batch(inp) 20 | txt_tokens = sample['txt_tokens'] # [B, T_t] 21 | spk_id = sample.get('spk_ids') 22 | with torch.no_grad(): 23 | output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True) 24 | mel_out = output['mel_out'] 25 | mel_out = mel_out.cpu().numpy() 26 | return mel_out[0] 27 | 28 | if __name__ == '__main__': 29 | DiffSpeechInfer.example_run() 30 | -------------------------------------------------------------------------------- /egs/tts/fs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from inference.tts.base_tts_infer import BaseTTSInfer 3 | from modules.tts.fs import FastSpeech 4 | from utils.commons.ckpt_utils import load_ckpt 5 | from utils.commons.hparams import hparams 6 | 7 | 8 | class FastSpeechInfer(BaseTTSInfer): 9 | def build_model(self): 10 | dict_size = len(self.ph_encoder) 11 | model = FastSpeech(dict_size, self.hparams) 12 | model.eval() 13 | load_ckpt(model, hparams['work_dir'], 'model') 14 | return model 15 | 16 | def forward_model(self, inp): 17 | sample = self.input_to_batch(inp) 18 | txt_tokens = sample['txt_tokens'] # [B, T_t] 19 | spk_id = sample.get('spk_ids') 20 | with torch.no_grad(): 21 | output = self.model(txt_tokens, spk_id=spk_id, infer=True) 22 | mel_out = output['mel_out'] 23 | mel_out = mel_out.cpu().numpy() 24 | return mel_out[0] 25 | 26 | 27 | if __name__ == '__main__': 28 | FastSpeechInfer.example_run() 29 | -------------------------------------------------------------------------------- /egs/tts/fs2_orig.py: -------------------------------------------------------------------------------- 1 | from inference.tts.fs import FastSpeechInfer 2 | from modules.tts.fs2_orig import FastSpeech2Orig 3 | from utils.commons.ckpt_utils import load_ckpt 4 | from utils.commons.hparams import hparams 5 | 6 | 7 | class FastSpeech2OrigInfer(FastSpeechInfer): 8 | def build_model(self): 9 | dict_size = len(self.ph_encoder) 10 | model = FastSpeech2Orig(dict_size, self.hparams) 11 | model.eval() 12 | load_ckpt(model, hparams['work_dir'], 'model') 13 | return model 14 | 15 | 16 | if __name__ == '__main__': 17 | FastSpeech2OrigInfer.example_run() 18 | -------------------------------------------------------------------------------- /egs/tts/ps_flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from inference.tts.base_tts_infer import BaseTTSInfer 3 | from modules.tts.portaspeech.portaspeech_flow import PortaSpeechFlow 4 | from utils.commons.ckpt_utils import load_ckpt 5 | from utils.commons.hparams 
import hparams 6 | 7 | 8 | class PortaSpeechFlowInfer(BaseTTSInfer): 9 | def build_model(self): 10 | ph_dict_size = len(self.ph_encoder) 11 | word_dict_size = len(self.word_encoder) 12 | model = PortaSpeechFlow(ph_dict_size, word_dict_size, self.hparams) 13 | load_ckpt(model, hparams['work_dir'], 'model') 14 | model.to(self.device) 15 | with torch.no_grad(): 16 | model.store_inverse_all() 17 | model.eval() 18 | return model 19 | 20 | def forward_model(self, inp): 21 | sample = self.input_to_batch(inp) 22 | with torch.no_grad(): 23 | output = self.model( 24 | sample['txt_tokens'], 25 | sample['word_tokens'], 26 | ph2word=sample['ph2word'], 27 | word_len=sample['word_lengths'].max(), 28 | infer=True, 29 | forward_post_glow=True, 30 | spk_id=sample.get('spk_ids') 31 | ) 32 | mel_out = output['mel_out'] 33 | mel_out = mel_out.cpu().numpy() 34 | return mel_out[0] 35 | 36 | 37 | if __name__ == '__main__': 38 | PortaSpeechFlowInfer.example_run() 39 | -------------------------------------------------------------------------------- /modules/FastDiff/config/FastDiff.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - ./base.yaml 3 | 4 | audio_sample_rate: 22050 5 | raw_data_dir: 'data/raw/LJSpeech-1.1' 6 | processed_data_dir: 'data/processed/LJSpeech' 7 | binary_data_dir: 'data/binary/LJSpeech' 8 | -------------------------------------------------------------------------------- /modules/FastDiff/config/FastDiff_libritts.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - ./base.yaml 3 | 4 | audio_sample_rate: 22050 5 | raw_data_dir: 'data/raw/LibriTTS' 6 | processed_data_dir: 'data/processed/LibriTTS' 7 | binary_data_dir: 'data/binary/LibriTTS' -------------------------------------------------------------------------------- /modules/FastDiff/config/FastDiff_tacotron.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - ./base.yaml 3 | 4 | #raw_data_dir: 'data/raw/LJSpeech-1.1' 5 | #processed_data_dir: 'data/processed/LJSpeech' 6 | #binary_data_dir: 'data/binary/LJSpeech_Taco' 7 | binary_data_dir: /apdcephfs/share_1316500/nlphuang/data/AdaGrad/LJSpeech_Taco/ 8 | 9 | binarizer_cls: data_gen.tts.vocoder_binarizer_tacotron.VocoderBinarizer_Tacotron 10 | 11 | binarization_args: 12 | with_wav: true 13 | with_spk_embed: false 14 | with_align: false 15 | with_word: false 16 | with_txt: false 17 | with_f0: false 18 | 19 | max_sentences: 50 # max batch size in training 20 | mel_fmin: 0.0 21 | mel_fmax: 8000.0 22 | valid_infer_interval: 10000 23 | val_check_interval: 2000 -------------------------------------------------------------------------------- /modules/FastDiff/config/FastDiff_vctk.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - ./base.yaml 3 | 4 | audio_sample_rate: 22050 5 | raw_data_dir: 'data/raw/VCTK' 6 | processed_data_dir: 'data/processed/VCTK' 7 | binary_data_dir: 'data/binary/VCTK' -------------------------------------------------------------------------------- /modules/FastDiff/config/base.yaml: -------------------------------------------------------------------------------- 1 | ############# 2 | # Custom dataset preprocess 3 | ############# 4 | audio_num_mel_bins: 80 5 | audio_sample_rate: 22050 6 | hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) 7 | win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) 
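# Illustrative example (hypothetical values, not part of this config): for a 24000 Hz
# corpus, keeping the same 12.5 ms hop and 50 ms window would give
#   audio_sample_rate: 24000
#   hop_size: 300    # 0.0125 * 24000
#   win_size: 1200   # 0.05 * 24000
# Note that upsample_ratios below (8 * 8 * 4 = 256) multiply out to hop_size here, so
# they would need adjusting to match any new hop_size as well.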
8 | fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 9 | fmax: 7600 # To be increased/reduced depending on data. 10 | fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter 11 | min_level_db: -100 12 | ref_level_db: 20 13 | griffin_lim_iters: 60 14 | num_spk: 1 # number of speakers 15 | mel_vmin: -6 16 | mel_vmax: 1.5 17 | 18 | ############# 19 | # FastDiff Model 20 | ############# 21 | audio_channels: 1 22 | inner_channels: 32 23 | cond_channels: 80 24 | upsample_ratios: [8, 8, 4] 25 | lvc_layers_each_block: 4 26 | lvc_kernel_size: 3 27 | kpnet_hidden_channels: 64 28 | kpnet_conv_size: 3 29 | dropout: 0.0 30 | diffusion_step_embed_dim_in: 128 31 | diffusion_step_embed_dim_mid: 512 32 | diffusion_step_embed_dim_out: 512 33 | use_weight_norm: True 34 | 35 | ########### 36 | # Diffusion 37 | ########### 38 | T: 1000 39 | beta_0: 0.000001 40 | beta_T: 0.01 41 | noise_schedule: '' 42 | N: '' 43 | 44 | 45 | ########### 46 | # train and eval 47 | ########### 48 | task_cls: modules.FastDiff.task.FastDiff.FastDiffTask 49 | max_updates: 1000000 # max training steps 50 | max_samples: 25600 # audio length in training 51 | max_sentences: 20 # max batch size in training 52 | num_sanity_val_steps: -1 53 | max_valid_sentences: 1 54 | valid_infer_interval: 10000 55 | val_check_interval: 2000 56 | num_test_samples: 0 57 | num_valid_plots: 10 58 | 59 | 60 | ############# 61 | # Stage 1 of data processing 62 | ############# 63 | pre_align_cls: egs.datasets.audio.pre_align.PreAlign 64 | pre_align_args: 65 | nsample_per_mfa_group: 1000 66 | txt_processor: en 67 | use_tone: true # for ZH 68 | sox_resample: false 69 | sox_to_wav: false 70 | allow_no_txt: true 71 | trim_sil: false 72 | denoise: false 73 | 74 | 75 | ############# 76 | # Stage 2 of data processing 77 | ############# 78 | binarizer_cls: data_gen.tts.vocoder_binarizer.VocoderBinarizer 79 | binarization_args: 80 | with_wav: true 81 | with_spk_embed: false 82 | with_align: false 83 | with_word: false 84 | with_txt: false 85 | with_f0: false 86 | shuffle: false 87 | with_spk_id: true 88 | with_f0cwt: false 89 | with_linear: false 90 | trim_eos_bos: false 91 | reset_phone_dict: true 92 | reset_word_dict: true 93 | 94 | 95 | ########### 96 | # optimization 97 | ########### 98 | lr: 2e-4 # learning rate 99 | weight_decay: 0 100 | scheduler: rsqrt # rsqrt|none 101 | optimizer_adam_beta1: 0.9 102 | optimizer_adam_beta2: 0.98 103 | clip_grad_norm: 1 104 | clip_grad_value: 0 105 | 106 | ############# 107 | # Setting for this Pytorch framework 108 | ############# 109 | max_input_tokens: 1550 110 | frames_multiple: 1 111 | use_word_input: false 112 | vocoder: pwg 113 | vocoder_ckpt: '' 114 | vocoder_denoise_c: 0.0 115 | max_tokens: 30000 116 | max_valid_tokens: 60000 117 | test_ids: [ ] 118 | profile_infer: false 119 | out_wav_norm: false 120 | save_gt: true 121 | save_f0: false 122 | aux_context_window: 0 123 | test_input_dir: '' # 'wavs' # wav->wav infer 124 | test_mel_dir: '' # 'mels' # mel->wav infer 125 | use_wav: True # mel->wav infer 126 | pitch_extractor: parselmouth 127 | loud_norm: false 128 | endless_ds: true 129 | test_num: 100 130 | min_frames: 0 131 | max_frames: 1548 132 | ds_workers: 1 133 | gen_dir_name: '' 134 | accumulate_grad_batches: 1 135 | tb_log_interval: 100 136 | print_nan_grads: false 137 | work_dir: '' # experiment directory. 
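# Note: `N` (number of reverse diffusion steps) and `noise_schedule` in the Diffusion
# section above are left empty and are normally supplied at inference time, e.g.
# (following egs/demo_tts.py; <mel_dir> is a placeholder):
#   CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config modules/FastDiff/config/FastDiff.yaml \
#     --exp_name FastDiff --infer --hparams='test_mel_dir=<mel_dir>,use_wav=False,N=4'
# If neither is set, the task falls back to a 4-step schedule
# (see test_step in modules/FastDiff/task/FastDiff.py).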
138 | infer: false # infer 139 | amp: false 140 | debug: false 141 | save_codes: [] 142 | save_best: true 143 | num_ckpt_keep: 3 144 | sort_by_len: true 145 | load_ckpt: '' 146 | check_val_every_n_epoch: 10 147 | max_epochs: 1000 148 | eval_max_batches: -1 149 | resume_from_checkpoint: 0 150 | rename_tmux: true 151 | valid_monitor_key: 'val_loss' 152 | valid_monitor_mode: 'min' 153 | train_set_name: 'train' 154 | train_sets: '' 155 | valid_set_name: 'valid' 156 | test_set_name: 'test' 157 | seed: 1234 -------------------------------------------------------------------------------- /modules/FastDiff/module/FastDiff_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import logging 4 | from modules.FastDiff.module.modules import DiffusionDBlock, TimeAware_LVCBlock 5 | from modules.FastDiff.module.util import calc_diffusion_step_embedding 6 | 7 | def swish(x): 8 | return x * torch.sigmoid(x) 9 | 10 | class FastDiff(nn.Module): 11 | """FastDiff module.""" 12 | 13 | def __init__(self, 14 | audio_channels=1, 15 | inner_channels=32, 16 | cond_channels=80, 17 | upsample_ratios=[8, 8, 4], 18 | lvc_layers_each_block=4, 19 | lvc_kernel_size=3, 20 | kpnet_hidden_channels=64, 21 | kpnet_conv_size=3, 22 | dropout=0.0, 23 | diffusion_step_embed_dim_in=128, 24 | diffusion_step_embed_dim_mid=512, 25 | diffusion_step_embed_dim_out=512, 26 | use_weight_norm=True): 27 | super().__init__() 28 | 29 | self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in 30 | 31 | self.audio_channels = audio_channels 32 | self.cond_channels = cond_channels 33 | self.lvc_block_nums = len(upsample_ratios) 34 | self.first_audio_conv = nn.Conv1d(1, inner_channels, 35 | kernel_size=7, padding=(7 - 1) // 2, 36 | dilation=1, bias=True) 37 | 38 | # define residual blocks 39 | self.lvc_blocks = nn.ModuleList() 40 | self.downsample = nn.ModuleList() 41 | 42 | # the layer-specific fc for noise scale embedding 43 | self.fc_t = nn.ModuleList() 44 | self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid) 45 | self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out) 46 | 47 | cond_hop_length = 1 48 | for n in range(self.lvc_block_nums): 49 | cond_hop_length = cond_hop_length * upsample_ratios[n] 50 | lvcb = TimeAware_LVCBlock( 51 | in_channels=inner_channels, 52 | cond_channels=cond_channels, 53 | upsample_ratio=upsample_ratios[n], 54 | conv_layers=lvc_layers_each_block, 55 | conv_kernel_size=lvc_kernel_size, 56 | cond_hop_length=cond_hop_length, 57 | kpnet_hidden_channels=kpnet_hidden_channels, 58 | kpnet_conv_size=kpnet_conv_size, 59 | kpnet_dropout=dropout, 60 | noise_scale_embed_dim_out=diffusion_step_embed_dim_out 61 | ) 62 | self.lvc_blocks += [lvcb] 63 | self.downsample.append(DiffusionDBlock(inner_channels, inner_channels, upsample_ratios[self.lvc_block_nums-n-1])) 64 | 65 | 66 | # define output layers 67 | self.final_conv = nn.Sequential(nn.Conv1d(inner_channels, audio_channels, kernel_size=7, padding=(7 - 1) // 2, 68 | dilation=1, bias=True)) 69 | 70 | # apply weight norm 71 | if use_weight_norm: 72 | self.apply_weight_norm() 73 | 74 | def forward(self, data): 75 | """Calculate forward propagation. 76 | Args: 77 | x (Tensor): Input noise signal (B, 1, T). 78 | c (Tensor): Local conditioning auxiliary features (B, C ,T'). 
79 | Returns: 80 | Tensor: Output tensor (B, out_channels, T) 81 | """ 82 | audio, c, diffusion_steps = data 83 | 84 | # embed diffusion step t 85 | diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in) 86 | diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) 87 | diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) 88 | 89 | audio = self.first_audio_conv(audio) 90 | downsample = [] 91 | for down_layer in self.downsample: 92 | downsample.append(audio) 93 | audio = down_layer(audio) 94 | 95 | x = audio 96 | for n, audio_down in enumerate(reversed(downsample)): 97 | x = self.lvc_blocks[n]((x, audio_down, c, diffusion_step_embed)) 98 | 99 | # apply final layers 100 | x = self.final_conv(x) 101 | 102 | return x 103 | 104 | def remove_weight_norm(self): 105 | """Remove weight normalization module from all of the layers.""" 106 | def _remove_weight_norm(m): 107 | try: 108 | logging.debug(f"Weight norm is removed from {m}.") 109 | torch.nn.utils.remove_weight_norm(m) 110 | except ValueError: # this module didn't have weight norm 111 | return 112 | 113 | self.apply(_remove_weight_norm) 114 | 115 | def apply_weight_norm(self): 116 | """Apply weight normalization module from all of the layers.""" 117 | def _apply_weight_norm(m): 118 | if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): 119 | torch.nn.utils.weight_norm(m) 120 | logging.debug(f"Weight norm is applied to {m}.") 121 | 122 | self.apply(_apply_weight_norm) 123 | 124 | -------------------------------------------------------------------------------- /modules/FastDiff/module/WaveNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from modules.FastDiff.module.util import calc_noise_scale_embedding 7 | def swish(x): 8 | return x * torch.sigmoid(x) 9 | 10 | 11 | # dilated conv layer with kaiming_normal initialization 12 | # from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py 13 | class Conv(nn.Module): 14 | def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1): 15 | super(Conv, self).__init__() 16 | self.padding = dilation * (kernel_size - 1) // 2 17 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding) 18 | self.conv = nn.utils.weight_norm(self.conv) 19 | nn.init.kaiming_normal_(self.conv.weight) 20 | 21 | def forward(self, x): 22 | out = self.conv(x) 23 | return out 24 | 25 | 26 | # conv1x1 layer with zero initialization 27 | # from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed 28 | class ZeroConv1d(nn.Module): 29 | def __init__(self, in_channel, out_channel): 30 | super(ZeroConv1d, self).__init__() 31 | self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0) 32 | self.conv.weight.data.zero_() 33 | self.conv.bias.data.zero_() 34 | 35 | def forward(self, x): 36 | out = self.conv(x) 37 | return out 38 | 39 | 40 | # every residual block (named residual layer in paper) 41 | # contains one noncausal dilated conv 42 | class Residual_block(nn.Module): 43 | def __init__(self, res_channels, skip_channels, dilation, 44 | noise_scale_embed_dim_out, multiband=True): 45 | super(Residual_block, self).__init__() 46 | self.res_channels = res_channels 47 | 48 | # the layer-specific fc for noise scale embedding 49 | self.fc_t = nn.Linear(noise_scale_embed_dim_out, self.res_channels) 50 | 
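        # fc_t projects the shared noise-scale embedding to a per-layer bias of size
        # res_channels; forward() reshapes it to (B, res_channels, 1) and adds it to the
        # block input before the dilated convolution.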
51 | # dilated conv layer 52 | self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation) 53 | 54 | # add mel spectrogram upsampler and conditioner conv1x1 layer 55 | self.upsample_conv2d = torch.nn.ModuleList() 56 | if multiband is True: 57 | params = 8 58 | else: 59 | params = 16 60 | for s in [params, params]: ####### Very Important!!!!! ####### 61 | conv_trans2d = torch.nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s)) 62 | conv_trans2d = torch.nn.utils.weight_norm(conv_trans2d) 63 | torch.nn.init.kaiming_normal_(conv_trans2d.weight) 64 | self.upsample_conv2d.append(conv_trans2d) 65 | self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1) # 80 is mel bands 66 | 67 | # residual conv1x1 layer, connect to next residual layer 68 | self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1) 69 | self.res_conv = nn.utils.weight_norm(self.res_conv) 70 | nn.init.kaiming_normal_(self.res_conv.weight) 71 | 72 | # skip conv1x1 layer, add to all skip outputs through skip connections 73 | self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1) 74 | self.skip_conv = nn.utils.weight_norm(self.skip_conv) 75 | nn.init.kaiming_normal_(self.skip_conv.weight) 76 | 77 | def forward(self, input_data): 78 | x, mel_spec, noise_scale_embed = input_data 79 | h = x 80 | B, C, L = x.shape # B, res_channels, L 81 | assert C == self.res_channels 82 | 83 | # add in noise scale embedding 84 | part_t = self.fc_t(noise_scale_embed) 85 | part_t = part_t.view([B, self.res_channels, 1]) 86 | h += part_t 87 | 88 | # dilated conv layer 89 | h = self.dilated_conv_layer(h) 90 | 91 | # add mel spectrogram as (local) conditioner 92 | assert mel_spec is not None 93 | 94 | # Upsample spectrogram to size of audio 95 | mel_spec = torch.unsqueeze(mel_spec, dim=1) # (B, 1, 80, T') 96 | mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4) 97 | mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4) 98 | mel_spec = torch.squeeze(mel_spec, dim=1) 99 | 100 | assert(mel_spec.size(2) >= L) 101 | if mel_spec.size(2) > L: 102 | mel_spec = mel_spec[:, :, :L] 103 | 104 | mel_spec = self.mel_conv(mel_spec) 105 | h += mel_spec 106 | 107 | # gated-tanh nonlinearity 108 | out = torch.tanh(h[:,:self.res_channels,:]) * torch.sigmoid(h[:,self.res_channels:,:]) 109 | 110 | # residual and skip outputs 111 | res = self.res_conv(out) 112 | assert x.shape == res.shape 113 | skip = self.skip_conv(out) 114 | 115 | return (x + res) * math.sqrt(0.5), skip # normalize for training stability 116 | 117 | 118 | class Residual_group(nn.Module): 119 | def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle, 120 | noise_scale_embed_dim_in, 121 | noise_scale_embed_dim_mid, 122 | noise_scale_embed_dim_out, multiband): 123 | super(Residual_group, self).__init__() 124 | self.num_res_layers = num_res_layers 125 | self.noise_scale_embed_dim_in = noise_scale_embed_dim_in 126 | 127 | # the shared two fc layers for noise scale embedding 128 | self.fc_t1 = nn.Linear(noise_scale_embed_dim_in, noise_scale_embed_dim_mid) 129 | self.fc_t2 = nn.Linear(noise_scale_embed_dim_mid, noise_scale_embed_dim_out) 130 | 131 | # stack all residual blocks with dilations 1, 2, ... , 512, ... 
, 1, 2, ..., 512 132 | self.residual_blocks = nn.ModuleList() 133 | for n in range(self.num_res_layers): 134 | self.residual_blocks.append(Residual_block(res_channels, skip_channels, 135 | dilation=2 ** (n % dilation_cycle), 136 | noise_scale_embed_dim_out=noise_scale_embed_dim_out, multiband=multiband)) 137 | 138 | def forward(self, input_data): 139 | x, mel_spectrogram, noise_scales = input_data 140 | 141 | # embed noise scale 142 | noise_scale_embed = calc_noise_scale_embedding(noise_scales, self.noise_scale_embed_dim_in) 143 | noise_scale_embed = swish(self.fc_t1(noise_scale_embed)) 144 | noise_scale_embed = swish(self.fc_t2(noise_scale_embed)) 145 | 146 | # pass all residual layers 147 | h = x 148 | skip = 0 149 | for n in range(self.num_res_layers): 150 | h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, noise_scale_embed)) # use the output from last residual layer 151 | skip += skip_n # accumulate all skip outputs 152 | 153 | return skip * math.sqrt(1.0 / self.num_res_layers) # normalize for training stability 154 | 155 | 156 | class WaveNet_vocoder(nn.Module): 157 | def __init__(self, in_channels, res_channels, skip_channels, out_channels, 158 | num_res_layers, dilation_cycle, 159 | noise_scale_embed_dim_in, 160 | noise_scale_embed_dim_mid, 161 | noise_scale_embed_dim_out, multiband): 162 | super(WaveNet_vocoder, self).__init__() 163 | 164 | # initial conv1x1 with relu 165 | self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU()) 166 | 167 | # all residual layers 168 | self.residual_layer = Residual_group(res_channels=res_channels, 169 | skip_channels=skip_channels, 170 | num_res_layers=num_res_layers, 171 | dilation_cycle=dilation_cycle, 172 | noise_scale_embed_dim_in=noise_scale_embed_dim_in, 173 | noise_scale_embed_dim_mid=noise_scale_embed_dim_mid, 174 | noise_scale_embed_dim_out=noise_scale_embed_dim_out, multiband=multiband) 175 | 176 | # final conv1x1 -> relu -> zeroconv1x1 177 | self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1), 178 | nn.ReLU(), 179 | ZeroConv1d(skip_channels, out_channels)) 180 | 181 | def forward(self, input_data): 182 | audio, mel_spectrogram, noise_scales = input_data # b x band x T, b x 80 x T', b x 1 183 | x = audio 184 | x = self.init_conv(x) 185 | x = self.residual_layer((x, mel_spectrogram, noise_scales)) 186 | x = self.final_conv(x) 187 | 188 | return x 189 | 190 | -------------------------------------------------------------------------------- /modules/FastDiff/task/FastDiff.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import utils 5 | from modules.FastDiff.module.FastDiff_model import FastDiff 6 | from tasks.vocoder.vocoder_base import VocoderBaseTask 7 | from utils import audio 8 | from utils.hparams import hparams 9 | from modules.FastDiff.module.util import theta_timestep_loss, compute_hyperparams_given_schedule, sampling_given_noise_schedule 10 | 11 | 12 | class FastDiffTask(VocoderBaseTask): 13 | def __init__(self): 14 | super(FastDiffTask, self).__init__() 15 | 16 | def build_model(self): 17 | self.model = FastDiff(audio_channels=hparams['audio_channels'], 18 | inner_channels=hparams['inner_channels'], 19 | cond_channels=hparams['cond_channels'], 20 | upsample_ratios=hparams['upsample_ratios'], 21 | lvc_layers_each_block=hparams['lvc_layers_each_block'], 22 | lvc_kernel_size=hparams['lvc_kernel_size'], 23 | kpnet_hidden_channels=hparams['kpnet_hidden_channels'], 24 | 
kpnet_conv_size=hparams['kpnet_conv_size'], 25 | dropout=hparams['dropout'], 26 | diffusion_step_embed_dim_in=hparams['diffusion_step_embed_dim_in'], 27 | diffusion_step_embed_dim_mid=hparams['diffusion_step_embed_dim_mid'], 28 | diffusion_step_embed_dim_out=hparams['diffusion_step_embed_dim_out'], 29 | use_weight_norm=hparams['use_weight_norm']) 30 | utils.print_arch(self.model) 31 | 32 | # Init hyperparameters by linear schedule 33 | noise_schedule = torch.linspace(float(hparams["beta_0"]), float(hparams["beta_T"]), int(hparams["T"])).cuda() 34 | diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule) 35 | 36 | # map diffusion hyperparameters to gpu 37 | for key in diffusion_hyperparams: 38 | if key in ["beta", "alpha", "sigma"]: 39 | diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda() 40 | self.diffusion_hyperparams = diffusion_hyperparams 41 | 42 | return self.model 43 | 44 | def _training_step(self, sample, batch_idx, optimizer_idx): 45 | mels = sample['mels'] 46 | y = sample['wavs'] 47 | X = (mels, y) 48 | loss = theta_timestep_loss(self.model, X, self.diffusion_hyperparams) 49 | return loss, {'loss': loss} 50 | 51 | 52 | def validation_step(self, sample, batch_idx): 53 | mels = sample['mels'] 54 | y = sample['wavs'] 55 | X = (mels, y) 56 | loss = theta_timestep_loss(self.model, X, self.diffusion_hyperparams) 57 | return loss, {'loss': loss} 58 | 59 | 60 | def test_step(self, sample, batch_idx): 61 | mels = sample['mels'] 62 | y = sample['wavs'] 63 | loss_output = {} 64 | 65 | if hparams['noise_schedule'] != '': 66 | noise_schedule = hparams['noise_schedule'] 67 | if isinstance(noise_schedule, list): 68 | noise_schedule = torch.FloatTensor(noise_schedule).cuda() 69 | else: 70 | # Select Schedule 71 | try: 72 | reverse_step = int(hparams.get('N')) 73 | except: 74 | print('Please specify $N (the number of revere iterations) in config file. Now denoise with 4 iterations.') 75 | reverse_step = 4 76 | if reverse_step == 1000: 77 | noise_schedule = torch.linspace(0.000001, 0.01, 1000).cuda() 78 | elif reverse_step == 200: 79 | noise_schedule = torch.linspace(0.0001, 0.02, 200).cuda() 80 | 81 | # Below are schedules derived by Noise Predictor. 
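            # Each schedule entry is a noise variance (beta) for one reverse iteration;
            # fewer steps trade some quality for faster synthesis. A custom schedule can
            # also be given in the config instead (illustrative):
            #   noise_schedule: [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]
            # in which case the `hparams['noise_schedule'] != ''` branch above is used.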
82 | elif reverse_step == 8: 83 | noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513, 84 | 0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5] 85 | elif reverse_step == 6: 86 | noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984, 87 | 0.006634317338466644, 0.09357017278671265, 0.6000000238418579] 88 | elif reverse_step == 4: 89 | noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01] 90 | elif reverse_step == 3: 91 | noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01] 92 | else: 93 | raise NotImplementedError 94 | 95 | if isinstance(noise_schedule, list): 96 | noise_schedule = torch.FloatTensor(noise_schedule).cuda() 97 | 98 | audio_length = mels.shape[-1] * hparams["hop_size"] 99 | # generate using DDPM reverse process 100 | 101 | y_ = sampling_given_noise_schedule( 102 | self.model, (1, 1, audio_length), self.diffusion_hyperparams, noise_schedule, 103 | condition=mels, ddim=False, return_sequence=False) 104 | gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') 105 | os.makedirs(gen_dir, exist_ok=True) 106 | 107 | if len(y) == 0: 108 | # Inference from mel 109 | for idx, (wav_pred, item_name) in enumerate(zip(y_, sample["item_name"])): 110 | wav_pred = wav_pred / wav_pred.abs().max() 111 | audio.save_wav(wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav', 112 | hparams['audio_sample_rate']) 113 | else: 114 | for idx, (wav_pred, wav_gt, item_name) in enumerate(zip(y_, y, sample["item_name"])): 115 | wav_gt = wav_gt / wav_gt.abs().max() 116 | wav_pred = wav_pred / wav_pred.abs().max() 117 | audio.save_wav(wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav', hparams['audio_sample_rate']) 118 | audio.save_wav(wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav', hparams['audio_sample_rate']) 119 | return loss_output 120 | 121 | def build_optimizer(self, model): 122 | self.optimizer = optimizer = torch.optim.AdamW( 123 | self.model.parameters(), 124 | lr=float(hparams['lr']), weight_decay=float(hparams['weight_decay'])) 125 | return optimizer 126 | 127 | def compute_rtf(self, sample, generation_time, sample_rate=22050): 128 | """ 129 | Computes RTF for a given sample. 
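        RTF = generation_time * sample_rate / total_length, where total_length is the
        number of audio samples; RTF < 1 means faster than real time. Illustrative
        example: a 5 s clip at 22050 Hz is 110250 samples, so generating it in 0.5 s
        gives RTF = 0.5 * 22050 / 110250 = 0.1.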
130 | """ 131 | total_length = sample.shape[-1] 132 | return float(generation_time * sample_rate / total_length) -------------------------------------------------------------------------------- /modules/commons/gdl_loss.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/FastDiff/c954758364aa845f37642952c7a16ff3e92e436e/modules/commons/gdl_loss.py -------------------------------------------------------------------------------- /modules/parallel_wavegan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/FastDiff/c954758364aa845f37642952c7a16ff3e92e436e/modules/parallel_wavegan/__init__.py -------------------------------------------------------------------------------- /modules/parallel_wavegan/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .causal_conv import * # NOQA 2 | from .residual_block import * # NOQA 3 | from modules.parallel_wavegan.layers.residual_stack import * # NOQA 4 | from .upsample import * # NOQA 5 | -------------------------------------------------------------------------------- /modules/parallel_wavegan/layers/causal_conv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Causal convolusion layer modules.""" 7 | 8 | 9 | import torch 10 | 11 | 12 | class CausalConv1d(torch.nn.Module): 13 | """CausalConv1d module with customized initialization.""" 14 | 15 | def __init__(self, in_channels, out_channels, kernel_size, 16 | dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}): 17 | """Initialize CausalConv1d module.""" 18 | super(CausalConv1d, self).__init__() 19 | self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params) 20 | self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, 21 | dilation=dilation, bias=bias) 22 | 23 | def forward(self, x): 24 | """Calculate forward propagation. 25 | 26 | Args: 27 | x (Tensor): Input tensor (B, in_channels, T). 28 | 29 | Returns: 30 | Tensor: Output tensor (B, out_channels, T). 31 | 32 | """ 33 | return self.conv(self.pad(x))[:, :, :x.size(2)] 34 | 35 | 36 | class CausalConvTranspose1d(torch.nn.Module): 37 | """CausalConvTranspose1d module with customized initialization.""" 38 | 39 | def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True): 40 | """Initialize CausalConvTranspose1d module.""" 41 | super(CausalConvTranspose1d, self).__init__() 42 | self.deconv = torch.nn.ConvTranspose1d( 43 | in_channels, out_channels, kernel_size, stride, bias=bias) 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | """Calculate forward propagation. 48 | 49 | Args: 50 | x (Tensor): Input tensor (B, in_channels, T_in). 51 | 52 | Returns: 53 | Tensor: Output tensor (B, out_channels, T_out). 54 | 55 | """ 56 | return self.deconv(x)[:, :, :-self.stride] 57 | -------------------------------------------------------------------------------- /modules/parallel_wavegan/layers/residual_stack.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Residual stack module in MelGAN.""" 7 | 8 | import torch 9 | 10 | from . 
import CausalConv1d 11 | 12 | 13 | class ResidualStack(torch.nn.Module): 14 | """Residual stack module introduced in MelGAN.""" 15 | 16 | def __init__(self, 17 | kernel_size=3, 18 | channels=32, 19 | dilation=1, 20 | bias=True, 21 | nonlinear_activation="LeakyReLU", 22 | nonlinear_activation_params={"negative_slope": 0.2}, 23 | pad="ReflectionPad1d", 24 | pad_params={}, 25 | use_causal_conv=False, 26 | ): 27 | """Initialize ResidualStack module. 28 | 29 | Args: 30 | kernel_size (int): Kernel size of dilation convolution layer. 31 | channels (int): Number of channels of convolution layers. 32 | dilation (int): Dilation factor. 33 | bias (bool): Whether to add bias parameter in convolution layers. 34 | nonlinear_activation (str): Activation function module name. 35 | nonlinear_activation_params (dict): Hyperparameters for activation function. 36 | pad (str): Padding function module name before dilated convolution layer. 37 | pad_params (dict): Hyperparameters for padding function. 38 | use_causal_conv (bool): Whether to use causal convolution. 39 | 40 | """ 41 | super(ResidualStack, self).__init__() 42 | 43 | # defile residual stack part 44 | if not use_causal_conv: 45 | assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 46 | self.stack = torch.nn.Sequential( 47 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 48 | getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params), 49 | torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias), 50 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 51 | torch.nn.Conv1d(channels, channels, 1, bias=bias), 52 | ) 53 | else: 54 | self.stack = torch.nn.Sequential( 55 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 56 | CausalConv1d(channels, channels, kernel_size, dilation=dilation, 57 | bias=bias, pad=pad, pad_params=pad_params), 58 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 59 | torch.nn.Conv1d(channels, channels, 1, bias=bias), 60 | ) 61 | 62 | # defile extra layer for skip connection 63 | self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias) 64 | 65 | def forward(self, c): 66 | """Calculate forward propagation. 67 | 68 | Args: 69 | c (Tensor): Input tensor (B, channels, T). 70 | 71 | Returns: 72 | Tensor: Output tensor (B, chennels, T). 73 | 74 | """ 75 | return self.stack(c) + self.skip_layer(c) 76 | -------------------------------------------------------------------------------- /modules/parallel_wavegan/layers/upsample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Upsampling module. 4 | 5 | This code is modified from https://github.com/r9y9/wavenet_vocoder. 6 | 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from . import Conv1d 14 | 15 | 16 | class Stretch2d(torch.nn.Module): 17 | """Stretch2d module.""" 18 | 19 | def __init__(self, x_scale, y_scale, mode="nearest"): 20 | """Initialize Stretch2d module. 21 | 22 | Args: 23 | x_scale (int): X scaling factor (Time axis in spectrogram). 24 | y_scale (int): Y scaling factor (Frequency axis in spectrogram). 25 | mode (str): Interpolation mode. 26 | 27 | """ 28 | super(Stretch2d, self).__init__() 29 | self.x_scale = x_scale 30 | self.y_scale = y_scale 31 | self.mode = mode 32 | 33 | def forward(self, x): 34 | """Calculate forward propagation. 35 | 36 | Args: 37 | x (Tensor): Input tensor (B, C, F, T). 
38 | 39 | Returns: 40 | Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), 41 | 42 | """ 43 | return F.interpolate( 44 | x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) 45 | 46 | 47 | class Conv2d(torch.nn.Conv2d): 48 | """Conv2d module with customized initialization.""" 49 | 50 | def __init__(self, *args, **kwargs): 51 | """Initialize Conv2d module.""" 52 | super(Conv2d, self).__init__(*args, **kwargs) 53 | 54 | def reset_parameters(self): 55 | """Reset parameters.""" 56 | self.weight.data.fill_(1. / np.prod(self.kernel_size)) 57 | if self.bias is not None: 58 | torch.nn.init.constant_(self.bias, 0.0) 59 | 60 | 61 | class UpsampleNetwork(torch.nn.Module): 62 | """Upsampling network module.""" 63 | 64 | def __init__(self, 65 | upsample_scales, 66 | nonlinear_activation=None, 67 | nonlinear_activation_params={}, 68 | interpolate_mode="nearest", 69 | freq_axis_kernel_size=1, 70 | use_causal_conv=False, 71 | ): 72 | """Initialize upsampling network module. 73 | 74 | Args: 75 | upsample_scales (list): List of upsampling scales. 76 | nonlinear_activation (str): Activation function name. 77 | nonlinear_activation_params (dict): Arguments for specified activation function. 78 | interpolate_mode (str): Interpolation mode. 79 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 80 | 81 | """ 82 | super(UpsampleNetwork, self).__init__() 83 | self.use_causal_conv = use_causal_conv 84 | self.up_layers = torch.nn.ModuleList() 85 | for scale in upsample_scales: 86 | # interpolation layer 87 | stretch = Stretch2d(scale, 1, interpolate_mode) 88 | self.up_layers += [stretch] 89 | 90 | # conv layer 91 | assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." 92 | freq_axis_padding = (freq_axis_kernel_size - 1) // 2 93 | kernel_size = (freq_axis_kernel_size, scale * 2 + 1) 94 | if use_causal_conv: 95 | padding = (freq_axis_padding, scale * 2) 96 | else: 97 | padding = (freq_axis_padding, scale) 98 | conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) 99 | self.up_layers += [conv] 100 | 101 | # nonlinear 102 | if nonlinear_activation is not None: 103 | nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) 104 | self.up_layers += [nonlinear] 105 | 106 | def forward(self, c): 107 | """Calculate forward propagation. 108 | 109 | Args: 110 | c : Input tensor (B, C, T). 111 | 112 | Returns: 113 | Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales). 114 | 115 | """ 116 | c = c.unsqueeze(1) # (B, 1, C, T) 117 | for f in self.up_layers: 118 | if self.use_causal_conv and isinstance(f, Conv2d): 119 | c = f(c)[..., :c.size(-1)] 120 | else: 121 | c = f(c) 122 | return c.squeeze(1) # (B, C, T') 123 | 124 | 125 | class ConvInUpsampleNetwork(torch.nn.Module): 126 | """Convolution + upsampling network module.""" 127 | 128 | def __init__(self, 129 | upsample_scales, 130 | nonlinear_activation=None, 131 | nonlinear_activation_params={}, 132 | interpolate_mode="nearest", 133 | freq_axis_kernel_size=1, 134 | aux_channels=80, 135 | aux_context_window=0, 136 | use_causal_conv=False 137 | ): 138 | """Initialize convolution + upsampling network module. 139 | 140 | Args: 141 | upsample_scales (list): List of upsampling scales. 142 | nonlinear_activation (str): Activation function name. 143 | nonlinear_activation_params (dict): Arguments for specified activation function. 144 | mode (str): Interpolation mode. 
145 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 146 | aux_channels (int): Number of channels of pre-convolutional layer. 147 | aux_context_window (int): Context window size of the pre-convolutional layer. 148 | use_causal_conv (bool): Whether to use causal structure. 149 | 150 | """ 151 | super(ConvInUpsampleNetwork, self).__init__() 152 | self.aux_context_window = aux_context_window 153 | self.use_causal_conv = use_causal_conv and aux_context_window > 0 154 | # To capture wide-context information in conditional features 155 | kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 156 | # NOTE(kan-bayashi): Here do not use padding because the input is already padded 157 | self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) 158 | self.upsample = UpsampleNetwork( 159 | upsample_scales=upsample_scales, 160 | nonlinear_activation=nonlinear_activation, 161 | nonlinear_activation_params=nonlinear_activation_params, 162 | interpolate_mode=interpolate_mode, 163 | freq_axis_kernel_size=freq_axis_kernel_size, 164 | use_causal_conv=use_causal_conv, 165 | ) 166 | 167 | def forward(self, c): 168 | """Calculate forward propagation. 169 | 170 | Args: 171 | c : Input tensor (B, C, T'). 172 | 173 | Returns: 174 | Tensor: Upsampled tensor (B, C, T), 175 | where T = (T' - aux_context_window * 2) * prod(upsample_scales). 176 | 177 | Note: 178 | The length of inputs considers the context window size. 179 | 180 | """ 181 | c_ = self.conv_in(c) 182 | c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ 183 | return self.upsample(c) 184 | -------------------------------------------------------------------------------- /modules/parallel_wavegan/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .parallel_wavegan import * # NOQA 2 | -------------------------------------------------------------------------------- /modules/parallel_wavegan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * # NOQA 2 | -------------------------------------------------------------------------------- /modules/parallel_wavegan/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Utility functions.""" 7 | 8 | import fnmatch 9 | import logging 10 | import os 11 | import sys 12 | try: 13 | import h5py 14 | except: 15 | pass 16 | import numpy as np 17 | 18 | 19 | def find_files(root_dir, query="*.wav", include_root_dir=True): 20 | """Find files recursively. 21 | 22 | Args: 23 | root_dir (str): Root root_dir to find. 24 | query (str): Query to find. 25 | include_root_dir (bool): If False, root_dir name is not included. 26 | 27 | Returns: 28 | list: List of found filenames. 29 | 30 | """ 31 | files = [] 32 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 33 | for filename in fnmatch.filter(filenames, query): 34 | files.append(os.path.join(root, filename)) 35 | if not include_root_dir: 36 | files = [file_.replace(root_dir + "/", "") for file_ in files] 37 | 38 | return files 39 | 40 | 41 | def read_hdf5(hdf5_name, hdf5_path): 42 | """Read hdf5 dataset. 43 | 44 | Args: 45 | hdf5_name (str): Filename of hdf5 file. 46 | hdf5_path (str): Dataset name in hdf5 file. 
47 | 48 | Return: 49 | any: Dataset values. 50 | 51 | """ 52 | if not os.path.exists(hdf5_name): 53 | logging.error(f"There is no such a hdf5 file ({hdf5_name}).") 54 | sys.exit(1) 55 | 56 | hdf5_file = h5py.File(hdf5_name, "r") 57 | 58 | if hdf5_path not in hdf5_file: 59 | logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})") 60 | sys.exit(1) 61 | 62 | hdf5_data = hdf5_file[hdf5_path][()] 63 | hdf5_file.close() 64 | 65 | return hdf5_data 66 | 67 | 68 | def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True): 69 | """Write dataset to hdf5. 70 | 71 | Args: 72 | hdf5_name (str): Hdf5 dataset filename. 73 | hdf5_path (str): Dataset path in hdf5. 74 | write_data (ndarray): Data to write. 75 | is_overwrite (bool): Whether to overwrite dataset. 76 | 77 | """ 78 | # convert to numpy array 79 | write_data = np.array(write_data) 80 | 81 | # check folder existence 82 | folder_name, _ = os.path.split(hdf5_name) 83 | if not os.path.exists(folder_name) and len(folder_name) != 0: 84 | os.makedirs(folder_name) 85 | 86 | # check hdf5 existence 87 | if os.path.exists(hdf5_name): 88 | # if already exists, open with r+ mode 89 | hdf5_file = h5py.File(hdf5_name, "r+") 90 | # check dataset existence 91 | if hdf5_path in hdf5_file: 92 | if is_overwrite: 93 | logging.warning("Dataset in hdf5 file already exists. " 94 | "recreate dataset in hdf5.") 95 | hdf5_file.__delitem__(hdf5_path) 96 | else: 97 | logging.error("Dataset in hdf5 file already exists. " 98 | "if you want to overwrite, please set is_overwrite = True.") 99 | hdf5_file.close() 100 | sys.exit(1) 101 | else: 102 | # if not exists, open with w mode 103 | hdf5_file = h5py.File(hdf5_name, "w") 104 | 105 | # write data to hdf5 106 | hdf5_file.create_dataset(hdf5_path, data=write_data) 107 | hdf5_file.flush() 108 | hdf5_file.close() 109 | 110 | 111 | class HDF5ScpLoader(object): 112 | """Loader class for a fests.scp file of hdf5 file. 113 | 114 | Examples: 115 | key1 /some/path/a.h5:feats 116 | key2 /some/path/b.h5:feats 117 | key3 /some/path/c.h5:feats 118 | key4 /some/path/d.h5:feats 119 | ... 120 | >>> loader = HDF5ScpLoader("hdf5.scp") 121 | >>> array = loader["key1"] 122 | 123 | key1 /some/path/a.h5 124 | key2 /some/path/b.h5 125 | key3 /some/path/c.h5 126 | key4 /some/path/d.h5 127 | ... 128 | >>> loader = HDF5ScpLoader("hdf5.scp", "feats") 129 | >>> array = loader["key1"] 130 | 131 | """ 132 | 133 | def __init__(self, feats_scp, default_hdf5_path="feats"): 134 | """Initialize HDF5 scp loader. 135 | 136 | Args: 137 | feats_scp (str): Kaldi-style feats.scp file with hdf5 format. 138 | default_hdf5_path (str): Path in hdf5 file. If the scp contain the info, not used. 
139 | 140 | """ 141 | self.default_hdf5_path = default_hdf5_path 142 | with open(feats_scp) as f: 143 | lines = [line.replace("\n", "") for line in f.readlines()] 144 | self.data = {} 145 | for line in lines: 146 | key, value = line.split() 147 | self.data[key] = value 148 | 149 | def get_path(self, key): 150 | """Get hdf5 file path for a given key.""" 151 | return self.data[key] 152 | 153 | def __getitem__(self, key): 154 | """Get ndarray for a given key.""" 155 | p = self.data[key] 156 | if ":" in p: 157 | return read_hdf5(*p.split(":")) 158 | else: 159 | return read_hdf5(p, self.default_hdf5_path) 160 | 161 | def __len__(self): 162 | """Return the length of the scp file.""" 163 | return len(self.data) 164 | 165 | def __iter__(self): 166 | """Return the iterator of the scp file.""" 167 | return iter(self.data) 168 | 169 | def keys(self): 170 | """Return the keys of the scp file.""" 171 | return self.data.keys() 172 | -------------------------------------------------------------------------------- /modules/wavenet_vocoder/conv.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class Conv1d(nn.Conv1d): 8 | """Extended nn.Conv1d for incremental dilated convolutions 9 | """ 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.clear_buffer() 14 | self._linearized_weight = None 15 | self.register_backward_hook(self._clear_linearized_weight) 16 | 17 | def incremental_forward(self, input): 18 | # input: (B, T, C) 19 | if self.training: 20 | raise RuntimeError('incremental_forward only supports eval mode') 21 | 22 | # run forward pre hooks (e.g., weight norm) 23 | for hook in self._forward_pre_hooks.values(): 24 | hook(self, input) 25 | 26 | # reshape weight 27 | weight = self._get_linearized_weight() 28 | kw = self.kernel_size[0] 29 | dilation = self.dilation[0] 30 | 31 | bsz = input.size(0) # input: bsz x len x dim 32 | if kw > 1: 33 | input = input.data 34 | if self.input_buffer is None: 35 | self.input_buffer = input.new(bsz, kw + (kw - 1) * (dilation - 1), input.size(2)) 36 | self.input_buffer.zero_() 37 | else: 38 | # shift buffer 39 | self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone() 40 | # append next input 41 | self.input_buffer[:, -1, :] = input[:, -1, :] 42 | input = self.input_buffer 43 | if dilation > 1: 44 | input = input[:, 0::dilation, :].contiguous() 45 | output = F.linear(input.view(bsz, -1), weight, self.bias) 46 | return output.view(bsz, 1, -1) 47 | 48 | def clear_buffer(self): 49 | self.input_buffer = None 50 | 51 | def _get_linearized_weight(self): 52 | if self._linearized_weight is None: 53 | kw = self.kernel_size[0] 54 | # nn.Conv1d 55 | if self.weight.size() == (self.out_channels, self.in_channels, kw): 56 | weight = self.weight.transpose(1, 2).contiguous() 57 | else: 58 | # fairseq.modules.conv_tbc.ConvTBC 59 | weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() 60 | assert weight.size() == (self.out_channels, kw, self.in_channels) 61 | self._linearized_weight = weight.view(self.out_channels, -1) 62 | return self._linearized_weight 63 | 64 | def _clear_linearized_weight(self, *args): 65 | self._linearized_weight = None 66 | -------------------------------------------------------------------------------- /modules/wavenet_vocoder/modules.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from 
__future__ import with_statement, print_function, absolute_import 3 | 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | from . import conv 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | 13 | def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): 14 | m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs) 15 | nn.init.kaiming_normal_(m.weight, nonlinearity="relu") 16 | if m.bias is not None: 17 | nn.init.constant_(m.bias, 0) 18 | return nn.utils.weight_norm(m) 19 | 20 | 21 | def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01): 22 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 23 | m.weight.data.normal_(0, std) 24 | return m 25 | 26 | 27 | def ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs): 28 | freq_axis_kernel_size = kernel_size[0] 29 | m = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs) 30 | m.weight.data.fill_(1.0 / freq_axis_kernel_size) 31 | m.bias.data.zero_() 32 | return nn.utils.weight_norm(m) 33 | 34 | 35 | def Conv1d1x1(in_channels, out_channels, bias=True): 36 | """1-by-1 convolution layer 37 | """ 38 | return Conv1d(in_channels, out_channels, kernel_size=1, padding=0, 39 | dilation=1, bias=bias) 40 | 41 | 42 | def _conv1x1_forward(conv, x, is_incremental): 43 | """Conv1x1 forward 44 | """ 45 | if is_incremental: 46 | x = conv.incremental_forward(x) 47 | else: 48 | x = conv(x) 49 | return x 50 | 51 | 52 | class ResidualConv1dGLU(nn.Module): 53 | """Residual dilated conv1d + Gated linear unit 54 | 55 | Args: 56 | residual_channels (int): Residual input / output channels 57 | gate_channels (int): Gated activation channels. 58 | kernel_size (int): Kernel size of convolution layers. 59 | skip_out_channels (int): Skip connection channels. If None, set to same 60 | as ``residual_channels``. 61 | cin_channels (int): Local conditioning channels. If negative value is 62 | set, local conditioning is disabled. 63 | gin_channels (int): Global conditioning channels. If negative value is 64 | set, global conditioning is disabled. 65 | dropout (float): Dropout probability. 66 | padding (int): Padding for convolution layers. If None, proper padding 67 | is computed depends on dilation and kernel_size. 68 | dilation (int): Dilation factor. 
69 | """ 70 | 71 | def __init__(self, residual_channels, gate_channels, kernel_size, 72 | skip_out_channels=None, 73 | cin_channels=-1, gin_channels=-1, 74 | dropout=1 - 0.95, padding=None, dilation=1, causal=True, 75 | bias=True, *args, **kwargs): 76 | super(ResidualConv1dGLU, self).__init__() 77 | self.dropout = dropout 78 | if skip_out_channels is None: 79 | skip_out_channels = residual_channels 80 | if padding is None: 81 | # no future time stamps available 82 | if causal: 83 | padding = (kernel_size - 1) * dilation 84 | else: 85 | padding = (kernel_size - 1) // 2 * dilation 86 | self.causal = causal 87 | 88 | self.conv = Conv1d(residual_channels, gate_channels, kernel_size, 89 | padding=padding, dilation=dilation, 90 | bias=bias, *args, **kwargs) 91 | 92 | # local conditioning 93 | if cin_channels > 0: 94 | self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False) 95 | else: 96 | self.conv1x1c = None 97 | 98 | # global conditioning 99 | if gin_channels > 0: 100 | self.conv1x1g = Conv1d1x1(gin_channels, gate_channels, bias=False) 101 | else: 102 | self.conv1x1g = None 103 | 104 | # conv output is split into two groups 105 | gate_out_channels = gate_channels // 2 106 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) 107 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias) 108 | 109 | def forward(self, x, c=None, g=None): 110 | return self._forward(x, c, g, False) 111 | 112 | def incremental_forward(self, x, c=None, g=None): 113 | return self._forward(x, c, g, True) 114 | 115 | def _forward(self, x, c, g, is_incremental): 116 | """Forward 117 | 118 | Args: 119 | x (Tensor): B x C x T 120 | c (Tensor): B x C x T, Local conditioning features 121 | g (Tensor): B x C x T, Expanded global conditioning features 122 | is_incremental (Bool) : Whether incremental mode or not 123 | 124 | Returns: 125 | Tensor: output 126 | """ 127 | residual = x 128 | x = F.dropout(x, p=self.dropout, training=self.training) 129 | if is_incremental: 130 | splitdim = -1 131 | x = self.conv.incremental_forward(x) 132 | else: 133 | splitdim = 1 134 | x = self.conv(x) 135 | # remove future time steps 136 | x = x[:, :, :residual.size(-1)] if self.causal else x 137 | 138 | a, b = x.split(x.size(splitdim) // 2, dim=splitdim) 139 | 140 | # local conditioning 141 | if c is not None: 142 | assert self.conv1x1c is not None 143 | c = _conv1x1_forward(self.conv1x1c, c, is_incremental) 144 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) 145 | a, b = a + ca, b + cb 146 | 147 | # global conditioning 148 | if g is not None: 149 | assert self.conv1x1g is not None 150 | g = _conv1x1_forward(self.conv1x1g, g, is_incremental) 151 | ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) 152 | a, b = a + ga, b + gb 153 | 154 | x = torch.tanh(a) * torch.sigmoid(b) 155 | 156 | # For skip connection 157 | s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental) 158 | 159 | # For residual connection 160 | x = _conv1x1_forward(self.conv1x1_out, x, is_incremental) 161 | 162 | x = (x + residual) * math.sqrt(0.5) 163 | return x, s 164 | 165 | def clear_buffer(self): 166 | for c in [self.conv, self.conv1x1_out, self.conv1x1_skip, 167 | self.conv1x1c, self.conv1x1g]: 168 | if c is not None: 169 | c.clear_buffer() 170 | -------------------------------------------------------------------------------- /modules/wavenet_vocoder/upsample.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import 
with_statement, print_function, absolute_import 3 | 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | 12 | class Stretch2d(nn.Module): 13 | def __init__(self, x_scale, y_scale, mode="nearest"): 14 | super(Stretch2d, self).__init__() 15 | self.x_scale = x_scale 16 | self.y_scale = y_scale 17 | self.mode = mode 18 | 19 | def forward(self, x): 20 | return F.interpolate( 21 | x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) 22 | 23 | 24 | def _get_activation(upsample_activation): 25 | nonlinear = getattr(nn, upsample_activation) 26 | return nonlinear 27 | 28 | 29 | class UpsampleNetwork(nn.Module): 30 | def __init__(self, upsample_scales, upsample_activation="none", 31 | upsample_activation_params={}, mode="nearest", 32 | freq_axis_kernel_size=1, cin_pad=0, cin_channels=80): 33 | super(UpsampleNetwork, self).__init__() 34 | self.up_layers = nn.ModuleList() 35 | total_scale = np.prod(upsample_scales) 36 | self.indent = cin_pad * total_scale 37 | for scale in upsample_scales: 38 | freq_axis_padding = (freq_axis_kernel_size - 1) // 2 39 | k_size = (freq_axis_kernel_size, scale * 2 + 1) 40 | padding = (freq_axis_padding, scale) 41 | stretch = Stretch2d(scale, 1, mode) 42 | conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) 43 | conv.weight.data.fill_(1. / np.prod(k_size)) 44 | conv = nn.utils.weight_norm(conv) 45 | self.up_layers.append(stretch) 46 | self.up_layers.append(conv) 47 | if upsample_activation != "none": 48 | nonlinear = _get_activation(upsample_activation) 49 | self.up_layers.append(nonlinear(**upsample_activation_params)) 50 | 51 | def forward(self, c): 52 | """ 53 | Args: 54 | c : B x C x T 55 | """ 56 | 57 | # B x 1 x C x T 58 | c = c.unsqueeze(1) 59 | for f in self.up_layers: 60 | c = f(c) 61 | # B x C x T 62 | c = c.squeeze(1) 63 | 64 | if self.indent > 0: 65 | c = c[:, :, self.indent:-self.indent] 66 | return c 67 | 68 | 69 | class ConvInUpsampleNetwork(nn.Module): 70 | def __init__(self, upsample_scales, upsample_activation="none", 71 | upsample_activation_params={}, mode="nearest", 72 | freq_axis_kernel_size=1, cin_pad=0, 73 | cin_channels=80): 74 | super(ConvInUpsampleNetwork, self).__init__() 75 | # To capture wide-context information in conditional features 76 | # meaningless if cin_pad == 0 77 | ks = 2 * cin_pad + 1 78 | self.conv_in = nn.Conv1d(cin_channels, cin_channels, kernel_size=ks, bias=False) 79 | self.upsample = UpsampleNetwork( 80 | upsample_scales, upsample_activation, upsample_activation_params, 81 | mode, freq_axis_kernel_size, cin_pad=0, cin_channels=cin_channels) 82 | 83 | def forward(self, c): 84 | c_up = self.upsample(self.conv_in(c)) 85 | return c_up 86 | -------------------------------------------------------------------------------- /modules/wavenet_vocoder/util.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | 5 | def _assert_valid_input_type(s): 6 | assert s == "mulaw-quantize" or s == "mulaw" or s == "raw" 7 | 8 | 9 | def is_mulaw_quantize(s): 10 | _assert_valid_input_type(s) 11 | return s == "mulaw-quantize" 12 | 13 | 14 | def is_mulaw(s): 15 | _assert_valid_input_type(s) 16 | return s == "mulaw" 17 | 18 | 19 | def is_raw(s): 20 | _assert_valid_input_type(s) 21 | return s == "raw" 22 | 23 | 24 | def is_scalar_input(s): 25 | return is_raw(s) or is_mulaw(s) 26 | 
-------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # FastDiff: A Fast Conditional Diffusion Model for High-Quality Speech Synthesis 2 | 3 |
drawing
4 | 5 | 6 | #### Rongjie Huang, Max W. Y. Lam, Jun Wang, Dan Su, Dong Yu, Yi Ren, Zhou Zhao 7 | 8 | PyTorch implementation of [FastDiff (IJCAI'22)](https://arxiv.org/abs/2204.09934): a conditional diffusion probabilistic model capable of generating high-fidelity speech efficiently. 9 | 10 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2204.09934) 11 | [![GitHub Stars](https://img.shields.io/github/stars/Rongjiehuang/FastDiff?style=social)](https://github.com/Rongjiehuang/FastDiff) 12 | ![visitors](https://visitor-badge.glitch.me/badge?page_id=Rongjiehuang/FastDiff) 13 | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/Rongjiehuang/ProDiff) 14 | 15 | We provide our implementation and pretrained models as open source in this repository. 16 | 17 | Visit our [demo page](https://fastdiff.github.io/) for audio samples. 18 | 19 | Our follow-up work might also interest you: [ProDiff (ACM Multimedia'22)](https://arxiv.org/abs/2207.06389) on [GitHub](https://github.com/Rongjiehuang/ProDiff). 20 | 21 | ## News 22 | - April 22, 2022: **FastDiff** was accepted by IJCAI 2022. 23 | - June 21, 2022: The LJSpeech checkpoint and demo code are provided. 24 | - August 12, 2022: The VCTK/LibriTTS checkpoints are provided. 25 | - August 25, 2022: **FastDiff (Tacotron)** is provided. 26 | - September 2022: We released our follow-up work [ProDiff (ACM Multimedia'22)](https://arxiv.org/abs/2207.06389) on [GitHub](https://github.com/Rongjiehuang/ProDiff), where we further optimized the speed-and-quality trade-off. 27 | 28 | # Quick Start 29 | We provide an example of how you can generate high-fidelity samples using FastDiff. 30 | 31 | To try FastDiff on your own dataset, simply clone this repo to a local machine with an NVIDIA GPU (CUDA and cuDNN installed) and follow the instructions below. 32 | 33 | ## Supported Datasets and Pretrained Models 34 | 35 | You can also use the pretrained models we provide [here](https://huggingface.co/Rongjiehuang/FastDiff). 36 | Details of each folder are as follows: 37 | 38 | | Dataset | Config | 39 | |--------------------|--------------------------------------------------| 40 | | LJSpeech | `modules/FastDiff/config/FastDiff.yaml` | 41 | | LibriTTS | `modules/FastDiff/config/FastDiff_libritts.yaml` | 42 | | VCTK | `modules/FastDiff/config/FastDiff_vctk.yaml` | 43 | | LJSpeech (Tacotron) | `modules/FastDiff/config/FastDiff_tacotron.yaml` | 44 | 45 | More supported datasets are coming soon. 46 | 47 | Put the checkpoints in `checkpoints/$your_experiment_name/model_ckpt_steps_*.ckpt`. 48 | 49 | ## Dependencies 50 | See requirements in `requirements.txt`: 51 | - [pytorch](https://github.com/pytorch/pytorch) 52 | - [librosa](https://github.com/librosa/librosa) 53 | - [NATSpeech](https://github.com/NATSpeech/NATSpeech) 54 | 55 | ## Multi-GPU 56 | By default, this implementation uses as many GPUs in parallel as returned by `torch.cuda.device_count()`. 57 | You can specify which GPUs to use by setting the `CUDA_VISIBLE_DEVICES` environment variable before running the training module.
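For example, a minimal sketch of restricting training to the first two GPUs (the config path and experiment name below are placeholders taken from the table above — substitute your own):

```bash
# Restrict FastDiff training to GPUs 0 and 1; the remaining flags follow the training command shown later in this README.
CUDA_VISIBLE_DEVICES=0,1 python tasks/run.py --config modules/FastDiff/config/FastDiff.yaml --exp_name FastDiff --reset
```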
58 | 59 | ## Inference for text-to-speech synthesis 60 | 61 | ### Using ProDiff 62 | We provide a more efficient and stable pipeline on [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/Rongjiehuang/ProDiff) and [GitHub](https://github.com/Rongjiehuang/ProDiff). 63 | 64 | ### Using Tacotron 65 | Download the LJSpeech checkpoint for neural vocoding of Tacotron outputs [here](https://zjueducn-my.sharepoint.com/:f:/g/personal/rongjiehuang_zju_edu_cn/Epia7La6O7FHsKPTHZXZpoMBF7PoDcjWeKgC-7jtpVkCOQ?e=b8vPiA). 66 | We provide a demo in `egs/demo_tacotron.ipynb`. 67 | 68 | ### Using PortaSpeech, DiffSpeech, or FastSpeech 2 69 | 70 | 1. Download the LJSpeech checkpoint and put it in `checkpoints/FastDiff/model_ckpt_steps_*.ckpt`. 71 | 2. Specify the input `$text` and an integer index `$model_index` to choose the TTS model: `0` (PortaSpeech, Ren et al.), `1` (FastSpeech 2, Ren et al.), or `2` (DiffSpeech, Liu et al.). 72 | 3. Set `N` for reverse sampling, which is a trade-off between quality and speed. 73 | 4. Run the following command. 74 | ```bash 75 | CUDA_VISIBLE_DEVICES=$GPU python egs/demo_tts.py --N $N --text $text --model $model_index 76 | ``` 77 | Generated wav files are saved in `checkpoints/FastDiff/` by default.
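As a concrete illustration of step 4 (the sentence, step count, and model index below are example values only, not recommended settings):

```bash
# PortaSpeech (--model 0) + FastDiff vocoding with 4 reverse diffusion steps.
CUDA_VISIBLE_DEVICES=0 python egs/demo_tts.py --N 4 --text "the quick brown fox jumps over the lazy dog" --model 0
```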
78 | Note: For better quality, it is recommended to fine-tune the FastDiff model. 79 | 80 | ## Inference from wav files 81 | 1. Create a `wavs` directory and copy your wav files into it. 82 | 2. Set `N` for reverse sampling, which is a trade-off between quality and speed. 83 | 3. Run the following command. 84 | ```bash 85 | CUDA_VISIBLE_DEVICES=$GPU python tasks/run.py --config $path/to/config --exp_name $your_experiment_name --infer --hparams='test_input_dir=wavs,N=$N' 86 | ``` 87 | 88 | Generated wav files are saved in `checkpoints/$your_experiment_name/` by default.
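For reference, the task code ships pre-derived noise schedules for several small step counts (e.g. `N` of 3, 4, 6, or 8), so a concrete call might look like the sketch below (the config path and experiment name are placeholders):

```bash
# Vocode every wav in ./wavs with the pre-derived 4-step schedule.
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config modules/FastDiff/config/FastDiff.yaml --exp_name FastDiff --infer --hparams='test_input_dir=wavs,N=4'
```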
89 | 90 | ## Inference for end-to-end speech synthesis 91 | 1. Create a `mels` directory and copy the generated mel-spectrogram files into it.
92 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 93 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts), and so forth. 94 | 2. Set `N` for reverse sampling, which is a trade-off between quality and speed. 95 | 3. Run the following command. 96 | ```bash 97 | CUDA_VISIBLE_DEVICES=$GPU python tasks/run.py --config $path/to/config --exp_name $your_experiment_name --infer --hparams='test_mel_dir=mels,use_wav=False,N=$N' 98 | ``` 99 | Generated wav files are saved in `checkpoints/$your_experiment_name/` by default.
100 | 101 | Note: If you find the output wav noisy, it is likely caused by a mismatch in mel-spectrogram preprocessing between the acoustic model and the vocoder. 102 | 103 | # Train your own model 104 | 105 | ### Data Preparation and Configuration ## 106 | 1. Set `raw_data_dir`, `processed_data_dir`, and `binary_data_dir` in the config file. For a custom dataset, specify the audio preprocessing configuration in `modules/FastDiff/config/base.yaml`. 107 | 2. Download the dataset to `raw_data_dir`. Note: the dataset structure needs to follow `egs/datasets/audio/*/pre_align.py`, or you can rewrite `pre_align.py` for your dataset. 108 | 3. Preprocess the dataset: 109 | ```bash 110 | # Preprocess step: unify the file structure. 111 | python data_gen/tts/bin/pre_align.py --config $path/to/config 112 | # Binarization step: binarize data for fast IO. 113 | CUDA_VISIBLE_DEVICES=$GPU python data_gen/tts/bin/binarize.py --config $path/to/config 114 | ``` 115 | 116 | We also provide our processed LJSpeech dataset [here](https://zjueducn-my.sharepoint.com/:f:/g/personal/rongjiehuang_zju_edu_cn/Eo7r83WZPK1GmlwvFhhIKeQBABZpYW3ec9c8WZoUV5HhbA?e=9QoWnf). 117 | 118 | ### Training the Refinement Network 119 | ```bash 120 | CUDA_VISIBLE_DEVICES=$GPU python tasks/run.py --config $path/to/config --exp_name $your_experiment_name --reset 121 | ``` 122 | 123 | ### Training the Noise Predictor Network (Optional) 124 | Refer to [Bilateral Denoising Diffusion Models (BDDMs)](https://github.com/tencent-ailab/bddm). 125 | 126 | ### Noise Scheduling (Optional) 127 | At this stage, you can use our pre-derived noise schedules, or refer to [Bilateral Denoising Diffusion Models (BDDMs)](https://github.com/tencent-ailab/bddm) to derive your own. 128 | 129 | ### Inference 130 | 131 | ```bash 132 | CUDA_VISIBLE_DEVICES=$GPU python tasks/run.py --config $path/to/config --exp_name $your_experiment_name --infer 133 | ``` 134 | 135 | 136 | ## Acknowledgements 137 | This implementation uses parts of the code from the following GitHub repos: 138 | [NATSpeech](https://github.com/NATSpeech/NATSpeech), 139 | [Tacotron2](https://github.com/NVIDIA/tacotron2), and 140 | [DiffWave-Vocoder](https://github.com/philsyn/DiffWave-Vocoder), 141 | as described in our code. 142 | 143 | ## Citations ## 144 | If you find this code useful in your research, please consider citing: 145 | ``` 146 | @inproceedings{huang2022fastdiff, 147 | title={FastDiff: A Fast Conditional Diffusion Model for High-Quality Speech Synthesis}, 148 | author={Huang, Rongjie and Lam, Max WY and Wang, Jun and Su, Dan and Yu, Dong and Ren, Yi and Zhao, Zhou}, 149 | booktitle = {Proceedings of the Thirty-First International Joint Conference on 150 | Artificial Intelligence, {IJCAI-22}}, 151 | publisher = {International Joint Conferences on Artificial Intelligence Organization}, 152 | year={2022} 153 | } 154 | ``` 155 | 156 | ## Disclaimer ## 157 | - This is not an officially supported Tencent product. 158 | 159 | - Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
160 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | librosa==0.8.0 3 | tqdm 4 | pandas 5 | numba==0.53.1 6 | numpy 7 | scipy==1.3 8 | PyYAML 9 | tensorboardX 10 | pyloudnorm 11 | setuptools>=41.0.0 12 | g2p_en 13 | resemblyzer 14 | webrtcvad 15 | tensorboard==2.6.0 16 | scikit-learn==0.24.1 17 | scikit-image==0.16.2 18 | textgrid 19 | jiwer 20 | pycwt 21 | PyWavelets 22 | praat-parselmouth==0.3.3 23 | jieba 24 | einops 25 | chardet -------------------------------------------------------------------------------- /tasks/run.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import subprocess 3 | from utils.hparams import set_hparams, hparams 4 | 5 | 6 | def run_task(): 7 | assert hparams['task_cls'] != '' 8 | pkg = ".".join(hparams["task_cls"].split(".")[:-1]) 9 | cls_name = hparams["task_cls"].split(".")[-1] 10 | task_cls = getattr(importlib.import_module(pkg), cls_name) 11 | task_cls.start() 12 | 13 | 14 | if __name__ == '__main__': 15 | try: 16 | import libtmux 17 | 18 | tmux_session_name = subprocess.run( 19 | "echo $(tmux list-panes -t \"$TMUX_PANE\" -F '#S' | head -n1)", 20 | shell=True, check=True, stdout=subprocess.PIPE).stdout 21 | tmux_session_name = tmux_session_name.decode().strip() 22 | server = libtmux.Server() 23 | session = server.find_where({"session_name": tmux_session_name}) 24 | window = session.attached_window 25 | except Exception as e: 26 | print('| libtmux load error.') 27 | 28 | try: 29 | from setproctitle import setproctitle 30 | 31 | # hide the process title 32 | setproctitle("python train.py") 33 | except: 34 | pass 35 | set_hparams() 36 | try: 37 | if hparams['rename_tmux'] and not hparams['infer']: 38 | window.rename_window('_'.join(hparams['exp_name'].split("_")[:-1])) 39 | except: 40 | pass 41 | 42 | run_task() 43 | -------------------------------------------------------------------------------- /tasks/vocoder/vocoder_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from torch.utils.data import DistributedSampler 6 | 7 | from tasks.base_task import BaseTask 8 | from tasks.base_task import data_loader 9 | from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler 10 | from utils.hparams import hparams 11 | 12 | 13 | class VocoderBaseTask(BaseTask): 14 | def __init__(self): 15 | super(VocoderBaseTask, self).__init__() 16 | self.max_sentences = hparams['max_sentences'] 17 | self.max_valid_sentences = hparams['max_valid_sentences'] 18 | if self.max_valid_sentences == -1: 19 | hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences 20 | self.dataset_cls = VocoderDataset 21 | 22 | @data_loader 23 | def train_dataloader(self): 24 | train_dataset = self.dataset_cls('train', shuffle=True) 25 | return self.build_dataloader(train_dataset, True, self.max_sentences, hparams['endless_ds']) 26 | 27 | @data_loader 28 | def val_dataloader(self): 29 | valid_dataset = self.dataset_cls('valid', shuffle=False) 30 | return self.build_dataloader(valid_dataset, False, self.max_valid_sentences) 31 | 32 | @data_loader 33 | def test_dataloader(self): 34 | test_dataset = self.dataset_cls('test', shuffle=False) 35 | return self.build_dataloader(test_dataset, False, self.max_valid_sentences) 36 | 37 | def 
build_dataloader(self, dataset, shuffle, max_sentences, endless=False): 38 | world_size = 1 39 | rank = 0 40 | if dist.is_initialized(): 41 | world_size = dist.get_world_size() 42 | rank = dist.get_rank() 43 | sampler_cls = DistributedSampler if not endless else EndlessDistributedSampler 44 | train_sampler = sampler_cls( 45 | dataset=dataset, 46 | num_replicas=world_size, 47 | rank=rank, 48 | shuffle=shuffle, 49 | ) 50 | return torch.utils.data.DataLoader( 51 | dataset=dataset, 52 | shuffle=False, 53 | collate_fn=dataset.collater, 54 | batch_size=max_sentences, 55 | num_workers=dataset.num_workers, 56 | sampler=train_sampler, 57 | pin_memory=True, 58 | ) 59 | 60 | def test_start(self): 61 | self.gen_dir = os.path.join(hparams['work_dir'], 62 | f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') 63 | os.makedirs(self.gen_dir, exist_ok=True) 64 | 65 | def test_end(self, outputs): 66 | return {} 67 | -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import librosa 3 | import librosa.filters 4 | import numpy as np 5 | import torch 6 | from scipy import signal 7 | from scipy.io import wavfile 8 | import torch.nn.functional as F 9 | 10 | 11 | def save_wav(wav, path, sr, norm=False): 12 | if norm: 13 | wav = wav / np.abs(wav).max() 14 | wav *= 32767 15 | # proposed by @dsmiller 16 | wavfile.write(path, sr, wav.astype(np.int16)) 17 | 18 | 19 | def to_mp3(out_path): 20 | subprocess.check_call( 21 | f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -ar 44100 -ac 1 -b:a 192k -y -hide_banner "{out_path}.mp3"', 22 | shell=True, stdin=subprocess.PIPE) 23 | subprocess.check_call(f'rm -f "{out_path}.wav"', shell=True) 24 | 25 | 26 | def get_hop_size(hparams): 27 | hop_size = hparams['hop_size'] 28 | if hop_size is None: 29 | assert hparams['frame_shift_ms'] is not None 30 | hop_size = int(hparams['frame_shift_ms'] / 1000 * hparams['audio_sample_rate']) 31 | return hop_size 32 | 33 | 34 | ########################################################################################### 35 | def griffin_lim(S, hparams, angles=None): 36 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) if angles is None else angles 37 | S_complex = np.abs(S).astype(np.complex) 38 | y = _istft(S_complex * angles, hparams) 39 | for i in range(hparams['griffin_lim_iters']): 40 | angles = np.exp(1j * np.angle(_stft(y, hparams))) 41 | y = _istft(S_complex * angles, hparams) 42 | return y 43 | 44 | 45 | def preemphasis(wav, k, preemphasize=True): 46 | if preemphasize: 47 | return signal.lfilter([1, -k], [1], wav) 48 | return wav 49 | 50 | 51 | def inv_preemphasis(wav, k, inv_preemphasize=True): 52 | if inv_preemphasize: 53 | return signal.lfilter([1], [1, -k], wav) 54 | return wav 55 | 56 | 57 | def _stft(y, hparams): 58 | return librosa.stft(y=y, n_fft=hparams['fft_size'], hop_length=get_hop_size(hparams), 59 | win_length=hparams['win_size'], pad_mode='constant') 60 | 61 | 62 | def _istft(y, hparams): 63 | return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams['win_size']) 64 | 65 | 66 | 67 | def librosa_pad_lr(x, fsize, fshift, pad_sides=1): 68 | '''compute right padding (final frame) or both sides padding (first and final frames) 69 | ''' 70 | assert pad_sides in (1, 2) 71 | # return int(fsize // 2) 72 | pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0] 73 | if pad_sides == 1: 74 | return 0, pad 75 | else: 76 | return pad // 
2, pad // 2 + pad % 2 77 | 78 | 79 | # Conversions 80 | _mel_basis = None 81 | _inv_mel_basis = None 82 | 83 | 84 | def _linear_to_mel(spectogram, hparams): 85 | global _mel_basis 86 | if _mel_basis is None: 87 | _mel_basis = _build_mel_basis(hparams) 88 | return np.dot(_mel_basis, spectogram) 89 | 90 | 91 | def _mel_to_linear(mel_spectrogram, hparams): 92 | global _inv_mel_basis 93 | if _inv_mel_basis is None: 94 | _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams)) 95 | return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) 96 | 97 | 98 | def _build_mel_basis(hparams): 99 | assert hparams['fmax'] <= hparams['audio_sample_rate'] // 2 100 | return librosa.filters.mel(hparams['audio_sample_rate'], hparams['fft_size'], n_mels=hparams['audio_num_mel_bins'], 101 | fmin=hparams['fmin'], fmax=hparams['fmax']) 102 | 103 | 104 | def amp_to_db(x): 105 | return 20 * np.log10(np.maximum(1e-5, x)) 106 | 107 | 108 | def db_to_amp(x): 109 | return 10.0 ** (x * 0.05) 110 | 111 | 112 | def normalize(S, hparams): 113 | return (S - hparams['min_level_db']) / -hparams['min_level_db'] 114 | 115 | 116 | def denormalize(D, hparams): 117 | return (D * -hparams['min_level_db']) + hparams['min_level_db'] 118 | 119 | 120 | #### torch audio 121 | 122 | 123 | def istft(amp, ang, hparams, pad=False, window=None): 124 | spec = amp * torch.exp(1j * ang) 125 | spec_r = spec.real 126 | spec_i = spec.imag 127 | spec = torch.stack([spec_r, spec_i], -1) 128 | if window is None: 129 | window = torch.hann_window(hparams['win_size']).to(amp.device) 130 | if pad: 131 | spec = F.pad(spec, [0, 0, 0, 1], mode='reflect') 132 | wav = torch.istft(spec, hparams['fft_size'], hparams['hop_size'], hparams['win_size']) 133 | return wav 134 | 135 | 136 | def griffin_lim_torch(amp, ang, hparams, n_iters=30): 137 | """ 138 | 139 | Examples: 140 | >>> x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, win_length=win_length, pad_mode="constant") 141 | >>> x_stft = x_stft[None, ...] 
142 | >>> amp = np.abs(x_stft) 143 | >>> angle_init = np.exp(2j * np.pi * np.random.rand(*x_stft.shape)) 144 | >>> amp = torch.FloatTensor(amp) 145 | >>> wav = griffin_lim_torch(amp, angle_init, hparams) 146 | 147 | :param amp: [B, n_fft, T] 148 | :param ang: [B, n_fft, T] 149 | :return: [B, T_wav] 150 | """ 151 | window = torch.hann_window(hparams['win_size']).to(amp.device) 152 | y = istft(amp, ang, hparams, window=window) 153 | for i in range(n_iters): 154 | x_stft = torch.stft(y, hparams['fft_size'], hparams['hop_size'], hparams['win_size'], window) 155 | x_stft = x_stft[..., 0] + 1j * x_stft[..., 1] 156 | ang = torch.angle(x_stft) 157 | y = istft(amp, ang, hparams, window=window) 158 | return y 159 | 160 | 161 | def split_audio_by_mel2ph(audio, mel2ph, hparams): 162 | if isinstance(audio, torch.Tensor): 163 | audio = audio.numpy() 164 | if isinstance(mel2ph, torch.Tensor): 165 | mel2ph = mel2ph.numpy() 166 | assert len(audio.shape) == 1, len(mel2ph.shape) == 1 167 | split_locs = [] 168 | for i in range(1, len(mel2ph)): 169 | if mel2ph[i] != mel2ph[i - 1]: 170 | split_loc = i * hparams['hop_size'] 171 | split_locs.append(split_loc) 172 | 173 | new_audio = [] 174 | for i in range(len(split_locs) - 1): 175 | new_audio.append(audio[split_locs[i]:split_locs[i + 1]]) 176 | new_audio.append(np.zeros([0.5 * hparams['audio_num_mel_bins']])) 177 | return np.concatenate(new_audio) 178 | -------------------------------------------------------------------------------- /utils/ckpt_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import re 5 | import torch 6 | 7 | 8 | def get_last_checkpoint(work_dir, steps=None): 9 | checkpoint = None 10 | last_ckpt_path = None 11 | ckpt_paths = get_all_ckpts(work_dir, steps) 12 | if len(ckpt_paths) > 0: 13 | last_ckpt_path = ckpt_paths[0] 14 | checkpoint = torch.load(last_ckpt_path, map_location='cpu') 15 | logging.info(f'load module from checkpoint: {last_ckpt_path}') 16 | return checkpoint, last_ckpt_path 17 | 18 | 19 | def get_all_ckpts(work_dir, steps=None): 20 | if steps is None: 21 | ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt' 22 | else: 23 | ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt' 24 | return sorted(glob.glob(ckpt_path_pattern), 25 | key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) 26 | 27 | 28 | def load_ckpt(cur_model, ckpt_base_dir, model_name='models', force=True, strict=True): 29 | if os.path.isfile(ckpt_base_dir): 30 | base_dir = os.path.dirname(ckpt_base_dir) 31 | ckpt_path = ckpt_base_dir 32 | checkpoint = torch.load(ckpt_base_dir, map_location='cpu') 33 | else: 34 | base_dir = ckpt_base_dir 35 | checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir) 36 | if checkpoint is not None: 37 | state_dict = checkpoint["state_dict"] 38 | if len([k for k in state_dict.keys() if '.' in k]) > 0: 39 | state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items() 40 | if k.startswith(f'{model_name}.')} 41 | else: 42 | if '.' 
not in model_name: 43 | state_dict = state_dict[model_name] 44 | else: 45 | base_model_name = model_name.split('.')[0] 46 | rest_model_name = model_name[len(base_model_name) + 1:] 47 | state_dict = { 48 | k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items() 49 | if k.startswith(f'{rest_model_name}.')} 50 | if not strict: 51 | cur_model_state_dict = cur_model.state_dict() 52 | unmatched_keys = [] 53 | for key, param in state_dict.items(): 54 | if key in cur_model_state_dict: 55 | new_param = cur_model_state_dict[key] 56 | if new_param.shape != param.shape: 57 | unmatched_keys.append(key) 58 | print("| Unmatched keys: ", key, new_param.shape, param.shape) 59 | for key in unmatched_keys: 60 | del state_dict[key] 61 | cur_model.load_state_dict(state_dict, strict=strict) 62 | print(f"| load '{model_name}' from '{ckpt_path}'.") 63 | else: 64 | e_msg = f"| ckpt not found in {base_dir}." 65 | if force: 66 | assert False, e_msg 67 | else: 68 | print(e_msg) 69 | -------------------------------------------------------------------------------- /utils/common_schedulers.py: -------------------------------------------------------------------------------- 1 | from utils.hparams import hparams 2 | 3 | 4 | class NoneSchedule(object): 5 | def __init__(self, optimizer): 6 | super().__init__() 7 | self.optimizer = optimizer 8 | self.constant_lr = hparams['lr'] 9 | self.step(0) 10 | 11 | def step(self, num_updates): 12 | self.lr = self.constant_lr 13 | for param_group in self.optimizer.param_groups: 14 | param_group['lr'] = self.lr 15 | return self.lr 16 | 17 | def get_lr(self): 18 | return self.optimizer.param_groups[0]['lr'] 19 | 20 | def get_last_lr(self): 21 | return self.get_lr() 22 | 23 | 24 | class RSQRTSchedule(object): 25 | def __init__(self, optimizer): 26 | super().__init__() 27 | self.optimizer = optimizer 28 | self.constant_lr = hparams['lr'] 29 | self.warmup_updates = hparams['warmup_updates'] 30 | self.hidden_size = hparams['hidden_size'] 31 | self.lr = hparams['lr'] 32 | for param_group in optimizer.param_groups: 33 | param_group['lr'] = self.lr 34 | self.step(0) 35 | 36 | def step(self, num_updates): 37 | constant_lr = self.constant_lr 38 | warmup = min(num_updates / self.warmup_updates, 1.0) 39 | rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5 40 | rsqrt_hidden = self.hidden_size ** -0.5 41 | self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7) 42 | for param_group in self.optimizer.param_groups: 43 | param_group['lr'] = self.lr 44 | return self.lr 45 | 46 | def get_lr(self): 47 | return self.optimizer.param_groups[0]['lr'] 48 | 49 | def get_last_lr(self): 50 | return self.get_lr() 51 | -------------------------------------------------------------------------------- /utils/ddp_utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parallel import DistributedDataParallel 2 | from torch.nn.parallel.distributed import _find_tensors 3 | import torch.optim 4 | import torch.utils.data 5 | import torch 6 | from packaging import version 7 | 8 | class DDP(DistributedDataParallel): 9 | """ 10 | Override the forward call in lightning so it goes to training and validation step respectively 11 | """ 12 | 13 | def forward(self, *inputs, **kwargs): # pragma: no cover 14 | if version.parse(torch.__version__[:6]) < version.parse("1.11"): 15 | self._sync_params() 16 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 17 | assert len(self.device_ids) == 1 18 | if self.module.training: 19 | 
output = self.module.training_step(*inputs[0], **kwargs[0]) 20 | elif self.module.testing: 21 | output = self.module.test_step(*inputs[0], **kwargs[0]) 22 | else: 23 | output = self.module.validation_step(*inputs[0], **kwargs[0]) 24 | if torch.is_grad_enabled(): 25 | # We'll return the output object verbatim since it is a freeform 26 | # object. We need to find any tensors in this object, though, 27 | # because we need to figure out which parameters were used during 28 | # this forward pass, to ensure we short circuit reduction for any 29 | # unused parameters. Only if `find_unused_parameters` is set. 30 | if self.find_unused_parameters: 31 | self.reducer.prepare_for_backward(list(_find_tensors(output))) 32 | else: 33 | self.reducer.prepare_for_backward([]) 34 | else: 35 | from torch.nn.parallel.distributed import \ 36 | logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref 37 | with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): 38 | if torch.is_grad_enabled() and self.require_backward_grad_sync: 39 | self.logger.set_runtime_stats_and_log() 40 | self.num_iterations += 1 41 | self.reducer.prepare_for_forward() 42 | 43 | # Notify the join context that this process has not joined, if 44 | # needed 45 | work = Join.notify_join_context(self) 46 | if work: 47 | self.reducer._set_forward_pass_work_handle( 48 | work, self._divide_by_initial_world_size 49 | ) 50 | 51 | # Calling _rebuild_buckets before forward compuation, 52 | # It may allocate new buckets before deallocating old buckets 53 | # inside _rebuild_buckets. To save peak memory usage, 54 | # call _rebuild_buckets before the peak memory usage increases 55 | # during forward computation. 56 | # This should be called only once during whole training period. 57 | if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): 58 | logging.info("Reducer buckets have been rebuilt in this iteration.") 59 | self._has_rebuilt_buckets = True 60 | 61 | # sync params according to location (before/after forward) user 62 | # specified as part of hook, if hook was specified. 63 | buffer_hook_registered = hasattr(self, 'buffer_hook') 64 | if self._check_sync_bufs_pre_fwd(): 65 | self._sync_buffers() 66 | 67 | if self._join_config.enable: 68 | # Notify joined ranks whether they should sync in backwards pass or not. 69 | self._check_global_requires_backward_grad_sync(is_joined_rank=False) 70 | 71 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 72 | if self.module.training: 73 | output = self.module.training_step(*inputs[0], **kwargs[0]) 74 | elif self.module.testing: 75 | output = self.module.test_step(*inputs[0], **kwargs[0]) 76 | else: 77 | output = self.module.validation_step(*inputs[0], **kwargs[0]) 78 | 79 | # sync params according to location (before/after forward) user 80 | # specified as part of hook, if hook was specified. 81 | if self._check_sync_bufs_post_fwd(): 82 | self._sync_buffers() 83 | 84 | if torch.is_grad_enabled() and self.require_backward_grad_sync: 85 | self.require_forward_param_sync = True 86 | # We'll return the output object verbatim since it is a freeform 87 | # object. We need to find any tensors in this object, though, 88 | # because we need to figure out which parameters were used during 89 | # this forward pass, to ensure we short circuit reduction for any 90 | # unused parameters. Only if `find_unused_parameters` is set. 91 | if self.find_unused_parameters and not self.static_graph: 92 | # Do not need to populate this for static graph. 
93 | self.reducer.prepare_for_backward(list(_find_tensors(output))) 94 | else: 95 | self.reducer.prepare_for_backward([]) 96 | else: 97 | self.require_forward_param_sync = False 98 | 99 | # TODO: DDPSink is currently enabled for unused parameter detection and 100 | # static graph training for first iteration. 101 | if (self.find_unused_parameters and not self.static_graph) or ( 102 | self.static_graph and self.num_iterations == 1 103 | ): 104 | state_dict = { 105 | 'static_graph': self.static_graph, 106 | 'num_iterations': self.num_iterations, 107 | } 108 | 109 | output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref( 110 | output 111 | ) 112 | output_placeholders = [None for _ in range(len(output_tensor_list))] 113 | # Do not touch tensors that have no grad_fn, which can cause issues 114 | # such as https://github.com/pytorch/pytorch/issues/60733 115 | for i, output in enumerate(output_tensor_list): 116 | if torch.is_tensor(output) and output.grad_fn is None: 117 | output_placeholders[i] = output 118 | 119 | # When find_unused_parameters=True, makes tensors which require grad 120 | # run through the DDPSink backward pass. When not all outputs are 121 | # used in loss, this makes those corresponding tensors receive 122 | # undefined gradient which the reducer then handles to ensure 123 | # param.grad field is not touched and we don't error out. 124 | passthrough_tensor_list = _DDPSink.apply( 125 | self.reducer, 126 | state_dict, 127 | *output_tensor_list, 128 | ) 129 | for i in range(len(output_placeholders)): 130 | if output_placeholders[i] is None: 131 | output_placeholders[i] = passthrough_tensor_list[i] 132 | 133 | # Reconstruct output data structure. 134 | output = _tree_unflatten_with_rref( 135 | output_placeholders, treespec, output_is_rref 136 | ) 137 | return output 138 | -------------------------------------------------------------------------------- /utils/hparams.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | import yaml 6 | 7 | global_print_hparams = True 8 | hparams = {} 9 | 10 | 11 | class Args: 12 | def __init__(self, **kwargs): 13 | for k, v in kwargs.items(): 14 | self.__setattr__(k, v) 15 | 16 | 17 | def override_config(old_config: dict, new_config: dict): 18 | for k, v in new_config.items(): 19 | if isinstance(v, dict) and k in old_config: 20 | override_config(old_config[k], new_config[k]) 21 | else: 22 | old_config[k] = v 23 | 24 | 25 | def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True): 26 | if config == '' and exp_name == '': 27 | parser = argparse.ArgumentParser(description='') 28 | parser.add_argument('--config', type=str, default='configs/config_base.yaml', 29 | help='location of the data corpus') 30 | parser.add_argument('--exp_name', type=str, default='', help='exp_name') 31 | parser.add_argument('--hparams', type=str, default='', 32 | help='location of the data corpus') 33 | parser.add_argument('--infer', action='store_true', help='infer') 34 | parser.add_argument('--validate', action='store_true', help='validate') 35 | parser.add_argument('--reset', action='store_true', help='reset hparams') 36 | parser.add_argument('--remove', action='store_true', help='remove old ckpt') 37 | parser.add_argument('--debug', action='store_true', help='debug') 38 | args, unknown = parser.parse_known_args() 39 | else: 40 | args = Args(config=config, exp_name=exp_name, hparams=hparams_str, 41 | infer=False, 
validate=False, reset=False, debug=False) 42 | global hparams 43 | assert args.config != '' or args.exp_name != '' 44 | 45 | config_chains = [] 46 | loaded_config = set() 47 | 48 | def load_config(config_fn): # deep first 49 | if not os.path.exists(config_fn): 50 | return {} 51 | with open(config_fn) as f: 52 | hparams_ = yaml.safe_load(f) 53 | loaded_config.add(config_fn) 54 | if 'base_config' in hparams_: 55 | ret_hparams = {} 56 | if not isinstance(hparams_['base_config'], list): 57 | hparams_['base_config'] = [hparams_['base_config']] 58 | for c in hparams_['base_config']: 59 | if c.startswith('.'): 60 | c = f'{os.path.dirname(config_fn)}/{c}' 61 | c = os.path.normpath(c) 62 | if c not in loaded_config: 63 | override_config(ret_hparams, load_config(c)) 64 | override_config(ret_hparams, hparams_) 65 | else: 66 | ret_hparams = hparams_ 67 | config_chains.append(config_fn) 68 | return ret_hparams 69 | 70 | saved_hparams = {} 71 | args_work_dir = '' 72 | if args.exp_name != '': 73 | args_work_dir = f'checkpoints/{args.exp_name}' 74 | ckpt_config_path = f'{args_work_dir}/config.yaml' 75 | if os.path.exists(ckpt_config_path): 76 | with open(ckpt_config_path) as f: 77 | saved_hparams_ = yaml.safe_load(f) 78 | if saved_hparams_ is not None: 79 | saved_hparams.update(saved_hparams_) 80 | hparams_ = {} 81 | if args.config != '': 82 | hparams_.update(load_config(args.config)) 83 | if not args.reset: 84 | hparams_.update(saved_hparams) 85 | hparams_['work_dir'] = args_work_dir 86 | 87 | # --hparams="a=1,b.c=2,d=[1 1 1]" 88 | if args.hparams != "": 89 | for new_hparam in args.hparams.split(","): 90 | k, v = new_hparam.split("=") 91 | v = v.strip("\'\" ") 92 | config_node = hparams_ 93 | for k_ in k.split(".")[:-1]: 94 | config_node = config_node[k_] 95 | k = k.split(".")[-1] 96 | if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]: 97 | if type(config_node[k]) == list: 98 | v = v.replace(" ", ",") 99 | config_node[k] = eval(v) 100 | else: 101 | config_node[k] = type(config_node[k])(v) 102 | if args_work_dir != '' and args.remove: 103 | answer = input("REMOVE old checkpoint? 
Y/N [Default: N]: ") 104 | if answer.lower() == "y": 105 | subprocess.check_call(f'rm -rf {args_work_dir}', shell=True) 106 | if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer: 107 | os.makedirs(hparams_['work_dir'], exist_ok=True) 108 | with open(ckpt_config_path, 'w') as f: 109 | yaml.safe_dump(hparams_, f) 110 | 111 | hparams_['infer'] = args.infer 112 | hparams_['debug'] = args.debug 113 | hparams_['validate'] = args.validate 114 | hparams_['exp_name'] = args.exp_name 115 | global global_print_hparams 116 | if global_hparams: 117 | hparams.clear() 118 | hparams.update(hparams_) 119 | if print_hparams and global_print_hparams and global_hparams: 120 | print('| Hparams chains: ', config_chains) 121 | print('| Hparams: ') 122 | for i, (k, v) in enumerate(sorted(hparams_.items())): 123 | print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "") 124 | print("") 125 | global_print_hparams = False 126 | return hparams_ 127 | -------------------------------------------------------------------------------- /utils/indexed_datasets.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from copy import deepcopy 3 | 4 | import numpy as np 5 | 6 | 7 | class IndexedDataset: 8 | def __init__(self, path, num_cache=1): 9 | super().__init__() 10 | self.path = path 11 | self.data_file = None 12 | self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets'] 13 | self.data_file = open(f"{path}.data", 'rb', buffering=-1) 14 | self.cache = [] 15 | self.num_cache = num_cache 16 | 17 | def check_index(self, i): 18 | if i < 0 or i >= len(self.data_offsets) - 1: 19 | raise IndexError('index out of range') 20 | 21 | def __del__(self): 22 | if self.data_file: 23 | self.data_file.close() 24 | 25 | def __getitem__(self, i): 26 | self.check_index(i) 27 | if self.num_cache > 0: 28 | for c in self.cache: 29 | if c[0] == i: 30 | return c[1] 31 | self.data_file.seek(self.data_offsets[i]) 32 | b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i]) 33 | item = pickle.loads(b) 34 | if self.num_cache > 0: 35 | self.cache = [(i, deepcopy(item))] + self.cache[:-1] 36 | return item 37 | 38 | def __len__(self): 39 | return len(self.data_offsets) - 1 40 | 41 | class IndexedDatasetBuilder: 42 | def __init__(self, path): 43 | self.path = path 44 | self.out_file = open(f"{path}.data", 'wb') 45 | self.byte_offsets = [0] 46 | 47 | def add_item(self, item): 48 | s = pickle.dumps(item) 49 | bytes = self.out_file.write(s) 50 | self.byte_offsets.append(self.byte_offsets[-1] + bytes) 51 | 52 | def finalize(self): 53 | self.out_file.close() 54 | np.save(open(f"{self.path}.idx", 'wb'), {'offsets': self.byte_offsets}) 55 | 56 | 57 | if __name__ == "__main__": 58 | import random 59 | from tqdm import tqdm 60 | ds_path = '/tmp/indexed_ds_example' 61 | size = 100 62 | items = [{"a": np.random.normal(size=[10000, 10]), 63 | "b": np.random.normal(size=[10000, 10])} for i in range(size)] 64 | builder = IndexedDatasetBuilder(ds_path) 65 | for i in tqdm(range(size)): 66 | builder.add_item(items[i]) 67 | builder.finalize() 68 | ds = IndexedDataset(ds_path) 69 | for i in tqdm(range(10000)): 70 | idx = random.randint(0, size - 1) 71 | assert (ds[idx]['a'] == items[idx]['a']).all() 72 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import scipy.ndimage 2 | 3 | def 
laplace_var(x): 4 | return scipy.ndimage.laplace(x).var() 5 | -------------------------------------------------------------------------------- /utils/multiprocess_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | 4 | from tqdm import tqdm 5 | 6 | 7 | def chunked_worker(worker_id, map_func, args, results_queue=None, init_ctx_func=None): 8 | ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None 9 | for job_idx, arg in args: 10 | try: 11 | if not isinstance(arg, tuple) and not isinstance(arg, list): 12 | arg = [arg] 13 | if ctx is not None: 14 | res = map_func(*arg, ctx=ctx) 15 | else: 16 | res = map_func(*arg) 17 | results_queue.put((job_idx, res)) 18 | except: 19 | traceback.print_exc() 20 | results_queue.put((job_idx, None)) 21 | 22 | 23 | def chunked_multiprocess_run( 24 | map_func, args, num_workers=None, ordered=True, 25 | init_ctx_func=None, q_max_size=1000, multithread=False): 26 | if multithread: 27 | from multiprocessing.dummy import Queue, Process 28 | else: 29 | from multiprocessing import Queue, Process 30 | args = zip(range(len(args)), args) 31 | args = list(args) 32 | n_jobs = len(args) 33 | if num_workers is None: 34 | num_workers = int(os.getenv('N_PROC', os.cpu_count())) 35 | results_queues = [] 36 | if ordered: 37 | for i in range(num_workers): 38 | results_queues.append(Queue(maxsize=q_max_size // num_workers)) 39 | else: 40 | results_queue = Queue(maxsize=q_max_size) 41 | for i in range(num_workers): 42 | results_queues.append(results_queue) 43 | workers = [] 44 | for i in range(num_workers): 45 | args_worker = args[i::num_workers] 46 | p = Process(target=chunked_worker, args=( 47 | i, map_func, args_worker, results_queues[i], init_ctx_func), daemon=True) 48 | workers.append(p) 49 | p.start() 50 | for n_finished in range(n_jobs): 51 | results_queue = results_queues[n_finished % num_workers] 52 | job_idx, res = results_queue.get() 53 | assert job_idx == n_finished or not ordered, (job_idx, n_finished) 54 | yield res 55 | for w in workers: 56 | w.join() 57 | 58 | 59 | def chunked_worker2(worker_id, args_queue=None, results_queue=None, init_ctx_func=None): 60 | ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None 61 | while True: 62 | args = args_queue.get() 63 | if args == '': 64 | return 65 | job_id, map_func, arg = args 66 | try: 67 | if ctx is not None: 68 | res = map_func(*arg, ctx=ctx) 69 | else: 70 | res = map_func(*arg) 71 | results_queue.put((job_id, res)) 72 | except: 73 | traceback.print_exc() 74 | results_queue.put((job_id, None)) 75 | 76 | 77 | class MultiprocessManager: 78 | def __init__(self, num_workers=None, init_ctx_func=None): 79 | from multiprocessing import Queue, Process 80 | if num_workers is None: 81 | num_workers = int(os.getenv('N_PROC', os.cpu_count())) 82 | self.num_workers = num_workers 83 | self.results_queue = Queue(maxsize=-1) 84 | self.args_queue = Queue(maxsize=-1) 85 | self.workers = [] 86 | self.total_jobs = 0 87 | for i in range(num_workers): 88 | p = Process(target=chunked_worker2, 89 | args=(i, self.args_queue, self.results_queue, init_ctx_func), 90 | daemon=True) 91 | self.workers.append(p) 92 | p.start() 93 | 94 | def add_job(self, func, arg): 95 | self.args_queue.put((self.total_jobs, func, arg)) 96 | self.total_jobs += 1 97 | 98 | def get_results(self): 99 | for w in range(self.num_workers): 100 | self.args_queue.put("") 101 | results = [None for _ in range(self.total_jobs)] 102 | self.n_finished = 0 103 | t = 
tqdm(desc='MultiprocessManager Process: ', total=self.total_jobs) 104 | while self.n_finished < self.total_jobs: 105 | t.update() 106 | job_id, res = self.results_queue.get() 107 | results[job_id] = res 108 | self.n_finished += 1 109 | for w in self.workers: 110 | w.join() 111 | return results 112 | -------------------------------------------------------------------------------- /utils/pitch_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import matplotlib.pyplot as plt 4 | from numba import jit 5 | 6 | import torch 7 | 8 | 9 | @jit 10 | def time_warp(costs): 11 | dtw = np.zeros_like(costs) 12 | dtw[0,1:] = np.inf 13 | dtw[1:,0] = np.inf 14 | eps = 1e-4 15 | for i in range(1,costs.shape[0]): 16 | for j in range(1,costs.shape[1]): 17 | dtw[i,j] = costs[i,j] + min(dtw[i-1,j],dtw[i,j-1],dtw[i-1,j-1]) 18 | return dtw 19 | 20 | 21 | def align_from_distances(distance_matrix, debug=False, return_mindist=False): 22 | # for each position in spectrum 1, returns best match position in spectrum2 23 | # using monotonic alignment 24 | dtw = time_warp(distance_matrix) 25 | 26 | i = distance_matrix.shape[0]-1 27 | j = distance_matrix.shape[1]-1 28 | results = [0] * distance_matrix.shape[0] 29 | while i > 0 and j > 0: 30 | results[i] = j 31 | i, j = min([(i-1,j),(i,j-1),(i-1,j-1)], key=lambda x: dtw[x[0],x[1]]) 32 | 33 | if debug: 34 | visual = np.zeros_like(dtw) 35 | visual[range(len(results)),results] = 1 36 | plt.matshow(visual) 37 | plt.show() 38 | if return_mindist: 39 | return results, dtw[-1, -1] 40 | return results 41 | 42 | 43 | def get_local_context(input_f, max_window=32, scale_factor=1.): 44 | # input_f: [S, 1], support numpy array or torch tensor 45 | # return hist: [S, max_window * 2], list of list 46 | T = input_f.shape[0] 47 | # max_window = int(max_window * scale_factor) 48 | derivative = [[0 for _ in range(max_window * 2)] for _ in range(T)] 49 | 50 | for t in range(T): # travel the time series 51 | for feat_idx in range(-max_window, max_window): 52 | if t + feat_idx < 0 or t + feat_idx >= T: 53 | value = 0 54 | else: 55 | value = input_f[t+feat_idx] 56 | derivative[t][feat_idx+max_window] = value 57 | return derivative 58 | 59 | 60 | def cal_localnorm_dist(src, tgt, src_len, tgt_len): 61 | local_src = torch.tensor(get_local_context(src)) 62 | local_tgt = torch.tensor(get_local_context(tgt, scale_factor=tgt_len / src_len)) 63 | 64 | local_norm_src = (local_src - local_src.mean(-1).unsqueeze(-1)) #/ local_src.std(-1).unsqueeze(-1) # [T1, 32] 65 | local_norm_tgt = (local_tgt - local_tgt.mean(-1).unsqueeze(-1)) #/ local_tgt.std(-1).unsqueeze(-1) # [T2, 32] 66 | 67 | dists = torch.cdist(local_norm_src[None, :, :], local_norm_tgt[None, :, :]) # [1, T1, T2] 68 | return dists 69 | 70 | 71 | ## here is API for one sample 72 | def LoNDTWDistance(src, tgt): 73 | # src: [S] 74 | # tgt: [T] 75 | dists = cal_localnorm_dist(src, tgt, src.shape[0], tgt.shape[0]) # [1, S, T] 76 | costs = dists.squeeze(0) # [S, T] 77 | alignment, min_distance = align_from_distances(costs.T.cpu().detach().numpy(), return_mindist=True) # [T] 78 | return alignment, min_distance 79 | 80 | 81 | # if __name__ == '__main__': 82 | # # utils from ns 83 | # from utils.pitch_utils import denorm_f0 84 | # from tasks.singing.fsinging import FastSingingDataset 85 | # from utils.hparams import hparams, set_hparams 86 | # 87 | # set_hparams() 88 | # 89 | # train_ds = FastSingingDataset('test') 90 | # 91 | # # Test One sample case 92 | # sample = 
train_ds[0] 93 | # amateur_f0 = sample['f0'] 94 | # prof_f0 = sample['prof_f0'] 95 | # 96 | # amateur_uv = sample['uv'] 97 | # amateur_padding = sample['mel2ph'] == 0 98 | # prof_uv = sample['prof_uv'] 99 | # prof_padding = sample['prof_mel2ph'] == 0 100 | # amateur_f0_denorm = denorm_f0(amateur_f0, amateur_uv, hparams, pitch_padding=amateur_padding) 101 | # prof_f0_denorm = denorm_f0(prof_f0, prof_uv, hparams, pitch_padding=prof_padding) 102 | # alignment, min_distance = LoNDTWDistance(amateur_f0_denorm, prof_f0_denorm) 103 | # print(min_distance) 104 | # python utils/pitch_distance.py --config egs/datasets/audio/molar/svc_ppg.yaml -------------------------------------------------------------------------------- /utils/pitch_utils.py: -------------------------------------------------------------------------------- 1 | ########## 2 | # world 3 | ########## 4 | import librosa 5 | import numpy as np 6 | import copy 7 | 8 | import torch 9 | 10 | gamma = 0 11 | mcepInput = 3 # 0 for dB, 3 for magnitude 12 | alpha = 0.45 13 | en_floor = 10 ** (-80 / 20) 14 | FFT_SIZE = 2048 15 | 16 | 17 | def code_harmonic(sp, order): 18 | import pysptk 19 | # get mcep 20 | mceps = np.apply_along_axis(pysptk.mcep, 1, sp, order - 1, alpha, itype=mcepInput, threshold=en_floor) 21 | 22 | # do fft and take real 23 | scale_mceps = copy.copy(mceps) 24 | scale_mceps[:, 0] *= 2 25 | scale_mceps[:, -1] *= 2 26 | mirror = np.hstack([scale_mceps[:, :-1], scale_mceps[:, -1:0:-1]]) 27 | mfsc = np.fft.rfft(mirror).real 28 | 29 | return mfsc 30 | 31 | 32 | def decode_harmonic(mfsc, fftlen=FFT_SIZE): 33 | import pysptk 34 | # get mcep back 35 | mceps_mirror = np.fft.irfft(mfsc) 36 | mceps_back = mceps_mirror[:, :60] 37 | mceps_back[:, 0] /= 2 38 | mceps_back[:, -1] /= 2 39 | 40 | # get sp 41 | spSm = np.exp(np.apply_along_axis(pysptk.mgc2sp, 1, mceps_back, alpha, gamma, fftlen=fftlen).real) 42 | 43 | return spSm 44 | 45 | 46 | def to_lf0(f0): 47 | f0[f0 < 1.0e-5] = 1.0e-6 48 | lf0 = f0.log() if isinstance(f0, torch.Tensor) else np.log(f0) 49 | lf0[f0 < 1.0e-5] = - 1.0E+10 50 | return lf0 51 | 52 | 53 | def to_f0(lf0): 54 | f0 = np.where(lf0 <= 0, 0.0, np.exp(lf0)) 55 | return f0.flatten() 56 | 57 | 58 | def formant_enhancement(coded_spectrogram, beta, fs): 59 | alpha_dict = { 60 | 8000: 0.31, 61 | 16000: 0.58, 62 | 22050: 0.65, 63 | 44100: 0.76, 64 | 48000: 0.77 65 | } 66 | alpha = alpha_dict[fs] 67 | datad = np.zeros((coded_spectrogram.shape[1],)) 68 | sp_dim = coded_spectrogram.shape[1] 69 | for i in range(coded_spectrogram.shape[0]): 70 | datad = mc2b(coded_spectrogram[i], datad, sp_dim - 1, alpha) 71 | datad[1] = datad[1] - alpha * beta * datad[2] 72 | for j in range(2, sp_dim): 73 | datad[j] *= 1 + beta 74 | coded_spectrogram[i] = b2mc(datad, coded_spectrogram[i], sp_dim - 1, alpha) 75 | return coded_spectrogram 76 | 77 | 78 | def mc2b(mc, b, m, a): 79 | """ 80 | Transform Mel Cepstrum to MLSA Digital Filter Coefficients 81 | 82 | void mc2b(mc, b, m, a) 83 | 84 | double *mc : mel cepstral coefficients 85 | double *b : MLSA digital filter coefficients 86 | int m : order of mel cepstrum 87 | double a : all-pass constant 88 | 89 | http://www.asel.udel.edu/icslp/cdrom/vol1/725/a725.pdf 90 | CELP coding system based on mel-generalized cepstral analysis 91 | :param mc: 92 | :param b: 93 | :param m: 94 | :param a: 95 | :return: 96 | """ 97 | b[m] = mc[m] 98 | for i in range(1, m + 1): 99 | b[m - i] = mc[m - i] - a * b[m - i + 1] 100 | return b 101 | 102 | 103 | def b2mc(b, mc, m, a): 104 | """ 105 | Transform MLSA Digital Filter 
Coefficients to Mel Cepstrum 106 | 107 | void b2mc(b, mc, m, a) 108 | 109 | double *b : MLSA digital filter coefficients 110 | double *mc : mel cepstral coefficients 111 | int m : order of mel cepstrum 112 | double a : all-pass constant 113 | 114 | http://www.asel.udel.edu/icslp/cdrom/vol1/725/a725.pdf 115 | CELP coding system based on mel-generalized cepstral analysis 116 | :param b: 117 | :param mc: 118 | :param m: 119 | :param a: 120 | :return: 121 | """ 122 | d = mc[m] = b[m] 123 | for i in range(1, m + 1): 124 | o = b[m - i] + a * d 125 | d = b[m - i] 126 | mc[m - i] = o 127 | return mc 128 | 129 | 130 | f0_bin = 256 131 | f0_max = 1100.0 132 | f0_min = 50.0 133 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 134 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 135 | 136 | 137 | def f0_to_coarse(f0): 138 | is_torch = isinstance(f0, torch.Tensor) 139 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 140 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 141 | 142 | f0_mel[f0_mel <= 1] = 1 143 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 144 | f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) 145 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min(), f0.min(), f0.max()) 146 | return f0_coarse 147 | 148 | 149 | def norm_f0(f0, uv, hparams): 150 | is_torch = isinstance(f0, torch.Tensor) 151 | if hparams['pitch_norm'] == 'standard': 152 | f0 = (f0 - hparams['f0_mean']) / hparams['f0_std'] 153 | if hparams['pitch_norm'] == 'log': 154 | f0 = torch.log2(f0 + 1e-8) if is_torch else np.log2(f0 + 1e-8) 155 | if uv is not None and hparams['use_uv']: 156 | f0[uv > 0] = 0 157 | return f0 158 | 159 | 160 | def norm_interp_f0(f0, hparams): 161 | is_torch = isinstance(f0, torch.Tensor) 162 | if is_torch: 163 | device = f0.device 164 | f0 = f0.data.cpu().numpy() 165 | uv = f0 == 0 166 | f0 = norm_f0(f0, uv, hparams) 167 | if sum(uv) == len(f0): 168 | f0[uv] = 0 169 | elif sum(uv) > 0: 170 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 171 | if is_torch: 172 | uv = torch.FloatTensor(uv) 173 | f0 = torch.FloatTensor(f0) 174 | f0 = f0.to(device) 175 | uv = uv.to(device) 176 | return f0, uv 177 | 178 | 179 | def denorm_f0(f0, uv, hparams, pitch_padding=None, min=None, max=None): 180 | is_torch = isinstance(f0, torch.Tensor) 181 | if hparams['pitch_norm'] == 'standard': 182 | f0 = f0 * hparams['f0_std'] + hparams['f0_mean'] 183 | if hparams['pitch_norm'] == 'log': 184 | f0 = 2 ** f0 185 | if min is None: 186 | min = 0 187 | if max is None: 188 | max = f0_max 189 | f0 = f0.clamp(min=min) if is_torch else np.clip(f0, min=min) 190 | f0 = f0.clamp(max=max) if is_torch else np.clip(f0, max=max) 191 | if uv is not None and hparams['use_uv']: 192 | f0[uv > 0] = 0 193 | if pitch_padding is not None: 194 | f0[pitch_padding] = 0 195 | return f0 196 | 197 | 198 | def pitchfeats(wav, sampling_rate, fft_size, hop_size, win_length, fmin, fmax): 199 | pitches, magnitudes = librosa.piptrack(wav, sampling_rate, 200 | n_fft=fft_size, win_length=win_length, hop_length=hop_size, 201 | fmin=fmin, fmax=fmax) 202 | pitches = pitches.T 203 | magnitudes = magnitudes.T 204 | assert pitches.shape == magnitudes.shape 205 | 206 | pitches = [pitches[i][find_f0(magnitudes[i])] for i, _ in enumerate(pitches)] 207 | 208 | return np.asarray(pitches) 209 | 210 | 211 | def find_f0(mags): 212 | tmp = 0 213 | mags = list(mags) 214 | for i, mag in enumerate(mags): 215 | if mag < 
tmp: 216 | # return i-1 217 | if tmp - mag > 2: 218 | # return i-1 219 | return mags.index(max(mags[0:i])) 220 | else: 221 | return 0 222 | else: 223 | tmp = mag 224 | return 0 225 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | 8 | LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime'] 9 | 10 | 11 | def spec_to_figure(spec, vmin=None, vmax=None, title=''): 12 | if isinstance(spec, torch.Tensor): 13 | spec = spec.cpu().numpy() 14 | fig = plt.figure(figsize=(12, 6)) 15 | plt.title(title) 16 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 17 | return fig 18 | 19 | 20 | def spec_f0_to_figure(spec, f0s, figsize=None): 21 | max_y = spec.shape[1] 22 | if isinstance(spec, torch.Tensor): 23 | spec = spec.detach().cpu().numpy() 24 | f0s = {k: f0.detach().cpu().numpy() for k, f0 in f0s.items()} 25 | f0s = {k: f0 / 10 for k, f0 in f0s.items()} 26 | fig = plt.figure(figsize=(12, 6) if figsize is None else figsize) 27 | plt.pcolor(spec.T) 28 | for i, (k, f0) in enumerate(f0s.items()): 29 | plt.plot(f0.clip(0, max_y), label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.8) 30 | plt.legend() 31 | return fig 32 | 33 | 34 | def dur_to_figure(dur_gt, dur_pred, txt, mels=None, vmin=-5.5, vmax=1): 35 | dur_gt = dur_gt.cpu().numpy() 36 | dur_pred = dur_pred.cpu().numpy() 37 | dur_gt = np.cumsum(dur_gt).astype(int) 38 | dur_pred = np.cumsum(dur_pred).astype(int) 39 | fig = plt.figure(figsize=(12, 6)) 40 | for i in range(len(dur_gt)): 41 | shift = (i % 8) + 1 42 | plt.text(dur_gt[i], shift * 4, txt[i]) 43 | plt.text(dur_pred[i], 40 + shift * 4, txt[i]) 44 | plt.vlines(dur_gt[i], 0, 40, colors='b') # blue is gt 45 | plt.vlines(dur_pred[i], 40, 80, colors='r') # red is pred 46 | plt.xlim(0, max(dur_gt[-1], dur_pred[-1])) 47 | if mels is not None: 48 | mels = mels.cpu().numpy() 49 | plt.pcolor(mels.T, vmin=vmin, vmax=vmax) 50 | return fig 51 | 52 | 53 | def f0_to_figure(f0_gt, f0_cwt=None, f0_pred=None): 54 | fig = plt.figure(figsize=(12, 8)) 55 | f0_gt = f0_gt.cpu().numpy() 56 | plt.plot(f0_gt, color='r', label='gt') 57 | if f0_cwt is not None: 58 | f0_cwt = f0_cwt.cpu().numpy() 59 | plt.plot(f0_cwt, color='b', label='cwt') 60 | if f0_pred is not None: 61 | f0_pred = f0_pred.cpu().numpy() 62 | plt.plot(f0_pred, color='green', label='pred') 63 | plt.legend() 64 | return fig 65 | -------------------------------------------------------------------------------- /utils/rnnoise.py: -------------------------------------------------------------------------------- 1 | # rnnoise.py, requirements: ffmpeg, sox, rnnoise, python 2 | import os 3 | import subprocess 4 | 5 | INSTALL_STR = """ 6 | RNNoise library not found. Please install RNNoise (https://github.com/xiph/rnnoise) to $REPO/rnnoise: 7 | sudo apt-get install -y autoconf automake libtool ffmpeg sox 8 | git clone https://github.com/xiph/rnnoise.git 9 | rm -rf rnnoise/.git 10 | cd rnnoise 11 | ./autogen.sh && ./configure && make 12 | cd .. 
13 | """ 14 | 15 | 16 | def rnnoise(filename, out_fn=None, verbose=False, out_sample_rate=22050): 17 | assert os.path.exists('./rnnoise/examples/rnnoise_demo'), INSTALL_STR 18 | if out_fn is None: 19 | out_fn = f"{filename[:-4]}.denoised.wav" 20 | out_48k_fn = f"{out_fn}.48000.wav" 21 | tmp0_fn = f"{out_fn}.0.wav" 22 | tmp1_fn = f"{out_fn}.1.wav" 23 | tmp2_fn = f"{out_fn}.2.raw" 24 | tmp3_fn = f"{out_fn}.3.raw" 25 | if verbose: 26 | print("Pre-processing audio...") # wav to pcm raw 27 | subprocess.check_call( 28 | f'sox "{filename}" -G -r48000 "{tmp0_fn}"', shell=True, stdin=subprocess.PIPE) # convert to raw 29 | subprocess.check_call( 30 | f'sox -v 0.95 "{tmp0_fn}" "{tmp1_fn}"', shell=True, stdin=subprocess.PIPE) # convert to raw 31 | subprocess.check_call( 32 | f'ffmpeg -y -i "{tmp1_fn}" -loglevel quiet -f s16le -ac 1 -ar 48000 "{tmp2_fn}"', 33 | shell=True, stdin=subprocess.PIPE) # convert to raw 34 | if verbose: 35 | print("Applying rnnoise algorithm to audio...") # rnnoise 36 | subprocess.check_call( 37 | f'./rnnoise/examples/rnnoise_demo "{tmp2_fn}" "{tmp3_fn}"', shell=True) 38 | 39 | if verbose: 40 | print("Post-processing audio...") # pcm raw to wav 41 | if filename == out_fn: 42 | subprocess.check_call(f'rm -f "{out_fn}"', shell=True) 43 | subprocess.check_call( 44 | f'sox -t raw -r 48000 -b 16 -e signed-integer -c 1 "{tmp3_fn}" "{out_48k_fn}"', shell=True) 45 | subprocess.check_call(f'sox "{out_48k_fn}" -G -r{out_sample_rate} "{out_fn}"', shell=True) 46 | subprocess.check_call(f'rm -f "{tmp0_fn}" "{tmp1_fn}" "{tmp2_fn}" "{tmp3_fn}" "{out_48k_fn}"', shell=True) 47 | if verbose: 48 | print("Audio-filtering completed!") 49 | -------------------------------------------------------------------------------- /vocoders/__init__.py: -------------------------------------------------------------------------------- 1 | from vocoders import pwg 2 | from vocoders import gl_mel, gl_linear 3 | from vocoders import stft 4 | -------------------------------------------------------------------------------- /vocoders/base_vocoder.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | VOCODERS = {} 4 | 5 | 6 | def register_vocoder(cls): 7 | VOCODERS[cls.__name__.lower()] = cls 8 | VOCODERS[cls.__name__] = cls 9 | return cls 10 | 11 | 12 | def get_vocoder_cls(hparams): 13 | if hparams['vocoder'] in VOCODERS: 14 | return VOCODERS[hparams['vocoder']] 15 | else: 16 | vocoder_cls = hparams['vocoder'] 17 | pkg = ".".join(vocoder_cls.split(".")[:-1]) 18 | cls_name = vocoder_cls.split(".")[-1] 19 | vocoder_cls = getattr(importlib.import_module(pkg), cls_name) 20 | return vocoder_cls 21 | 22 | 23 | class BaseVocoder: 24 | def spec2wav(self, mel): 25 | """ 26 | 27 | :param mel: [T, 80] 28 | :return: wav: [T'] 29 | """ 30 | 31 | raise NotImplementedError 32 | 33 | @staticmethod 34 | def wav2spec(wav_fn): 35 | """ 36 | 37 | :param wav_fn: str 38 | :return: wav, mel: [T, 80] 39 | """ 40 | raise NotImplementedError 41 | -------------------------------------------------------------------------------- /vocoders/gl_linear.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | from utils import audio 3 | from utils.audio import griffin_lim 4 | from utils.hparams import hparams, set_hparams 5 | from vocoders.base_vocoder import BaseVocoder, register_vocoder 6 | import numpy as np 7 | 8 | 9 | @register_vocoder 10 | class GLLinear(BaseVocoder): 11 | def spec2wav(self, spec, **kwargs): 12 | phase = 
kwargs.get('phase', None) 13 | spec = audio.denormalize(spec, hparams) 14 | spec = audio.db_to_amp(spec) 15 | spec = np.abs(spec.T) 16 | return griffin_lim(spec, hparams, phase) 17 | 18 | @staticmethod 19 | def wav2spec(wav_fn): 20 | sample_rate = hparams['audio_sample_rate'] 21 | wav, _ = librosa.core.load(wav_fn, sr=sample_rate) 22 | fft_size = hparams['fft_size'] 23 | hop_size = hparams['hop_size'] 24 | min_level_db = hparams['min_level_db'] 25 | x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hparams['hop_size'], 26 | win_length=hparams['win_size'], window='hann', pad_mode="constant") 27 | spc = np.abs(x_stft) # [n_bins, T] 28 | phase = np.angle(x_stft) 29 | spc = audio.amp_to_db(spc) 30 | spc = audio.normalize(spc, {'min_level_db': min_level_db}) 31 | spc = spc.T # [T, n_bins] 32 | 33 | l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1) 34 | wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0) 35 | wav_data = wav[:spc.shape[0] * hop_size] 36 | return wav_data, spc # [T, n_bins] 37 | 38 | 39 | if __name__ == "__main__": 40 | """ 41 | Run: python vocoders/gl.py --config configs/tts/transformer_tts.yaml 42 | """ 43 | set_hparams() 44 | fn = '相爱后动物伤感-07' 45 | wav_path = f'tmp/{fn}.wav' 46 | vocoder = GLLinear() 47 | _, spec = vocoder.wav2spec(wav_path) 48 | spec, phase = spec[:, :513], spec[:, 513:] 49 | wav = vocoder.spec2wav(spec.T) 50 | librosa.output.write_wav(f'tmp/{fn}_gl.wav', wav, 22050) 51 | wav = vocoder.spec2wav(spec.T, phase=phase.T) 52 | librosa.output.write_wav(f'tmp/{fn}_gl_phase.wav', wav, 22050) 53 | -------------------------------------------------------------------------------- /vocoders/gl_mel.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | from utils.hparams import hparams 4 | from vocoders.base_vocoder import register_vocoder 5 | from vocoders.pwg import PWG 6 | from utils.audio import griffin_lim 7 | 8 | 9 | @register_vocoder 10 | class GLMel(PWG): 11 | def __init__(self): 12 | self.mel_basis = librosa.filters.mel(hparams['audio_sample_rate'], hparams['fft_size'], 13 | hparams['audio_num_mel_bins'], hparams['fmin'], hparams['fmax']) 14 | 15 | def spec2wav(self, spec, **kwargs): 16 | spec = 10 ** spec 17 | spec = np.abs(spec) 18 | x_stft = librosa.util.nnls(self.mel_basis, spec.T) 19 | return griffin_lim(x_stft, hparams) 20 | -------------------------------------------------------------------------------- /vocoders/pwg.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | import librosa 4 | import torch 5 | import yaml 6 | from sklearn.preprocessing import StandardScaler 7 | from torch import nn 8 | 9 | import utils 10 | from modules.parallel_wavegan.models import ParallelWaveGANGenerator 11 | from modules.parallel_wavegan.utils import read_hdf5 12 | from utils.hparams import hparams 13 | from utils.pitch_utils import f0_to_coarse 14 | from vocoders.base_vocoder import BaseVocoder, register_vocoder 15 | import numpy as np 16 | 17 | 18 | def load_pwg_model(config_path, checkpoint_path, stats_path): 19 | # load config 20 | with open(config_path) as f: 21 | config = yaml.load(f, Loader=yaml.Loader) 22 | 23 | # setup 24 | if torch.cuda.is_available(): 25 | device = torch.device("cuda") 26 | else: 27 | device = torch.device("cpu") 28 | model = ParallelWaveGANGenerator(**config["generator_params"]) 29 | 30 | ckpt_dict = torch.load(checkpoint_path, map_location="cpu") 31 | if 'state_dict' 
not in ckpt_dict: # official vocoder 32 | model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["models"]["generator"]) 33 | scaler = StandardScaler() 34 | if config["format"] == "hdf5": 35 | scaler.mean_ = read_hdf5(stats_path, "mean") 36 | scaler.scale_ = read_hdf5(stats_path, "scale") 37 | elif config["format"] == "npy": 38 | scaler.mean_ = np.load(stats_path)[0] 39 | scaler.scale_ = np.load(stats_path)[1] 40 | else: 41 | raise ValueError("support only hdf5 or npy format.") 42 | else: # custom PWG vocoder 43 | utils.load_ckpt(model, checkpoint_path, 'model_gen') 44 | scaler = None 45 | 46 | model.remove_weight_norm() 47 | model = model.eval().to(device) 48 | print(f"| Loaded models parameters from {checkpoint_path}.") 49 | print(f"| PWG device: {device}.") 50 | return model, scaler, config, device 51 | 52 | 53 | total_time = 0 54 | @register_vocoder 55 | class PWG(BaseVocoder): 56 | def __init__(self): 57 | if hparams['vocoder_ckpt'] == '': # load LJSpeech PWG pretrained models 58 | base_dir = 'wavegan_pretrained' 59 | ckpts = glob.glob(f'{base_dir}/checkpoint-*steps.pkl') 60 | ckpt = sorted(ckpts, key= 61 | lambda x: int(re.findall(f'{base_dir}/checkpoint-(\d+)steps.pkl', x)[0]))[-1] 62 | config_path = f'{base_dir}/config.yaml' 63 | print('| load PWG: ', ckpt) 64 | self.model, self.scaler, self.config, self.device = load_pwg_model( 65 | config_path=config_path, 66 | checkpoint_path=ckpt, 67 | stats_path=f'{base_dir}/stats.h5', 68 | ) 69 | else: 70 | base_dir = hparams['vocoder_ckpt'] 71 | config_path = f'{base_dir}/config.yaml' 72 | ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key= 73 | lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1] 74 | print('| load PWG: ', ckpt) 75 | self.scaler = None 76 | self.model, _, self.config, self.device = load_pwg_model( 77 | config_path=config_path, 78 | checkpoint_path=ckpt, 79 | stats_path=f'{base_dir}/stats.h5', 80 | ) 81 | if 'aux_context_window' not in self.config: 82 | self.config['aux_context_window'] = self.config['generator_params']['aux_context_window'] 83 | 84 | def spec2wav(self, mel, **kwargs): 85 | # start generation 86 | config = self.config 87 | device = self.device 88 | pad_size = (config["aux_context_window"], config["aux_context_window"]) 89 | c = mel 90 | if self.scaler is not None: 91 | c = self.scaler.transform(c) 92 | 93 | with torch.no_grad(): 94 | z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device) 95 | c = np.pad(c, (pad_size, (0, 0)), "edge") 96 | c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device) 97 | p = kwargs.get('f0') 98 | if p is not None: 99 | p = f0_to_coarse(p) 100 | p = np.pad(p, (pad_size,), "edge") 101 | p = torch.LongTensor(p[None, :]).to(device) 102 | with utils.Timer('pwg', enable=hparams['profile_infer']): 103 | y = self.model(z, c, p).view(-1) 104 | wav_out = y.cpu().numpy() 105 | return wav_out 106 | 107 | @staticmethod 108 | def wav2spec(wav_fn, return_linear=False): 109 | from data_gen.tts.data_gen_utils import process_utterance 110 | res = process_utterance( 111 | wav_fn, fft_size=hparams['fft_size'], 112 | hop_size=hparams['hop_size'], 113 | win_length=hparams['win_size'], 114 | num_mels=hparams['audio_num_mel_bins'], 115 | fmin=hparams['fmin'], 116 | fmax=hparams['fmax'], 117 | sample_rate=hparams['audio_sample_rate'], 118 | loud_norm=hparams['loud_norm'], 119 | min_level_db=hparams['min_level_db'], 120 | return_linear=return_linear, vocoder='pwg') 121 | if return_linear: 122 | return res[0], res[1].T, res[2].T 
# [T, 80], [T, n_fft] 123 | else: 124 | return res[0], res[1].T 125 | 126 | @staticmethod 127 | def wav2mfcc(wav_fn): 128 | fft_size = hparams['fft_size'] 129 | hop_size = hparams['hop_size'] 130 | win_length = hparams['win_size'] 131 | sample_rate = hparams['audio_sample_rate'] 132 | wav, _ = librosa.core.load(wav_fn, sr=sample_rate) 133 | mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13, 134 | n_fft=fft_size, hop_length=hop_size, 135 | win_length=win_length, pad_mode="constant", power=1.0) 136 | mfcc_delta = librosa.feature.delta(mfcc, order=1) 137 | mfcc_delta_delta = librosa.feature.delta(mfcc, order=2) 138 | mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T 139 | return mfcc 140 | -------------------------------------------------------------------------------- /vocoders/stft.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | from utils import audio 3 | from utils.audio import griffin_lim 4 | from utils.hparams import hparams, set_hparams 5 | from vocoders.base_vocoder import BaseVocoder, register_vocoder 6 | import numpy as np 7 | 8 | 9 | @register_vocoder 10 | class STFT(BaseVocoder): 11 | rescale = 100 12 | 13 | def spec2wav(self, spec, **kwargs): 14 | """ 15 | 16 | :param spec: [2, T, n_bins] 17 | :param kwargs: 18 | :return: wav 19 | """ 20 | spec = spec.transpose([0, 2, 1]) 21 | spec = spec[0] + 1j * spec[1] 22 | spec = spec * STFT.rescale 23 | return librosa.istft(spec, hop_length=hparams['hop_size'], win_length=hparams['win_size']) 24 | 25 | @staticmethod 26 | def wav2spec(wav_fn): 27 | sample_rate = hparams['audio_sample_rate'] 28 | wav, _ = librosa.core.load(wav_fn, sr=sample_rate) 29 | x_stft = librosa.stft(wav, n_fft=hparams['fft_size'], hop_length=hparams['hop_size'], 30 | win_length=hparams['win_size'], window='hann', pad_mode="constant") 31 | x_stft = x_stft.T / STFT.rescale 32 | stft = np.abs(x_stft) # [T, n_bins] 33 | real = np.real(x_stft) 34 | imag = np.imag(x_stft) 35 | real_imag = np.stack([real, imag], -1) # [T, n_bins, 2] 36 | return wav, stft, real_imag 37 | -------------------------------------------------------------------------------- /vocoders/vocoder_utils.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | 3 | from utils.hparams import hparams 4 | import numpy as np 5 | 6 | 7 | def denoise(wav, v=0.1): 8 | spec = librosa.stft(y=wav, n_fft=hparams['fft_size'], hop_length=hparams['hop_size'], 9 | win_length=hparams['win_size'], pad_mode='constant') 10 | spec_m = np.abs(spec) 11 | spec_m = np.clip(spec_m - v, a_min=0, a_max=None) 12 | spec_a = np.angle(spec) 13 | 14 | return librosa.istft(spec_m * np.exp(1j * spec_a), hop_length=hparams['hop_size'], 15 | win_length=hparams['win_size']) 16 | --------------------------------------------------------------------------------
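The following is a minimal usage sketch rather than part of the repository above: it shows how the hyper-parameter loader (utils/hparams.py) and the vocoder registry (vocoders/base_vocoder.py) fit together for a wav -> mel -> wav round trip. The config path, the input wav, the assumption that the configured vocoder is one of the mel-based classes shown above (PWG, GLLinear, GLMel), and the use of soundfile for writing the result are all placeholders for illustration.

# Minimal sketch: resynthesize a wav through the registered vocoder (assumed paths/config).
import soundfile as sf  # assumed available for writing the output wav

from utils.hparams import hparams, set_hparams
from vocoders.base_vocoder import get_vocoder_cls  # importing the package also registers PWG/GL/STFT
from vocoders.vocoder_utils import denoise

set_hparams(config='egs/egs_bases/tts/base.yaml')            # placeholder config; must define 'vocoder' and the audio hparams
vocoder = get_vocoder_cls(hparams)()                         # resolves the class named by hparams['vocoder']
wav, mel = vocoder.wav2spec('egs/audios/LJ001-0001_gt.wav')  # mel: [T, 80] for the mel-based vocoders
wav_out = vocoder.spec2wav(mel)                              # waveform reconstructed from the mel spectrogram
wav_out = denoise(wav_out, v=0.1)                            # optional spectral-subtraction denoising
sf.write('LJ001-0001_resyn.wav', wav_out, hparams['audio_sample_rate'])

Note that, as the code above shows, the PWG vocoder additionally expects a trained checkpoint (hparams['vocoder_ckpt'] or a wavegan_pretrained directory), while the Griffin-Lim vocoders (GLLinear, GLMel) run without one.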