├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml ├── pull_request_template.md └── workflows │ ├── build-docker-image.yml │ ├── docs.yml │ └── stale.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .project-root ├── .readthedocs.yaml ├── API_FLAGS.txt ├── LICENSE ├── README.md ├── docker-compose.dev.yml ├── dockerfile ├── dockerfile.dev ├── docs ├── CNAME ├── README.ja.md ├── README.ko.md ├── README.pt-BR.md ├── README.zh.md ├── assets │ └── figs │ │ ├── VS_1.jpg │ │ ├── VS_1_pt-BR.png │ │ ├── agent_gradio.png │ │ ├── diagram.png │ │ ├── diagrama.png │ │ └── logo-circle.png ├── en │ ├── finetune.md │ ├── index.md │ ├── inference.md │ ├── samples.md │ └── start_agent.md ├── ja │ ├── finetune.md │ ├── index.md │ ├── inference.md │ ├── samples.md │ └── start_agent.md ├── ko │ ├── finetune.md │ ├── index.md │ ├── inference.md │ ├── samples.md │ └── start_agent.md ├── pt │ ├── finetune.md │ ├── index.md │ ├── inference.md │ ├── samples.md │ └── start_agent.md ├── requirements.txt ├── stylesheets │ └── extra.css └── zh │ ├── finetune.md │ ├── index.md │ ├── inference.md │ ├── samples.md │ └── start_agent.md ├── entrypoint.sh ├── fish_speech ├── callbacks │ ├── __init__.py │ └── grad_norm.py ├── configs │ ├── base.yaml │ ├── firefly_gan_vq.yaml │ ├── lora │ │ └── r_8_alpha_16.yaml │ └── text2semantic_finetune.yaml ├── conversation.py ├── datasets │ ├── concat_repeat.py │ ├── protos │ │ ├── text-data.proto │ │ ├── text_data_pb2.py │ │ └── text_data_stream.py │ ├── semantic.py │ └── vqgan.py ├── i18n │ ├── README.md │ ├── __init__.py │ ├── core.py │ ├── locale │ │ ├── en_US.json │ │ ├── es_ES.json │ │ ├── ja_JP.json │ │ ├── ko_KR.json │ │ ├── pt_BR.json │ │ └── zh_CN.json │ └── scan.py ├── inference_engine │ ├── __init__.py │ ├── reference_loader.py │ ├── utils.py │ └── vq_manager.py ├── models │ ├── text2semantic │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── lit_module.py │ │ ├── llama.py │ │ └── lora.py │ └── vqgan │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── modules │ │ ├── firefly.py │ │ └── fsq.py │ │ └── utils.py ├── scheduler.py ├── text │ ├── __init__.py │ ├── clean.py │ └── spliter.py ├── tokenizer.py ├── train.py └── utils │ ├── __init__.py │ ├── braceexpand.py │ ├── context.py │ ├── file.py │ ├── instantiators.py │ ├── logger.py │ ├── logging_utils.py │ ├── rich_utils.py │ ├── schema.py │ ├── spectrogram.py │ └── utils.py ├── inference.ipynb ├── mkdocs.yml ├── pyproject.toml ├── pyrightconfig.json └── tools ├── api_client.py ├── api_server.py ├── download_models.py ├── e2e_webui.py ├── export_onnx.py ├── extract_model.py ├── fish_e2e.py ├── llama ├── build_dataset.py ├── eval_in_context.py ├── merge_lora.py └── quantize.py ├── run_webui.py ├── server ├── agent │ ├── __init__.py │ ├── generate.py │ ├── generation_utils.py │ └── pre_generation_utils.py ├── api_utils.py ├── exception_handler.py ├── inference.py ├── model_manager.py ├── model_utils.py └── views.py ├── smart_pad.py ├── vqgan ├── create_train_split.py └── extract_vq.py ├── webui ├── __init__.py ├── inference.py └── variables.py └── whisper_asr.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .github 3 | results 4 | data 5 | *.filelist 6 | /data_server/target 7 | checkpoints 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: "🕷️ Bug 
report" 2 | description: | 3 | Please follow this template carefully to ensure we can address your issue quickly. 4 | Make sure to provide as much detail as possible, including logs and screenshots. 5 | labels: 6 | - bug 7 | body: 8 | - type: checkboxes 9 | attributes: 10 | label: Self Checks 11 | description: "To ensure timely help, please confirm the following:" 12 | options: 13 | - label: This template is only for bug reports. For questions, please visit [Discussions](https://github.com/fishaudio/fish-speech/discussions). 14 | required: true 15 | - label: I have thoroughly reviewed the project documentation (installation, training, inference) but couldn't find information to solve my problem. [English](https://speech.fish.audio/) [中文](https://speech.fish.audio/zh/) [日本語](https://speech.fish.audio/ja/) [Portuguese (Brazil)](https://speech.fish.audio/pt/) 16 | required: true 17 | - label: I have searched for existing issues, including closed ones. [Search issues](https://github.com/fishaudio/fish-speech/issues) 18 | required: true 19 | - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/fishaudio/fish-speech/issues/515)). 20 | required: true 21 | - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" 22 | required: true 23 | - label: "Please do not modify this template and fill in all required fields." 24 | required: true 25 | - type: dropdown 26 | attributes: 27 | label: Cloud or Self Hosted 28 | multiple: true 29 | options: 30 | - Cloud 31 | - Self Hosted (Docker) 32 | - Self Hosted (Source) 33 | validations: 34 | required: true 35 | - type: textarea 36 | attributes: 37 | label: Environment Details 38 | description: "Provide details such as OS, Python version, and any relevant software or dependencies." 39 | placeholder: e.g., macOS 13.5, Python 3.10, torch==2.4.1, Gradio 4.44.0 40 | validations: 41 | required: true 42 | - type: textarea 43 | attributes: 44 | label: Steps to Reproduce 45 | description: | 46 | Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks. 47 | placeholder: | 48 | 1. Run the command `python -m tools.api_client -t "xxxxx"` 49 | 2. Observe the console output error: `ModuleNotFoundError: No module named 'pyaudio'` (with screenshots or logs will be better) 50 | validations: 51 | required: true 52 | - type: textarea 53 | attributes: 54 | label: ✔️ Expected Behavior 55 | placeholder: Describe what you expected to happen. 56 | validations: 57 | required: false 58 | - type: textarea 59 | attributes: 60 | label: ❌ Actual Behavior 61 | placeholder: Describe what actually happened. 62 | validations: 63 | required: false 64 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: "\U0001F4E7 Discussions" 4 | url: https://github.com/fishaudio/fish-speech/discussions 5 | about: General discussions and request help from the community 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: "⭐ Feature or enhancement request" 2 | description: Propose something new. 
3 | labels: 4 | - enhancement 5 | body: 6 | - type: checkboxes 7 | attributes: 8 | label: Self Checks 9 | description: "To make sure we get to you in time, please check the following :)" 10 | options: 11 | - label: I have thoroughly reviewed the project documentation (installation, training, inference) but couldn't find any relevant information that meets my needs. [English](https://speech.fish.audio/) [中文](https://speech.fish.audio/zh/) [日本語](https://speech.fish.audio/ja/) [Portuguese (Brazil)](https://speech.fish.audio/pt/) 12 | required: true 13 | - label: I have searched for existing issues, including closed ones. [Search issues](https://github.com/fishaudio/fish-speech/issues) 14 | required: true 15 | - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/fishaudio/fish-speech/issues/515)). 16 | required: true 17 | - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" 18 | required: true 19 | - label: "Please do not modify this template :) and fill in all the required fields." 20 | required: true 21 | 22 | - type: textarea 23 | attributes: 24 | label: 1. Is this request related to a challenge you're experiencing? Tell us your story. 25 | description: | 26 | Describe the specific problem or scenario you're facing in detail. For example: 27 | *"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because..."* 28 | placeholder: Please describe the situation in as much detail as possible. 29 | validations: 30 | required: true 31 | 32 | - type: textarea 33 | attributes: 34 | label: 2. What is your suggested solution? 35 | description: | 36 | Provide a clear description of the feature or enhancement you'd like to propose. 37 | How would this feature solve your issue or improve the project? 38 | placeholder: Describe your idea or proposed solution here. 39 | validations: 40 | required: true 41 | 42 | - type: textarea 43 | attributes: 44 | label: 3. Additional context or comments 45 | description: | 46 | Any other relevant information, links, documents, or screenshots that provide clarity. 47 | Use this section for anything not covered above. 48 | placeholder: Add any extra details here. 49 | validations: 50 | required: false 51 | 52 | - type: checkboxes 53 | attributes: 54 | label: 4. Can you help us with this feature? 55 | description: | 56 | Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration. 57 | options: 58 | - label: I am interested in contributing to this feature. 59 | required: false 60 | 61 | - type: markdown 62 | attributes: 63 | value: | 64 | **Note:** Please submit only one request per issue to keep discussions focused and manageable. 65 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | **Is this PR adding a new feature or fixing a bug?** 2 | 3 | Add feature / Fix bug. 4 | 5 | **Is this pull request related to any issue? 
If yes, please link the issue.** 6 | 7 | #xxx 8 | -------------------------------------------------------------------------------- /.github/workflows/build-docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Build Image 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - "v*" 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest-16c64g 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Docker Buildx 16 | uses: docker/setup-buildx-action@v3 17 | - name: Get Version 18 | run: | 19 | if [[ $GITHUB_REF == refs/tags/v* ]]; then 20 | version=$(basename ${GITHUB_REF}) 21 | else 22 | version=nightly 23 | fi 24 | 25 | echo "version=${version}" >> $GITHUB_ENV 26 | echo "Current version: ${version}" 27 | 28 | - name: Login to Docker Hub 29 | uses: docker/login-action@v3 30 | with: 31 | username: ${{ secrets.DOCKER_USER }} 32 | password: ${{ secrets.DOCKER_PAT }} 33 | 34 | - name: Build and Push Image 35 | uses: docker/build-push-action@v6 36 | with: 37 | context: . 38 | file: dockerfile 39 | platforms: linux/amd64 40 | push: true 41 | tags: | 42 | fishaudio/fish-speech:${{ env.version }} 43 | fishaudio/fish-speech:latest 44 | outputs: type=image,oci-mediatypes=true,compression=zstd,compression-level=3,force-compression=true 45 | cache-from: type=registry,ref=fishaudio/fish-speech:latest 46 | cache-to: type=inline 47 | 48 | - name: Build and Push Dev Image 49 | uses: docker/build-push-action@v6 50 | with: 51 | context: . 52 | file: dockerfile.dev 53 | platforms: linux/amd64 54 | push: true 55 | build-args: | 56 | VERSION=${{ env.version }} 57 | BASE_IMAGE=fishaudio/fish-speech:${{ env.version }} 58 | tags: | 59 | fishaudio/fish-speech:${{ env.version }}-dev 60 | fishaudio/fish-speech:latest-dev 61 | outputs: type=image,oci-mediatypes=true,compression=zstd,compression-level=3,force-compression=true 62 | cache-from: type=registry,ref=fishaudio/fish-speech:latest-dev 63 | cache-to: type=inline 64 | 65 | - name: Push README to Dockerhub 66 | uses: peter-evans/dockerhub-description@v4 67 | with: 68 | username: ${{ secrets.DOCKER_USER }} 69 | password: ${{ secrets.DOCKER_PAT }} 70 | repository: fishaudio/fish-speech 71 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | paths: 7 | - 'docs/**' 8 | - 'mkdocs.yml' 9 | 10 | permissions: 11 | contents: write 12 | 13 | jobs: 14 | deploy: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Configure Git Credentials 19 | run: | 20 | git config user.name github-actions[bot] 21 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: 3.x 25 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 26 | - uses: actions/cache@v4 27 | with: 28 | key: mkdocs-material-${{ env.cache_id }} 29 | path: .cache 30 | restore-keys: | 31 | mkdocs-material- 32 | - run: pip install -r docs/requirements.txt 33 | - run: mkdocs gh-deploy --force 34 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "0 0 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | 
permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v9 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 14 17 | stale-issue-label: "stale" 18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 19 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." 20 | days-before-pr-stale: 30 21 | days-before-pr-close: 30 22 | stale-pr-label: "stale" 23 | stale-pr-message: "This PR is stale because it has been open for 30 days with no activity." 24 | close-pr-message: "This PR was closed because it has been inactive for 30 days since being marked as stale." 25 | repo-token: ${{ secrets.GITHUB_TOKEN }} 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .pgx.* 3 | .pdm-python 4 | /fish_speech.egg-info 5 | __pycache__ 6 | /results 7 | /data 8 | /*.test.sh 9 | *.filelist 10 | filelists 11 | /fish_speech/text/cmudict_cache.pickle 12 | /checkpoints 13 | /.vscode 14 | /data_server/target 15 | /*.npy 16 | /*.wav 17 | /*.mp3 18 | /*.lab 19 | /results 20 | /data 21 | /.idea 22 | ffmpeg.exe 23 | ffprobe.exe 24 | asr-label* 25 | /.cache 26 | /fishenv 27 | /.locale 28 | /demo-audios 29 | /references 30 | /example 31 | /faster_whisper 32 | /.gradio 33 | *log 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_schedule: monthly 3 | 4 | repos: 5 | - repo: https://github.com/pycqa/isort 6 | rev: 6.0.1 7 | hooks: 8 | - id: isort 9 | args: [--profile=black] 10 | 11 | - repo: https://github.com/psf/black 12 | rev: 25.1.0 13 | hooks: 14 | - id: black 15 | 16 | - repo: https://github.com/pre-commit/pre-commit-hooks 17 | rev: v5.0.0 18 | hooks: 19 | - id: end-of-file-fixer 20 | - id: check-yaml 21 | - id: check-json 22 | - id: mixed-line-ending 23 | args: ["--fix=lf"] 24 | - id: check-added-large-files 25 | args: ["--maxkb=5000"] 26 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/.project-root -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for MkDocs projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the version of Python and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | 13 | mkdocs: 14 | configuration: mkdocs.yml 15 | 16 | # Optionally declare the Python requirements required to build your docs 17 | python: 18 | install: 19 | - requirements: docs/requirements.txt 20 | -------------------------------------------------------------------------------- /API_FLAGS.txt: -------------------------------------------------------------------------------- 1 | # --infer 2 | --api 3 | --listen 0.0.0.0:8080 \ 4 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ 5 | --decoder-checkpoint-path 
"checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ 6 | --decoder-config-name firefly_gan_vq 7 | -------------------------------------------------------------------------------- /docker-compose.dev.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | fish-speech: 5 | build: 6 | context: . 7 | dockerfile: dockerfile.dev 8 | container_name: fish-speech 9 | volumes: 10 | - ./:/exp 11 | deploy: 12 | resources: 13 | reservations: 14 | devices: 15 | - driver: nvidia 16 | count: all 17 | capabilities: [gpu] 18 | command: tail -f /dev/null 19 | -------------------------------------------------------------------------------- /dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim-bookworm AS stage-1 2 | ARG TARGETARCH 3 | 4 | ARG HUGGINGFACE_MODEL=fish-speech-1.5 5 | ARG HF_ENDPOINT=https://huggingface.co 6 | 7 | WORKDIR /opt/fish-speech 8 | 9 | RUN set -ex \ 10 | && pip install huggingface_hub \ 11 | && HF_ENDPOINT=${HF_ENDPOINT} huggingface-cli download --resume-download fishaudio/${HUGGINGFACE_MODEL} --local-dir checkpoints/${HUGGINGFACE_MODEL} 12 | 13 | FROM python:3.12-slim-bookworm 14 | ARG TARGETARCH 15 | 16 | ARG DEPENDENCIES=" \ 17 | ca-certificates \ 18 | libsox-dev \ 19 | build-essential \ 20 | cmake \ 21 | libasound-dev \ 22 | portaudio19-dev \ 23 | libportaudio2 \ 24 | libportaudiocpp0 \ 25 | ffmpeg" 26 | 27 | RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ 28 | --mount=type=cache,target=/var/lib/apt,sharing=locked \ 29 | set -ex \ 30 | && rm -f /etc/apt/apt.conf.d/docker-clean \ 31 | && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \ 32 | && apt-get update \ 33 | && apt-get -y install --no-install-recommends ${DEPENDENCIES} \ 34 | && echo "no" | dpkg-reconfigure dash 35 | 36 | WORKDIR /opt/fish-speech 37 | 38 | COPY . . 
39 | 40 | RUN --mount=type=cache,target=/root/.cache,sharing=locked \ 41 | set -ex \ 42 | && pip install -e .[stable] 43 | 44 | COPY --from=stage-1 /opt/fish-speech/checkpoints /opt/fish-speech/checkpoints 45 | 46 | ENV GRADIO_SERVER_NAME="0.0.0.0" 47 | 48 | EXPOSE 7860 49 | 50 | CMD ["./entrypoint.sh"] 51 | -------------------------------------------------------------------------------- /dockerfile.dev: -------------------------------------------------------------------------------- 1 | ARG VERSION=dev 2 | ARG BASE_IMAGE=ghcr.io/fishaudio/fish-speech:${VERSION} 3 | 4 | FROM ${BASE_IMAGE} 5 | 6 | ARG TOOLS=" \ 7 | git \ 8 | curl \ 9 | build-essential \ 10 | ffmpeg \ 11 | libsm6 \ 12 | libxext6 \ 13 | libjpeg-dev \ 14 | zlib1g-dev \ 15 | aria2 \ 16 | zsh \ 17 | openssh-server \ 18 | sudo \ 19 | protobuf-compiler \ 20 | libasound-dev \ 21 | portaudio19-dev \ 22 | libportaudio2 \ 23 | libportaudiocpp0 \ 24 | cmake" 25 | 26 | RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ 27 | --mount=type=cache,target=/var/lib/apt,sharing=locked \ 28 | set -ex \ 29 | && apt-get update \ 30 | && apt-get -y install --no-install-recommends ${TOOLS} 31 | 32 | # Install oh-my-zsh so your terminal looks nice 33 | RUN sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended 34 | 35 | # Set zsh as default shell 36 | RUN chsh -s /usr/bin/zsh 37 | ENV SHELL=/usr/bin/zsh 38 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | speech.fish.audio 2 | -------------------------------------------------------------------------------- /docs/README.ja.md: -------------------------------------------------------------------------------- 1 |
2 |

Fish Speech

3 | 4 | [English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | **日本語** | [한국어](README.ko.md)
5 | 6 | 7 | Fish Speech 1.4 - Open-Source Multilingual Text-to-Speech with Voice Cloning | Product Hunt 8 | 9 | 10 | fishaudio%2Ffish-speech | Trendshift 11 | 12 |
13 |
14 |
15 | 16 |
17 |
18 |
19 |
20 | 21 |
22 | 23 | Discord 24 | 25 | 26 | Docker 27 | 28 | 29 | Huggingface 30 | 31 |
32 | 33 | このコードリポジトリはApache 2.0ライセンスの下で公開されており、モデルはCC-BY-NC-SA-4.0ライセンスの下で公開されています。詳細については[LICENSE](../LICENSE)をご参照ください。 34 | 35 | --- 36 | 37 | ## 機能 38 | 39 | 1. **ゼロショット & フューショット TTS**:10〜30 秒の音声サンプルを入力して、高品質の TTS 出力を生成します。**詳細は [音声クローンのベストプラクティス](https://docs.fish.audio/text-to-speech/voice-clone-best-practices) を参照してください。** 40 | 2. **多言語 & クロスリンガル対応**:多言語テキストを入力ボックスにコピーペーストするだけで、言語を気にする必要はありません。現在、英語、日本語、韓国語、中国語、フランス語、ドイツ語、アラビア語、スペイン語に対応しています。 41 | 3. **音素依存なし**:このモデルは強力な汎化能力を持ち、TTS に音素を必要としません。あらゆる言語スクリプトに対応可能です。 42 | 4. **高精度**:5 分間の英語テキストに対し、CER(文字誤り率)と WER(単語誤り率)は約 2%の精度を達成します。 43 | 5. **高速**:fish-tech アクセラレーションにより、Nvidia RTX 4060 ラップトップではリアルタイムファクターが約 1:5、Nvidia RTX 4090 では約 1:15 です。 44 | 6. **WebUI 推論**:使いやすい Gradio ベースの Web ユーザーインターフェースを搭載し、Chrome、Firefox、Edge などのブラウザに対応しています。 45 | 7. **GUI 推論**:PyQt6 のグラフィカルインターフェースを提供し、API サーバーとシームレスに連携します。Linux、Windows、macOS に対応しています。[GUI を見る](https://github.com/AnyaCoder/fish-speech-gui)。 46 | 8. **デプロイしやすい**:Linux、Windows、macOS にネイティブ対応した推論サーバーを簡単にセットアップでき、速度の低下を最小限に抑えます。 47 | 48 | ## 免責事項 49 | 50 | コードベースの違法な使用については一切責任を負いません。DMCA(デジタルミレニアム著作権法)およびその他の関連法については、地域の法律を参照してください。 51 | 52 | ## オンラインデモ 53 | 54 | [Fish Audio](https://fish.audio) 55 | 56 | ## ローカル推論のクイックスタート 57 | 58 | [inference.ipynb](/inference.ipynb) 59 | 60 | ## ビデオ 61 | 62 | #### V1.5 デモビデオ: [Watch the video on X (Twitter).](https://x.com/FishAudio/status/1864370933496205728) 63 | 64 | ## ドキュメント 65 | 66 | - [英語](https://speech.fish.audio/) 67 | - [中文](https://speech.fish.audio/zh/) 68 | - [日本語](https://speech.fish.audio/ja/) 69 | - [ポルトガル語 (ブラジル)](https://speech.fish.audio/pt/) 70 | 71 | ## サンプル (2024/10/02 V1.4) 72 | 73 | - [英語](https://speech.fish.audio/samples/) 74 | - [中文](https://speech.fish.audio/zh/samples/) 75 | - [日本語](https://speech.fish.audio/ja/samples/) 76 | - [ポルトガル語 (ブラジル)](https://speech.fish.audio/pt/samples/) 77 | 78 | ## クレジット 79 | 80 | - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2) 81 | - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) 82 | - [GPT VITS](https://github.com/innnky/gpt-vits) 83 | - [MQTTS](https://github.com/b04901014/MQTTS) 84 | - [GPT Fast](https://github.com/pytorch-labs/gpt-fast) 85 | - [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 86 | 87 | ## スポンサー 88 | 89 |
90 | 91 | 6Block Avatar 92 | 93 |
94 | データ処理スポンサー:6Block 95 |
96 | -------------------------------------------------------------------------------- /docs/README.ko.md: -------------------------------------------------------------------------------- 1 |
2 |

Fish Speech

3 | 4 | [English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | **한국어**
5 | 6 | 7 | Fish Speech 1.4 - Open-Source Multilingual Text-to-Speech with Voice Cloning | Product Hunt 8 | 9 | 10 | fishaudio%2Ffish-speech | Trendshift 11 | 12 |
13 |
14 |
15 | 16 |
17 |
18 |
19 |
20 | 21 |
22 | 23 | Discord 24 | 25 | 26 | Docker 27 | 28 | 29 | Huggingface 30 | 31 |
32 | 33 | 이 코드 저장소는 Apache 2.0 라이선스 하에 배포되며, 모델은 CC-BY-NC-SA-4.0 라이선스 하에 배포됩니다. 자세한 내용은 [LICENSE](../LICENSE)를 참조하십시오. 34 | 35 | --- 36 | 37 | ## 기능 38 | 39 | 1. **Zero-shot & Few-shot TTS:** 10초에서 30초의 음성 샘플을 입력하여 고품질의 TTS 출력을 생성합니다. **자세한 가이드는 [모범 사례](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)를 참조하시길 바랍니다.** 40 | 41 | 2. **다국어 및 교차 언어 지원:** 다국어 걱정 없이, 텍스트를 입력창에 복사하여 붙여넣기만 하면 됩니다. 현재 영어, 일본어, 한국어, 중국어, 프랑스어, 독일어, 아랍어, 스페인어를 지원합니다. 42 | 43 | 3. **음소 의존성 제거:** 이 모델은 강력한 일반화 능력을 가지고 있으며, TTS가 음소에 의존하지 않습니다. 모든 언어 스크립트 텍스트를 손쉽게 처리할 수 있습니다. 44 | 45 | 4. **높은 정확도:** 영어 텍스트 기준 5분 기준에서 단, 2%의 문자 오류율(CER)과 단어 오류율(WER)을 달성합니다. 46 | 47 | 5. **빠른 속도:** fish-tech 가속을 통해 실시간 인자(RTF)는 Nvidia RTX 4060 노트북에서는 약 1:5, Nvidia RTX 4090에서는 1:15입니다. 48 | 49 | 6. **웹 UI 추론:** Chrome, Firefox, Edge 등 다양한 브라우저에서 호환되는 Gradio 기반의 사용하기 쉬운 웹 UI를 제공합니다. 50 | 51 | 7. **GUI 추론:** PyQt6 그래픽 인터페이스를 제공하여 API 서버와 원활하게 작동합니다. Linux, Windows 및 macOS를 지원합니다. [GUI 참조](https://github.com/AnyaCoder/fish-speech-gui). 52 | 53 | 8. **배포 친화적:** Linux, Windows, macOS에서 네이티브로 지원되는 추론 서버를 쉽게 설정할 수 있어 속도 손실을 최소화합니다. 54 | 55 | ## 면책 조항 56 | 57 | 이 코드베이스의 불법적 사용에 대해 어떠한 책임도 지지 않습니다. DMCA 및 관련 법률에 대한 로컬 법률을 참조하십시오. 58 | 59 | ## 온라인 데모 60 | 61 | [Fish Audio](https://fish.audio) 62 | 63 | ## 로컬 추론을 위한 빠른 시작 64 | 65 | [inference.ipynb](/inference.ipynb) 66 | 67 | ## 영상 68 | 69 | #### V1.5 데모 영상: [Watch the video on X (Twitter).](https://x.com/FishAudio/status/1864370933496205728) 70 | 71 | ## 문서 72 | 73 | - [English](https://speech.fish.audio/) 74 | - [中文](https://speech.fish.audio/zh/) 75 | - [日本語](https://speech.fish.audio/ja/) 76 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/) 77 | - [한국어](https://speech.fish.audio/ko/) 78 | 79 | ## Samples (2024/10/02 V1.4) 80 | 81 | - [English](https://speech.fish.audio/samples/) 82 | - [中文](https://speech.fish.audio/zh/samples/) 83 | - [日本語](https://speech.fish.audio/ja/samples/) 84 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/) 85 | - [한국어](https://speech.fish.audio/ko/samples/) 86 | 87 | ## Credits 88 | 89 | - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2) 90 | - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) 91 | - [GPT VITS](https://github.com/innnky/gpt-vits) 92 | - [MQTTS](https://github.com/b04901014/MQTTS) 93 | - [GPT Fast](https://github.com/pytorch-labs/gpt-fast) 94 | - [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 95 | 96 | ## Sponsor 97 | 98 |
99 | 100 | 6Block Avatar 101 | 102 |
103 | 데이터 처리 후원: 6Block 104 |
105 | -------------------------------------------------------------------------------- /docs/README.zh.md: -------------------------------------------------------------------------------- 1 |
2 |

Fish Speech

3 | 4 | [English](../README.md) | **简体中文** | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | [한국어](README.ko.md)
5 | 6 | 7 | Fish Speech 1.4 - Open-Source Multilingual Text-to-Speech with Voice Cloning | Product Hunt 8 | 9 | 10 | fishaudio%2Ffish-speech | Trendshift 11 | 12 |
13 |
14 |
15 | 16 |
17 |
18 |
19 | 20 |
21 | 22 |
23 | 24 | Discord 25 | 26 | 27 | Docker 28 | 29 | 30 | Huggingface 31 | 32 |
33 | 34 |
35 | 36 | 此代码库根据 Apache 2.0 许可证发布,模型根据 CC-BY-NC-SA-4.0 许可证发布。请参阅 [LICENSE](../LICENSE) 了解更多细节. 37 | 38 | --- 39 | 40 | ## 特性 41 | 42 | 1. **零样本 & 小样本 TTS**:输入 10 到 30 秒的声音样本即可生成高质量的 TTS 输出。**详见 [语音克隆最佳实践指南](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)。** 43 | 2. **多语言 & 跨语言支持**:只需复制并粘贴多语言文本到输入框中,无需担心语言问题。目前支持英语、日语、韩语、中文、法语、德语、阿拉伯语和西班牙语。 44 | 3. **无音素依赖**:模型具备强大的泛化能力,不依赖音素进行 TTS,能够处理任何文字表示的语言。 45 | 4. **高准确率**:在 5 分钟的英文文本上,达到了约 2% 的 CER(字符错误率)和 WER(词错误率)。 46 | 5. **快速**:通过 fish-tech 加速,在 Nvidia RTX 4060 笔记本上的实时因子约为 1:5,在 Nvidia RTX 4090 上约为 1:15。 47 | 6. **WebUI 推理**:提供易于使用的基于 Gradio 的网页用户界面,兼容 Chrome、Firefox、Edge 等浏览器。 48 | 7. **GUI 推理**:提供 PyQt6 图形界面,与 API 服务器无缝协作。支持 Linux、Windows 和 macOS。[查看 GUI](https://github.com/AnyaCoder/fish-speech-gui)。 49 | 8. **易于部署**:轻松设置推理服务器,原生支持 Linux、Windows 和 macOS,最大程度减少速度损失。 50 | 51 | ## 免责声明 52 | 53 | 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规. 54 | 55 | ## 在线 DEMO 56 | 57 | [Fish Audio](https://fish.audio) 58 | 59 | ## 快速开始本地推理 60 | 61 | [inference.ipynb](/inference.ipynb) 62 | 63 | ## 视频 64 | 65 | #### 1.5 介绍: https://www.bilibili.com/video/BV1EKiDYBE4o 66 | 67 | #### 1.4 介绍: https://www.bilibili.com/video/BV1pu46eVEk7 68 | 69 | #### 1.2 介绍: https://www.bilibili.com/video/BV1wz421B71D 70 | 71 | #### 1.1 介绍: https://www.bilibili.com/video/BV1zJ4m1K7cj 72 | 73 | ## 文档 74 | 75 | - [English](https://speech.fish.audio/) 76 | - [中文](https://speech.fish.audio/zh/) 77 | - [日本語](https://speech.fish.audio/ja/) 78 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/) 79 | 80 | ## 例子 (2024/10/02 V1.4) 81 | 82 | - [English](https://speech.fish.audio/samples/) 83 | - [中文](https://speech.fish.audio/zh/samples/) 84 | - [日本語](https://speech.fish.audio/ja/samples/) 85 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/) 86 | 87 | ## 鸣谢 88 | 89 | - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2) 90 | - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) 91 | - [GPT VITS](https://github.com/innnky/gpt-vits) 92 | - [MQTTS](https://github.com/b04901014/MQTTS) 93 | - [GPT Fast](https://github.com/pytorch-labs/gpt-fast) 94 | - [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 95 | 96 | ## 赞助 97 | 98 |
99 | 100 | 6Block Avatar 101 | 102 |
103 | 数据处理服务器由 6Block 提供 104 |
105 | -------------------------------------------------------------------------------- /docs/assets/figs/VS_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/VS_1.jpg -------------------------------------------------------------------------------- /docs/assets/figs/VS_1_pt-BR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/VS_1_pt-BR.png -------------------------------------------------------------------------------- /docs/assets/figs/agent_gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/agent_gradio.png -------------------------------------------------------------------------------- /docs/assets/figs/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/diagram.png -------------------------------------------------------------------------------- /docs/assets/figs/diagrama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/diagrama.png -------------------------------------------------------------------------------- /docs/assets/figs/logo-circle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/docs/assets/figs/logo-circle.png -------------------------------------------------------------------------------- /docs/en/finetune.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning 2 | 3 | Obviously, when you opened this page, you were not satisfied with the performance of the few-shot pre-trained model. You want to fine-tune a model to improve its performance on your dataset. 4 | 5 | In current version, you only need to finetune the 'LLAMA' part. 6 | 7 | ## Fine-tuning LLAMA 8 | ### 1. Prepare the dataset 9 | 10 | ``` 11 | . 12 | ├── SPK1 13 | │ ├── 21.15-26.44.lab 14 | │ ├── 21.15-26.44.mp3 15 | │ ├── 27.51-29.98.lab 16 | │ ├── 27.51-29.98.mp3 17 | │ ├── 30.1-32.71.lab 18 | │ └── 30.1-32.71.mp3 19 | └── SPK2 20 | ├── 38.79-40.85.lab 21 | └── 38.79-40.85.mp3 22 | ``` 23 | 24 | You need to convert your dataset into the above format and place it under `data`. The audio file can have the extensions `.mp3`, `.wav`, or `.flac`, and the annotation file should have the extensions `.lab`. 25 | 26 | !!! info "Dataset Format" 27 | The `.lab` annotation file only needs to contain the transcription of the audio, with no special formatting required. For example, if `hi.mp3` says "Hello, goodbye," then the `hi.lab` file would contain a single line of text: "Hello, goodbye." 28 | 29 | !!! warning 30 | It's recommended to apply loudness normalization to the dataset. You can use [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) to do this. 31 | 32 | ```bash 33 | fap loudness-norm data-raw data --clean 34 | ``` 35 | 36 | 37 | ### 2. 
Batch extraction of semantic tokens 38 | 39 | Make sure you have downloaded the VQGAN weights. If not, run the following command: 40 | 41 | ```bash 42 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 43 | ``` 44 | 45 | You can then run the following command to extract semantic tokens: 46 | 47 | ```bash 48 | python tools/vqgan/extract_vq.py data \ 49 | --num-workers 1 --batch-size 16 \ 50 | --config-name "firefly_gan_vq" \ 51 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 52 | ``` 53 | 54 | !!! note 55 | You can adjust `--num-workers` and `--batch-size` to increase extraction speed, but please make sure not to exceed your GPU memory limit. 56 | For the VITS format, you can specify a file list using `--filelist xxx.list`. 57 | 58 | This command will create `.npy` files in the `data` directory, as shown below: 59 | 60 | ``` 61 | . 62 | ├── SPK1 63 | │ ├── 21.15-26.44.lab 64 | │ ├── 21.15-26.44.mp3 65 | │ ├── 21.15-26.44.npy 66 | │ ├── 27.51-29.98.lab 67 | │ ├── 27.51-29.98.mp3 68 | │ ├── 27.51-29.98.npy 69 | │ ├── 30.1-32.71.lab 70 | │ ├── 30.1-32.71.mp3 71 | │ └── 30.1-32.71.npy 72 | └── SPK2 73 | ├── 38.79-40.85.lab 74 | ├── 38.79-40.85.mp3 75 | └── 38.79-40.85.npy 76 | ``` 77 | 78 | ### 3. Pack the dataset into protobuf 79 | 80 | ```bash 81 | python tools/llama/build_dataset.py \ 82 | --input "data" \ 83 | --output "data/protos" \ 84 | --text-extension .lab \ 85 | --num-workers 16 86 | ``` 87 | 88 | After the command finishes executing, you should see the `quantized-dataset-ft.protos` file in the `data` directory. 89 | 90 | ### 4. Finally, fine-tuning with LoRA 91 | 92 | Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command: 93 | 94 | ```bash 95 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 96 | ``` 97 | 98 | Finally, you can start the fine-tuning by running the following command: 99 | 100 | ```bash 101 | python fish_speech/train.py --config-name text2semantic_finetune \ 102 | project=$project \ 103 | +lora@model.model.lora_config=r_8_alpha_16 104 | ``` 105 | 106 | !!! note 107 | You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`. 108 | 109 | !!! note 110 | For Windows users, you can use `trainer.strategy.process_group_backend=gloo` to avoid `nccl` issues. 111 | 112 | After training is complete, you can refer to the [inference](inference.md) section to generate speech. 113 | 114 | !!! info 115 | By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability. 116 | If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting. 117 | 118 | After training, you need to convert the LoRA weights to regular weights before performing inference. 119 | 120 | ```bash 121 | python tools/llama/merge_lora.py \ 122 | --lora-config r_8_alpha_16 \ 123 | --base-weight checkpoints/fish-speech-1.5 \ 124 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \ 125 | --output checkpoints/fish-speech-1.5-yth-lora/ 126 | ``` 127 | !!! note 128 | You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data. 
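As a last sanity check before the semantic-token extraction in step 2, it can help to verify that every audio file under `data` has a matching `.lab` transcript and that no transcript is orphaned. The snippet below is a minimal, illustrative sketch rather than part of the repository's tooling; it only assumes the `data` layout and the `.mp3`/`.wav`/`.flac`/`.lab` extensions described in step 1.

```python
from pathlib import Path

AUDIO_EXTS = {".mp3", ".wav", ".flac"}  # audio extensions accepted in step 1


def check_dataset(root: str = "data") -> None:
    """Report audio files missing a .lab transcript and .lab files missing audio."""
    root_path = Path(root)
    audio_files = [p for p in root_path.rglob("*") if p.suffix.lower() in AUDIO_EXTS]
    missing_labs = [p for p in audio_files if not p.with_suffix(".lab").exists()]
    orphan_labs = [
        lab for lab in root_path.rglob("*.lab")
        if not any(lab.with_suffix(ext).exists() for ext in AUDIO_EXTS)
    ]

    print(f"Found {len(audio_files)} audio files under {root_path}")
    for p in missing_labs:
        print(f"Missing transcript: {p}")
    for lab in orphan_labs:
        print(f"Transcript without audio: {lab}")


if __name__ == "__main__":
    check_dataset()
```

If it prints nothing beyond the file count, the dataset matches the layout described in step 1.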
129 | -------------------------------------------------------------------------------- /docs/en/start_agent.md: -------------------------------------------------------------------------------- 1 | # Start Agent 2 | 3 | ## Requirements 4 | 5 | - GPU memory: at least 8GB (with quantization); 16GB or more is recommended. 6 | - Disk usage: 10GB 7 | 8 | ## Download Model 9 | 10 | You can get the model by running: 11 | 12 | ```bash 13 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b 14 | ``` 15 | 16 | Put the files in the 'checkpoints' folder. 17 | 18 | You also need the fish-speech model, which you can download by following the instructions in [inference](inference.md). 19 | 20 | There will then be 2 folders under checkpoints: 21 | 22 | `checkpoints/fish-speech-1.4` and `checkpoints/fish-agent-v0.1-3b` 23 | 24 | ## Environment Preparation 25 | 26 | If you already have fish-speech installed, you can use it directly by adding the following dependency: 27 | ```bash 28 | pip install cachetools 29 | ``` 30 | 31 | !!! note 32 | Please use a Python version below 3.12 for compilation. 33 | 34 | If you don't have fish-speech yet, use the commands below to set up your environment: 35 | 36 | ```bash 37 | sudo apt-get install portaudio19-dev 38 | 39 | pip install -e .[stable] 40 | ``` 41 | 42 | ## Launch the Agent Demo 43 | 44 | To start fish-agent, run the command below from the repository root: 45 | 46 | ```bash 47 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile 48 | ``` 49 | 50 | The `--compile` arg is only supported on Python < 3.12, and it greatly speeds up token generation. 51 | 52 | Note that compilation does not happen immediately at startup. 53 | 54 | Then open another terminal and run: 55 | 56 | ```bash 57 | python -m tools.e2e_webui 58 | ``` 59 | 60 | This will create a Gradio WebUI on the device. 61 | 62 | The first time you use the model, it will spend a short time compiling (if `--compile` is enabled), so please be patient. 63 | 64 | ## Gradio Webui 65 | 

66 | 67 |

68 | 69 | Have a good time! 70 | 71 | ## Performance 72 | 73 | Under our test, a 4060 laptop just barely runs, but is very stretched, which is only about 8 tokens/s. The 4090 is around 95 tokens/s under compile, which is what we recommend. 74 | 75 | # About Agent 76 | 77 | The demo is an early alpha test version, the inference speed needs to be optimised, and there are a lot of bugs waiting to be fixed. If you've found a bug or want to fix it, we'd be very happy to receive an issue or a pull request. 78 | -------------------------------------------------------------------------------- /docs/ja/finetune.md: -------------------------------------------------------------------------------- 1 | # 微調整 2 | 3 | 明らかに、このページを開いたとき、few-shot 事前トレーニングモデルのパフォーマンスに満足していなかったことでしょう。データセット上でのパフォーマンスを向上させるためにモデルを微調整したいと考えています。 4 | 5 | 現在のバージョンでは、「LLAMA」部分のみを微調整する必要があります。 6 | 7 | ## LLAMAの微調整 8 | ### 1. データセットの準備 9 | 10 | ``` 11 | . 12 | ├── SPK1 13 | │ ├── 21.15-26.44.lab 14 | │ ├── 21.15-26.44.mp3 15 | │ ├── 27.51-29.98.lab 16 | │ ├── 27.51-29.98.mp3 17 | │ ├── 30.1-32.71.lab 18 | │ └── 30.1-32.71.mp3 19 | └── SPK2 20 | ├── 38.79-40.85.lab 21 | └── 38.79-40.85.mp3 22 | ``` 23 | 24 | データセットを上記の形式に変換し、「data」ディレクトリに配置する必要があります。音声ファイルの拡張子は「.mp3」、「.wav」、または「.flac」にすることができ、注釈ファイルの拡張子は「.lab」にする必要があります。 25 | 26 | !!! info 27 | 標準ファイル `.lab` には、音声の転写テキストのみを含め、特別なフォーマットは必要ありません。例えば、`hi.mp3` で「こんにちは、さようなら」と言っている場合、`hi.lab` ファイルには「こんにちは、さようなら」という一行のテキストを含めるだけです。 28 | 29 | !!! warning 30 | データセットにラウドネス正規化を適用することをお勧めします。これを行うには、[fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) を使用できます。 31 | 32 | ```bash 33 | fap loudness-norm data-raw data --clean 34 | ``` 35 | 36 | 37 | ### 2. セマンティックトークンのバッチ抽出 38 | 39 | VQGANの重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。 40 | 41 | ```bash 42 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 43 | ``` 44 | 45 | 次に、次のコマンドを実行してセマンティックトークンを抽出できます。 46 | 47 | ```bash 48 | python tools/vqgan/extract_vq.py data \ 49 | --num-workers 1 --batch-size 16 \ 50 | --config-name "firefly_gan_vq" \ 51 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 52 | ``` 53 | 54 | !!! note 55 | `--num-workers` と `--batch-size` を調整して抽出速度を上げることができますが、GPUメモリの制限を超えないようにしてください。 56 | VITS形式の場合、`--filelist xxx.list` を使用してファイルリストを指定できます。 57 | 58 | このコマンドは、`data`ディレクトリに`.npy`ファイルを作成します。以下のように表示されます。 59 | 60 | ``` 61 | . 62 | ├── SPK1 63 | │ ├── 21.15-26.44.lab 64 | │ ├── 21.15-26.44.mp3 65 | │ ├── 21.15-26.44.npy 66 | │ ├── 27.51-29.98.lab 67 | │ ├── 27.51-29.98.mp3 68 | │ ├── 27.51-29.98.npy 69 | │ ├── 30.1-32.71.lab 70 | │ ├── 30.1-32.71.mp3 71 | │ └── 30.1-32.71.npy 72 | └── SPK2 73 | ├── 38.79-40.85.lab 74 | ├── 38.79-40.85.mp3 75 | └── 38.79-40.85.npy 76 | ``` 77 | 78 | ### 3. データセットをprotobufにパックする 79 | 80 | ```bash 81 | python tools/llama/build_dataset.py \ 82 | --input "data" \ 83 | --output "data/protos" \ 84 | --text-extension .lab \ 85 | --num-workers 16 86 | ``` 87 | 88 | コマンドの実行が完了すると、`data`ディレクトリに`quantized-dataset-ft.protos`ファイルが表示されます。 89 | 90 | ### 4. 最後に、LoRAを使用して微調整する 91 | 92 | 同様に、`LLAMA`の重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。 93 | 94 | ```bash 95 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 96 | ``` 97 | 98 | 最後に、次のコマンドを実行して微調整を開始できます。 99 | 100 | ```bash 101 | python fish_speech/train.py --config-name text2semantic_finetune \ 102 | project=$project \ 103 | +lora@model.model.lora_config=r_8_alpha_16 104 | ``` 105 | 106 | !!! 
note 107 | `fish_speech/configs/text2semantic_finetune.yaml` を変更して、`batch_size`、`gradient_accumulation_steps` などのトレーニングパラメータを変更し、GPUメモリに適合させることができます。 108 | 109 | !!! note 110 | Windowsユーザーの場合、`trainer.strategy.process_group_backend=gloo` を使用して `nccl` の問題を回避できます。 111 | 112 | トレーニングが完了したら、[推論](inference.md)セクションを参照し、音声を生成します。 113 | 114 | !!! info 115 | デフォルトでは、モデルは話者の発話パターンのみを学習し、音色は学習しません。音色の安定性を確保するためにプロンプトを使用する必要があります。 116 | 音色を学習したい場合は、トレーニングステップ数を増やすことができますが、これにより過学習が発生する可能性があります。 117 | 118 | トレーニングが完了したら、推論を行う前にLoRAの重みを通常の重みに変換する必要があります。 119 | 120 | ```bash 121 | python tools/llama/merge_lora.py \ 122 | --lora-config r_8_alpha_16 \ 123 | --base-weight checkpoints/fish-speech-1.5 \ 124 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \ 125 | --output checkpoints/fish-speech-1.5-yth-lora/ 126 | ``` 127 | !!! note 128 | 他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、分布外(OOD)データでより良いパフォーマンスを発揮します。 129 | -------------------------------------------------------------------------------- /docs/ja/inference.md: -------------------------------------------------------------------------------- 1 | # 推論 2 | 3 | 推論は、コマンドライン、HTTP API、および Web UI をサポートしています。 4 | 5 | !!! note 6 | 全体として、推論は次のいくつかの部分で構成されています: 7 | 8 | 1. VQGANを使用して、与えられた約10秒の音声をエンコードします。 9 | 2. エンコードされたセマンティックトークンと対応するテキストを例として言語モデルに入力します。 10 | 3. 新しいテキストが与えられた場合、モデルに対応するセマンティックトークンを生成させます。 11 | 4. 生成されたセマンティックトークンをVITS / VQGANに入力してデコードし、対応する音声を生成します。 12 | 13 | ## モデルをダウンロード 14 | 必要な`vqgan`および`llama`モデルを Hugging Face リポジトリからダウンロードします。 15 | 16 | ```bash 17 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 18 | ``` 19 | 20 | ## コマンドライン推論 21 | ### 1. 音声からプロンプトを生成する: 22 | 23 | !!! note 24 | モデルにランダムに音声の音色を選ばせる場合、このステップをスキップできます。 25 | 26 | !!! warning "将来のバージョンに関する警告" 27 | 元のパス(tools/vqgan/infernce.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。 28 | 29 | ```bash 30 | python fish_speech/models/vqgan/inference.py \ 31 | -i "paimon.wav" \ 32 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 33 | ``` 34 | 35 | `fake.npy`ファイルが生成されるはずです。 36 | 37 | ### 2. テキストからセマンティックトークンを生成する: 38 | 39 | !!! warning "将来のバージョンに関する警告" 40 | 元のパス(tools/llama/generate.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。 41 | 42 | ```bash 43 | python fish_speech/models/text2semantic/inference.py \ 44 | --text "変換したいテキスト" \ 45 | --prompt-text "参照テキスト" \ 46 | --prompt-tokens "fake.npy" \ 47 | --checkpoint-path "checkpoints/fish-speech-1.5" \ 48 | --num-samples 2 \ 49 | --compile 50 | ``` 51 | 52 | このコマンドは、作業ディレクトリに`codes_N`ファイルを作成します。ここで、N は 0 から始まる整数です。 53 | 54 | !!! note 55 | `--compile`を使用して CUDA カーネルを融合し、より高速な推論を実現することができます(約 30 トークン/秒 -> 約 500 トークン/秒)。 56 | それに対応して、加速を使用しない場合は、`--compile`パラメータをコメントアウトできます。 57 | 58 | !!! info 59 | bf16 をサポートしていない GPU の場合、`--half`パラメータを使用する必要があるかもしれません。 60 | 61 | ### 3. セマンティックトークンから音声を生成する: 62 | 63 | #### VQGAN デコーダー 64 | 65 | !!! 
warning "将来のバージョンに関する警告" 66 | 元のパス(tools/vqgan/infernce.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。 67 | 68 | ```bash 69 | python fish_speech/models/vqgan/inference.py \ 70 | -i "codes_0.npy" \ 71 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 72 | ``` 73 | 74 | ## HTTP API 推論 75 | 76 | 推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます: 77 | 78 | ```bash 79 | python -m tools.api_server \ 80 | --listen 0.0.0.0:8080 \ 81 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ 82 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ 83 | --decoder-config-name firefly_gan_vq 84 | ``` 85 | 86 | > 推論を高速化したい場合は、`--compile` パラメータを追加できます。 87 | 88 | その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。 89 | 90 | 以下は、`tools/api_client.py` を使用してリクエストを送信する例です。 91 | 92 | ```bash 93 | python -m tools.api_client \ 94 | --text "入力するテキスト" \ 95 | --reference_audio "参照音声へのパス" \ 96 | --reference_text "参照音声テキスト" \ 97 | --streaming True 98 | ``` 99 | 100 | 上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。 101 | 102 | !!! info 103 | 使用可能なパラメータの詳細については、コマンド` python -m tools.api_client -h `を使用してください 104 | 105 | ## WebUI 推論 106 | 107 | 次のコマンドを使用して WebUI を起動できます: 108 | 109 | ```bash 110 | python -m tools.run_webui \ 111 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ 112 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ 113 | --decoder-config-name firefly_gan_vq 114 | ``` 115 | > 推論を高速化したい場合は、`--compile` パラメータを追加できます。 116 | 117 | !!! note 118 | ラベルファイルと参照音声ファイルをメインディレクトリの `references` フォルダ(自分で作成する必要があります)に事前に保存しておくことで、WebUI で直接呼び出すことができます。 119 | 120 | !!! note 121 | Gradio 環境変数(`GRADIO_SHARE`、`GRADIO_SERVER_PORT`、`GRADIO_SERVER_NAME`など)を使用して WebUI を構成できます。 122 | 123 | お楽しみください! 124 | -------------------------------------------------------------------------------- /docs/ja/start_agent.md: -------------------------------------------------------------------------------- 1 | # エージェントの開始 2 | 3 | !!! note 4 | もしあなたがネイティブ・スピーカーで、翻訳に問題があるとお感じでしたら、issueかpull requestをお送りください! 5 | 6 | ## 要件 7 | 8 | - GPUメモリ: 最低8GB(量子化使用時)、16GB以上推奨 9 | - ディスク使用量: 10GB 10 | 11 | ## モデルのダウンロード 12 | 13 | 以下のコマンドでモデルを取得できます: 14 | 15 | ```bash 16 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b 17 | ``` 18 | 19 | これらを'checkpoints'フォルダに配置してください。 20 | 21 | また、[inference](inference.md)の手順に従ってfish-speechモデルもダウンロードする必要があります。 22 | 23 | checkpointsには2つのフォルダが必要です。 24 | 25 | `checkpoints/fish-speech-1.4`と`checkpoints/fish-agent-v0.1-3b`です。 26 | 27 | ## 環境準備 28 | 29 | すでにFish-speechをお持ちの場合は、以下の指示を追加するだけで直接使用できます: 30 | ```bash 31 | pip install cachetools 32 | ``` 33 | 34 | !!! 
note 35 | コンパイルにはPythonバージョン3.12未満を使用してください。 36 | 37 | お持ちでない場合は、以下のコマンドで環境を構築してください: 38 | 39 | ```bash 40 | sudo apt-get install portaudio19-dev 41 | 42 | pip install -e .[stable] 43 | ``` 44 | 45 | ## エージェントデモの起動 46 | 47 | fish-agentを構築するには、メインフォルダで以下のコマンドを使用してください: 48 | 49 | ```bash 50 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile 51 | ``` 52 | 53 | `--compile`引数はPython < 3.12でのみサポートされており、トークン生成を大幅に高速化します。 54 | 55 | 一度にコンパイルは行われません(覚えておいてください)。 56 | 57 | 次に、別のターミナルを開いて以下のコマンドを使用します: 58 | 59 | ```bash 60 | python -m tools.e2e_webui 61 | ``` 62 | 63 | これにより、デバイス上にGradio WebUIが作成されます。 64 | 65 | モデルを初めて使用する際は、(`--compile`がTrueの場合)しばらくコンパイルが行われますので、お待ちください。 66 | 67 | ## Gradio Webui 68 |

69 | 70 |

71 | 72 | お楽しみください! 73 | 74 | ## パフォーマンス 75 | 76 | テストでは、4060搭載のラップトップではかろうじて動作しますが、非常に厳しい状態で、約8トークン/秒程度です。4090ではコンパイル時に約95トークン/秒で、これが推奨環境です。 77 | 78 | # エージェントについて 79 | 80 | このデモは初期アルファテストバージョンで、推論速度の最適化が必要で、修正を待つバグが多数あります。バグを発見した場合や修正したい場合は、issueやプルリクエストをいただけると大変嬉しく思います。 81 | -------------------------------------------------------------------------------- /docs/ko/finetune.md: -------------------------------------------------------------------------------- 1 | # 파인튜닝 2 | 3 | 이 페이지를 열었다는 것은, 사전 학습된 퓨샷(Few-shot) 모델의 성능에 만족하지 못했다는 의미일 것입니다. 데이터셋의 성능을 향상시키기 위해 모델을 파인튜닝하고 싶으시겠죠. 4 | 5 | 현재 버전에서는 'LLAMA' 부분만 파인튜닝하시면 됩니다. 6 | 7 | ## LLAMA 파인튜닝 8 | ### 1. 데이터셋 준비 9 | 10 | ``` 11 | . 12 | ├── SPK1 13 | │ ├── 21.15-26.44.lab 14 | │ ├── 21.15-26.44.mp3 15 | │ ├── 27.51-29.98.lab 16 | │ ├── 27.51-29.98.mp3 17 | │ ├── 30.1-32.71.lab 18 | │ └── 30.1-32.71.mp3 19 | └── SPK2 20 | ├── 38.79-40.85.lab 21 | └── 38.79-40.85.mp3 22 | ``` 23 | 24 | 위와 같은 형식으로 데이터셋을 변환하여 `data` 디렉토리 안에 배치하세요. 오디오 파일의 확장자는 `.mp3`, `.wav`, `.flac` 중 하나여야 하며, 주석 파일은 `.lab` 확장자를 사용해야 합니다. 25 | 26 | !!! info "데이터셋 형식" 27 | `.lab` 주석 파일은 오디오의 전사 내용만 포함하면 되며, 특별한 형식이 필요하지 않습니다. 예를 들어, `hi.mp3`에서 "Hello, goodbye"라는 대사를 말한다면, `hi.lab` 파일에는 "Hello, goodbye"라는 한 줄의 텍스트만 있어야 합니다. 28 | 29 | !!! warning 30 | 데이터셋에 대한 음량 정규화(loudness normalization)를 적용하는 것이 좋습니다. 이를 위해 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess)를 사용할 수 있습니다. 31 | 32 | ```bash 33 | fap loudness-norm data-raw data --clean 34 | ``` 35 | 36 | ### 2. 시맨틱 토큰 배치 추출 37 | 38 | VQGAN 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요: 39 | 40 | ```bash 41 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 42 | ``` 43 | 44 | 이후 시맨틱 토큰을 추출하기 위해 아래 명령어를 실행하세요: 45 | 46 | ```bash 47 | python tools/vqgan/extract_vq.py data \ 48 | --num-workers 1 --batch-size 16 \ 49 | --config-name "firefly_gan_vq" \ 50 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 51 | ``` 52 | 53 | !!! note 54 | 추출 속도를 높이기 위해 `--num-workers`와 `--batch-size` 값을 조정할 수 있지만, GPU 메모리 한도를 초과하지 않도록 주의하세요. 55 | VITS 형식의 경우, `--filelist xxx.list`를 사용하여 파일 목록을 지정할 수 있습니다. 56 | 57 | 이 명령을 실행하면 `data` 디렉토리 안에 `.npy` 파일이 생성됩니다. 다음과 같이 표시됩니다: 58 | 59 | ``` 60 | . 61 | ├── SPK1 62 | │ ├── 21.15-26.44.lab 63 | │ ├── 21.15-26.44.mp3 64 | │ ├── 21.15-26.44.npy 65 | │ ├── 27.51-29.98.lab 66 | │ ├── 27.51-29.98.mp3 67 | │ ├── 27.51-29.98.npy 68 | │ ├── 30.1-32.71.lab 69 | │ ├── 30.1-32.71.mp3 70 | │ └── 30.1-32.71.npy 71 | └── SPK2 72 | ├── 38.79-40.85.lab 73 | ├── 38.79-40.85.mp3 74 | └── 38.79-40.85.npy 75 | ``` 76 | 77 | ### 3. 데이터셋을 protobuf로 패킹 78 | 79 | ```bash 80 | python tools/llama/build_dataset.py \ 81 | --input "data" \ 82 | --output "data/protos" \ 83 | --text-extension .lab \ 84 | --num-workers 16 85 | ``` 86 | 87 | 명령이 완료되면 `data` 디렉토리 안에 `quantized-dataset-ft.protos` 파일이 생성됩니다. 88 | 89 | ### 4. 마지막으로, LoRA를 이용한 파인튜닝 90 | 91 | 마찬가지로, `LLAMA` 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요: 92 | 93 | ```bash 94 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 95 | ``` 96 | 97 | 마지막으로, 아래 명령어를 실행하여 파인튜닝을 시작할 수 있습니다: 98 | 99 | ```bash 100 | python fish_speech/train.py --config-name text2semantic_finetune \ 101 | project=$project \ 102 | +lora@model.model.lora_config=r_8_alpha_16 103 | ``` 104 | 105 | !!! note 106 | `batch_size`, `gradient_accumulation_steps` 등의 학습 매개변수를 GPU 메모리에 맞게 조정하려면 `fish_speech/configs/text2semantic_finetune.yaml` 파일을 수정할 수 있습니다. 107 | 108 | !!! 
note 109 | Windows 사용자의 경우, `nccl` 문제를 피하려면 `trainer.strategy.process_group_backend=gloo`를 사용할 수 있습니다. 110 | 111 | 훈련이 완료되면 [추론](inference.md) 섹션을 참고하여 음성을 생성할 수 있습니다. 112 | 113 | !!! info 114 | 기본적으로 모델은 화자의 말하는 패턴만 학습하고 음색은 학습하지 않습니다. 음색의 안정성을 위해 프롬프트를 사용해야 합니다. 115 | 음색을 학습하려면 훈련 단계를 늘릴 수 있지만, 이는 과적합의 위험을 초래할 수 있습니다. 116 | 117 | 훈련이 끝나면 LoRA 가중치를 일반 가중치로 변환한 후에 추론을 수행해야 합니다. 118 | 119 | ```bash 120 | python tools/llama/merge_lora.py \ 121 | --lora-config r_8_alpha_16 \ 122 | --base-weight checkpoints/fish-speech-1.5 \ 123 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \ 124 | --output checkpoints/fish-speech-1.5-yth-lora/ 125 | ``` 126 | 127 | !!! note 128 | 다른 체크포인트도 시도해 볼 수 있습니다. 요구 사항에 맞는 가장 초기 체크포인트를 사용하는 것이 좋습니다. 이들은 종종 분포 밖(OOD) 데이터에서 더 좋은 성능을 발휘합니다. 129 | -------------------------------------------------------------------------------- /docs/ko/inference.md: -------------------------------------------------------------------------------- 1 | # 추론 2 | 3 | 추론은 명령줄, HTTP API, 그리고 웹 UI에서 지원됩니다. 4 | 5 | !!! note 6 | 전체 추론 과정은 다음의 여러 단계로 구성됩니다: 7 | 8 | 1. VQGAN을 사용하여 약 10초 분량의 음성을 인코딩합니다. 9 | 2. 인코딩된 시맨틱 토큰과 해당 텍스트를 예시로 언어 모델에 입력합니다. 10 | 3. 새로운 텍스트를 입력하면, 모델이 해당하는 시맨틱 토큰을 생성합니다. 11 | 4. 생성된 시맨틱 토큰을 VITS / VQGAN에 입력하여 음성을 디코딩하고 생성합니다. 12 | 13 | ## 모델 다운로드 14 | 필요한 `vqgan` 및 `llama` 모델을 Hugging Face 리포지토리에서 다운로드하세요. 15 | 16 | ```bash 17 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 18 | ``` 19 | 20 | ## 명령줄 추론 21 | ### 1. 음성에서 프롬프트 생성: 22 | 23 | !!! note 24 | 모델이 음색을 무작위로 선택하도록 하려면 이 단계를 건너뛸 수 있습니다. 25 | 26 | !!! warning "향후 버전 경고" 27 | 원래 경로(tools/vqgan/infernce.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오. 28 | 29 | ```bash 30 | python fish_speech/models/vqgan/inference.py \ 31 | -i "paimon.wav" \ 32 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 33 | ``` 34 | 35 | 이 명령을 실행하면 `fake.npy` 파일을 얻게 됩니다. 36 | 37 | ### 2. 텍스트에서 시맨틱 토큰 생성: 38 | 39 | !!! warning "향후 버전 경고" 40 | 원래 경로(tools/llama/generate.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오. 41 | 42 | ```bash 43 | python fish_speech/models/text2semantic/inference.py \ 44 | --text "변환할 텍스트" \ 45 | --prompt-text "참고할 텍스트" \ 46 | --prompt-tokens "fake.npy" \ 47 | --checkpoint-path "checkpoints/fish-speech-1.5" \ 48 | --num-samples 2 \ 49 | --compile 50 | ``` 51 | 52 | 이 명령을 실행하면 작업 디렉토리에 `codes_N` 파일이 생성되며, N은 0부터 시작하는 정수입니다. 53 | 54 | !!! note 55 | 빠른 추론을 위해 `--compile` 옵션을 사용하여 CUDA 커널을 결합할 수 있습니다 (~초당 30 토큰 -> ~초당 500 토큰). 56 | `--compile` 매개변수를 주석 처리하여 가속화 옵션을 사용하지 않을 수도 있습니다. 57 | 58 | !!! info 59 | bf16을 지원하지 않는 GPU의 경우 `--half` 매개변수를 사용해야 할 수 있습니다. 60 | 61 | ### 3. 시맨틱 토큰에서 음성 생성: 62 | 63 | #### VQGAN 디코더 64 | 65 | !!! warning "향후 버전 경고" 66 | 원래 경로(tools/vqgan/infernce.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오. 67 | 68 | ```bash 69 | python fish_speech/models/vqgan/inference.py \ 70 | -i "codes_0.npy" \ 71 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 72 | ``` 73 | 74 | ## HTTP API 추론 75 | 76 | 추론을 위한 HTTP API를 제공하고 있습니다. 
아래의 명령어로 서버를 시작할 수 있습니다: 77 | 78 | ```bash 79 | python -m tools.api_server \ 80 | --listen 0.0.0.0:8080 \ 81 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ 82 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ 83 | --decoder-config-name firefly_gan_vq 84 | ``` 85 | 86 | 추론 속도를 높이고 싶다면 `--compile` 매개변수를 추가할 수 있습니다. 87 | 88 | 이후, http://127.0.0.1:8080/ 에서 API를 확인하고 테스트할 수 있습니다. 89 | 90 | 아래는 `tools/api_client.py`를 사용하여 요청을 보내는 예시입니다. 91 | 92 | ```bash 93 | python -m tools.api_client \ 94 | --text "입력할 텍스트" \ 95 | --reference_audio "참고 음성 경로" \ 96 | --reference_text "참고 음성의 텍스트 내용" \ 97 | --streaming True 98 | ``` 99 | 100 | 위 명령은 참고 음성 정보를 바탕으로 원하는 음성을 합성하고, 스트리밍 방식으로 반환합니다. 101 | 102 | 다음 예시는 여러 개의 참고 음성 경로와 텍스트를 한꺼번에 사용할 수 있음을 보여줍니다. 명령에서 공백으로 구분하여 입력합니다. 103 | 104 | ```bash 105 | python -m tools.api_client \ 106 | --text "입력할 텍스트" \ 107 | --reference_audio "참고 음성 경로1" "참고 음성 경로2" \ 108 | --reference_text "참고 음성 텍스트1" "참고 음성 텍스트2"\ 109 | --streaming False \ 110 | --output "generated" \ 111 | --format "mp3" 112 | ``` 113 | 114 | 위 명령어는 여러 참고 음성 정보를 바탕으로 `MP3` 형식의 음성을 합성하여, 현재 디렉토리에 `generated.mp3`로 저장합니다. 115 | 116 | `--reference_audio`와 `--reference_text` 대신에 `--reference_id`(하나만 사용 가능)를 사용할 수 있습니다. 프로젝트 루트 디렉토리에 `references/` 폴더를 만들어 해당 음성과 주석 텍스트를 넣어야 합니다. 참고 음성은 최대 90초까지 지원됩니다. 117 | 118 | !!! info 119 | 제공되는 파라미터는 `python -m tools.api_client -h`를 사용하여 확인할 수 있습니다. 120 | 121 | ## GUI 추론 122 | [클라이언트 다운로드](https://github.com/AnyaCoder/fish-speech-gui/releases) 123 | 124 | ## WebUI 추론 125 | 126 | 다음 명령으로 WebUI를 시작할 수 있습니다: 127 | 128 | ```bash 129 | python -m tools.run_webui \ 130 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ 131 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ 132 | --decoder-config-name firefly_gan_vq 133 | ``` 134 | 135 | > 추론 속도를 높이고 싶다면 `--compile` 매개변수를 추가할 수 있습니다. 136 | 137 | !!! note 138 | 라벨 파일과 참고 음성 파일을 미리 메인 디렉토리의 `references` 폴더에 저장해 두면, WebUI에서 바로 호출할 수 있습니다. (해당 폴더는 직접 생성해야 합니다.) 139 | 140 | !!! note 141 | WebUI를 구성하기 위해 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`과 같은 Gradio 환경 변수를 사용할 수 있습니다. 142 | 143 | 즐기세요! 144 | -------------------------------------------------------------------------------- /docs/ko/start_agent.md: -------------------------------------------------------------------------------- 1 | # 에이전트 시작하기 2 | 3 | !!! note 4 | 전체 문서는 claude3.5 Sonnet에 의해 번역되었으며, 원어민인 경우 번역에 문제가 있다고 생각되면 이슈나 풀 리퀘스트를 보내주셔서 대단히 감사합니다! 5 | 6 | ## 요구사항 7 | 8 | - GPU 메모리: 최소 8GB(양자화 사용 시), 16GB 이상 권장 9 | - 디스크 사용량: 10GB 10 | 11 | ## 모델 다운로드 12 | 13 | 다음 명령어로 모델을 받을 수 있습니다: 14 | 15 | ```bash 16 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b 17 | ``` 18 | 19 | 'checkpoints' 폴더에 파일들을 넣으세요. 20 | 21 | 또한 [inference](inference.md)에 설명된 대로 fish-speech 모델도 다운로드해야 합니다. 22 | 23 | checkpoints에는 2개의 폴더가 있어야 합니다. 24 | 25 | `checkpoints/fish-speech-1.4`와 `checkpoints/fish-agent-v0.1-3b`입니다. 26 | 27 | ## 환경 준비 28 | 29 | 이미 Fish-speech가 있다면 다음 명령어를 추가하여 바로 사용할 수 있습니다: 30 | ```bash 31 | pip install cachetools 32 | ``` 33 | 34 | !!! 참고 35 | 컴파일을 위해 Python 3.12 미만 버전을 사용해 주세요. 
36 | 37 | 없다면 아래 명령어를 사용하여 환경을 구축하세요: 38 | 39 | ```bash 40 | sudo apt-get install portaudio19-dev 41 | 42 | pip install -e .[stable] 43 | ``` 44 | 45 | ## 에이전트 데모 실행 46 | 47 | fish-agent를 구축하려면 메인 폴더에서 아래 명령어를 사용하세요: 48 | 49 | ```bash 50 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile 51 | ``` 52 | 53 | `--compile` 인자는 Python < 3.12에서만 지원되며, 토큰 생성 속도를 크게 향상시킵니다. 54 | 55 | 컴파일은 한 번에 끝나지 않고 다소 시간이 걸리므로, 이 점을 기억해 주세요. 56 | 57 | 그런 다음 다른 터미널을 열고 다음 명령어를 사용하세요: 58 | 59 | ```bash 60 | python -m tools.e2e_webui 61 | ``` 62 | 63 | 이렇게 하면 기기에 Gradio WebUI가 생성됩니다. 64 | 65 | 모델을 처음 사용할 때는 (`--compile`이 True인 경우) 잠시 컴파일이 진행되므로 기다려 주세요. 66 | 67 | ## Gradio Webui 68 |

69 | 70 |
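참고로, 다른 기기에서 WebUI에 접속해야 한다면 Gradio 환경 변수로 바인딩 주소와 포트를 지정해 볼 수 있습니다. (스크립트가 해당 값을 직접 고정하지 않는 경우에만 적용되며, 아래 값은 예시입니다.)

```bash
# 예시: 외부 접속을 허용하고 포트를 7860으로 지정 (값은 예시)
GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 python -m tools.e2e_webui
```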

71 | 72 | 즐거운 시간 되세요! 73 | 74 | ## 성능 75 | 76 | 테스트 결과, 4060 노트북은 겨우 실행되며 매우 부하가 큰 상태로, 초당 약 8토큰 정도만 처리합니다. 4090은 컴파일 상태에서 초당 약 95토큰을 처리하며, 이것이 저희가 권장하는 사양입니다. 77 | 78 | # 에이전트 소개 79 | 80 | 이 데모는 초기 알파 테스트 버전으로, 추론 속도 최적화가 필요하며 수정해야 할 버그가 많이 있습니다. 버그를 발견하거나 수정하고 싶으시다면 이슈나 풀 리퀘스트를 보내주시면 매우 감사하겠습니다. 81 | -------------------------------------------------------------------------------- /docs/pt/finetune.md: -------------------------------------------------------------------------------- 1 | # Ajuste Fino 2 | 3 | É óbvio que ao abrir esta página, você não deve estar muito satisfeito com o desempenho do modelo pré-treinado com poucos exemplos. Você pode querer ajustar o modelo para melhorar seu desempenho em seu conjunto de dados. 4 | 5 | Na atual versão, a única coisa que você precisa ajustar é a parte do 'LLAMA'. 6 | 7 | ## Ajuste Fino do LLAMA 8 | ### 1. Preparando o conjunto de dados 9 | 10 | ``` 11 | . 12 | ├── SPK1 13 | │ ├── 21.15-26.44.lab 14 | │ ├── 21.15-26.44.mp3 15 | │ ├── 27.51-29.98.lab 16 | │ ├── 27.51-29.98.mp3 17 | │ ├── 30.1-32.71.lab 18 | │ └── 30.1-32.71.mp3 19 | └── SPK2 20 | ├── 38.79-40.85.lab 21 | └── 38.79-40.85.mp3 22 | ``` 23 | 24 | Você precisa converter seu conjunto de dados para o formato acima e colocá-lo em `data`. O arquivo de áudio pode ter as extensões `.mp3`, `.wav` ou `.flac`, e o arquivo de anotação deve ter a extensão `.lab`. 25 | 26 | !!! info 27 | O arquivo de anotação `.lab` deve conter apenas a transcrição do áudio, sem a necessidade de formatação especial. Por exemplo, se o arquivo `hi.mp3` disser "Olá, tchau", o arquivo `hi.lab` conterá uma única linha de texto: "Olá, tchau". 28 | 29 | !!! warning 30 | É recomendado aplicar normalização de volume ao conjunto de dados. Você pode usar o [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) para fazer isso. 31 | 32 | ```bash 33 | fap loudness-norm data-raw data --clean 34 | ``` 35 | 36 | 37 | ### 2. Extração em lote de tokens semânticos 38 | 39 | Certifique-se de ter baixado os pesos do VQGAN. Se não, execute o seguinte comando: 40 | 41 | ```bash 42 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 43 | ``` 44 | 45 | Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos: 46 | 47 | ```bash 48 | python tools/vqgan/extract_vq.py data \ 49 | --num-workers 1 --batch-size 16 \ 50 | --config-name "firefly_gan_vq" \ 51 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 52 | ``` 53 | 54 | !!! note 55 | Você pode ajustar `--num-workers` e `--batch-size` para aumentar a velocidade de extração, mas certifique-se de não exceder o limite de memória da sua GPU.   56 | Para o formato VITS, você pode especificar uma lista de arquivos usando `--filelist xxx.list`. 57 | 58 | Este comando criará arquivos `.npy` no diretório `data`, como mostrado abaixo: 59 | 60 | ``` 61 | . 62 | ├── SPK1 63 | │ ├── 21.15-26.44.lab 64 | │ ├── 21.15-26.44.mp3 65 | │ ├── 21.15-26.44.npy 66 | │ ├── 27.51-29.98.lab 67 | │ ├── 27.51-29.98.mp3 68 | │ ├── 27.51-29.98.npy 69 | │ ├── 30.1-32.71.lab 70 | │ ├── 30.1-32.71.mp3 71 | │ └── 30.1-32.71.npy 72 | └── SPK2 73 | ├── 38.79-40.85.lab 74 | ├── 38.79-40.85.mp3 75 | └── 38.79-40.85.npy 76 | ``` 77 | 78 | ### 3. 
Empacotar o conjunto de dados em protobuf 79 | 80 | ```bash 81 | python tools/llama/build_dataset.py \ 82 | --input "data" \ 83 | --output "data/protos" \ 84 | --text-extension .lab \ 85 | --num-workers 16 86 | ``` 87 | 88 | Após executar o comando, você deverá ver o arquivo `quantized-dataset-ft.protos` no diretório `data`. 89 | 90 | ### 4. E finalmente, chegamos ao ajuste fino com LoRA 91 | 92 | Da mesma forma, certifique-se de ter baixado os pesos do `LLAMA`. Se não, execute o seguinte comando: 93 | 94 | ```bash 95 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 96 | ``` 97 | 98 | E então, execute o seguinte comando para iniciar o ajuste fino: 99 | 100 | ```bash 101 | python fish_speech/train.py --config-name text2semantic_finetune \ 102 | project=$project \ 103 | +lora@model.model.lora_config=r_8_alpha_16 104 | ``` 105 | 106 | !!! note 107 | Se quiser, você pode modificar os parâmetros de treinamento, como `batch_size`, `gradient_accumulation_steps`, etc., para se ajustar à memória da sua GPU, modificando `fish_speech/configs/text2semantic_finetune.yaml`. 108 | 109 | !!! note 110 | Para usuários do Windows, é recomendado usar `trainer.strategy.process_group_backend=gloo` para evitar problemas com `nccl`. 111 | 112 | Após concluir o treinamento, consulte a seção [inferência](inference.md). 113 | 114 | !!! info 115 | Por padrão, o modelo aprenderá apenas os padrões de fala do orador e não o timbre. Ainda pode ser preciso usar prompts para garantir a estabilidade do timbre. 116 | Se quiser que ele aprenda o timbre, aumente o número de etapas de treinamento, mas isso pode levar ao overfitting (sobreajuste). 117 | 118 | Após o treinamento, é preciso converter os pesos do LoRA em pesos regulares antes de realizar a inferência. 119 | 120 | ```bash 121 | python tools/llama/merge_lora.py \ 122 | --lora-config r_8_alpha_16 \ 123 | --base-weight checkpoints/fish-speech-1.5 \ 124 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \ 125 | --output checkpoints/fish-speech-1.5-yth-lora/ 126 | ``` 127 | !!! note 128 | É possível também tentar outros checkpoints. Sugerimos usar o checkpoint que melhor atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora da distribuição (OOD). 129 | -------------------------------------------------------------------------------- /docs/pt/start_agent.md: -------------------------------------------------------------------------------- 1 | # Iniciar Agente 2 | 3 | !!! note 4 | Todo o documento foi traduzido por claude3.5 Sonnet, se você for um falante nativo e achar a tradução problemática, muito obrigado por nos enviar um problema ou uma solicitação pull! 5 | 6 | ## Requisitos 7 | 8 | - Memória GPU: No mínimo 8GB (com quantização), 16GB ou mais é recomendado. 9 | - Uso de disco: 10GB 10 | 11 | ## Download do Modelo 12 | 13 | Você pode obter o modelo através de: 14 | 15 | ```bash 16 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b 17 | ``` 18 | 19 | Coloque-os na pasta 'checkpoints'. 20 | 21 | Você também precisará do modelo fish-speech que pode ser baixado seguindo as instruções em [inference](inference.md). 22 | 23 | Então haverá 2 pastas em checkpoints. 24 | 25 | O `checkpoints/fish-speech-1.4` e `checkpoints/fish-agent-v0.1-3b` 26 | 27 | ## Preparação do Ambiente 28 | 29 | Se você já tem o Fish-speech, pode usar diretamente adicionando a seguinte instrução: 30 | ```bash 31 | pip install cachetools 32 | ``` 33 | 34 | !!! 
note 35 | Por favor, use a versão Python abaixo de 3.12 para compilação. 36 | 37 | Se você não tem, use os comandos abaixo para construir seu ambiente: 38 | 39 | ```bash 40 | sudo apt-get install portaudio19-dev 41 | 42 | pip install -e .[stable] 43 | ``` 44 | 45 | ## Iniciar a Demo do Agente 46 | 47 | Para construir o fish-agent, use o comando abaixo na pasta principal: 48 | 49 | ```bash 50 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile 51 | ``` 52 | 53 | O argumento `--compile` só é suportado em Python < 3.12 e aumentará muito a velocidade de geração de tokens. 54 | 55 | A compilação não acontece de uma só vez e leva algum tempo, tenha isso em mente. 56 | 57 | Então abra outro terminal e use o comando: 58 | 59 | ```bash 60 | python -m tools.e2e_webui 61 | ``` 62 | 63 | Isso criará uma WebUI Gradio no dispositivo. 64 | 65 | Quando você usar o modelo pela primeira vez, ele irá compilar (se `--compile` estiver True) por um curto período, então aguarde com paciência. 66 | 67 | ## Gradio Webui 68 |

69 | 70 |
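A título de exemplo, se você precisar de um link público temporário para testar a WebUI em outro dispositivo, pode tentar a variável de ambiente do Gradio abaixo (supondo que o script não fixe esse parâmetro explicitamente):

```bash
# Exemplo: gerar um link público temporário do Gradio (valor apenas ilustrativo)
GRADIO_SHARE=true python -m tools.e2e_webui
```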

71 | 72 | Divirta-se! 73 | 74 | ## Desempenho 75 | 76 | Em nossos testes, um laptop com 4060 mal consegue rodar, ficando muito sobrecarregado, gerando apenas cerca de 8 tokens/s. A 4090 gera cerca de 95 tokens/s com compilação, que é o que recomendamos. 77 | 78 | # Sobre o Agente 79 | 80 | A demo é uma versão alpha inicial de teste, a velocidade de inferência precisa ser otimizada, e há muitos bugs aguardando correção. Se você encontrou um bug ou quer corrigi-lo, ficaremos muito felizes em receber uma issue ou um pull request. 81 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs-material 2 | mkdocs-static-i18n[material] 3 | mkdocs[i18n] 4 | -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | .md-grid { 2 | max-width: 1440px; 3 | } 4 | -------------------------------------------------------------------------------- /docs/zh/finetune.md: -------------------------------------------------------------------------------- 1 | # 微调 2 | 3 | 显然, 当你打开这个页面的时候, 你已经对预训练模型 zero-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好. 4 | 5 | 在目前版本,你只需要微调'LLAMA'部分即可. 6 | 7 | ## LLAMA 微调 8 | ### 1. 准备数据集 9 | 10 | ``` 11 | . 12 | ├── SPK1 13 | │ ├── 21.15-26.44.lab 14 | │ ├── 21.15-26.44.mp3 15 | │ ├── 27.51-29.98.lab 16 | │ ├── 27.51-29.98.mp3 17 | │ ├── 30.1-32.71.lab 18 | │ └── 30.1-32.71.mp3 19 | └── SPK2 20 | ├── 38.79-40.85.lab 21 | └── 38.79-40.85.mp3 22 | ``` 23 | 24 | 你需要将数据集转为以上格式, 并放到 `data` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀建议为 `.lab`. 25 | 26 | !!! info 27 | 标注文件 `.lab` 仅需包含音频的转写文本,无需遵循特殊格式要求。例如,如果 `hi.mp3` 中的内容是“你好,再见。”,那么 `hi.lab` 文件中只需包含一行文本:“你好,再见”。 28 | 29 | !!! warning 30 | 建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤. 31 | ```bash 32 | fap loudness-norm data-raw data --clean 33 | ``` 34 | 35 | ### 2. 批量提取语义 token 36 | 37 | 确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令: 38 | 39 | ```bash 40 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 41 | ``` 42 | 43 | 对于中国大陆用户, 可使用 mirror 下载. 44 | 45 | ```bash 46 | HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 47 | ``` 48 | 49 | 随后可运行以下命令来提取语义 token: 50 | 51 | ```bash 52 | python tools/vqgan/extract_vq.py data \ 53 | --num-workers 1 --batch-size 16 \ 54 | --config-name "firefly_gan_vq" \ 55 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 56 | ``` 57 | 58 | !!! note 59 | 你可以调整 `--num-workers` 和 `--batch-size` 来提高提取速度, 但是请注意不要超过你的显存限制. 60 | 61 | 该命令会在 `data` 目录下创建 `.npy` 文件, 如下所示: 62 | 63 | ``` 64 | . 65 | ├── SPK1 66 | │ ├── 21.15-26.44.lab 67 | │ ├── 21.15-26.44.mp3 68 | │ ├── 21.15-26.44.npy 69 | │ ├── 27.51-29.98.lab 70 | │ ├── 27.51-29.98.mp3 71 | │ ├── 27.51-29.98.npy 72 | │ ├── 30.1-32.71.lab 73 | │ ├── 30.1-32.71.mp3 74 | │ └── 30.1-32.71.npy 75 | └── SPK2 76 | ├── 38.79-40.85.lab 77 | ├── 38.79-40.85.mp3 78 | └── 38.79-40.85.npy 79 | ``` 80 | 81 | ### 3. 打包数据集为 protobuf 82 | 83 | ```bash 84 | python tools/llama/build_dataset.py \ 85 | --input "data" \ 86 | --output "data/protos" \ 87 | --text-extension .lab \ 88 | --num-workers 16 89 | ``` 90 | 91 | 命令执行完毕后, 你应该能在 `data` 目录下看到 `protos` 文件. 92 | 93 | 94 | ### 4. 
最后, 使用 LoRA 进行微调 95 | 96 | 同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令: 97 | 98 | ```bash 99 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 100 | ``` 101 | 102 | 对于中国大陆用户, 可使用 mirror 下载. 103 | 104 | ```bash 105 | HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 106 | ``` 107 | 108 | 最后, 你可以运行以下命令来启动微调: 109 | 110 | ```bash 111 | python fish_speech/train.py --config-name text2semantic_finetune \ 112 | project=$project \ 113 | +lora@model.model.lora_config=r_8_alpha_16 114 | ``` 115 | 116 | !!! note 117 | 你可以通过修改 `fish_speech/configs/text2semantic_finetune.yaml` 来修改训练参数如 `batch_size`, `gradient_accumulation_steps` 等, 来适应你的显存. 118 | 119 | !!! note 120 | 对于 Windows 用户, 你可以使用 `trainer.strategy.process_group_backend=gloo` 来避免 `nccl` 的问题. 121 | 122 | 训练结束后, 你可以参考 [推理](inference.md) 部分来测试你的模型. 123 | 124 | !!! info 125 | 默认配置下, 基本只会学到说话人的发音方式, 而不包含音色, 你依然需要使用 prompt 来保证音色的稳定性. 126 | 如果你想要学到音色, 请将训练步数调大, 但这有可能会导致过拟合. 127 | 128 | 训练完成后, 你需要先将 loRA 的权重转为普通权重, 然后再进行推理. 129 | 130 | ```bash 131 | python tools/llama/merge_lora.py \ 132 | --lora-config r_8_alpha_16 \ 133 | --base-weight checkpoints/fish-speech-1.5 \ 134 | --lora-weight results/$project/checkpoints/step_000000010.ckpt \ 135 | --output checkpoints/fish-speech-1.5-yth-lora/ 136 | ``` 137 | 138 | !!! note 139 | 你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好. 140 | -------------------------------------------------------------------------------- /docs/zh/inference.md: -------------------------------------------------------------------------------- 1 | # 推理 2 | 3 | 推理支持命令行, http api, 以及 webui 三种方式. 4 | 5 | !!! note 6 | 总的来说, 推理分为几个部分: 7 | 8 | 1. 给定一段 ~10 秒的语音, 将它用 VQGAN 编码. 9 | 2. 将编码后的语义 token 和对应文本输入语言模型作为例子. 10 | 3. 给定一段新文本, 让模型生成对应的语义 token. 11 | 4. 将生成的语义 token 输入 VQGAN 解码, 生成对应的语音. 12 | 13 | ## 下载模型 14 | 从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。 15 | 16 | ```bash 17 | huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 18 | ``` 19 | 20 | 对于中国大陆用户,可使用 mirror 下载。 21 | 22 | ```bash 23 | HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5 24 | ``` 25 | 26 | ## 命令行推理 27 | ### 1. 从语音生成 prompt: 28 | 29 | !!! note 30 | 如果你打算让模型随机选择音色, 你可以跳过这一步. 31 | 32 | ```bash 33 | python fish_speech/models/vqgan/inference.py \ 34 | -i "paimon.wav" \ 35 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 36 | ``` 37 | 38 | 你应该能得到一个 `fake.npy` 文件. 39 | 40 | ### 2. 从文本生成语义 token: 41 | 42 | ```bash 43 | python fish_speech/models/text2semantic/inference.py \ 44 | --text "要转换的文本" \ 45 | --prompt-text "你的参考文本" \ 46 | --prompt-tokens "fake.npy" \ 47 | --checkpoint-path "checkpoints/fish-speech-1.5" \ 48 | --num-samples 2 \ 49 | --compile 50 | ``` 51 | 52 | 该命令会在工作目录下创建 `codes_N` 文件, 其中 N 是从 0 开始的整数. 53 | 54 | !!! note 55 | 您可能希望使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 个 token/秒 -> ~500 个 token/秒). 56 | 对应的, 如果你不打算使用加速, 你可以注释掉 `--compile` 参数. 57 | 58 | !!! info 59 | 对于不支持 bf16 的 GPU, 你可能需要使用 `--half` 参数. 60 | 61 | ### 3. 
从语义 token 生成人声: 62 | 63 | #### VQGAN 解码 64 | 65 | ```bash 66 | python fish_speech/models/vqgan/inference.py \ 67 | -i "codes_0.npy" \ 68 | --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" 69 | ``` 70 | 71 | ## HTTP API 推理 72 | 73 | 运行以下命令来启动 HTTP 服务: 74 | 75 | ```bash 76 | python -m tools.api_server \ 77 | --listen 0.0.0.0:8080 \ 78 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ 79 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ 80 | --decoder-config-name firefly_gan_vq 81 | ``` 82 | > 如果你想要加速推理,可以加上`--compile`参数。 83 | 84 | 推荐中国大陆用户运行以下命令来启动 HTTP 服务: 85 | ```bash 86 | HF_ENDPOINT=https://hf-mirror.com python -m ...(同上) 87 | ``` 88 | 89 | 随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API. 90 | 91 | 下面是使用`tools/api_client.py`发送请求的示例。 92 | 93 | ```bash 94 | python -m tools.api_client \ 95 | --text "要输入的文本" \ 96 | --reference_audio "参考音频路径" \ 97 | --reference_text "参考音频的文本内容" \ 98 | --streaming True 99 | ``` 100 | 101 | 上面的命令表示按照参考音频的信息,合成所需的音频并流式返回. 102 | 103 | 下面的示例展示了, 可以一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。 104 | ```bash 105 | python -m tools.api_client \ 106 | --text "要输入的文本" \ 107 | --reference_audio "参考音频路径1" "参考音频路径2" \ 108 | --reference_text "参考音频的文本内容1" "参考音频的文本内容2"\ 109 | --streaming False \ 110 | --output "generated" \ 111 | --format "mp3" 112 | ``` 113 | 114 | 上面的命令表示按照多个参考音频的信息,合成所需的`MP3`格式音频,并保存为当前目录的`generated.mp3`文件。 115 | 116 | 还可以用`--reference_id`(仅能用一个)来代替`--reference_audio`和`--reference_text`, 前提是在项目根目录下创建`references/`文件夹, 117 | 里面放上任意对音频与标注文本。 目前支持的参考音频最多加起来总时长90s。 118 | 119 | !!! info 120 | 要了解有关可用参数的更多信息,可以使用命令`python -m tools.api_client -h` 121 | 122 | ## GUI 推理 123 | [下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases) 124 | 125 | ## WebUI 推理 126 | 127 | 你可以使用以下命令来启动 WebUI: 128 | 129 | ```bash 130 | python -m tools.run_webui \ 131 | --llama-checkpoint-path "checkpoints/fish-speech-1.5" \ 132 | --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ 133 | --decoder-config-name firefly_gan_vq 134 | ``` 135 | > 如果你想要加速推理,可以加上`--compile`参数。 136 | 137 | !!! note 138 | 你可以提前将label文件和参考音频文件保存到主目录下的 `references` 文件夹(需要自行创建),这样你可以直接在WebUI中调用它们。 139 | 140 | !!! note 141 | 你可以使用 Gradio 环境变量, 如 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` 来配置 WebUI. 142 | 143 | 祝大家玩得开心! 144 | -------------------------------------------------------------------------------- /docs/zh/start_agent.md: -------------------------------------------------------------------------------- 1 | # 启动 Agent 2 | 3 | ## 要求 4 | 5 | - GPU 显存: 至少 8GB(在量化的条件下),推荐 16GB 及以上 6 | - 硬盘使用量: 10GB 7 | 8 | ## 下载模型 9 | 10 | 你可以执行下面的语句来获取模型: 11 | 12 | ```bash 13 | huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b 14 | ``` 15 | 16 | 如果你处于国内网络,首先执行: 17 | 18 | ```bash 19 | export HF_ENDPOINT=https://hf-mirror.com 20 | ``` 21 | 22 | 把他们放进名为 'checkpoints' 的文件夹内。 23 | 24 | 你同样需要 fish-speech 的模型,关于如何获取 fish-speech 模型请查看[inference](inference.md)。 25 | 26 | 完成后你的 checkpoints 文件夹中会有两个子文件夹:`checkpoints/fish-speech-1.4` 和 `checkpoints/fish-agent-v0.1-3b`。 27 | 28 | ## Environment Prepare 29 | 30 | 如果你已经有了 Fish-Speech 环境,你可以在安装下面的包的前提下直接使用: 31 | 32 | ```bash 33 | pip install cachetools 34 | ``` 35 | 36 | !!! 
note 37 | 请使用小于 3.12 的 Python 版本,以使 compile 功能可用。 38 | 39 | 如果你没有 Fish-Speech 环境,请执行下面的语句来构造你的环境: 40 | 41 | ```bash 42 | sudo apt-get install portaudio19-dev 43 | 44 | pip install -e .[stable] 45 | ``` 46 | 47 | ## 启动 Agent Demo 48 | 49 | 你需要使用以下指令来构建 fish-agent: 50 | 51 | ```bash 52 | python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile 53 | ``` 54 | 55 | `--compile`只能在小于 3.12 版本的 Python 使用,这个功能可以极大程度上提高生成速度。 56 | 57 | 你需要注意,compile 需要进行一段时间。 58 | 59 | 然后启动另一个终端并执行: 60 | 61 | ```bash 62 | python -m tools.e2e_webui 63 | ``` 64 | 65 | 这会在设备上创建一个 Gradio WebUI。 66 | 67 | 每当进行第一轮对话的时候,模型需要 compile 一段时间,请耐心等待。 68 | 69 | ## Gradio Webui 70 | 71 |

72 | 73 |
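顺带一提,如果想改用其他端口访问 WebUI,可以尝试通过 Gradio 环境变量指定(仅在脚本没有显式固定端口时生效,下面的取值仅为示例):

```bash
# 示例:将 Gradio 端口改为 7861(取值仅为示例)
GRADIO_SERVER_PORT=7861 python -m tools.e2e_webui
```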

74 | 75 | 玩得开心! 76 | 77 | ## Performance 78 | 79 | 在我们的测试环境下, 4060 laptop GPU 只能刚刚运行该模型,只有大概 8 tokens/s。 4090 GPU 可以在编译后达到 95 tokens/s,我们推荐使用至少 4080 以上级别的 GPU 来达到较好体验。 80 | 81 | # About Agent 82 | 83 | 该模型仍处于测试阶段。如果你发现了问题,请给我们提 issue 或者 pull request,我们非常感谢。 84 | -------------------------------------------------------------------------------- /entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_ENABLED=${CUDA_ENABLED:-true} 4 | DEVICE="" 5 | 6 | if [ "${CUDA_ENABLED}" != "true" ]; then 7 | DEVICE="--device cpu" 8 | fi 9 | 10 | exec python tools/run_webui.py ${DEVICE} 11 | -------------------------------------------------------------------------------- /fish_speech/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .grad_norm import GradNormMonitor 2 | 3 | __all__ = ["GradNormMonitor"] 4 | -------------------------------------------------------------------------------- /fish_speech/callbacks/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import lightning.pytorch as pl 4 | import torch 5 | from lightning import LightningModule, Trainer 6 | from lightning.pytorch.callbacks import Callback 7 | from torch import Tensor, nn 8 | from torch.utils._foreach_utils import ( 9 | _group_tensors_by_device_and_dtype, 10 | _has_foreach_support, 11 | ) 12 | 13 | 14 | @torch.no_grad() 15 | def grad_norm( 16 | parameters: Union[Tensor, list[Tensor]], 17 | norm_type: float = 2.0, 18 | ) -> float: 19 | """ 20 | Returns the norm of the gradients of the given parameters. 21 | 22 | Args: 23 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 24 | single Tensor that will have gradients normalized 25 | norm_type (float): type of the used p-norm. 26 | 27 | Returns: 28 | Total norm of the parameter gradients (viewed as a single vector). 29 | """ # noqa: E501 30 | 31 | if isinstance(parameters, Tensor): 32 | parameters = [parameters] 33 | 34 | grads = [p.grad for p in parameters if p.grad is not None] 35 | if len(grads) == 0: 36 | return None 37 | 38 | first_device = grads[0].device 39 | grouped_grads: dict[ 40 | tuple[torch.device, torch.dtype], list[list[Tensor]] 41 | ] = _group_tensors_by_device_and_dtype( 42 | [[g.detach() for g in grads]] 43 | ) # type: ignore[assignment] 44 | 45 | norms = [] 46 | for (device, _), ([grads], _) in grouped_grads.items(): 47 | if _has_foreach_support(grads, device=device): 48 | norms.extend(torch._foreach_norm(grads, norm_type)) 49 | else: 50 | norms.extend([torch.norm(g, norm_type) for g in grads]) 51 | 52 | return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) 53 | 54 | 55 | class GradNormMonitor(Callback): 56 | """ 57 | Callback that computes the gradient norm of the model parameters. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | norm_type: float = 2.0, 63 | logging_interval: str = "step", 64 | sub_module: Optional[Union[str, list[str]]] = None, 65 | ) -> None: 66 | """ 67 | Args: 68 | norm_type (float): type of the used p-norm. 69 | logging_interval (str): "step" or "epoch". 70 | """ 71 | super().__init__() 72 | 73 | self.norm_type = norm_type 74 | self.logging_interval = logging_interval 75 | self.sub_module = sub_module 76 | 77 | def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None: 78 | """ 79 | Computes the gradient norm of the model parameters and logs it to the logger. 
80 | 81 | Args: 82 | trainer (Trainer): The trainer object 83 | model (LightningModule): The current lightningModule 84 | """ 85 | 86 | lightning_model = model 87 | 88 | if self.sub_module is None: 89 | return self.log_sub_module_grad_norm(lightning_model, model, "") 90 | 91 | sub_modules = self.sub_module 92 | if isinstance(sub_modules, str): 93 | sub_modules = [sub_modules] 94 | 95 | for sub_module in sub_modules: 96 | self.log_sub_module_grad_norm( 97 | lightning_model, getattr(model, sub_module), f"/{sub_module}" 98 | ) 99 | 100 | def log_sub_module_grad_norm( 101 | self, lightning_model: LightningModule, model: nn.Module, path: str 102 | ) -> None: 103 | grad_norm_val = grad_norm(model.parameters(), self.norm_type) 104 | if grad_norm_val is None: 105 | return 106 | 107 | on_step = self.logging_interval == "step" 108 | lightning_model.log( 109 | f"train{path}/grad_norm", 110 | grad_norm_val, 111 | on_step=on_step, 112 | on_epoch=not on_step, 113 | ) 114 | -------------------------------------------------------------------------------- /fish_speech/configs/base.yaml: -------------------------------------------------------------------------------- 1 | # Base configuration for training a model 2 | paths: 3 | run_dir: results/${project} 4 | ckpt_dir: ${paths.run_dir}/checkpoints 5 | 6 | hydra: 7 | run: 8 | dir: ${paths.run_dir} 9 | 10 | # Lightning Trainer 11 | trainer: 12 | _target_: lightning.pytorch.trainer.Trainer 13 | 14 | default_root_dir: ${paths.run_dir} 15 | accelerator: gpu 16 | num_nodes: 1 17 | devices: auto 18 | strategy: 19 | _target_: lightning.pytorch.strategies.DDPStrategy 20 | process_group_backend: nccl # This should be override when training on windows 21 | 22 | precision: bf16-mixed 23 | 24 | # disable validation by epoch end 25 | check_val_every_n_epoch: null 26 | val_check_interval: 5000 27 | max_steps: 100_000 28 | 29 | # Use torch.backends.cudnn.benchmark to speed up training 30 | benchmark: true 31 | 32 | # Callbacks 33 | callbacks: 34 | model_checkpoint: 35 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 36 | dirpath: ${paths.ckpt_dir} 37 | filename: "step_{step:09d}" 38 | save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt 39 | save_top_k: 5 # save 5 latest checkpoints 40 | monitor: step # use step to monitor checkpoints 41 | mode: max # save the latest checkpoint with the highest global_step 42 | every_n_epochs: null # don't save checkpoints by epoch end 43 | every_n_train_steps: 5000 # save checkpoints every 5000 steps 44 | auto_insert_metric_name: false 45 | 46 | model_summary: 47 | _target_: lightning.pytorch.callbacks.ModelSummary 48 | max_depth: 2 # the maximum depth of layer nesting that the summary will include 49 | 50 | learning_rate_monitor: 51 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 52 | logging_interval: step 53 | log_momentum: false 54 | 55 | grad_norm_monitor: 56 | _target_: fish_speech.callbacks.GradNormMonitor 57 | norm_type: 2 58 | logging_interval: step 59 | 60 | # Logger 61 | logger: 62 | tensorboard: 63 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 64 | save_dir: "${paths.run_dir}/tensorboard/" 65 | name: null 66 | log_graph: false 67 | default_hp_metric: true 68 | prefix: "" 69 | 70 | # wandb: 71 | # _target_: lightning.pytorch.loggers.wandb.WandbLogger 72 | # # name: "" # name of the run (normally generated by wandb) 73 | # save_dir: "${paths.run_dir}" 74 | # offline: False 75 | # id: null # pass correct id to resume experiment! 
76 | # anonymous: null # enable anonymous logging 77 | # project: "fish-speech" 78 | # log_model: False # upload lightning ckpts 79 | # prefix: "" # a string to put at the beginning of metric keys 80 | # # entity: "" # set to name of your wandb team 81 | # group: "" 82 | # tags: ["vq", "hq", "finetune"] 83 | # job_type: "" 84 | 85 | # Loop 86 | train: true 87 | test: false 88 | -------------------------------------------------------------------------------- /fish_speech/configs/firefly_gan_vq.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture 2 | spec_transform: 3 | _target_: fish_speech.utils.spectrogram.LogMelSpectrogram 4 | sample_rate: 44100 5 | n_mels: 160 6 | n_fft: 2048 7 | hop_length: 512 8 | win_length: 2048 9 | backbone: 10 | _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder 11 | input_channels: 160 12 | depths: [3, 3, 9, 3] 13 | dims: [128, 256, 384, 512] 14 | drop_path_rate: 0.2 15 | kernel_size: 7 16 | head: 17 | _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator 18 | hop_length: 512 19 | upsample_rates: [8, 8, 2, 2, 2] # aka. strides 20 | upsample_kernel_sizes: [16, 16, 4, 4, 4] 21 | resblock_kernel_sizes: [3, 7, 11] 22 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 23 | num_mels: 512 24 | upsample_initial_channel: 512 25 | pre_conv_kernel_size: 13 26 | post_conv_kernel_size: 13 27 | quantizer: 28 | _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize 29 | input_dim: 512 30 | n_groups: 8 31 | n_codebooks: 1 32 | levels: [8, 5, 5, 5] 33 | downsample_factor: [2, 2] 34 | -------------------------------------------------------------------------------- /fish_speech/configs/lora/r_8_alpha_16.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_speech.models.text2semantic.lora.LoraConfig 2 | r: 8 3 | lora_alpha: 16 4 | lora_dropout: 0.01 5 | -------------------------------------------------------------------------------- /fish_speech/configs/text2semantic_finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - _self_ 4 | 5 | project: text2semantic_finetune_dual_ar 6 | max_length: 4096 7 | pretrained_ckpt_path: checkpoints/fish-speech-1.5 8 | 9 | # Lightning Trainer 10 | trainer: 11 | accumulate_grad_batches: 1 12 | gradient_clip_val: 1.0 13 | gradient_clip_algorithm: "norm" 14 | max_steps: 10000 15 | precision: bf16-true 16 | limit_val_batches: 10 17 | val_check_interval: 100 18 | # strategy: 19 | # find_unused_parameters: true 20 | # static_graph: true 21 | 22 | # Dataset Configuration 23 | tokenizer: 24 | _target_: fish_speech.tokenizer.FishTokenizer 25 | model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken 26 | 27 | # Dataset Configuration 28 | train_dataset: 29 | _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset 30 | proto_files: 31 | - data/protos 32 | tokenizer: ${tokenizer} 33 | causal: true 34 | max_length: ${max_length} 35 | use_speaker: false 36 | interactive_prob: 0.7 37 | 38 | val_dataset: 39 | _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset 40 | proto_files: 41 | - data/protos 42 | tokenizer: ${tokenizer} 43 | causal: true 44 | max_length: ${max_length} 45 | use_speaker: false 46 | interactive_prob: 0.7 47 | 48 | data: 49 | _target_: fish_speech.datasets.semantic.SemanticDataModule 50 | 
train_dataset: ${train_dataset} 51 | val_dataset: ${val_dataset} 52 | num_workers: 4 53 | batch_size: 4 54 | tokenizer: ${tokenizer} 55 | max_length: ${max_length} 56 | 57 | # Model Configuration 58 | model: 59 | _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic 60 | model: 61 | _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained 62 | path: ${pretrained_ckpt_path} 63 | load_weights: true 64 | max_length: ${max_length} 65 | lora_config: null 66 | 67 | optimizer: 68 | _target_: torch.optim.AdamW 69 | _partial_: true 70 | lr: 1e-4 71 | weight_decay: 0 72 | betas: [0.9, 0.95] 73 | eps: 1e-5 74 | 75 | lr_scheduler: 76 | _target_: torch.optim.lr_scheduler.LambdaLR 77 | _partial_: true 78 | lr_lambda: 79 | _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda 80 | _partial_: true 81 | num_warmup_steps: 10 82 | 83 | # Callbacks 84 | callbacks: 85 | model_checkpoint: 86 | every_n_train_steps: ${trainer.val_check_interval} 87 | -------------------------------------------------------------------------------- /fish_speech/datasets/concat_repeat.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | from typing import Iterable 4 | 5 | from torch.utils.data import Dataset, IterableDataset 6 | 7 | 8 | class ConcatRepeatDataset(Dataset): 9 | datasets: list[Dataset] 10 | cumulative_sizes: list[int] 11 | repeats: list[int] 12 | 13 | @staticmethod 14 | def cumsum(sequence, repeats): 15 | r, s = [], 0 16 | for dataset, repeat in zip(sequence, repeats): 17 | l = len(dataset) * repeat 18 | r.append(l + s) 19 | s += l 20 | return r 21 | 22 | def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): 23 | super().__init__() 24 | 25 | self.datasets = list(datasets) 26 | self.repeats = repeats 27 | 28 | assert len(self.datasets) > 0, "datasets should not be an empty iterable" 29 | assert len(self.datasets) == len( 30 | repeats 31 | ), "datasets and repeats should have the same length" 32 | 33 | for d in self.datasets: 34 | assert not isinstance( 35 | d, IterableDataset 36 | ), "ConcatRepeatDataset does not support IterableDataset" 37 | 38 | self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) 39 | 40 | def __len__(self): 41 | return self.cumulative_sizes[-1] 42 | 43 | def __getitem__(self, idx): 44 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 45 | 46 | if dataset_idx == 0: 47 | sample_idx = idx 48 | else: 49 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 50 | 51 | dataset = self.datasets[dataset_idx] 52 | 53 | return dataset[sample_idx % len(dataset)] 54 | -------------------------------------------------------------------------------- /fish_speech/datasets/protos/text-data.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package text_data; 4 | 5 | message Semantics { 6 | repeated uint32 values = 1; 7 | } 8 | 9 | message Sentence { 10 | repeated string texts = 1; 11 | repeated Semantics semantics = 3; 12 | } 13 | 14 | message TextData { 15 | string source = 1; 16 | string name = 2; 17 | repeated Sentence sentences = 4; 18 | } 19 | 20 | message SampledData { 21 | string source = 1; 22 | string name = 2; 23 | repeated Sentence samples = 3; 24 | } 25 | -------------------------------------------------------------------------------- /fish_speech/datasets/protos/text_data_pb2.py: -------------------------------------------------------------------------------- 1 
| # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: text-data.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 17 | b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' 18 | ) 19 | 20 | _globals = globals() 21 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 22 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) 23 | if _descriptor._USE_C_DESCRIPTORS == False: 24 | DESCRIPTOR._options = None 25 | _globals["_SEMANTICS"]._serialized_start = 30 26 | _globals["_SEMANTICS"]._serialized_end = 57 27 | _globals["_SENTENCE"]._serialized_start = 59 28 | _globals["_SENTENCE"]._serialized_end = 125 29 | _globals["_TEXTDATA"]._serialized_start = 127 30 | _globals["_TEXTDATA"]._serialized_end = 207 31 | _globals["_SAMPLEDDATA"]._serialized_start = 209 32 | _globals["_SAMPLEDDATA"]._serialized_end = 290 33 | # @@protoc_insertion_point(module_scope) 34 | -------------------------------------------------------------------------------- /fish_speech/datasets/protos/text_data_stream.py: -------------------------------------------------------------------------------- 1 | import struct 2 | 3 | from .text_data_pb2 import TextData 4 | 5 | 6 | def read_pb_stream(f): 7 | while True: 8 | buf = f.read(4) 9 | if len(buf) == 0: 10 | break 11 | size = struct.unpack("I", buf)[0] 12 | buf = f.read(size) 13 | text_data = TextData() 14 | text_data.ParseFromString(buf) 15 | yield text_data 16 | 17 | 18 | def write_pb_stream(f, text_data): 19 | buf = text_data.SerializeToString() 20 | f.write(struct.pack("I", len(buf))) 21 | f.write(buf) 22 | 23 | 24 | def pack_pb_stream(text_data): 25 | buf = text_data.SerializeToString() 26 | return struct.pack("I", len(buf)) + buf 27 | 28 | 29 | def split_pb_stream(f): 30 | while True: 31 | head = f.read(4) 32 | if len(head) == 0: 33 | break 34 | size = struct.unpack("I", head)[0] 35 | buf = f.read(size) 36 | yield head + buf 37 | -------------------------------------------------------------------------------- /fish_speech/datasets/vqgan.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import librosa 6 | import numpy as np 7 | import torch 8 | from lightning import LightningDataModule 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from fish_speech.utils import RankedLogger 12 | 13 | logger = RankedLogger(__name__, rank_zero_only=False) 14 | 15 | 16 | class VQGANDataset(Dataset): 17 | def __init__( 18 | self, 19 | filelist: str, 20 | 
sample_rate: int = 32000, 21 | hop_length: int = 640, 22 | slice_frames: Optional[int] = None, 23 | ): 24 | super().__init__() 25 | 26 | filelist = Path(filelist) 27 | root = filelist.parent 28 | 29 | self.files = [ 30 | root / line.strip() 31 | for line in filelist.read_text(encoding="utf-8").splitlines() 32 | if line.strip() 33 | ] 34 | self.sample_rate = sample_rate 35 | self.hop_length = hop_length 36 | self.slice_frames = slice_frames 37 | 38 | def __len__(self): 39 | return len(self.files) 40 | 41 | def get_item(self, idx): 42 | file = self.files[idx] 43 | 44 | audio, _ = librosa.load(file, sr=self.sample_rate, mono=True) 45 | 46 | # Slice audio and features 47 | if ( 48 | self.slice_frames is not None 49 | and audio.shape[0] > self.slice_frames * self.hop_length 50 | ): 51 | start = np.random.randint( 52 | 0, audio.shape[0] - self.slice_frames * self.hop_length 53 | ) 54 | audio = audio[start : start + self.slice_frames * self.hop_length] 55 | 56 | if len(audio) == 0: 57 | return None 58 | 59 | max_value = np.abs(audio).max() 60 | if max_value > 1.0: 61 | audio = audio / max_value 62 | 63 | return { 64 | "audio": torch.from_numpy(audio), 65 | } 66 | 67 | def __getitem__(self, idx): 68 | try: 69 | return self.get_item(idx) 70 | except Exception as e: 71 | import traceback 72 | 73 | traceback.print_exc() 74 | logger.error(f"Error loading {self.files[idx]}: {e}") 75 | return None 76 | 77 | 78 | @dataclass 79 | class VQGANCollator: 80 | def __call__(self, batch): 81 | batch = [x for x in batch if x is not None] 82 | 83 | audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) 84 | audio_maxlen = audio_lengths.max() 85 | 86 | # Rounds up to nearest multiple of 2 (audio_lengths) 87 | audios = [] 88 | for x in batch: 89 | audios.append( 90 | torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) 91 | ) 92 | 93 | return { 94 | "audios": torch.stack(audios), 95 | "audio_lengths": audio_lengths, 96 | } 97 | 98 | 99 | class VQGANDataModule(LightningDataModule): 100 | def __init__( 101 | self, 102 | train_dataset: VQGANDataset, 103 | val_dataset: VQGANDataset, 104 | batch_size: int = 32, 105 | num_workers: int = 4, 106 | val_batch_size: Optional[int] = None, 107 | ): 108 | super().__init__() 109 | 110 | self.train_dataset = train_dataset 111 | self.val_dataset = val_dataset 112 | self.batch_size = batch_size 113 | self.val_batch_size = val_batch_size or batch_size 114 | self.num_workers = num_workers 115 | 116 | def train_dataloader(self): 117 | return DataLoader( 118 | self.train_dataset, 119 | batch_size=self.batch_size, 120 | collate_fn=VQGANCollator(), 121 | num_workers=self.num_workers, 122 | shuffle=True, 123 | persistent_workers=True, 124 | ) 125 | 126 | def val_dataloader(self): 127 | return DataLoader( 128 | self.val_dataset, 129 | batch_size=self.val_batch_size, 130 | collate_fn=VQGANCollator(), 131 | num_workers=self.num_workers, 132 | persistent_workers=True, 133 | ) 134 | 135 | 136 | if __name__ == "__main__": 137 | dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt") 138 | dataloader = DataLoader( 139 | dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator() 140 | ) 141 | 142 | for batch in dataloader: 143 | print(batch["audios"].shape) 144 | print(batch["features"].shape) 145 | print(batch["audio_lengths"]) 146 | print(batch["feature_lengths"]) 147 | break 148 | -------------------------------------------------------------------------------- /fish_speech/i18n/README.md: 
-------------------------------------------------------------------------------- 1 | ## i18n Folder Attribution 2 | 3 | The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: 4 | 5 | ### fish_speech/i18n/core.py 6 | 7 | **Related code from RVC:** 8 | [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) 9 | 10 | **Initial commit:** 11 | add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) 12 | 13 | **Initial author:** 14 | [@L4Ph](https://github.com/L4Ph) 15 | 16 | ### fish_speech/i18n/scan.py 17 | 18 | **Related code from RVC:** 19 | [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) 20 | 21 | **Initial commit:** 22 | File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) 23 | 24 | **Initial author:** 25 | [@towzeur](https://github.com/towzeur) 26 | 27 | We appreciate the contributions of the RVC project and its authors. 28 | -------------------------------------------------------------------------------- /fish_speech/i18n/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import i18n 2 | 3 | __all__ = ["i18n"] 4 | -------------------------------------------------------------------------------- /fish_speech/i18n/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | import locale 3 | from pathlib import Path 4 | 5 | I18N_FILE_PATH = Path(__file__).parent / "locale" 6 | DEFAULT_LANGUAGE = "en_US" 7 | 8 | 9 | def load_language_list(language): 10 | with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: 11 | language_list = json.load(f) 12 | 13 | return language_list 14 | 15 | 16 | class I18nAuto: 17 | def __init__(self): 18 | i18n_file = Path(".locale") 19 | 20 | if i18n_file.exists(): 21 | with open(i18n_file, "r", encoding="utf-8") as f: 22 | language = f.read().strip() 23 | else: 24 | # getlocale can't identify the system's language ((None, None)) 25 | language = locale.getdefaultlocale()[0] 26 | 27 | if (I18N_FILE_PATH / f"{language}.json").exists() is False: 28 | language = DEFAULT_LANGUAGE 29 | 30 | self.language = language 31 | self.language_map = load_language_list(language) 32 | 33 | def __call__(self, key): 34 | return self.language_map.get(key, key) 35 | 36 | def __repr__(self): 37 | return "Use Language: " + self.language 38 | 39 | 40 | i18n = I18nAuto() 41 | -------------------------------------------------------------------------------- /fish_speech/i18n/scan.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import glob 3 | import json 4 | from collections import OrderedDict 5 | from pathlib import Path 6 | 7 | from loguru import logger 8 | 9 | from .core import 
DEFAULT_LANGUAGE, I18N_FILE_PATH 10 | 11 | 12 | def extract_i18n_strings(node): 13 | i18n_strings = [] 14 | 15 | if ( 16 | isinstance(node, ast.Call) 17 | and isinstance(node.func, ast.Name) 18 | and node.func.id == "i18n" 19 | ): 20 | for arg in node.args: 21 | if isinstance(arg, ast.Str): 22 | i18n_strings.append(arg.s) 23 | 24 | for child_node in ast.iter_child_nodes(node): 25 | i18n_strings.extend(extract_i18n_strings(child_node)) 26 | 27 | return i18n_strings 28 | 29 | 30 | # scan the directory for all .py files (recursively) 31 | # for each file, parse the code into an AST 32 | # for each AST, extract the i18n strings 33 | 34 | strings = [] 35 | folders = ["fish_speech", "tools"] 36 | # for filename in glob.iglob("**/*.py", recursive=True): 37 | for folder in folders: 38 | for f in Path(folder).rglob("*.py"): 39 | code = f.read_text(encoding="utf-8") 40 | if "i18n(" in code: 41 | tree = ast.parse(code) 42 | i18n_strings = extract_i18n_strings(tree) 43 | logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") 44 | strings.extend(i18n_strings) 45 | 46 | code_keys = set(strings) 47 | logger.info(f"Total unique: {len(code_keys)}") 48 | 49 | 50 | standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" 51 | with open(standard_file, "r", encoding="utf-8") as f: 52 | standard_data = json.load(f, object_pairs_hook=OrderedDict) 53 | standard_keys = set(standard_data.keys()) 54 | 55 | # Define the standard file name 56 | unused_keys = standard_keys - code_keys 57 | logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}") 58 | for unused_key in unused_keys: 59 | logger.info(f"\t{unused_key}") 60 | 61 | missing_keys = code_keys - standard_keys 62 | logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") 63 | for missing_key in missing_keys: 64 | logger.info(f"\t{missing_key}") 65 | 66 | code_keys_dict = OrderedDict() 67 | for s in strings: 68 | code_keys_dict[s] = s 69 | 70 | # write back 71 | with open(standard_file, "w", encoding="utf-8") as f: 72 | json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) 73 | f.write("\n") 74 | 75 | logger.info(f"Updated {standard_file}") 76 | 77 | 78 | # Define the standard file name 79 | standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" 80 | 81 | # Find all JSON files in the directory 82 | dir_path = I18N_FILE_PATH 83 | languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] 84 | 85 | # Load the standard file 86 | with open(standard_file, "r", encoding="utf-8") as f: 87 | standard_data = json.load(f, object_pairs_hook=OrderedDict) 88 | 89 | # Loop through each language file 90 | for lang_file in languages: 91 | # Load the language file 92 | with open(lang_file, "r", encoding="utf-8") as f: 93 | lang_data = json.load(f, object_pairs_hook=OrderedDict) 94 | 95 | # Find the difference between the language file and the standard file 96 | diff = set(standard_data.keys()) - set(lang_data.keys()) 97 | 98 | miss = set(lang_data.keys()) - set(standard_data.keys()) 99 | 100 | # Add any missing keys to the language file 101 | for key in diff: 102 | lang_data[key] = "#!" 
+ key 103 | logger.info(f"Added missing key: {key} to {lang_file}") 104 | 105 | # Del any extra keys to the language file 106 | for key in miss: 107 | del lang_data[key] 108 | logger.info(f"Del extra key: {key} from {lang_file}") 109 | 110 | # Sort the keys of the language file to match the order of the standard file 111 | lang_data = OrderedDict( 112 | sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) 113 | ) 114 | 115 | # Save the updated language file 116 | with open(lang_file, "w", encoding="utf-8") as f: 117 | json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) 118 | f.write("\n") 119 | 120 | logger.info(f"Updated {lang_file}") 121 | 122 | logger.info("Done") 123 | -------------------------------------------------------------------------------- /fish_speech/inference_engine/reference_loader.py: -------------------------------------------------------------------------------- 1 | import io 2 | from hashlib import sha256 3 | from pathlib import Path 4 | from typing import Callable, Literal, Tuple 5 | 6 | import torch 7 | import torchaudio 8 | from loguru import logger 9 | 10 | from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture 11 | from fish_speech.utils.file import ( 12 | AUDIO_EXTENSIONS, 13 | audio_to_bytes, 14 | list_files, 15 | read_ref_text, 16 | ) 17 | from fish_speech.utils.schema import ServeReferenceAudio 18 | 19 | 20 | class ReferenceLoader: 21 | 22 | def __init__(self) -> None: 23 | """ 24 | Component of the TTSInferenceEngine class. 25 | Loads and manages the cache for the reference audio and text. 26 | """ 27 | self.ref_by_id: dict = {} 28 | self.ref_by_hash: dict = {} 29 | 30 | # Make Pylance happy (attribut/method not defined...) 31 | self.decoder_model: FireflyArchitecture 32 | self.encode_reference: Callable 33 | 34 | # Define the torchaudio backend 35 | backends = torchaudio.list_audio_backends() 36 | if "ffmpeg" in backends: 37 | self.backend = "ffmpeg" 38 | else: 39 | self.backend = "soundfile" 40 | 41 | def load_by_id( 42 | self, 43 | id: str, 44 | use_cache: Literal["on", "off"], 45 | ) -> Tuple: 46 | 47 | # Load the references audio and text by id 48 | ref_folder = Path("references") / id 49 | ref_folder.mkdir(parents=True, exist_ok=True) 50 | ref_audios = list_files( 51 | ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False 52 | ) 53 | 54 | if use_cache == "off" or id not in self.ref_by_id: 55 | # If the references are not already loaded, encode them 56 | prompt_tokens = [ 57 | self.encode_reference( 58 | # decoder_model=self.decoder_model, 59 | reference_audio=audio_to_bytes(str(ref_audio)), 60 | enable_reference_audio=True, 61 | ) 62 | for ref_audio in ref_audios 63 | ] 64 | prompt_texts = [ 65 | read_ref_text(str(ref_audio.with_suffix(".lab"))) 66 | for ref_audio in ref_audios 67 | ] 68 | self.ref_by_id[id] = (prompt_tokens, prompt_texts) 69 | 70 | else: 71 | # Reuse already encoded references 72 | logger.info("Use same references") 73 | prompt_tokens, prompt_texts = self.ref_by_id[id] 74 | 75 | return prompt_tokens, prompt_texts 76 | 77 | def load_by_hash( 78 | self, 79 | references: list[ServeReferenceAudio], 80 | use_cache: Literal["on", "off"], 81 | ) -> Tuple: 82 | 83 | # Load the references audio and text by hash 84 | audio_hashes = [sha256(ref.audio).hexdigest() for ref in references] 85 | 86 | cache_used = False 87 | prompt_tokens, prompt_texts = [], [] 88 | for i, ref in enumerate(references): 89 | if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash: 90 | # If the 
references are not already loaded, encode them 91 | prompt_tokens.append( 92 | self.encode_reference( 93 | reference_audio=ref.audio, 94 | enable_reference_audio=True, 95 | ) 96 | ) 97 | prompt_texts.append(ref.text) 98 | self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts) 99 | 100 | else: 101 | # Reuse already encoded references 102 | prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]] 103 | cache_used = True 104 | 105 | if cache_used: 106 | logger.info("Use same references") 107 | 108 | return prompt_tokens, prompt_texts 109 | 110 | def load_audio(self, reference_audio, sr): 111 | """ 112 | Load the audio data from a file or bytes. 113 | """ 114 | if len(reference_audio) > 255 or not Path(reference_audio).exists(): 115 | audio_data = reference_audio 116 | reference_audio = io.BytesIO(audio_data) 117 | 118 | waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend) 119 | 120 | if waveform.shape[0] > 1: 121 | waveform = torch.mean(waveform, dim=0, keepdim=True) 122 | 123 | if original_sr != sr: 124 | resampler = torchaudio.transforms.Resample( 125 | orig_freq=original_sr, new_freq=sr 126 | ) 127 | waveform = resampler(waveform) 128 | 129 | audio = waveform.squeeze().numpy() 130 | return audio 131 | -------------------------------------------------------------------------------- /fish_speech/inference_engine/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import wave 3 | from dataclasses import dataclass 4 | from typing import Literal, Optional, Tuple 5 | 6 | import numpy as np 7 | 8 | 9 | @dataclass 10 | class InferenceResult: 11 | code: Literal["header", "segment", "error", "final"] 12 | audio: Optional[Tuple[int, np.ndarray]] 13 | error: Optional[Exception] 14 | 15 | 16 | def wav_chunk_header( 17 | sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1 18 | ) -> bytes: 19 | buffer = io.BytesIO() 20 | 21 | with wave.open(buffer, "wb") as wav_file: 22 | wav_file.setnchannels(channels) 23 | wav_file.setsampwidth(bit_depth // 8) 24 | wav_file.setframerate(sample_rate) 25 | 26 | wav_header_bytes = buffer.getvalue() 27 | buffer.close() 28 | 29 | return wav_header_bytes 30 | -------------------------------------------------------------------------------- /fish_speech/inference_engine/vq_manager.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from loguru import logger 5 | 6 | from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture 7 | 8 | 9 | class VQManager: 10 | 11 | def __init__(self): 12 | # Make Pylance happy (attribut/method not defined...) 
13 | self.decoder_model: FireflyArchitecture 14 | self.load_audio: Callable 15 | 16 | def decode_vq_tokens(self, codes): 17 | feature_lengths = torch.tensor( 18 | [codes.shape[1]], device=self.decoder_model.device 19 | ) 20 | logger.info(f"VQ features: {codes.shape}") 21 | 22 | if isinstance(self.decoder_model, FireflyArchitecture): 23 | return self.decoder_model.decode( 24 | indices=codes[None], 25 | feature_lengths=feature_lengths, 26 | )[0].squeeze() 27 | 28 | raise ValueError(f"Unknown model type: {type(self.decoder_model)}") 29 | 30 | def encode_reference(self, reference_audio, enable_reference_audio): 31 | if enable_reference_audio and reference_audio is not None: 32 | # Load audios, and prepare basic info here 33 | reference_audio_content = self.load_audio( 34 | reference_audio, self.decoder_model.spec_transform.sample_rate 35 | ) 36 | 37 | audios = torch.from_numpy(reference_audio_content).to( 38 | self.decoder_model.device 39 | )[None, None, :] 40 | audio_lengths = torch.tensor( 41 | [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long 42 | ) 43 | logger.info( 44 | f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds" 45 | ) 46 | 47 | # VQ Encoder 48 | if isinstance(self.decoder_model, FireflyArchitecture): 49 | prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0] 50 | logger.info(f"Encoded prompt: {prompt_tokens.shape}") 51 | else: 52 | raise ValueError(f"Unknown model type: {type(self.decoder_model)}") 53 | else: 54 | prompt_tokens = None 55 | logger.info("No reference audio provided") 56 | 57 | return prompt_tokens 58 | -------------------------------------------------------------------------------- /fish_speech/models/text2semantic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/fish_speech/models/text2semantic/__init__.py -------------------------------------------------------------------------------- /fish_speech/models/text2semantic/lora.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import loralib as lora 4 | 5 | 6 | @dataclass 7 | class LoraConfig: 8 | r: int 9 | lora_alpha: float 10 | lora_dropout: float = 0.0 11 | 12 | 13 | def setup_lora(model, lora_config): 14 | # Replace the embedding layer with a LoRA layer 15 | model.embeddings = lora.Embedding( 16 | num_embeddings=model.embeddings.num_embeddings, 17 | embedding_dim=model.embeddings.embedding_dim, 18 | padding_idx=model.embeddings.padding_idx, 19 | r=lora_config.r, 20 | lora_alpha=lora_config.lora_alpha, 21 | ) 22 | 23 | model.codebook_embeddings = lora.Embedding( 24 | num_embeddings=model.codebook_embeddings.num_embeddings, 25 | embedding_dim=model.codebook_embeddings.embedding_dim, 26 | padding_idx=model.codebook_embeddings.padding_idx, 27 | r=lora_config.r, 28 | lora_alpha=lora_config.lora_alpha, 29 | ) 30 | 31 | # Replace output layer with a LoRA layer 32 | linears = [(model, "output")] 33 | 34 | # Replace all linear layers with LoRA layers 35 | for layer in model.layers: 36 | linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) 37 | linears.extend( 38 | [ 39 | (layer.feed_forward, "w1"), 40 | (layer.feed_forward, "w2"), 41 | (layer.feed_forward, "w3"), 42 | ] 43 | ) 44 | 45 | if hasattr(model, "fast_layers"): 46 | model.fast_embeddings = lora.Embedding( 47 | 
num_embeddings=model.fast_embeddings.num_embeddings, 48 | embedding_dim=model.fast_embeddings.embedding_dim, 49 | padding_idx=model.fast_embeddings.padding_idx, 50 | r=lora_config.r, 51 | lora_alpha=lora_config.lora_alpha, 52 | ) 53 | 54 | # Dual-AR model 55 | linears.append((model, "fast_output")) 56 | 57 | for layer in model.fast_layers: 58 | linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) 59 | linears.extend( 60 | [ 61 | (layer.feed_forward, "w1"), 62 | (layer.feed_forward, "w2"), 63 | (layer.feed_forward, "w3"), 64 | ] 65 | ) 66 | 67 | for module, layer in linears: 68 | updated_linear = lora.Linear( 69 | in_features=getattr(module, layer).in_features, 70 | out_features=getattr(module, layer).out_features, 71 | bias=getattr(module, layer).bias, 72 | r=lora_config.r, 73 | lora_alpha=lora_config.lora_alpha, 74 | lora_dropout=lora_config.lora_dropout, 75 | ) 76 | setattr(module, layer, updated_linear) 77 | 78 | # Mark only the LoRA layers as trainable 79 | lora.mark_only_lora_as_trainable(model, bias="none") 80 | 81 | 82 | def get_merged_state_dict(model): 83 | # This line will merge the state dict of the model and the LoRA parameters 84 | model.eval() 85 | 86 | # Then we need to remove the LoRA parameters from the state dict 87 | state_dict = model.state_dict() 88 | for name in list(state_dict.keys()): 89 | if "lora" in name: 90 | state_dict.pop(name) 91 | 92 | return state_dict 93 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fishaudio/fish-speech/58046eaa1a4cefb0c8cc3a3a667b34186ea02dde/fish_speech/models/vqgan/__init__.py -------------------------------------------------------------------------------- /fish_speech/models/vqgan/inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | import hydra 5 | import numpy as np 6 | import pyrootutils 7 | import soundfile as sf 8 | import torch 9 | import torchaudio 10 | from hydra import compose, initialize 11 | from hydra.utils import instantiate 12 | from loguru import logger 13 | from omegaconf import OmegaConf 14 | 15 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 16 | 17 | from fish_speech.utils.file import AUDIO_EXTENSIONS 18 | 19 | # register eval resolver 20 | OmegaConf.register_new_resolver("eval", eval) 21 | 22 | 23 | def load_model(config_name, checkpoint_path, device="cuda"): 24 | hydra.core.global_hydra.GlobalHydra.instance().clear() 25 | with initialize(version_base="1.3", config_path="../../configs"): 26 | cfg = compose(config_name=config_name) 27 | 28 | model = instantiate(cfg) 29 | state_dict = torch.load( 30 | checkpoint_path, map_location=device, mmap=True, weights_only=True 31 | ) 32 | if "state_dict" in state_dict: 33 | state_dict = state_dict["state_dict"] 34 | 35 | if any("generator" in k for k in state_dict): 36 | state_dict = { 37 | k.replace("generator.", ""): v 38 | for k, v in state_dict.items() 39 | if "generator." 
in k 40 | } 41 | 42 | result = model.load_state_dict(state_dict, strict=False, assign=True) 43 | model.eval() 44 | model.to(device) 45 | 46 | logger.info(f"Loaded model: {result}") 47 | return model 48 | 49 | 50 | @torch.no_grad() 51 | @click.command() 52 | @click.option( 53 | "--input-path", 54 | "-i", 55 | default="test.wav", 56 | type=click.Path(exists=True, path_type=Path), 57 | ) 58 | @click.option( 59 | "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path) 60 | ) 61 | @click.option("--config-name", default="firefly_gan_vq") 62 | @click.option( 63 | "--checkpoint-path", 64 | default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", 65 | ) 66 | @click.option( 67 | "--device", 68 | "-d", 69 | default="cuda", 70 | ) 71 | def main(input_path, output_path, config_name, checkpoint_path, device): 72 | model = load_model(config_name, checkpoint_path, device=device) 73 | 74 | if input_path.suffix in AUDIO_EXTENSIONS: 75 | logger.info(f"Processing in-place reconstruction of {input_path}") 76 | 77 | # Load audio 78 | audio, sr = torchaudio.load(str(input_path)) 79 | if audio.shape[0] > 1: 80 | audio = audio.mean(0, keepdim=True) 81 | audio = torchaudio.functional.resample( 82 | audio, sr, model.spec_transform.sample_rate 83 | ) 84 | 85 | audios = audio[None].to(device) 86 | logger.info( 87 | f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds" 88 | ) 89 | 90 | # VQ Encoder 91 | audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long) 92 | indices = model.encode(audios, audio_lengths)[0][0] 93 | 94 | logger.info(f"Generated indices of shape {indices.shape}") 95 | 96 | # Save indices 97 | np.save(output_path.with_suffix(".npy"), indices.cpu().numpy()) 98 | elif input_path.suffix == ".npy": 99 | logger.info(f"Processing precomputed indices from {input_path}") 100 | indices = np.load(input_path) 101 | indices = torch.from_numpy(indices).to(device).long() 102 | assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}" 103 | else: 104 | raise ValueError(f"Unknown input type: {input_path}") 105 | 106 | # Restore 107 | feature_lengths = torch.tensor([indices.shape[1]], device=device) 108 | fake_audios, _ = model.decode( 109 | indices=indices[None], feature_lengths=feature_lengths 110 | ) 111 | audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate 112 | 113 | logger.info( 114 | f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}" 115 | ) 116 | 117 | # Save audio 118 | fake_audio = fake_audios[0, 0].float().cpu().numpy() 119 | sf.write(output_path, fake_audio, model.spec_transform.sample_rate) 120 | logger.info(f"Saved audio to {output_path}") 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/modules/fsq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | from vector_quantize_pytorch import GroupedResidualFSQ 8 | 9 | from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet 10 | 11 | 12 | @dataclass 13 | class FSQResult: 14 | z: torch.Tensor 15 | codes: torch.Tensor 16 | latents: torch.Tensor 17 | 18 | 19 | class 
DownsampleFiniteScalarQuantize(nn.Module): 20 | def __init__( 21 | self, 22 | input_dim: int = 512, 23 | n_codebooks: int = 9, 24 | n_groups: int = 1, 25 | levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10 26 | downsample_factor: tuple[int] = (2, 2), 27 | downsample_dims: tuple[int] | None = None, 28 | ): 29 | super().__init__() 30 | 31 | if downsample_dims is None: 32 | downsample_dims = [input_dim for _ in range(len(downsample_factor))] 33 | 34 | all_dims = (input_dim,) + tuple(downsample_dims) 35 | 36 | self.residual_fsq = GroupedResidualFSQ( 37 | dim=all_dims[-1], 38 | levels=levels, 39 | num_quantizers=n_codebooks, 40 | groups=n_groups, 41 | ) 42 | 43 | self.downsample_factor = downsample_factor 44 | self.downsample_dims = downsample_dims 45 | 46 | self.downsample = nn.Sequential( 47 | *[ 48 | nn.Sequential( 49 | FishConvNet( 50 | all_dims[idx], 51 | all_dims[idx + 1], 52 | kernel_size=factor, 53 | stride=factor, 54 | ), 55 | ConvNeXtBlock(dim=all_dims[idx + 1]), 56 | ) 57 | for idx, factor in enumerate(downsample_factor) 58 | ] 59 | ) 60 | 61 | self.upsample = nn.Sequential( 62 | *[ 63 | nn.Sequential( 64 | FishTransConvNet( 65 | all_dims[idx + 1], 66 | all_dims[idx], 67 | kernel_size=factor, 68 | stride=factor, 69 | ), 70 | ConvNeXtBlock(dim=all_dims[idx]), 71 | ) 72 | for idx, factor in reversed(list(enumerate(downsample_factor))) 73 | ] 74 | ) 75 | 76 | self.apply(self._init_weights) 77 | 78 | def _init_weights(self, m): 79 | if isinstance(m, (nn.Conv1d, nn.Linear)): 80 | nn.init.trunc_normal_(m.weight, std=0.02) 81 | nn.init.constant_(m.bias, 0) 82 | 83 | def forward(self, z) -> FSQResult: 84 | original_shape = z.shape 85 | z = self.downsample(z) 86 | quantized, indices = self.residual_fsq(z.mT) 87 | result = FSQResult( 88 | z=quantized.mT, 89 | codes=indices.mT, 90 | latents=z, 91 | ) 92 | result.z = self.upsample(result.z) 93 | 94 | # Pad or crop z to match original shape 95 | diff = original_shape[-1] - result.z.shape[-1] 96 | left = diff // 2 97 | right = diff - left 98 | 99 | if diff > 0: 100 | result.z = F.pad(result.z, (left, right)) 101 | elif diff < 0: 102 | result.z = result.z[..., -left:right] 103 | 104 | return result 105 | 106 | def encode(self, z): 107 | z = self.downsample(z) 108 | _, indices = self.residual_fsq(z.mT) 109 | indices = rearrange(indices, "g b l r -> b (g r) l") 110 | return indices 111 | 112 | def decode(self, indices: torch.Tensor): 113 | indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups) 114 | z_q = self.residual_fsq.get_output_from_indices(indices) 115 | z_q = self.upsample(z_q.mT) 116 | return z_q 117 | -------------------------------------------------------------------------------- /fish_speech/models/vqgan/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import torch 3 | from matplotlib import pyplot as plt 4 | 5 | matplotlib.use("Agg") 6 | 7 | 8 | def convert_pad_shape(pad_shape): 9 | l = pad_shape[::-1] 10 | pad_shape = [item for sublist in l for item in sublist] 11 | return pad_shape 12 | 13 | 14 | def sequence_mask(length, max_length=None): 15 | if max_length is None: 16 | max_length = length.max() 17 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 18 | return x.unsqueeze(0) < length.unsqueeze(1) 19 | 20 | 21 | def init_weights(m, mean=0.0, std=0.01): 22 | classname = m.__class__.__name__ 23 | if classname.find("Conv") != -1: 24 | m.weight.data.normal_(mean, std) 25 | 26 | 27 | def get_padding(kernel_size, 
dilation=1): 28 | return int((kernel_size * dilation - dilation) / 2) 29 | 30 | 31 | def plot_mel(data, titles=None): 32 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 33 | 34 | if titles is None: 35 | titles = [None for i in range(len(data))] 36 | 37 | plt.tight_layout() 38 | 39 | for i in range(len(data)): 40 | mel = data[i] 41 | 42 | if isinstance(mel, torch.Tensor): 43 | mel = mel.float().detach().cpu().numpy() 44 | 45 | axes[i][0].imshow(mel, origin="lower") 46 | axes[i][0].set_aspect(2.5, adjustable="box") 47 | axes[i][0].set_ylim(0, mel.shape[0]) 48 | axes[i][0].set_title(titles[i], fontsize="medium") 49 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 50 | axes[i][0].set_anchor("W") 51 | 52 | return fig 53 | 54 | 55 | def slice_segments(x, ids_str, segment_size=4): 56 | ret = torch.zeros_like(x[:, :, :segment_size]) 57 | for i in range(x.size(0)): 58 | idx_str = ids_str[i] 59 | idx_end = idx_str + segment_size 60 | ret[i] = x[i, :, idx_str:idx_end] 61 | 62 | return ret 63 | 64 | 65 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 66 | b, d, t = x.size() 67 | if x_lengths is None: 68 | x_lengths = t 69 | ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) 70 | ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) 71 | ret = slice_segments(x, ids_str, segment_size) 72 | return ret, ids_str 73 | 74 | 75 | @torch.jit.script 76 | def fused_add_tanh_sigmoid_multiply(in_act, n_channels): 77 | n_channels_int = n_channels[0] 78 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 79 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 80 | acts = t_act * s_act 81 | 82 | return acts 83 | 84 | 85 | def avg_with_mask(x, mask): 86 | assert mask.dtype == torch.float, "Mask should be float" 87 | 88 | if mask.ndim == 2: 89 | mask = mask.unsqueeze(1) 90 | 91 | if mask.shape[1] == 1: 92 | mask = mask.expand_as(x) 93 | 94 | return (x * mask).sum() / mask.sum() 95 | -------------------------------------------------------------------------------- /fish_speech/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def get_cosine_schedule_with_warmup_lr_lambda( 5 | current_step: int, 6 | *, 7 | num_warmup_steps: int | float, 8 | num_training_steps: int, 9 | num_cycles: float = 0.5, 10 | final_lr_ratio: float = 0.0, 11 | ): 12 | if 0 < num_warmup_steps < 1: # float mode 13 | num_warmup_steps = int(num_warmup_steps * num_training_steps) 14 | 15 | if current_step < num_warmup_steps: 16 | return float(current_step) / float(max(1, num_warmup_steps)) 17 | 18 | progress = float(current_step - num_warmup_steps) / float( 19 | max(1, num_training_steps - num_warmup_steps) 20 | ) 21 | 22 | return max( 23 | final_lr_ratio, 24 | 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), 25 | ) 26 | 27 | 28 | def get_constant_schedule_with_warmup_lr_lambda( 29 | current_step: int, 30 | *, 31 | num_warmup_steps: int | float, 32 | num_training_steps: int | None = None, 33 | ): 34 | if 0 < num_warmup_steps < 1: # float mode 35 | num_warmup_steps = int(num_warmup_steps * num_training_steps) 36 | 37 | if current_step < num_warmup_steps: 38 | return float(current_step) / float(max(1, num_warmup_steps)) 39 | 40 | return 1.0 41 | -------------------------------------------------------------------------------- /fish_speech/text/__init__.py: -------------------------------------------------------------------------------- 1 | from .clean import clean_text 2 | from 
.spliter import split_text 3 | 4 | __all__ = ["clean_text", "split_text"] 5 | -------------------------------------------------------------------------------- /fish_speech/text/clean.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | SYMBOLS_MAPPING = { 4 | "‘": "'", 5 | "’": "'", 6 | } 7 | 8 | REPLACE_SYMBOL_REGEX = re.compile( 9 | "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) 10 | ) 11 | 12 | 13 | EMOJI_REGEX = re.compile( 14 | "[" 15 | "\U0001f600-\U0001f64f" # emoticons 16 | "\U0001f300-\U0001f5ff" # symbols & pictographs 17 | "\U0001f680-\U0001f6ff" # transport & map symbols 18 | "\U0001f1e0-\U0001f1ff" # flags (iOS) 19 | "]+", 20 | flags=re.UNICODE, 21 | ) 22 | 23 | 24 | def clean_text(text): 25 | # Clean the text 26 | text = text.strip() 27 | 28 | # Replace the symbols in SYMBOLS_MAPPING (curly quotes) with their ASCII counterparts 29 | text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) 30 | 31 | # Remove emojis 32 | text = EMOJI_REGEX.sub(r"", text) 33 | 34 | # Collapse runs of consecutive commas (",," or ",,") into a single comma 35 | text = re.sub(r"[,,]{2,}", lambda m: m.group()[0], text) 36 | 37 | return text 38 | -------------------------------------------------------------------------------- /fish_speech/text/spliter.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | from fish_speech.text.clean import clean_text 5 | 6 | 7 | def utf_8_len(text: str): 8 | return len(text.encode("utf-8")) 9 | 10 | 11 | def break_text(texts, length, splits: set): 12 | for text in texts: 13 | if utf_8_len(text) <= length: 14 | yield text 15 | continue 16 | 17 | curr = "" 18 | for char in text: 19 | curr += char 20 | 21 | if char in splits: 22 | yield curr 23 | curr = "" 24 | 25 | if curr: 26 | yield curr 27 | 28 | 29 | def break_text_by_length(texts, length): 30 | for text in texts: 31 | if utf_8_len(text) <= length: 32 | yield text 33 | continue 34 | 35 | curr = "" 36 | for char in text: 37 | curr += char 38 | 39 | if utf_8_len(curr) >= length: 40 | yield curr 41 | curr = "" 42 | 43 | if curr: 44 | yield curr 45 | 46 | 47 | def add_cleaned(curr, segments): 48 | curr = curr.strip() 49 | if curr and not all(c.isspace() or c in string.punctuation for c in curr): 50 | segments.append(curr) 51 | 52 | 53 | def protect_float(text): 54 | # Turns 3.14 into <3_f_14> to prevent splitting 55 | return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) 56 | 57 | 58 | def unprotect_float(text): 59 | # Turns <3_f_14> into 3.14 60 | return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) 61 | 62 | 63 | def split_text(text, length): 64 | text = clean_text(text) 65 | 66 | # Break the text into pieces with the following rules: 67 | # 1. Split the text at ".", "!", "?" if text is NOT a float 68 | # 2. If the text is longer than length, split at "," 69 | # 3. If the text is still longer than length, split at " " 70 | # 4.
If the text is still longer than length, split at any character to length 71 | 72 | texts = [text] 73 | texts = map(protect_float, texts) 74 | texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) 75 | texts = map(unprotect_float, texts) 76 | texts = break_text(texts, length, {",", ","}) 77 | texts = break_text(texts, length, {" "}) 78 | texts = list(break_text_by_length(texts, length)) 79 | 80 | # Then, merge the texts into segments with length <= length 81 | segments = [] 82 | curr = "" 83 | 84 | for text in texts: 85 | if utf_8_len(curr) + utf_8_len(text) <= length: 86 | curr += text 87 | else: 88 | add_cleaned(curr, segments) 89 | curr = text 90 | 91 | if curr: 92 | add_cleaned(curr, segments) 93 | 94 | return segments 95 | 96 | 97 | if __name__ == "__main__": 98 | # Test the split_text function 99 | 100 | text = "This is a test sentence. This is another test sentence. And a third one." 101 | 102 | assert split_text(text, 50) == [ 103 | "This is a test sentence.", 104 | "This is another test sentence. And a third one.", 105 | ] 106 | assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] 107 | assert split_text(" ", 10) == [] 108 | assert split_text("a", 10) == ["a"] 109 | 110 | text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." 111 | assert split_text(text, 50) == [ 112 | "This is a test sentence with only commas,", 113 | "and no dots, and no exclamation marks,", 114 | "and no question marks, and no newlines.", 115 | ] 116 | 117 | text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." 118 | # First half split at " ", second half split at "," 119 | assert split_text(text, 50) == [ 120 | "This is a test sentence This is a test sentence", 121 | "This is a test sentence. This is a test sentence,", 122 | "This is a test sentence, This is a test sentence.", 123 | ] 124 | 125 | text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" 126 | assert split_text(text, 50) == [ 127 | "这是一段很长的中文文本,", 128 | "而且没有句号,也没有感叹号,", 129 | "也没有问号,也没有换行符.", 130 | ] 131 | -------------------------------------------------------------------------------- /fish_speech/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["USE_LIBUV"] = "0" 4 | import sys 5 | from typing import Optional 6 | 7 | import hydra 8 | import lightning as L 9 | import pyrootutils 10 | import torch 11 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 12 | from lightning.pytorch.loggers import Logger 13 | from lightning.pytorch.strategies import DDPStrategy 14 | from omegaconf import DictConfig, OmegaConf 15 | 16 | os.environ.pop("SLURM_NTASKS", None) 17 | os.environ.pop("SLURM_JOB_NAME", None) 18 | os.environ.pop("SLURM_NTASKS_PER_NODE", None) 19 | 20 | # register eval resolver and root 21 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 22 | 23 | # Allow TF32 on Ampere GPUs 24 | torch.set_float32_matmul_precision("high") 25 | torch.backends.cudnn.allow_tf32 = True 26 | 27 | # register eval resolver 28 | OmegaConf.register_new_resolver("eval", eval) 29 | 30 | import fish_speech.utils as utils 31 | 32 | log = utils.RankedLogger(__name__, rank_zero_only=True) 33 | 34 | 35 | @utils.task_wrapper 36 | def train(cfg: DictConfig) -> tuple[dict, dict]: 37 | """Trains the model. 
Can additionally evaluate on a testset, using best weights obtained during 38 | training. 39 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 40 | failure. Useful for multiruns, saving info about the crash, etc. 41 | Args: 42 | cfg (DictConfig): Configuration composed by Hydra. 43 | Returns: 44 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 45 | """ # noqa: E501 46 | 47 | # set seed for random number generators in pytorch, numpy and python.random 48 | if cfg.get("seed"): 49 | L.seed_everything(cfg.seed, workers=False) 50 | 51 | if cfg.get("deterministic"): 52 | torch.use_deterministic_algorithms(True) 53 | 54 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 55 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 56 | 57 | log.info(f"Instantiating model <{cfg.model._target_}>") 58 | model: LightningModule = hydra.utils.instantiate(cfg.model) 59 | 60 | log.info("Instantiating callbacks...") 61 | callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 62 | 63 | log.info("Instantiating loggers...") 64 | logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger")) 65 | 66 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 67 | trainer: Trainer = hydra.utils.instantiate( 68 | cfg.trainer, 69 | callbacks=callbacks, 70 | logger=logger, 71 | ) 72 | 73 | object_dict = { 74 | "cfg": cfg, 75 | "datamodule": datamodule, 76 | "model": model, 77 | "callbacks": callbacks, 78 | "logger": logger, 79 | "trainer": trainer, 80 | } 81 | 82 | if logger: 83 | log.info("Logging hyperparameters!") 84 | utils.log_hyperparameters(object_dict) 85 | 86 | if cfg.get("train"): 87 | log.info("Starting training!") 88 | 89 | ckpt_path = cfg.get("ckpt_path") 90 | auto_resume = False 91 | 92 | resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir) 93 | if resume_ckpt_path is not None: 94 | ckpt_path = resume_ckpt_path 95 | auto_resume = True 96 | 97 | if ckpt_path is not None: 98 | log.info(f"Resuming from checkpoint: {ckpt_path}") 99 | 100 | # resume weights only is disabled for auto-resume 101 | if cfg.get("resume_weights_only") and auto_resume is False: 102 | log.info("Resuming weights only!") 103 | ckpt = torch.load(ckpt_path, map_location=model.device) 104 | if "state_dict" in ckpt: 105 | ckpt = ckpt["state_dict"] 106 | err = model.load_state_dict(ckpt, strict=False) 107 | log.info(f"Error loading state dict: {err}") 108 | ckpt_path = None 109 | 110 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 111 | 112 | train_metrics = trainer.callback_metrics 113 | 114 | if cfg.get("test"): 115 | log.info("Starting testing!") 116 | ckpt_path = trainer.checkpoint_callback.best_model_path 117 | if ckpt_path == "": 118 | log.warning("Best ckpt not found! 
Using current weights for testing...") 119 | ckpt_path = cfg.get("ckpt_path") 120 | 121 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 122 | log.info(f"Best ckpt path: {ckpt_path}") 123 | 124 | test_metrics = trainer.callback_metrics 125 | 126 | # merge train and test metrics 127 | metric_dict = {**train_metrics, **test_metrics} 128 | 129 | return metric_dict, object_dict 130 | 131 | 132 | @hydra.main( 133 | version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml" 134 | ) 135 | def main(cfg: DictConfig) -> Optional[float]: 136 | # train the model 137 | train(cfg) 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /fish_speech/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .braceexpand import braceexpand 2 | from .context import autocast_exclude_mps 3 | from .file import get_latest_checkpoint 4 | from .instantiators import instantiate_callbacks, instantiate_loggers 5 | from .logger import RankedLogger 6 | from .logging_utils import log_hyperparameters 7 | from .rich_utils import enforce_tags, print_config_tree 8 | from .utils import extras, get_metric_value, set_seed, task_wrapper 9 | 10 | __all__ = [ 11 | "enforce_tags", 12 | "extras", 13 | "get_metric_value", 14 | "RankedLogger", 15 | "instantiate_callbacks", 16 | "instantiate_loggers", 17 | "log_hyperparameters", 18 | "print_config_tree", 19 | "task_wrapper", 20 | "braceexpand", 21 | "get_latest_checkpoint", 22 | "autocast_exclude_mps", 23 | "set_seed", 24 | ] 25 | -------------------------------------------------------------------------------- /fish_speech/utils/context.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | 3 | import torch 4 | 5 | 6 | def autocast_exclude_mps( 7 | device_type: str, dtype: torch.dtype 8 | ) -> nullcontext | torch.autocast: 9 | return ( 10 | nullcontext() 11 | if torch.backends.mps.is_available() 12 | else torch.autocast(device_type, dtype) 13 | ) 14 | -------------------------------------------------------------------------------- /fish_speech/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | from loguru import logger 6 | from natsort import natsorted 7 | 8 | AUDIO_EXTENSIONS = { 9 | ".mp3", 10 | ".wav", 11 | ".flac", 12 | ".ogg", 13 | ".m4a", 14 | ".wma", 15 | ".aac", 16 | ".aiff", 17 | ".aif", 18 | ".aifc", 19 | } 20 | 21 | VIDEO_EXTENSIONS = { 22 | ".mp4", 23 | ".avi", 24 | } 25 | 26 | 27 | def get_latest_checkpoint(path: Path | str) -> Path | None: 28 | # Find the latest checkpoint 29 | ckpt_dir = Path(path) 30 | 31 | if ckpt_dir.exists() is False: 32 | return None 33 | 34 | ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) 35 | if len(ckpts) == 0: 36 | return None 37 | 38 | return ckpts[-1] 39 | 40 | 41 | def audio_to_bytes(file_path): 42 | if not file_path or not Path(file_path).exists(): 43 | return None 44 | with open(file_path, "rb") as wav_file: 45 | wav = wav_file.read() 46 | return wav 47 | 48 | 49 | def read_ref_text(ref_text): 50 | path = Path(ref_text) 51 | if path.exists() and path.is_file(): 52 | with path.open("r", encoding="utf-8") as file: 53 | return file.read() 54 | return ref_text 55 | 56 | 57 | def list_files( 58 | path: Union[Path, str], 59 | extensions: set[str] = set(), 
60 | recursive: bool = False, 61 | sort: bool = True, 62 | ) -> list[Path]: 63 | """List files in a directory. 64 | 65 | Args: 66 | path (Path): Path to the directory. 67 | extensions (set, optional): Extensions to filter. Defaults to None. 68 | recursive (bool, optional): Whether to search recursively. Defaults to False. 69 | sort (bool, optional): Whether to sort the files. Defaults to True. 70 | 71 | Returns: 72 | list: List of files. 73 | """ 74 | 75 | if isinstance(path, str): 76 | path = Path(path) 77 | 78 | if not path.exists(): 79 | raise FileNotFoundError(f"Directory {path} does not exist.") 80 | 81 | files = [file for ext in extensions for file in path.rglob(f"*{ext}")] 82 | 83 | if sort: 84 | files = natsorted(files) 85 | 86 | return files 87 | 88 | 89 | def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]: 90 | """ 91 | Load a Bert-VITS2 style filelist. 92 | """ 93 | 94 | files = set() 95 | results = [] 96 | count_duplicated, count_not_found = 0, 0 97 | 98 | LANGUAGE_TO_LANGUAGES = { 99 | "zh": ["zh", "en"], 100 | "jp": ["jp", "en"], 101 | "en": ["en"], 102 | } 103 | 104 | with open(path, "r", encoding="utf-8") as f: 105 | for line in f.readlines(): 106 | splits = line.strip().split("|", maxsplit=3) 107 | if len(splits) != 4: 108 | logger.warning(f"Invalid line: {line}") 109 | continue 110 | 111 | filename, speaker, language, text = splits 112 | file = Path(filename) 113 | language = language.strip().lower() 114 | 115 | if language == "ja": 116 | language = "jp" 117 | 118 | assert language in ["zh", "jp", "en"], f"Invalid language {language}" 119 | languages = LANGUAGE_TO_LANGUAGES[language] 120 | 121 | if file in files: 122 | logger.warning(f"Duplicated file: {file}") 123 | count_duplicated += 1 124 | continue 125 | 126 | if not file.exists(): 127 | logger.warning(f"File not found: {file}") 128 | count_not_found += 1 129 | continue 130 | 131 | results.append((file, speaker, languages, text)) 132 | 133 | if count_duplicated > 0: 134 | logger.warning(f"Total duplicated files: {count_duplicated}") 135 | 136 | if count_not_found > 0: 137 | logger.warning(f"Total files not found: {count_not_found}") 138 | 139 | return results 140 | -------------------------------------------------------------------------------- /fish_speech/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | from pytorch_lightning import Callback 6 | from pytorch_lightning.loggers import Logger 7 | 8 | from .logger import RankedLogger 9 | 10 | log = RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config.""" 15 | 16 | callbacks: List[Callback] = [] 17 | 18 | if not callbacks_cfg: 19 | log.warning("No callback configs found! 
Skipping...") 20 | return callbacks 21 | 22 | if not isinstance(callbacks_cfg, DictConfig): 23 | raise TypeError("Callbacks config must be a DictConfig!") 24 | 25 | for _, cb_conf in callbacks_cfg.items(): 26 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 27 | log.info(f"Instantiating callback <{cb_conf._target_}>") 28 | callbacks.append(hydra.utils.instantiate(cb_conf)) 29 | 30 | return callbacks 31 | 32 | 33 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 34 | """Instantiates loggers from config.""" 35 | 36 | logger: List[Logger] = [] 37 | 38 | if not logger_cfg: 39 | log.warning("No logger configs found! Skipping...") 40 | return logger 41 | 42 | if not isinstance(logger_cfg, DictConfig): 43 | raise TypeError("Logger config must be a DictConfig!") 44 | 45 | for _, lg_conf in logger_cfg.items(): 46 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 47 | log.info(f"Instantiating logger <{lg_conf._target_}>") 48 | logger.append(hydra.utils.instantiate(lg_conf)) 49 | 50 | return logger 51 | -------------------------------------------------------------------------------- /fish_speech/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = True, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes 17 | with their rank prefixed in the log message. 18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `True`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 22 | """ 23 | logger = logging.getLogger(name) 24 | super().__init__(logger=logger, extra=extra) 25 | self.rank_zero_only = rank_zero_only 26 | 27 | def log( 28 | self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs 29 | ) -> None: 30 | """Delegate a log call to the underlying logger, after prefixing its message with the rank 31 | of the process it's being logged from. If `'rank'` is provided, then the log will only 32 | occur on that rank/process. 33 | 34 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 35 | :param msg: The message to log. 36 | :param rank: The rank to log at. 37 | :param args: Additional args to pass to the underlying logging function. 38 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 
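        Example (illustrative; the instance name ``log`` is an assumption, matching the module-level loggers used elsewhere in this repo): ``log.log(logging.INFO, "Loaded model", rank=0)`` emits the message only from rank 0, while ``log.info("Loaded model")`` with no ``rank`` follows the ``rank_zero_only`` setting.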
39 | """ 40 | if self.isEnabledFor(level): 41 | msg, kwargs = self.process(msg, kwargs) 42 | current_rank = getattr(rank_zero_only, "rank", None) 43 | if current_rank is None: 44 | raise RuntimeError( 45 | "The `rank_zero_only.rank` needs to be set before use" 46 | ) 47 | msg = rank_prefixed_message(msg, current_rank) 48 | if self.rank_zero_only: 49 | if current_rank == 0: 50 | self.logger.log(level, msg, *args, **kwargs) 51 | else: 52 | if rank is None: 53 | self.logger.log(level, msg, *args, **kwargs) 54 | elif current_rank == rank: 55 | self.logger.log(level, msg, *args, **kwargs) 56 | -------------------------------------------------------------------------------- /fish_speech/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.utilities import rank_zero_only 2 | 3 | from fish_speech.utils import logger as log 4 | 5 | 6 | @rank_zero_only 7 | def log_hyperparameters(object_dict: dict) -> None: 8 | """Controls which config parts are saved by lightning loggers. 9 | 10 | Additionally saves: 11 | - Number of model parameters 12 | """ 13 | 14 | hparams = {} 15 | 16 | cfg = object_dict["cfg"] 17 | model = object_dict["model"] 18 | trainer = object_dict["trainer"] 19 | 20 | if not trainer.logger: 21 | log.warning("Logger not found! Skipping hyperparameter logging...") 22 | return 23 | 24 | hparams["model"] = cfg["model"] 25 | 26 | # save number of model parameters 27 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 28 | hparams["model/params/trainable"] = sum( 29 | p.numel() for p in model.parameters() if p.requires_grad 30 | ) 31 | hparams["model/params/non_trainable"] = sum( 32 | p.numel() for p in model.parameters() if not p.requires_grad 33 | ) 34 | 35 | hparams["data"] = cfg["data"] 36 | hparams["trainer"] = cfg["trainer"] 37 | 38 | hparams["callbacks"] = cfg.get("callbacks") 39 | hparams["extras"] = cfg.get("extras") 40 | 41 | hparams["task_name"] = cfg.get("task_name") 42 | hparams["tags"] = cfg.get("tags") 43 | hparams["ckpt_path"] = cfg.get("ckpt_path") 44 | hparams["seed"] = cfg.get("seed") 45 | 46 | # send hparams to all loggers 47 | for logger in trainer.loggers: 48 | logger.log_hyperparams(hparams) 49 | -------------------------------------------------------------------------------- /fish_speech/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from fish_speech.utils import logger as log 13 | 14 | 15 | @rank_zero_only 16 | def print_config_tree( 17 | cfg: DictConfig, 18 | print_order: Sequence[str] = ( 19 | "data", 20 | "model", 21 | "callbacks", 22 | "logger", 23 | "trainer", 24 | "paths", 25 | "extras", 26 | ), 27 | resolve: bool = False, 28 | save_to_file: bool = False, 29 | ) -> None: 30 | """Prints content of DictConfig using Rich library and its tree structure. 31 | 32 | Args: 33 | cfg (DictConfig): Configuration composed by Hydra. 34 | print_order (Sequence[str], optional): Determines in what order config components are printed. 35 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 
36 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 37 | """ # noqa: E501 38 | 39 | style = "dim" 40 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 41 | 42 | queue = [] 43 | 44 | # add fields from `print_order` to queue 45 | for field in print_order: 46 | ( 47 | queue.append(field) 48 | if field in cfg 49 | else log.warning( 50 | f"Field '{field}' not found in config. " 51 | + f"Skipping '{field}' config printing..." 52 | ) 53 | ) 54 | 55 | # add all the other fields to queue (not specified in `print_order`) 56 | for field in cfg: 57 | if field not in queue: 58 | queue.append(field) 59 | 60 | # generate config tree from queue 61 | for field in queue: 62 | branch = tree.add(field, style=style, guide_style=style) 63 | 64 | config_group = cfg[field] 65 | if isinstance(config_group, DictConfig): 66 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 67 | else: 68 | branch_content = str(config_group) 69 | 70 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 71 | 72 | # print config tree 73 | rich.print(tree) 74 | 75 | # save config tree to file 76 | if save_to_file: 77 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 78 | rich.print(tree, file=file) 79 | 80 | 81 | @rank_zero_only 82 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 83 | """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 84 | 85 | if not cfg.get("tags"): 86 | if "id" in HydraConfig().cfg.hydra.job: 87 | raise ValueError("Specify tags before launching a multirun!") 88 | 89 | log.warning("No tags provided in config. Prompting user to input tags...") 90 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 91 | tags = [t.strip() for t in tags.split(",") if t != ""] 92 | 93 | with open_dict(cfg): 94 | cfg.tags = tags 95 | 96 | log.info(f"Tags: {cfg.tags}") 97 | 98 | if save_to_file: 99 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 100 | rich.print(cfg.tags, file=file) 101 | -------------------------------------------------------------------------------- /fish_speech/utils/spectrogram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio.functional as F 3 | from torch import Tensor, nn 4 | from torchaudio.transforms import MelScale 5 | 6 | 7 | class LinearSpectrogram(nn.Module): 8 | def __init__( 9 | self, 10 | n_fft=2048, 11 | win_length=2048, 12 | hop_length=512, 13 | center=False, 14 | mode="pow2_sqrt", 15 | ): 16 | super().__init__() 17 | 18 | self.n_fft = n_fft 19 | self.win_length = win_length 20 | self.hop_length = hop_length 21 | self.center = center 22 | self.mode = mode 23 | self.return_complex = True 24 | 25 | self.register_buffer("window", torch.hann_window(win_length), persistent=False) 26 | 27 | def forward(self, y: Tensor) -> Tensor: 28 | if y.ndim == 3: 29 | y = y.squeeze(1) 30 | 31 | y = torch.nn.functional.pad( 32 | y.unsqueeze(1), 33 | ( 34 | (self.win_length - self.hop_length) // 2, 35 | (self.win_length - self.hop_length + 1) // 2, 36 | ), 37 | mode="reflect", 38 | ).squeeze(1) 39 | 40 | spec = torch.stft( 41 | y, 42 | self.n_fft, 43 | hop_length=self.hop_length, 44 | win_length=self.win_length, 45 | window=self.window, 46 | center=self.center, 47 | pad_mode="reflect", 48 | normalized=False, 49 | onesided=True, 50 | return_complex=self.return_complex, 51 | ) 52 | 53 | if self.return_complex: 54 | spec = torch.view_as_real(spec) 55 | 
56 | if self.mode == "pow2_sqrt": 57 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 58 | 59 | return spec 60 | 61 | 62 | class LogMelSpectrogram(nn.Module): 63 | def __init__( 64 | self, 65 | sample_rate=44100, 66 | n_fft=2048, 67 | win_length=2048, 68 | hop_length=512, 69 | n_mels=128, 70 | center=False, 71 | f_min=0.0, 72 | f_max=None, 73 | ): 74 | super().__init__() 75 | 76 | self.sample_rate = sample_rate 77 | self.n_fft = n_fft 78 | self.win_length = win_length 79 | self.hop_length = hop_length 80 | self.center = center 81 | self.n_mels = n_mels 82 | self.f_min = f_min 83 | self.f_max = f_max or float(sample_rate // 2) 84 | 85 | self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) 86 | 87 | fb = F.melscale_fbanks( 88 | n_freqs=self.n_fft // 2 + 1, 89 | f_min=self.f_min, 90 | f_max=self.f_max, 91 | n_mels=self.n_mels, 92 | sample_rate=self.sample_rate, 93 | norm="slaney", 94 | mel_scale="slaney", 95 | ) 96 | self.register_buffer( 97 | "fb", 98 | fb, 99 | persistent=False, 100 | ) 101 | 102 | def compress(self, x: Tensor) -> Tensor: 103 | return torch.log(torch.clamp(x, min=1e-5)) 104 | 105 | def decompress(self, x: Tensor) -> Tensor: 106 | return torch.exp(x) 107 | 108 | def apply_mel_scale(self, x: Tensor) -> Tensor: 109 | return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) 110 | 111 | def forward( 112 | self, x: Tensor, return_linear: bool = False, sample_rate: int = None 113 | ) -> Tensor: 114 | if sample_rate is not None and sample_rate != self.sample_rate: 115 | x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) 116 | 117 | linear = self.spectrogram(x) 118 | x = self.apply_mel_scale(linear) 119 | x = self.compress(x) 120 | 121 | if return_linear: 122 | return x, self.compress(linear) 123 | 124 | return x 125 | -------------------------------------------------------------------------------- /fish_speech/utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from importlib.util import find_spec 4 | from typing import Callable 5 | 6 | import numpy as np 7 | import torch 8 | from omegaconf import DictConfig 9 | 10 | from .logger import RankedLogger 11 | from .rich_utils import enforce_tags, print_config_tree 12 | 13 | log = RankedLogger(__name__, rank_zero_only=True) 14 | 15 | 16 | def extras(cfg: DictConfig) -> None: 17 | """Applies optional utilities before the task is started. 18 | 19 | Utilities: 20 | - Ignoring python warnings 21 | - Setting tags from command line 22 | - Rich config printing 23 | """ 24 | 25 | # return if no `extras` config 26 | if not cfg.get("extras"): 27 | log.warning("Extras config not found! ") 28 | return 29 | 30 | # disable python warnings 31 | if cfg.extras.get("ignore_warnings"): 32 | log.info("Disabling python warnings! ") 33 | warnings.filterwarnings("ignore") 34 | 35 | # prompt user to input tags from command line if none are provided in the config 36 | if cfg.extras.get("enforce_tags"): 37 | log.info("Enforcing tags! ") 38 | enforce_tags(cfg, save_to_file=True) 39 | 40 | # pretty print config tree using Rich library 41 | if cfg.extras.get("print_config"): 42 | log.info("Printing config tree with Rich! ") 43 | print_config_tree(cfg, resolve=True, save_to_file=True) 44 | 45 | 46 | def task_wrapper(task_func: Callable) -> Callable: 47 | """Optional decorator that controls the failure behavior when executing the task function. 
48 | 49 | This wrapper can be used to: 50 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 51 | - save the exception to a `.log` file 52 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 53 | - etc. (adjust depending on your needs) 54 | 55 | Example: 56 | ``` 57 | @utils.task_wrapper 58 | def train(cfg: DictConfig) -> Tuple[dict, dict]: 59 | 60 | ... 61 | 62 | return metric_dict, object_dict 63 | ``` 64 | """ # noqa: E501 65 | 66 | def wrap(cfg: DictConfig): 67 | # execute the task 68 | try: 69 | metric_dict, object_dict = task_func(cfg=cfg) 70 | 71 | # things to do if exception occurs 72 | except Exception as ex: 73 | # save exception to `.log` file 74 | log.exception("") 75 | 76 | # some hyperparameter combinations might be invalid or 77 | # cause out-of-memory errors so when using hparam search 78 | # plugins like Optuna, you might want to disable 79 | # raising the below exception to avoid multirun failure 80 | raise ex 81 | 82 | # things to always do after either success or exception 83 | finally: 84 | # display output dir path in terminal 85 | log.info(f"Output dir: {cfg.paths.run_dir}") 86 | 87 | # always close wandb run (even if exception occurs so multirun won't fail) 88 | if find_spec("wandb"): # check if wandb is installed 89 | import wandb 90 | 91 | if wandb.run: 92 | log.info("Closing wandb!") 93 | wandb.finish() 94 | 95 | return metric_dict, object_dict 96 | 97 | return wrap 98 | 99 | 100 | def get_metric_value(metric_dict: dict, metric_name: str) -> float: 101 | """Safely retrieves value of the metric logged in LightningModule.""" 102 | 103 | if not metric_name: 104 | log.info("Metric name is None! Skipping metric value retrieval...") 105 | return None 106 | 107 | if metric_name not in metric_dict: 108 | raise Exception( 109 | f"Metric value not found! \n" 110 | "Make sure metric name logged in LightningModule is correct!\n" 111 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 112 | ) 113 | 114 | metric_value = metric_dict[metric_name].item() 115 | log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") 116 | 117 | return metric_value 118 | 119 | 120 | def set_seed(seed: int): 121 | if seed < 0: 122 | seed = -seed 123 | if seed > (1 << 31): 124 | seed = 1 << 31 125 | 126 | random.seed(seed) 127 | np.random.seed(seed) 128 | torch.manual_seed(seed) 129 | 130 | if torch.cuda.is_available(): 131 | torch.cuda.manual_seed(seed) 132 | torch.cuda.manual_seed_all(seed) 133 | 134 | if torch.backends.cudnn.is_available(): 135 | torch.backends.cudnn.deterministic = True 136 | torch.backends.cudnn.benchmark = False 137 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Fish Speech 2 | site_description: Targeting SOTA TTS solutions. 
3 | site_url: https://speech.fish.audio 4 | 5 | # Repository 6 | repo_name: fishaudio/fish-speech 7 | repo_url: https://github.com/fishaudio/fish-speech 8 | edit_uri: blob/main/docs 9 | 10 | # Copyright 11 | copyright: Copyright © 2023-2025 by Fish Audio 12 | 13 | theme: 14 | name: material 15 | favicon: assets/figs/logo-circle.png 16 | language: en 17 | features: 18 | - content.action.edit 19 | - content.action.view 20 | - navigation.tracking 21 | - navigation.footer 22 | # - navigation.tabs 23 | - search 24 | - search.suggest 25 | - search.highlight 26 | - search.share 27 | - content.code.copy 28 | icon: 29 | logo: fontawesome/solid/fish 30 | 31 | palette: 32 | # Palette toggle for automatic mode 33 | - media: "(prefers-color-scheme)" 34 | toggle: 35 | icon: material/brightness-auto 36 | name: Switch to light mode 37 | 38 | # Palette toggle for light mode 39 | - media: "(prefers-color-scheme: light)" 40 | scheme: default 41 | toggle: 42 | icon: material/brightness-7 43 | name: Switch to dark mode 44 | primary: black 45 | font: 46 | code: Roboto Mono 47 | 48 | # Palette toggle for dark mode 49 | - media: "(prefers-color-scheme: dark)" 50 | scheme: slate 51 | toggle: 52 | icon: material/brightness-4 53 | name: Switch to light mode 54 | primary: black 55 | font: 56 | code: Roboto Mono 57 | 58 | nav: 59 | - Introduction: index.md 60 | - Finetune: finetune.md 61 | - Inference: inference.md 62 | - Start Agent: start_agent.md 63 | - Samples: samples.md 64 | 65 | # Plugins 66 | plugins: 67 | - search: 68 | separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' 69 | lang: 70 | - en 71 | - zh 72 | - ja 73 | - pt 74 | - ko 75 | - i18n: 76 | docs_structure: folder 77 | languages: 78 | - locale: en 79 | name: English 80 | default: true 81 | build: true 82 | - locale: zh 83 | name: 简体中文 84 | build: true 85 | nav: 86 | - 介绍: zh/index.md 87 | - 微调: zh/finetune.md 88 | - 推理: zh/inference.md 89 | - 启动Agent: zh/start_agent.md 90 | - 例子: zh/samples.md 91 | - locale: ja 92 | name: 日本語 93 | build: true 94 | nav: 95 | - Fish Speech の紹介: ja/index.md 96 | - 微調整: ja/finetune.md 97 | - 推論: ja/inference.md 98 | - スタートエージェント: ja/start_agent.md 99 | - サンプル: ja/samples.md 100 | - locale: pt 101 | name: Português (Brasil) 102 | build: true 103 | nav: 104 | - Introdução: pt/index.md 105 | - Ajuste Fino: pt/finetune.md 106 | - Inferência: pt/inference.md 107 | - Agente inicial: pt/start_agent.md 108 | - Amostras: pt/samples.md 109 | - locale: ko 110 | name: 한국어 111 | build: true 112 | nav: 113 | - 소개: ko/index.md 114 | - 파인튜닝: ko/finetune.md 115 | - 추론: ko/inference.md 116 | - 샘플: ko/samples.md 117 | 118 | markdown_extensions: 119 | - pymdownx.highlight: 120 | anchor_linenums: true 121 | line_spans: __span 122 | pygments_lang_class: true 123 | - pymdownx.inlinehilite 124 | - pymdownx.snippets 125 | - pymdownx.superfences 126 | - admonition 127 | - pymdownx.details 128 | - pymdownx.superfences 129 | - attr_list 130 | - md_in_html 131 | - pymdownx.superfences 132 | 133 | extra_css: 134 | - stylesheets/extra.css 135 | 136 | extra: 137 | social: 138 | - icon: fontawesome/brands/discord 139 | link: https://discord.gg/Es5qTB9BcN 140 | - icon: fontawesome/brands/docker 141 | link: https://hub.docker.com/r/fishaudio/fish-speech 142 | - icon: fontawesome/brands/qq 143 | link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093 144 | homepage: https://speech.fish.audio 145 | 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fish-speech" 3 | version = "0.1.0" 4 | authors = [ 5 | {name = "Lengyue", email = "lengyue@lengyue.me"}, 6 | ] 7 | description = "Fish Speech" 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | keywords = ["TTS", "Speech"] 11 | license = {text = "Apache-2.0"} 12 | classifiers = [ 13 | "Programming Language :: Python :: 3", 14 | ] 15 | dependencies = [ 16 | "numpy<=1.26.4", 17 | "transformers>=4.45.2", 18 | "datasets==2.18.0", 19 | "lightning>=2.1.0", 20 | "hydra-core>=1.3.2", 21 | "tensorboard>=2.14.1", 22 | "natsort>=8.4.0", 23 | "einops>=0.7.0", 24 | "librosa>=0.10.1", 25 | "rich>=13.5.3", 26 | "gradio>5.0.0", 27 | "wandb>=0.15.11", 28 | "grpcio>=1.58.0", 29 | "kui>=1.6.0", 30 | "uvicorn>=0.30.0", 31 | "loguru>=0.6.0", 32 | "loralib>=0.1.2", 33 | "pyrootutils>=1.0.4", 34 | "vector_quantize_pytorch==1.14.24", 35 | "resampy>=0.4.3", 36 | "einx[torch]==0.2.2", 37 | "zstandard>=0.22.0", 38 | "pydub", 39 | "pyaudio", 40 | "faster_whisper", 41 | "modelscope==1.17.1", 42 | "funasr==1.1.5", 43 | "opencc-python-reimplemented==0.1.7", 44 | "silero-vad", 45 | "ormsgpack", 46 | "tiktoken>=0.8.0", 47 | "pydantic==2.9.2", 48 | "cachetools", 49 | ] 50 | 51 | [project.optional-dependencies] 52 | stable = [ 53 | "torch<=2.4.1", 54 | "torchaudio", 55 | ] 56 | 57 | [build-system] 58 | requires = ["setuptools", "setuptools-scm"] 59 | build-backend = "setuptools.build_meta" 60 | 61 | [tool.setuptools] 62 | packages = ["fish_speech", "tools"] 63 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | "data", 4 | "filelists" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /tools/api_server.py: -------------------------------------------------------------------------------- 1 | import re 2 | from threading import Lock 3 | 4 | import pyrootutils 5 | import uvicorn 6 | from kui.asgi import ( 7 | Depends, 8 | FactoryClass, 9 | HTTPException, 10 | HttpRoute, 11 | Kui, 12 | OpenAPI, 13 | Routes, 14 | ) 15 | from kui.cors import CORSConfig 16 | from kui.openapi.specification import Info 17 | from kui.security import bearer_auth 18 | from loguru import logger 19 | from typing_extensions import Annotated 20 | 21 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 22 | 23 | from tools.server.api_utils import MsgPackRequest, parse_args 24 | from tools.server.exception_handler import ExceptionHandler 25 | from tools.server.model_manager import ModelManager 26 | from tools.server.views import routes 27 | 28 | 29 | class API(ExceptionHandler): 30 | def __init__(self): 31 | self.args = parse_args() 32 | 33 | def api_auth(endpoint): 34 | async def verify(token: Annotated[str, Depends(bearer_auth)]): 35 | if token != self.args.api_key: 36 | raise HTTPException(401, None, "Invalid token") 37 | return await endpoint() 38 | 39 | async def passthrough(): 40 | return await endpoint() 41 | 42 | if self.args.api_key is not None: 43 | return verify 44 | else: 45 | return passthrough 46 | 47 | self.routes = Routes( 48 | routes, # keep existing routes 49 | http_middlewares=[api_auth], # apply api_auth middleware 50 | ) 51 | 52 | # OpenAPIの設定 53 | self.openapi = OpenAPI( 54 | Info( 55 | { 56 | 
"title": "Fish Speech API", 57 | "version": "1.5.0", 58 | } 59 | ), 60 | ).routes 61 | 62 | # Initialize the app 63 | self.app = Kui( 64 | routes=self.routes + self.openapi[1:], # Remove the default route 65 | exception_handlers={ 66 | HTTPException: self.http_exception_handler, 67 | Exception: self.other_exception_handler, 68 | }, 69 | factory_class=FactoryClass(http=MsgPackRequest), 70 | cors_config=CORSConfig(), 71 | ) 72 | 73 | # Add the state variables 74 | self.app.state.lock = Lock() 75 | self.app.state.device = self.args.device 76 | self.app.state.max_text_length = self.args.max_text_length 77 | 78 | # Associate the app with the model manager 79 | self.app.on_startup(self.initialize_app) 80 | 81 | async def initialize_app(self, app: Kui): 82 | # Make the ModelManager available to the views 83 | app.state.model_manager = ModelManager( 84 | mode=self.args.mode, 85 | device=self.args.device, 86 | half=self.args.half, 87 | compile=self.args.compile, 88 | asr_enabled=self.args.load_asr_model, 89 | llama_checkpoint_path=self.args.llama_checkpoint_path, 90 | decoder_checkpoint_path=self.args.decoder_checkpoint_path, 91 | decoder_config_name=self.args.decoder_config_name, 92 | ) 93 | 94 | logger.info(f"Startup done, listening server at http://{self.args.listen}") 95 | 96 | 97 | # Each worker process created by Uvicorn has its own memory space, 98 | # meaning that models and variables are not shared between processes. 99 | # Therefore, any variables (like `llama_queue` or `decoder_model`) 100 | # will not be shared across workers. 101 | 102 | # Multi-threading for deep learning can cause issues, such as inconsistent 103 | # outputs if multiple threads access the same buffers simultaneously. 104 | # Instead, it's better to use multiprocessing or independent models per thread. 
105 | 106 | if __name__ == "__main__": 107 | api = API() 108 | 109 | # IPv6 address format is [xxxx:xxxx::xxxx]:port 110 | match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen) 111 | if match: 112 | host, port = match.groups() # IPv6 113 | else: 114 | host, port = api.args.listen.split(":") # IPv4 115 | 116 | uvicorn.run( 117 | api.app, 118 | host=host, 119 | port=int(port), 120 | workers=api.args.workers, 121 | log_level="info", 122 | ) 123 | -------------------------------------------------------------------------------- /tools/download_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from huggingface_hub import hf_hub_download 4 | 5 | 6 | # Download 7 | def check_and_download_files(repo_id, file_list, local_dir): 8 | os.makedirs(local_dir, exist_ok=True) 9 | for file in file_list: 10 | file_path = os.path.join(local_dir, file) 11 | if not os.path.exists(file_path): 12 | print(f"{file} 不存在,从 Hugging Face 仓库下载...") 13 | hf_hub_download( 14 | repo_id=repo_id, 15 | filename=file, 16 | resume_download=True, 17 | local_dir=local_dir, 18 | local_dir_use_symlinks=False, 19 | ) 20 | else: 21 | print(f"{file} 已存在,跳过下载。") 22 | 23 | 24 | # 1st 25 | repo_id_1 = "fishaudio/fish-speech-1.5" 26 | local_dir_1 = "./checkpoints/fish-speech-1.5" 27 | files_1 = [ 28 | ".gitattributes", 29 | "model.pth", 30 | "README.md", 31 | "special_tokens.json", 32 | "tokenizer.tiktoken", 33 | "config.json", 34 | "firefly-gan-vq-fsq-8x1024-21hz-generator.pth", 35 | ] 36 | 37 | # 3rd 38 | repo_id_3 = "fishaudio/fish-speech-1" 39 | local_dir_3 = "./" 40 | files_3 = [ 41 | "ffmpeg.exe", 42 | "ffprobe.exe", 43 | ] 44 | 45 | # 4th 46 | repo_id_4 = "SpicyqSama007/fish-speech-packed" 47 | local_dir_4 = "./" 48 | files_4 = [ 49 | "asr-label-win-x64.exe", 50 | ] 51 | 52 | check_and_download_files(repo_id_1, files_1, local_dir_1) 53 | 54 | check_and_download_files(repo_id_3, files_3, local_dir_3) 55 | check_and_download_files(repo_id_4, files_4, local_dir_4) 56 | -------------------------------------------------------------------------------- /tools/extract_model.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | from loguru import logger 4 | 5 | 6 | @click.command() 7 | @click.argument("model_path") 8 | @click.argument("output_path") 9 | def main(model_path, output_path): 10 | if model_path == output_path: 11 | logger.error("Model path and output path are the same") 12 | return 13 | 14 | logger.info(f"Loading model from {model_path}") 15 | state_dict = torch.load(model_path, map_location="cpu")["state_dict"] 16 | torch.save(state_dict, output_path) 17 | logger.info(f"Model saved to {output_path}") 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /tools/llama/merge_lora.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from copy import deepcopy 3 | from pathlib import Path 4 | 5 | import click 6 | import hydra 7 | import torch 8 | from hydra import compose, initialize 9 | from hydra.utils import instantiate 10 | from loguru import logger 11 | 12 | from fish_speech.models.text2semantic.llama import BaseTransformer 13 | from fish_speech.models.text2semantic.lora import get_merged_state_dict 14 | 15 | 16 | @click.command() 17 | @click.option("--lora-config", type=str, default="r_8_alpha_16") 18 | @click.option("--base-weight", type=str, 
default="checkpoints/fish-speech-1.4") 19 | @click.option("--lora-weight", type=str, required=True) 20 | @click.option("--output", type=str, required=True) 21 | def merge(lora_config, base_weight, lora_weight, output): 22 | output = Path(output) 23 | logger.info( 24 | f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}" 25 | ) 26 | 27 | with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"): 28 | cfg = compose(config_name=lora_config) 29 | 30 | lora_config = instantiate(cfg) 31 | logger.info(f"Loaded lora model with config {lora_config}") 32 | 33 | llama_model = BaseTransformer.from_pretrained( 34 | path=base_weight, 35 | load_weights=True, 36 | lora_config=lora_config, 37 | ) 38 | logger.info(f"Loaded llama model") 39 | 40 | llama_state_dict = llama_model.state_dict() 41 | llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k} 42 | llama_state_dict_copy = deepcopy(llama_state_dict) 43 | lora_state_dict = torch.load(lora_weight, map_location="cpu", weights_only=False) 44 | 45 | if "state_dict" in llama_state_dict: 46 | llama_state_dict = llama_state_dict["state_dict"] 47 | 48 | if "state_dict" in lora_state_dict: 49 | lora_state_dict = lora_state_dict["state_dict"] 50 | 51 | # remove prefix model. 52 | if any(k.startswith("model.") for k in llama_state_dict.keys()): 53 | llama_state_dict = { 54 | k.replace("model.", ""): v 55 | for k, v in llama_state_dict.items() 56 | if k.startswith("model.") 57 | } 58 | if any(k.startswith("model.") for k in lora_state_dict.keys()): 59 | lora_state_dict = { 60 | k.replace("model.", ""): v 61 | for k, v in lora_state_dict.items() 62 | if k.startswith("model.") 63 | } 64 | 65 | logger.info(f"Found {len(llama_state_dict)} keys in llama model") 66 | logger.info(f"Found {len(lora_state_dict)} keys in lora model") 67 | 68 | merged_state_dict = llama_state_dict | lora_state_dict 69 | llama_model.load_state_dict(merged_state_dict, strict=True) 70 | logger.info(f"Merged model loaded") 71 | 72 | # Trigger eval mode to merge lora 73 | llama_model.eval() 74 | llama_model.save_pretrained(output, drop_lora=True) 75 | logger.info(f"Saved merged model to {output}, validating") 76 | 77 | new_state_dict = torch.load(output / "model.pth", map_location="cpu") 78 | original_keys = set(llama_state_dict_copy.keys()) 79 | 80 | tolerance = 1e-5 81 | for key in original_keys: 82 | diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item() 83 | if diff_l1 > tolerance: 84 | logger.info(f"Significant difference found in key: {key}") 85 | break 86 | 87 | if diff_l1 <= tolerance: 88 | logger.warning( 89 | "Merged model seems identical to the original model. Further validation might be needed." 
90 | ) 91 | else: 92 | logger.info("Merged model is different from the original model, check passed") 93 | 94 | 95 | if __name__ == "__main__": 96 | merge() 97 | -------------------------------------------------------------------------------- /tools/run_webui.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | 5 | import pyrootutils 6 | import torch 7 | from loguru import logger 8 | 9 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 10 | 11 | from fish_speech.inference_engine import TTSInferenceEngine 12 | from fish_speech.models.text2semantic.inference import launch_thread_safe_queue 13 | from fish_speech.models.vqgan.inference import load_model as load_decoder_model 14 | from fish_speech.utils.schema import ServeTTSRequest 15 | from tools.webui import build_app 16 | from tools.webui.inference import get_inference_wrapper 17 | 18 | # Make einx happy 19 | os.environ["EINX_FILTER_TRACEBACK"] = "false" 20 | 21 | 22 | def parse_args(): 23 | parser = ArgumentParser() 24 | parser.add_argument( 25 | "--llama-checkpoint-path", 26 | type=Path, 27 | default="checkpoints/fish-speech-1.5", 28 | ) 29 | parser.add_argument( 30 | "--decoder-checkpoint-path", 31 | type=Path, 32 | default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", 33 | ) 34 | parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") 35 | parser.add_argument("--device", type=str, default="cuda") 36 | parser.add_argument("--half", action="store_true") 37 | parser.add_argument("--compile", action="store_true") 38 | parser.add_argument("--max-gradio-length", type=int, default=0) 39 | parser.add_argument("--theme", type=str, default="light") 40 | 41 | return parser.parse_args() 42 | 43 | 44 | if __name__ == "__main__": 45 | args = parse_args() 46 | args.precision = torch.half if args.half else torch.bfloat16 47 | 48 | # Check if MPS or CUDA is available 49 | if torch.backends.mps.is_available(): 50 | args.device = "mps" 51 | logger.info("mps is available, running on mps.") 52 | elif not torch.cuda.is_available(): 53 | logger.info("CUDA is not available, running on CPU.") 54 | args.device = "cpu" 55 | 56 | logger.info("Loading Llama model...") 57 | llama_queue = launch_thread_safe_queue( 58 | checkpoint_path=args.llama_checkpoint_path, 59 | device=args.device, 60 | precision=args.precision, 61 | compile=args.compile, 62 | ) 63 | 64 | logger.info("Loading VQ-GAN model...") 65 | decoder_model = load_decoder_model( 66 | config_name=args.decoder_config_name, 67 | checkpoint_path=args.decoder_checkpoint_path, 68 | device=args.device, 69 | ) 70 | 71 | logger.info("Decoder model loaded, warming up...") 72 | 73 | # Create the inference engine 74 | inference_engine = TTSInferenceEngine( 75 | llama_queue=llama_queue, 76 | decoder_model=decoder_model, 77 | compile=args.compile, 78 | precision=args.precision, 79 | ) 80 | 81 | # Dry run to check if the model is loaded correctly and avoid the first-time latency 82 | list( 83 | inference_engine.inference( 84 | ServeTTSRequest( 85 | text="Hello world.", 86 | references=[], 87 | reference_id=None, 88 | max_new_tokens=1024, 89 | chunk_length=200, 90 | top_p=0.7, 91 | repetition_penalty=1.5, 92 | temperature=0.7, 93 | format="wav", 94 | ) 95 | ) 96 | ) 97 | 98 | logger.info("Warming up done, launching the web UI...") 99 | 100 | # Get the inference function with the immutable arguments 101 | inference_fct = 
get_inference_wrapper(inference_engine) 102 | 103 | app = build_app(inference_fct, args.theme) 104 | app.launch(show_api=True) 105 | -------------------------------------------------------------------------------- /tools/server/agent/__init__.py: -------------------------------------------------------------------------------- 1 | import struct 2 | from functools import partial 3 | 4 | import ormsgpack 5 | 6 | from tools.server.agent.generate import generate_responses 7 | from tools.server.agent.pre_generation_utils import prepare_messages 8 | 9 | 10 | def execute_request(input_queue, tokenizer, config, request, device): 11 | """ 12 | This function prepares the conversation, encodes the request, 13 | sends the generation request, and handles decoding/streaming. 14 | It returns a response generator (ServeResponse or ServeStreamResponse). 15 | """ 16 | prompt, im_end_id = prepare_messages(request, tokenizer, config) 17 | yield from generate_responses( 18 | input_queue, tokenizer, config, request, prompt, im_end_id, device 19 | ) 20 | 21 | 22 | def response_generator(req, llama_queue, tokenizer, config, device): 23 | """ 24 | Non-streaming response wrapper for the chat endpoint. 25 | Only returns the final result. 26 | """ 27 | generator = execute_request(llama_queue, tokenizer, config, req, device) 28 | return next(generator) 29 | 30 | 31 | async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode): 32 | """ 33 | Streaming response wrapper for the chat endpoint. 34 | Returns the response in chunks. 35 | """ 36 | generator = execute_request(llama_queue, tokenizer, config, req, device) 37 | for i in generator: 38 | if json_mode: 39 | body = i.model_dump_json().encode("utf-8") 40 | yield b"data: " + body + b"\n\n" 41 | else: 42 | body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) 43 | yield struct.pack("I", len(body)) + body 44 | 45 | 46 | def get_response_generator( 47 | llama_queue, tokenizer, config, req, device, json_mode 48 | ) -> partial: 49 | """ 50 | Get the correct response generator based on the request. 51 | """ 52 | if not req.streaming: 53 | return partial(response_generator, req, llama_queue, tokenizer, config, device) 54 | else: 55 | return partial( 56 | streaming_generator, req, llama_queue, tokenizer, config, device, json_mode 57 | ) 58 | -------------------------------------------------------------------------------- /tools/server/agent/generate.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from fish_speech.utils.schema import ServeMessage, ServeResponse, ServeStreamResponse 4 | from tools.server.agent.generation_utils import ( 5 | initialize_decode_buffers, 6 | process_response_tokens, 7 | send_reset_buffer, 8 | ) 9 | from tools.server.agent.pre_generation_utils import ( 10 | create_generation_request, 11 | send_generation_request, 12 | ) 13 | 14 | 15 | def generate_responses( 16 | input_queue, tokenizer, config, request, prompt, im_end_id, device 17 | ): 18 | """ 19 | Main generation function that handles the conversation, encodes the request, 20 | sends the generation request, and handles decoding/streaming. 21 | It returns a response generator (ServeResponse or ServeStreamResponse). 
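    A rough consumption sketch (the real wrappers live in tools/server/agent/__init__.py):

        gen = generate_responses(
            input_queue, tokenizer, config, request, prompt, im_end_id, device
        )
        for item in gen:
            ...  # ServeStreamResponse chunks when request.streaming, otherwise a single ServeResponse at the end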
22 | """ 23 | stats = {} 24 | start = time.time() 25 | stats["start_time"] = start 26 | stats["tokens_count"] = 0 27 | 28 | # Prepare and send the generation request 29 | req = create_generation_request(prompt, request, im_end_id, device) 30 | response_queue = send_generation_request(input_queue, req) 31 | decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples) 32 | 33 | while True: 34 | response = response_queue.get() 35 | 36 | # Handle abnormal finish or error 37 | if response in ["stop", "error"]: 38 | finish_reason = response 39 | break 40 | 41 | # Process the response tokens 42 | is_first_token = stats["tokens_count"] == 0 43 | responses = process_response_tokens( 44 | response, 45 | tokenizer, 46 | config, 47 | request, 48 | decode_buffer, 49 | parts, 50 | finished, 51 | im_end_id, 52 | stats, 53 | start, 54 | is_first_token, 55 | ) 56 | 57 | # Yield the responses if streaming 58 | if request.streaming and responses: 59 | for r in responses: 60 | yield r 61 | 62 | stats["tokens_count"] += 1 63 | 64 | # Check if all samples are finished 65 | if all(finished): 66 | finish_reason = "stop" 67 | break 68 | 69 | # Finalize the response 70 | final_responses = finalize_response( 71 | request, finished, decode_buffer, tokenizer, parts, stats, finish_reason 72 | ) 73 | for fr in final_responses: 74 | yield fr 75 | 76 | 77 | def finalize_response( 78 | request, finished, decode_buffer, tokenizer, parts, stats, finish_reason 79 | ): 80 | """ 81 | Finalize the response by sending the remaining text buffers. 82 | """ 83 | responses = [] 84 | 85 | # Send the remaining text buffers 86 | for sample_id in range(request.num_samples): 87 | responses.extend( 88 | send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request) 89 | ) 90 | 91 | # Calculate the final stats 92 | stats["total_time"] = (time.time() - stats["start_time"]) * 1000 93 | stats["total_tokens"] = stats["tokens_count"] 94 | 95 | # If streaming, send the final chunks for each sample 96 | if request.streaming: 97 | for sample_id in range(request.num_samples): 98 | if finished[sample_id]: 99 | continue 100 | responses.append( 101 | ServeStreamResponse( 102 | finish_reason=finish_reason, stats=stats, sample_id=sample_id 103 | ) 104 | ) 105 | else: 106 | # If not streaming, send the full messages for each sample 107 | full_messages = [ 108 | ServeMessage(role="assistant", parts=parts[i]) 109 | for i in range(request.num_samples) 110 | ] 111 | responses.append( 112 | ServeResponse( 113 | messages=full_messages, 114 | finish_reason=finish_reason, 115 | stats=stats, 116 | ) 117 | ) 118 | 119 | return responses 120 | -------------------------------------------------------------------------------- /tools/server/agent/generation_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from fish_speech.utils.schema import ( 4 | ServeStreamDelta, 5 | ServeStreamResponse, 6 | ServeTextPart, 7 | ServeVQPart, 8 | ) 9 | 10 | 11 | def initialize_decode_buffers(num_samples): 12 | """Initialise the decode buffers for each sample.""" 13 | decode_buffer = [[] for _ in range(num_samples)] 14 | parts = [[] for _ in range(num_samples)] 15 | finished = [False for _ in range(num_samples)] 16 | return decode_buffer, parts, finished 17 | 18 | 19 | def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request): 20 | """Send the remaining text buffer for a sample.""" 21 | if len(decode_buffer[sample_id]) == 0: 22 | return [] 23 | 24 | decoded = 
tokenizer.decode(decode_buffer[sample_id]) 25 | part = ServeTextPart(text=decoded) 26 | 27 | responses = [] 28 | if request.streaming: 29 | responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part))) 30 | else: 31 | parts[sample_id].append(part) 32 | 33 | decode_buffer[sample_id] = [] 34 | return responses 35 | 36 | 37 | def handle_semantic_tokens(tokens, config, sample_id, parts, request): 38 | """Handle the semantic tokens returned by the model.""" 39 | responses = [] 40 | _tokens = tokens[1:].clone() 41 | 42 | if not config.share_codebook_embeddings: 43 | for i in range(len(_tokens)): 44 | _tokens[i] -= config.codebook_size * i 45 | 46 | # If streaming, send the VQ parts directly 47 | if request.streaming: 48 | responses.append( 49 | ServeStreamResponse( 50 | sample_id=sample_id, 51 | delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())), 52 | ) 53 | ) 54 | else: 55 | # If not streaming, accumulate the VQ parts 56 | if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart): 57 | parts[sample_id].append(ServeVQPart(codes=_tokens.tolist())) 58 | else: 59 | # Accumulate the codes 60 | for codebook_id, value in enumerate(_tokens): 61 | parts[sample_id][-1].codes[codebook_id].append(value.item()) 62 | 63 | return responses 64 | 65 | 66 | def process_response_tokens( 67 | response, 68 | tokenizer, 69 | config, 70 | request, 71 | decode_buffer, 72 | parts, 73 | finished, 74 | im_end_id, 75 | stats, 76 | start, 77 | is_first_token, 78 | ): 79 | """Process the response tokens returned by the model.""" 80 | responses = [] 81 | for sample_id, tokens in enumerate(response): 82 | if finished[sample_id]: 83 | continue 84 | 85 | # End of the conversation 86 | if tokens[0] == im_end_id: 87 | finished[sample_id] = True 88 | # Send the remaining text buffer 89 | responses.extend( 90 | send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request) 91 | ) 92 | if request.streaming: 93 | responses.append( 94 | ServeStreamResponse( 95 | sample_id=sample_id, 96 | finish_reason="stop", 97 | stats=stats, 98 | ) 99 | ) 100 | continue 101 | 102 | # Check if the token is semantic 103 | is_semantic = ( 104 | tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id 105 | ) 106 | 107 | if is_semantic: 108 | # Before the semantic tokens, send the remaining text buffer 109 | responses.extend( 110 | send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request) 111 | ) 112 | responses.extend( 113 | handle_semantic_tokens(tokens, config, sample_id, parts, request) 114 | ) 115 | else: 116 | # Accumulate the text tokens (not implemented?) 117 | decode_buffer[sample_id].append(tokens[0, 0]) 118 | 119 | if is_first_token: 120 | stats["time_to_first_token"] = (time.time() - start) * 1000 121 | 122 | return responses 123 | -------------------------------------------------------------------------------- /tools/server/agent/pre_generation_utils.py: -------------------------------------------------------------------------------- 1 | import queue 2 | 3 | from fish_speech.conversation import Conversation, Message 4 | from fish_speech.models.text2semantic.inference import GenerateRequest 5 | from fish_speech.tokenizer import IM_END_TOKEN 6 | 7 | 8 | def prepare_messages(request, tokenizer, config): 9 | """ 10 | Reorganise the provided list of messages into a conversation. 11 | Encode the conversation for inference. 
12 | """ 13 | # Convert the messages to ConversationMessage objects 14 | messages = [msg.to_conversation_message() for msg in request.messages] 15 | 16 | if len(messages) < 1: 17 | raise ValueError("At least one message is required") 18 | 19 | # Check the last message to determine the next step 20 | last_role = messages[-1].role 21 | match last_role: 22 | case "user": 23 | # The last message is from the user, ask the assistant to respond with a new message 24 | messages.append( 25 | Message(role="assistant", parts=[], add_im_end=False, modality="voice") 26 | ) 27 | case "raw": 28 | # The last message is raw text, ask the assistant to complete it 29 | messages[-1].add_im_start = False 30 | messages[-1].add_im_end = False 31 | messages[-1].modality = "voice" 32 | case "assistant": 33 | # The last message is from the assistant, ask the assistant to continue 34 | messages[-1].add_im_end = False 35 | case _: 36 | # We expect it to be assistant if not user or raw 37 | raise ValueError("The last message must be from the assistant, user or raw") 38 | 39 | # Create a conversation object and encode it for inference 40 | conv = Conversation(messages=messages) 41 | prompt = conv.encode_for_inference( 42 | tokenizer=tokenizer, num_codebooks=config.num_codebooks 43 | ) 44 | im_end_id = tokenizer.get_token_id(IM_END_TOKEN) 45 | 46 | return prompt, im_end_id 47 | 48 | 49 | def create_generation_request(prompt, request, im_end_id, device): 50 | """ 51 | Convert the request into a dictionary that can be sent to the model for generation. 52 | """ 53 | req = { 54 | "prompt": prompt.to(device), 55 | "max_new_tokens": request.max_new_tokens, 56 | "im_end_id": im_end_id, 57 | "temperature": request.temperature, 58 | "top_p": request.top_p, 59 | "repetition_penalty": request.repetition_penalty, 60 | "num_samples": request.num_samples, 61 | "early_stop_threshold": request.early_stop_threshold, 62 | } 63 | return req 64 | 65 | 66 | def send_generation_request(input_queue, req): 67 | """ 68 | Send the generation request to the model and return a queue to get the response. 
69 | """ 70 | response_queue = queue.Queue() 71 | input_queue.put(GenerateRequest(req, response_queue)) 72 | return response_queue 73 | -------------------------------------------------------------------------------- /tools/server/api_utils.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from http import HTTPStatus 3 | from typing import Annotated, Any 4 | 5 | import ormsgpack 6 | from baize.datastructures import ContentType 7 | from kui.asgi import HTTPException, HttpRequest 8 | 9 | from fish_speech.inference_engine import TTSInferenceEngine 10 | from fish_speech.utils.schema import ServeTTSRequest 11 | from tools.server.inference import inference_wrapper as inference 12 | 13 | 14 | def parse_args(): 15 | parser = ArgumentParser() 16 | parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts") 17 | parser.add_argument("--load-asr-model", action="store_true") 18 | parser.add_argument( 19 | "--llama-checkpoint-path", 20 | type=str, 21 | default="checkpoints/fish-speech-1.5", 22 | ) 23 | parser.add_argument( 24 | "--decoder-checkpoint-path", 25 | type=str, 26 | default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", 27 | ) 28 | parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") 29 | parser.add_argument("--device", type=str, default="cuda") 30 | parser.add_argument("--half", action="store_true") 31 | parser.add_argument("--compile", action="store_true") 32 | parser.add_argument("--max-text-length", type=int, default=0) 33 | parser.add_argument("--listen", type=str, default="127.0.0.1:8080") 34 | parser.add_argument("--workers", type=int, default=1) 35 | parser.add_argument("--api-key", type=str, default=None) 36 | 37 | return parser.parse_args() 38 | 39 | 40 | class MsgPackRequest(HttpRequest): 41 | async def data( 42 | self, 43 | ) -> Annotated[ 44 | Any, ContentType("application/msgpack"), ContentType("application/json") 45 | ]: 46 | if self.content_type == "application/msgpack": 47 | return ormsgpack.unpackb(await self.body) 48 | 49 | elif self.content_type == "application/json": 50 | return await self.json 51 | 52 | raise HTTPException( 53 | HTTPStatus.UNSUPPORTED_MEDIA_TYPE, 54 | headers={"Accept": "application/msgpack, application/json"}, 55 | ) 56 | 57 | 58 | async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine): 59 | for chunk in inference(req, engine): 60 | if isinstance(chunk, bytes): 61 | yield chunk 62 | 63 | 64 | async def buffer_to_async_generator(buffer): 65 | yield buffer 66 | 67 | 68 | def get_content_type(audio_format): 69 | if audio_format == "wav": 70 | return "audio/wav" 71 | elif audio_format == "flac": 72 | return "audio/flac" 73 | elif audio_format == "mp3": 74 | return "audio/mpeg" 75 | else: 76 | return "application/octet-stream" 77 | -------------------------------------------------------------------------------- /tools/server/exception_handler.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from http import HTTPStatus 3 | 4 | from kui.asgi import HTTPException, JSONResponse 5 | 6 | 7 | class ExceptionHandler: 8 | 9 | async def http_exception_handler(self, exc: HTTPException): 10 | return JSONResponse( 11 | dict( 12 | statusCode=exc.status_code, 13 | message=exc.content, 14 | error=HTTPStatus(exc.status_code).phrase, 15 | ), 16 | exc.status_code, 17 | exc.headers, 18 | ) 19 | 20 | async def other_exception_handler(self, exc: 
Exception): 21 | traceback.print_exc() 22 | 23 | status = HTTPStatus.INTERNAL_SERVER_ERROR 24 | return JSONResponse( 25 | dict(statusCode=status, message=str(exc), error=status.phrase), 26 | status, 27 | ) 28 | -------------------------------------------------------------------------------- /tools/server/inference.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | 3 | import numpy as np 4 | from kui.asgi import HTTPException 5 | 6 | from fish_speech.inference_engine import TTSInferenceEngine 7 | from fish_speech.utils.schema import ServeTTSRequest 8 | 9 | AMPLITUDE = 32768 # 2**15: scales float audio in [-1, 1] to 16-bit signed PCM 10 | 11 | 12 | def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine): 13 | """ 14 | Wrapper for the inference function. 15 | Used in the API server. 16 | """ 17 | count = 0 18 | for result in engine.inference(req): 19 | match result.code: 20 | case "header": 21 | if isinstance(result.audio, tuple): 22 | yield result.audio[1] 23 | 24 | case "error": 25 | raise HTTPException( 26 | HTTPStatus.INTERNAL_SERVER_ERROR, 27 | content=str(result.error), 28 | ) 29 | 30 | case "segment": 31 | count += 1 32 | if isinstance(result.audio, tuple): 33 | yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes() 34 | 35 | case "final": 36 | count += 1 37 | if isinstance(result.audio, tuple): 38 | yield result.audio[1] 39 | return None # Stop the generator 40 | 41 | if count == 0: 42 | raise HTTPException( 43 | HTTPStatus.INTERNAL_SERVER_ERROR, 44 | content="No audio generated, please check the input text.", 45 | ) 46 | -------------------------------------------------------------------------------- /tools/server/model_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from funasr import AutoModel 3 | from loguru import logger 4 | 5 | from fish_speech.inference_engine import TTSInferenceEngine 6 | from fish_speech.models.text2semantic.inference import ( 7 | launch_thread_safe_queue, 8 | launch_thread_safe_queue_agent, 9 | ) 10 | from fish_speech.models.vqgan.inference import load_model as load_decoder_model 11 | from fish_speech.utils.schema import ServeTTSRequest 12 | from tools.server.inference import inference_wrapper as inference 13 | 14 | ASR_MODEL_NAME = "iic/SenseVoiceSmall" 15 | 16 | 17 | class ModelManager: 18 | def __init__( 19 | self, 20 | mode: str, 21 | device: str, 22 | half: bool, 23 | compile: bool, 24 | asr_enabled: bool, 25 | llama_checkpoint_path: str, 26 | decoder_checkpoint_path: str, 27 | decoder_config_name: str, 28 | ) -> None: 29 | 30 | self.mode = mode 31 | self.device = device 32 | self.half = half 33 | self.compile = compile 34 | 35 | self.precision = torch.half if half else torch.bfloat16 36 | 37 | # Check if MPS or CUDA is available 38 | if torch.backends.mps.is_available(): 39 | self.device = "mps" 40 | logger.info("mps is available, running on mps.") 41 | elif not torch.cuda.is_available(): 42 | self.device = "cpu" 43 | logger.info("CUDA is not available, running on CPU.") 44 | 45 | # Load the ASR model if enabled 46 | if asr_enabled: 47 | self.load_asr_model(self.device) 48 | 49 | # Load the TTS models 50 | self.load_llama_model( 51 | llama_checkpoint_path, self.device, self.precision, self.compile, self.mode 52 | ) 53 | self.load_decoder_model( 54 | decoder_config_name, decoder_checkpoint_path, self.device 55 | ) 56 | self.tts_inference_engine = TTSInferenceEngine( 57 | llama_queue=self.llama_queue, 58 |
decoder_model=self.decoder_model, 59 | precision=self.precision, 60 | compile=self.compile, 61 | ) 62 | 63 | # Warm up the models 64 | if self.mode == "tts": 65 | self.warm_up(self.tts_inference_engine) 66 | 67 | def load_asr_model(self, device, hub="ms") -> None: 68 | self.asr_model = AutoModel( 69 | model=ASR_MODEL_NAME, 70 | device=device, 71 | disable_pbar=True, 72 | hub=hub, 73 | ) 74 | logger.info("ASR model loaded.") 75 | 76 | def load_llama_model( 77 | self, checkpoint_path, device, precision, compile, mode 78 | ) -> None: 79 | 80 | if mode == "tts": 81 | self.llama_queue = launch_thread_safe_queue( 82 | checkpoint_path=checkpoint_path, 83 | device=device, 84 | precision=precision, 85 | compile=compile, 86 | ) 87 | elif mode == "agent": 88 | self.llama_queue, self.tokenizer, self.config = ( 89 | launch_thread_safe_queue_agent( 90 | checkpoint_path=checkpoint_path, 91 | device=device, 92 | precision=precision, 93 | compile=compile, 94 | ) 95 | ) 96 | else: 97 | raise ValueError(f"Invalid mode: {mode}") 98 | 99 | logger.info("LLAMA model loaded.") 100 | 101 | def load_decoder_model(self, config_name, checkpoint_path, device) -> None: 102 | self.decoder_model = load_decoder_model( 103 | config_name=config_name, 104 | checkpoint_path=checkpoint_path, 105 | device=device, 106 | ) 107 | logger.info("Decoder model loaded.") 108 | 109 | def warm_up(self, tts_inference_engine) -> None: 110 | request = ServeTTSRequest( 111 | text="Hello world.", 112 | references=[], 113 | reference_id=None, 114 | max_new_tokens=1024, 115 | chunk_length=200, 116 | top_p=0.7, 117 | repetition_penalty=1.2, 118 | temperature=0.7, 119 | format="wav", 120 | ) 121 | list(inference(request, tts_inference_engine)) 122 | logger.info("Models warmed up.") 123 | -------------------------------------------------------------------------------- /tools/server/model_utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import re 3 | 4 | import librosa 5 | import torch 6 | import torchaudio 7 | from cachetools import LRUCache, cached 8 | 9 | CACHE_MAXSIZE = 10000 10 | MICRO_BATCH_SIZE = 8 11 | ASR_SAMPLE_RATE = 16000 12 | HUGE_GAP_THRESHOLD = 4000 13 | 14 | 15 | @torch.no_grad() 16 | @torch.autocast(device_type="cuda", dtype=torch.half) 17 | def batch_encode(model, audios_list: list[bytes]): 18 | audios: list[torch.Tensor] = [ 19 | ( 20 | torch.from_numpy( 21 | librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0] 22 | )[None] 23 | if isinstance(audio, bytes) 24 | else audio 25 | ) 26 | for audio in audios_list 27 | ] 28 | 29 | lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device) 30 | max_length = lengths.max().item() 31 | 32 | print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s") 33 | 34 | padded = torch.stack( 35 | [ 36 | torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1]))) 37 | for audio in audios 38 | ] 39 | ).to(model.device) 40 | 41 | features, feature_lengths = model.encode(padded, audio_lengths=lengths) 42 | features, feature_lengths = features.cpu(), feature_lengths.cpu() 43 | 44 | return [feature[..., :length] for feature, length in zip(features, feature_lengths)] 45 | 46 | 47 | @cached( 48 | cache=LRUCache(maxsize=CACHE_MAXSIZE), 49 | key=lambda model, audios: (model.device, tuple(audios)), 50 | ) 51 | def cached_vqgan_batch_encode(model, audios: list[bytes]): 52 | return batch_encode(model, audios) 53 | 54 | 55 | @torch.no_grad() 56 | 
@torch.autocast(device_type="cuda", dtype=torch.half) 57 | def batch_vqgan_decode(model, features): 58 | lengths = torch.tensor( 59 | [feature.shape[-1] for feature in features], device=model.device 60 | ) 61 | max_length = lengths.max().item() 62 | padded = torch.stack( 63 | [ 64 | torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1])) 65 | for feature in features 66 | ] 67 | ).to(model.device) 68 | 69 | # If bs too large, we do micro batch decode 70 | audios, audio_lengths = [], [] 71 | for i in range(0, padded.shape[0], MICRO_BATCH_SIZE): 72 | audio, audio_length = model.decode( 73 | padded[i : i + MICRO_BATCH_SIZE], 74 | feature_lengths=lengths[i : i + MICRO_BATCH_SIZE], 75 | ) 76 | audios.append(audio) 77 | audio_lengths.append(audio_length) 78 | audios = torch.cat(audios, dim=0) 79 | audio_lengths = torch.cat(audio_lengths, dim=0) 80 | audios, audio_lengths = audios.cpu(), audio_lengths.cpu() 81 | 82 | return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)] 83 | 84 | 85 | @torch.no_grad() 86 | def batch_asr(model, lock, audios, sr, language="auto"): 87 | resampled_audios = [] 88 | for audio in audios: 89 | audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE) 90 | assert audio.ndim == 1 91 | resampled_audios.append(audio) 92 | 93 | with lock: 94 | res = model.generate( 95 | input=resampled_audios, 96 | batch_size=len(resampled_audios), 97 | language=language, 98 | use_itn=True, 99 | ) 100 | 101 | results = [] 102 | for r, audio in zip(res, audios): 103 | text = r["text"] 104 | text = re.sub(r"<\|.*?\|>", "", text) 105 | duration = len(audio) / sr * 1000 106 | huge_gap = False 107 | 108 | if "timestamp" in r and len(r["timestamp"]) > 2: 109 | for timestamp_a, timestamp_b in zip( 110 | r["timestamp"][:-1], r["timestamp"][1:] 111 | ): 112 | # If there is a gap of more than 4 seconds, we consider it as a huge gap 113 | if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD: 114 | huge_gap = True 115 | break 116 | 117 | # Doesn't make sense to have a huge gap at the end 118 | if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD: 119 | huge_gap = True 120 | 121 | results.append( 122 | { 123 | "text": text, 124 | "duration": duration, 125 | "huge_gap": huge_gap, 126 | } 127 | ) 128 | 129 | return results 130 | -------------------------------------------------------------------------------- /tools/smart_pad.py: -------------------------------------------------------------------------------- 1 | import random 2 | from multiprocessing import Pool 3 | from pathlib import Path 4 | 5 | import click 6 | import librosa 7 | import torch.nn.functional as F 8 | import torchaudio 9 | from tqdm import tqdm 10 | 11 | from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files 12 | 13 | threshold = 10 ** (-50 / 20.0) 14 | 15 | 16 | def process(file): 17 | waveform, sample_rate = torchaudio.load(str(file), backend="sox") 18 | if waveform.size(0) > 1: 19 | waveform = waveform.mean(dim=0, keepdim=True) 20 | 21 | loudness = librosa.feature.rms( 22 | y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True 23 | )[0] 24 | 25 | for i in range(len(loudness) - 1, 0, -1): 26 | if loudness[i] > threshold: 27 | break 28 | 29 | end_silent_time = (len(loudness) - i) * 512 / sample_rate 30 | 31 | if end_silent_time <= 0.3: 32 | random_time = random.uniform(0.3, 0.7) - end_silent_time 33 | waveform = F.pad( 34 | waveform, (0, int(random_time * sample_rate)), mode="constant", value=0 35 | ) 36 | 37 | for i in range(len(loudness)): 38 | 
if loudness[i] > threshold: 39 | break 40 | 41 | start_silent_time = i * 512 / sample_rate 42 | 43 | if start_silent_time > 0.02: 44 | waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :] 45 | 46 | torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate) 47 | 48 | 49 | @click.command() 50 | @click.argument("source", type=Path) 51 | @click.option("--num-workers", type=int, default=12) 52 | def main(source, num_workers): 53 | files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True)) 54 | 55 | with Pool(num_workers) as p: 56 | list(tqdm(p.imap_unordered(process, files), total=len(files))) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /tools/vqgan/create_train_split.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | from random import Random 4 | 5 | import click 6 | from loguru import logger 7 | from pydub import AudioSegment 8 | from tqdm import tqdm 9 | 10 | from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist 11 | 12 | 13 | @click.command() 14 | @click.argument("root", type=click.Path(exists=True, path_type=Path)) 15 | @click.option("--val-ratio", type=float, default=None) 16 | @click.option("--val-count", type=int, default=None) 17 | @click.option("--filelist", default=None, type=Path) 18 | @click.option("--min-duration", default=None, type=float) 19 | @click.option("--max-duration", default=None, type=float) 20 | def main(root, val_ratio, val_count, filelist, min_duration, max_duration): 21 | if filelist: 22 | files = [i[0] for i in load_filelist(filelist)] 23 | else: 24 | files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) 25 | 26 | if min_duration is None and max_duration is None: 27 | filtered_files = list(map(str, [file.relative_to(root) for file in files])) 28 | else: 29 | filtered_files = [] 30 | for file in tqdm(files): 31 | try: 32 | audio = AudioSegment.from_file(str(file)) 33 | duration = len(audio) / 1000.0 34 | 35 | if min_duration is not None and duration < min_duration: 36 | logger.info( 37 | f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}" 38 | ) 39 | continue 40 | 41 | if max_duration is not None and duration > max_duration: 42 | logger.info( 43 | f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}" 44 | ) 45 | continue 46 | 47 | filtered_files.append(str(file.relative_to(root))) 48 | except Exception as e: 49 | logger.info(f"Error processing {file}: {e}") 50 | 51 | logger.info( 52 | f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering" 53 | ) 54 | 55 | Random(42).shuffle(filtered_files) 56 | 57 | if val_count is None and val_ratio is None: 58 | logger.info("Validation ratio and count not specified, using min(20%, 100)") 59 | val_size = min(100, math.ceil(len(filtered_files) * 0.2)) 60 | elif val_count is not None and val_ratio is not None: 61 | logger.error("Cannot specify both val_count and val_ratio") 62 | return 63 | elif val_count is not None: 64 | if val_count < 1 or val_count > len(filtered_files): 65 | logger.error("val_count must be between 1 and number of files") 66 | return 67 | val_size = val_count 68 | else: 69 | val_size = math.ceil(len(filtered_files) * val_ratio) 70 | 71 | logger.info(f"Using {val_size} files for validation") 72 | 73 | with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f: 74 | 
f.write("\n".join(filtered_files[val_size:])) 75 | 76 | with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f: 77 | f.write("\n".join(filtered_files[:val_size])) 78 | 79 | logger.info("Done") 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /tools/webui/inference.py: -------------------------------------------------------------------------------- 1 | import html 2 | from functools import partial 3 | from typing import Any, Callable 4 | 5 | from fish_speech.i18n import i18n 6 | from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest 7 | 8 | 9 | def inference_wrapper( 10 | text, 11 | reference_id, 12 | reference_audio, 13 | reference_text, 14 | max_new_tokens, 15 | chunk_length, 16 | top_p, 17 | repetition_penalty, 18 | temperature, 19 | seed, 20 | use_memory_cache, 21 | engine, 22 | ): 23 | """ 24 | Wrapper for the inference function. 25 | Used in the Gradio interface. 26 | """ 27 | 28 | if reference_audio: 29 | references = get_reference_audio(reference_audio, reference_text) 30 | else: 31 | references = [] 32 | 33 | req = ServeTTSRequest( 34 | text=text, 35 | reference_id=reference_id if reference_id else None, 36 | references=references, 37 | max_new_tokens=max_new_tokens, 38 | chunk_length=chunk_length, 39 | top_p=top_p, 40 | repetition_penalty=repetition_penalty, 41 | temperature=temperature, 42 | seed=int(seed) if seed else None, 43 | use_memory_cache=use_memory_cache, 44 | ) 45 | 46 | for result in engine.inference(req): 47 | match result.code: 48 | case "final": 49 | return result.audio, None 50 | case "error": 51 | return None, build_html_error_message(i18n(result.error)) 52 | case _: 53 | pass 54 | 55 | return None, i18n("No audio generated") 56 | 57 | 58 | def get_reference_audio(reference_audio: str, reference_text: str) -> list: 59 | """ 60 | Get the reference audio bytes. 61 | """ 62 | 63 | with open(reference_audio, "rb") as audio_file: 64 | audio_bytes = audio_file.read() 65 | 66 | return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)] 67 | 68 | 69 | def build_html_error_message(error: Any) -> str: 70 | 71 | error = error if isinstance(error, Exception) else Exception("Unknown error") 72 | 73 | return f""" 74 |
<div style="color: red; 75 | font-weight: bold;"> 76 | {html.escape(str(error))} 77 | </div>
78 | """ 79 | 80 | 81 | def get_inference_wrapper(engine) -> Callable: 82 | """ 83 | Get the inference function with the immutable arguments. 84 | """ 85 | 86 | return partial( 87 | inference_wrapper, 88 | engine=engine, 89 | ) 90 | -------------------------------------------------------------------------------- /tools/webui/variables.py: -------------------------------------------------------------------------------- 1 | from fish_speech.i18n import i18n 2 | 3 | HEADER_MD = f"""# Fish Speech 4 | 5 | {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")} 6 | 7 | {i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")} 8 | 9 | {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")} 10 | 11 | {i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")} 12 | """ 13 | 14 | TEXTBOX_PLACEHOLDER = i18n("Put your text here.") 15 | --------------------------------------------------------------------------------