├── .gitignore ├── LICENSE ├── README.md ├── api.py ├── configs ├── E2TTS_Base_train.yaml ├── E2TTS_Small_train.yaml ├── F5TTS_Base_train.yaml └── F5TTS_Small_train.yaml ├── data ├── Emilia_ZH_EN_pinyin │ └── vocab.txt └── librispeech_pc_test_clean_cross_sentence.lst ├── f5_tts ├── api.py ├── configs │ ├── E2TTS_Base_train.yaml │ ├── E2TTS_Small_train.yaml │ ├── F5TTS_Base_train.yaml │ └── F5TTS_Small_train.yaml ├── eval │ ├── README.md │ ├── ecapa_tdnn.py │ ├── eval_infer_batch.py │ ├── eval_infer_batch.sh │ ├── eval_librispeech_test_clean.py │ ├── eval_seedtts_testset.py │ ├── eval_utmos.py │ └── utils_eval.py ├── infer │ ├── README.md │ ├── SHARED.md │ ├── examples │ │ ├── basic │ │ │ ├── basic.toml │ │ │ ├── basic_ref_en.wav │ │ │ └── basic_ref_zh.wav │ │ ├── multi │ │ │ ├── country.flac │ │ │ ├── main.flac │ │ │ ├── story.toml │ │ │ ├── story.txt │ │ │ └── town.flac │ │ └── vocab.txt │ ├── infer_cli.py │ ├── infer_gradio.py │ ├── speech_edit.py │ └── utils_infer.py ├── model │ ├── __init__.py │ ├── backbones │ │ ├── README.md │ │ ├── dit.py │ │ ├── mmdit.py │ │ └── unett.py │ ├── cfm.py │ ├── dataset.py │ ├── modules.py │ ├── trainer.py │ └── utils.py ├── scripts │ ├── count_max_epoch.py │ └── count_params_gflops.py ├── socket_server.py └── train │ ├── README.md │ ├── datasets │ ├── prepare_csv_wavs.py │ ├── prepare_emilia.py │ ├── prepare_libritts.py │ ├── prepare_ljspeech.py │ └── prepare_wenetspeech4tts.py │ ├── finetune_cli.py │ ├── finetune_gradio.py │ └── train.py ├── requirements.txt ├── run-api.bat ├── run-webui.bat ├── runtest.bat ├── test.py ├── testcuda.py └── 测试GPU是否可用.bat /.gitignore: -------------------------------------------------------------------------------- 1 | # Customed 2 | .vscode/ 3 | tests/ 4 | runs/ 5 | ckpts/ 6 | wandb/ 7 | results/ 8 | tmp/ 9 | modelscache/ 10 | runtime/ 11 | run-test.bat 12 | *.7z 13 | *.mp3 14 | *.wav 15 | docs 16 | *.exe 17 | 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | cover/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | .pybuilder/ 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | # For a library or package, you might want to ignore these files since the code is 105 | # intended to run in multiple environments; otherwise, check them in: 106 | # .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # poetry 116 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 117 | # This is especially recommended for binary packages to ensure reproducibility, and is more 118 | # commonly ignored for libraries. 119 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 120 | #poetry.lock 121 | 122 | # pdm 123 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 124 | #pdm.lock 125 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 126 | # in version control. 127 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 128 | .pdm.toml 129 | .pdm-python 130 | .pdm-build/ 131 | 132 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 133 | __pypackages__/ 134 | 135 | # Celery stuff 136 | celerybeat-schedule 137 | celerybeat.pid 138 | 139 | # SageMath parsed files 140 | *.sage.py 141 | 142 | # Environments 143 | .env 144 | .venv 145 | env/ 146 | venv/ 147 | ENV/ 148 | env.bak/ 149 | venv.bak/ 150 | 151 | # Spyder project settings 152 | .spyderproject 153 | .spyproject 154 | 155 | # Rope project settings 156 | .ropeproject 157 | 158 | # mkdocs documentation 159 | /site 160 | 161 | # mypy 162 | .mypy_cache/ 163 | .dmypy.json 164 | dmypy.json 165 | 166 | # Pyre type checker 167 | .pyre/ 168 | 169 | # pytype static type analyzer 170 | .pytype/ 171 | 172 | # Cython debug symbols 173 | cython_debug/ 174 | 175 | # PyCharm 176 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 177 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 178 | # and can be added to the global gitignore or merged into this file. For a more nuclear 179 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
180 | #.idea/ 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yushen CHEN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # F5-TTS-api 2 | 3 | This is an API for the [F5-TTS](https://github.com/SWivid/F5-TTS) project. 4 | 5 | > F5-TTS is a fully non-autoregressive, flow-matching-based text-to-speech (TTS) system open-sourced by Shanghai Jiao Tong University. It stands out for its efficiency, natural-sounding output, and multilingual support. 6 | 7 | ## Features 8 | 9 | - Provides the API file `api.py`, which can be plugged into the video translation project [pyvideotrans](https://github.com/jianchang512/pyvideotrans) 10 | - Provides an OpenAI TTS-compatible endpoint 11 | - Provides an all-in-one package for Windows 12 | 13 | 14 | ## All-in-one package deployment, about 5 GB (includes the F5-TTS models and environment) 15 | 16 | > The package only works on Windows 10/11; download, unzip, and use. 17 | > 18 | > Download: https://www.123684.com/s/03Sxjv-okTJ3 19 | 20 | 1. Start the API service: double-click `run-api.bat`; the endpoint address is `http://127.0.0.1:5010/api` 21 | 22 | > The package ships with CUDA 11.8 by default. If you have an NVIDIA GPU and a properly configured CUDA/cuDNN environment, GPU acceleration is used automatically. 23 | 24 | ## Using api.py inside a third-party all-in-one package 25 | 26 | 1. Copy api.py and the configs folder into the root directory of the third-party package. 27 | 2. Locate the python.exe bundled with that package, e.g. inside a py311 folder; then type `cmd` in the address bar of the root folder, press Enter, and run 28 | `.\py311\python api.py`. If you see `module flask not found`, first run `.\py311\python -m pip install waitress flask` 29 | 30 | ## Using api.py after deploying the official F5-TTS project from source 31 | 32 | 1. Copy api.py and the configs folder into the project folder. 33 | 2. Install the required modules: `pip install flask waitress` 34 | 3. Run `python api.py` 35 | 36 | 37 | 38 | ## Usage notes / proxy and VPN 39 | 40 | 1. The models are downloaded online from `huggingface.co`, which is not reachable from mainland China; enable a system or global proxy in advance, otherwise the model download will fail. 41 | 42 | 43 | ## Using it in the video translation software 44 | 45 | 1. Start the API service. 46 | 2. Open the video translation software, go to Menu -> TTS Settings -> F5-TTS and fill in the API address; if you have not changed it, use `http://127.0.0.1:5010` 47 | 3. Fill in the reference audio and the text spoken in it 48 | 49 | ![](https://pyvideotrans.com/img/f5002.jpg) 50 | 51 | 52 | 53 | 54 | ## API usage example 55 | 56 | ``` 57 | import requests 58 | 59 | res=requests.post('http://127.0.0.1:5010/api',data={ 60 | "ref_text": 'Fill in the transcript of 1.wav here', 61 | "gen_text": '''Fill in the text you want to synthesize here.''', 62 | "model": 'f5-tts' 63 | },files={"audio":open('./1.wav','rb')}) 64 | 65 | if res.status_code!=200: 66 | print(res.text) 67 | exit() 68 | 69 | with open("ceshi.wav",'wb') as f: 70 | f.write(res.content) 71 | ``` 72 | 73 | ## OpenAI-compatible TTS endpoint 74 | 75 | The `voice` parameter must use three `#` characters to separate the reference audio from the text spoken in that reference audio, for example 76 | 77 | `1.wav###你说四大皆空,却为何紧闭双眼,若你睁开眼睛看看我,我不相信你,两眼空空。` 78 | meaning the reference audio is 1.wav, located in the same directory as api.py, and the text spoken in 1.wav is "你说四大皆空,却为何紧闭双眼,若你睁开眼睛看看我,我不相信你,两眼空空。" 79 | 80 | The response is always WAV audio data. 81 | 82 | ``` 83 | import requests 84 | import json 85 | import os 86 | import base64 87 | import struct 88 | 89 | 90 | from openai import OpenAI 91 | 92 | client = OpenAI(api_key='12314', base_url='http://127.0.0.1:5010/v1') 93 | with client.audio.speech.with_streaming_response.create( 94 | model='f5-tts', 95 | voice='1.wav###你说四大皆空,却为何紧闭双眼,若你睁开眼睛看看我,我不相信你,两眼空空。', 96 | input='你好啊,亲爱的朋友们', 97 | speed=1.0 98 | ) as response: 99 | with open('./test.wav', 'wb') as f: 100 | for chunk in response.iter_bytes(): 101 | f.write(chunk) 102 | 103 | ```
-------------------------------------------------------------------------------- /api.py: -------------------------------------------------------------------------------- 1 | 2 | import os,time,sys 3 | from pathlib import Path 4 | ROOT_DIR=Path(__file__).parent.as_posix() 5 | 6 | # ffmpeg 7 | if sys.platform == 'win32': 8 | os.environ['PATH'] = ROOT_DIR + f';{ROOT_DIR}\\ffmpeg;' + os.environ['PATH'] 9 | else: 10 | os.environ['PATH'] = ROOT_DIR + f':{ROOT_DIR}/ffmpeg:' + os.environ['PATH'] 11 | 12 | SANFANG=True 13 | if Path(f"{ROOT_DIR}/modelscache").exists(): 14 | SANFANG=False 15 | os.environ['HF_HOME']=Path(f"{ROOT_DIR}/modelscache").as_posix() 16 | 17 | 18 | import re 19 | import torch 20 | from torch.backends import cudnn 21 | import torchaudio 22 | import numpy as np 23 | from flask import Flask, request, jsonify, send_file, render_template 24 | from flask_cors import CORS 25 | from einops import rearrange 26 | from vocos import Vocos 27 | from pydub import AudioSegment, silence 28 | 29 | from cached_path import cached_path 30 | 31 | import soundfile as sf 32 | import io 33 | import tempfile 34 | import logging 35 | import traceback 36 | from waitress import serve 37 | from importlib.resources import files 38 | from omegaconf import OmegaConf 39 | 40 | from f5_tts.infer.utils_infer import ( 41 | infer_process, 42 | load_model, 43 | load_vocoder, 44 | preprocess_ref_audio_text, 45 | remove_silence_for_generated_wav, 46 | ) 47 | from f5_tts.model import DiT, UNetT 48 | 49 | 50 | TMPDIR=(Path(__file__).parent/'tmp').as_posix() 51 | Path(TMPDIR).mkdir(exist_ok=True) 52 | 53 | # Set up logging 54 | logging.basicConfig(level=logging.INFO) 55 | logger = logging.getLogger(__name__) 56 | 57 | app = Flask(__name__, template_folder='templates') 58 | CORS(app) 59 | 60 | # --------------------- Settings -------------------- # 61 | 62 | 63 | 64 | # Add this near the top of the file, after other imports 65 | UPLOAD_FOLDER = 'data' 66 | if not os.path.exists(UPLOAD_FOLDER): 67 | os.makedirs(UPLOAD_FOLDER) 68 | 69 | 70 | def load_model2(repo_name='F5-TTS',vocoder_name='vocos'): 71 | mel_spec_type = vocoder_name 72 | model_cfg = f"{ROOT_DIR}/configs/F5TTS_Base_train.yaml" 73 | 74 | model_cfg = OmegaConf.load(model_cfg).model.arch 75 |
model_cls = DiT 76 | 77 | ckpt_file = "" 78 | vocab_file='' 79 | remove_silence=False 80 | speed=1.0 81 | if repo_name=='F5-TTS': 82 | if vocoder_name == "vocos": 83 | repo_name = "F5-TTS" 84 | exp_name = "F5TTS_Base" 85 | ckpt_step = 1200000 86 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) 87 | 88 | elif vocoder_name == "bigvgan": 89 | repo_name = "F5-TTS" 90 | exp_name = "F5TTS_Base_bigvgan" 91 | ckpt_step = 1250000 92 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")) 93 | else: 94 | mel_spec_type='vocos' 95 | model_cls = UNetT 96 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 97 | if ckpt_file == "": 98 | repo_name = "E2-TTS" 99 | exp_name = "E2TTS_Base" 100 | ckpt_step = 1200000 101 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) 102 | model=load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file) 103 | return model 104 | 105 | 106 | 107 | # Dictionary to store loaded models 108 | loaded_models = {} 109 | 110 | 111 | 112 | @app.route('/api', methods=['POST']) 113 | def api(): 114 | logger.info("Accessing generate_audio route") 115 | ref_text = request.form.get('ref_text') 116 | gen_text = request.form.get('gen_text') 117 | remove_silence = int(request.form.get('remove_silence',0)) 118 | 119 | speed = float(request.form.get('speed',1.0)) 120 | model_choice = 'F5-TTS' 121 | vocoder_name = request.form.get('vocoder_name','vocos') 122 | 123 | 124 | if not all([ref_text, gen_text, model_choice]): # Include audio_filename in the check 125 | return jsonify({"error": "Missing required parameters"}), 400 126 | 127 | audio_file = request.files['audio'] 128 | if audio_file.filename == '': 129 | logger.error("No audio file selected") 130 | return jsonify({"error": "No audio file selected"}), 400 131 | 132 | 133 | logger.info(f"Processing audio file: {audio_file.filename}") 134 | audio_name=f'{TMPDIR}/{time.time()}-{audio_file.filename}' 135 | audio_file.save(audio_name) 136 | 137 | try: 138 | 139 | if model_choice not in loaded_models: 140 | loaded_models[model_choice] = load_model2(repo_name=model_choice) 141 | 142 | 143 | model = loaded_models[model_choice] 144 | if vocoder_name == "vocos": 145 | 146 | vocoder = load_vocoder(vocoder_name=vocoder_name, 147 | is_local=True if not SANFANG else False, 148 | local_path='./modelscache/hub/models--charactr--vocos-mel-24khz/snapshots/0feb3fdd929bcd6649e0e7c5a688cf7dd012ef21/' if not SANFANG else None 149 | ) 150 | elif vocoder_name == "bigvgan": 151 | vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=False, local_path="./checkpoints/bigvgan_v2_24khz_100band_256x") 152 | 153 | 154 | main_voice = {"ref_audio": audio_name, "ref_text": ref_text} 155 | voices = {"main": main_voice} 156 | for voice in voices: 157 | voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text( 158 | voices[voice]["ref_audio"], voices[voice]["ref_text"] 159 | ) 160 | print("Voice:", voice) 161 | print("Ref_audio:", voices[voice]["ref_audio"]) 162 | print("Ref_text:", voices[voice]["ref_text"]) 163 | 164 | generated_audio_segments = [] 165 | reg1 = r"(?=\[\w+\])" 166 | chunks = re.split(reg1, gen_text) 167 | reg2 = r"\[(\w+)\]" 168 | for text in chunks: 169 | if not text.strip(): 170 | continue 171 | match = re.match(reg2, text) 172 | if match: 173 | voice = match[1] 174 | else: 175 | print("No voice tag found, using main.") 176 | voice = "main" 177 | if 
voice not in voices: 178 | print(f"Voice {voice} not found, using main.") 179 | voice = "main" 180 | text = re.sub(reg2, "", text) 181 | gen_text = text.strip() 182 | ref_audio = voices[voice]["ref_audio"] 183 | ref_text = voices[voice]["ref_text"] 184 | print(f"Voice: {voice}") 185 | 186 | audio, final_sample_rate, spectragram = infer_process( 187 | ref_audio, ref_text, gen_text, model, vocoder, mel_spec_type=vocoder_name, speed=speed 188 | ) 189 | generated_audio_segments.append(audio) 190 | 191 | # 192 | if generated_audio_segments: 193 | final_wave = np.concatenate(generated_audio_segments) 194 | 195 | 196 | wave_path=TMPDIR+f'/out-{time.time()}.wav' 197 | print(f'{wave_path=}') 198 | with open(wave_path, "wb") as f: 199 | sf.write(f.name, final_wave, final_sample_rate) 200 | if remove_silence==1: 201 | remove_silence_for_generated_wav(f.name) 202 | print(f.name) 203 | 204 | return send_file(wave_path, mimetype="audio/wav", as_attachment=True, download_name=audio_file.filename) 205 | 206 | except Exception as e: 207 | logger.error(f"Error generating audio: {str(e)}", exc_info=True) 208 | return jsonify({"error": str(e)}), 500 209 | 210 | 211 | 212 | @app.route('/v1/audio/speech', methods=['POST']) 213 | def audio_speech(): 214 | """ 215 | 兼容 OpenAI /v1/audio/speech API 的接口 216 | """ 217 | if not request.is_json: 218 | return jsonify({"error": "请求必须是 JSON 格式"}), 400 219 | 220 | data = request.get_json() 221 | 222 | # 检查请求中是否包含必要的参数 223 | if 'input' not in data or 'voice' not in data: 224 | return jsonify({"error": "请求缺少必要的参数: input, voice"}), 400 225 | 226 | 227 | gen_text = data.get('input') 228 | speed = float(data.get('speed',1.0)) 229 | 230 | # 参考音频 231 | voice = data.get('voice','') 232 | 233 | audio_file,ref_text=voice.split('###') 234 | 235 | if not Path(audio_file).exists() or not Path(f'{ROOT_DIR}/{audio_file}').exists(): 236 | return jsonify({"error": {"message": f"必须填写'参考音频路径###参考音频文本'", "type": e.__class__.__name__, "param": f'speed={speed},voice={voice},input={gen_text}', "code": 400}}), 500 237 | 238 | model_choice='F5-TTS' 239 | 240 | try: 241 | if model_choice not in loaded_models: 242 | loaded_models[model_choice] = load_model2(repo_name=model_choice) 243 | 244 | 245 | model = loaded_models[model_choice] 246 | 247 | vocoder = load_vocoder(vocoder_name='vocos', 248 | is_local=True if not SANFANG else False, 249 | local_path='./modelscache/hub/models--charactr--vocos-mel-24khz/snapshots/0feb3fdd929bcd6649e0e7c5a688cf7dd012ef21/' if not SANFANG else None 250 | ) 251 | 252 | 253 | main_voice = {"ref_audio": audio_file, "ref_text": ref_text} 254 | voices = {"main": main_voice} 255 | for voice in voices: 256 | voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text( 257 | voices[voice]["ref_audio"], voices[voice]["ref_text"] 258 | ) 259 | print("Voice:", voice) 260 | print("Ref_audio:", voices[voice]["ref_audio"]) 261 | print("Ref_text:", voices[voice]["ref_text"]) 262 | 263 | generated_audio_segments = [] 264 | reg1 = r"(?=\[\w+\])" 265 | chunks = re.split(reg1, gen_text) 266 | reg2 = r"\[(\w+)\]" 267 | for text in chunks: 268 | if not text.strip(): 269 | continue 270 | match = re.match(reg2, text) 271 | if match: 272 | voice = match[1] 273 | else: 274 | print("No voice tag found, using main.") 275 | voice = "main" 276 | if voice not in voices: 277 | print(f"Voice {voice} not found, using main.") 278 | voice = "main" 279 | text = re.sub(reg2, "", text) 280 | gen_text = text.strip() 281 | ref_audio = voices[voice]["ref_audio"] 282 | ref_text = 
voices[voice]["ref_text"] 283 | print(f"Voice: {voice}") 284 | 285 | audio, final_sample_rate, spectragram = infer_process( 286 | ref_audio, ref_text, gen_text, model, vocoder, mel_spec_type='vocos', speed=speed 287 | ) 288 | generated_audio_segments.append(audio) 289 | 290 | # 291 | if generated_audio_segments: 292 | final_wave = np.concatenate(generated_audio_segments) 293 | 294 | 295 | wave_path=TMPDIR+f'/openai-{time.time()}.wav' 296 | print(f'{wave_path=}') 297 | with open(wave_path, "wb") as f: 298 | sf.write(f.name, final_wave, final_sample_rate) 299 | print(f.name) 300 | 301 | return send_file(wave_path, mimetype="audio/x-wav") 302 | except Exception as e: 303 | return jsonify({"error": {"message": f"{e}", "type": e.__class__.__name__, "param": f'speed={speed},voice={voice},input={text}', "code": 400}}), 500 304 | 305 | 306 | 307 | if __name__ == '__main__': 308 | try: 309 | host="127.0.0.1" 310 | port=5010 311 | print(f"api接口地址 http://{host}:{port}") 312 | serve(app,host=host, port=port) 313 | except Exception as e: 314 | logger.error(f"An error occurred: {str(e)}") 315 | logger.error(traceback.format_exc()) -------------------------------------------------------------------------------- /configs/E2TTS_Base_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | datasets: 6 | name: Emilia_ZH_EN # dataset name 7 | batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 8 | batch_size_type: frame # "frame" or "sample" 9 | max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models 10 | num_workers: 16 11 | 12 | optim: 13 | epochs: 15 14 | learning_rate: 7.5e-5 15 | num_warmup_updates: 20000 # warmup steps 16 | grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps 17 | max_grad_norm: 1.0 # gradient clipping 18 | bnb_optimizer: False # use bnb 8bit AdamW optimizer or not 19 | 20 | model: 21 | name: E2TTS_Base 22 | tokenizer: pinyin 23 | tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 24 | arch: 25 | dim: 1024 26 | depth: 24 27 | heads: 16 28 | ff_mult: 4 29 | mel_spec: 30 | target_sample_rate: 24000 31 | n_mel_channels: 100 32 | hop_length: 256 33 | win_length: 1024 34 | n_fft: 1024 35 | mel_spec_type: vocos # 'vocos' or 'bigvgan' 36 | vocoder: 37 | is_local: False # use local offline ckpt or not 38 | local_path: None # local vocoder path 39 | 40 | ckpts: 41 | logger: wandb # wandb | tensorboard | None 42 | save_per_updates: 50000 # save checkpoint per steps 43 | last_per_steps: 5000 # save last checkpoint per steps 44 | save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} -------------------------------------------------------------------------------- /configs/E2TTS_Small_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | datasets: 6 | name: Emilia_ZH_EN 7 | batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 8 | batch_size_type: frame # "frame" or "sample" 9 | max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models 10 | num_workers: 16 11 | 12 | optim: 13 | epochs: 15 14 | learning_rate: 7.5e-5 15 | num_warmup_updates: 20000 # warmup steps 16 | grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps 17 | max_grad_norm: 1.0 18 | bnb_optimizer: False 19 | 20 | model: 21 | name: E2TTS_Small 22 | tokenizer: pinyin 23 | tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 24 | arch: 25 | dim: 768 26 | depth: 20 27 | heads: 12 28 | ff_mult: 4 29 | mel_spec: 30 | target_sample_rate: 24000 31 | n_mel_channels: 100 32 | hop_length: 256 33 | win_length: 1024 34 | n_fft: 1024 35 | mel_spec_type: vocos # 'vocos' or 'bigvgan' 36 | vocoder: 37 | is_local: False # use local offline ckpt or not 38 | local_path: None # local vocoder path 39 | 40 | ckpts: 41 | logger: wandb # wandb | tensorboard | None 42 | save_per_updates: 50000 # save checkpoint per steps 43 | last_per_steps: 5000 # save last checkpoint per steps 44 | save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} -------------------------------------------------------------------------------- /configs/F5TTS_Base_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | datasets: 6 | name: Emilia_ZH_EN # dataset name 7 | batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 8 | batch_size_type: frame # "frame" or "sample" 9 | max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models 10 | num_workers: 16 11 | 12 | optim: 13 | epochs: 15 14 | learning_rate: 7.5e-5 15 | num_warmup_updates: 20000 # warmup steps 16 | grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps 17 | max_grad_norm: 1.0 # gradient clipping 18 | bnb_optimizer: False # use bnb 8bit AdamW optimizer or not 19 | 20 | model: 21 | name: F5TTS_Base # model name 22 | tokenizer: pinyin # tokenizer type 23 | tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 24 | arch: 25 | dim: 1024 26 | depth: 22 27 | heads: 16 28 | ff_mult: 2 29 | text_dim: 512 30 | conv_layers: 4 31 | checkpoint_activations: False # recompute activations and save memory for extra compute 32 | mel_spec: 33 | target_sample_rate: 24000 34 | n_mel_channels: 100 35 | hop_length: 256 36 | win_length: 1024 37 | n_fft: 1024 38 | mel_spec_type: vocos # 'vocos' or 'bigvgan' 39 | vocoder: 40 | is_local: False # use local offline ckpt or not 41 | local_path: None # local vocoder path 42 | 43 | ckpts: 44 | logger: wandb # wandb | tensorboard | None 45 | save_per_updates: 50000 # save checkpoint per steps 46 | last_per_steps: 5000 # save last checkpoint per steps 47 | save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} -------------------------------------------------------------------------------- /configs/F5TTS_Small_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | datasets: 6 | name: Emilia_ZH_EN 7 | batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 8 | batch_size_type: frame # 
"frame" or "sample" 9 | max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models 10 | num_workers: 16 11 | 12 | optim: 13 | epochs: 15 14 | learning_rate: 7.5e-5 15 | num_warmup_updates: 20000 # warmup steps 16 | grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps 17 | max_grad_norm: 1.0 # gradient clipping 18 | bnb_optimizer: False # use bnb 8bit AdamW optimizer or not 19 | 20 | model: 21 | name: F5TTS_Small 22 | tokenizer: pinyin 23 | tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 24 | arch: 25 | dim: 768 26 | depth: 18 27 | heads: 12 28 | ff_mult: 2 29 | text_dim: 512 30 | conv_layers: 4 31 | checkpoint_activations: False # recompute activations and save memory for extra compute 32 | mel_spec: 33 | target_sample_rate: 24000 34 | n_mel_channels: 100 35 | hop_length: 256 36 | win_length: 1024 37 | n_fft: 1024 38 | mel_spec_type: vocos # 'vocos' or 'bigvgan' 39 | vocoder: 40 | is_local: False # use local offline ckpt or not 41 | local_path: None # local vocoder path 42 | 43 | ckpts: 44 | logger: wandb # wandb | tensorboard | None 45 | save_per_updates: 50000 # save checkpoint per steps 46 | last_per_steps: 5000 # save last checkpoint per steps 47 | save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} -------------------------------------------------------------------------------- /f5_tts/api.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | from importlib.resources import files 4 | 5 | import soundfile as sf 6 | import tqdm 7 | from cached_path import cached_path 8 | 9 | from f5_tts.infer.utils_infer import ( 10 | hop_length, 11 | infer_process, 12 | load_model, 13 | load_vocoder, 14 | preprocess_ref_audio_text, 15 | remove_silence_for_generated_wav, 16 | save_spectrogram, 17 | transcribe, 18 | target_sample_rate, 19 | ) 20 | from f5_tts.model import DiT, UNetT 21 | from f5_tts.model.utils import seed_everything 22 | 23 | 24 | class F5TTS: 25 | def __init__( 26 | self, 27 | model_type="F5-TTS", 28 | ckpt_file="", 29 | vocab_file="", 30 | ode_method="euler", 31 | use_ema=True, 32 | vocoder_name="vocos", 33 | local_path=None, 34 | device=None, 35 | hf_cache_dir=None, 36 | ): 37 | # Initialize parameters 38 | self.final_wave = None 39 | self.target_sample_rate = target_sample_rate 40 | self.hop_length = hop_length 41 | self.seed = -1 42 | self.mel_spec_type = vocoder_name 43 | 44 | # Set device 45 | if device is not None: 46 | self.device = device 47 | else: 48 | import torch 49 | 50 | self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 51 | 52 | # Load models 53 | self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir) 54 | self.load_ema_model( 55 | model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir 56 | ) 57 | 58 | def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None): 59 | self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir) 60 | 61 | def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None): 62 | if model_type == "F5-TTS": 63 | if not ckpt_file: 64 | if mel_spec_type == "vocos": 65 | ckpt_file = str( 66 | 
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) 67 | ) 68 | elif mel_spec_type == "bigvgan": 69 | ckpt_file = str( 70 | cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir) 71 | ) 72 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 73 | model_cls = DiT 74 | elif model_type == "E2-TTS": 75 | if not ckpt_file: 76 | ckpt_file = str( 77 | cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) 78 | ) 79 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 80 | model_cls = UNetT 81 | else: 82 | raise ValueError(f"Unknown model type: {model_type}") 83 | 84 | self.ema_model = load_model( 85 | model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device 86 | ) 87 | 88 | def transcribe(self, ref_audio, language=None): 89 | return transcribe(ref_audio, language) 90 | 91 | def export_wav(self, wav, file_wave, remove_silence=False): 92 | sf.write(file_wave, wav, self.target_sample_rate) 93 | 94 | if remove_silence: 95 | remove_silence_for_generated_wav(file_wave) 96 | 97 | def export_spectrogram(self, spect, file_spect): 98 | save_spectrogram(spect, file_spect) 99 | 100 | def infer( 101 | self, 102 | ref_file, 103 | ref_text, 104 | gen_text, 105 | show_info=print, 106 | progress=tqdm, 107 | target_rms=0.1, 108 | cross_fade_duration=0.15, 109 | sway_sampling_coef=-1, 110 | cfg_strength=2, 111 | nfe_step=32, 112 | speed=1.0, 113 | fix_duration=None, 114 | remove_silence=False, 115 | file_wave=None, 116 | file_spect=None, 117 | seed=-1, 118 | ): 119 | if seed == -1: 120 | seed = random.randint(0, sys.maxsize) 121 | seed_everything(seed) 122 | self.seed = seed 123 | 124 | ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device) 125 | 126 | wav, sr, spect = infer_process( 127 | ref_file, 128 | ref_text, 129 | gen_text, 130 | self.ema_model, 131 | self.vocoder, 132 | self.mel_spec_type, 133 | show_info=show_info, 134 | progress=progress, 135 | target_rms=target_rms, 136 | cross_fade_duration=cross_fade_duration, 137 | nfe_step=nfe_step, 138 | cfg_strength=cfg_strength, 139 | sway_sampling_coef=sway_sampling_coef, 140 | speed=speed, 141 | fix_duration=fix_duration, 142 | device=self.device, 143 | ) 144 | 145 | if file_wave is not None: 146 | self.export_wav(wav, file_wave, remove_silence) 147 | 148 | if file_spect is not None: 149 | self.export_spectrogram(spect, file_spect) 150 | 151 | return wav, sr, spect 152 | 153 | 154 | if __name__ == "__main__": 155 | f5tts = F5TTS() 156 | 157 | wav, sr, spect = f5tts.infer( 158 | ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), 159 | ref_text="some call me nature, others call me mother nature.", 160 | gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. 
Respect me and I'll nurture you; ignore me and you shall face the consequences.""", 161 | file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")), 162 | file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")), 163 | seed=-1, # random seed = -1 164 | ) 165 | 166 | print("seed :", f5tts.seed) 167 | -------------------------------------------------------------------------------- /f5_tts/configs/E2TTS_Base_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | datasets: 6 | name: Emilia_ZH_EN # dataset name 7 | batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 8 | batch_size_type: frame # "frame" or "sample" 9 | max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models 10 | num_workers: 16 11 | 12 | optim: 13 | epochs: 15 14 | learning_rate: 7.5e-5 15 | num_warmup_updates: 20000 # warmup steps 16 | grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps 17 | max_grad_norm: 1.0 # gradient clipping 18 | bnb_optimizer: False # use bnb 8bit AdamW optimizer or not 19 | 20 | model: 21 | name: E2TTS_Base 22 | tokenizer: pinyin 23 | tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 24 | arch: 25 | dim: 1024 26 | depth: 24 27 | heads: 16 28 | ff_mult: 4 29 | mel_spec: 30 | target_sample_rate: 24000 31 | n_mel_channels: 100 32 | hop_length: 256 33 | win_length: 1024 34 | n_fft: 1024 35 | mel_spec_type: vocos # 'vocos' or 'bigvgan' 36 | vocoder: 37 | is_local: False # use local offline ckpt or not 38 | local_path: None # local vocoder path 39 | 40 | ckpts: 41 | logger: wandb # wandb | tensorboard | None 42 | save_per_updates: 50000 # save checkpoint per steps 43 | last_per_steps: 5000 # save last checkpoint per steps 44 | save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} -------------------------------------------------------------------------------- /f5_tts/configs/E2TTS_Small_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | datasets: 6 | name: Emilia_ZH_EN 7 | batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 8 | batch_size_type: frame # "frame" or "sample" 9 | max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models 10 | num_workers: 16 11 | 12 | optim: 13 | epochs: 15 14 | learning_rate: 7.5e-5 15 | num_warmup_updates: 20000 # warmup steps 16 | grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps 17 | max_grad_norm: 1.0 18 | bnb_optimizer: False 19 | 20 | model: 21 | name: E2TTS_Small 22 | tokenizer: pinyin 23 | tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 24 | arch: 25 | dim: 768 26 | depth: 20 27 | heads: 12 28 | ff_mult: 4 29 | mel_spec: 30 | target_sample_rate: 24000 31 | n_mel_channels: 100 32 | hop_length: 256 33 | win_length: 1024 34 | n_fft: 1024 35 | mel_spec_type: vocos # 'vocos' or 'bigvgan' 36 | vocoder: 37 | is_local: False # use local offline ckpt or not 38 | local_path: None # local vocoder path 39 | 40 | ckpts: 41 | logger: wandb # wandb | tensorboard | None 42 | save_per_updates: 50000 # save checkpoint per steps 43 | last_per_steps: 5000 # save last checkpoint per steps 44 | save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} -------------------------------------------------------------------------------- /f5_tts/configs/F5TTS_Base_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | datasets: 6 | name: Emilia_ZH_EN # dataset name 7 | batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 8 | batch_size_type: frame # "frame" or "sample" 9 | max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models 10 | num_workers: 16 11 | 12 | optim: 13 | epochs: 15 14 | learning_rate: 7.5e-5 15 | num_warmup_updates: 20000 # warmup steps 16 | grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps 17 | max_grad_norm: 1.0 # gradient clipping 18 | bnb_optimizer: False # use bnb 8bit AdamW optimizer or not 19 | 20 | model: 21 | name: F5TTS_Base # model name 22 | tokenizer: pinyin # tokenizer type 23 | tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 24 | arch: 25 | dim: 1024 26 | depth: 22 27 | heads: 16 28 | ff_mult: 2 29 | text_dim: 512 30 | conv_layers: 4 31 | checkpoint_activations: False # recompute activations and save memory for extra compute 32 | mel_spec: 33 | target_sample_rate: 24000 34 | n_mel_channels: 100 35 | hop_length: 256 36 | win_length: 1024 37 | n_fft: 1024 38 | mel_spec_type: vocos # 'vocos' or 'bigvgan' 39 | vocoder: 40 | is_local: False # use local offline ckpt or not 41 | local_path: None # local vocoder path 42 | 43 | ckpts: 44 | logger: wandb # wandb | tensorboard | None 45 | save_per_updates: 50000 # save checkpoint per steps 46 | last_per_steps: 5000 # save last checkpoint per steps 47 | save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} -------------------------------------------------------------------------------- /f5_tts/configs/F5TTS_Small_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | datasets: 6 | name: Emilia_ZH_EN 7 | batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 8 | 
batch_size_type: frame # "frame" or "sample" 9 | max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models 10 | num_workers: 16 11 | 12 | optim: 13 | epochs: 15 14 | learning_rate: 7.5e-5 15 | num_warmup_updates: 20000 # warmup steps 16 | grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps 17 | max_grad_norm: 1.0 # gradient clipping 18 | bnb_optimizer: False # use bnb 8bit AdamW optimizer or not 19 | 20 | model: 21 | name: F5TTS_Small 22 | tokenizer: pinyin 23 | tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 24 | arch: 25 | dim: 768 26 | depth: 18 27 | heads: 12 28 | ff_mult: 2 29 | text_dim: 512 30 | conv_layers: 4 31 | checkpoint_activations: False # recompute activations and save memory for extra compute 32 | mel_spec: 33 | target_sample_rate: 24000 34 | n_mel_channels: 100 35 | hop_length: 256 36 | win_length: 1024 37 | n_fft: 1024 38 | mel_spec_type: vocos # 'vocos' or 'bigvgan' 39 | vocoder: 40 | is_local: False # use local offline ckpt or not 41 | local_path: None # local vocoder path 42 | 43 | ckpts: 44 | logger: wandb # wandb | tensorboard | None 45 | save_per_updates: 50000 # save checkpoint per steps 46 | last_per_steps: 5000 # save last checkpoint per steps 47 | save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} -------------------------------------------------------------------------------- /f5_tts/eval/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Evaluation 3 | 4 | Install packages for evaluation: 5 | 6 | ```bash 7 | pip install -e .[eval] 8 | ``` 9 | 10 | ## Generating Samples for Evaluation 11 | 12 | ### Prepare Test Datasets 13 | 14 | 1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval). 15 | 2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/). 16 | 3. Unzip the downloaded datasets and place them in the `data/` directory. 17 | 4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py` 18 | 5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst` 19 | 20 | ### Batch Inference for Test Set 21 | 22 | To run batch inference for evaluations, execute the following commands: 23 | 24 | ```bash 25 | # batch inference for evaluations 26 | accelerate config # if not set before 27 | bash src/f5_tts/eval/eval_infer_batch.sh 28 | ``` 29 | 30 | ## Objective Evaluation on Generated Results 31 | 32 | ### Download Evaluation Model Checkpoints 33 | 34 | 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh) 35 | 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3) 36 | 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view). 37 | 38 | Then update in the following scripts with the paths you put evaluation model ckpts to. 
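For reference, the sketch below shows one way the WavLM-based speaker-verification model used for SIM scoring could be instantiated with the `ECAPA_TDNN_SMALL` helper defined in `ecapa_tdnn.py` below; it is a minimal illustration, and the checkpoint filename and state-dict key are assumptions rather than the project's exact loading code:

```python
import torch

from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL

# Assumed local path to the WavLM-large speaker-verification checkpoint downloaded above
wavlm_ckpt = "checkpoints/wavlm_large_finetune.pth"

# feat_dim=1024 matches the hidden size of WavLM-large features
model = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type="wavlm_large")
state = torch.load(wavlm_ckpt, map_location="cpu")
model.load_state_dict(state.get("model", state), strict=False)  # the "model" key is an assumption
model.eval()
```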
39 | 40 | ### Objective Evaluation 41 | 42 | Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations: 43 | ```bash 44 | # Evaluation [WER] for Seed-TTS test [ZH] set 45 | python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir --gpu_nums 8 46 | 47 | # Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence) 48 | python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir --librispeech_test_clean_path 49 | 50 | # Evaluation [UTMOS]. --ext: Audio extension 51 | python src/f5_tts/eval/eval_utmos.py --audio_dir --ext wav 52 | ``` 53 | -------------------------------------------------------------------------------- /f5_tts/eval/ecapa_tdnn.py: -------------------------------------------------------------------------------- 1 | # just for speaker similarity evaluation, third-party code 2 | 3 | # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ 4 | # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN 5 | 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | """ Res2Conv1d + BatchNorm1d + ReLU 13 | """ 14 | 15 | 16 | class Res2Conv1dReluBn(nn.Module): 17 | """ 18 | in_channels == out_channels == channels 19 | """ 20 | 21 | def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4): 22 | super().__init__() 23 | assert channels % scale == 0, "{} % {} != 0".format(channels, scale) 24 | self.scale = scale 25 | self.width = channels // scale 26 | self.nums = scale if scale == 1 else scale - 1 27 | 28 | self.convs = [] 29 | self.bns = [] 30 | for i in range(self.nums): 31 | self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias)) 32 | self.bns.append(nn.BatchNorm1d(self.width)) 33 | self.convs = nn.ModuleList(self.convs) 34 | self.bns = nn.ModuleList(self.bns) 35 | 36 | def forward(self, x): 37 | out = [] 38 | spx = torch.split(x, self.width, 1) 39 | for i in range(self.nums): 40 | if i == 0: 41 | sp = spx[i] 42 | else: 43 | sp = sp + spx[i] 44 | # Order: conv -> relu -> bn 45 | sp = self.convs[i](sp) 46 | sp = self.bns[i](F.relu(sp)) 47 | out.append(sp) 48 | if self.scale != 1: 49 | out.append(spx[self.nums]) 50 | out = torch.cat(out, dim=1) 51 | 52 | return out 53 | 54 | 55 | """ Conv1d + BatchNorm1d + ReLU 56 | """ 57 | 58 | 59 | class Conv1dReluBn(nn.Module): 60 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True): 61 | super().__init__() 62 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) 63 | self.bn = nn.BatchNorm1d(out_channels) 64 | 65 | def forward(self, x): 66 | return self.bn(F.relu(self.conv(x))) 67 | 68 | 69 | """ The SE connection of 1D case. 70 | """ 71 | 72 | 73 | class SE_Connect(nn.Module): 74 | def __init__(self, channels, se_bottleneck_dim=128): 75 | super().__init__() 76 | self.linear1 = nn.Linear(channels, se_bottleneck_dim) 77 | self.linear2 = nn.Linear(se_bottleneck_dim, channels) 78 | 79 | def forward(self, x): 80 | out = x.mean(dim=2) 81 | out = F.relu(self.linear1(out)) 82 | out = torch.sigmoid(self.linear2(out)) 83 | out = x * out.unsqueeze(2) 84 | 85 | return out 86 | 87 | 88 | """ SE-Res2Block of the ECAPA-TDNN architecture. 
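Structure: 1x1 Conv1dReluBn -> Res2Conv1dReluBn -> 1x1 Conv1dReluBn -> SE_Connect, with a 1x1 Conv1d shortcut on the residual path when in_channels != out_channels.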
89 | """ 90 | 91 | # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): 92 | # return nn.Sequential( 93 | # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0), 94 | # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale), 95 | # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0), 96 | # SE_Connect(channels) 97 | # ) 98 | 99 | 100 | class SE_Res2Block(nn.Module): 101 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim): 102 | super().__init__() 103 | self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 104 | self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale) 105 | self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0) 106 | self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) 107 | 108 | self.shortcut = None 109 | if in_channels != out_channels: 110 | self.shortcut = nn.Conv1d( 111 | in_channels=in_channels, 112 | out_channels=out_channels, 113 | kernel_size=1, 114 | ) 115 | 116 | def forward(self, x): 117 | residual = x 118 | if self.shortcut: 119 | residual = self.shortcut(x) 120 | 121 | x = self.Conv1dReluBn1(x) 122 | x = self.Res2Conv1dReluBn(x) 123 | x = self.Conv1dReluBn2(x) 124 | x = self.SE_Connect(x) 125 | 126 | return x + residual 127 | 128 | 129 | """ Attentive weighted mean and standard deviation pooling. 130 | """ 131 | 132 | 133 | class AttentiveStatsPool(nn.Module): 134 | def __init__(self, in_dim, attention_channels=128, global_context_att=False): 135 | super().__init__() 136 | self.global_context_att = global_context_att 137 | 138 | # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. 139 | if global_context_att: 140 | self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper 141 | else: 142 | self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper 143 | self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper 144 | 145 | def forward(self, x): 146 | if self.global_context_att: 147 | context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) 148 | context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) 149 | x_in = torch.cat((x, context_mean, context_std), dim=1) 150 | else: 151 | x_in = x 152 | 153 | # DON'T use ReLU here! In experiments, I find ReLU hard to converge. 
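        # tanh followed by a softmax over the time axis produces per-frame attention weights, used below for the weighted mean and standard deviation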
154 | alpha = torch.tanh(self.linear1(x_in)) 155 | # alpha = F.relu(self.linear1(x_in)) 156 | alpha = torch.softmax(self.linear2(alpha), dim=2) 157 | mean = torch.sum(alpha * x, dim=2) 158 | residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 159 | std = torch.sqrt(residuals.clamp(min=1e-9)) 160 | return torch.cat([mean, std], dim=1) 161 | 162 | 163 | class ECAPA_TDNN(nn.Module): 164 | def __init__( 165 | self, 166 | feat_dim=80, 167 | channels=512, 168 | emb_dim=192, 169 | global_context_att=False, 170 | feat_type="wavlm_large", 171 | sr=16000, 172 | feature_selection="hidden_states", 173 | update_extract=False, 174 | config_path=None, 175 | ): 176 | super().__init__() 177 | 178 | self.feat_type = feat_type 179 | self.feature_selection = feature_selection 180 | self.update_extract = update_extract 181 | self.sr = sr 182 | 183 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True 184 | try: 185 | local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main") 186 | self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path) 187 | except: # noqa: E722 188 | self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type) 189 | 190 | if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( 191 | self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention" 192 | ): 193 | self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False 194 | if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( 195 | self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention" 196 | ): 197 | self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False 198 | 199 | self.feat_num = self.get_feat_num() 200 | self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) 201 | 202 | if feat_type != "fbank" and feat_type != "mfcc": 203 | freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"] 204 | for name, param in self.feature_extract.named_parameters(): 205 | for freeze_val in freeze_list: 206 | if freeze_val in name: 207 | param.requires_grad = False 208 | break 209 | 210 | if not self.update_extract: 211 | for param in self.feature_extract.parameters(): 212 | param.requires_grad = False 213 | 214 | self.instance_norm = nn.InstanceNorm1d(feat_dim) 215 | # self.channels = [channels] * 4 + [channels * 3] 216 | self.channels = [channels] * 4 + [1536] 217 | 218 | self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2) 219 | self.layer2 = SE_Res2Block( 220 | self.channels[0], 221 | self.channels[1], 222 | kernel_size=3, 223 | stride=1, 224 | padding=2, 225 | dilation=2, 226 | scale=8, 227 | se_bottleneck_dim=128, 228 | ) 229 | self.layer3 = SE_Res2Block( 230 | self.channels[1], 231 | self.channels[2], 232 | kernel_size=3, 233 | stride=1, 234 | padding=3, 235 | dilation=3, 236 | scale=8, 237 | se_bottleneck_dim=128, 238 | ) 239 | self.layer4 = SE_Res2Block( 240 | self.channels[2], 241 | self.channels[3], 242 | kernel_size=3, 243 | stride=1, 244 | padding=4, 245 | dilation=4, 246 | scale=8, 247 | se_bottleneck_dim=128, 248 | ) 249 | 250 | # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) 251 | cat_channels = channels * 3 252 | self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) 253 | self.pooling = AttentiveStatsPool( 254 | self.channels[-1], attention_channels=128, global_context_att=global_context_att 255 | ) 256 | self.bn = nn.BatchNorm1d(self.channels[-1] * 2) 
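        # AttentiveStatsPool concatenates the weighted mean and std, hence 2 * channels[-1] features feed the final embedding projection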
257 | self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) 258 | 259 | def get_feat_num(self): 260 | self.feature_extract.eval() 261 | wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)] 262 | with torch.no_grad(): 263 | features = self.feature_extract(wav) 264 | select_feature = features[self.feature_selection] 265 | if isinstance(select_feature, (list, tuple)): 266 | return len(select_feature) 267 | else: 268 | return 1 269 | 270 | def get_feat(self, x): 271 | if self.update_extract: 272 | x = self.feature_extract([sample for sample in x]) 273 | else: 274 | with torch.no_grad(): 275 | if self.feat_type == "fbank" or self.feat_type == "mfcc": 276 | x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len 277 | else: 278 | x = self.feature_extract([sample for sample in x]) 279 | 280 | if self.feat_type == "fbank": 281 | x = x.log() 282 | 283 | if self.feat_type != "fbank" and self.feat_type != "mfcc": 284 | x = x[self.feature_selection] 285 | if isinstance(x, (list, tuple)): 286 | x = torch.stack(x, dim=0) 287 | else: 288 | x = x.unsqueeze(0) 289 | norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 290 | x = (norm_weights * x).sum(dim=0) 291 | x = torch.transpose(x, 1, 2) + 1e-6 292 | 293 | x = self.instance_norm(x) 294 | return x 295 | 296 | def forward(self, x): 297 | x = self.get_feat(x) 298 | 299 | out1 = self.layer1(x) 300 | out2 = self.layer2(out1) 301 | out3 = self.layer3(out2) 302 | out4 = self.layer4(out3) 303 | 304 | out = torch.cat([out2, out3, out4], dim=1) 305 | out = F.relu(self.conv(out)) 306 | out = self.bn(self.pooling(out)) 307 | out = self.linear(out) 308 | 309 | return out 310 | 311 | 312 | def ECAPA_TDNN_SMALL( 313 | feat_dim, 314 | emb_dim=256, 315 | feat_type="wavlm_large", 316 | sr=16000, 317 | feature_selection="hidden_states", 318 | update_extract=False, 319 | config_path=None, 320 | ): 321 | return ECAPA_TDNN( 322 | feat_dim=feat_dim, 323 | channels=512, 324 | emb_dim=emb_dim, 325 | feat_type=feat_type, 326 | sr=sr, 327 | feature_selection=feature_selection, 328 | update_extract=update_extract, 329 | config_path=config_path, 330 | ) 331 | -------------------------------------------------------------------------------- /f5_tts/eval/eval_infer_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import argparse 7 | import time 8 | from importlib.resources import files 9 | 10 | import torch 11 | import torchaudio 12 | from accelerate import Accelerator 13 | from tqdm import tqdm 14 | 15 | from f5_tts.eval.utils_eval import ( 16 | get_inference_prompt, 17 | get_librispeech_test_clean_metainfo, 18 | get_seedtts_testset_metainfo, 19 | ) 20 | from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder 21 | from f5_tts.model import CFM, DiT, UNetT 22 | from f5_tts.model.utils import get_tokenizer 23 | 24 | accelerator = Accelerator() 25 | device = f"cuda:{accelerator.process_index}" 26 | 27 | 28 | # --------------------- Dataset Settings -------------------- # 29 | 30 | target_sample_rate = 24000 31 | n_mel_channels = 100 32 | hop_length = 256 33 | win_length = 1024 34 | n_fft = 1024 35 | target_rms = 0.1 36 | 37 | rel_path = str(files("f5_tts").joinpath("../../")) 38 | 39 | 40 | def main(): 41 | # ---------------------- infer setting ---------------------- # 42 | 43 | parser = argparse.ArgumentParser(description="batch inference") 44 | 45 | parser.add_argument("-s", 
"--seed", default=None, type=int) 46 | parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN") 47 | parser.add_argument("-n", "--expname", required=True) 48 | parser.add_argument("-c", "--ckptstep", default=1200000, type=int) 49 | parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"]) 50 | parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]) 51 | 52 | parser.add_argument("-nfe", "--nfestep", default=32, type=int) 53 | parser.add_argument("-o", "--odemethod", default="euler") 54 | parser.add_argument("-ss", "--swaysampling", default=-1, type=float) 55 | 56 | parser.add_argument("-t", "--testset", required=True) 57 | 58 | args = parser.parse_args() 59 | 60 | seed = args.seed 61 | dataset_name = args.dataset 62 | exp_name = args.expname 63 | ckpt_step = args.ckptstep 64 | ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt" 65 | mel_spec_type = args.mel_spec_type 66 | tokenizer = args.tokenizer 67 | 68 | nfe_step = args.nfestep 69 | ode_method = args.odemethod 70 | sway_sampling_coef = args.swaysampling 71 | 72 | testset = args.testset 73 | 74 | infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended) 75 | cfg_strength = 2.0 76 | speed = 1.0 77 | use_truth_duration = False 78 | no_ref_audio = False 79 | 80 | if exp_name == "F5TTS_Base": 81 | model_cls = DiT 82 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 83 | 84 | elif exp_name == "E2TTS_Base": 85 | model_cls = UNetT 86 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 87 | 88 | if testset == "ls_pc_test_clean": 89 | metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" 90 | librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path 91 | metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path) 92 | 93 | elif testset == "seedtts_test_zh": 94 | metalst = rel_path + "/data/seedtts_testset/zh/meta.lst" 95 | metainfo = get_seedtts_testset_metainfo(metalst) 96 | 97 | elif testset == "seedtts_test_en": 98 | metalst = rel_path + "/data/seedtts_testset/en/meta.lst" 99 | metainfo = get_seedtts_testset_metainfo(metalst) 100 | 101 | # path to save genereted wavs 102 | output_dir = ( 103 | f"{rel_path}/" 104 | f"results/{exp_name}_{ckpt_step}/{testset}/" 105 | f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}" 106 | f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" 107 | f"_cfg{cfg_strength}_speed{speed}" 108 | f"{'_gt-dur' if use_truth_duration else ''}" 109 | f"{'_no-ref-audio' if no_ref_audio else ''}" 110 | ) 111 | 112 | # -------------------------------------------------# 113 | 114 | use_ema = True 115 | 116 | prompts_all = get_inference_prompt( 117 | metainfo, 118 | speed=speed, 119 | tokenizer=tokenizer, 120 | target_sample_rate=target_sample_rate, 121 | n_mel_channels=n_mel_channels, 122 | hop_length=hop_length, 123 | mel_spec_type=mel_spec_type, 124 | target_rms=target_rms, 125 | use_truth_duration=use_truth_duration, 126 | infer_batch_size=infer_batch_size, 127 | ) 128 | 129 | # Vocoder model 130 | local = False 131 | if mel_spec_type == "vocos": 132 | vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" 133 | elif mel_spec_type == "bigvgan": 134 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" 135 | vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path) 136 | 137 | # Tokenizer 138 | vocab_char_map, vocab_size = 
get_tokenizer(dataset_name, tokenizer) 139 | 140 | # Model 141 | model = CFM( 142 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), 143 | mel_spec_kwargs=dict( 144 | n_fft=n_fft, 145 | hop_length=hop_length, 146 | win_length=win_length, 147 | n_mel_channels=n_mel_channels, 148 | target_sample_rate=target_sample_rate, 149 | mel_spec_type=mel_spec_type, 150 | ), 151 | odeint_kwargs=dict( 152 | method=ode_method, 153 | ), 154 | vocab_char_map=vocab_char_map, 155 | ).to(device) 156 | 157 | dtype = torch.float32 if mel_spec_type == "bigvgan" else None 158 | model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) 159 | 160 | if not os.path.exists(output_dir) and accelerator.is_main_process: 161 | os.makedirs(output_dir) 162 | 163 | # start batch inference 164 | accelerator.wait_for_everyone() 165 | start = time.time() 166 | 167 | with accelerator.split_between_processes(prompts_all) as prompts: 168 | for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): 169 | utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt 170 | ref_mels = ref_mels.to(device) 171 | ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) 172 | total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) 173 | 174 | # Inference 175 | with torch.inference_mode(): 176 | generated, _ = model.sample( 177 | cond=ref_mels, 178 | text=final_text_list, 179 | duration=total_mel_lens, 180 | lens=ref_mel_lens, 181 | steps=nfe_step, 182 | cfg_strength=cfg_strength, 183 | sway_sampling_coef=sway_sampling_coef, 184 | no_ref_audio=no_ref_audio, 185 | seed=seed, 186 | ) 187 | # Final result 188 | for i, gen in enumerate(generated): 189 | gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) 190 | gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) 191 | if mel_spec_type == "vocos": 192 | generated_wave = vocoder.decode(gen_mel_spec).cpu() 193 | elif mel_spec_type == "bigvgan": 194 | generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() 195 | 196 | if ref_rms_list[i] < target_rms: 197 | generated_wave = generated_wave * ref_rms_list[i] / target_rms 198 | torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate) 199 | 200 | accelerator.wait_for_everyone() 201 | if accelerator.is_main_process: 202 | timediff = time.time() - start 203 | print(f"Done batch inference in {timediff / 60 :.2f} minutes.") 204 | 205 | 206 | if __name__ == "__main__": 207 | main() 208 | -------------------------------------------------------------------------------- /f5_tts/eval/eval_infer_batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # e.g. F5-TTS, 16 NFE 4 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 5 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 6 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 7 | 8 | # e.g. Vanilla E2 TTS, 32 NFE 9 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 10 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 11 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 12 | 13 | # etc. 
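# After generation, score the wavs (a sketch; flags follow eval_librispeech_test_clean.py and
# eval_seedtts_testset.py, and <results_dir> stands for the folder eval_infer_batch.py creates
# under results/ from the experiment, checkpoint, seed, ODE, NFE and vocoder settings):
# python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -g <results_dir> -p <path>/LibriSpeech/test-clean -n 8
# python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh -g <results_dir> -n 8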
14 | -------------------------------------------------------------------------------- /f5_tts/eval/eval_librispeech_test_clean.py: -------------------------------------------------------------------------------- 1 | # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation) 2 | 3 | import argparse 4 | import json 5 | import os 6 | import sys 7 | 8 | sys.path.append(os.getcwd()) 9 | 10 | import multiprocessing as mp 11 | from importlib.resources import files 12 | 13 | import numpy as np 14 | from f5_tts.eval.utils_eval import ( 15 | get_librispeech_test, 16 | run_asr_wer, 17 | run_sim, 18 | ) 19 | 20 | rel_path = str(files("f5_tts").joinpath("../../")) 21 | 22 | 23 | def get_args(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]) 26 | parser.add_argument("-l", "--lang", type=str, default="en") 27 | parser.add_argument("-g", "--gen_wav_dir", type=str, required=True) 28 | parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True) 29 | parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use") 30 | parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory") 31 | return parser.parse_args() 32 | 33 | 34 | def main(): 35 | args = get_args() 36 | eval_task = args.eval_task 37 | lang = args.lang 38 | librispeech_test_clean_path = args.librispeech_test_clean_path # test-clean path 39 | gen_wav_dir = args.gen_wav_dir 40 | metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" 41 | 42 | gpus = list(range(args.gpu_nums)) 43 | test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path) 44 | 45 | ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book, 46 | ## leading to a low similarity for the ground truth in some cases. 
47 | # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth 48 | 49 | local = args.local 50 | if local: # use local custom checkpoint dir 51 | asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" 52 | else: 53 | asr_ckpt_dir = "" # auto download to cache dir 54 | wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" 55 | 56 | # --------------------------- WER --------------------------- 57 | 58 | if eval_task == "wer": 59 | wer_results = [] 60 | wers = [] 61 | 62 | with mp.Pool(processes=len(gpus)) as pool: 63 | args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] 64 | results = pool.map(run_asr_wer, args) 65 | for r in results: 66 | wer_results.extend(r) 67 | 68 | wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl" 69 | with open(wer_result_path, "w") as f: 70 | for line in wer_results: 71 | wers.append(line["wer"]) 72 | json_line = json.dumps(line, ensure_ascii=False) 73 | f.write(json_line + "\n") 74 | 75 | wer = round(np.mean(wers) * 100, 3) 76 | print(f"\nTotal {len(wers)} samples") 77 | print(f"WER : {wer}%") 78 | print(f"Results have been saved to {wer_result_path}") 79 | 80 | # --------------------------- SIM --------------------------- 81 | 82 | if eval_task == "sim": 83 | sims = [] 84 | with mp.Pool(processes=len(gpus)) as pool: 85 | args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] 86 | results = pool.map(run_sim, args) 87 | for r in results: 88 | sims.extend(r) 89 | 90 | sim = round(sum(sims) / len(sims), 3) 91 | print(f"\nTotal {len(sims)} samples") 92 | print(f"SIM : {sim}") 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /f5_tts/eval/eval_seedtts_testset.py: -------------------------------------------------------------------------------- 1 | # Evaluate with Seed-TTS testset 2 | 3 | import argparse 4 | import json 5 | import os 6 | import sys 7 | 8 | sys.path.append(os.getcwd()) 9 | 10 | import multiprocessing as mp 11 | from importlib.resources import files 12 | 13 | import numpy as np 14 | from f5_tts.eval.utils_eval import ( 15 | get_seed_tts_test, 16 | run_asr_wer, 17 | run_sim, 18 | ) 19 | 20 | rel_path = str(files("f5_tts").joinpath("../../")) 21 | 22 | 23 | def get_args(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]) 26 | parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"]) 27 | parser.add_argument("-g", "--gen_wav_dir", type=str, required=True) 28 | parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use") 29 | parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory") 30 | return parser.parse_args() 31 | 32 | 33 | def main(): 34 | args = get_args() 35 | eval_task = args.eval_task 36 | lang = args.lang 37 | gen_wav_dir = args.gen_wav_dir 38 | metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset 39 | 40 | # NOTE. 
paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different 41 | # zh 1.254 seems a result of 4 workers wer_seed_tts 42 | gpus = list(range(args.gpu_nums)) 43 | test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus) 44 | 45 | local = args.local 46 | if local: # use local custom checkpoint dir 47 | if lang == "zh": 48 | asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr 49 | elif lang == "en": 50 | asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" 51 | else: 52 | asr_ckpt_dir = "" # auto download to cache dir 53 | wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" 54 | 55 | # --------------------------- WER --------------------------- 56 | 57 | if eval_task == "wer": 58 | wer_results = [] 59 | wers = [] 60 | 61 | with mp.Pool(processes=len(gpus)) as pool: 62 | args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] 63 | results = pool.map(run_asr_wer, args) 64 | for r in results: 65 | wer_results.extend(r) 66 | 67 | wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl" 68 | with open(wer_result_path, "w") as f: 69 | for line in wer_results: 70 | wers.append(line["wer"]) 71 | json_line = json.dumps(line, ensure_ascii=False) 72 | f.write(json_line + "\n") 73 | 74 | wer = round(np.mean(wers) * 100, 3) 75 | print(f"\nTotal {len(wers)} samples") 76 | print(f"WER : {wer}%") 77 | print(f"Results have been saved to {wer_result_path}") 78 | 79 | # --------------------------- SIM --------------------------- 80 | 81 | if eval_task == "sim": 82 | sims = [] 83 | with mp.Pool(processes=len(gpus)) as pool: 84 | args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] 85 | results = pool.map(run_sim, args) 86 | for r in results: 87 | sims.extend(r) 88 | 89 | sim = round(sum(sims) / len(sims), 3) 90 | print(f"\nTotal {len(sims)} samples") 91 | print(f"SIM : {sim}") 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /f5_tts/eval/eval_utmos.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | import librosa 6 | import torch 7 | from tqdm import tqdm 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser(description="UTMOS Evaluation") 12 | parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.") 13 | parser.add_argument("--ext", type=str, default="wav", help="Audio extension.") 14 | args = parser.parse_args() 15 | 16 | device = "cuda" if torch.cuda.is_available() else "cpu" 17 | 18 | predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True) 19 | predictor = predictor.to(device) 20 | 21 | audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}")) 22 | utmos_results = {} 23 | utmos_score = 0 24 | 25 | for audio_path in tqdm(audio_paths, desc="Processing"): 26 | wav_name = audio_path.stem 27 | wav, sr = librosa.load(audio_path, sr=None, mono=True) 28 | wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0) 29 | score = predictor(wav_tensor, sr) 30 | utmos_results[str(wav_name)] = score.item() 31 | utmos_score += score.item() 32 | 33 | avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0 34 | print(f"UTMOS: {avg_score}") 35 | 36 | utmos_result_path = Path(args.audio_dir) / "utmos_results.json" 37 | with open(utmos_result_path, "w", encoding="utf-8") as f: 38 | 
json.dump(utmos_results, f, ensure_ascii=False, indent=4) 39 | 40 | print(f"Results have been saved to {utmos_result_path}") 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /f5_tts/eval/utils_eval.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import string 5 | from pathlib import Path 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchaudio 10 | from tqdm import tqdm 11 | 12 | from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL 13 | from f5_tts.model.modules import MelSpec 14 | from f5_tts.model.utils import convert_char_to_pinyin 15 | 16 | 17 | # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav 18 | def get_seedtts_testset_metainfo(metalst): 19 | f = open(metalst) 20 | lines = f.readlines() 21 | f.close() 22 | metainfo = [] 23 | for line in lines: 24 | if len(line.strip().split("|")) == 5: 25 | utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") 26 | elif len(line.strip().split("|")) == 4: 27 | utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") 28 | gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") 29 | if not os.path.isabs(prompt_wav): 30 | prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) 31 | metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) 32 | return metainfo 33 | 34 | 35 | # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav 36 | def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path): 37 | f = open(metalst) 38 | lines = f.readlines() 39 | f.close() 40 | metainfo = [] 41 | for line in lines: 42 | ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t") 43 | 44 | # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) 45 | ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") 46 | ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") 47 | 48 | # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' 
# if use librispeech test-clean (no-pc) 49 | gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") 50 | gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") 51 | 52 | metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav)) 53 | 54 | return metainfo 55 | 56 | 57 | # padded to max length mel batch 58 | def padded_mel_batch(ref_mels): 59 | max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() 60 | padded_ref_mels = [] 61 | for mel in ref_mels: 62 | padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) 63 | padded_ref_mels.append(padded_ref_mel) 64 | padded_ref_mels = torch.stack(padded_ref_mels) 65 | padded_ref_mels = padded_ref_mels.permute(0, 2, 1) 66 | return padded_ref_mels 67 | 68 | 69 | # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav 70 | 71 | 72 | def get_inference_prompt( 73 | metainfo, 74 | speed=1.0, 75 | tokenizer="pinyin", 76 | polyphone=True, 77 | target_sample_rate=24000, 78 | n_fft=1024, 79 | win_length=1024, 80 | n_mel_channels=100, 81 | hop_length=256, 82 | mel_spec_type="vocos", 83 | target_rms=0.1, 84 | use_truth_duration=False, 85 | infer_batch_size=1, 86 | num_buckets=200, 87 | min_secs=3, 88 | max_secs=40, 89 | ): 90 | prompts_all = [] 91 | 92 | min_tokens = min_secs * target_sample_rate // hop_length 93 | max_tokens = max_secs * target_sample_rate // hop_length 94 | 95 | batch_accum = [0] * num_buckets 96 | utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( 97 | [[] for _ in range(num_buckets)] for _ in range(6) 98 | ) 99 | 100 | mel_spectrogram = MelSpec( 101 | n_fft=n_fft, 102 | hop_length=hop_length, 103 | win_length=win_length, 104 | n_mel_channels=n_mel_channels, 105 | target_sample_rate=target_sample_rate, 106 | mel_spec_type=mel_spec_type, 107 | ) 108 | 109 | for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."): 110 | # Audio 111 | ref_audio, ref_sr = torchaudio.load(prompt_wav) 112 | ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) 113 | if ref_rms < target_rms: 114 | ref_audio = ref_audio * target_rms / ref_rms 115 | assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." 
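        # Bring the reference audio to the target sample rate (24 kHz by default) before mel extraction and duration estimation.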
116 | if ref_sr != target_sample_rate: 117 | resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) 118 | ref_audio = resampler(ref_audio) 119 | 120 | # Text 121 | if len(prompt_text[-1].encode("utf-8")) == 1: 122 | prompt_text = prompt_text + " " 123 | text = [prompt_text + gt_text] 124 | if tokenizer == "pinyin": 125 | text_list = convert_char_to_pinyin(text, polyphone=polyphone) 126 | else: 127 | text_list = text 128 | 129 | # Duration, mel frame length 130 | ref_mel_len = ref_audio.shape[-1] // hop_length 131 | if use_truth_duration: 132 | gt_audio, gt_sr = torchaudio.load(gt_wav) 133 | if gt_sr != target_sample_rate: 134 | resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) 135 | gt_audio = resampler(gt_audio) 136 | total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) 137 | 138 | # # test vocoder resynthesis 139 | # ref_audio = gt_audio 140 | else: 141 | ref_text_len = len(prompt_text.encode("utf-8")) 142 | gen_text_len = len(gt_text.encode("utf-8")) 143 | total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed) 144 | 145 | # to mel spectrogram 146 | ref_mel = mel_spectrogram(ref_audio) 147 | ref_mel = ref_mel.squeeze(0) 148 | 149 | # deal with batch 150 | assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 151 | assert ( 152 | min_tokens <= total_mel_len <= max_tokens 153 | ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." 154 | bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets) 155 | 156 | utts[bucket_i].append(utt) 157 | ref_rms_list[bucket_i].append(ref_rms) 158 | ref_mels[bucket_i].append(ref_mel) 159 | ref_mel_lens[bucket_i].append(ref_mel_len) 160 | total_mel_lens[bucket_i].append(total_mel_len) 161 | final_text_list[bucket_i].extend(text_list) 162 | 163 | batch_accum[bucket_i] += total_mel_len 164 | 165 | if batch_accum[bucket_i] >= infer_batch_size: 166 | # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}") 167 | prompts_all.append( 168 | ( 169 | utts[bucket_i], 170 | ref_rms_list[bucket_i], 171 | padded_mel_batch(ref_mels[bucket_i]), 172 | ref_mel_lens[bucket_i], 173 | total_mel_lens[bucket_i], 174 | final_text_list[bucket_i], 175 | ) 176 | ) 177 | batch_accum[bucket_i] = 0 178 | ( 179 | utts[bucket_i], 180 | ref_rms_list[bucket_i], 181 | ref_mels[bucket_i], 182 | ref_mel_lens[bucket_i], 183 | total_mel_lens[bucket_i], 184 | final_text_list[bucket_i], 185 | ) = [], [], [], [], [], [] 186 | 187 | # add residual 188 | for bucket_i, bucket_frames in enumerate(batch_accum): 189 | if bucket_frames > 0: 190 | prompts_all.append( 191 | ( 192 | utts[bucket_i], 193 | ref_rms_list[bucket_i], 194 | padded_mel_batch(ref_mels[bucket_i]), 195 | ref_mel_lens[bucket_i], 196 | total_mel_lens[bucket_i], 197 | final_text_list[bucket_i], 198 | ) 199 | ) 200 | # not only leave easy work for last workers 201 | random.seed(666) 202 | random.shuffle(prompts_all) 203 | 204 | return prompts_all 205 | 206 | 207 | # get wav_res_ref_text of seed-tts test metalst 208 | # https://github.com/BytedanceSpeech/seed-tts-eval 209 | 210 | 211 | def get_seed_tts_test(metalst, gen_wav_dir, gpus): 212 | f = open(metalst) 213 | lines = f.readlines() 214 | f.close() 215 | 216 | test_set_ = [] 217 | for line in tqdm(lines): 218 | if len(line.strip().split("|")) == 5: 219 | utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") 220 | elif 
len(line.strip().split("|")) == 4: 221 | utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") 222 | 223 | if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")): 224 | continue 225 | gen_wav = os.path.join(gen_wav_dir, utt + ".wav") 226 | if not os.path.isabs(prompt_wav): 227 | prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) 228 | 229 | test_set_.append((gen_wav, prompt_wav, gt_text)) 230 | 231 | num_jobs = len(gpus) 232 | if num_jobs == 1: 233 | return [(gpus[0], test_set_)] 234 | 235 | wav_per_job = len(test_set_) // num_jobs + 1 236 | test_set = [] 237 | for i in range(num_jobs): 238 | test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job])) 239 | 240 | return test_set 241 | 242 | 243 | # get librispeech test-clean cross sentence test 244 | 245 | 246 | def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False): 247 | f = open(metalst) 248 | lines = f.readlines() 249 | f.close() 250 | 251 | test_set_ = [] 252 | for line in tqdm(lines): 253 | ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t") 254 | 255 | if eval_ground_truth: 256 | gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") 257 | gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") 258 | else: 259 | if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")): 260 | raise FileNotFoundError(f"Generated wav not found: {gen_utt}") 261 | gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav") 262 | 263 | ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") 264 | ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") 265 | 266 | test_set_.append((gen_wav, ref_wav, gen_txt)) 267 | 268 | num_jobs = len(gpus) 269 | if num_jobs == 1: 270 | return [(gpus[0], test_set_)] 271 | 272 | wav_per_job = len(test_set_) // num_jobs + 1 273 | test_set = [] 274 | for i in range(num_jobs): 275 | test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job])) 276 | 277 | return test_set 278 | 279 | 280 | # load asr model 281 | 282 | 283 | def load_asr_model(lang, ckpt_dir=""): 284 | if lang == "zh": 285 | from funasr import AutoModel 286 | 287 | model = AutoModel( 288 | model=os.path.join(ckpt_dir, "paraformer-zh"), 289 | # vad_model = os.path.join(ckpt_dir, "fsmn-vad"), 290 | # punc_model = os.path.join(ckpt_dir, "ct-punc"), 291 | # spk_model = os.path.join(ckpt_dir, "cam++"), 292 | disable_update=True, 293 | ) # following seed-tts setting 294 | elif lang == "en": 295 | from faster_whisper import WhisperModel 296 | 297 | model_size = "large-v3" if ckpt_dir == "" else ckpt_dir 298 | model = WhisperModel(model_size, device="cuda", compute_type="float16") 299 | return model 300 | 301 | 302 | # WER Evaluation, the way Seed-TTS does 303 | 304 | 305 | def run_asr_wer(args): 306 | rank, lang, test_set, ckpt_dir = args 307 | 308 | if lang == "zh": 309 | import zhconv 310 | 311 | torch.cuda.set_device(rank) 312 | elif lang == "en": 313 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) 314 | else: 315 | raise NotImplementedError( 316 | "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now." 
317 | ) 318 | 319 | asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir) 320 | 321 | from zhon.hanzi import punctuation 322 | 323 | punctuation_all = punctuation + string.punctuation 324 | wer_results = [] 325 | 326 | from jiwer import compute_measures 327 | 328 | for gen_wav, prompt_wav, truth in tqdm(test_set): 329 | if lang == "zh": 330 | res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True) 331 | hypo = res[0]["text"] 332 | hypo = zhconv.convert(hypo, "zh-cn") 333 | elif lang == "en": 334 | segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en") 335 | hypo = "" 336 | for segment in segments: 337 | hypo = hypo + " " + segment.text 338 | 339 | raw_truth = truth 340 | raw_hypo = hypo 341 | 342 | for x in punctuation_all: 343 | truth = truth.replace(x, "") 344 | hypo = hypo.replace(x, "") 345 | 346 | truth = truth.replace(" ", " ") 347 | hypo = hypo.replace(" ", " ") 348 | 349 | if lang == "zh": 350 | truth = " ".join([x for x in truth]) 351 | hypo = " ".join([x for x in hypo]) 352 | elif lang == "en": 353 | truth = truth.lower() 354 | hypo = hypo.lower() 355 | 356 | measures = compute_measures(truth, hypo) 357 | wer = measures["wer"] 358 | 359 | # ref_list = truth.split(" ") 360 | # subs = measures["substitutions"] / len(ref_list) 361 | # dele = measures["deletions"] / len(ref_list) 362 | # inse = measures["insertions"] / len(ref_list) 363 | 364 | wer_results.append( 365 | { 366 | "wav": Path(gen_wav).stem, 367 | "truth": raw_truth, 368 | "hypo": raw_hypo, 369 | "wer": wer, 370 | } 371 | ) 372 | 373 | return wer_results 374 | 375 | 376 | # SIM Evaluation 377 | 378 | 379 | def run_sim(args): 380 | rank, test_set, ckpt_dir = args 381 | device = f"cuda:{rank}" 382 | 383 | model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None) 384 | state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage) 385 | model.load_state_dict(state_dict["model"], strict=False) 386 | 387 | use_gpu = True if torch.cuda.is_available() else False 388 | if use_gpu: 389 | model = model.cuda(device) 390 | model.eval() 391 | 392 | sims = [] 393 | for wav1, wav2, truth in tqdm(test_set): 394 | wav1, sr1 = torchaudio.load(wav1) 395 | wav2, sr2 = torchaudio.load(wav2) 396 | 397 | resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000) 398 | resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000) 399 | wav1 = resample1(wav1) 400 | wav2 = resample2(wav2) 401 | 402 | if use_gpu: 403 | wav1 = wav1.cuda(device) 404 | wav2 = wav2.cuda(device) 405 | with torch.no_grad(): 406 | emb1 = model(wav1) 407 | emb2 = model(wav2) 408 | 409 | sim = F.cosine_similarity(emb1, emb2)[0].item() 410 | # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).") 411 | sims.append(sim) 412 | 413 | return sims 414 | -------------------------------------------------------------------------------- /f5_tts/infer/README.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts. 
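If you prefer to fetch the files yourself (e.g. to later pass a local path via `--ckpt_file`), a minimal sketch using the `huggingface-cli` tool is shown below; the file names follow the `SWivid/F5-TTS` repository layout and the target directory is only an example.

```bash
huggingface-cli download SWivid/F5-TTS F5TTS_Base/model_1200000.safetensors F5TTS_Base/vocab.txt --local-dir ckpts
```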
4 | 
5 | **More checkpoints, contributed by the community and supporting more languages, can be found in [SHARED.md](SHARED.md).**
6 | 
7 | A single generation currently supports up to **30s**, which is the **total length** of both prompt and output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text; chunked generation will be performed automatically. Long reference audio will be **clipped to ~15s**.
8 | 
9 | To avoid possible inference failures, make sure you have read through the following instructions.
10 | 
11 | - Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncation in the middle of a word, leading to suboptimal generation.
12 | - Uppercase letters will be uttered letter by letter, so use lowercase letters for normal words.
13 | - Add some spaces (blank: " ") or punctuation (e.g. "," ".") to explicitly introduce pauses.
14 | - Convert numbers to Chinese characters if you want them read in Chinese; otherwise they will be read in English.
15 | - If the generation output is blank (pure silence), check your ffmpeg installation (various tutorials are available online: blogs, videos, etc.).
16 | - Try turning off `use_ema` if you are using an early-stage finetuned checkpoint (one trained for only a few updates).
17 | 
18 | 
19 | ## Gradio App
20 | 
21 | Currently supported features:
22 | 
23 | - Basic TTS with Chunk Inference
24 | - Multi-Style / Multi-Speaker Generation
25 | - Voice Chat powered by Qwen2.5-3B-Instruct
26 | 
27 | The cli command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`; it launches a Gradio app (web interface) for inference.
28 | 
29 | The script will load model checkpoints from Hugging Face. You can also manually download the files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at startup; the ASR model is loaded to transcribe the reference audio if `ref_text` is not provided, and the LLM is loaded when Voice Chat is used.
30 | 
31 | It can also be used as a component of a larger application.
32 | ```python
33 | import gradio as gr
34 | from f5_tts.infer.infer_gradio import app
35 | 
36 | with gr.Blocks() as main_app:
37 |     gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")
38 | 
39 |     # ... other Gradio components
40 | 
41 |     app.render()
42 | 
43 | main_app.launch()
44 | ```
45 | 
46 | 
47 | ## CLI Inference
48 | 
49 | The cli command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, a command line tool for inference.
50 | 
51 | The script will load model checkpoints from Hugging Face. You can also manually download files and use `--ckpt_file` to specify the model you want to load, or update the path directly in `infer_cli.py`.
52 | 
53 | To use a custom vocabulary, provide your `vocab.txt` file with `--vocab_file`.
54 | 
55 | Basically, you can run inference with flags:
56 | ```bash
57 | # Leaving --ref_text "" will have the ASR model transcribe the reference audio (extra GPU memory usage)
58 | f5-tts_infer-cli \
59 | --model "F5-TTS" \
60 | --ref_audio "ref_audio.wav" \
61 | --ref_text "The content, subtitle or transcription of reference audio." \
62 | --gen_text "Some text you want TTS model generate for you."
63 | 
64 | # Choose vocoder (pass the matching checkpoint path to --ckpt_file)
65 | f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <ckpt_path>
66 | f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <ckpt_path>
67 | 
68 | # More instructions
69 | f5-tts_infer-cli --help
70 | ```
71 | 
72 | A `.toml` file allows for more flexible usage.
73 | 74 | ```bash 75 | f5-tts_infer-cli -c custom.toml 76 | ``` 77 | 78 | For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`: 79 | 80 | ```toml 81 | # F5-TTS | E2-TTS 82 | model = "F5-TTS" 83 | ref_audio = "infer/examples/basic/basic_ref_en.wav" 84 | # If an empty "", transcribes the reference audio automatically. 85 | ref_text = "Some call me nature, others call me mother nature." 86 | gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring." 87 | # File with text to generate. Ignores the text above. 88 | gen_file = "" 89 | remove_silence = false 90 | output_dir = "tests" 91 | ``` 92 | 93 | You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`. 94 | 95 | ```toml 96 | # F5-TTS | E2-TTS 97 | model = "F5-TTS" 98 | ref_audio = "infer/examples/multi/main.flac" 99 | # If an empty "", transcribes the reference audio automatically. 100 | ref_text = "" 101 | gen_text = "" 102 | # File with text to generate. Ignores the text above. 103 | gen_file = "infer/examples/multi/story.txt" 104 | remove_silence = true 105 | output_dir = "tests" 106 | 107 | [voices.town] 108 | ref_audio = "infer/examples/multi/town.flac" 109 | ref_text = "" 110 | 111 | [voices.country] 112 | ref_audio = "infer/examples/multi/country.flac" 113 | ref_text = "" 114 | ``` 115 | You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`. 116 | 117 | ## Speech Editing 118 | 119 | To test speech editing capabilities, use the following command: 120 | 121 | ```bash 122 | python src/f5_tts/infer/speech_edit.py 123 | ``` 124 | 125 | ## Socket Realtime Client 126 | 127 | To communicate with socket server you need to run 128 | ```bash 129 | python src/f5_tts/socket_server.py 130 | ``` 131 | 132 |
133 | Then create client to communicate 134 | 135 | ``` python 136 | import socket 137 | import numpy as np 138 | import asyncio 139 | import pyaudio 140 | 141 | async def listen_to_voice(text, server_ip='localhost', server_port=9999): 142 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 143 | client_socket.connect((server_ip, server_port)) 144 | 145 | async def play_audio_stream(): 146 | buffer = b'' 147 | p = pyaudio.PyAudio() 148 | stream = p.open(format=pyaudio.paFloat32, 149 | channels=1, 150 | rate=24000, # Ensure this matches the server's sampling rate 151 | output=True, 152 | frames_per_buffer=2048) 153 | 154 | try: 155 | while True: 156 | chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024) 157 | if not chunk: # End of stream 158 | break 159 | if b"END_OF_AUDIO" in chunk: 160 | buffer += chunk.replace(b"END_OF_AUDIO", b"") 161 | if buffer: 162 | audio_array = np.frombuffer(buffer, dtype=np.float32).copy() # Make a writable copy 163 | stream.write(audio_array.tobytes()) 164 | break 165 | buffer += chunk 166 | if len(buffer) >= 4096: 167 | audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy() # Make a writable copy 168 | stream.write(audio_array.tobytes()) 169 | buffer = buffer[4096:] 170 | finally: 171 | stream.stop_stream() 172 | stream.close() 173 | p.terminate() 174 | 175 | try: 176 | # Send only the text to the server 177 | await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8')) 178 | await play_audio_stream() 179 | print("Audio playback finished.") 180 | 181 | except Exception as e: 182 | print(f"Error in listen_to_voice: {e}") 183 | 184 | finally: 185 | client_socket.close() 186 | 187 | # Example usage: Replace this with your actual server IP and port 188 | async def main(): 189 | await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998) 190 | 191 | # Run the main async function 192 | asyncio.run(main()) 193 | ``` 194 | 195 |
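Note that the client above expects raw 32-bit float PCM streamed at 24 kHz and treats `END_OF_AUDIO` as the end-of-stream marker; adjust the sampling rate and buffer sizes if your server is configured differently.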
196 | 197 | -------------------------------------------------------------------------------- /f5_tts/infer/SHARED.md: -------------------------------------------------------------------------------- 1 | 2 | # Shared Model Cards 3 | 4 | 5 | ### **Prerequisites of using** 6 | - This document is serving as a quick lookup table for the community training/finetuning result, with various language support. 7 | - The models in this repository are open source and are based on voluntary contributions from contributors. 8 | - The use of models must be conditioned on respect for the respective creators. The convenience brought comes from their efforts. 9 | 10 | 11 | ### **Welcome to share here** 12 | - Have a pretrained/finetuned result: model checkpoint (pruned best to facilitate inference, i.e. leave only `ema_model_state_dict`) and corresponding vocab file (for tokenization). 13 | - Host a public [huggingface model repository](https://huggingface.co/new) and upload the model related files. 14 | - Make a pull request adding a model card to the current page, i.e. `src\f5_tts\infer\SHARED.md`. 15 | 16 | 17 | ### Supported Languages 18 | - [Multilingual](#multilingual) 19 | - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts) 20 | - [English](#english) 21 | - [Finnish](#finnish) 22 | - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen) 23 | - [French](#french) 24 | - [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio) 25 | - [Hindi](#hindi) 26 | - [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab) 27 | - [Italian](#italian) 28 | - [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79) 29 | - [Japanese](#japanese) 30 | - [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica) 31 | - [Mandarin](#mandarin) 32 | - [Russian](#russian) 33 | - [F5-TTS Base @ ru @ HotDro4illa](#f5-tts-base--ru--hotdro4illa) 34 | - [Spanish](#spanish) 35 | - [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar) 36 | 37 | 38 | ## Multilingual 39 | 40 | #### F5-TTS Base @ zh & en @ F5-TTS 41 | |Model|🤗Hugging Face|Data (Hours)|Model License| 42 | |:---:|:------------:|:-----------:|:-------------:| 43 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0| 44 | 45 | ```bash 46 | Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors 47 | Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt 48 | Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} 49 | ``` 50 | 51 | *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) 
...* 52 | 53 | 54 | ## English 55 | 56 | 57 | ## Finnish 58 | 59 | #### F5-TTS Base @ fi @ AsmoKoskinen 60 | |Model|🤗Hugging Face|Data|Model License| 61 | |:---:|:------------:|:-----------:|:-------------:| 62 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0| 63 | 64 | ```bash 65 | Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors 66 | Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt 67 | Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} 68 | ``` 69 | 70 | 71 | ## French 72 | 73 | #### F5-TTS Base @ fr @ RASPIAUDIO 74 | |Model|🤗Hugging Face|Data (Hours)|Model License| 75 | |:---:|:------------:|:-----------:|:-------------:| 76 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0| 77 | 78 | ```bash 79 | Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt 80 | Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt 81 | Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} 82 | ``` 83 | 84 | - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french). 85 | - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys). 86 | - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434). 87 | 88 | 89 | ## Hindi 90 | 91 | #### F5-TTS Small @ hi @ SPRINGLab 92 | |Model|🤗Hugging Face|Data (Hours)|Model License| 93 | |:---:|:------------:|:-----------:|:-------------:| 94 | |F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0| 95 | 96 | ```bash 97 | Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors 98 | Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt 99 | Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} 100 | ``` 101 | 102 | - Authors: SPRING Lab, Indian Institute of Technology, Madras 103 | - Website: https://asr.iitm.ac.in/ 104 | 105 | 106 | ## Italian 107 | 108 | #### F5-TTS Base @ it @ alien79 109 | |Model|🤗Hugging Face|Data|Model License| 110 | |:---:|:------------:|:-----------:|:-------------:| 111 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0| 112 | 113 | ```bash 114 | Model: hf://alien79/F5-TTS-italian/model_159600.safetensors 115 | Vocab: hf://alien79/F5-TTS-italian/vocab.txt 116 | Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} 117 | ``` 118 | 119 | - Trained by [Mithril Man](https://github.com/MithrilMan) 120 | - Model details on [hf project home](https://huggingface.co/alien79/F5-TTS-italian) 121 | - Open to collaborations to further improve the model 122 | 123 | 124 | ## Japanese 125 | 126 | #### F5-TTS Base @ ja @ Jmica 127 | |Model|🤗Hugging Face|Data (Hours)|Model License| 128 | |:---:|:------------:|:-----------:|:-------------:| 129 | |F5-TTS Base|[ckpt & 
vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_25498980)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0| 130 | 131 | ```bash 132 | Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt 133 | Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt 134 | Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} 135 | ``` 136 | 137 | 138 | ## Mandarin 139 | 140 | 141 | ## Russian 142 | 143 | #### F5-TTS Base @ ru @ HotDro4illa 144 | |Model|🤗Hugging Face|Data (Hours)|Model License| 145 | |:---:|:------------:|:-----------:|:-------------:| 146 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/hotstone228/F5-TTS-Russian)|[Common voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0)|cc-by-nc-4.0| 147 | 148 | ```bash 149 | Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors 150 | Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt 151 | Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} 152 | ``` 153 | - Finetuned by [HotDro4illa](https://github.com/HotDro4illa) 154 | - Any improvements are welcome 155 | 156 | 157 | ## Spanish 158 | 159 | #### F5-TTS Base @ es @ jpgallegoar 160 | |Model|🤗Hugging Face|Data (Hours)|Model License| 161 | |:---:|:------------:|:-----------:|:-------------:| 162 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0| 163 | 164 | - @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model. 165 | -------------------------------------------------------------------------------- /f5_tts/infer/examples/basic/basic.toml: -------------------------------------------------------------------------------- 1 | # F5-TTS | E2-TTS 2 | model = "F5-TTS" 3 | ref_audio = "infer/examples/basic/basic_ref_en.wav" 4 | # If an empty "", transcribes the reference audio automatically. 5 | ref_text = "Some call me nature, others call me mother nature." 6 | gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring." 7 | # File with text to generate. Ignores the text above. 
8 | gen_file = "" 9 | remove_silence = false 10 | output_dir = "tests" 11 | output_file = "infer_cli_basic.wav" 12 | -------------------------------------------------------------------------------- /f5_tts/infer/examples/basic/basic_ref_en.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/f5-tts-api/d3c408cdde4e4343e5cbb7ba528f17596e4870ac/f5_tts/infer/examples/basic/basic_ref_en.wav -------------------------------------------------------------------------------- /f5_tts/infer/examples/basic/basic_ref_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/f5-tts-api/d3c408cdde4e4343e5cbb7ba528f17596e4870ac/f5_tts/infer/examples/basic/basic_ref_zh.wav -------------------------------------------------------------------------------- /f5_tts/infer/examples/multi/country.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/f5-tts-api/d3c408cdde4e4343e5cbb7ba528f17596e4870ac/f5_tts/infer/examples/multi/country.flac -------------------------------------------------------------------------------- /f5_tts/infer/examples/multi/main.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/f5-tts-api/d3c408cdde4e4343e5cbb7ba528f17596e4870ac/f5_tts/infer/examples/multi/main.flac -------------------------------------------------------------------------------- /f5_tts/infer/examples/multi/story.toml: -------------------------------------------------------------------------------- 1 | # F5-TTS | E2-TTS 2 | model = "F5-TTS" 3 | ref_audio = "infer/examples/multi/main.flac" 4 | # If an empty "", transcribes the reference audio automatically. 5 | ref_text = "" 6 | gen_text = "" 7 | # File with text to generate. Ignores the text above. 8 | gen_file = "infer/examples/multi/story.txt" 9 | remove_silence = true 10 | output_dir = "tests" 11 | output_file = "infer_cli_story.wav" 12 | 13 | [voices.town] 14 | ref_audio = "infer/examples/multi/town.flac" 15 | ref_text = "" 16 | 17 | [voices.country] 18 | ref_audio = "infer/examples/multi/country.flac" 19 | ref_text = "" 20 | 21 | -------------------------------------------------------------------------------- /f5_tts/infer/examples/multi/story.txt: -------------------------------------------------------------------------------- 1 | A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. 
Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.” -------------------------------------------------------------------------------- /f5_tts/infer/examples/multi/town.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianchang512/f5-tts-api/d3c408cdde4e4343e5cbb7ba528f17596e4870ac/f5_tts/infer/examples/multi/town.flac -------------------------------------------------------------------------------- /f5_tts/infer/infer_cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | import os 4 | import re 5 | from datetime import datetime 6 | from importlib.resources import files 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import soundfile as sf 11 | import tomli 12 | from cached_path import cached_path 13 | from omegaconf import OmegaConf 14 | 15 | from f5_tts.infer.utils_infer import ( 16 | mel_spec_type, 17 | target_rms, 18 | cross_fade_duration, 19 | nfe_step, 20 | cfg_strength, 21 | sway_sampling_coef, 22 | speed, 23 | fix_duration, 24 | infer_process, 25 | load_model, 26 | load_vocoder, 27 | preprocess_ref_audio_text, 28 | remove_silence_for_generated_wav, 29 | ) 30 | from f5_tts.model import DiT, UNetT 31 | 32 | 33 | parser = argparse.ArgumentParser( 34 | prog="python3 infer-cli.py", 35 | description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.", 36 | epilog="Specify options above to override one or more settings from config.", 37 | ) 38 | parser.add_argument( 39 | "-c", 40 | "--config", 41 | type=str, 42 | default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"), 43 | help="The configuration file, default see infer/examples/basic/basic.toml", 44 | ) 45 | 46 | 47 | # Note. 
Not to provide default value here in order to read default from config file 48 | 49 | parser.add_argument( 50 | "-m", 51 | "--model", 52 | type=str, 53 | help="The model name: F5-TTS | E2-TTS", 54 | ) 55 | parser.add_argument( 56 | "-mc", 57 | "--model_cfg", 58 | type=str, 59 | help="The path to F5-TTS model config file .yaml", 60 | ) 61 | parser.add_argument( 62 | "-p", 63 | "--ckpt_file", 64 | type=str, 65 | help="The path to model checkpoint .pt, leave blank to use default", 66 | ) 67 | parser.add_argument( 68 | "-v", 69 | "--vocab_file", 70 | type=str, 71 | help="The path to vocab file .txt, leave blank to use default", 72 | ) 73 | parser.add_argument( 74 | "-r", 75 | "--ref_audio", 76 | type=str, 77 | help="The reference audio file.", 78 | ) 79 | parser.add_argument( 80 | "-s", 81 | "--ref_text", 82 | type=str, 83 | help="The transcript/subtitle for the reference audio", 84 | ) 85 | parser.add_argument( 86 | "-t", 87 | "--gen_text", 88 | type=str, 89 | help="The text to make model synthesize a speech", 90 | ) 91 | parser.add_argument( 92 | "-f", 93 | "--gen_file", 94 | type=str, 95 | help="The file with text to generate, will ignore --gen_text", 96 | ) 97 | parser.add_argument( 98 | "-o", 99 | "--output_dir", 100 | type=str, 101 | help="The path to output folder", 102 | ) 103 | parser.add_argument( 104 | "-w", 105 | "--output_file", 106 | type=str, 107 | help="The name of output file", 108 | ) 109 | parser.add_argument( 110 | "--save_chunk", 111 | action="store_true", 112 | help="To save each audio chunks during inference", 113 | ) 114 | parser.add_argument( 115 | "--remove_silence", 116 | action="store_true", 117 | help="To remove long silence found in ouput", 118 | ) 119 | parser.add_argument( 120 | "--load_vocoder_from_local", 121 | action="store_true", 122 | help="To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz", 123 | ) 124 | parser.add_argument( 125 | "--vocoder_name", 126 | type=str, 127 | choices=["vocos", "bigvgan"], 128 | help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}", 129 | ) 130 | parser.add_argument( 131 | "--target_rms", 132 | type=float, 133 | help=f"Target output speech loudness normalization value, default {target_rms}", 134 | ) 135 | parser.add_argument( 136 | "--cross_fade_duration", 137 | type=float, 138 | help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}", 139 | ) 140 | parser.add_argument( 141 | "--nfe_step", 142 | type=int, 143 | help=f"The number of function evaluation (denoising steps), default {nfe_step}", 144 | ) 145 | parser.add_argument( 146 | "--cfg_strength", 147 | type=float, 148 | help=f"Classifier-free guidance strength, default {cfg_strength}", 149 | ) 150 | parser.add_argument( 151 | "--sway_sampling_coef", 152 | type=float, 153 | help=f"Sway Sampling coefficient, default {sway_sampling_coef}", 154 | ) 155 | parser.add_argument( 156 | "--speed", 157 | type=float, 158 | help=f"The speed of the generated audio, default {speed}", 159 | ) 160 | parser.add_argument( 161 | "--fix_duration", 162 | type=float, 163 | help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}", 164 | ) 165 | args = parser.parse_args() 166 | 167 | 168 | # config file 169 | 170 | config = tomli.load(open(args.config, "rb")) 171 | 172 | 173 | # command-line interface parameters 174 | 175 | model = args.model or config.get("model", "F5-TTS") 176 | model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml"))) 
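# Resolution order for the settings below: explicit CLI flag > value from the .toml config > package default (imported from utils_infer).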
177 | ckpt_file = args.ckpt_file or config.get("ckpt_file", "") 178 | vocab_file = args.vocab_file or config.get("vocab_file", "") 179 | 180 | ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav") 181 | ref_text = ( 182 | args.ref_text 183 | if args.ref_text is not None 184 | else config.get("ref_text", "Some call me nature, others call me mother nature.") 185 | ) 186 | gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.") 187 | gen_file = args.gen_file or config.get("gen_file", "") 188 | 189 | output_dir = args.output_dir or config.get("output_dir", "tests") 190 | output_file = args.output_file or config.get( 191 | "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav" 192 | ) 193 | 194 | save_chunk = args.save_chunk or config.get("save_chunk", False) 195 | remove_silence = args.remove_silence or config.get("remove_silence", False) 196 | load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False) 197 | 198 | vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type) 199 | target_rms = args.target_rms or config.get("target_rms", target_rms) 200 | cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration) 201 | nfe_step = args.nfe_step or config.get("nfe_step", nfe_step) 202 | cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength) 203 | sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef) 204 | speed = args.speed or config.get("speed", speed) 205 | fix_duration = args.fix_duration or config.get("fix_duration", fix_duration) 206 | 207 | 208 | # patches for pip pkg user 209 | if "infer/examples/" in ref_audio: 210 | ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}")) 211 | if "infer/examples/" in gen_file: 212 | gen_file = str(files("f5_tts").joinpath(f"{gen_file}")) 213 | if "voices" in config: 214 | for voice in config["voices"]: 215 | voice_ref_audio = config["voices"][voice]["ref_audio"] 216 | if "infer/examples/" in voice_ref_audio: 217 | config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}")) 218 | 219 | 220 | # ignore gen_text if gen_file provided 221 | 222 | if gen_file: 223 | gen_text = codecs.open(gen_file, "r", "utf-8").read() 224 | 225 | 226 | # output path 227 | 228 | wave_path = Path(output_dir) / output_file 229 | # spectrogram_path = Path(output_dir) / "infer_cli_out.png" 230 | if save_chunk: 231 | output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks") 232 | if not os.path.exists(output_chunk_dir): 233 | os.makedirs(output_chunk_dir) 234 | 235 | 236 | # load vocoder 237 | 238 | if vocoder_name == "vocos": 239 | vocoder_local_path = "../checkpoints/vocos-mel-24khz" 240 | elif vocoder_name == "bigvgan": 241 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" 242 | 243 | vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path) 244 | 245 | 246 | # load TTS model 247 | 248 | if model == "F5-TTS": 249 | model_cls = DiT 250 | model_cfg = OmegaConf.load(model_cfg).model.arch 251 | if not ckpt_file: # path not specified, download from repo 252 | if vocoder_name == "vocos": 253 | repo_name = "F5-TTS" 254 | exp_name = "F5TTS_Base" 255 | ckpt_step = 1200000 256 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) 257 | # 
ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path 258 | elif vocoder_name == "bigvgan": 259 | repo_name = "F5-TTS" 260 | exp_name = "F5TTS_Base_bigvgan" 261 | ckpt_step = 1250000 262 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")) 263 | 264 | elif model == "E2-TTS": 265 | assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet" 266 | assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet" 267 | model_cls = UNetT 268 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 269 | if not ckpt_file: # path not specified, download from repo 270 | repo_name = "E2-TTS" 271 | exp_name = "E2TTS_Base" 272 | ckpt_step = 1200000 273 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) 274 | # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path 275 | 276 | print(f"Using {model}...") 277 | ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) 278 | 279 | 280 | # inference process 281 | 282 | 283 | def main(): 284 | main_voice = {"ref_audio": ref_audio, "ref_text": ref_text} 285 | if "voices" not in config: 286 | voices = {"main": main_voice} 287 | else: 288 | voices = config["voices"] 289 | voices["main"] = main_voice 290 | for voice in voices: 291 | print("Voice:", voice) 292 | print("ref_audio ", voices[voice]["ref_audio"]) 293 | voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text( 294 | voices[voice]["ref_audio"], voices[voice]["ref_text"] 295 | ) 296 | print("ref_audio_", voices[voice]["ref_audio"], "\n\n") 297 | 298 | generated_audio_segments = [] 299 | reg1 = r"(?=\[\w+\])" 300 | chunks = re.split(reg1, gen_text) 301 | reg2 = r"\[(\w+)\]" 302 | for text in chunks: 303 | if not text.strip(): 304 | continue 305 | match = re.match(reg2, text) 306 | if match: 307 | voice = match[1] 308 | else: 309 | print("No voice tag found, using main.") 310 | voice = "main" 311 | if voice not in voices: 312 | print(f"Voice {voice} not found, using main.") 313 | voice = "main" 314 | text = re.sub(reg2, "", text) 315 | ref_audio_ = voices[voice]["ref_audio"] 316 | ref_text_ = voices[voice]["ref_text"] 317 | gen_text_ = text.strip() 318 | print(f"Voice: {voice}") 319 | audio_segment, final_sample_rate, spectragram = infer_process( 320 | ref_audio_, 321 | ref_text_, 322 | gen_text_, 323 | ema_model, 324 | vocoder, 325 | mel_spec_type=vocoder_name, 326 | target_rms=target_rms, 327 | cross_fade_duration=cross_fade_duration, 328 | nfe_step=nfe_step, 329 | cfg_strength=cfg_strength, 330 | sway_sampling_coef=sway_sampling_coef, 331 | speed=speed, 332 | fix_duration=fix_duration, 333 | ) 334 | generated_audio_segments.append(audio_segment) 335 | 336 | if save_chunk: 337 | if len(gen_text_) > 200: 338 | gen_text_ = gen_text_[:200] + " ... 
" 339 | sf.write( 340 | os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"), 341 | audio_segment, 342 | final_sample_rate, 343 | ) 344 | 345 | if generated_audio_segments: 346 | final_wave = np.concatenate(generated_audio_segments) 347 | 348 | if not os.path.exists(output_dir): 349 | os.makedirs(output_dir) 350 | 351 | with open(wave_path, "wb") as f: 352 | sf.write(f.name, final_wave, final_sample_rate) 353 | # Remove silence 354 | if remove_silence: 355 | remove_silence_for_generated_wav(f.name) 356 | print(f.name) 357 | 358 | 359 | if __name__ == "__main__": 360 | main() 361 | -------------------------------------------------------------------------------- /f5_tts/infer/speech_edit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torchaudio 8 | 9 | from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram 10 | from f5_tts.model import CFM, DiT, UNetT 11 | from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer 12 | 13 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 14 | 15 | 16 | # --------------------- Dataset Settings -------------------- # 17 | 18 | target_sample_rate = 24000 19 | n_mel_channels = 100 20 | hop_length = 256 21 | win_length = 1024 22 | n_fft = 1024 23 | mel_spec_type = "vocos" # 'vocos' or 'bigvgan' 24 | target_rms = 0.1 25 | 26 | tokenizer = "pinyin" 27 | dataset_name = "Emilia_ZH_EN" 28 | 29 | 30 | # ---------------------- infer setting ---------------------- # 31 | 32 | seed = None # int | None 33 | 34 | exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base 35 | ckpt_step = 1200000 36 | 37 | nfe_step = 32 # 16, 32 38 | cfg_strength = 2.0 39 | ode_method = "euler" # euler | midpoint 40 | sway_sampling_coef = -1.0 41 | speed = 1.0 42 | 43 | if exp_name == "F5TTS_Base": 44 | model_cls = DiT 45 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 46 | 47 | elif exp_name == "E2TTS_Base": 48 | model_cls = UNetT 49 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 50 | 51 | ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors" 52 | output_dir = "tests" 53 | 54 | # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment] 55 | # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git 56 | # [write the origin_text into a file, e.g. tests/test_edit.txt] 57 | # ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char" 58 | # [result will be saved at same path of audio file] 59 | # [--language "zho" for Chinese, "eng" for English] 60 | # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"] 61 | 62 | audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav" 63 | origin_text = "Some call me nature, others call me mother nature." 64 | target_text = "Some call me optimist, others call me realist." 
65 | parts_to_edit = [ 66 | [1.42, 2.44], 67 | [4.04, 4.9], 68 | ] # stard_ends of "nature" & "mother nature", in seconds 69 | fix_duration = [ 70 | 1.2, 71 | 1, 72 | ] # fix duration for "optimist" & "realist", in seconds 73 | 74 | # audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav" 75 | # origin_text = "对,这就是我,万人敬仰的太乙真人。" 76 | # target_text = "对,那就是你,万人敬仰的太白金星。" 77 | # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ] 78 | # fix_duration = None # use origin text duration 79 | 80 | 81 | # -------------------------------------------------# 82 | 83 | use_ema = True 84 | 85 | if not os.path.exists(output_dir): 86 | os.makedirs(output_dir) 87 | 88 | # Vocoder model 89 | local = False 90 | if mel_spec_type == "vocos": 91 | vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" 92 | elif mel_spec_type == "bigvgan": 93 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" 94 | vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path) 95 | 96 | # Tokenizer 97 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) 98 | 99 | # Model 100 | model = CFM( 101 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), 102 | mel_spec_kwargs=dict( 103 | n_fft=n_fft, 104 | hop_length=hop_length, 105 | win_length=win_length, 106 | n_mel_channels=n_mel_channels, 107 | target_sample_rate=target_sample_rate, 108 | mel_spec_type=mel_spec_type, 109 | ), 110 | odeint_kwargs=dict( 111 | method=ode_method, 112 | ), 113 | vocab_char_map=vocab_char_map, 114 | ).to(device) 115 | 116 | dtype = torch.float32 if mel_spec_type == "bigvgan" else None 117 | model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) 118 | 119 | # Audio 120 | audio, sr = torchaudio.load(audio_to_edit) 121 | if audio.shape[0] > 1: 122 | audio = torch.mean(audio, dim=0, keepdim=True) 123 | rms = torch.sqrt(torch.mean(torch.square(audio))) 124 | if rms < target_rms: 125 | audio = audio * target_rms / rms 126 | if sr != target_sample_rate: 127 | resampler = torchaudio.transforms.Resample(sr, target_sample_rate) 128 | audio = resampler(audio) 129 | offset = 0 130 | audio_ = torch.zeros(1, 0) 131 | edit_mask = torch.zeros(1, 0, dtype=torch.bool) 132 | for part in parts_to_edit: 133 | start, end = part 134 | part_dur = end - start if fix_duration is None else fix_duration.pop(0) 135 | part_dur = part_dur * target_sample_rate 136 | start = start * target_sample_rate 137 | audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1) 138 | edit_mask = torch.cat( 139 | ( 140 | edit_mask, 141 | torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool), 142 | torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool), 143 | ), 144 | dim=-1, 145 | ) 146 | offset = end * target_sample_rate 147 | # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1) 148 | edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True) 149 | audio = audio.to(device) 150 | edit_mask = edit_mask.to(device) 151 | 152 | # Text 153 | text_list = [target_text] 154 | if tokenizer == "pinyin": 155 | final_text_list = convert_char_to_pinyin(text_list) 156 | else: 157 | final_text_list = [text_list] 158 | print(f"text : {text_list}") 159 | print(f"pinyin: {final_text_list}") 160 | 161 | # Duration 162 | ref_audio_len = 0 163 | duration = audio.shape[-1] // hop_length 164 | 165 | # Inference 166 | with 
torch.inference_mode(): 167 | generated, trajectory = model.sample( 168 | cond=audio, 169 | text=final_text_list, 170 | duration=duration, 171 | steps=nfe_step, 172 | cfg_strength=cfg_strength, 173 | sway_sampling_coef=sway_sampling_coef, 174 | seed=seed, 175 | edit_mask=edit_mask, 176 | ) 177 | print(f"Generated mel: {generated.shape}") 178 | 179 | # Final result 180 | generated = generated.to(torch.float32) 181 | generated = generated[:, ref_audio_len:, :] 182 | gen_mel_spec = generated.permute(0, 2, 1) 183 | if mel_spec_type == "vocos": 184 | generated_wave = vocoder.decode(gen_mel_spec).cpu() 185 | elif mel_spec_type == "bigvgan": 186 | generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() 187 | 188 | if rms < target_rms: 189 | generated_wave = generated_wave * rms / target_rms 190 | 191 | save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png") 192 | torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate) 193 | print(f"Generated wav: {generated_wave.shape}") 194 | -------------------------------------------------------------------------------- /f5_tts/model/__init__.py: -------------------------------------------------------------------------------- 1 | from f5_tts.model.cfm import CFM 2 | 3 | from f5_tts.model.backbones.unett import UNetT 4 | from f5_tts.model.backbones.dit import DiT 5 | from f5_tts.model.backbones.mmdit import MMDiT 6 | 7 | from f5_tts.model.trainer import Trainer 8 | 9 | 10 | __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"] 11 | -------------------------------------------------------------------------------- /f5_tts/model/backbones/README.md: -------------------------------------------------------------------------------- 1 | ## Backbones quick introduction 2 | 3 | 4 | ### unett.py 5 | - flat unet transformer 6 | - structure same as in e2-tts & voicebox paper except using rotary pos emb 7 | - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat 8 | 9 | ### dit.py 10 | - adaln-zero dit 11 | - embedded timestep as condition 12 | - concatted noised_input + masked_cond + embedded_text, linear proj in 13 | - possible abs pos emb & convnextv2 blocks for embedded text before concat 14 | - possible long skip connection (first layer to last layer) 15 | 16 | ### mmdit.py 17 | - sd3 structure 18 | - timestep as condition 19 | - left stream: text embedded and applied a abs pos emb 20 | - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett 21 | -------------------------------------------------------------------------------- /f5_tts/model/backbones/dit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | 16 | from x_transformers.x_transformers import RotaryEmbedding 17 | 18 | from f5_tts.model.modules import ( 19 | TimestepEmbedding, 20 | ConvNeXtV2Block, 21 | ConvPositionEmbedding, 22 | DiTBlock, 23 | AdaLayerNormZero_Final, 24 | precompute_freqs_cis, 25 | get_pos_embed_indices, 26 | ) 27 | 28 | 29 | # Text embedding 30 | 31 | 32 | class TextEmbedding(nn.Module): 33 | def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): 34 | super().__init__() 35 | self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler 
token 36 | 37 | if conv_layers > 0: 38 | self.extra_modeling = True 39 | self.precompute_max_pos = 4096 # ~44s of 24khz audio 40 | self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) 41 | self.text_blocks = nn.Sequential( 42 | *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] 43 | ) 44 | else: 45 | self.extra_modeling = False 46 | 47 | def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 48 | text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() 49 | text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens 50 | batch, text_len = text.shape[0], text.shape[1] 51 | text = F.pad(text, (0, seq_len - text_len), value=0) 52 | 53 | if drop_text: # cfg for text 54 | text = torch.zeros_like(text) 55 | 56 | text = self.text_embed(text) # b n -> b n d 57 | 58 | # possible extra modeling 59 | if self.extra_modeling: 60 | # sinus pos emb 61 | batch_start = torch.zeros((batch,), dtype=torch.long) 62 | pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) 63 | text_pos_embed = self.freqs_cis[pos_idx] 64 | text = text + text_pos_embed 65 | 66 | # convnextv2 blocks 67 | text = self.text_blocks(text) 68 | 69 | return text 70 | 71 | 72 | # noised input audio and context mixing embedding 73 | 74 | 75 | class InputEmbedding(nn.Module): 76 | def __init__(self, mel_dim, text_dim, out_dim): 77 | super().__init__() 78 | self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) 79 | self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) 80 | 81 | def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 82 | if drop_audio_cond: # cfg for cond audio 83 | cond = torch.zeros_like(cond) 84 | 85 | x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) 86 | x = self.conv_pos_embed(x) + x 87 | return x 88 | 89 | 90 | # Transformer backbone using DiT blocks 91 | 92 | 93 | class DiT(nn.Module): 94 | def __init__( 95 | self, 96 | *, 97 | dim, 98 | depth=8, 99 | heads=8, 100 | dim_head=64, 101 | dropout=0.1, 102 | ff_mult=4, 103 | mel_dim=100, 104 | text_num_embeds=256, 105 | text_dim=None, 106 | conv_layers=0, 107 | long_skip_connection=False, 108 | checkpoint_activations=False, 109 | ): 110 | super().__init__() 111 | 112 | self.time_embed = TimestepEmbedding(dim) 113 | if text_dim is None: 114 | text_dim = mel_dim 115 | self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) 116 | self.input_embed = InputEmbedding(mel_dim, text_dim, dim) 117 | 118 | self.rotary_embed = RotaryEmbedding(dim_head) 119 | 120 | self.dim = dim 121 | self.depth = depth 122 | 123 | self.transformer_blocks = nn.ModuleList( 124 | [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)] 125 | ) 126 | self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None 127 | 128 | self.norm_out = AdaLayerNormZero_Final(dim) # final modulation 129 | self.proj_out = nn.Linear(dim, mel_dim) 130 | 131 | self.checkpoint_activations = checkpoint_activations 132 | 133 | def ckpt_wrapper(self, module): 134 | # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py 135 | def ckpt_forward(*inputs): 136 | outputs = module(*inputs) 137 | return outputs 138 | 139 | return ckpt_forward 140 | 141 | def forward( 142 | self, 143 | x: float["b n d"], # nosied input audio # noqa: F722 
144 | cond: float["b n d"], # masked cond audio # noqa: F722 145 | text: int["b nt"], # text # noqa: F722 146 | time: float["b"] | float[""], # time step # noqa: F821 F722 147 | drop_audio_cond, # cfg for cond audio 148 | drop_text, # cfg for text 149 | mask: bool["b n"] | None = None, # noqa: F722 150 | ): 151 | batch, seq_len = x.shape[0], x.shape[1] 152 | if time.ndim == 0: 153 | time = time.repeat(batch) 154 | 155 | # t: conditioning time, c: context (text + masked cond audio), x: noised input audio 156 | t = self.time_embed(time) 157 | text_embed = self.text_embed(text, seq_len, drop_text=drop_text) 158 | x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) 159 | 160 | rope = self.rotary_embed.forward_from_seq_len(seq_len) 161 | 162 | if self.long_skip_connection is not None: 163 | residual = x 164 | 165 | for block in self.transformer_blocks: 166 | if self.checkpoint_activations: 167 | x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope) 168 | else: 169 | x = block(x, t, mask=mask, rope=rope) 170 | 171 | if self.long_skip_connection is not None: 172 | x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) 173 | 174 | x = self.norm_out(x, t) 175 | output = self.proj_out(x) 176 | 177 | return output 178 | -------------------------------------------------------------------------------- /f5_tts/model/backbones/mmdit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from x_transformers.x_transformers import RotaryEmbedding 16 | 17 | from f5_tts.model.modules import ( 18 | TimestepEmbedding, 19 | ConvPositionEmbedding, 20 | MMDiTBlock, 21 | AdaLayerNormZero_Final, 22 | precompute_freqs_cis, 23 | get_pos_embed_indices, 24 | ) 25 | 26 | 27 | # text embedding 28 | 29 | 30 | class TextEmbedding(nn.Module): 31 | def __init__(self, out_dim, text_num_embeds): 32 | super().__init__() 33 | self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token 34 | 35 | self.precompute_max_pos = 1024 36 | self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) 37 | 38 | def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 39 | text = text + 1 40 | if drop_text: 41 | text = torch.zeros_like(text) 42 | text = self.text_embed(text) 43 | 44 | # sinus pos emb 45 | batch_start = torch.zeros((text.shape[0],), dtype=torch.long) 46 | batch_text_len = text.shape[1] 47 | pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos) 48 | text_pos_embed = self.freqs_cis[pos_idx] 49 | 50 | text = text + text_pos_embed 51 | 52 | return text 53 | 54 | 55 | # noised input & masked cond audio embedding 56 | 57 | 58 | class AudioEmbedding(nn.Module): 59 | def __init__(self, in_dim, out_dim): 60 | super().__init__() 61 | self.linear = nn.Linear(2 * in_dim, out_dim) 62 | self.conv_pos_embed = ConvPositionEmbedding(out_dim) 63 | 64 | def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722 65 | if drop_audio_cond: 66 | cond = torch.zeros_like(cond) 67 | x = torch.cat((x, cond), dim=-1) 68 | x = self.linear(x) 69 | x = self.conv_pos_embed(x) + x 70 | return x 71 | 72 | 73 | # Transformer backbone using MM-DiT blocks 74 | 75 | 76 
| class MMDiT(nn.Module): 77 | def __init__( 78 | self, 79 | *, 80 | dim, 81 | depth=8, 82 | heads=8, 83 | dim_head=64, 84 | dropout=0.1, 85 | ff_mult=4, 86 | text_num_embeds=256, 87 | mel_dim=100, 88 | ): 89 | super().__init__() 90 | 91 | self.time_embed = TimestepEmbedding(dim) 92 | self.text_embed = TextEmbedding(dim, text_num_embeds) 93 | self.audio_embed = AudioEmbedding(mel_dim, dim) 94 | 95 | self.rotary_embed = RotaryEmbedding(dim_head) 96 | 97 | self.dim = dim 98 | self.depth = depth 99 | 100 | self.transformer_blocks = nn.ModuleList( 101 | [ 102 | MMDiTBlock( 103 | dim=dim, 104 | heads=heads, 105 | dim_head=dim_head, 106 | dropout=dropout, 107 | ff_mult=ff_mult, 108 | context_pre_only=i == depth - 1, 109 | ) 110 | for i in range(depth) 111 | ] 112 | ) 113 | self.norm_out = AdaLayerNormZero_Final(dim) # final modulation 114 | self.proj_out = nn.Linear(dim, mel_dim) 115 | 116 | def forward( 117 | self, 118 | x: float["b n d"], # nosied input audio # noqa: F722 119 | cond: float["b n d"], # masked cond audio # noqa: F722 120 | text: int["b nt"], # text # noqa: F722 121 | time: float["b"] | float[""], # time step # noqa: F821 F722 122 | drop_audio_cond, # cfg for cond audio 123 | drop_text, # cfg for text 124 | mask: bool["b n"] | None = None, # noqa: F722 125 | ): 126 | batch = x.shape[0] 127 | if time.ndim == 0: 128 | time = time.repeat(batch) 129 | 130 | # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio 131 | t = self.time_embed(time) 132 | c = self.text_embed(text, drop_text=drop_text) 133 | x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) 134 | 135 | seq_len = x.shape[1] 136 | text_len = text.shape[1] 137 | rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) 138 | rope_text = self.rotary_embed.forward_from_seq_len(text_len) 139 | 140 | for block in self.transformer_blocks: 141 | c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text) 142 | 143 | x = self.norm_out(x, t) 144 | output = self.proj_out(x) 145 | 146 | return output 147 | -------------------------------------------------------------------------------- /f5_tts/model/backbones/unett.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | from typing import Literal 12 | 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | 17 | from x_transformers import RMSNorm 18 | from x_transformers.x_transformers import RotaryEmbedding 19 | 20 | from f5_tts.model.modules import ( 21 | TimestepEmbedding, 22 | ConvNeXtV2Block, 23 | ConvPositionEmbedding, 24 | Attention, 25 | AttnProcessor, 26 | FeedForward, 27 | precompute_freqs_cis, 28 | get_pos_embed_indices, 29 | ) 30 | 31 | 32 | # Text embedding 33 | 34 | 35 | class TextEmbedding(nn.Module): 36 | def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): 37 | super().__init__() 38 | self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token 39 | 40 | if conv_layers > 0: 41 | self.extra_modeling = True 42 | self.precompute_max_pos = 4096 # ~44s of 24khz audio 43 | self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) 44 | self.text_blocks = nn.Sequential( 45 | *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] 46 | ) 47 | else: 48 | self.extra_modeling 
= False 49 | 50 | def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 51 | text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() 52 | text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens 53 | batch, text_len = text.shape[0], text.shape[1] 54 | text = F.pad(text, (0, seq_len - text_len), value=0) 55 | 56 | if drop_text: # cfg for text 57 | text = torch.zeros_like(text) 58 | 59 | text = self.text_embed(text) # b n -> b n d 60 | 61 | # possible extra modeling 62 | if self.extra_modeling: 63 | # sinus pos emb 64 | batch_start = torch.zeros((batch,), dtype=torch.long) 65 | pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) 66 | text_pos_embed = self.freqs_cis[pos_idx] 67 | text = text + text_pos_embed 68 | 69 | # convnextv2 blocks 70 | text = self.text_blocks(text) 71 | 72 | return text 73 | 74 | 75 | # noised input audio and context mixing embedding 76 | 77 | 78 | class InputEmbedding(nn.Module): 79 | def __init__(self, mel_dim, text_dim, out_dim): 80 | super().__init__() 81 | self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) 82 | self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) 83 | 84 | def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 85 | if drop_audio_cond: # cfg for cond audio 86 | cond = torch.zeros_like(cond) 87 | 88 | x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) 89 | x = self.conv_pos_embed(x) + x 90 | return x 91 | 92 | 93 | # Flat UNet Transformer backbone 94 | 95 | 96 | class UNetT(nn.Module): 97 | def __init__( 98 | self, 99 | *, 100 | dim, 101 | depth=8, 102 | heads=8, 103 | dim_head=64, 104 | dropout=0.1, 105 | ff_mult=4, 106 | mel_dim=100, 107 | text_num_embeds=256, 108 | text_dim=None, 109 | conv_layers=0, 110 | skip_connect_type: Literal["add", "concat", "none"] = "concat", 111 | ): 112 | super().__init__() 113 | assert depth % 2 == 0, "UNet-Transformer's depth should be even." 
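        # note: the first depth // 2 layers push their outputs onto a skip stack and the later half
        # pops them back in (added, or concatenated then projected), hence the even-depth requirement;
        # see forward() below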
114 | 115 | self.time_embed = TimestepEmbedding(dim) 116 | if text_dim is None: 117 | text_dim = mel_dim 118 | self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) 119 | self.input_embed = InputEmbedding(mel_dim, text_dim, dim) 120 | 121 | self.rotary_embed = RotaryEmbedding(dim_head) 122 | 123 | # transformer layers & skip connections 124 | 125 | self.dim = dim 126 | self.skip_connect_type = skip_connect_type 127 | needs_skip_proj = skip_connect_type == "concat" 128 | 129 | self.depth = depth 130 | self.layers = nn.ModuleList([]) 131 | 132 | for idx in range(depth): 133 | is_later_half = idx >= (depth // 2) 134 | 135 | attn_norm = RMSNorm(dim) 136 | attn = Attention( 137 | processor=AttnProcessor(), 138 | dim=dim, 139 | heads=heads, 140 | dim_head=dim_head, 141 | dropout=dropout, 142 | ) 143 | 144 | ff_norm = RMSNorm(dim) 145 | ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") 146 | 147 | skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None 148 | 149 | self.layers.append( 150 | nn.ModuleList( 151 | [ 152 | skip_proj, 153 | attn_norm, 154 | attn, 155 | ff_norm, 156 | ff, 157 | ] 158 | ) 159 | ) 160 | 161 | self.norm_out = RMSNorm(dim) 162 | self.proj_out = nn.Linear(dim, mel_dim) 163 | 164 | def forward( 165 | self, 166 | x: float["b n d"], # nosied input audio # noqa: F722 167 | cond: float["b n d"], # masked cond audio # noqa: F722 168 | text: int["b nt"], # text # noqa: F722 169 | time: float["b"] | float[""], # time step # noqa: F821 F722 170 | drop_audio_cond, # cfg for cond audio 171 | drop_text, # cfg for text 172 | mask: bool["b n"] | None = None, # noqa: F722 173 | ): 174 | batch, seq_len = x.shape[0], x.shape[1] 175 | if time.ndim == 0: 176 | time = time.repeat(batch) 177 | 178 | # t: conditioning time, c: context (text + masked cond audio), x: noised input audio 179 | t = self.time_embed(time) 180 | text_embed = self.text_embed(text, seq_len, drop_text=drop_text) 181 | x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) 182 | 183 | # postfix time t to input x, [b n d] -> [b n+1 d] 184 | x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x 185 | if mask is not None: 186 | mask = F.pad(mask, (1, 0), value=1) 187 | 188 | rope = self.rotary_embed.forward_from_seq_len(seq_len + 1) 189 | 190 | # flat unet transformer 191 | skip_connect_type = self.skip_connect_type 192 | skips = [] 193 | for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers): 194 | layer = idx + 1 195 | 196 | # skip connection logic 197 | is_first_half = layer <= (self.depth // 2) 198 | is_later_half = not is_first_half 199 | 200 | if is_first_half: 201 | skips.append(x) 202 | 203 | if is_later_half: 204 | skip = skips.pop() 205 | if skip_connect_type == "concat": 206 | x = torch.cat((x, skip), dim=-1) 207 | x = maybe_skip_proj(x) 208 | elif skip_connect_type == "add": 209 | x = x + skip 210 | 211 | # attention and feedforward blocks 212 | x = attn(attn_norm(x), rope=rope, mask=mask) + x 213 | x = ff(ff_norm(x)) + x 214 | 215 | assert len(skips) == 0 216 | 217 | x = self.norm_out(x)[:, 1:, :] # unpack t from x 218 | 219 | return self.proj_out(x) 220 | -------------------------------------------------------------------------------- /f5_tts/model/cfm.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 
| from __future__ import annotations 11 | 12 | from random import random 13 | from typing import Callable 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from torch import nn 18 | from torch.nn.utils.rnn import pad_sequence 19 | from torchdiffeq import odeint 20 | 21 | from f5_tts.model.modules import MelSpec 22 | from f5_tts.model.utils import ( 23 | default, 24 | exists, 25 | lens_to_mask, 26 | list_str_to_idx, 27 | list_str_to_tensor, 28 | mask_from_frac_lengths, 29 | ) 30 | 31 | 32 | class CFM(nn.Module): 33 | def __init__( 34 | self, 35 | transformer: nn.Module, 36 | sigma=0.0, 37 | odeint_kwargs: dict = dict( 38 | # atol = 1e-5, 39 | # rtol = 1e-5, 40 | method="euler" # 'midpoint' 41 | ), 42 | audio_drop_prob=0.3, 43 | cond_drop_prob=0.2, 44 | num_channels=None, 45 | mel_spec_module: nn.Module | None = None, 46 | mel_spec_kwargs: dict = dict(), 47 | frac_lengths_mask: tuple[float, float] = (0.7, 1.0), 48 | vocab_char_map: dict[str:int] | None = None, 49 | ): 50 | super().__init__() 51 | 52 | self.frac_lengths_mask = frac_lengths_mask 53 | 54 | # mel spec 55 | self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) 56 | num_channels = default(num_channels, self.mel_spec.n_mel_channels) 57 | self.num_channels = num_channels 58 | 59 | # classifier-free guidance 60 | self.audio_drop_prob = audio_drop_prob 61 | self.cond_drop_prob = cond_drop_prob 62 | 63 | # transformer 64 | self.transformer = transformer 65 | dim = transformer.dim 66 | self.dim = dim 67 | 68 | # conditional flow related 69 | self.sigma = sigma 70 | 71 | # sampling related 72 | self.odeint_kwargs = odeint_kwargs 73 | 74 | # vocab map for tokenization 75 | self.vocab_char_map = vocab_char_map 76 | 77 | @property 78 | def device(self): 79 | return next(self.parameters()).device 80 | 81 | @torch.no_grad() 82 | def sample( 83 | self, 84 | cond: float["b n d"] | float["b nw"], # noqa: F722 85 | text: int["b nt"] | list[str], # noqa: F722 86 | duration: int | int["b"], # noqa: F821 87 | *, 88 | lens: int["b"] | None = None, # noqa: F821 89 | steps=32, 90 | cfg_strength=1.0, 91 | sway_sampling_coef=None, 92 | seed: int | None = None, 93 | max_duration=4096, 94 | vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 95 | no_ref_audio=False, 96 | duplicate_test=False, 97 | t_inter=0.1, 98 | edit_mask=None, 99 | ): 100 | self.eval() 101 | # raw wave 102 | 103 | if cond.ndim == 2: 104 | cond = self.mel_spec(cond) 105 | cond = cond.permute(0, 2, 1) 106 | assert cond.shape[-1] == self.num_channels 107 | 108 | cond = cond.to(next(self.parameters()).dtype) 109 | 110 | batch, cond_seq_len, device = *cond.shape[:2], cond.device 111 | if not exists(lens): 112 | lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) 113 | 114 | # text 115 | 116 | if isinstance(text, list): 117 | if exists(self.vocab_char_map): 118 | text = list_str_to_idx(text, self.vocab_char_map).to(device) 119 | else: 120 | text = list_str_to_tensor(text).to(device) 121 | assert text.shape[0] == batch 122 | 123 | if exists(text): 124 | text_lens = (text != -1).sum(dim=-1) 125 | lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters 126 | 127 | # duration 128 | 129 | cond_mask = lens_to_mask(lens) 130 | if edit_mask is not None: 131 | cond_mask = cond_mask & edit_mask 132 | 133 | if isinstance(duration, int): 134 | duration = torch.full((batch,), duration, device=device, dtype=torch.long) 135 | 136 | duration = torch.maximum(lens + 1, duration) # just 
add one token so something is generated 137 | duration = duration.clamp(max=max_duration) 138 | max_duration = duration.amax() 139 | 140 | # duplicate test corner for inner time step oberservation 141 | if duplicate_test: 142 | test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) 143 | 144 | cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) 145 | cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) 146 | cond_mask = cond_mask.unsqueeze(-1) 147 | step_cond = torch.where( 148 | cond_mask, cond, torch.zeros_like(cond) 149 | ) # allow direct control (cut cond audio) with lens passed in 150 | 151 | if batch > 1: 152 | mask = lens_to_mask(duration) 153 | else: # save memory and speed up, as single inference need no mask currently 154 | mask = None 155 | 156 | # test for no ref audio 157 | if no_ref_audio: 158 | cond = torch.zeros_like(cond) 159 | 160 | # neural ode 161 | 162 | def fn(t, x): 163 | # at each step, conditioning is fixed 164 | # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) 165 | 166 | # predict flow 167 | pred = self.transformer( 168 | x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False 169 | ) 170 | if cfg_strength < 1e-5: 171 | return pred 172 | 173 | null_pred = self.transformer( 174 | x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True 175 | ) 176 | return pred + (pred - null_pred) * cfg_strength 177 | 178 | # noise input 179 | # to make sure batch inference result is same with different batch size, and for sure single inference 180 | # still some difference maybe due to convolutional layers 181 | y0 = [] 182 | for dur in duration: 183 | if exists(seed): 184 | torch.manual_seed(seed) 185 | y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) 186 | y0 = pad_sequence(y0, padding_value=0, batch_first=True) 187 | 188 | t_start = 0 189 | 190 | # duplicate test corner for inner time step oberservation 191 | if duplicate_test: 192 | t_start = t_inter 193 | y0 = (1 - t_start) * y0 + t_start * test_cond 194 | steps = int(steps * (1 - t_start)) 195 | 196 | t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) 197 | if sway_sampling_coef is not None: 198 | t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) 199 | 200 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs) 201 | 202 | sampled = trajectory[-1] 203 | out = sampled 204 | out = torch.where(cond_mask, cond, out) 205 | 206 | if exists(vocoder): 207 | out = out.permute(0, 2, 1) 208 | out = vocoder(out) 209 | 210 | return out, trajectory 211 | 212 | def forward( 213 | self, 214 | inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 215 | text: int["b nt"] | list[str], # noqa: F722 216 | *, 217 | lens: int["b"] | None = None, # noqa: F821 218 | noise_scheduler: str | None = None, 219 | ): 220 | # handle raw wave 221 | if inp.ndim == 2: 222 | inp = self.mel_spec(inp) 223 | inp = inp.permute(0, 2, 1) 224 | assert inp.shape[-1] == self.num_channels 225 | 226 | batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma 227 | 228 | # handle text as string 229 | if isinstance(text, list): 230 | if exists(self.vocab_char_map): 231 | text = list_str_to_idx(text, self.vocab_char_map).to(device) 232 | else: 233 | text = list_str_to_tensor(text).to(device) 234 | assert text.shape[0] == batch 235 | 236 | # lens and mask 237 | if not 
exists(lens): 238 | lens = torch.full((batch,), seq_len, device=device) 239 | 240 | mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch 241 | 242 | # get a random span to mask out for training conditionally 243 | frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) 244 | rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) 245 | 246 | if exists(mask): 247 | rand_span_mask &= mask 248 | 249 | # mel is x1 250 | x1 = inp 251 | 252 | # x0 is gaussian noise 253 | x0 = torch.randn_like(x1) 254 | 255 | # time step 256 | time = torch.rand((batch,), dtype=dtype, device=self.device) 257 | # TODO. noise_scheduler 258 | 259 | # sample xt (φ_t(x) in the paper) 260 | t = time.unsqueeze(-1).unsqueeze(-1) 261 | φ = (1 - t) * x0 + t * x1 262 | flow = x1 - x0 263 | 264 | # only predict what is within the random mask span for infilling 265 | cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) 266 | 267 | # transformer and cfg training with a drop rate 268 | drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper 269 | if random() < self.cond_drop_prob: # p_uncond in voicebox paper 270 | drop_audio_cond = True 271 | drop_text = True 272 | else: 273 | drop_text = False 274 | 275 | # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here 276 | # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences 277 | pred = self.transformer( 278 | x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text 279 | ) 280 | 281 | # flow matching loss 282 | loss = F.mse_loss(pred, flow, reduction="none") 283 | loss = loss[rand_span_mask] 284 | 285 | return loss.mean(), cond, pred 286 | -------------------------------------------------------------------------------- /f5_tts/model/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from importlib.resources import files 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torchaudio 8 | from datasets import Dataset as Dataset_ 9 | from datasets import load_from_disk 10 | from torch import nn 11 | from torch.utils.data import Dataset, Sampler 12 | from tqdm import tqdm 13 | 14 | from f5_tts.model.modules import MelSpec 15 | from f5_tts.model.utils import default 16 | 17 | 18 | class HFDataset(Dataset): 19 | def __init__( 20 | self, 21 | hf_dataset: Dataset, 22 | target_sample_rate=24_000, 23 | n_mel_channels=100, 24 | hop_length=256, 25 | n_fft=1024, 26 | win_length=1024, 27 | mel_spec_type="vocos", 28 | ): 29 | self.data = hf_dataset 30 | self.target_sample_rate = target_sample_rate 31 | self.hop_length = hop_length 32 | 33 | self.mel_spectrogram = MelSpec( 34 | n_fft=n_fft, 35 | hop_length=hop_length, 36 | win_length=win_length, 37 | n_mel_channels=n_mel_channels, 38 | target_sample_rate=target_sample_rate, 39 | mel_spec_type=mel_spec_type, 40 | ) 41 | 42 | def get_frame_len(self, index): 43 | row = self.data[index] 44 | audio = row["audio"]["array"] 45 | sample_rate = row["audio"]["sampling_rate"] 46 | return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length 47 | 48 | def __len__(self): 49 | return len(self.data) 50 | 51 | def __getitem__(self, index): 52 | row = self.data[index] 53 | audio = row["audio"]["array"] 54 | 55 | # logger.info(f"Audio shape: {audio.shape}") 56 | 57 | sample_rate = 
row["audio"]["sampling_rate"] 58 | duration = audio.shape[-1] / sample_rate 59 | 60 | if duration > 30 or duration < 0.3: 61 | return self.__getitem__((index + 1) % len(self.data)) 62 | 63 | audio_tensor = torch.from_numpy(audio).float() 64 | 65 | if sample_rate != self.target_sample_rate: 66 | resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) 67 | audio_tensor = resampler(audio_tensor) 68 | 69 | audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t') 70 | 71 | mel_spec = self.mel_spectrogram(audio_tensor) 72 | 73 | mel_spec = mel_spec.squeeze(0) # '1 d t -> d t' 74 | 75 | text = row["text"] 76 | 77 | return dict( 78 | mel_spec=mel_spec, 79 | text=text, 80 | ) 81 | 82 | 83 | class CustomDataset(Dataset): 84 | def __init__( 85 | self, 86 | custom_dataset: Dataset, 87 | durations=None, 88 | target_sample_rate=24_000, 89 | hop_length=256, 90 | n_mel_channels=100, 91 | n_fft=1024, 92 | win_length=1024, 93 | mel_spec_type="vocos", 94 | preprocessed_mel=False, 95 | mel_spec_module: nn.Module | None = None, 96 | ): 97 | self.data = custom_dataset 98 | self.durations = durations 99 | self.target_sample_rate = target_sample_rate 100 | self.hop_length = hop_length 101 | self.n_fft = n_fft 102 | self.win_length = win_length 103 | self.mel_spec_type = mel_spec_type 104 | self.preprocessed_mel = preprocessed_mel 105 | 106 | if not preprocessed_mel: 107 | self.mel_spectrogram = default( 108 | mel_spec_module, 109 | MelSpec( 110 | n_fft=n_fft, 111 | hop_length=hop_length, 112 | win_length=win_length, 113 | n_mel_channels=n_mel_channels, 114 | target_sample_rate=target_sample_rate, 115 | mel_spec_type=mel_spec_type, 116 | ), 117 | ) 118 | 119 | def get_frame_len(self, index): 120 | if ( 121 | self.durations is not None 122 | ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM 123 | return self.durations[index] * self.target_sample_rate / self.hop_length 124 | return self.data[index]["duration"] * self.target_sample_rate / self.hop_length 125 | 126 | def __len__(self): 127 | return len(self.data) 128 | 129 | def __getitem__(self, index): 130 | while True: 131 | row = self.data[index] 132 | audio_path = row["audio_path"] 133 | text = row["text"] 134 | duration = row["duration"] 135 | 136 | # filter by given length 137 | if 0.3 <= duration <= 30: 138 | break # valid 139 | 140 | index = (index + 1) % len(self.data) 141 | 142 | if self.preprocessed_mel: 143 | mel_spec = torch.tensor(row["mel_spec"]) 144 | else: 145 | audio, source_sample_rate = torchaudio.load(audio_path) 146 | 147 | # make sure mono input 148 | if audio.shape[0] > 1: 149 | audio = torch.mean(audio, dim=0, keepdim=True) 150 | 151 | # resample if necessary 152 | if source_sample_rate != self.target_sample_rate: 153 | resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) 154 | audio = resampler(audio) 155 | 156 | # to mel spectrogram 157 | mel_spec = self.mel_spectrogram(audio) 158 | mel_spec = mel_spec.squeeze(0) # '1 d t -> d t' 159 | 160 | return { 161 | "mel_spec": mel_spec, 162 | "text": text, 163 | } 164 | 165 | 166 | # Dynamic Batch Sampler 167 | class DynamicBatchSampler(Sampler[list[int]]): 168 | """Extension of Sampler that will do the following: 169 | 1. Change the batch size (essentially number of sequences) 170 | in a batch to ensure that the total number of frames are less 171 | than a certain threshold. 172 | 2. Make sure the padding efficiency in the batch is high. 
173 | """ 174 | 175 | def __init__( 176 | self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False 177 | ): 178 | self.sampler = sampler 179 | self.frames_threshold = frames_threshold 180 | self.max_samples = max_samples 181 | 182 | indices, batches = [], [] 183 | data_source = self.sampler.data_source 184 | 185 | for idx in tqdm( 186 | self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration" 187 | ): 188 | indices.append((idx, data_source.get_frame_len(idx))) 189 | indices.sort(key=lambda elem: elem[1]) 190 | 191 | batch = [] 192 | batch_frames = 0 193 | for idx, frame_len in tqdm( 194 | indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu" 195 | ): 196 | if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples): 197 | batch.append(idx) 198 | batch_frames += frame_len 199 | else: 200 | if len(batch) > 0: 201 | batches.append(batch) 202 | if frame_len <= self.frames_threshold: 203 | batch = [idx] 204 | batch_frames = frame_len 205 | else: 206 | batch = [] 207 | batch_frames = 0 208 | 209 | if not drop_last and len(batch) > 0: 210 | batches.append(batch) 211 | 212 | del indices 213 | 214 | # if want to have different batches between epochs, may just set a seed and log it in ckpt 215 | # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different 216 | # e.g. for epoch n, use (random_seed + n) 217 | random.seed(random_seed) 218 | random.shuffle(batches) 219 | 220 | self.batches = batches 221 | 222 | def __iter__(self): 223 | return iter(self.batches) 224 | 225 | def __len__(self): 226 | return len(self.batches) 227 | 228 | 229 | # Load dataset 230 | 231 | 232 | def load_dataset( 233 | dataset_name: str, 234 | tokenizer: str = "pinyin", 235 | dataset_type: str = "CustomDataset", 236 | audio_type: str = "raw", 237 | mel_spec_module: nn.Module | None = None, 238 | mel_spec_kwargs: dict = dict(), 239 | ) -> CustomDataset | HFDataset: 240 | """ 241 | dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset 242 | - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer 243 | """ 244 | 245 | print("Loading dataset ...") 246 | 247 | if dataset_type == "CustomDataset": 248 | rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}")) 249 | if audio_type == "raw": 250 | try: 251 | train_dataset = load_from_disk(f"{rel_data_path}/raw") 252 | except: # noqa: E722 253 | train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow") 254 | preprocessed_mel = False 255 | elif audio_type == "mel": 256 | train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow") 257 | preprocessed_mel = True 258 | with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f: 259 | data_dict = json.load(f) 260 | durations = data_dict["duration"] 261 | train_dataset = CustomDataset( 262 | train_dataset, 263 | durations=durations, 264 | preprocessed_mel=preprocessed_mel, 265 | mel_spec_module=mel_spec_module, 266 | **mel_spec_kwargs, 267 | ) 268 | 269 | elif dataset_type == "CustomDatasetPath": 270 | try: 271 | train_dataset = load_from_disk(f"{dataset_name}/raw") 272 | except: # noqa: E722 273 | train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow") 274 | 275 | with open(f"{dataset_name}/duration.json", "r", 
encoding="utf-8") as f: 276 | data_dict = json.load(f) 277 | durations = data_dict["duration"] 278 | train_dataset = CustomDataset( 279 | train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs 280 | ) 281 | 282 | elif dataset_type == "HFDataset": 283 | print( 284 | "Should manually modify the path of huggingface dataset to your need.\n" 285 | + "May also the corresponding script cuz different dataset may have different format." 286 | ) 287 | pre, post = dataset_name.split("_") 288 | train_dataset = HFDataset( 289 | load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))), 290 | ) 291 | 292 | return train_dataset 293 | 294 | 295 | # collation 296 | 297 | 298 | def collate_fn(batch): 299 | mel_specs = [item["mel_spec"].squeeze(0) for item in batch] 300 | mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) 301 | max_mel_length = mel_lengths.amax() 302 | 303 | padded_mel_specs = [] 304 | for spec in mel_specs: # TODO. maybe records mask for attention here 305 | padding = (0, max_mel_length - spec.size(-1)) 306 | padded_spec = F.pad(spec, padding, value=0) 307 | padded_mel_specs.append(padded_spec) 308 | 309 | mel_specs = torch.stack(padded_mel_specs) 310 | 311 | text = [item["text"] for item in batch] 312 | text_lengths = torch.LongTensor([len(item) for item in text]) 313 | 314 | return dict( 315 | mel=mel_specs, 316 | mel_lengths=mel_lengths, 317 | text=text, 318 | text_lengths=text_lengths, 319 | ) 320 | -------------------------------------------------------------------------------- /f5_tts/model/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import random 5 | from collections import defaultdict 6 | from importlib.resources import files 7 | 8 | import torch 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | import jieba 12 | from pypinyin import lazy_pinyin, Style 13 | 14 | 15 | # seed everything 16 | 17 | 18 | def seed_everything(seed=0): 19 | random.seed(seed) 20 | os.environ["PYTHONHASHSEED"] = str(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | 27 | 28 | # helpers 29 | 30 | 31 | def exists(v): 32 | return v is not None 33 | 34 | 35 | def default(v, d): 36 | return v if exists(v) else d 37 | 38 | 39 | # tensor helpers 40 | 41 | 42 | def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 43 | if not exists(length): 44 | length = t.amax() 45 | 46 | seq = torch.arange(length, device=t.device) 47 | return seq[None, :] < t[:, None] 48 | 49 | 50 | def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 51 | max_seq_len = seq_len.max().item() 52 | seq = torch.arange(max_seq_len, device=start.device).long() 53 | start_mask = seq[None, :] >= start[:, None] 54 | end_mask = seq[None, :] < end[:, None] 55 | return start_mask & end_mask 56 | 57 | 58 | def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 59 | lengths = (frac_lengths * seq_len).long() 60 | max_start = seq_len - lengths 61 | 62 | rand = torch.rand_like(frac_lengths) 63 | start = (max_start * rand).long().clamp(min=0) 64 | end = start + lengths 65 | 66 | return mask_from_start_end_indices(seq_len, start, end) 67 | 68 | 69 | def maybe_masked_mean(t: float["b 
n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 70 | if not exists(mask): 71 | return t.mean(dim=1) 72 | 73 | t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) 74 | num = t.sum(dim=1) 75 | den = mask.float().sum(dim=1) 76 | 77 | return num / den.clamp(min=1.0) 78 | 79 | 80 | # simple utf-8 tokenizer, since paper went character based 81 | def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 82 | list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style 83 | text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) 84 | return text 85 | 86 | 87 | # char tokenizer, based on custom dataset's extracted .txt file 88 | def list_str_to_idx( 89 | text: list[str] | list[list[str]], 90 | vocab_char_map: dict[str, int], # {char: idx} 91 | padding_value=-1, 92 | ) -> int["b nt"]: # noqa: F722 93 | list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style 94 | text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) 95 | return text 96 | 97 | 98 | # Get tokenizer 99 | 100 | 101 | def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): 102 | """ 103 | tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file 104 | - "char" for char-wise tokenizer, need .txt vocab_file 105 | - "byte" for utf-8 tokenizer 106 | - "custom" if you're directly passing in a path to the vocab.txt you want to use 107 | vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols 108 | - if use "char", derived from unfiltered character & symbol counts of custom dataset 109 | - if use "byte", set to 256 (unicode byte range) 110 | """ 111 | if tokenizer in ["pinyin", "char"]: 112 | tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") 113 | with open(tokenizer_path, "r", encoding="utf-8") as f: 114 | vocab_char_map = {} 115 | for i, char in enumerate(f): 116 | vocab_char_map[char[:-1]] = i 117 | vocab_size = len(vocab_char_map) 118 | assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" 119 | 120 | elif tokenizer == "byte": 121 | vocab_char_map = None 122 | vocab_size = 256 123 | 124 | elif tokenizer == "custom": 125 | with open(dataset_name, "r", encoding="utf-8") as f: 126 | vocab_char_map = {} 127 | for i, char in enumerate(f): 128 | vocab_char_map[char[:-1]] = i 129 | vocab_size = len(vocab_char_map) 130 | 131 | return vocab_char_map, vocab_size 132 | 133 | 134 | # convert char to pinyin 135 | 136 | jieba.initialize() 137 | print("Word segmentation module jieba initialized.\n") 138 | 139 | 140 | def convert_char_to_pinyin(text_list, polyphone=True): 141 | final_text_list = [] 142 | custom_trans = str.maketrans( 143 | {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} 144 | ) # add custom trans here, to address oov 145 | 146 | def is_chinese(c): 147 | return ( 148 | "\u3100" <= c <= "\u9fff" # common chinese characters 149 | ) 150 | 151 | for text in text_list: 152 | char_list = [] 153 | text = text.translate(custom_trans) 154 | for seg in jieba.cut(text): 155 | seg_byte_len = len(bytes(seg, "UTF-8")) 156 | if seg_byte_len == len(seg): # if pure alphabets and symbols 157 | if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": 158 | char_list.append(" ") 159 | char_list.extend(seg) 160 | elif polyphone and seg_byte_len == 3 * len(seg): # if 
pure east asian characters 161 | seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) 162 | for i, c in enumerate(seg): 163 | if is_chinese(c): 164 | char_list.append(" ") 165 | char_list.append(seg_[i]) 166 | else: # if mixed characters, alphabets and symbols 167 | for c in seg: 168 | if ord(c) < 256: 169 | char_list.extend(c) 170 | elif is_chinese(c): 171 | char_list.append(" ") 172 | char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) 173 | else: 174 | char_list.append(c) 175 | final_text_list.append(char_list) 176 | 177 | return final_text_list 178 | 179 | 180 | # filter func for dirty data with many repetitions 181 | 182 | 183 | def repetition_found(text, length=2, tolerance=10): 184 | pattern_count = defaultdict(int) 185 | for i in range(len(text) - length + 1): 186 | pattern = text[i : i + length] 187 | pattern_count[pattern] += 1 188 | for pattern, count in pattern_count.items(): 189 | if count > tolerance: 190 | return True 191 | return False 192 | -------------------------------------------------------------------------------- /f5_tts/scripts/count_max_epoch.py: -------------------------------------------------------------------------------- 1 | """ADAPTIVE BATCH SIZE""" 2 | 3 | print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in") 4 | print(" -> least padding, gather wavs with accumulated frames in a batch\n") 5 | 6 | # data 7 | total_hours = 95282 8 | mel_hop_length = 256 9 | mel_sampling_rate = 24000 10 | 11 | # target 12 | wanted_max_updates = 1000000 13 | 14 | # train params 15 | gpus = 8 16 | frames_per_gpu = 38400 # 8 * 38400 = 307200 17 | grad_accum = 1 18 | 19 | # intermediate 20 | mini_batch_frames = frames_per_gpu * grad_accum * gpus 21 | mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600 22 | updates_per_epoch = total_hours / mini_batch_hours 23 | steps_per_epoch = updates_per_epoch * grad_accum 24 | 25 | # result 26 | epochs = wanted_max_updates / updates_per_epoch 27 | print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})") 28 | print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates") 29 | print(f" or approx. 
0/{steps_per_epoch:.0f} steps") 30 | 31 | # others 32 | print(f"total {total_hours:.0f} hours") 33 | print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch") 34 | -------------------------------------------------------------------------------- /f5_tts/scripts/count_params_gflops.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | from f5_tts.model import CFM, DiT 7 | 8 | import torch 9 | import thop 10 | 11 | 12 | """ ~155M """ 13 | # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4) 14 | # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4) 15 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2) 16 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4) 17 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True) 18 | # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2) 19 | 20 | """ ~335M """ 21 | # FLOPs: 622.1 G, Params: 333.2 M 22 | # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4) 23 | # FLOPs: 363.4 G, Params: 335.8 M 24 | transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 25 | 26 | 27 | model = CFM(transformer=transformer) 28 | target_sample_rate = 24000 29 | n_mel_channels = 100 30 | hop_length = 256 31 | duration = 20 32 | frame_length = int(duration * target_sample_rate / hop_length) 33 | text_length = 150 34 | 35 | flops, params = thop.profile( 36 | model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)) 37 | ) 38 | print(f"FLOPs: {flops / 1e9} G") 39 | print(f"Params: {params / 1e6} M") 40 | -------------------------------------------------------------------------------- /f5_tts/socket_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import socket 4 | import struct 5 | import torch 6 | import torchaudio 7 | import traceback 8 | from importlib.resources import files 9 | from threading import Thread 10 | 11 | from cached_path import cached_path 12 | 13 | from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model 14 | from model.backbones.dit import DiT 15 | 16 | 17 | class TTSStreamingProcessor: 18 | def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): 19 | self.device = device or ( 20 | "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 21 | ) 22 | 23 | # Load the model using the provided checkpoint and vocab files 24 | self.model = load_model( 25 | model_cls=DiT, 26 | model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4), 27 | ckpt_path=ckpt_file, 28 | mel_spec_type="vocos", # or "bigvgan" depending on vocoder 29 | vocab_file=vocab_file, 30 | ode_method="euler", 31 | use_ema=True, 32 | device=self.device, 33 | ).to(self.device, dtype=dtype) 34 | 35 | # Load the vocoder 36 | self.vocoder = load_vocoder(is_local=False) 37 | 38 | # Set sampling rate for streaming 39 | self.sampling_rate = 24000 # Consistency with client 40 | 41 | # Set reference audio and text 42 | self.ref_audio = ref_audio 43 | self.ref_text = ref_text 44 | 45 | # Warm up the model 46 | 
self._warm_up() 47 | 48 | def _warm_up(self): 49 | """Warm up the model with a dummy input to ensure it's ready for real-time processing.""" 50 | print("Warming up the model...") 51 | ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text) 52 | audio, sr = torchaudio.load(ref_audio) 53 | gen_text = "Warm-up text for the model." 54 | 55 | # Pass the vocoder as an argument here 56 | infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device) 57 | print("Warm-up completed.") 58 | 59 | def generate_stream(self, text, play_steps_in_s=0.5): 60 | """Generate audio in chunks and yield them in real-time.""" 61 | # Preprocess the reference audio and text 62 | ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text) 63 | 64 | # Load reference audio 65 | audio, sr = torchaudio.load(ref_audio) 66 | 67 | # Run inference for the input text 68 | audio_chunk, final_sample_rate, _ = infer_batch_process( 69 | (audio, sr), 70 | ref_text, 71 | [text], 72 | self.model, 73 | self.vocoder, 74 | device=self.device, # Pass vocoder here 75 | ) 76 | 77 | # Break the generated audio into chunks and send them 78 | chunk_size = int(final_sample_rate * play_steps_in_s) 79 | 80 | if len(audio_chunk) < chunk_size: 81 | packed_audio = struct.pack(f"{len(audio_chunk)}f", *audio_chunk) 82 | yield packed_audio 83 | return 84 | 85 | for i in range(0, len(audio_chunk), chunk_size): 86 | chunk = audio_chunk[i : i + chunk_size] 87 | 88 | # Check if it's the final chunk 89 | if i + chunk_size >= len(audio_chunk): 90 | chunk = audio_chunk[i:] 91 | 92 | # Send the chunk if it is not empty 93 | if len(chunk) > 0: 94 | packed_audio = struct.pack(f"{len(chunk)}f", *chunk) 95 | yield packed_audio 96 | 97 | 98 | def handle_client(client_socket, processor): 99 | try: 100 | while True: 101 | # Receive data from the client 102 | data = client_socket.recv(1024).decode("utf-8") 103 | if not data: 104 | break 105 | 106 | try: 107 | # The client sends the text input 108 | text = data.strip() 109 | 110 | # Generate and stream audio chunks 111 | for audio_chunk in processor.generate_stream(text): 112 | client_socket.sendall(audio_chunk) 113 | 114 | # Send end-of-audio signal 115 | client_socket.sendall(b"END_OF_AUDIO") 116 | 117 | except Exception as inner_e: 118 | print(f"Error during processing: {inner_e}") 119 | traceback.print_exc() # Print the full traceback to diagnose the issue 120 | break 121 | 122 | except Exception as e: 123 | print(f"Error handling client: {e}") 124 | traceback.print_exc() 125 | finally: 126 | client_socket.close() 127 | 128 | 129 | def start_server(host, port, processor): 130 | server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 131 | server.bind((host, port)) 132 | server.listen(5) 133 | print(f"Server listening on {host}:{port}") 134 | 135 | while True: 136 | client_socket, addr = server.accept() 137 | print(f"Accepted connection from {addr}") 138 | client_handler = Thread(target=handle_client, args=(client_socket, processor)) 139 | client_handler.start() 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | 145 | parser.add_argument("--host", default="0.0.0.0") 146 | parser.add_argument("--port", default=9998) 147 | 148 | parser.add_argument( 149 | "--ckpt_file", 150 | default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")), 151 | help="Path to the model checkpoint file", 152 | ) 153 | parser.add_argument( 154 | "--vocab_file", 155 | default="", 156 | 
help="Path to the vocab file if customized", 157 | ) 158 | 159 | parser.add_argument( 160 | "--ref_audio", 161 | default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), 162 | help="Reference audio to provide model with speaker characteristics", 163 | ) 164 | parser.add_argument( 165 | "--ref_text", 166 | default="", 167 | help="Reference audio subtitle, leave empty to auto-transcribe", 168 | ) 169 | 170 | parser.add_argument("--device", default=None, help="Device to run the model on") 171 | parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference") 172 | 173 | args = parser.parse_args() 174 | 175 | try: 176 | # Initialize the processor with the model and vocoder 177 | processor = TTSStreamingProcessor( 178 | ckpt_file=args.ckpt_file, 179 | vocab_file=args.vocab_file, 180 | ref_audio=args.ref_audio, 181 | ref_text=args.ref_text, 182 | device=args.device, 183 | dtype=args.dtype, 184 | ) 185 | 186 | # Start the server 187 | start_server(args.host, args.port, processor) 188 | 189 | except KeyboardInterrupt: 190 | gc.collect() 191 | -------------------------------------------------------------------------------- /f5_tts/train/README.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | ## Prepare Dataset 4 | 5 | Example data processing scripts, and you may tailor your own one along with a Dataset class in `src/f5_tts/model/dataset.py`. 6 | 7 | ### 1. Some specific Datasets preparing scripts 8 | Download corresponding dataset first, and fill in the path in scripts. 9 | 10 | ```bash 11 | # Prepare the Emilia dataset 12 | python src/f5_tts/train/datasets/prepare_emilia.py 13 | 14 | # Prepare the Wenetspeech4TTS dataset 15 | python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py 16 | 17 | # Prepare the LibriTTS dataset 18 | python src/f5_tts/train/datasets/prepare_libritts.py 19 | 20 | # Prepare the LJSpeech dataset 21 | python src/f5_tts/train/datasets/prepare_ljspeech.py 22 | ``` 23 | 24 | ### 2. Create custom dataset with metadata.csv 25 | Use guidance see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029). 26 | 27 | ```bash 28 | python src/f5_tts/train/datasets/prepare_csv_wavs.py 29 | ``` 30 | 31 | ## Training & Finetuning 32 | 33 | Once your datasets are prepared, you can start the training process. 34 | 35 | ### 1. Training script used for pretrained model 36 | 37 | ```bash 38 | # setup accelerate config, e.g. use multi-gpu ddp, fp16 39 | # will be to: ~/.cache/huggingface/accelerate/default_config.yaml 40 | accelerate config 41 | 42 | # .yaml files are under src/f5_tts/configs directory 43 | accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml 44 | 45 | # possible to overwrite accelerate and hydra config 46 | accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200 47 | ``` 48 | 49 | ### 2. Finetuning practice 50 | Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57). 51 | 52 | Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143). 53 | 54 | The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results. 55 | 56 | ### 3. 
Wandb Logging 57 | 58 | The `wandb/` dir will be created under path you run training/finetuning scripts. 59 | 60 | By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`). 61 | 62 | To turn on wandb logging, you can either: 63 | 64 | 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login) 65 | 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows: 66 | 67 | On Mac & Linux: 68 | 69 | ``` 70 | export WANDB_API_KEY= 71 | ``` 72 | 73 | On Windows: 74 | 75 | ``` 76 | set WANDB_API_KEY= 77 | ``` 78 | Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows: 79 | 80 | ``` 81 | export WANDB_MODE=offline 82 | ``` 83 | -------------------------------------------------------------------------------- /f5_tts/train/datasets/prepare_csv_wavs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import argparse 7 | import csv 8 | import json 9 | import shutil 10 | from importlib.resources import files 11 | from pathlib import Path 12 | 13 | import torchaudio 14 | from tqdm import tqdm 15 | from datasets.arrow_writer import ArrowWriter 16 | 17 | from f5_tts.model.utils import ( 18 | convert_char_to_pinyin, 19 | ) 20 | 21 | 22 | PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt") 23 | 24 | 25 | def is_csv_wavs_format(input_dataset_dir): 26 | fpath = Path(input_dataset_dir) 27 | metadata = fpath / "metadata.csv" 28 | wavs = fpath / "wavs" 29 | return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir() 30 | 31 | 32 | def prepare_csv_wavs_dir(input_dir): 33 | assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}" 34 | input_dir = Path(input_dir) 35 | metadata_path = input_dir / "metadata.csv" 36 | audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix()) 37 | 38 | sub_result, durations = [], [] 39 | vocab_set = set() 40 | polyphone = True 41 | for audio_path, text in audio_path_text_pairs: 42 | if not Path(audio_path).exists(): 43 | print(f"audio {audio_path} not found, skipping") 44 | continue 45 | audio_duration = get_audio_duration(audio_path) 46 | # assume tokenizer = "pinyin" ("pinyin" | "char") 47 | text = convert_char_to_pinyin([text], polyphone=polyphone)[0] 48 | sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration}) 49 | durations.append(audio_duration) 50 | vocab_set.update(list(text)) 51 | 52 | return sub_result, durations, vocab_set 53 | 54 | 55 | def get_audio_duration(audio_path): 56 | audio, sample_rate = torchaudio.load(audio_path) 57 | return audio.shape[1] / sample_rate 58 | 59 | 60 | def read_audio_text_pairs(csv_file_path): 61 | audio_text_pairs = [] 62 | 63 | parent = Path(csv_file_path).parent 64 | with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile: 65 | reader = csv.reader(csvfile, delimiter="|") 66 | next(reader) # Skip the header row 67 | for row in reader: 68 | if len(row) >= 2: 69 | audio_file = row[0].strip() # First column: audio file path 70 | text = row[1].strip() # Second column: text 71 | audio_file_path = parent / audio_file 72 | audio_text_pairs.append((audio_file_path.as_posix(), text)) 73 | 74 | return audio_text_pairs 75 | 76 | 77 | def 
save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune): 78 | out_dir = Path(out_dir) 79 | # save preprocessed dataset to disk 80 | out_dir.mkdir(exist_ok=True, parents=True) 81 | print(f"\nSaving to {out_dir} ...") 82 | 83 | # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom 84 | # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB") 85 | raw_arrow_path = out_dir / "raw.arrow" 86 | with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer: 87 | for line in tqdm(result, desc="Writing to raw.arrow ..."): 88 | writer.write(line) 89 | 90 | # dup a json separately saving duration in case for DynamicBatchSampler ease 91 | dur_json_path = out_dir / "duration.json" 92 | with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f: 93 | json.dump({"duration": duration_list}, f, ensure_ascii=False) 94 | 95 | # vocab map, i.e. tokenizer 96 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 97 | # if tokenizer == "pinyin": 98 | # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) 99 | voca_out_path = out_dir / "vocab.txt" 100 | with open(voca_out_path.as_posix(), "w") as f: 101 | for vocab in sorted(text_vocab_set): 102 | f.write(vocab + "\n") 103 | 104 | if is_finetune: 105 | file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix() 106 | shutil.copy2(file_vocab_finetune, voca_out_path) 107 | else: 108 | with open(voca_out_path, "w") as f: 109 | for vocab in sorted(text_vocab_set): 110 | f.write(vocab + "\n") 111 | 112 | dataset_name = out_dir.stem 113 | print(f"\nFor {dataset_name}, sample count: {len(result)}") 114 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") 115 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") 116 | 117 | 118 | def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True): 119 | if is_finetune: 120 | assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}" 121 | sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir) 122 | save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune) 123 | 124 | 125 | def cli(): 126 | # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin 127 | # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain 128 | parser = argparse.ArgumentParser(description="Prepare and save dataset.") 129 | parser.add_argument("inp_dir", type=str, help="Input directory containing the data.") 130 | parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.") 131 | parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune") 132 | 133 | args = parser.parse_args() 134 | 135 | prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain) 136 | 137 | 138 | if __name__ == "__main__": 139 | cli() 140 | -------------------------------------------------------------------------------- /f5_tts/train/datasets/prepare_emilia.py: -------------------------------------------------------------------------------- 1 | # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07 2 | # if use updated new version, i.e. 
WebDataset, feel free to modify / draft your own script 3 | 4 | # generate audio text map for Emilia ZH & EN 5 | # evaluate for vocab size 6 | 7 | import os 8 | import sys 9 | 10 | sys.path.append(os.getcwd()) 11 | 12 | import json 13 | from concurrent.futures import ProcessPoolExecutor 14 | from importlib.resources import files 15 | from pathlib import Path 16 | from tqdm import tqdm 17 | 18 | from datasets.arrow_writer import ArrowWriter 19 | 20 | from f5_tts.model.utils import ( 21 | repetition_found, 22 | convert_char_to_pinyin, 23 | ) 24 | 25 | 26 | out_zh = { 27 | "ZH_B00041_S06226", 28 | "ZH_B00042_S09204", 29 | "ZH_B00065_S09430", 30 | "ZH_B00065_S09431", 31 | "ZH_B00066_S09327", 32 | "ZH_B00066_S09328", 33 | } 34 | zh_filters = ["い", "て"] 35 | # seems synthesized audios, or heavily code-switched 36 | out_en = { 37 | "EN_B00013_S00913", 38 | "EN_B00042_S00120", 39 | "EN_B00055_S04111", 40 | "EN_B00061_S00693", 41 | "EN_B00061_S01494", 42 | "EN_B00061_S03375", 43 | "EN_B00059_S00092", 44 | "EN_B00111_S04300", 45 | "EN_B00100_S03759", 46 | "EN_B00087_S03811", 47 | "EN_B00059_S00950", 48 | "EN_B00089_S00946", 49 | "EN_B00078_S05127", 50 | "EN_B00070_S04089", 51 | "EN_B00074_S09659", 52 | "EN_B00061_S06983", 53 | "EN_B00061_S07060", 54 | "EN_B00059_S08397", 55 | "EN_B00082_S06192", 56 | "EN_B00091_S01238", 57 | "EN_B00089_S07349", 58 | "EN_B00070_S04343", 59 | "EN_B00061_S02400", 60 | "EN_B00076_S01262", 61 | "EN_B00068_S06467", 62 | "EN_B00076_S02943", 63 | "EN_B00064_S05954", 64 | "EN_B00061_S05386", 65 | "EN_B00066_S06544", 66 | "EN_B00076_S06944", 67 | "EN_B00072_S08620", 68 | "EN_B00076_S07135", 69 | "EN_B00076_S09127", 70 | "EN_B00065_S00497", 71 | "EN_B00059_S06227", 72 | "EN_B00063_S02859", 73 | "EN_B00075_S01547", 74 | "EN_B00061_S08286", 75 | "EN_B00079_S02901", 76 | "EN_B00092_S03643", 77 | "EN_B00096_S08653", 78 | "EN_B00063_S04297", 79 | "EN_B00063_S04614", 80 | "EN_B00079_S04698", 81 | "EN_B00104_S01666", 82 | "EN_B00061_S09504", 83 | "EN_B00061_S09694", 84 | "EN_B00065_S05444", 85 | "EN_B00063_S06860", 86 | "EN_B00065_S05725", 87 | "EN_B00069_S07628", 88 | "EN_B00083_S03875", 89 | "EN_B00071_S07665", 90 | "EN_B00071_S07665", 91 | "EN_B00062_S04187", 92 | "EN_B00065_S09873", 93 | "EN_B00065_S09922", 94 | "EN_B00084_S02463", 95 | "EN_B00067_S05066", 96 | "EN_B00106_S08060", 97 | "EN_B00073_S06399", 98 | "EN_B00073_S09236", 99 | "EN_B00087_S00432", 100 | "EN_B00085_S05618", 101 | "EN_B00064_S01262", 102 | "EN_B00072_S01739", 103 | "EN_B00059_S03913", 104 | "EN_B00069_S04036", 105 | "EN_B00067_S05623", 106 | "EN_B00060_S05389", 107 | "EN_B00060_S07290", 108 | "EN_B00062_S08995", 109 | } 110 | en_filters = ["ا", "い", "て"] 111 | 112 | 113 | def deal_with_audio_dir(audio_dir): 114 | audio_jsonl = audio_dir.with_suffix(".jsonl") 115 | sub_result, durations = [], [] 116 | vocab_set = set() 117 | bad_case_zh = 0 118 | bad_case_en = 0 119 | with open(audio_jsonl, "r") as f: 120 | lines = f.readlines() 121 | for line in tqdm(lines, desc=f"{audio_jsonl.stem}"): 122 | obj = json.loads(line) 123 | text = obj["text"] 124 | if obj["language"] == "zh": 125 | if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text): 126 | bad_case_zh += 1 127 | continue 128 | else: 129 | text = text.translate( 130 | str.maketrans({",": ",", "!": "!", "?": "?"}) 131 | ) # not "。" cuz much code-switched 132 | if obj["language"] == "en": 133 | if ( 134 | obj["wav"].split("/")[1] in out_en 135 | or any(f in text for f in en_filters) 136 | or repetition_found(text, 
length=4) 137 | ): 138 | bad_case_en += 1 139 | continue 140 | if tokenizer == "pinyin": 141 | text = convert_char_to_pinyin([text], polyphone=polyphone)[0] 142 | duration = obj["duration"] 143 | sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration}) 144 | durations.append(duration) 145 | vocab_set.update(list(text)) 146 | return sub_result, durations, vocab_set, bad_case_zh, bad_case_en 147 | 148 | 149 | def main(): 150 | assert tokenizer in ["pinyin", "char"] 151 | result = [] 152 | duration_list = [] 153 | text_vocab_set = set() 154 | total_bad_case_zh = 0 155 | total_bad_case_en = 0 156 | 157 | # process raw data 158 | executor = ProcessPoolExecutor(max_workers=max_workers) 159 | futures = [] 160 | for lang in langs: 161 | dataset_path = Path(os.path.join(dataset_dir, lang)) 162 | [ 163 | futures.append(executor.submit(deal_with_audio_dir, audio_dir)) 164 | for audio_dir in dataset_path.iterdir() 165 | if audio_dir.is_dir() 166 | ] 167 | for futures in tqdm(futures, total=len(futures)): 168 | sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result() 169 | result.extend(sub_result) 170 | duration_list.extend(durations) 171 | text_vocab_set.update(vocab_set) 172 | total_bad_case_zh += bad_case_zh 173 | total_bad_case_en += bad_case_en 174 | executor.shutdown() 175 | 176 | # save preprocessed dataset to disk 177 | if not os.path.exists(f"{save_dir}"): 178 | os.makedirs(f"{save_dir}") 179 | print(f"\nSaving to {save_dir} ...") 180 | 181 | # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom 182 | # dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") 183 | with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer: 184 | for line in tqdm(result, desc="Writing to raw.arrow ..."): 185 | writer.write(line) 186 | 187 | # dup a json separately saving duration in case for DynamicBatchSampler ease 188 | with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: 189 | json.dump({"duration": duration_list}, f, ensure_ascii=False) 190 | 191 | # vocab map, i.e. tokenizer 192 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 
193 | # if tokenizer == "pinyin": 194 | # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) 195 | with open(f"{save_dir}/vocab.txt", "w") as f: 196 | for vocab in sorted(text_vocab_set): 197 | f.write(vocab + "\n") 198 | 199 | print(f"\nFor {dataset_name}, sample count: {len(result)}") 200 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") 201 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") 202 | if "ZH" in langs: 203 | print(f"Bad zh transcription case: {total_bad_case_zh}") 204 | if "EN" in langs: 205 | print(f"Bad en transcription case: {total_bad_case_en}\n") 206 | 207 | 208 | if __name__ == "__main__": 209 | max_workers = 32 210 | 211 | tokenizer = "pinyin" # "pinyin" | "char" 212 | polyphone = True 213 | 214 | langs = ["ZH", "EN"] 215 | dataset_dir = "/Emilia_Dataset/raw" 216 | dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}" 217 | save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" 218 | print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") 219 | 220 | main() 221 | 222 | # Emilia ZH & EN 223 | # samples count 37837916 (after removal) 224 | # pinyin vocab size 2543 (polyphone) 225 | # total duration 95281.87 (hours) 226 | # bad zh asr cnt 230435 (samples) 227 | # bad eh asr cnt 37217 (samples) 228 | 229 | # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme) 230 | # please be careful if using pretrained model, make sure the vocab.txt is same 231 | -------------------------------------------------------------------------------- /f5_tts/train/datasets/prepare_libritts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import json 7 | from concurrent.futures import ProcessPoolExecutor 8 | from importlib.resources import files 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | import soundfile as sf 12 | from datasets.arrow_writer import ArrowWriter 13 | 14 | 15 | def deal_with_audio_dir(audio_dir): 16 | sub_result, durations = [], [] 17 | vocab_set = set() 18 | audio_lists = list(audio_dir.rglob("*.wav")) 19 | 20 | for line in audio_lists: 21 | text_path = line.with_suffix(".normalized.txt") 22 | text = open(text_path, "r").read().strip() 23 | duration = sf.info(line).duration 24 | if duration < 0.4 or duration > 30: 25 | continue 26 | sub_result.append({"audio_path": str(line), "text": text, "duration": duration}) 27 | durations.append(duration) 28 | vocab_set.update(list(text)) 29 | return sub_result, durations, vocab_set 30 | 31 | 32 | def main(): 33 | result = [] 34 | duration_list = [] 35 | text_vocab_set = set() 36 | 37 | # process raw data 38 | executor = ProcessPoolExecutor(max_workers=max_workers) 39 | futures = [] 40 | 41 | for subset in tqdm(SUB_SET): 42 | dataset_path = Path(os.path.join(dataset_dir, subset)) 43 | [ 44 | futures.append(executor.submit(deal_with_audio_dir, audio_dir)) 45 | for audio_dir in dataset_path.iterdir() 46 | if audio_dir.is_dir() 47 | ] 48 | for future in tqdm(futures, total=len(futures)): 49 | sub_result, durations, vocab_set = future.result() 50 | result.extend(sub_result) 51 | duration_list.extend(durations) 52 | text_vocab_set.update(vocab_set) 53 | executor.shutdown() 54 | 55 | # save preprocessed dataset to disk 56 | if not os.path.exists(f"{save_dir}"): 57 | os.makedirs(f"{save_dir}") 58 | print(f"\nSaving to {save_dir} ...") 59 | 60 | with 
ArrowWriter(path=f"{save_dir}/raw.arrow") as writer: 61 | for line in tqdm(result, desc="Writing to raw.arrow ..."): 62 | writer.write(line) 63 | 64 | # dup a json separately saving duration in case for DynamicBatchSampler ease 65 | with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: 66 | json.dump({"duration": duration_list}, f, ensure_ascii=False) 67 | 68 | # vocab map, i.e. tokenizer 69 | with open(f"{save_dir}/vocab.txt", "w") as f: 70 | for vocab in sorted(text_vocab_set): 71 | f.write(vocab + "\n") 72 | 73 | print(f"\nFor {dataset_name}, sample count: {len(result)}") 74 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") 75 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") 76 | 77 | 78 | if __name__ == "__main__": 79 | max_workers = 36 80 | 81 | tokenizer = "char" # "pinyin" | "char" 82 | 83 | SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"] 84 | dataset_dir = "/LibriTTS" 85 | dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "") 86 | save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" 87 | print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") 88 | main() 89 | 90 | # For LibriTTS_100_360_500_char, sample count: 354218 91 | # For LibriTTS_100_360_500_char, vocab size is: 78 92 | # For LibriTTS_100_360_500_char, total 554.09 hours 93 | -------------------------------------------------------------------------------- /f5_tts/train/datasets/prepare_ljspeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import json 7 | from importlib.resources import files 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | import soundfile as sf 11 | from datasets.arrow_writer import ArrowWriter 12 | 13 | 14 | def main(): 15 | result = [] 16 | duration_list = [] 17 | text_vocab_set = set() 18 | 19 | with open(meta_info, "r") as f: 20 | lines = f.readlines() 21 | for line in tqdm(lines): 22 | uttr, text, norm_text = line.split("|") 23 | norm_text = norm_text.strip() 24 | wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav" 25 | duration = sf.info(wav_path).duration 26 | if duration < 0.4 or duration > 30: 27 | continue 28 | result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration}) 29 | duration_list.append(duration) 30 | text_vocab_set.update(list(norm_text)) 31 | 32 | # save preprocessed dataset to disk 33 | if not os.path.exists(f"{save_dir}"): 34 | os.makedirs(f"{save_dir}") 35 | print(f"\nSaving to {save_dir} ...") 36 | 37 | with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer: 38 | for line in tqdm(result, desc="Writing to raw.arrow ..."): 39 | writer.write(line) 40 | 41 | # dup a json separately saving duration in case for DynamicBatchSampler ease 42 | with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: 43 | json.dump({"duration": duration_list}, f, ensure_ascii=False) 44 | 45 | # vocab map, i.e. tokenizer 46 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 
47 | with open(f"{save_dir}/vocab.txt", "w") as f: 48 | for vocab in sorted(text_vocab_set): 49 | f.write(vocab + "\n") 50 | 51 | print(f"\nFor {dataset_name}, sample count: {len(result)}") 52 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") 53 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") 54 | 55 | 56 | if __name__ == "__main__": 57 | tokenizer = "char" # "pinyin" | "char" 58 | 59 | dataset_dir = "/LJSpeech-1.1" 60 | dataset_name = f"LJSpeech_{tokenizer}" 61 | meta_info = os.path.join(dataset_dir, "metadata.csv") 62 | save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" 63 | print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") 64 | 65 | main() 66 | -------------------------------------------------------------------------------- /f5_tts/train/datasets/prepare_wenetspeech4tts.py: -------------------------------------------------------------------------------- 1 | # generate audio text map for WenetSpeech4TTS 2 | # evaluate for vocab size 3 | 4 | import os 5 | import sys 6 | 7 | sys.path.append(os.getcwd()) 8 | 9 | import json 10 | from concurrent.futures import ProcessPoolExecutor 11 | from importlib.resources import files 12 | from tqdm import tqdm 13 | 14 | import torchaudio 15 | from datasets import Dataset 16 | 17 | from f5_tts.model.utils import convert_char_to_pinyin 18 | 19 | 20 | def deal_with_sub_path_files(dataset_path, sub_path): 21 | print(f"Dealing with: {sub_path}") 22 | 23 | text_dir = os.path.join(dataset_path, sub_path, "txts") 24 | audio_dir = os.path.join(dataset_path, sub_path, "wavs") 25 | text_files = os.listdir(text_dir) 26 | 27 | audio_paths, texts, durations = [], [], [] 28 | for text_file in tqdm(text_files): 29 | with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file: 30 | first_line = file.readline().split("\t") 31 | audio_nm = first_line[0] 32 | audio_path = os.path.join(audio_dir, audio_nm + ".wav") 33 | text = first_line[1].strip() 34 | 35 | audio_paths.append(audio_path) 36 | 37 | if tokenizer == "pinyin": 38 | texts.extend(convert_char_to_pinyin([text], polyphone=polyphone)) 39 | elif tokenizer == "char": 40 | texts.append(text) 41 | 42 | audio, sample_rate = torchaudio.load(audio_path) 43 | durations.append(audio.shape[-1] / sample_rate) 44 | 45 | return audio_paths, texts, durations 46 | 47 | 48 | def main(): 49 | assert tokenizer in ["pinyin", "char"] 50 | 51 | audio_path_list, text_list, duration_list = [], [], [] 52 | 53 | executor = ProcessPoolExecutor(max_workers=max_workers) 54 | futures = [] 55 | for dataset_path in dataset_paths: 56 | sub_items = os.listdir(dataset_path) 57 | sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))] 58 | for sub_path in sub_paths: 59 | futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path)) 60 | for future in tqdm(futures, total=len(futures)): 61 | audio_paths, texts, durations = future.result() 62 | audio_path_list.extend(audio_paths) 63 | text_list.extend(texts) 64 | duration_list.extend(durations) 65 | executor.shutdown() 66 | 67 | if not os.path.exists("data"): 68 | os.makedirs("data") 69 | 70 | print(f"\nSaving to {save_dir} ...") 71 | dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) 72 | dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") # arrow format 73 | 74 | with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: 75 | json.dump( 76 | {"duration": duration_list}, 
f, ensure_ascii=False 77 | ) # dup a json separately saving duration in case for DynamicBatchSampler ease 78 | 79 | print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...") 80 | text_vocab_set = set() 81 | for text in tqdm(text_list): 82 | text_vocab_set.update(list(text)) 83 | 84 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 85 | if tokenizer == "pinyin": 86 | text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) 87 | 88 | with open(f"{save_dir}/vocab.txt", "w") as f: 89 | for vocab in sorted(text_vocab_set): 90 | f.write(vocab + "\n") 91 | print(f"\nFor {dataset_name}, sample count: {len(text_list)}") 92 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n") 93 | 94 | 95 | if __name__ == "__main__": 96 | max_workers = 32 97 | 98 | tokenizer = "pinyin" # "pinyin" | "char" 99 | polyphone = True 100 | dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic 101 | 102 | dataset_name = ( 103 | ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1] 104 | + "_" 105 | + tokenizer 106 | ) 107 | dataset_paths = [ 108 | "/WenetSpeech4TTS/Basic", 109 | "/WenetSpeech4TTS/Standard", 110 | "/WenetSpeech4TTS/Premium", 111 | ][-dataset_choice:] 112 | save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" 113 | print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n") 114 | 115 | main() 116 | 117 | # Results (if adding alphabets with accents and symbols): 118 | # WenetSpeech4TTS Basic Standard Premium 119 | # samples count 3932473 1941220 407494 120 | # pinyin vocab size 1349 1348 1344 (no polyphone) 121 | # - - 1459 (polyphone) 122 | # char vocab size 5264 5219 5042 123 | 124 | # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. 
way of polyphoneme) 125 | # please be careful if using pretrained model, make sure the vocab.txt is same 126 | -------------------------------------------------------------------------------- /f5_tts/train/finetune_cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | from cached_path import cached_path 6 | from f5_tts.model import CFM, UNetT, DiT, Trainer 7 | from f5_tts.model.utils import get_tokenizer 8 | from f5_tts.model.dataset import load_dataset 9 | from importlib.resources import files 10 | 11 | 12 | # -------------------------- Dataset Settings --------------------------- # 13 | target_sample_rate = 24000 14 | n_mel_channels = 100 15 | hop_length = 256 16 | win_length = 1024 17 | n_fft = 1024 18 | mel_spec_type = "vocos" # 'vocos' or 'bigvgan' 19 | 20 | 21 | # -------------------------- Argument Parsing --------------------------- # 22 | def parse_args(): 23 | # batch_size_per_gpu = 1000 settting for gpu 8GB 24 | # batch_size_per_gpu = 1600 settting for gpu 12GB 25 | # batch_size_per_gpu = 2000 settting for gpu 16GB 26 | # batch_size_per_gpu = 3200 settting for gpu 24GB 27 | 28 | # num_warmup_updates = 300 for 5000 sample about 10 hours 29 | 30 | # change save_per_updates , last_per_steps change this value what you need , 31 | 32 | parser = argparse.ArgumentParser(description="Train CFM Model") 33 | 34 | parser.add_argument( 35 | "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name" 36 | ) 37 | parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use") 38 | parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training") 39 | parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU") 40 | parser.add_argument( 41 | "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type" 42 | ) 43 | parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch") 44 | parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") 45 | parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping") 46 | parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs") 47 | parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps") 48 | parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps") 49 | parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps") 50 | parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune") 51 | parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint") 52 | parser.add_argument( 53 | "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type" 54 | ) 55 | parser.add_argument( 56 | "--tokenizer_path", 57 | type=str, 58 | default=None, 59 | help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')", 60 | ) 61 | parser.add_argument( 62 | "--log_samples", 63 | type=bool, 64 | default=False, 65 | help="Log inferenced samples per ckpt save steps", 66 | ) 67 | parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger") 68 | parser.add_argument( 69 | 
"--bnb_optimizer", 70 | type=bool, 71 | default=False, 72 | help="Use 8-bit Adam optimizer from bitsandbytes", 73 | ) 74 | 75 | return parser.parse_args() 76 | 77 | 78 | # -------------------------- Training Settings -------------------------- # 79 | 80 | 81 | def main(): 82 | args = parse_args() 83 | 84 | checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}")) 85 | 86 | # Model parameters based on experiment name 87 | if args.exp_name == "F5TTS_Base": 88 | wandb_resume_id = None 89 | model_cls = DiT 90 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 91 | if args.finetune: 92 | if args.pretrain is None: 93 | ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) 94 | else: 95 | ckpt_path = args.pretrain 96 | elif args.exp_name == "E2TTS_Base": 97 | wandb_resume_id = None 98 | model_cls = UNetT 99 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 100 | if args.finetune: 101 | if args.pretrain is None: 102 | ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) 103 | else: 104 | ckpt_path = args.pretrain 105 | 106 | if args.finetune: 107 | if not os.path.isdir(checkpoint_path): 108 | os.makedirs(checkpoint_path, exist_ok=True) 109 | 110 | file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path)) 111 | if not os.path.isfile(file_checkpoint): 112 | shutil.copy2(ckpt_path, file_checkpoint) 113 | print("copy checkpoint for finetune") 114 | 115 | # Use the tokenizer and tokenizer_path provided in the command line arguments 116 | tokenizer = args.tokenizer 117 | if tokenizer == "custom": 118 | if not args.tokenizer_path: 119 | raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.") 120 | tokenizer_path = args.tokenizer_path 121 | else: 122 | tokenizer_path = args.dataset_name 123 | 124 | vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) 125 | 126 | print("\nvocab : ", vocab_size) 127 | print("\nvocoder : ", mel_spec_type) 128 | 129 | mel_spec_kwargs = dict( 130 | n_fft=n_fft, 131 | hop_length=hop_length, 132 | win_length=win_length, 133 | n_mel_channels=n_mel_channels, 134 | target_sample_rate=target_sample_rate, 135 | mel_spec_type=mel_spec_type, 136 | ) 137 | 138 | model = CFM( 139 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), 140 | mel_spec_kwargs=mel_spec_kwargs, 141 | vocab_char_map=vocab_char_map, 142 | ) 143 | 144 | trainer = Trainer( 145 | model, 146 | args.epochs, 147 | args.learning_rate, 148 | num_warmup_updates=args.num_warmup_updates, 149 | save_per_updates=args.save_per_updates, 150 | checkpoint_path=checkpoint_path, 151 | batch_size=args.batch_size_per_gpu, 152 | batch_size_type=args.batch_size_type, 153 | max_samples=args.max_samples, 154 | grad_accumulation_steps=args.grad_accumulation_steps, 155 | max_grad_norm=args.max_grad_norm, 156 | logger=args.logger, 157 | wandb_project=args.dataset_name, 158 | wandb_run_name=args.exp_name, 159 | wandb_resume_id=wandb_resume_id, 160 | log_samples=args.log_samples, 161 | last_per_steps=args.last_per_steps, 162 | bnb_optimizer=args.bnb_optimizer, 163 | ) 164 | 165 | train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) 166 | 167 | trainer.train( 168 | train_dataset, 169 | resumable_with_seed=666, # seed for shuffling dataset 170 | ) 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- 
/f5_tts/train/train.py: -------------------------------------------------------------------------------- 1 | # training script. 2 | 3 | import os 4 | from importlib.resources import files 5 | 6 | import hydra 7 | 8 | from f5_tts.model import CFM, DiT, Trainer, UNetT 9 | from f5_tts.model.dataset import load_dataset 10 | from f5_tts.model.utils import get_tokenizer 11 | 12 | os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable) 13 | 14 | 15 | @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None) 16 | def main(cfg): 17 | tokenizer = cfg.model.tokenizer 18 | mel_spec_type = cfg.model.mel_spec.mel_spec_type 19 | exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}" 20 | 21 | # set text tokenizer 22 | if tokenizer != "custom": 23 | tokenizer_path = cfg.datasets.name 24 | else: 25 | tokenizer_path = cfg.model.tokenizer_path 26 | vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) 27 | 28 | # set model 29 | if "F5TTS" in cfg.model.name: 30 | model_cls = DiT 31 | elif "E2TTS" in cfg.model.name: 32 | model_cls = UNetT 33 | wandb_resume_id = None 34 | 35 | model = CFM( 36 | transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels), 37 | mel_spec_kwargs=cfg.model.mel_spec, 38 | vocab_char_map=vocab_char_map, 39 | ) 40 | 41 | # init trainer 42 | trainer = Trainer( 43 | model, 44 | epochs=cfg.optim.epochs, 45 | learning_rate=cfg.optim.learning_rate, 46 | num_warmup_updates=cfg.optim.num_warmup_updates, 47 | save_per_updates=cfg.ckpts.save_per_updates, 48 | checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")), 49 | batch_size=cfg.datasets.batch_size_per_gpu, 50 | batch_size_type=cfg.datasets.batch_size_type, 51 | max_samples=cfg.datasets.max_samples, 52 | grad_accumulation_steps=cfg.optim.grad_accumulation_steps, 53 | max_grad_norm=cfg.optim.max_grad_norm, 54 | logger=cfg.ckpts.logger, 55 | wandb_project="CFM-TTS", 56 | wandb_run_name=exp_name, 57 | wandb_resume_id=wandb_resume_id, 58 | last_per_steps=cfg.ckpts.last_per_steps, 59 | log_samples=True, 60 | bnb_optimizer=cfg.optim.bnb_optimizer, 61 | mel_spec_type=mel_spec_type, 62 | is_local_vocoder=cfg.model.vocoder.is_local, 63 | local_vocoder_path=cfg.model.vocoder.local_path, 64 | ) 65 | 66 | train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec) 67 | trainer.train( 68 | train_dataset, 69 | num_workers=cfg.datasets.num_workers, 70 | resumable_with_seed=666, # seed for shuffling dataset 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles 2 | aiohttp 3 | aliyun-python-sdk-core 4 | aliyun-python-sdk-kms 5 | anyio 6 | async-timeout 7 | attrs 8 | av 9 | certifi 10 | cffi 11 | charset-normalizer 12 | click 13 | colorama 14 | ctranslate2 15 | datasets 16 | ema-pytorch 17 | encodec 18 | exceptiongroup 19 | fastapi 20 | faster-whisper 21 | ffmpy 22 | filelock 23 | Flask 24 | Flask-Cors 25 | fsspec 26 | funasr 27 | google-api-core 28 | google-auth 29 | gradio 30 | gradio_client 31 | httpcore 32 | httpx 33 | huggingface-hub 34 | jieba 35 | librosa 36 | matplotlib 37 | multidict 38 | multiprocess 39 | numba==0.60.0 40 | numpy==1.23.5 41 | omegaconf==2.3.0 42 | onnxruntime==1.19.2 43 | 
pandas==2.2.3 44 | pillow==11.0.0 45 | pypinyin==0.53.0 46 | regex==2024.9.11 47 | requests 48 | torch==2.5.0 49 | torch-complex==0.4.4 50 | torchaudio==2.5.0 51 | torchdiffeq==0.2.4 52 | transformers==4.45.2 53 | waitress 54 | zhconv 55 | zhon 56 | -------------------------------------------------------------------------------- /run-api.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | 4 | call %cd%/runtime/python api.py 5 | 6 | pause -------------------------------------------------------------------------------- /run-webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | 4 | call %cd%/runtime/python ./f5_tts/infer/infer_gradio.py 5 | 6 | pause -------------------------------------------------------------------------------- /runtest.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | 4 | call %cd%/runtime/python test.py 5 | 6 | pause -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | res=requests.post('http://127.0.0.1:5010/api',data={ 4 | "gen_text": '你好,今天是个不错的日子', 5 | "ref_text": '你说四大皆空,却为何紧闭双眼,若你睁开眼睛看看我,我不相信你,两眼空空', 6 | "model": 'f5-tts' 7 | },files={"audio":open('./1.wav','rb')}) 8 | 9 | if res.status_code!=200: 10 | print(res.text) 11 | exit() 12 | 13 | with open("ceshi.wav",'wb') as f: 14 | f.write(res.content) 15 | 16 | 17 | -------------------------------------------------------------------------------- /testcuda.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import torch 4 | import os 5 | import sys 6 | from torch.backends import cudnn 7 | 8 | if torch.cuda.is_available(): 9 | print('CUDA 可用') 10 | else: 11 | print("CUDA不可用,请确保是英伟达显卡并安装了CUDA11.8+版本") 12 | sys.exit() 13 | 14 | if cudnn.is_available() and cudnn.is_acceptable(torch.tensor(1.).cuda()): 15 | print('cuDNN可用') 16 | else: 17 | print('cuDNN不可用,请安装cuDNN') 18 | sys.exit() 19 | -------------------------------------------------------------------------------- /测试GPU是否可用.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | 4 | call %cd%/runtime/python testcuda.py 5 | 6 | pause --------------------------------------------------------------------------------
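The package above ships a streaming socket server (`f5_tts/socket_server.py`) but no matching client. Below is a minimal client sketch for reference only: it assumes the server is already running with its default settings (port 9998, 24 kHz output) and that raw float32 samples are streamed until the `END_OF_AUDIO` marker, as implemented in `generate_stream` / `handle_client`. The `out.wav` filename, the 16-bit PCM conversion, and the example text are illustrative choices, not files or behavior taken from the repository.

```python
import socket
import wave

import numpy as np

HOST = "127.0.0.1"            # assumes the server runs locally with its default port
PORT = 9998
SAMPLE_RATE = 24000           # matches self.sampling_rate in TTSStreamingProcessor
END_MARKER = b"END_OF_AUDIO"  # sentinel sent by handle_client after the last audio chunk


def request_tts(text: str, out_path: str = "out.wav") -> None:
    """Send one line of text and save the streamed float32 audio as a 16-bit WAV file."""
    buf = bytearray()
    with socket.create_connection((HOST, PORT)) as sock:
        sock.sendall(text.encode("utf-8"))
        while True:
            data = sock.recv(4096)
            if not data:
                break
            buf.extend(data)
            # the marker may arrive split across recv() calls, so test the accumulated buffer
            if buf.endswith(END_MARKER):
                buf = buf[: -len(END_MARKER)]
                break

    # the server packs raw float32 samples with struct.pack(f"{n}f", ...), native byte order
    samples = np.frombuffer(bytes(buf), dtype=np.float32)
    pcm16 = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

    with wave.open(out_path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 16-bit PCM
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(pcm16.tobytes())


if __name__ == "__main__":
    request_tts("Hello, this is a quick streaming test.")
```

Start the server first (e.g. `python f5_tts/socket_server.py` with its default arguments), then run this sketch in a separate process; the output path is arbitrary.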