├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ ├── feature_request.yml │ ├── help_wanted.yml │ └── question.yml └── workflows │ ├── pre-commit.yaml │ ├── publish-docker-image.yaml │ └── sync-hf.yaml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── ckpts └── README.md ├── data ├── Emilia_ZH_EN_pinyin │ └── vocab.txt └── librispeech_pc_test_clean_cross_sentence.lst ├── pyproject.toml ├── ruff.toml └── src └── f5_tts ├── api.py ├── eval ├── README.md ├── ecapa_tdnn.py ├── eval_infer_batch.py ├── eval_infer_batch.sh ├── eval_librispeech_test_clean.py ├── eval_seedtts_testset.py └── utils_eval.py ├── infer ├── README.md ├── SHARED.md ├── examples │ ├── basic │ │ ├── basic.toml │ │ ├── basic_ref_en.wav │ │ └── basic_ref_zh.wav │ ├── multi │ │ ├── country.flac │ │ ├── main.flac │ │ ├── story.toml │ │ ├── story.txt │ │ └── town.flac │ └── vocab.txt ├── infer_cli.py ├── infer_gradio.py ├── speech_edit.py └── utils_infer.py ├── model ├── __init__.py ├── backbones │ ├── README.md │ ├── dit.py │ ├── mmdit.py │ └── unett.py ├── cfm.py ├── dataset.py ├── modules.py ├── trainer.py └── utils.py ├── scripts ├── count_max_epoch.py └── count_params_gflops.py ├── socket_server.py └── train ├── README.md ├── datasets ├── prepare_csv_wavs.py ├── prepare_emilia.py └── prepare_wenetspeech4tts.py ├── finetune_cli.py ├── finetune_gradio.py └── train.py /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: "Bug Report" 2 | description: | 3 | Please provide as much details to help address the issue, including logs and screenshots. 4 | labels: 5 | - bug 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Checks 10 | description: "To ensure timely help, please confirm the following:" 11 | options: 12 | - label: This template is only for bug reports, usage problems go with 'Help Wanted'. 13 | required: true 14 | - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem. 15 | required: true 16 | - label: I have searched for existing issues, including closed ones, and couldn't find a solution. 17 | required: true 18 | - label: I confirm that I am using English to submit this report in order to facilitate communication. 19 | required: true 20 | - type: textarea 21 | attributes: 22 | label: Environment Details 23 | description: "Provide details such as OS, Python version, and any relevant software or dependencies." 24 | placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8 25 | validations: 26 | required: true 27 | - type: textarea 28 | attributes: 29 | label: Steps to Reproduce 30 | description: | 31 | Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks. 32 | placeholder: | 33 | 1. Create a new conda environment. 34 | 2. Clone the repository, install as local editable and properly set up. 35 | 3. Run the command: `accelerate launch src/f5_tts/train/train.py`. 36 | 4. Have following error message... (attach logs). 37 | validations: 38 | required: true 39 | - type: textarea 40 | attributes: 41 | label: ✔️ Expected Behavior 42 | placeholder: Describe what you expected to happen. 43 | validations: 44 | required: false 45 | - type: textarea 46 | attributes: 47 | label: ❌ Actual Behavior 48 | placeholder: Describe what actually happened. 
49 | validations: 50 | required: false -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: "Feature Request" 2 | description: | 3 | Constructive suggestions and new ideas regarding the current repo. 4 | labels: 5 | - enhancement 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Checks 10 | description: "To help us grasp quickly, please confirm the following:" 11 | options: 12 | - label: This template is only for feature requests. 13 | required: true 14 | - label: I have thoroughly reviewed the project documentation but couldn't find any relevant information that meets my needs. 15 | required: true 16 | - label: I have searched for existing issues, including closed ones, and found no existing discussion yet. 17 | required: true 18 | - label: I confirm that I am using English to submit this report in order to facilitate communication. 19 | required: true 20 | - type: textarea 21 | attributes: 22 | label: 1. Is this request related to a challenge you're experiencing? Tell us your story. 23 | description: | 24 | Describe the specific problem or scenario you're facing in detail. For example: 25 | *"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."* 26 | placeholder: Please describe the situation in as much detail as possible. 27 | validations: 28 | required: true 29 | 30 | - type: textarea 31 | attributes: 32 | label: 2. What is your suggested solution? 33 | description: | 34 | Provide a clear description of the feature or enhancement you'd like to propose. 35 | How would this feature solve your issue or improve the project? 36 | placeholder: Describe your idea or proposed solution here. 37 | validations: 38 | required: true 39 | 40 | - type: textarea 41 | attributes: 42 | label: 3. Additional context or comments 43 | description: | 44 | Any other relevant information, links, documents, or screenshots that provide clarity. 45 | Use this section for anything not covered above. 46 | placeholder: Add any extra details here. 47 | validations: 48 | required: false 49 | 50 | - type: checkboxes 51 | attributes: 52 | label: 4. Can you help us with this feature? 53 | description: | 54 | Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration. 55 | options: 56 | - label: I am interested in contributing to this feature. 57 | required: false 58 | 59 | - type: markdown 60 | attributes: 61 | value: | 62 | **Note:** Please submit only one request per issue to keep discussions focused and manageable. -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/help_wanted.yml: -------------------------------------------------------------------------------- 1 | name: "Help Wanted" 2 | description: | 3 | Please provide as many details as possible to help address the issue, including logs and screenshots. 4 | labels: 5 | - help wanted 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Checks 10 | description: "To ensure timely help, please confirm the following:" 11 | options: 12 | - label: This template is only for usage issues encountered.
13 | required: true 14 | - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem. 15 | required: true 16 | - label: I have searched for existing issues, including closed ones, and couldn't find a solution. 17 | required: true 18 | - label: I confirm that I am using English to submit this report in order to facilitate communication. 19 | required: true 20 | - type: textarea 21 | attributes: 22 | label: Environment Details 23 | description: "Provide details such as OS, Python version, and any relevant software or dependencies." 24 | placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1 25 | validations: 26 | required: true 27 | - type: textarea 28 | attributes: 29 | label: Steps to Reproduce 30 | description: | 31 | Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks. 32 | placeholder: | 33 | 1. Create a new conda environment. 34 | 2. Clone the repository and install it as a pip package. 35 | 3. Run the command: `f5-tts_infer-gradio` with no ref_text provided. 36 | 4. Get stuck with the following message... (attach logs and also the error message, e.g. after ctrl-c). 37 | validations: 38 | required: true 39 | - type: textarea 40 | attributes: 41 | label: ✔️ Expected Behavior 42 | placeholder: Describe what you expected to happen, e.g. a generated audio is output 43 | validations: 44 | required: false 45 | - type: textarea 46 | attributes: 47 | label: ❌ Actual Behavior 48 | placeholder: Describe what actually happened, failure messages, etc. 49 | validations: 50 | required: false -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.yml: -------------------------------------------------------------------------------- 1 | name: "Question" 2 | description: | 3 | A pure question or inquiry about the project; usage issues go with "Help Wanted". 4 | labels: 5 | - question 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Checks 10 | description: "To help us grasp quickly, please confirm the following:" 11 | options: 12 | - label: This template is only for questions, not feature requests or bug reports. 13 | required: true 14 | - label: I have thoroughly reviewed the project documentation and read the related paper(s). 15 | required: true 16 | - label: I have searched for existing issues, including closed ones, and found no similar questions. 17 | required: true 18 | - label: I confirm that I am using English to submit this report in order to facilitate communication. 19 | required: true 20 | - type: textarea 21 | attributes: 22 | label: Question details 23 | description: | 24 | Question details, clearly stated using proper markdown syntax.
25 | validations: 26 | required: true 27 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.1 15 | -------------------------------------------------------------------------------- /.github/workflows/publish-docker-image.yaml: -------------------------------------------------------------------------------- 1 | name: Create and publish a Docker image 2 | 3 | # Configures this workflow to run every time a change is pushed to the branch called `main`. 4 | on: 5 | push: 6 | branches: ['main'] 7 | 8 | # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds. 9 | env: 10 | REGISTRY: ghcr.io 11 | IMAGE_NAME: ${{ github.repository }} 12 | 13 | # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu. 14 | jobs: 15 | build-and-push-image: 16 | runs-on: ubuntu-latest 17 | # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. 18 | permissions: 19 | contents: read 20 | packages: write 21 | # 22 | steps: 23 | - name: Checkout repository 24 | uses: actions/checkout@v4 25 | - name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧 26 | uses: jlumbroso/free-disk-space@main 27 | with: 28 | # If set to "true", this might remove tools that are actually needed, but it frees about 6 GB 29 | tool-cache: false 30 | 31 | # All of these default to true, but feel free to set to "false" if necessary for your workflow 32 | android: true 33 | dotnet: true 34 | haskell: true 35 | large-packages: false 36 | swap-storage: false 37 | docker-images: false 38 | # Uses the `docker/login-action` action to log in to the Container registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. 39 | - name: Log in to the Container registry 40 | uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 41 | with: 42 | registry: ${{ env.REGISTRY }} 43 | username: ${{ github.actor }} 44 | password: ${{ secrets.GITHUB_TOKEN }} 45 | # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels. 46 | - name: Extract metadata (tags, labels) for Docker 47 | id: meta 48 | uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 49 | with: 50 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 51 | # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages. 52 | # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
53 | # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step. 54 | - name: Build and push Docker image 55 | uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 56 | with: 57 | context: . 58 | push: true 59 | tags: ${{ steps.meta.outputs.tags }} 60 | labels: ${{ steps.meta.outputs.labels }} 61 | -------------------------------------------------------------------------------- /.github/workflows/sync-hf.yaml: -------------------------------------------------------------------------------- 1 | name: Sync to HF Space 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | trigger_curl: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Send cURL POST request 14 | run: | 15 | curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \ 16 | -s \ 17 | -H "Content-Type: application/json" \ 18 | -d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}" 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Customed 2 | .vscode/ 3 | tests/ 4 | runs/ 5 | data/ 6 | ckpts/ 7 | wandb/ 8 | results/ 9 | 10 | 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | # .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 
112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 121 | .pdm.toml 122 | .pdm-python 123 | .pdm-build/ 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # PyCharm 169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 171 | # and can be added to the global gitignore or merged into this file. For a more nuclear 172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 173 | #.idea/ 174 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/third_party/BigVGAN"] 2 | path = src/third_party/BigVGAN 3 | url = https://github.com/NVIDIA/BigVGAN.git 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.7.0 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [--fix] 9 | # Run the formatter. 
10 | - id: ruff-format 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v2.3.0 13 | hooks: 14 | - id: check-yaml 15 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel 2 | 3 | USER root 4 | 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | 7 | LABEL github_repo="https://github.com/SWivid/F5-TTS" 8 | 9 | RUN set -x \ 10 | && apt-get update \ 11 | && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \ 12 | && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \ 13 | && rm -rf /var/lib/apt/lists/* \ 14 | && apt-get clean 15 | 16 | WORKDIR /workspace 17 | 18 | RUN git clone https://github.com/SWivid/F5-TTS.git \ 19 | && cd F5-TTS \ 20 | && pip install -e .[eval] 21 | 22 | ENV SHELL=/bin/bash 23 | 24 | WORKDIR /workspace/F5-TTS 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yushen CHEN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching 2 | 3 | [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885) 5 | [![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/) 6 | [![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS) 7 | [![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS) 8 | [![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/) 9 | 10 | 11 | **F5-TTS**: Diffusion Transformer with ConvNeXt V2, with faster training and inference. 12 | 13 | **E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
14 | 15 | **Sway Sampling**: Inference-time flow step sampling strategy that greatly improves performance. 16 | 17 | ### Thanks to all the contributors! 18 | 19 | ## News 20 | - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN). 21 | 22 | ## Installation 23 | 24 | ```bash 25 | # Create a Python 3.10 conda env (you could also use virtualenv) 26 | conda create -n f5-tts python=3.10 27 | conda activate f5-tts 28 | 29 | # Install PyTorch matching your CUDA version, e.g. 30 | pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 31 | ``` 32 | 33 | Then you can choose from a few options below: 34 | 35 | ### 1. As a pip package (if just for inference) 36 | 37 | ```bash 38 | pip install git+https://github.com/SWivid/F5-TTS.git 39 | ``` 40 | 41 | ### 2. Local editable (if also doing training or finetuning) 42 | 43 | ```bash 44 | git clone https://github.com/SWivid/F5-TTS.git 45 | cd F5-TTS 46 | # git submodule update --init --recursive # (optional, if you need bigvgan) 47 | pip install -e . 48 | ``` 49 | If you initialized the submodule, you should add the following code at the beginning of `src/third_party/BigVGAN/bigvgan.py`. 50 | ```python 51 | import os 52 | import sys 53 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 54 | ``` 55 | 56 | ### 3. Docker usage 57 | ```bash 58 | # Build from Dockerfile 59 | docker build -t f5tts:v1 . 60 | 61 | # Or pull from GitHub Container Registry 62 | docker pull ghcr.io/swivid/f5-tts:main 63 | ``` 64 | 65 | 66 | ## Inference 67 | 68 | ### 1. Gradio App 69 | 70 | Currently supported features: 71 | 72 | - Basic TTS with Chunk Inference 73 | - Multi-Style / Multi-Speaker Generation 74 | - Voice Chat powered by Qwen2.5-3B-Instruct 75 | - [Custom model](src/f5_tts/infer/SHARED.md) inference (local only) 76 | 77 | ```bash 78 | # Launch a Gradio app (web interface) 79 | f5-tts_infer-gradio 80 | 81 | # Specify the port/host 82 | f5-tts_infer-gradio --port 7860 --host 0.0.0.0 83 | 84 | # Launch a share link 85 | f5-tts_infer-gradio --share 86 | ``` 87 | 88 | ### 2. CLI Inference 89 | 90 | ```bash 91 | # Run with flags 92 | # Leaving --ref_text "" will have the ASR model transcribe the reference audio (extra GPU memory usage) 93 | f5-tts_infer-cli \ 94 | --model "F5-TTS" \ 95 | --ref_audio "ref_audio.wav" \ 96 | --ref_text "The content, subtitle or transcription of reference audio." \ 97 | --gen_text "Some text you want TTS model generate for you." 98 | 99 | # Run with default settings, i.e. src/f5_tts/infer/examples/basic/basic.toml 100 | f5-tts_infer-cli 101 | # Or with your own .toml file 102 | f5-tts_infer-cli -c custom.toml 103 | 104 | # Multi-voice. See src/f5_tts/infer/README.md 105 | f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml 106 | ``` 107 | 108 | ### 3. More instructions 109 | 110 | - For better generation results, take a moment to read the [detailed guidance](src/f5_tts/infer). 111 | - The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) page is very useful; please try to find a solution by searching for keywords related to the problem you encountered. If no answer is found, feel free to open an issue. 112 | 113 | 114 | ## Training 115 | 116 | ### 1. Gradio App 117 | 118 | Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
119 | 120 | ```bash 121 | # Quick start with Gradio web interface 122 | f5-tts_finetune-gradio 123 | ``` 124 | 125 | 126 | ## [Evaluation](src/f5_tts/eval) 127 | 128 | 129 | ## Development 130 | 131 | Use pre-commit to ensure code quality (will run linters and formatters automatically) 132 | 133 | ```bash 134 | pip install pre-commit 135 | pre-commit install 136 | ``` 137 | 138 | When making a pull request, before each commit, run: 139 | 140 | ```bash 141 | pre-commit run --all-files 142 | ``` 143 | 144 | Note: Some model components have linting exceptions for E722 to accommodate tensor notation 145 | 146 | 147 | ## Acknowledgements 148 | 149 | - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective 150 | - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets 151 | - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion 152 | - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure 153 | - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder 154 | - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools 155 | - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test 156 | - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~ 157 | - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman) 158 | - [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ) 159 | 160 | ## Citation 161 | If our work and codebase is useful for you, please cite as: 162 | ``` 163 | @article{chen-etal-2024-f5tts, 164 | title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching}, 165 | author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen}, 166 | journal={arXiv preprint arXiv:2410.06885}, 167 | year={2024}, 168 | } 169 | ``` 170 | ## License 171 | 172 | Our code is released under MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause. 173 | -------------------------------------------------------------------------------- /ckpts/README.md: -------------------------------------------------------------------------------- 1 | 2 | Pretrained model ckpts. 
https://huggingface.co/SWivid/F5-TTS 3 | 4 | ``` 5 | ckpts/ 6 | E2TTS_Base/ 7 | model_1200000.pt 8 | F5TTS_Base/ 9 | model_1200000.pt 10 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "f5-tts" 7 | version = "0.1.1" 8 | description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching" 9 | readme = "README.md" 10 | license = {text = "MIT License"} 11 | classifiers = [ 12 | "License :: OSI Approved :: MIT License", 13 | "Operating System :: OS Independent", 14 | "Programming Language :: Python :: 3", 15 | ] 16 | dependencies = [ 17 | "accelerate>=0.33.0", 18 | "bitsandbytes>0.37.0", 19 | "cached_path", 20 | "click", 21 | "datasets", 22 | "ema_pytorch>=0.5.2", 23 | "gradio>=3.45.2", 24 | "jieba", 25 | "librosa", 26 | "matplotlib", 27 | "numpy<=1.26.4", 28 | "pydub", 29 | "pypinyin", 30 | "safetensors", 31 | "soundfile", 32 | "tomli", 33 | "torch>=2.0.0", 34 | "torchaudio>=2.0.0", 35 | "torchdiffeq", 36 | "tqdm>=4.65.0", 37 | "transformers", 38 | "transformers_stream_generator", 39 | "vocos", 40 | "wandb", 41 | "x_transformers>=1.31.14", 42 | ] 43 | 44 | [project.optional-dependencies] 45 | eval = [ 46 | "faster_whisper==0.10.1", 47 | "funasr", 48 | "jiwer", 49 | "modelscope", 50 | "zhconv", 51 | "zhon", 52 | ] 53 | 54 | [project.urls] 55 | Homepage = "https://github.com/SWivid/F5-TTS" 56 | 57 | [project.scripts] 58 | "f5-tts_infer-cli" = "f5_tts.infer.infer_cli:main" 59 | "f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main" 60 | "f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main" 61 | "f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main" 62 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 120 2 | target-version = "py310" 3 | 4 | [lint] 5 | # Only ignore variables with names starting with "_". 
6 | dummy-variable-rgx = "^_.*$" 7 | 8 | [lint.isort] 9 | force-single-line = true 10 | lines-after-imports = 2 11 | -------------------------------------------------------------------------------- /src/f5_tts/api.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | from importlib.resources import files 4 | 5 | import soundfile as sf 6 | import torch 7 | import tqdm 8 | from cached_path import cached_path 9 | 10 | from f5_tts.infer.utils_infer import ( 11 | hop_length, 12 | infer_process, 13 | load_model, 14 | load_vocoder, 15 | preprocess_ref_audio_text, 16 | remove_silence_for_generated_wav, 17 | save_spectrogram, 18 | transcribe, 19 | target_sample_rate, 20 | ) 21 | from f5_tts.model import DiT, UNetT 22 | from f5_tts.model.utils import seed_everything 23 | 24 | 25 | class F5TTS: 26 | def __init__( 27 | self, 28 | model_type="F5-TTS", 29 | ckpt_file="", 30 | vocab_file="", 31 | ode_method="euler", 32 | use_ema=True, 33 | vocoder_name="vocos", 34 | local_path=None, 35 | device=None, 36 | hf_cache_dir=None, 37 | ): 38 | # Initialize parameters 39 | self.final_wave = None 40 | self.target_sample_rate = target_sample_rate 41 | self.hop_length = hop_length 42 | self.seed = -1 43 | self.mel_spec_type = vocoder_name 44 | 45 | # Set device 46 | self.device = device or ( 47 | "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 48 | ) 49 | 50 | # Load models 51 | self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir) 52 | self.load_ema_model( 53 | model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir 54 | ) 55 | 56 | def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None): 57 | self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir) 58 | 59 | def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None): 60 | if model_type == "F5-TTS": 61 | if not ckpt_file: 62 | if mel_spec_type == "vocos": 63 | ckpt_file = str( 64 | cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) 65 | ) 66 | elif mel_spec_type == "bigvgan": 67 | ckpt_file = str( 68 | cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir) 69 | ) 70 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 71 | model_cls = DiT 72 | elif model_type == "E2-TTS": 73 | if not ckpt_file: 74 | ckpt_file = str( 75 | cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) 76 | ) 77 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 78 | model_cls = UNetT 79 | else: 80 | raise ValueError(f"Unknown model type: {model_type}") 81 | 82 | self.ema_model = load_model( 83 | model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device 84 | ) 85 | 86 | def transcribe(self, ref_audio, language=None): 87 | return transcribe(ref_audio, language) 88 | 89 | def export_wav(self, wav, file_wave, remove_silence=False): 90 | sf.write(file_wave, wav, self.target_sample_rate) 91 | 92 | if remove_silence: 93 | remove_silence_for_generated_wav(file_wave) 94 | 95 | def export_spectrogram(self, spect, file_spect): 96 | save_spectrogram(spect, file_spect) 97 | 98 | def infer( 99 | self, 100 | ref_file, 101 | ref_text, 102 | gen_text, 103 | show_info=print, 104 | progress=tqdm, 105 
| target_rms=0.1, 106 | cross_fade_duration=0.15, 107 | sway_sampling_coef=-1, 108 | cfg_strength=2, 109 | nfe_step=32, 110 | speed=1.0, 111 | fix_duration=None, 112 | remove_silence=False, 113 | file_wave=None, 114 | file_spect=None, 115 | seed=-1, 116 | ): 117 | if seed == -1: 118 | seed = random.randint(0, sys.maxsize) 119 | seed_everything(seed) 120 | self.seed = seed 121 | 122 | ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device) 123 | 124 | wav, sr, spect = infer_process( 125 | ref_file, 126 | ref_text, 127 | gen_text, 128 | self.ema_model, 129 | self.vocoder, 130 | self.mel_spec_type, 131 | show_info=show_info, 132 | progress=progress, 133 | target_rms=target_rms, 134 | cross_fade_duration=cross_fade_duration, 135 | nfe_step=nfe_step, 136 | cfg_strength=cfg_strength, 137 | sway_sampling_coef=sway_sampling_coef, 138 | speed=speed, 139 | fix_duration=fix_duration, 140 | device=self.device, 141 | ) 142 | 143 | if file_wave is not None: 144 | self.export_wav(wav, file_wave, remove_silence) 145 | 146 | if file_spect is not None: 147 | self.export_spectrogram(spect, file_spect) 148 | 149 | return wav, sr, spect 150 | 151 | 152 | if __name__ == "__main__": 153 | f5tts = F5TTS() 154 | 155 | wav, sr, spect = f5tts.infer( 156 | ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), 157 | ref_text="some call me nature, others call me mother nature.", 158 | gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""", 159 | file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")), 160 | file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")), 161 | seed=-1, # random seed = -1 162 | ) 163 | 164 | print("seed :", f5tts.seed) 165 | -------------------------------------------------------------------------------- /src/f5_tts/eval/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Evaluation 3 | 4 | Install packages for evaluation: 5 | 6 | ```bash 7 | pip install -e .[eval] 8 | ``` 9 | 10 | ## Generating Samples for Evaluation 11 | 12 | ### Prepare Test Datasets 13 | 14 | 1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval). 15 | 2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/). 16 | 3. Unzip the downloaded datasets and place them in the `data/` directory. 17 | 4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py` 18 | 5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst` 19 | 20 | ### Batch Inference for Test Set 21 | 22 | To run batch inference for evaluations, execute the following commands: 23 | 24 | ```bash 25 | # batch inference for evaluations 26 | accelerate config # if not set before 27 | bash src/f5_tts/eval/eval_infer_batch.sh 28 | ``` 29 | 30 | ## Objective Evaluation on Generated Results 31 | 32 | ### Download Evaluation Model Checkpoints 33 | 34 | 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh) 35 | 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3) 36 | 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view). 
37 | 38 | Then update in the following scripts with the paths you put evaluation model ckpts to. 39 | 40 | ### Objective Evaluation 41 | 42 | Update the path with your batch-inferenced results, and carry out WER / SIM evaluations: 43 | ```bash 44 | # Evaluation for Seed-TTS test set 45 | python src/f5_tts/eval/eval_seedtts_testset.py 46 | 47 | # Evaluation for LibriSpeech-PC test-clean (cross-sentence) 48 | python src/f5_tts/eval/eval_librispeech_test_clean.py 49 | ``` -------------------------------------------------------------------------------- /src/f5_tts/eval/ecapa_tdnn.py: -------------------------------------------------------------------------------- 1 | # just for speaker similarity evaluation, third-party code 2 | 3 | # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ 4 | # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN 5 | 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | """ Res2Conv1d + BatchNorm1d + ReLU 13 | """ 14 | 15 | 16 | class Res2Conv1dReluBn(nn.Module): 17 | """ 18 | in_channels == out_channels == channels 19 | """ 20 | 21 | def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4): 22 | super().__init__() 23 | assert channels % scale == 0, "{} % {} != 0".format(channels, scale) 24 | self.scale = scale 25 | self.width = channels // scale 26 | self.nums = scale if scale == 1 else scale - 1 27 | 28 | self.convs = [] 29 | self.bns = [] 30 | for i in range(self.nums): 31 | self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias)) 32 | self.bns.append(nn.BatchNorm1d(self.width)) 33 | self.convs = nn.ModuleList(self.convs) 34 | self.bns = nn.ModuleList(self.bns) 35 | 36 | def forward(self, x): 37 | out = [] 38 | spx = torch.split(x, self.width, 1) 39 | for i in range(self.nums): 40 | if i == 0: 41 | sp = spx[i] 42 | else: 43 | sp = sp + spx[i] 44 | # Order: conv -> relu -> bn 45 | sp = self.convs[i](sp) 46 | sp = self.bns[i](F.relu(sp)) 47 | out.append(sp) 48 | if self.scale != 1: 49 | out.append(spx[self.nums]) 50 | out = torch.cat(out, dim=1) 51 | 52 | return out 53 | 54 | 55 | """ Conv1d + BatchNorm1d + ReLU 56 | """ 57 | 58 | 59 | class Conv1dReluBn(nn.Module): 60 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True): 61 | super().__init__() 62 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) 63 | self.bn = nn.BatchNorm1d(out_channels) 64 | 65 | def forward(self, x): 66 | return self.bn(F.relu(self.conv(x))) 67 | 68 | 69 | """ The SE connection of 1D case. 70 | """ 71 | 72 | 73 | class SE_Connect(nn.Module): 74 | def __init__(self, channels, se_bottleneck_dim=128): 75 | super().__init__() 76 | self.linear1 = nn.Linear(channels, se_bottleneck_dim) 77 | self.linear2 = nn.Linear(se_bottleneck_dim, channels) 78 | 79 | def forward(self, x): 80 | out = x.mean(dim=2) 81 | out = F.relu(self.linear1(out)) 82 | out = torch.sigmoid(self.linear2(out)) 83 | out = x * out.unsqueeze(2) 84 | 85 | return out 86 | 87 | 88 | """ SE-Res2Block of the ECAPA-TDNN architecture. 
89 | """ 90 | 91 | # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): 92 | # return nn.Sequential( 93 | # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0), 94 | # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale), 95 | # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0), 96 | # SE_Connect(channels) 97 | # ) 98 | 99 | 100 | class SE_Res2Block(nn.Module): 101 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim): 102 | super().__init__() 103 | self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 104 | self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale) 105 | self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0) 106 | self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) 107 | 108 | self.shortcut = None 109 | if in_channels != out_channels: 110 | self.shortcut = nn.Conv1d( 111 | in_channels=in_channels, 112 | out_channels=out_channels, 113 | kernel_size=1, 114 | ) 115 | 116 | def forward(self, x): 117 | residual = x 118 | if self.shortcut: 119 | residual = self.shortcut(x) 120 | 121 | x = self.Conv1dReluBn1(x) 122 | x = self.Res2Conv1dReluBn(x) 123 | x = self.Conv1dReluBn2(x) 124 | x = self.SE_Connect(x) 125 | 126 | return x + residual 127 | 128 | 129 | """ Attentive weighted mean and standard deviation pooling. 130 | """ 131 | 132 | 133 | class AttentiveStatsPool(nn.Module): 134 | def __init__(self, in_dim, attention_channels=128, global_context_att=False): 135 | super().__init__() 136 | self.global_context_att = global_context_att 137 | 138 | # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. 139 | if global_context_att: 140 | self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper 141 | else: 142 | self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper 143 | self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper 144 | 145 | def forward(self, x): 146 | if self.global_context_att: 147 | context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) 148 | context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) 149 | x_in = torch.cat((x, context_mean, context_std), dim=1) 150 | else: 151 | x_in = x 152 | 153 | # DON'T use ReLU here! In experiments, I find ReLU hard to converge. 
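# Attentive statistics pooling: linear1/linear2 act as a small attention head (Conv1d -> tanh -> Conv1d -> softmax over the time axis); the resulting frame weights alpha are used below to compute an attention-weighted mean and standard deviation, concatenated channel-wise.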
154 | alpha = torch.tanh(self.linear1(x_in)) 155 | # alpha = F.relu(self.linear1(x_in)) 156 | alpha = torch.softmax(self.linear2(alpha), dim=2) 157 | mean = torch.sum(alpha * x, dim=2) 158 | residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 159 | std = torch.sqrt(residuals.clamp(min=1e-9)) 160 | return torch.cat([mean, std], dim=1) 161 | 162 | 163 | class ECAPA_TDNN(nn.Module): 164 | def __init__( 165 | self, 166 | feat_dim=80, 167 | channels=512, 168 | emb_dim=192, 169 | global_context_att=False, 170 | feat_type="wavlm_large", 171 | sr=16000, 172 | feature_selection="hidden_states", 173 | update_extract=False, 174 | config_path=None, 175 | ): 176 | super().__init__() 177 | 178 | self.feat_type = feat_type 179 | self.feature_selection = feature_selection 180 | self.update_extract = update_extract 181 | self.sr = sr 182 | 183 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True 184 | try: 185 | local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main") 186 | self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path) 187 | except: # noqa: E722 188 | self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type) 189 | 190 | if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( 191 | self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention" 192 | ): 193 | self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False 194 | if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( 195 | self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention" 196 | ): 197 | self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False 198 | 199 | self.feat_num = self.get_feat_num() 200 | self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) 201 | 202 | if feat_type != "fbank" and feat_type != "mfcc": 203 | freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"] 204 | for name, param in self.feature_extract.named_parameters(): 205 | for freeze_val in freeze_list: 206 | if freeze_val in name: 207 | param.requires_grad = False 208 | break 209 | 210 | if not self.update_extract: 211 | for param in self.feature_extract.parameters(): 212 | param.requires_grad = False 213 | 214 | self.instance_norm = nn.InstanceNorm1d(feat_dim) 215 | # self.channels = [channels] * 4 + [channels * 3] 216 | self.channels = [channels] * 4 + [1536] 217 | 218 | self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2) 219 | self.layer2 = SE_Res2Block( 220 | self.channels[0], 221 | self.channels[1], 222 | kernel_size=3, 223 | stride=1, 224 | padding=2, 225 | dilation=2, 226 | scale=8, 227 | se_bottleneck_dim=128, 228 | ) 229 | self.layer3 = SE_Res2Block( 230 | self.channels[1], 231 | self.channels[2], 232 | kernel_size=3, 233 | stride=1, 234 | padding=3, 235 | dilation=3, 236 | scale=8, 237 | se_bottleneck_dim=128, 238 | ) 239 | self.layer4 = SE_Res2Block( 240 | self.channels[2], 241 | self.channels[3], 242 | kernel_size=3, 243 | stride=1, 244 | padding=4, 245 | dilation=4, 246 | scale=8, 247 | se_bottleneck_dim=128, 248 | ) 249 | 250 | # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) 251 | cat_channels = channels * 3 252 | self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) 253 | self.pooling = AttentiveStatsPool( 254 | self.channels[-1], attention_channels=128, global_context_att=global_context_att 255 | ) 256 | self.bn = nn.BatchNorm1d(self.channels[-1] * 2) 
257 | self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) 258 | 259 | def get_feat_num(self): 260 | self.feature_extract.eval() 261 | wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)] 262 | with torch.no_grad(): 263 | features = self.feature_extract(wav) 264 | select_feature = features[self.feature_selection] 265 | if isinstance(select_feature, (list, tuple)): 266 | return len(select_feature) 267 | else: 268 | return 1 269 | 270 | def get_feat(self, x): 271 | if self.update_extract: 272 | x = self.feature_extract([sample for sample in x]) 273 | else: 274 | with torch.no_grad(): 275 | if self.feat_type == "fbank" or self.feat_type == "mfcc": 276 | x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len 277 | else: 278 | x = self.feature_extract([sample for sample in x]) 279 | 280 | if self.feat_type == "fbank": 281 | x = x.log() 282 | 283 | if self.feat_type != "fbank" and self.feat_type != "mfcc": 284 | x = x[self.feature_selection] 285 | if isinstance(x, (list, tuple)): 286 | x = torch.stack(x, dim=0) 287 | else: 288 | x = x.unsqueeze(0) 289 | norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 290 | x = (norm_weights * x).sum(dim=0) 291 | x = torch.transpose(x, 1, 2) + 1e-6 292 | 293 | x = self.instance_norm(x) 294 | return x 295 | 296 | def forward(self, x): 297 | x = self.get_feat(x) 298 | 299 | out1 = self.layer1(x) 300 | out2 = self.layer2(out1) 301 | out3 = self.layer3(out2) 302 | out4 = self.layer4(out3) 303 | 304 | out = torch.cat([out2, out3, out4], dim=1) 305 | out = F.relu(self.conv(out)) 306 | out = self.bn(self.pooling(out)) 307 | out = self.linear(out) 308 | 309 | return out 310 | 311 | 312 | def ECAPA_TDNN_SMALL( 313 | feat_dim, 314 | emb_dim=256, 315 | feat_type="wavlm_large", 316 | sr=16000, 317 | feature_selection="hidden_states", 318 | update_extract=False, 319 | config_path=None, 320 | ): 321 | return ECAPA_TDNN( 322 | feat_dim=feat_dim, 323 | channels=512, 324 | emb_dim=emb_dim, 325 | feat_type=feat_type, 326 | sr=sr, 327 | feature_selection=feature_selection, 328 | update_extract=update_extract, 329 | config_path=config_path, 330 | ) 331 | -------------------------------------------------------------------------------- /src/f5_tts/eval/eval_infer_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import argparse 7 | import time 8 | from importlib.resources import files 9 | 10 | import torch 11 | import torchaudio 12 | from accelerate import Accelerator 13 | from tqdm import tqdm 14 | 15 | from f5_tts.eval.utils_eval import ( 16 | get_inference_prompt, 17 | get_librispeech_test_clean_metainfo, 18 | get_seedtts_testset_metainfo, 19 | ) 20 | from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder 21 | from f5_tts.model import CFM, DiT, UNetT 22 | from f5_tts.model.utils import get_tokenizer 23 | 24 | accelerator = Accelerator() 25 | device = f"cuda:{accelerator.process_index}" 26 | 27 | 28 | # --------------------- Dataset Settings -------------------- # 29 | 30 | target_sample_rate = 24000 31 | n_mel_channels = 100 32 | hop_length = 256 33 | win_length = 1024 34 | n_fft = 1024 35 | target_rms = 0.1 36 | 37 | 38 | tokenizer = "pinyin" 39 | rel_path = str(files("f5_tts").joinpath("../../")) 40 | 41 | 42 | def main(): 43 | # ---------------------- infer setting ---------------------- # 44 | 45 | parser = argparse.ArgumentParser(description="batch inference") 46 | 47 | 
parser.add_argument("-s", "--seed", default=None, type=int) 48 | parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN") 49 | parser.add_argument("-n", "--expname", required=True) 50 | parser.add_argument("-c", "--ckptstep", default=1200000, type=int) 51 | parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"]) 52 | 53 | parser.add_argument("-nfe", "--nfestep", default=32, type=int) 54 | parser.add_argument("-o", "--odemethod", default="euler") 55 | parser.add_argument("-ss", "--swaysampling", default=-1, type=float) 56 | 57 | parser.add_argument("-t", "--testset", required=True) 58 | 59 | args = parser.parse_args() 60 | 61 | seed = args.seed 62 | dataset_name = args.dataset 63 | exp_name = args.expname 64 | ckpt_step = args.ckptstep 65 | ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt" 66 | mel_spec_type = args.mel_spec_type 67 | 68 | nfe_step = args.nfestep 69 | ode_method = args.odemethod 70 | sway_sampling_coef = args.swaysampling 71 | 72 | testset = args.testset 73 | 74 | infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended) 75 | cfg_strength = 2.0 76 | speed = 1.0 77 | use_truth_duration = False 78 | no_ref_audio = False 79 | 80 | if exp_name == "F5TTS_Base": 81 | model_cls = DiT 82 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 83 | 84 | elif exp_name == "E2TTS_Base": 85 | model_cls = UNetT 86 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 87 | 88 | if testset == "ls_pc_test_clean": 89 | metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" 90 | librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path 91 | metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path) 92 | 93 | elif testset == "seedtts_test_zh": 94 | metalst = rel_path + "/data/seedtts_testset/zh/meta.lst" 95 | metainfo = get_seedtts_testset_metainfo(metalst) 96 | 97 | elif testset == "seedtts_test_en": 98 | metalst = rel_path + "/data/seedtts_testset/en/meta.lst" 99 | metainfo = get_seedtts_testset_metainfo(metalst) 100 | 101 | # path to save genereted wavs 102 | output_dir = ( 103 | f"{rel_path}/" 104 | f"results/{exp_name}_{ckpt_step}/{testset}/" 105 | f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}" 106 | f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" 107 | f"_cfg{cfg_strength}_speed{speed}" 108 | f"{'_gt-dur' if use_truth_duration else ''}" 109 | f"{'_no-ref-audio' if no_ref_audio else ''}" 110 | ) 111 | 112 | # -------------------------------------------------# 113 | 114 | use_ema = True 115 | 116 | prompts_all = get_inference_prompt( 117 | metainfo, 118 | speed=speed, 119 | tokenizer=tokenizer, 120 | target_sample_rate=target_sample_rate, 121 | n_mel_channels=n_mel_channels, 122 | hop_length=hop_length, 123 | mel_spec_type=mel_spec_type, 124 | target_rms=target_rms, 125 | use_truth_duration=use_truth_duration, 126 | infer_batch_size=infer_batch_size, 127 | ) 128 | 129 | # Vocoder model 130 | local = False 131 | if mel_spec_type == "vocos": 132 | vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" 133 | elif mel_spec_type == "bigvgan": 134 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" 135 | vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path) 136 | 137 | # Tokenizer 138 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) 139 | 140 | # Model 141 | model = CFM( 142 | 
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), 143 | mel_spec_kwargs=dict( 144 | n_fft=n_fft, 145 | hop_length=hop_length, 146 | win_length=win_length, 147 | n_mel_channels=n_mel_channels, 148 | target_sample_rate=target_sample_rate, 149 | mel_spec_type=mel_spec_type, 150 | ), 151 | odeint_kwargs=dict( 152 | method=ode_method, 153 | ), 154 | vocab_char_map=vocab_char_map, 155 | ).to(device) 156 | 157 | dtype = torch.float32 if mel_spec_type == "bigvgan" else None 158 | model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) 159 | 160 | if not os.path.exists(output_dir) and accelerator.is_main_process: 161 | os.makedirs(output_dir) 162 | 163 | # start batch inference 164 | accelerator.wait_for_everyone() 165 | start = time.time() 166 | 167 | with accelerator.split_between_processes(prompts_all) as prompts: 168 | for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): 169 | utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt 170 | ref_mels = ref_mels.to(device) 171 | ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) 172 | total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) 173 | 174 | # Inference 175 | with torch.inference_mode(): 176 | generated, _ = model.sample( 177 | cond=ref_mels, 178 | text=final_text_list, 179 | duration=total_mel_lens, 180 | lens=ref_mel_lens, 181 | steps=nfe_step, 182 | cfg_strength=cfg_strength, 183 | sway_sampling_coef=sway_sampling_coef, 184 | no_ref_audio=no_ref_audio, 185 | seed=seed, 186 | ) 187 | # Final result 188 | for i, gen in enumerate(generated): 189 | gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) 190 | gen_mel_spec = gen.permute(0, 2, 1) 191 | if mel_spec_type == "vocos": 192 | generated_wave = vocoder.decode(gen_mel_spec) 193 | elif mel_spec_type == "bigvgan": 194 | generated_wave = vocoder(gen_mel_spec) 195 | 196 | if ref_rms_list[i] < target_rms: 197 | generated_wave = generated_wave * ref_rms_list[i] / target_rms 198 | torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate) 199 | 200 | accelerator.wait_for_everyone() 201 | if accelerator.is_main_process: 202 | timediff = time.time() - start 203 | print(f"Done batch inference in {timediff / 60 :.2f} minutes.") 204 | 205 | 206 | if __name__ == "__main__": 207 | main() 208 | -------------------------------------------------------------------------------- /src/f5_tts/eval/eval_infer_batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # e.g. F5-TTS, 16 NFE 4 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 5 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 6 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 7 | 8 | # e.g. Vanilla E2 TTS, 32 NFE 9 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 10 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 11 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 12 | 13 | # etc. 
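# e.g. F5-TTS with the BigVGAN vocoder, 32 NFE (illustrative only: flag names follow eval_infer_batch.py's argparse; -c 1250000 assumes a BigVGAN-matched checkpoint placed at ckpts/F5TTS_Base/model_1250000.pt)
# accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 32 -m "bigvgan" -c 1250000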
14 | -------------------------------------------------------------------------------- /src/f5_tts/eval/eval_librispeech_test_clean.py: -------------------------------------------------------------------------------- 1 | # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation) 2 | 3 | import sys 4 | import os 5 | 6 | sys.path.append(os.getcwd()) 7 | 8 | import multiprocessing as mp 9 | from importlib.resources import files 10 | 11 | import numpy as np 12 | 13 | from f5_tts.eval.utils_eval import ( 14 | get_librispeech_test, 15 | run_asr_wer, 16 | run_sim, 17 | ) 18 | 19 | rel_path = str(files("f5_tts").joinpath("../../")) 20 | 21 | 22 | eval_task = "wer" # sim | wer 23 | lang = "en" 24 | metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" 25 | librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path 26 | gen_wav_dir = "PATH_TO_GENERATED" # generated wavs 27 | 28 | gpus = [0, 1, 2, 3, 4, 5, 6, 7] 29 | test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path) 30 | 31 | ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book, 32 | ## leading to a low similarity for the ground truth in some cases. 33 | # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth 34 | 35 | local = False 36 | if local: # use local custom checkpoint dir 37 | asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" 38 | else: 39 | asr_ckpt_dir = "" # auto download to cache dir 40 | 41 | wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" 42 | 43 | 44 | # --------------------------- WER --------------------------- 45 | 46 | if eval_task == "wer": 47 | wers = [] 48 | 49 | with mp.Pool(processes=len(gpus)) as pool: 50 | args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] 51 | results = pool.map(run_asr_wer, args) 52 | for wers_ in results: 53 | wers.extend(wers_) 54 | 55 | wer = round(np.mean(wers) * 100, 3) 56 | print(f"\nTotal {len(wers)} samples") 57 | print(f"WER : {wer}%") 58 | 59 | 60 | # --------------------------- SIM --------------------------- 61 | 62 | if eval_task == "sim": 63 | sim_list = [] 64 | 65 | with mp.Pool(processes=len(gpus)) as pool: 66 | args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] 67 | results = pool.map(run_sim, args) 68 | for sim_ in results: 69 | sim_list.extend(sim_) 70 | 71 | sim = round(sum(sim_list) / len(sim_list), 3) 72 | print(f"\nTotal {len(sim_list)} samples") 73 | print(f"SIM : {sim}") 74 | -------------------------------------------------------------------------------- /src/f5_tts/eval/eval_seedtts_testset.py: -------------------------------------------------------------------------------- 1 | # Evaluate with Seed-TTS testset 2 | 3 | import sys 4 | import os 5 | 6 | sys.path.append(os.getcwd()) 7 | 8 | import multiprocessing as mp 9 | from importlib.resources import files 10 | 11 | import numpy as np 12 | 13 | from f5_tts.eval.utils_eval import ( 14 | get_seed_tts_test, 15 | run_asr_wer, 16 | run_sim, 17 | ) 18 | 19 | rel_path = str(files("f5_tts").joinpath("../../")) 20 | 21 | 22 | eval_task = "wer" # sim | wer 23 | lang = "zh" # zh | en 24 | metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset 25 | # gen_wav_dir = rel_path + f"/data/seedtts_testset/{lang}/wavs" # ground truth wavs 26 | gen_wav_dir = 
"PATH_TO_GENERATED" # generated wavs 27 | 28 | 29 | # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different 30 | # zh 1.254 seems a result of 4 workers wer_seed_tts 31 | gpus = [0, 1, 2, 3, 4, 5, 6, 7] 32 | test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus) 33 | 34 | local = False 35 | if local: # use local custom checkpoint dir 36 | if lang == "zh": 37 | asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr 38 | elif lang == "en": 39 | asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" 40 | else: 41 | asr_ckpt_dir = "" # auto download to cache dir 42 | 43 | wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" 44 | 45 | 46 | # --------------------------- WER --------------------------- 47 | 48 | if eval_task == "wer": 49 | wers = [] 50 | 51 | with mp.Pool(processes=len(gpus)) as pool: 52 | args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] 53 | results = pool.map(run_asr_wer, args) 54 | for wers_ in results: 55 | wers.extend(wers_) 56 | 57 | wer = round(np.mean(wers) * 100, 3) 58 | print(f"\nTotal {len(wers)} samples") 59 | print(f"WER : {wer}%") 60 | 61 | 62 | # --------------------------- SIM --------------------------- 63 | 64 | if eval_task == "sim": 65 | sim_list = [] 66 | 67 | with mp.Pool(processes=len(gpus)) as pool: 68 | args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] 69 | results = pool.map(run_sim, args) 70 | for sim_ in results: 71 | sim_list.extend(sim_) 72 | 73 | sim = round(sum(sim_list) / len(sim_list), 3) 74 | print(f"\nTotal {len(sim_list)} samples") 75 | print(f"SIM : {sim}") 76 | -------------------------------------------------------------------------------- /src/f5_tts/eval/utils_eval.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import string 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torchaudio 9 | from tqdm import tqdm 10 | 11 | from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL 12 | from f5_tts.model.modules import MelSpec 13 | from f5_tts.model.utils import convert_char_to_pinyin 14 | 15 | 16 | # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav 17 | def get_seedtts_testset_metainfo(metalst): 18 | f = open(metalst) 19 | lines = f.readlines() 20 | f.close() 21 | metainfo = [] 22 | for line in lines: 23 | if len(line.strip().split("|")) == 5: 24 | utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") 25 | elif len(line.strip().split("|")) == 4: 26 | utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") 27 | gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") 28 | if not os.path.isabs(prompt_wav): 29 | prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) 30 | metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) 31 | return metainfo 32 | 33 | 34 | # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav 35 | def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path): 36 | f = open(metalst) 37 | lines = f.readlines() 38 | f.close() 39 | metainfo = [] 40 | for line in lines: 41 | ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t") 42 | 43 | # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' 
# if use librispeech test-clean (no-pc) 44 | ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") 45 | ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") 46 | 47 | # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) 48 | gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") 49 | gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") 50 | 51 | metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav)) 52 | 53 | return metainfo 54 | 55 | 56 | # padded to max length mel batch 57 | def padded_mel_batch(ref_mels): 58 | max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() 59 | padded_ref_mels = [] 60 | for mel in ref_mels: 61 | padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) 62 | padded_ref_mels.append(padded_ref_mel) 63 | padded_ref_mels = torch.stack(padded_ref_mels) 64 | padded_ref_mels = padded_ref_mels.permute(0, 2, 1) 65 | return padded_ref_mels 66 | 67 | 68 | # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav 69 | 70 | 71 | def get_inference_prompt( 72 | metainfo, 73 | speed=1.0, 74 | tokenizer="pinyin", 75 | polyphone=True, 76 | target_sample_rate=24000, 77 | n_fft=1024, 78 | win_length=1024, 79 | n_mel_channels=100, 80 | hop_length=256, 81 | mel_spec_type="vocos", 82 | target_rms=0.1, 83 | use_truth_duration=False, 84 | infer_batch_size=1, 85 | num_buckets=200, 86 | min_secs=3, 87 | max_secs=40, 88 | ): 89 | prompts_all = [] 90 | 91 | min_tokens = min_secs * target_sample_rate // hop_length 92 | max_tokens = max_secs * target_sample_rate // hop_length 93 | 94 | batch_accum = [0] * num_buckets 95 | utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( 96 | [[] for _ in range(num_buckets)] for _ in range(6) 97 | ) 98 | 99 | mel_spectrogram = MelSpec( 100 | n_fft=n_fft, 101 | hop_length=hop_length, 102 | win_length=win_length, 103 | n_mel_channels=n_mel_channels, 104 | target_sample_rate=target_sample_rate, 105 | mel_spec_type=mel_spec_type, 106 | ) 107 | 108 | for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."): 109 | # Audio 110 | ref_audio, ref_sr = torchaudio.load(prompt_wav) 111 | ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) 112 | if ref_rms < target_rms: 113 | ref_audio = ref_audio * target_rms / ref_rms 114 | assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." 
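        # resample the reference audio to the target sample rate expected by the mel spectrogram frontend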
115 | if ref_sr != target_sample_rate: 116 | resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) 117 | ref_audio = resampler(ref_audio) 118 | 119 | # Text 120 | if len(prompt_text[-1].encode("utf-8")) == 1: 121 | prompt_text = prompt_text + " " 122 | text = [prompt_text + gt_text] 123 | if tokenizer == "pinyin": 124 | text_list = convert_char_to_pinyin(text, polyphone=polyphone) 125 | else: 126 | text_list = text 127 | 128 | # Duration, mel frame length 129 | ref_mel_len = ref_audio.shape[-1] // hop_length 130 | if use_truth_duration: 131 | gt_audio, gt_sr = torchaudio.load(gt_wav) 132 | if gt_sr != target_sample_rate: 133 | resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) 134 | gt_audio = resampler(gt_audio) 135 | total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) 136 | 137 | # # test vocoder resynthesis 138 | # ref_audio = gt_audio 139 | else: 140 | ref_text_len = len(prompt_text.encode("utf-8")) 141 | gen_text_len = len(gt_text.encode("utf-8")) 142 | total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed) 143 | 144 | # to mel spectrogram 145 | ref_mel = mel_spectrogram(ref_audio) 146 | ref_mel = ref_mel.squeeze(0) 147 | 148 | # deal with batch 149 | assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 150 | assert ( 151 | min_tokens <= total_mel_len <= max_tokens 152 | ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." 153 | bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets) 154 | 155 | utts[bucket_i].append(utt) 156 | ref_rms_list[bucket_i].append(ref_rms) 157 | ref_mels[bucket_i].append(ref_mel) 158 | ref_mel_lens[bucket_i].append(ref_mel_len) 159 | total_mel_lens[bucket_i].append(total_mel_len) 160 | final_text_list[bucket_i].extend(text_list) 161 | 162 | batch_accum[bucket_i] += total_mel_len 163 | 164 | if batch_accum[bucket_i] >= infer_batch_size: 165 | # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}") 166 | prompts_all.append( 167 | ( 168 | utts[bucket_i], 169 | ref_rms_list[bucket_i], 170 | padded_mel_batch(ref_mels[bucket_i]), 171 | ref_mel_lens[bucket_i], 172 | total_mel_lens[bucket_i], 173 | final_text_list[bucket_i], 174 | ) 175 | ) 176 | batch_accum[bucket_i] = 0 177 | ( 178 | utts[bucket_i], 179 | ref_rms_list[bucket_i], 180 | ref_mels[bucket_i], 181 | ref_mel_lens[bucket_i], 182 | total_mel_lens[bucket_i], 183 | final_text_list[bucket_i], 184 | ) = [], [], [], [], [], [] 185 | 186 | # add residual 187 | for bucket_i, bucket_frames in enumerate(batch_accum): 188 | if bucket_frames > 0: 189 | prompts_all.append( 190 | ( 191 | utts[bucket_i], 192 | ref_rms_list[bucket_i], 193 | padded_mel_batch(ref_mels[bucket_i]), 194 | ref_mel_lens[bucket_i], 195 | total_mel_lens[bucket_i], 196 | final_text_list[bucket_i], 197 | ) 198 | ) 199 | # not only leave easy work for last workers 200 | random.seed(666) 201 | random.shuffle(prompts_all) 202 | 203 | return prompts_all 204 | 205 | 206 | # get wav_res_ref_text of seed-tts test metalst 207 | # https://github.com/BytedanceSpeech/seed-tts-eval 208 | 209 | 210 | def get_seed_tts_test(metalst, gen_wav_dir, gpus): 211 | f = open(metalst) 212 | lines = f.readlines() 213 | f.close() 214 | 215 | test_set_ = [] 216 | for line in tqdm(lines): 217 | if len(line.strip().split("|")) == 5: 218 | utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") 219 | elif 
len(line.strip().split("|")) == 4: 220 | utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") 221 | 222 | if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")): 223 | continue 224 | gen_wav = os.path.join(gen_wav_dir, utt + ".wav") 225 | if not os.path.isabs(prompt_wav): 226 | prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) 227 | 228 | test_set_.append((gen_wav, prompt_wav, gt_text)) 229 | 230 | num_jobs = len(gpus) 231 | if num_jobs == 1: 232 | return [(gpus[0], test_set_)] 233 | 234 | wav_per_job = len(test_set_) // num_jobs + 1 235 | test_set = [] 236 | for i in range(num_jobs): 237 | test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job])) 238 | 239 | return test_set 240 | 241 | 242 | # get librispeech test-clean cross sentence test 243 | 244 | 245 | def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False): 246 | f = open(metalst) 247 | lines = f.readlines() 248 | f.close() 249 | 250 | test_set_ = [] 251 | for line in tqdm(lines): 252 | ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t") 253 | 254 | if eval_ground_truth: 255 | gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") 256 | gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") 257 | else: 258 | if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")): 259 | raise FileNotFoundError(f"Generated wav not found: {gen_utt}") 260 | gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav") 261 | 262 | ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") 263 | ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") 264 | 265 | test_set_.append((gen_wav, ref_wav, gen_txt)) 266 | 267 | num_jobs = len(gpus) 268 | if num_jobs == 1: 269 | return [(gpus[0], test_set_)] 270 | 271 | wav_per_job = len(test_set_) // num_jobs + 1 272 | test_set = [] 273 | for i in range(num_jobs): 274 | test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job])) 275 | 276 | return test_set 277 | 278 | 279 | # load asr model 280 | 281 | 282 | def load_asr_model(lang, ckpt_dir=""): 283 | if lang == "zh": 284 | from funasr import AutoModel 285 | 286 | model = AutoModel( 287 | model=os.path.join(ckpt_dir, "paraformer-zh"), 288 | # vad_model = os.path.join(ckpt_dir, "fsmn-vad"), 289 | # punc_model = os.path.join(ckpt_dir, "ct-punc"), 290 | # spk_model = os.path.join(ckpt_dir, "cam++"), 291 | disable_update=True, 292 | ) # following seed-tts setting 293 | elif lang == "en": 294 | from faster_whisper import WhisperModel 295 | 296 | model_size = "large-v3" if ckpt_dir == "" else ckpt_dir 297 | model = WhisperModel(model_size, device="cuda", compute_type="float16") 298 | return model 299 | 300 | 301 | # WER Evaluation, the way Seed-TTS does 302 | 303 | 304 | def run_asr_wer(args): 305 | rank, lang, test_set, ckpt_dir = args 306 | 307 | if lang == "zh": 308 | import zhconv 309 | 310 | torch.cuda.set_device(rank) 311 | elif lang == "en": 312 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) 313 | else: 314 | raise NotImplementedError( 315 | "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now." 
316 | ) 317 | 318 | asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir) 319 | 320 | from zhon.hanzi import punctuation 321 | 322 | punctuation_all = punctuation + string.punctuation 323 | wers = [] 324 | 325 | from jiwer import compute_measures 326 | 327 | for gen_wav, prompt_wav, truth in tqdm(test_set): 328 | if lang == "zh": 329 | res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True) 330 | hypo = res[0]["text"] 331 | hypo = zhconv.convert(hypo, "zh-cn") 332 | elif lang == "en": 333 | segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en") 334 | hypo = "" 335 | for segment in segments: 336 | hypo = hypo + " " + segment.text 337 | 338 | # raw_truth = truth 339 | # raw_hypo = hypo 340 | 341 | for x in punctuation_all: 342 | truth = truth.replace(x, "") 343 | hypo = hypo.replace(x, "") 344 | 345 | truth = truth.replace(" ", " ") 346 | hypo = hypo.replace(" ", " ") 347 | 348 | if lang == "zh": 349 | truth = " ".join([x for x in truth]) 350 | hypo = " ".join([x for x in hypo]) 351 | elif lang == "en": 352 | truth = truth.lower() 353 | hypo = hypo.lower() 354 | 355 | measures = compute_measures(truth, hypo) 356 | wer = measures["wer"] 357 | 358 | # ref_list = truth.split(" ") 359 | # subs = measures["substitutions"] / len(ref_list) 360 | # dele = measures["deletions"] / len(ref_list) 361 | # inse = measures["insertions"] / len(ref_list) 362 | 363 | wers.append(wer) 364 | 365 | return wers 366 | 367 | 368 | # SIM Evaluation 369 | 370 | 371 | def run_sim(args): 372 | rank, test_set, ckpt_dir = args 373 | device = f"cuda:{rank}" 374 | 375 | model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None) 376 | state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage) 377 | model.load_state_dict(state_dict["model"], strict=False) 378 | 379 | use_gpu = True if torch.cuda.is_available() else False 380 | if use_gpu: 381 | model = model.cuda(device) 382 | model.eval() 383 | 384 | sim_list = [] 385 | for wav1, wav2, truth in tqdm(test_set): 386 | wav1, sr1 = torchaudio.load(wav1) 387 | wav2, sr2 = torchaudio.load(wav2) 388 | 389 | resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000) 390 | resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000) 391 | wav1 = resample1(wav1) 392 | wav2 = resample2(wav2) 393 | 394 | if use_gpu: 395 | wav1 = wav1.cuda(device) 396 | wav2 = wav2.cuda(device) 397 | with torch.no_grad(): 398 | emb1 = model(wav1) 399 | emb2 = model(wav2) 400 | 401 | sim = F.cosine_similarity(emb1, emb2)[0].item() 402 | # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).") 403 | sim_list.append(sim) 404 | 405 | return sim_list 406 | -------------------------------------------------------------------------------- /src/f5_tts/infer/README.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts. 4 | 5 | **More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.** 6 | 7 | Currently support **30s for a single** generation, which is the **total length** including both prompt and output audio. 
However, you can provide `infer_cli` and `infer_gradio` with longer text, which will automatically be split into chunks for generation. Long reference audio will be **clipped to ~15s**. 8 | 9 | To avoid possible inference failures, make sure you have read through the following instructions. 10 | 11 | - Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of a word, leading to suboptimal generation. 12 | - Uppercase letters will be uttered letter by letter, so use lowercase letters for normal words. 13 | - Add some spaces (blank: " ") or punctuation (e.g. "," ".") to explicitly introduce pauses. 14 | - Preprocess numbers into Chinese characters if you want them read in Chinese; otherwise they will be read in English. 15 | 16 | 17 | ## Gradio App 18 | 19 | Currently supported features: 20 | 21 | - Basic TTS with Chunk Inference 22 | - Multi-Style / Multi-Speaker Generation 23 | - Voice Chat powered by Qwen2.5-3B-Instruct 24 | 25 | The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference. 26 | 27 | The script will load model checkpoints from Hugging Face. You can also manually download the files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at first; an ASR model is loaded to do transcription if `ref_text` is not provided, and an LLM is loaded if you use Voice Chat. 28 | 29 | It can also be used as a component of a larger application. 30 | ```python 31 | import gradio as gr 32 | from f5_tts.infer.infer_gradio import app 33 | 34 | with gr.Blocks() as main_app: 35 | gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app") 36 | 37 | # ... other Gradio components 38 | 39 | app.render() 40 | 41 | main_app.launch() 42 | ``` 43 | 44 | 45 | ## CLI Inference 46 | 47 | The CLI command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, a command-line tool for inference. 48 | 49 | The script will load model checkpoints from Hugging Face. You can also manually download files and use `--ckpt_file` to specify the model you want to load, or directly update the path in `infer_cli.py`. 50 | 51 | To use a different vocab, pass your `vocab.txt` file with `--vocab_file`. 52 | 53 | Basically, you can run inference with flags: 54 | ```bash 55 | # Leaving --ref_text "" will have the ASR model transcribe (extra GPU memory usage) 56 | f5-tts_infer-cli \ 57 | --model "F5-TTS" \ 58 | --ref_audio "ref_audio.wav" \ 59 | --ref_text "The content, subtitle or transcription of reference audio." \ 60 | --gen_text "Some text you want TTS model generate for you." 61 | 62 | # Choose Vocoder 63 | f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file 64 | f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file 65 | ``` 66 | 67 | A `.toml` file allows for more flexible usage. 68 | 69 | ```bash 70 | f5-tts_infer-cli -c custom.toml 71 | ``` 72 | 73 | For example, you can use a `.toml` file to pass in variables; refer to `src/f5_tts/infer/examples/basic/basic.toml`: 74 | 75 | ```toml 76 | # F5-TTS | E2-TTS 77 | model = "F5-TTS" 78 | ref_audio = "infer/examples/basic/basic_ref_en.wav" 79 | # If an empty "", transcribes the reference audio automatically. 80 | ref_text = "Some call me nature, others call me mother nature." 81 | gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
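# Note: command-line flags (e.g. --ref_audio, --gen_text) override the corresponding values in this file.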
82 | # File with text to generate. Ignores the text above. 83 | gen_file = "" 84 | remove_silence = false 85 | output_dir = "tests" 86 | ``` 87 | 88 | You can also leverage a `.toml` file to do multi-style generation; refer to `src/f5_tts/infer/examples/multi/story.toml`. 89 | 90 | ```toml 91 | # F5-TTS | E2-TTS 92 | model = "F5-TTS" 93 | ref_audio = "infer/examples/multi/main.flac" 94 | # If an empty "", transcribes the reference audio automatically. 95 | ref_text = "" 96 | gen_text = "" 97 | # File with text to generate. Ignores the text above. 98 | gen_file = "infer/examples/multi/story.txt" 99 | remove_silence = true 100 | output_dir = "tests" 101 | 102 | [voices.town] 103 | ref_audio = "infer/examples/multi/town.flac" 104 | ref_text = "" 105 | 106 | [voices.country] 107 | ref_audio = "infer/examples/multi/country.flac" 108 | ref_text = "" 109 | ``` 110 | Mark the voice with `[main]`, `[town]`, or `[country]` wherever you want to switch voices; refer to `src/f5_tts/infer/examples/multi/story.txt`. 111 | 112 | ## Speech Editing 113 | 114 | To test speech editing capabilities, use the following command: 115 | 116 | ```bash 117 | python src/f5_tts/infer/speech_edit.py 118 | ``` 119 | 120 | ## Socket Realtime Client 121 | 122 | To communicate with the socket server, you need to run 123 | ```bash 124 | python src/f5_tts/socket_server.py 125 | ``` 126 | 127 |
128 | Then create client to communicate 129 | 130 | ``` python 131 | import socket 132 | import numpy as np 133 | import asyncio 134 | import pyaudio 135 | 136 | async def listen_to_voice(text, server_ip='localhost', server_port=9999): 137 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 138 | client_socket.connect((server_ip, server_port)) 139 | 140 | async def play_audio_stream(): 141 | buffer = b'' 142 | p = pyaudio.PyAudio() 143 | stream = p.open(format=pyaudio.paFloat32, 144 | channels=1, 145 | rate=24000, # Ensure this matches the server's sampling rate 146 | output=True, 147 | frames_per_buffer=2048) 148 | 149 | try: 150 | while True: 151 | chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024) 152 | if not chunk: # End of stream 153 | break 154 | if b"END_OF_AUDIO" in chunk: 155 | buffer += chunk.replace(b"END_OF_AUDIO", b"") 156 | if buffer: 157 | audio_array = np.frombuffer(buffer, dtype=np.float32).copy() # Make a writable copy 158 | stream.write(audio_array.tobytes()) 159 | break 160 | buffer += chunk 161 | if len(buffer) >= 4096: 162 | audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy() # Make a writable copy 163 | stream.write(audio_array.tobytes()) 164 | buffer = buffer[4096:] 165 | finally: 166 | stream.stop_stream() 167 | stream.close() 168 | p.terminate() 169 | 170 | try: 171 | # Send only the text to the server 172 | await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8')) 173 | await play_audio_stream() 174 | print("Audio playback finished.") 175 | 176 | except Exception as e: 177 | print(f"Error in listen_to_voice: {e}") 178 | 179 | finally: 180 | client_socket.close() 181 | 182 | # Example usage: Replace this with your actual server IP and port 183 | async def main(): 184 | await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998) 185 | 186 | # Run the main async function 187 | asyncio.run(main()) 188 | ``` 189 | 190 |
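The client example above additionally needs `pyaudio` for local audio playback; assuming a standard pip setup, it can be installed with:

```bash
pip install pyaudio
```

Also make sure the `server_port` passed to `listen_to_voice` matches the port your `socket_server.py` instance is actually listening on (the example `main()` uses 9998, while the function default is 9999).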
191 | 192 | -------------------------------------------------------------------------------- /src/f5_tts/infer/SHARED.md: -------------------------------------------------------------------------------- 1 | 2 | # Shared Model Cards 3 | 4 | 5 | ### **Prerequisites of using** 6 | - This document is serving as a quick lookup table for the community training/finetuning result, with various language support. 7 | - The models in this repository are open source and are based on voluntary contributions from contributors. 8 | - The use of models must be conditioned on respect for the respective creators. The convenience brought comes from their efforts. 9 | 10 | 11 | ### **Welcome to share here** 12 | - Have a pretrained/finetuned result: model checkpoint (pruned best to facilitate inference, i.e. leave only `ema_model_state_dict`) and corresponding vocab file (for tokenization). 13 | - Host a public [huggingface model repository](https://huggingface.co/new) and upload the model related files. 14 | - Make a pull request adding a model card to the current page, i.e. `src\f5_tts\infer\SHARED.md`. 15 | 16 | 17 | ### Supported Languages 18 | - [Multilingual](#multilingual) 19 | - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en) 20 | - [Mandarin](#mandarin) 21 | - [Japanese](#japanese) 22 | - [F5-TTS Base @ pretrain/finetune @ ja](#f5-tts-base--pretrainfinetune--ja) 23 | - [English](#english) 24 | - [French](#french) 25 | - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr) 26 | 27 | 28 | ## Multilingual 29 | 30 | #### F5-TTS Base @ pretrain @ zh & en 31 | |Model|🤗Hugging Face|Data (Hours)|Model License| 32 | |:---:|:------------:|:-----------:|:-------------:| 33 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0| 34 | 35 | ```bash 36 | MODEL_CKPT: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors 37 | VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt 38 | ``` 39 | 40 | *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...* 41 | 42 | 43 | ## Mandarin 44 | 45 | ## Japanese 46 | 47 | #### F5-TTS Base @ pretrain/finetune @ ja 48 | |Model|🤗Hugging Face|Data (Hours)|Model License| 49 | |:---:|:------------:|:-----------:|:-------------:| 50 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0| 51 | 52 | ```bash 53 | MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt 54 | VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt 55 | ``` 56 | 57 | ## English 58 | 59 | 60 | ## French 61 | 62 | #### French LibriVox @ finetune @ fr 63 | |Model|🤗Hugging Face|Data (Hours)|Model License| 64 | |:---:|:------------:|:-----------:|:-------------:| 65 | |F5-TTS French|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0| 66 | 67 | ```bash 68 | MODEL_CKPT: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt 69 | VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt 70 | ``` 71 | 72 | - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french). 
73 | - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys). 74 | - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434). 75 | -------------------------------------------------------------------------------- /src/f5_tts/infer/examples/basic/basic.toml: -------------------------------------------------------------------------------- 1 | # F5-TTS | E2-TTS 2 | model = "F5-TTS" 3 | ref_audio = "infer/examples/basic/basic_ref_en.wav" 4 | # If an empty "", transcribes the reference audio automatically. 5 | ref_text = "Some call me nature, others call me mother nature." 6 | gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring." 7 | # File with text to generate. Ignores the text above. 8 | gen_file = "" 9 | remove_silence = false 10 | output_dir = "tests" -------------------------------------------------------------------------------- /src/f5_tts/infer/examples/basic/basic_ref_en.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/basic/basic_ref_en.wav -------------------------------------------------------------------------------- /src/f5_tts/infer/examples/basic/basic_ref_zh.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/basic/basic_ref_zh.wav -------------------------------------------------------------------------------- /src/f5_tts/infer/examples/multi/country.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/multi/country.flac -------------------------------------------------------------------------------- /src/f5_tts/infer/examples/multi/main.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/multi/main.flac -------------------------------------------------------------------------------- /src/f5_tts/infer/examples/multi/story.toml: -------------------------------------------------------------------------------- 1 | # F5-TTS | E2-TTS 2 | model = "F5-TTS" 3 | ref_audio = "infer/examples/multi/main.flac" 4 | # If an empty "", transcribes the reference audio automatically. 5 | ref_text = "" 6 | gen_text = "" 7 | # File with text to generate. Ignores the text above. 8 | gen_file = "infer/examples/multi/story.txt" 9 | remove_silence = true 10 | output_dir = "tests" 11 | 12 | [voices.town] 13 | ref_audio = "infer/examples/multi/town.flac" 14 | ref_text = "" 15 | 16 | [voices.country] 17 | ref_audio = "infer/examples/multi/country.flac" 18 | ref_text = "" 19 | 20 | -------------------------------------------------------------------------------- /src/f5_tts/infer/examples/multi/story.txt: -------------------------------------------------------------------------------- 1 | A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. 
The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.” -------------------------------------------------------------------------------- /src/f5_tts/infer/examples/multi/town.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/multi/town.flac -------------------------------------------------------------------------------- /src/f5_tts/infer/infer_cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | import os 4 | import re 5 | from importlib.resources import files 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import soundfile as sf 10 | import tomli 11 | from cached_path import cached_path 12 | 13 | from f5_tts.infer.utils_infer import ( 14 | infer_process, 15 | load_model, 16 | load_vocoder, 17 | preprocess_ref_audio_text, 18 | remove_silence_for_generated_wav, 19 | ) 20 | from f5_tts.model import DiT, UNetT 21 | 22 | parser = argparse.ArgumentParser( 23 | prog="python3 infer-cli.py", 24 | description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.", 25 | epilog="Specify options above to override one or more settings from config.", 26 | ) 27 | parser.add_argument( 28 | "-c", 29 | "--config", 30 | help="Configuration file. Default=infer/examples/basic/basic.toml", 31 | default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"), 32 | ) 33 | parser.add_argument( 34 | "-m", 35 | "--model", 36 | help="F5-TTS | E2-TTS", 37 | ) 38 | parser.add_argument( 39 | "-p", 40 | "--ckpt_file", 41 | help="The Checkpoint .pt", 42 | ) 43 | parser.add_argument( 44 | "-v", 45 | "--vocab_file", 46 | help="The vocab .txt", 47 | ) 48 | parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.") 49 | parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.") 50 | parser.add_argument( 51 | "-t", 52 | "--gen_text", 53 | type=str, 54 | help="Text to generate.", 55 | ) 56 | parser.add_argument( 57 | "-f", 58 | "--gen_file", 59 | type=str, 60 | help="File with text to generate. 
Ignores --text", 61 | ) 62 | parser.add_argument( 63 | "-o", 64 | "--output_dir", 65 | type=str, 66 | help="Path to output folder..", 67 | ) 68 | parser.add_argument( 69 | "--remove_silence", 70 | help="Remove silence.", 71 | ) 72 | parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name") 73 | parser.add_argument( 74 | "--load_vocoder_from_local", 75 | action="store_true", 76 | help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz", 77 | ) 78 | parser.add_argument( 79 | "--speed", 80 | type=float, 81 | default=1.0, 82 | help="Adjust the speed of the audio generation (default: 1.0)", 83 | ) 84 | args = parser.parse_args() 85 | 86 | config = tomli.load(open(args.config, "rb")) 87 | 88 | ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"] 89 | ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"] 90 | gen_text = args.gen_text if args.gen_text else config["gen_text"] 91 | gen_file = args.gen_file if args.gen_file else config["gen_file"] 92 | 93 | # patches for pip pkg user 94 | if "infer/examples/" in ref_audio: 95 | ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}")) 96 | if "infer/examples/" in gen_file: 97 | gen_file = str(files("f5_tts").joinpath(f"{gen_file}")) 98 | if "voices" in config: 99 | for voice in config["voices"]: 100 | voice_ref_audio = config["voices"][voice]["ref_audio"] 101 | if "infer/examples/" in voice_ref_audio: 102 | config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}")) 103 | 104 | if gen_file: 105 | gen_text = codecs.open(gen_file, "r", "utf-8").read() 106 | output_dir = args.output_dir if args.output_dir else config["output_dir"] 107 | model = args.model if args.model else config["model"] 108 | ckpt_file = args.ckpt_file if args.ckpt_file else "" 109 | vocab_file = args.vocab_file if args.vocab_file else "" 110 | remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"] 111 | speed = args.speed 112 | wave_path = Path(output_dir) / "infer_cli_out.wav" 113 | # spectrogram_path = Path(output_dir) / "infer_cli_out.png" 114 | if args.vocoder_name == "vocos": 115 | vocoder_local_path = "../checkpoints/vocos-mel-24khz" 116 | elif args.vocoder_name == "bigvgan": 117 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" 118 | mel_spec_type = args.vocoder_name 119 | 120 | vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path) 121 | 122 | 123 | # load models 124 | if model == "F5-TTS": 125 | model_cls = DiT 126 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 127 | if ckpt_file == "": 128 | if args.vocoder_name == "vocos": 129 | repo_name = "F5-TTS" 130 | exp_name = "F5TTS_Base" 131 | ckpt_step = 1200000 132 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) 133 | # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path 134 | elif args.vocoder_name == "bigvgan": 135 | repo_name = "F5-TTS" 136 | exp_name = "F5TTS_Base_bigvgan" 137 | ckpt_step = 1250000 138 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")) 139 | 140 | elif model == "E2-TTS": 141 | model_cls = UNetT 142 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 143 | if ckpt_file == "": 144 | repo_name = "E2-TTS" 145 | exp_name = "E2TTS_Base" 146 | ckpt_step = 1200000 147 | 
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) 148 | # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path 149 | elif args.vocoder_name == "bigvgan": # TODO: need to test 150 | repo_name = "F5-TTS" 151 | exp_name = "F5TTS_Base_bigvgan" 152 | ckpt_step = 1250000 153 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")) 154 | 155 | 156 | print(f"Using {model}...") 157 | ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file) 158 | 159 | 160 | def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed): 161 | main_voice = {"ref_audio": ref_audio, "ref_text": ref_text} 162 | if "voices" not in config: 163 | voices = {"main": main_voice} 164 | else: 165 | voices = config["voices"] 166 | voices["main"] = main_voice 167 | for voice in voices: 168 | voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text( 169 | voices[voice]["ref_audio"], voices[voice]["ref_text"] 170 | ) 171 | print("Voice:", voice) 172 | print("Ref_audio:", voices[voice]["ref_audio"]) 173 | print("Ref_text:", voices[voice]["ref_text"]) 174 | 175 | generated_audio_segments = [] 176 | reg1 = r"(?=\[\w+\])" 177 | chunks = re.split(reg1, text_gen) 178 | reg2 = r"\[(\w+)\]" 179 | for text in chunks: 180 | if not text.strip(): 181 | continue 182 | match = re.match(reg2, text) 183 | if match: 184 | voice = match[1] 185 | else: 186 | print("No voice tag found, using main.") 187 | voice = "main" 188 | if voice not in voices: 189 | print(f"Voice {voice} not found, using main.") 190 | voice = "main" 191 | text = re.sub(reg2, "", text) 192 | gen_text = text.strip() 193 | ref_audio = voices[voice]["ref_audio"] 194 | ref_text = voices[voice]["ref_text"] 195 | print(f"Voice: {voice}") 196 | audio, final_sample_rate, spectragram = infer_process( 197 | ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed 198 | ) 199 | generated_audio_segments.append(audio) 200 | 201 | if generated_audio_segments: 202 | final_wave = np.concatenate(generated_audio_segments) 203 | 204 | if not os.path.exists(output_dir): 205 | os.makedirs(output_dir) 206 | 207 | with open(wave_path, "wb") as f: 208 | sf.write(f.name, final_wave, final_sample_rate) 209 | # Remove silence 210 | if remove_silence: 211 | remove_silence_for_generated_wav(f.name) 212 | print(f.name) 213 | 214 | 215 | def main(): 216 | main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed) 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /src/f5_tts/infer/speech_edit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torchaudio 6 | 7 | from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram 8 | from f5_tts.model import CFM, DiT, UNetT 9 | from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer 10 | 11 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 12 | 13 | 14 | # --------------------- Dataset Settings -------------------- # 15 | 16 | target_sample_rate = 24000 17 | n_mel_channels = 100 18 | hop_length = 256 19 | win_length = 1024 20 | n_fft = 1024 21 | mel_spec_type = "vocos" # 'vocos' 
or 'bigvgan' 22 | target_rms = 0.1 23 | 24 | tokenizer = "pinyin" 25 | dataset_name = "Emilia_ZH_EN" 26 | 27 | 28 | # ---------------------- infer setting ---------------------- # 29 | 30 | seed = None # int | None 31 | 32 | exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base 33 | ckpt_step = 1200000 34 | 35 | nfe_step = 32 # 16, 32 36 | cfg_strength = 2.0 37 | ode_method = "euler" # euler | midpoint 38 | sway_sampling_coef = -1.0 39 | speed = 1.0 40 | 41 | if exp_name == "F5TTS_Base": 42 | model_cls = DiT 43 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 44 | 45 | elif exp_name == "E2TTS_Base": 46 | model_cls = UNetT 47 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 48 | 49 | ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors" 50 | output_dir = "tests" 51 | 52 | # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment] 53 | # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git 54 | # [write the origin_text into a file, e.g. tests/test_edit.txt] 55 | # ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char" 56 | # [result will be saved at same path of audio file] 57 | # [--language "zho" for Chinese, "eng" for English] 58 | # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"] 59 | 60 | audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav" 61 | origin_text = "Some call me nature, others call me mother nature." 62 | target_text = "Some call me optimist, others call me realist." 63 | parts_to_edit = [ 64 | [1.42, 2.44], 65 | [4.04, 4.9], 66 | ] # stard_ends of "nature" & "mother nature", in seconds 67 | fix_duration = [ 68 | 1.2, 69 | 1, 70 | ] # fix duration for "optimist" & "realist", in seconds 71 | 72 | # audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav" 73 | # origin_text = "对,这就是我,万人敬仰的太乙真人。" 74 | # target_text = "对,那就是你,万人敬仰的太白金星。" 75 | # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ] 76 | # fix_duration = None # use origin text duration 77 | 78 | 79 | # -------------------------------------------------# 80 | 81 | use_ema = True 82 | 83 | if not os.path.exists(output_dir): 84 | os.makedirs(output_dir) 85 | 86 | # Vocoder model 87 | local = False 88 | if mel_spec_type == "vocos": 89 | vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" 90 | elif mel_spec_type == "bigvgan": 91 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" 92 | vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path) 93 | 94 | # Tokenizer 95 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) 96 | 97 | # Model 98 | model = CFM( 99 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), 100 | mel_spec_kwargs=dict( 101 | n_fft=n_fft, 102 | hop_length=hop_length, 103 | win_length=win_length, 104 | n_mel_channels=n_mel_channels, 105 | target_sample_rate=target_sample_rate, 106 | mel_spec_type=mel_spec_type, 107 | ), 108 | odeint_kwargs=dict( 109 | method=ode_method, 110 | ), 111 | vocab_char_map=vocab_char_map, 112 | ).to(device) 113 | 114 | dtype = torch.float32 if mel_spec_type == "bigvgan" else None 115 | model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) 116 | 117 | # Audio 118 | audio, sr = torchaudio.load(audio_to_edit) 119 | if audio.shape[0] > 1: 120 | audio = 
torch.mean(audio, dim=0, keepdim=True) 121 | rms = torch.sqrt(torch.mean(torch.square(audio))) 122 | if rms < target_rms: 123 | audio = audio * target_rms / rms 124 | if sr != target_sample_rate: 125 | resampler = torchaudio.transforms.Resample(sr, target_sample_rate) 126 | audio = resampler(audio) 127 | offset = 0 128 | audio_ = torch.zeros(1, 0) 129 | edit_mask = torch.zeros(1, 0, dtype=torch.bool) 130 | for part in parts_to_edit: 131 | start, end = part 132 | part_dur = end - start if fix_duration is None else fix_duration.pop(0) 133 | part_dur = part_dur * target_sample_rate 134 | start = start * target_sample_rate 135 | audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1) 136 | edit_mask = torch.cat( 137 | ( 138 | edit_mask, 139 | torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool), 140 | torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool), 141 | ), 142 | dim=-1, 143 | ) 144 | offset = end * target_sample_rate 145 | # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1) 146 | edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True) 147 | audio = audio.to(device) 148 | edit_mask = edit_mask.to(device) 149 | 150 | # Text 151 | text_list = [target_text] 152 | if tokenizer == "pinyin": 153 | final_text_list = convert_char_to_pinyin(text_list) 154 | else: 155 | final_text_list = [text_list] 156 | print(f"text : {text_list}") 157 | print(f"pinyin: {final_text_list}") 158 | 159 | # Duration 160 | ref_audio_len = 0 161 | duration = audio.shape[-1] // hop_length 162 | 163 | # Inference 164 | with torch.inference_mode(): 165 | generated, trajectory = model.sample( 166 | cond=audio, 167 | text=final_text_list, 168 | duration=duration, 169 | steps=nfe_step, 170 | cfg_strength=cfg_strength, 171 | sway_sampling_coef=sway_sampling_coef, 172 | seed=seed, 173 | edit_mask=edit_mask, 174 | ) 175 | print(f"Generated mel: {generated.shape}") 176 | 177 | # Final result 178 | generated = generated.to(torch.float32) 179 | generated = generated[:, ref_audio_len:, :] 180 | gen_mel_spec = generated.permute(0, 2, 1) 181 | if mel_spec_type == "vocos": 182 | generated_wave = vocoder.decode(gen_mel_spec) 183 | elif mel_spec_type == "bigvgan": 184 | generated_wave = vocoder(gen_mel_spec) 185 | 186 | if rms < target_rms: 187 | generated_wave = generated_wave * rms / target_rms 188 | 189 | save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png") 190 | torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate) 191 | print(f"Generated wav: {generated_wave.shape}") 192 | -------------------------------------------------------------------------------- /src/f5_tts/model/__init__.py: -------------------------------------------------------------------------------- 1 | from f5_tts.model.cfm import CFM 2 | 3 | from f5_tts.model.backbones.unett import UNetT 4 | from f5_tts.model.backbones.dit import DiT 5 | from f5_tts.model.backbones.mmdit import MMDiT 6 | 7 | from f5_tts.model.trainer import Trainer 8 | 9 | 10 | __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"] 11 | -------------------------------------------------------------------------------- /src/f5_tts/model/backbones/README.md: -------------------------------------------------------------------------------- 1 | ## Backbones quick introduction 2 | 3 | 4 | ### unett.py 5 | - flat unet transformer 6 | - structure same as in e2-tts & voicebox paper except 
using rotary pos emb 7 | - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat 8 | 9 | ### dit.py 10 | - adaln-zero dit 11 | - embedded timestep as condition 12 | - concatted noised_input + masked_cond + embedded_text, linear proj in 13 | - possible abs pos emb & convnextv2 blocks for embedded text before concat 14 | - possible long skip connection (first layer to last layer) 15 | 16 | ### mmdit.py 17 | - sd3 structure 18 | - timestep as condition 19 | - left stream: text embedded and applied a abs pos emb 20 | - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett 21 | -------------------------------------------------------------------------------- /src/f5_tts/model/backbones/dit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | 16 | from x_transformers.x_transformers import RotaryEmbedding 17 | 18 | from f5_tts.model.modules import ( 19 | TimestepEmbedding, 20 | ConvNeXtV2Block, 21 | ConvPositionEmbedding, 22 | DiTBlock, 23 | AdaLayerNormZero_Final, 24 | precompute_freqs_cis, 25 | get_pos_embed_indices, 26 | ) 27 | 28 | 29 | # Text embedding 30 | 31 | 32 | class TextEmbedding(nn.Module): 33 | def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): 34 | super().__init__() 35 | self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token 36 | 37 | if conv_layers > 0: 38 | self.extra_modeling = True 39 | self.precompute_max_pos = 4096 # ~44s of 24khz audio 40 | self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) 41 | self.text_blocks = nn.Sequential( 42 | *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] 43 | ) 44 | else: 45 | self.extra_modeling = False 46 | 47 | def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 48 | text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() 49 | text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens 50 | batch, text_len = text.shape[0], text.shape[1] 51 | text = F.pad(text, (0, seq_len - text_len), value=0) 52 | 53 | if drop_text: # cfg for text 54 | text = torch.zeros_like(text) 55 | 56 | text = self.text_embed(text) # b n -> b n d 57 | 58 | # possible extra modeling 59 | if self.extra_modeling: 60 | # sinus pos emb 61 | batch_start = torch.zeros((batch,), dtype=torch.long) 62 | pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) 63 | text_pos_embed = self.freqs_cis[pos_idx] 64 | text = text + text_pos_embed 65 | 66 | # convnextv2 blocks 67 | text = self.text_blocks(text) 68 | 69 | return text 70 | 71 | 72 | # noised input audio and context mixing embedding 73 | 74 | 75 | class InputEmbedding(nn.Module): 76 | def __init__(self, mel_dim, text_dim, out_dim): 77 | super().__init__() 78 | self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) 79 | self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) 80 | 81 | def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 82 | if drop_audio_cond: # cfg for cond audio 83 | cond = torch.zeros_like(cond) 84 | 85 | x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) 86 | x = self.conv_pos_embed(x) + x 87 | return x 88 | 89 | 90 | # Transformer backbone using DiT blocks 91 | 92 | 93 | class DiT(nn.Module): 94 | def __init__( 95 | self, 96 | *, 97 | dim, 98 | depth=8, 99 | heads=8, 100 | dim_head=64, 101 | dropout=0.1, 102 | ff_mult=4, 103 | mel_dim=100, 104 | text_num_embeds=256, 105 | text_dim=None, 106 | conv_layers=0, 107 | long_skip_connection=False, 108 | ): 109 | super().__init__() 110 | 111 | self.time_embed = TimestepEmbedding(dim) 112 | if text_dim is None: 113 | text_dim = mel_dim 114 | self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) 115 | self.input_embed = InputEmbedding(mel_dim, text_dim, dim) 116 | 117 | self.rotary_embed = RotaryEmbedding(dim_head) 118 | 119 | self.dim = dim 120 | self.depth = depth 121 | 122 | self.transformer_blocks = nn.ModuleList( 123 | [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)] 124 | ) 125 | self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None 126 | 127 | self.norm_out = AdaLayerNormZero_Final(dim) # final modulation 128 | self.proj_out = nn.Linear(dim, mel_dim) 129 | 130 | def forward( 131 | self, 132 | x: float["b n d"], # nosied input audio # noqa: F722 133 | cond: float["b n d"], # masked cond audio # noqa: F722 134 | text: int["b nt"], # text # noqa: F722 135 | time: float["b"] | float[""], # time step # noqa: F821 F722 136 | drop_audio_cond, # cfg for cond audio 137 | drop_text, # cfg for text 138 | mask: bool["b n"] | None = None, # noqa: F722 139 | ): 140 | batch, seq_len = x.shape[0], x.shape[1] 141 | if time.ndim == 0: 142 | time = time.repeat(batch) 143 | 144 | # t: conditioning time, c: context (text + masked cond audio), x: noised input audio 145 | t = self.time_embed(time) 146 | text_embed = self.text_embed(text, seq_len, drop_text=drop_text) 147 | x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) 148 | 149 | rope = self.rotary_embed.forward_from_seq_len(seq_len) 150 | 151 | if self.long_skip_connection is not None: 152 | residual = x 153 | 154 | for block in 
self.transformer_blocks: 155 | x = block(x, t, mask=mask, rope=rope) 156 | 157 | if self.long_skip_connection is not None: 158 | x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) 159 | 160 | x = self.norm_out(x, t) 161 | output = self.proj_out(x) 162 | 163 | return output 164 | -------------------------------------------------------------------------------- /src/f5_tts/model/backbones/mmdit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from x_transformers.x_transformers import RotaryEmbedding 16 | 17 | from f5_tts.model.modules import ( 18 | TimestepEmbedding, 19 | ConvPositionEmbedding, 20 | MMDiTBlock, 21 | AdaLayerNormZero_Final, 22 | precompute_freqs_cis, 23 | get_pos_embed_indices, 24 | ) 25 | 26 | 27 | # text embedding 28 | 29 | 30 | class TextEmbedding(nn.Module): 31 | def __init__(self, out_dim, text_num_embeds): 32 | super().__init__() 33 | self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token 34 | 35 | self.precompute_max_pos = 1024 36 | self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) 37 | 38 | def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 39 | text = text + 1 40 | if drop_text: 41 | text = torch.zeros_like(text) 42 | text = self.text_embed(text) 43 | 44 | # sinus pos emb 45 | batch_start = torch.zeros((text.shape[0],), dtype=torch.long) 46 | batch_text_len = text.shape[1] 47 | pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos) 48 | text_pos_embed = self.freqs_cis[pos_idx] 49 | 50 | text = text + text_pos_embed 51 | 52 | return text 53 | 54 | 55 | # noised input & masked cond audio embedding 56 | 57 | 58 | class AudioEmbedding(nn.Module): 59 | def __init__(self, in_dim, out_dim): 60 | super().__init__() 61 | self.linear = nn.Linear(2 * in_dim, out_dim) 62 | self.conv_pos_embed = ConvPositionEmbedding(out_dim) 63 | 64 | def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722 65 | if drop_audio_cond: 66 | cond = torch.zeros_like(cond) 67 | x = torch.cat((x, cond), dim=-1) 68 | x = self.linear(x) 69 | x = self.conv_pos_embed(x) + x 70 | return x 71 | 72 | 73 | # Transformer backbone using MM-DiT blocks 74 | 75 | 76 | class MMDiT(nn.Module): 77 | def __init__( 78 | self, 79 | *, 80 | dim, 81 | depth=8, 82 | heads=8, 83 | dim_head=64, 84 | dropout=0.1, 85 | ff_mult=4, 86 | text_num_embeds=256, 87 | mel_dim=100, 88 | ): 89 | super().__init__() 90 | 91 | self.time_embed = TimestepEmbedding(dim) 92 | self.text_embed = TextEmbedding(dim, text_num_embeds) 93 | self.audio_embed = AudioEmbedding(mel_dim, dim) 94 | 95 | self.rotary_embed = RotaryEmbedding(dim_head) 96 | 97 | self.dim = dim 98 | self.depth = depth 99 | 100 | self.transformer_blocks = nn.ModuleList( 101 | [ 102 | MMDiTBlock( 103 | dim=dim, 104 | heads=heads, 105 | dim_head=dim_head, 106 | dropout=dropout, 107 | ff_mult=ff_mult, 108 | context_pre_only=i == depth - 1, 109 | ) 110 | for i in range(depth) 111 | ] 112 | ) 113 | self.norm_out = AdaLayerNormZero_Final(dim) # final modulation 114 | self.proj_out = nn.Linear(dim, mel_dim) 115 | 116 | def forward( 117 | self, 118 | x: float["b n d"], # nosied input audio # noqa: F722 119 | 
cond: float["b n d"], # masked cond audio # noqa: F722 120 | text: int["b nt"], # text # noqa: F722 121 | time: float["b"] | float[""], # time step # noqa: F821 F722 122 | drop_audio_cond, # cfg for cond audio 123 | drop_text, # cfg for text 124 | mask: bool["b n"] | None = None, # noqa: F722 125 | ): 126 | batch = x.shape[0] 127 | if time.ndim == 0: 128 | time = time.repeat(batch) 129 | 130 | # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio 131 | t = self.time_embed(time) 132 | c = self.text_embed(text, drop_text=drop_text) 133 | x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) 134 | 135 | seq_len = x.shape[1] 136 | text_len = text.shape[1] 137 | rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) 138 | rope_text = self.rotary_embed.forward_from_seq_len(text_len) 139 | 140 | for block in self.transformer_blocks: 141 | c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text) 142 | 143 | x = self.norm_out(x, t) 144 | output = self.proj_out(x) 145 | 146 | return output 147 | -------------------------------------------------------------------------------- /src/f5_tts/model/backbones/unett.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 | 10 | from __future__ import annotations 11 | from typing import Literal 12 | 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | 17 | from x_transformers import RMSNorm 18 | from x_transformers.x_transformers import RotaryEmbedding 19 | 20 | from f5_tts.model.modules import ( 21 | TimestepEmbedding, 22 | ConvNeXtV2Block, 23 | ConvPositionEmbedding, 24 | Attention, 25 | AttnProcessor, 26 | FeedForward, 27 | precompute_freqs_cis, 28 | get_pos_embed_indices, 29 | ) 30 | 31 | 32 | # Text embedding 33 | 34 | 35 | class TextEmbedding(nn.Module): 36 | def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): 37 | super().__init__() 38 | self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token 39 | 40 | if conv_layers > 0: 41 | self.extra_modeling = True 42 | self.precompute_max_pos = 4096 # ~44s of 24khz audio 43 | self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) 44 | self.text_blocks = nn.Sequential( 45 | *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] 46 | ) 47 | else: 48 | self.extra_modeling = False 49 | 50 | def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 51 | text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() 52 | text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens 53 | batch, text_len = text.shape[0], text.shape[1] 54 | text = F.pad(text, (0, seq_len - text_len), value=0) 55 | 56 | if drop_text: # cfg for text 57 | text = torch.zeros_like(text) 58 | 59 | text = self.text_embed(text) # b n -> b n d 60 | 61 | # possible extra modeling 62 | if self.extra_modeling: 63 | # sinus pos emb 64 | batch_start = torch.zeros((batch,), dtype=torch.long) 65 | pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) 66 | text_pos_embed = self.freqs_cis[pos_idx] 67 | text = text + text_pos_embed 68 | 69 | # convnextv2 blocks 70 | text = self.text_blocks(text) 71 | 72 | return text 73 | 74 | 75 | # noised input audio and context mixing embedding 76 | 77 | 78 | class InputEmbedding(nn.Module): 79 | def __init__(self, mel_dim, text_dim, out_dim): 80 | super().__init__() 81 | self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) 82 | self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) 83 | 84 | def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 85 | if drop_audio_cond: # cfg for cond audio 86 | cond = torch.zeros_like(cond) 87 | 88 | x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) 89 | x = self.conv_pos_embed(x) + x 90 | return x 91 | 92 | 93 | # Flat UNet Transformer backbone 94 | 95 | 96 | class UNetT(nn.Module): 97 | def __init__( 98 | self, 99 | *, 100 | dim, 101 | depth=8, 102 | heads=8, 103 | dim_head=64, 104 | dropout=0.1, 105 | ff_mult=4, 106 | mel_dim=100, 107 | text_num_embeds=256, 108 | text_dim=None, 109 | conv_layers=0, 110 | skip_connect_type: Literal["add", "concat", "none"] = "concat", 111 | ): 112 | super().__init__() 113 | assert depth % 2 == 0, "UNet-Transformer's depth should be even." 
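        # An even depth is required because forward() pairs the layers UNet-style:
        # each of the first depth//2 layers pushes its output onto a stack, and each
        # layer in the later half pops one as a skip connection ("add" sums it in,
        # "concat" concatenates and projects back to `dim` via the per-layer skip_proj).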
114 | 115 | self.time_embed = TimestepEmbedding(dim) 116 | if text_dim is None: 117 | text_dim = mel_dim 118 | self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) 119 | self.input_embed = InputEmbedding(mel_dim, text_dim, dim) 120 | 121 | self.rotary_embed = RotaryEmbedding(dim_head) 122 | 123 | # transformer layers & skip connections 124 | 125 | self.dim = dim 126 | self.skip_connect_type = skip_connect_type 127 | needs_skip_proj = skip_connect_type == "concat" 128 | 129 | self.depth = depth 130 | self.layers = nn.ModuleList([]) 131 | 132 | for idx in range(depth): 133 | is_later_half = idx >= (depth // 2) 134 | 135 | attn_norm = RMSNorm(dim) 136 | attn = Attention( 137 | processor=AttnProcessor(), 138 | dim=dim, 139 | heads=heads, 140 | dim_head=dim_head, 141 | dropout=dropout, 142 | ) 143 | 144 | ff_norm = RMSNorm(dim) 145 | ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") 146 | 147 | skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None 148 | 149 | self.layers.append( 150 | nn.ModuleList( 151 | [ 152 | skip_proj, 153 | attn_norm, 154 | attn, 155 | ff_norm, 156 | ff, 157 | ] 158 | ) 159 | ) 160 | 161 | self.norm_out = RMSNorm(dim) 162 | self.proj_out = nn.Linear(dim, mel_dim) 163 | 164 | def forward( 165 | self, 166 | x: float["b n d"], # nosied input audio # noqa: F722 167 | cond: float["b n d"], # masked cond audio # noqa: F722 168 | text: int["b nt"], # text # noqa: F722 169 | time: float["b"] | float[""], # time step # noqa: F821 F722 170 | drop_audio_cond, # cfg for cond audio 171 | drop_text, # cfg for text 172 | mask: bool["b n"] | None = None, # noqa: F722 173 | ): 174 | batch, seq_len = x.shape[0], x.shape[1] 175 | if time.ndim == 0: 176 | time = time.repeat(batch) 177 | 178 | # t: conditioning time, c: context (text + masked cond audio), x: noised input audio 179 | t = self.time_embed(time) 180 | text_embed = self.text_embed(text, seq_len, drop_text=drop_text) 181 | x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) 182 | 183 | # postfix time t to input x, [b n d] -> [b n+1 d] 184 | x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x 185 | if mask is not None: 186 | mask = F.pad(mask, (1, 0), value=1) 187 | 188 | rope = self.rotary_embed.forward_from_seq_len(seq_len + 1) 189 | 190 | # flat unet transformer 191 | skip_connect_type = self.skip_connect_type 192 | skips = [] 193 | for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers): 194 | layer = idx + 1 195 | 196 | # skip connection logic 197 | is_first_half = layer <= (self.depth // 2) 198 | is_later_half = not is_first_half 199 | 200 | if is_first_half: 201 | skips.append(x) 202 | 203 | if is_later_half: 204 | skip = skips.pop() 205 | if skip_connect_type == "concat": 206 | x = torch.cat((x, skip), dim=-1) 207 | x = maybe_skip_proj(x) 208 | elif skip_connect_type == "add": 209 | x = x + skip 210 | 211 | # attention and feedforward blocks 212 | x = attn(attn_norm(x), rope=rope, mask=mask) + x 213 | x = ff(ff_norm(x)) + x 214 | 215 | assert len(skips) == 0 216 | 217 | x = self.norm_out(x)[:, 1:, :] # unpack t from x 218 | 219 | return self.proj_out(x) 220 | -------------------------------------------------------------------------------- /src/f5_tts/model/cfm.py: -------------------------------------------------------------------------------- 1 | """ 2 | ein notation: 3 | b - batch 4 | n - sequence 5 | nt - text sequence 6 | nw - raw wave length 7 | d - dimension 8 | """ 9 
| 10 | from __future__ import annotations 11 | 12 | from random import random 13 | from typing import Callable 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from torch import nn 18 | from torch.nn.utils.rnn import pad_sequence 19 | from torchdiffeq import odeint 20 | 21 | from f5_tts.model.modules import MelSpec 22 | from f5_tts.model.utils import ( 23 | default, 24 | exists, 25 | lens_to_mask, 26 | list_str_to_idx, 27 | list_str_to_tensor, 28 | mask_from_frac_lengths, 29 | ) 30 | 31 | 32 | class CFM(nn.Module): 33 | def __init__( 34 | self, 35 | transformer: nn.Module, 36 | sigma=0.0, 37 | odeint_kwargs: dict = dict( 38 | # atol = 1e-5, 39 | # rtol = 1e-5, 40 | method="euler" # 'midpoint' 41 | ), 42 | audio_drop_prob=0.3, 43 | cond_drop_prob=0.2, 44 | num_channels=None, 45 | mel_spec_module: nn.Module | None = None, 46 | mel_spec_kwargs: dict = dict(), 47 | frac_lengths_mask: tuple[float, float] = (0.7, 1.0), 48 | vocab_char_map: dict[str:int] | None = None, 49 | ): 50 | super().__init__() 51 | 52 | self.frac_lengths_mask = frac_lengths_mask 53 | 54 | # mel spec 55 | self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) 56 | num_channels = default(num_channels, self.mel_spec.n_mel_channels) 57 | self.num_channels = num_channels 58 | 59 | # classifier-free guidance 60 | self.audio_drop_prob = audio_drop_prob 61 | self.cond_drop_prob = cond_drop_prob 62 | 63 | # transformer 64 | self.transformer = transformer 65 | dim = transformer.dim 66 | self.dim = dim 67 | 68 | # conditional flow related 69 | self.sigma = sigma 70 | 71 | # sampling related 72 | self.odeint_kwargs = odeint_kwargs 73 | 74 | # vocab map for tokenization 75 | self.vocab_char_map = vocab_char_map 76 | 77 | @property 78 | def device(self): 79 | return next(self.parameters()).device 80 | 81 | @torch.no_grad() 82 | def sample( 83 | self, 84 | cond: float["b n d"] | float["b nw"], # noqa: F722 85 | text: int["b nt"] | list[str], # noqa: F722 86 | duration: int | int["b"], # noqa: F821 87 | *, 88 | lens: int["b"] | None = None, # noqa: F821 89 | steps=32, 90 | cfg_strength=1.0, 91 | sway_sampling_coef=None, 92 | seed: int | None = None, 93 | max_duration=4096, 94 | vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 95 | no_ref_audio=False, 96 | duplicate_test=False, 97 | t_inter=0.1, 98 | edit_mask=None, 99 | ): 100 | self.eval() 101 | # raw wave 102 | 103 | if cond.ndim == 2: 104 | cond = self.mel_spec(cond) 105 | cond = cond.permute(0, 2, 1) 106 | assert cond.shape[-1] == self.num_channels 107 | 108 | cond = cond.to(next(self.parameters()).dtype) 109 | 110 | batch, cond_seq_len, device = *cond.shape[:2], cond.device 111 | if not exists(lens): 112 | lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) 113 | 114 | # text 115 | 116 | if isinstance(text, list): 117 | if exists(self.vocab_char_map): 118 | text = list_str_to_idx(text, self.vocab_char_map).to(device) 119 | else: 120 | text = list_str_to_tensor(text).to(device) 121 | assert text.shape[0] == batch 122 | 123 | if exists(text): 124 | text_lens = (text != -1).sum(dim=-1) 125 | lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters 126 | 127 | # duration 128 | 129 | cond_mask = lens_to_mask(lens) 130 | if edit_mask is not None: 131 | cond_mask = cond_mask & edit_mask 132 | 133 | if isinstance(duration, int): 134 | duration = torch.full((batch,), duration, device=device, dtype=torch.long) 135 | 136 | duration = torch.maximum(lens + 1, duration) # 
just add one token so something is generated 137 | duration = duration.clamp(max=max_duration) 138 | max_duration = duration.amax() 139 | 140 | # duplicate test corner for inner time step oberservation 141 | if duplicate_test: 142 | test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) 143 | 144 | cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) 145 | cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) 146 | cond_mask = cond_mask.unsqueeze(-1) 147 | step_cond = torch.where( 148 | cond_mask, cond, torch.zeros_like(cond) 149 | ) # allow direct control (cut cond audio) with lens passed in 150 | 151 | if batch > 1: 152 | mask = lens_to_mask(duration) 153 | else: # save memory and speed up, as single inference need no mask currently 154 | mask = None 155 | 156 | # test for no ref audio 157 | if no_ref_audio: 158 | cond = torch.zeros_like(cond) 159 | 160 | # neural ode 161 | 162 | def fn(t, x): 163 | # at each step, conditioning is fixed 164 | # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) 165 | 166 | # predict flow 167 | pred = self.transformer( 168 | x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False 169 | ) 170 | if cfg_strength < 1e-5: 171 | return pred 172 | 173 | null_pred = self.transformer( 174 | x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True 175 | ) 176 | return pred + (pred - null_pred) * cfg_strength 177 | 178 | # noise input 179 | # to make sure batch inference result is same with different batch size, and for sure single inference 180 | # still some difference maybe due to convolutional layers 181 | y0 = [] 182 | for dur in duration: 183 | if exists(seed): 184 | torch.manual_seed(seed) 185 | y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) 186 | y0 = pad_sequence(y0, padding_value=0, batch_first=True) 187 | 188 | t_start = 0 189 | 190 | # duplicate test corner for inner time step oberservation 191 | if duplicate_test: 192 | t_start = t_inter 193 | y0 = (1 - t_start) * y0 + t_start * test_cond 194 | steps = int(steps * (1 - t_start)) 195 | 196 | t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype) 197 | if sway_sampling_coef is not None: 198 | t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) 199 | 200 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs) 201 | 202 | sampled = trajectory[-1] 203 | out = sampled 204 | out = torch.where(cond_mask, cond, out) 205 | 206 | if exists(vocoder): 207 | out = out.permute(0, 2, 1) 208 | out = vocoder(out) 209 | 210 | return out, trajectory 211 | 212 | def forward( 213 | self, 214 | inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 215 | text: int["b nt"] | list[str], # noqa: F722 216 | *, 217 | lens: int["b"] | None = None, # noqa: F821 218 | noise_scheduler: str | None = None, 219 | ): 220 | # handle raw wave 221 | if inp.ndim == 2: 222 | inp = self.mel_spec(inp) 223 | inp = inp.permute(0, 2, 1) 224 | assert inp.shape[-1] == self.num_channels 225 | 226 | batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma 227 | 228 | # handle text as string 229 | if isinstance(text, list): 230 | if exists(self.vocab_char_map): 231 | text = list_str_to_idx(text, self.vocab_char_map).to(device) 232 | else: 233 | text = list_str_to_tensor(text).to(device) 234 | assert text.shape[0] == batch 235 | 236 | # lens and mask 237 | if not 
exists(lens): 238 | lens = torch.full((batch,), seq_len, device=device) 239 | 240 | mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch 241 | 242 | # get a random span to mask out for training conditionally 243 | frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) 244 | rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) 245 | 246 | if exists(mask): 247 | rand_span_mask &= mask 248 | 249 | # mel is x1 250 | x1 = inp 251 | 252 | # x0 is gaussian noise 253 | x0 = torch.randn_like(x1) 254 | 255 | # time step 256 | time = torch.rand((batch,), dtype=dtype, device=self.device) 257 | # TODO. noise_scheduler 258 | 259 | # sample xt (φ_t(x) in the paper) 260 | t = time.unsqueeze(-1).unsqueeze(-1) 261 | φ = (1 - t) * x0 + t * x1 262 | flow = x1 - x0 263 | 264 | # only predict what is within the random mask span for infilling 265 | cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) 266 | 267 | # transformer and cfg training with a drop rate 268 | drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper 269 | if random() < self.cond_drop_prob: # p_uncond in voicebox paper 270 | drop_audio_cond = True 271 | drop_text = True 272 | else: 273 | drop_text = False 274 | 275 | # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here 276 | # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences 277 | pred = self.transformer( 278 | x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text 279 | ) 280 | 281 | # flow matching loss 282 | loss = F.mse_loss(pred, flow, reduction="none") 283 | loss = loss[rand_span_mask] 284 | 285 | return loss.mean(), cond, pred 286 | -------------------------------------------------------------------------------- /src/f5_tts/model/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from importlib.resources import files 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torchaudio 8 | from datasets import Dataset as Dataset_ 9 | from datasets import load_from_disk 10 | from torch import nn 11 | from torch.utils.data import Dataset, Sampler 12 | from tqdm import tqdm 13 | 14 | from f5_tts.model.modules import MelSpec 15 | from f5_tts.model.utils import default 16 | 17 | 18 | class HFDataset(Dataset): 19 | def __init__( 20 | self, 21 | hf_dataset: Dataset, 22 | target_sample_rate=24_000, 23 | n_mel_channels=100, 24 | hop_length=256, 25 | n_fft=1024, 26 | win_length=1024, 27 | mel_spec_type="vocos", 28 | ): 29 | self.data = hf_dataset 30 | self.target_sample_rate = target_sample_rate 31 | self.hop_length = hop_length 32 | 33 | self.mel_spectrogram = MelSpec( 34 | n_fft=n_fft, 35 | hop_length=hop_length, 36 | win_length=win_length, 37 | n_mel_channels=n_mel_channels, 38 | target_sample_rate=target_sample_rate, 39 | mel_spec_type=mel_spec_type, 40 | ) 41 | 42 | def get_frame_len(self, index): 43 | row = self.data[index] 44 | audio = row["audio"]["array"] 45 | sample_rate = row["audio"]["sampling_rate"] 46 | return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length 47 | 48 | def __len__(self): 49 | return len(self.data) 50 | 51 | def __getitem__(self, index): 52 | row = self.data[index] 53 | audio = row["audio"]["array"] 54 | 55 | # logger.info(f"Audio shape: {audio.shape}") 56 | 57 | 
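        # Clips outside the 0.3-30 s range are skipped by falling through to the next
        # index, and audio is resampled to target_sample_rate before mel extraction below.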
sample_rate = row["audio"]["sampling_rate"] 58 | duration = audio.shape[-1] / sample_rate 59 | 60 | if duration > 30 or duration < 0.3: 61 | return self.__getitem__((index + 1) % len(self.data)) 62 | 63 | audio_tensor = torch.from_numpy(audio).float() 64 | 65 | if sample_rate != self.target_sample_rate: 66 | resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) 67 | audio_tensor = resampler(audio_tensor) 68 | 69 | audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t') 70 | 71 | mel_spec = self.mel_spectrogram(audio_tensor) 72 | 73 | mel_spec = mel_spec.squeeze(0) # '1 d t -> d t' 74 | 75 | text = row["text"] 76 | 77 | return dict( 78 | mel_spec=mel_spec, 79 | text=text, 80 | ) 81 | 82 | 83 | class CustomDataset(Dataset): 84 | def __init__( 85 | self, 86 | custom_dataset: Dataset, 87 | durations=None, 88 | target_sample_rate=24_000, 89 | hop_length=256, 90 | n_mel_channels=100, 91 | n_fft=1024, 92 | win_length=1024, 93 | mel_spec_type="vocos", 94 | preprocessed_mel=False, 95 | mel_spec_module: nn.Module | None = None, 96 | ): 97 | self.data = custom_dataset 98 | self.durations = durations 99 | self.target_sample_rate = target_sample_rate 100 | self.hop_length = hop_length 101 | self.n_fft = n_fft 102 | self.win_length = win_length 103 | self.mel_spec_type = mel_spec_type 104 | self.preprocessed_mel = preprocessed_mel 105 | 106 | if not preprocessed_mel: 107 | self.mel_spectrogram = default( 108 | mel_spec_module, 109 | MelSpec( 110 | n_fft=n_fft, 111 | hop_length=hop_length, 112 | win_length=win_length, 113 | n_mel_channels=n_mel_channels, 114 | target_sample_rate=target_sample_rate, 115 | mel_spec_type=mel_spec_type, 116 | ), 117 | ) 118 | 119 | def get_frame_len(self, index): 120 | if ( 121 | self.durations is not None 122 | ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM 123 | return self.durations[index] * self.target_sample_rate / self.hop_length 124 | return self.data[index]["duration"] * self.target_sample_rate / self.hop_length 125 | 126 | def __len__(self): 127 | return len(self.data) 128 | 129 | def __getitem__(self, index): 130 | row = self.data[index] 131 | audio_path = row["audio_path"] 132 | text = row["text"] 133 | duration = row["duration"] 134 | 135 | if self.preprocessed_mel: 136 | mel_spec = torch.tensor(row["mel_spec"]) 137 | 138 | else: 139 | audio, source_sample_rate = torchaudio.load(audio_path) 140 | if audio.shape[0] > 1: 141 | audio = torch.mean(audio, dim=0, keepdim=True) 142 | 143 | if duration > 30 or duration < 0.3: 144 | return self.__getitem__((index + 1) % len(self.data)) 145 | 146 | if source_sample_rate != self.target_sample_rate: 147 | resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) 148 | audio = resampler(audio) 149 | 150 | mel_spec = self.mel_spectrogram(audio) 151 | mel_spec = mel_spec.squeeze(0) # '1 d t -> d t') 152 | 153 | return dict( 154 | mel_spec=mel_spec, 155 | text=text, 156 | ) 157 | 158 | 159 | # Dynamic Batch Sampler 160 | 161 | 162 | class DynamicBatchSampler(Sampler[list[int]]): 163 | """Extension of Sampler that will do the following: 164 | 1. Change the batch size (essentially number of sequences) 165 | in a batch to ensure that the total number of frames are less 166 | than a certain threshold. 167 | 2. Make sure the padding efficiency in the batch is high. 
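    Batches are computed once at construction (indices sorted ascending by frame length)
    and then shuffled with `random_seed`. A minimal sketch of how the Trainer wires this
    sampler up when batch_size_type is "frame" (threshold/max_samples values here are
    illustrative, not defaults):

        sampler = SequentialSampler(train_dataset)
        batch_sampler = DynamicBatchSampler(sampler, frames_threshold=38400, max_samples=64)
        loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_sampler=batch_sampler)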
168 | """ 169 | 170 | def __init__( 171 | self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False 172 | ): 173 | self.sampler = sampler 174 | self.frames_threshold = frames_threshold 175 | self.max_samples = max_samples 176 | 177 | indices, batches = [], [] 178 | data_source = self.sampler.data_source 179 | 180 | for idx in tqdm( 181 | self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration" 182 | ): 183 | indices.append((idx, data_source.get_frame_len(idx))) 184 | indices.sort(key=lambda elem: elem[1]) 185 | 186 | batch = [] 187 | batch_frames = 0 188 | for idx, frame_len in tqdm( 189 | indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu" 190 | ): 191 | if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples): 192 | batch.append(idx) 193 | batch_frames += frame_len 194 | else: 195 | if len(batch) > 0: 196 | batches.append(batch) 197 | if frame_len <= self.frames_threshold: 198 | batch = [idx] 199 | batch_frames = frame_len 200 | else: 201 | batch = [] 202 | batch_frames = 0 203 | 204 | if not drop_last and len(batch) > 0: 205 | batches.append(batch) 206 | 207 | del indices 208 | 209 | # if want to have different batches between epochs, may just set a seed and log it in ckpt 210 | # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different 211 | # e.g. for epoch n, use (random_seed + n) 212 | random.seed(random_seed) 213 | random.shuffle(batches) 214 | 215 | self.batches = batches 216 | 217 | def __iter__(self): 218 | return iter(self.batches) 219 | 220 | def __len__(self): 221 | return len(self.batches) 222 | 223 | 224 | # Load dataset 225 | 226 | 227 | def load_dataset( 228 | dataset_name: str, 229 | tokenizer: str = "pinyin", 230 | dataset_type: str = "CustomDataset", 231 | audio_type: str = "raw", 232 | mel_spec_module: nn.Module | None = None, 233 | mel_spec_kwargs: dict = dict(), 234 | ) -> CustomDataset | HFDataset: 235 | """ 236 | dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset 237 | - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer 238 | """ 239 | 240 | print("Loading dataset ...") 241 | 242 | if dataset_type == "CustomDataset": 243 | rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}")) 244 | if audio_type == "raw": 245 | try: 246 | train_dataset = load_from_disk(f"{rel_data_path}/raw") 247 | except: # noqa: E722 248 | train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow") 249 | preprocessed_mel = False 250 | elif audio_type == "mel": 251 | train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow") 252 | preprocessed_mel = True 253 | with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f: 254 | data_dict = json.load(f) 255 | durations = data_dict["duration"] 256 | train_dataset = CustomDataset( 257 | train_dataset, 258 | durations=durations, 259 | preprocessed_mel=preprocessed_mel, 260 | mel_spec_module=mel_spec_module, 261 | **mel_spec_kwargs, 262 | ) 263 | 264 | elif dataset_type == "CustomDatasetPath": 265 | try: 266 | train_dataset = load_from_disk(f"{dataset_name}/raw") 267 | except: # noqa: E722 268 | train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow") 269 | 270 | with open(f"{dataset_name}/duration.json", "r", 
encoding="utf-8") as f: 271 | data_dict = json.load(f) 272 | durations = data_dict["duration"] 273 | train_dataset = CustomDataset( 274 | train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs 275 | ) 276 | 277 | elif dataset_type == "HFDataset": 278 | print( 279 | "Should manually modify the path of huggingface dataset to your need.\n" 280 | + "May also the corresponding script cuz different dataset may have different format." 281 | ) 282 | pre, post = dataset_name.split("_") 283 | train_dataset = HFDataset( 284 | load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))), 285 | ) 286 | 287 | return train_dataset 288 | 289 | 290 | # collation 291 | 292 | 293 | def collate_fn(batch): 294 | mel_specs = [item["mel_spec"].squeeze(0) for item in batch] 295 | mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) 296 | max_mel_length = mel_lengths.amax() 297 | 298 | padded_mel_specs = [] 299 | for spec in mel_specs: # TODO. maybe records mask for attention here 300 | padding = (0, max_mel_length - spec.size(-1)) 301 | padded_spec = F.pad(spec, padding, value=0) 302 | padded_mel_specs.append(padded_spec) 303 | 304 | mel_specs = torch.stack(padded_mel_specs) 305 | 306 | text = [item["text"] for item in batch] 307 | text_lengths = torch.LongTensor([len(item) for item in text]) 308 | 309 | return dict( 310 | mel=mel_specs, 311 | mel_lengths=mel_lengths, 312 | text=text, 313 | text_lengths=text_lengths, 314 | ) 315 | -------------------------------------------------------------------------------- /src/f5_tts/model/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import gc 4 | import os 5 | 6 | import torch 7 | import torchaudio 8 | import wandb 9 | from accelerate import Accelerator 10 | from accelerate.utils import DistributedDataParallelKwargs 11 | from ema_pytorch import EMA 12 | from torch.optim import AdamW 13 | from torch.optim.lr_scheduler import LinearLR, SequentialLR 14 | from torch.utils.data import DataLoader, Dataset, SequentialSampler 15 | from tqdm import tqdm 16 | 17 | from f5_tts.model import CFM 18 | from f5_tts.model.dataset import DynamicBatchSampler, collate_fn 19 | from f5_tts.model.utils import default, exists 20 | 21 | # trainer 22 | 23 | 24 | class Trainer: 25 | def __init__( 26 | self, 27 | model: CFM, 28 | epochs, 29 | learning_rate, 30 | num_warmup_updates=20000, 31 | save_per_updates=1000, 32 | checkpoint_path=None, 33 | batch_size=32, 34 | batch_size_type: str = "sample", 35 | max_samples=32, 36 | grad_accumulation_steps=1, 37 | max_grad_norm=1.0, 38 | noise_scheduler: str | None = None, 39 | duration_predictor: torch.nn.Module | None = None, 40 | logger: str | None = "wandb", # "wandb" | "tensorboard" | None 41 | wandb_project="test_e2-tts", 42 | wandb_run_name="test_run", 43 | wandb_resume_id: str = None, 44 | log_samples: bool = False, 45 | last_per_steps=None, 46 | accelerate_kwargs: dict = dict(), 47 | ema_kwargs: dict = dict(), 48 | bnb_optimizer: bool = False, 49 | mel_spec_type: str = "vocos", # "vocos" | "bigvgan" 50 | ): 51 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 52 | 53 | if logger == "wandb" and not wandb.api.api_key: 54 | logger = None 55 | print(f"Using logger: {logger}") 56 | self.log_samples = log_samples 57 | 58 | self.accelerator = Accelerator( 59 | log_with=logger if logger == "wandb" else None, 60 | kwargs_handlers=[ddp_kwargs], 61 
| gradient_accumulation_steps=grad_accumulation_steps, 62 | **accelerate_kwargs, 63 | ) 64 | 65 | self.logger = logger 66 | if self.logger == "wandb": 67 | if exists(wandb_resume_id): 68 | init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} 69 | else: 70 | init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} 71 | 72 | self.accelerator.init_trackers( 73 | project_name=wandb_project, 74 | init_kwargs=init_kwargs, 75 | config={ 76 | "epochs": epochs, 77 | "learning_rate": learning_rate, 78 | "num_warmup_updates": num_warmup_updates, 79 | "batch_size": batch_size, 80 | "batch_size_type": batch_size_type, 81 | "max_samples": max_samples, 82 | "grad_accumulation_steps": grad_accumulation_steps, 83 | "max_grad_norm": max_grad_norm, 84 | "gpus": self.accelerator.num_processes, 85 | "noise_scheduler": noise_scheduler, 86 | }, 87 | ) 88 | 89 | elif self.logger == "tensorboard": 90 | from torch.utils.tensorboard import SummaryWriter 91 | 92 | self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}") 93 | 94 | self.model = model 95 | 96 | if self.is_main: 97 | self.ema_model = EMA(model, include_online_model=False, **ema_kwargs) 98 | self.ema_model.to(self.accelerator.device) 99 | 100 | self.epochs = epochs 101 | self.num_warmup_updates = num_warmup_updates 102 | self.save_per_updates = save_per_updates 103 | self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps) 104 | self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") 105 | 106 | self.batch_size = batch_size 107 | self.batch_size_type = batch_size_type 108 | self.max_samples = max_samples 109 | self.grad_accumulation_steps = grad_accumulation_steps 110 | self.max_grad_norm = max_grad_norm 111 | self.vocoder_name = mel_spec_type 112 | 113 | self.noise_scheduler = noise_scheduler 114 | 115 | self.duration_predictor = duration_predictor 116 | 117 | if bnb_optimizer: 118 | import bitsandbytes as bnb 119 | 120 | self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate) 121 | else: 122 | self.optimizer = AdamW(model.parameters(), lr=learning_rate) 123 | self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) 124 | 125 | @property 126 | def is_main(self): 127 | return self.accelerator.is_main_process 128 | 129 | def save_checkpoint(self, step, last=False): 130 | self.accelerator.wait_for_everyone() 131 | if self.is_main: 132 | checkpoint = dict( 133 | model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), 134 | optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(), 135 | ema_model_state_dict=self.ema_model.state_dict(), 136 | scheduler_state_dict=self.scheduler.state_dict(), 137 | step=step, 138 | ) 139 | if not os.path.exists(self.checkpoint_path): 140 | os.makedirs(self.checkpoint_path) 141 | if last: 142 | self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") 143 | print(f"Saved last checkpoint at step {step}") 144 | else: 145 | self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") 146 | 147 | def load_checkpoint(self): 148 | if ( 149 | not exists(self.checkpoint_path) 150 | or not os.path.exists(self.checkpoint_path) 151 | or not os.listdir(self.checkpoint_path) 152 | ): 153 | return 0 154 | 155 | self.accelerator.wait_for_everyone() 156 | if "model_last.pt" in os.listdir(self.checkpoint_path): 157 | latest_checkpoint = "model_last.pt" 158 | else: 159 | latest_checkpoint = sorted( 160 | [f for f in 
os.listdir(self.checkpoint_path) if f.endswith(".pt")], 161 | key=lambda x: int("".join(filter(str.isdigit, x))), 162 | )[-1] 163 | # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ 164 | checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") 165 | 166 | # patch for backward compatibility, 305e3ea 167 | for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]: 168 | if key in checkpoint["ema_model_state_dict"]: 169 | del checkpoint["ema_model_state_dict"][key] 170 | 171 | if self.is_main: 172 | self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"]) 173 | 174 | if "step" in checkpoint: 175 | # patch for backward compatibility, 305e3ea 176 | for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]: 177 | if key in checkpoint["model_state_dict"]: 178 | del checkpoint["model_state_dict"][key] 179 | 180 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) 181 | self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"]) 182 | if self.scheduler: 183 | self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) 184 | step = checkpoint["step"] 185 | else: 186 | checkpoint["model_state_dict"] = { 187 | k.replace("ema_model.", ""): v 188 | for k, v in checkpoint["ema_model_state_dict"].items() 189 | if k not in ["initted", "step"] 190 | } 191 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) 192 | step = 0 193 | 194 | del checkpoint 195 | gc.collect() 196 | return step 197 | 198 | def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None): 199 | if self.log_samples: 200 | from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef 201 | 202 | vocoder = load_vocoder(vocoder_name=self.vocoder_name) 203 | target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate 204 | log_samples_path = f"{self.checkpoint_path}/samples" 205 | os.makedirs(log_samples_path, exist_ok=True) 206 | 207 | if exists(resumable_with_seed): 208 | generator = torch.Generator() 209 | generator.manual_seed(resumable_with_seed) 210 | else: 211 | generator = None 212 | 213 | if self.batch_size_type == "sample": 214 | train_dataloader = DataLoader( 215 | train_dataset, 216 | collate_fn=collate_fn, 217 | num_workers=num_workers, 218 | pin_memory=True, 219 | persistent_workers=True, 220 | batch_size=self.batch_size, 221 | shuffle=True, 222 | generator=generator, 223 | ) 224 | elif self.batch_size_type == "frame": 225 | self.accelerator.even_batches = False 226 | sampler = SequentialSampler(train_dataset) 227 | batch_sampler = DynamicBatchSampler( 228 | sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False 229 | ) 230 | train_dataloader = DataLoader( 231 | train_dataset, 232 | collate_fn=collate_fn, 233 | num_workers=num_workers, 234 | pin_memory=True, 235 | persistent_workers=True, 236 | batch_sampler=batch_sampler, 237 | ) 238 | else: 239 | raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") 240 | 241 | # accelerator.prepare() dispatches batches to devices; 242 | # which means the length of dataloader calculated before, should consider the number of devices 243 | 
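        # LR schedule built below: linear warmup from ~0 to the base LR over warmup_steps,
        # then linear decay back toward 0 over the remaining updates; SequentialLR switches
        # between the two schedulers at the warmup milestone.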
warmup_steps = ( 244 | self.num_warmup_updates * self.accelerator.num_processes 245 | ) # consider a fixed warmup steps while using accelerate multi-gpu ddp 246 | # otherwise by default with split_batches=False, warmup steps change with num_processes 247 | total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps 248 | decay_steps = total_steps - warmup_steps 249 | warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) 250 | decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) 251 | self.scheduler = SequentialLR( 252 | self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps] 253 | ) 254 | train_dataloader, self.scheduler = self.accelerator.prepare( 255 | train_dataloader, self.scheduler 256 | ) # actual steps = 1 gpu steps / gpus 257 | start_step = self.load_checkpoint() 258 | global_step = start_step 259 | 260 | if exists(resumable_with_seed): 261 | orig_epoch_step = len(train_dataloader) 262 | skipped_epoch = int(start_step // orig_epoch_step) 263 | skipped_batch = start_step % orig_epoch_step 264 | skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch) 265 | else: 266 | skipped_epoch = 0 267 | 268 | for epoch in range(skipped_epoch, self.epochs): 269 | self.model.train() 270 | if exists(resumable_with_seed) and epoch == skipped_epoch: 271 | progress_bar = tqdm( 272 | skipped_dataloader, 273 | desc=f"Epoch {epoch+1}/{self.epochs}", 274 | unit="step", 275 | disable=not self.accelerator.is_local_main_process, 276 | initial=skipped_batch, 277 | total=orig_epoch_step, 278 | ) 279 | else: 280 | progress_bar = tqdm( 281 | train_dataloader, 282 | desc=f"Epoch {epoch+1}/{self.epochs}", 283 | unit="step", 284 | disable=not self.accelerator.is_local_main_process, 285 | ) 286 | 287 | for batch in progress_bar: 288 | with self.accelerator.accumulate(self.model): 289 | text_inputs = batch["text"] 290 | mel_spec = batch["mel"].permute(0, 2, 1) 291 | mel_lengths = batch["mel_lengths"] 292 | 293 | # TODO. 
add duration predictor training 294 | if self.duration_predictor is not None and self.accelerator.is_local_main_process: 295 | dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations")) 296 | self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step) 297 | 298 | loss, cond, pred = self.model( 299 | mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler 300 | ) 301 | self.accelerator.backward(loss) 302 | 303 | if self.max_grad_norm > 0 and self.accelerator.sync_gradients: 304 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 305 | 306 | self.optimizer.step() 307 | self.scheduler.step() 308 | self.optimizer.zero_grad() 309 | 310 | if self.is_main: 311 | self.ema_model.update() 312 | 313 | global_step += 1 314 | 315 | if self.accelerator.is_local_main_process: 316 | self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step) 317 | if self.logger == "tensorboard": 318 | self.writer.add_scalar("loss", loss.item(), global_step) 319 | self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step) 320 | 321 | progress_bar.set_postfix(step=str(global_step), loss=loss.item()) 322 | 323 | if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0: 324 | self.save_checkpoint(global_step) 325 | 326 | if self.log_samples and self.accelerator.is_local_main_process: 327 | ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0] 328 | torchaudio.save( 329 | f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate 330 | ) 331 | with torch.inference_mode(): 332 | generated, _ = self.accelerator.unwrap_model(self.model).sample( 333 | cond=mel_spec[0][:ref_audio_len].unsqueeze(0), 334 | text=[text_inputs[0] + [" "] + text_inputs[0]], 335 | duration=ref_audio_len * 2, 336 | steps=nfe_step, 337 | cfg_strength=cfg_strength, 338 | sway_sampling_coef=sway_sampling_coef, 339 | ) 340 | generated = generated.to(torch.float32) 341 | gen_audio = vocoder.decode( 342 | generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device) 343 | ) 344 | torchaudio.save( 345 | f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate 346 | ) 347 | 348 | if global_step % self.last_per_steps == 0: 349 | self.save_checkpoint(global_step, last=True) 350 | 351 | self.save_checkpoint(global_step, last=True) 352 | 353 | self.accelerator.end_training() 354 | -------------------------------------------------------------------------------- /src/f5_tts/model/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import random 5 | from collections import defaultdict 6 | from importlib.resources import files 7 | 8 | import torch 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | import jieba 12 | from pypinyin import lazy_pinyin, Style 13 | 14 | 15 | # seed everything 16 | 17 | 18 | def seed_everything(seed=0): 19 | random.seed(seed) 20 | os.environ["PYTHONHASHSEED"] = str(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | 27 | 28 | # helpers 29 | 30 | 31 | def exists(v): 32 | return v is not None 33 | 34 | 35 | def default(v, d): 36 | return v if exists(v) else d 37 | 38 | 39 | # tensor helpers 40 | 41 | 42 | def lens_to_mask(t: 
int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 43 | if not exists(length): 44 | length = t.amax() 45 | 46 | seq = torch.arange(length, device=t.device) 47 | return seq[None, :] < t[:, None] 48 | 49 | 50 | def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 51 | max_seq_len = seq_len.max().item() 52 | seq = torch.arange(max_seq_len, device=start.device).long() 53 | start_mask = seq[None, :] >= start[:, None] 54 | end_mask = seq[None, :] < end[:, None] 55 | return start_mask & end_mask 56 | 57 | 58 | def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 59 | lengths = (frac_lengths * seq_len).long() 60 | max_start = seq_len - lengths 61 | 62 | rand = torch.rand_like(frac_lengths) 63 | start = (max_start * rand).long().clamp(min=0) 64 | end = start + lengths 65 | 66 | return mask_from_start_end_indices(seq_len, start, end) 67 | 68 | 69 | def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 70 | if not exists(mask): 71 | return t.mean(dim=1) 72 | 73 | t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) 74 | num = t.sum(dim=1) 75 | den = mask.float().sum(dim=1) 76 | 77 | return num / den.clamp(min=1.0) 78 | 79 | 80 | # simple utf-8 tokenizer, since paper went character based 81 | def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 82 | list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style 83 | text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) 84 | return text 85 | 86 | 87 | # char tokenizer, based on custom dataset's extracted .txt file 88 | def list_str_to_idx( 89 | text: list[str] | list[list[str]], 90 | vocab_char_map: dict[str, int], # {char: idx} 91 | padding_value=-1, 92 | ) -> int["b nt"]: # noqa: F722 93 | list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style 94 | text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) 95 | return text 96 | 97 | 98 | # Get tokenizer 99 | 100 | 101 | def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): 102 | """ 103 | tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file 104 | - "char" for char-wise tokenizer, need .txt vocab_file 105 | - "byte" for utf-8 tokenizer 106 | - "custom" if you're directly passing in a path to the vocab.txt you want to use 107 | vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols 108 | - if use "char", derived from unfiltered character & symbol counts of custom dataset 109 | - if use "byte", set to 256 (unicode byte range) 110 | """ 111 | if tokenizer in ["pinyin", "char"]: 112 | tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") 113 | with open(tokenizer_path, "r", encoding="utf-8") as f: 114 | vocab_char_map = {} 115 | for i, char in enumerate(f): 116 | vocab_char_map[char[:-1]] = i 117 | vocab_size = len(vocab_char_map) 118 | assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" 119 | 120 | elif tokenizer == "byte": 121 | vocab_char_map = None 122 | vocab_size = 256 123 | 124 | elif tokenizer == "custom": 125 | with open(dataset_name, "r", encoding="utf-8") as f: 126 | vocab_char_map = {} 127 | for i, char in enumerate(f): 128 | vocab_char_map[char[:-1]] = i 129 | vocab_size = 
len(vocab_char_map) 130 | 131 | return vocab_char_map, vocab_size 132 | 133 | 134 | # convert char to pinyin 135 | 136 | 137 | def convert_char_to_pinyin(text_list, polyphone=True): 138 | final_text_list = [] 139 | god_knows_why_en_testset_contains_zh_quote = str.maketrans( 140 | {"“": '"', "”": '"', "‘": "'", "’": "'"} 141 | ) # in case librispeech (orig no-pc) test-clean 142 | custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov 143 | for text in text_list: 144 | char_list = [] 145 | text = text.translate(god_knows_why_en_testset_contains_zh_quote) 146 | text = text.translate(custom_trans) 147 | for seg in jieba.cut(text): 148 | seg_byte_len = len(bytes(seg, "UTF-8")) 149 | if seg_byte_len == len(seg): # if pure alphabets and symbols 150 | if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": 151 | char_list.append(" ") 152 | char_list.extend(seg) 153 | elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters 154 | seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) 155 | for c in seg: 156 | if c not in "。,、;:?!《》【】—…": 157 | char_list.append(" ") 158 | char_list.append(c) 159 | else: # if mixed chinese characters, alphabets and symbols 160 | for c in seg: 161 | if ord(c) < 256: 162 | char_list.extend(c) 163 | else: 164 | if c not in "。,、;:?!《》【】—…": 165 | char_list.append(" ") 166 | char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) 167 | else: # if is zh punc 168 | char_list.append(c) 169 | final_text_list.append(char_list) 170 | 171 | return final_text_list 172 | 173 | 174 | # filter func for dirty data with many repetitions 175 | 176 | 177 | def repetition_found(text, length=2, tolerance=10): 178 | pattern_count = defaultdict(int) 179 | for i in range(len(text) - length + 1): 180 | pattern = text[i : i + length] 181 | pattern_count[pattern] += 1 182 | for pattern, count in pattern_count.items(): 183 | if count > tolerance: 184 | return True 185 | return False 186 | -------------------------------------------------------------------------------- /src/f5_tts/scripts/count_max_epoch.py: -------------------------------------------------------------------------------- 1 | """ADAPTIVE BATCH SIZE""" 2 | 3 | print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in") 4 | print(" -> least padding, gather wavs with accumulated frames in a batch\n") 5 | 6 | # data 7 | total_hours = 95282 8 | mel_hop_length = 256 9 | mel_sampling_rate = 24000 10 | 11 | # target 12 | wanted_max_updates = 1000000 13 | 14 | # train params 15 | gpus = 8 16 | frames_per_gpu = 38400 # 8 * 38400 = 307200 17 | grad_accum = 1 18 | 19 | # intermediate 20 | mini_batch_frames = frames_per_gpu * grad_accum * gpus 21 | mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600 22 | updates_per_epoch = total_hours / mini_batch_hours 23 | steps_per_epoch = updates_per_epoch * grad_accum 24 | 25 | # result 26 | epochs = wanted_max_updates / updates_per_epoch 27 | print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})") 28 | print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates") 29 | print(f" or approx. 
0/{steps_per_epoch:.0f} steps") 30 | 31 | # others 32 | print(f"total {total_hours:.0f} hours") 33 | print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch") 34 | -------------------------------------------------------------------------------- /src/f5_tts/scripts/count_params_gflops.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | from f5_tts.model import CFM, DiT 7 | 8 | import torch 9 | import thop 10 | 11 | 12 | """ ~155M """ 13 | # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4) 14 | # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4) 15 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2) 16 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4) 17 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True) 18 | # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2) 19 | 20 | """ ~335M """ 21 | # FLOPs: 622.1 G, Params: 333.2 M 22 | # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4) 23 | # FLOPs: 363.4 G, Params: 335.8 M 24 | transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 25 | 26 | 27 | model = CFM(transformer=transformer) 28 | target_sample_rate = 24000 29 | n_mel_channels = 100 30 | hop_length = 256 31 | duration = 20 32 | frame_length = int(duration * target_sample_rate / hop_length) 33 | text_length = 150 34 | 35 | flops, params = thop.profile( 36 | model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)) 37 | ) 38 | print(f"FLOPs: {flops / 1e9} G") 39 | print(f"Params: {params / 1e6} M") 40 | -------------------------------------------------------------------------------- /src/f5_tts/socket_server.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import struct 3 | import torch 4 | import torchaudio 5 | from threading import Thread 6 | 7 | 8 | import gc 9 | import traceback 10 | 11 | 12 | from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model 13 | from model.backbones.dit import DiT 14 | 15 | 16 | class TTSStreamingProcessor: 17 | def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): 18 | self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | # Load the model using the provided checkpoint and vocab files 21 | self.model = load_model( 22 | model_cls=DiT, 23 | model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4), 24 | ckpt_path=ckpt_file, 25 | mel_spec_type="vocos", # or "bigvgan" depending on vocoder 26 | vocab_file=vocab_file, 27 | ode_method="euler", 28 | use_ema=True, 29 | device=self.device, 30 | ).to(self.device, dtype=dtype) 31 | 32 | # Load the vocoder 33 | self.vocoder = load_vocoder(is_local=False) 34 | 35 | # Set sampling rate for streaming 36 | self.sampling_rate = 24000 # Consistency with client 37 | 38 | # Set reference audio and text 39 | self.ref_audio = ref_audio 40 | self.ref_text = ref_text 41 | 42 | # Warm up the model 43 | self._warm_up() 44 | 45 | def _warm_up(self): 46 | """Warm up the model with a dummy input to ensure it's ready for real-time processing.""" 47 | 
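        # Run one pass through the same infer_batch_process path used for real requests,
        # so the first client call does not pay any one-time setup cost.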
print("Warming up the model...") 48 | ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text) 49 | audio, sr = torchaudio.load(ref_audio) 50 | gen_text = "Warm-up text for the model." 51 | 52 | # Pass the vocoder as an argument here 53 | infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device) 54 | print("Warm-up completed.") 55 | 56 | def generate_stream(self, text, play_steps_in_s=0.5): 57 | """Generate audio in chunks and yield them in real-time.""" 58 | # Preprocess the reference audio and text 59 | ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text) 60 | 61 | # Load reference audio 62 | audio, sr = torchaudio.load(ref_audio) 63 | 64 | # Run inference for the input text 65 | audio_chunk, final_sample_rate, _ = infer_batch_process( 66 | (audio, sr), 67 | ref_text, 68 | [text], 69 | self.model, 70 | self.vocoder, 71 | device=self.device, # Pass vocoder here 72 | ) 73 | 74 | # Break the generated audio into chunks and send them 75 | chunk_size = int(final_sample_rate * play_steps_in_s) 76 | 77 | for i in range(0, len(audio_chunk), chunk_size): 78 | chunk = audio_chunk[i : i + chunk_size] 79 | 80 | # Check if it's the final chunk 81 | if i + chunk_size >= len(audio_chunk): 82 | chunk = audio_chunk[i:] 83 | 84 | # Avoid sending empty or repeated chunks 85 | if len(chunk) == 0: 86 | break 87 | 88 | # Pack and send the audio chunk 89 | packed_audio = struct.pack(f"{len(chunk)}f", *chunk) 90 | yield packed_audio 91 | 92 | # Ensure that no final word is repeated by not resending partial chunks 93 | if len(audio_chunk) % chunk_size != 0: 94 | remaining_chunk = audio_chunk[-(len(audio_chunk) % chunk_size) :] 95 | packed_audio = struct.pack(f"{len(remaining_chunk)}f", *remaining_chunk) 96 | yield packed_audio 97 | 98 | 99 | def handle_client(client_socket, processor): 100 | try: 101 | while True: 102 | # Receive data from the client 103 | data = client_socket.recv(1024).decode("utf-8") 104 | if not data: 105 | break 106 | 107 | try: 108 | # The client sends the text input 109 | text = data.strip() 110 | 111 | # Generate and stream audio chunks 112 | for audio_chunk in processor.generate_stream(text): 113 | client_socket.sendall(audio_chunk) 114 | 115 | # Send end-of-audio signal 116 | client_socket.sendall(b"END_OF_AUDIO") 117 | 118 | except Exception as inner_e: 119 | print(f"Error during processing: {inner_e}") 120 | traceback.print_exc() # Print the full traceback to diagnose the issue 121 | break 122 | 123 | except Exception as e: 124 | print(f"Error handling client: {e}") 125 | traceback.print_exc() 126 | finally: 127 | client_socket.close() 128 | 129 | 130 | def start_server(host, port, processor): 131 | server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 132 | server.bind((host, port)) 133 | server.listen(5) 134 | print(f"Server listening on {host}:{port}") 135 | 136 | while True: 137 | client_socket, addr = server.accept() 138 | print(f"Accepted connection from {addr}") 139 | client_handler = Thread(target=handle_client, args=(client_socket, processor)) 140 | client_handler.start() 141 | 142 | 143 | if __name__ == "__main__": 144 | try: 145 | # Load the model and vocoder using the provided files 146 | ckpt_file = "" # pointing your checkpoint "ckpts/model/model_1096.pt" 147 | vocab_file = "" # Add vocab file path if needed 148 | ref_audio = "" # add ref audio"./tests/ref_audio/reference.wav" 149 | ref_text = "" 150 | 151 | # Initialize the processor with the model and vocoder 152 | 
processor = TTSStreamingProcessor( 153 | ckpt_file=ckpt_file, 154 | vocab_file=vocab_file, 155 | ref_audio=ref_audio, 156 | ref_text=ref_text, 157 | dtype=torch.float32, 158 | ) 159 | 160 | # Start the server 161 | start_server("0.0.0.0", 9998, processor) 162 | except KeyboardInterrupt: 163 | gc.collect() 164 | -------------------------------------------------------------------------------- /src/f5_tts/train/README.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | ## Prepare Dataset 4 | 5 | Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `src/f5_tts/model/dataset.py`. 6 | 7 | ### 1. Datasets used for pretrained models 8 | Download corresponding dataset first, and fill in the path in scripts. 9 | 10 | ```bash 11 | # Prepare the Emilia dataset 12 | python src/f5_tts/train/datasets/prepare_emilia.py 13 | 14 | # Prepare the Wenetspeech4TTS dataset 15 | python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py 16 | ``` 17 | 18 | ### 2. Create custom dataset with metadata.csv 19 | Use guidance see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029). 20 | 21 | ```bash 22 | python src/f5_tts/train/datasets/prepare_csv_wavs.py 23 | ``` 24 | 25 | ## Training & Finetuning 26 | 27 | Once your datasets are prepared, you can start the training process. 28 | 29 | ### 1. Training script used for pretrained model 30 | 31 | ```bash 32 | # setup accelerate config, e.g. use multi-gpu ddp, fp16 33 | # will be to: ~/.cache/huggingface/accelerate/default_config.yaml 34 | accelerate config 35 | accelerate launch src/f5_tts/train/train.py 36 | ``` 37 | 38 | ### 2. Finetuning practice 39 | Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57). 40 | 41 | Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143). 42 | 43 | ### 3. Wandb Logging 44 | 45 | The `wandb/` dir will be created under path you run training/finetuning scripts. 46 | 47 | By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`). 48 | 49 | To turn on wandb logging, you can either: 50 | 51 | 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login) 52 | 2. 
Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows: 53 | 54 | On Mac & Linux: 55 | 56 | ``` 57 | export WANDB_API_KEY= 58 | ``` 59 | 60 | On Windows: 61 | 62 | ``` 63 | set WANDB_API_KEY= 64 | ``` 65 | Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows: 66 | 67 | ``` 68 | export WANDB_MODE=offline 69 | ``` 70 | -------------------------------------------------------------------------------- /src/f5_tts/train/datasets/prepare_csv_wavs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import argparse 7 | import csv 8 | import json 9 | import shutil 10 | from importlib.resources import files 11 | from pathlib import Path 12 | 13 | import torchaudio 14 | from tqdm import tqdm 15 | from datasets.arrow_writer import ArrowWriter 16 | 17 | from f5_tts.model.utils import ( 18 | convert_char_to_pinyin, 19 | ) 20 | 21 | 22 | PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt") 23 | 24 | 25 | def is_csv_wavs_format(input_dataset_dir): 26 | fpath = Path(input_dataset_dir) 27 | metadata = fpath / "metadata.csv" 28 | wavs = fpath / "wavs" 29 | return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir() 30 | 31 | 32 | def prepare_csv_wavs_dir(input_dir): 33 | assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}" 34 | input_dir = Path(input_dir) 35 | metadata_path = input_dir / "metadata.csv" 36 | audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix()) 37 | 38 | sub_result, durations = [], [] 39 | vocab_set = set() 40 | polyphone = True 41 | for audio_path, text in audio_path_text_pairs: 42 | if not Path(audio_path).exists(): 43 | print(f"audio {audio_path} not found, skipping") 44 | continue 45 | audio_duration = get_audio_duration(audio_path) 46 | # assume tokenizer = "pinyin" ("pinyin" | "char") 47 | text = convert_char_to_pinyin([text], polyphone=polyphone)[0] 48 | sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration}) 49 | durations.append(audio_duration) 50 | vocab_set.update(list(text)) 51 | 52 | return sub_result, durations, vocab_set 53 | 54 | 55 | def get_audio_duration(audio_path): 56 | audio, sample_rate = torchaudio.load(audio_path) 57 | return audio.shape[1] / sample_rate 58 | 59 | 60 | def read_audio_text_pairs(csv_file_path): 61 | audio_text_pairs = [] 62 | 63 | parent = Path(csv_file_path).parent 64 | with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile: 65 | reader = csv.reader(csvfile, delimiter="|") 66 | next(reader) # Skip the header row 67 | for row in reader: 68 | if len(row) >= 2: 69 | audio_file = row[0].strip() # First column: audio file path 70 | text = row[1].strip() # Second column: text 71 | audio_file_path = parent / audio_file 72 | audio_text_pairs.append((audio_file_path.as_posix(), text)) 73 | 74 | return audio_text_pairs 75 | 76 | 77 | def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune): 78 | out_dir = Path(out_dir) 79 | # save preprocessed dataset to disk 80 | out_dir.mkdir(exist_ok=True, parents=True) 81 | print(f"\nSaving to {out_dir} ...") 82 | 83 | # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom 84 | # 
dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB") 85 | raw_arrow_path = out_dir / "raw.arrow" 86 | with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer: 87 | for line in tqdm(result, desc="Writing to raw.arrow ..."): 88 | writer.write(line) 89 | 90 | # dup a json separately saving duration in case for DynamicBatchSampler ease 91 | dur_json_path = out_dir / "duration.json" 92 | with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f: 93 | json.dump({"duration": duration_list}, f, ensure_ascii=False) 94 | 95 | # vocab map, i.e. tokenizer 96 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 97 | # if tokenizer == "pinyin": 98 | # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) 99 | voca_out_path = out_dir / "vocab.txt" 100 | with open(voca_out_path.as_posix(), "w") as f: 101 | for vocab in sorted(text_vocab_set): 102 | f.write(vocab + "\n") 103 | 104 | if is_finetune: 105 | file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix() 106 | shutil.copy2(file_vocab_finetune, voca_out_path) 107 | else: 108 | with open(voca_out_path, "w") as f: 109 | for vocab in sorted(text_vocab_set): 110 | f.write(vocab + "\n") 111 | 112 | dataset_name = out_dir.stem 113 | print(f"\nFor {dataset_name}, sample count: {len(result)}") 114 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") 115 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") 116 | 117 | 118 | def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True): 119 | if is_finetune: 120 | assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}" 121 | sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir) 122 | save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune) 123 | 124 | 125 | def cli(): 126 | # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin 127 | # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain 128 | parser = argparse.ArgumentParser(description="Prepare and save dataset.") 129 | parser.add_argument("inp_dir", type=str, help="Input directory containing the data.") 130 | parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.") 131 | parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune") 132 | 133 | args = parser.parse_args() 134 | 135 | prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain) 136 | 137 | 138 | if __name__ == "__main__": 139 | cli() 140 | -------------------------------------------------------------------------------- /src/f5_tts/train/datasets/prepare_emilia.py: -------------------------------------------------------------------------------- 1 | # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07 2 | # if use updated new version, i.e. 
WebDataset, feel free to modify / draft your own script 3 | 4 | # generate audio text map for Emilia ZH & EN 5 | # evaluate for vocab size 6 | 7 | import os 8 | import sys 9 | 10 | sys.path.append(os.getcwd()) 11 | 12 | import json 13 | from concurrent.futures import ProcessPoolExecutor 14 | from importlib.resources import files 15 | from pathlib import Path 16 | from tqdm import tqdm 17 | 18 | from datasets.arrow_writer import ArrowWriter 19 | 20 | from f5_tts.model.utils import ( 21 | repetition_found, 22 | convert_char_to_pinyin, 23 | ) 24 | 25 | 26 | out_zh = { 27 | "ZH_B00041_S06226", 28 | "ZH_B00042_S09204", 29 | "ZH_B00065_S09430", 30 | "ZH_B00065_S09431", 31 | "ZH_B00066_S09327", 32 | "ZH_B00066_S09328", 33 | } 34 | zh_filters = ["い", "て"] 35 | # seems synthesized audios, or heavily code-switched 36 | out_en = { 37 | "EN_B00013_S00913", 38 | "EN_B00042_S00120", 39 | "EN_B00055_S04111", 40 | "EN_B00061_S00693", 41 | "EN_B00061_S01494", 42 | "EN_B00061_S03375", 43 | "EN_B00059_S00092", 44 | "EN_B00111_S04300", 45 | "EN_B00100_S03759", 46 | "EN_B00087_S03811", 47 | "EN_B00059_S00950", 48 | "EN_B00089_S00946", 49 | "EN_B00078_S05127", 50 | "EN_B00070_S04089", 51 | "EN_B00074_S09659", 52 | "EN_B00061_S06983", 53 | "EN_B00061_S07060", 54 | "EN_B00059_S08397", 55 | "EN_B00082_S06192", 56 | "EN_B00091_S01238", 57 | "EN_B00089_S07349", 58 | "EN_B00070_S04343", 59 | "EN_B00061_S02400", 60 | "EN_B00076_S01262", 61 | "EN_B00068_S06467", 62 | "EN_B00076_S02943", 63 | "EN_B00064_S05954", 64 | "EN_B00061_S05386", 65 | "EN_B00066_S06544", 66 | "EN_B00076_S06944", 67 | "EN_B00072_S08620", 68 | "EN_B00076_S07135", 69 | "EN_B00076_S09127", 70 | "EN_B00065_S00497", 71 | "EN_B00059_S06227", 72 | "EN_B00063_S02859", 73 | "EN_B00075_S01547", 74 | "EN_B00061_S08286", 75 | "EN_B00079_S02901", 76 | "EN_B00092_S03643", 77 | "EN_B00096_S08653", 78 | "EN_B00063_S04297", 79 | "EN_B00063_S04614", 80 | "EN_B00079_S04698", 81 | "EN_B00104_S01666", 82 | "EN_B00061_S09504", 83 | "EN_B00061_S09694", 84 | "EN_B00065_S05444", 85 | "EN_B00063_S06860", 86 | "EN_B00065_S05725", 87 | "EN_B00069_S07628", 88 | "EN_B00083_S03875", 89 | "EN_B00071_S07665", 90 | "EN_B00071_S07665", 91 | "EN_B00062_S04187", 92 | "EN_B00065_S09873", 93 | "EN_B00065_S09922", 94 | "EN_B00084_S02463", 95 | "EN_B00067_S05066", 96 | "EN_B00106_S08060", 97 | "EN_B00073_S06399", 98 | "EN_B00073_S09236", 99 | "EN_B00087_S00432", 100 | "EN_B00085_S05618", 101 | "EN_B00064_S01262", 102 | "EN_B00072_S01739", 103 | "EN_B00059_S03913", 104 | "EN_B00069_S04036", 105 | "EN_B00067_S05623", 106 | "EN_B00060_S05389", 107 | "EN_B00060_S07290", 108 | "EN_B00062_S08995", 109 | } 110 | en_filters = ["ا", "い", "て"] 111 | 112 | 113 | def deal_with_audio_dir(audio_dir): 114 | audio_jsonl = audio_dir.with_suffix(".jsonl") 115 | sub_result, durations = [], [] 116 | vocab_set = set() 117 | bad_case_zh = 0 118 | bad_case_en = 0 119 | with open(audio_jsonl, "r") as f: 120 | lines = f.readlines() 121 | for line in tqdm(lines, desc=f"{audio_jsonl.stem}"): 122 | obj = json.loads(line) 123 | text = obj["text"] 124 | if obj["language"] == "zh": 125 | if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text): 126 | bad_case_zh += 1 127 | continue 128 | else: 129 | text = text.translate( 130 | str.maketrans({",": ",", "!": "!", "?": "?"}) 131 | ) # not "。" cuz much code-switched 132 | if obj["language"] == "en": 133 | if ( 134 | obj["wav"].split("/")[1] in out_en 135 | or any(f in text for f in en_filters) 136 | or repetition_found(text, 
length=4) 137 | ): 138 | bad_case_en += 1 139 | continue 140 | if tokenizer == "pinyin": 141 | text = convert_char_to_pinyin([text], polyphone=polyphone)[0] 142 | duration = obj["duration"] 143 | sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration}) 144 | durations.append(duration) 145 | vocab_set.update(list(text)) 146 | return sub_result, durations, vocab_set, bad_case_zh, bad_case_en 147 | 148 | 149 | def main(): 150 | assert tokenizer in ["pinyin", "char"] 151 | result = [] 152 | duration_list = [] 153 | text_vocab_set = set() 154 | total_bad_case_zh = 0 155 | total_bad_case_en = 0 156 | 157 | # process raw data 158 | executor = ProcessPoolExecutor(max_workers=max_workers) 159 | futures = [] 160 | for lang in langs: 161 | dataset_path = Path(os.path.join(dataset_dir, lang)) 162 | [ 163 | futures.append(executor.submit(deal_with_audio_dir, audio_dir)) 164 | for audio_dir in dataset_path.iterdir() 165 | if audio_dir.is_dir() 166 | ] 167 | for futures in tqdm(futures, total=len(futures)): 168 | sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result() 169 | result.extend(sub_result) 170 | duration_list.extend(durations) 171 | text_vocab_set.update(vocab_set) 172 | total_bad_case_zh += bad_case_zh 173 | total_bad_case_en += bad_case_en 174 | executor.shutdown() 175 | 176 | # save preprocessed dataset to disk 177 | if not os.path.exists(f"{save_dir}"): 178 | os.makedirs(f"{save_dir}") 179 | print(f"\nSaving to {save_dir} ...") 180 | 181 | # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom 182 | # dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") 183 | with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer: 184 | for line in tqdm(result, desc="Writing to raw.arrow ..."): 185 | writer.write(line) 186 | 187 | # dup a json separately saving duration in case for DynamicBatchSampler ease 188 | with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: 189 | json.dump({"duration": duration_list}, f, ensure_ascii=False) 190 | 191 | # vocab map, i.e. tokenizer 192 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 
193 | # if tokenizer == "pinyin": 194 | # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) 195 | with open(f"{save_dir}/vocab.txt", "w") as f: 196 | for vocab in sorted(text_vocab_set): 197 | f.write(vocab + "\n") 198 | 199 | print(f"\nFor {dataset_name}, sample count: {len(result)}") 200 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") 201 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") 202 | if "ZH" in langs: 203 | print(f"Bad zh transcription case: {total_bad_case_zh}") 204 | if "EN" in langs: 205 | print(f"Bad en transcription case: {total_bad_case_en}\n") 206 | 207 | 208 | if __name__ == "__main__": 209 | max_workers = 32 210 | 211 | tokenizer = "pinyin" # "pinyin" | "char" 212 | polyphone = True 213 | 214 | langs = ["ZH", "EN"] 215 | dataset_dir = "/Emilia_Dataset/raw" 216 | dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}" 217 | save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" 218 | print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") 219 | 220 | main() 221 | 222 | # Emilia ZH & EN 223 | # samples count 37837916 (after removal) 224 | # pinyin vocab size 2543 (polyphone) 225 | # total duration 95281.87 (hours) 226 | # bad zh asr cnt 230435 (samples) 227 | # bad eh asr cnt 37217 (samples) 228 | 229 | # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme) 230 | # please be careful if using pretrained model, make sure the vocab.txt is same 231 | -------------------------------------------------------------------------------- /src/f5_tts/train/datasets/prepare_wenetspeech4tts.py: -------------------------------------------------------------------------------- 1 | # generate audio text map for WenetSpeech4TTS 2 | # evaluate for vocab size 3 | 4 | import os 5 | import sys 6 | 7 | sys.path.append(os.getcwd()) 8 | 9 | import json 10 | from concurrent.futures import ProcessPoolExecutor 11 | from importlib.resources import files 12 | from tqdm import tqdm 13 | 14 | import torchaudio 15 | from datasets import Dataset 16 | 17 | from f5_tts.model.utils import convert_char_to_pinyin 18 | 19 | 20 | def deal_with_sub_path_files(dataset_path, sub_path): 21 | print(f"Dealing with: {sub_path}") 22 | 23 | text_dir = os.path.join(dataset_path, sub_path, "txts") 24 | audio_dir = os.path.join(dataset_path, sub_path, "wavs") 25 | text_files = os.listdir(text_dir) 26 | 27 | audio_paths, texts, durations = [], [], [] 28 | for text_file in tqdm(text_files): 29 | with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file: 30 | first_line = file.readline().split("\t") 31 | audio_nm = first_line[0] 32 | audio_path = os.path.join(audio_dir, audio_nm + ".wav") 33 | text = first_line[1].strip() 34 | 35 | audio_paths.append(audio_path) 36 | 37 | if tokenizer == "pinyin": 38 | texts.extend(convert_char_to_pinyin([text], polyphone=polyphone)) 39 | elif tokenizer == "char": 40 | texts.append(text) 41 | 42 | audio, sample_rate = torchaudio.load(audio_path) 43 | durations.append(audio.shape[-1] / sample_rate) 44 | 45 | return audio_paths, texts, durations 46 | 47 | 48 | def main(): 49 | assert tokenizer in ["pinyin", "char"] 50 | 51 | audio_path_list, text_list, duration_list = [], [], [] 52 | 53 | executor = ProcessPoolExecutor(max_workers=max_workers) 54 | futures = [] 55 | for dataset_path in dataset_paths: 56 | sub_items = os.listdir(dataset_path) 57 | sub_paths = [item for item in sub_items if 
os.path.isdir(os.path.join(dataset_path, item))] 58 | for sub_path in sub_paths: 59 | futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path)) 60 | for future in tqdm(futures, total=len(futures)): 61 | audio_paths, texts, durations = future.result() 62 | audio_path_list.extend(audio_paths) 63 | text_list.extend(texts) 64 | duration_list.extend(durations) 65 | executor.shutdown() 66 | 67 | if not os.path.exists("data"): 68 | os.makedirs("data") 69 | 70 | print(f"\nSaving to {save_dir} ...") 71 | dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) 72 | dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") # arrow format 73 | 74 | with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: 75 | json.dump( 76 | {"duration": duration_list}, f, ensure_ascii=False 77 | ) # dup a json separately saving duration in case for DynamicBatchSampler ease 78 | 79 | print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...") 80 | text_vocab_set = set() 81 | for text in tqdm(text_list): 82 | text_vocab_set.update(list(text)) 83 | 84 | # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 85 | if tokenizer == "pinyin": 86 | text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) 87 | 88 | with open(f"{save_dir}/vocab.txt", "w") as f: 89 | for vocab in sorted(text_vocab_set): 90 | f.write(vocab + "\n") 91 | print(f"\nFor {dataset_name}, sample count: {len(text_list)}") 92 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n") 93 | 94 | 95 | if __name__ == "__main__": 96 | max_workers = 32 97 | 98 | tokenizer = "pinyin" # "pinyin" | "char" 99 | polyphone = True 100 | dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic 101 | 102 | dataset_name = ( 103 | ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1] 104 | + "_" 105 | + tokenizer 106 | ) 107 | dataset_paths = [ 108 | "/WenetSpeech4TTS/Basic", 109 | "/WenetSpeech4TTS/Standard", 110 | "/WenetSpeech4TTS/Premium", 111 | ][-dataset_choice:] 112 | save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" 113 | print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n") 114 | 115 | main() 116 | 117 | # Results (if adding alphabets with accents and symbols): 118 | # WenetSpeech4TTS Basic Standard Premium 119 | # samples count 3932473 1941220 407494 120 | # pinyin vocab size 1349 1348 1344 (no polyphone) 121 | # - - 1459 (polyphone) 122 | # char vocab size 5264 5219 5042 123 | 124 | # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. 
way of polyphoneme) 125 | # please be careful if using pretrained model, make sure the vocab.txt is same 126 | -------------------------------------------------------------------------------- /src/f5_tts/train/finetune_cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | from cached_path import cached_path 6 | from f5_tts.model import CFM, UNetT, DiT, Trainer 7 | from f5_tts.model.utils import get_tokenizer 8 | from f5_tts.model.dataset import load_dataset 9 | from importlib.resources import files 10 | 11 | 12 | # -------------------------- Dataset Settings --------------------------- # 13 | target_sample_rate = 24000 14 | n_mel_channels = 100 15 | hop_length = 256 16 | win_length = 1024 17 | n_fft = 1024 18 | mel_spec_type = "vocos" # 'vocos' or 'bigvgan' 19 | 20 | 21 | # -------------------------- Argument Parsing --------------------------- # 22 | def parse_args(): 23 | # batch_size_per_gpu = 1000 settting for gpu 8GB 24 | # batch_size_per_gpu = 1600 settting for gpu 12GB 25 | # batch_size_per_gpu = 2000 settting for gpu 16GB 26 | # batch_size_per_gpu = 3200 settting for gpu 24GB 27 | 28 | # num_warmup_updates = 300 for 5000 sample about 10 hours 29 | 30 | # change save_per_updates , last_per_steps change this value what you need , 31 | 32 | parser = argparse.ArgumentParser(description="Train CFM Model") 33 | 34 | parser.add_argument( 35 | "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name" 36 | ) 37 | parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use") 38 | parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training") 39 | parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU") 40 | parser.add_argument( 41 | "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type" 42 | ) 43 | parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch") 44 | parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") 45 | parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping") 46 | parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs") 47 | parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps") 48 | parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps") 49 | parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps") 50 | parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune") 51 | parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint") 52 | parser.add_argument( 53 | "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type" 54 | ) 55 | parser.add_argument( 56 | "--tokenizer_path", 57 | type=str, 58 | default=None, 59 | help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')", 60 | ) 61 | parser.add_argument( 62 | "--log_samples", 63 | type=bool, 64 | default=False, 65 | help="Log inferenced samples per ckpt save steps", 66 | ) 67 | parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger") 68 | parser.add_argument( 69 | 
"--bnb_optimizer", 70 | type=bool, 71 | default=False, 72 | help="Use 8-bit Adam optimizer from bitsandbytes", 73 | ) 74 | 75 | return parser.parse_args() 76 | 77 | 78 | # -------------------------- Training Settings -------------------------- # 79 | 80 | 81 | def main(): 82 | args = parse_args() 83 | 84 | checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}")) 85 | 86 | # Model parameters based on experiment name 87 | if args.exp_name == "F5TTS_Base": 88 | wandb_resume_id = None 89 | model_cls = DiT 90 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 91 | if args.finetune: 92 | if args.pretrain is None: 93 | ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) 94 | else: 95 | ckpt_path = args.pretrain 96 | elif args.exp_name == "E2TTS_Base": 97 | wandb_resume_id = None 98 | model_cls = UNetT 99 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 100 | if args.finetune: 101 | if args.pretrain is None: 102 | ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) 103 | else: 104 | ckpt_path = args.pretrain 105 | 106 | if args.finetune: 107 | if not os.path.isdir(checkpoint_path): 108 | os.makedirs(checkpoint_path, exist_ok=True) 109 | 110 | file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path)) 111 | if not os.path.isfile(file_checkpoint): 112 | shutil.copy2(ckpt_path, file_checkpoint) 113 | print("copy checkpoint for finetune") 114 | 115 | # Use the tokenizer and tokenizer_path provided in the command line arguments 116 | tokenizer = args.tokenizer 117 | if tokenizer == "custom": 118 | if not args.tokenizer_path: 119 | raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.") 120 | tokenizer_path = args.tokenizer_path 121 | else: 122 | tokenizer_path = args.dataset_name 123 | 124 | vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) 125 | 126 | print("\nvocab : ", vocab_size) 127 | print("\nvocoder : ", mel_spec_type) 128 | 129 | mel_spec_kwargs = dict( 130 | n_fft=n_fft, 131 | hop_length=hop_length, 132 | win_length=win_length, 133 | n_mel_channels=n_mel_channels, 134 | target_sample_rate=target_sample_rate, 135 | mel_spec_type=mel_spec_type, 136 | ) 137 | 138 | model = CFM( 139 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), 140 | mel_spec_kwargs=mel_spec_kwargs, 141 | vocab_char_map=vocab_char_map, 142 | ) 143 | 144 | trainer = Trainer( 145 | model, 146 | args.epochs, 147 | args.learning_rate, 148 | num_warmup_updates=args.num_warmup_updates, 149 | save_per_updates=args.save_per_updates, 150 | checkpoint_path=checkpoint_path, 151 | batch_size=args.batch_size_per_gpu, 152 | batch_size_type=args.batch_size_type, 153 | max_samples=args.max_samples, 154 | grad_accumulation_steps=args.grad_accumulation_steps, 155 | max_grad_norm=args.max_grad_norm, 156 | logger=args.logger, 157 | wandb_project=args.dataset_name, 158 | wandb_run_name=args.exp_name, 159 | wandb_resume_id=wandb_resume_id, 160 | log_samples=args.log_samples, 161 | last_per_steps=args.last_per_steps, 162 | bnb_optimizer=args.bnb_optimizer, 163 | ) 164 | 165 | train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) 166 | 167 | trainer.train( 168 | train_dataset, 169 | resumable_with_seed=666, # seed for shuffling dataset 170 | ) 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- 
/src/f5_tts/train/train.py: -------------------------------------------------------------------------------- 1 | # training script. 2 | 3 | from importlib.resources import files 4 | 5 | from f5_tts.model import CFM, DiT, Trainer, UNetT 6 | from f5_tts.model.dataset import load_dataset 7 | from f5_tts.model.utils import get_tokenizer 8 | 9 | # -------------------------- Dataset Settings --------------------------- # 10 | 11 | target_sample_rate = 24000 12 | n_mel_channels = 100 13 | hop_length = 256 14 | win_length = 1024 15 | n_fft = 1024 16 | mel_spec_type = "vocos" # 'vocos' or 'bigvgan' 17 | 18 | tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' 19 | tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) 20 | dataset_name = "Emilia_ZH_EN" 21 | 22 | # -------------------------- Training Settings -------------------------- # 23 | 24 | exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base 25 | 26 | learning_rate = 7.5e-5 27 | 28 | batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200 29 | batch_size_type = "frame" # "frame" or "sample" 30 | max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models 31 | grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps 32 | max_grad_norm = 1.0 33 | 34 | epochs = 11 # use linear decay, thus epochs control the slope 35 | num_warmup_updates = 20000 # warmup steps 36 | save_per_updates = 50000 # save checkpoint per steps 37 | last_per_steps = 5000 # save last checkpoint per steps 38 | 39 | # model params 40 | if exp_name == "F5TTS_Base": 41 | wandb_resume_id = None 42 | model_cls = DiT 43 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) 44 | elif exp_name == "E2TTS_Base": 45 | wandb_resume_id = None 46 | model_cls = UNetT 47 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) 48 | 49 | 50 | # ----------------------------------------------------------------------- # 51 | 52 | 53 | def main(): 54 | if tokenizer == "custom": 55 | tokenizer_path = tokenizer_path 56 | else: 57 | tokenizer_path = dataset_name 58 | vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) 59 | 60 | mel_spec_kwargs = dict( 61 | n_fft=n_fft, 62 | hop_length=hop_length, 63 | win_length=win_length, 64 | n_mel_channels=n_mel_channels, 65 | target_sample_rate=target_sample_rate, 66 | mel_spec_type=mel_spec_type, 67 | ) 68 | 69 | model = CFM( 70 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), 71 | mel_spec_kwargs=mel_spec_kwargs, 72 | vocab_char_map=vocab_char_map, 73 | ) 74 | 75 | trainer = Trainer( 76 | model, 77 | epochs, 78 | learning_rate, 79 | num_warmup_updates=num_warmup_updates, 80 | save_per_updates=save_per_updates, 81 | checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")), 82 | batch_size=batch_size_per_gpu, 83 | batch_size_type=batch_size_type, 84 | max_samples=max_samples, 85 | grad_accumulation_steps=grad_accumulation_steps, 86 | max_grad_norm=max_grad_norm, 87 | wandb_project="CFM-TTS", 88 | wandb_run_name=exp_name, 89 | wandb_resume_id=wandb_resume_id, 90 | last_per_steps=last_per_steps, 91 | log_samples=True, 92 | mel_spec_type=mel_spec_type, 93 | ) 94 | 95 | train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) 96 | trainer.train( 97 | train_dataset, 98 | resumable_with_seed=666, # seed for shuffling dataset 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | 
main() 104 | --------------------------------------------------------------------------------
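As a usage illustration for `src/f5_tts/socket_server.py` above: the server accepts UTF-8 text on a TCP socket, streams the synthesized audio back as raw float32 samples (packed with `struct.pack(f"{n}f", ...)` at 24 kHz), and terminates each response with the literal marker `END_OF_AUDIO`. The client below is not part of the repository — it is a minimal sketch assuming the server is already running locally on the default port 9998; playing back or saving the returned samples is left to the caller.

```python
# socket_client.py (hypothetical helper, not part of the repo)
import socket
import struct

HOST, PORT = "127.0.0.1", 9998  # assumes the server was started with start_server("0.0.0.0", 9998, ...)
SAMPLE_RATE = 24000             # matches TTSStreamingProcessor.sampling_rate
END_MARKER = b"END_OF_AUDIO"    # sent by handle_client after the last audio chunk


def synthesize(text: str) -> list[float]:
    """Send text to the streaming TTS server and return the received float32 samples."""
    with socket.create_connection((HOST, PORT)) as sock:
        # the server reads a single recv(1024), so keep each request short
        sock.sendall(text.encode("utf-8"))
        buffer = b""
        while True:
            packet = sock.recv(4096)
            if not packet:
                break  # server closed the connection
            buffer += packet
            if buffer.endswith(END_MARKER):
                buffer = buffer[: -len(END_MARKER)]
                break
    # chunks were packed as native float32, so decode 4 bytes per sample
    n_samples = len(buffer) // 4
    return list(struct.unpack(f"{n_samples}f", buffer[: n_samples * 4]))


if __name__ == "__main__":
    audio = synthesize("Hello from the streaming client.")
    print(f"Received {len(audio)} samples (~{len(audio) / SAMPLE_RATE:.2f} s at {SAMPLE_RATE} Hz)")
```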