├── .dockerignore
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   └── feature_request.yml
│   ├── pull_request_template.md
│   └── workflows
│       ├── build-docker-image.yml
│       ├── docs.yml
│       └── stale.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .project-root
├── .readthedocs.yaml
├── API_FLAGS.txt
├── LICENSE
├── README.md
├── docker-compose.dev.yml
├── dockerfile
├── dockerfile.dev
├── docs
│   ├── CNAME
│   ├── README.ja.md
│   ├── README.ko.md
│   ├── README.pt-BR.md
│   ├── README.zh.md
│   ├── assets
│   │   └── figs
│   │       ├── VS_1.jpg
│   │       ├── VS_1_pt-BR.png
│   │       ├── agent_gradio.png
│   │       ├── diagram.png
│   │       ├── diagrama.png
│   │       └── logo-circle.png
│   ├── en
│   │   ├── finetune.md
│   │   ├── index.md
│   │   ├── inference.md
│   │   ├── samples.md
│   │   └── start_agent.md
│   ├── ja
│   │   ├── finetune.md
│   │   ├── index.md
│   │   ├── inference.md
│   │   ├── samples.md
│   │   └── start_agent.md
│   ├── ko
│   │   ├── finetune.md
│   │   ├── index.md
│   │   ├── inference.md
│   │   ├── samples.md
│   │   └── start_agent.md
│   ├── pt
│   │   ├── finetune.md
│   │   ├── index.md
│   │   ├── inference.md
│   │   ├── samples.md
│   │   └── start_agent.md
│   ├── requirements.txt
│   ├── stylesheets
│   │   └── extra.css
│   └── zh
│       ├── finetune.md
│       ├── index.md
│       ├── inference.md
│       ├── samples.md
│       └── start_agent.md
├── entrypoint.sh
├── fish_speech
│   ├── callbacks
│   │   ├── __init__.py
│   │   └── grad_norm.py
│   ├── configs
│   │   ├── base.yaml
│   │   ├── firefly_gan_vq.yaml
│   │   ├── lora
│   │   │   └── r_8_alpha_16.yaml
│   │   └── text2semantic_finetune.yaml
│   ├── conversation.py
│   ├── datasets
│   │   ├── concat_repeat.py
│   │   ├── protos
│   │   │   ├── text-data.proto
│   │   │   ├── text_data_pb2.py
│   │   │   └── text_data_stream.py
│   │   ├── semantic.py
│   │   └── vqgan.py
│   ├── i18n
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── core.py
│   │   ├── locale
│   │   │   ├── en_US.json
│   │   │   ├── es_ES.json
│   │   │   ├── ja_JP.json
│   │   │   ├── ko_KR.json
│   │   │   ├── pt_BR.json
│   │   │   └── zh_CN.json
│   │   └── scan.py
│   ├── inference_engine
│   │   ├── __init__.py
│   │   ├── reference_loader.py
│   │   ├── utils.py
│   │   └── vq_manager.py
│   ├── models
│   │   ├── text2semantic
│   │   │   ├── __init__.py
│   │   │   ├── inference.py
│   │   │   ├── lit_module.py
│   │   │   ├── llama.py
│   │   │   └── lora.py
│   │   └── vqgan
│   │       ├── __init__.py
│   │       ├── inference.py
│   │       ├── modules
│   │       │   ├── firefly.py
│   │       │   └── fsq.py
│   │       └── utils.py
│   ├── scheduler.py
│   ├── text
│   │   ├── __init__.py
│   │   ├── clean.py
│   │   └── spliter.py
│   ├── tokenizer.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       ├── braceexpand.py
│       ├── context.py
│       ├── file.py
│       ├── instantiators.py
│       ├── logger.py
│       ├── logging_utils.py
│       ├── rich_utils.py
│       ├── schema.py
│       ├── spectrogram.py
│       └── utils.py
├── inference.ipynb
├── mkdocs.yml
├── pyproject.toml
├── pyrightconfig.json
└── tools
    ├── api_client.py
    ├── api_server.py
    ├── download_models.py
    ├── e2e_webui.py
    ├── export_onnx.py
    ├── extract_model.py
    ├── fish_e2e.py
    ├── llama
    │   ├── build_dataset.py
    │   ├── eval_in_context.py
    │   ├── merge_lora.py
    │   └── quantize.py
    ├── run_webui.py
    ├── server
    │   ├── agent
    │   │   ├── __init__.py
    │   │   ├── generate.py
    │   │   ├── generation_utils.py
    │   │   └── pre_generation_utils.py
    │   ├── api_utils.py
    │   ├── exception_handler.py
    │   ├── inference.py
    │   ├── model_manager.py
    │   ├── model_utils.py
    │   └── views.py
    ├── smart_pad.py
    ├── vqgan
    │   ├── create_train_split.py
    │   └── extract_vq.py
    ├── webui
    │   ├── __init__.py
    │   ├── inference.py
    │   └── variables.py
    └── whisper_asr.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | .git
2 | .github
3 | results
4 | data
5 | *.filelist
6 | /data_server/target
7 | checkpoints
8 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: "🕷️ Bug report"
2 | description: |
3 | Please follow this template carefully to ensure we can address your issue quickly.
4 | Make sure to provide as much detail as possible, including logs and screenshots.
5 | labels:
6 | - bug
7 | body:
8 | - type: checkboxes
9 | attributes:
10 | label: Self Checks
11 | description: "To ensure timely help, please confirm the following:"
12 | options:
13 | - label: This template is only for bug reports. For questions, please visit [Discussions](https://github.com/fishaudio/fish-speech/discussions).
14 | required: true
15 | - label: I have thoroughly reviewed the project documentation (installation, training, inference) but couldn't find information to solve my problem. [English](https://speech.fish.audio/) [中文](https://speech.fish.audio/zh/) [日本語](https://speech.fish.audio/ja/) [Portuguese (Brazil)](https://speech.fish.audio/pt/)
16 | required: true
17 | - label: I have searched for existing issues, including closed ones. [Search issues](https://github.com/fishaudio/fish-speech/issues)
18 | required: true
19 | - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/fishaudio/fish-speech/issues/515)).
20 | required: true
21 | - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
22 | required: true
23 | - label: "Please do not modify this template and fill in all required fields."
24 | required: true
25 | - type: dropdown
26 | attributes:
27 | label: Cloud or Self Hosted
28 | multiple: true
29 | options:
30 | - Cloud
31 | - Self Hosted (Docker)
32 | - Self Hosted (Source)
33 | validations:
34 | required: true
35 | - type: textarea
36 | attributes:
37 | label: Environment Details
38 | description: "Provide details such as OS, Python version, and any relevant software or dependencies."
39 | placeholder: e.g., macOS 13.5, Python 3.10, torch==2.4.1, Gradio 4.44.0
40 | validations:
41 | required: true
42 | - type: textarea
43 | attributes:
44 | label: Steps to Reproduce
45 | description: |
46 | Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
47 | placeholder: |
48 | 1. Run the command `python -m tools.api_client -t "xxxxx"`
49 | 2. Observe the console output error: `ModuleNotFoundError: No module named 'pyaudio'` (screenshots or logs would be even better)
50 | validations:
51 | required: true
52 | - type: textarea
53 | attributes:
54 | label: ✔️ Expected Behavior
55 | placeholder: Describe what you expected to happen.
56 | validations:
57 | required: false
58 | - type: textarea
59 | attributes:
60 | label: ❌ Actual Behavior
61 | placeholder: Describe what actually happened.
62 | validations:
63 | required: false
64 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: "\U0001F4E7 Discussions"
4 | url: https://github.com/fishaudio/fish-speech/discussions
5 | about: General discussions and request help from the community
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: "⭐ Feature or enhancement request"
2 | description: Propose something new.
3 | labels:
4 | - enhancement
5 | body:
6 | - type: checkboxes
7 | attributes:
8 | label: Self Checks
9 | description: "To make sure we get to you in time, please check the following :)"
10 | options:
11 | - label: I have thoroughly reviewed the project documentation (installation, training, inference) but couldn't find any relevant information that meets my needs. [English](https://speech.fish.audio/) [中文](https://speech.fish.audio/zh/) [日本語](https://speech.fish.audio/ja/) [Portuguese (Brazil)](https://speech.fish.audio/pt/)
12 | required: true
13 | - label: I have searched for existing issues, including closed ones. [Search issues](https://github.com/fishaudio/fish-speech/issues)
14 | required: true
15 | - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/fishaudio/fish-speech/issues/515)).
16 | required: true
17 | - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
18 | required: true
19 | - label: "Please do not modify this template :) and fill in all the required fields."
20 | required: true
21 |
22 | - type: textarea
23 | attributes:
24 | label: 1. Is this request related to a challenge you're experiencing? Tell us your story.
25 | description: |
26 | Describe the specific problem or scenario you’re facing in detail. For example:
27 | *"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."*
28 | placeholder: Please describe the situation in as much detail as possible.
29 | validations:
30 | required: true
31 |
32 | - type: textarea
33 | attributes:
34 | label: 2. What is your suggested solution?
35 | description: |
36 | Provide a clear description of the feature or enhancement you'd like to propose.
37 | How would this feature solve your issue or improve the project?
38 | placeholder: Describe your idea or proposed solution here.
39 | validations:
40 | required: true
41 |
42 | - type: textarea
43 | attributes:
44 | label: 3. Additional context or comments
45 | description: |
46 | Any other relevant information, links, documents, or screenshots that provide clarity.
47 | Use this section for anything not covered above.
48 | placeholder: Add any extra details here.
49 | validations:
50 | required: false
51 |
52 | - type: checkboxes
53 | attributes:
54 | label: 4. Can you help us with this feature?
55 | description: |
56 | Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration.
57 | options:
58 | - label: I am interested in contributing to this feature.
59 | required: false
60 |
61 | - type: markdown
62 | attributes:
63 | value: |
64 | **Note:** Please submit only one request per issue to keep discussions focused and manageable.
65 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | **Is this PR adding a new feature or fixing a bug?**
2 |
3 | Add feature / Fix BUG.
4 |
5 | **Is this pull request related to any issue? If yes, please link the issue.**
6 |
7 | #xxx
8 |
--------------------------------------------------------------------------------
/.github/workflows/build-docker-image.yml:
--------------------------------------------------------------------------------
1 | name: Build Image
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | tags:
8 | - "v*"
9 |
10 | jobs:
11 | build:
12 | runs-on: ubuntu-latest-16c64g
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up Docker Buildx
16 | uses: docker/setup-buildx-action@v3
17 | - name: Get Version
18 | run: |
19 | if [[ $GITHUB_REF == refs/tags/v* ]]; then
20 | version=$(basename ${GITHUB_REF})
21 | else
22 | version=nightly
23 | fi
24 |
25 | echo "version=${version}" >> $GITHUB_ENV
26 | echo "Current version: ${version}"
27 |
28 | - name: Login to Docker Hub
29 | uses: docker/login-action@v3
30 | with:
31 | username: ${{ secrets.DOCKER_USER }}
32 | password: ${{ secrets.DOCKER_PAT }}
33 |
34 | - name: Build and Push Image
35 | uses: docker/build-push-action@v6
36 | with:
37 | context: .
38 | file: dockerfile
39 | platforms: linux/amd64
40 | push: true
41 | tags: |
42 | fishaudio/fish-speech:${{ env.version }}
43 | fishaudio/fish-speech:latest
44 | outputs: type=image,oci-mediatypes=true,compression=zstd,compression-level=3,force-compression=true
45 | cache-from: type=registry,ref=fishaudio/fish-speech:latest
46 | cache-to: type=inline
47 |
48 | - name: Build and Push Dev Image
49 | uses: docker/build-push-action@v6
50 | with:
51 | context: .
52 | file: dockerfile.dev
53 | platforms: linux/amd64
54 | push: true
55 | build-args: |
56 | VERSION=${{ env.version }}
57 | BASE_IMAGE=fishaudio/fish-speech:${{ env.version }}
58 | tags: |
59 | fishaudio/fish-speech:${{ env.version }}-dev
60 | fishaudio/fish-speech:latest-dev
61 | outputs: type=image,oci-mediatypes=true,compression=zstd,compression-level=3,force-compression=true
62 | cache-from: type=registry,ref=fishaudio/fish-speech:latest-dev
63 | cache-to: type=inline
64 |
65 | - name: Push README to Dockerhub
66 | uses: peter-evans/dockerhub-description@v4
67 | with:
68 | username: ${{ secrets.DOCKER_USER }}
69 | password: ${{ secrets.DOCKER_PAT }}
70 | repository: fishaudio/fish-speech
71 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 | on:
3 | push:
4 | branches:
5 | - main
6 | paths:
7 | - 'docs/**'
8 | - 'mkdocs.yml'
9 |
10 | permissions:
11 | contents: write
12 |
13 | jobs:
14 | deploy:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Configure Git Credentials
19 | run: |
20 | git config user.name github-actions[bot]
21 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com
22 | - uses: actions/setup-python@v5
23 | with:
24 | python-version: 3.x
25 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
26 | - uses: actions/cache@v4
27 | with:
28 | key: mkdocs-material-${{ env.cache_id }}
29 | path: .cache
30 | restore-keys: |
31 | mkdocs-material-
32 | - run: pip install -r docs/requirements.txt
33 | - run: mkdocs gh-deploy --force
34 |
--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
1 | name: Close inactive issues
2 | on:
3 | schedule:
4 | - cron: "0 0 * * *"
5 |
6 | jobs:
7 | close-issues:
8 | runs-on: ubuntu-latest
9 | permissions:
10 | issues: write
11 | pull-requests: write
12 | steps:
13 | - uses: actions/stale@v9
14 | with:
15 | days-before-issue-stale: 30
16 | days-before-issue-close: 14
17 | stale-issue-label: "stale"
18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
19 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
20 | days-before-pr-stale: 30
21 | days-before-pr-close: 30
22 | stale-pr-label: "stale"
23 | stale-pr-message: "This PR is stale because it has been open for 30 days with no activity."
24 | close-pr-message: "This PR was closed because it has been inactive for 30 days since being marked as stale."
25 | repo-token: ${{ secrets.GITHUB_TOKEN }}
26 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .pgx.*
3 | .pdm-python
4 | /fish_speech.egg-info
5 | __pycache__
6 | /results
7 | /data
8 | /*.test.sh
9 | *.filelist
10 | filelists
11 | /fish_speech/text/cmudict_cache.pickle
12 | /checkpoints
13 | /.vscode
14 | /data_server/target
15 | /*.npy
16 | /*.wav
17 | /*.mp3
18 | /*.lab
19 | /results
20 | /data
21 | /.idea
22 | ffmpeg.exe
23 | ffprobe.exe
24 | asr-label*
25 | /.cache
26 | /fishenv
27 | /.locale
28 | /demo-audios
29 | /references
30 | /example
31 | /faster_whisper
32 | /.gradio
33 | *log
34 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ci:
2 | autoupdate_schedule: monthly
3 |
4 | repos:
5 | - repo: https://github.com/pycqa/isort
6 | rev: 6.0.1
7 | hooks:
8 | - id: isort
9 | args: [--profile=black]
10 |
11 | - repo: https://github.com/psf/black
12 | rev: 25.1.0
13 | hooks:
14 | - id: black
15 |
16 | - repo: https://github.com/pre-commit/pre-commit-hooks
17 | rev: v5.0.0
18 | hooks:
19 | - id: end-of-file-fixer
20 | - id: check-yaml
21 | - id: check-json
22 | - id: mixed-line-ending
23 | args: ["--fix=lf"]
24 | - id: check-added-large-files
25 | args: ["--maxkb=5000"]
26 |
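# Usage sketch with the standard pre-commit CLI (not specific to this repository's config):
#   pip install pre-commit
#   pre-commit install          # run the hooks automatically on every git commit
#   pre-commit run --all-files  # run all hooks once across the whole tree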
--------------------------------------------------------------------------------
/.project-root:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/.project-root
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file for MkDocs projects
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the version of Python and other tools you might need
8 | build:
9 | os: ubuntu-22.04
10 | tools:
11 | python: "3.12"
12 |
13 | mkdocs:
14 | configuration: mkdocs.yml
15 |
16 | # Optionally declare the Python requirements required to build your docs
17 | python:
18 | install:
19 | - requirements: docs/requirements.txt
20 |
--------------------------------------------------------------------------------
/API_FLAGS.txt:
--------------------------------------------------------------------------------
1 | # --infer
2 | --api
3 | --listen 0.0.0.0:8080 \
4 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
5 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
6 | --decoder-config-name firefly_gan_vq
7 |
--------------------------------------------------------------------------------
/docker-compose.dev.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 |
3 | services:
4 | fish-speech:
5 | build:
6 | context: .
7 | dockerfile: dockerfile.dev
8 | container_name: fish-speech
9 | volumes:
10 | - ./:/exp
11 | deploy:
12 | resources:
13 | reservations:
14 | devices:
15 | - driver: nvidia
16 | count: all
17 | capabilities: [gpu]
18 | command: tail -f /dev/null
19 |
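# Minimal usage sketch (assumption: run from the repository root, with the
# NVIDIA Container Toolkit installed so the GPU reservation above can be honored):
#   docker compose -f docker-compose.dev.yml up -d
#   docker compose -f docker-compose.dev.yml exec fish-speech zsh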
--------------------------------------------------------------------------------
/dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.12-slim-bookworm AS stage-1
2 | ARG TARGETARCH
3 |
4 | ARG HUGGINGFACE_MODEL=fish-speech-1.5
5 | ARG HF_ENDPOINT=https://huggingface.co
6 |
7 | WORKDIR /opt/fish-speech
8 |
9 | RUN set -ex \
10 | && pip install huggingface_hub \
11 | && HF_ENDPOINT=${HF_ENDPOINT} huggingface-cli download --resume-download fishaudio/${HUGGINGFACE_MODEL} --local-dir checkpoints/${HUGGINGFACE_MODEL}
12 |
13 | FROM python:3.12-slim-bookworm
14 | ARG TARGETARCH
15 |
16 | ARG DEPENDENCIES=" \
17 | ca-certificates \
18 | libsox-dev \
19 | build-essential \
20 | cmake \
21 | libasound-dev \
22 | portaudio19-dev \
23 | libportaudio2 \
24 | libportaudiocpp0 \
25 | ffmpeg"
26 |
27 | RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
28 | --mount=type=cache,target=/var/lib/apt,sharing=locked \
29 | set -ex \
30 | && rm -f /etc/apt/apt.conf.d/docker-clean \
31 | && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \
32 | && apt-get update \
33 | && apt-get -y install --no-install-recommends ${DEPENDENCIES} \
34 | && echo "no" | dpkg-reconfigure dash
35 |
36 | WORKDIR /opt/fish-speech
37 |
38 | COPY . .
39 |
40 | RUN --mount=type=cache,target=/root/.cache,sharing=locked \
41 | set -ex \
42 | && pip install -e .[stable]
43 |
44 | COPY --from=stage-1 /opt/fish-speech/checkpoints /opt/fish-speech/checkpoints
45 |
46 | ENV GRADIO_SERVER_NAME="0.0.0.0"
47 |
48 | EXPOSE 7860
49 |
50 | CMD ["./entrypoint.sh"]
51 |
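# Local build/run sketch (mirrors the CI build; the tag name and port mapping are arbitrary):
#   docker build -f dockerfile -t fish-speech:local .
#   docker run --gpus all -p 7860:7860 fish-speech:local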
--------------------------------------------------------------------------------
/dockerfile.dev:
--------------------------------------------------------------------------------
1 | ARG VERSION=dev
2 | ARG BASE_IMAGE=ghcr.io/fishaudio/fish-speech:${VERSION}
3 |
4 | FROM ${BASE_IMAGE}
5 |
6 | ARG TOOLS=" \
7 | git \
8 | curl \
9 | build-essential \
10 | ffmpeg \
11 | libsm6 \
12 | libxext6 \
13 | libjpeg-dev \
14 | zlib1g-dev \
15 | aria2 \
16 | zsh \
17 | openssh-server \
18 | sudo \
19 | protobuf-compiler \
20 | libasound-dev \
21 | portaudio19-dev \
22 | libportaudio2 \
23 | libportaudiocpp0 \
24 | cmake"
25 |
26 | RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
27 | --mount=type=cache,target=/var/lib/apt,sharing=locked \
28 | set -ex \
29 | && apt-get update \
30 | && apt-get -y install --no-install-recommends ${TOOLS}
31 |
32 | # Install oh-my-zsh so your terminal looks nice
33 | RUN sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended
34 |
35 | # Set zsh as default shell
36 | RUN chsh -s /usr/bin/zsh
37 | ENV SHELL=/usr/bin/zsh
38 |
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | speech.fish.audio
2 |
--------------------------------------------------------------------------------
/docs/README.ja.md:
--------------------------------------------------------------------------------
1 |
2 | Fish Speech
3 |
4 | [English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | **日本語** | [한국어](README.ko.md)
5 |
32 |
33 | このコードリポジトリはApache 2.0ライセンスの下で公開されており、モデルはCC-BY-NC-SA-4.0ライセンスの下で公開されています。詳細については[LICENSE](../LICENSE)をご参照ください。
34 |
35 | ---
36 |
37 | ## 機能
38 |
39 | 1. **ゼロショット & フューショット TTS**:10〜30 秒の音声サンプルを入力して、高品質の TTS 出力を生成します。**詳細は [音声クローンのベストプラクティス](https://docs.fish.audio/text-to-speech/voice-clone-best-practices) を参照してください。**
40 | 2. **多言語 & クロスリンガル対応**:多言語テキストを入力ボックスにコピーペーストするだけで、言語を気にする必要はありません。現在、英語、日本語、韓国語、中国語、フランス語、ドイツ語、アラビア語、スペイン語に対応しています。
41 | 3. **音素依存なし**:このモデルは強力な汎化能力を持ち、TTS に音素を必要としません。あらゆる言語スクリプトに対応可能です。
42 | 4. **高精度**:5 分間の英語テキストに対し、CER(文字誤り率)と WER(単語誤り率)は約 2%の精度を達成します。
43 | 5. **高速**:fish-tech アクセラレーションにより、Nvidia RTX 4060 ラップトップではリアルタイムファクターが約 1:5、Nvidia RTX 4090 では約 1:15 です。
44 | 6. **WebUI 推論**:使いやすい Gradio ベースの Web ユーザーインターフェースを搭載し、Chrome、Firefox、Edge などのブラウザに対応しています。
45 | 7. **GUI 推論**:PyQt6 のグラフィカルインターフェースを提供し、API サーバーとシームレスに連携します。Linux、Windows、macOS に対応しています。[GUI を見る](https://github.com/AnyaCoder/fish-speech-gui)。
46 | 8. **デプロイしやすい**:Linux、Windows、macOS にネイティブ対応した推論サーバーを簡単にセットアップでき、速度の低下を最小限に抑えます。
47 |
48 | ## 免責事項
49 |
50 | コードベースの違法な使用については一切責任を負いません。DMCA(デジタルミレニアム著作権法)およびその他の関連法については、地域の法律を参照してください。
51 |
52 | ## オンラインデモ
53 |
54 | [Fish Audio](https://fish.audio)
55 |
56 | ## ローカル推論のクイックスタート
57 |
58 | [inference.ipynb](/inference.ipynb)
59 |
60 | ## ビデオ
61 |
62 | #### V1.5 デモビデオ: [Watch the video on X (Twitter).](https://x.com/FishAudio/status/1864370933496205728)
63 |
64 | ## ドキュメント
65 |
66 | - [英語](https://speech.fish.audio/)
67 | - [中文](https://speech.fish.audio/zh/)
68 | - [日本語](https://speech.fish.audio/ja/)
69 | - [ポルトガル語 (ブラジル)](https://speech.fish.audio/pt/)
70 |
71 | ## サンプル (2024/10/02 V1.4)
72 |
73 | - [英語](https://speech.fish.audio/samples/)
74 | - [中文](https://speech.fish.audio/zh/samples/)
75 | - [日本語](https://speech.fish.audio/ja/samples/)
76 | - [ポルトガル語 (ブラジル)](https://speech.fish.audio/pt/samples/)
77 |
78 | ## クレジット
79 |
80 | - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
81 | - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
82 | - [GPT VITS](https://github.com/innnky/gpt-vits)
83 | - [MQTTS](https://github.com/b04901014/MQTTS)
84 | - [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
85 | - [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
86 |
87 | ## スポンサー
88 |
89 |
96 |
--------------------------------------------------------------------------------
/docs/README.ko.md:
--------------------------------------------------------------------------------
1 |
2 | Fish Speech
3 |
4 | [English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | **한국어**
5 |
32 |
33 | 이 코드 저장소는 Apache 2.0 라이선스 하에 배포되며, 모델은 CC-BY-NC-SA-4.0 라이선스 하에 배포됩니다. 자세한 내용은 [LICENSE](../LICENSE)를 참조하십시오.
34 |
35 | ---
36 |
37 | ## 기능
38 |
39 | 1. **Zero-shot & Few-shot TTS:** 10초에서 30초의 음성 샘플을 입력하여 고품질의 TTS 출력을 생성합니다. **자세한 가이드는 [모범 사례](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)를 참조하시길 바랍니다.**
40 |
41 | 2. **다국어 및 교차 언어 지원:** 다국어 걱정 없이, 텍스트를 입력창에 복사하여 붙여넣기만 하면 됩니다. 현재 영어, 일본어, 한국어, 중국어, 프랑스어, 독일어, 아랍어, 스페인어를 지원합니다.
42 |
43 | 3. **음소 의존성 제거:** 이 모델은 강력한 일반화 능력을 가지고 있으며, TTS가 음소에 의존하지 않습니다. 모든 언어 스크립트 텍스트를 손쉽게 처리할 수 있습니다.
44 |
45 | 4. **높은 정확도:** 영어 텍스트 기준 5분 기준에서 단, 2%의 문자 오류율(CER)과 단어 오류율(WER)을 달성합니다.
46 |
47 | 5. **빠른 속도:** fish-tech 가속을 통해 실시간 인자(RTF)는 Nvidia RTX 4060 노트북에서는 약 1:5, Nvidia RTX 4090에서는 1:15입니다.
48 |
49 | 6. **웹 UI 추론:** Chrome, Firefox, Edge 등 다양한 브라우저에서 호환되는 Gradio 기반의 사용하기 쉬운 웹 UI를 제공합니다.
50 |
51 | 7. **GUI 추론:** PyQt6 그래픽 인터페이스를 제공하여 API 서버와 원활하게 작동합니다. Linux, Windows 및 macOS를 지원합니다. [GUI 참조](https://github.com/AnyaCoder/fish-speech-gui).
52 |
53 | 8. **배포 친화적:** Linux, Windows, macOS에서 네이티브로 지원되는 추론 서버를 쉽게 설정할 수 있어 속도 손실을 최소화합니다.
54 |
55 | ## 면책 조항
56 |
57 | 이 코드베이스의 불법적 사용에 대해 어떠한 책임도 지지 않습니다. DMCA 및 관련 법률에 대한 로컬 법률을 참조하십시오.
58 |
59 | ## 온라인 데모
60 |
61 | [Fish Audio](https://fish.audio)
62 |
63 | ## 로컬 추론을 위한 빠른 시작
64 |
65 | [inference.ipynb](/inference.ipynb)
66 |
67 | ## 영상
68 |
69 | #### V1.5 데모 영상: [Watch the video on X (Twitter).](https://x.com/FishAudio/status/1864370933496205728)
70 |
71 | ## 문서
72 |
73 | - [English](https://speech.fish.audio/)
74 | - [中文](https://speech.fish.audio/zh/)
75 | - [日本語](https://speech.fish.audio/ja/)
76 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/)
77 | - [한국어](https://speech.fish.audio/ko/)
78 |
79 | ## Samples (2024/10/02 V1.4)
80 |
81 | - [English](https://speech.fish.audio/samples/)
82 | - [中文](https://speech.fish.audio/zh/samples/)
83 | - [日本語](https://speech.fish.audio/ja/samples/)
84 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
85 | - [한국어](https://speech.fish.audio/ko/samples/)
86 |
87 | ## Credits
88 |
89 | - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
90 | - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
91 | - [GPT VITS](https://github.com/innnky/gpt-vits)
92 | - [MQTTS](https://github.com/b04901014/MQTTS)
93 | - [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
94 | - [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
95 |
96 | ## Sponsor
97 |
98 |
105 |
--------------------------------------------------------------------------------
/docs/README.zh.md:
--------------------------------------------------------------------------------
1 |
2 | Fish Speech
3 |
4 | [English](../README.md) | **简体中文** | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | [한국어](README.ko.md)
5 |
35 |
36 | 此代码库根据 Apache 2.0 许可证发布,模型根据 CC-BY-NC-SA-4.0 许可证发布。请参阅 [LICENSE](../LICENSE) 了解更多细节.
37 |
38 | ---
39 |
40 | ## 特性
41 |
42 | 1. **零样本 & 小样本 TTS**:输入 10 到 30 秒的声音样本即可生成高质量的 TTS 输出。**详见 [语音克隆最佳实践指南](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)。**
43 | 2. **多语言 & 跨语言支持**:只需复制并粘贴多语言文本到输入框中,无需担心语言问题。目前支持英语、日语、韩语、中文、法语、德语、阿拉伯语和西班牙语。
44 | 3. **无音素依赖**:模型具备强大的泛化能力,不依赖音素进行 TTS,能够处理任何文字表示的语言。
45 | 4. **高准确率**:在 5 分钟的英文文本上,达到了约 2% 的 CER(字符错误率)和 WER(词错误率)。
46 | 5. **快速**:通过 fish-tech 加速,在 Nvidia RTX 4060 笔记本上的实时因子约为 1:5,在 Nvidia RTX 4090 上约为 1:15。
47 | 6. **WebUI 推理**:提供易于使用的基于 Gradio 的网页用户界面,兼容 Chrome、Firefox、Edge 等浏览器。
48 | 7. **GUI 推理**:提供 PyQt6 图形界面,与 API 服务器无缝协作。支持 Linux、Windows 和 macOS。[查看 GUI](https://github.com/AnyaCoder/fish-speech-gui)。
49 | 8. **易于部署**:轻松设置推理服务器,原生支持 Linux、Windows 和 macOS,最大程度减少速度损失。
50 |
51 | ## 免责声明
52 |
53 | 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
54 |
55 | ## 在线 DEMO
56 |
57 | [Fish Audio](https://fish.audio)
58 |
59 | ## 快速开始本地推理
60 |
61 | [inference.ipynb](/inference.ipynb)
62 |
63 | ## 视频
64 |
65 | #### 1.5 介绍: https://www.bilibili.com/video/BV1EKiDYBE4o
66 |
67 | #### 1.4 介绍: https://www.bilibili.com/video/BV1pu46eVEk7
68 |
69 | #### 1.2 介绍: https://www.bilibili.com/video/BV1wz421B71D
70 |
71 | #### 1.1 介绍: https://www.bilibili.com/video/BV1zJ4m1K7cj
72 |
73 | ## 文档
74 |
75 | - [English](https://speech.fish.audio/)
76 | - [中文](https://speech.fish.audio/zh/)
77 | - [日本語](https://speech.fish.audio/ja/)
78 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/)
79 |
80 | ## 例子 (2024/10/02 V1.4)
81 |
82 | - [English](https://speech.fish.audio/samples/)
83 | - [中文](https://speech.fish.audio/zh/samples/)
84 | - [日本語](https://speech.fish.audio/ja/samples/)
85 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
86 |
87 | ## 鸣谢
88 |
89 | - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
90 | - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
91 | - [GPT VITS](https://github.com/innnky/gpt-vits)
92 | - [MQTTS](https://github.com/b04901014/MQTTS)
93 | - [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
94 | - [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
95 |
96 | ## 赞助
97 |
98 |
105 |
--------------------------------------------------------------------------------
/docs/assets/figs/VS_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/VS_1.jpg
--------------------------------------------------------------------------------
/docs/assets/figs/VS_1_pt-BR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/VS_1_pt-BR.png
--------------------------------------------------------------------------------
/docs/assets/figs/agent_gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/agent_gradio.png
--------------------------------------------------------------------------------
/docs/assets/figs/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/diagram.png
--------------------------------------------------------------------------------
/docs/assets/figs/diagrama.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/diagrama.png
--------------------------------------------------------------------------------
/docs/assets/figs/logo-circle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/logo-circle.png
--------------------------------------------------------------------------------
/docs/en/finetune.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning
2 |
3 | If you are reading this page, you are probably not satisfied with the performance of the few-shot pre-trained model and want to fine-tune it to improve its performance on your dataset.
4 |
5 | In the current version, you only need to fine-tune the 'LLAMA' part.
6 |
7 | ## Fine-tuning LLAMA
8 | ### 1. Prepare the dataset
9 |
10 | ```
11 | .
12 | ├── SPK1
13 | │ ├── 21.15-26.44.lab
14 | │ ├── 21.15-26.44.mp3
15 | │ ├── 27.51-29.98.lab
16 | │ ├── 27.51-29.98.mp3
17 | │ ├── 30.1-32.71.lab
18 | │ └── 30.1-32.71.mp3
19 | └── SPK2
20 | ├── 38.79-40.85.lab
21 | └── 38.79-40.85.mp3
22 | ```
23 |
24 | You need to convert your dataset into the above format and place it under `data`. The audio files can have the extension `.mp3`, `.wav`, or `.flac`, and the annotation files should have the extension `.lab`.
25 |
26 | !!! info "Dataset Format"
27 | The `.lab` annotation file only needs to contain the transcription of the audio, with no special formatting required. For example, if `hi.mp3` says "Hello, goodbye," then the `hi.lab` file would contain a single line of text: "Hello, goodbye."
28 |
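For example, a minimal shell sketch of creating such an annotation file (the transcript text here is just a placeholder):

```bash
# Write the transcription of data/SPK1/hi.mp3 into a matching .lab file
echo "Hello, goodbye." > data/SPK1/hi.lab
```
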
29 | !!! warning
30 | It's recommended to apply loudness normalization to the dataset. You can use [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) to do this.
31 |
32 | ```bash
33 | fap loudness-norm data-raw data --clean
34 | ```
35 |
36 |
37 | ### 2. Batch extraction of semantic tokens
38 |
39 | Make sure you have downloaded the VQGAN weights. If not, run the following command:
40 |
41 | ```bash
42 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
43 | ```
44 |
45 | You can then run the following command to extract semantic tokens:
46 |
47 | ```bash
48 | python tools/vqgan/extract_vq.py data \
49 | --num-workers 1 --batch-size 16 \
50 | --config-name "firefly_gan_vq" \
51 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
52 | ```
53 |
54 | !!! note
55 | You can adjust `--num-workers` and `--batch-size` to increase extraction speed, but please make sure not to exceed your GPU memory limit.
56 | For the VITS format, you can specify a file list using `--filelist xxx.list`.
57 |
58 | This command will create `.npy` files in the `data` directory, as shown below:
59 |
60 | ```
61 | .
62 | ├── SPK1
63 | │ ├── 21.15-26.44.lab
64 | │ ├── 21.15-26.44.mp3
65 | │ ├── 21.15-26.44.npy
66 | │ ├── 27.51-29.98.lab
67 | │ ├── 27.51-29.98.mp3
68 | │ ├── 27.51-29.98.npy
69 | │ ├── 30.1-32.71.lab
70 | │ ├── 30.1-32.71.mp3
71 | │ └── 30.1-32.71.npy
72 | └── SPK2
73 | ├── 38.79-40.85.lab
74 | ├── 38.79-40.85.mp3
75 | └── 38.79-40.85.npy
76 | ```
77 |
78 | ### 3. Pack the dataset into protobuf
79 |
80 | ```bash
81 | python tools/llama/build_dataset.py \
82 | --input "data" \
83 | --output "data/protos" \
84 | --text-extension .lab \
85 | --num-workers 16
86 | ```
87 |
88 | After the command finishes executing, you should see the `quantized-dataset-ft.protos` file in the `data` directory.
89 |
90 | ### 4. Finally, fine-tuning with LoRA
91 |
92 | Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
93 |
94 | ```bash
95 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
96 | ```
97 |
98 | Finally, you can start the fine-tuning by running the following command:
99 |
100 | ```bash
101 | python fish_speech/train.py --config-name text2semantic_finetune \
102 | project=$project \
103 | +lora@model.model.lora_config=r_8_alpha_16
104 | ```
105 |
106 | !!! note
107 | You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`.
108 |
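Because training is launched through Hydra, such values can usually also be overridden on the command line instead of editing the YAML. The key paths below (`data.batch_size`, `trainer.accumulate_grad_batches`) are assumptions for illustration only; check `fish_speech/configs/text2semantic_finetune.yaml` for the actual names:

```bash
# Hypothetical override sketch -- verify the key names against text2semantic_finetune.yaml
python fish_speech/train.py --config-name text2semantic_finetune \
    project=$project \
    +lora@model.model.lora_config=r_8_alpha_16 \
    data.batch_size=4 \
    trainer.accumulate_grad_batches=2
```
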
109 | !!! note
110 | For Windows users, you can use `trainer.strategy.process_group_backend=gloo` to avoid `nccl` issues.
111 |
112 | After training is complete, you can refer to the [inference](inference.md) section to generate speech.
113 |
114 | !!! info
115 | By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability.
116 | If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting.
117 |
118 | After training, you need to convert the LoRA weights to regular weights before performing inference.
119 |
120 | ```bash
121 | python tools/llama/merge_lora.py \
122 | --lora-config r_8_alpha_16 \
123 | --base-weight checkpoints/fish-speech-1.5 \
124 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \
125 | --output checkpoints/fish-speech-1.5-yth-lora/
126 | ```
127 | !!! note
128 | You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data.
129 |
--------------------------------------------------------------------------------
/docs/en/start_agent.md:
--------------------------------------------------------------------------------
1 | # Start Agent
2 |
3 | ## Requirements
4 |
5 | - GPU memory: at least 8 GB (with quantization); 16 GB or more is recommended.
6 | - Disk usage: 10 GB
7 |
8 | ## Download Model
9 |
10 | You can download the model with:
11 |
12 | ```bash
13 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
14 | ```
15 |
16 | Put the downloaded files in the `checkpoints` folder.
17 |
18 | You also need the fish-speech model, which you can download by following the instructions in [inference](inference.md).
19 |
20 | Afterwards there should be two folders under `checkpoints`:
21 |
22 | `checkpoints/fish-speech-1.4` and `checkpoints/fish-agent-v0.1-3b`.
23 |
24 | ## Environment Preparation
25 |
26 | If you already have fish-speech set up, you can use it directly after installing one extra dependency:
27 | ```bash
28 | pip install cachetools
29 | ```
30 |
31 | !!! note
32 | Please use a Python version below 3.12 for compilation.
33 |
34 | If you don't have it yet, use the commands below to build your environment:
35 |
36 | ```bash
37 | sudo apt-get install portaudio19-dev
38 |
39 | pip install -e .[stable]
40 | ```
41 |
42 | ## Launch the Agent Demo
43 |
44 | To launch fish-agent, run the command below from the repository root:
45 |
46 | ```bash
47 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
48 | ```
49 |
50 | The `--compile` flag is only supported on Python < 3.12, and it greatly speeds up token generation.
51 |
52 | Keep in mind that compilation does not happen immediately at startup.
53 |
54 | Then open another terminal and use the command:
55 |
56 | ```bash
57 | python -m tools.e2e_webui
58 | ```
59 |
60 | This will create a Gradio WebUI on the device.
61 |
62 | The first time you use the model, it will spend a short while compiling (if `--compile` is enabled), so please be patient.
63 |
64 | ## Gradio Webui
65 |
66 |
67 |
68 |
69 | Have a good time!
70 |
71 | ## Performance
72 |
73 | In our tests, a laptop RTX 4060 just barely runs it and is very stretched, at only about 8 tokens/s. An RTX 4090 reaches around 95 tokens/s with compilation, which is what we recommend.
74 |
75 | # About Agent
76 |
77 | The demo is an early alpha test version: the inference speed still needs to be optimised, and there are many bugs waiting to be fixed. If you've found a bug or want to fix one, we'd be very happy to receive an issue or a pull request.
78 |
--------------------------------------------------------------------------------
/docs/ja/finetune.md:
--------------------------------------------------------------------------------
1 | # 微調整
2 |
3 | 明らかに、このページを開いたとき、few-shot 事前トレーニングモデルのパフォーマンスに満足していなかったことでしょう。データセット上でのパフォーマンスを向上させるためにモデルを微調整したいと考えています。
4 |
5 | 現在のバージョンでは、「LLAMA」部分のみを微調整する必要があります。
6 |
7 | ## LLAMAの微調整
8 | ### 1. データセットの準備
9 |
10 | ```
11 | .
12 | ├── SPK1
13 | │ ├── 21.15-26.44.lab
14 | │ ├── 21.15-26.44.mp3
15 | │ ├── 27.51-29.98.lab
16 | │ ├── 27.51-29.98.mp3
17 | │ ├── 30.1-32.71.lab
18 | │ └── 30.1-32.71.mp3
19 | └── SPK2
20 | ├── 38.79-40.85.lab
21 | └── 38.79-40.85.mp3
22 | ```
23 |
24 | データセットを上記の形式に変換し、「data」ディレクトリに配置する必要があります。音声ファイルの拡張子は「.mp3」、「.wav」、または「.flac」にすることができ、注釈ファイルの拡張子は「.lab」にする必要があります。
25 |
26 | !!! info
27 | 標準ファイル `.lab` には、音声の転写テキストのみを含め、特別なフォーマットは必要ありません。例えば、`hi.mp3` で「こんにちは、さようなら」と言っている場合、`hi.lab` ファイルには「こんにちは、さようなら」という一行のテキストを含めるだけです。
28 |
29 | !!! warning
30 | データセットにラウドネス正規化を適用することをお勧めします。これを行うには、[fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) を使用できます。
31 |
32 | ```bash
33 | fap loudness-norm data-raw data --clean
34 | ```
35 |
36 |
37 | ### 2. セマンティックトークンのバッチ抽出
38 |
39 | VQGANの重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
40 |
41 | ```bash
42 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
43 | ```
44 |
45 | 次に、次のコマンドを実行してセマンティックトークンを抽出できます。
46 |
47 | ```bash
48 | python tools/vqgan/extract_vq.py data \
49 | --num-workers 1 --batch-size 16 \
50 | --config-name "firefly_gan_vq" \
51 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
52 | ```
53 |
54 | !!! note
55 | `--num-workers` と `--batch-size` を調整して抽出速度を上げることができますが、GPUメモリの制限を超えないようにしてください。
56 | VITS形式の場合、`--filelist xxx.list` を使用してファイルリストを指定できます。
57 |
58 | このコマンドは、`data`ディレクトリに`.npy`ファイルを作成します。以下のように表示されます。
59 |
60 | ```
61 | .
62 | ├── SPK1
63 | │ ├── 21.15-26.44.lab
64 | │ ├── 21.15-26.44.mp3
65 | │ ├── 21.15-26.44.npy
66 | │ ├── 27.51-29.98.lab
67 | │ ├── 27.51-29.98.mp3
68 | │ ├── 27.51-29.98.npy
69 | │ ├── 30.1-32.71.lab
70 | │ ├── 30.1-32.71.mp3
71 | │ └── 30.1-32.71.npy
72 | └── SPK2
73 | ├── 38.79-40.85.lab
74 | ├── 38.79-40.85.mp3
75 | └── 38.79-40.85.npy
76 | ```
77 |
78 | ### 3. データセットをprotobufにパックする
79 |
80 | ```bash
81 | python tools/llama/build_dataset.py \
82 | --input "data" \
83 | --output "data/protos" \
84 | --text-extension .lab \
85 | --num-workers 16
86 | ```
87 |
88 | コマンドの実行が完了すると、`data`ディレクトリに`quantized-dataset-ft.protos`ファイルが表示されます。
89 |
90 | ### 4. 最後に、LoRAを使用して微調整する
91 |
92 | 同様に、`LLAMA`の重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
93 |
94 | ```bash
95 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
96 | ```
97 |
98 | 最後に、次のコマンドを実行して微調整を開始できます。
99 |
100 | ```bash
101 | python fish_speech/train.py --config-name text2semantic_finetune \
102 | project=$project \
103 | +lora@model.model.lora_config=r_8_alpha_16
104 | ```
105 |
106 | !!! note
107 | `fish_speech/configs/text2semantic_finetune.yaml` を変更して、`batch_size`、`gradient_accumulation_steps` などのトレーニングパラメータを変更し、GPUメモリに適合させることができます。
108 |
109 | !!! note
110 | Windowsユーザーの場合、`trainer.strategy.process_group_backend=gloo` を使用して `nccl` の問題を回避できます。
111 |
112 | トレーニングが完了したら、[推論](inference.md)セクションを参照し、音声を生成します。
113 |
114 | !!! info
115 | デフォルトでは、モデルは話者の発話パターンのみを学習し、音色は学習しません。音色の安定性を確保するためにプロンプトを使用する必要があります。
116 | 音色を学習したい場合は、トレーニングステップ数を増やすことができますが、これにより過学習が発生する可能性があります。
117 |
118 | トレーニングが完了したら、推論を行う前にLoRAの重みを通常の重みに変換する必要があります。
119 |
120 | ```bash
121 | python tools/llama/merge_lora.py \
122 | --lora-config r_8_alpha_16 \
123 | --base-weight checkpoints/fish-speech-1.5 \
124 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \
125 | --output checkpoints/fish-speech-1.5-yth-lora/
126 | ```
127 | !!! note
128 | 他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、分布外(OOD)データでより良いパフォーマンスを発揮します。
129 |
--------------------------------------------------------------------------------
/docs/ja/inference.md:
--------------------------------------------------------------------------------
1 | # 推論
2 |
3 | 推論は、コマンドライン、HTTP API、および Web UI をサポートしています。
4 |
5 | !!! note
6 | 全体として、推論は次のいくつかの部分で構成されています:
7 |
8 | 1. VQGANを使用して、与えられた約10秒の音声をエンコードします。
9 | 2. エンコードされたセマンティックトークンと対応するテキストを例として言語モデルに入力します。
10 | 3. 新しいテキストが与えられた場合、モデルに対応するセマンティックトークンを生成させます。
11 | 4. 生成されたセマンティックトークンをVITS / VQGANに入力してデコードし、対応する音声を生成します。
12 |
13 | ## モデルをダウンロード
14 | 必要な`vqgan`および`llama`モデルを Hugging Face リポジトリからダウンロードします。
15 |
16 | ```bash
17 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
18 | ```
19 |
20 | ## コマンドライン推論
21 | ### 1. 音声からプロンプトを生成する:
22 |
23 | !!! note
24 | モデルにランダムに音声の音色を選ばせる場合、このステップをスキップできます。
25 |
26 | !!! warning "将来のバージョンに関する警告"
27 | 元のパス(tools/vqgan/inference.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。
28 |
29 | ```bash
30 | python fish_speech/models/vqgan/inference.py \
31 | -i "paimon.wav" \
32 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
33 | ```
34 |
35 | `fake.npy`ファイルが生成されるはずです。
36 |
37 | ### 2. テキストからセマンティックトークンを生成する:
38 |
39 | !!! warning "将来のバージョンに関する警告"
40 | 元のパス(tools/llama/generate.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。
41 |
42 | ```bash
43 | python fish_speech/models/text2semantic/inference.py \
44 | --text "変換したいテキスト" \
45 | --prompt-text "参照テキスト" \
46 | --prompt-tokens "fake.npy" \
47 | --checkpoint-path "checkpoints/fish-speech-1.5" \
48 | --num-samples 2 \
49 | --compile
50 | ```
51 |
52 | このコマンドは、作業ディレクトリに`codes_N`ファイルを作成します。ここで、N は 0 から始まる整数です。
53 |
54 | !!! note
55 | `--compile`を使用して CUDA カーネルを融合し、より高速な推論を実現することができます(約 30 トークン/秒 -> 約 500 トークン/秒)。
56 | それに対応して、加速を使用しない場合は、`--compile`パラメータをコメントアウトできます。
57 |
58 | !!! info
59 | bf16 をサポートしていない GPU の場合、`--half`パラメータを使用する必要があるかもしれません。
60 |
61 | ### 3. セマンティックトークンから音声を生成する:
62 |
63 | #### VQGAN デコーダー
64 |
65 | !!! warning "将来のバージョンに関する警告"
66 | 元のパス(tools/vqgan/inference.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。
67 |
68 | ```bash
69 | python fish_speech/models/vqgan/inference.py \
70 | -i "codes_0.npy" \
71 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
72 | ```
73 |
74 | ## HTTP API 推論
75 |
76 | 推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます:
77 |
78 | ```bash
79 | python -m tools.api_server \
80 | --listen 0.0.0.0:8080 \
81 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
82 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
83 | --decoder-config-name firefly_gan_vq
84 | ```
85 |
86 | > 推論を高速化したい場合は、`--compile` パラメータを追加できます。
87 |
88 | その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。
89 |
90 | 以下は、`tools/api_client.py` を使用してリクエストを送信する例です。
91 |
92 | ```bash
93 | python -m tools.api_client \
94 | --text "入力するテキスト" \
95 | --reference_audio "参照音声へのパス" \
96 | --reference_text "参照音声テキスト" \
97 | --streaming True
98 | ```
99 |
100 | 上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。
101 |
102 | !!! info
103 | 使用可能なパラメータの詳細については、コマンド` python -m tools.api_client -h `を使用してください
104 |
105 | ## WebUI 推論
106 |
107 | 次のコマンドを使用して WebUI を起動できます:
108 |
109 | ```bash
110 | python -m tools.run_webui \
111 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
112 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
113 | --decoder-config-name firefly_gan_vq
114 | ```
115 | > 推論を高速化したい場合は、`--compile` パラメータを追加できます。
116 |
117 | !!! note
118 | ラベルファイルと参照音声ファイルをメインディレクトリの `references` フォルダ(自分で作成する必要があります)に事前に保存しておくことで、WebUI で直接呼び出すことができます。
119 |
120 | !!! note
121 | Gradio 環境変数(`GRADIO_SHARE`、`GRADIO_SERVER_PORT`、`GRADIO_SERVER_NAME`など)を使用して WebUI を構成できます。
122 |
123 | お楽しみください!
124 |
--------------------------------------------------------------------------------
/docs/ja/start_agent.md:
--------------------------------------------------------------------------------
1 | # エージェントの開始
2 |
3 | !!! note
4 | もしあなたがネイティブ・スピーカーで、翻訳に問題があるとお感じでしたら、issueかpull requestをお送りください!
5 |
6 | ## 要件
7 |
8 | - GPUメモリ: 最低8GB(量子化使用時)、16GB以上推奨
9 | - ディスク使用量: 10GB
10 |
11 | ## モデルのダウンロード
12 |
13 | 以下のコマンドでモデルを取得できます:
14 |
15 | ```bash
16 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
17 | ```
18 |
19 | これらを'checkpoints'フォルダに配置してください。
20 |
21 | また、[inference](inference.md)の手順に従ってfish-speechモデルもダウンロードする必要があります。
22 |
23 | checkpointsには2つのフォルダが必要です。
24 |
25 | `checkpoints/fish-speech-1.4`と`checkpoints/fish-agent-v0.1-3b`です。
26 |
27 | ## 環境準備
28 |
29 | すでにFish-speechをお持ちの場合は、以下の指示を追加するだけで直接使用できます:
30 | ```bash
31 | pip install cachetools
32 | ```
33 |
34 | !!! note
35 | コンパイルにはPythonバージョン3.12未満を使用してください。
36 |
37 | お持ちでない場合は、以下のコマンドで環境を構築してください:
38 |
39 | ```bash
40 | sudo apt-get install portaudio19-dev
41 |
42 | pip install -e .[stable]
43 | ```
44 |
45 | ## エージェントデモの起動
46 |
47 | fish-agentを構築するには、メインフォルダで以下のコマンドを使用してください:
48 |
49 | ```bash
50 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
51 | ```
52 |
53 | `--compile`引数はPython < 3.12でのみサポートされており、トークン生成を大幅に高速化します。
54 |
55 | 一度にコンパイルは行われません(覚えておいてください)。
56 |
57 | 次に、別のターミナルを開いて以下のコマンドを使用します:
58 |
59 | ```bash
60 | python -m tools.e2e_webui
61 | ```
62 |
63 | これにより、デバイス上にGradio WebUIが作成されます。
64 |
65 | モデルを初めて使用する際は、(`--compile`がTrueの場合)しばらくコンパイルが行われますので、お待ちください。
66 |
67 | ## Gradio Webui
68 |
69 |
70 |
71 |
72 | お楽しみください!
73 |
74 | ## パフォーマンス
75 |
76 | テストでは、4060搭載のラップトップではかろうじて動作しますが、非常に厳しい状態で、約8トークン/秒程度です。4090ではコンパイル時に約95トークン/秒で、これが推奨環境です。
77 |
78 | # エージェントについて
79 |
80 | このデモは初期アルファテストバージョンで、推論速度の最適化が必要で、修正を待つバグが多数あります。バグを発見した場合や修正したい場合は、issueやプルリクエストをいただけると大変嬉しく思います。
81 |
--------------------------------------------------------------------------------
/docs/ko/finetune.md:
--------------------------------------------------------------------------------
1 | # 파인튜닝
2 |
3 | 이 페이지를 열었다는 것은, 사전 학습된 퓨샷(Few-shot) 모델의 성능에 만족하지 못했다는 의미일 것입니다. 데이터셋의 성능을 향상시키기 위해 모델을 파인튜닝하고 싶으시겠죠.
4 |
5 | 현재 버전에서는 'LLAMA' 부분만 파인튜닝하시면 됩니다.
6 |
7 | ## LLAMA 파인튜닝
8 | ### 1. 데이터셋 준비
9 |
10 | ```
11 | .
12 | ├── SPK1
13 | │ ├── 21.15-26.44.lab
14 | │ ├── 21.15-26.44.mp3
15 | │ ├── 27.51-29.98.lab
16 | │ ├── 27.51-29.98.mp3
17 | │ ├── 30.1-32.71.lab
18 | │ └── 30.1-32.71.mp3
19 | └── SPK2
20 | ├── 38.79-40.85.lab
21 | └── 38.79-40.85.mp3
22 | ```
23 |
24 | 위와 같은 형식으로 데이터셋을 변환하여 `data` 디렉토리 안에 배치하세요. 오디오 파일의 확장자는 `.mp3`, `.wav`, `.flac` 중 하나여야 하며, 주석 파일은 `.lab` 확장자를 사용해야 합니다.
25 |
26 | !!! info "데이터셋 형식"
27 | `.lab` 주석 파일은 오디오의 전사 내용만 포함하면 되며, 특별한 형식이 필요하지 않습니다. 예를 들어, `hi.mp3`에서 "Hello, goodbye"라는 대사를 말한다면, `hi.lab` 파일에는 "Hello, goodbye"라는 한 줄의 텍스트만 있어야 합니다.
28 |
29 | !!! warning
30 | 데이터셋에 대한 음량 정규화(loudness normalization)를 적용하는 것이 좋습니다. 이를 위해 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess)를 사용할 수 있습니다.
31 |
32 | ```bash
33 | fap loudness-norm data-raw data --clean
34 | ```
35 |
36 | ### 2. 시맨틱 토큰 배치 추출
37 |
38 | VQGAN 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요:
39 |
40 | ```bash
41 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
42 | ```
43 |
44 | 이후 시맨틱 토큰을 추출하기 위해 아래 명령어를 실행하세요:
45 |
46 | ```bash
47 | python tools/vqgan/extract_vq.py data \
48 | --num-workers 1 --batch-size 16 \
49 | --config-name "firefly_gan_vq" \
50 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
51 | ```
52 |
53 | !!! note
54 | 추출 속도를 높이기 위해 `--num-workers`와 `--batch-size` 값을 조정할 수 있지만, GPU 메모리 한도를 초과하지 않도록 주의하세요.
55 | VITS 형식의 경우, `--filelist xxx.list`를 사용하여 파일 목록을 지정할 수 있습니다.
56 |
57 | 이 명령을 실행하면 `data` 디렉토리 안에 `.npy` 파일이 생성됩니다. 다음과 같이 표시됩니다:
58 |
59 | ```
60 | .
61 | ├── SPK1
62 | │ ├── 21.15-26.44.lab
63 | │ ├── 21.15-26.44.mp3
64 | │ ├── 21.15-26.44.npy
65 | │ ├── 27.51-29.98.lab
66 | │ ├── 27.51-29.98.mp3
67 | │ ├── 27.51-29.98.npy
68 | │ ├── 30.1-32.71.lab
69 | │ ├── 30.1-32.71.mp3
70 | │ └── 30.1-32.71.npy
71 | └── SPK2
72 | ├── 38.79-40.85.lab
73 | ├── 38.79-40.85.mp3
74 | └── 38.79-40.85.npy
75 | ```
76 |
77 | ### 3. 데이터셋을 protobuf로 패킹
78 |
79 | ```bash
80 | python tools/llama/build_dataset.py \
81 | --input "data" \
82 | --output "data/protos" \
83 | --text-extension .lab \
84 | --num-workers 16
85 | ```
86 |
87 | 명령이 완료되면 `data` 디렉토리 안에 `quantized-dataset-ft.protos` 파일이 생성됩니다.
88 |
89 | ### 4. 마지막으로, LoRA를 이용한 파인튜닝
90 |
91 | 마찬가지로, `LLAMA` 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요:
92 |
93 | ```bash
94 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
95 | ```
96 |
97 | 마지막으로, 아래 명령어를 실행하여 파인튜닝을 시작할 수 있습니다:
98 |
99 | ```bash
100 | python fish_speech/train.py --config-name text2semantic_finetune \
101 | project=$project \
102 | +lora@model.model.lora_config=r_8_alpha_16
103 | ```
104 |
105 | !!! note
106 | `batch_size`, `gradient_accumulation_steps` 등의 학습 매개변수를 GPU 메모리에 맞게 조정하려면 `fish_speech/configs/text2semantic_finetune.yaml` 파일을 수정할 수 있습니다.
107 |
108 | !!! note
109 | Windows 사용자의 경우, `nccl` 문제를 피하려면 `trainer.strategy.process_group_backend=gloo`를 사용할 수 있습니다.
110 |
111 | 훈련이 완료되면 [추론](inference.md) 섹션을 참고하여 음성을 생성할 수 있습니다.
112 |
113 | !!! info
114 | 기본적으로 모델은 화자의 말하는 패턴만 학습하고 음색은 학습하지 않습니다. 음색의 안정성을 위해 프롬프트를 사용해야 합니다.
115 | 음색을 학습하려면 훈련 단계를 늘릴 수 있지만, 이는 과적합의 위험을 초래할 수 있습니다.
116 |
117 | 훈련이 끝나면 LoRA 가중치를 일반 가중치로 변환한 후에 추론을 수행해야 합니다.
118 |
119 | ```bash
120 | python tools/llama/merge_lora.py \
121 | --lora-config r_8_alpha_16 \
122 | --base-weight checkpoints/fish-speech-1.5 \
123 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \
124 | --output checkpoints/fish-speech-1.5-yth-lora/
125 | ```
126 |
127 | !!! note
128 | 다른 체크포인트도 시도해 볼 수 있습니다. 요구 사항에 맞는 가장 초기 체크포인트를 사용하는 것이 좋습니다. 이들은 종종 분포 밖(OOD) 데이터에서 더 좋은 성능을 발휘합니다.
129 |
--------------------------------------------------------------------------------
/docs/ko/inference.md:
--------------------------------------------------------------------------------
1 | # 추론
2 |
3 | 추론은 명령줄, HTTP API, 그리고 웹 UI에서 지원됩니다.
4 |
5 | !!! note
6 | 전체 추론 과정은 다음의 여러 단계로 구성됩니다:
7 |
8 | 1. VQGAN을 사용하여 약 10초 분량의 음성을 인코딩합니다.
9 | 2. 인코딩된 시맨틱 토큰과 해당 텍스트를 예시로 언어 모델에 입력합니다.
10 | 3. 새로운 텍스트를 입력하면, 모델이 해당하는 시맨틱 토큰을 생성합니다.
11 | 4. 생성된 시맨틱 토큰을 VITS / VQGAN에 입력하여 음성을 디코딩하고 생성합니다.
12 |
13 | ## 모델 다운로드
14 | 필요한 `vqgan` 및 `llama` 모델을 Hugging Face 리포지토리에서 다운로드하세요.
15 |
16 | ```bash
17 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
18 | ```
19 |
20 | ## 명령줄 추론
21 | ### 1. 음성에서 프롬프트 생성:
22 |
23 | !!! note
24 | 모델이 음색을 무작위로 선택하도록 하려면 이 단계를 건너뛸 수 있습니다.
25 |
26 | !!! warning "향후 버전 경고"
27 | 원래 경로(tools/vqgan/inference.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오.
28 |
29 | ```bash
30 | python fish_speech/models/vqgan/inference.py \
31 | -i "paimon.wav" \
32 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
33 | ```
34 |
35 | 이 명령을 실행하면 `fake.npy` 파일을 얻게 됩니다.
36 |
37 | ### 2. 텍스트에서 시맨틱 토큰 생성:
38 |
39 | !!! warning "향후 버전 경고"
40 | 원래 경로(tools/llama/generate.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오.
41 |
42 | ```bash
43 | python fish_speech/models/text2semantic/inference.py \
44 | --text "변환할 텍스트" \
45 | --prompt-text "참고할 텍스트" \
46 | --prompt-tokens "fake.npy" \
47 | --checkpoint-path "checkpoints/fish-speech-1.5" \
48 | --num-samples 2 \
49 | --compile
50 | ```
51 |
52 | 이 명령을 실행하면 작업 디렉토리에 `codes_N` 파일이 생성되며, N은 0부터 시작하는 정수입니다.
53 |
54 | !!! note
55 | 빠른 추론을 위해 `--compile` 옵션을 사용하여 CUDA 커널을 결합할 수 있습니다 (~초당 30 토큰 -> ~초당 500 토큰).
56 | `--compile` 매개변수를 주석 처리하여 가속화 옵션을 사용하지 않을 수도 있습니다.
57 |
58 | !!! info
59 | bf16을 지원하지 않는 GPU의 경우 `--half` 매개변수를 사용해야 할 수 있습니다.
60 |
61 | ### 3. 시맨틱 토큰에서 음성 생성:
62 |
63 | #### VQGAN 디코더
64 |
65 | !!! warning "향후 버전 경고"
66 | 원래 경로(tools/vqgan/inference.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오.
67 |
68 | ```bash
69 | python fish_speech/models/vqgan/inference.py \
70 | -i "codes_0.npy" \
71 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
72 | ```
73 |
74 | ## HTTP API 추론
75 |
76 | 추론을 위한 HTTP API를 제공하고 있습니다. 아래의 명령어로 서버를 시작할 수 있습니다:
77 |
78 | ```bash
79 | python -m tools.api_server \
80 | --listen 0.0.0.0:8080 \
81 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
82 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
83 | --decoder-config-name firefly_gan_vq
84 | ```
85 |
86 | 추론 속도를 높이고 싶다면 `--compile` 매개변수를 추가할 수 있습니다.
87 |
88 | 이후, http://127.0.0.1:8080/ 에서 API를 확인하고 테스트할 수 있습니다.
89 |
90 | 아래는 `tools/api_client.py`를 사용하여 요청을 보내는 예시입니다.
91 |
92 | ```bash
93 | python -m tools.api_client \
94 | --text "입력할 텍스트" \
95 | --reference_audio "참고 음성 경로" \
96 | --reference_text "참고 음성의 텍스트 내용" \
97 | --streaming True
98 | ```
99 |
100 | 위 명령은 참고 음성 정보를 바탕으로 원하는 음성을 합성하고, 스트리밍 방식으로 반환합니다.
101 |
102 | 다음 예시는 여러 개의 참고 음성 경로와 텍스트를 한꺼번에 사용할 수 있음을 보여줍니다. 명령에서 공백으로 구분하여 입력합니다.
103 |
104 | ```bash
105 | python -m tools.api_client \
106 | --text "입력할 텍스트" \
107 | --reference_audio "참고 음성 경로1" "참고 음성 경로2" \
108 | --reference_text "참고 음성 텍스트1" "참고 음성 텍스트2"\
109 | --streaming False \
110 | --output "generated" \
111 | --format "mp3"
112 | ```
113 |
114 | 위 명령어는 여러 참고 음성 정보를 바탕으로 `MP3` 형식의 음성을 합성하여, 현재 디렉토리에 `generated.mp3`로 저장합니다.
115 |
116 | `--reference_audio`와 `--reference_text` 대신에 `--reference_id`(하나만 사용 가능)를 사용할 수 있습니다. 프로젝트 루트 디렉토리에 `references/` 폴더를 만들어 해당 음성과 주석 텍스트를 넣어야 합니다. 참고 음성은 최대 90초까지 지원됩니다.
117 |
118 | !!! info
119 | 제공되는 파라미터는 `python -m tools.api_client -h`를 사용하여 확인할 수 있습니다.
120 |
121 | ## GUI 추론
122 | [클라이언트 다운로드](https://github.com/AnyaCoder/fish-speech-gui/releases)
123 |
124 | ## WebUI 추론
125 |
126 | 다음 명령으로 WebUI를 시작할 수 있습니다:
127 |
128 | ```bash
129 | python -m tools.run_webui \
130 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
131 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
132 | --decoder-config-name firefly_gan_vq
133 | ```
134 |
135 | > 추론 속도를 높이고 싶다면 `--compile` 매개변수를 추가할 수 있습니다.
136 |
137 | !!! note
138 | 라벨 파일과 참고 음성 파일을 미리 메인 디렉토리의 `references` 폴더에 저장해 두면, WebUI에서 바로 호출할 수 있습니다. (해당 폴더는 직접 생성해야 합니다.)
139 |
140 | !!! note
141 | WebUI를 구성하기 위해 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`과 같은 Gradio 환경 변수를 사용할 수 있습니다.
142 |
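예를 들어, 아래는 Gradio 환경 변수로 WebUI의 주소와 포트를 지정하는 간단한 스케치입니다. 포트 `7860`은 임의로 정한 예시 값입니다.

```bash
# 포트 7860은 예시 값입니다.
GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 python -m tools.run_webui \
    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
    --decoder-config-name firefly_gan_vq
```
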
143 | 즐기세요!
144 |
--------------------------------------------------------------------------------
/docs/ko/start_agent.md:
--------------------------------------------------------------------------------
1 | # 에이전트 시작하기
2 |
3 | !!! note
4 |     전체 문서는 Claude 3.5 Sonnet으로 번역되었습니다. 원어민이시고 번역에 문제가 있다고 생각되시면, 이슈나 풀 리퀘스트를 보내주시면 대단히 감사하겠습니다!
5 |
6 | ## 요구사항
7 |
8 | - GPU 메모리: 최소 8GB(양자화 사용 시), 16GB 이상 권장
9 | - 디스크 사용량: 10GB
10 |
11 | ## 모델 다운로드
12 |
13 | 다음 명령어로 모델을 받을 수 있습니다:
14 |
15 | ```bash
16 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
17 | ```
18 |
19 | 'checkpoints' 폴더에 파일들을 넣으세요.
20 |
21 | 또한 [inference](inference.md)에 설명된 대로 fish-speech 모델도 다운로드해야 합니다.
22 |
23 | checkpoints에는 2개의 폴더가 있어야 합니다.
24 |
25 | `checkpoints/fish-speech-1.4`와 `checkpoints/fish-agent-v0.1-3b`입니다.
26 |
27 | ## 환경 준비
28 |
29 | 이미 Fish-Speech 환경이 있다면, 아래 패키지를 추가로 설치한 뒤 바로 사용할 수 있습니다:
30 | ```bash
31 | pip install cachetools
32 | ```
33 |
34 | !!! note
35 | 컴파일을 위해 Python 3.12 미만 버전을 사용해 주세요.
36 |
37 | 없다면 아래 명령어를 사용하여 환경을 구축하세요:
38 |
39 | ```bash
40 | sudo apt-get install portaudio19-dev
41 |
42 | pip install -e .[stable]
43 | ```
44 |
45 | ## 에이전트 데모 실행
46 |
47 | fish-agent를 구축하려면 메인 폴더에서 아래 명령어를 사용하세요:
48 |
49 | ```bash
50 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
51 | ```
52 |
53 | `--compile` 인자는 Python < 3.12에서만 지원되며, 토큰 생성 속도를 크게 향상시킵니다.
54 |
55 | 컴파일은 한 번에 끝나지 않으므로, 시간이 걸린다는 점을 기억해 두세요.
56 |
57 | 그런 다음 다른 터미널을 열고 다음 명령어를 사용하세요:
58 |
59 | ```bash
60 | python -m tools.e2e_webui
61 | ```
62 |
63 | 이렇게 하면 기기에 Gradio WebUI가 생성됩니다.
64 |
65 | 모델을 처음 사용할 때는 (`--compile`이 True인 경우) 잠시 컴파일이 진행되므로 기다려 주세요.
66 |
67 | ## Gradio Webui
68 |
69 |
70 |
71 |
72 | 즐거운 시간 되세요!
73 |
74 | ## 성능
75 |
76 | 테스트 결과, 4060 노트북은 겨우 실행되며 매우 부하가 큰 상태로, 초당 약 8토큰 정도만 처리합니다. 4090은 컴파일 상태에서 초당 약 95토큰을 처리하며, 이것이 저희가 권장하는 사양입니다.
77 |
78 | # 에이전트 소개
79 |
80 | 이 데모는 초기 알파 테스트 버전으로, 추론 속도 최적화가 필요하며 수정해야 할 버그가 많이 있습니다. 버그를 발견하거나 수정하고 싶으시다면 이슈나 풀 리퀘스트를 보내주시면 매우 감사하겠습니다.
81 |
--------------------------------------------------------------------------------
/docs/pt/finetune.md:
--------------------------------------------------------------------------------
1 | # Ajuste Fino
2 |
3 | É óbvio que ao abrir esta página, você não deve estar muito satisfeito com o desempenho do modelo pré-treinado com poucos exemplos. Você pode querer ajustar o modelo para melhorar seu desempenho em seu conjunto de dados.
4 |
5 | Na atual versão, a única coisa que você precisa ajustar é a parte do 'LLAMA'.
6 |
7 | ## Ajuste Fino do LLAMA
8 | ### 1. Preparando o conjunto de dados
9 |
10 | ```
11 | .
12 | ├── SPK1
13 | │ ├── 21.15-26.44.lab
14 | │ ├── 21.15-26.44.mp3
15 | │ ├── 27.51-29.98.lab
16 | │ ├── 27.51-29.98.mp3
17 | │ ├── 30.1-32.71.lab
18 | │ └── 30.1-32.71.mp3
19 | └── SPK2
20 | ├── 38.79-40.85.lab
21 | └── 38.79-40.85.mp3
22 | ```
23 |
24 | Você precisa converter seu conjunto de dados para o formato acima e colocá-lo em `data`. O arquivo de áudio pode ter as extensões `.mp3`, `.wav` ou `.flac`, e o arquivo de anotação deve ter a extensão `.lab`.
25 |
26 | !!! info
27 | O arquivo de anotação `.lab` deve conter apenas a transcrição do áudio, sem a necessidade de formatação especial. Por exemplo, se o arquivo `hi.mp3` disser "Olá, tchau", o arquivo `hi.lab` conterá uma única linha de texto: "Olá, tchau".
28 |
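A título de ilustração, o trecho abaixo cria essa estrutura para um único arquivo (o nome da pasta `SPK1`, o nome do arquivo e os caminhos são apenas exemplos hipotéticos):

```bash
# caminhos e nomes de arquivo hipotéticos
mkdir -p data/SPK1
cp /caminho/para/21.15-26.44.mp3 data/SPK1/
# o .lab contém somente a transcrição do áudio, em uma única linha
echo "Olá, tchau." > data/SPK1/21.15-26.44.lab
```
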
29 | !!! warning
30 | É recomendado aplicar normalização de volume ao conjunto de dados. Você pode usar o [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) para fazer isso.
31 |
32 | ```bash
33 | fap loudness-norm data-raw data --clean
34 | ```
35 |
36 |
37 | ### 2. Extração em lote de tokens semânticos
38 |
39 | Certifique-se de ter baixado os pesos do VQGAN. Se não, execute o seguinte comando:
40 |
41 | ```bash
42 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
43 | ```
44 |
45 | Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos:
46 |
47 | ```bash
48 | python tools/vqgan/extract_vq.py data \
49 | --num-workers 1 --batch-size 16 \
50 | --config-name "firefly_gan_vq" \
51 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
52 | ```
53 |
54 | !!! note
55 | Você pode ajustar `--num-workers` e `--batch-size` para aumentar a velocidade de extração, mas certifique-se de não exceder o limite de memória da sua GPU.
56 | Para o formato VITS, você pode especificar uma lista de arquivos usando `--filelist xxx.list`.
57 |
58 | Este comando criará arquivos `.npy` no diretório `data`, como mostrado abaixo:
59 |
60 | ```
61 | .
62 | ├── SPK1
63 | │ ├── 21.15-26.44.lab
64 | │ ├── 21.15-26.44.mp3
65 | │ ├── 21.15-26.44.npy
66 | │ ├── 27.51-29.98.lab
67 | │ ├── 27.51-29.98.mp3
68 | │ ├── 27.51-29.98.npy
69 | │ ├── 30.1-32.71.lab
70 | │ ├── 30.1-32.71.mp3
71 | │ └── 30.1-32.71.npy
72 | └── SPK2
73 | ├── 38.79-40.85.lab
74 | ├── 38.79-40.85.mp3
75 | └── 38.79-40.85.npy
76 | ```
77 |
78 | ### 3. Empacotar o conjunto de dados em protobuf
79 |
80 | ```bash
81 | python tools/llama/build_dataset.py \
82 | --input "data" \
83 | --output "data/protos" \
84 | --text-extension .lab \
85 | --num-workers 16
86 | ```
87 |
88 | Após executar o comando, você deverá ver o arquivo `quantized-dataset-ft.protos` no diretório `data`.
89 |
90 | ### 4. E finalmente, chegamos ao ajuste fino com LoRA
91 |
92 | Da mesma forma, certifique-se de ter baixado os pesos do `LLAMA`. Se não, execute o seguinte comando:
93 |
94 | ```bash
95 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
96 | ```
97 |
98 | E então, execute o seguinte comando para iniciar o ajuste fino:
99 |
100 | ```bash
101 | python fish_speech/train.py --config-name text2semantic_finetune \
102 | project=$project \
103 | +lora@model.model.lora_config=r_8_alpha_16
104 | ```
105 |
106 | !!! note
107 | Se quiser, você pode modificar os parâmetros de treinamento, como `batch_size`, `gradient_accumulation_steps`, etc., para se ajustar à memória da sua GPU, modificando `fish_speech/configs/text2semantic_finetune.yaml`.
108 |
109 | !!! note
110 | Para usuários do Windows, é recomendado usar `trainer.strategy.process_group_backend=gloo` para evitar problemas com `nccl`.
111 |
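Como o treinamento usa o Hydra, esses parâmetros também podem ser sobrescritos direto na linha de comando, sem editar o YAML. Abaixo há apenas um esboço com valores ilustrativos (`data.batch_size` e `trainer.accumulate_grad_batches` são as chaves correspondentes em `text2semantic_finetune.yaml`; a última linha só é necessária no Windows):

```bash
# valores apenas ilustrativos
python fish_speech/train.py --config-name text2semantic_finetune \
    project=$project \
    +lora@model.model.lora_config=r_8_alpha_16 \
    data.batch_size=2 \
    trainer.accumulate_grad_batches=2 \
    trainer.strategy.process_group_backend=gloo
```
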
112 | Após concluir o treinamento, consulte a seção [inferência](inference.md).
113 |
114 | !!! info
115 | Por padrão, o modelo aprenderá apenas os padrões de fala do orador e não o timbre. Ainda pode ser preciso usar prompts para garantir a estabilidade do timbre.
116 | Se quiser que ele aprenda o timbre, aumente o número de etapas de treinamento, mas isso pode levar ao overfitting (sobreajuste).
117 |
118 | Após o treinamento, é preciso converter os pesos do LoRA em pesos regulares antes de realizar a inferência.
119 |
120 | ```bash
121 | python tools/llama/merge_lora.py \
122 | --lora-config r_8_alpha_16 \
123 | --base-weight checkpoints/fish-speech-1.5 \
124 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \
125 | --output checkpoints/fish-speech-1.5-yth-lora/
126 | ```
127 | !!! note
128 | É possível também tentar outros checkpoints. Sugerimos usar o checkpoint que melhor atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora da distribuição (OOD).
129 |
--------------------------------------------------------------------------------
/docs/pt/start_agent.md:
--------------------------------------------------------------------------------
1 | # Iniciar Agente
2 |
3 | !!! note
4 |     Todo o documento foi traduzido pelo Claude 3.5 Sonnet. Se você for um falante nativo e achar a tradução problemática, ficaremos muito gratos se nos enviar uma issue ou um pull request!
5 |
6 | ## Requisitos
7 |
8 | - Memória GPU: No mínimo 8GB (com quantização), 16GB ou mais é recomendado.
9 | - Uso de disco: 10GB
10 |
11 | ## Download do Modelo
12 |
13 | Você pode obter o modelo através de:
14 |
15 | ```bash
16 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
17 | ```
18 |
19 | Coloque-os na pasta 'checkpoints'.
20 |
21 | Você também precisará do modelo fish-speech que pode ser baixado seguindo as instruções em [inference](inference.md).
22 |
23 | Então haverá 2 pastas em checkpoints.
24 |
25 | O `checkpoints/fish-speech-1.4` e `checkpoints/fish-agent-v0.1-3b`
26 |
27 | ## Preparação do Ambiente
28 |
29 | Se você já tem o Fish-speech, pode usar diretamente adicionando a seguinte instrução:
30 | ```bash
31 | pip install cachetools
32 | ```
33 |
34 | !!! note
35 | Por favor, use a versão Python abaixo de 3.12 para compilação.
36 |
37 | Se você não tem, use os comandos abaixo para construir seu ambiente:
38 |
39 | ```bash
40 | sudo apt-get install portaudio19-dev
41 |
42 | pip install -e .[stable]
43 | ```
44 |
45 | ## Iniciar a Demo do Agente
46 |
47 | Para construir o fish-agent, use o comando abaixo na pasta principal:
48 |
49 | ```bash
50 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
51 | ```
52 |
53 | O argumento `--compile` só suporta Python < 3.12, o que aumentará muito a velocidade de geração de tokens.
54 |
55 | A compilação não acontece de uma só vez; lembre-se de que ela leva algum tempo.
56 |
57 | Então abra outro terminal e use o comando:
58 |
59 | ```bash
60 | python -m tools.e2e_webui
61 | ```
62 |
63 | Isso criará uma WebUI Gradio no dispositivo.
64 |
65 | Quando você usar o modelo pela primeira vez, ele irá compilar (se `--compile` estiver True) por um curto período, então aguarde com paciência.
66 |
67 | ## Gradio Webui
68 |
69 |
70 |
71 |
72 | Divirta-se!
73 |
74 | ## Desempenho
75 |
76 | Em nossos testes, um laptop com 4060 mal consegue rodar, ficando muito sobrecarregado, gerando apenas cerca de 8 tokens/s. A 4090 gera cerca de 95 tokens/s com compilação, que é o que recomendamos.
77 |
78 | # Sobre o Agente
79 |
80 | A demo é uma versão alpha inicial de teste, a velocidade de inferência precisa ser otimizada, e há muitos bugs aguardando correção. Se você encontrou um bug ou quer corrigi-lo, ficaremos muito felizes em receber uma issue ou um pull request.
81 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | mkdocs-material
2 | mkdocs-static-i18n[material]
3 | mkdocs[i18n]
4 |
--------------------------------------------------------------------------------
/docs/stylesheets/extra.css:
--------------------------------------------------------------------------------
1 | .md-grid {
2 | max-width: 1440px;
3 | }
4 |
--------------------------------------------------------------------------------
/docs/zh/finetune.md:
--------------------------------------------------------------------------------
1 | # 微调
2 |
3 | 显然, 当你打开这个页面的时候, 你已经对预训练模型 zero-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好.
4 |
5 | 在目前版本,你只需要微调'LLAMA'部分即可.
6 |
7 | ## LLAMA 微调
8 | ### 1. 准备数据集
9 |
10 | ```
11 | .
12 | ├── SPK1
13 | │ ├── 21.15-26.44.lab
14 | │ ├── 21.15-26.44.mp3
15 | │ ├── 27.51-29.98.lab
16 | │ ├── 27.51-29.98.mp3
17 | │ ├── 30.1-32.71.lab
18 | │ └── 30.1-32.71.mp3
19 | └── SPK2
20 | ├── 38.79-40.85.lab
21 | └── 38.79-40.85.mp3
22 | ```
23 |
24 | 你需要将数据集转为以上格式, 并放到 `data` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀建议为 `.lab`.
25 |
26 | !!! info
27 | 标注文件 `.lab` 仅需包含音频的转写文本,无需遵循特殊格式要求。例如,如果 `hi.mp3` 中的内容是“你好,再见。”,那么 `hi.lab` 文件中只需包含一行文本:“你好,再见”。
28 |
29 | !!! warning
30 | 建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤.
31 | ```bash
32 | fap loudness-norm data-raw data --clean
33 | ```
34 |
35 | ### 2. 批量提取语义 token
36 |
37 | 确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:
38 |
39 | ```bash
40 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
41 | ```
42 |
43 | 对于中国大陆用户, 可使用 mirror 下载.
44 |
45 | ```bash
46 | HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
47 | ```
48 |
49 | 随后可运行以下命令来提取语义 token:
50 |
51 | ```bash
52 | python tools/vqgan/extract_vq.py data \
53 | --num-workers 1 --batch-size 16 \
54 | --config-name "firefly_gan_vq" \
55 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
56 | ```
57 |
58 | !!! note
59 | 你可以调整 `--num-workers` 和 `--batch-size` 来提高提取速度, 但是请注意不要超过你的显存限制.
60 |
61 | 该命令会在 `data` 目录下创建 `.npy` 文件, 如下所示:
62 |
63 | ```
64 | .
65 | ├── SPK1
66 | │ ├── 21.15-26.44.lab
67 | │ ├── 21.15-26.44.mp3
68 | │ ├── 21.15-26.44.npy
69 | │ ├── 27.51-29.98.lab
70 | │ ├── 27.51-29.98.mp3
71 | │ ├── 27.51-29.98.npy
72 | │ ├── 30.1-32.71.lab
73 | │ ├── 30.1-32.71.mp3
74 | │ └── 30.1-32.71.npy
75 | └── SPK2
76 | ├── 38.79-40.85.lab
77 | ├── 38.79-40.85.mp3
78 | └── 38.79-40.85.npy
79 | ```
80 |
81 | ### 3. 打包数据集为 protobuf
82 |
83 | ```bash
84 | python tools/llama/build_dataset.py \
85 | --input "data" \
86 | --output "data/protos" \
87 | --text-extension .lab \
88 | --num-workers 16
89 | ```
90 |
91 | 命令执行完毕后, 你应该能在 `data` 目录下看到 `protos` 文件.
92 |
93 |
94 | ### 4. 最后, 使用 LoRA 进行微调
95 |
96 | 同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令:
97 |
98 | ```bash
99 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
100 | ```
101 |
102 | 对于中国大陆用户, 可使用 mirror 下载.
103 |
104 | ```bash
105 | HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
106 | ```
107 |
108 | 最后, 你可以运行以下命令来启动微调:
109 |
110 | ```bash
111 | python fish_speech/train.py --config-name text2semantic_finetune \
112 | project=$project \
113 | +lora@model.model.lora_config=r_8_alpha_16
114 | ```
115 |
116 | !!! note
117 | 你可以通过修改 `fish_speech/configs/text2semantic_finetune.yaml` 来修改训练参数如 `batch_size`, `gradient_accumulation_steps` 等, 来适应你的显存.
118 |
119 | !!! note
120 | 对于 Windows 用户, 你可以使用 `trainer.strategy.process_group_backend=gloo` 来避免 `nccl` 的问题.
121 |
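由于训练入口基于 Hydra, 这些参数也可以直接在命令行覆盖而无需修改 YAML。下面只是一个示意写法, 数值仅供参考(`data.batch_size` 和 `trainer.accumulate_grad_batches` 对应 `text2semantic_finetune.yaml` 中的键; 最后一行仅 Windows 用户需要):

```bash
# 数值仅为示例
python fish_speech/train.py --config-name text2semantic_finetune \
    project=$project \
    +lora@model.model.lora_config=r_8_alpha_16 \
    data.batch_size=2 \
    trainer.accumulate_grad_batches=2 \
    trainer.strategy.process_group_backend=gloo
```
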
122 | 训练结束后, 你可以参考 [推理](inference.md) 部分来测试你的模型.
123 |
124 | !!! info
125 | 默认配置下, 基本只会学到说话人的发音方式, 而不包含音色, 你依然需要使用 prompt 来保证音色的稳定性.
126 | 如果你想要学到音色, 请将训练步数调大, 但这有可能会导致过拟合.
127 |
128 | 训练完成后, 你需要先将 LoRA 的权重转为普通权重, 然后再进行推理.
129 |
130 | ```bash
131 | python tools/llama/merge_lora.py \
132 | --lora-config r_8_alpha_16 \
133 | --base-weight checkpoints/fish-speech-1.5 \
134 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \
135 | --output checkpoints/fish-speech-1.5-yth-lora/
136 | ```
137 |
138 | !!! note
139 | 你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好.
140 |
--------------------------------------------------------------------------------
/docs/zh/inference.md:
--------------------------------------------------------------------------------
1 | # 推理
2 |
3 | 推理支持命令行, http api, 以及 webui 三种方式.
4 |
5 | !!! note
6 | 总的来说, 推理分为几个部分:
7 |
8 | 1. 给定一段 ~10 秒的语音, 将它用 VQGAN 编码.
9 | 2. 将编码后的语义 token 和对应文本输入语言模型作为例子.
10 | 3. 给定一段新文本, 让模型生成对应的语义 token.
11 | 4. 将生成的语义 token 输入 VQGAN 解码, 生成对应的语音.
12 |
13 | ## 下载模型
14 | 从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。
15 |
16 | ```bash
17 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
18 | ```
19 |
20 | 对于中国大陆用户,可使用 mirror 下载。
21 |
22 | ```bash
23 | HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
24 | ```
25 |
26 | ## 命令行推理
27 | ### 1. 从语音生成 prompt:
28 |
29 | !!! note
30 | 如果你打算让模型随机选择音色, 你可以跳过这一步.
31 |
32 | ```bash
33 | python fish_speech/models/vqgan/inference.py \
34 | -i "paimon.wav" \
35 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
36 | ```
37 |
38 | 你应该能得到一个 `fake.npy` 文件.
39 |
40 | ### 2. 从文本生成语义 token:
41 |
42 | ```bash
43 | python fish_speech/models/text2semantic/inference.py \
44 | --text "要转换的文本" \
45 | --prompt-text "你的参考文本" \
46 | --prompt-tokens "fake.npy" \
47 | --checkpoint-path "checkpoints/fish-speech-1.5" \
48 | --num-samples 2 \
49 | --compile
50 | ```
51 |
52 | 该命令会在工作目录下创建 `codes_N` 文件, 其中 N 是从 0 开始的整数.
53 |
54 | !!! note
55 | 您可能希望使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 个 token/秒 -> ~500 个 token/秒).
56 | 对应的, 如果你不打算使用加速, 你可以注释掉 `--compile` 参数.
57 |
58 | !!! info
59 | 对于不支持 bf16 的 GPU, 你可能需要使用 `--half` 参数.
60 |
61 | ### 3. 从语义 token 生成人声:
62 |
63 | #### VQGAN 解码
64 |
65 | ```bash
66 | python fish_speech/models/vqgan/inference.py \
67 | -i "codes_0.npy" \
68 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
69 | ```
70 |
71 | ## HTTP API 推理
72 |
73 | 运行以下命令来启动 HTTP 服务:
74 |
75 | ```bash
76 | python -m tools.api_server \
77 | --listen 0.0.0.0:8080 \
78 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
79 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
80 | --decoder-config-name firefly_gan_vq
81 | ```
82 | > 如果你想要加速推理,可以加上`--compile`参数。
83 |
84 | 推荐中国大陆用户运行以下命令来启动 HTTP 服务:
85 | ```bash
86 | HF_ENDPOINT=https://hf-mirror.com python -m ...(同上)
87 | ```
88 |
89 | 随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API.
90 |
91 | 下面是使用`tools/api_client.py`发送请求的示例。
92 |
93 | ```bash
94 | python -m tools.api_client \
95 | --text "要输入的文本" \
96 | --reference_audio "参考音频路径" \
97 | --reference_text "参考音频的文本内容" \
98 | --streaming True
99 | ```
100 |
101 | 上面的命令表示按照参考音频的信息,合成所需的音频并流式返回.
102 |
103 | 下面的示例展示了如何一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。
104 | ```bash
105 | python -m tools.api_client \
106 | --text "要输入的文本" \
107 | --reference_audio "参考音频路径1" "参考音频路径2" \
108 |     --reference_text "参考音频的文本内容1" "参考音频的文本内容2" \
109 | --streaming False \
110 | --output "generated" \
111 | --format "mp3"
112 | ```
113 |
114 | 上面的命令表示按照多个参考音频的信息,合成所需的`MP3`格式音频,并保存为当前目录的`generated.mp3`文件。
115 |
116 | 还可以用`--reference_id`(仅能用一个)来代替`--reference_audio`和`--reference_text`, 前提是在项目根目录下创建`references/`文件夹,
117 | 里面放上任意对音频与标注文本。 目前支持的参考音频最多加起来总时长90s。
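
下面是一个使用 `--reference_id` 的简单示意, 其中文件夹名 `my_voice` 和文件名均为假设的示例, 每段参考音频旁需要放置同名的 `.lab` 标注文本:

```bash
# 假设已手动创建如下目录:
#   references/my_voice/sample.wav  <- 参考音频
#   references/my_voice/sample.lab  <- 对应的标注文本
python -m tools.api_client \
    --text "要输入的文本" \
    --reference_id "my_voice"
```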
118 |
119 | !!! info
120 | 要了解有关可用参数的更多信息,可以使用命令`python -m tools.api_client -h`
121 |
122 | ## GUI 推理
123 | [下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases)
124 |
125 | ## WebUI 推理
126 |
127 | 你可以使用以下命令来启动 WebUI:
128 |
129 | ```bash
130 | python -m tools.run_webui \
131 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
132 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
133 | --decoder-config-name firefly_gan_vq
134 | ```
135 | > 如果你想要加速推理,可以加上`--compile`参数。
136 |
137 | !!! note
138 | 你可以提前将label文件和参考音频文件保存到主目录下的 `references` 文件夹(需要自行创建),这样你可以直接在WebUI中调用它们。
139 |
140 | !!! note
141 | 你可以使用 Gradio 环境变量, 如 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` 来配置 WebUI.
142 |
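例如, 下面的写法用环境变量指定 WebUI 的监听地址和端口(端口 `7860` 仅为示例值):

```bash
# 端口 7860 仅为示例
GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 python -m tools.run_webui \
    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
    --decoder-config-name firefly_gan_vq
```
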
143 | 祝大家玩得开心!
144 |
--------------------------------------------------------------------------------
/docs/zh/start_agent.md:
--------------------------------------------------------------------------------
1 | # 启动 Agent
2 |
3 | ## 要求
4 |
5 | - GPU 显存: 至少 8GB(在量化的条件下),推荐 16GB 及以上
6 | - 硬盘使用量: 10GB
7 |
8 | ## 下载模型
9 |
10 | 你可以执行下面的语句来获取模型:
11 |
12 | ```bash
13 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
14 | ```
15 |
16 | 如果你处于国内网络,首先执行:
17 |
18 | ```bash
19 | export HF_ENDPOINT=https://hf-mirror.com
20 | ```
21 |
22 | 把它们放进名为 'checkpoints' 的文件夹内。
23 |
24 | 你同样需要 fish-speech 的模型,关于如何获取 fish-speech 模型请查看[inference](inference.md)。
25 |
26 | 完成后你的 checkpoints 文件夹中会有两个子文件夹:`checkpoints/fish-speech-1.4` 和 `checkpoints/fish-agent-v0.1-3b`。
27 |
28 | ## 准备环境
29 |
30 | 如果你已经有了 Fish-Speech 环境,你可以在安装下面的包的前提下直接使用:
31 |
32 | ```bash
33 | pip install cachetools
34 | ```
35 |
36 | !!! note
37 |     请使用低于 3.12 的 Python 版本, 以保证 compile 功能可用.
38 |
39 | 如果你没有 Fish-Speech 环境,请执行下面的语句来构造你的环境:
40 |
41 | ```bash
42 | sudo apt-get install portaudio19-dev
43 |
44 | pip install -e .[stable]
45 | ```
46 |
47 | ## 启动 Agent Demo
48 |
49 | 你需要使用以下指令来构建 fish-agent:
50 |
51 | ```bash
52 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
53 | ```
54 |
55 | `--compile`只能在小于 3.12 版本的 Python 使用,这个功能可以极大程度上提高生成速度。
56 |
57 | 你需要注意, compile 需要进行一段时间.
58 |
59 | 然后启动另一个终端并执行:
60 |
61 | ```bash
62 | python -m tools.e2e_webui
63 | ```
64 |
65 | 这会在设备上创建一个 Gradio WebUI。
66 |
67 | 每当进行第一轮对话的时候, 模型需要 compile 一段时间, 请耐心等待。
68 |
69 | ## Gradio Webui
70 |
71 |
72 |
73 |
74 |
75 | 玩得开心!
76 |
77 | ## Performance
78 |
79 | 在我们的测试环境下, 4060 laptop GPU 只能勉强运行该模型, 大概只有 8 tokens/s。4090 GPU 在编译后可以达到约 95 tokens/s, 我们推荐使用 4080 及以上级别的 GPU 来获得较好体验。
80 |
81 | # About Agent
82 |
83 | 该模型仍处于测试阶段。如果你发现了问题,请给我们提 issue 或者 pull request,我们非常感谢。
84 |
--------------------------------------------------------------------------------
/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
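# Set CUDA_ENABLED=false to run the WebUI on CPU; --device cpu is then passed below.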
3 | CUDA_ENABLED=${CUDA_ENABLED:-true}
4 | DEVICE=""
5 |
6 | if [ "${CUDA_ENABLED}" != "true" ]; then
7 | DEVICE="--device cpu"
8 | fi
9 |
10 | exec python tools/run_webui.py ${DEVICE}
11 |
--------------------------------------------------------------------------------
/fish_speech/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .grad_norm import GradNormMonitor
2 |
3 | __all__ = ["GradNormMonitor"]
4 |
--------------------------------------------------------------------------------
/fish_speech/callbacks/grad_norm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import lightning.pytorch as pl
4 | import torch
5 | from lightning import LightningModule, Trainer
6 | from lightning.pytorch.callbacks import Callback
7 | from torch import Tensor, nn
8 | from torch.utils._foreach_utils import (
9 | _group_tensors_by_device_and_dtype,
10 | _has_foreach_support,
11 | )
12 |
13 |
14 | @torch.no_grad()
15 | def grad_norm(
16 | parameters: Union[Tensor, list[Tensor]],
17 | norm_type: float = 2.0,
18 | ) -> float:
19 | """
20 | Returns the norm of the gradients of the given parameters.
21 |
22 | Args:
23 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
24 | single Tensor that will have gradients normalized
25 | norm_type (float): type of the used p-norm.
26 |
27 | Returns:
28 | Total norm of the parameter gradients (viewed as a single vector).
29 | """ # noqa: E501
30 |
31 | if isinstance(parameters, Tensor):
32 | parameters = [parameters]
33 |
34 | grads = [p.grad for p in parameters if p.grad is not None]
35 | if len(grads) == 0:
36 | return None
37 |
38 | first_device = grads[0].device
39 | grouped_grads: dict[
40 | tuple[torch.device, torch.dtype], list[list[Tensor]]
41 | ] = _group_tensors_by_device_and_dtype(
42 | [[g.detach() for g in grads]]
43 | ) # type: ignore[assignment]
44 |
45 | norms = []
46 | for (device, _), ([grads], _) in grouped_grads.items():
47 | if _has_foreach_support(grads, device=device):
48 | norms.extend(torch._foreach_norm(grads, norm_type))
49 | else:
50 | norms.extend([torch.norm(g, norm_type) for g in grads])
51 |
52 | return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
53 |
54 |
55 | class GradNormMonitor(Callback):
56 | """
57 | Callback that computes the gradient norm of the model parameters.
58 | """
59 |
60 | def __init__(
61 | self,
62 | norm_type: float = 2.0,
63 | logging_interval: str = "step",
64 | sub_module: Optional[Union[str, list[str]]] = None,
65 | ) -> None:
66 | """
67 | Args:
68 | norm_type (float): type of the used p-norm.
69 | logging_interval (str): "step" or "epoch".
70 | """
71 | super().__init__()
72 |
73 | self.norm_type = norm_type
74 | self.logging_interval = logging_interval
75 | self.sub_module = sub_module
76 |
77 | def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
78 | """
79 | Computes the gradient norm of the model parameters and logs it to the logger.
80 |
81 | Args:
82 | trainer (Trainer): The trainer object
83 | model (LightningModule): The current lightningModule
84 | """
85 |
86 | lightning_model = model
87 |
88 | if self.sub_module is None:
89 | return self.log_sub_module_grad_norm(lightning_model, model, "")
90 |
91 | sub_modules = self.sub_module
92 | if isinstance(sub_modules, str):
93 | sub_modules = [sub_modules]
94 |
95 | for sub_module in sub_modules:
96 | self.log_sub_module_grad_norm(
97 | lightning_model, getattr(model, sub_module), f"/{sub_module}"
98 | )
99 |
100 | def log_sub_module_grad_norm(
101 | self, lightning_model: LightningModule, model: nn.Module, path: str
102 | ) -> None:
103 | grad_norm_val = grad_norm(model.parameters(), self.norm_type)
104 | if grad_norm_val is None:
105 | return
106 |
107 | on_step = self.logging_interval == "step"
108 | lightning_model.log(
109 | f"train{path}/grad_norm",
110 | grad_norm_val,
111 | on_step=on_step,
112 | on_epoch=not on_step,
113 | )
114 |
--------------------------------------------------------------------------------
/fish_speech/configs/base.yaml:
--------------------------------------------------------------------------------
1 | # Base configuration for training a model
2 | paths:
3 | run_dir: results/${project}
4 | ckpt_dir: ${paths.run_dir}/checkpoints
5 |
6 | hydra:
7 | run:
8 | dir: ${paths.run_dir}
9 |
10 | # Lightning Trainer
11 | trainer:
12 | _target_: lightning.pytorch.trainer.Trainer
13 |
14 | default_root_dir: ${paths.run_dir}
15 | accelerator: gpu
16 | num_nodes: 1
17 | devices: auto
18 | strategy:
19 | _target_: lightning.pytorch.strategies.DDPStrategy
20 |     process_group_backend: nccl # This should be overridden when training on Windows
21 |
22 | precision: bf16-mixed
23 |
24 | # disable validation by epoch end
25 | check_val_every_n_epoch: null
26 | val_check_interval: 5000
27 | max_steps: 100_000
28 |
29 | # Use torch.backends.cudnn.benchmark to speed up training
30 | benchmark: true
31 |
32 | # Callbacks
33 | callbacks:
34 | model_checkpoint:
35 | _target_: lightning.pytorch.callbacks.ModelCheckpoint
36 | dirpath: ${paths.ckpt_dir}
37 | filename: "step_{step:09d}"
38 |     save_last: false # if true, additionally saves an exact copy of the last checkpoint to last.ckpt
39 | save_top_k: 5 # save 5 latest checkpoints
40 | monitor: step # use step to monitor checkpoints
41 | mode: max # save the latest checkpoint with the highest global_step
42 | every_n_epochs: null # don't save checkpoints by epoch end
43 | every_n_train_steps: 5000 # save checkpoints every 5000 steps
44 | auto_insert_metric_name: false
45 |
46 | model_summary:
47 | _target_: lightning.pytorch.callbacks.ModelSummary
48 | max_depth: 2 # the maximum depth of layer nesting that the summary will include
49 |
50 | learning_rate_monitor:
51 | _target_: lightning.pytorch.callbacks.LearningRateMonitor
52 | logging_interval: step
53 | log_momentum: false
54 |
55 | grad_norm_monitor:
56 | _target_: fish_speech.callbacks.GradNormMonitor
57 | norm_type: 2
58 | logging_interval: step
59 |
60 | # Logger
61 | logger:
62 | tensorboard:
63 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
64 | save_dir: "${paths.run_dir}/tensorboard/"
65 | name: null
66 | log_graph: false
67 | default_hp_metric: true
68 | prefix: ""
69 |
70 | # wandb:
71 | # _target_: lightning.pytorch.loggers.wandb.WandbLogger
72 | # # name: "" # name of the run (normally generated by wandb)
73 | # save_dir: "${paths.run_dir}"
74 | # offline: False
75 | # id: null # pass correct id to resume experiment!
76 | # anonymous: null # enable anonymous logging
77 | # project: "fish-speech"
78 | # log_model: False # upload lightning ckpts
79 | # prefix: "" # a string to put at the beginning of metric keys
80 | # # entity: "" # set to name of your wandb team
81 | # group: ""
82 | # tags: ["vq", "hq", "finetune"]
83 | # job_type: ""
84 |
85 | # Loop
86 | train: true
87 | test: false
88 |
--------------------------------------------------------------------------------
/fish_speech/configs/firefly_gan_vq.yaml:
--------------------------------------------------------------------------------
1 | _target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
2 | spec_transform:
3 | _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
4 | sample_rate: 44100
5 | n_mels: 160
6 | n_fft: 2048
7 | hop_length: 512
8 | win_length: 2048
9 | backbone:
10 | _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
11 | input_channels: 160
12 | depths: [3, 3, 9, 3]
13 | dims: [128, 256, 384, 512]
14 | drop_path_rate: 0.2
15 | kernel_size: 7
16 | head:
17 | _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
18 | hop_length: 512
19 | upsample_rates: [8, 8, 2, 2, 2] # aka. strides
20 | upsample_kernel_sizes: [16, 16, 4, 4, 4]
21 | resblock_kernel_sizes: [3, 7, 11]
22 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
23 | num_mels: 512
24 | upsample_initial_channel: 512
25 | pre_conv_kernel_size: 13
26 | post_conv_kernel_size: 13
27 | quantizer:
28 | _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
29 | input_dim: 512
30 | n_groups: 8
31 | n_codebooks: 1
32 | levels: [8, 5, 5, 5]
33 | downsample_factor: [2, 2]
34 |
--------------------------------------------------------------------------------
/fish_speech/configs/lora/r_8_alpha_16.yaml:
--------------------------------------------------------------------------------
1 | _target_: fish_speech.models.text2semantic.lora.LoraConfig
2 | r: 8
3 | lora_alpha: 16
4 | lora_dropout: 0.01
5 |
--------------------------------------------------------------------------------
/fish_speech/configs/text2semantic_finetune.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 | - _self_
4 |
5 | project: text2semantic_finetune_dual_ar
6 | max_length: 4096
7 | pretrained_ckpt_path: checkpoints/fish-speech-1.5
8 |
9 | # Lightning Trainer
10 | trainer:
11 | accumulate_grad_batches: 1
12 | gradient_clip_val: 1.0
13 | gradient_clip_algorithm: "norm"
14 | max_steps: 10000
15 | precision: bf16-true
16 | limit_val_batches: 10
17 | val_check_interval: 100
18 | # strategy:
19 | # find_unused_parameters: true
20 | # static_graph: true
21 |
22 | # Tokenizer Configuration
23 | tokenizer:
24 | _target_: fish_speech.tokenizer.FishTokenizer
25 | model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken
26 |
27 | # Dataset Configuration
28 | train_dataset:
29 | _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
30 | proto_files:
31 | - data/protos
32 | tokenizer: ${tokenizer}
33 | causal: true
34 | max_length: ${max_length}
35 | use_speaker: false
36 | interactive_prob: 0.7
37 |
38 | val_dataset:
39 | _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
40 | proto_files:
41 | - data/protos
42 | tokenizer: ${tokenizer}
43 | causal: true
44 | max_length: ${max_length}
45 | use_speaker: false
46 | interactive_prob: 0.7
47 |
48 | data:
49 | _target_: fish_speech.datasets.semantic.SemanticDataModule
50 | train_dataset: ${train_dataset}
51 | val_dataset: ${val_dataset}
52 | num_workers: 4
53 | batch_size: 4
54 | tokenizer: ${tokenizer}
55 | max_length: ${max_length}
56 |
57 | # Model Configuration
58 | model:
59 | _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
60 | model:
61 | _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
62 | path: ${pretrained_ckpt_path}
63 | load_weights: true
64 | max_length: ${max_length}
65 | lora_config: null
66 |
67 | optimizer:
68 | _target_: torch.optim.AdamW
69 | _partial_: true
70 | lr: 1e-4
71 | weight_decay: 0
72 | betas: [0.9, 0.95]
73 | eps: 1e-5
74 |
75 | lr_scheduler:
76 | _target_: torch.optim.lr_scheduler.LambdaLR
77 | _partial_: true
78 | lr_lambda:
79 | _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
80 | _partial_: true
81 | num_warmup_steps: 10
82 |
83 | # Callbacks
84 | callbacks:
85 | model_checkpoint:
86 | every_n_train_steps: ${trainer.val_check_interval}
87 |
--------------------------------------------------------------------------------
/fish_speech/datasets/concat_repeat.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import random
3 | from typing import Iterable
4 |
5 | from torch.utils.data import Dataset, IterableDataset
6 |
7 |
8 | class ConcatRepeatDataset(Dataset):
9 | datasets: list[Dataset]
10 | cumulative_sizes: list[int]
11 | repeats: list[int]
12 |
13 | @staticmethod
14 | def cumsum(sequence, repeats):
15 | r, s = [], 0
16 | for dataset, repeat in zip(sequence, repeats):
17 | l = len(dataset) * repeat
18 | r.append(l + s)
19 | s += l
20 | return r
21 |
22 | def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
23 | super().__init__()
24 |
25 | self.datasets = list(datasets)
26 | self.repeats = repeats
27 |
28 | assert len(self.datasets) > 0, "datasets should not be an empty iterable"
29 | assert len(self.datasets) == len(
30 | repeats
31 | ), "datasets and repeats should have the same length"
32 |
33 | for d in self.datasets:
34 | assert not isinstance(
35 | d, IterableDataset
36 | ), "ConcatRepeatDataset does not support IterableDataset"
37 |
38 | self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
39 |
40 | def __len__(self):
41 | return self.cumulative_sizes[-1]
42 |
43 | def __getitem__(self, idx):
44 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
45 |
46 | if dataset_idx == 0:
47 | sample_idx = idx
48 | else:
49 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
50 |
51 | dataset = self.datasets[dataset_idx]
52 |
53 | return dataset[sample_idx % len(dataset)]
54 |
--------------------------------------------------------------------------------
/fish_speech/datasets/protos/text-data.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package text_data;
4 |
5 | message Semantics {
6 | repeated uint32 values = 1;
7 | }
8 |
9 | message Sentence {
10 | repeated string texts = 1;
11 | repeated Semantics semantics = 3;
12 | }
13 |
14 | message TextData {
15 | string source = 1;
16 | string name = 2;
17 | repeated Sentence sentences = 4;
18 | }
19 |
20 | message SampledData {
21 | string source = 1;
22 | string name = 2;
23 | repeated Sentence samples = 3;
24 | }
25 |
--------------------------------------------------------------------------------
/fish_speech/datasets/protos/text_data_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: text-data.proto
4 | # Protobuf Python Version: 4.25.1
5 | """Generated protocol buffer code."""
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import descriptor_pool as _descriptor_pool
8 | from google.protobuf import symbol_database as _symbol_database
9 | from google.protobuf.internal import builder as _builder
10 |
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
17 | b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
18 | )
19 |
20 | _globals = globals()
21 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
23 | if _descriptor._USE_C_DESCRIPTORS == False:
24 | DESCRIPTOR._options = None
25 | _globals["_SEMANTICS"]._serialized_start = 30
26 | _globals["_SEMANTICS"]._serialized_end = 57
27 | _globals["_SENTENCE"]._serialized_start = 59
28 | _globals["_SENTENCE"]._serialized_end = 125
29 | _globals["_TEXTDATA"]._serialized_start = 127
30 | _globals["_TEXTDATA"]._serialized_end = 207
31 | _globals["_SAMPLEDDATA"]._serialized_start = 209
32 | _globals["_SAMPLEDDATA"]._serialized_end = 290
33 | # @@protoc_insertion_point(module_scope)
34 |
--------------------------------------------------------------------------------
/fish_speech/datasets/protos/text_data_stream.py:
--------------------------------------------------------------------------------
1 | import struct
2 |
3 | from .text_data_pb2 import TextData
4 |
5 |
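# Framing used by the helpers below: each record is an unsigned length prefix
# (struct format "I", typically 4 bytes) followed by one serialized
# text_data.TextData message.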
6 | def read_pb_stream(f):
7 | while True:
8 | buf = f.read(4)
9 | if len(buf) == 0:
10 | break
11 | size = struct.unpack("I", buf)[0]
12 | buf = f.read(size)
13 | text_data = TextData()
14 | text_data.ParseFromString(buf)
15 | yield text_data
16 |
17 |
18 | def write_pb_stream(f, text_data):
19 | buf = text_data.SerializeToString()
20 | f.write(struct.pack("I", len(buf)))
21 | f.write(buf)
22 |
23 |
24 | def pack_pb_stream(text_data):
25 | buf = text_data.SerializeToString()
26 | return struct.pack("I", len(buf)) + buf
27 |
28 |
29 | def split_pb_stream(f):
30 | while True:
31 | head = f.read(4)
32 | if len(head) == 0:
33 | break
34 | size = struct.unpack("I", head)[0]
35 | buf = f.read(size)
36 | yield head + buf
37 |
--------------------------------------------------------------------------------
/fish_speech/datasets/vqgan.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import Optional
4 |
5 | import librosa
6 | import numpy as np
7 | import torch
8 | from lightning import LightningDataModule
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from fish_speech.utils import RankedLogger
12 |
13 | logger = RankedLogger(__name__, rank_zero_only=False)
14 |
15 |
16 | class VQGANDataset(Dataset):
17 | def __init__(
18 | self,
19 | filelist: str,
20 | sample_rate: int = 32000,
21 | hop_length: int = 640,
22 | slice_frames: Optional[int] = None,
23 | ):
24 | super().__init__()
25 |
26 | filelist = Path(filelist)
27 | root = filelist.parent
28 |
29 | self.files = [
30 | root / line.strip()
31 | for line in filelist.read_text(encoding="utf-8").splitlines()
32 | if line.strip()
33 | ]
34 | self.sample_rate = sample_rate
35 | self.hop_length = hop_length
36 | self.slice_frames = slice_frames
37 |
38 | def __len__(self):
39 | return len(self.files)
40 |
41 | def get_item(self, idx):
42 | file = self.files[idx]
43 |
44 | audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
45 |
46 | # Slice audio and features
47 | if (
48 | self.slice_frames is not None
49 | and audio.shape[0] > self.slice_frames * self.hop_length
50 | ):
51 | start = np.random.randint(
52 | 0, audio.shape[0] - self.slice_frames * self.hop_length
53 | )
54 | audio = audio[start : start + self.slice_frames * self.hop_length]
55 |
56 | if len(audio) == 0:
57 | return None
58 |
59 | max_value = np.abs(audio).max()
60 | if max_value > 1.0:
61 | audio = audio / max_value
62 |
63 | return {
64 | "audio": torch.from_numpy(audio),
65 | }
66 |
67 | def __getitem__(self, idx):
68 | try:
69 | return self.get_item(idx)
70 | except Exception as e:
71 | import traceback
72 |
73 | traceback.print_exc()
74 | logger.error(f"Error loading {self.files[idx]}: {e}")
75 | return None
76 |
77 |
78 | @dataclass
79 | class VQGANCollator:
80 | def __call__(self, batch):
81 | batch = [x for x in batch if x is not None]
82 |
83 | audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
84 | audio_maxlen = audio_lengths.max()
85 |
86 |         # Pad every clip in the batch to the longest audio length
87 | audios = []
88 | for x in batch:
89 | audios.append(
90 | torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
91 | )
92 |
93 | return {
94 | "audios": torch.stack(audios),
95 | "audio_lengths": audio_lengths,
96 | }
97 |
98 |
99 | class VQGANDataModule(LightningDataModule):
100 | def __init__(
101 | self,
102 | train_dataset: VQGANDataset,
103 | val_dataset: VQGANDataset,
104 | batch_size: int = 32,
105 | num_workers: int = 4,
106 | val_batch_size: Optional[int] = None,
107 | ):
108 | super().__init__()
109 |
110 | self.train_dataset = train_dataset
111 | self.val_dataset = val_dataset
112 | self.batch_size = batch_size
113 | self.val_batch_size = val_batch_size or batch_size
114 | self.num_workers = num_workers
115 |
116 | def train_dataloader(self):
117 | return DataLoader(
118 | self.train_dataset,
119 | batch_size=self.batch_size,
120 | collate_fn=VQGANCollator(),
121 | num_workers=self.num_workers,
122 | shuffle=True,
123 | persistent_workers=True,
124 | )
125 |
126 | def val_dataloader(self):
127 | return DataLoader(
128 | self.val_dataset,
129 | batch_size=self.val_batch_size,
130 | collate_fn=VQGANCollator(),
131 | num_workers=self.num_workers,
132 | persistent_workers=True,
133 | )
134 |
135 |
136 | if __name__ == "__main__":
137 | dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
138 | dataloader = DataLoader(
139 | dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
140 | )
141 |
142 | for batch in dataloader:
143 | print(batch["audios"].shape)
144 | print(batch["features"].shape)
145 | print(batch["audio_lengths"])
146 | print(batch["feature_lengths"])
147 | break
148 |
--------------------------------------------------------------------------------
/fish_speech/i18n/README.md:
--------------------------------------------------------------------------------
1 | ## i18n Folder Attribution
2 |
3 | The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
4 |
5 | ### fish_speech/i18n/core.py
6 |
7 | **Related code from RVC:**
8 | [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
9 |
10 | **Initial commit:**
11 | add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
12 |
13 | **Initial author:**
14 | [@L4Ph](https://github.com/L4Ph)
15 |
16 | ### fish_speech/i18n/scan.py
17 |
18 | **Related code from RVC:**
19 | [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
20 |
21 | **Initial commit:**
22 | File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
23 |
24 | **Initial author:**
25 | [@towzeur](https://github.com/towzeur)
26 |
27 | We appreciate the contributions of the RVC project and its authors.
28 |
--------------------------------------------------------------------------------
/fish_speech/i18n/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import i18n
2 |
3 | __all__ = ["i18n"]
4 |
--------------------------------------------------------------------------------
/fish_speech/i18n/core.py:
--------------------------------------------------------------------------------
1 | import json
2 | import locale
3 | from pathlib import Path
4 |
5 | I18N_FILE_PATH = Path(__file__).parent / "locale"
6 | DEFAULT_LANGUAGE = "en_US"
7 |
8 |
9 | def load_language_list(language):
10 | with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
11 | language_list = json.load(f)
12 |
13 | return language_list
14 |
15 |
16 | class I18nAuto:
17 | def __init__(self):
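        # A plain-text ".locale" file in the working directory (e.g. containing "en_US") overrides the auto-detected language.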
18 | i18n_file = Path(".locale")
19 |
20 | if i18n_file.exists():
21 | with open(i18n_file, "r", encoding="utf-8") as f:
22 | language = f.read().strip()
23 | else:
24 |             # locale.getlocale() may return (None, None), so fall back to getdefaultlocale()
25 | language = locale.getdefaultlocale()[0]
26 |
27 | if (I18N_FILE_PATH / f"{language}.json").exists() is False:
28 | language = DEFAULT_LANGUAGE
29 |
30 | self.language = language
31 | self.language_map = load_language_list(language)
32 |
33 | def __call__(self, key):
34 | return self.language_map.get(key, key)
35 |
36 | def __repr__(self):
37 | return "Use Language: " + self.language
38 |
39 |
40 | i18n = I18nAuto()
41 |
--------------------------------------------------------------------------------
/fish_speech/i18n/scan.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import glob
3 | import json
4 | from collections import OrderedDict
5 | from pathlib import Path
6 |
7 | from loguru import logger
8 |
9 | from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
10 |
11 |
12 | def extract_i18n_strings(node):
13 | i18n_strings = []
14 |
15 | if (
16 | isinstance(node, ast.Call)
17 | and isinstance(node.func, ast.Name)
18 | and node.func.id == "i18n"
19 | ):
20 | for arg in node.args:
21 | if isinstance(arg, ast.Str):
22 | i18n_strings.append(arg.s)
23 |
24 | for child_node in ast.iter_child_nodes(node):
25 | i18n_strings.extend(extract_i18n_strings(child_node))
26 |
27 | return i18n_strings
28 |
29 |
30 | # scan the directory for all .py files (recursively)
31 | # for each file, parse the code into an AST
32 | # for each AST, extract the i18n strings
33 |
34 | strings = []
35 | folders = ["fish_speech", "tools"]
36 | # for filename in glob.iglob("**/*.py", recursive=True):
37 | for folder in folders:
38 | for f in Path(folder).rglob("*.py"):
39 | code = f.read_text(encoding="utf-8")
40 | if "i18n(" in code:
41 | tree = ast.parse(code)
42 | i18n_strings = extract_i18n_strings(tree)
43 | logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
44 | strings.extend(i18n_strings)
45 |
46 | code_keys = set(strings)
47 | logger.info(f"Total unique: {len(code_keys)}")
48 |
49 |
50 | standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
51 | with open(standard_file, "r", encoding="utf-8") as f:
52 | standard_data = json.load(f, object_pairs_hook=OrderedDict)
53 | standard_keys = set(standard_data.keys())
54 |
55 | # Report keys that exist in the standard file but are unused in the code
56 | unused_keys = standard_keys - code_keys
57 | logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
58 | for unused_key in unused_keys:
59 | logger.info(f"\t{unused_key}")
60 |
61 | missing_keys = code_keys - standard_keys
62 | logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
63 | for missing_key in missing_keys:
64 | logger.info(f"\t{missing_key}")
65 |
66 | code_keys_dict = OrderedDict()
67 | for s in strings:
68 | code_keys_dict[s] = s
69 |
70 | # write back
71 | with open(standard_file, "w", encoding="utf-8") as f:
72 | json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
73 | f.write("\n")
74 |
75 | logger.info(f"Updated {standard_file}")
76 |
77 |
78 | # Define the standard file name
79 | standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
80 |
81 | # Find all JSON files in the directory
82 | dir_path = I18N_FILE_PATH
83 | languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
84 |
85 | # Load the standard file
86 | with open(standard_file, "r", encoding="utf-8") as f:
87 | standard_data = json.load(f, object_pairs_hook=OrderedDict)
88 |
89 | # Loop through each language file
90 | for lang_file in languages:
91 | # Load the language file
92 | with open(lang_file, "r", encoding="utf-8") as f:
93 | lang_data = json.load(f, object_pairs_hook=OrderedDict)
94 |
95 | # Find the difference between the language file and the standard file
96 | diff = set(standard_data.keys()) - set(lang_data.keys())
97 |
98 | miss = set(lang_data.keys()) - set(standard_data.keys())
99 |
100 | # Add any missing keys to the language file
101 | for key in diff:
102 | lang_data[key] = "#!" + key
103 | logger.info(f"Added missing key: {key} to {lang_file}")
104 |
105 |     # Delete any extra keys from the language file
106 | for key in miss:
107 | del lang_data[key]
108 | logger.info(f"Del extra key: {key} from {lang_file}")
109 |
110 | # Sort the keys of the language file to match the order of the standard file
111 | lang_data = OrderedDict(
112 | sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
113 | )
114 |
115 | # Save the updated language file
116 | with open(lang_file, "w", encoding="utf-8") as f:
117 | json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
118 | f.write("\n")
119 |
120 | logger.info(f"Updated {lang_file}")
121 |
122 | logger.info("Done")
123 |
--------------------------------------------------------------------------------
/fish_speech/inference_engine/reference_loader.py:
--------------------------------------------------------------------------------
1 | import io
2 | from hashlib import sha256
3 | from pathlib import Path
4 | from typing import Callable, Literal, Tuple
5 |
6 | import torch
7 | import torchaudio
8 | from loguru import logger
9 |
10 | from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
11 | from fish_speech.utils.file import (
12 | AUDIO_EXTENSIONS,
13 | audio_to_bytes,
14 | list_files,
15 | read_ref_text,
16 | )
17 | from fish_speech.utils.schema import ServeReferenceAudio
18 |
19 |
20 | class ReferenceLoader:
21 |
22 | def __init__(self) -> None:
23 | """
24 | Component of the TTSInferenceEngine class.
25 | Loads and manages the cache for the reference audio and text.
26 | """
27 | self.ref_by_id: dict = {}
28 | self.ref_by_hash: dict = {}
29 |
30 |         # Make Pylance happy (attribute/method not defined...)
31 | self.decoder_model: FireflyArchitecture
32 | self.encode_reference: Callable
33 |
34 | # Define the torchaudio backend
35 | backends = torchaudio.list_audio_backends()
36 | if "ffmpeg" in backends:
37 | self.backend = "ffmpeg"
38 | else:
39 | self.backend = "soundfile"
40 |
41 | def load_by_id(
42 | self,
43 | id: str,
44 | use_cache: Literal["on", "off"],
45 | ) -> Tuple:
46 |
47 | # Load the references audio and text by id
48 | ref_folder = Path("references") / id
49 | ref_folder.mkdir(parents=True, exist_ok=True)
50 | ref_audios = list_files(
51 | ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
52 | )
53 |
54 | if use_cache == "off" or id not in self.ref_by_id:
55 | # If the references are not already loaded, encode them
56 | prompt_tokens = [
57 | self.encode_reference(
58 | # decoder_model=self.decoder_model,
59 | reference_audio=audio_to_bytes(str(ref_audio)),
60 | enable_reference_audio=True,
61 | )
62 | for ref_audio in ref_audios
63 | ]
64 | prompt_texts = [
65 | read_ref_text(str(ref_audio.with_suffix(".lab")))
66 | for ref_audio in ref_audios
67 | ]
68 | self.ref_by_id[id] = (prompt_tokens, prompt_texts)
69 |
70 | else:
71 | # Reuse already encoded references
72 | logger.info("Use same references")
73 | prompt_tokens, prompt_texts = self.ref_by_id[id]
74 |
75 | return prompt_tokens, prompt_texts
76 |
77 | def load_by_hash(
78 | self,
79 | references: list[ServeReferenceAudio],
80 | use_cache: Literal["on", "off"],
81 | ) -> Tuple:
82 |
83 | # Load the references audio and text by hash
84 | audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
85 |
86 | cache_used = False
87 | prompt_tokens, prompt_texts = [], []
88 | for i, ref in enumerate(references):
89 | if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
90 | # If the references are not already loaded, encode them
91 | prompt_tokens.append(
92 | self.encode_reference(
93 | reference_audio=ref.audio,
94 | enable_reference_audio=True,
95 | )
96 | )
97 | prompt_texts.append(ref.text)
98 | self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
99 |
100 | else:
101 | # Reuse already encoded references
102 | prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]]
103 | cache_used = True
104 |
105 | if cache_used:
106 | logger.info("Use same references")
107 |
108 | return prompt_tokens, prompt_texts
109 |
110 | def load_audio(self, reference_audio, sr):
111 | """
112 | Load the audio data from a file or bytes.
113 | """
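        # Inputs that are too long to be a path, or that do not point to an existing file, are treated as raw audio bytes.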
114 | if len(reference_audio) > 255 or not Path(reference_audio).exists():
115 | audio_data = reference_audio
116 | reference_audio = io.BytesIO(audio_data)
117 |
118 | waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
119 |
120 | if waveform.shape[0] > 1:
121 | waveform = torch.mean(waveform, dim=0, keepdim=True)
122 |
123 | if original_sr != sr:
124 | resampler = torchaudio.transforms.Resample(
125 | orig_freq=original_sr, new_freq=sr
126 | )
127 | waveform = resampler(waveform)
128 |
129 | audio = waveform.squeeze().numpy()
130 | return audio
131 |
--------------------------------------------------------------------------------
/fish_speech/inference_engine/utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import wave
3 | from dataclasses import dataclass
4 | from typing import Literal, Optional, Tuple
5 |
6 | import numpy as np
7 |
8 |
9 | @dataclass
10 | class InferenceResult:
11 | code: Literal["header", "segment", "error", "final"]
12 | audio: Optional[Tuple[int, np.ndarray]]
13 | error: Optional[Exception]
14 |
15 |
16 | def wav_chunk_header(
17 | sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
18 | ) -> bytes:
19 | buffer = io.BytesIO()
20 |
21 | with wave.open(buffer, "wb") as wav_file:
22 | wav_file.setnchannels(channels)
23 | wav_file.setsampwidth(bit_depth // 8)
24 | wav_file.setframerate(sample_rate)
25 |
26 | wav_header_bytes = buffer.getvalue()
27 | buffer.close()
28 |
29 | return wav_header_bytes
30 |
--------------------------------------------------------------------------------
/fish_speech/inference_engine/vq_manager.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | from loguru import logger
5 |
6 | from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
7 |
8 |
9 | class VQManager:
10 |
11 | def __init__(self):
12 |         # Make Pylance happy (attribute/method not defined...)
13 | self.decoder_model: FireflyArchitecture
14 | self.load_audio: Callable
15 |
16 | def decode_vq_tokens(self, codes):
17 | feature_lengths = torch.tensor(
18 | [codes.shape[1]], device=self.decoder_model.device
19 | )
20 | logger.info(f"VQ features: {codes.shape}")
21 |
22 | if isinstance(self.decoder_model, FireflyArchitecture):
23 | return self.decoder_model.decode(
24 | indices=codes[None],
25 | feature_lengths=feature_lengths,
26 | )[0].squeeze()
27 |
28 | raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
29 |
30 | def encode_reference(self, reference_audio, enable_reference_audio):
31 | if enable_reference_audio and reference_audio is not None:
32 | # Load audios, and prepare basic info here
33 | reference_audio_content = self.load_audio(
34 | reference_audio, self.decoder_model.spec_transform.sample_rate
35 | )
36 |
37 | audios = torch.from_numpy(reference_audio_content).to(
38 | self.decoder_model.device
39 | )[None, None, :]
40 | audio_lengths = torch.tensor(
41 | [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
42 | )
43 | logger.info(
44 | f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
45 | )
46 |
47 | # VQ Encoder
48 | if isinstance(self.decoder_model, FireflyArchitecture):
49 | prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
50 | logger.info(f"Encoded prompt: {prompt_tokens.shape}")
51 | else:
52 | raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
53 | else:
54 | prompt_tokens = None
55 | logger.info("No reference audio provided")
56 |
57 | return prompt_tokens
58 |
--------------------------------------------------------------------------------
/fish_speech/models/text2semantic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/fish_speech/models/text2semantic/__init__.py
--------------------------------------------------------------------------------
/fish_speech/models/text2semantic/lora.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import loralib as lora
4 |
5 |
6 | @dataclass
7 | class LoraConfig:
8 | r: int
9 | lora_alpha: float
10 | lora_dropout: float = 0.0
11 |
12 |
13 | def setup_lora(model, lora_config):
14 | # Replace the embedding layer with a LoRA layer
15 | model.embeddings = lora.Embedding(
16 | num_embeddings=model.embeddings.num_embeddings,
17 | embedding_dim=model.embeddings.embedding_dim,
18 | padding_idx=model.embeddings.padding_idx,
19 | r=lora_config.r,
20 | lora_alpha=lora_config.lora_alpha,
21 | )
22 |
23 | model.codebook_embeddings = lora.Embedding(
24 | num_embeddings=model.codebook_embeddings.num_embeddings,
25 | embedding_dim=model.codebook_embeddings.embedding_dim,
26 | padding_idx=model.codebook_embeddings.padding_idx,
27 | r=lora_config.r,
28 | lora_alpha=lora_config.lora_alpha,
29 | )
30 |
31 | # Replace output layer with a LoRA layer
32 | linears = [(model, "output")]
33 |
34 | # Replace all linear layers with LoRA layers
35 | for layer in model.layers:
36 | linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
37 | linears.extend(
38 | [
39 | (layer.feed_forward, "w1"),
40 | (layer.feed_forward, "w2"),
41 | (layer.feed_forward, "w3"),
42 | ]
43 | )
44 |
45 | if hasattr(model, "fast_layers"):
46 | model.fast_embeddings = lora.Embedding(
47 | num_embeddings=model.fast_embeddings.num_embeddings,
48 | embedding_dim=model.fast_embeddings.embedding_dim,
49 | padding_idx=model.fast_embeddings.padding_idx,
50 | r=lora_config.r,
51 | lora_alpha=lora_config.lora_alpha,
52 | )
53 |
54 | # Dual-AR model
55 | linears.append((model, "fast_output"))
56 |
57 | for layer in model.fast_layers:
58 | linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
59 | linears.extend(
60 | [
61 | (layer.feed_forward, "w1"),
62 | (layer.feed_forward, "w2"),
63 | (layer.feed_forward, "w3"),
64 | ]
65 | )
66 |
67 | for module, layer in linears:
68 | updated_linear = lora.Linear(
69 | in_features=getattr(module, layer).in_features,
70 | out_features=getattr(module, layer).out_features,
71 | bias=getattr(module, layer).bias,
72 | r=lora_config.r,
73 | lora_alpha=lora_config.lora_alpha,
74 | lora_dropout=lora_config.lora_dropout,
75 | )
76 | setattr(module, layer, updated_linear)
77 |
78 | # Mark only the LoRA layers as trainable
79 | lora.mark_only_lora_as_trainable(model, bias="none")
80 |
81 |
82 | def get_merged_state_dict(model):
83 | # This line will merge the state dict of the model and the LoRA parameters
84 | model.eval()
85 |
86 | # Then we need to remove the LoRA parameters from the state dict
87 | state_dict = model.state_dict()
88 | for name in list(state_dict.keys()):
89 | if "lora" in name:
90 | state_dict.pop(name)
91 |
92 | return state_dict
93 |
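Usage sketch (not part of the file above), assuming `model` is a text2semantic transformer with the attribute layout setup_lora expects (embeddings, layers[*].attention / feed_forward, output); the r/alpha values simply mirror the r_8_alpha_16 config name.

    import torch

    config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.0)
    setup_lora(model, config)                # swap embeddings/linears for loralib layers
    # ... fine-tune; only LoRA parameters require gradients ...
    merged = get_merged_state_dict(model)    # eval() merges LoRA weights, then lora.* keys are dropped
    torch.save(merged, "merged_model.pth")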
--------------------------------------------------------------------------------
/fish_speech/models/vqgan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/fish_speech/models/vqgan/__init__.py
--------------------------------------------------------------------------------
/fish_speech/models/vqgan/inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import click
4 | import hydra
5 | import numpy as np
6 | import pyrootutils
7 | import soundfile as sf
8 | import torch
9 | import torchaudio
10 | from hydra import compose, initialize
11 | from hydra.utils import instantiate
12 | from loguru import logger
13 | from omegaconf import OmegaConf
14 |
15 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
16 |
17 | from fish_speech.utils.file import AUDIO_EXTENSIONS
18 |
19 | # register eval resolver
20 | OmegaConf.register_new_resolver("eval", eval)
21 |
22 |
23 | def load_model(config_name, checkpoint_path, device="cuda"):
24 | hydra.core.global_hydra.GlobalHydra.instance().clear()
25 | with initialize(version_base="1.3", config_path="../../configs"):
26 | cfg = compose(config_name=config_name)
27 |
28 | model = instantiate(cfg)
29 | state_dict = torch.load(
30 | checkpoint_path, map_location=device, mmap=True, weights_only=True
31 | )
32 | if "state_dict" in state_dict:
33 | state_dict = state_dict["state_dict"]
34 |
35 | if any("generator" in k for k in state_dict):
36 | state_dict = {
37 | k.replace("generator.", ""): v
38 | for k, v in state_dict.items()
39 | if "generator." in k
40 | }
41 |
42 | result = model.load_state_dict(state_dict, strict=False, assign=True)
43 | model.eval()
44 | model.to(device)
45 |
46 | logger.info(f"Loaded model: {result}")
47 | return model
48 |
49 |
50 | @torch.no_grad()
51 | @click.command()
52 | @click.option(
53 | "--input-path",
54 | "-i",
55 | default="test.wav",
56 | type=click.Path(exists=True, path_type=Path),
57 | )
58 | @click.option(
59 | "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
60 | )
61 | @click.option("--config-name", default="firefly_gan_vq")
62 | @click.option(
63 | "--checkpoint-path",
64 | default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
65 | )
66 | @click.option(
67 | "--device",
68 | "-d",
69 | default="cuda",
70 | )
71 | def main(input_path, output_path, config_name, checkpoint_path, device):
72 | model = load_model(config_name, checkpoint_path, device=device)
73 |
74 | if input_path.suffix in AUDIO_EXTENSIONS:
75 | logger.info(f"Processing in-place reconstruction of {input_path}")
76 |
77 | # Load audio
78 | audio, sr = torchaudio.load(str(input_path))
79 | if audio.shape[0] > 1:
80 | audio = audio.mean(0, keepdim=True)
81 | audio = torchaudio.functional.resample(
82 | audio, sr, model.spec_transform.sample_rate
83 | )
84 |
85 | audios = audio[None].to(device)
86 | logger.info(
87 | f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
88 | )
89 |
90 | # VQ Encoder
91 | audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
92 | indices = model.encode(audios, audio_lengths)[0][0]
93 |
94 | logger.info(f"Generated indices of shape {indices.shape}")
95 |
96 | # Save indices
97 | np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
98 | elif input_path.suffix == ".npy":
99 | logger.info(f"Processing precomputed indices from {input_path}")
100 | indices = np.load(input_path)
101 | indices = torch.from_numpy(indices).to(device).long()
102 | assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
103 | else:
104 | raise ValueError(f"Unknown input type: {input_path}")
105 |
106 | # Restore
107 | feature_lengths = torch.tensor([indices.shape[1]], device=device)
108 | fake_audios, _ = model.decode(
109 | indices=indices[None], feature_lengths=feature_lengths
110 | )
111 | audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
112 |
113 | logger.info(
114 | f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
115 | )
116 |
117 | # Save audio
118 | fake_audio = fake_audios[0, 0].float().cpu().numpy()
119 | sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
120 | logger.info(f"Saved audio to {output_path}")
121 |
122 |
123 | if __name__ == "__main__":
124 | main()
125 |
--------------------------------------------------------------------------------
/fish_speech/models/vqgan/modules/fsq.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 | from vector_quantize_pytorch import GroupedResidualFSQ
8 |
9 | from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
10 |
11 |
12 | @dataclass
13 | class FSQResult:
14 | z: torch.Tensor
15 | codes: torch.Tensor
16 | latents: torch.Tensor
17 |
18 |
19 | class DownsampleFiniteScalarQuantize(nn.Module):
20 | def __init__(
21 | self,
22 | input_dim: int = 512,
23 | n_codebooks: int = 9,
24 | n_groups: int = 1,
25 | levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26 | downsample_factor: tuple[int] = (2, 2),
27 | downsample_dims: tuple[int] | None = None,
28 | ):
29 | super().__init__()
30 |
31 | if downsample_dims is None:
32 | downsample_dims = [input_dim for _ in range(len(downsample_factor))]
33 |
34 | all_dims = (input_dim,) + tuple(downsample_dims)
35 |
36 | self.residual_fsq = GroupedResidualFSQ(
37 | dim=all_dims[-1],
38 | levels=levels,
39 | num_quantizers=n_codebooks,
40 | groups=n_groups,
41 | )
42 |
43 | self.downsample_factor = downsample_factor
44 | self.downsample_dims = downsample_dims
45 |
46 | self.downsample = nn.Sequential(
47 | *[
48 | nn.Sequential(
49 | FishConvNet(
50 | all_dims[idx],
51 | all_dims[idx + 1],
52 | kernel_size=factor,
53 | stride=factor,
54 | ),
55 | ConvNeXtBlock(dim=all_dims[idx + 1]),
56 | )
57 | for idx, factor in enumerate(downsample_factor)
58 | ]
59 | )
60 |
61 | self.upsample = nn.Sequential(
62 | *[
63 | nn.Sequential(
64 | FishTransConvNet(
65 | all_dims[idx + 1],
66 | all_dims[idx],
67 | kernel_size=factor,
68 | stride=factor,
69 | ),
70 | ConvNeXtBlock(dim=all_dims[idx]),
71 | )
72 | for idx, factor in reversed(list(enumerate(downsample_factor)))
73 | ]
74 | )
75 |
76 | self.apply(self._init_weights)
77 |
78 | def _init_weights(self, m):
79 | if isinstance(m, (nn.Conv1d, nn.Linear)):
80 | nn.init.trunc_normal_(m.weight, std=0.02)
81 | nn.init.constant_(m.bias, 0)
82 |
83 | def forward(self, z) -> FSQResult:
84 | original_shape = z.shape
85 | z = self.downsample(z)
86 | quantized, indices = self.residual_fsq(z.mT)
87 | result = FSQResult(
88 | z=quantized.mT,
89 | codes=indices.mT,
90 | latents=z,
91 | )
92 | result.z = self.upsample(result.z)
93 |
94 | # Pad or crop z to match original shape
95 | diff = original_shape[-1] - result.z.shape[-1]
96 | left = diff // 2
97 | right = diff - left
98 |
99 | if diff > 0:
100 | result.z = F.pad(result.z, (left, right))
101 | elif diff < 0:
102 | result.z = result.z[..., -left:right]
103 |
104 | return result
105 |
106 | def encode(self, z):
107 | z = self.downsample(z)
108 | _, indices = self.residual_fsq(z.mT)
109 | indices = rearrange(indices, "g b l r -> b (g r) l")
110 | return indices
111 |
112 | def decode(self, indices: torch.Tensor):
113 | indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
114 | z_q = self.residual_fsq.get_output_from_indices(indices)
115 | z_q = self.upsample(z_q.mT)
116 | return z_q
117 |
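Shape-only round-trip sketch with the defaults above (a downsample factor of (2, 2) compresses time by roughly 4x; tensors are random, not from the repo):

    import torch

    fsq = DownsampleFiniteScalarQuantize(input_dim=512, n_codebooks=9, n_groups=1)
    x = torch.randn(1, 512, 100)   # (batch, channels, time)
    codes = fsq.encode(x)          # (1, n_groups * n_codebooks, ~25)
    z_q = fsq.decode(codes)        # (1, 512, ~100); decode does not re-crop to the original length
    result = fsq(x)                # FSQResult; result.z is padded/cropped back to x's length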
--------------------------------------------------------------------------------
/fish_speech/models/vqgan/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import torch
3 | from matplotlib import pyplot as plt
4 |
5 | matplotlib.use("Agg")
6 |
7 |
8 | def convert_pad_shape(pad_shape):
9 | l = pad_shape[::-1]
10 | pad_shape = [item for sublist in l for item in sublist]
11 | return pad_shape
12 |
13 |
14 | def sequence_mask(length, max_length=None):
15 | if max_length is None:
16 | max_length = length.max()
17 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
18 | return x.unsqueeze(0) < length.unsqueeze(1)
19 |
20 |
21 | def init_weights(m, mean=0.0, std=0.01):
22 | classname = m.__class__.__name__
23 | if classname.find("Conv") != -1:
24 | m.weight.data.normal_(mean, std)
25 |
26 |
27 | def get_padding(kernel_size, dilation=1):
28 | return int((kernel_size * dilation - dilation) / 2)
29 |
30 |
31 | def plot_mel(data, titles=None):
32 | fig, axes = plt.subplots(len(data), 1, squeeze=False)
33 |
34 | if titles is None:
35 | titles = [None for i in range(len(data))]
36 |
37 | plt.tight_layout()
38 |
39 | for i in range(len(data)):
40 | mel = data[i]
41 |
42 | if isinstance(mel, torch.Tensor):
43 | mel = mel.float().detach().cpu().numpy()
44 |
45 | axes[i][0].imshow(mel, origin="lower")
46 | axes[i][0].set_aspect(2.5, adjustable="box")
47 | axes[i][0].set_ylim(0, mel.shape[0])
48 | axes[i][0].set_title(titles[i], fontsize="medium")
49 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
50 | axes[i][0].set_anchor("W")
51 |
52 | return fig
53 |
54 |
55 | def slice_segments(x, ids_str, segment_size=4):
56 | ret = torch.zeros_like(x[:, :, :segment_size])
57 | for i in range(x.size(0)):
58 | idx_str = ids_str[i]
59 | idx_end = idx_str + segment_size
60 | ret[i] = x[i, :, idx_str:idx_end]
61 |
62 | return ret
63 |
64 |
65 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
66 | b, d, t = x.size()
67 | if x_lengths is None:
68 | x_lengths = t
69 | ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
70 | ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
71 | ret = slice_segments(x, ids_str, segment_size)
72 | return ret, ids_str
73 |
74 |
75 | @torch.jit.script
76 | def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
77 | n_channels_int = n_channels[0]
78 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
79 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
80 | acts = t_act * s_act
81 |
82 | return acts
83 |
84 |
85 | def avg_with_mask(x, mask):
86 | assert mask.dtype == torch.float, "Mask should be float"
87 |
88 | if mask.ndim == 2:
89 | mask = mask.unsqueeze(1)
90 |
91 | if mask.shape[1] == 1:
92 | mask = mask.expand_as(x)
93 |
94 | return (x * mask).sum() / mask.sum()
95 |
--------------------------------------------------------------------------------
/fish_speech/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def get_cosine_schedule_with_warmup_lr_lambda(
5 | current_step: int,
6 | *,
7 | num_warmup_steps: int | float,
8 | num_training_steps: int,
9 | num_cycles: float = 0.5,
10 | final_lr_ratio: float = 0.0,
11 | ):
12 | if 0 < num_warmup_steps < 1: # float mode
13 | num_warmup_steps = int(num_warmup_steps * num_training_steps)
14 |
15 | if current_step < num_warmup_steps:
16 | return float(current_step) / float(max(1, num_warmup_steps))
17 |
18 | progress = float(current_step - num_warmup_steps) / float(
19 | max(1, num_training_steps - num_warmup_steps)
20 | )
21 |
22 | return max(
23 | final_lr_ratio,
24 | 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
25 | )
26 |
27 |
28 | def get_constant_schedule_with_warmup_lr_lambda(
29 | current_step: int,
30 | *,
31 | num_warmup_steps: int | float,
32 | num_training_steps: int | None = None,
33 | ):
34 | if 0 < num_warmup_steps < 1: # float mode
35 | num_warmup_steps = int(num_warmup_steps * num_training_steps)
36 |
37 | if current_step < num_warmup_steps:
38 | return float(current_step) / float(max(1, num_warmup_steps))
39 |
40 | return 1.0
41 |
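Usage sketch (not part of the file above): these lambdas are meant to be wrapped into a torch LambdaLR, which the repo's Hydra configs presumably do in an equivalent way; the hyperparameters below are made up.

    from functools import partial

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
    lr_lambda = partial(
        get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=0.01,    # a float in (0, 1) is treated as a fraction of total steps
        num_training_steps=10_000,
        final_lr_ratio=0.1,       # the multiplier never decays below 0.1 of the base lr
    )
    scheduler = LambdaLR(optimizer, lr_lambda)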
--------------------------------------------------------------------------------
/fish_speech/text/__init__.py:
--------------------------------------------------------------------------------
1 | from .clean import clean_text
2 | from .spliter import split_text
3 |
4 | __all__ = ["clean_text", "split_text"]
5 |
--------------------------------------------------------------------------------
/fish_speech/text/clean.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | SYMBOLS_MAPPING = {
4 | "‘": "'",
5 | "’": "'",
6 | }
7 |
8 | REPLACE_SYMBOL_REGEX = re.compile(
9 | "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
10 | )
11 |
12 |
13 | EMOJI_REGEX = re.compile(
14 | "["
15 | "\U0001f600-\U0001f64f" # emoticons
16 | "\U0001f300-\U0001f5ff" # symbols & pictographs
17 | "\U0001f680-\U0001f6ff" # transport & map symbols
18 | "\U0001f1e0-\U0001f1ff" # flags (iOS)
19 | "]+",
20 | flags=re.UNICODE,
21 | )
22 |
23 |
24 | def clean_text(text):
25 | # Clean the text
26 | text = text.strip()
27 |
28 | # Replace curly quotes with their ASCII counterparts
29 | text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
30 |
31 | # Remove emojis
32 | text = EMOJI_REGEX.sub(r"", text)
33 |
34 | # Collapse runs of commas (",,," or ",,,") into a single comma
35 | text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
36 |
37 | return text
38 |
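Quick sketch of the combined effect (curly quote normalized, emoji stripped, comma run collapsed); not part of the file above:

    assert clean_text("It’s fine 😊,,, really") == "It's fine , really"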
--------------------------------------------------------------------------------
/fish_speech/text/spliter.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | from fish_speech.text.clean import clean_text
5 |
6 |
7 | def utf_8_len(text: str):
8 | return len(text.encode("utf-8"))
9 |
10 |
11 | def break_text(texts, length, splits: set):
12 | for text in texts:
13 | if utf_8_len(text) <= length:
14 | yield text
15 | continue
16 |
17 | curr = ""
18 | for char in text:
19 | curr += char
20 |
21 | if char in splits:
22 | yield curr
23 | curr = ""
24 |
25 | if curr:
26 | yield curr
27 |
28 |
29 | def break_text_by_length(texts, length):
30 | for text in texts:
31 | if utf_8_len(text) <= length:
32 | yield text
33 | continue
34 |
35 | curr = ""
36 | for char in text:
37 | curr += char
38 |
39 | if utf_8_len(curr) >= length:
40 | yield curr
41 | curr = ""
42 |
43 | if curr:
44 | yield curr
45 |
46 |
47 | def add_cleaned(curr, segments):
48 | curr = curr.strip()
49 | if curr and not all(c.isspace() or c in string.punctuation for c in curr):
50 | segments.append(curr)
51 |
52 |
53 | def protect_float(text):
54 | # Turns 3.14 into <3_f_14> to prevent splitting
55 | return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
56 |
57 |
58 | def unprotect_float(text):
59 | # Turns <3_f_14> into 3.14
60 | return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
61 |
62 |
63 | def split_text(text, length):
64 | text = clean_text(text)
65 |
66 | # Break the text into pieces with following rules:
67 | # 1. Split the text at ".", "!", "?" if text is NOT a float
68 | # 2. If the text is longer than length, split at ","
69 | # 3. If the text is still longer than length, split at " "
70 | # 4. If the text is still longer than length, split at any character to length
71 |
72 | texts = [text]
73 | texts = map(protect_float, texts)
74 | texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
75 | texts = map(unprotect_float, texts)
76 | texts = break_text(texts, length, {",", ","})
77 | texts = break_text(texts, length, {" "})
78 | texts = list(break_text_by_length(texts, length))
79 |
80 | # Then, merge the texts into segments with length <= length
81 | segments = []
82 | curr = ""
83 |
84 | for text in texts:
85 | if utf_8_len(curr) + utf_8_len(text) <= length:
86 | curr += text
87 | else:
88 | add_cleaned(curr, segments)
89 | curr = text
90 |
91 | if curr:
92 | add_cleaned(curr, segments)
93 |
94 | return segments
95 |
96 |
97 | if __name__ == "__main__":
98 | # Test the split_text function
99 |
100 | text = "This is a test sentence. This is another test sentence. And a third one."
101 |
102 | assert split_text(text, 50) == [
103 | "This is a test sentence.",
104 | "This is another test sentence. And a third one.",
105 | ]
106 | assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
107 | assert split_text(" ", 10) == []
108 | assert split_text("a", 10) == ["a"]
109 |
110 | text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
111 | assert split_text(text, 50) == [
112 | "This is a test sentence with only commas,",
113 | "and no dots, and no exclamation marks,",
114 | "and no question marks, and no newlines.",
115 | ]
116 |
117 | text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
118 | # First half split at " ", second half split at ","
119 | assert split_text(text, 50) == [
120 | "This is a test sentence This is a test sentence",
121 | "This is a test sentence. This is a test sentence,",
122 | "This is a test sentence, This is a test sentence.",
123 | ]
124 |
125 | text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
126 | assert split_text(text, 50) == [
127 | "这是一段很长的中文文本,",
128 | "而且没有句号,也没有感叹号,",
129 | "也没有问号,也没有换行符.",
130 | ]
131 |
--------------------------------------------------------------------------------
/fish_speech/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["USE_LIBUV"] = "0"
4 | import sys
5 | from typing import Optional
6 |
7 | import hydra
8 | import lightning as L
9 | import pyrootutils
10 | import torch
11 | from lightning import Callback, LightningDataModule, LightningModule, Trainer
12 | from lightning.pytorch.loggers import Logger
13 | from lightning.pytorch.strategies import DDPStrategy
14 | from omegaconf import DictConfig, OmegaConf
15 |
16 | os.environ.pop("SLURM_NTASKS", None)
17 | os.environ.pop("SLURM_JOB_NAME", None)
18 | os.environ.pop("SLURM_NTASKS_PER_NODE", None)
19 |
20 | # register eval resolver and root
21 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
22 |
23 | # Allow TF32 on Ampere GPUs
24 | torch.set_float32_matmul_precision("high")
25 | torch.backends.cudnn.allow_tf32 = True
26 |
27 | # register eval resolver
28 | OmegaConf.register_new_resolver("eval", eval)
29 |
30 | import fish_speech.utils as utils
31 |
32 | log = utils.RankedLogger(__name__, rank_zero_only=True)
33 |
34 |
35 | @utils.task_wrapper
36 | def train(cfg: DictConfig) -> tuple[dict, dict]:
37 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
38 | training.
39 | This method is wrapped in an optional @task_wrapper decorator, which controls the behavior during
40 | failure. Useful for multiruns, saving info about the crash, etc.
41 | Args:
42 | cfg (DictConfig): Configuration composed by Hydra.
43 | Returns:
44 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
45 | """ # noqa: E501
46 |
47 | # set seed for random number generators in pytorch, numpy and python.random
48 | if cfg.get("seed"):
49 | L.seed_everything(cfg.seed, workers=False)
50 |
51 | if cfg.get("deterministic"):
52 | torch.use_deterministic_algorithms(True)
53 |
54 | log.info(f"Instantiating datamodule <{cfg.data._target_}>")
55 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
56 |
57 | log.info(f"Instantiating model <{cfg.model._target_}>")
58 | model: LightningModule = hydra.utils.instantiate(cfg.model)
59 |
60 | log.info("Instantiating callbacks...")
61 | callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
62 |
63 | log.info("Instantiating loggers...")
64 | logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
65 |
66 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
67 | trainer: Trainer = hydra.utils.instantiate(
68 | cfg.trainer,
69 | callbacks=callbacks,
70 | logger=logger,
71 | )
72 |
73 | object_dict = {
74 | "cfg": cfg,
75 | "datamodule": datamodule,
76 | "model": model,
77 | "callbacks": callbacks,
78 | "logger": logger,
79 | "trainer": trainer,
80 | }
81 |
82 | if logger:
83 | log.info("Logging hyperparameters!")
84 | utils.log_hyperparameters(object_dict)
85 |
86 | if cfg.get("train"):
87 | log.info("Starting training!")
88 |
89 | ckpt_path = cfg.get("ckpt_path")
90 | auto_resume = False
91 |
92 | resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
93 | if resume_ckpt_path is not None:
94 | ckpt_path = resume_ckpt_path
95 | auto_resume = True
96 |
97 | if ckpt_path is not None:
98 | log.info(f"Resuming from checkpoint: {ckpt_path}")
99 |
100 | # resume weights only is disabled for auto-resume
101 | if cfg.get("resume_weights_only") and auto_resume is False:
102 | log.info("Resuming weights only!")
103 | ckpt = torch.load(ckpt_path, map_location=model.device)
104 | if "state_dict" in ckpt:
105 | ckpt = ckpt["state_dict"]
106 | err = model.load_state_dict(ckpt, strict=False)
107 | log.info(f"Error loading state dict: {err}")
108 | ckpt_path = None
109 |
110 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
111 |
112 | train_metrics = trainer.callback_metrics
113 |
114 | if cfg.get("test"):
115 | log.info("Starting testing!")
116 | ckpt_path = trainer.checkpoint_callback.best_model_path
117 | if ckpt_path == "":
118 | log.warning("Best ckpt not found! Using current weights for testing...")
119 | ckpt_path = cfg.get("ckpt_path")
120 |
121 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
122 | log.info(f"Best ckpt path: {ckpt_path}")
123 |
124 | test_metrics = trainer.callback_metrics
125 |
126 | # merge train and test metrics
127 | metric_dict = {**train_metrics, **test_metrics}
128 |
129 | return metric_dict, object_dict
130 |
131 |
132 | @hydra.main(
133 | version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
134 | )
135 | def main(cfg: DictConfig) -> Optional[float]:
136 | # train the model
137 | train(cfg)
138 |
139 |
140 | if __name__ == "__main__":
141 | main()
142 |
--------------------------------------------------------------------------------
/fish_speech/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .braceexpand import braceexpand
2 | from .context import autocast_exclude_mps
3 | from .file import get_latest_checkpoint
4 | from .instantiators import instantiate_callbacks, instantiate_loggers
5 | from .logger import RankedLogger
6 | from .logging_utils import log_hyperparameters
7 | from .rich_utils import enforce_tags, print_config_tree
8 | from .utils import extras, get_metric_value, set_seed, task_wrapper
9 |
10 | __all__ = [
11 | "enforce_tags",
12 | "extras",
13 | "get_metric_value",
14 | "RankedLogger",
15 | "instantiate_callbacks",
16 | "instantiate_loggers",
17 | "log_hyperparameters",
18 | "print_config_tree",
19 | "task_wrapper",
20 | "braceexpand",
21 | "get_latest_checkpoint",
22 | "autocast_exclude_mps",
23 | "set_seed",
24 | ]
25 |
--------------------------------------------------------------------------------
/fish_speech/utils/context.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 |
3 | import torch
4 |
5 |
6 | def autocast_exclude_mps(
7 | device_type: str, dtype: torch.dtype
8 | ) -> nullcontext | torch.autocast:
9 | return (
10 | nullcontext()
11 | if torch.backends.mps.is_available()
12 | else torch.autocast(device_type, dtype)
13 | )
14 |
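One-line usage sketch (not part of the file above): on MPS machines the returned context is a no-op, elsewhere it is a normal autocast region.

    import torch

    with autocast_exclude_mps(device_type="cuda", dtype=torch.bfloat16):
        pass  # forward pass would go here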
--------------------------------------------------------------------------------
/fish_speech/utils/file.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Union
4 |
5 | from loguru import logger
6 | from natsort import natsorted
7 |
8 | AUDIO_EXTENSIONS = {
9 | ".mp3",
10 | ".wav",
11 | ".flac",
12 | ".ogg",
13 | ".m4a",
14 | ".wma",
15 | ".aac",
16 | ".aiff",
17 | ".aif",
18 | ".aifc",
19 | }
20 |
21 | VIDEO_EXTENSIONS = {
22 | ".mp4",
23 | ".avi",
24 | }
25 |
26 |
27 | def get_latest_checkpoint(path: Path | str) -> Path | None:
28 | # Find the latest checkpoint
29 | ckpt_dir = Path(path)
30 |
31 | if ckpt_dir.exists() is False:
32 | return None
33 |
34 | ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
35 | if len(ckpts) == 0:
36 | return None
37 |
38 | return ckpts[-1]
39 |
40 |
41 | def audio_to_bytes(file_path):
42 | if not file_path or not Path(file_path).exists():
43 | return None
44 | with open(file_path, "rb") as wav_file:
45 | wav = wav_file.read()
46 | return wav
47 |
48 |
49 | def read_ref_text(ref_text):
50 | path = Path(ref_text)
51 | if path.exists() and path.is_file():
52 | with path.open("r", encoding="utf-8") as file:
53 | return file.read()
54 | return ref_text
55 |
56 |
57 | def list_files(
58 | path: Union[Path, str],
59 | extensions: set[str] = set(),
60 | recursive: bool = False,
61 | sort: bool = True,
62 | ) -> list[Path]:
63 | """List files in a directory.
64 |
65 | Args:
66 | path (Path): Path to the directory.
67 | extensions (set, optional): Extensions to filter. Defaults to an empty set.
68 | recursive (bool, optional): Whether to search recursively. Defaults to False.
69 | sort (bool, optional): Whether to sort the files. Defaults to True.
70 |
71 | Returns:
72 | list: List of files.
73 | """
74 |
75 | if isinstance(path, str):
76 | path = Path(path)
77 |
78 | if not path.exists():
79 | raise FileNotFoundError(f"Directory {path} does not exist.")
80 |
81 | files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
82 |
83 | if sort:
84 | files = natsorted(files)
85 |
86 | return files
87 |
88 |
89 | def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
90 | """
91 | Load a Bert-VITS2 style filelist.
92 | """
93 |
94 | files = set()
95 | results = []
96 | count_duplicated, count_not_found = 0, 0
97 |
98 | LANGUAGE_TO_LANGUAGES = {
99 | "zh": ["zh", "en"],
100 | "jp": ["jp", "en"],
101 | "en": ["en"],
102 | }
103 |
104 | with open(path, "r", encoding="utf-8") as f:
105 | for line in f.readlines():
106 | splits = line.strip().split("|", maxsplit=3)
107 | if len(splits) != 4:
108 | logger.warning(f"Invalid line: {line}")
109 | continue
110 |
111 | filename, speaker, language, text = splits
112 | file = Path(filename)
113 | language = language.strip().lower()
114 |
115 | if language == "ja":
116 | language = "jp"
117 |
118 | assert language in ["zh", "jp", "en"], f"Invalid language {language}"
119 | languages = LANGUAGE_TO_LANGUAGES[language]
120 |
121 | if file in files:
122 | logger.warning(f"Duplicated file: {file}")
123 | count_duplicated += 1
124 | continue
125 |
126 | if not file.exists():
127 | logger.warning(f"File not found: {file}")
128 | count_not_found += 1
129 | continue
130 |
131 | results.append((file, speaker, languages, text))
132 |
133 | if count_duplicated > 0:
134 | logger.warning(f"Total duplicated files: {count_duplicated}")
135 |
136 | if count_not_found > 0:
137 | logger.warning(f"Total files not found: {count_not_found}")
138 |
139 | return results
140 |
--------------------------------------------------------------------------------
/fish_speech/utils/instantiators.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import hydra
4 | from omegaconf import DictConfig
5 | from pytorch_lightning import Callback
6 | from pytorch_lightning.loggers import Logger
7 |
8 | from .logger import RankedLogger
9 |
10 | log = RankedLogger(__name__, rank_zero_only=True)
11 |
12 |
13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
14 | """Instantiates callbacks from config."""
15 |
16 | callbacks: List[Callback] = []
17 |
18 | if not callbacks_cfg:
19 | log.warning("No callback configs found! Skipping..")
20 | return callbacks
21 |
22 | if not isinstance(callbacks_cfg, DictConfig):
23 | raise TypeError("Callbacks config must be a DictConfig!")
24 |
25 | for _, cb_conf in callbacks_cfg.items():
26 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
27 | log.info(f"Instantiating callback <{cb_conf._target_}>")
28 | callbacks.append(hydra.utils.instantiate(cb_conf))
29 |
30 | return callbacks
31 |
32 |
33 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
34 | """Instantiates loggers from config."""
35 |
36 | logger: List[Logger] = []
37 |
38 | if not logger_cfg:
39 | log.warning("No logger configs found! Skipping...")
40 | return logger
41 |
42 | if not isinstance(logger_cfg, DictConfig):
43 | raise TypeError("Logger config must be a DictConfig!")
44 |
45 | for _, lg_conf in logger_cfg.items():
46 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
47 | log.info(f"Instantiating logger <{lg_conf._target_}>")
48 | logger.append(hydra.utils.instantiate(lg_conf))
49 |
50 | return logger
51 |
--------------------------------------------------------------------------------
/fish_speech/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Mapping, Optional
3 |
4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
5 |
6 |
7 | class RankedLogger(logging.LoggerAdapter):
8 | """A multi-GPU-friendly python command line logger."""
9 |
10 | def __init__(
11 | self,
12 | name: str = __name__,
13 | rank_zero_only: bool = True,
14 | extra: Optional[Mapping[str, object]] = None,
15 | ) -> None:
16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes
17 | with their rank prefixed in the log message.
18 |
19 | :param name: The name of the logger. Default is ``__name__``.
20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is ``True``.
21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
22 | """
23 | logger = logging.getLogger(name)
24 | super().__init__(logger=logger, extra=extra)
25 | self.rank_zero_only = rank_zero_only
26 |
27 | def log(
28 | self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
29 | ) -> None:
30 | """Delegate a log call to the underlying logger, after prefixing its message with the rank
31 | of the process it's being logged from. If `'rank'` is provided, then the log will only
32 | occur on that rank/process.
33 |
34 | :param level: The level to log at. Look at `logging.__init__.py` for more information.
35 | :param msg: The message to log.
36 | :param rank: The rank to log at.
37 | :param args: Additional args to pass to the underlying logging function.
38 | :param kwargs: Any additional keyword args to pass to the underlying logging function.
39 | """
40 | if self.isEnabledFor(level):
41 | msg, kwargs = self.process(msg, kwargs)
42 | current_rank = getattr(rank_zero_only, "rank", None)
43 | if current_rank is None:
44 | raise RuntimeError(
45 | "The `rank_zero_only.rank` needs to be set before use"
46 | )
47 | msg = rank_prefixed_message(msg, current_rank)
48 | if self.rank_zero_only:
49 | if current_rank == 0:
50 | self.logger.log(level, msg, *args, **kwargs)
51 | else:
52 | if rank is None:
53 | self.logger.log(level, msg, *args, **kwargs)
54 | elif current_rank == rank:
55 | self.logger.log(level, msg, *args, **kwargs)
56 |
--------------------------------------------------------------------------------
/fish_speech/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | from lightning.pytorch.utilities import rank_zero_only
2 |
3 | from fish_speech.utils import logger as log
4 |
5 |
6 | @rank_zero_only
7 | def log_hyperparameters(object_dict: dict) -> None:
8 | """Controls which config parts are saved by lightning loggers.
9 |
10 | Additionally saves:
11 | - Number of model parameters
12 | """
13 |
14 | hparams = {}
15 |
16 | cfg = object_dict["cfg"]
17 | model = object_dict["model"]
18 | trainer = object_dict["trainer"]
19 |
20 | if not trainer.logger:
21 | log.warning("Logger not found! Skipping hyperparameter logging...")
22 | return
23 |
24 | hparams["model"] = cfg["model"]
25 |
26 | # save number of model parameters
27 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
28 | hparams["model/params/trainable"] = sum(
29 | p.numel() for p in model.parameters() if p.requires_grad
30 | )
31 | hparams["model/params/non_trainable"] = sum(
32 | p.numel() for p in model.parameters() if not p.requires_grad
33 | )
34 |
35 | hparams["data"] = cfg["data"]
36 | hparams["trainer"] = cfg["trainer"]
37 |
38 | hparams["callbacks"] = cfg.get("callbacks")
39 | hparams["extras"] = cfg.get("extras")
40 |
41 | hparams["task_name"] = cfg.get("task_name")
42 | hparams["tags"] = cfg.get("tags")
43 | hparams["ckpt_path"] = cfg.get("ckpt_path")
44 | hparams["seed"] = cfg.get("seed")
45 |
46 | # send hparams to all loggers
47 | for logger in trainer.loggers:
48 | logger.log_hyperparams(hparams)
49 |
--------------------------------------------------------------------------------
/fish_speech/utils/rich_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Sequence
3 |
4 | import rich
5 | import rich.syntax
6 | import rich.tree
7 | from hydra.core.hydra_config import HydraConfig
8 | from lightning.pytorch.utilities import rank_zero_only
9 | from omegaconf import DictConfig, OmegaConf, open_dict
10 | from rich.prompt import Prompt
11 |
12 | from fish_speech.utils import logger as log
13 |
14 |
15 | @rank_zero_only
16 | def print_config_tree(
17 | cfg: DictConfig,
18 | print_order: Sequence[str] = (
19 | "data",
20 | "model",
21 | "callbacks",
22 | "logger",
23 | "trainer",
24 | "paths",
25 | "extras",
26 | ),
27 | resolve: bool = False,
28 | save_to_file: bool = False,
29 | ) -> None:
30 | """Prints content of DictConfig using Rich library and its tree structure.
31 |
32 | Args:
33 | cfg (DictConfig): Configuration composed by Hydra.
34 | print_order (Sequence[str], optional): Determines in what order config components are printed.
35 | resolve (bool, optional): Whether to resolve reference fields of DictConfig.
36 | save_to_file (bool, optional): Whether to export config to the hydra output folder.
37 | """ # noqa: E501
38 |
39 | style = "dim"
40 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
41 |
42 | queue = []
43 |
44 | # add fields from `print_order` to queue
45 | for field in print_order:
46 | (
47 | queue.append(field)
48 | if field in cfg
49 | else log.warning(
50 | f"Field '{field}' not found in config. "
51 | + f"Skipping '{field}' config printing..."
52 | )
53 | )
54 |
55 | # add all the other fields to queue (not specified in `print_order`)
56 | for field in cfg:
57 | if field not in queue:
58 | queue.append(field)
59 |
60 | # generate config tree from queue
61 | for field in queue:
62 | branch = tree.add(field, style=style, guide_style=style)
63 |
64 | config_group = cfg[field]
65 | if isinstance(config_group, DictConfig):
66 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
67 | else:
68 | branch_content = str(config_group)
69 |
70 | branch.add(rich.syntax.Syntax(branch_content, "yaml"))
71 |
72 | # print config tree
73 | rich.print(tree)
74 |
75 | # save config tree to file
76 | if save_to_file:
77 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
78 | rich.print(tree, file=file)
79 |
80 |
81 | @rank_zero_only
82 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
83 | """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
84 |
85 | if not cfg.get("tags"):
86 | if "id" in HydraConfig().cfg.hydra.job:
87 | raise ValueError("Specify tags before launching a multirun!")
88 |
89 | log.warning("No tags provided in config. Prompting user to input tags...")
90 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
91 | tags = [t.strip() for t in tags.split(",") if t != ""]
92 |
93 | with open_dict(cfg):
94 | cfg.tags = tags
95 |
96 | log.info(f"Tags: {cfg.tags}")
97 |
98 | if save_to_file:
99 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
100 | rich.print(cfg.tags, file=file)
101 |
--------------------------------------------------------------------------------
/fish_speech/utils/spectrogram.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio.functional as F
3 | from torch import Tensor, nn
4 | from torchaudio.transforms import MelScale
5 |
6 |
7 | class LinearSpectrogram(nn.Module):
8 | def __init__(
9 | self,
10 | n_fft=2048,
11 | win_length=2048,
12 | hop_length=512,
13 | center=False,
14 | mode="pow2_sqrt",
15 | ):
16 | super().__init__()
17 |
18 | self.n_fft = n_fft
19 | self.win_length = win_length
20 | self.hop_length = hop_length
21 | self.center = center
22 | self.mode = mode
23 | self.return_complex = True
24 |
25 | self.register_buffer("window", torch.hann_window(win_length), persistent=False)
26 |
27 | def forward(self, y: Tensor) -> Tensor:
28 | if y.ndim == 3:
29 | y = y.squeeze(1)
30 |
31 | y = torch.nn.functional.pad(
32 | y.unsqueeze(1),
33 | (
34 | (self.win_length - self.hop_length) // 2,
35 | (self.win_length - self.hop_length + 1) // 2,
36 | ),
37 | mode="reflect",
38 | ).squeeze(1)
39 |
40 | spec = torch.stft(
41 | y,
42 | self.n_fft,
43 | hop_length=self.hop_length,
44 | win_length=self.win_length,
45 | window=self.window,
46 | center=self.center,
47 | pad_mode="reflect",
48 | normalized=False,
49 | onesided=True,
50 | return_complex=self.return_complex,
51 | )
52 |
53 | if self.return_complex:
54 | spec = torch.view_as_real(spec)
55 |
56 | if self.mode == "pow2_sqrt":
57 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
58 |
59 | return spec
60 |
61 |
62 | class LogMelSpectrogram(nn.Module):
63 | def __init__(
64 | self,
65 | sample_rate=44100,
66 | n_fft=2048,
67 | win_length=2048,
68 | hop_length=512,
69 | n_mels=128,
70 | center=False,
71 | f_min=0.0,
72 | f_max=None,
73 | ):
74 | super().__init__()
75 |
76 | self.sample_rate = sample_rate
77 | self.n_fft = n_fft
78 | self.win_length = win_length
79 | self.hop_length = hop_length
80 | self.center = center
81 | self.n_mels = n_mels
82 | self.f_min = f_min
83 | self.f_max = f_max or float(sample_rate // 2)
84 |
85 | self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
86 |
87 | fb = F.melscale_fbanks(
88 | n_freqs=self.n_fft // 2 + 1,
89 | f_min=self.f_min,
90 | f_max=self.f_max,
91 | n_mels=self.n_mels,
92 | sample_rate=self.sample_rate,
93 | norm="slaney",
94 | mel_scale="slaney",
95 | )
96 | self.register_buffer(
97 | "fb",
98 | fb,
99 | persistent=False,
100 | )
101 |
102 | def compress(self, x: Tensor) -> Tensor:
103 | return torch.log(torch.clamp(x, min=1e-5))
104 |
105 | def decompress(self, x: Tensor) -> Tensor:
106 | return torch.exp(x)
107 |
108 | def apply_mel_scale(self, x: Tensor) -> Tensor:
109 | return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
110 |
111 | def forward(
112 | self, x: Tensor, return_linear: bool = False, sample_rate: int = None
113 | ) -> Tensor:
114 | if sample_rate is not None and sample_rate != self.sample_rate:
115 | x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
116 |
117 | linear = self.spectrogram(x)
118 | x = self.apply_mel_scale(linear)
119 | x = self.compress(x)
120 |
121 | if return_linear:
122 | return x, self.compress(linear)
123 |
124 | return x
125 |
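Shape-only sketch with the defaults above (1 second of synthetic 44.1 kHz audio, hop 512, 128 mel bins); not part of the file above:

    import torch

    mel_transform = LogMelSpectrogram()    # 44.1 kHz, n_fft 2048, hop 512, 128 mels
    wav = torch.randn(1, 44100)            # (batch, samples)
    mel = mel_transform(wav)               # (1, 128, ~86) log-compressed mel frames
    mel2, linear = mel_transform(wav, return_linear=True)  # also returns the compressed linear spectrogram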
--------------------------------------------------------------------------------
/fish_speech/utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 | from importlib.util import find_spec
4 | from typing import Callable
5 |
6 | import numpy as np
7 | import torch
8 | from omegaconf import DictConfig
9 |
10 | from .logger import RankedLogger
11 | from .rich_utils import enforce_tags, print_config_tree
12 |
13 | log = RankedLogger(__name__, rank_zero_only=True)
14 |
15 |
16 | def extras(cfg: DictConfig) -> None:
17 | """Applies optional utilities before the task is started.
18 |
19 | Utilities:
20 | - Ignoring python warnings
21 | - Setting tags from command line
22 | - Rich config printing
23 | """
24 |
25 | # return if no `extras` config
26 | if not cfg.get("extras"):
27 | log.warning("Extras config not found! ")
28 | return
29 |
30 | # disable python warnings
31 | if cfg.extras.get("ignore_warnings"):
32 | log.info("Disabling python warnings! ")
33 | warnings.filterwarnings("ignore")
34 |
35 | # prompt user to input tags from command line if none are provided in the config
36 | if cfg.extras.get("enforce_tags"):
37 | log.info("Enforcing tags! ")
38 | enforce_tags(cfg, save_to_file=True)
39 |
40 | # pretty print config tree using Rich library
41 | if cfg.extras.get("print_config"):
42 | log.info("Printing config tree with Rich! ")
43 | print_config_tree(cfg, resolve=True, save_to_file=True)
44 |
45 |
46 | def task_wrapper(task_func: Callable) -> Callable:
47 | """Optional decorator that controls the failure behavior when executing the task function.
48 |
49 | This wrapper can be used to:
50 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
51 | - save the exception to a `.log` file
52 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
53 | - etc. (adjust depending on your needs)
54 |
55 | Example:
56 | ```
57 | @utils.task_wrapper
58 | def train(cfg: DictConfig) -> Tuple[dict, dict]:
59 |
60 | ...
61 |
62 | return metric_dict, object_dict
63 | ```
64 | """ # noqa: E501
65 |
66 | def wrap(cfg: DictConfig):
67 | # execute the task
68 | try:
69 | metric_dict, object_dict = task_func(cfg=cfg)
70 |
71 | # things to do if exception occurs
72 | except Exception as ex:
73 | # save exception to `.log` file
74 | log.exception("")
75 |
76 | # some hyperparameter combinations might be invalid or
77 | # cause out-of-memory errors so when using hparam search
78 | # plugins like Optuna, you might want to disable
79 | # raising the below exception to avoid multirun failure
80 | raise ex
81 |
82 | # things to always do after either success or exception
83 | finally:
84 | # display output dir path in terminal
85 | log.info(f"Output dir: {cfg.paths.run_dir}")
86 |
87 | # always close wandb run (even if exception occurs so multirun won't fail)
88 | if find_spec("wandb"): # check if wandb is installed
89 | import wandb
90 |
91 | if wandb.run:
92 | log.info("Closing wandb!")
93 | wandb.finish()
94 |
95 | return metric_dict, object_dict
96 |
97 | return wrap
98 |
99 |
100 | def get_metric_value(metric_dict: dict, metric_name: str) -> float:
101 | """Safely retrieves value of the metric logged in LightningModule."""
102 |
103 | if not metric_name:
104 | log.info("Metric name is None! Skipping metric value retrieval...")
105 | return None
106 |
107 | if metric_name not in metric_dict:
108 | raise Exception(
109 | f"Metric value not found! \n"
110 | "Make sure metric name logged in LightningModule is correct!\n"
111 | "Make sure `optimized_metric` name in `hparams_search` config is correct!"
112 | )
113 |
114 | metric_value = metric_dict[metric_name].item()
115 | log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
116 |
117 | return metric_value
118 |
119 |
120 | def set_seed(seed: int):
121 | if seed < 0:
122 | seed = -seed
123 | if seed > (1 << 31):
124 | seed = 1 << 31
125 |
126 | random.seed(seed)
127 | np.random.seed(seed)
128 | torch.manual_seed(seed)
129 |
130 | if torch.cuda.is_available():
131 | torch.cuda.manual_seed(seed)
132 | torch.cuda.manual_seed_all(seed)
133 |
134 | if torch.backends.cudnn.is_available():
135 | torch.backends.cudnn.deterministic = True
136 | torch.backends.cudnn.benchmark = False
137 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Fish Speech
2 | site_description: Targeting SOTA TTS solutions.
3 | site_url: https://speech.fish.audio
4 |
5 | # Repository
6 | repo_name: fishaudio/fish-speech
7 | repo_url: https://github.com/fishaudio/fish-speech
8 | edit_uri: blob/main/docs
9 |
10 | # Copyright
11 | copyright: Copyright © 2023-2025 by Fish Audio
12 |
13 | theme:
14 | name: material
15 | favicon: assets/figs/logo-circle.png
16 | language: en
17 | features:
18 | - content.action.edit
19 | - content.action.view
20 | - navigation.tracking
21 | - navigation.footer
22 | # - navigation.tabs
23 | - search
24 | - search.suggest
25 | - search.highlight
26 | - search.share
27 | - content.code.copy
28 | icon:
29 | logo: fontawesome/solid/fish
30 |
31 | palette:
32 | # Palette toggle for automatic mode
33 | - media: "(prefers-color-scheme)"
34 | toggle:
35 | icon: material/brightness-auto
36 | name: Switch to light mode
37 |
38 | # Palette toggle for light mode
39 | - media: "(prefers-color-scheme: light)"
40 | scheme: default
41 | toggle:
42 | icon: material/brightness-7
43 | name: Switch to dark mode
44 | primary: black
45 | font:
46 | code: Roboto Mono
47 |
48 | # Palette toggle for dark mode
49 | - media: "(prefers-color-scheme: dark)"
50 | scheme: slate
51 | toggle:
52 | icon: material/brightness-4
53 | name: Switch to light mode
54 | primary: black
55 | font:
56 | code: Roboto Mono
57 |
58 | nav:
59 | - Introduction: index.md
60 | - Finetune: finetune.md
61 | - Inference: inference.md
62 | - Start Agent: start_agent.md
63 | - Samples: samples.md
64 |
65 | # Plugins
66 | plugins:
67 | - search:
68 | separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])'
69 | lang:
70 | - en
71 | - zh
72 | - ja
73 | - pt
74 | - ko
75 | - i18n:
76 | docs_structure: folder
77 | languages:
78 | - locale: en
79 | name: English
80 | default: true
81 | build: true
82 | - locale: zh
83 | name: 简体中文
84 | build: true
85 | nav:
86 | - 介绍: zh/index.md
87 | - 微调: zh/finetune.md
88 | - 推理: zh/inference.md
89 | - 启动Agent: zh/start_agent.md
90 | - 例子: zh/samples.md
91 | - locale: ja
92 | name: 日本語
93 | build: true
94 | nav:
95 | - Fish Speech の紹介: ja/index.md
96 | - 微調整: ja/finetune.md
97 | - 推論: ja/inference.md
98 | - スタートエージェント: ja/start_agent.md
99 | - サンプル: ja/samples.md
100 | - locale: pt
101 | name: Português (Brasil)
102 | build: true
103 | nav:
104 | - Introdução: pt/index.md
105 | - Ajuste Fino: pt/finetune.md
106 | - Inferência: pt/inference.md
107 | - Agente inicial: pt/start_agent.md
108 | - Amostras: pt/samples.md
109 | - locale: ko
110 | name: 한국어
111 | build: true
112 | nav:
113 | - 소개: ko/index.md
114 | - 파인튜닝: ko/finetune.md
115 | - 추론: ko/inference.md
116 | - 샘플: ko/samples.md
117 |
118 | markdown_extensions:
119 | - pymdownx.highlight:
120 | anchor_linenums: true
121 | line_spans: __span
122 | pygments_lang_class: true
123 | - pymdownx.inlinehilite
124 | - pymdownx.snippets
125 | - pymdownx.superfences
126 | - admonition
127 | - pymdownx.details
128 | - pymdownx.superfences
129 | - attr_list
130 | - md_in_html
131 | - pymdownx.superfences
132 |
133 | extra_css:
134 | - stylesheets/extra.css
135 |
136 | extra:
137 | social:
138 | - icon: fontawesome/brands/discord
139 | link: https://discord.gg/Es5qTB9BcN
140 | - icon: fontawesome/brands/docker
141 | link: https://hub.docker.com/r/fishaudio/fish-speech
142 | - icon: fontawesome/brands/qq
143 | link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093
144 | homepage: https://speech.fish.audio
145 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "fish-speech"
3 | version = "0.1.0"
4 | authors = [
5 | {name = "Lengyue", email = "lengyue@lengyue.me"},
6 | ]
7 | description = "Fish Speech"
8 | readme = "README.md"
9 | requires-python = ">=3.10"
10 | keywords = ["TTS", "Speech"]
11 | license = {text = "Apache-2.0"}
12 | classifiers = [
13 | "Programming Language :: Python :: 3",
14 | ]
15 | dependencies = [
16 | "numpy<=1.26.4",
17 | "transformers>=4.45.2",
18 | "datasets==2.18.0",
19 | "lightning>=2.1.0",
20 | "hydra-core>=1.3.2",
21 | "tensorboard>=2.14.1",
22 | "natsort>=8.4.0",
23 | "einops>=0.7.0",
24 | "librosa>=0.10.1",
25 | "rich>=13.5.3",
26 | "gradio>5.0.0",
27 | "wandb>=0.15.11",
28 | "grpcio>=1.58.0",
29 | "kui>=1.6.0",
30 | "uvicorn>=0.30.0",
31 | "loguru>=0.6.0",
32 | "loralib>=0.1.2",
33 | "pyrootutils>=1.0.4",
34 | "vector_quantize_pytorch==1.14.24",
35 | "resampy>=0.4.3",
36 | "einx[torch]==0.2.2",
37 | "zstandard>=0.22.0",
38 | "pydub",
39 | "pyaudio",
40 | "faster_whisper",
41 | "modelscope==1.17.1",
42 | "funasr==1.1.5",
43 | "opencc-python-reimplemented==0.1.7",
44 | "silero-vad",
45 | "ormsgpack",
46 | "tiktoken>=0.8.0",
47 | "pydantic==2.9.2",
48 | "cachetools",
49 | ]
50 |
51 | [project.optional-dependencies]
52 | stable = [
53 | "torch<=2.4.1",
54 | "torchaudio",
55 | ]
56 |
57 | [build-system]
58 | requires = ["setuptools", "setuptools-scm"]
59 | build-backend = "setuptools.build_meta"
60 |
61 | [tool.setuptools]
62 | packages = ["fish_speech", "tools"]
63 |
--------------------------------------------------------------------------------
/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "exclude": [
3 | "data",
4 | "filelists"
5 | ]
6 | }
7 |
--------------------------------------------------------------------------------
/tools/api_server.py:
--------------------------------------------------------------------------------
1 | import re
2 | from threading import Lock
3 |
4 | import pyrootutils
5 | import uvicorn
6 | from kui.asgi import (
7 | Depends,
8 | FactoryClass,
9 | HTTPException,
10 | HttpRoute,
11 | Kui,
12 | OpenAPI,
13 | Routes,
14 | )
15 | from kui.cors import CORSConfig
16 | from kui.openapi.specification import Info
17 | from kui.security import bearer_auth
18 | from loguru import logger
19 | from typing_extensions import Annotated
20 |
21 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
22 |
23 | from tools.server.api_utils import MsgPackRequest, parse_args
24 | from tools.server.exception_handler import ExceptionHandler
25 | from tools.server.model_manager import ModelManager
26 | from tools.server.views import routes
27 |
28 |
29 | class API(ExceptionHandler):
30 | def __init__(self):
31 | self.args = parse_args()
32 |
33 | def api_auth(endpoint):
34 | async def verify(token: Annotated[str, Depends(bearer_auth)]):
35 | if token != self.args.api_key:
36 | raise HTTPException(401, None, "Invalid token")
37 | return await endpoint()
38 |
39 | async def passthrough():
40 | return await endpoint()
41 |
42 | if self.args.api_key is not None:
43 | return verify
44 | else:
45 | return passthrough
46 |
47 | self.routes = Routes(
48 | routes, # keep existing routes
49 | http_middlewares=[api_auth], # apply api_auth middleware
50 | )
51 |
52 | # OpenAPI configuration
53 | self.openapi = OpenAPI(
54 | Info(
55 | {
56 | "title": "Fish Speech API",
57 | "version": "1.5.0",
58 | }
59 | ),
60 | ).routes
61 |
62 | # Initialize the app
63 | self.app = Kui(
64 | routes=self.routes + self.openapi[1:], # Remove the default route
65 | exception_handlers={
66 | HTTPException: self.http_exception_handler,
67 | Exception: self.other_exception_handler,
68 | },
69 | factory_class=FactoryClass(http=MsgPackRequest),
70 | cors_config=CORSConfig(),
71 | )
72 |
73 | # Add the state variables
74 | self.app.state.lock = Lock()
75 | self.app.state.device = self.args.device
76 | self.app.state.max_text_length = self.args.max_text_length
77 |
78 | # Associate the app with the model manager
79 | self.app.on_startup(self.initialize_app)
80 |
81 | async def initialize_app(self, app: Kui):
82 | # Make the ModelManager available to the views
83 | app.state.model_manager = ModelManager(
84 | mode=self.args.mode,
85 | device=self.args.device,
86 | half=self.args.half,
87 | compile=self.args.compile,
88 | asr_enabled=self.args.load_asr_model,
89 | llama_checkpoint_path=self.args.llama_checkpoint_path,
90 | decoder_checkpoint_path=self.args.decoder_checkpoint_path,
91 | decoder_config_name=self.args.decoder_config_name,
92 | )
93 |
94 | logger.info(f"Startup done, listening server at http://{self.args.listen}")
95 |
96 |
97 | # Each worker process created by Uvicorn has its own memory space,
98 | # meaning that models and variables are not shared between processes.
99 | # Therefore, any variables (like `llama_queue` or `decoder_model`)
100 | # will not be shared across workers.
101 |
102 | # Multi-threading for deep learning can cause issues, such as inconsistent
103 | # outputs if multiple threads access the same buffers simultaneously.
104 | # Instead, it's better to use multiprocessing or independent models per thread.
105 |
106 | if __name__ == "__main__":
107 | api = API()
108 |
109 | # IPv6 address format is [xxxx:xxxx::xxxx]:port
110 | match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen)
111 | if match:
112 | host, port = match.groups() # IPv6
113 | else:
114 | host, port = api.args.listen.split(":") # IPv4
115 |
116 | uvicorn.run(
117 | api.app,
118 | host=host,
119 | port=int(port),
120 | workers=api.args.workers,
121 | log_level="info",
122 | )
123 |
--------------------------------------------------------------------------------
/tools/download_models.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from huggingface_hub import hf_hub_download
4 |
5 |
6 | # Download the listed files from a Hugging Face repo if they are missing locally
7 | def check_and_download_files(repo_id, file_list, local_dir):
8 | os.makedirs(local_dir, exist_ok=True)
9 | for file in file_list:
10 | file_path = os.path.join(local_dir, file)
11 | if not os.path.exists(file_path):
12 |             print(f"{file} not found, downloading from the Hugging Face repository...")
13 | hf_hub_download(
14 | repo_id=repo_id,
15 | filename=file,
16 | resume_download=True,
17 | local_dir=local_dir,
18 | local_dir_use_symlinks=False,
19 | )
20 | else:
21 |             print(f"{file} already exists, skipping download.")
22 |
23 |
24 | # 1st
25 | repo_id_1 = "fishaudio/fish-speech-1.5"
26 | local_dir_1 = "./checkpoints/fish-speech-1.5"
27 | files_1 = [
28 | ".gitattributes",
29 | "model.pth",
30 | "README.md",
31 | "special_tokens.json",
32 | "tokenizer.tiktoken",
33 | "config.json",
34 | "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
35 | ]
36 |
37 | # 3rd
38 | repo_id_3 = "fishaudio/fish-speech-1"
39 | local_dir_3 = "./"
40 | files_3 = [
41 | "ffmpeg.exe",
42 | "ffprobe.exe",
43 | ]
44 |
45 | # 4th
46 | repo_id_4 = "SpicyqSama007/fish-speech-packed"
47 | local_dir_4 = "./"
48 | files_4 = [
49 | "asr-label-win-x64.exe",
50 | ]
51 |
52 | check_and_download_files(repo_id_1, files_1, local_dir_1)
53 |
54 | check_and_download_files(repo_id_3, files_3, local_dir_3)
55 | check_and_download_files(repo_id_4, files_4, local_dir_4)
56 |
--------------------------------------------------------------------------------
/tools/extract_model.py:
--------------------------------------------------------------------------------
1 | import click
2 | import torch
3 | from loguru import logger
4 |
5 |
6 | @click.command()
7 | @click.argument("model_path")
8 | @click.argument("output_path")
9 | def main(model_path, output_path):
10 | if model_path == output_path:
11 | logger.error("Model path and output path are the same")
12 | return
13 |
14 | logger.info(f"Loading model from {model_path}")
15 | state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
16 | torch.save(state_dict, output_path)
17 | logger.info(f"Model saved to {output_path}")
18 |
19 |
20 | if __name__ == "__main__":
21 | main()
22 |
--------------------------------------------------------------------------------
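extract_model.py keeps only the `state_dict` entry of a Lightning-style checkpoint, so the output loads like any plain PyTorch weight file. A hedged sketch of reading the result back (the path and model class are placeholders, not part of the repo):

```python
# Placeholder sketch: load the extracted weights back into a model.
import torch

state_dict = torch.load("extracted.pth", map_location="cpu")  # placeholder path
print(f"{len(state_dict)} tensors, first key: {next(iter(state_dict))}")

# model = SomeModel(...)                        # whichever architecture produced it
# model.load_state_dict(state_dict, strict=True)
```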
/tools/llama/merge_lora.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from copy import deepcopy
3 | from pathlib import Path
4 |
5 | import click
6 | import hydra
7 | import torch
8 | from hydra import compose, initialize
9 | from hydra.utils import instantiate
10 | from loguru import logger
11 |
12 | from fish_speech.models.text2semantic.llama import BaseTransformer
13 | from fish_speech.models.text2semantic.lora import get_merged_state_dict
14 |
15 |
16 | @click.command()
17 | @click.option("--lora-config", type=str, default="r_8_alpha_16")
18 | @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
19 | @click.option("--lora-weight", type=str, required=True)
20 | @click.option("--output", type=str, required=True)
21 | def merge(lora_config, base_weight, lora_weight, output):
22 | output = Path(output)
23 | logger.info(
24 | f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
25 | )
26 |
27 | with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
28 | cfg = compose(config_name=lora_config)
29 |
30 | lora_config = instantiate(cfg)
31 | logger.info(f"Loaded lora model with config {lora_config}")
32 |
33 | llama_model = BaseTransformer.from_pretrained(
34 | path=base_weight,
35 | load_weights=True,
36 | lora_config=lora_config,
37 | )
38 |     logger.info("Loaded llama model")
39 |
40 | llama_state_dict = llama_model.state_dict()
41 | llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
42 | llama_state_dict_copy = deepcopy(llama_state_dict)
43 | lora_state_dict = torch.load(lora_weight, map_location="cpu", weights_only=False)
44 |
45 | if "state_dict" in llama_state_dict:
46 | llama_state_dict = llama_state_dict["state_dict"]
47 |
48 | if "state_dict" in lora_state_dict:
49 | lora_state_dict = lora_state_dict["state_dict"]
50 |
51 | # remove prefix model.
52 | if any(k.startswith("model.") for k in llama_state_dict.keys()):
53 | llama_state_dict = {
54 | k.replace("model.", ""): v
55 | for k, v in llama_state_dict.items()
56 | if k.startswith("model.")
57 | }
58 | if any(k.startswith("model.") for k in lora_state_dict.keys()):
59 | lora_state_dict = {
60 | k.replace("model.", ""): v
61 | for k, v in lora_state_dict.items()
62 | if k.startswith("model.")
63 | }
64 |
65 | logger.info(f"Found {len(llama_state_dict)} keys in llama model")
66 | logger.info(f"Found {len(lora_state_dict)} keys in lora model")
67 |
68 | merged_state_dict = llama_state_dict | lora_state_dict
69 | llama_model.load_state_dict(merged_state_dict, strict=True)
70 |     logger.info("Merged model loaded")
71 |
72 | # Trigger eval mode to merge lora
73 | llama_model.eval()
74 | llama_model.save_pretrained(output, drop_lora=True)
75 | logger.info(f"Saved merged model to {output}, validating")
76 |
77 | new_state_dict = torch.load(output / "model.pth", map_location="cpu")
78 | original_keys = set(llama_state_dict_copy.keys())
79 |
80 | tolerance = 1e-5
81 | for key in original_keys:
82 | diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
83 | if diff_l1 > tolerance:
84 | logger.info(f"Significant difference found in key: {key}")
85 | break
86 |
87 | if diff_l1 <= tolerance:
88 | logger.warning(
89 | "Merged model seems identical to the original model. Further validation might be needed."
90 | )
91 | else:
92 | logger.info("Merged model is different from the original model, check passed")
93 |
94 |
95 | if __name__ == "__main__":
96 | merge()
97 |
--------------------------------------------------------------------------------
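The merge above leans on Python 3.9+ dict union: `llama_state_dict | lora_state_dict` keeps every base key, adds the LoRA adapter keys, and lets any key present in both take the right-hand (LoRA checkpoint) value; the actual folding of the adapters into the base weights then happens when `eval()` and `save_pretrained(drop_lora=True)` run. A tiny sketch of the union semantics, with illustrative key names only:

```python
# Dict union: right-hand values win on overlapping keys, new keys are added.
base = {"linear.weight": "base", "embed.weight": "base"}
lora = {"linear.lora_A": "adapter", "linear.lora_B": "adapter"}

merged = base | lora
print(sorted(merged))
# ['embed.weight', 'linear.lora_A', 'linear.lora_B', 'linear.weight']
```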
/tools/run_webui.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from pathlib import Path
4 |
5 | import pyrootutils
6 | import torch
7 | from loguru import logger
8 |
9 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
10 |
11 | from fish_speech.inference_engine import TTSInferenceEngine
12 | from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
13 | from fish_speech.models.vqgan.inference import load_model as load_decoder_model
14 | from fish_speech.utils.schema import ServeTTSRequest
15 | from tools.webui import build_app
16 | from tools.webui.inference import get_inference_wrapper
17 |
18 | # Make einx happy
19 | os.environ["EINX_FILTER_TRACEBACK"] = "false"
20 |
21 |
22 | def parse_args():
23 | parser = ArgumentParser()
24 | parser.add_argument(
25 | "--llama-checkpoint-path",
26 | type=Path,
27 | default="checkpoints/fish-speech-1.5",
28 | )
29 | parser.add_argument(
30 | "--decoder-checkpoint-path",
31 | type=Path,
32 | default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
33 | )
34 | parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
35 | parser.add_argument("--device", type=str, default="cuda")
36 | parser.add_argument("--half", action="store_true")
37 | parser.add_argument("--compile", action="store_true")
38 | parser.add_argument("--max-gradio-length", type=int, default=0)
39 | parser.add_argument("--theme", type=str, default="light")
40 |
41 | return parser.parse_args()
42 |
43 |
44 | if __name__ == "__main__":
45 | args = parse_args()
46 | args.precision = torch.half if args.half else torch.bfloat16
47 |
48 | # Check if MPS or CUDA is available
49 | if torch.backends.mps.is_available():
50 | args.device = "mps"
51 | logger.info("mps is available, running on mps.")
52 | elif not torch.cuda.is_available():
53 | logger.info("CUDA is not available, running on CPU.")
54 | args.device = "cpu"
55 |
56 | logger.info("Loading Llama model...")
57 | llama_queue = launch_thread_safe_queue(
58 | checkpoint_path=args.llama_checkpoint_path,
59 | device=args.device,
60 | precision=args.precision,
61 | compile=args.compile,
62 | )
63 |
64 | logger.info("Loading VQ-GAN model...")
65 | decoder_model = load_decoder_model(
66 | config_name=args.decoder_config_name,
67 | checkpoint_path=args.decoder_checkpoint_path,
68 | device=args.device,
69 | )
70 |
71 | logger.info("Decoder model loaded, warming up...")
72 |
73 | # Create the inference engine
74 | inference_engine = TTSInferenceEngine(
75 | llama_queue=llama_queue,
76 | decoder_model=decoder_model,
77 | compile=args.compile,
78 | precision=args.precision,
79 | )
80 |
81 | # Dry run to check if the model is loaded correctly and avoid the first-time latency
82 | list(
83 | inference_engine.inference(
84 | ServeTTSRequest(
85 | text="Hello world.",
86 | references=[],
87 | reference_id=None,
88 | max_new_tokens=1024,
89 | chunk_length=200,
90 | top_p=0.7,
91 | repetition_penalty=1.5,
92 | temperature=0.7,
93 | format="wav",
94 | )
95 | )
96 | )
97 |
98 | logger.info("Warming up done, launching the web UI...")
99 |
100 | # Get the inference function with the immutable arguments
101 | inference_fct = get_inference_wrapper(inference_engine)
102 |
103 | app = build_app(inference_fct, args.theme)
104 | app.launch(show_api=True)
105 |
--------------------------------------------------------------------------------
/tools/server/agent/__init__.py:
--------------------------------------------------------------------------------
1 | import struct
2 | from functools import partial
3 |
4 | import ormsgpack
5 |
6 | from tools.server.agent.generate import generate_responses
7 | from tools.server.agent.pre_generation_utils import prepare_messages
8 |
9 |
10 | def execute_request(input_queue, tokenizer, config, request, device):
11 | """
12 | This function prepares the conversation, encodes the request,
13 | sends the generation request, and handles decoding/streaming.
14 | It returns a response generator (ServeResponse or ServeStreamResponse).
15 | """
16 | prompt, im_end_id = prepare_messages(request, tokenizer, config)
17 | yield from generate_responses(
18 | input_queue, tokenizer, config, request, prompt, im_end_id, device
19 | )
20 |
21 |
22 | def response_generator(req, llama_queue, tokenizer, config, device):
23 | """
24 | Non-streaming response wrapper for the chat endpoint.
25 | Only returns the final result.
26 | """
27 | generator = execute_request(llama_queue, tokenizer, config, req, device)
28 | return next(generator)
29 |
30 |
31 | async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
32 | """
33 | Streaming response wrapper for the chat endpoint.
34 | Returns the response in chunks.
35 | """
36 | generator = execute_request(llama_queue, tokenizer, config, req, device)
37 | for i in generator:
38 | if json_mode:
39 | body = i.model_dump_json().encode("utf-8")
40 | yield b"data: " + body + b"\n\n"
41 | else:
42 | body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
43 | yield struct.pack("I", len(body)) + body
44 |
45 |
46 | def get_response_generator(
47 | llama_queue, tokenizer, config, req, device, json_mode
48 | ) -> partial:
49 | """
50 | Get the correct response generator based on the request.
51 | """
52 | if not req.streaming:
53 | return partial(response_generator, req, llama_queue, tokenizer, config, device)
54 | else:
55 | return partial(
56 | streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
57 | )
58 |
--------------------------------------------------------------------------------
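When `json_mode` is off, `streaming_generator` frames each msgpack chunk with a length prefix built by `struct.pack("I", len(body))`, a native-endian unsigned int (4 bytes on common platforms). A hedged client-side sketch of reading that framing back, using an in-memory buffer in place of the real HTTP response stream:

```python
# Sketch of decoding the length-prefixed msgpack frames produced by
# streaming_generator when json_mode is off.
import io
import struct

import ormsgpack


def iter_frames(stream):
    while True:
        header = stream.read(4)
        if len(header) < 4:
            return  # end of stream
        (length,) = struct.unpack("I", header)
        yield ormsgpack.unpackb(stream.read(length))


# Round-trip two frames through an in-memory buffer.
buf = io.BytesIO()
for payload in ({"delta": "hel"}, {"delta": "lo"}):
    body = ormsgpack.packb(payload)
    buf.write(struct.pack("I", len(body)) + body)
buf.seek(0)

print(list(iter_frames(buf)))  # [{'delta': 'hel'}, {'delta': 'lo'}]
```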
/tools/server/agent/generate.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from fish_speech.utils.schema import ServeMessage, ServeResponse, ServeStreamResponse
4 | from tools.server.agent.generation_utils import (
5 | initialize_decode_buffers,
6 | process_response_tokens,
7 | send_reset_buffer,
8 | )
9 | from tools.server.agent.pre_generation_utils import (
10 | create_generation_request,
11 | send_generation_request,
12 | )
13 |
14 |
15 | def generate_responses(
16 | input_queue, tokenizer, config, request, prompt, im_end_id, device
17 | ):
18 | """
19 | Main generation function that handles the conversation, encodes the request,
20 | sends the generation request, and handles decoding/streaming.
21 | It returns a response generator (ServeResponse or ServeStreamResponse).
22 | """
23 | stats = {}
24 | start = time.time()
25 | stats["start_time"] = start
26 | stats["tokens_count"] = 0
27 |
28 | # Prepare and send the generation request
29 | req = create_generation_request(prompt, request, im_end_id, device)
30 | response_queue = send_generation_request(input_queue, req)
31 | decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
32 |
33 | while True:
34 | response = response_queue.get()
35 |
36 | # Handle abnormal finish or error
37 | if response in ["stop", "error"]:
38 | finish_reason = response
39 | break
40 |
41 | # Process the response tokens
42 | is_first_token = stats["tokens_count"] == 0
43 | responses = process_response_tokens(
44 | response,
45 | tokenizer,
46 | config,
47 | request,
48 | decode_buffer,
49 | parts,
50 | finished,
51 | im_end_id,
52 | stats,
53 | start,
54 | is_first_token,
55 | )
56 |
57 | # Yield the responses if streaming
58 | if request.streaming and responses:
59 | for r in responses:
60 | yield r
61 |
62 | stats["tokens_count"] += 1
63 |
64 | # Check if all samples are finished
65 | if all(finished):
66 | finish_reason = "stop"
67 | break
68 |
69 | # Finalize the response
70 | final_responses = finalize_response(
71 | request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
72 | )
73 | for fr in final_responses:
74 | yield fr
75 |
76 |
77 | def finalize_response(
78 | request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
79 | ):
80 | """
81 | Finalize the response by sending the remaining text buffers.
82 | """
83 | responses = []
84 |
85 | # Send the remaining text buffers
86 | for sample_id in range(request.num_samples):
87 | responses.extend(
88 | send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
89 | )
90 |
91 | # Calculate the final stats
92 | stats["total_time"] = (time.time() - stats["start_time"]) * 1000
93 | stats["total_tokens"] = stats["tokens_count"]
94 |
95 | # If streaming, send the final chunks for each sample
96 | if request.streaming:
97 | for sample_id in range(request.num_samples):
98 | if finished[sample_id]:
99 | continue
100 | responses.append(
101 | ServeStreamResponse(
102 | finish_reason=finish_reason, stats=stats, sample_id=sample_id
103 | )
104 | )
105 | else:
106 | # If not streaming, send the full messages for each sample
107 | full_messages = [
108 | ServeMessage(role="assistant", parts=parts[i])
109 | for i in range(request.num_samples)
110 | ]
111 | responses.append(
112 | ServeResponse(
113 | messages=full_messages,
114 | finish_reason=finish_reason,
115 | stats=stats,
116 | )
117 | )
118 |
119 | return responses
120 |
--------------------------------------------------------------------------------
/tools/server/agent/generation_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from fish_speech.utils.schema import (
4 | ServeStreamDelta,
5 | ServeStreamResponse,
6 | ServeTextPart,
7 | ServeVQPart,
8 | )
9 |
10 |
11 | def initialize_decode_buffers(num_samples):
12 | """Initialise the decode buffers for each sample."""
13 | decode_buffer = [[] for _ in range(num_samples)]
14 | parts = [[] for _ in range(num_samples)]
15 | finished = [False for _ in range(num_samples)]
16 | return decode_buffer, parts, finished
17 |
18 |
19 | def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
20 | """Send the remaining text buffer for a sample."""
21 | if len(decode_buffer[sample_id]) == 0:
22 | return []
23 |
24 | decoded = tokenizer.decode(decode_buffer[sample_id])
25 | part = ServeTextPart(text=decoded)
26 |
27 | responses = []
28 | if request.streaming:
29 | responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
30 | else:
31 | parts[sample_id].append(part)
32 |
33 | decode_buffer[sample_id] = []
34 | return responses
35 |
36 |
37 | def handle_semantic_tokens(tokens, config, sample_id, parts, request):
38 | """Handle the semantic tokens returned by the model."""
39 | responses = []
40 | _tokens = tokens[1:].clone()
41 |
42 | if not config.share_codebook_embeddings:
43 | for i in range(len(_tokens)):
44 | _tokens[i] -= config.codebook_size * i
45 |
46 | # If streaming, send the VQ parts directly
47 | if request.streaming:
48 | responses.append(
49 | ServeStreamResponse(
50 | sample_id=sample_id,
51 | delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
52 | )
53 | )
54 | else:
55 | # If not streaming, accumulate the VQ parts
56 | if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
57 | parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
58 | else:
59 | # Accumulate the codes
60 | for codebook_id, value in enumerate(_tokens):
61 | parts[sample_id][-1].codes[codebook_id].append(value.item())
62 |
63 | return responses
64 |
65 |
66 | def process_response_tokens(
67 | response,
68 | tokenizer,
69 | config,
70 | request,
71 | decode_buffer,
72 | parts,
73 | finished,
74 | im_end_id,
75 | stats,
76 | start,
77 | is_first_token,
78 | ):
79 | """Process the response tokens returned by the model."""
80 | responses = []
81 | for sample_id, tokens in enumerate(response):
82 | if finished[sample_id]:
83 | continue
84 |
85 | # End of the conversation
86 | if tokens[0] == im_end_id:
87 | finished[sample_id] = True
88 | # Send the remaining text buffer
89 | responses.extend(
90 | send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
91 | )
92 | if request.streaming:
93 | responses.append(
94 | ServeStreamResponse(
95 | sample_id=sample_id,
96 | finish_reason="stop",
97 | stats=stats,
98 | )
99 | )
100 | continue
101 |
102 | # Check if the token is semantic
103 | is_semantic = (
104 | tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
105 | )
106 |
107 | if is_semantic:
108 | # Before the semantic tokens, send the remaining text buffer
109 | responses.extend(
110 | send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
111 | )
112 | responses.extend(
113 | handle_semantic_tokens(tokens, config, sample_id, parts, request)
114 | )
115 | else:
116 | # Accumulate the text tokens (not implemented?)
117 | decode_buffer[sample_id].append(tokens[0, 0])
118 |
119 | if is_first_token:
120 | stats["time_to_first_token"] = (time.time() - start) * 1000
121 |
122 | return responses
123 |
--------------------------------------------------------------------------------
/tools/server/agent/pre_generation_utils.py:
--------------------------------------------------------------------------------
1 | import queue
2 |
3 | from fish_speech.conversation import Conversation, Message
4 | from fish_speech.models.text2semantic.inference import GenerateRequest
5 | from fish_speech.tokenizer import IM_END_TOKEN
6 |
7 |
8 | def prepare_messages(request, tokenizer, config):
9 | """
10 | Reorganise the provided list of messages into a conversation.
11 | Encode the conversation for inference.
12 | """
13 | # Convert the messages to ConversationMessage objects
14 | messages = [msg.to_conversation_message() for msg in request.messages]
15 |
16 | if len(messages) < 1:
17 | raise ValueError("At least one message is required")
18 |
19 | # Check the last message to determine the next step
20 | last_role = messages[-1].role
21 | match last_role:
22 | case "user":
23 | # The last message is from the user, ask the assistant to respond with a new message
24 | messages.append(
25 | Message(role="assistant", parts=[], add_im_end=False, modality="voice")
26 | )
27 | case "raw":
28 | # The last message is raw text, ask the assistant to complete it
29 | messages[-1].add_im_start = False
30 | messages[-1].add_im_end = False
31 | messages[-1].modality = "voice"
32 | case "assistant":
33 | # The last message is from the assistant, ask the assistant to continue
34 | messages[-1].add_im_end = False
35 | case _:
36 | # We expect it to be assistant if not user or raw
37 | raise ValueError("The last message must be from the assistant, user or raw")
38 |
39 | # Create a conversation object and encode it for inference
40 | conv = Conversation(messages=messages)
41 | prompt = conv.encode_for_inference(
42 | tokenizer=tokenizer, num_codebooks=config.num_codebooks
43 | )
44 | im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
45 |
46 | return prompt, im_end_id
47 |
48 |
49 | def create_generation_request(prompt, request, im_end_id, device):
50 | """
51 | Convert the request into a dictionary that can be sent to the model for generation.
52 | """
53 | req = {
54 | "prompt": prompt.to(device),
55 | "max_new_tokens": request.max_new_tokens,
56 | "im_end_id": im_end_id,
57 | "temperature": request.temperature,
58 | "top_p": request.top_p,
59 | "repetition_penalty": request.repetition_penalty,
60 | "num_samples": request.num_samples,
61 | "early_stop_threshold": request.early_stop_threshold,
62 | }
63 | return req
64 |
65 |
66 | def send_generation_request(input_queue, req):
67 | """
68 | Send the generation request to the model and return a queue to get the response.
69 | """
70 | response_queue = queue.Queue()
71 | input_queue.put(GenerateRequest(req, response_queue))
72 | return response_queue
73 |
--------------------------------------------------------------------------------
/tools/server/api_utils.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from http import HTTPStatus
3 | from typing import Annotated, Any
4 |
5 | import ormsgpack
6 | from baize.datastructures import ContentType
7 | from kui.asgi import HTTPException, HttpRequest
8 |
9 | from fish_speech.inference_engine import TTSInferenceEngine
10 | from fish_speech.utils.schema import ServeTTSRequest
11 | from tools.server.inference import inference_wrapper as inference
12 |
13 |
14 | def parse_args():
15 | parser = ArgumentParser()
16 | parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
17 | parser.add_argument("--load-asr-model", action="store_true")
18 | parser.add_argument(
19 | "--llama-checkpoint-path",
20 | type=str,
21 | default="checkpoints/fish-speech-1.5",
22 | )
23 | parser.add_argument(
24 | "--decoder-checkpoint-path",
25 | type=str,
26 | default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
27 | )
28 | parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
29 | parser.add_argument("--device", type=str, default="cuda")
30 | parser.add_argument("--half", action="store_true")
31 | parser.add_argument("--compile", action="store_true")
32 | parser.add_argument("--max-text-length", type=int, default=0)
33 | parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
34 | parser.add_argument("--workers", type=int, default=1)
35 | parser.add_argument("--api-key", type=str, default=None)
36 |
37 | return parser.parse_args()
38 |
39 |
40 | class MsgPackRequest(HttpRequest):
41 | async def data(
42 | self,
43 | ) -> Annotated[
44 | Any, ContentType("application/msgpack"), ContentType("application/json")
45 | ]:
46 | if self.content_type == "application/msgpack":
47 | return ormsgpack.unpackb(await self.body)
48 |
49 | elif self.content_type == "application/json":
50 | return await self.json
51 |
52 | raise HTTPException(
53 | HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
54 | headers={"Accept": "application/msgpack, application/json"},
55 | )
56 |
57 |
58 | async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
59 | for chunk in inference(req, engine):
60 | if isinstance(chunk, bytes):
61 | yield chunk
62 |
63 |
64 | async def buffer_to_async_generator(buffer):
65 | yield buffer
66 |
67 |
68 | def get_content_type(audio_format):
69 | if audio_format == "wav":
70 | return "audio/wav"
71 | elif audio_format == "flac":
72 | return "audio/flac"
73 | elif audio_format == "mp3":
74 | return "audio/mpeg"
75 | else:
76 | return "application/octet-stream"
77 |
--------------------------------------------------------------------------------
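`MsgPackRequest.data` accepts either `application/msgpack` (decoded with `ormsgpack.unpackb`) or `application/json`; anything else is rejected with 415. A sketch of building both request bodies on the client side (the target route and host are placeholders, since the actual paths are registered in tools/server/views.py):

```python
# Build the two body formats MsgPackRequest understands.
import json

import ormsgpack

payload = {"text": "Hello world.", "format": "wav"}

msgpack_body = ormsgpack.packb(payload)          # Content-Type: application/msgpack
json_body = json.dumps(payload).encode("utf-8")  # Content-Type: application/json

# Purely illustrative (requests is not a stated project dependency):
# requests.post("http://127.0.0.1:8080/<route>", data=msgpack_body,
#               headers={"Content-Type": "application/msgpack"})
print(len(msgpack_body), len(json_body))
```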
/tools/server/exception_handler.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from http import HTTPStatus
3 |
4 | from kui.asgi import HTTPException, JSONResponse
5 |
6 |
7 | class ExceptionHandler:
8 |
9 | async def http_exception_handler(self, exc: HTTPException):
10 | return JSONResponse(
11 | dict(
12 | statusCode=exc.status_code,
13 | message=exc.content,
14 | error=HTTPStatus(exc.status_code).phrase,
15 | ),
16 | exc.status_code,
17 | exc.headers,
18 | )
19 |
20 | async def other_exception_handler(self, exc: Exception):
21 | traceback.print_exc()
22 |
23 | status = HTTPStatus.INTERNAL_SERVER_ERROR
24 | return JSONResponse(
25 | dict(statusCode=status, message=str(exc), error=status.phrase),
26 | status,
27 | )
28 |
--------------------------------------------------------------------------------
/tools/server/inference.py:
--------------------------------------------------------------------------------
1 | from http import HTTPStatus
2 |
3 | import numpy as np
4 | from kui.asgi import HTTPException
5 |
6 | from fish_speech.inference_engine import TTSInferenceEngine
7 | from fish_speech.utils.schema import ServeTTSRequest
8 |
9 | AMPLITUDE = 32768  # Scale factor: map float audio in [-1, 1] to the 16-bit PCM range
10 |
11 |
12 | def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
13 | """
14 | Wrapper for the inference function.
15 | Used in the API server.
16 | """
17 | count = 0
18 | for result in engine.inference(req):
19 | match result.code:
20 | case "header":
21 | if isinstance(result.audio, tuple):
22 | yield result.audio[1]
23 |
24 | case "error":
25 | raise HTTPException(
26 | HTTPStatus.INTERNAL_SERVER_ERROR,
27 | content=str(result.error),
28 | )
29 |
30 | case "segment":
31 | count += 1
32 | if isinstance(result.audio, tuple):
33 | yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
34 |
35 | case "final":
36 | count += 1
37 | if isinstance(result.audio, tuple):
38 | yield result.audio[1]
39 | return None # Stop the generator
40 |
41 | if count == 0:
42 | raise HTTPException(
43 | HTTPStatus.INTERNAL_SERVER_ERROR,
44 | content="No audio generated, please check the input text.",
45 | )
46 |
--------------------------------------------------------------------------------
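The `AMPLITUDE` constant converts normalized float audio into 16-bit PCM for the streamed `segment` chunks: samples in [-1.0, 1.0) map onto the int16 range. A small numpy sketch of that conversion (the explicit clip here is an extra safety step, not something `inference_wrapper` itself does):

```python
# Float samples in [-1.0, 1.0) scale to int16 [-32768, 32767].
import numpy as np

float_audio = np.array([-1.0, -0.5, 0.0, 0.5, 0.99997], dtype=np.float32)
pcm16 = np.clip(float_audio * 32768, -32768, 32767).astype(np.int16)
print(pcm16)  # [-32768 -16384      0  16384  32767]
```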
/tools/server/model_manager.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from funasr import AutoModel
3 | from loguru import logger
4 |
5 | from fish_speech.inference_engine import TTSInferenceEngine
6 | from fish_speech.models.text2semantic.inference import (
7 | launch_thread_safe_queue,
8 | launch_thread_safe_queue_agent,
9 | )
10 | from fish_speech.models.vqgan.inference import load_model as load_decoder_model
11 | from fish_speech.utils.schema import ServeTTSRequest
12 | from tools.server.inference import inference_wrapper as inference
13 |
14 | ASR_MODEL_NAME = "iic/SenseVoiceSmall"
15 |
16 |
17 | class ModelManager:
18 | def __init__(
19 | self,
20 | mode: str,
21 | device: str,
22 | half: bool,
23 | compile: bool,
24 | asr_enabled: bool,
25 | llama_checkpoint_path: str,
26 | decoder_checkpoint_path: str,
27 | decoder_config_name: str,
28 | ) -> None:
29 |
30 | self.mode = mode
31 | self.device = device
32 | self.half = half
33 | self.compile = compile
34 |
35 | self.precision = torch.half if half else torch.bfloat16
36 |
37 | # Check if MPS or CUDA is available
38 | if torch.backends.mps.is_available():
39 | self.device = "mps"
40 | logger.info("mps is available, running on mps.")
41 | elif not torch.cuda.is_available():
42 | self.device = "cpu"
43 | logger.info("CUDA is not available, running on CPU.")
44 |
45 | # Load the ASR model if enabled
46 | if asr_enabled:
47 | self.load_asr_model(self.device)
48 |
49 | # Load the TTS models
50 | self.load_llama_model(
51 | llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
52 | )
53 | self.load_decoder_model(
54 | decoder_config_name, decoder_checkpoint_path, self.device
55 | )
56 | self.tts_inference_engine = TTSInferenceEngine(
57 | llama_queue=self.llama_queue,
58 | decoder_model=self.decoder_model,
59 | precision=self.precision,
60 | compile=self.compile,
61 | )
62 |
63 | # Warm up the models
64 | if self.mode == "tts":
65 | self.warm_up(self.tts_inference_engine)
66 |
67 | def load_asr_model(self, device, hub="ms") -> None:
68 | self.asr_model = AutoModel(
69 | model=ASR_MODEL_NAME,
70 | device=device,
71 | disable_pbar=True,
72 | hub=hub,
73 | )
74 | logger.info("ASR model loaded.")
75 |
76 | def load_llama_model(
77 | self, checkpoint_path, device, precision, compile, mode
78 | ) -> None:
79 |
80 | if mode == "tts":
81 | self.llama_queue = launch_thread_safe_queue(
82 | checkpoint_path=checkpoint_path,
83 | device=device,
84 | precision=precision,
85 | compile=compile,
86 | )
87 | elif mode == "agent":
88 | self.llama_queue, self.tokenizer, self.config = (
89 | launch_thread_safe_queue_agent(
90 | checkpoint_path=checkpoint_path,
91 | device=device,
92 | precision=precision,
93 | compile=compile,
94 | )
95 | )
96 | else:
97 | raise ValueError(f"Invalid mode: {mode}")
98 |
99 | logger.info("LLAMA model loaded.")
100 |
101 | def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
102 | self.decoder_model = load_decoder_model(
103 | config_name=config_name,
104 | checkpoint_path=checkpoint_path,
105 | device=device,
106 | )
107 | logger.info("Decoder model loaded.")
108 |
109 | def warm_up(self, tts_inference_engine) -> None:
110 | request = ServeTTSRequest(
111 | text="Hello world.",
112 | references=[],
113 | reference_id=None,
114 | max_new_tokens=1024,
115 | chunk_length=200,
116 | top_p=0.7,
117 | repetition_penalty=1.2,
118 | temperature=0.7,
119 | format="wav",
120 | )
121 | list(inference(request, tts_inference_engine))
122 | logger.info("Models warmed up.")
123 |
--------------------------------------------------------------------------------
/tools/server/model_utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import re
3 |
4 | import librosa
5 | import torch
6 | import torchaudio
7 | from cachetools import LRUCache, cached
8 |
9 | CACHE_MAXSIZE = 10000
10 | MICRO_BATCH_SIZE = 8
11 | ASR_SAMPLE_RATE = 16000
12 | HUGE_GAP_THRESHOLD = 4000
13 |
14 |
15 | @torch.no_grad()
16 | @torch.autocast(device_type="cuda", dtype=torch.half)
17 | def batch_encode(model, audios_list: list[bytes]):
18 | audios: list[torch.Tensor] = [
19 | (
20 | torch.from_numpy(
21 | librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
22 | )[None]
23 | if isinstance(audio, bytes)
24 | else audio
25 | )
26 | for audio in audios_list
27 | ]
28 |
29 | lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
30 | max_length = lengths.max().item()
31 |
32 | print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
33 |
34 | padded = torch.stack(
35 | [
36 | torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
37 | for audio in audios
38 | ]
39 | ).to(model.device)
40 |
41 | features, feature_lengths = model.encode(padded, audio_lengths=lengths)
42 | features, feature_lengths = features.cpu(), feature_lengths.cpu()
43 |
44 | return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
45 |
46 |
47 | @cached(
48 | cache=LRUCache(maxsize=CACHE_MAXSIZE),
49 | key=lambda model, audios: (model.device, tuple(audios)),
50 | )
51 | def cached_vqgan_batch_encode(model, audios: list[bytes]):
52 | return batch_encode(model, audios)
53 |
54 |
55 | @torch.no_grad()
56 | @torch.autocast(device_type="cuda", dtype=torch.half)
57 | def batch_vqgan_decode(model, features):
58 | lengths = torch.tensor(
59 | [feature.shape[-1] for feature in features], device=model.device
60 | )
61 | max_length = lengths.max().item()
62 | padded = torch.stack(
63 | [
64 | torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
65 | for feature in features
66 | ]
67 | ).to(model.device)
68 |
69 |     # If the batch is too large, decode in micro-batches
70 | audios, audio_lengths = [], []
71 | for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
72 | audio, audio_length = model.decode(
73 | padded[i : i + MICRO_BATCH_SIZE],
74 | feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
75 | )
76 | audios.append(audio)
77 | audio_lengths.append(audio_length)
78 | audios = torch.cat(audios, dim=0)
79 | audio_lengths = torch.cat(audio_lengths, dim=0)
80 | audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
81 |
82 | return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
83 |
84 |
85 | @torch.no_grad()
86 | def batch_asr(model, lock, audios, sr, language="auto"):
87 | resampled_audios = []
88 | for audio in audios:
89 | audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
90 | assert audio.ndim == 1
91 | resampled_audios.append(audio)
92 |
93 | with lock:
94 | res = model.generate(
95 | input=resampled_audios,
96 | batch_size=len(resampled_audios),
97 | language=language,
98 | use_itn=True,
99 | )
100 |
101 | results = []
102 | for r, audio in zip(res, audios):
103 | text = r["text"]
104 | text = re.sub(r"<\|.*?\|>", "", text)
105 | duration = len(audio) / sr * 1000
106 | huge_gap = False
107 |
108 | if "timestamp" in r and len(r["timestamp"]) > 2:
109 | for timestamp_a, timestamp_b in zip(
110 | r["timestamp"][:-1], r["timestamp"][1:]
111 | ):
112 | # If there is a gap of more than 4 seconds, we consider it as a huge gap
113 | if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
114 | huge_gap = True
115 | break
116 |
117 | # Doesn't make sense to have a huge gap at the end
118 | if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
119 | huge_gap = True
120 |
121 | results.append(
122 | {
123 | "text": text,
124 | "duration": duration,
125 | "huge_gap": huge_gap,
126 | }
127 | )
128 |
129 | return results
130 |
--------------------------------------------------------------------------------
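`cached_vqgan_batch_encode` can key its LRU cache on `(model.device, tuple(audios))` because `bytes` objects are hashable, so repeated reference audio skips the expensive re-encode. A minimal sketch of the same cachetools pattern with a stand-in encode function:

```python
# Stand-in for the VQ-GAN encode: shows why (device, tuple(bytes)) is a valid key.
from cachetools import LRUCache, cached

calls = {"n": 0}


@cached(cache=LRUCache(maxsize=4), key=lambda device, audios: (device, tuple(audios)))
def fake_encode(device, audios):
    calls["n"] += 1
    return [len(a) for a in audios]


fake_encode("cuda", [b"ref-audio-1", b"ref-audio-2"])
fake_encode("cuda", [b"ref-audio-1", b"ref-audio-2"])  # served from the cache
print(calls["n"])  # 1
```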
/tools/smart_pad.py:
--------------------------------------------------------------------------------
1 | import random
2 | from multiprocessing import Pool
3 | from pathlib import Path
4 |
5 | import click
6 | import librosa
7 | import torch.nn.functional as F
8 | import torchaudio
9 | from tqdm import tqdm
10 |
11 | from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
12 |
13 | threshold = 10 ** (-50 / 20.0)
14 |
15 |
16 | def process(file):
17 | waveform, sample_rate = torchaudio.load(str(file), backend="sox")
18 | if waveform.size(0) > 1:
19 | waveform = waveform.mean(dim=0, keepdim=True)
20 |
21 | loudness = librosa.feature.rms(
22 | y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
23 | )[0]
24 |
25 | for i in range(len(loudness) - 1, 0, -1):
26 | if loudness[i] > threshold:
27 | break
28 |
29 | end_silent_time = (len(loudness) - i) * 512 / sample_rate
30 |
31 | if end_silent_time <= 0.3:
32 | random_time = random.uniform(0.3, 0.7) - end_silent_time
33 | waveform = F.pad(
34 | waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
35 | )
36 |
37 | for i in range(len(loudness)):
38 | if loudness[i] > threshold:
39 | break
40 |
41 | start_silent_time = i * 512 / sample_rate
42 |
43 | if start_silent_time > 0.02:
44 | waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
45 |
46 | torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
47 |
48 |
49 | @click.command()
50 | @click.argument("source", type=Path)
51 | @click.option("--num-workers", type=int, default=12)
52 | def main(source, num_workers):
53 | files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
54 |
55 | with Pool(num_workers) as p:
56 | list(tqdm(p.imap_unordered(process, files), total=len(files)))
57 |
58 |
59 | if __name__ == "__main__":
60 | main()
61 |
--------------------------------------------------------------------------------
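The `threshold` constant in smart_pad.py is -50 dBFS converted to a linear RMS amplitude via `10 ** (dB / 20)`; frames whose RMS stays below it are treated as silence when trimming the start and padding the end. A quick round-trip check of that conversion:

```python
# dBFS <-> linear amplitude: amplitude = 10 ** (dB / 20).
import math

db = -50.0
amplitude = 10 ** (db / 20.0)        # ~0.00316, the `threshold` constant above
back_to_db = 20 * math.log10(amplitude)

print(f"{amplitude:.5f} -> {back_to_db:.1f} dBFS")  # 0.00316 -> -50.0 dBFS
```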
/tools/vqgan/create_train_split.py:
--------------------------------------------------------------------------------
1 | import math
2 | from pathlib import Path
3 | from random import Random
4 |
5 | import click
6 | from loguru import logger
7 | from pydub import AudioSegment
8 | from tqdm import tqdm
9 |
10 | from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
11 |
12 |
13 | @click.command()
14 | @click.argument("root", type=click.Path(exists=True, path_type=Path))
15 | @click.option("--val-ratio", type=float, default=None)
16 | @click.option("--val-count", type=int, default=None)
17 | @click.option("--filelist", default=None, type=Path)
18 | @click.option("--min-duration", default=None, type=float)
19 | @click.option("--max-duration", default=None, type=float)
20 | def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
21 | if filelist:
22 | files = [i[0] for i in load_filelist(filelist)]
23 | else:
24 | files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
25 |
26 | if min_duration is None and max_duration is None:
27 | filtered_files = list(map(str, [file.relative_to(root) for file in files]))
28 | else:
29 | filtered_files = []
30 | for file in tqdm(files):
31 | try:
32 | audio = AudioSegment.from_file(str(file))
33 | duration = len(audio) / 1000.0
34 |
35 | if min_duration is not None and duration < min_duration:
36 | logger.info(
37 | f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
38 | )
39 | continue
40 |
41 | if max_duration is not None and duration > max_duration:
42 | logger.info(
43 | f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
44 | )
45 | continue
46 |
47 | filtered_files.append(str(file.relative_to(root)))
48 | except Exception as e:
49 | logger.info(f"Error processing {file}: {e}")
50 |
51 | logger.info(
52 | f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
53 | )
54 |
55 | Random(42).shuffle(filtered_files)
56 |
57 | if val_count is None and val_ratio is None:
58 | logger.info("Validation ratio and count not specified, using min(20%, 100)")
59 | val_size = min(100, math.ceil(len(filtered_files) * 0.2))
60 | elif val_count is not None and val_ratio is not None:
61 | logger.error("Cannot specify both val_count and val_ratio")
62 | return
63 | elif val_count is not None:
64 | if val_count < 1 or val_count > len(filtered_files):
65 | logger.error("val_count must be between 1 and number of files")
66 | return
67 | val_size = val_count
68 | else:
69 | val_size = math.ceil(len(filtered_files) * val_ratio)
70 |
71 | logger.info(f"Using {val_size} files for validation")
72 |
73 | with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
74 | f.write("\n".join(filtered_files[val_size:]))
75 |
76 | with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
77 | f.write("\n".join(filtered_files[:val_size]))
78 |
79 | logger.info("Done")
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
/tools/webui/inference.py:
--------------------------------------------------------------------------------
1 | import html
2 | from functools import partial
3 | from typing import Any, Callable
4 |
5 | from fish_speech.i18n import i18n
6 | from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
7 |
8 |
9 | def inference_wrapper(
10 | text,
11 | reference_id,
12 | reference_audio,
13 | reference_text,
14 | max_new_tokens,
15 | chunk_length,
16 | top_p,
17 | repetition_penalty,
18 | temperature,
19 | seed,
20 | use_memory_cache,
21 | engine,
22 | ):
23 | """
24 | Wrapper for the inference function.
25 | Used in the Gradio interface.
26 | """
27 |
28 | if reference_audio:
29 | references = get_reference_audio(reference_audio, reference_text)
30 | else:
31 | references = []
32 |
33 | req = ServeTTSRequest(
34 | text=text,
35 | reference_id=reference_id if reference_id else None,
36 | references=references,
37 | max_new_tokens=max_new_tokens,
38 | chunk_length=chunk_length,
39 | top_p=top_p,
40 | repetition_penalty=repetition_penalty,
41 | temperature=temperature,
42 | seed=int(seed) if seed else None,
43 | use_memory_cache=use_memory_cache,
44 | )
45 |
46 | for result in engine.inference(req):
47 | match result.code:
48 | case "final":
49 | return result.audio, None
50 | case "error":
51 | return None, build_html_error_message(i18n(result.error))
52 | case _:
53 | pass
54 |
55 | return None, i18n("No audio generated")
56 |
57 |
58 | def get_reference_audio(reference_audio: str, reference_text: str) -> list:
59 | """
60 | Get the reference audio bytes.
61 | """
62 |
63 | with open(reference_audio, "rb") as audio_file:
64 | audio_bytes = audio_file.read()
65 |
66 | return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]
67 |
68 |
69 | def build_html_error_message(error: Any) -> str:
70 |
71 | error = error if isinstance(error, Exception) else Exception("Unknown error")
72 |
73 |     return f"""
74 |     <div style="color: red; font-weight: bold;">
75 |         {html.escape(str(error))}
76 |     </div>
77 | 
78 |     """
79 |
80 |
81 | def get_inference_wrapper(engine) -> Callable:
82 | """
83 | Get the inference function with the immutable arguments.
84 | """
85 |
86 | return partial(
87 | inference_wrapper,
88 | engine=engine,
89 | )
90 |
--------------------------------------------------------------------------------
/tools/webui/variables.py:
--------------------------------------------------------------------------------
1 | from fish_speech.i18n import i18n
2 |
3 | HEADER_MD = f"""# Fish Speech
4 |
5 | {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
6 |
7 | {i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
8 |
9 | {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
10 |
11 | {i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
12 | """
13 |
14 | TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
15 |
--------------------------------------------------------------------------------