├── tests └── ref_audio │ ├── test_en_1_ref_short.wav │ └── test_zh_1_ref_short.wav ├── ckpts └── README.md ├── model ├── __init__.py ├── backbones │ ├── README.md │ ├── mmdit.py │ ├── dit.py │ └── unett.py ├── dataset.py ├── cfm.py ├── ecapa_tdnn.py ├── trainer.py ├── utils.py └── modules.py ├── requirements.txt ├── test_infer_batch.sh ├── LICENSE ├── scripts ├── count_max_epoch.py ├── count_params_gflops.py ├── eval_seedtts_testset.py ├── eval_librispeech_test_clean.py ├── prepare_wenetspeech4tts.py └── prepare_emilia.py ├── test_train.py ├── .gitignore ├── test_infer_single.py ├── README.md ├── test_infer_single_edit.py ├── test_infer_batch.py └── data └── Emilia_ZH_EN_pinyin └── vocab.txt /tests/ref_audio/test_en_1_ref_short.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/F5-TTS/main/tests/ref_audio/test_en_1_ref_short.wav -------------------------------------------------------------------------------- /tests/ref_audio/test_zh_1_ref_short.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/F5-TTS/main/tests/ref_audio/test_zh_1_ref_short.wav -------------------------------------------------------------------------------- /ckpts/README.md: -------------------------------------------------------------------------------- 1 | 2 | Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS 3 | 4 | ``` 5 | ckpts/ 6 | E2TTS_Base/ 7 | model_1200000.pt 8 | F5TTS_Base/ 9 | model_1200000.pt 10 | ``` -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.cfm import CFM 2 | 3 | from model.backbones.unett import UNetT 4 | from model.backbones.dit import DiT 5 | from model.backbones.mmdit import MMDiT 6 | 7 | from model.trainer import Trainer 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=0.33.0 2 | datasets 3 | einops>=0.8.0 4 | einx>=0.3.0 5 | ema_pytorch>=0.5.2 6 | faster_whisper 7 | funasr 8 | jieba 9 | jiwer 10 | librosa 11 | matplotlib 12 | pypinyin 13 | torch>=2.0 14 | torchaudio>=2.3.0 15 | torchdiffeq 16 | tqdm>=4.65.0 17 | transformers 18 | vocos 19 | wandb 20 | x_transformers>=1.31.14 21 | zhconv 22 | zhon 23 | -------------------------------------------------------------------------------- /test_infer_batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # e.g. F5-TTS, 16 NFE 4 | accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 5 | accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 6 | accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 7 | 8 | # e.g. Vanilla E2 TTS, 32 NFE 9 | accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 10 | accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 11 | accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 12 | 13 | # etc. 
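# (flag reference, from the argparse setup in test_infer_batch.py:
#  -n/--expname and -t/--testset are required; -nfe/--nfestep defaults to 32,
#  -o/--odemethod defaults to "euler", -ss/--swaysampling defaults to -1;
#  -s/--seed, -c/--ckptstep and -d/--dataset are optional)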
14 | -------------------------------------------------------------------------------- /model/backbones/README.md: -------------------------------------------------------------------------------- 1 | ## Backbones quick introduction 2 | 3 | 4 | ### unett.py 5 | - flat unet transformer 6 | - structure same as in e2-tts & voicebox paper except using rotary pos emb 7 | - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat 8 | 9 | ### dit.py 10 | - adaln-zero dit 11 | - embedded timestep as condition 12 | - concatted noised_input + masked_cond + embedded_text, linear proj in 13 | - possible abs pos emb & convnextv2 blocks for embedded text before concat 14 | - possible long skip connection (first layer to last layer) 15 | 16 | ### mmdit.py 17 | - sd3 structure 18 | - timestep as condition 19 | - left stream: text embedded and applied a abs pos emb 20 | - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yushen CHEN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/count_max_epoch.py: -------------------------------------------------------------------------------- 1 | '''ADAPTIVE BATCH SIZE''' 2 | print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in') 3 | print(' -> least padding, gather wavs with accumulated frames in a batch\n') 4 | 5 | # data 6 | total_hours = 95282 7 | mel_hop_length = 256 8 | mel_sampling_rate = 24000 9 | 10 | # target 11 | wanted_max_updates = 1000000 12 | 13 | # train params 14 | gpus = 8 15 | frames_per_gpu = 38400 # 8 * 38400 = 307200 16 | grad_accum = 1 17 | 18 | # intermediate 19 | mini_batch_frames = frames_per_gpu * grad_accum * gpus 20 | mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600 21 | updates_per_epoch = total_hours / mini_batch_hours 22 | steps_per_epoch = updates_per_epoch * grad_accum 23 | 24 | # result 25 | epochs = wanted_max_updates / updates_per_epoch 26 | print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})") 27 | print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates") 28 | print(f" or approx. 
0/{steps_per_epoch:.0f} steps") 29 | 30 | # others 31 | print(f"total {total_hours:.0f} hours") 32 | print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch") 33 | -------------------------------------------------------------------------------- /scripts/count_params_gflops.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.getcwd()) 3 | 4 | from model import M2_TTS, UNetT, DiT, MMDiT 5 | 6 | import torch 7 | import thop 8 | 9 | 10 | ''' ~155M ''' 11 | # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4) 12 | # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4) 13 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2) 14 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4) 15 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True) 16 | # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2) 17 | 18 | ''' ~335M ''' 19 | # FLOPs: 622.1 G, Params: 333.2 M 20 | # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4) 21 | # FLOPs: 363.4 G, Params: 335.8 M 22 | transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) 23 | 24 | 25 | model = M2_TTS(transformer=transformer) 26 | target_sample_rate = 24000 27 | n_mel_channels = 100 28 | hop_length = 256 29 | duration = 20 30 | frame_length = int(duration * target_sample_rate / hop_length) 31 | text_length = 150 32 | 33 | flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))) 34 | print(f"FLOPs: {flops / 1e9} G") 35 | print(f"Params: {params / 1e6} M") 36 | -------------------------------------------------------------------------------- /scripts/eval_seedtts_testset.py: -------------------------------------------------------------------------------- 1 | # Evaluate with Seed-TTS testset 2 | 3 | import sys, os 4 | sys.path.append(os.getcwd()) 5 | 6 | import multiprocessing as mp 7 | import numpy as np 8 | 9 | from model.utils import ( 10 | get_seed_tts_test, 11 | run_asr_wer, 12 | run_sim, 13 | ) 14 | 15 | 16 | eval_task = "wer" # sim | wer 17 | lang = "zh" # zh | en 18 | metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset 19 | # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs 20 | gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs 21 | 22 | 23 | # NOTE. 
paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different 24 | # zh 1.254 seems a result of 4 workers wer_seed_tts 25 | gpus = [0,1,2,3,4,5,6,7] 26 | test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus) 27 | 28 | local = False 29 | if local: # use local custom checkpoint dir 30 | if lang == "zh": 31 | asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr 32 | elif lang == "en": 33 | asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" 34 | else: 35 | asr_ckpt_dir = "" # auto download to cache dir 36 | 37 | wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" 38 | 39 | 40 | # --------------------------- WER --------------------------- 41 | 42 | if eval_task == "wer": 43 | wers = [] 44 | 45 | with mp.Pool(processes=len(gpus)) as pool: 46 | args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] 47 | results = pool.map(run_asr_wer, args) 48 | for wers_ in results: 49 | wers.extend(wers_) 50 | 51 | wer = round(np.mean(wers)*100, 3) 52 | print(f"\nTotal {len(wers)} samples") 53 | print(f"WER : {wer}%") 54 | 55 | 56 | # --------------------------- SIM --------------------------- 57 | 58 | if eval_task == "sim": 59 | sim_list = [] 60 | 61 | with mp.Pool(processes=len(gpus)) as pool: 62 | args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] 63 | results = pool.map(run_sim, args) 64 | for sim_ in results: 65 | sim_list.extend(sim_) 66 | 67 | sim = round(sum(sim_list)/len(sim_list), 3) 68 | print(f"\nTotal {len(sim_list)} samples") 69 | print(f"SIM : {sim}") 70 | -------------------------------------------------------------------------------- /scripts/eval_librispeech_test_clean.py: -------------------------------------------------------------------------------- 1 | # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation) 2 | 3 | import sys, os 4 | sys.path.append(os.getcwd()) 5 | 6 | import multiprocessing as mp 7 | import numpy as np 8 | 9 | from model.utils import ( 10 | get_librispeech_test, 11 | run_asr_wer, 12 | run_sim, 13 | ) 14 | 15 | 16 | eval_task = "wer" # sim | wer 17 | lang = "en" 18 | metalst = "data/librispeech_pc_test_clean_cross_sentence.lst" 19 | librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path 20 | gen_wav_dir = "PATH_TO_GENERATED" # generated wavs 21 | 22 | gpus = [0,1,2,3,4,5,6,7] 23 | test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path) 24 | 25 | ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book, 26 | ## leading to a low similarity for the ground truth in some cases. 
27 | # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth 28 | 29 | local = False 30 | if local: # use local custom checkpoint dir 31 | asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" 32 | else: 33 | asr_ckpt_dir = "" # auto download to cache dir 34 | 35 | wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" 36 | 37 | 38 | # --------------------------- WER --------------------------- 39 | 40 | if eval_task == "wer": 41 | wers = [] 42 | 43 | with mp.Pool(processes=len(gpus)) as pool: 44 | args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] 45 | results = pool.map(run_asr_wer, args) 46 | for wers_ in results: 47 | wers.extend(wers_) 48 | 49 | wer = round(np.mean(wers)*100, 3) 50 | print(f"\nTotal {len(wers)} samples") 51 | print(f"WER : {wer}%") 52 | 53 | 54 | # --------------------------- SIM --------------------------- 55 | 56 | if eval_task == "sim": 57 | sim_list = [] 58 | 59 | with mp.Pool(processes=len(gpus)) as pool: 60 | args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] 61 | results = pool.map(run_sim, args) 62 | for sim_ in results: 63 | sim_list.extend(sim_) 64 | 65 | sim = round(sum(sim_list)/len(sim_list), 3) 66 | print(f"\nTotal {len(sim_list)} samples") 67 | print(f"SIM : {sim}") 68 | -------------------------------------------------------------------------------- /test_train.py: -------------------------------------------------------------------------------- 1 | from model import CFM, UNetT, DiT, MMDiT, Trainer 2 | from model.utils import get_tokenizer 3 | from model.dataset import load_dataset 4 | 5 | 6 | # -------------------------- Dataset Settings --------------------------- # 7 | 8 | target_sample_rate = 24000 9 | n_mel_channels = 100 10 | hop_length = 256 11 | 12 | tokenizer = "pinyin" 13 | dataset_name = "Emilia_ZH_EN" 14 | 15 | 16 | # -------------------------- Training Settings -------------------------- # 17 | 18 | exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base 19 | 20 | learning_rate = 7.5e-5 21 | 22 | batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200 23 | batch_size_type = "frame" # "frame" or "sample" 24 | max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models 25 | grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps 26 | max_grad_norm = 1. 
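# with batch_size_type = "frame", batch_size_per_gpu counts mel frames:
# 38400 frames * 256 hop / 24000 Hz ≈ 409.6 s of audio per GPU per batch,
# i.e. 307200 frames (≈ 0.91 h) across 8 GPUs, cf. scripts/count_max_epoch.py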
27 | 28 | epochs = 11 # use linear decay, thus epochs control the slope 29 | num_warmup_updates = 20000 # warmup steps 30 | save_per_updates = 50000 # save checkpoint per steps 31 | last_per_steps = 5000 # save last checkpoint per steps 32 | 33 | # model params 34 | if exp_name == "F5TTS_Base": 35 | wandb_resume_id = None 36 | model_cls = DiT 37 | model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) 38 | elif exp_name == "E2TTS_Base": 39 | wandb_resume_id = None 40 | model_cls = UNetT 41 | model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4) 42 | 43 | 44 | # ----------------------------------------------------------------------- # 45 | 46 | def main(): 47 | 48 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) 49 | 50 | mel_spec_kwargs = dict( 51 | target_sample_rate = target_sample_rate, 52 | n_mel_channels = n_mel_channels, 53 | hop_length = hop_length, 54 | ) 55 | 56 | e2tts = CFM( 57 | transformer = model_cls( 58 | **model_cfg, 59 | text_num_embeds = vocab_size, 60 | mel_dim = n_mel_channels 61 | ), 62 | mel_spec_kwargs = mel_spec_kwargs, 63 | vocab_char_map = vocab_char_map, 64 | ) 65 | 66 | trainer = Trainer( 67 | e2tts, 68 | epochs, 69 | learning_rate, 70 | num_warmup_updates = num_warmup_updates, 71 | save_per_updates = save_per_updates, 72 | checkpoint_path = f'ckpts/{exp_name}', 73 | batch_size = batch_size_per_gpu, 74 | batch_size_type = batch_size_type, 75 | max_samples = max_samples, 76 | grad_accumulation_steps = grad_accumulation_steps, 77 | max_grad_norm = max_grad_norm, 78 | wandb_project = "CFM-TTS", 79 | wandb_run_name = exp_name, 80 | wandb_resume_id = wandb_resume_id, 81 | last_per_steps = last_per_steps, 82 | ) 83 | 84 | train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) 85 | trainer.train(train_dataset, 86 | resumable_with_seed = 666 # seed for shuffling dataset 87 | ) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Customed 2 | .vscode/ 3 | tests/ 4 | runs/ 5 | data/ 6 | ckpts/ 7 | wandb/ 8 | results/ 9 | 10 | 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | # .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 121 | .pdm.toml 122 | .pdm-python 123 | .pdm-build/ 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # PyCharm 169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 171 | # and can be added to the global gitignore or merged into this file. For a more nuclear 172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
173 | #.idea/ 174 | -------------------------------------------------------------------------------- /model/backbones/mmdit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from einops import repeat 16 | 17 | from x_transformers.x_transformers import RotaryEmbedding 18 | 19 | from model.modules import ( 20 | TimestepEmbedding, 21 | ConvPositionEmbedding, 22 | MMDiTBlock, 23 | AdaLayerNormZero_Final, 24 | precompute_freqs_cis, get_pos_embed_indices, 25 | ) 26 | 27 | 28 | # text embedding 29 | 30 | class TextEmbedding(nn.Module): 31 | def __init__(self, out_dim, text_num_embeds): 32 | super().__init__() 33 | self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token 34 | 35 | self.precompute_max_pos = 1024 36 | self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) 37 | 38 | def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']: 39 | text = text + 1 40 | if drop_text: 41 | text = torch.zeros_like(text) 42 | text = self.text_embed(text) 43 | 44 | # sinus pos emb 45 | batch_start = torch.zeros((text.shape[0],), dtype=torch.long) 46 | batch_text_len = text.shape[1] 47 | pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos) 48 | text_pos_embed = self.freqs_cis[pos_idx] 49 | 50 | text = text + text_pos_embed 51 | 52 | return text 53 | 54 | 55 | # noised input & masked cond audio embedding 56 | 57 | class AudioEmbedding(nn.Module): 58 | def __init__(self, in_dim, out_dim): 59 | super().__init__() 60 | self.linear = nn.Linear(2 * in_dim, out_dim) 61 | self.conv_pos_embed = ConvPositionEmbedding(out_dim) 62 | 63 | def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False): 64 | if drop_audio_cond: 65 | cond = torch.zeros_like(cond) 66 | x = torch.cat((x, cond), dim = -1) 67 | x = self.linear(x) 68 | x = self.conv_pos_embed(x) + x 69 | return x 70 | 71 | 72 | # Transformer backbone using MM-DiT blocks 73 | 74 | class MMDiT(nn.Module): 75 | def __init__(self, *, 76 | dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4, 77 | text_num_embeds = 256, mel_dim = 100, 78 | ): 79 | super().__init__() 80 | 81 | self.time_embed = TimestepEmbedding(dim) 82 | self.text_embed = TextEmbedding(dim, text_num_embeds) 83 | self.audio_embed = AudioEmbedding(mel_dim, dim) 84 | 85 | self.rotary_embed = RotaryEmbedding(dim_head) 86 | 87 | self.dim = dim 88 | self.depth = depth 89 | 90 | self.transformer_blocks = nn.ModuleList( 91 | [ 92 | MMDiTBlock( 93 | dim = dim, 94 | heads = heads, 95 | dim_head = dim_head, 96 | dropout = dropout, 97 | ff_mult = ff_mult, 98 | context_pre_only = i == depth - 1, 99 | ) 100 | for i in range(depth) 101 | ] 102 | ) 103 | self.norm_out = AdaLayerNormZero_Final(dim) # final modulation 104 | self.proj_out = nn.Linear(dim, mel_dim) 105 | 106 | def forward( 107 | self, 108 | x: float['b n d'], # nosied input audio 109 | cond: float['b n d'], # masked cond audio 110 | text: int['b nt'], # text 111 | time: float['b'] | float[''], # time step 112 | drop_audio_cond, # cfg for cond audio 113 | drop_text, # cfg for text 114 | mask: bool['b n'] | None = None, 115 | ): 116 | batch = x.shape[0] 117 | if time.ndim == 0: 118 | time = repeat(time, ' -> b', b = 
batch) 119 | 120 | # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio 121 | t = self.time_embed(time) 122 | c = self.text_embed(text, drop_text = drop_text) 123 | x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond) 124 | 125 | seq_len = x.shape[1] 126 | text_len = text.shape[1] 127 | rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) 128 | rope_text = self.rotary_embed.forward_from_seq_len(text_len) 129 | 130 | for block in self.transformer_blocks: 131 | c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text) 132 | 133 | x = self.norm_out(x, t) 134 | output = self.proj_out(x) 135 | 136 | return output 137 | -------------------------------------------------------------------------------- /scripts/prepare_wenetspeech4tts.py: -------------------------------------------------------------------------------- 1 | # generate audio text map for WenetSpeech4TTS 2 | # evaluate for vocab size 3 | 4 | import sys, os 5 | sys.path.append(os.getcwd()) 6 | 7 | import json 8 | from tqdm import tqdm 9 | from concurrent.futures import ProcessPoolExecutor 10 | 11 | import torchaudio 12 | from datasets import Dataset 13 | 14 | from model.utils import convert_char_to_pinyin 15 | 16 | 17 | def deal_with_sub_path_files(dataset_path, sub_path): 18 | print(f"Dealing with: {sub_path}") 19 | 20 | text_dir = os.path.join(dataset_path, sub_path, "txts") 21 | audio_dir = os.path.join(dataset_path, sub_path, "wavs") 22 | text_files = os.listdir(text_dir) 23 | 24 | audio_paths, texts, durations = [], [], [] 25 | for text_file in tqdm(text_files): 26 | with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file: 27 | first_line = file.readline().split("\t") 28 | audio_nm = first_line[0] 29 | audio_path = os.path.join(audio_dir, audio_nm + ".wav") 30 | text = first_line[1].strip() 31 | 32 | audio_paths.append(audio_path) 33 | 34 | if tokenizer == "pinyin": 35 | texts.extend(convert_char_to_pinyin([text], polyphone = polyphone)) 36 | elif tokenizer == "char": 37 | texts.append(text) 38 | 39 | audio, sample_rate = torchaudio.load(audio_path) 40 | durations.append(audio.shape[-1] / sample_rate) 41 | 42 | return audio_paths, texts, durations 43 | 44 | 45 | def main(): 46 | assert tokenizer in ["pinyin", "char"] 47 | 48 | audio_path_list, text_list, duration_list = [], [], [] 49 | 50 | executor = ProcessPoolExecutor(max_workers=max_workers) 51 | futures = [] 52 | for dataset_path in dataset_paths: 53 | sub_items = os.listdir(dataset_path) 54 | sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))] 55 | for sub_path in sub_paths: 56 | futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path)) 57 | for future in tqdm(futures, total=len(futures)): 58 | audio_paths, texts, durations = future.result() 59 | audio_path_list.extend(audio_paths) 60 | text_list.extend(texts) 61 | duration_list.extend(durations) 62 | executor.shutdown() 63 | 64 | if not os.path.exists("data"): 65 | os.makedirs("data") 66 | 67 | print(f"\nSaving to data/{dataset_name}_{tokenizer} ...") 68 | dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) 69 | dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format 70 | 71 | with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f: 72 | json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case 
for DynamicBatchSampler ease 73 | 74 | print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...") 75 | text_vocab_set = set() 76 | for text in tqdm(text_list): 77 | text_vocab_set.update(list(text)) 78 | 79 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 80 | if tokenizer == "pinyin": 81 | text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) 82 | 83 | with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f: 84 | for vocab in sorted(text_vocab_set): 85 | f.write(vocab + "\n") 86 | print(f"\nFor {dataset_name}, sample count: {len(text_list)}") 87 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n") 88 | 89 | 90 | if __name__ == "__main__": 91 | 92 | max_workers = 32 93 | 94 | tokenizer = "pinyin" # "pinyin" | "char" 95 | polyphone = True 96 | dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic 97 | 98 | dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1] 99 | dataset_paths = [ 100 | "/WenetSpeech4TTS/Basic", 101 | "/WenetSpeech4TTS/Standard", 102 | "/WenetSpeech4TTS/Premium", 103 | ][-dataset_choice:] 104 | print(f"\nChoose Dataset: {dataset_name}\n") 105 | 106 | main() 107 | 108 | # Results (if adding alphabets with accents and symbols): 109 | # WenetSpeech4TTS Basic Standard Premium 110 | # samples count 3932473 1941220 407494 111 | # pinyin vocab size 1349 1348 1344 (no polyphone) 112 | # - - 1459 (polyphone) 113 | # char vocab size 5264 5219 5042 114 | 115 | # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme) 116 | # please be careful if using pretrained model, make sure the vocab.txt is same 117 | -------------------------------------------------------------------------------- /test_infer_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import torch 5 | import torchaudio 6 | from einops import rearrange 7 | from ema_pytorch import EMA 8 | from vocos import Vocos 9 | 10 | from model import CFM, UNetT, DiT, MMDiT 11 | from model.utils import ( 12 | get_tokenizer, 13 | convert_char_to_pinyin, 14 | save_spectrogram, 15 | ) 16 | 17 | device = "cuda" if torch.cuda.is_available() else "cpu" 18 | 19 | 20 | # --------------------- Dataset Settings -------------------- # 21 | 22 | target_sample_rate = 24000 23 | n_mel_channels = 100 24 | hop_length = 256 25 | target_rms = 0.1 26 | 27 | tokenizer = "pinyin" 28 | dataset_name = "Emilia_ZH_EN" 29 | 30 | 31 | # ---------------------- infer setting ---------------------- # 32 | 33 | seed = None # int | None 34 | 35 | exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base 36 | ckpt_step = 1200000 37 | 38 | nfe_step = 32 # 16, 32 39 | cfg_strength = 2. 40 | ode_method = 'euler' # euler | midpoint 41 | sway_sampling_coef = -1. 42 | speed = 1. 43 | fix_duration = 27 # None (will linear estimate. 
if code-switched, consider fix) | float (total in seconds, include ref audio) 44 | 45 | if exp_name == "F5TTS_Base": 46 | model_cls = DiT 47 | model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) 48 | 49 | elif exp_name == "E2TTS_Base": 50 | model_cls = UNetT 51 | model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4) 52 | 53 | checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device) 54 | output_dir = "tests" 55 | 56 | ref_audio = "tests/ref_audio/test_en_1_ref_short.wav" 57 | ref_text = "Some call me nature, others call me mother nature." 58 | gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences." 59 | 60 | # ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav" 61 | # ref_text = "对,这就是我,万人敬仰的太乙真人。" 62 | # gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\"" 63 | 64 | 65 | # -------------------------------------------------# 66 | 67 | use_ema = True 68 | 69 | if not os.path.exists(output_dir): 70 | os.makedirs(output_dir) 71 | 72 | # Vocoder model 73 | local = False 74 | if local: 75 | vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" 76 | vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") 77 | state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device) 78 | vocos.load_state_dict(state_dict) 79 | vocos.eval() 80 | else: 81 | vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") 82 | 83 | # Tokenizer 84 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) 85 | 86 | # Model 87 | model = CFM( 88 | transformer = model_cls( 89 | **model_cfg, 90 | text_num_embeds = vocab_size, 91 | mel_dim = n_mel_channels 92 | ), 93 | mel_spec_kwargs = dict( 94 | target_sample_rate = target_sample_rate, 95 | n_mel_channels = n_mel_channels, 96 | hop_length = hop_length, 97 | ), 98 | odeint_kwargs = dict( 99 | method = ode_method, 100 | ), 101 | vocab_char_map = vocab_char_map, 102 | ).to(device) 103 | 104 | if use_ema == True: 105 | ema_model = EMA(model, include_online_model = False).to(device) 106 | ema_model.load_state_dict(checkpoint['ema_model_state_dict']) 107 | ema_model.copy_params_from_ema_to_model() 108 | else: 109 | model.load_state_dict(checkpoint['model_state_dict']) 110 | 111 | # Audio 112 | audio, sr = torchaudio.load(ref_audio) 113 | rms = torch.sqrt(torch.mean(torch.square(audio))) 114 | if rms < target_rms: 115 | audio = audio * target_rms / rms 116 | if sr != target_sample_rate: 117 | resampler = torchaudio.transforms.Resample(sr, target_sample_rate) 118 | audio = resampler(audio) 119 | audio = audio.to(device) 120 | 121 | # Text 122 | text_list = [ref_text + gen_text] 123 | if tokenizer == "pinyin": 124 | final_text_list = convert_char_to_pinyin(text_list) 125 | else: 126 | final_text_list = [text_list] 127 | print(f"text : {text_list}") 128 | print(f"pinyin: {final_text_list}") 129 | 130 | # Duration 131 | ref_audio_len = audio.shape[-1] // hop_length 132 | if fix_duration is not None: 133 | duration = int(fix_duration * target_sample_rate / hop_length) 134 | else: # simple linear scale calcul 135 | zh_pause_punc = r"。,、;:?!" 
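    # rough estimate: Chinese pause punctuation is counted twice (once by len(),
    # once per regex match) to reserve extra frames for pauses; the generated part
    # then scales ref_audio_len by the gen/ref text-length ratio, divided by speed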
136 | ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text)) 137 | gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text)) 138 | duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) 139 | 140 | # Inference 141 | with torch.inference_mode(): 142 | generated, trajectory = model.sample( 143 | cond = audio, 144 | text = final_text_list, 145 | duration = duration, 146 | steps = nfe_step, 147 | cfg_strength = cfg_strength, 148 | sway_sampling_coef = sway_sampling_coef, 149 | seed = seed, 150 | ) 151 | print(f"Generated mel: {generated.shape}") 152 | 153 | # Final result 154 | generated = generated[:, ref_audio_len:, :] 155 | generated_mel_spec = rearrange(generated, '1 n d -> 1 d n') 156 | generated_wave = vocos.decode(generated_mel_spec.cpu()) 157 | if rms < target_rms: 158 | generated_wave = generated_wave * rms / target_rms 159 | 160 | save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png") 161 | torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate) 162 | print(f"Generated wav: {generated_wave.shape}") 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching 3 | 4 | [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885) 5 | [![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/) 6 | [![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS) \ 7 | **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference. \ 8 | **E2 TTS**: Flat-UNet Transformer, closest reproduction.\ 9 | **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance 10 | 11 | ## Installation 12 | Clone this repository. 13 | ```bash 14 | git clone git@github.com:SWivid/F5-TTS.git 15 | cd F5-TTS 16 | ``` 17 | Install packages. 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ## Prepare Dataset 23 | Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`. 24 | ```bash 25 | # prepare custom dataset up to your need 26 | # download corresponding dataset first, and fill in the path in scripts 27 | 28 | # Prepare the Emilia dataset 29 | python scripts/prepare_emilia.py 30 | 31 | # Prepare the Wenetspeech4TTS dataset 32 | python scripts/prepare_wenetspeech4tts.py 33 | ``` 34 | 35 | ## Training 36 | Once your datasets are prepared, you can start the training process. 37 | ```bash 38 | # setup accelerate config, e.g. use multi-gpu ddp, fp16 39 | # will be to: ~/.cache/huggingface/accelerate/default_config.yaml 40 | accelerate config 41 | accelerate launch test_train.py 42 | ``` 43 | 44 | ## Inference 45 | To inference with pretrained models, download the checkpoints from [🤗](https://huggingface.co/SWivid/F5-TTS). 46 | 47 | ### Single Inference 48 | You can test single inference using the following command. Before running the command, modify the config up to your need. 49 | ```bash 50 | # modify the config up to your need, 51 | # e.g. 
fix_duration (the total length of prompt + to_generate, currently support up to 30s) 52 | # nfe_step (larger takes more time to do more precise inference ode) 53 | # ode_method (switch to 'midpoint' for better compatibility with small nfe_step, ) 54 | # ( though 'midpoint' is 2nd-order ode solver, slower compared to 1st-order 'Euler') 55 | python test_infer_single.py 56 | ``` 57 | ### Speech Edit 58 | To test speech editing capabilities, use the following command. 59 | ``` 60 | python test_infer_single_edit.py 61 | ``` 62 | 63 | ## Evaluation 64 | ### Prepare Test Datasets 65 | 1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval). 66 | 2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/). 67 | 3. Unzip the downloaded datasets and place them in the data/ directory. 68 | 4. Update the path for the test-clean data in `test_infer_batch.py` 69 | 5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo 70 | 71 | ### Batch Inference for Test Set 72 | To run batch inference for evaluations, execute the following commands: 73 | ```bash 74 | # batch inference for evaluations 75 | accelerate config # if not set before 76 | bash test_infer_batch.sh 77 | ``` 78 | 79 | ### Download Evaluation Model Checkpoints 80 | 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh) 81 | 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3) 82 | 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view). 83 | 84 | ### Objective Evaluation 85 | **Some Notes**\ 86 | For faster-whisper with CUDA 11: \ 87 | `pip install --force-reinstall ctranslate2==3.24.0`\ 88 | (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:\ 89 | `pip install faster-whisper==0.10.1` 90 | 91 | Update the path with your batch-inferenced results, and carry out WER / SIM evaluations: 92 | ```bash 93 | # Evaluation for Seed-TTS test set 94 | python scripts/eval_seedtts_testset.py 95 | 96 | # Evaluation for LibriSpeech-PC test-clean (cross-sentence) 97 | python scripts/eval_librispeech_test_clean.py 98 | ``` 99 | 100 | ## Acknowledgements 101 | 102 | - E2-TTS brilliant work, simple and effective 103 | - Emilia, WenetSpeech4TTS valuable datasets 104 | - lucidrains initial CFM structure with also bfs18 for discussion 105 | - SD3 & Huggingface diffusers DiT and MMDiT code structure 106 | - torchdiffeq as ODE solver, Vocos as vocoder 107 | - mrfakename huggingface space demo ~ 108 | - FunASR, faster-whisper & UniSpeech for evaluation tools 109 | - ctc-forced-aligner for speech edit test 110 | 111 | ## Citation 112 | ``` 113 | @article{chen-etal-2024-f5tts, 114 | title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching}, 115 | author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen}, 116 | journal={arXiv preprint arXiv:2410.06885}, 117 | year={2024}, 118 | } 119 | ``` 120 | ## LICENSE 121 | Our code is released under MIT License. 
122 | -------------------------------------------------------------------------------- /model/backbones/dit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | 16 | from einops import repeat 17 | 18 | from x_transformers.x_transformers import RotaryEmbedding 19 | 20 | from model.modules import ( 21 | TimestepEmbedding, 22 | ConvNeXtV2Block, 23 | ConvPositionEmbedding, 24 | DiTBlock, 25 | AdaLayerNormZero_Final, 26 | precompute_freqs_cis, get_pos_embed_indices, 27 | ) 28 | 29 | 30 | # Text embedding 31 | 32 | class TextEmbedding(nn.Module): 33 | def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2): 34 | super().__init__() 35 | self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token 36 | 37 | if conv_layers > 0: 38 | self.extra_modeling = True 39 | self.precompute_max_pos = 4096 # ~44s of 24khz audio 40 | self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) 41 | self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) 42 | else: 43 | self.extra_modeling = False 44 | 45 | def forward(self, text: int['b nt'], seq_len, drop_text = False): 46 | batch, text_len = text.shape[0], text.shape[1] 47 | text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() 48 | text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens 49 | text = F.pad(text, (0, seq_len - text_len), value = 0) 50 | 51 | if drop_text: # cfg for text 52 | text = torch.zeros_like(text) 53 | 54 | text = self.text_embed(text) # b n -> b n d 55 | 56 | # possible extra modeling 57 | if self.extra_modeling: 58 | # sinus pos emb 59 | batch_start = torch.zeros((batch,), dtype=torch.long) 60 | pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) 61 | text_pos_embed = self.freqs_cis[pos_idx] 62 | text = text + text_pos_embed 63 | 64 | # convnextv2 blocks 65 | text = self.text_blocks(text) 66 | 67 | return text 68 | 69 | 70 | # noised input audio and context mixing embedding 71 | 72 | class InputEmbedding(nn.Module): 73 | def __init__(self, mel_dim, text_dim, out_dim): 74 | super().__init__() 75 | self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) 76 | self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim) 77 | 78 | def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False): 79 | if drop_audio_cond: # cfg for cond audio 80 | cond = torch.zeros_like(cond) 81 | 82 | x = self.proj(torch.cat((x, cond, text_embed), dim = -1)) 83 | x = self.conv_pos_embed(x) + x 84 | return x 85 | 86 | 87 | # Transformer backbone using DiT blocks 88 | 89 | class DiT(nn.Module): 90 | def __init__(self, *, 91 | dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4, 92 | mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0, 93 | long_skip_connection = False, 94 | ): 95 | super().__init__() 96 | 97 | self.time_embed = TimestepEmbedding(dim) 98 | if text_dim is None: 99 | text_dim = mel_dim 100 | self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers) 101 | self.input_embed = 
InputEmbedding(mel_dim, text_dim, dim) 102 | 103 | self.rotary_embed = RotaryEmbedding(dim_head) 104 | 105 | self.dim = dim 106 | self.depth = depth 107 | 108 | self.transformer_blocks = nn.ModuleList( 109 | [ 110 | DiTBlock( 111 | dim = dim, 112 | heads = heads, 113 | dim_head = dim_head, 114 | ff_mult = ff_mult, 115 | dropout = dropout 116 | ) 117 | for _ in range(depth) 118 | ] 119 | ) 120 | self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None 121 | 122 | self.norm_out = AdaLayerNormZero_Final(dim) # final modulation 123 | self.proj_out = nn.Linear(dim, mel_dim) 124 | 125 | def forward( 126 | self, 127 | x: float['b n d'], # nosied input audio 128 | cond: float['b n d'], # masked cond audio 129 | text: int['b nt'], # text 130 | time: float['b'] | float[''], # time step 131 | drop_audio_cond, # cfg for cond audio 132 | drop_text, # cfg for text 133 | mask: bool['b n'] | None = None, 134 | ): 135 | batch, seq_len = x.shape[0], x.shape[1] 136 | if time.ndim == 0: 137 | time = repeat(time, ' -> b', b = batch) 138 | 139 | # t: conditioning time, c: context (text + masked cond audio), x: noised input audio 140 | t = self.time_embed(time) 141 | text_embed = self.text_embed(text, seq_len, drop_text = drop_text) 142 | x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond) 143 | 144 | rope = self.rotary_embed.forward_from_seq_len(seq_len) 145 | 146 | if self.long_skip_connection is not None: 147 | residual = x 148 | 149 | for block in self.transformer_blocks: 150 | x = block(x, t, mask = mask, rope = rope) 151 | 152 | if self.long_skip_connection is not None: 153 | x = self.long_skip_connection(torch.cat((x, residual), dim = -1)) 154 | 155 | x = self.norm_out(x, t) 156 | output = self.proj_out(x) 157 | 158 | return output 159 | -------------------------------------------------------------------------------- /test_infer_single_edit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torchaudio 6 | from einops import rearrange 7 | from ema_pytorch import EMA 8 | from vocos import Vocos 9 | 10 | from model import CFM, UNetT, DiT, MMDiT 11 | from model.utils import ( 12 | get_tokenizer, 13 | convert_char_to_pinyin, 14 | save_spectrogram, 15 | ) 16 | 17 | device = "cuda" if torch.cuda.is_available() else "cpu" 18 | 19 | 20 | # --------------------- Dataset Settings -------------------- # 21 | 22 | target_sample_rate = 24000 23 | n_mel_channels = 100 24 | hop_length = 256 25 | target_rms = 0.1 26 | 27 | tokenizer = "pinyin" 28 | dataset_name = "Emilia_ZH_EN" 29 | 30 | 31 | # ---------------------- infer setting ---------------------- # 32 | 33 | seed = None # int | None 34 | 35 | exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base 36 | ckpt_step = 1200000 37 | 38 | nfe_step = 32 # 16, 32 39 | cfg_strength = 2. 40 | ode_method = 'euler' # euler | midpoint 41 | sway_sampling_coef = -1. 42 | speed = 1. 
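# note: lengths of the edited spans are configured below via parts_to_edit / fix_duration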
43 | 44 | if exp_name == "F5TTS_Base": 45 | model_cls = DiT 46 | model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) 47 | 48 | elif exp_name == "E2TTS_Base": 49 | model_cls = UNetT 50 | model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4) 51 | 52 | checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device) 53 | output_dir = "tests" 54 | 55 | # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment] 56 | # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git 57 | # [write the origin_text into a file, e.g. tests/test_edit.txt] 58 | # ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char" 59 | # [result will be saved at same path of audio file] 60 | # [--language "zho" for Chinese, "eng" for English] 61 | # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"] 62 | 63 | audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav" 64 | origin_text = "Some call me nature, others call me mother nature." 65 | target_text = "Some call me optimist, others call me realist." 66 | parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # stard_ends of "nature" & "mother nature", in seconds 67 | fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds 68 | 69 | # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav" 70 | # origin_text = "对,这就是我,万人敬仰的太乙真人。" 71 | # target_text = "对,那就是你,万人敬仰的太白金星。" 72 | # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ] 73 | # fix_duration = None # use origin text duration 74 | 75 | 76 | # -------------------------------------------------# 77 | 78 | use_ema = True 79 | 80 | if not os.path.exists(output_dir): 81 | os.makedirs(output_dir) 82 | 83 | # Vocoder model 84 | local = False 85 | if local: 86 | vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" 87 | vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") 88 | state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device) 89 | vocos.load_state_dict(state_dict) 90 | vocos.eval() 91 | else: 92 | vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") 93 | 94 | # Tokenizer 95 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) 96 | 97 | # Model 98 | model = CFM( 99 | transformer = model_cls( 100 | **model_cfg, 101 | text_num_embeds = vocab_size, 102 | mel_dim = n_mel_channels 103 | ), 104 | mel_spec_kwargs = dict( 105 | target_sample_rate = target_sample_rate, 106 | n_mel_channels = n_mel_channels, 107 | hop_length = hop_length, 108 | ), 109 | odeint_kwargs = dict( 110 | method = ode_method, 111 | ), 112 | vocab_char_map = vocab_char_map, 113 | ).to(device) 114 | 115 | if use_ema == True: 116 | ema_model = EMA(model, include_online_model = False).to(device) 117 | ema_model.load_state_dict(checkpoint['ema_model_state_dict']) 118 | ema_model.copy_params_from_ema_to_model() 119 | else: 120 | model.load_state_dict(checkpoint['model_state_dict']) 121 | 122 | # Audio 123 | audio, sr = torchaudio.load(audio_to_edit) 124 | rms = torch.sqrt(torch.mean(torch.square(audio))) 125 | if rms < target_rms: 126 | audio = audio * target_rms / rms 127 | if sr != target_sample_rate: 128 | resampler = torchaudio.transforms.Resample(sr, target_sample_rate) 129 | audio = resampler(audio) 130 | offset = 0 131 | audio_ = torch.zeros(1, 0) 132 | edit_mask = torch.zeros(1, 0, 
dtype=torch.bool) 133 | for part in parts_to_edit: 134 | start, end = part 135 | part_dur = end - start if fix_duration is None else fix_duration.pop(0) 136 | part_dur = part_dur * target_sample_rate 137 | start = start * target_sample_rate 138 | audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1) 139 | edit_mask = torch.cat((edit_mask, 140 | torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool), 141 | torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool) 142 | ), dim = -1) 143 | offset = end * target_sample_rate 144 | # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1) 145 | edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True) 146 | audio = audio.to(device) 147 | edit_mask = edit_mask.to(device) 148 | 149 | # Text 150 | text_list = [target_text] 151 | if tokenizer == "pinyin": 152 | final_text_list = convert_char_to_pinyin(text_list) 153 | else: 154 | final_text_list = [text_list] 155 | print(f"text : {text_list}") 156 | print(f"pinyin: {final_text_list}") 157 | 158 | # Duration 159 | ref_audio_len = 0 160 | duration = audio.shape[-1] // hop_length 161 | 162 | # Inference 163 | with torch.inference_mode(): 164 | generated, trajectory = model.sample( 165 | cond = audio, 166 | text = final_text_list, 167 | duration = duration, 168 | steps = nfe_step, 169 | cfg_strength = cfg_strength, 170 | sway_sampling_coef = sway_sampling_coef, 171 | seed = seed, 172 | edit_mask = edit_mask, 173 | ) 174 | print(f"Generated mel: {generated.shape}") 175 | 176 | # Final result 177 | generated = generated[:, ref_audio_len:, :] 178 | generated_mel_spec = rearrange(generated, '1 n d -> 1 d n') 179 | generated_wave = vocos.decode(generated_mel_spec.cpu()) 180 | if rms < target_rms: 181 | generated_wave = generated_wave * rms / target_rms 182 | 183 | save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png") 184 | torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate) 185 | print(f"Generated wav: {generated_wave.shape}") 186 | -------------------------------------------------------------------------------- /test_infer_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | from tqdm import tqdm 5 | import argparse 6 | 7 | import torch 8 | import torchaudio 9 | from accelerate import Accelerator 10 | from einops import rearrange 11 | from ema_pytorch import EMA 12 | from vocos import Vocos 13 | 14 | from model import CFM, UNetT, DiT 15 | from model.utils import ( 16 | get_tokenizer, 17 | get_seedtts_testset_metainfo, 18 | get_librispeech_test_clean_metainfo, 19 | get_inference_prompt, 20 | ) 21 | 22 | accelerator = Accelerator() 23 | device = f"cuda:{accelerator.process_index}" 24 | 25 | 26 | # --------------------- Dataset Settings -------------------- # 27 | 28 | target_sample_rate = 24000 29 | n_mel_channels = 100 30 | hop_length = 256 31 | target_rms = 0.1 32 | 33 | tokenizer = "pinyin" 34 | 35 | 36 | # ---------------------- infer setting ---------------------- # 37 | 38 | parser = argparse.ArgumentParser(description="batch inference") 39 | 40 | parser.add_argument('-s', '--seed', default=None, type=int) 41 | parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN") 42 | parser.add_argument('-n', '--expname', required=True) 43 | parser.add_argument('-c', '--ckptstep', default=1200000, type=int) 44 | 
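# ODE sampling flags: number of function evaluations, solver method, sway sampling coefficient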
45 | parser.add_argument('-nfe', '--nfestep', default=32, type=int) 46 | parser.add_argument('-o', '--odemethod', default="euler") 47 | parser.add_argument('-ss', '--swaysampling', default=-1, type=float) 48 | 49 | parser.add_argument('-t', '--testset', required=True) 50 | 51 | args = parser.parse_args() 52 | 53 | 54 | seed = args.seed 55 | dataset_name = args.dataset 56 | exp_name = args.expname 57 | ckpt_step = args.ckptstep 58 | checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device) 59 | 60 | nfe_step = args.nfestep 61 | ode_method = args.odemethod 62 | sway_sampling_coef = args.swaysampling 63 | 64 | testset = args.testset 65 | 66 | 67 | infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended) 68 | cfg_strength = 2. 69 | speed = 1. 70 | use_truth_duration = False 71 | no_ref_audio = False 72 | 73 | 74 | if exp_name == "F5TTS_Base": 75 | model_cls = DiT 76 | model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) 77 | 78 | elif exp_name == "E2TTS_Base": 79 | model_cls = UNetT 80 | model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4) 81 | 82 | 83 | if testset == "ls_pc_test_clean": 84 | metalst = "data/librispeech_pc_test_clean_cross_sentence.lst" 85 | librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path 86 | metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path) 87 | 88 | elif testset == "seedtts_test_zh": 89 | metalst = "data/seedtts_testset/zh/meta.lst" 90 | metainfo = get_seedtts_testset_metainfo(metalst) 91 | 92 | elif testset == "seedtts_test_en": 93 | metalst = "data/seedtts_testset/en/meta.lst" 94 | metainfo = get_seedtts_testset_metainfo(metalst) 95 | 96 | 97 | # path to save genereted wavs 98 | if seed is None: seed = random.randint(-10000, 10000) 99 | output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \ 100 | f"seed{seed}_{ode_method}_nfe{nfe_step}" \ 101 | f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \ 102 | f"_cfg{cfg_strength}_speed{speed}" \ 103 | f"{'_gt-dur' if use_truth_duration else ''}" \ 104 | f"{'_no-ref-audio' if no_ref_audio else ''}" 105 | 106 | 107 | # -------------------------------------------------# 108 | 109 | use_ema = True 110 | 111 | prompts_all = get_inference_prompt( 112 | metainfo, 113 | speed = speed, 114 | tokenizer = tokenizer, 115 | target_sample_rate = target_sample_rate, 116 | n_mel_channels = n_mel_channels, 117 | hop_length = hop_length, 118 | target_rms = target_rms, 119 | use_truth_duration = use_truth_duration, 120 | infer_batch_size = infer_batch_size, 121 | ) 122 | 123 | # Vocoder model 124 | local = False 125 | if local: 126 | vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" 127 | vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") 128 | state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device) 129 | vocos.load_state_dict(state_dict) 130 | vocos.eval() 131 | else: 132 | vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") 133 | 134 | # Tokenizer 135 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) 136 | 137 | # Model 138 | model = CFM( 139 | transformer = model_cls( 140 | **model_cfg, 141 | text_num_embeds = vocab_size, 142 | mel_dim = n_mel_channels 143 | ), 144 | mel_spec_kwargs = dict( 145 | target_sample_rate = target_sample_rate, 146 | n_mel_channels = n_mel_channels, 147 | hop_length = hop_length, 148 | ), 149 | odeint_kwargs = dict( 150 | method = ode_method, 151 | ), 152 | 
vocab_char_map = vocab_char_map, 153 | ).to(device) 154 | 155 | if use_ema == True: 156 | ema_model = EMA(model, include_online_model = False).to(device) 157 | ema_model.load_state_dict(checkpoint['ema_model_state_dict']) 158 | ema_model.copy_params_from_ema_to_model() 159 | else: 160 | model.load_state_dict(checkpoint['model_state_dict']) 161 | 162 | if not os.path.exists(output_dir) and accelerator.is_main_process: 163 | os.makedirs(output_dir) 164 | 165 | # start batch inference 166 | accelerator.wait_for_everyone() 167 | start = time.time() 168 | 169 | with accelerator.split_between_processes(prompts_all) as prompts: 170 | 171 | for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): 172 | utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt 173 | ref_mels = ref_mels.to(device) 174 | ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device) 175 | total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device) 176 | 177 | # Inference 178 | with torch.inference_mode(): 179 | generated, _ = model.sample( 180 | cond = ref_mels, 181 | text = final_text_list, 182 | duration = total_mel_lens, 183 | lens = ref_mel_lens, 184 | steps = nfe_step, 185 | cfg_strength = cfg_strength, 186 | sway_sampling_coef = sway_sampling_coef, 187 | no_ref_audio = no_ref_audio, 188 | seed = seed, 189 | ) 190 | # Final result 191 | for i, gen in enumerate(generated): 192 | gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0) 193 | gen_mel_spec = rearrange(gen, '1 n d -> 1 d n') 194 | generated_wave = vocos.decode(gen_mel_spec.cpu()) 195 | if ref_rms_list[i] < target_rms: 196 | generated_wave = generated_wave * ref_rms_list[i] / target_rms 197 | torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate) 198 | 199 | accelerator.wait_for_everyone() 200 | if accelerator.is_main_process: 201 | timediff = time.time() - start 202 | print(f"Done batch inference in {timediff / 60 :.2f} minutes.") 203 | -------------------------------------------------------------------------------- /scripts/prepare_emilia.py: -------------------------------------------------------------------------------- 1 | # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07 2 | # if use updated new version, i.e. 
WebDataset, feel free to modify / draft your own script 3 | 4 | # generate audio text map for Emilia ZH & EN 5 | # evaluate for vocab size 6 | 7 | import sys, os 8 | sys.path.append(os.getcwd()) 9 | 10 | from pathlib import Path 11 | import json 12 | from tqdm import tqdm 13 | from concurrent.futures import ProcessPoolExecutor 14 | 15 | from datasets import Dataset 16 | from datasets.arrow_writer import ArrowWriter 17 | 18 | from model.utils import ( 19 | repetition_found, 20 | convert_char_to_pinyin, 21 | ) 22 | 23 | 24 | out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"} 25 | zh_filters = ["い", "て"] 26 | # seems synthesized audios, or heavily code-switched 27 | out_en = { 28 | "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375", 29 | 30 | "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995", 31 | } 32 | en_filters = ["ا", "い", "て"] 33 | 34 | 35 | def deal_with_audio_dir(audio_dir): 36 | audio_jsonl = audio_dir.with_suffix(".jsonl") 37 | sub_result, durations = [], [] 38 | vocab_set = set() 39 | bad_case_zh = 0 40 | bad_case_en = 0 41 | with open(audio_jsonl, "r") as f: 42 | lines = f.readlines() 43 | for line in tqdm(lines, desc=f"{audio_jsonl.stem}"): 44 | obj = json.loads(line) 45 | text = obj["text"] 46 | if obj['language'] == "zh": 47 | if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text): 48 | bad_case_zh += 1 49 | continue 50 | else: 51 | text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched 52 | if obj['language'] == "en": 53 | if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4): 54 | bad_case_en += 1 55 | continue 56 | if tokenizer == "pinyin": 57 | text = convert_char_to_pinyin([text], polyphone = polyphone)[0] 58 | duration = obj["duration"] 59 | sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration}) 60 | durations.append(duration) 61 | vocab_set.update(list(text)) 62 | return sub_result, durations, vocab_set, bad_case_zh, bad_case_en 63 | 64 | 65 | def main(): 66 | assert tokenizer in 
["pinyin", "char"] 67 | result = [] 68 | duration_list = [] 69 | text_vocab_set = set() 70 | total_bad_case_zh = 0 71 | total_bad_case_en = 0 72 | 73 | # process raw data 74 | executor = ProcessPoolExecutor(max_workers=max_workers) 75 | futures = [] 76 | for lang in langs: 77 | dataset_path = Path(os.path.join(dataset_dir, lang)) 78 | [ 79 | futures.append(executor.submit(deal_with_audio_dir, audio_dir)) 80 | for audio_dir in dataset_path.iterdir() 81 | if audio_dir.is_dir() 82 | ] 83 | for futures in tqdm(futures, total=len(futures)): 84 | sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result() 85 | result.extend(sub_result) 86 | duration_list.extend(durations) 87 | text_vocab_set.update(vocab_set) 88 | total_bad_case_zh += bad_case_zh 89 | total_bad_case_en += bad_case_en 90 | executor.shutdown() 91 | 92 | # save preprocessed dataset to disk 93 | if not os.path.exists(f"data/{dataset_name}"): 94 | os.makedirs(f"data/{dataset_name}") 95 | print(f"\nSaving to data/{dataset_name} ...") 96 | # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom 97 | # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB") 98 | with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer: 99 | for line in tqdm(result, desc=f"Writing to raw.arrow ..."): 100 | writer.write(line) 101 | 102 | # dup a json separately saving duration in case for DynamicBatchSampler ease 103 | with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f: 104 | json.dump({"duration": duration_list}, f, ensure_ascii=False) 105 | 106 | # vocab map, i.e. tokenizer 107 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 108 | # if tokenizer == "pinyin": 109 | # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) 110 | with open(f"data/{dataset_name}/vocab.txt", "w") as f: 111 | for vocab in sorted(text_vocab_set): 112 | f.write(vocab + "\n") 113 | 114 | print(f"\nFor {dataset_name}, sample count: {len(result)}") 115 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") 116 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") 117 | if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}") 118 | if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n") 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | max_workers = 32 124 | 125 | tokenizer = "pinyin" # "pinyin" | "char" 126 | polyphone = True 127 | 128 | langs = ["ZH", "EN"] 129 | dataset_dir = "/Emilia_Dataset/raw" 130 | dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}" 131 | print(f"\nPrepare for {dataset_name}\n") 132 | 133 | main() 134 | 135 | # Emilia ZH & EN 136 | # samples count 37837916 (after removal) 137 | # pinyin vocab size 2543 (polyphone) 138 | # total duration 95281.87 (hours) 139 | # bad zh asr cnt 230435 (samples) 140 | # bad eh asr cnt 37217 (samples) 141 | 142 | # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. 
way of polyphoneme) 143 | # please be careful if using pretrained model, make sure the vocab.txt is same 144 | -------------------------------------------------------------------------------- /model/backbones/unett.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | from typing import Literal 12 | 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | 17 | from einops import repeat, pack, unpack 18 | 19 | from x_transformers import RMSNorm 20 | from x_transformers.x_transformers import RotaryEmbedding 21 | 22 | from model.modules import ( 23 | TimestepEmbedding, 24 | ConvNeXtV2Block, 25 | ConvPositionEmbedding, 26 | Attention, 27 | AttnProcessor, 28 | FeedForward, 29 | precompute_freqs_cis, get_pos_embed_indices, 30 | ) 31 | 32 | 33 | # Text embedding 34 | 35 | class TextEmbedding(nn.Module): 36 | def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2): 37 | super().__init__() 38 | self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token 39 | 40 | if conv_layers > 0: 41 | self.extra_modeling = True 42 | self.precompute_max_pos = 4096 # ~44s of 24khz audio 43 | self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) 44 | self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) 45 | else: 46 | self.extra_modeling = False 47 | 48 | def forward(self, text: int['b nt'], seq_len, drop_text = False): 49 | batch, text_len = text.shape[0], text.shape[1] 50 | text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() 51 | text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens 52 | text = F.pad(text, (0, seq_len - text_len), value = 0) 53 | 54 | if drop_text: # cfg for text 55 | text = torch.zeros_like(text) 56 | 57 | text = self.text_embed(text) # b n -> b n d 58 | 59 | # possible extra modeling 60 | if self.extra_modeling: 61 | # sinus pos emb 62 | batch_start = torch.zeros((batch,), dtype=torch.long) 63 | pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) 64 | text_pos_embed = self.freqs_cis[pos_idx] 65 | text = text + text_pos_embed 66 | 67 | # convnextv2 blocks 68 | text = self.text_blocks(text) 69 | 70 | return text 71 | 72 | 73 | # noised input audio and context mixing embedding 74 | 75 | class InputEmbedding(nn.Module): 76 | def __init__(self, mel_dim, text_dim, out_dim): 77 | super().__init__() 78 | self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) 79 | self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim) 80 | 81 | def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False): 82 | if drop_audio_cond: # cfg for cond audio 83 | cond = torch.zeros_like(cond) 84 | 85 | x = self.proj(torch.cat((x, cond, text_embed), dim = -1)) 86 | x = self.conv_pos_embed(x) + x 87 | return x 88 | 89 | 90 | # Flat UNet Transformer backbone 91 | 92 | class UNetT(nn.Module): 93 | def __init__(self, *, 94 | dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4, 95 | mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0, 96 | skip_connect_type: Literal['add', 'concat', 'none'] = 'concat', 97 | ): 98 | super().__init__() 99 | assert depth % 2 == 0, "UNet-Transformer's depth should be even." 
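        # depth must be even: forward() pushes one skip per layer in the first half and pops exactly
        # one per layer in the second half (see the skip connection logic and the final `assert len(skips) == 0`)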
100 | 101 | self.time_embed = TimestepEmbedding(dim) 102 | if text_dim is None: 103 | text_dim = mel_dim 104 | self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers) 105 | self.input_embed = InputEmbedding(mel_dim, text_dim, dim) 106 | 107 | self.rotary_embed = RotaryEmbedding(dim_head) 108 | 109 | # transformer layers & skip connections 110 | 111 | self.dim = dim 112 | self.skip_connect_type = skip_connect_type 113 | needs_skip_proj = skip_connect_type == 'concat' 114 | 115 | self.depth = depth 116 | self.layers = nn.ModuleList([]) 117 | 118 | for idx in range(depth): 119 | is_later_half = idx >= (depth // 2) 120 | 121 | attn_norm = RMSNorm(dim) 122 | attn = Attention( 123 | processor = AttnProcessor(), 124 | dim = dim, 125 | heads = heads, 126 | dim_head = dim_head, 127 | dropout = dropout, 128 | ) 129 | 130 | ff_norm = RMSNorm(dim) 131 | ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh") 132 | 133 | skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None 134 | 135 | self.layers.append(nn.ModuleList([ 136 | skip_proj, 137 | attn_norm, 138 | attn, 139 | ff_norm, 140 | ff, 141 | ])) 142 | 143 | self.norm_out = RMSNorm(dim) 144 | self.proj_out = nn.Linear(dim, mel_dim) 145 | 146 | def forward( 147 | self, 148 | x: float['b n d'], # nosied input audio 149 | cond: float['b n d'], # masked cond audio 150 | text: int['b nt'], # text 151 | time: float['b'] | float[''], # time step 152 | drop_audio_cond, # cfg for cond audio 153 | drop_text, # cfg for text 154 | mask: bool['b n'] | None = None, 155 | ): 156 | batch, seq_len = x.shape[0], x.shape[1] 157 | if time.ndim == 0: 158 | time = repeat(time, ' -> b', b = batch) 159 | 160 | # t: conditioning time, c: context (text + masked cond audio), x: noised input audio 161 | t = self.time_embed(time) 162 | text_embed = self.text_embed(text, seq_len, drop_text = drop_text) 163 | x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond) 164 | 165 | # postfix time t to input x, [b n d] -> [b n+1 d] 166 | x, ps = pack((t, x), 'b * d') 167 | if mask is not None: 168 | mask = F.pad(mask, (1, 0), value=1) 169 | 170 | rope = self.rotary_embed.forward_from_seq_len(seq_len + 1) 171 | 172 | # flat unet transformer 173 | skip_connect_type = self.skip_connect_type 174 | skips = [] 175 | for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers): 176 | layer = idx + 1 177 | 178 | # skip connection logic 179 | is_first_half = layer <= (self.depth // 2) 180 | is_later_half = not is_first_half 181 | 182 | if is_first_half: 183 | skips.append(x) 184 | 185 | if is_later_half: 186 | skip = skips.pop() 187 | if skip_connect_type == 'concat': 188 | x = torch.cat((x, skip), dim = -1) 189 | x = maybe_skip_proj(x) 190 | elif skip_connect_type == 'add': 191 | x = x + skip 192 | 193 | # attention and feedforward blocks 194 | x = attn(attn_norm(x), rope = rope, mask = mask) + x 195 | x = ff(ff_norm(x)) + x 196 | 197 | assert len(skips) == 0 198 | 199 | _, x = unpack(self.norm_out(x), ps, 'b * d') 200 | 201 | return self.proj_out(x) 202 | -------------------------------------------------------------------------------- /model/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from tqdm import tqdm 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset, Sampler 8 | import torchaudio 9 | from datasets import 
load_dataset, load_from_disk 10 | from datasets import Dataset as Dataset_ 11 | 12 | from einops import rearrange 13 | 14 | from model.modules import MelSpec 15 | 16 | 17 | class HFDataset(Dataset): 18 | def __init__( 19 | self, 20 | hf_dataset: Dataset, 21 | target_sample_rate = 24_000, 22 | n_mel_channels = 100, 23 | hop_length = 256, 24 | ): 25 | self.data = hf_dataset 26 | self.target_sample_rate = target_sample_rate 27 | self.hop_length = hop_length 28 | self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length) 29 | 30 | def get_frame_len(self, index): 31 | row = self.data[index] 32 | audio = row['audio']['array'] 33 | sample_rate = row['audio']['sampling_rate'] 34 | return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length 35 | 36 | def __len__(self): 37 | return len(self.data) 38 | 39 | def __getitem__(self, index): 40 | row = self.data[index] 41 | audio = row['audio']['array'] 42 | 43 | # logger.info(f"Audio shape: {audio.shape}") 44 | 45 | sample_rate = row['audio']['sampling_rate'] 46 | duration = audio.shape[-1] / sample_rate 47 | 48 | if duration > 30 or duration < 0.3: 49 | return self.__getitem__((index + 1) % len(self.data)) 50 | 51 | audio_tensor = torch.from_numpy(audio).float() 52 | 53 | if sample_rate != self.target_sample_rate: 54 | resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) 55 | audio_tensor = resampler(audio_tensor) 56 | 57 | audio_tensor = rearrange(audio_tensor, 't -> 1 t') 58 | 59 | mel_spec = self.mel_spectrogram(audio_tensor) 60 | 61 | mel_spec = rearrange(mel_spec, '1 d t -> d t') 62 | 63 | text = row['text'] 64 | 65 | return dict( 66 | mel_spec = mel_spec, 67 | text = text, 68 | ) 69 | 70 | 71 | class CustomDataset(Dataset): 72 | def __init__( 73 | self, 74 | custom_dataset: Dataset, 75 | durations = None, 76 | target_sample_rate = 24_000, 77 | hop_length = 256, 78 | n_mel_channels = 100, 79 | preprocessed_mel = False, 80 | ): 81 | self.data = custom_dataset 82 | self.durations = durations 83 | self.target_sample_rate = target_sample_rate 84 | self.hop_length = hop_length 85 | self.preprocessed_mel = preprocessed_mel 86 | if not preprocessed_mel: 87 | self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels) 88 | 89 | def get_frame_len(self, index): 90 | if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM 91 | return self.durations[index] * self.target_sample_rate / self.hop_length 92 | return self.data[index]["duration"] * self.target_sample_rate / self.hop_length 93 | 94 | def __len__(self): 95 | return len(self.data) 96 | 97 | def __getitem__(self, index): 98 | row = self.data[index] 99 | audio_path = row["audio_path"] 100 | text = row["text"] 101 | duration = row["duration"] 102 | 103 | if self.preprocessed_mel: 104 | mel_spec = torch.tensor(row["mel_spec"]) 105 | 106 | else: 107 | audio, source_sample_rate = torchaudio.load(audio_path) 108 | 109 | if duration > 30 or duration < 0.3: 110 | return self.__getitem__((index + 1) % len(self.data)) 111 | 112 | if source_sample_rate != self.target_sample_rate: 113 | resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) 114 | audio = resampler(audio) 115 | 116 | mel_spec = self.mel_spectrogram(audio) 117 | mel_spec = rearrange(mel_spec, '1 d t -> d t') 118 | 119 | return dict( 120 | mel_spec = mel_spec, 121 | text = text, 122 | 
) 123 | 124 | 125 | # Dynamic Batch Sampler 126 | 127 | class DynamicBatchSampler(Sampler[list[int]]): 128 | """ Extension of Sampler that will do the following: 129 | 1. Change the batch size (essentially number of sequences) 130 | in a batch to ensure that the total number of frames are less 131 | than a certain threshold. 132 | 2. Make sure the padding efficiency in the batch is high. 133 | """ 134 | 135 | def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False): 136 | self.sampler = sampler 137 | self.frames_threshold = frames_threshold 138 | self.max_samples = max_samples 139 | 140 | indices, batches = [], [] 141 | data_source = self.sampler.data_source 142 | 143 | for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"): 144 | indices.append((idx, data_source.get_frame_len(idx))) 145 | indices.sort(key=lambda elem : elem[1]) 146 | 147 | batch = [] 148 | batch_frames = 0 149 | for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"): 150 | if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples): 151 | batch.append(idx) 152 | batch_frames += frame_len 153 | else: 154 | if len(batch) > 0: 155 | batches.append(batch) 156 | if frame_len <= self.frames_threshold: 157 | batch = [idx] 158 | batch_frames = frame_len 159 | else: 160 | batch = [] 161 | batch_frames = 0 162 | 163 | if not drop_last and len(batch) > 0: 164 | batches.append(batch) 165 | 166 | del indices 167 | 168 | # if want to have different batches between epochs, may just set a seed and log it in ckpt 169 | # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different 170 | # e.g. 
for epoch n, use (random_seed + n) 171 | random.seed(random_seed) 172 | random.shuffle(batches) 173 | 174 | self.batches = batches 175 | 176 | def __iter__(self): 177 | return iter(self.batches) 178 | 179 | def __len__(self): 180 | return len(self.batches) 181 | 182 | 183 | # Load dataset 184 | 185 | def load_dataset( 186 | dataset_name: str, 187 | tokenizer: str, 188 | dataset_type: str = "CustomDataset", 189 | audio_type: str = "raw", 190 | mel_spec_kwargs: dict = dict() 191 | ) -> CustomDataset | HFDataset: 192 | 193 | print("Loading dataset ...") 194 | 195 | if dataset_type == "CustomDataset": 196 | if audio_type == "raw": 197 | try: 198 | train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw") 199 | except: 200 | train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow") 201 | preprocessed_mel = False 202 | elif audio_type == "mel": 203 | train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow") 204 | preprocessed_mel = True 205 | with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f: 206 | data_dict = json.load(f) 207 | durations = data_dict["duration"] 208 | train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs) 209 | 210 | elif dataset_type == "HFDataset": 211 | print("Should manually modify the path of huggingface dataset to your need.\n" + 212 | "May also the corresponding script cuz different dataset may have different format.") 213 | pre, post = dataset_name.split("_") 214 | train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),) 215 | 216 | return train_dataset 217 | 218 | 219 | # collation 220 | 221 | def collate_fn(batch): 222 | mel_specs = [item['mel_spec'].squeeze(0) for item in batch] 223 | mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) 224 | max_mel_length = mel_lengths.amax() 225 | 226 | padded_mel_specs = [] 227 | for spec in mel_specs: # TODO. 
maybe records mask for attention here 228 | padding = (0, max_mel_length - spec.size(-1)) 229 | padded_spec = F.pad(spec, padding, value = 0) 230 | padded_mel_specs.append(padded_spec) 231 | 232 | mel_specs = torch.stack(padded_mel_specs) 233 | 234 | text = [item['text'] for item in batch] 235 | text_lengths = torch.LongTensor([len(item) for item in text]) 236 | 237 | return dict( 238 | mel = mel_specs, 239 | mel_lengths = mel_lengths, 240 | text = text, 241 | text_lengths = text_lengths, 242 | ) 243 | -------------------------------------------------------------------------------- /model/cfm.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | from typing import Callable 12 | from random import random 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from torch.nn.utils.rnn import pad_sequence 18 | 19 | from torchdiffeq import odeint 20 | 21 | from einops import rearrange 22 | 23 | from model.modules import MelSpec 24 | 25 | from model.utils import ( 26 | default, exists, 27 | list_str_to_idx, list_str_to_tensor, 28 | lens_to_mask, mask_from_frac_lengths, 29 | ) 30 | 31 | 32 | class CFM(nn.Module): 33 | def __init__( 34 | self, 35 | transformer: nn.Module, 36 | sigma = 0., 37 | odeint_kwargs: dict = dict( 38 | # atol = 1e-5, 39 | # rtol = 1e-5, 40 | method = 'euler' # 'midpoint' 41 | ), 42 | audio_drop_prob = 0.3, 43 | cond_drop_prob = 0.2, 44 | num_channels = None, 45 | mel_spec_module: nn.Module | None = None, 46 | mel_spec_kwargs: dict = dict(), 47 | frac_lengths_mask: tuple[float, float] = (0.7, 1.), 48 | vocab_char_map: dict[str: int] | None = None 49 | ): 50 | super().__init__() 51 | 52 | self.frac_lengths_mask = frac_lengths_mask 53 | 54 | # mel spec 55 | self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) 56 | num_channels = default(num_channels, self.mel_spec.n_mel_channels) 57 | self.num_channels = num_channels 58 | 59 | # classifier-free guidance 60 | self.audio_drop_prob = audio_drop_prob 61 | self.cond_drop_prob = cond_drop_prob 62 | 63 | # transformer 64 | self.transformer = transformer 65 | dim = transformer.dim 66 | self.dim = dim 67 | 68 | # conditional flow related 69 | self.sigma = sigma 70 | 71 | # sampling related 72 | self.odeint_kwargs = odeint_kwargs 73 | 74 | # vocab map for tokenization 75 | self.vocab_char_map = vocab_char_map 76 | 77 | @property 78 | def device(self): 79 | return next(self.parameters()).device 80 | 81 | @torch.no_grad() 82 | def sample( 83 | self, 84 | cond: float['b n d'] | float['b nw'], 85 | text: int['b nt'] | list[str], 86 | duration: int | int['b'], 87 | *, 88 | lens: int['b'] | None = None, 89 | steps = 32, 90 | cfg_strength = 1., 91 | sway_sampling_coef = None, 92 | seed: int | None = None, 93 | max_duration = 4096, 94 | vocoder: Callable[[float['b d n']], float['b nw']] | None = None, 95 | no_ref_audio = False, 96 | duplicate_test = False, 97 | t_inter = 0.1, 98 | edit_mask = None, 99 | ): 100 | self.eval() 101 | 102 | # raw wave 103 | 104 | if cond.ndim == 2: 105 | cond = self.mel_spec(cond) 106 | cond = rearrange(cond, 'b d n -> b n d') 107 | assert cond.shape[-1] == self.num_channels 108 | 109 | batch, cond_seq_len, device = *cond.shape[:2], cond.device 110 | if not exists(lens): 111 | lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long) 112 
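        # sampling pipeline from here on: tokenize text -> resolve target duration -> pad cond and build
        # step_cond from cond_mask -> integrate the flow ODE (classifier-free guidance, optional sway-sampled
        # time schedule) -> finally overwrite the conditioning frames of the output with the reference cond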
| 113 | # text 114 | 115 | if isinstance(text, list): 116 | if exists(self.vocab_char_map): 117 | text = list_str_to_idx(text, self.vocab_char_map).to(device) 118 | else: 119 | text = list_str_to_tensor(text).to(device) 120 | assert text.shape[0] == batch 121 | 122 | if exists(text): 123 | text_lens = (text != -1).sum(dim = -1) 124 | lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters 125 | 126 | # duration 127 | 128 | cond_mask = lens_to_mask(lens) 129 | if edit_mask is not None: 130 | cond_mask = cond_mask & edit_mask 131 | 132 | if isinstance(duration, int): 133 | duration = torch.full((batch,), duration, device = device, dtype = torch.long) 134 | 135 | duration = torch.maximum(lens + 1, duration) # just add one token so something is generated 136 | duration = duration.clamp(max = max_duration) 137 | max_duration = duration.amax() 138 | 139 | # duplicate test corner for inner time step oberservation 140 | if duplicate_test: 141 | test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.) 142 | 143 | cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.) 144 | cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False) 145 | cond_mask = rearrange(cond_mask, '... -> ... 1') 146 | step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in 147 | 148 | if batch > 1: 149 | mask = lens_to_mask(duration) 150 | else: # save memory and speed up, as single inference need no mask currently 151 | mask = None 152 | 153 | # test for no ref audio 154 | if no_ref_audio: 155 | cond = torch.zeros_like(cond) 156 | 157 | # neural ode 158 | 159 | def fn(t, x): 160 | # at each step, conditioning is fixed 161 | # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) 162 | 163 | # predict flow 164 | pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False) 165 | if cfg_strength < 1e-5: 166 | return pred 167 | 168 | null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True) 169 | return pred + (pred - null_pred) * cfg_strength 170 | 171 | # noise input 172 | # to make sure batch inference result is same with different batch size, and for sure single inference 173 | # still some difference maybe due to convolutional layers 174 | y0 = [] 175 | for dur in duration: 176 | if exists(seed): 177 | torch.manual_seed(seed) 178 | y0.append(torch.randn(dur, self.num_channels, device = self.device)) 179 | y0 = pad_sequence(y0, padding_value = 0, batch_first = True) 180 | 181 | t_start = 0 182 | 183 | # duplicate test corner for inner time step oberservation 184 | if duplicate_test: 185 | t_start = t_inter 186 | y0 = (1 - t_start) * y0 + t_start * test_cond 187 | steps = int(steps * (1 - t_start)) 188 | 189 | t = torch.linspace(t_start, 1, steps, device = self.device) 190 | if sway_sampling_coef is not None: 191 | t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) 192 | 193 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs) 194 | 195 | sampled = trajectory[-1] 196 | out = sampled 197 | out = torch.where(cond_mask, cond, out) 198 | 199 | if exists(vocoder): 200 | out = rearrange(out, 'b n d -> b d n') 201 | out = vocoder(out) 202 | 203 | return out, trajectory 204 | 205 | def forward( 206 | self, 207 | inp: float['b n d'] | float['b nw'], # mel or raw wave 208 | text: 
int['b nt'] | list[str], 209 | *, 210 | lens: int['b'] | None = None, 211 | noise_scheduler: str | None = None, 212 | ): 213 | # handle raw wave 214 | if inp.ndim == 2: 215 | inp = self.mel_spec(inp) 216 | inp = rearrange(inp, 'b d n -> b n d') 217 | assert inp.shape[-1] == self.num_channels 218 | 219 | batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma 220 | 221 | # handle text as string 222 | if isinstance(text, list): 223 | if exists(self.vocab_char_map): 224 | text = list_str_to_idx(text, self.vocab_char_map).to(device) 225 | else: 226 | text = list_str_to_tensor(text).to(device) 227 | assert text.shape[0] == batch 228 | 229 | # lens and mask 230 | if not exists(lens): 231 | lens = torch.full((batch,), seq_len, device = device) 232 | 233 | mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch 234 | 235 | # get a random span to mask out for training conditionally 236 | frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask) 237 | rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) 238 | 239 | if exists(mask): 240 | rand_span_mask &= mask 241 | 242 | # mel is x1 243 | x1 = inp 244 | 245 | # x0 is gaussian noise 246 | x0 = torch.randn_like(x1) 247 | 248 | # time step 249 | time = torch.rand((batch,), dtype = dtype, device = self.device) 250 | # TODO. noise_scheduler 251 | 252 | # sample xt (φ_t(x) in the paper) 253 | t = rearrange(time, 'b -> b 1 1') 254 | φ = (1 - t) * x0 + t * x1 255 | flow = x1 - x0 256 | 257 | # only predict what is within the random mask span for infilling 258 | cond = torch.where( 259 | rand_span_mask[..., None], 260 | torch.zeros_like(x1), x1 261 | ) 262 | 263 | # transformer and cfg training with a drop rate 264 | drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper 265 | if random() < self.cond_drop_prob: # p_uncond in voicebox paper 266 | drop_audio_cond = True 267 | drop_text = True 268 | else: 269 | drop_text = False 270 | 271 | # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here 272 | # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences 273 | pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text) 274 | 275 | # flow matching loss 276 | loss = F.mse_loss(pred, flow, reduction = 'none') 277 | loss = loss[rand_span_mask] 278 | 279 | return loss.mean(), cond, pred 280 | -------------------------------------------------------------------------------- /model/ecapa_tdnn.py: -------------------------------------------------------------------------------- 1 | # just for speaker similarity evaluation, third-party code 2 | 3 | # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ 4 | # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN 5 | 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | ''' Res2Conv1d + BatchNorm1d + ReLU 13 | ''' 14 | 15 | class Res2Conv1dReluBn(nn.Module): 16 | ''' 17 | in_channels == out_channels == channels 18 | ''' 19 | 20 | def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4): 21 | super().__init__() 22 | assert channels % scale == 0, "{} % {} != 0".format(channels, scale) 23 | self.scale = scale 24 | self.width = channels // scale 25 | 
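        # Res2Net-style split: the input is chunked into `scale` groups of width channels // scale;
        # in forward(), the first `nums` groups each get their own conv + BN, with the previous group's
        # output added in beforehand (except for the first group), and when scale > 1 the remaining
        # group is concatenated back unchanged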
self.nums = scale if scale == 1 else scale - 1 26 | 27 | self.convs = [] 28 | self.bns = [] 29 | for i in range(self.nums): 30 | self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias)) 31 | self.bns.append(nn.BatchNorm1d(self.width)) 32 | self.convs = nn.ModuleList(self.convs) 33 | self.bns = nn.ModuleList(self.bns) 34 | 35 | def forward(self, x): 36 | out = [] 37 | spx = torch.split(x, self.width, 1) 38 | for i in range(self.nums): 39 | if i == 0: 40 | sp = spx[i] 41 | else: 42 | sp = sp + spx[i] 43 | # Order: conv -> relu -> bn 44 | sp = self.convs[i](sp) 45 | sp = self.bns[i](F.relu(sp)) 46 | out.append(sp) 47 | if self.scale != 1: 48 | out.append(spx[self.nums]) 49 | out = torch.cat(out, dim=1) 50 | 51 | return out 52 | 53 | 54 | ''' Conv1d + BatchNorm1d + ReLU 55 | ''' 56 | 57 | class Conv1dReluBn(nn.Module): 58 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True): 59 | super().__init__() 60 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) 61 | self.bn = nn.BatchNorm1d(out_channels) 62 | 63 | def forward(self, x): 64 | return self.bn(F.relu(self.conv(x))) 65 | 66 | 67 | ''' The SE connection of 1D case. 68 | ''' 69 | 70 | class SE_Connect(nn.Module): 71 | def __init__(self, channels, se_bottleneck_dim=128): 72 | super().__init__() 73 | self.linear1 = nn.Linear(channels, se_bottleneck_dim) 74 | self.linear2 = nn.Linear(se_bottleneck_dim, channels) 75 | 76 | def forward(self, x): 77 | out = x.mean(dim=2) 78 | out = F.relu(self.linear1(out)) 79 | out = torch.sigmoid(self.linear2(out)) 80 | out = x * out.unsqueeze(2) 81 | 82 | return out 83 | 84 | 85 | ''' SE-Res2Block of the ECAPA-TDNN architecture. 86 | ''' 87 | 88 | # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): 89 | # return nn.Sequential( 90 | # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0), 91 | # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale), 92 | # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0), 93 | # SE_Connect(channels) 94 | # ) 95 | 96 | class SE_Res2Block(nn.Module): 97 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim): 98 | super().__init__() 99 | self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 100 | self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale) 101 | self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0) 102 | self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) 103 | 104 | self.shortcut = None 105 | if in_channels != out_channels: 106 | self.shortcut = nn.Conv1d( 107 | in_channels=in_channels, 108 | out_channels=out_channels, 109 | kernel_size=1, 110 | ) 111 | 112 | def forward(self, x): 113 | residual = x 114 | if self.shortcut: 115 | residual = self.shortcut(x) 116 | 117 | x = self.Conv1dReluBn1(x) 118 | x = self.Res2Conv1dReluBn(x) 119 | x = self.Conv1dReluBn2(x) 120 | x = self.SE_Connect(x) 121 | 122 | return x + residual 123 | 124 | 125 | ''' Attentive weighted mean and standard deviation pooling. 
126 | ''' 127 | 128 | class AttentiveStatsPool(nn.Module): 129 | def __init__(self, in_dim, attention_channels=128, global_context_att=False): 130 | super().__init__() 131 | self.global_context_att = global_context_att 132 | 133 | # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. 134 | if global_context_att: 135 | self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper 136 | else: 137 | self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper 138 | self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper 139 | 140 | def forward(self, x): 141 | 142 | if self.global_context_att: 143 | context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) 144 | context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) 145 | x_in = torch.cat((x, context_mean, context_std), dim=1) 146 | else: 147 | x_in = x 148 | 149 | # DON'T use ReLU here! In experiments, I find ReLU hard to converge. 150 | alpha = torch.tanh(self.linear1(x_in)) 151 | # alpha = F.relu(self.linear1(x_in)) 152 | alpha = torch.softmax(self.linear2(alpha), dim=2) 153 | mean = torch.sum(alpha * x, dim=2) 154 | residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2 155 | std = torch.sqrt(residuals.clamp(min=1e-9)) 156 | return torch.cat([mean, std], dim=1) 157 | 158 | 159 | class ECAPA_TDNN(nn.Module): 160 | def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False, 161 | feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None): 162 | super().__init__() 163 | 164 | self.feat_type = feat_type 165 | self.feature_selection = feature_selection 166 | self.update_extract = update_extract 167 | self.sr = sr 168 | 169 | torch.hub._validate_not_a_forked_repo=lambda a,b,c: True 170 | try: 171 | local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main") 172 | self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path) 173 | except: 174 | self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type) 175 | 176 | if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"): 177 | self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False 178 | if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"): 179 | self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False 180 | 181 | self.feat_num = self.get_feat_num() 182 | self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) 183 | 184 | if feat_type != 'fbank' and feat_type != 'mfcc': 185 | freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer'] 186 | for name, param in self.feature_extract.named_parameters(): 187 | for freeze_val in freeze_list: 188 | if freeze_val in name: 189 | param.requires_grad = False 190 | break 191 | 192 | if not self.update_extract: 193 | for param in self.feature_extract.parameters(): 194 | param.requires_grad = False 195 | 196 | self.instance_norm = nn.InstanceNorm1d(feat_dim) 197 | # self.channels = [channels] * 4 + [channels * 3] 198 | self.channels = [channels] * 4 + [1536] 199 | 200 | self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2) 201 | self.layer2 = 
SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128) 202 | self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128) 203 | self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128) 204 | 205 | # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) 206 | cat_channels = channels * 3 207 | self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) 208 | self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att) 209 | self.bn = nn.BatchNorm1d(self.channels[-1] * 2) 210 | self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) 211 | 212 | 213 | def get_feat_num(self): 214 | self.feature_extract.eval() 215 | wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)] 216 | with torch.no_grad(): 217 | features = self.feature_extract(wav) 218 | select_feature = features[self.feature_selection] 219 | if isinstance(select_feature, (list, tuple)): 220 | return len(select_feature) 221 | else: 222 | return 1 223 | 224 | def get_feat(self, x): 225 | if self.update_extract: 226 | x = self.feature_extract([sample for sample in x]) 227 | else: 228 | with torch.no_grad(): 229 | if self.feat_type == 'fbank' or self.feat_type == 'mfcc': 230 | x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len 231 | else: 232 | x = self.feature_extract([sample for sample in x]) 233 | 234 | if self.feat_type == 'fbank': 235 | x = x.log() 236 | 237 | if self.feat_type != "fbank" and self.feat_type != "mfcc": 238 | x = x[self.feature_selection] 239 | if isinstance(x, (list, tuple)): 240 | x = torch.stack(x, dim=0) 241 | else: 242 | x = x.unsqueeze(0) 243 | norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 244 | x = (norm_weights * x).sum(dim=0) 245 | x = torch.transpose(x, 1, 2) + 1e-6 246 | 247 | x = self.instance_norm(x) 248 | return x 249 | 250 | def forward(self, x): 251 | x = self.get_feat(x) 252 | 253 | out1 = self.layer1(x) 254 | out2 = self.layer2(out1) 255 | out3 = self.layer3(out2) 256 | out4 = self.layer4(out3) 257 | 258 | out = torch.cat([out2, out3, out4], dim=1) 259 | out = F.relu(self.conv(out)) 260 | out = self.bn(self.pooling(out)) 261 | out = self.linear(out) 262 | 263 | return out 264 | 265 | 266 | def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None): 267 | return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim, 268 | feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path) 269 | -------------------------------------------------------------------------------- /model/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import gc 5 | from tqdm import tqdm 6 | import wandb 7 | 8 | import torch 9 | from torch.optim import AdamW 10 | from torch.utils.data import DataLoader, Dataset, SequentialSampler 11 | from torch.optim.lr_scheduler import LinearLR, SequentialLR 12 | 13 | from einops import rearrange 14 | 15 | from accelerate import Accelerator 16 | from accelerate.utils import DistributedDataParallelKwargs 17 | 18 | from ema_pytorch 
import EMA 19 | 20 | from model import CFM 21 | from model.utils import exists, default 22 | from model.dataset import DynamicBatchSampler, collate_fn 23 | 24 | 25 | # trainer 26 | 27 | class Trainer: 28 | def __init__( 29 | self, 30 | model: CFM, 31 | epochs, 32 | learning_rate, 33 | num_warmup_updates = 20000, 34 | save_per_updates = 1000, 35 | checkpoint_path = None, 36 | batch_size = 32, 37 | batch_size_type: str = "sample", 38 | max_samples = 32, 39 | grad_accumulation_steps = 1, 40 | max_grad_norm = 1.0, 41 | noise_scheduler: str | None = None, 42 | duration_predictor: torch.nn.Module | None = None, 43 | wandb_project = "test_e2-tts", 44 | wandb_run_name = "test_run", 45 | wandb_resume_id: str = None, 46 | last_per_steps = None, 47 | accelerate_kwargs: dict = dict(), 48 | ema_kwargs: dict = dict() 49 | ): 50 | 51 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) 52 | 53 | self.accelerator = Accelerator( 54 | log_with = "wandb", 55 | kwargs_handlers = [ddp_kwargs], 56 | gradient_accumulation_steps = grad_accumulation_steps, 57 | **accelerate_kwargs 58 | ) 59 | 60 | if exists(wandb_resume_id): 61 | init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}} 62 | else: 63 | init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}} 64 | self.accelerator.init_trackers( 65 | project_name = wandb_project, 66 | init_kwargs=init_kwargs, 67 | config={"epochs": epochs, 68 | "learning_rate": learning_rate, 69 | "num_warmup_updates": num_warmup_updates, 70 | "batch_size": batch_size, 71 | "batch_size_type": batch_size_type, 72 | "max_samples": max_samples, 73 | "grad_accumulation_steps": grad_accumulation_steps, 74 | "max_grad_norm": max_grad_norm, 75 | "gpus": self.accelerator.num_processes, 76 | "noise_scheduler": noise_scheduler} 77 | ) 78 | 79 | self.model = model 80 | 81 | if self.is_main: 82 | self.ema_model = EMA( 83 | model, 84 | include_online_model = False, 85 | **ema_kwargs 86 | ) 87 | 88 | self.ema_model.to(self.accelerator.device) 89 | 90 | self.epochs = epochs 91 | self.num_warmup_updates = num_warmup_updates 92 | self.save_per_updates = save_per_updates 93 | self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps) 94 | self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts') 95 | 96 | self.batch_size = batch_size 97 | self.batch_size_type = batch_size_type 98 | self.max_samples = max_samples 99 | self.grad_accumulation_steps = grad_accumulation_steps 100 | self.max_grad_norm = max_grad_norm 101 | 102 | self.noise_scheduler = noise_scheduler 103 | 104 | self.duration_predictor = duration_predictor 105 | 106 | self.optimizer = AdamW(model.parameters(), lr=learning_rate) 107 | self.model, self.optimizer = self.accelerator.prepare( 108 | self.model, self.optimizer 109 | ) 110 | 111 | @property 112 | def is_main(self): 113 | return self.accelerator.is_main_process 114 | 115 | def save_checkpoint(self, step, last=False): 116 | self.accelerator.wait_for_everyone() 117 | if self.is_main: 118 | checkpoint = dict( 119 | model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(), 120 | optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(), 121 | ema_model_state_dict = self.ema_model.state_dict(), 122 | scheduler_state_dict = self.scheduler.state_dict(), 123 | step = step 124 | ) 125 | if not os.path.exists(self.checkpoint_path): 126 | os.makedirs(self.checkpoint_path) 127 | if last == True: 128 | self.accelerator.save(checkpoint, 
f"{self.checkpoint_path}/model_last.pt") 129 | print(f"Saved last checkpoint at step {step}") 130 | else: 131 | self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") 132 | 133 | def load_checkpoint(self): 134 | if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path): 135 | return 0 136 | 137 | self.accelerator.wait_for_everyone() 138 | if "model_last.pt" in os.listdir(self.checkpoint_path): 139 | latest_checkpoint = "model_last.pt" 140 | else: 141 | latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1] 142 | # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ 143 | checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu") 144 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict']) 145 | self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict']) 146 | 147 | if self.is_main: 148 | self.ema_model.load_state_dict(checkpoint['ema_model_state_dict']) 149 | 150 | if self.scheduler: 151 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 152 | 153 | step = checkpoint['step'] 154 | del checkpoint; gc.collect() 155 | return step 156 | 157 | def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None): 158 | 159 | if exists(resumable_with_seed): 160 | generator = torch.Generator() 161 | generator.manual_seed(resumable_with_seed) 162 | else: 163 | generator = None 164 | 165 | if self.batch_size_type == "sample": 166 | train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, 167 | batch_size=self.batch_size, shuffle=True, generator=generator) 168 | elif self.batch_size_type == "frame": 169 | self.accelerator.even_batches = False 170 | sampler = SequentialSampler(train_dataset) 171 | batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False) 172 | train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, 173 | batch_sampler=batch_sampler) 174 | else: 175 | raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but recieved {self.batch_size_type}") 176 | 177 | # accelerator.prepare() dispatches batches to devices; 178 | # which means the length of dataloader calculated before, should consider the number of devices 179 | warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp 180 | # otherwise by default with split_batches=False, warmup steps change with num_processes 181 | total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps 182 | decay_steps = total_steps - warmup_steps 183 | warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) 184 | decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) 185 | self.scheduler = SequentialLR(self.optimizer, 186 | schedulers=[warmup_scheduler, decay_scheduler], 187 | milestones=[warmup_steps]) 188 | train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus 189 | start_step = self.load_checkpoint() 
190 | global_step = start_step 191 | 192 | if exists(resumable_with_seed): 193 | orig_epoch_step = len(train_dataloader) 194 | skipped_epoch = int(start_step // orig_epoch_step) 195 | skipped_batch = start_step % orig_epoch_step 196 | skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch) 197 | else: 198 | skipped_epoch = 0 199 | 200 | for epoch in range(skipped_epoch, self.epochs): 201 | self.model.train() 202 | if exists(resumable_with_seed) and epoch == skipped_epoch: 203 | progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, 204 | initial=skipped_batch, total=orig_epoch_step) 205 | else: 206 | progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process) 207 | 208 | for batch in progress_bar: 209 | with self.accelerator.accumulate(self.model): 210 | text_inputs = batch['text'] 211 | mel_spec = rearrange(batch['mel'], 'b d n -> b n d') 212 | mel_lengths = batch["mel_lengths"] 213 | 214 | # TODO. add duration predictor training 215 | if self.duration_predictor is not None and self.accelerator.is_local_main_process: 216 | dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations')) 217 | self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step) 218 | 219 | loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler) 220 | self.accelerator.backward(loss) 221 | 222 | if self.max_grad_norm > 0 and self.accelerator.sync_gradients: 223 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 224 | 225 | self.optimizer.step() 226 | self.scheduler.step() 227 | self.optimizer.zero_grad() 228 | 229 | if self.is_main: 230 | self.ema_model.update() 231 | 232 | global_step += 1 233 | 234 | if self.accelerator.is_local_main_process: 235 | self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step) 236 | 237 | progress_bar.set_postfix(step=str(global_step), loss=loss.item()) 238 | 239 | if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0: 240 | self.save_checkpoint(global_step) 241 | 242 | if global_step % self.last_per_steps == 0: 243 | self.save_checkpoint(global_step, last=True) 244 | 245 | self.accelerator.end_training() 246 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import re 5 | import math 6 | import random 7 | import string 8 | from tqdm import tqdm 9 | from collections import defaultdict 10 | 11 | import matplotlib 12 | matplotlib.use("Agg") 13 | import matplotlib.pylab as plt 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from torch.nn.utils.rnn import pad_sequence 18 | import torchaudio 19 | 20 | import einx 21 | from einops import rearrange, reduce 22 | 23 | import jieba 24 | from pypinyin import lazy_pinyin, Style 25 | import zhconv 26 | from zhon.hanzi import punctuation 27 | from jiwer import compute_measures 28 | 29 | from funasr import AutoModel 30 | from faster_whisper import WhisperModel 31 | 32 | from model.ecapa_tdnn import ECAPA_TDNN_SMALL 33 | from model.modules import MelSpec 34 | 35 | 36 | # seed everything 37 | 38 | def seed_everything(seed = 0): 39 | random.seed(seed) 40 | 
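    # seed python's RNG and hash seed, torch CPU & CUDA RNGs, and force deterministic (non-benchmark)
    # cuDNN so that runs with the same seed are repeatable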
os.environ['PYTHONHASHSEED'] = str(seed) 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | 47 | # helpers 48 | 49 | def exists(v): 50 | return v is not None 51 | 52 | def default(v, d): 53 | return v if exists(v) else d 54 | 55 | # tensor helpers 56 | 57 | def lens_to_mask( 58 | t: int['b'], 59 | length: int | None = None 60 | ) -> bool['b n']: 61 | 62 | if not exists(length): 63 | length = t.amax() 64 | 65 | seq = torch.arange(length, device = t.device) 66 | return einx.less('n, b -> b n', seq, t) 67 | 68 | def mask_from_start_end_indices( 69 | seq_len: int['b'], 70 | start: int['b'], 71 | end: int['b'] 72 | ): 73 | max_seq_len = seq_len.max().item() 74 | seq = torch.arange(max_seq_len, device = start.device).long() 75 | return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end) 76 | 77 | def mask_from_frac_lengths( 78 | seq_len: int['b'], 79 | frac_lengths: float['b'] 80 | ): 81 | lengths = (frac_lengths * seq_len).long() 82 | max_start = seq_len - lengths 83 | 84 | rand = torch.rand_like(frac_lengths) 85 | start = (max_start * rand).long().clamp(min = 0) 86 | end = start + lengths 87 | 88 | return mask_from_start_end_indices(seq_len, start, end) 89 | 90 | def maybe_masked_mean( 91 | t: float['b n d'], 92 | mask: bool['b n'] = None 93 | ) -> float['b d']: 94 | 95 | if not exists(mask): 96 | return t.mean(dim = 1) 97 | 98 | t = einx.where('b n, b n d, -> b n d', mask, t, 0.) 99 | num = reduce(t, 'b n d -> b d', 'sum') 100 | den = reduce(mask.float(), 'b n -> b', 'sum') 101 | 102 | return einx.divide('b d, b -> b d', num, den.clamp(min = 1.)) 103 | 104 | 105 | # simple utf-8 tokenizer, since paper went character based 106 | def list_str_to_tensor( 107 | text: list[str], 108 | padding_value = -1 109 | ) -> int['b nt']: 110 | list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style 111 | text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True) 112 | return text 113 | 114 | # char tokenizer, based on custom dataset's extracted .txt file 115 | def list_str_to_idx( 116 | text: list[str] | list[list[str]], 117 | vocab_char_map: dict[str, int], # {char: idx} 118 | padding_value = -1 119 | ) -> int['b nt']: 120 | list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style 121 | text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True) 122 | return text 123 | 124 | 125 | # Get tokenizer 126 | 127 | def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): 128 | ''' 129 | tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file 130 | - "char" for char-wise tokenizer, need .txt vocab_file 131 | - "byte" for utf-8 tokenizer 132 | vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols 133 | - if use "char", derived from unfiltered character & symbol counts of custom dataset 134 | - if use "byte", set to 256 (unicode byte range) 135 | ''' 136 | if tokenizer in ["pinyin", "char"]: 137 | with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r") as f: 138 | vocab_char_map = {} 139 | for i, char in enumerate(f): 140 | vocab_char_map[char[:-1]] = i 141 | vocab_size = len(vocab_char_map) 142 | assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" 143 | 144 | elif 
tokenizer == "byte": 145 | vocab_char_map = None 146 | vocab_size = 256 147 | 148 | return vocab_char_map, vocab_size 149 | 150 | 151 | # convert char to pinyin 152 | 153 | def convert_char_to_pinyin(text_list, polyphone = True): 154 | final_text_list = [] 155 | god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean 156 | custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov 157 | for text in text_list: 158 | char_list = [] 159 | text = text.translate(god_knows_why_en_testset_contains_zh_quote) 160 | text = text.translate(custom_trans) 161 | for seg in jieba.cut(text): 162 | seg_byte_len = len(bytes(seg, 'UTF-8')) 163 | if seg_byte_len == len(seg): # if pure alphabets and symbols 164 | if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": 165 | char_list.append(" ") 166 | char_list.extend(seg) 167 | elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters 168 | seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) 169 | for c in seg: 170 | if c not in "。,、;:?!《》【】—…": 171 | char_list.append(" ") 172 | char_list.append(c) 173 | else: # if mixed chinese characters, alphabets and symbols 174 | for c in seg: 175 | if ord(c) < 256: 176 | char_list.extend(c) 177 | else: 178 | if c not in "。,、;:?!《》【】—…": 179 | char_list.append(" ") 180 | char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) 181 | else: # if is zh punc 182 | char_list.append(c) 183 | final_text_list.append(char_list) 184 | 185 | return final_text_list 186 | 187 | 188 | # save spectrogram 189 | def save_spectrogram(spectrogram, path): 190 | plt.figure(figsize=(12, 4)) 191 | plt.imshow(spectrogram, origin='lower', aspect='auto') 192 | plt.colorbar() 193 | plt.savefig(path) 194 | plt.close() 195 | 196 | 197 | # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav 198 | def get_seedtts_testset_metainfo(metalst): 199 | f = open(metalst); lines = f.readlines(); f.close() 200 | metainfo = [] 201 | for line in lines: 202 | if len(line.strip().split('|')) == 5: 203 | utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|') 204 | elif len(line.strip().split('|')) == 4: 205 | utt, prompt_text, prompt_wav, gt_text = line.strip().split('|') 206 | gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") 207 | if not os.path.isabs(prompt_wav): 208 | prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) 209 | metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) 210 | return metainfo 211 | 212 | 213 | # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav 214 | def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path): 215 | f = open(metalst); lines = f.readlines(); f.close() 216 | metainfo = [] 217 | for line in lines: 218 | ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t') 219 | 220 | # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) 221 | ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-') 222 | ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac') 223 | 224 | # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' 
# if use librispeech test-clean (no-pc) 225 | gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-') 226 | gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac') 227 | 228 | metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav)) 229 | 230 | return metainfo 231 | 232 | 233 | # padded to max length mel batch 234 | def padded_mel_batch(ref_mels): 235 | max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() 236 | padded_ref_mels = [] 237 | for mel in ref_mels: 238 | padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0) 239 | padded_ref_mels.append(padded_ref_mel) 240 | padded_ref_mels = torch.stack(padded_ref_mels) 241 | padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d') 242 | return padded_ref_mels 243 | 244 | 245 | # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav 246 | 247 | def get_inference_prompt( 248 | metainfo, 249 | speed = 1., tokenizer = "pinyin", polyphone = True, 250 | target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1, 251 | use_truth_duration = False, 252 | infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40, 253 | ): 254 | prompts_all = [] 255 | 256 | min_tokens = min_secs * target_sample_rate // hop_length 257 | max_tokens = max_secs * target_sample_rate // hop_length 258 | 259 | batch_accum = [0] * num_buckets 260 | utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \ 261 | ([[] for _ in range(num_buckets)] for _ in range(6)) 262 | 263 | mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length) 264 | 265 | for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."): 266 | 267 | # Audio 268 | ref_audio, ref_sr = torchaudio.load(prompt_wav) 269 | ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) 270 | if ref_rms < target_rms: 271 | ref_audio = ref_audio * target_rms / ref_rms 272 | assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." 273 | if ref_sr != target_sample_rate: 274 | resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) 275 | ref_audio = resampler(ref_audio) 276 | 277 | # Text 278 | text = [prompt_text + gt_text] 279 | if tokenizer == "pinyin": 280 | text_list = convert_char_to_pinyin(text, polyphone = polyphone) 281 | else: 282 | text_list = text 283 | 284 | # Duration, mel frame length 285 | ref_mel_len = ref_audio.shape[-1] // hop_length 286 | if use_truth_duration: 287 | gt_audio, gt_sr = torchaudio.load(gt_wav) 288 | if gt_sr != target_sample_rate: 289 | resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) 290 | gt_audio = resampler(gt_audio) 291 | total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) 292 | 293 | # # test vocoder resynthesis 294 | # ref_audio = gt_audio 295 | else: 296 | zh_pause_punc = r"。,、;:?!" 297 | ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text)) 298 | gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text)) 299 | total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed) 300 | 301 | # to mel spectrogram 302 | ref_mel = mel_spectrogram(ref_audio) 303 | ref_mel = rearrange(ref_mel, '1 d n -> d n') 304 | 305 | # deal with batch 306 | assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 
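# bucketing: total_mel_len is mapped linearly onto [0, num_buckets), so prompts with similar target duration share a bucket; each bucket accumulates mel frames until it reaches infer_batch_size, then is flushed as one padded batch via padded_mel_batch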
307 | assert min_tokens <= total_mel_len <= max_tokens, \ 308 | f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." 309 | bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets) 310 | 311 | utts[bucket_i].append(utt) 312 | ref_rms_list[bucket_i].append(ref_rms) 313 | ref_mels[bucket_i].append(ref_mel) 314 | ref_mel_lens[bucket_i].append(ref_mel_len) 315 | total_mel_lens[bucket_i].append(total_mel_len) 316 | final_text_list[bucket_i].extend(text_list) 317 | 318 | batch_accum[bucket_i] += total_mel_len 319 | 320 | if batch_accum[bucket_i] >= infer_batch_size: 321 | # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}") 322 | prompts_all.append(( 323 | utts[bucket_i], 324 | ref_rms_list[bucket_i], 325 | padded_mel_batch(ref_mels[bucket_i]), 326 | ref_mel_lens[bucket_i], 327 | total_mel_lens[bucket_i], 328 | final_text_list[bucket_i] 329 | )) 330 | batch_accum[bucket_i] = 0 331 | utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], [] 332 | 333 | # add residual 334 | for bucket_i, bucket_frames in enumerate(batch_accum): 335 | if bucket_frames > 0: 336 | prompts_all.append(( 337 | utts[bucket_i], 338 | ref_rms_list[bucket_i], 339 | padded_mel_batch(ref_mels[bucket_i]), 340 | ref_mel_lens[bucket_i], 341 | total_mel_lens[bucket_i], 342 | final_text_list[bucket_i] 343 | )) 344 | # not only leave easy work for last workers 345 | random.seed(666) 346 | random.shuffle(prompts_all) 347 | 348 | return prompts_all 349 | 350 | 351 | # get wav_res_ref_text of seed-tts test metalst 352 | # https://github.com/BytedanceSpeech/seed-tts-eval 353 | 354 | def get_seed_tts_test(metalst, gen_wav_dir, gpus): 355 | f = open(metalst) 356 | lines = f.readlines() 357 | f.close() 358 | 359 | test_set_ = [] 360 | for line in tqdm(lines): 361 | if len(line.strip().split('|')) == 5: 362 | utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|') 363 | elif len(line.strip().split('|')) == 4: 364 | utt, prompt_text, prompt_wav, gt_text = line.strip().split('|') 365 | 366 | if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')): 367 | continue 368 | gen_wav = os.path.join(gen_wav_dir, utt + '.wav') 369 | if not os.path.isabs(prompt_wav): 370 | prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) 371 | 372 | test_set_.append((gen_wav, prompt_wav, gt_text)) 373 | 374 | num_jobs = len(gpus) 375 | if num_jobs == 1: 376 | return [(gpus[0], test_set_)] 377 | 378 | wav_per_job = len(test_set_) // num_jobs + 1 379 | test_set = [] 380 | for i in range(num_jobs): 381 | test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job])) 382 | 383 | return test_set 384 | 385 | 386 | # get librispeech test-clean cross sentence test 387 | 388 | def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False): 389 | f = open(metalst) 390 | lines = f.readlines() 391 | f.close() 392 | 393 | test_set_ = [] 394 | for line in tqdm(lines): 395 | ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t') 396 | 397 | if eval_ground_truth: 398 | gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-') 399 | gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac') 400 | else: 401 | if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')): 402 | raise 
FileNotFoundError(f"Generated wav not found: {gen_utt}") 403 | gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav') 404 | 405 | ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-') 406 | ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac') 407 | 408 | test_set_.append((gen_wav, ref_wav, gen_txt)) 409 | 410 | num_jobs = len(gpus) 411 | if num_jobs == 1: 412 | return [(gpus[0], test_set_)] 413 | 414 | wav_per_job = len(test_set_) // num_jobs + 1 415 | test_set = [] 416 | for i in range(num_jobs): 417 | test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job])) 418 | 419 | return test_set 420 | 421 | 422 | # load asr model 423 | 424 | def load_asr_model(lang, ckpt_dir = ""): 425 | if lang == "zh": 426 | model = AutoModel( 427 | model = os.path.join(ckpt_dir, "paraformer-zh"), 428 | # vad_model = os.path.join(ckpt_dir, "fsmn-vad"), 429 | # punc_model = os.path.join(ckpt_dir, "ct-punc"), 430 | # spk_model = os.path.join(ckpt_dir, "cam++"), 431 | disable_update=True, 432 | ) # following seed-tts setting 433 | elif lang == "en": 434 | model_size = "large-v3" if ckpt_dir == "" else ckpt_dir 435 | model = WhisperModel(model_size, device="cuda", compute_type="float16") 436 | return model 437 | 438 | 439 | # WER Evaluation, the way Seed-TTS does 440 | 441 | def run_asr_wer(args): 442 | rank, lang, test_set, ckpt_dir = args 443 | 444 | if lang == "zh": 445 | torch.cuda.set_device(rank) 446 | elif lang == "en": 447 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) 448 | else: 449 | raise NotImplementedError("only lang 'zh' (funasr paraformer-zh) and 'en' (faster-whisper-large-v3) are supported for now.") 450 | 451 | asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir) 452 | 453 | punctuation_all = punctuation + string.punctuation 454 | wers = [] 455 | 456 | for gen_wav, prompt_wav, truth in tqdm(test_set): 457 | if lang == "zh": 458 | res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True) 459 | hypo = res[0]["text"] 460 | hypo = zhconv.convert(hypo, 'zh-cn') 461 | elif lang == "en": 462 | segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en") 463 | hypo = '' 464 | for segment in segments: 465 | hypo = hypo + ' ' + segment.text 466 | 467 | # raw_truth = truth 468 | # raw_hypo = hypo 469 | 470 | for x in punctuation_all: 471 | truth = truth.replace(x, '') 472 | hypo = hypo.replace(x, '') 473 | 474 | truth = truth.replace('  ', ' ') 475 | hypo = hypo.replace('  ', ' ') 476 | 477 | if lang == "zh": 478 | truth = " ".join([x for x in truth]) 479 | hypo = " ".join([x for x in hypo]) 480 | elif lang == "en": 481 | truth = truth.lower() 482 | hypo = hypo.lower() 483 | 484 | measures = compute_measures(truth, hypo) 485 | wer = measures["wer"] 486 | 487 | # ref_list = truth.split(" ") 488 | # subs = measures["substitutions"] / len(ref_list) 489 | # dele = measures["deletions"] / len(ref_list) 490 | # inse = measures["insertions"] / len(ref_list) 491 | 492 | wers.append(wer) 493 | 494 | return wers 495 | 496 | 497 | # SIM Evaluation 498 | 499 | def run_sim(args): 500 | rank, test_set, ckpt_dir = args 501 | device = f"cuda:{rank}" 502 | 503 | model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None) 504 | state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage) 505 | model.load_state_dict(state_dict['model'], strict=False) 506 | 507 | use_gpu = torch.cuda.is_available() 508 | if use_gpu: 509 | model = model.cuda(device) 510 | model.eval() 511 | 512 | sim_list = [] 513 |
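# SIM: resample both wavs to 16 kHz, embed them with the WavLM-large ECAPA-TDNN verification model loaded above, and score speaker similarity as the cosine similarity of the two embeddings (range -1.0 to 1.0)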
for wav1, wav2, truth in tqdm(test_set): 514 | 515 | wav1, sr1 = torchaudio.load(wav1) 516 | wav2, sr2 = torchaudio.load(wav2) 517 | 518 | resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000) 519 | resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000) 520 | wav1 = resample1(wav1) 521 | wav2 = resample2(wav2) 522 | 523 | if use_gpu: 524 | wav1 = wav1.cuda(device) 525 | wav2 = wav2.cuda(device) 526 | with torch.no_grad(): 527 | emb1 = model(wav1) 528 | emb2 = model(wav2) 529 | 530 | sim = F.cosine_similarity(emb1, emb2)[0].item() 531 | # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).") 532 | sim_list.append(sim) 533 | 534 | return sim_list 535 | 536 | 537 | # filter func for dirty data with many repetitions 538 | 539 | def repetition_found(text, length = 2, tolerance = 10): 540 | pattern_count = defaultdict(int) 541 | for i in range(len(text) - length + 1): 542 | pattern = text[i:i + length] 543 | pattern_count[pattern] += 1 544 | for pattern, count in pattern_count.items(): 545 | if count > tolerance: 546 | return True 547 | return False 548 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | from typing import Optional 12 | import math 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | import torchaudio 18 | 19 | from einops import rearrange 20 | from x_transformers.x_transformers import apply_rotary_pos_emb 21 | 22 | 23 | # raw wav to mel spec 24 | 25 | class MelSpec(nn.Module): 26 | def __init__( 27 | self, 28 | filter_length = 1024, 29 | hop_length = 256, 30 | win_length = 1024, 31 | n_mel_channels = 100, 32 | target_sample_rate = 24_000, 33 | normalize = False, 34 | power = 1, 35 | norm = None, 36 | center = True, 37 | ): 38 | super().__init__() 39 | self.n_mel_channels = n_mel_channels 40 | 41 | self.mel_stft = torchaudio.transforms.MelSpectrogram( 42 | sample_rate = target_sample_rate, 43 | n_fft = filter_length, 44 | win_length = win_length, 45 | hop_length = hop_length, 46 | n_mels = n_mel_channels, 47 | power = power, 48 | center = center, 49 | normalized = normalize, 50 | norm = norm, 51 | ) 52 | 53 | self.register_buffer('dummy', torch.tensor(0), persistent = False) 54 | 55 | def forward(self, inp): 56 | if len(inp.shape) == 3: 57 | inp = rearrange(inp, 'b 1 nw -> b nw') 58 | 59 | assert len(inp.shape) == 2 60 | 61 | if self.dummy.device != inp.device: 62 | self.to(inp.device) 63 | 64 | mel = self.mel_stft(inp) 65 | mel = mel.clamp(min = 1e-5).log() 66 | return mel 67 | 68 | 69 | # sinusoidal position embedding 70 | 71 | class SinusPositionEmbedding(nn.Module): 72 | def __init__(self, dim): 73 | super().__init__() 74 | self.dim = dim 75 | 76 | def forward(self, x, scale=1000): 77 | device = x.device 78 | half_dim = self.dim // 2 79 | emb = math.log(10000) / (half_dim - 1) 80 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 81 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) 82 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 83 | return emb 84 | 85 | 86 | # convolutional position embedding 87 | 88 | class ConvPositionEmbedding(nn.Module): 89 | def __init__(self, dim, kernel_size = 31, groups = 16): 90 | super().__init__() 91 | assert kernel_size % 2 != 0 
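# an odd kernel_size keeps the sequence length unchanged with padding = kernel_size // 2; the two grouped Conv1d + Mish layers below inject a local, convolution-based positional bias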
92 | self.conv1d = nn.Sequential( 93 | nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2), 94 | nn.Mish(), 95 | nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2), 96 | nn.Mish(), 97 | ) 98 | 99 | def forward(self, x: float['b n d'], mask: bool['b n'] | None = None): 100 | if mask is not None: 101 | mask = mask[..., None] 102 | x = x.masked_fill(~mask, 0.) 103 | 104 | x = rearrange(x, 'b n d -> b d n') 105 | x = self.conv1d(x) 106 | out = rearrange(x, 'b d n -> b n d') 107 | 108 | if mask is not None: 109 | out = out.masked_fill(~mask, 0.) 110 | 111 | return out 112 | 113 | 114 | # rotary positional embedding related 115 | 116 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.): 117 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 118 | # has some connection to NTK literature 119 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 120 | # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py 121 | theta *= theta_rescale_factor ** (dim / (dim - 2)) 122 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 123 | t = torch.arange(end, device=freqs.device) # type: ignore 124 | freqs = torch.outer(t, freqs).float() # type: ignore 125 | freqs_cos = torch.cos(freqs) # real part 126 | freqs_sin = torch.sin(freqs) # imaginary part 127 | return torch.cat([freqs_cos, freqs_sin], dim=-1) 128 | 129 | def get_pos_embed_indices(start, length, max_pos, scale=1.): 130 | # length = length if isinstance(length, int) else length.max() 131 | scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar 132 | pos = start.unsqueeze(1) + ( 133 | torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * 134 | scale.unsqueeze(1)).long() 135 | # avoid extra long error. 136 | pos = torch.where(pos < max_pos, pos, max_pos - 1) 137 | return pos 138 | 139 | 140 | # Global Response Normalization layer (Instance Normalization ?) 
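# GRN (ConvNeXt-V2): Gx = L2 norm of x over the sequence dim (per channel), Nx = Gx / mean(Gx) over channels, output = gamma * (x * Nx) + beta + x, i.e. a learned globally-normalized channel gating with a residual path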
141 | 142 | class GRN(nn.Module): 143 | def __init__(self, dim): 144 | super().__init__() 145 | self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) 146 | self.beta = nn.Parameter(torch.zeros(1, 1, dim)) 147 | 148 | def forward(self, x): 149 | Gx = torch.norm(x, p=2, dim=1, keepdim=True) 150 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 151 | return self.gamma * (x * Nx) + self.beta + x 152 | 153 | 154 | # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py 155 | # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 156 | 157 | class ConvNeXtV2Block(nn.Module): 158 | def __init__( 159 | self, 160 | dim: int, 161 | intermediate_dim: int, 162 | dilation: int = 1, 163 | ): 164 | super().__init__() 165 | padding = (dilation * (7 - 1)) // 2 166 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv 167 | self.norm = nn.LayerNorm(dim, eps=1e-6) 168 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 169 | self.act = nn.GELU() 170 | self.grn = GRN(intermediate_dim) 171 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 172 | 173 | def forward(self, x: torch.Tensor) -> torch.Tensor: 174 | residual = x 175 | x = x.transpose(1, 2) # b n d -> b d n 176 | x = self.dwconv(x) 177 | x = x.transpose(1, 2) # b d n -> b n d 178 | x = self.norm(x) 179 | x = self.pwconv1(x) 180 | x = self.act(x) 181 | x = self.grn(x) 182 | x = self.pwconv2(x) 183 | return residual + x 184 | 185 | 186 | # AdaLayerNormZero 187 | # return with modulated x for attn input, and params for later mlp modulation 188 | 189 | class AdaLayerNormZero(nn.Module): 190 | def __init__(self, dim): 191 | super().__init__() 192 | 193 | self.silu = nn.SiLU() 194 | self.linear = nn.Linear(dim, dim * 6) 195 | 196 | self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 197 | 198 | def forward(self, x, emb = None): 199 | emb = self.linear(self.silu(emb)) 200 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) 201 | 202 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] 203 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp 204 | 205 | 206 | # AdaLayerNormZero for final layer 207 | # return only with modulated x for attn input, cuz no more mlp modulation 208 | 209 | class AdaLayerNormZero_Final(nn.Module): 210 | def __init__(self, dim): 211 | super().__init__() 212 | 213 | self.silu = nn.SiLU() 214 | self.linear = nn.Linear(dim, dim * 2) 215 | 216 | self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 217 | 218 | def forward(self, x, emb): 219 | emb = self.linear(self.silu(emb)) 220 | scale, shift = torch.chunk(emb, 2, dim=1) 221 | 222 | x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] 223 | return x 224 | 225 | 226 | # FeedForward 227 | 228 | class FeedForward(nn.Module): 229 | def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'): 230 | super().__init__() 231 | inner_dim = int(dim * mult) 232 | dim_out = dim_out if dim_out is not None else dim 233 | 234 | activation = nn.GELU(approximate=approximate) 235 | project_in = nn.Sequential( 236 | nn.Linear(dim, inner_dim), 237 | activation 238 | ) 239 | self.ff = nn.Sequential( 240 | project_in, 241 | nn.Dropout(dropout), 242 | nn.Linear(inner_dim, dim_out) 243 | ) 244 | 245 | def forward(self, x): 246 | return self.ff(x) 247 | 248 | 249 | # Attention with possible joint part 250 | # modified from 
diffusers/src/diffusers/models/attention_processor.py 251 | 252 | class Attention(nn.Module): 253 | def __init__( 254 | self, 255 | processor: JointAttnProcessor | AttnProcessor, 256 | dim: int, 257 | heads: int = 8, 258 | dim_head: int = 64, 259 | dropout: float = 0.0, 260 | context_dim: Optional[int] = None, # if not None -> joint attention 261 | context_pre_only = None, 262 | ): 263 | super().__init__() 264 | 265 | if not hasattr(F, "scaled_dot_product_attention"): 266 | raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.") 267 | 268 | self.processor = processor 269 | 270 | self.dim = dim 271 | self.heads = heads 272 | self.inner_dim = dim_head * heads 273 | self.dropout = dropout 274 | 275 | self.context_dim = context_dim 276 | self.context_pre_only = context_pre_only 277 | 278 | self.to_q = nn.Linear(dim, self.inner_dim) 279 | self.to_k = nn.Linear(dim, self.inner_dim) 280 | self.to_v = nn.Linear(dim, self.inner_dim) 281 | 282 | if self.context_dim is not None: 283 | self.to_k_c = nn.Linear(context_dim, self.inner_dim) 284 | self.to_v_c = nn.Linear(context_dim, self.inner_dim) 285 | if self.context_pre_only is not None: 286 | self.to_q_c = nn.Linear(context_dim, self.inner_dim) 287 | 288 | self.to_out = nn.ModuleList([]) 289 | self.to_out.append(nn.Linear(self.inner_dim, dim)) 290 | self.to_out.append(nn.Dropout(dropout)) 291 | 292 | if self.context_pre_only is not None and not self.context_pre_only: 293 | self.to_out_c = nn.Linear(self.inner_dim, dim) 294 | 295 | def forward( 296 | self, 297 | x: float['b n d'], # noised input x 298 | c: float['b n d'] = None, # context c 299 | mask: bool['b n'] | None = None, 300 | rope = None, # rotary position embedding for x 301 | c_rope = None, # rotary position embedding for c 302 | ) -> torch.Tensor: 303 | if c is not None: 304 | return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope) 305 | else: 306 | return self.processor(self, x, mask = mask, rope = rope) 307 | 308 | 309 | # Attention processor 310 | 311 | class AttnProcessor: 312 | def __init__(self): 313 | pass 314 | 315 | def __call__( 316 | self, 317 | attn: Attention, 318 | x: float['b n d'], # noised input x 319 | mask: bool['b n'] | None = None, 320 | rope = None, # rotary position embedding 321 | ) -> torch.FloatTensor: 322 | 323 | batch_size = x.shape[0] 324 | 325 | # `sample` projections. 326 | query = attn.to_q(x) 327 | key = attn.to_k(x) 328 | value = attn.to_v(x) 329 | 330 | # apply rotary position embedding 331 | if rope is not None: 332 | freqs, xpos_scale = rope 333 | q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.) 334 | 335 | query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) 336 | key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) 337 | 338 | # attention 339 | inner_dim = key.shape[-1] 340 | head_dim = inner_dim // attn.heads 341 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 342 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 343 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 344 | 345 | # mask. e.g.
inference got a batch with different target durations, mask out the padding 346 | if mask is not None: 347 | attn_mask = mask 348 | attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n') 349 | attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) 350 | else: 351 | attn_mask = None 352 | 353 | x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) 354 | x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 355 | x = x.to(query.dtype) 356 | 357 | # linear proj 358 | x = attn.to_out[0](x) 359 | # dropout 360 | x = attn.to_out[1](x) 361 | 362 | if mask is not None: 363 | mask = rearrange(mask, 'b n -> b n 1') 364 | x = x.masked_fill(~mask, 0.) 365 | 366 | return x 367 | 368 | 369 | # Joint Attention processor for MM-DiT 370 | # modified from diffusers/src/diffusers/models/attention_processor.py 371 | 372 | class JointAttnProcessor: 373 | def __init__(self): 374 | pass 375 | 376 | def __call__( 377 | self, 378 | attn: Attention, 379 | x: float['b n d'], # noised input x 380 | c: float['b nt d'] = None, # context c, here text 381 | mask: bool['b n'] | None = None, 382 | rope = None, # rotary position embedding for x 383 | c_rope = None, # rotary position embedding for c 384 | ) -> torch.FloatTensor: 385 | residual = x 386 | 387 | batch_size = c.shape[0] 388 | 389 | # `sample` projections. 390 | query = attn.to_q(x) 391 | key = attn.to_k(x) 392 | value = attn.to_v(x) 393 | 394 | # `context` projections. 395 | c_query = attn.to_q_c(c) 396 | c_key = attn.to_k_c(c) 397 | c_value = attn.to_v_c(c) 398 | 399 | # apply rope for context and noised input independently 400 | if rope is not None: 401 | freqs, xpos_scale = rope 402 | q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.) 403 | query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) 404 | key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) 405 | if c_rope is not None: 406 | freqs, xpos_scale = c_rope 407 | q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.) 408 | c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) 409 | c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) 410 | 411 | # attention 412 | query = torch.cat([query, c_query], dim=1) 413 | key = torch.cat([key, c_key], dim=1) 414 | value = torch.cat([value, c_value], dim=1) 415 | 416 | inner_dim = key.shape[-1] 417 | head_dim = inner_dim // attn.heads 418 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 419 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 420 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 421 | 422 | # mask. e.g. inference got a batch with different target durations, mask out the padding 423 | if mask is not None: 424 | attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text) 425 | attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n') 426 | attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) 427 | else: 428 | attn_mask = None 429 | 430 | x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) 431 | x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 432 | x = x.to(query.dtype) 433 | 434 | # Split the attention outputs. 
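# the first residual.shape[1] positions (noised-input length) return to the x stream, the rest (text length) to the context stream, mirroring the earlier concatenation along the sequence dim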
435 | x, c = ( 436 | x[:, :residual.shape[1]], 437 | x[:, residual.shape[1]:], 438 | ) 439 | 440 | # linear proj 441 | x = attn.to_out[0](x) 442 | # dropout 443 | x = attn.to_out[1](x) 444 | if not attn.context_pre_only: 445 | c = attn.to_out_c(c) 446 | 447 | if mask is not None: 448 | mask = rearrange(mask, 'b n -> b n 1') 449 | x = x.masked_fill(~mask, 0.) 450 | # c = c.masked_fill(~mask, 0.) # no mask for c (text) 451 | 452 | return x, c 453 | 454 | 455 | # DiT Block 456 | 457 | class DiTBlock(nn.Module): 458 | 459 | def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1): 460 | super().__init__() 461 | 462 | self.attn_norm = AdaLayerNormZero(dim) 463 | self.attn = Attention( 464 | processor = AttnProcessor(), 465 | dim = dim, 466 | heads = heads, 467 | dim_head = dim_head, 468 | dropout = dropout, 469 | ) 470 | 471 | self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 472 | self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh") 473 | 474 | def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding 475 | # pre-norm & modulation for attention input 476 | norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) 477 | 478 | # attention 479 | attn_output = self.attn(x=norm, mask=mask, rope=rope) 480 | 481 | # process attention output for input x 482 | x = x + gate_msa.unsqueeze(1) * attn_output 483 | 484 | norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 485 | ff_output = self.ff(norm) 486 | x = x + gate_mlp.unsqueeze(1) * ff_output 487 | 488 | return x 489 | 490 | 491 | # MMDiT Block https://arxiv.org/abs/2403.03206 492 | 493 | class MMDiTBlock(nn.Module): 494 | r""" 495 | modified from diffusers/src/diffusers/models/attention.py 496 | 497 | notes. 498 | _c: context related. text, cond, etc. (left part in sd3 fig2.b) 499 | _x: noised input related. 
(right part) 500 | context_pre_only: last layer only do prenorm + modulation cuz no more ffn 501 | """ 502 | 503 | def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False): 504 | super().__init__() 505 | 506 | self.context_pre_only = context_pre_only 507 | 508 | self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) 509 | self.attn_norm_x = AdaLayerNormZero(dim) 510 | self.attn = Attention( 511 | processor = JointAttnProcessor(), 512 | dim = dim, 513 | heads = heads, 514 | dim_head = dim_head, 515 | dropout = dropout, 516 | context_dim = dim, 517 | context_pre_only = context_pre_only, 518 | ) 519 | 520 | if not context_pre_only: 521 | self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 522 | self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh") 523 | else: 524 | self.ff_norm_c = None 525 | self.ff_c = None 526 | self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 527 | self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh") 528 | 529 | def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding 530 | # pre-norm & modulation for attention input 531 | if self.context_pre_only: 532 | norm_c = self.attn_norm_c(c, t) 533 | else: 534 | norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t) 535 | norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t) 536 | 537 | # attention 538 | x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope) 539 | 540 | # process attention output for context c 541 | if self.context_pre_only: 542 | c = None 543 | else: # if not last layer 544 | c = c + c_gate_msa.unsqueeze(1) * c_attn_output 545 | 546 | norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 547 | c_ff_output = self.ff_c(norm_c) 548 | c = c + c_gate_mlp.unsqueeze(1) * c_ff_output 549 | 550 | # process attention output for input x 551 | x = x + x_gate_msa.unsqueeze(1) * x_attn_output 552 | 553 | norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] 554 | x_ff_output = self.ff_x(norm_x) 555 | x = x + x_gate_mlp.unsqueeze(1) * x_ff_output 556 | 557 | return c, x 558 | 559 | 560 | # time step conditioning embedding 561 | 562 | class TimestepEmbedding(nn.Module): 563 | def __init__(self, dim, freq_embed_dim=256): 564 | super().__init__() 565 | self.time_embed = SinusPositionEmbedding(freq_embed_dim) 566 | self.time_mlp = nn.Sequential( 567 | nn.Linear(freq_embed_dim, dim), 568 | nn.SiLU(), 569 | nn.Linear(dim, dim) 570 | ) 571 | 572 | def forward(self, timestep: float['b']): 573 | time_hidden = self.time_embed(timestep) 574 | time = self.time_mlp(time_hidden) # b d 575 | return time 576 | -------------------------------------------------------------------------------- /data/Emilia_ZH_EN_pinyin/vocab.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | " 4 | # 5 | $ 6 | % 7 | & 8 | ' 9 | ( 10 | ) 11 | * 12 | + 13 | , 14 | - 15 | . 16 | / 17 | 0 18 | 1 19 | 2 20 | 3 21 | 4 22 | 5 23 | 6 24 | 7 25 | 8 26 | 9 27 | : 28 | ; 29 | = 30 | > 31 | ? 
32 | @ 33 | A 34 | B 35 | C 36 | D 37 | E 38 | F 39 | G 40 | H 41 | I 42 | J 43 | K 44 | L 45 | M 46 | N 47 | O 48 | P 49 | Q 50 | R 51 | S 52 | T 53 | U 54 | V 55 | W 56 | X 57 | Y 58 | Z 59 | [ 60 | \ 61 | ] 62 | _ 63 | a 64 | a1 65 | ai1 66 | ai2 67 | ai3 68 | ai4 69 | an1 70 | an3 71 | an4 72 | ang1 73 | ang2 74 | ang4 75 | ao1 76 | ao2 77 | ao3 78 | ao4 79 | b 80 | ba 81 | ba1 82 | ba2 83 | ba3 84 | ba4 85 | bai1 86 | bai2 87 | bai3 88 | bai4 89 | ban1 90 | ban2 91 | ban3 92 | ban4 93 | bang1 94 | bang2 95 | bang3 96 | bang4 97 | bao1 98 | bao2 99 | bao3 100 | bao4 101 | bei 102 | bei1 103 | bei2 104 | bei3 105 | bei4 106 | ben1 107 | ben2 108 | ben3 109 | ben4 110 | beng 111 | beng1 112 | beng2 113 | beng3 114 | beng4 115 | bi1 116 | bi2 117 | bi3 118 | bi4 119 | bian1 120 | bian2 121 | bian3 122 | bian4 123 | biao1 124 | biao2 125 | biao3 126 | bie1 127 | bie2 128 | bie3 129 | bie4 130 | bin1 131 | bin4 132 | bing1 133 | bing2 134 | bing3 135 | bing4 136 | bo 137 | bo1 138 | bo2 139 | bo3 140 | bo4 141 | bu2 142 | bu3 143 | bu4 144 | c 145 | ca1 146 | cai1 147 | cai2 148 | cai3 149 | cai4 150 | can1 151 | can2 152 | can3 153 | can4 154 | cang1 155 | cang2 156 | cao1 157 | cao2 158 | cao3 159 | ce4 160 | cen1 161 | cen2 162 | ceng1 163 | ceng2 164 | ceng4 165 | cha1 166 | cha2 167 | cha3 168 | cha4 169 | chai1 170 | chai2 171 | chan1 172 | chan2 173 | chan3 174 | chan4 175 | chang1 176 | chang2 177 | chang3 178 | chang4 179 | chao1 180 | chao2 181 | chao3 182 | che1 183 | che2 184 | che3 185 | che4 186 | chen1 187 | chen2 188 | chen3 189 | chen4 190 | cheng1 191 | cheng2 192 | cheng3 193 | cheng4 194 | chi1 195 | chi2 196 | chi3 197 | chi4 198 | chong1 199 | chong2 200 | chong3 201 | chong4 202 | chou1 203 | chou2 204 | chou3 205 | chou4 206 | chu1 207 | chu2 208 | chu3 209 | chu4 210 | chua1 211 | chuai1 212 | chuai2 213 | chuai3 214 | chuai4 215 | chuan1 216 | chuan2 217 | chuan3 218 | chuan4 219 | chuang1 220 | chuang2 221 | chuang3 222 | chuang4 223 | chui1 224 | chui2 225 | chun1 226 | chun2 227 | chun3 228 | chuo1 229 | chuo4 230 | ci1 231 | ci2 232 | ci3 233 | ci4 234 | cong1 235 | cong2 236 | cou4 237 | cu1 238 | cu4 239 | cuan1 240 | cuan2 241 | cuan4 242 | cui1 243 | cui3 244 | cui4 245 | cun1 246 | cun2 247 | cun4 248 | cuo1 249 | cuo2 250 | cuo4 251 | d 252 | da 253 | da1 254 | da2 255 | da3 256 | da4 257 | dai1 258 | dai2 259 | dai3 260 | dai4 261 | dan1 262 | dan2 263 | dan3 264 | dan4 265 | dang1 266 | dang2 267 | dang3 268 | dang4 269 | dao1 270 | dao2 271 | dao3 272 | dao4 273 | de 274 | de1 275 | de2 276 | dei3 277 | den4 278 | deng1 279 | deng2 280 | deng3 281 | deng4 282 | di1 283 | di2 284 | di3 285 | di4 286 | dia3 287 | dian1 288 | dian2 289 | dian3 290 | dian4 291 | diao1 292 | diao3 293 | diao4 294 | die1 295 | die2 296 | die4 297 | ding1 298 | ding2 299 | ding3 300 | ding4 301 | diu1 302 | dong1 303 | dong3 304 | dong4 305 | dou1 306 | dou2 307 | dou3 308 | dou4 309 | du1 310 | du2 311 | du3 312 | du4 313 | duan1 314 | duan2 315 | duan3 316 | duan4 317 | dui1 318 | dui4 319 | dun1 320 | dun3 321 | dun4 322 | duo1 323 | duo2 324 | duo3 325 | duo4 326 | e 327 | e1 328 | e2 329 | e3 330 | e4 331 | ei2 332 | en1 333 | en4 334 | er 335 | er2 336 | er3 337 | er4 338 | f 339 | fa1 340 | fa2 341 | fa3 342 | fa4 343 | fan1 344 | fan2 345 | fan3 346 | fan4 347 | fang1 348 | fang2 349 | fang3 350 | fang4 351 | fei1 352 | fei2 353 | fei3 354 | fei4 355 | fen1 356 | fen2 357 | fen3 358 | fen4 359 | feng1 360 | feng2 361 | feng3 362 | feng4 363 | fo2 364 | fou2 365 | fou3 
366 | fu1 367 | fu2 368 | fu3 369 | fu4 370 | g 371 | ga1 372 | ga2 373 | ga3 374 | ga4 375 | gai1 376 | gai2 377 | gai3 378 | gai4 379 | gan1 380 | gan2 381 | gan3 382 | gan4 383 | gang1 384 | gang2 385 | gang3 386 | gang4 387 | gao1 388 | gao2 389 | gao3 390 | gao4 391 | ge1 392 | ge2 393 | ge3 394 | ge4 395 | gei2 396 | gei3 397 | gen1 398 | gen2 399 | gen3 400 | gen4 401 | geng1 402 | geng3 403 | geng4 404 | gong1 405 | gong3 406 | gong4 407 | gou1 408 | gou2 409 | gou3 410 | gou4 411 | gu 412 | gu1 413 | gu2 414 | gu3 415 | gu4 416 | gua1 417 | gua2 418 | gua3 419 | gua4 420 | guai1 421 | guai2 422 | guai3 423 | guai4 424 | guan1 425 | guan2 426 | guan3 427 | guan4 428 | guang1 429 | guang2 430 | guang3 431 | guang4 432 | gui1 433 | gui2 434 | gui3 435 | gui4 436 | gun3 437 | gun4 438 | guo1 439 | guo2 440 | guo3 441 | guo4 442 | h 443 | ha1 444 | ha2 445 | ha3 446 | hai1 447 | hai2 448 | hai3 449 | hai4 450 | han1 451 | han2 452 | han3 453 | han4 454 | hang1 455 | hang2 456 | hang4 457 | hao1 458 | hao2 459 | hao3 460 | hao4 461 | he1 462 | he2 463 | he4 464 | hei1 465 | hen2 466 | hen3 467 | hen4 468 | heng1 469 | heng2 470 | heng4 471 | hong1 472 | hong2 473 | hong3 474 | hong4 475 | hou1 476 | hou2 477 | hou3 478 | hou4 479 | hu1 480 | hu2 481 | hu3 482 | hu4 483 | hua1 484 | hua2 485 | hua4 486 | huai2 487 | huai4 488 | huan1 489 | huan2 490 | huan3 491 | huan4 492 | huang1 493 | huang2 494 | huang3 495 | huang4 496 | hui1 497 | hui2 498 | hui3 499 | hui4 500 | hun1 501 | hun2 502 | hun4 503 | huo 504 | huo1 505 | huo2 506 | huo3 507 | huo4 508 | i 509 | j 510 | ji1 511 | ji2 512 | ji3 513 | ji4 514 | jia 515 | jia1 516 | jia2 517 | jia3 518 | jia4 519 | jian1 520 | jian2 521 | jian3 522 | jian4 523 | jiang1 524 | jiang2 525 | jiang3 526 | jiang4 527 | jiao1 528 | jiao2 529 | jiao3 530 | jiao4 531 | jie1 532 | jie2 533 | jie3 534 | jie4 535 | jin1 536 | jin2 537 | jin3 538 | jin4 539 | jing1 540 | jing2 541 | jing3 542 | jing4 543 | jiong3 544 | jiu1 545 | jiu2 546 | jiu3 547 | jiu4 548 | ju1 549 | ju2 550 | ju3 551 | ju4 552 | juan1 553 | juan2 554 | juan3 555 | juan4 556 | jue1 557 | jue2 558 | jue4 559 | jun1 560 | jun4 561 | k 562 | ka1 563 | ka2 564 | ka3 565 | kai1 566 | kai2 567 | kai3 568 | kai4 569 | kan1 570 | kan2 571 | kan3 572 | kan4 573 | kang1 574 | kang2 575 | kang4 576 | kao1 577 | kao2 578 | kao3 579 | kao4 580 | ke1 581 | ke2 582 | ke3 583 | ke4 584 | ken3 585 | keng1 586 | kong1 587 | kong3 588 | kong4 589 | kou1 590 | kou2 591 | kou3 592 | kou4 593 | ku1 594 | ku2 595 | ku3 596 | ku4 597 | kua1 598 | kua3 599 | kua4 600 | kuai3 601 | kuai4 602 | kuan1 603 | kuan2 604 | kuan3 605 | kuang1 606 | kuang2 607 | kuang4 608 | kui1 609 | kui2 610 | kui3 611 | kui4 612 | kun1 613 | kun3 614 | kun4 615 | kuo4 616 | l 617 | la 618 | la1 619 | la2 620 | la3 621 | la4 622 | lai2 623 | lai4 624 | lan2 625 | lan3 626 | lan4 627 | lang1 628 | lang2 629 | lang3 630 | lang4 631 | lao1 632 | lao2 633 | lao3 634 | lao4 635 | le 636 | le1 637 | le4 638 | lei 639 | lei1 640 | lei2 641 | lei3 642 | lei4 643 | leng1 644 | leng2 645 | leng3 646 | leng4 647 | li 648 | li1 649 | li2 650 | li3 651 | li4 652 | lia3 653 | lian2 654 | lian3 655 | lian4 656 | liang2 657 | liang3 658 | liang4 659 | liao1 660 | liao2 661 | liao3 662 | liao4 663 | lie1 664 | lie2 665 | lie3 666 | lie4 667 | lin1 668 | lin2 669 | lin3 670 | lin4 671 | ling2 672 | ling3 673 | ling4 674 | liu1 675 | liu2 676 | liu3 677 | liu4 678 | long1 679 | long2 680 | long3 681 | long4 682 | lou1 683 | lou2 684 | lou3 685 | 
lou4 686 | lu1 687 | lu2 688 | lu3 689 | lu4 690 | luan2 691 | luan3 692 | luan4 693 | lun1 694 | lun2 695 | lun4 696 | luo1 697 | luo2 698 | luo3 699 | luo4 700 | lv2 701 | lv3 702 | lv4 703 | lve3 704 | lve4 705 | m 706 | ma 707 | ma1 708 | ma2 709 | ma3 710 | ma4 711 | mai2 712 | mai3 713 | mai4 714 | man1 715 | man2 716 | man3 717 | man4 718 | mang2 719 | mang3 720 | mao1 721 | mao2 722 | mao3 723 | mao4 724 | me 725 | mei2 726 | mei3 727 | mei4 728 | men 729 | men1 730 | men2 731 | men4 732 | meng 733 | meng1 734 | meng2 735 | meng3 736 | meng4 737 | mi1 738 | mi2 739 | mi3 740 | mi4 741 | mian2 742 | mian3 743 | mian4 744 | miao1 745 | miao2 746 | miao3 747 | miao4 748 | mie1 749 | mie4 750 | min2 751 | min3 752 | ming2 753 | ming3 754 | ming4 755 | miu4 756 | mo1 757 | mo2 758 | mo3 759 | mo4 760 | mou1 761 | mou2 762 | mou3 763 | mu2 764 | mu3 765 | mu4 766 | n 767 | n2 768 | na1 769 | na2 770 | na3 771 | na4 772 | nai2 773 | nai3 774 | nai4 775 | nan1 776 | nan2 777 | nan3 778 | nan4 779 | nang1 780 | nang2 781 | nang3 782 | nao1 783 | nao2 784 | nao3 785 | nao4 786 | ne 787 | ne2 788 | ne4 789 | nei3 790 | nei4 791 | nen4 792 | neng2 793 | ni1 794 | ni2 795 | ni3 796 | ni4 797 | nian1 798 | nian2 799 | nian3 800 | nian4 801 | niang2 802 | niang4 803 | niao2 804 | niao3 805 | niao4 806 | nie1 807 | nie4 808 | nin2 809 | ning2 810 | ning3 811 | ning4 812 | niu1 813 | niu2 814 | niu3 815 | niu4 816 | nong2 817 | nong4 818 | nou4 819 | nu2 820 | nu3 821 | nu4 822 | nuan3 823 | nuo2 824 | nuo4 825 | nv2 826 | nv3 827 | nve4 828 | o 829 | o1 830 | o2 831 | ou1 832 | ou2 833 | ou3 834 | ou4 835 | p 836 | pa1 837 | pa2 838 | pa4 839 | pai1 840 | pai2 841 | pai3 842 | pai4 843 | pan1 844 | pan2 845 | pan4 846 | pang1 847 | pang2 848 | pang4 849 | pao1 850 | pao2 851 | pao3 852 | pao4 853 | pei1 854 | pei2 855 | pei4 856 | pen1 857 | pen2 858 | pen4 859 | peng1 860 | peng2 861 | peng3 862 | peng4 863 | pi1 864 | pi2 865 | pi3 866 | pi4 867 | pian1 868 | pian2 869 | pian4 870 | piao1 871 | piao2 872 | piao3 873 | piao4 874 | pie1 875 | pie2 876 | pie3 877 | pin1 878 | pin2 879 | pin3 880 | pin4 881 | ping1 882 | ping2 883 | po1 884 | po2 885 | po3 886 | po4 887 | pou1 888 | pu1 889 | pu2 890 | pu3 891 | pu4 892 | q 893 | qi1 894 | qi2 895 | qi3 896 | qi4 897 | qia1 898 | qia3 899 | qia4 900 | qian1 901 | qian2 902 | qian3 903 | qian4 904 | qiang1 905 | qiang2 906 | qiang3 907 | qiang4 908 | qiao1 909 | qiao2 910 | qiao3 911 | qiao4 912 | qie1 913 | qie2 914 | qie3 915 | qie4 916 | qin1 917 | qin2 918 | qin3 919 | qin4 920 | qing1 921 | qing2 922 | qing3 923 | qing4 924 | qiong1 925 | qiong2 926 | qiu1 927 | qiu2 928 | qiu3 929 | qu1 930 | qu2 931 | qu3 932 | qu4 933 | quan1 934 | quan2 935 | quan3 936 | quan4 937 | que1 938 | que2 939 | que4 940 | qun2 941 | r 942 | ran2 943 | ran3 944 | rang1 945 | rang2 946 | rang3 947 | rang4 948 | rao2 949 | rao3 950 | rao4 951 | re2 952 | re3 953 | re4 954 | ren2 955 | ren3 956 | ren4 957 | reng1 958 | reng2 959 | ri4 960 | rong1 961 | rong2 962 | rong3 963 | rou2 964 | rou4 965 | ru2 966 | ru3 967 | ru4 968 | ruan2 969 | ruan3 970 | rui3 971 | rui4 972 | run4 973 | ruo4 974 | s 975 | sa1 976 | sa2 977 | sa3 978 | sa4 979 | sai1 980 | sai4 981 | san1 982 | san2 983 | san3 984 | san4 985 | sang1 986 | sang3 987 | sang4 988 | sao1 989 | sao2 990 | sao3 991 | sao4 992 | se4 993 | sen1 994 | seng1 995 | sha1 996 | sha2 997 | sha3 998 | sha4 999 | shai1 1000 | shai2 1001 | shai3 1002 | shai4 1003 | shan1 1004 | shan3 1005 | shan4 1006 | shang 1007 | shang1 
1008 | shang3 1009 | shang4 1010 | shao1 1011 | shao2 1012 | shao3 1013 | shao4 1014 | she1 1015 | she2 1016 | she3 1017 | she4 1018 | shei2 1019 | shen1 1020 | shen2 1021 | shen3 1022 | shen4 1023 | sheng1 1024 | sheng2 1025 | sheng3 1026 | sheng4 1027 | shi 1028 | shi1 1029 | shi2 1030 | shi3 1031 | shi4 1032 | shou1 1033 | shou2 1034 | shou3 1035 | shou4 1036 | shu1 1037 | shu2 1038 | shu3 1039 | shu4 1040 | shua1 1041 | shua2 1042 | shua3 1043 | shua4 1044 | shuai1 1045 | shuai3 1046 | shuai4 1047 | shuan1 1048 | shuan4 1049 | shuang1 1050 | shuang3 1051 | shui2 1052 | shui3 1053 | shui4 1054 | shun3 1055 | shun4 1056 | shuo1 1057 | shuo4 1058 | si1 1059 | si2 1060 | si3 1061 | si4 1062 | song1 1063 | song3 1064 | song4 1065 | sou1 1066 | sou3 1067 | sou4 1068 | su1 1069 | su2 1070 | su4 1071 | suan1 1072 | suan4 1073 | sui1 1074 | sui2 1075 | sui3 1076 | sui4 1077 | sun1 1078 | sun3 1079 | suo 1080 | suo1 1081 | suo2 1082 | suo3 1083 | t 1084 | ta1 1085 | ta2 1086 | ta3 1087 | ta4 1088 | tai1 1089 | tai2 1090 | tai4 1091 | tan1 1092 | tan2 1093 | tan3 1094 | tan4 1095 | tang1 1096 | tang2 1097 | tang3 1098 | tang4 1099 | tao1 1100 | tao2 1101 | tao3 1102 | tao4 1103 | te4 1104 | teng2 1105 | ti1 1106 | ti2 1107 | ti3 1108 | ti4 1109 | tian1 1110 | tian2 1111 | tian3 1112 | tiao1 1113 | tiao2 1114 | tiao3 1115 | tiao4 1116 | tie1 1117 | tie2 1118 | tie3 1119 | tie4 1120 | ting1 1121 | ting2 1122 | ting3 1123 | tong1 1124 | tong2 1125 | tong3 1126 | tong4 1127 | tou 1128 | tou1 1129 | tou2 1130 | tou4 1131 | tu1 1132 | tu2 1133 | tu3 1134 | tu4 1135 | tuan1 1136 | tuan2 1137 | tui1 1138 | tui2 1139 | tui3 1140 | tui4 1141 | tun1 1142 | tun2 1143 | tun4 1144 | tuo1 1145 | tuo2 1146 | tuo3 1147 | tuo4 1148 | u 1149 | v 1150 | w 1151 | wa 1152 | wa1 1153 | wa2 1154 | wa3 1155 | wa4 1156 | wai1 1157 | wai3 1158 | wai4 1159 | wan1 1160 | wan2 1161 | wan3 1162 | wan4 1163 | wang1 1164 | wang2 1165 | wang3 1166 | wang4 1167 | wei1 1168 | wei2 1169 | wei3 1170 | wei4 1171 | wen1 1172 | wen2 1173 | wen3 1174 | wen4 1175 | weng1 1176 | weng4 1177 | wo1 1178 | wo2 1179 | wo3 1180 | wo4 1181 | wu1 1182 | wu2 1183 | wu3 1184 | wu4 1185 | x 1186 | xi1 1187 | xi2 1188 | xi3 1189 | xi4 1190 | xia1 1191 | xia2 1192 | xia4 1193 | xian1 1194 | xian2 1195 | xian3 1196 | xian4 1197 | xiang1 1198 | xiang2 1199 | xiang3 1200 | xiang4 1201 | xiao1 1202 | xiao2 1203 | xiao3 1204 | xiao4 1205 | xie1 1206 | xie2 1207 | xie3 1208 | xie4 1209 | xin1 1210 | xin2 1211 | xin4 1212 | xing1 1213 | xing2 1214 | xing3 1215 | xing4 1216 | xiong1 1217 | xiong2 1218 | xiu1 1219 | xiu3 1220 | xiu4 1221 | xu 1222 | xu1 1223 | xu2 1224 | xu3 1225 | xu4 1226 | xuan1 1227 | xuan2 1228 | xuan3 1229 | xuan4 1230 | xue1 1231 | xue2 1232 | xue3 1233 | xue4 1234 | xun1 1235 | xun2 1236 | xun4 1237 | y 1238 | ya 1239 | ya1 1240 | ya2 1241 | ya3 1242 | ya4 1243 | yan1 1244 | yan2 1245 | yan3 1246 | yan4 1247 | yang1 1248 | yang2 1249 | yang3 1250 | yang4 1251 | yao1 1252 | yao2 1253 | yao3 1254 | yao4 1255 | ye1 1256 | ye2 1257 | ye3 1258 | ye4 1259 | yi 1260 | yi1 1261 | yi2 1262 | yi3 1263 | yi4 1264 | yin1 1265 | yin2 1266 | yin3 1267 | yin4 1268 | ying1 1269 | ying2 1270 | ying3 1271 | ying4 1272 | yo1 1273 | yong1 1274 | yong2 1275 | yong3 1276 | yong4 1277 | you1 1278 | you2 1279 | you3 1280 | you4 1281 | yu1 1282 | yu2 1283 | yu3 1284 | yu4 1285 | yuan1 1286 | yuan2 1287 | yuan3 1288 | yuan4 1289 | yue1 1290 | yue4 1291 | yun1 1292 | yun2 1293 | yun3 1294 | yun4 1295 | z 1296 | za1 1297 | za2 1298 | za3 1299 | zai1 1300 | zai3 
1301 | zai4 1302 | zan1 1303 | zan2 1304 | zan3 1305 | zan4 1306 | zang1 1307 | zang4 1308 | zao1 1309 | zao2 1310 | zao3 1311 | zao4 1312 | ze2 1313 | ze4 1314 | zei2 1315 | zen3 1316 | zeng1 1317 | zeng4 1318 | zha1 1319 | zha2 1320 | zha3 1321 | zha4 1322 | zhai1 1323 | zhai2 1324 | zhai3 1325 | zhai4 1326 | zhan1 1327 | zhan2 1328 | zhan3 1329 | zhan4 1330 | zhang1 1331 | zhang2 1332 | zhang3 1333 | zhang4 1334 | zhao1 1335 | zhao2 1336 | zhao3 1337 | zhao4 1338 | zhe 1339 | zhe1 1340 | zhe2 1341 | zhe3 1342 | zhe4 1343 | zhen1 1344 | zhen2 1345 | zhen3 1346 | zhen4 1347 | zheng1 1348 | zheng2 1349 | zheng3 1350 | zheng4 1351 | zhi1 1352 | zhi2 1353 | zhi3 1354 | zhi4 1355 | zhong1 1356 | zhong2 1357 | zhong3 1358 | zhong4 1359 | zhou1 1360 | zhou2 1361 | zhou3 1362 | zhou4 1363 | zhu1 1364 | zhu2 1365 | zhu3 1366 | zhu4 1367 | zhua1 1368 | zhua2 1369 | zhua3 1370 | zhuai1 1371 | zhuai3 1372 | zhuai4 1373 | zhuan1 1374 | zhuan2 1375 | zhuan3 1376 | zhuan4 1377 | zhuang1 1378 | zhuang4 1379 | zhui1 1380 | zhui4 1381 | zhun1 1382 | zhun2 1383 | zhun3 1384 | zhuo1 1385 | zhuo2 1386 | zi 1387 | zi1 1388 | zi2 1389 | zi3 1390 | zi4 1391 | zong1 1392 | zong2 1393 | zong3 1394 | zong4 1395 | zou1 1396 | zou2 1397 | zou3 1398 | zou4 1399 | zu1 1400 | zu2 1401 | zu3 1402 | zuan1 1403 | zuan3 1404 | zuan4 1405 | zui2 1406 | zui3 1407 | zui4 1408 | zun1 1409 | zuo 1410 | zuo1 1411 | zuo2 1412 | zuo3 1413 | zuo4 1414 | { 1415 | ~ 1416 | ¡ 1417 | ¢ 1418 | £ 1419 | ¥ 1420 | § 1421 | ¨ 1422 | © 1423 | « 1424 | ® 1425 | ¯ 1426 | ° 1427 | ± 1428 | ² 1429 | ³ 1430 | ´ 1431 | µ 1432 | · 1433 | ¹ 1434 | º 1435 | » 1436 | ¼ 1437 | ½ 1438 | ¾ 1439 | ¿ 1440 | À 1441 | Á 1442 |  1443 | à 1444 | Ä 1445 | Å 1446 | Æ 1447 | Ç 1448 | È 1449 | É 1450 | Ê 1451 | Í 1452 | Î 1453 | Ñ 1454 | Ó 1455 | Ö 1456 | × 1457 | Ø 1458 | Ú 1459 | Ü 1460 | Ý 1461 | Þ 1462 | ß 1463 | à 1464 | á 1465 | â 1466 | ã 1467 | ä 1468 | å 1469 | æ 1470 | ç 1471 | è 1472 | é 1473 | ê 1474 | ë 1475 | ì 1476 | í 1477 | î 1478 | ï 1479 | ð 1480 | ñ 1481 | ò 1482 | ó 1483 | ô 1484 | õ 1485 | ö 1486 | ø 1487 | ù 1488 | ú 1489 | û 1490 | ü 1491 | ý 1492 | Ā 1493 | ā 1494 | ă 1495 | ą 1496 | ć 1497 | Č 1498 | č 1499 | Đ 1500 | đ 1501 | ē 1502 | ė 1503 | ę 1504 | ě 1505 | ĝ 1506 | ğ 1507 | ħ 1508 | ī 1509 | į 1510 | İ 1511 | ı 1512 | Ł 1513 | ł 1514 | ń 1515 | ņ 1516 | ň 1517 | ŋ 1518 | Ō 1519 | ō 1520 | ő 1521 | œ 1522 | ř 1523 | Ś 1524 | ś 1525 | Ş 1526 | ş 1527 | Š 1528 | š 1529 | Ť 1530 | ť 1531 | ũ 1532 | ū 1533 | ź 1534 | Ż 1535 | ż 1536 | Ž 1537 | ž 1538 | ơ 1539 | ư 1540 | ǎ 1541 | ǐ 1542 | ǒ 1543 | ǔ 1544 | ǚ 1545 | ș 1546 | ț 1547 | ɑ 1548 | ɔ 1549 | ɕ 1550 | ə 1551 | ɛ 1552 | ɜ 1553 | ɡ 1554 | ɣ 1555 | ɪ 1556 | ɫ 1557 | ɴ 1558 | ɹ 1559 | ɾ 1560 | ʃ 1561 | ʊ 1562 | ʌ 1563 | ʒ 1564 | ʔ 1565 | ʰ 1566 | ʷ 1567 | ʻ 1568 | ʾ 1569 | ʿ 1570 | ˈ 1571 | ː 1572 | ˙ 1573 | ˜ 1574 | ˢ 1575 | ́ 1576 | ̅ 1577 | Α 1578 | Β 1579 | Δ 1580 | Ε 1581 | Θ 1582 | Κ 1583 | Λ 1584 | Μ 1585 | Ξ 1586 | Π 1587 | Σ 1588 | Τ 1589 | Φ 1590 | Χ 1591 | Ψ 1592 | Ω 1593 | ά 1594 | έ 1595 | ή 1596 | ί 1597 | α 1598 | β 1599 | γ 1600 | δ 1601 | ε 1602 | ζ 1603 | η 1604 | θ 1605 | ι 1606 | κ 1607 | λ 1608 | μ 1609 | ν 1610 | ξ 1611 | ο 1612 | π 1613 | ρ 1614 | ς 1615 | σ 1616 | τ 1617 | υ 1618 | φ 1619 | χ 1620 | ψ 1621 | ω 1622 | ϊ 1623 | ό 1624 | ύ 1625 | ώ 1626 | ϕ 1627 | ϵ 1628 | Ё 1629 | А 1630 | Б 1631 | В 1632 | Г 1633 | Д 1634 | Е 1635 | Ж 1636 | З 1637 | И 1638 | Й 1639 | К 1640 | Л 1641 | М 1642 | Н 1643 | О 1644 | П 1645 | Р 1646 | С 1647 | Т 1648 | У 1649 | Ф 
1650 | Х 1651 | Ц 1652 | Ч 1653 | Ш 1654 | Щ 1655 | Ы 1656 | Ь 1657 | Э 1658 | Ю 1659 | Я 1660 | а 1661 | б 1662 | в 1663 | г 1664 | д 1665 | е 1666 | ж 1667 | з 1668 | и 1669 | й 1670 | к 1671 | л 1672 | м 1673 | н 1674 | о 1675 | п 1676 | р 1677 | с 1678 | т 1679 | у 1680 | ф 1681 | х 1682 | ц 1683 | ч 1684 | ш 1685 | щ 1686 | ъ 1687 | ы 1688 | ь 1689 | э 1690 | ю 1691 | я 1692 | ё 1693 | і 1694 | ְ 1695 | ִ 1696 | ֵ 1697 | ֶ 1698 | ַ 1699 | ָ 1700 | ֹ 1701 | ּ 1702 | ־ 1703 | ׁ 1704 | א 1705 | ב 1706 | ג 1707 | ד 1708 | ה 1709 | ו 1710 | ז 1711 | ח 1712 | ט 1713 | י 1714 | כ 1715 | ל 1716 | ם 1717 | מ 1718 | ן 1719 | נ 1720 | ס 1721 | ע 1722 | פ 1723 | ק 1724 | ר 1725 | ש 1726 | ת 1727 | أ 1728 | ب 1729 | ة 1730 | ت 1731 | ج 1732 | ح 1733 | د 1734 | ر 1735 | ز 1736 | س 1737 | ص 1738 | ط 1739 | ع 1740 | ق 1741 | ك 1742 | ل 1743 | م 1744 | ن 1745 | ه 1746 | و 1747 | ي 1748 | َ 1749 | ُ 1750 | ِ 1751 | ْ 1752 | ก 1753 | ข 1754 | ง 1755 | จ 1756 | ต 1757 | ท 1758 | น 1759 | ป 1760 | ย 1761 | ร 1762 | ว 1763 | ส 1764 | ห 1765 | อ 1766 | ฮ 1767 | ั 1768 | า 1769 | ี 1770 | ึ 1771 | โ 1772 | ใ 1773 | ไ 1774 | ่ 1775 | ้ 1776 | ์ 1777 | ḍ 1778 | Ḥ 1779 | ḥ 1780 | ṁ 1781 | ṃ 1782 | ṅ 1783 | ṇ 1784 | Ṛ 1785 | ṛ 1786 | Ṣ 1787 | ṣ 1788 | Ṭ 1789 | ṭ 1790 | ạ 1791 | ả 1792 | Ấ 1793 | ấ 1794 | ầ 1795 | ậ 1796 | ắ 1797 | ằ 1798 | ẻ 1799 | ẽ 1800 | ế 1801 | ề 1802 | ể 1803 | ễ 1804 | ệ 1805 | ị 1806 | ọ 1807 | ỏ 1808 | ố 1809 | ồ 1810 | ộ 1811 | ớ 1812 | ờ 1813 | ở 1814 | ụ 1815 | ủ 1816 | ứ 1817 | ữ 1818 | ἀ 1819 | ἁ 1820 | Ἀ 1821 | ἐ 1822 | ἔ 1823 | ἰ 1824 | ἱ 1825 | ὀ 1826 | ὁ 1827 | ὐ 1828 | ὲ 1829 | ὸ 1830 | ᾶ 1831 | ᾽ 1832 | ῆ 1833 | ῇ 1834 | ῶ 1835 | ‎ 1836 | ‑ 1837 | ‒ 1838 | – 1839 | — 1840 | ― 1841 | ‖ 1842 | † 1843 | ‡ 1844 | • 1845 | … 1846 | ‧ 1847 | ‬ 1848 | ′ 1849 | ″ 1850 | ⁄ 1851 | ⁡ 1852 | ⁰ 1853 | ⁴ 1854 | ⁵ 1855 | ⁶ 1856 | ⁷ 1857 | ⁸ 1858 | ⁹ 1859 | ₁ 1860 | ₂ 1861 | ₃ 1862 | € 1863 | ₱ 1864 | ₹ 1865 | ₽ 1866 | ℃ 1867 | ℏ 1868 | ℓ 1869 | № 1870 | ℝ 1871 | ™ 1872 | ⅓ 1873 | ⅔ 1874 | ⅛ 1875 | → 1876 | ∂ 1877 | ∈ 1878 | ∑ 1879 | − 1880 | ∗ 1881 | √ 1882 | ∞ 1883 | ∫ 1884 | ≈ 1885 | ≠ 1886 | ≡ 1887 | ≤ 1888 | ≥ 1889 | ⋅ 1890 | ⋯ 1891 | █ 1892 | ♪ 1893 | ⟨ 1894 | ⟩ 1895 | 、 1896 | 。 1897 | 《 1898 | 》 1899 | 「 1900 | 」 1901 | 【 1902 | 】 1903 | あ 1904 | う 1905 | え 1906 | お 1907 | か 1908 | が 1909 | き 1910 | ぎ 1911 | く 1912 | ぐ 1913 | け 1914 | げ 1915 | こ 1916 | ご 1917 | さ 1918 | し 1919 | じ 1920 | す 1921 | ず 1922 | せ 1923 | ぜ 1924 | そ 1925 | ぞ 1926 | た 1927 | だ 1928 | ち 1929 | っ 1930 | つ 1931 | で 1932 | と 1933 | ど 1934 | な 1935 | に 1936 | ね 1937 | の 1938 | は 1939 | ば 1940 | ひ 1941 | ぶ 1942 | へ 1943 | べ 1944 | ま 1945 | み 1946 | む 1947 | め 1948 | も 1949 | ゃ 1950 | や 1951 | ゆ 1952 | ょ 1953 | よ 1954 | ら 1955 | り 1956 | る 1957 | れ 1958 | ろ 1959 | わ 1960 | を 1961 | ん 1962 | ァ 1963 | ア 1964 | ィ 1965 | イ 1966 | ウ 1967 | ェ 1968 | エ 1969 | オ 1970 | カ 1971 | ガ 1972 | キ 1973 | ク 1974 | ケ 1975 | ゲ 1976 | コ 1977 | ゴ 1978 | サ 1979 | ザ 1980 | シ 1981 | ジ 1982 | ス 1983 | ズ 1984 | セ 1985 | ゾ 1986 | タ 1987 | ダ 1988 | チ 1989 | ッ 1990 | ツ 1991 | テ 1992 | デ 1993 | ト 1994 | ド 1995 | ナ 1996 | ニ 1997 | ネ 1998 | ノ 1999 | バ 2000 | パ 2001 | ビ 2002 | ピ 2003 | フ 2004 | プ 2005 | ヘ 2006 | ベ 2007 | ペ 2008 | ホ 2009 | ボ 2010 | ポ 2011 | マ 2012 | ミ 2013 | ム 2014 | メ 2015 | モ 2016 | ャ 2017 | ヤ 2018 | ュ 2019 | ユ 2020 | ョ 2021 | ヨ 2022 | ラ 2023 | リ 2024 | ル 2025 | レ 2026 | ロ 2027 | ワ 2028 | ン 2029 | ・ 2030 | ー 2031 | ㄋ 2032 | ㄍ 2033 | ㄎ 2034 | ㄏ 2035 | ㄓ 2036 | ㄕ 2037 | ㄚ 2038 | ㄜ 2039 | ㄟ 2040 | ㄤ 2041 | ㄥ 2042 | ㄧ 2043 | ㄱ 2044 | ㄴ 
2045 | ㄷ 2046 | ㄹ 2047 | ㅁ 2048 | ㅂ 2049 | ㅅ 2050 | ㅈ 2051 | ㅍ 2052 | ㅎ 2053 | ㅏ 2054 | ㅓ 2055 | ㅗ 2056 | ㅜ 2057 | ㅡ 2058 | ㅣ 2059 | 㗎 2060 | 가 2061 | 각 2062 | 간 2063 | 갈 2064 | 감 2065 | 갑 2066 | 갓 2067 | 갔 2068 | 강 2069 | 같 2070 | 개 2071 | 거 2072 | 건 2073 | 걸 2074 | 겁 2075 | 것 2076 | 겉 2077 | 게 2078 | 겠 2079 | 겨 2080 | 결 2081 | 겼 2082 | 경 2083 | 계 2084 | 고 2085 | 곤 2086 | 골 2087 | 곱 2088 | 공 2089 | 과 2090 | 관 2091 | 광 2092 | 교 2093 | 구 2094 | 국 2095 | 굴 2096 | 귀 2097 | 귄 2098 | 그 2099 | 근 2100 | 글 2101 | 금 2102 | 기 2103 | 긴 2104 | 길 2105 | 까 2106 | 깍 2107 | 깔 2108 | 깜 2109 | 깨 2110 | 께 2111 | 꼬 2112 | 꼭 2113 | 꽃 2114 | 꾸 2115 | 꿔 2116 | 끔 2117 | 끗 2118 | 끝 2119 | 끼 2120 | 나 2121 | 난 2122 | 날 2123 | 남 2124 | 납 2125 | 내 2126 | 냐 2127 | 냥 2128 | 너 2129 | 넘 2130 | 넣 2131 | 네 2132 | 녁 2133 | 년 2134 | 녕 2135 | 노 2136 | 녹 2137 | 놀 2138 | 누 2139 | 눈 2140 | 느 2141 | 는 2142 | 늘 2143 | 니 2144 | 님 2145 | 닙 2146 | 다 2147 | 닥 2148 | 단 2149 | 달 2150 | 닭 2151 | 당 2152 | 대 2153 | 더 2154 | 덕 2155 | 던 2156 | 덥 2157 | 데 2158 | 도 2159 | 독 2160 | 동 2161 | 돼 2162 | 됐 2163 | 되 2164 | 된 2165 | 될 2166 | 두 2167 | 둑 2168 | 둥 2169 | 드 2170 | 들 2171 | 등 2172 | 디 2173 | 따 2174 | 딱 2175 | 딸 2176 | 땅 2177 | 때 2178 | 떤 2179 | 떨 2180 | 떻 2181 | 또 2182 | 똑 2183 | 뚱 2184 | 뛰 2185 | 뜻 2186 | 띠 2187 | 라 2188 | 락 2189 | 란 2190 | 람 2191 | 랍 2192 | 랑 2193 | 래 2194 | 랜 2195 | 러 2196 | 런 2197 | 럼 2198 | 렇 2199 | 레 2200 | 려 2201 | 력 2202 | 렵 2203 | 렸 2204 | 로 2205 | 록 2206 | 롬 2207 | 루 2208 | 르 2209 | 른 2210 | 를 2211 | 름 2212 | 릉 2213 | 리 2214 | 릴 2215 | 림 2216 | 마 2217 | 막 2218 | 만 2219 | 많 2220 | 말 2221 | 맑 2222 | 맙 2223 | 맛 2224 | 매 2225 | 머 2226 | 먹 2227 | 멍 2228 | 메 2229 | 면 2230 | 명 2231 | 몇 2232 | 모 2233 | 목 2234 | 몸 2235 | 못 2236 | 무 2237 | 문 2238 | 물 2239 | 뭐 2240 | 뭘 2241 | 미 2242 | 민 2243 | 밌 2244 | 밑 2245 | 바 2246 | 박 2247 | 밖 2248 | 반 2249 | 받 2250 | 발 2251 | 밤 2252 | 밥 2253 | 방 2254 | 배 2255 | 백 2256 | 밸 2257 | 뱀 2258 | 버 2259 | 번 2260 | 벌 2261 | 벚 2262 | 베 2263 | 벼 2264 | 벽 2265 | 별 2266 | 병 2267 | 보 2268 | 복 2269 | 본 2270 | 볼 2271 | 봐 2272 | 봤 2273 | 부 2274 | 분 2275 | 불 2276 | 비 2277 | 빔 2278 | 빛 2279 | 빠 2280 | 빨 2281 | 뼈 2282 | 뽀 2283 | 뿅 2284 | 쁘 2285 | 사 2286 | 산 2287 | 살 2288 | 삼 2289 | 샀 2290 | 상 2291 | 새 2292 | 색 2293 | 생 2294 | 서 2295 | 선 2296 | 설 2297 | 섭 2298 | 섰 2299 | 성 2300 | 세 2301 | 셔 2302 | 션 2303 | 셨 2304 | 소 2305 | 속 2306 | 손 2307 | 송 2308 | 수 2309 | 숙 2310 | 순 2311 | 술 2312 | 숫 2313 | 숭 2314 | 숲 2315 | 쉬 2316 | 쉽 2317 | 스 2318 | 슨 2319 | 습 2320 | 슷 2321 | 시 2322 | 식 2323 | 신 2324 | 실 2325 | 싫 2326 | 심 2327 | 십 2328 | 싶 2329 | 싸 2330 | 써 2331 | 쓰 2332 | 쓴 2333 | 씌 2334 | 씨 2335 | 씩 2336 | 씬 2337 | 아 2338 | 악 2339 | 안 2340 | 않 2341 | 알 2342 | 야 2343 | 약 2344 | 얀 2345 | 양 2346 | 얘 2347 | 어 2348 | 언 2349 | 얼 2350 | 엄 2351 | 업 2352 | 없 2353 | 었 2354 | 엉 2355 | 에 2356 | 여 2357 | 역 2358 | 연 2359 | 염 2360 | 엽 2361 | 영 2362 | 옆 2363 | 예 2364 | 옛 2365 | 오 2366 | 온 2367 | 올 2368 | 옷 2369 | 옹 2370 | 와 2371 | 왔 2372 | 왜 2373 | 요 2374 | 욕 2375 | 용 2376 | 우 2377 | 운 2378 | 울 2379 | 웃 2380 | 워 2381 | 원 2382 | 월 2383 | 웠 2384 | 위 2385 | 윙 2386 | 유 2387 | 육 2388 | 윤 2389 | 으 2390 | 은 2391 | 을 2392 | 음 2393 | 응 2394 | 의 2395 | 이 2396 | 익 2397 | 인 2398 | 일 2399 | 읽 2400 | 임 2401 | 입 2402 | 있 2403 | 자 2404 | 작 2405 | 잔 2406 | 잖 2407 | 잘 2408 | 잡 2409 | 잤 2410 | 장 2411 | 재 2412 | 저 2413 | 전 2414 | 점 2415 | 정 2416 | 제 2417 | 져 2418 | 졌 2419 | 조 2420 | 족 2421 | 좀 2422 | 종 2423 | 좋 2424 | 죠 2425 | 주 2426 | 준 2427 | 줄 2428 | 중 2429 | 줘 2430 | 즈 2431 | 즐 2432 | 즘 2433 | 지 2434 | 진 2435 | 집 2436 | 짜 2437 | 짝 2438 | 쩌 2439 | 쪼 
2440 | 쪽 2441 | 쫌 2442 | 쭈 2443 | 쯔 2444 | 찌 2445 | 찍 2446 | 차 2447 | 착 2448 | 찾 2449 | 책 2450 | 처 2451 | 천 2452 | 철 2453 | 체 2454 | 쳐 2455 | 쳤 2456 | 초 2457 | 촌 2458 | 추 2459 | 출 2460 | 춤 2461 | 춥 2462 | 춰 2463 | 치 2464 | 친 2465 | 칠 2466 | 침 2467 | 칩 2468 | 칼 2469 | 커 2470 | 켓 2471 | 코 2472 | 콩 2473 | 쿠 2474 | 퀴 2475 | 크 2476 | 큰 2477 | 큽 2478 | 키 2479 | 킨 2480 | 타 2481 | 태 2482 | 터 2483 | 턴 2484 | 털 2485 | 테 2486 | 토 2487 | 통 2488 | 투 2489 | 트 2490 | 특 2491 | 튼 2492 | 틀 2493 | 티 2494 | 팀 2495 | 파 2496 | 팔 2497 | 패 2498 | 페 2499 | 펜 2500 | 펭 2501 | 평 2502 | 포 2503 | 폭 2504 | 표 2505 | 품 2506 | 풍 2507 | 프 2508 | 플 2509 | 피 2510 | 필 2511 | 하 2512 | 학 2513 | 한 2514 | 할 2515 | 함 2516 | 합 2517 | 항 2518 | 해 2519 | 햇 2520 | 했 2521 | 행 2522 | 허 2523 | 험 2524 | 형 2525 | 혜 2526 | 호 2527 | 혼 2528 | 홀 2529 | 화 2530 | 회 2531 | 획 2532 | 후 2533 | 휴 2534 | 흐 2535 | 흔 2536 | 희 2537 | 히 2538 | 힘 2539 | ﷺ 2540 | ﷻ 2541 | ! 2542 | , 2543 | ? 2544 | � 2545 | 𠮶 2546 | --------------------------------------------------------------------------------