├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   ├── feature_request.yml
│   │   ├── help_wanted.yml
│   │   └── question.yml
│   └── workflows
│       ├── pre-commit.yaml
│       ├── publish-docker-image.yaml
│       └── sync-hf.yaml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── README.md
├── ckpts
│   └── README.md
├── data
│   ├── Emilia_ZH_EN_pinyin
│   │   └── vocab.txt
│   └── librispeech_pc_test_clean_cross_sentence.lst
├── pyproject.toml
├── ruff.toml
└── src
    └── f5_tts
        ├── api.py
        ├── eval
        │   ├── README.md
        │   ├── ecapa_tdnn.py
        │   ├── eval_infer_batch.py
        │   ├── eval_infer_batch.sh
        │   ├── eval_librispeech_test_clean.py
        │   ├── eval_seedtts_testset.py
        │   └── utils_eval.py
        ├── infer
        │   ├── README.md
        │   ├── SHARED.md
        │   ├── examples
        │   │   ├── basic
        │   │   │   ├── basic.toml
        │   │   │   ├── basic_ref_en.wav
        │   │   │   └── basic_ref_zh.wav
        │   │   ├── multi
        │   │   │   ├── country.flac
        │   │   │   ├── main.flac
        │   │   │   ├── story.toml
        │   │   │   ├── story.txt
        │   │   │   └── town.flac
        │   │   └── vocab.txt
        │   ├── infer_cli.py
        │   ├── infer_gradio.py
        │   ├── speech_edit.py
        │   └── utils_infer.py
        ├── model
        │   ├── __init__.py
        │   ├── backbones
        │   │   ├── README.md
        │   │   ├── dit.py
        │   │   ├── mmdit.py
        │   │   └── unett.py
        │   ├── cfm.py
        │   ├── dataset.py
        │   ├── modules.py
        │   ├── trainer.py
        │   └── utils.py
        ├── scripts
        │   ├── count_max_epoch.py
        │   └── count_params_gflops.py
        ├── socket_server.py
        └── train
            ├── README.md
            ├── datasets
            │   ├── prepare_csv_wavs.py
            │   ├── prepare_emilia.py
            │   └── prepare_wenetspeech4tts.py
            ├── finetune_cli.py
            ├── finetune_gradio.py
            └── train.py
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: "Bug Report"
2 | description: |
3 | Please provide as many details as possible to help address the issue, including logs and screenshots.
4 | labels:
5 | - bug
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Checks
10 | description: "To ensure timely help, please confirm the following:"
11 | options:
12 | - label: This template is only for bug reports; usage problems go with 'Help Wanted'.
13 | required: true
14 | - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
15 | required: true
16 | - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
17 | required: true
18 | - label: I confirm that I am using English to submit this report in order to facilitate communication.
19 | required: true
20 | - type: textarea
21 | attributes:
22 | label: Environment Details
23 | description: "Provide details such as OS, Python version, and any relevant software or dependencies."
24 | placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
25 | validations:
26 | required: true
27 | - type: textarea
28 | attributes:
29 | label: Steps to Reproduce
30 | description: |
31 | Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
32 | placeholder: |
33 | 1. Create a new conda environment.
34 | 2. Clone the repository, install as local editable and properly set up.
35 | 3. Run the command: `accelerate launch src/f5_tts/train/train.py`.
36 | 4. Got the following error message... (attach logs).
37 | validations:
38 | required: true
39 | - type: textarea
40 | attributes:
41 | label: ✔️ Expected Behavior
42 | placeholder: Describe what you expected to happen.
43 | validations:
44 | required: false
45 | - type: textarea
46 | attributes:
47 | label: ❌ Actual Behavior
48 | placeholder: Describe what actually happened.
49 | validations:
50 | required: false
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: "Feature Request"
2 | description: |
3 | Constructive suggestions and new ideas regarding the current repo.
4 | labels:
5 | - enhancement
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Checks
10 | description: "To help us grasp quickly, please confirm the following:"
11 | options:
12 | - label: This template is only for feature requests.
13 | required: true
14 | - label: I have thoroughly reviewed the project documentation but couldn't find any relevant information that meets my needs.
15 | required: true
16 | - label: I have searched for existing issues, including closed ones, and found no existing discussion yet.
17 | required: true
18 | - label: I confirm that I am using English to submit this report in order to facilitate communication.
19 | required: true
20 | - type: textarea
21 | attributes:
22 | label: 1. Is this request related to a challenge you're experiencing? Tell us your story.
23 | description: |
24 | Describe the specific problem or scenario you're facing in detail. For example:
25 | *"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."*
26 | placeholder: Please describe the situation in as much detail as possible.
27 | validations:
28 | required: true
29 |
30 | - type: textarea
31 | attributes:
32 | label: 2. What is your suggested solution?
33 | description: |
34 | Provide a clear description of the feature or enhancement you'd like to propose.
35 | How would this feature solve your issue or improve the project?
36 | placeholder: Describe your idea or proposed solution here.
37 | validations:
38 | required: true
39 |
40 | - type: textarea
41 | attributes:
42 | label: 3. Additional context or comments
43 | description: |
44 | Any other relevant information, links, documents, or screenshots that provide clarity.
45 | Use this section for anything not covered above.
46 | placeholder: Add any extra details here.
47 | validations:
48 | required: false
49 |
50 | - type: checkboxes
51 | attributes:
52 | label: 4. Can you help us with this feature?
53 | description: |
54 | Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration.
55 | options:
56 | - label: I am interested in contributing to this feature.
57 | required: false
58 |
59 | - type: markdown
60 | attributes:
61 | value: |
62 | **Note:** Please submit only one request per issue to keep discussions focused and manageable.
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/help_wanted.yml:
--------------------------------------------------------------------------------
1 | name: "Help Wanted"
2 | description: |
3 | Please provide as many details as possible to help address the issue, including logs and screenshots.
4 | labels:
5 | - help wanted
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Checks
10 | description: "To ensure timely help, please confirm the following:"
11 | options:
12 | - label: This template is only for usage issues encountered.
13 | required: true
14 | - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
15 | required: true
16 | - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
17 | required: true
18 | - label: I confirm that I am using English to submit this report in order to facilitate communication.
19 | required: true
20 | - type: textarea
21 | attributes:
22 | label: Environment Details
23 | description: "Provide details such as OS, Python version, and any relevant software or dependencies."
24 | placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
25 | validations:
26 | required: true
27 | - type: textarea
28 | attributes:
29 | label: Steps to Reproduce
30 | description: |
31 | Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
32 | placeholder: |
33 | 1. Create a new conda environment.
34 | 2. Clone the repository and install as pip package.
35 | 3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
36 | 4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
37 | validations:
38 | required: true
39 | - type: textarea
40 | attributes:
41 | label: ✔️ Expected Behavior
42 | placeholder: Describe what you expected to happen, e.g. output a generated audio
43 | validations:
44 | required: false
45 | - type: textarea
46 | attributes:
47 | label: ❌ Actual Behavior
48 | placeholder: Describe what actually happened, failure messages, etc.
49 | validations:
50 | required: false
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.yml:
--------------------------------------------------------------------------------
1 | name: "Question"
2 | description: |
3 | Pure questions or inquiries about the project; usage issues go with "Help Wanted".
4 | labels:
5 | - question
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Checks
10 | description: "To help us grasp quickly, please confirm the following:"
11 | options:
12 | - label: This template is only for questions, not feature requests or bug reports.
13 | required: true
14 | - label: I have thoroughly reviewed the project documentation and read the related paper(s).
15 | required: true
16 | - label: I have searched for existing issues, including closed ones, and found no similar questions.
17 | required: true
18 | - label: I confirm that I am using English to submit this report in order to facilitate communication.
19 | required: true
20 | - type: textarea
21 | attributes:
22 | label: Question details
23 | description: |
24 | Question details, clearly stated using proper markdown syntax.
25 | validations:
26 | required: true
27 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yaml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 | jobs:
9 | pre-commit:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v3
13 | - uses: actions/setup-python@v3
14 | - uses: pre-commit/action@v3.0.1
15 |
--------------------------------------------------------------------------------
/.github/workflows/publish-docker-image.yaml:
--------------------------------------------------------------------------------
1 | name: Create and publish a Docker image
2 |
3 | # Configures this workflow to run every time a change is pushed to the branch called `main`.
4 | on:
5 | push:
6 | branches: ['main']
7 |
8 | # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
9 | env:
10 | REGISTRY: ghcr.io
11 | IMAGE_NAME: ${{ github.repository }}
12 |
13 | # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
14 | jobs:
15 | build-and-push-image:
16 | runs-on: ubuntu-latest
17 | # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
18 | permissions:
19 | contents: read
20 | packages: write
21 | #
22 | steps:
23 | - name: Checkout repository
24 | uses: actions/checkout@v4
25 | - name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
26 | uses: jlumbroso/free-disk-space@main
27 | with:
28 | # If set to "true", this might remove tools that are actually needed, but it frees about 6 GB
29 | tool-cache: false
30 |
31 | # All of these default to true, but feel free to set to "false" if necessary for your workflow
32 | android: true
33 | dotnet: true
34 | haskell: true
35 | large-packages: false
36 | swap-storage: false
37 | docker-images: false
38 | # Uses the `docker/login-action` action to log in to the Container registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
39 | - name: Log in to the Container registry
40 | uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
41 | with:
42 | registry: ${{ env.REGISTRY }}
43 | username: ${{ github.actor }}
44 | password: ${{ secrets.GITHUB_TOKEN }}
45 | # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
46 | - name: Extract metadata (tags, labels) for Docker
47 | id: meta
48 | uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
49 | with:
50 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
51 | # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
52 | # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
53 | # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
54 | - name: Build and push Docker image
55 | uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
56 | with:
57 | context: .
58 | push: true
59 | tags: ${{ steps.meta.outputs.tags }}
60 | labels: ${{ steps.meta.outputs.labels }}
61 |
--------------------------------------------------------------------------------
/.github/workflows/sync-hf.yaml:
--------------------------------------------------------------------------------
1 | name: Sync to HF Space
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | trigger_curl:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Send cURL POST request
14 | run: |
15 | curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
16 | -s \
17 | -H "Content-Type: application/json" \
18 | -d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"
19 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 | .vscode/
3 | tests/
4 | runs/
5 | data/
6 | ckpts/
7 | wandb/
8 | results/
9 |
10 |
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 | cover/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | .pybuilder/
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 |
96 | # pyenv
97 | # For a library or package, you might want to ignore these files since the code is
98 | # intended to run in multiple environments; otherwise, check them in:
99 | # .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock
107 |
108 | # poetry
109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110 | # This is especially recommended for binary packages to ensure reproducibility, and is more
111 | # commonly ignored for libraries.
112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113 | #poetry.lock
114 |
115 | # pdm
116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117 | #pdm.lock
118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119 | # in version control.
120 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
121 | .pdm.toml
122 | .pdm-python
123 | .pdm-build/
124 |
125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126 | __pypackages__/
127 |
128 | # Celery stuff
129 | celerybeat-schedule
130 | celerybeat.pid
131 |
132 | # SageMath parsed files
133 | *.sage.py
134 |
135 | # Environments
136 | .env
137 | .venv
138 | env/
139 | venv/
140 | ENV/
141 | env.bak/
142 | venv.bak/
143 |
144 | # Spyder project settings
145 | .spyderproject
146 | .spyproject
147 |
148 | # Rope project settings
149 | .ropeproject
150 |
151 | # mkdocs documentation
152 | /site
153 |
154 | # mypy
155 | .mypy_cache/
156 | .dmypy.json
157 | dmypy.json
158 |
159 | # Pyre type checker
160 | .pyre/
161 |
162 | # pytype static type analyzer
163 | .pytype/
164 |
165 | # Cython debug symbols
166 | cython_debug/
167 |
168 | # PyCharm
169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171 | # and can be added to the global gitignore or merged into this file. For a more nuclear
172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173 | #.idea/
174 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "src/third_party/BigVGAN"]
2 | path = src/third_party/BigVGAN
3 | url = https://github.com/NVIDIA/BigVGAN.git
4 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | # Ruff version.
4 | rev: v0.7.0
5 | hooks:
6 | # Run the linter.
7 | - id: ruff
8 | args: [--fix]
9 | # Run the formatter.
10 | - id: ruff-format
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v2.3.0
13 | hooks:
14 | - id: check-yaml
15 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
2 |
3 | USER root
4 |
5 | ARG DEBIAN_FRONTEND=noninteractive
6 |
7 | LABEL github_repo="https://github.com/SWivid/F5-TTS"
8 |
9 | RUN set -x \
10 | && apt-get update \
11 | && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
12 | && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
13 | && rm -rf /var/lib/apt/lists/* \
14 | && apt-get clean
15 |
16 | WORKDIR /workspace
17 |
18 | RUN git clone https://github.com/SWivid/F5-TTS.git \
19 | && cd F5-TTS \
20 | && pip install -e .[eval]
21 |
22 | ENV SHELL=/bin/bash
23 |
24 | WORKDIR /workspace/F5-TTS
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yushen CHEN
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
2 |
3 | [](https://github.com/SWivid/F5-TTS)
4 | [](https://arxiv.org/abs/2410.06885)
5 | [](https://swivid.github.io/F5-TTS/)
6 | [](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
7 | [](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
8 | [](https://x-lance.sjtu.edu.cn/)
9 |
10 |
11 | **F5-TTS**: Diffusion Transformer with ConvNeXt V2, with faster training and inference.
12 |
13 | **E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
14 |
15 | **Sway Sampling**: Inference-time flow step sampling strategy that greatly improves performance.
16 |
17 | ### Thanks to all the contributors!
18 |
19 | ## News
20 | - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
21 |
22 | ## Installation
23 |
24 | ```bash
25 | # Create a python 3.10 conda env (you could also use virtualenv)
26 | conda create -n f5-tts python=3.10
27 | conda activate f5-tts
28 |
29 | # Install pytorch with your CUDA version, e.g.
30 | pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
31 | ```
32 |
33 | Then you can choose from a few options below:
34 |
35 | ### 1. As a pip package (if just for inference)
36 |
37 | ```bash
38 | pip install git+https://github.com/SWivid/F5-TTS.git
39 | ```
40 |
41 | ### 2. Local editable (if also do training, finetuning)
42 |
43 | ```bash
44 | git clone https://github.com/SWivid/F5-TTS.git
45 | cd F5-TTS
46 | # git submodule update --init --recursive # (optional, if need bigvgan)
47 | pip install -e .
48 | ```
49 | If you initialized the submodule, add the following code at the beginning of `src/third_party/BigVGAN/bigvgan.py`:
50 | ```python
51 | import os
52 | import sys
53 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
54 | ```
55 |
56 | ### 3. Docker usage
57 | ```bash
58 | # Build from Dockerfile
59 | docker build -t f5tts:v1 .
60 |
61 | # Or pull from GitHub Container Registry
62 | docker pull ghcr.io/swivid/f5-tts:main
63 | ```
64 |
65 |
66 | ## Inference
67 |
68 | ### 1. Gradio App
69 |
70 | Currently supported features:
71 |
72 | - Basic TTS with Chunk Inference
73 | - Multi-Style / Multi-Speaker Generation
74 | - Voice Chat powered by Qwen2.5-3B-Instruct
75 | - [Custom model](src/f5_tts/infer/SHARED.md) inference (local only)
76 |
77 | ```bash
78 | # Launch a Gradio app (web interface)
79 | f5-tts_infer-gradio
80 |
81 | # Specify the port/host
82 | f5-tts_infer-gradio --port 7860 --host 0.0.0.0
83 |
84 | # Launch a share link
85 | f5-tts_infer-gradio --share
86 | ```
87 |
88 | ### 2. CLI Inference
89 |
90 | ```bash
91 | # Run with flags
92 | # Leaving --ref_text "" will have the ASR model transcribe the reference audio (extra GPU memory usage)
93 | f5-tts_infer-cli \
94 | --model "F5-TTS" \
95 | --ref_audio "ref_audio.wav" \
96 | --ref_text "The content, subtitle or transcription of reference audio." \
97 | --gen_text "Some text you want the TTS model to generate for you."
98 |
99 | # Run with default setting. src/f5_tts/infer/examples/basic/basic.toml
100 | f5-tts_infer-cli
101 | # Or with your own .toml file
102 | f5-tts_infer-cli -c custom.toml
103 |
104 | # Multi voice. See src/f5_tts/infer/README.md
105 | f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
106 | ```
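
Besides the CLI, the same inference pipeline can be driven from Python via the `F5TTS` class in `src/f5_tts/api.py`. A minimal sketch based on the example at the bottom of that file (output paths are placeholders):

```python
from importlib.resources import files

from f5_tts.api import F5TTS

f5tts = F5TTS()  # downloads the default F5-TTS base checkpoint and Vocos vocoder on first use

wav, sr, spect = f5tts.infer(
    ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
    ref_text="some call me nature, others call me mother nature.",
    gen_text="Some text you want the TTS model to generate for you.",
    file_wave="out.wav",    # optional: save the generated audio
    file_spect="out.png",   # optional: save the spectrogram
    seed=-1,                # -1 picks a random seed
)
print("seed :", f5tts.seed)
```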
107 |
108 | ### 3. More instructions
109 |
110 | - For better generation results, take a moment to read the [detailed guidance](src/f5_tts/infer).
111 | - The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) page is very useful; please try to find a solution by searching for keywords related to the problem you encountered. If no answer is found, feel free to open an issue.
112 |
113 |
114 | ## Training
115 |
116 | ### 1. Gradio App
117 |
118 | Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
119 |
120 | ```bash
121 | # Quick start with Gradio web interface
122 | f5-tts_finetune-gradio
123 | ```
124 |
125 |
126 | ## [Evaluation](src/f5_tts/eval)
127 |
128 |
129 | ## Development
130 |
131 | Use pre-commit to ensure code quality (it will run linters and formatters automatically):
132 |
133 | ```bash
134 | pip install pre-commit
135 | pre-commit install
136 | ```
137 |
138 | When making a pull request, before each commit, run:
139 |
140 | ```bash
141 | pre-commit run --all-files
142 | ```
143 |
144 | Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
145 |
146 |
147 | ## Acknowledgements
148 |
149 | - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
150 | - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
151 | - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
152 | - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
153 | - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
154 | - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
155 | - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
156 | - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
157 | - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
158 | - [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
159 |
160 | ## Citation
161 | If our work and codebase are useful for you, please cite as:
162 | ```
163 | @article{chen-etal-2024-f5tts,
164 | title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
165 | author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
166 | journal={arXiv preprint arXiv:2410.06885},
167 | year={2024},
168 | }
169 | ```
170 | ## License
171 |
172 | Our code is released under MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause.
173 |
--------------------------------------------------------------------------------
/ckpts/README.md:
--------------------------------------------------------------------------------
1 |
2 | Pretrained model checkpoints: https://huggingface.co/SWivid/F5-TTS
3 |
4 | ```
5 | ckpts/
6 | E2TTS_Base/
7 | model_1200000.pt
8 | F5TTS_Base/
9 | model_1200000.pt
10 | ```
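
If you prefer to fetch checkpoints programmatically instead of placing them here by hand, `src/f5_tts/api.py` resolves them through `cached_path`. A minimal sketch (assuming the filenames on the Hugging Face repo stay as referenced in `api.py`):

```python
from cached_path import cached_path

# Download (and cache) the F5-TTS base checkpoint from the Hugging Face Hub;
# the returned string is the local path to pass as ckpt_file / ckpt_path.
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
print(ckpt_file)
```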
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "f5-tts"
7 | version = "0.1.1"
8 | description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9 | readme = "README.md"
10 | license = {text = "MIT License"}
11 | classifiers = [
12 | "License :: OSI Approved :: MIT License",
13 | "Operating System :: OS Independent",
14 | "Programming Language :: Python :: 3",
15 | ]
16 | dependencies = [
17 | "accelerate>=0.33.0",
18 | "bitsandbytes>0.37.0",
19 | "cached_path",
20 | "click",
21 | "datasets",
22 | "ema_pytorch>=0.5.2",
23 | "gradio>=3.45.2",
24 | "jieba",
25 | "librosa",
26 | "matplotlib",
27 | "numpy<=1.26.4",
28 | "pydub",
29 | "pypinyin",
30 | "safetensors",
31 | "soundfile",
32 | "tomli",
33 | "torch>=2.0.0",
34 | "torchaudio>=2.0.0",
35 | "torchdiffeq",
36 | "tqdm>=4.65.0",
37 | "transformers",
38 | "transformers_stream_generator",
39 | "vocos",
40 | "wandb",
41 | "x_transformers>=1.31.14",
42 | ]
43 |
44 | [project.optional-dependencies]
45 | eval = [
46 | "faster_whisper==0.10.1",
47 | "funasr",
48 | "jiwer",
49 | "modelscope",
50 | "zhconv",
51 | "zhon",
52 | ]
53 |
54 | [project.urls]
55 | Homepage = "https://github.com/SWivid/F5-TTS"
56 |
57 | [project.scripts]
58 | "f5-tts_infer-cli" = "f5_tts.infer.infer_cli:main"
59 | "f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main"
60 | "f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main"
61 | "f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main"
62 |
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | line-length = 120
2 | target-version = "py310"
3 |
4 | [lint]
5 | # Only ignore variables with names starting with "_".
6 | dummy-variable-rgx = "^_.*$"
7 |
8 | [lint.isort]
9 | force-single-line = true
10 | lines-after-imports = 2
11 |
--------------------------------------------------------------------------------
/src/f5_tts/api.py:
--------------------------------------------------------------------------------
1 | import random
2 | import sys
3 | from importlib.resources import files
4 |
5 | import soundfile as sf
6 | import torch
7 | import tqdm
8 | from cached_path import cached_path
9 |
10 | from f5_tts.infer.utils_infer import (
11 | hop_length,
12 | infer_process,
13 | load_model,
14 | load_vocoder,
15 | preprocess_ref_audio_text,
16 | remove_silence_for_generated_wav,
17 | save_spectrogram,
18 | transcribe,
19 | target_sample_rate,
20 | )
21 | from f5_tts.model import DiT, UNetT
22 | from f5_tts.model.utils import seed_everything
23 |
24 |
25 | class F5TTS:
26 | def __init__(
27 | self,
28 | model_type="F5-TTS",
29 | ckpt_file="",
30 | vocab_file="",
31 | ode_method="euler",
32 | use_ema=True,
33 | vocoder_name="vocos",
34 | local_path=None,
35 | device=None,
36 | hf_cache_dir=None,
37 | ):
38 | # Initialize parameters
39 | self.final_wave = None
40 | self.target_sample_rate = target_sample_rate
41 | self.hop_length = hop_length
42 | self.seed = -1
43 | self.mel_spec_type = vocoder_name
44 |
45 | # Set device
46 | self.device = device or (
47 | "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
48 | )
49 |
50 | # Load models
51 | self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
52 | self.load_ema_model(
53 | model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
54 | )
55 |
56 | def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
57 | self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
58 |
59 | def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
60 | if model_type == "F5-TTS":
61 | if not ckpt_file:
62 | if mel_spec_type == "vocos":
63 | ckpt_file = str(
64 | cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
65 | )
66 | elif mel_spec_type == "bigvgan":
67 | ckpt_file = str(
68 | cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
69 | )
70 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
71 | model_cls = DiT
72 | elif model_type == "E2-TTS":
73 | if not ckpt_file:
74 | ckpt_file = str(
75 | cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
76 | )
77 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
78 | model_cls = UNetT
79 | else:
80 | raise ValueError(f"Unknown model type: {model_type}")
81 |
82 | self.ema_model = load_model(
83 | model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
84 | )
85 |
86 | def transcribe(self, ref_audio, language=None):
87 | return transcribe(ref_audio, language)
88 |
89 | def export_wav(self, wav, file_wave, remove_silence=False):
90 | sf.write(file_wave, wav, self.target_sample_rate)
91 |
92 | if remove_silence:
93 | remove_silence_for_generated_wav(file_wave)
94 |
95 | def export_spectrogram(self, spect, file_spect):
96 | save_spectrogram(spect, file_spect)
97 |
98 | def infer(
99 | self,
100 | ref_file,
101 | ref_text,
102 | gen_text,
103 | show_info=print,
104 | progress=tqdm,
105 | target_rms=0.1,
106 | cross_fade_duration=0.15,
107 | sway_sampling_coef=-1,
108 | cfg_strength=2,
109 | nfe_step=32,
110 | speed=1.0,
111 | fix_duration=None,
112 | remove_silence=False,
113 | file_wave=None,
114 | file_spect=None,
115 | seed=-1,
116 | ):
117 | if seed == -1:
118 | seed = random.randint(0, sys.maxsize)
119 | seed_everything(seed)
120 | self.seed = seed
121 |
122 | ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
123 |
124 | wav, sr, spect = infer_process(
125 | ref_file,
126 | ref_text,
127 | gen_text,
128 | self.ema_model,
129 | self.vocoder,
130 | self.mel_spec_type,
131 | show_info=show_info,
132 | progress=progress,
133 | target_rms=target_rms,
134 | cross_fade_duration=cross_fade_duration,
135 | nfe_step=nfe_step,
136 | cfg_strength=cfg_strength,
137 | sway_sampling_coef=sway_sampling_coef,
138 | speed=speed,
139 | fix_duration=fix_duration,
140 | device=self.device,
141 | )
142 |
143 | if file_wave is not None:
144 | self.export_wav(wav, file_wave, remove_silence)
145 |
146 | if file_spect is not None:
147 | self.export_spectrogram(spect, file_spect)
148 |
149 | return wav, sr, spect
150 |
151 |
152 | if __name__ == "__main__":
153 | f5tts = F5TTS()
154 |
155 | wav, sr, spect = f5tts.infer(
156 | ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
157 | ref_text="some call me nature, others call me mother nature.",
158 | gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
159 | file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
160 | file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
161 | seed=-1, # random seed = -1
162 | )
163 |
164 | print("seed :", f5tts.seed)
165 |
--------------------------------------------------------------------------------
/src/f5_tts/eval/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Evaluation
3 |
4 | Install packages for evaluation:
5 |
6 | ```bash
7 | pip install -e .[eval]
8 | ```
9 |
10 | ## Generating Samples for Evaluation
11 |
12 | ### Prepare Test Datasets
13 |
14 | 1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
15 | 2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
16 | 3. Unzip the downloaded datasets and place them in the `data/` directory.
17 | 4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
18 | 5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
19 |
20 | ### Batch Inference for Test Set
21 |
22 | To run batch inference for evaluations, execute the following commands:
23 |
24 | ```bash
25 | # batch inference for evaluations
26 | accelerate config # if not set before
27 | bash src/f5_tts/eval/eval_infer_batch.sh
28 | ```
29 |
30 | ## Objective Evaluation on Generated Results
31 |
32 | ### Download Evaluation Model Checkpoints
33 |
34 | 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
35 | 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
36 | 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
37 |
38 | Then update the following scripts with the paths where you placed the evaluation model checkpoints.
39 |
40 | ### Objective Evaluation
41 |
42 | Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
43 | ```bash
44 | # Evaluation for Seed-TTS test set
45 | python src/f5_tts/eval/eval_seedtts_testset.py
46 |
47 | # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
48 | python src/f5_tts/eval/eval_librispeech_test_clean.py
49 | ```
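
For reference, the WER metric here is the standard word error rate between ASR transcripts of the generated audio and the target text; `jiwer` (listed under the `eval` extras in `pyproject.toml`) computes exactly this kind of metric. A toy sketch with made-up strings:

```python
from jiwer import wer

reference = "some call me nature others call me mother nature"
hypothesis = "some call me nature others call me mother of nature"

# WER = (substitutions + deletions + insertions) / number of reference words
print(f"WER: {wer(reference, hypothesis) * 100:.2f}%")
```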
--------------------------------------------------------------------------------
/src/f5_tts/eval/ecapa_tdnn.py:
--------------------------------------------------------------------------------
1 | # just for speaker similarity evaluation, third-party code
2 |
3 | # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
4 | # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5 |
6 | import os
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | """ Res2Conv1d + BatchNorm1d + ReLU
13 | """
14 |
15 |
16 | class Res2Conv1dReluBn(nn.Module):
17 | """
18 | in_channels == out_channels == channels
19 | """
20 |
21 | def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
22 | super().__init__()
23 | assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
24 | self.scale = scale
25 | self.width = channels // scale
26 | self.nums = scale if scale == 1 else scale - 1
27 |
28 | self.convs = []
29 | self.bns = []
30 | for i in range(self.nums):
31 | self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
32 | self.bns.append(nn.BatchNorm1d(self.width))
33 | self.convs = nn.ModuleList(self.convs)
34 | self.bns = nn.ModuleList(self.bns)
35 |
36 | def forward(self, x):
37 | out = []
38 | spx = torch.split(x, self.width, 1)
39 | for i in range(self.nums):
40 | if i == 0:
41 | sp = spx[i]
42 | else:
43 | sp = sp + spx[i]
44 | # Order: conv -> relu -> bn
45 | sp = self.convs[i](sp)
46 | sp = self.bns[i](F.relu(sp))
47 | out.append(sp)
48 | if self.scale != 1:
49 | out.append(spx[self.nums])
50 | out = torch.cat(out, dim=1)
51 |
52 | return out
53 |
54 |
55 | """ Conv1d + BatchNorm1d + ReLU
56 | """
57 |
58 |
59 | class Conv1dReluBn(nn.Module):
60 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
61 | super().__init__()
62 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
63 | self.bn = nn.BatchNorm1d(out_channels)
64 |
65 | def forward(self, x):
66 | return self.bn(F.relu(self.conv(x)))
67 |
68 |
69 | """ The SE connection of 1D case.
70 | """
71 |
72 |
73 | class SE_Connect(nn.Module):
74 | def __init__(self, channels, se_bottleneck_dim=128):
75 | super().__init__()
76 | self.linear1 = nn.Linear(channels, se_bottleneck_dim)
77 | self.linear2 = nn.Linear(se_bottleneck_dim, channels)
78 |
79 | def forward(self, x):
80 | out = x.mean(dim=2)
81 | out = F.relu(self.linear1(out))
82 | out = torch.sigmoid(self.linear2(out))
83 | out = x * out.unsqueeze(2)
84 |
85 | return out
86 |
87 |
88 | """ SE-Res2Block of the ECAPA-TDNN architecture.
89 | """
90 |
91 | # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
92 | # return nn.Sequential(
93 | # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
94 | # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
95 | # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
96 | # SE_Connect(channels)
97 | # )
98 |
99 |
100 | class SE_Res2Block(nn.Module):
101 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
102 | super().__init__()
103 | self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
104 | self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
105 | self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
106 | self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
107 |
108 | self.shortcut = None
109 | if in_channels != out_channels:
110 | self.shortcut = nn.Conv1d(
111 | in_channels=in_channels,
112 | out_channels=out_channels,
113 | kernel_size=1,
114 | )
115 |
116 | def forward(self, x):
117 | residual = x
118 | if self.shortcut:
119 | residual = self.shortcut(x)
120 |
121 | x = self.Conv1dReluBn1(x)
122 | x = self.Res2Conv1dReluBn(x)
123 | x = self.Conv1dReluBn2(x)
124 | x = self.SE_Connect(x)
125 |
126 | return x + residual
127 |
128 |
129 | """ Attentive weighted mean and standard deviation pooling.
130 | """
131 |
132 |
133 | class AttentiveStatsPool(nn.Module):
134 | def __init__(self, in_dim, attention_channels=128, global_context_att=False):
135 | super().__init__()
136 | self.global_context_att = global_context_att
137 |
138 | # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
139 | if global_context_att:
140 | self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
141 | else:
142 | self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
143 | self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
144 |
145 | def forward(self, x):
146 | if self.global_context_att:
147 | context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
148 | context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
149 | x_in = torch.cat((x, context_mean, context_std), dim=1)
150 | else:
151 | x_in = x
152 |
153 | # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
154 | alpha = torch.tanh(self.linear1(x_in))
155 | # alpha = F.relu(self.linear1(x_in))
156 | alpha = torch.softmax(self.linear2(alpha), dim=2)
157 | mean = torch.sum(alpha * x, dim=2)
158 | residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
159 | std = torch.sqrt(residuals.clamp(min=1e-9))
160 | return torch.cat([mean, std], dim=1)
161 |
162 |
163 | class ECAPA_TDNN(nn.Module):
164 | def __init__(
165 | self,
166 | feat_dim=80,
167 | channels=512,
168 | emb_dim=192,
169 | global_context_att=False,
170 | feat_type="wavlm_large",
171 | sr=16000,
172 | feature_selection="hidden_states",
173 | update_extract=False,
174 | config_path=None,
175 | ):
176 | super().__init__()
177 |
178 | self.feat_type = feat_type
179 | self.feature_selection = feature_selection
180 | self.update_extract = update_extract
181 | self.sr = sr
182 |
183 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
184 | try:
185 | local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
186 | self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
187 | except: # noqa: E722
188 | self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
189 |
190 | if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
191 | self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
192 | ):
193 | self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
194 | if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
195 | self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
196 | ):
197 | self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
198 |
199 | self.feat_num = self.get_feat_num()
200 | self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
201 |
202 | if feat_type != "fbank" and feat_type != "mfcc":
203 | freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
204 | for name, param in self.feature_extract.named_parameters():
205 | for freeze_val in freeze_list:
206 | if freeze_val in name:
207 | param.requires_grad = False
208 | break
209 |
210 | if not self.update_extract:
211 | for param in self.feature_extract.parameters():
212 | param.requires_grad = False
213 |
214 | self.instance_norm = nn.InstanceNorm1d(feat_dim)
215 | # self.channels = [channels] * 4 + [channels * 3]
216 | self.channels = [channels] * 4 + [1536]
217 |
218 | self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
219 | self.layer2 = SE_Res2Block(
220 | self.channels[0],
221 | self.channels[1],
222 | kernel_size=3,
223 | stride=1,
224 | padding=2,
225 | dilation=2,
226 | scale=8,
227 | se_bottleneck_dim=128,
228 | )
229 | self.layer3 = SE_Res2Block(
230 | self.channels[1],
231 | self.channels[2],
232 | kernel_size=3,
233 | stride=1,
234 | padding=3,
235 | dilation=3,
236 | scale=8,
237 | se_bottleneck_dim=128,
238 | )
239 | self.layer4 = SE_Res2Block(
240 | self.channels[2],
241 | self.channels[3],
242 | kernel_size=3,
243 | stride=1,
244 | padding=4,
245 | dilation=4,
246 | scale=8,
247 | se_bottleneck_dim=128,
248 | )
249 |
250 | # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
251 | cat_channels = channels * 3
252 | self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
253 | self.pooling = AttentiveStatsPool(
254 | self.channels[-1], attention_channels=128, global_context_att=global_context_att
255 | )
256 | self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
257 | self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
258 |
259 | def get_feat_num(self):
260 | self.feature_extract.eval()
261 | wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
262 | with torch.no_grad():
263 | features = self.feature_extract(wav)
264 | select_feature = features[self.feature_selection]
265 | if isinstance(select_feature, (list, tuple)):
266 | return len(select_feature)
267 | else:
268 | return 1
269 |
270 | def get_feat(self, x):
271 | if self.update_extract:
272 | x = self.feature_extract([sample for sample in x])
273 | else:
274 | with torch.no_grad():
275 | if self.feat_type == "fbank" or self.feat_type == "mfcc":
276 | x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
277 | else:
278 | x = self.feature_extract([sample for sample in x])
279 |
280 | if self.feat_type == "fbank":
281 | x = x.log()
282 |
283 | if self.feat_type != "fbank" and self.feat_type != "mfcc":
284 | x = x[self.feature_selection]
285 | if isinstance(x, (list, tuple)):
286 | x = torch.stack(x, dim=0)
287 | else:
288 | x = x.unsqueeze(0)
289 | norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
290 | x = (norm_weights * x).sum(dim=0)
291 | x = torch.transpose(x, 1, 2) + 1e-6
292 |
293 | x = self.instance_norm(x)
294 | return x
295 |
296 | def forward(self, x):
297 | x = self.get_feat(x)
298 |
299 | out1 = self.layer1(x)
300 | out2 = self.layer2(out1)
301 | out3 = self.layer3(out2)
302 | out4 = self.layer4(out3)
303 |
304 | out = torch.cat([out2, out3, out4], dim=1)
305 | out = F.relu(self.conv(out))
306 | out = self.bn(self.pooling(out))
307 | out = self.linear(out)
308 |
309 | return out
310 |
311 |
312 | def ECAPA_TDNN_SMALL(
313 | feat_dim,
314 | emb_dim=256,
315 | feat_type="wavlm_large",
316 | sr=16000,
317 | feature_selection="hidden_states",
318 | update_extract=False,
319 | config_path=None,
320 | ):
321 | return ECAPA_TDNN(
322 | feat_dim=feat_dim,
323 | channels=512,
324 | emb_dim=emb_dim,
325 | feat_type=feat_type,
326 | sr=sr,
327 | feature_selection=feature_selection,
328 | update_extract=update_extract,
329 | config_path=config_path,
330 | )
331 |
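
For context, the SIM metric embeds both the generated and the reference speech with this WavLM-large-based ECAPA-TDNN and takes their cosine similarity. A rough, hypothetical usage sketch (the checkpoint path and the `"model"` state-dict key follow the UniSpeech release and are assumptions here):

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch of speaker-similarity (SIM) evaluation with this model.
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large")
state_dict = torch.load("../checkpoints/UniSpeech/wavlm_large_finetune.pth", map_location="cpu")
model.load_state_dict(state_dict["model"], strict=False)  # assumes weights live under the "model" key
model.eval()

wav1 = torch.randn(1, 16000)  # stand-ins for 16 kHz generated / reference audio
wav2 = torch.randn(1, 16000)
with torch.no_grad():
    emb1, emb2 = model(wav1), model(wav2)
similarity = F.cosine_similarity(emb1, emb2, dim=-1).item()
print(f"SIM: {similarity:.3f}")
```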
--------------------------------------------------------------------------------
/src/f5_tts/eval/eval_infer_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | import argparse
7 | import time
8 | from importlib.resources import files
9 |
10 | import torch
11 | import torchaudio
12 | from accelerate import Accelerator
13 | from tqdm import tqdm
14 |
15 | from f5_tts.eval.utils_eval import (
16 | get_inference_prompt,
17 | get_librispeech_test_clean_metainfo,
18 | get_seedtts_testset_metainfo,
19 | )
20 | from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
21 | from f5_tts.model import CFM, DiT, UNetT
22 | from f5_tts.model.utils import get_tokenizer
23 |
24 | accelerator = Accelerator()
25 | device = f"cuda:{accelerator.process_index}"
26 |
27 |
28 | # --------------------- Dataset Settings -------------------- #
29 |
30 | target_sample_rate = 24000
31 | n_mel_channels = 100
32 | hop_length = 256
33 | win_length = 1024
34 | n_fft = 1024
35 | target_rms = 0.1
36 |
37 |
38 | tokenizer = "pinyin"
39 | rel_path = str(files("f5_tts").joinpath("../../"))
40 |
41 |
42 | def main():
43 | # ---------------------- infer setting ---------------------- #
44 |
45 | parser = argparse.ArgumentParser(description="batch inference")
46 |
47 | parser.add_argument("-s", "--seed", default=None, type=int)
48 | parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
49 | parser.add_argument("-n", "--expname", required=True)
50 | parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
51 | parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
52 |
53 | parser.add_argument("-nfe", "--nfestep", default=32, type=int)
54 | parser.add_argument("-o", "--odemethod", default="euler")
55 | parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
56 |
57 | parser.add_argument("-t", "--testset", required=True)
58 |
59 | args = parser.parse_args()
60 |
61 | seed = args.seed
62 | dataset_name = args.dataset
63 | exp_name = args.expname
64 | ckpt_step = args.ckptstep
65 | ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
66 | mel_spec_type = args.mel_spec_type
67 |
68 | nfe_step = args.nfestep
69 | ode_method = args.odemethod
70 | sway_sampling_coef = args.swaysampling
71 |
72 | testset = args.testset
73 |
74 | infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
75 | cfg_strength = 2.0
76 | speed = 1.0
77 | use_truth_duration = False
78 | no_ref_audio = False
79 |
80 | if exp_name == "F5TTS_Base":
81 | model_cls = DiT
82 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
83 |
84 | elif exp_name == "E2TTS_Base":
85 | model_cls = UNetT
86 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
87 |
88 | if testset == "ls_pc_test_clean":
89 | metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
90 | librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path
91 | metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
92 |
93 | elif testset == "seedtts_test_zh":
94 | metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
95 | metainfo = get_seedtts_testset_metainfo(metalst)
96 |
97 | elif testset == "seedtts_test_en":
98 | metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
99 | metainfo = get_seedtts_testset_metainfo(metalst)
100 |
101 | # path to save generated wavs
102 | output_dir = (
103 | f"{rel_path}/"
104 | f"results/{exp_name}_{ckpt_step}/{testset}/"
105 | f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
106 | f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
107 | f"_cfg{cfg_strength}_speed{speed}"
108 | f"{'_gt-dur' if use_truth_duration else ''}"
109 | f"{'_no-ref-audio' if no_ref_audio else ''}"
110 | )
111 |
112 | # -------------------------------------------------#
113 |
114 | use_ema = True
115 |
116 | prompts_all = get_inference_prompt(
117 | metainfo,
118 | speed=speed,
119 | tokenizer=tokenizer,
120 | target_sample_rate=target_sample_rate,
121 | n_mel_channels=n_mel_channels,
122 | hop_length=hop_length,
123 | mel_spec_type=mel_spec_type,
124 | target_rms=target_rms,
125 | use_truth_duration=use_truth_duration,
126 | infer_batch_size=infer_batch_size,
127 | )
128 |
129 | # Vocoder model
130 | local = False
131 | if mel_spec_type == "vocos":
132 | vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
133 | elif mel_spec_type == "bigvgan":
134 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
135 | vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
136 |
137 | # Tokenizer
138 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
139 |
140 | # Model
141 | model = CFM(
142 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
143 | mel_spec_kwargs=dict(
144 | n_fft=n_fft,
145 | hop_length=hop_length,
146 | win_length=win_length,
147 | n_mel_channels=n_mel_channels,
148 | target_sample_rate=target_sample_rate,
149 | mel_spec_type=mel_spec_type,
150 | ),
151 | odeint_kwargs=dict(
152 | method=ode_method,
153 | ),
154 | vocab_char_map=vocab_char_map,
155 | ).to(device)
156 |
157 | dtype = torch.float32 if mel_spec_type == "bigvgan" else None
158 | model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
159 |
160 | if not os.path.exists(output_dir) and accelerator.is_main_process:
161 | os.makedirs(output_dir)
162 |
163 | # start batch inference
164 | accelerator.wait_for_everyone()
165 | start = time.time()
166 |
167 | with accelerator.split_between_processes(prompts_all) as prompts:
168 | for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
169 | utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
170 | ref_mels = ref_mels.to(device)
171 | ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
172 | total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
173 |
174 | # Inference
175 | with torch.inference_mode():
176 | generated, _ = model.sample(
177 | cond=ref_mels,
178 | text=final_text_list,
179 | duration=total_mel_lens,
180 | lens=ref_mel_lens,
181 | steps=nfe_step,
182 | cfg_strength=cfg_strength,
183 | sway_sampling_coef=sway_sampling_coef,
184 | no_ref_audio=no_ref_audio,
185 | seed=seed,
186 | )
187 | # Final result
188 | for i, gen in enumerate(generated):
189 | gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
190 | gen_mel_spec = gen.permute(0, 2, 1)
191 | if mel_spec_type == "vocos":
192 | generated_wave = vocoder.decode(gen_mel_spec)
193 | elif mel_spec_type == "bigvgan":
194 | generated_wave = vocoder(gen_mel_spec)
195 |
196 | if ref_rms_list[i] < target_rms:
197 | generated_wave = generated_wave * ref_rms_list[i] / target_rms
198 | torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
199 |
200 | accelerator.wait_for_everyone()
201 | if accelerator.is_main_process:
202 | timediff = time.time() - start
203 | print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
204 |
205 |
206 | if __name__ == "__main__":
207 | main()
208 |
--------------------------------------------------------------------------------
/src/f5_tts/eval/eval_infer_batch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # e.g. F5-TTS, 16 NFE
4 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
5 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
6 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
7 |
8 | # e.g. Vanilla E2 TTS, 32 NFE
9 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
10 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
11 | accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
12 |
13 | # etc.
14 |
--------------------------------------------------------------------------------
/src/f5_tts/eval/eval_librispeech_test_clean.py:
--------------------------------------------------------------------------------
1 | # Evaluate with LibriSpeech test-clean: ~3s prompt to generate 4-10s audio (following the VALL-E / Voicebox evaluation protocol)
2 |
3 | import sys
4 | import os
5 |
6 | sys.path.append(os.getcwd())
7 |
8 | import multiprocessing as mp
9 | from importlib.resources import files
10 |
11 | import numpy as np
12 |
13 | from f5_tts.eval.utils_eval import (
14 | get_librispeech_test,
15 | run_asr_wer,
16 | run_sim,
17 | )
18 |
19 | rel_path = str(files("f5_tts").joinpath("../../"))
20 |
21 |
22 | eval_task = "wer" # sim | wer
23 | lang = "en"
24 | metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
25 | librispeech_test_clean_path = "/LibriSpeech/test-clean" # test-clean path
26 | gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
27 |
28 | gpus = [0, 1, 2, 3, 4, 5, 6, 7]
29 | test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
30 |
31 | ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
32 | ## leading to a low similarity for the ground truth in some cases.
33 | # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
34 |
35 | local = False
36 | if local: # use local custom checkpoint dir
37 | asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
38 | else:
39 | asr_ckpt_dir = "" # auto download to cache dir
40 |
41 | wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
42 |
43 |
44 | # --------------------------- WER ---------------------------
45 |
46 | if eval_task == "wer":
47 | wers = []
48 |
49 | with mp.Pool(processes=len(gpus)) as pool:
50 | args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
51 | results = pool.map(run_asr_wer, args)
52 | for wers_ in results:
53 | wers.extend(wers_)
54 |
55 | wer = round(np.mean(wers) * 100, 3)
56 | print(f"\nTotal {len(wers)} samples")
57 | print(f"WER : {wer}%")
58 |
59 |
60 | # --------------------------- SIM ---------------------------
61 |
62 | if eval_task == "sim":
63 | sim_list = []
64 |
65 | with mp.Pool(processes=len(gpus)) as pool:
66 | args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
67 | results = pool.map(run_sim, args)
68 | for sim_ in results:
69 | sim_list.extend(sim_)
70 |
71 | sim = round(sum(sim_list) / len(sim_list), 3)
72 | print(f"\nTotal {len(sim_list)} samples")
73 | print(f"SIM : {sim}")
74 |
--------------------------------------------------------------------------------
/src/f5_tts/eval/eval_seedtts_testset.py:
--------------------------------------------------------------------------------
1 | # Evaluate with Seed-TTS testset
2 |
3 | import sys
4 | import os
5 |
6 | sys.path.append(os.getcwd())
7 |
8 | import multiprocessing as mp
9 | from importlib.resources import files
10 |
11 | import numpy as np
12 |
13 | from f5_tts.eval.utils_eval import (
14 | get_seed_tts_test,
15 | run_asr_wer,
16 | run_sim,
17 | )
18 |
19 | rel_path = str(files("f5_tts").joinpath("../../"))
20 |
21 |
22 | eval_task = "wer" # sim | wer
23 | lang = "zh" # zh | en
24 | metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
25 | # gen_wav_dir = rel_path + f"/data/seedtts_testset/{lang}/wavs" # ground truth wavs
26 | gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
27 |
28 |
29 | # NOTE. paraformer-zh results will differ slightly with the number of gpus, because the per-worker batch size differs
30 | # e.g. zh 1.254 appears to be the wer_seed_tts result obtained with 4 workers
31 | gpus = [0, 1, 2, 3, 4, 5, 6, 7]
32 | test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
33 |
34 | local = False
35 | if local: # use local custom checkpoint dir
36 | if lang == "zh":
37 | asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
38 | elif lang == "en":
39 | asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
40 | else:
41 | asr_ckpt_dir = "" # auto download to cache dir
42 |
43 | wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
44 |
45 |
46 | # --------------------------- WER ---------------------------
47 |
48 | if eval_task == "wer":
49 | wers = []
50 |
51 | with mp.Pool(processes=len(gpus)) as pool:
52 | args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
53 | results = pool.map(run_asr_wer, args)
54 | for wers_ in results:
55 | wers.extend(wers_)
56 |
57 | wer = round(np.mean(wers) * 100, 3)
58 | print(f"\nTotal {len(wers)} samples")
59 | print(f"WER : {wer}%")
60 |
61 |
62 | # --------------------------- SIM ---------------------------
63 |
64 | if eval_task == "sim":
65 | sim_list = []
66 |
67 | with mp.Pool(processes=len(gpus)) as pool:
68 | args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
69 | results = pool.map(run_sim, args)
70 | for sim_ in results:
71 | sim_list.extend(sim_)
72 |
73 | sim = round(sum(sim_list) / len(sim_list), 3)
74 | print(f"\nTotal {len(sim_list)} samples")
75 | print(f"SIM : {sim}")
76 |
--------------------------------------------------------------------------------
/src/f5_tts/eval/utils_eval.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import string
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import torchaudio
9 | from tqdm import tqdm
10 |
11 | from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
12 | from f5_tts.model.modules import MelSpec
13 | from f5_tts.model.utils import convert_char_to_pinyin
14 |
15 |
16 | # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
17 | def get_seedtts_testset_metainfo(metalst):
18 | f = open(metalst)
19 | lines = f.readlines()
20 | f.close()
21 | metainfo = []
22 | for line in lines:
23 | if len(line.strip().split("|")) == 5:
24 | utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
25 | elif len(line.strip().split("|")) == 4:
26 | utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
27 | gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
28 | if not os.path.isabs(prompt_wav):
29 | prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
30 | metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
31 | return metainfo
32 |
33 |
34 | # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
35 | def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
36 | f = open(metalst)
37 | lines = f.readlines()
38 | f.close()
39 | metainfo = []
40 | for line in lines:
41 | ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
42 |
43 | # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
44 | ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
45 | ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
46 |
47 | # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
48 | gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
49 | gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
50 |
51 | metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
52 |
53 | return metainfo
54 |
55 |
56 | # padded to max length mel batch
57 | def padded_mel_batch(ref_mels):
58 | max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
59 | padded_ref_mels = []
60 | for mel in ref_mels:
61 | padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
62 | padded_ref_mels.append(padded_ref_mel)
63 | padded_ref_mels = torch.stack(padded_ref_mels)
64 | padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
65 | return padded_ref_mels
66 |
67 |
68 | # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
69 |
70 |
71 | def get_inference_prompt(
72 | metainfo,
73 | speed=1.0,
74 | tokenizer="pinyin",
75 | polyphone=True,
76 | target_sample_rate=24000,
77 | n_fft=1024,
78 | win_length=1024,
79 | n_mel_channels=100,
80 | hop_length=256,
81 | mel_spec_type="vocos",
82 | target_rms=0.1,
83 | use_truth_duration=False,
84 | infer_batch_size=1,
85 | num_buckets=200,
86 | min_secs=3,
87 | max_secs=40,
88 | ):
89 | prompts_all = []
90 |
91 | min_tokens = min_secs * target_sample_rate // hop_length
92 | max_tokens = max_secs * target_sample_rate // hop_length
93 |
94 | batch_accum = [0] * num_buckets
95 | utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
96 | [[] for _ in range(num_buckets)] for _ in range(6)
97 | )
98 |
99 | mel_spectrogram = MelSpec(
100 | n_fft=n_fft,
101 | hop_length=hop_length,
102 | win_length=win_length,
103 | n_mel_channels=n_mel_channels,
104 | target_sample_rate=target_sample_rate,
105 | mel_spec_type=mel_spec_type,
106 | )
107 |
108 | for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
109 | # Audio
110 | ref_audio, ref_sr = torchaudio.load(prompt_wav)
111 | ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
112 | if ref_rms < target_rms:
113 | ref_audio = ref_audio * target_rms / ref_rms
114 | assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
115 | if ref_sr != target_sample_rate:
116 | resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
117 | ref_audio = resampler(ref_audio)
118 |
119 | # Text
120 | if len(prompt_text[-1].encode("utf-8")) == 1:
121 | prompt_text = prompt_text + " "
122 | text = [prompt_text + gt_text]
123 | if tokenizer == "pinyin":
124 | text_list = convert_char_to_pinyin(text, polyphone=polyphone)
125 | else:
126 | text_list = text
127 |
128 | # Duration, mel frame length
129 | ref_mel_len = ref_audio.shape[-1] // hop_length
130 | if use_truth_duration:
131 | gt_audio, gt_sr = torchaudio.load(gt_wav)
132 | if gt_sr != target_sample_rate:
133 | resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
134 | gt_audio = resampler(gt_audio)
135 | total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
136 |
137 | # # test vocoder resynthesis
138 | # ref_audio = gt_audio
139 | else:
140 | ref_text_len = len(prompt_text.encode("utf-8"))
141 | gen_text_len = len(gt_text.encode("utf-8"))
142 | total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
143 |
144 | # to mel spectrogram
145 | ref_mel = mel_spectrogram(ref_audio)
146 | ref_mel = ref_mel.squeeze(0)
147 |
148 | # deal with batch
149 | assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
150 | assert (
151 | min_tokens <= total_mel_len <= max_tokens
152 | ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
153 | bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
154 |
155 | utts[bucket_i].append(utt)
156 | ref_rms_list[bucket_i].append(ref_rms)
157 | ref_mels[bucket_i].append(ref_mel)
158 | ref_mel_lens[bucket_i].append(ref_mel_len)
159 | total_mel_lens[bucket_i].append(total_mel_len)
160 | final_text_list[bucket_i].extend(text_list)
161 |
162 | batch_accum[bucket_i] += total_mel_len
163 |
164 | if batch_accum[bucket_i] >= infer_batch_size:
165 | # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
166 | prompts_all.append(
167 | (
168 | utts[bucket_i],
169 | ref_rms_list[bucket_i],
170 | padded_mel_batch(ref_mels[bucket_i]),
171 | ref_mel_lens[bucket_i],
172 | total_mel_lens[bucket_i],
173 | final_text_list[bucket_i],
174 | )
175 | )
176 | batch_accum[bucket_i] = 0
177 | (
178 | utts[bucket_i],
179 | ref_rms_list[bucket_i],
180 | ref_mels[bucket_i],
181 | ref_mel_lens[bucket_i],
182 | total_mel_lens[bucket_i],
183 | final_text_list[bucket_i],
184 | ) = [], [], [], [], [], []
185 |
186 | # add residual
187 | for bucket_i, bucket_frames in enumerate(batch_accum):
188 | if bucket_frames > 0:
189 | prompts_all.append(
190 | (
191 | utts[bucket_i],
192 | ref_rms_list[bucket_i],
193 | padded_mel_batch(ref_mels[bucket_i]),
194 | ref_mel_lens[bucket_i],
195 | total_mel_lens[bucket_i],
196 | final_text_list[bucket_i],
197 | )
198 | )
199 |     # shuffle so the last workers are not left with only the easy (short) batches
200 | random.seed(666)
201 | random.shuffle(prompts_all)
202 |
203 | return prompts_all
204 |
205 |
206 | # get wav_res_ref_text of seed-tts test metalst
207 | # https://github.com/BytedanceSpeech/seed-tts-eval
208 |
209 |
210 | def get_seed_tts_test(metalst, gen_wav_dir, gpus):
211 | f = open(metalst)
212 | lines = f.readlines()
213 | f.close()
214 |
215 | test_set_ = []
216 | for line in tqdm(lines):
217 | if len(line.strip().split("|")) == 5:
218 | utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
219 | elif len(line.strip().split("|")) == 4:
220 | utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
221 |
222 | if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
223 | continue
224 | gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
225 | if not os.path.isabs(prompt_wav):
226 | prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
227 |
228 | test_set_.append((gen_wav, prompt_wav, gt_text))
229 |
230 | num_jobs = len(gpus)
231 | if num_jobs == 1:
232 | return [(gpus[0], test_set_)]
233 |
234 | wav_per_job = len(test_set_) // num_jobs + 1
235 | test_set = []
236 | for i in range(num_jobs):
237 | test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
238 |
239 | return test_set
240 |
241 |
242 | # get librispeech test-clean cross sentence test
243 |
244 |
245 | def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
246 | f = open(metalst)
247 | lines = f.readlines()
248 | f.close()
249 |
250 | test_set_ = []
251 | for line in tqdm(lines):
252 | ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
253 |
254 | if eval_ground_truth:
255 | gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
256 | gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
257 | else:
258 | if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
259 | raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
260 | gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
261 |
262 | ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
263 | ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
264 |
265 | test_set_.append((gen_wav, ref_wav, gen_txt))
266 |
267 | num_jobs = len(gpus)
268 | if num_jobs == 1:
269 | return [(gpus[0], test_set_)]
270 |
271 | wav_per_job = len(test_set_) // num_jobs + 1
272 | test_set = []
273 | for i in range(num_jobs):
274 | test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
275 |
276 | return test_set
277 |
278 |
279 | # load asr model
280 |
281 |
282 | def load_asr_model(lang, ckpt_dir=""):
283 | if lang == "zh":
284 | from funasr import AutoModel
285 |
286 | model = AutoModel(
287 | model=os.path.join(ckpt_dir, "paraformer-zh"),
288 | # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
289 | # punc_model = os.path.join(ckpt_dir, "ct-punc"),
290 | # spk_model = os.path.join(ckpt_dir, "cam++"),
291 | disable_update=True,
292 | ) # following seed-tts setting
293 | elif lang == "en":
294 | from faster_whisper import WhisperModel
295 |
296 | model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
297 | model = WhisperModel(model_size, device="cuda", compute_type="float16")
298 | return model
299 |
300 |
301 | # WER Evaluation, the way Seed-TTS does
302 |
303 |
304 | def run_asr_wer(args):
305 | rank, lang, test_set, ckpt_dir = args
306 |
307 | if lang == "zh":
308 | import zhconv
309 |
310 | torch.cuda.set_device(rank)
311 | elif lang == "en":
312 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
313 | else:
314 | raise NotImplementedError(
315 | "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
316 | )
317 |
318 | asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
319 |
320 | from zhon.hanzi import punctuation
321 |
322 | punctuation_all = punctuation + string.punctuation
323 | wers = []
324 |
325 | from jiwer import compute_measures
326 |
327 | for gen_wav, prompt_wav, truth in tqdm(test_set):
328 | if lang == "zh":
329 | res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
330 | hypo = res[0]["text"]
331 | hypo = zhconv.convert(hypo, "zh-cn")
332 | elif lang == "en":
333 | segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
334 | hypo = ""
335 | for segment in segments:
336 | hypo = hypo + " " + segment.text
337 |
338 | # raw_truth = truth
339 | # raw_hypo = hypo
340 |
341 | for x in punctuation_all:
342 | truth = truth.replace(x, "")
343 | hypo = hypo.replace(x, "")
344 |
345 |         truth = truth.replace("  ", " ")
346 |         hypo = hypo.replace("  ", " ")
347 |
348 | if lang == "zh":
349 | truth = " ".join([x for x in truth])
350 | hypo = " ".join([x for x in hypo])
351 | elif lang == "en":
352 | truth = truth.lower()
353 | hypo = hypo.lower()
354 |
355 | measures = compute_measures(truth, hypo)
356 | wer = measures["wer"]
357 |
358 | # ref_list = truth.split(" ")
359 | # subs = measures["substitutions"] / len(ref_list)
360 | # dele = measures["deletions"] / len(ref_list)
361 | # inse = measures["insertions"] / len(ref_list)
362 |
363 | wers.append(wer)
364 |
365 | return wers
366 |
367 |
368 | # SIM Evaluation
369 |
370 |
371 | def run_sim(args):
372 | rank, test_set, ckpt_dir = args
373 | device = f"cuda:{rank}"
374 |
375 | model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
376 | state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
377 | model.load_state_dict(state_dict["model"], strict=False)
378 |
379 |     use_gpu = torch.cuda.is_available()
380 | if use_gpu:
381 | model = model.cuda(device)
382 | model.eval()
383 |
384 | sim_list = []
385 | for wav1, wav2, truth in tqdm(test_set):
386 | wav1, sr1 = torchaudio.load(wav1)
387 | wav2, sr2 = torchaudio.load(wav2)
388 |
389 | resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
390 | resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
391 | wav1 = resample1(wav1)
392 | wav2 = resample2(wav2)
393 |
394 | if use_gpu:
395 | wav1 = wav1.cuda(device)
396 | wav2 = wav2.cuda(device)
397 | with torch.no_grad():
398 | emb1 = model(wav1)
399 | emb2 = model(wav2)
400 |
401 | sim = F.cosine_similarity(emb1, emb2)[0].item()
402 | # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
403 | sim_list.append(sim)
404 |
405 | return sim_list
406 |
--------------------------------------------------------------------------------
/src/f5_tts/infer/README.md:
--------------------------------------------------------------------------------
1 | # Inference
2 |
3 | The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts.
4 |
5 | **More checkpoints, contributed by the community and supporting more languages, can be found in [SHARED.md](SHARED.md).**
6 |
7 | Currently supports **30s for a single** generation, which is the **total length** including both prompt and output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text; chunked generation will be performed automatically. Long reference audio will be **clipped to ~15s**.
8 |
9 | To avoid possible inference failures, make sure you have read through the following instructions.
10 |
11 | - Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of a word, leading to suboptimal generation.
12 | - Uppercase letters will be uttered letter by letter, so use lowercase letters for normal words.
13 | - Add some spaces (blank: " ") or punctuation (e.g. "," ".") to explicitly introduce pauses.
14 | - Preprocess numbers into Chinese characters if you want them read in Chinese; otherwise they will be read in English (see the sketch below).
15 |
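For example, a minimal preprocessing sketch for the last point (purely illustrative; `digits_to_chinese` is a hypothetical helper, not part of the package), which spells digits out as Chinese characters before passing the text in:

```python
# Minimal sketch: spell ASCII digits as Chinese characters so they are read in Chinese.
DIGIT_TO_ZH = {"0": "零", "1": "一", "2": "二", "3": "三", "4": "四",
               "5": "五", "6": "六", "7": "七", "8": "八", "9": "九"}

def digits_to_chinese(text: str) -> str:
    return "".join(DIGIT_TO_ZH.get(ch, ch) for ch in text)

print(digits_to_chinese("房间号是305"))  # -> 房间号是三零五
```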
16 |
17 | ## Gradio App
18 |
19 | Currently supported features:
20 |
21 | - Basic TTS with Chunk Inference
22 | - Multi-Style / Multi-Speaker Generation
23 | - Voice Chat powered by Qwen2.5-3B-Instruct
24 |
25 | The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.
26 |
27 | The script will load model checkpoints from Hugging Face. You can also manually download the files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at first; an ASR model is loaded for transcription if `ref_text` is not provided, and an LLM is loaded if Voice Chat is used.
28 |
29 | It can also be used as a component of a larger application.
30 | ```python
31 | import gradio as gr
32 | from f5_tts.infer.infer_gradio import app
33 |
34 | with gr.Blocks() as main_app:
35 | gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")
36 |
37 | # ... other Gradio components
38 |
39 | app.render()
40 |
41 | main_app.launch()
42 | ```
43 |
44 |
45 | ## CLI Inference
46 |
47 | The CLI command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, a command-line tool for inference.
48 |
49 | The script will load model checkpoints from Hugging Face. You can also manually download the files and use `--ckpt_file` to specify the model to load, or update the path directly in `infer_cli.py`.
50 |
51 | To use a custom `vocab.txt`, provide it with `--vocab_file`.
52 |
53 | Basically, you can run inference with flags:
54 | ```bash
55 | # Leaving --ref_text "" will have the ASR model transcribe the reference audio (extra GPU memory usage)
56 | f5-tts_infer-cli \
57 | --model "F5-TTS" \
58 | --ref_audio "ref_audio.wav" \
59 | --ref_text "The content, subtitle or transcription of reference audio." \
60 | --gen_text "Some text you want TTS model generate for you."
61 |
62 | # Choose Vocoder
63 | f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file
64 | f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file
65 | ```
66 |
67 | A `.toml` file allows for more flexible usage.
68 |
69 | ```bash
70 | f5-tts_infer-cli -c custom.toml
71 | ```
72 |
73 | For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
74 |
75 | ```toml
76 | # F5-TTS | E2-TTS
77 | model = "F5-TTS"
78 | ref_audio = "infer/examples/basic/basic_ref_en.wav"
79 | # If an empty "", transcribes the reference audio automatically.
80 | ref_text = "Some call me nature, others call me mother nature."
81 | gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
82 | # File with text to generate. Ignores the text above.
83 | gen_file = ""
84 | remove_silence = false
85 | output_dir = "tests"
86 | ```
87 |
88 | You can also leverage a `.toml` file for multi-style generation; refer to `src/f5_tts/infer/examples/multi/story.toml`.
89 |
90 | ```toml
91 | # F5-TTS | E2-TTS
92 | model = "F5-TTS"
93 | ref_audio = "infer/examples/multi/main.flac"
94 | # If an empty "", transcribes the reference audio automatically.
95 | ref_text = ""
96 | gen_text = ""
97 | # File with text to generate. Ignores the text above.
98 | gen_file = "infer/examples/multi/story.txt"
99 | remove_silence = true
100 | output_dir = "tests"
101 |
102 | [voices.town]
103 | ref_audio = "infer/examples/multi/town.flac"
104 | ref_text = ""
105 |
106 | [voices.country]
107 | ref_audio = "infer/examples/multi/country.flac"
108 | ref_text = ""
109 | ```
110 | Mark the voice with `[main]`, `[town]`, or `[country]` whenever you want to switch voices; refer to `src/f5_tts/infer/examples/multi/story.txt`.
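
For example, a `gen_file` might interleave voices like this (an illustrative snippet in the style of `story.txt`):

```text
[main] The narrator speaks in the main voice. [town] “This sentence is spoken with the town voice,” [main] he said, before switching back to the main voice.
```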
111 |
112 | ## Speech Editing
113 |
114 | To test speech editing capabilities, use the following command:
115 |
116 | ```bash
117 | python src/f5_tts/infer/speech_edit.py
118 | ```
119 |
120 | ## Socket Realtime Client
121 |
122 | To communicate with the socket server, you first need to run
123 | ```bash
124 | python src/f5_tts/socket_server.py
125 | ```
126 |
127 |
128 | Then create a client to communicate:
129 |
130 | ``` python
131 | import socket
132 | import numpy as np
133 | import asyncio
134 | import pyaudio
135 |
136 | async def listen_to_voice(text, server_ip='localhost', server_port=9999):
137 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
138 | client_socket.connect((server_ip, server_port))
139 |
140 | async def play_audio_stream():
141 | buffer = b''
142 | p = pyaudio.PyAudio()
143 | stream = p.open(format=pyaudio.paFloat32,
144 | channels=1,
145 | rate=24000, # Ensure this matches the server's sampling rate
146 | output=True,
147 | frames_per_buffer=2048)
148 |
149 | try:
150 | while True:
151 | chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024)
152 | if not chunk: # End of stream
153 | break
154 | if b"END_OF_AUDIO" in chunk:
155 | buffer += chunk.replace(b"END_OF_AUDIO", b"")
156 | if buffer:
157 | audio_array = np.frombuffer(buffer, dtype=np.float32).copy() # Make a writable copy
158 | stream.write(audio_array.tobytes())
159 | break
160 | buffer += chunk
161 | if len(buffer) >= 4096:
162 | audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy() # Make a writable copy
163 | stream.write(audio_array.tobytes())
164 | buffer = buffer[4096:]
165 | finally:
166 | stream.stop_stream()
167 | stream.close()
168 | p.terminate()
169 |
170 | try:
171 | # Send only the text to the server
172 | await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8'))
173 | await play_audio_stream()
174 | print("Audio playback finished.")
175 |
176 | except Exception as e:
177 | print(f"Error in listen_to_voice: {e}")
178 |
179 | finally:
180 | client_socket.close()
181 |
182 | # Example usage: Replace this with your actual server IP and port
183 | async def main():
184 | await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998)
185 |
186 | # Run the main async function
187 | asyncio.run(main())
188 | ```
189 |
190 |
191 |
192 |
--------------------------------------------------------------------------------
/src/f5_tts/infer/SHARED.md:
--------------------------------------------------------------------------------
1 |
2 | # Shared Model Cards
3 |
4 |
5 | ### **Prerequisites for use**
6 | - This document serves as a quick lookup table for community training/finetuning results, with support for various languages.
7 | - The models in this repository are open source and based on voluntary contributions from community members.
8 | - Use of these models must respect their respective creators; the convenience they bring comes from the creators' efforts.
9 |
10 |
11 | ### **Welcome to share here**
12 | - Have a pretrained/finetuned result: a model checkpoint (preferably pruned for inference, i.e. keeping only `ema_model_state_dict`; see the sketch after this list) and the corresponding vocab file (for tokenization).
13 | - Host a public [huggingface model repository](https://huggingface.co/new) and upload the model-related files.
14 | - Make a pull request adding a model card to the current page, i.e. `src\f5_tts\infer\SHARED.md`.
15 |
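A possible pruning sketch for the first point (an assumption-laden example: it presumes the training checkpoint stores the EMA weights under the `ema_model_state_dict` key mentioned above; adjust paths and keys to your own setup):

```python
# Keep only the EMA weights so the shared checkpoint stays small for inference.
import torch

ckpt = torch.load("ckpts/your_exp/model_last.pt", map_location="cpu")
torch.save({"ema_model_state_dict": ckpt["ema_model_state_dict"]}, "model_pruned.pt")
```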
16 |
17 | ### Supported Languages
18 | - [Multilingual](#multilingual)
19 | - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
20 | - [Mandarin](#mandarin)
21 | - [Japanese](#japanese)
22 | - [F5-TTS Base @ pretrain/finetune @ ja](#f5-tts-base--pretrainfinetune--ja)
23 | - [English](#english)
24 | - [French](#french)
25 | - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
26 |
27 |
28 | ## Multilingual
29 |
30 | #### F5-TTS Base @ pretrain @ zh & en
31 | |Model|🤗Hugging Face|Data (Hours)|Model License|
32 | |:---:|:------------:|:-----------:|:-------------:|
33 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
34 |
35 | ```bash
36 | MODEL_CKPT: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
37 | VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
38 | ```
39 |
40 | *Other info, e.g. author info, GitHub repo, links to sampled results, usage instructions, tutorials (blog, video, etc.) ...*
41 |
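As a usage sketch (not an official recipe), the `hf://` paths above can be resolved to local cache paths with `cached_path`, the same helper `infer_cli.py` uses, and then passed to the CLI:

```python
# Resolve the hf:// URIs above to local files, then point the CLI at them.
from cached_path import cached_path

ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
vocab_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"))
print(ckpt_file, vocab_file)  # e.g. for: f5-tts_infer-cli --ckpt_file ... --vocab_file ...
```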
42 |
43 | ## Mandarin
44 |
45 | ## Japanese
46 |
47 | #### F5-TTS Base @ pretrain/finetune @ ja
48 | |Model|🤗Hugging Face|Data (Hours)|Model License|
49 | |:---:|:------------:|:-----------:|:-------------:|
50 | |F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
51 |
52 | ```bash
53 | MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
54 | VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
55 | ```
56 |
57 | ## English
58 |
59 |
60 | ## French
61 |
62 | #### French LibriVox @ finetune @ fr
63 | |Model|🤗Hugging Face|Data (Hours)|Model License|
64 | |:---:|:------------:|:-----------:|:-------------:|
65 | |F5-TTS French|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
66 |
67 | ```bash
68 | MODEL_CKPT: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
69 | VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
70 | ```
71 |
72 | - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
73 | - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
74 | - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
75 |
--------------------------------------------------------------------------------
/src/f5_tts/infer/examples/basic/basic.toml:
--------------------------------------------------------------------------------
1 | # F5-TTS | E2-TTS
2 | model = "F5-TTS"
3 | ref_audio = "infer/examples/basic/basic_ref_en.wav"
4 | # If an empty "", transcribes the reference audio automatically.
5 | ref_text = "Some call me nature, others call me mother nature."
6 | gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
7 | # File with text to generate. Ignores the text above.
8 | gen_file = ""
9 | remove_silence = false
10 | output_dir = "tests"
--------------------------------------------------------------------------------
/src/f5_tts/infer/examples/basic/basic_ref_en.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/basic/basic_ref_en.wav
--------------------------------------------------------------------------------
/src/f5_tts/infer/examples/basic/basic_ref_zh.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/basic/basic_ref_zh.wav
--------------------------------------------------------------------------------
/src/f5_tts/infer/examples/multi/country.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/multi/country.flac
--------------------------------------------------------------------------------
/src/f5_tts/infer/examples/multi/main.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/multi/main.flac
--------------------------------------------------------------------------------
/src/f5_tts/infer/examples/multi/story.toml:
--------------------------------------------------------------------------------
1 | # F5-TTS | E2-TTS
2 | model = "F5-TTS"
3 | ref_audio = "infer/examples/multi/main.flac"
4 | # If an empty "", transcribes the reference audio automatically.
5 | ref_text = ""
6 | gen_text = ""
7 | # File with text to generate. Ignores the text above.
8 | gen_file = "infer/examples/multi/story.txt"
9 | remove_silence = true
10 | output_dir = "tests"
11 |
12 | [voices.town]
13 | ref_audio = "infer/examples/multi/town.flac"
14 | ref_text = ""
15 |
16 | [voices.country]
17 | ref_audio = "infer/examples/multi/country.flac"
18 | ref_text = ""
19 |
20 |
--------------------------------------------------------------------------------
/src/f5_tts/infer/examples/multi/story.txt:
--------------------------------------------------------------------------------
1 | A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
--------------------------------------------------------------------------------
/src/f5_tts/infer/examples/multi/town.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpscr/F5-TTS/96946f85fa0c1eecb971f9f30f2e5b922d270e5c/src/f5_tts/infer/examples/multi/town.flac
--------------------------------------------------------------------------------
/src/f5_tts/infer/infer_cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import codecs
3 | import os
4 | import re
5 | from importlib.resources import files
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import soundfile as sf
10 | import tomli
11 | from cached_path import cached_path
12 |
13 | from f5_tts.infer.utils_infer import (
14 | infer_process,
15 | load_model,
16 | load_vocoder,
17 | preprocess_ref_audio_text,
18 | remove_silence_for_generated_wav,
19 | )
20 | from f5_tts.model import DiT, UNetT
21 |
22 | parser = argparse.ArgumentParser(
23 | prog="python3 infer-cli.py",
24 | description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
25 | epilog="Specify options above to override one or more settings from config.",
26 | )
27 | parser.add_argument(
28 | "-c",
29 | "--config",
30 | help="Configuration file. Default=infer/examples/basic/basic.toml",
31 | default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
32 | )
33 | parser.add_argument(
34 | "-m",
35 | "--model",
36 | help="F5-TTS | E2-TTS",
37 | )
38 | parser.add_argument(
39 | "-p",
40 | "--ckpt_file",
41 | help="The Checkpoint .pt",
42 | )
43 | parser.add_argument(
44 | "-v",
45 | "--vocab_file",
46 | help="The vocab .txt",
47 | )
48 | parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
49 | parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
50 | parser.add_argument(
51 | "-t",
52 | "--gen_text",
53 | type=str,
54 | help="Text to generate.",
55 | )
56 | parser.add_argument(
57 | "-f",
58 | "--gen_file",
59 | type=str,
60 | help="File with text to generate. Ignores --text",
61 | )
62 | parser.add_argument(
63 | "-o",
64 | "--output_dir",
65 | type=str,
66 | help="Path to output folder..",
67 | )
68 | parser.add_argument(
69 | "--remove_silence",
70 | help="Remove silence.",
71 | )
72 | parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
73 | parser.add_argument(
74 | "--load_vocoder_from_local",
75 | action="store_true",
76 | help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
77 | )
78 | parser.add_argument(
79 | "--speed",
80 | type=float,
81 | default=1.0,
82 | help="Adjust the speed of the audio generation (default: 1.0)",
83 | )
84 | args = parser.parse_args()
85 |
86 | config = tomli.load(open(args.config, "rb"))
87 |
88 | ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
89 | ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
90 | gen_text = args.gen_text if args.gen_text else config["gen_text"]
91 | gen_file = args.gen_file if args.gen_file else config["gen_file"]
92 |
93 | # patches for pip pkg user
94 | if "infer/examples/" in ref_audio:
95 | ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
96 | if "infer/examples/" in gen_file:
97 | gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
98 | if "voices" in config:
99 | for voice in config["voices"]:
100 | voice_ref_audio = config["voices"][voice]["ref_audio"]
101 | if "infer/examples/" in voice_ref_audio:
102 | config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
103 |
104 | if gen_file:
105 | gen_text = codecs.open(gen_file, "r", "utf-8").read()
106 | output_dir = args.output_dir if args.output_dir else config["output_dir"]
107 | model = args.model if args.model else config["model"]
108 | ckpt_file = args.ckpt_file if args.ckpt_file else ""
109 | vocab_file = args.vocab_file if args.vocab_file else ""
110 | remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
111 | speed = args.speed
112 | wave_path = Path(output_dir) / "infer_cli_out.wav"
113 | # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
114 | if args.vocoder_name == "vocos":
115 | vocoder_local_path = "../checkpoints/vocos-mel-24khz"
116 | elif args.vocoder_name == "bigvgan":
117 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
118 | mel_spec_type = args.vocoder_name
119 |
120 | vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
121 |
122 |
123 | # load models
124 | if model == "F5-TTS":
125 | model_cls = DiT
126 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
127 | if ckpt_file == "":
128 | if args.vocoder_name == "vocos":
129 | repo_name = "F5-TTS"
130 | exp_name = "F5TTS_Base"
131 | ckpt_step = 1200000
132 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
133 | # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
134 | elif args.vocoder_name == "bigvgan":
135 | repo_name = "F5-TTS"
136 | exp_name = "F5TTS_Base_bigvgan"
137 | ckpt_step = 1250000
138 | ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
139 |
140 | elif model == "E2-TTS":
141 | model_cls = UNetT
142 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
143 |     if ckpt_file == "":
144 |         if args.vocoder_name == "vocos":
145 |             repo_name = "E2-TTS"
146 |             exp_name = "E2TTS_Base"
147 |             ckpt_step = 1200000
148 |             ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
149 |         elif args.vocoder_name == "bigvgan":  # TODO: need to test
150 |             repo_name = "F5-TTS"
151 |             exp_name = "F5TTS_Base_bigvgan"
152 |             ckpt_step = 1250000
153 |             ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
154 |
155 |
156 | print(f"Using {model}...")
157 | ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
158 |
159 |
160 | def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
161 | main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
162 | if "voices" not in config:
163 | voices = {"main": main_voice}
164 | else:
165 | voices = config["voices"]
166 | voices["main"] = main_voice
167 | for voice in voices:
168 | voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
169 | voices[voice]["ref_audio"], voices[voice]["ref_text"]
170 | )
171 | print("Voice:", voice)
172 | print("Ref_audio:", voices[voice]["ref_audio"])
173 | print("Ref_text:", voices[voice]["ref_text"])
174 |
175 | generated_audio_segments = []
176 | reg1 = r"(?=\[\w+\])"
177 | chunks = re.split(reg1, text_gen)
178 | reg2 = r"\[(\w+)\]"
179 | for text in chunks:
180 | if not text.strip():
181 | continue
182 | match = re.match(reg2, text)
183 | if match:
184 | voice = match[1]
185 | else:
186 | print("No voice tag found, using main.")
187 | voice = "main"
188 | if voice not in voices:
189 | print(f"Voice {voice} not found, using main.")
190 | voice = "main"
191 | text = re.sub(reg2, "", text)
192 | gen_text = text.strip()
193 | ref_audio = voices[voice]["ref_audio"]
194 | ref_text = voices[voice]["ref_text"]
195 | print(f"Voice: {voice}")
196 |         audio, final_sample_rate, spectrogram = infer_process(
197 | ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
198 | )
199 | generated_audio_segments.append(audio)
200 |
201 | if generated_audio_segments:
202 | final_wave = np.concatenate(generated_audio_segments)
203 |
204 | if not os.path.exists(output_dir):
205 | os.makedirs(output_dir)
206 |
207 | with open(wave_path, "wb") as f:
208 | sf.write(f.name, final_wave, final_sample_rate)
209 | # Remove silence
210 | if remove_silence:
211 | remove_silence_for_generated_wav(f.name)
212 | print(f.name)
213 |
214 |
215 | def main():
216 | main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
217 |
218 |
219 | if __name__ == "__main__":
220 | main()
221 |
--------------------------------------------------------------------------------
/src/f5_tts/infer/speech_edit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import torchaudio
6 |
7 | from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
8 | from f5_tts.model import CFM, DiT, UNetT
9 | from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
10 |
11 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
12 |
13 |
14 | # --------------------- Dataset Settings -------------------- #
15 |
16 | target_sample_rate = 24000
17 | n_mel_channels = 100
18 | hop_length = 256
19 | win_length = 1024
20 | n_fft = 1024
21 | mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
22 | target_rms = 0.1
23 |
24 | tokenizer = "pinyin"
25 | dataset_name = "Emilia_ZH_EN"
26 |
27 |
28 | # ---------------------- infer setting ---------------------- #
29 |
30 | seed = None # int | None
31 |
32 | exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
33 | ckpt_step = 1200000
34 |
35 | nfe_step = 32 # 16, 32
36 | cfg_strength = 2.0
37 | ode_method = "euler" # euler | midpoint
38 | sway_sampling_coef = -1.0
39 | speed = 1.0
40 |
41 | if exp_name == "F5TTS_Base":
42 | model_cls = DiT
43 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
44 |
45 | elif exp_name == "E2TTS_Base":
46 | model_cls = UNetT
47 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
48 |
49 | ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
50 | output_dir = "tests"
51 |
52 | # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
53 | # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
54 | # [write the origin_text into a file, e.g. tests/test_edit.txt]
55 | # ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
56 | # [result will be saved at same path of audio file]
57 | # [--language "zho" for Chinese, "eng" for English]
58 | # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
59 |
60 | audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
61 | origin_text = "Some call me nature, others call me mother nature."
62 | target_text = "Some call me optimist, others call me realist."
63 | parts_to_edit = [
64 | [1.42, 2.44],
65 | [4.04, 4.9],
66 | ] # start_ends of "nature" & "mother nature", in seconds
67 | fix_duration = [
68 | 1.2,
69 | 1,
70 | ] # fix duration for "optimist" & "realist", in seconds
71 |
72 | # audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
73 | # origin_text = "对,这就是我,万人敬仰的太乙真人。"
74 | # target_text = "对,那就是你,万人敬仰的太白金星。"
75 | # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
76 | # fix_duration = None # use origin text duration
77 |
78 |
79 | # -------------------------------------------------#
80 |
81 | use_ema = True
82 |
83 | if not os.path.exists(output_dir):
84 | os.makedirs(output_dir)
85 |
86 | # Vocoder model
87 | local = False
88 | if mel_spec_type == "vocos":
89 | vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
90 | elif mel_spec_type == "bigvgan":
91 | vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
92 | vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
93 |
94 | # Tokenizer
95 | vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
96 |
97 | # Model
98 | model = CFM(
99 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
100 | mel_spec_kwargs=dict(
101 | n_fft=n_fft,
102 | hop_length=hop_length,
103 | win_length=win_length,
104 | n_mel_channels=n_mel_channels,
105 | target_sample_rate=target_sample_rate,
106 | mel_spec_type=mel_spec_type,
107 | ),
108 | odeint_kwargs=dict(
109 | method=ode_method,
110 | ),
111 | vocab_char_map=vocab_char_map,
112 | ).to(device)
113 |
114 | dtype = torch.float32 if mel_spec_type == "bigvgan" else None
115 | model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
116 |
117 | # Audio
118 | audio, sr = torchaudio.load(audio_to_edit)
119 | if audio.shape[0] > 1:
120 | audio = torch.mean(audio, dim=0, keepdim=True)
121 | rms = torch.sqrt(torch.mean(torch.square(audio)))
122 | if rms < target_rms:
123 | audio = audio * target_rms / rms
124 | if sr != target_sample_rate:
125 | resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
126 | audio = resampler(audio)
127 | offset = 0
128 | audio_ = torch.zeros(1, 0)
129 | edit_mask = torch.zeros(1, 0, dtype=torch.bool)
130 | for part in parts_to_edit:
131 | start, end = part
132 | part_dur = end - start if fix_duration is None else fix_duration.pop(0)
133 | part_dur = part_dur * target_sample_rate
134 | start = start * target_sample_rate
135 | audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
136 | edit_mask = torch.cat(
137 | (
138 | edit_mask,
139 | torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
140 | torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
141 | ),
142 | dim=-1,
143 | )
144 | offset = end * target_sample_rate
145 | # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
146 | edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
147 | audio = audio.to(device)
148 | edit_mask = edit_mask.to(device)
149 |
150 | # Text
151 | text_list = [target_text]
152 | if tokenizer == "pinyin":
153 | final_text_list = convert_char_to_pinyin(text_list)
154 | else:
155 | final_text_list = [text_list]
156 | print(f"text : {text_list}")
157 | print(f"pinyin: {final_text_list}")
158 |
159 | # Duration
160 | ref_audio_len = 0
161 | duration = audio.shape[-1] // hop_length
162 |
163 | # Inference
164 | with torch.inference_mode():
165 | generated, trajectory = model.sample(
166 | cond=audio,
167 | text=final_text_list,
168 | duration=duration,
169 | steps=nfe_step,
170 | cfg_strength=cfg_strength,
171 | sway_sampling_coef=sway_sampling_coef,
172 | seed=seed,
173 | edit_mask=edit_mask,
174 | )
175 | print(f"Generated mel: {generated.shape}")
176 |
177 | # Final result
178 | generated = generated.to(torch.float32)
179 | generated = generated[:, ref_audio_len:, :]
180 | gen_mel_spec = generated.permute(0, 2, 1)
181 | if mel_spec_type == "vocos":
182 | generated_wave = vocoder.decode(gen_mel_spec)
183 | elif mel_spec_type == "bigvgan":
184 | generated_wave = vocoder(gen_mel_spec)
185 |
186 | if rms < target_rms:
187 | generated_wave = generated_wave * rms / target_rms
188 |
189 | save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
190 | torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
191 | print(f"Generated wav: {generated_wave.shape}")
192 |
--------------------------------------------------------------------------------
/src/f5_tts/model/__init__.py:
--------------------------------------------------------------------------------
1 | from f5_tts.model.cfm import CFM
2 |
3 | from f5_tts.model.backbones.unett import UNetT
4 | from f5_tts.model.backbones.dit import DiT
5 | from f5_tts.model.backbones.mmdit import MMDiT
6 |
7 | from f5_tts.model.trainer import Trainer
8 |
9 |
10 | __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
11 |
--------------------------------------------------------------------------------
/src/f5_tts/model/backbones/README.md:
--------------------------------------------------------------------------------
1 | ## Backbones quick introduction
2 |
3 |
4 | ### unett.py
5 | - flat unet transformer
6 | - structure same as in e2-tts & voicebox paper except using rotary pos emb
7 | - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8 |
9 | ### dit.py
10 | - adaln-zero dit
11 | - embedded timestep as condition
12 | - concatted noised_input + masked_cond + embedded_text, linear proj in (see the sketch below)
13 | - possible abs pos emb & convnextv2 blocks for embedded text before concat
14 | - possible long skip connection (first layer to last layer)
15 |
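A minimal sketch of the dit.py input mixing described above, with toy dimensions (the actual projection lives in `InputEmbedding` in `dit.py` below):

```python
# Concatenate noised input, masked cond audio and embedded text, then project to model width.
import torch
import torch.nn as nn

mel_dim, text_dim, dim = 100, 512, 1024
proj_in = nn.Linear(mel_dim * 2 + text_dim, dim)

x = torch.randn(2, 300, mel_dim)            # b n d: noised mel input
cond = torch.randn(2, 300, mel_dim)         # b n d: masked cond audio
text_embed = torch.randn(2, 300, text_dim)  # b n d: embedded text
h = proj_in(torch.cat((x, cond, text_embed), dim=-1))  # (2, 300, 1024), fed to the DiT blocks
print(h.shape)
```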
16 | ### mmdit.py
17 | - sd3 structure
18 | - timestep as condition
19 | - left stream: text embedded and applied an abs pos emb
20 | - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
21 |
--------------------------------------------------------------------------------
/src/f5_tts/model/backbones/dit.py:
--------------------------------------------------------------------------------
1 | """
2 | ein notation:
3 | b - batch
4 | n - sequence
5 | nt - text sequence
6 | nw - raw wave length
7 | d - dimension
8 | """
9 |
10 | from __future__ import annotations
11 |
12 | import torch
13 | from torch import nn
14 | import torch.nn.functional as F
15 |
16 | from x_transformers.x_transformers import RotaryEmbedding
17 |
18 | from f5_tts.model.modules import (
19 | TimestepEmbedding,
20 | ConvNeXtV2Block,
21 | ConvPositionEmbedding,
22 | DiTBlock,
23 | AdaLayerNormZero_Final,
24 | precompute_freqs_cis,
25 | get_pos_embed_indices,
26 | )
27 |
28 |
29 | # Text embedding
30 |
31 |
32 | class TextEmbedding(nn.Module):
33 | def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34 | super().__init__()
35 | self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36 |
37 | if conv_layers > 0:
38 | self.extra_modeling = True
39 | self.precompute_max_pos = 4096 # ~44s of 24khz audio
40 | self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41 | self.text_blocks = nn.Sequential(
42 | *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
43 | )
44 | else:
45 | self.extra_modeling = False
46 |
47 | def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
48 | text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49 | text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50 | batch, text_len = text.shape[0], text.shape[1]
51 | text = F.pad(text, (0, seq_len - text_len), value=0)
52 |
53 | if drop_text: # cfg for text
54 | text = torch.zeros_like(text)
55 |
56 | text = self.text_embed(text) # b n -> b n d
57 |
58 | # possible extra modeling
59 | if self.extra_modeling:
60 | # sinus pos emb
61 | batch_start = torch.zeros((batch,), dtype=torch.long)
62 | pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
63 | text_pos_embed = self.freqs_cis[pos_idx]
64 | text = text + text_pos_embed
65 |
66 | # convnextv2 blocks
67 | text = self.text_blocks(text)
68 |
69 | return text
70 |
71 |
72 | # noised input audio and context mixing embedding
73 |
74 |
75 | class InputEmbedding(nn.Module):
76 | def __init__(self, mel_dim, text_dim, out_dim):
77 | super().__init__()
78 | self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79 | self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
80 |
81 | def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
82 | if drop_audio_cond: # cfg for cond audio
83 | cond = torch.zeros_like(cond)
84 |
85 | x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
86 | x = self.conv_pos_embed(x) + x
87 | return x
88 |
89 |
90 | # Transformer backbone using DiT blocks
91 |
92 |
93 | class DiT(nn.Module):
94 | def __init__(
95 | self,
96 | *,
97 | dim,
98 | depth=8,
99 | heads=8,
100 | dim_head=64,
101 | dropout=0.1,
102 | ff_mult=4,
103 | mel_dim=100,
104 | text_num_embeds=256,
105 | text_dim=None,
106 | conv_layers=0,
107 | long_skip_connection=False,
108 | ):
109 | super().__init__()
110 |
111 | self.time_embed = TimestepEmbedding(dim)
112 | if text_dim is None:
113 | text_dim = mel_dim
114 | self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
115 | self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
116 |
117 | self.rotary_embed = RotaryEmbedding(dim_head)
118 |
119 | self.dim = dim
120 | self.depth = depth
121 |
122 | self.transformer_blocks = nn.ModuleList(
123 | [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
124 | )
125 | self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
126 |
127 | self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
128 | self.proj_out = nn.Linear(dim, mel_dim)
129 |
130 | def forward(
131 | self,
132 | x: float["b n d"], # noised input audio # noqa: F722
133 | cond: float["b n d"], # masked cond audio # noqa: F722
134 | text: int["b nt"], # text # noqa: F722
135 | time: float["b"] | float[""], # time step # noqa: F821 F722
136 | drop_audio_cond, # cfg for cond audio
137 | drop_text, # cfg for text
138 | mask: bool["b n"] | None = None, # noqa: F722
139 | ):
140 | batch, seq_len = x.shape[0], x.shape[1]
141 | if time.ndim == 0:
142 | time = time.repeat(batch)
143 |
144 | # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
145 | t = self.time_embed(time)
146 | text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
147 | x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
148 |
149 | rope = self.rotary_embed.forward_from_seq_len(seq_len)
150 |
151 | if self.long_skip_connection is not None:
152 | residual = x
153 |
154 | for block in self.transformer_blocks:
155 | x = block(x, t, mask=mask, rope=rope)
156 |
157 | if self.long_skip_connection is not None:
158 | x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
159 |
160 | x = self.norm_out(x, t)
161 | output = self.proj_out(x)
162 |
163 | return output
164 |
--------------------------------------------------------------------------------
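A minimal shape-check sketch for the DiT backbone defined above (toy sizes chosen for illustration, not a configuration used by this repo; assumes the package is importable as f5_tts):

import torch
from f5_tts.model.backbones.dit import DiT

model = DiT(dim=256, depth=2, heads=4, dim_head=64, mel_dim=100, text_num_embeds=256, conv_layers=2)
x = torch.randn(2, 120, 100)           # noised mel frames, [b, n, mel_dim]
cond = torch.randn(2, 120, 100)        # masked conditioning mel, same shape as x
text = torch.randint(0, 256, (2, 40))  # character token ids, [b, nt]; nt need not match n
time = torch.rand(2)                   # flow-matching time step per sample
out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)                       # expected: torch.Size([2, 120, 100])

The output keeps the frame length of the noised mel input, which is what CFM (further down) integrates over during sampling.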
/src/f5_tts/model/backbones/mmdit.py:
--------------------------------------------------------------------------------
1 | """
2 | ein notation:
3 | b - batch
4 | n - sequence
5 | nt - text sequence
6 | nw - raw wave length
7 | d - dimension
8 | """
9 |
10 | from __future__ import annotations
11 |
12 | import torch
13 | from torch import nn
14 |
15 | from x_transformers.x_transformers import RotaryEmbedding
16 |
17 | from f5_tts.model.modules import (
18 | TimestepEmbedding,
19 | ConvPositionEmbedding,
20 | MMDiTBlock,
21 | AdaLayerNormZero_Final,
22 | precompute_freqs_cis,
23 | get_pos_embed_indices,
24 | )
25 |
26 |
27 | # text embedding
28 |
29 |
30 | class TextEmbedding(nn.Module):
31 | def __init__(self, out_dim, text_num_embeds):
32 | super().__init__()
33 | self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34 |
35 | self.precompute_max_pos = 1024
36 | self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37 |
38 | def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39 | text = text + 1
40 | if drop_text:
41 | text = torch.zeros_like(text)
42 | text = self.text_embed(text)
43 |
44 | # sinus pos emb
45 | batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
46 | batch_text_len = text.shape[1]
47 | pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
48 | text_pos_embed = self.freqs_cis[pos_idx]
49 |
50 | text = text + text_pos_embed
51 |
52 | return text
53 |
54 |
55 | # noised input & masked cond audio embedding
56 |
57 |
58 | class AudioEmbedding(nn.Module):
59 | def __init__(self, in_dim, out_dim):
60 | super().__init__()
61 | self.linear = nn.Linear(2 * in_dim, out_dim)
62 | self.conv_pos_embed = ConvPositionEmbedding(out_dim)
63 |
64 | def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
65 | if drop_audio_cond:
66 | cond = torch.zeros_like(cond)
67 | x = torch.cat((x, cond), dim=-1)
68 | x = self.linear(x)
69 | x = self.conv_pos_embed(x) + x
70 | return x
71 |
72 |
73 | # Transformer backbone using MM-DiT blocks
74 |
75 |
76 | class MMDiT(nn.Module):
77 | def __init__(
78 | self,
79 | *,
80 | dim,
81 | depth=8,
82 | heads=8,
83 | dim_head=64,
84 | dropout=0.1,
85 | ff_mult=4,
86 | text_num_embeds=256,
87 | mel_dim=100,
88 | ):
89 | super().__init__()
90 |
91 | self.time_embed = TimestepEmbedding(dim)
92 | self.text_embed = TextEmbedding(dim, text_num_embeds)
93 | self.audio_embed = AudioEmbedding(mel_dim, dim)
94 |
95 | self.rotary_embed = RotaryEmbedding(dim_head)
96 |
97 | self.dim = dim
98 | self.depth = depth
99 |
100 | self.transformer_blocks = nn.ModuleList(
101 | [
102 | MMDiTBlock(
103 | dim=dim,
104 | heads=heads,
105 | dim_head=dim_head,
106 | dropout=dropout,
107 | ff_mult=ff_mult,
108 | context_pre_only=i == depth - 1,
109 | )
110 | for i in range(depth)
111 | ]
112 | )
113 | self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
114 | self.proj_out = nn.Linear(dim, mel_dim)
115 |
116 | def forward(
117 | self,
118 | x: float["b n d"], # noised input audio # noqa: F722
119 | cond: float["b n d"], # masked cond audio # noqa: F722
120 | text: int["b nt"], # text # noqa: F722
121 | time: float["b"] | float[""], # time step # noqa: F821 F722
122 | drop_audio_cond, # cfg for cond audio
123 | drop_text, # cfg for text
124 | mask: bool["b n"] | None = None, # noqa: F722
125 | ):
126 | batch = x.shape[0]
127 | if time.ndim == 0:
128 | time = time.repeat(batch)
129 |
130 | # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131 | t = self.time_embed(time)
132 | c = self.text_embed(text, drop_text=drop_text)
133 | x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134 |
135 | seq_len = x.shape[1]
136 | text_len = text.shape[1]
137 | rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
138 | rope_text = self.rotary_embed.forward_from_seq_len(text_len)
139 |
140 | for block in self.transformer_blocks:
141 | c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
142 |
143 | x = self.norm_out(x, t)
144 | output = self.proj_out(x)
145 |
146 | return output
147 |
--------------------------------------------------------------------------------
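A similar smoke-test sketch for MMDiT, again with illustrative sizes. The point of difference from DiT: text and audio are processed as two streams of different lengths with their own rotary embeddings, and only the audio stream is projected back to mel:

import torch
from f5_tts.model.backbones.mmdit import MMDiT

model = MMDiT(dim=256, depth=2, heads=4, dim_head=64, text_num_embeds=256, mel_dim=100)
x = torch.randn(1, 200, 100)           # noised mel, audio stream of length n
cond = torch.randn(1, 200, 100)        # masked conditioning mel, same length as x
text = torch.randint(0, 256, (1, 60))  # text ids; the text positional table above is precomputed for 1024 positions
time = torch.rand(1)
out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)                       # expected: torch.Size([1, 200, 100])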
/src/f5_tts/model/backbones/unett.py:
--------------------------------------------------------------------------------
1 | """
2 | ein notation:
3 | b - batch
4 | n - sequence
5 | nt - text sequence
6 | nw - raw wave length
7 | d - dimension
8 | """
9 |
10 | from __future__ import annotations
11 | from typing import Literal
12 |
13 | import torch
14 | from torch import nn
15 | import torch.nn.functional as F
16 |
17 | from x_transformers import RMSNorm
18 | from x_transformers.x_transformers import RotaryEmbedding
19 |
20 | from f5_tts.model.modules import (
21 | TimestepEmbedding,
22 | ConvNeXtV2Block,
23 | ConvPositionEmbedding,
24 | Attention,
25 | AttnProcessor,
26 | FeedForward,
27 | precompute_freqs_cis,
28 | get_pos_embed_indices,
29 | )
30 |
31 |
32 | # Text embedding
33 |
34 |
35 | class TextEmbedding(nn.Module):
36 | def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37 | super().__init__()
38 | self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39 |
40 | if conv_layers > 0:
41 | self.extra_modeling = True
42 | self.precompute_max_pos = 4096 # ~44s of 24khz audio
43 | self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44 | self.text_blocks = nn.Sequential(
45 | *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
46 | )
47 | else:
48 | self.extra_modeling = False
49 |
50 | def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
51 | text = text + 1 # use 0 as filler token; batches are pre-padded with -1, see list_str_to_idx()
52 | text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53 | batch, text_len = text.shape[0], text.shape[1]
54 | text = F.pad(text, (0, seq_len - text_len), value=0)
55 |
56 | if drop_text: # cfg for text
57 | text = torch.zeros_like(text)
58 |
59 | text = self.text_embed(text) # b n -> b n d
60 |
61 | # possible extra modeling
62 | if self.extra_modeling:
63 | # sinus pos emb
64 | batch_start = torch.zeros((batch,), dtype=torch.long)
65 | pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
66 | text_pos_embed = self.freqs_cis[pos_idx]
67 | text = text + text_pos_embed
68 |
69 | # convnextv2 blocks
70 | text = self.text_blocks(text)
71 |
72 | return text
73 |
74 |
75 | # noised input audio and context mixing embedding
76 |
77 |
78 | class InputEmbedding(nn.Module):
79 | def __init__(self, mel_dim, text_dim, out_dim):
80 | super().__init__()
81 | self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
82 | self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
83 |
84 | def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
85 | if drop_audio_cond: # cfg for cond audio
86 | cond = torch.zeros_like(cond)
87 |
88 | x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
89 | x = self.conv_pos_embed(x) + x
90 | return x
91 |
92 |
93 | # Flat UNet Transformer backbone
94 |
95 |
96 | class UNetT(nn.Module):
97 | def __init__(
98 | self,
99 | *,
100 | dim,
101 | depth=8,
102 | heads=8,
103 | dim_head=64,
104 | dropout=0.1,
105 | ff_mult=4,
106 | mel_dim=100,
107 | text_num_embeds=256,
108 | text_dim=None,
109 | conv_layers=0,
110 | skip_connect_type: Literal["add", "concat", "none"] = "concat",
111 | ):
112 | super().__init__()
113 | assert depth % 2 == 0, "UNet-Transformer's depth should be even."
114 |
115 | self.time_embed = TimestepEmbedding(dim)
116 | if text_dim is None:
117 | text_dim = mel_dim
118 | self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
119 | self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120 |
121 | self.rotary_embed = RotaryEmbedding(dim_head)
122 |
123 | # transformer layers & skip connections
124 |
125 | self.dim = dim
126 | self.skip_connect_type = skip_connect_type
127 | needs_skip_proj = skip_connect_type == "concat"
128 |
129 | self.depth = depth
130 | self.layers = nn.ModuleList([])
131 |
132 | for idx in range(depth):
133 | is_later_half = idx >= (depth // 2)
134 |
135 | attn_norm = RMSNorm(dim)
136 | attn = Attention(
137 | processor=AttnProcessor(),
138 | dim=dim,
139 | heads=heads,
140 | dim_head=dim_head,
141 | dropout=dropout,
142 | )
143 |
144 | ff_norm = RMSNorm(dim)
145 | ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
146 |
147 | skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
148 |
149 | self.layers.append(
150 | nn.ModuleList(
151 | [
152 | skip_proj,
153 | attn_norm,
154 | attn,
155 | ff_norm,
156 | ff,
157 | ]
158 | )
159 | )
160 |
161 | self.norm_out = RMSNorm(dim)
162 | self.proj_out = nn.Linear(dim, mel_dim)
163 |
164 | def forward(
165 | self,
166 | x: float["b n d"], # noised input audio # noqa: F722
167 | cond: float["b n d"], # masked cond audio # noqa: F722
168 | text: int["b nt"], # text # noqa: F722
169 | time: float["b"] | float[""], # time step # noqa: F821 F722
170 | drop_audio_cond, # cfg for cond audio
171 | drop_text, # cfg for text
172 | mask: bool["b n"] | None = None, # noqa: F722
173 | ):
174 | batch, seq_len = x.shape[0], x.shape[1]
175 | if time.ndim == 0:
176 | time = time.repeat(batch)
177 |
178 | # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179 | t = self.time_embed(time)
180 | text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
181 | x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182 |
183 | # postfix time t to input x, [b n d] -> [b n+1 d]
184 | x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
185 | if mask is not None:
186 | mask = F.pad(mask, (1, 0), value=1)
187 |
188 | rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
189 |
190 | # flat unet transformer
191 | skip_connect_type = self.skip_connect_type
192 | skips = []
193 | for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
194 | layer = idx + 1
195 |
196 | # skip connection logic
197 | is_first_half = layer <= (self.depth // 2)
198 | is_later_half = not is_first_half
199 |
200 | if is_first_half:
201 | skips.append(x)
202 |
203 | if is_later_half:
204 | skip = skips.pop()
205 | if skip_connect_type == "concat":
206 | x = torch.cat((x, skip), dim=-1)
207 | x = maybe_skip_proj(x)
208 | elif skip_connect_type == "add":
209 | x = x + skip
210 |
211 | # attention and feedforward blocks
212 | x = attn(attn_norm(x), rope=rope, mask=mask) + x
213 | x = ff(ff_norm(x)) + x
214 |
215 | assert len(skips) == 0
216 |
217 | x = self.norm_out(x)[:, 1:, :] # unpack t from x
218 |
219 | return self.proj_out(x)
220 |
--------------------------------------------------------------------------------
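A brief sketch of what makes UNetT different from the two backbones above (illustrative sizes): depth must be even, the time embedding is packed in as an extra leading token, and the second half of the layers consumes skip connections pushed by the first half:

import torch
from f5_tts.model.backbones.unett import UNetT

model = UNetT(dim=256, depth=4, heads=4, dim_head=64, mel_dim=100, text_num_embeds=256, conv_layers=2, skip_connect_type="concat")
x = torch.randn(2, 150, 100)
cond = torch.randn(2, 150, 100)
text = torch.randint(0, 256, (2, 50))
time = torch.rand(2)
out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)  # expected: torch.Size([2, 150, 100]); the packed time token is stripped again before proj_out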
/src/f5_tts/model/cfm.py:
--------------------------------------------------------------------------------
1 | """
2 | ein notation:
3 | b - batch
4 | n - sequence
5 | nt - text sequence
6 | nw - raw wave length
7 | d - dimension
8 | """
9 |
10 | from __future__ import annotations
11 |
12 | from random import random
13 | from typing import Callable
14 |
15 | import torch
16 | import torch.nn.functional as F
17 | from torch import nn
18 | from torch.nn.utils.rnn import pad_sequence
19 | from torchdiffeq import odeint
20 |
21 | from f5_tts.model.modules import MelSpec
22 | from f5_tts.model.utils import (
23 | default,
24 | exists,
25 | lens_to_mask,
26 | list_str_to_idx,
27 | list_str_to_tensor,
28 | mask_from_frac_lengths,
29 | )
30 |
31 |
32 | class CFM(nn.Module):
33 | def __init__(
34 | self,
35 | transformer: nn.Module,
36 | sigma=0.0,
37 | odeint_kwargs: dict = dict(
38 | # atol = 1e-5,
39 | # rtol = 1e-5,
40 | method="euler" # 'midpoint'
41 | ),
42 | audio_drop_prob=0.3,
43 | cond_drop_prob=0.2,
44 | num_channels=None,
45 | mel_spec_module: nn.Module | None = None,
46 | mel_spec_kwargs: dict = dict(),
47 | frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
48 | vocab_char_map: dict[str:int] | None = None,
49 | ):
50 | super().__init__()
51 |
52 | self.frac_lengths_mask = frac_lengths_mask
53 |
54 | # mel spec
55 | self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
56 | num_channels = default(num_channels, self.mel_spec.n_mel_channels)
57 | self.num_channels = num_channels
58 |
59 | # classifier-free guidance
60 | self.audio_drop_prob = audio_drop_prob
61 | self.cond_drop_prob = cond_drop_prob
62 |
63 | # transformer
64 | self.transformer = transformer
65 | dim = transformer.dim
66 | self.dim = dim
67 |
68 | # conditional flow related
69 | self.sigma = sigma
70 |
71 | # sampling related
72 | self.odeint_kwargs = odeint_kwargs
73 |
74 | # vocab map for tokenization
75 | self.vocab_char_map = vocab_char_map
76 |
77 | @property
78 | def device(self):
79 | return next(self.parameters()).device
80 |
81 | @torch.no_grad()
82 | def sample(
83 | self,
84 | cond: float["b n d"] | float["b nw"], # noqa: F722
85 | text: int["b nt"] | list[str], # noqa: F722
86 | duration: int | int["b"], # noqa: F821
87 | *,
88 | lens: int["b"] | None = None, # noqa: F821
89 | steps=32,
90 | cfg_strength=1.0,
91 | sway_sampling_coef=None,
92 | seed: int | None = None,
93 | max_duration=4096,
94 | vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
95 | no_ref_audio=False,
96 | duplicate_test=False,
97 | t_inter=0.1,
98 | edit_mask=None,
99 | ):
100 | self.eval()
101 | # raw wave
102 |
103 | if cond.ndim == 2:
104 | cond = self.mel_spec(cond)
105 | cond = cond.permute(0, 2, 1)
106 | assert cond.shape[-1] == self.num_channels
107 |
108 | cond = cond.to(next(self.parameters()).dtype)
109 |
110 | batch, cond_seq_len, device = *cond.shape[:2], cond.device
111 | if not exists(lens):
112 | lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
113 |
114 | # text
115 |
116 | if isinstance(text, list):
117 | if exists(self.vocab_char_map):
118 | text = list_str_to_idx(text, self.vocab_char_map).to(device)
119 | else:
120 | text = list_str_to_tensor(text).to(device)
121 | assert text.shape[0] == batch
122 |
123 | if exists(text):
124 | text_lens = (text != -1).sum(dim=-1)
125 | lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
126 |
127 | # duration
128 |
129 | cond_mask = lens_to_mask(lens)
130 | if edit_mask is not None:
131 | cond_mask = cond_mask & edit_mask
132 |
133 | if isinstance(duration, int):
134 | duration = torch.full((batch,), duration, device=device, dtype=torch.long)
135 |
136 | duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
137 | duration = duration.clamp(max=max_duration)
138 | max_duration = duration.amax()
139 |
140 | # duplicate test corner for inner time step observation
141 | if duplicate_test:
142 | test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
143 |
144 | cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
145 | cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
146 | cond_mask = cond_mask.unsqueeze(-1)
147 | step_cond = torch.where(
148 | cond_mask, cond, torch.zeros_like(cond)
149 | ) # allow direct control (cut cond audio) with lens passed in
150 |
151 | if batch > 1:
152 | mask = lens_to_mask(duration)
153 | else: # save memory and speed up, as single inference currently needs no mask
154 | mask = None
155 |
156 | # test for no ref audio
157 | if no_ref_audio:
158 | cond = torch.zeros_like(cond)
159 |
160 | # neural ode
161 |
162 | def fn(t, x):
163 | # at each step, conditioning is fixed
164 | # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
165 |
166 | # predict flow
167 | pred = self.transformer(
168 | x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
169 | )
170 | if cfg_strength < 1e-5:
171 | return pred
172 |
173 | null_pred = self.transformer(
174 | x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
175 | )
176 | return pred + (pred - null_pred) * cfg_strength
177 |
178 | # noise input
179 | # seed per sample so batched inference gives the same result across batch sizes, and likewise matches single inference
180 | # some difference may still remain, likely due to convolutional layers
181 | y0 = []
182 | for dur in duration:
183 | if exists(seed):
184 | torch.manual_seed(seed)
185 | y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
186 | y0 = pad_sequence(y0, padding_value=0, batch_first=True)
187 |
188 | t_start = 0
189 |
190 | # duplicate test corner for inner time step observation
191 | if duplicate_test:
192 | t_start = t_inter
193 | y0 = (1 - t_start) * y0 + t_start * test_cond
194 | steps = int(steps * (1 - t_start))
195 |
196 | t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
197 | if sway_sampling_coef is not None:
198 | t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
199 |
200 | trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
201 |
202 | sampled = trajectory[-1]
203 | out = sampled
204 | out = torch.where(cond_mask, cond, out)
205 |
206 | if exists(vocoder):
207 | out = out.permute(0, 2, 1)
208 | out = vocoder(out)
209 |
210 | return out, trajectory
211 |
212 | def forward(
213 | self,
214 | inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
215 | text: int["b nt"] | list[str], # noqa: F722
216 | *,
217 | lens: int["b"] | None = None, # noqa: F821
218 | noise_scheduler: str | None = None,
219 | ):
220 | # handle raw wave
221 | if inp.ndim == 2:
222 | inp = self.mel_spec(inp)
223 | inp = inp.permute(0, 2, 1)
224 | assert inp.shape[-1] == self.num_channels
225 |
226 | batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
227 |
228 | # handle text as string
229 | if isinstance(text, list):
230 | if exists(self.vocab_char_map):
231 | text = list_str_to_idx(text, self.vocab_char_map).to(device)
232 | else:
233 | text = list_str_to_tensor(text).to(device)
234 | assert text.shape[0] == batch
235 |
236 | # lens and mask
237 | if not exists(lens):
238 | lens = torch.full((batch,), seq_len, device=device)
239 |
240 | mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
241 |
242 | # get a random span to mask out for training conditionally
243 | frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
244 | rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
245 |
246 | if exists(mask):
247 | rand_span_mask &= mask
248 |
249 | # mel is x1
250 | x1 = inp
251 |
252 | # x0 is gaussian noise
253 | x0 = torch.randn_like(x1)
254 |
255 | # time step
256 | time = torch.rand((batch,), dtype=dtype, device=self.device)
257 | # TODO. noise_scheduler
258 |
259 | # sample xt (φ_t(x) in the paper)
260 | t = time.unsqueeze(-1).unsqueeze(-1)
261 | φ = (1 - t) * x0 + t * x1
262 | flow = x1 - x0
263 |
264 | # only predict what is within the random mask span for infilling
265 | cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
266 |
267 | # transformer and cfg training with a drop rate
268 | drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
269 | if random() < self.cond_drop_prob: # p_uncond in voicebox paper
270 | drop_audio_cond = True
271 | drop_text = True
272 | else:
273 | drop_text = False
274 |
275 | # if you want to rigorously mask out padding, record lengths in collate_fn in dataset.py and pass them in here
276 | # adding the mask uses more memory, so the batch sampler threshold would also need to be scaled down for long sequences
277 | pred = self.transformer(
278 | x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
279 | )
280 |
281 | # flow matching loss
282 | loss = F.mse_loss(pred, flow, reduction="none")
283 | loss = loss[rand_span_mask]
284 |
285 | return loss.mean(), cond, pred
286 |
--------------------------------------------------------------------------------
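One detail of CFM.sample worth seeing in numbers is sway sampling: with a coefficient s, the uniform ODE time grid is warped as t <- t + s * (cos(pi*t/2) - 1 + t), which keeps the endpoints 0 and 1 fixed and, for negative s, pulls intermediate points toward 0 so more integration steps are spent early in the trajectory. A tiny standalone reproduction of that line (s = -1 is an arbitrary illustrative value):

import torch

s = -1.0                     # sway_sampling_coef; negative values shift steps toward t = 0
t = torch.linspace(0, 1, 5)  # uniform schedule: 0.00, 0.25, 0.50, 0.75, 1.00
t_sway = t + s * (torch.cos(torch.pi / 2 * t) - 1 + t)
print(t_sway)                # ~[0.00, 0.08, 0.29, 0.62, 1.00]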
/src/f5_tts/model/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from importlib.resources import files
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torchaudio
8 | from datasets import Dataset as Dataset_
9 | from datasets import load_dataset as hf_load_dataset, load_from_disk  # alias: the local load_dataset() below shadows the hf name
10 | from torch import nn
11 | from torch.utils.data import Dataset, Sampler
12 | from tqdm import tqdm
13 |
14 | from f5_tts.model.modules import MelSpec
15 | from f5_tts.model.utils import default
16 |
17 |
18 | class HFDataset(Dataset):
19 | def __init__(
20 | self,
21 | hf_dataset: Dataset,
22 | target_sample_rate=24_000,
23 | n_mel_channels=100,
24 | hop_length=256,
25 | n_fft=1024,
26 | win_length=1024,
27 | mel_spec_type="vocos",
28 | ):
29 | self.data = hf_dataset
30 | self.target_sample_rate = target_sample_rate
31 | self.hop_length = hop_length
32 |
33 | self.mel_spectrogram = MelSpec(
34 | n_fft=n_fft,
35 | hop_length=hop_length,
36 | win_length=win_length,
37 | n_mel_channels=n_mel_channels,
38 | target_sample_rate=target_sample_rate,
39 | mel_spec_type=mel_spec_type,
40 | )
41 |
42 | def get_frame_len(self, index):
43 | row = self.data[index]
44 | audio = row["audio"]["array"]
45 | sample_rate = row["audio"]["sampling_rate"]
46 | return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
47 |
48 | def __len__(self):
49 | return len(self.data)
50 |
51 | def __getitem__(self, index):
52 | row = self.data[index]
53 | audio = row["audio"]["array"]
54 |
55 | # logger.info(f"Audio shape: {audio.shape}")
56 |
57 | sample_rate = row["audio"]["sampling_rate"]
58 | duration = audio.shape[-1] / sample_rate
59 |
60 | if duration > 30 or duration < 0.3:
61 | return self.__getitem__((index + 1) % len(self.data))
62 |
63 | audio_tensor = torch.from_numpy(audio).float()
64 |
65 | if sample_rate != self.target_sample_rate:
66 | resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
67 | audio_tensor = resampler(audio_tensor)
68 |
69 | audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
70 |
71 | mel_spec = self.mel_spectrogram(audio_tensor)
72 |
73 | mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
74 |
75 | text = row["text"]
76 |
77 | return dict(
78 | mel_spec=mel_spec,
79 | text=text,
80 | )
81 |
82 |
83 | class CustomDataset(Dataset):
84 | def __init__(
85 | self,
86 | custom_dataset: Dataset,
87 | durations=None,
88 | target_sample_rate=24_000,
89 | hop_length=256,
90 | n_mel_channels=100,
91 | n_fft=1024,
92 | win_length=1024,
93 | mel_spec_type="vocos",
94 | preprocessed_mel=False,
95 | mel_spec_module: nn.Module | None = None,
96 | ):
97 | self.data = custom_dataset
98 | self.durations = durations
99 | self.target_sample_rate = target_sample_rate
100 | self.hop_length = hop_length
101 | self.n_fft = n_fft
102 | self.win_length = win_length
103 | self.mel_spec_type = mel_spec_type
104 | self.preprocessed_mel = preprocessed_mel
105 |
106 | if not preprocessed_mel:
107 | self.mel_spectrogram = default(
108 | mel_spec_module,
109 | MelSpec(
110 | n_fft=n_fft,
111 | hop_length=hop_length,
112 | win_length=win_length,
113 | n_mel_channels=n_mel_channels,
114 | target_sample_rate=target_sample_rate,
115 | mel_spec_type=mel_spec_type,
116 | ),
117 | )
118 |
119 | def get_frame_len(self, index):
120 | if (
121 | self.durations is not None
122 | ): # Please make sure the separately provided durations are correct, otherwise you will almost certainly OOM
123 | return self.durations[index] * self.target_sample_rate / self.hop_length
124 | return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
125 |
126 | def __len__(self):
127 | return len(self.data)
128 |
129 | def __getitem__(self, index):
130 | row = self.data[index]
131 | audio_path = row["audio_path"]
132 | text = row["text"]
133 | duration = row["duration"]
134 |
135 | if self.preprocessed_mel:
136 | mel_spec = torch.tensor(row["mel_spec"])
137 |
138 | else:
139 | audio, source_sample_rate = torchaudio.load(audio_path)
140 | if audio.shape[0] > 1:
141 | audio = torch.mean(audio, dim=0, keepdim=True)
142 |
143 | if duration > 30 or duration < 0.3:
144 | return self.__getitem__((index + 1) % len(self.data))
145 |
146 | if source_sample_rate != self.target_sample_rate:
147 | resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
148 | audio = resampler(audio)
149 |
150 | mel_spec = self.mel_spectrogram(audio)
151 | mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
152 |
153 | return dict(
154 | mel_spec=mel_spec,
155 | text=text,
156 | )
157 |
158 |
159 | # Dynamic Batch Sampler
160 |
161 |
162 | class DynamicBatchSampler(Sampler[list[int]]):
163 | """Extension of Sampler that will do the following:
164 | 1. Change the batch size (essentially number of sequences)
165 | in a batch to ensure that the total number of frames are less
166 | than a certain threshold.
167 | 2. Make sure the padding efficiency in the batch is high.
168 | """
169 |
170 | def __init__(
171 | self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
172 | ):
173 | self.sampler = sampler
174 | self.frames_threshold = frames_threshold
175 | self.max_samples = max_samples
176 |
177 | indices, batches = [], []
178 | data_source = self.sampler.data_source
179 |
180 | for idx in tqdm(
181 | self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
182 | ):
183 | indices.append((idx, data_source.get_frame_len(idx)))
184 | indices.sort(key=lambda elem: elem[1])
185 |
186 | batch = []
187 | batch_frames = 0
188 | for idx, frame_len in tqdm(
189 | indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
190 | ):
191 | if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
192 | batch.append(idx)
193 | batch_frames += frame_len
194 | else:
195 | if len(batch) > 0:
196 | batches.append(batch)
197 | if frame_len <= self.frames_threshold:
198 | batch = [idx]
199 | batch_frames = frame_len
200 | else:
201 | batch = []
202 | batch_frames = 0
203 |
204 | if not drop_last and len(batch) > 0:
205 | batches.append(batch)
206 |
207 | del indices
208 |
209 | # if you want different batches between epochs, just set a per-epoch seed and log it in the ckpt
210 | # because during multi-gpu training, although the batches on each gpu do not change between epochs, the overall minibatch that gets formed is different
211 | # e.g. for epoch n, use (random_seed + n)
212 | random.seed(random_seed)
213 | random.shuffle(batches)
214 |
215 | self.batches = batches
216 |
217 | def __iter__(self):
218 | return iter(self.batches)
219 |
220 | def __len__(self):
221 | return len(self.batches)
222 |
223 |
224 | # Load dataset
225 |
226 |
227 | def load_dataset(
228 | dataset_name: str,
229 | tokenizer: str = "pinyin",
230 | dataset_type: str = "CustomDataset",
231 | audio_type: str = "raw",
232 | mel_spec_module: nn.Module | None = None,
233 | mel_spec_kwargs: dict = dict(),
234 | ) -> CustomDataset | HFDataset:
235 | """
236 | dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
237 | - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
238 | """
239 |
240 | print("Loading dataset ...")
241 |
242 | if dataset_type == "CustomDataset":
243 | rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
244 | if audio_type == "raw":
245 | try:
246 | train_dataset = load_from_disk(f"{rel_data_path}/raw")
247 | except: # noqa: E722
248 | train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
249 | preprocessed_mel = False
250 | elif audio_type == "mel":
251 | train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
252 | preprocessed_mel = True
253 | with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
254 | data_dict = json.load(f)
255 | durations = data_dict["duration"]
256 | train_dataset = CustomDataset(
257 | train_dataset,
258 | durations=durations,
259 | preprocessed_mel=preprocessed_mel,
260 | mel_spec_module=mel_spec_module,
261 | **mel_spec_kwargs,
262 | )
263 |
264 | elif dataset_type == "CustomDatasetPath":
265 | try:
266 | train_dataset = load_from_disk(f"{dataset_name}/raw")
267 | except: # noqa: E722
268 | train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
269 |
270 | with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
271 | data_dict = json.load(f)
272 | durations = data_dict["duration"]
273 | train_dataset = CustomDataset(
274 | train_dataset, durations=durations, preprocessed_mel=False, **mel_spec_kwargs  # raw audio is loaded in this path, so mel is computed on the fly
275 | )
276 |
277 | elif dataset_type == "HFDataset":
278 | print(
279 | "Should manually modify the path of huggingface dataset to your need.\n"
280 | + "May also the corresponding script cuz different dataset may have different format."
281 | )
282 | pre, post = dataset_name.split("_")
283 | train_dataset = HFDataset(
284 | load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
285 | )
286 |
287 | return train_dataset
288 |
289 |
290 | # collation
291 |
292 |
293 | def collate_fn(batch):
294 | mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
295 | mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
296 | max_mel_length = mel_lengths.amax()
297 |
298 | padded_mel_specs = []
299 | for spec in mel_specs: # TODO. maybe records mask for attention here
300 | padding = (0, max_mel_length - spec.size(-1))
301 | padded_spec = F.pad(spec, padding, value=0)
302 | padded_mel_specs.append(padded_spec)
303 |
304 | mel_specs = torch.stack(padded_mel_specs)
305 |
306 | text = [item["text"] for item in batch]
307 | text_lengths = torch.LongTensor([len(item) for item in text])
308 |
309 | return dict(
310 | mel=mel_specs,
311 | mel_lengths=mel_lengths,
312 | text=text,
313 | text_lengths=text_lengths,
314 | )
315 |
--------------------------------------------------------------------------------
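A hypothetical sketch of wiring this file together for frame-based batching (the same pattern the Trainer below uses when batch_size_type is "frame"); the dataset name and thresholds are placeholders for whatever you have prepared under data/:

from torch.utils.data import DataLoader, SequentialSampler
from f5_tts.model.dataset import load_dataset, DynamicBatchSampler, collate_fn

train_dataset = load_dataset("Emilia_ZH_EN", tokenizer="pinyin")  # any prepared dataset works
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(sampler, frames_threshold=38400, max_samples=64, random_seed=0, drop_last=False)
loader = DataLoader(train_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=4)
batch = next(iter(loader))
print(batch["mel"].shape, batch["text_lengths"])  # padded mel batch and per-item text lengths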
/src/f5_tts/model/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import gc
4 | import os
5 |
6 | import torch
7 | import torchaudio
8 | import wandb
9 | from accelerate import Accelerator
10 | from accelerate.utils import DistributedDataParallelKwargs
11 | from ema_pytorch import EMA
12 | from torch.optim import AdamW
13 | from torch.optim.lr_scheduler import LinearLR, SequentialLR
14 | from torch.utils.data import DataLoader, Dataset, SequentialSampler
15 | from tqdm import tqdm
16 |
17 | from f5_tts.model import CFM
18 | from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
19 | from f5_tts.model.utils import default, exists
20 |
21 | # trainer
22 |
23 |
24 | class Trainer:
25 | def __init__(
26 | self,
27 | model: CFM,
28 | epochs,
29 | learning_rate,
30 | num_warmup_updates=20000,
31 | save_per_updates=1000,
32 | checkpoint_path=None,
33 | batch_size=32,
34 | batch_size_type: str = "sample",
35 | max_samples=32,
36 | grad_accumulation_steps=1,
37 | max_grad_norm=1.0,
38 | noise_scheduler: str | None = None,
39 | duration_predictor: torch.nn.Module | None = None,
40 | logger: str | None = "wandb", # "wandb" | "tensorboard" | None
41 | wandb_project="test_e2-tts",
42 | wandb_run_name="test_run",
43 | wandb_resume_id: str = None,
44 | log_samples: bool = False,
45 | last_per_steps=None,
46 | accelerate_kwargs: dict = dict(),
47 | ema_kwargs: dict = dict(),
48 | bnb_optimizer: bool = False,
49 | mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
50 | ):
51 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
52 |
53 | if logger == "wandb" and not wandb.api.api_key:
54 | logger = None
55 | print(f"Using logger: {logger}")
56 | self.log_samples = log_samples
57 |
58 | self.accelerator = Accelerator(
59 | log_with=logger if logger == "wandb" else None,
60 | kwargs_handlers=[ddp_kwargs],
61 | gradient_accumulation_steps=grad_accumulation_steps,
62 | **accelerate_kwargs,
63 | )
64 |
65 | self.logger = logger
66 | if self.logger == "wandb":
67 | if exists(wandb_resume_id):
68 | init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
69 | else:
70 | init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
71 |
72 | self.accelerator.init_trackers(
73 | project_name=wandb_project,
74 | init_kwargs=init_kwargs,
75 | config={
76 | "epochs": epochs,
77 | "learning_rate": learning_rate,
78 | "num_warmup_updates": num_warmup_updates,
79 | "batch_size": batch_size,
80 | "batch_size_type": batch_size_type,
81 | "max_samples": max_samples,
82 | "grad_accumulation_steps": grad_accumulation_steps,
83 | "max_grad_norm": max_grad_norm,
84 | "gpus": self.accelerator.num_processes,
85 | "noise_scheduler": noise_scheduler,
86 | },
87 | )
88 |
89 | elif self.logger == "tensorboard":
90 | from torch.utils.tensorboard import SummaryWriter
91 |
92 | self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
93 |
94 | self.model = model
95 |
96 | if self.is_main:
97 | self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
98 | self.ema_model.to(self.accelerator.device)
99 |
100 | self.epochs = epochs
101 | self.num_warmup_updates = num_warmup_updates
102 | self.save_per_updates = save_per_updates
103 | self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
104 | self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
105 |
106 | self.batch_size = batch_size
107 | self.batch_size_type = batch_size_type
108 | self.max_samples = max_samples
109 | self.grad_accumulation_steps = grad_accumulation_steps
110 | self.max_grad_norm = max_grad_norm
111 | self.vocoder_name = mel_spec_type
112 |
113 | self.noise_scheduler = noise_scheduler
114 |
115 | self.duration_predictor = duration_predictor
116 |
117 | if bnb_optimizer:
118 | import bitsandbytes as bnb
119 |
120 | self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
121 | else:
122 | self.optimizer = AdamW(model.parameters(), lr=learning_rate)
123 | self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
124 |
125 | @property
126 | def is_main(self):
127 | return self.accelerator.is_main_process
128 |
129 | def save_checkpoint(self, step, last=False):
130 | self.accelerator.wait_for_everyone()
131 | if self.is_main:
132 | checkpoint = dict(
133 | model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
134 | optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
135 | ema_model_state_dict=self.ema_model.state_dict(),
136 | scheduler_state_dict=self.scheduler.state_dict(),
137 | step=step,
138 | )
139 | if not os.path.exists(self.checkpoint_path):
140 | os.makedirs(self.checkpoint_path)
141 | if last:
142 | self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
143 | print(f"Saved last checkpoint at step {step}")
144 | else:
145 | self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
146 |
147 | def load_checkpoint(self):
148 | if (
149 | not exists(self.checkpoint_path)
150 | or not os.path.exists(self.checkpoint_path)
151 | or not os.listdir(self.checkpoint_path)
152 | ):
153 | return 0
154 |
155 | self.accelerator.wait_for_everyone()
156 | if "model_last.pt" in os.listdir(self.checkpoint_path):
157 | latest_checkpoint = "model_last.pt"
158 | else:
159 | latest_checkpoint = sorted(
160 | [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
161 | key=lambda x: int("".join(filter(str.isdigit, x))),
162 | )[-1]
163 | # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
164 | checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
165 |
166 | # patch for backward compatibility, 305e3ea
167 | for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
168 | if key in checkpoint["ema_model_state_dict"]:
169 | del checkpoint["ema_model_state_dict"][key]
170 |
171 | if self.is_main:
172 | self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
173 |
174 | if "step" in checkpoint:
175 | # patch for backward compatibility, 305e3ea
176 | for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
177 | if key in checkpoint["model_state_dict"]:
178 | del checkpoint["model_state_dict"][key]
179 |
180 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
181 | self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
182 | if self.scheduler:
183 | self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
184 | step = checkpoint["step"]
185 | else:
186 | checkpoint["model_state_dict"] = {
187 | k.replace("ema_model.", ""): v
188 | for k, v in checkpoint["ema_model_state_dict"].items()
189 | if k not in ["initted", "step"]
190 | }
191 | self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
192 | step = 0
193 |
194 | del checkpoint
195 | gc.collect()
196 | return step
197 |
198 | def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
199 | if self.log_samples:
200 | from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
201 |
202 | vocoder = load_vocoder(vocoder_name=self.vocoder_name)
203 | target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
204 | log_samples_path = f"{self.checkpoint_path}/samples"
205 | os.makedirs(log_samples_path, exist_ok=True)
206 |
207 | if exists(resumable_with_seed):
208 | generator = torch.Generator()
209 | generator.manual_seed(resumable_with_seed)
210 | else:
211 | generator = None
212 |
213 | if self.batch_size_type == "sample":
214 | train_dataloader = DataLoader(
215 | train_dataset,
216 | collate_fn=collate_fn,
217 | num_workers=num_workers,
218 | pin_memory=True,
219 | persistent_workers=True,
220 | batch_size=self.batch_size,
221 | shuffle=True,
222 | generator=generator,
223 | )
224 | elif self.batch_size_type == "frame":
225 | self.accelerator.even_batches = False
226 | sampler = SequentialSampler(train_dataset)
227 | batch_sampler = DynamicBatchSampler(
228 | sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
229 | )
230 | train_dataloader = DataLoader(
231 | train_dataset,
232 | collate_fn=collate_fn,
233 | num_workers=num_workers,
234 | pin_memory=True,
235 | persistent_workers=True,
236 | batch_sampler=batch_sampler,
237 | )
238 | else:
239 | raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
240 |
241 | # accelerator.prepare() dispatches batches to devices;
242 | # so the dataloader length computed beforehand should account for the number of devices
243 | warmup_steps = (
244 | self.num_warmup_updates * self.accelerator.num_processes
245 | ) # keep a fixed number of warmup updates while using accelerate multi-gpu ddp
246 | # otherwise, by default with split_batches=False, warmup steps would change with num_processes
247 | total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
248 | decay_steps = total_steps - warmup_steps
249 | warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
250 | decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
251 | self.scheduler = SequentialLR(
252 | self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
253 | )
254 | train_dataloader, self.scheduler = self.accelerator.prepare(
255 | train_dataloader, self.scheduler
256 | ) # actual steps per process = single-gpu steps / number of gpus
257 | start_step = self.load_checkpoint()
258 | global_step = start_step
259 |
260 | if exists(resumable_with_seed):
261 | orig_epoch_step = len(train_dataloader)
262 | skipped_epoch = int(start_step // orig_epoch_step)
263 | skipped_batch = start_step % orig_epoch_step
264 | skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
265 | else:
266 | skipped_epoch = 0
267 |
268 | for epoch in range(skipped_epoch, self.epochs):
269 | self.model.train()
270 | if exists(resumable_with_seed) and epoch == skipped_epoch:
271 | progress_bar = tqdm(
272 | skipped_dataloader,
273 | desc=f"Epoch {epoch+1}/{self.epochs}",
274 | unit="step",
275 | disable=not self.accelerator.is_local_main_process,
276 | initial=skipped_batch,
277 | total=orig_epoch_step,
278 | )
279 | else:
280 | progress_bar = tqdm(
281 | train_dataloader,
282 | desc=f"Epoch {epoch+1}/{self.epochs}",
283 | unit="step",
284 | disable=not self.accelerator.is_local_main_process,
285 | )
286 |
287 | for batch in progress_bar:
288 | with self.accelerator.accumulate(self.model):
289 | text_inputs = batch["text"]
290 | mel_spec = batch["mel"].permute(0, 2, 1)
291 | mel_lengths = batch["mel_lengths"]
292 |
293 | # TODO. add duration predictor training
294 | if self.duration_predictor is not None and self.accelerator.is_local_main_process:
295 | dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
296 | self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
297 |
298 | loss, cond, pred = self.model(
299 | mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
300 | )
301 | self.accelerator.backward(loss)
302 |
303 | if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
304 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
305 |
306 | self.optimizer.step()
307 | self.scheduler.step()
308 | self.optimizer.zero_grad()
309 |
310 | if self.is_main:
311 | self.ema_model.update()
312 |
313 | global_step += 1
314 |
315 | if self.accelerator.is_local_main_process:
316 | self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
317 | if self.logger == "tensorboard":
318 | self.writer.add_scalar("loss", loss.item(), global_step)
319 | self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
320 |
321 | progress_bar.set_postfix(step=str(global_step), loss=loss.item())
322 |
323 | if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
324 | self.save_checkpoint(global_step)
325 |
326 | if self.log_samples and self.accelerator.is_local_main_process:
327 | ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
328 | torchaudio.save(
329 | f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
330 | )
331 | with torch.inference_mode():
332 | generated, _ = self.accelerator.unwrap_model(self.model).sample(
333 | cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
334 | text=[text_inputs[0] + [" "] + text_inputs[0]],
335 | duration=ref_audio_len * 2,
336 | steps=nfe_step,
337 | cfg_strength=cfg_strength,
338 | sway_sampling_coef=sway_sampling_coef,
339 | )
340 | generated = generated.to(torch.float32)
341 | gen_audio = vocoder.decode(
342 | generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
343 | )
344 | torchaudio.save(
345 | f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
346 | )
347 |
348 | if global_step % self.last_per_steps == 0:
349 | self.save_checkpoint(global_step, last=True)
350 |
351 | self.save_checkpoint(global_step, last=True)
352 |
353 | self.accelerator.end_training()
354 |
--------------------------------------------------------------------------------
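A condensed, hypothetical sketch of driving the Trainer end to end; the real entry point with the project's actual hyperparameters is src/f5_tts/train/train.py, so the values below are placeholders only:

from f5_tts.model import CFM, DiT
from f5_tts.model.trainer import Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer

vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100)
model = CFM(transformer=transformer, vocab_char_map=vocab_char_map)
trainer = Trainer(model, epochs=11, learning_rate=7.5e-5, batch_size=38400, batch_size_type="frame", max_samples=64, grad_accumulation_steps=1, logger=None)
train_dataset = load_dataset("Emilia_ZH_EN", tokenizer="pinyin")
trainer.train(train_dataset, resumable_with_seed=666)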
/src/f5_tts/model/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import random
5 | from collections import defaultdict
6 | from importlib.resources import files
7 |
8 | import torch
9 | from torch.nn.utils.rnn import pad_sequence
10 |
11 | import jieba
12 | from pypinyin import lazy_pinyin, Style
13 |
14 |
15 | # seed everything
16 |
17 |
18 | def seed_everything(seed=0):
19 | random.seed(seed)
20 | os.environ["PYTHONHASHSEED"] = str(seed)
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed(seed)
23 | torch.cuda.manual_seed_all(seed)
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 |
27 |
28 | # helpers
29 |
30 |
31 | def exists(v):
32 | return v is not None
33 |
34 |
35 | def default(v, d):
36 | return v if exists(v) else d
37 |
38 |
39 | # tensor helpers
40 |
41 |
42 | def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
43 | if not exists(length):
44 | length = t.amax()
45 |
46 | seq = torch.arange(length, device=t.device)
47 | return seq[None, :] < t[:, None]
48 |
49 |
50 | def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
51 | max_seq_len = seq_len.max().item()
52 | seq = torch.arange(max_seq_len, device=start.device).long()
53 | start_mask = seq[None, :] >= start[:, None]
54 | end_mask = seq[None, :] < end[:, None]
55 | return start_mask & end_mask
56 |
57 |
58 | def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
59 | lengths = (frac_lengths * seq_len).long()
60 | max_start = seq_len - lengths
61 |
62 | rand = torch.rand_like(frac_lengths)
63 | start = (max_start * rand).long().clamp(min=0)
64 | end = start + lengths
65 |
66 | return mask_from_start_end_indices(seq_len, start, end)
67 |
68 |
69 | def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
70 | if not exists(mask):
71 | return t.mean(dim=1)
72 |
73 | t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
74 | num = t.sum(dim=1)
75 | den = mask.float().sum(dim=1)
76 |
77 | return num / den.clamp(min=1.0)
78 |
79 |
80 | # simple utf-8 tokenizer, since paper went character based
81 | def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
82 | list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
83 | text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
84 | return text
85 |
86 |
87 | # char tokenizer, based on custom dataset's extracted .txt file
88 | def list_str_to_idx(
89 | text: list[str] | list[list[str]],
90 | vocab_char_map: dict[str, int], # {char: idx}
91 | padding_value=-1,
92 | ) -> int["b nt"]: # noqa: F722
93 | list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
94 | text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
95 | return text
96 |
97 |
98 | # Get tokenizer
99 |
100 |
101 | def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
102 | """
103 | tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
104 | - "char" for char-wise tokenizer, need .txt vocab_file
105 | - "byte" for utf-8 tokenizer
106 | - "custom" if you're directly passing in a path to the vocab.txt you want to use
107 | vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
108 | - if use "char", derived from unfiltered character & symbol counts of custom dataset
109 | - if use "byte", set to 256 (unicode byte range)
110 | """
111 | if tokenizer in ["pinyin", "char"]:
112 | tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
113 | with open(tokenizer_path, "r", encoding="utf-8") as f:
114 | vocab_char_map = {}
115 | for i, char in enumerate(f):
116 | vocab_char_map[char[:-1]] = i
117 | vocab_size = len(vocab_char_map)
118 | assert vocab_char_map[" "] == 0, "make sure space is at index 0 in vocab.txt, since 0 is used for the unknown char"
119 |
120 | elif tokenizer == "byte":
121 | vocab_char_map = None
122 | vocab_size = 256
123 |
124 | elif tokenizer == "custom":
125 | with open(dataset_name, "r", encoding="utf-8") as f:
126 | vocab_char_map = {}
127 | for i, char in enumerate(f):
128 | vocab_char_map[char[:-1]] = i
129 | vocab_size = len(vocab_char_map)
130 |
131 | return vocab_char_map, vocab_size
132 |
133 |
134 | # convert char to pinyin
135 |
136 |
137 | def convert_char_to_pinyin(text_list, polyphone=True):
138 | final_text_list = []
139 | god_knows_why_en_testset_contains_zh_quote = str.maketrans(
140 | {"“": '"', "”": '"', "‘": "'", "’": "'"}
141 | ) # in case librispeech (orig no-pc) test-clean
142 | custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
143 | for text in text_list:
144 | char_list = []
145 | text = text.translate(god_knows_why_en_testset_contains_zh_quote)
146 | text = text.translate(custom_trans)
147 | for seg in jieba.cut(text):
148 | seg_byte_len = len(bytes(seg, "UTF-8"))
149 | if seg_byte_len == len(seg): # if pure alphabets and symbols
150 | if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
151 | char_list.append(" ")
152 | char_list.extend(seg)
153 | elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
154 | seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
155 | for c in seg:
156 | if c not in "。,、;:?!《》【】—…":
157 | char_list.append(" ")
158 | char_list.append(c)
159 | else: # if mixed chinese characters, alphabets and symbols
160 | for c in seg:
161 | if ord(c) < 256:
162 | char_list.extend(c)
163 | else:
164 | if c not in "。,、;:?!《》【】—…":
165 | char_list.append(" ")
166 | char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
167 | else: # if is zh punc
168 | char_list.append(c)
169 | final_text_list.append(char_list)
170 |
171 | return final_text_list
172 |
173 |
174 | # filter func for dirty data with many repetitions
175 |
176 |
177 | def repetition_found(text, length=2, tolerance=10):
178 | pattern_count = defaultdict(int)
179 | for i in range(len(text) - length + 1):
180 | pattern = text[i : i + length]
181 | pattern_count[pattern] += 1
182 | for pattern, count in pattern_count.items():
183 | if count > tolerance:
184 | return True
185 | return False
186 |
--------------------------------------------------------------------------------
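A few of the helpers above are easy to sanity-check in isolation; an illustrative session (values chosen by hand, not taken from the repo's tests):

import torch
from f5_tts.model.utils import lens_to_mask, list_str_to_idx, repetition_found

print(lens_to_mask(torch.tensor([2, 4])))   # [[True, True, False, False], [True, True, True, True]]

vocab = {" ": 0, "a": 1, "b": 2}            # toy vocab; real ones come from get_tokenizer()
print(list_str_to_idx(["ab", "a"], vocab))  # [[1, 2], [1, -1]] -- unknown chars map to 0, padding is -1

print(repetition_found("ab" * 12, length=2, tolerance=10))  # True: "ab" occurs 12 times, above the tolerance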
/src/f5_tts/scripts/count_max_epoch.py:
--------------------------------------------------------------------------------
1 | """ADAPTIVE BATCH SIZE"""
2 |
3 | print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
4 | print(" -> least padding, gather wavs with accumulated frames in a batch\n")
5 |
6 | # data
7 | total_hours = 95282
8 | mel_hop_length = 256
9 | mel_sampling_rate = 24000
10 |
11 | # target
12 | wanted_max_updates = 1000000
13 |
14 | # train params
15 | gpus = 8
16 | frames_per_gpu = 38400 # 8 * 38400 = 307200
17 | grad_accum = 1
18 |
19 | # intermediate
20 | mini_batch_frames = frames_per_gpu * grad_accum * gpus
21 | mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
22 | updates_per_epoch = total_hours / mini_batch_hours
23 | steps_per_epoch = updates_per_epoch * grad_accum
24 |
25 | # result
26 | epochs = wanted_max_updates / updates_per_epoch
27 | print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
28 | print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
29 | print(f" or approx. 0/{steps_per_epoch:.0f} steps")
30 |
31 | # others
32 | print(f"total {total_hours:.0f} hours")
33 | print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
34 |
--------------------------------------------------------------------------------
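With the hard-coded values above, the arithmetic works out as follows: one mini-batch is 8 gpus x 38,400 frames = 307,200 frames, i.e. 307,200 x 256 / 24,000 / 3,600 ≈ 0.91 hours of audio, so an epoch over 95,282 hours is roughly 104,700 updates, and reaching 1,000,000 updates takes about 9.6 epochs, which the script rounds and prints as 10.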
/src/f5_tts/scripts/count_params_gflops.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | from f5_tts.model import CFM, DiT
7 |
8 | import torch
9 | import thop
10 |
11 |
12 | """ ~155M """
13 | # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
14 | # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
15 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
16 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
17 | # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
18 | # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
19 |
20 | """ ~335M """
21 | # FLOPs: 622.1 G, Params: 333.2 M
22 | # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
23 | # FLOPs: 363.4 G, Params: 335.8 M
24 | transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
25 |
26 |
27 | model = CFM(transformer=transformer)
28 | target_sample_rate = 24000
29 | n_mel_channels = 100
30 | hop_length = 256
31 | duration = 20
32 | frame_length = int(duration * target_sample_rate / hop_length)
33 | text_length = 150
34 |
35 | flops, params = thop.profile(
36 | model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
37 | )
38 | print(f"FLOPs: {flops / 1e9} G")
39 | print(f"Params: {params / 1e6} M")
40 |
--------------------------------------------------------------------------------
/src/f5_tts/socket_server.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import struct
3 | import torch
4 | import torchaudio
5 | from threading import Thread
6 |
7 |
8 | import gc
9 | import traceback
10 |
11 |
12 | from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
13 | from model.backbones.dit import DiT
14 |
15 |
16 | class TTSStreamingProcessor:
17 | def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
18 | self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 | # Load the model using the provided checkpoint and vocab files
21 | self.model = load_model(
22 | model_cls=DiT,
23 | model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
24 | ckpt_path=ckpt_file,
25 | mel_spec_type="vocos", # or "bigvgan" depending on vocoder
26 | vocab_file=vocab_file,
27 | ode_method="euler",
28 | use_ema=True,
29 | device=self.device,
30 | ).to(self.device, dtype=dtype)
31 |
32 | # Load the vocoder
33 | self.vocoder = load_vocoder(is_local=False)
34 |
35 | # Set sampling rate for streaming
36 | self.sampling_rate = 24000 # Consistency with client
37 |
38 | # Set reference audio and text
39 | self.ref_audio = ref_audio
40 | self.ref_text = ref_text
41 |
42 | # Warm up the model
43 | self._warm_up()
44 |
45 | def _warm_up(self):
46 | """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
47 | print("Warming up the model...")
48 | ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
49 | audio, sr = torchaudio.load(ref_audio)
50 | gen_text = "Warm-up text for the model."
51 |
52 | # Pass the vocoder as an argument here
53 | infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
54 | print("Warm-up completed.")
55 |
56 | def generate_stream(self, text, play_steps_in_s=0.5):
57 | """Generate audio in chunks and yield them in real-time."""
58 | # Preprocess the reference audio and text
59 | ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
60 |
61 | # Load reference audio
62 | audio, sr = torchaudio.load(ref_audio)
63 |
64 | # Run inference for the input text
65 | audio_chunk, final_sample_rate, _ = infer_batch_process(
66 | (audio, sr),
67 | ref_text,
68 | [text],
69 | self.model,
70 | self.vocoder,
71 |             device=self.device,
72 | )
73 |
74 | # Break the generated audio into chunks and send them
75 | chunk_size = int(final_sample_rate * play_steps_in_s)
76 |
77 | for i in range(0, len(audio_chunk), chunk_size):
78 | chunk = audio_chunk[i : i + chunk_size]
79 |
80 | # Check if it's the final chunk
81 | if i + chunk_size >= len(audio_chunk):
82 | chunk = audio_chunk[i:]
83 |
84 | # Avoid sending empty or repeated chunks
85 | if len(chunk) == 0:
86 | break
87 |
88 | # Pack and send the audio chunk
89 | packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
90 | yield packed_audio
91 |
92 |         # Note: the loop above already yields the final partial chunk
93 |         # (the last slice simply comes up short), so the tail must not be
94 |         # re-sent here, otherwise the client would hear the last fraction
95 |         # of a second of audio twice.
96 |
97 |
98 |
99 | def handle_client(client_socket, processor):
100 | try:
101 | while True:
102 | # Receive data from the client
103 | data = client_socket.recv(1024).decode("utf-8")
104 | if not data:
105 | break
106 |
107 | try:
108 | # The client sends the text input
109 | text = data.strip()
110 |
111 | # Generate and stream audio chunks
112 | for audio_chunk in processor.generate_stream(text):
113 | client_socket.sendall(audio_chunk)
114 |
115 | # Send end-of-audio signal
116 | client_socket.sendall(b"END_OF_AUDIO")
117 |
118 | except Exception as inner_e:
119 | print(f"Error during processing: {inner_e}")
120 | traceback.print_exc() # Print the full traceback to diagnose the issue
121 | break
122 |
123 | except Exception as e:
124 | print(f"Error handling client: {e}")
125 | traceback.print_exc()
126 | finally:
127 | client_socket.close()
128 |
129 |
130 | def start_server(host, port, processor):
131 | server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
132 | server.bind((host, port))
133 | server.listen(5)
134 | print(f"Server listening on {host}:{port}")
135 |
136 | while True:
137 | client_socket, addr = server.accept()
138 | print(f"Accepted connection from {addr}")
139 | client_handler = Thread(target=handle_client, args=(client_socket, processor))
140 | client_handler.start()
141 |
142 |
143 | if __name__ == "__main__":
144 | try:
145 | # Load the model and vocoder using the provided files
146 |         ckpt_file = ""  # path to your checkpoint, e.g. "ckpts/model/model_1096.pt"
147 |         vocab_file = ""  # path to the vocab file, if needed
148 |         ref_audio = ""  # path to the reference audio, e.g. "./tests/ref_audio/reference.wav"
149 | ref_text = ""
150 |
151 | # Initialize the processor with the model and vocoder
152 | processor = TTSStreamingProcessor(
153 | ckpt_file=ckpt_file,
154 | vocab_file=vocab_file,
155 | ref_audio=ref_audio,
156 | ref_text=ref_text,
157 | dtype=torch.float32,
158 | )
159 |
160 | # Start the server
161 | start_server("0.0.0.0", 9998, processor)
162 | except KeyboardInterrupt:
163 | gc.collect()
164 |
--------------------------------------------------------------------------------
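For reference, a minimal client sketch for the streaming server above (illustrative, not part of the repo): it assumes the defaults from `__main__` (port 9998, 24 kHz output), sends one line of text, accumulates the raw float32 PCM chunks until the `END_OF_AUDIO` sentinel arrives, and writes the result to a WAV file with `torchaudio`:

```python
import socket

import numpy as np
import torch
import torchaudio

HOST, PORT = "127.0.0.1", 9998  # matches start_server("0.0.0.0", 9998, processor)

with socket.create_connection((HOST, PORT)) as sock:
    sock.sendall("Hello from the streaming client.".encode("utf-8"))

    buffer = b""
    while True:
        packet = sock.recv(4096)
        if not packet:
            break
        buffer += packet
        # the server terminates each utterance with this sentinel
        if buffer.endswith(b"END_OF_AUDIO"):
            buffer = buffer[: -len(b"END_OF_AUDIO")]
            break

# each chunk was packed as raw native-endian float32 samples at 24 kHz
samples = np.frombuffer(buffer, dtype=np.float32).copy()
torchaudio.save("client_output.wav", torch.from_numpy(samples).unsqueeze(0), 24000)
```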
/src/f5_tts/train/README.md:
--------------------------------------------------------------------------------
1 | # Training
2 |
3 | ## Prepare Dataset
4 |
5 | Example data processing scripts are provided for Emilia and WenetSpeech4TTS; you may also tailor your own, along with a matching Dataset class in `src/f5_tts/model/dataset.py`.
6 |
7 | ### 1. Datasets used for pretrained models
8 | Download the corresponding dataset first, then fill in the path in the scripts (the exact variables to edit are sketched below the commands).
9 |
10 | ```bash
11 | # Prepare the Emilia dataset
12 | python src/f5_tts/train/datasets/prepare_emilia.py
13 |
14 | # Prepare the Wenetspeech4TTS dataset
15 | python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
16 | ```
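
The dataset locations are plain module-level variables near the bottom of each script; the values below are the defaults shipped in this repo, so edit them to point at your downloads:

```python
# src/f5_tts/train/datasets/prepare_emilia.py
dataset_dir = "/Emilia_Dataset/raw"

# src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
dataset_choice = 1  # 1: Premium, 2: Standard, 3: Basic
dataset_paths = [
    "/WenetSpeech4TTS/Basic",
    "/WenetSpeech4TTS/Standard",
    "/WenetSpeech4TTS/Premium",
][-dataset_choice:]
```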
17 |
18 | ### 2. Create custom dataset with metadata.csv
19 | For guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029); an example of the expected layout follows the command below.
20 |
21 | ```bash
22 | python src/f5_tts/train/datasets/prepare_csv_wavs.py
23 | ```
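
The script expects a `wavs/` folder plus a pipe-delimited `metadata.csv` with a header row; audio paths in the CSV are resolved relative to the CSV's directory. A minimal illustrative layout (file names and header are placeholders):

```
my_dataset/
├── metadata.csv
└── wavs/
    ├── 0001.wav
    └── 0002.wav

# metadata.csv
audio_file|text
wavs/0001.wav|First transcript here.
wavs/0002.wav|Second transcript here.
```

Pass the input and output directories explicitly, as in the script's own usage comment:

```bash
python src/f5_tts/train/datasets/prepare_csv_wavs.py /path/to/my_dataset /path/to/output_dir_pinyin
# add --pretrain to build a fresh vocab.txt instead of copying the pretrained one
```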
24 |
25 | ## Training & Finetuning
26 |
27 | Once your datasets are prepared, you can start the training process.
28 |
29 | ### 1. Training script used for pretrained model
30 |
31 | ```bash
32 | # setup accelerate config, e.g. use multi-gpu ddp, fp16
33 | # the config will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
34 | accelerate config
35 | accelerate launch src/f5_tts/train/train.py
36 | ```
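
If you prefer not to rely on the saved config, the usual `accelerate launch` flags can be passed directly on the command line (standard Accelerate options, not repo-specific; adjust GPU count and precision to your setup):

```bash
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 src/f5_tts/train/train.py
```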
37 |
38 | ### 2. Finetuning practice
39 | See the finetuning discussion board: [#57](https://github.com/SWivid/F5-TTS/discussions/57).
40 |
41 | For Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143). A CLI example with `finetune_cli.py` is sketched below.
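
As a command-line sketch, finetuning from the released base model on a prepared dataset could look like the following (flags as defined in `src/f5_tts/train/finetune_cli.py`; `my_dataset` is a placeholder for the name you used when preparing the data, and the batch size / schedule should be tuned to your data and GPU memory):

```bash
accelerate launch src/f5_tts/train/finetune_cli.py \
    --exp_name F5TTS_Base \
    --dataset_name my_dataset \
    --learning_rate 1e-5 \
    --batch_size_per_gpu 3200 \
    --epochs 100 \
    --num_warmup_updates 300 \
    --save_per_updates 10000 \
    --last_per_steps 50000
```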
42 |
43 | ### 3. Wandb Logging
44 |
45 | The `wandb/` directory will be created under the path from which you run the training/finetuning scripts.
46 |
47 | By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
48 |
49 | To turn on wandb logging, you can either:
50 |
51 | 1. Manually log in with `wandb login`: learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
52 | 2. Log in programmatically by setting an environment variable: get an API key at https://wandb.ai/site/ and set it as follows:
53 |
54 | On Mac & Linux:
55 |
56 | ```
57 | export WANDB_API_KEY=
58 | ```
59 |
60 | On Windows:
61 |
62 | ```
63 | set WANDB_API_KEY=
64 | ```
65 | Moreover, if you cannot access Wandb and want to log metrics offline, set the environment variable as follows:
66 |
67 | ```
68 | export WANDB_MODE=offline
69 | ```
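
The Windows equivalent, mirroring the API-key example above (standard `cmd` syntax):

```
set WANDB_MODE=offline
```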
70 |
--------------------------------------------------------------------------------
/src/f5_tts/train/datasets/prepare_csv_wavs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | import argparse
7 | import csv
8 | import json
9 | import shutil
10 | from importlib.resources import files
11 | from pathlib import Path
12 |
13 | import torchaudio
14 | from tqdm import tqdm
15 | from datasets.arrow_writer import ArrowWriter
16 |
17 | from f5_tts.model.utils import (
18 | convert_char_to_pinyin,
19 | )
20 |
21 |
22 | PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
23 |
24 |
25 | def is_csv_wavs_format(input_dataset_dir):
26 | fpath = Path(input_dataset_dir)
27 | metadata = fpath / "metadata.csv"
28 | wavs = fpath / "wavs"
29 | return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
30 |
31 |
32 | def prepare_csv_wavs_dir(input_dir):
33 | assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
34 | input_dir = Path(input_dir)
35 | metadata_path = input_dir / "metadata.csv"
36 | audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
37 |
38 | sub_result, durations = [], []
39 | vocab_set = set()
40 | polyphone = True
41 | for audio_path, text in audio_path_text_pairs:
42 | if not Path(audio_path).exists():
43 | print(f"audio {audio_path} not found, skipping")
44 | continue
45 | audio_duration = get_audio_duration(audio_path)
46 | # assume tokenizer = "pinyin" ("pinyin" | "char")
47 | text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
48 | sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
49 | durations.append(audio_duration)
50 | vocab_set.update(list(text))
51 |
52 | return sub_result, durations, vocab_set
53 |
54 |
55 | def get_audio_duration(audio_path):
56 | audio, sample_rate = torchaudio.load(audio_path)
57 | return audio.shape[1] / sample_rate
58 |
59 |
60 | def read_audio_text_pairs(csv_file_path):
61 | audio_text_pairs = []
62 |
63 | parent = Path(csv_file_path).parent
64 | with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
65 | reader = csv.reader(csvfile, delimiter="|")
66 | next(reader) # Skip the header row
67 | for row in reader:
68 | if len(row) >= 2:
69 | audio_file = row[0].strip() # First column: audio file path
70 | text = row[1].strip() # Second column: text
71 | audio_file_path = parent / audio_file
72 | audio_text_pairs.append((audio_file_path.as_posix(), text))
73 |
74 | return audio_text_pairs
75 |
76 |
77 | def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
78 | out_dir = Path(out_dir)
79 | # save preprocessed dataset to disk
80 | out_dir.mkdir(exist_ok=True, parents=True)
81 | print(f"\nSaving to {out_dir} ...")
82 |
83 | # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
84 | # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB")
85 | raw_arrow_path = out_dir / "raw.arrow"
86 | with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
87 | for line in tqdm(result, desc="Writing to raw.arrow ..."):
88 | writer.write(line)
89 |
90 |     # also save durations to a separate json, for easy use with DynamicBatchSampler
91 | dur_json_path = out_dir / "duration.json"
92 | with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
93 | json.dump({"duration": duration_list}, f, ensure_ascii=False)
94 |
95 |     # vocab map, i.e. tokenizer
96 |     # add alphabets and symbols (optional, if you plan to finetune on de/fr etc.):
97 |     # if tokenizer == "pinyin":
98 |     #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
99 |     voca_out_path = out_dir / "vocab.txt"
100 |     # for finetuning, copy the pretrained vocab so token ids stay aligned with the
101 |     # pretrained model; otherwise build vocab.txt from the characters collected above
102 |
103 |
104 | if is_finetune:
105 | file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
106 | shutil.copy2(file_vocab_finetune, voca_out_path)
107 | else:
108 | with open(voca_out_path, "w") as f:
109 | for vocab in sorted(text_vocab_set):
110 | f.write(vocab + "\n")
111 |
112 | dataset_name = out_dir.stem
113 | print(f"\nFor {dataset_name}, sample count: {len(result)}")
114 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
115 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
116 |
117 |
118 | def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
119 | if is_finetune:
120 | assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
121 | sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
122 | save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
123 |
124 |
125 | def cli():
126 |     # finetune: python src/f5_tts/train/datasets/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
127 |     # pretrain: python src/f5_tts/train/datasets/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin --pretrain
128 | parser = argparse.ArgumentParser(description="Prepare and save dataset.")
129 | parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
130 | parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
131 | parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
132 |
133 | args = parser.parse_args()
134 |
135 | prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
136 |
137 |
138 | if __name__ == "__main__":
139 | cli()
140 |
--------------------------------------------------------------------------------
/src/f5_tts/train/datasets/prepare_emilia.py:
--------------------------------------------------------------------------------
1 | # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
2 | # if using the updated WebDataset version, feel free to modify this script or draft your own
3 |
4 | # generate audio text map for Emilia ZH & EN
5 | # evaluate for vocab size
6 |
7 | import os
8 | import sys
9 |
10 | sys.path.append(os.getcwd())
11 |
12 | import json
13 | from concurrent.futures import ProcessPoolExecutor
14 | from importlib.resources import files
15 | from pathlib import Path
16 | from tqdm import tqdm
17 |
18 | from datasets.arrow_writer import ArrowWriter
19 |
20 | from f5_tts.model.utils import (
21 | repetition_found,
22 | convert_char_to_pinyin,
23 | )
24 |
25 |
26 | out_zh = {
27 | "ZH_B00041_S06226",
28 | "ZH_B00042_S09204",
29 | "ZH_B00065_S09430",
30 | "ZH_B00065_S09431",
31 | "ZH_B00066_S09327",
32 | "ZH_B00066_S09328",
33 | }
34 | zh_filters = ["い", "て"]
35 | # seems synthesized audios, or heavily code-switched
36 | out_en = {
37 | "EN_B00013_S00913",
38 | "EN_B00042_S00120",
39 | "EN_B00055_S04111",
40 | "EN_B00061_S00693",
41 | "EN_B00061_S01494",
42 | "EN_B00061_S03375",
43 | "EN_B00059_S00092",
44 | "EN_B00111_S04300",
45 | "EN_B00100_S03759",
46 | "EN_B00087_S03811",
47 | "EN_B00059_S00950",
48 | "EN_B00089_S00946",
49 | "EN_B00078_S05127",
50 | "EN_B00070_S04089",
51 | "EN_B00074_S09659",
52 | "EN_B00061_S06983",
53 | "EN_B00061_S07060",
54 | "EN_B00059_S08397",
55 | "EN_B00082_S06192",
56 | "EN_B00091_S01238",
57 | "EN_B00089_S07349",
58 | "EN_B00070_S04343",
59 | "EN_B00061_S02400",
60 | "EN_B00076_S01262",
61 | "EN_B00068_S06467",
62 | "EN_B00076_S02943",
63 | "EN_B00064_S05954",
64 | "EN_B00061_S05386",
65 | "EN_B00066_S06544",
66 | "EN_B00076_S06944",
67 | "EN_B00072_S08620",
68 | "EN_B00076_S07135",
69 | "EN_B00076_S09127",
70 | "EN_B00065_S00497",
71 | "EN_B00059_S06227",
72 | "EN_B00063_S02859",
73 | "EN_B00075_S01547",
74 | "EN_B00061_S08286",
75 | "EN_B00079_S02901",
76 | "EN_B00092_S03643",
77 | "EN_B00096_S08653",
78 | "EN_B00063_S04297",
79 | "EN_B00063_S04614",
80 | "EN_B00079_S04698",
81 | "EN_B00104_S01666",
82 | "EN_B00061_S09504",
83 | "EN_B00061_S09694",
84 | "EN_B00065_S05444",
85 | "EN_B00063_S06860",
86 | "EN_B00065_S05725",
87 | "EN_B00069_S07628",
88 | "EN_B00083_S03875",
89 | "EN_B00071_S07665",
90 | "EN_B00071_S07665",
91 | "EN_B00062_S04187",
92 | "EN_B00065_S09873",
93 | "EN_B00065_S09922",
94 | "EN_B00084_S02463",
95 | "EN_B00067_S05066",
96 | "EN_B00106_S08060",
97 | "EN_B00073_S06399",
98 | "EN_B00073_S09236",
99 | "EN_B00087_S00432",
100 | "EN_B00085_S05618",
101 | "EN_B00064_S01262",
102 | "EN_B00072_S01739",
103 | "EN_B00059_S03913",
104 | "EN_B00069_S04036",
105 | "EN_B00067_S05623",
106 | "EN_B00060_S05389",
107 | "EN_B00060_S07290",
108 | "EN_B00062_S08995",
109 | }
110 | en_filters = ["ا", "い", "て"]
111 |
112 |
113 | def deal_with_audio_dir(audio_dir):
114 | audio_jsonl = audio_dir.with_suffix(".jsonl")
115 | sub_result, durations = [], []
116 | vocab_set = set()
117 | bad_case_zh = 0
118 | bad_case_en = 0
119 | with open(audio_jsonl, "r") as f:
120 | lines = f.readlines()
121 | for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
122 | obj = json.loads(line)
123 | text = obj["text"]
124 | if obj["language"] == "zh":
125 | if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
126 | bad_case_zh += 1
127 | continue
128 | else:
129 | text = text.translate(
130 | str.maketrans({",": ",", "!": "!", "?": "?"})
131 |                     )  # "。" is deliberately left unconverted, since much of the data is code-switched
132 | if obj["language"] == "en":
133 | if (
134 | obj["wav"].split("/")[1] in out_en
135 | or any(f in text for f in en_filters)
136 | or repetition_found(text, length=4)
137 | ):
138 | bad_case_en += 1
139 | continue
140 | if tokenizer == "pinyin":
141 | text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
142 | duration = obj["duration"]
143 | sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
144 | durations.append(duration)
145 | vocab_set.update(list(text))
146 | return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
147 |
148 |
149 | def main():
150 | assert tokenizer in ["pinyin", "char"]
151 | result = []
152 | duration_list = []
153 | text_vocab_set = set()
154 | total_bad_case_zh = 0
155 | total_bad_case_en = 0
156 |
157 | # process raw data
158 | executor = ProcessPoolExecutor(max_workers=max_workers)
159 | futures = []
160 | for lang in langs:
161 | dataset_path = Path(os.path.join(dataset_dir, lang))
162 |         # submit one job per sub-directory of this language split
163 |         for audio_dir in dataset_path.iterdir():
164 |             if audio_dir.is_dir():
165 |                 futures.append(executor.submit(deal_with_audio_dir, audio_dir))
166 |
167 |     for future in tqdm(futures, total=len(futures)):
168 |         sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
169 | result.extend(sub_result)
170 | duration_list.extend(durations)
171 | text_vocab_set.update(vocab_set)
172 | total_bad_case_zh += bad_case_zh
173 | total_bad_case_en += bad_case_en
174 | executor.shutdown()
175 |
176 | # save preprocessed dataset to disk
177 | if not os.path.exists(f"{save_dir}"):
178 | os.makedirs(f"{save_dir}")
179 | print(f"\nSaving to {save_dir} ...")
180 |
181 | # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
182 | # dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")
183 | with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
184 | for line in tqdm(result, desc="Writing to raw.arrow ..."):
185 | writer.write(line)
186 |
187 |     # also save durations to a separate json, for easy use with DynamicBatchSampler
188 | with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
189 | json.dump({"duration": duration_list}, f, ensure_ascii=False)
190 |
191 | # vocab map, i.e. tokenizer
192 |     # add alphabets and symbols (optional, if you plan to finetune on de/fr etc.)
193 | # if tokenizer == "pinyin":
194 | # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
195 | with open(f"{save_dir}/vocab.txt", "w") as f:
196 | for vocab in sorted(text_vocab_set):
197 | f.write(vocab + "\n")
198 |
199 | print(f"\nFor {dataset_name}, sample count: {len(result)}")
200 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
201 | print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
202 | if "ZH" in langs:
203 | print(f"Bad zh transcription case: {total_bad_case_zh}")
204 | if "EN" in langs:
205 | print(f"Bad en transcription case: {total_bad_case_en}\n")
206 |
207 |
208 | if __name__ == "__main__":
209 | max_workers = 32
210 |
211 | tokenizer = "pinyin" # "pinyin" | "char"
212 | polyphone = True
213 |
214 | langs = ["ZH", "EN"]
215 | dataset_dir = "/Emilia_Dataset/raw"
216 | dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
217 | save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
218 | print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
219 |
220 | main()
221 |
222 | # Emilia ZH & EN
223 | # samples count 37837916 (after removal)
224 | # pinyin vocab size 2543 (polyphone)
225 | # total duration 95281.87 (hours)
226 | # bad zh asr cnt 230435 (samples)
227 | # bad en asr cnt 37217 (samples)
228 |
229 | # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. handling of polyphones)
230 | # be careful if using a pretrained model: make sure the vocab.txt is the same
231 |
--------------------------------------------------------------------------------
/src/f5_tts/train/datasets/prepare_wenetspeech4tts.py:
--------------------------------------------------------------------------------
1 | # generate audio text map for WenetSpeech4TTS
2 | # evaluate for vocab size
3 |
4 | import os
5 | import sys
6 |
7 | sys.path.append(os.getcwd())
8 |
9 | import json
10 | from concurrent.futures import ProcessPoolExecutor
11 | from importlib.resources import files
12 | from tqdm import tqdm
13 |
14 | import torchaudio
15 | from datasets import Dataset
16 |
17 | from f5_tts.model.utils import convert_char_to_pinyin
18 |
19 |
20 | def deal_with_sub_path_files(dataset_path, sub_path):
21 | print(f"Dealing with: {sub_path}")
22 |
23 | text_dir = os.path.join(dataset_path, sub_path, "txts")
24 | audio_dir = os.path.join(dataset_path, sub_path, "wavs")
25 | text_files = os.listdir(text_dir)
26 |
27 | audio_paths, texts, durations = [], [], []
28 | for text_file in tqdm(text_files):
29 | with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
30 | first_line = file.readline().split("\t")
31 | audio_nm = first_line[0]
32 | audio_path = os.path.join(audio_dir, audio_nm + ".wav")
33 | text = first_line[1].strip()
34 |
35 | audio_paths.append(audio_path)
36 |
37 | if tokenizer == "pinyin":
38 | texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
39 | elif tokenizer == "char":
40 | texts.append(text)
41 |
42 | audio, sample_rate = torchaudio.load(audio_path)
43 | durations.append(audio.shape[-1] / sample_rate)
44 |
45 | return audio_paths, texts, durations
46 |
47 |
48 | def main():
49 | assert tokenizer in ["pinyin", "char"]
50 |
51 | audio_path_list, text_list, duration_list = [], [], []
52 |
53 | executor = ProcessPoolExecutor(max_workers=max_workers)
54 | futures = []
55 | for dataset_path in dataset_paths:
56 | sub_items = os.listdir(dataset_path)
57 | sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
58 | for sub_path in sub_paths:
59 | futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
60 | for future in tqdm(futures, total=len(futures)):
61 | audio_paths, texts, durations = future.result()
62 | audio_path_list.extend(audio_paths)
63 | text_list.extend(texts)
64 | duration_list.extend(durations)
65 | executor.shutdown()
66 |
67 |     if not os.path.exists(save_dir):
68 |         os.makedirs(save_dir)
69 |
70 | print(f"\nSaving to {save_dir} ...")
71 | dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
72 | dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") # arrow format
73 |
74 | with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
75 | json.dump(
76 | {"duration": duration_list}, f, ensure_ascii=False
77 |         )  # also save durations to a separate json, for easy use with DynamicBatchSampler
78 |
79 | print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
80 | text_vocab_set = set()
81 | for text in tqdm(text_list):
82 | text_vocab_set.update(list(text))
83 |
84 |     # add alphabets and symbols (optional, if you plan to finetune on de/fr etc.)
85 | if tokenizer == "pinyin":
86 | text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
87 |
88 | with open(f"{save_dir}/vocab.txt", "w") as f:
89 | for vocab in sorted(text_vocab_set):
90 | f.write(vocab + "\n")
91 | print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
92 | print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
93 |
94 |
95 | if __name__ == "__main__":
96 | max_workers = 32
97 |
98 | tokenizer = "pinyin" # "pinyin" | "char"
99 | polyphone = True
100 | dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
101 |
102 | dataset_name = (
103 | ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
104 | + "_"
105 | + tokenizer
106 | )
107 | dataset_paths = [
108 | "/WenetSpeech4TTS/Basic",
109 | "/WenetSpeech4TTS/Standard",
110 | "/WenetSpeech4TTS/Premium",
111 | ][-dataset_choice:]
112 | save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
113 | print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n")
114 |
115 | main()
116 |
117 | # Results (if adding alphabets with accents and symbols):
118 | # WenetSpeech4TTS Basic Standard Premium
119 | # samples count 3932473 1941220 407494
120 | # pinyin vocab size 1349 1348 1344 (no polyphone)
121 | # - - 1459 (polyphone)
122 | # char vocab size 5264 5219 5042
123 |
124 | # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. handling of polyphones)
125 | # be careful if using a pretrained model: make sure the vocab.txt is the same
126 |
--------------------------------------------------------------------------------
/src/f5_tts/train/finetune_cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | from cached_path import cached_path
6 | from f5_tts.model import CFM, UNetT, DiT, Trainer
7 | from f5_tts.model.utils import get_tokenizer
8 | from f5_tts.model.dataset import load_dataset
9 | from importlib.resources import files
10 |
11 |
12 | # -------------------------- Dataset Settings --------------------------- #
13 | target_sample_rate = 24000
14 | n_mel_channels = 100
15 | hop_length = 256
16 | win_length = 1024
17 | n_fft = 1024
18 | mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
19 |
20 |
21 | # -------------------------- Argument Parsing --------------------------- #
22 | def parse_args():
23 |     # batch_size_per_gpu = 1000  setting for an 8GB GPU
24 |     # batch_size_per_gpu = 1600  setting for a 12GB GPU
25 |     # batch_size_per_gpu = 2000  setting for a 16GB GPU
26 |     # batch_size_per_gpu = 3200  setting for a 24GB GPU
27 |
28 |     # num_warmup_updates = 300 for a ~5000-sample (about 10 hour) dataset
29 |
30 |     # adjust save_per_updates and last_per_steps to whatever suits your run
31 |
32 | parser = argparse.ArgumentParser(description="Train CFM Model")
33 |
34 | parser.add_argument(
35 | "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
36 | )
37 | parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
38 | parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
39 | parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
40 | parser.add_argument(
41 | "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
42 | )
43 | parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
44 | parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
45 | parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
46 | parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
47 | parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
48 | parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
49 | parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
50 | parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
51 | parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
52 | parser.add_argument(
53 | "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
54 | )
55 | parser.add_argument(
56 | "--tokenizer_path",
57 | type=str,
58 | default=None,
59 | help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
60 | )
61 | parser.add_argument(
62 | "--log_samples",
63 | type=bool,
64 | default=False,
65 |         help="Log inferred samples at each checkpoint save step",
66 | )
67 | parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
68 | parser.add_argument(
69 | "--bnb_optimizer",
70 | type=bool,
71 | default=False,
72 | help="Use 8-bit Adam optimizer from bitsandbytes",
73 | )
74 |
75 | return parser.parse_args()
76 |
77 |
78 | # -------------------------- Training Settings -------------------------- #
79 |
80 |
81 | def main():
82 | args = parse_args()
83 |
84 | checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
85 |
86 | # Model parameters based on experiment name
87 | if args.exp_name == "F5TTS_Base":
88 | wandb_resume_id = None
89 | model_cls = DiT
90 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
91 | if args.finetune:
92 | if args.pretrain is None:
93 | ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
94 | else:
95 | ckpt_path = args.pretrain
96 | elif args.exp_name == "E2TTS_Base":
97 | wandb_resume_id = None
98 | model_cls = UNetT
99 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
100 | if args.finetune:
101 | if args.pretrain is None:
102 | ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
103 | else:
104 | ckpt_path = args.pretrain
105 |
106 | if args.finetune:
107 | if not os.path.isdir(checkpoint_path):
108 | os.makedirs(checkpoint_path, exist_ok=True)
109 |
110 | file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
111 | if not os.path.isfile(file_checkpoint):
112 | shutil.copy2(ckpt_path, file_checkpoint)
113 |             print("copied pretrained checkpoint for finetuning")
114 |
115 | # Use the tokenizer and tokenizer_path provided in the command line arguments
116 | tokenizer = args.tokenizer
117 | if tokenizer == "custom":
118 | if not args.tokenizer_path:
119 | raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
120 | tokenizer_path = args.tokenizer_path
121 | else:
122 | tokenizer_path = args.dataset_name
123 |
124 | vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
125 |
126 | print("\nvocab : ", vocab_size)
127 | print("\nvocoder : ", mel_spec_type)
128 |
129 | mel_spec_kwargs = dict(
130 | n_fft=n_fft,
131 | hop_length=hop_length,
132 | win_length=win_length,
133 | n_mel_channels=n_mel_channels,
134 | target_sample_rate=target_sample_rate,
135 | mel_spec_type=mel_spec_type,
136 | )
137 |
138 | model = CFM(
139 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
140 | mel_spec_kwargs=mel_spec_kwargs,
141 | vocab_char_map=vocab_char_map,
142 | )
143 |
144 | trainer = Trainer(
145 | model,
146 | args.epochs,
147 | args.learning_rate,
148 | num_warmup_updates=args.num_warmup_updates,
149 | save_per_updates=args.save_per_updates,
150 | checkpoint_path=checkpoint_path,
151 | batch_size=args.batch_size_per_gpu,
152 | batch_size_type=args.batch_size_type,
153 | max_samples=args.max_samples,
154 | grad_accumulation_steps=args.grad_accumulation_steps,
155 | max_grad_norm=args.max_grad_norm,
156 | logger=args.logger,
157 | wandb_project=args.dataset_name,
158 | wandb_run_name=args.exp_name,
159 | wandb_resume_id=wandb_resume_id,
160 | log_samples=args.log_samples,
161 | last_per_steps=args.last_per_steps,
162 | bnb_optimizer=args.bnb_optimizer,
163 | )
164 |
165 | train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
166 |
167 | trainer.train(
168 | train_dataset,
169 | resumable_with_seed=666, # seed for shuffling dataset
170 | )
171 |
172 |
173 | if __name__ == "__main__":
174 | main()
175 |
--------------------------------------------------------------------------------
/src/f5_tts/train/train.py:
--------------------------------------------------------------------------------
1 | # training script.
2 |
3 | from importlib.resources import files
4 |
5 | from f5_tts.model import CFM, DiT, Trainer, UNetT
6 | from f5_tts.model.dataset import load_dataset
7 | from f5_tts.model.utils import get_tokenizer
8 |
9 | # -------------------------- Dataset Settings --------------------------- #
10 |
11 | target_sample_rate = 24000
12 | n_mel_channels = 100
13 | hop_length = 256
14 | win_length = 1024
15 | n_fft = 1024
16 | mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
17 |
18 | tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
19 | tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
20 | dataset_name = "Emilia_ZH_EN"
21 |
22 | # -------------------------- Training Settings -------------------------- #
23 |
24 | exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
25 |
26 | learning_rate = 7.5e-5
27 |
28 | batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
29 | batch_size_type = "frame" # "frame" or "sample"
30 | max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
31 | grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
32 | max_grad_norm = 1.0
33 |
34 | epochs = 11 # use linear decay, thus epochs control the slope
35 | num_warmup_updates = 20000 # warmup steps
36 | save_per_updates = 50000 # save checkpoint per steps
37 | last_per_steps = 5000 # save last checkpoint per steps
38 |
39 | # model params
40 | if exp_name == "F5TTS_Base":
41 | wandb_resume_id = None
42 | model_cls = DiT
43 | model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
44 | elif exp_name == "E2TTS_Base":
45 | wandb_resume_id = None
46 | model_cls = UNetT
47 | model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
48 |
49 |
50 | # ----------------------------------------------------------------------- #
51 |
52 |
53 | def main():
54 |     # note: assigning to tokenizer_path inside main() would make it a local and
55 |     # raise UnboundLocalError, so resolve the vocab path into a separate name
56 |     vocab_path = tokenizer_path if tokenizer == "custom" else dataset_name
57 |
58 |     vocab_char_map, vocab_size = get_tokenizer(vocab_path, tokenizer)
59 |
60 | mel_spec_kwargs = dict(
61 | n_fft=n_fft,
62 | hop_length=hop_length,
63 | win_length=win_length,
64 | n_mel_channels=n_mel_channels,
65 | target_sample_rate=target_sample_rate,
66 | mel_spec_type=mel_spec_type,
67 | )
68 |
69 | model = CFM(
70 | transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
71 | mel_spec_kwargs=mel_spec_kwargs,
72 | vocab_char_map=vocab_char_map,
73 | )
74 |
75 | trainer = Trainer(
76 | model,
77 | epochs,
78 | learning_rate,
79 | num_warmup_updates=num_warmup_updates,
80 | save_per_updates=save_per_updates,
81 | checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")),
82 | batch_size=batch_size_per_gpu,
83 | batch_size_type=batch_size_type,
84 | max_samples=max_samples,
85 | grad_accumulation_steps=grad_accumulation_steps,
86 | max_grad_norm=max_grad_norm,
87 | wandb_project="CFM-TTS",
88 | wandb_run_name=exp_name,
89 | wandb_resume_id=wandb_resume_id,
90 | last_per_steps=last_per_steps,
91 | log_samples=True,
92 | mel_spec_type=mel_spec_type,
93 | )
94 |
95 | train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
96 | trainer.train(
97 | train_dataset,
98 | resumable_with_seed=666, # seed for shuffling dataset
99 | )
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
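As a sanity check on the settings in `train.py`, the amount of audio consumed per optimizer update follows directly from the frame-based batch size, mirroring what `scripts/count_max_epoch.py` prints. A rough illustrative calculation, assuming the 8-GPU setup mentioned in the batch-size comment:

```python
# frames per update -> hours of audio per update, using the values from train.py
target_sample_rate = 24000
hop_length = 256
gpus = 8
batch_size_per_gpu = 38400  # frames per GPU per update
grad_accumulation_steps = 1

frames_per_update = batch_size_per_gpu * gpus * grad_accumulation_steps  # 307200
seconds_per_update = frames_per_update * hop_length / target_sample_rate  # 3276.8 s
print(f"~{seconds_per_update / 3600:.2f} hours of audio per optimizer update")  # ~0.91
```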