├── .github └── workflows │ └── build.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── app.py ├── assets ├── bria.mp3 ├── favicon.ico └── logo.png ├── colab_demo.ipynb ├── data ├── audio.wav └── caption.txt ├── datasets ├── sample_dataset.csv └── sample_val_dataset.csv ├── docker-compose.yml ├── fam ├── __init__.py ├── llm │ ├── __init__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── base.py │ │ ├── flattened_encodec.py │ │ └── tilted_encodec.py │ ├── config │ │ ├── __init__.py │ │ └── finetune_params.py │ ├── decoders.py │ ├── enhancers.py │ ├── fast_inference.py │ ├── fast_inference_utils.py │ ├── fast_model.py │ ├── fast_quantize.py │ ├── finetune.py │ ├── inference.py │ ├── layers │ │ ├── __init__.py │ │ ├── attn.py │ │ ├── combined.py │ │ └── layers.py │ ├── loaders │ │ ├── __init__.py │ │ └── training_data.py │ ├── mixins │ │ ├── __init__.py │ │ ├── causal.py │ │ └── non_causal.py │ ├── model.py │ ├── preprocessing │ │ ├── __init__.py │ │ ├── audio_token_mode.py │ │ └── data_pipeline.py │ └── utils.py ├── py.typed ├── quantiser │ ├── __init__.py │ ├── audio │ │ ├── __init__.py │ │ └── speaker_encoder │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ └── model.py │ └── text │ │ ├── __init__.py │ │ └── tokenise.py └── telemetry │ ├── README.md │ ├── __init__.py │ └── posthog.py ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── serving.py ├── setup.py └── tests ├── llm └── loaders │ └── test_dataloader.py └── resources ├── data └── caption.txt └── datasets └── sample_dataset.csv /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Poetry Install Matrix 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11"] 15 | 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v3 19 | 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install Poetry 26 | uses: Gr1N/setup-poetry@v8 27 | 28 | - name: Install dependencies 29 | run: | 30 | poetry --version 31 | poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.pkl 3 | *.flac 4 | *.npz 5 | *.wav 6 | *.m4a 7 | *.opus 8 | *.npy 9 | *wandb 10 | *.parquet 11 | *.wav 12 | *.pt 13 | *.bin 14 | *.png 15 | *.DS_Store 16 | *.idea 17 | *.ipynb_checkpoints/ 18 | *__pycache__/ 19 | *.pyc 20 | *.tsv 21 | *.bak 22 | *.tar 23 | *.db 24 | *.dat 25 | *.json 26 | 27 | # Byte-compiled / optimized / DLL files 28 | __pycache__/ 29 | *.py[cod] 30 | *$py.class 31 | 32 | # C extensions 33 | *.so 34 | 35 | # Distribution / packaging 36 | .Python 37 | build/ 38 | develop-eggs/ 39 | dist/ 40 | downloads/ 41 | eggs/ 42 | .eggs/ 43 | lib/ 44 | lib64/ 45 | parts/ 46 | sdist/ 47 | var/ 48 | wheels/ 49 | share/python-wheels/ 50 | *.egg-info/ 51 | .installed.cfg 52 | *.egg 53 | MANIFEST 54 | 55 | # PyInstaller 56 | # Usually these files are written by a python script from a template 57 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
58 | *.manifest 59 | *.spec 60 | 61 | # Installer logs 62 | pip-log.txt 63 | pip-delete-this-directory.txt 64 | 65 | # Unit test / coverage reports 66 | htmlcov/ 67 | .tox/ 68 | .nox/ 69 | .coverage 70 | .coverage.* 71 | .cache 72 | nosetests.xml 73 | coverage.xml 74 | *.cover 75 | *.py,cover 76 | .hypothesis/ 77 | .pytest_cache/ 78 | cover/ 79 | 80 | # Translations 81 | *.mo 82 | *.pot 83 | 84 | # Django stuff: 85 | *.log 86 | local_settings.py 87 | db.sqlite3 88 | db.sqlite3-journal 89 | 90 | # Flask stuff: 91 | instance/ 92 | .webassets-cache 93 | 94 | # Scrapy stuff: 95 | .scrapy 96 | 97 | # Sphinx documentation 98 | docs/_build/ 99 | 100 | # PyBuilder 101 | .pybuilder/ 102 | target/ 103 | 104 | # Jupyter Notebook 105 | .ipynb_checkpoints 106 | 107 | # IPython 108 | profile_default/ 109 | ipython_config.py 110 | 111 | # pyenv 112 | # For a library or package, you might want to ignore these files since the code is 113 | # intended to run in multiple environments; otherwise, check them in: 114 | # .python-version 115 | 116 | # pipenv 117 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 118 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 119 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 120 | # install all needed dependencies. 121 | #Pipfile.lock 122 | 123 | # poetry 124 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 125 | # This is especially recommended for binary packages to ensure reproducibility, and is more 126 | # commonly ignored for libraries. 127 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 128 | #poetry.lock 129 | 130 | # pdm 131 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 132 | #pdm.lock 133 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 134 | # in version control. 135 | # https://pdm.fming.dev/#use-with-ide 136 | .pdm.toml 137 | 138 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 139 | __pypackages__/ 140 | 141 | # Celery stuff 142 | celerybeat-schedule 143 | celerybeat.pid 144 | 145 | # SageMath parsed files 146 | *.sage.py 147 | 148 | # Environments 149 | .env 150 | .venv 151 | env/ 152 | venv/ 153 | ENV/ 154 | env.bak/ 155 | venv.bak/ 156 | 157 | # Spyder project settings 158 | .spyderproject 159 | .spyproject 160 | 161 | # Rope project settings 162 | .ropeproject 163 | 164 | # mkdocs documentation 165 | /site 166 | 167 | # mypy 168 | .mypy_cache/ 169 | .dmypy.json 170 | dmypy.json 171 | 172 | # Pyre type checker 173 | .pyre/ 174 | 175 | # pytype static type analyzer 176 | .pytype/ 177 | 178 | # Cython debug symbols 179 | cython_debug/ 180 | 181 | # PyCharm 182 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 183 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 184 | # and can be added to the global gitignore or merged into this file. For a more nuclear 185 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
186 | #.idea/ 187 | **/.tmp 188 | !fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt 189 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/python-poetry/poetry 3 | rev: "1.8" 4 | hooks: 5 | - id: poetry-lock 6 | - id: poetry-export 7 | args: ["--dev", "-f", "requirements.txt", "-o", "requirements.txt"] 8 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as base 2 | 3 | ENV POETRY_NO_INTERACTION=1 \ 4 | POETRY_VIRTUALENVS_IN_PROJECT=1 \ 5 | POETRY_VIRTUALENVS_CREATE=1 \ 6 | POETRY_CACHE_DIR=/tmp/poetry_cache \ 7 | DEBIAN_FRONTEND=noninteractive 8 | 9 | # Install system dependencies in a single RUN command to reduce layers 10 | # Combine apt-get update, upgrade, and installation of packages. Clean up in the same layer to reduce image size. 11 | RUN apt-get update && \ 12 | apt-get upgrade -y && \ 13 | apt-get install -y python3.10 python3-pip git wget curl build-essential pipx && \ 14 | apt-get autoremove -y && \ 15 | apt-get clean && \ 16 | rm -rf /var/lib/apt/lists/* 17 | 18 | # install via pip given ubuntu 22.04 as per docs https://pipx.pypa.io/stable/installation/ 19 | RUN python3 -m pip install --user pipx && \ 20 | python3 -m pipx ensurepath && \ 21 | python3 -m pipx install poetry==1.8.2 22 | 23 | # make pipx installs (i.e poetry) available 24 | ENV PATH="/root/.local/bin:${PATH}" 25 | 26 | # install ffmpeg 27 | RUN wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz &&\ 28 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5 &&\ 29 | md5sum -c ffmpeg-git-amd64-static.tar.xz.md5 &&\ 30 | tar xvf ffmpeg-git-amd64-static.tar.xz &&\ 31 | mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/ &&\ 32 | rm -rf ffmpeg-git-* 33 | 34 | WORKDIR /app 35 | 36 | COPY pyproject.toml poetry.lock ./ 37 | RUN touch README.md # poetry will complain otherwise 38 | 39 | RUN poetry install --without dev --no-root 40 | RUN poetry run python -m pip install torch==2.2.1 torchaudio==2.2.1 && \ 41 | rm -rf $POETRY_CACHE_DIR 42 | 43 | COPY fam ./fam 44 | COPY serving.py ./ 45 | COPY app.py ./ 46 | 47 | RUN poetry install --only-root 48 | 49 | ENTRYPOINT ["poetry", "run", "python", "serving.py"] 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MetaVoice-1B 2 | 3 | 4 | 5 | [![Playground](https://img.shields.io/static/v1?label=Try&message=Playground&color=fc4982&url=https://ttsdemo.themetavoice.xyz/)](https://ttsdemo.themetavoice.xyz/) 6 | 7 | [Open In Colab](https://colab.research.google.com/github/metavoiceio/metavoice-src/blob/main/colab_demo.ipynb) 8 | 9 | [![](https://dcbadge.vercel.app/api/server/Cpy6U3na8Z?style=flat&compact=True)](https://discord.gg/tbTbkGEgJM) 10 | [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/OnusFM.svg?style=social&label=@metavoiceio)](https://twitter.com/metavoiceio) 11 | 12 | 13 | 14 | MetaVoice-1B is a 1.2B parameter base model trained on 100K hours of speech for TTS (text-to-speech). It has been built with the following priorities: 15 | * **Emotional speech rhythm and tone** in English. 16 | * **Zero-shot cloning for American & British voices**, with 30s reference audio. 17 | * Support for (cross-lingual) **voice cloning with finetuning**. 18 | * We have had success with as little as 1 minute of training data for Indian speakers. 19 | * Synthesis of **arbitrary length text**. 20 | 21 | We’re releasing MetaVoice-1B under the Apache 2.0 license; *it can be used without restrictions*. 22 | 23 | 24 | ## Quickstart - tl;dr 25 | 26 | Web UI 27 | ```bash 28 | docker-compose up -d ui && docker-compose ps && docker-compose logs -f 29 | ``` 30 | 31 | Server 32 | ```bash 33 | # navigate to /docs for API definitions 34 | docker-compose up -d server && docker-compose ps && docker-compose logs -f 35 | ``` 36 | 37 | ## Installation 38 | 39 | **Pre-requisites:** 40 | - GPU VRAM >=12GB 41 | - Python >=3.10,<3.12 42 | - pipx ([installation instructions](https://pipx.pypa.io/stable/installation/)) 43 | 44 | **Environment setup** 45 | ```bash 46 | # install ffmpeg 47 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz 48 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5 49 | md5sum -c ffmpeg-git-amd64-static.tar.xz.md5 50 | tar xvf ffmpeg-git-amd64-static.tar.xz 51 | sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/ 52 | rm -rf ffmpeg-git-* 53 | 54 | # install rust if not installed (ensure you've restarted your terminal after installation) 55 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh 56 | ``` 57 | 58 | ### Project dependencies installation 59 | 1. [Using poetry](#using-poetry-recommended) 60 | 2. [Using pip/conda](#using-pipconda) 61 | 62 | #### Using poetry (recommended) 63 | ```bash 64 | # install poetry if not installed (ensure you've restarted your terminal after installation) 65 | pipx install poetry 66 | 67 | # disable any conda envs that might interfere with poetry's venv 68 | conda deactivate 69 | 70 | # if running from Linux, keyring backend can hang on `poetry install`. This prevents that.
71 | export PYTHON_KEYRING_BACKEND=keyring.backends.fail.Keyring 72 | 73 | # pip's dependency resolver will complain, this is temporary expected behaviour 74 | # full inference & finetuning functionality will still be available 75 | poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1 76 | ``` 77 | 78 | #### Using pip/conda 79 | NOTE 1: When raising issues, we'll ask you to try with poetry first. 80 | NOTE 2: All commands in this README use `poetry` by default, so you can just remove any `poetry run`. 81 | 82 | ```bash 83 | pip install -r requirements.txt 84 | pip install torch==2.2.1 torchaudio==2.2.1 85 | pip install -e . 86 | ``` 87 | 88 | ## Usage 89 | 1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py) 90 | ```bash 91 | # You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio. 92 | # Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16. 93 | poetry run python -i fam/llm/fast_inference.py 94 | 95 | # Run e.g. of API usage within the interactive python session 96 | tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.", spk_ref_path="assets/bria.mp3") 97 | ``` 98 | > Note: The script takes 30-90s to startup (depending on hardware). This is because we torch.compile the model for fast inference. 99 | 100 | > On Ampere, Ada-Lovelace, and Hopper architecture GPUs, once compiled, the synthesise() API runs faster than real-time, with a Real-Time Factor (RTF) < 1.0. 101 | 102 | 2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py) 103 | ```bash 104 | # You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio. 105 | # Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16. 106 | 107 | # navigate to /docs for API definitions 108 | poetry run python serving.py 109 | 110 | poetry run python app.py 111 | ``` 112 | 113 | 3. Use it via [Hugging Face](https://huggingface.co/metavoiceio) 114 | 4. [Google Colab Demo](https://colab.research.google.com/github/metavoiceio/metavoice-src/blob/main/colab_demo.ipynb) 115 | 116 | ## Finetuning 117 | We support finetuning the first stage LLM (see [Architecture section](#Architecture)). 118 | 119 | In order to finetune, we expect a "|"-delimited CSV dataset of the following format: 120 | 121 | ```csv 122 | audio_files|captions 123 | ./data/audio.wav|./data/caption.txt 124 | ``` 125 | 126 | Note that we don't perform any dataset overlap checks, so ensure that your train and val datasets are disjoint. 127 | 128 | Try it out using our sample datasets via: 129 | ```bash 130 | poetry run finetune --train ./datasets/sample_dataset.csv --val ./datasets/sample_val_dataset.csv 131 | ``` 132 | 133 | Once you've trained your model, you can use it for inference via: 134 | ```bash 135 | poetry run python -i fam/llm/fast_inference.py --first_stage_path ./my-finetuned_model.pt 136 | ``` 137 | 138 | ### Configuration 139 | 140 | In order to set hyperparameters such as learning rate, what to freeze, etc, you 141 | can edit the [finetune_params.py](./fam/llm/config/finetune_params.py) file. 
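For illustration only, the snippet below shows the kind of module-level settings such a config file typically exposes. The names and defaults here are hypothetical; check [finetune_params.py](./fam/llm/config/finetune_params.py) itself for the real ones.

```python
# Hypothetical example: the actual parameter names and defaults live in
# fam/llm/config/finetune_params.py and may differ.
learning_rate = 3e-5      # optimiser step size
freeze_embeddings = True  # e.g. keep the token embedding layer frozen while finetuning
batch_size = 2            # lower this if you run out of GPU VRAM
```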
142 | 143 | We've got a light & optional integration with W&B that can be enabled via setting 144 | `wandb_log = True` & by installing the appropriate dependencies. 145 | 146 | ```bash 147 | poetry install -E observable 148 | ``` 149 | 150 | ## Upcoming 151 | - [x] Faster inference ⚡ 152 | - [x] Fine-tuning code 📐 153 | - [ ] Synthesis of arbitrary length text 154 | 155 | 156 | ## Architecture 157 | We predict EnCodec tokens from text and speaker information. This is then diffused up to the waveform level, with post-processing applied to clean up the audio. 158 | 159 | * We use a causal GPT to predict the first two hierarchies of EnCodec tokens. Text and audio are part of the LLM context. Speaker information is passed via conditioning at the token embedding layer. This speaker conditioning is obtained from a separately trained speaker verification network. 160 | - The two hierarchies are predicted in a "flattened interleaved" manner: we predict the first token of the first hierarchy, then the first token of the second hierarchy, then the second token of the first hierarchy, and so on (a minimal illustrative sketch is included in the appendix at the end of this README). 161 | - We use condition-free sampling to boost the cloning capability of the model. 162 | - The text is tokenised using a custom trained BPE tokeniser with 512 tokens. 163 | - Note that we've skipped predicting semantic tokens as done in other works, as we found that this isn't strictly necessary. 164 | * We use a non-causal (encoder-style) transformer to predict the rest of the 6 hierarchies from the first two hierarchies. This is a super small model (~10Mn parameters), and has extensive zero-shot generalisation to most speakers we've tried. Since it's non-causal, we're also able to predict all the timesteps in parallel. 165 | * We use multi-band diffusion to generate waveforms from the EnCodec tokens. We noticed that the speech is clearer than using the original RVQ decoder or VOCOS. However, the diffusion at waveform level leaves some background artifacts which are quite unpleasant to the ear. We clean this up in the next step. 166 | * We use DeepFilterNet to clear up the artifacts introduced by the multi-band diffusion. 167 | 168 | ## Optimizations 169 | The model supports: 170 | 1. KV-caching via Flash Decoding 171 | 2. Batching (including texts of different lengths) 172 | 173 | ## Contribute 174 | - See all [active issues](https://github.com/metavoiceio/metavoice-src/issues)! 175 | 176 | ## Acknowledgements 177 | We are grateful to Together.ai for their 24/7 help in marshalling our cluster. We thank the teams of AWS, GCP & Hugging Face for support with their cloud platforms. 178 | 179 | - [A Défossez et al.](https://arxiv.org/abs/2210.13438) for EnCodec. 180 | - [RS Roman et al.](https://arxiv.org/abs/2308.02560) for Multiband Diffusion. 181 | - [@liusongxiang](https://github.com/liusongxiang/ppg-vc/blob/main/speaker_encoder/inference.py) for the speaker encoder implementation. 182 | - [@karpathy](https://github.com/karpathy/nanoGPT) for NanoGPT, which our inference implementation is based on. 183 | - [@Rikorose](https://github.com/Rikorose) for DeepFilterNet. 184 | 185 | Apologies in advance if we've missed anyone out. Please let us know if we have.
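## Appendix: "flattened interleaved" ordering (illustrative)

A minimal, self-contained sketch of the token ordering described in the Architecture section above. This is not the project's actual implementation (see [flattened_encodec.py](./fam/llm/adapters/flattened_encodec.py) for the real decoding logic); it only shows how two equal-length EnCodec hierarchies are woven into a single sequence and recovered again.

```python
def flatten_interleave(h1: list[int], h2: list[int]) -> list[int]:
    """Interleave two hierarchies as h1[0], h2[0], h1[1], h2[1], ..."""
    assert len(h1) == len(h2), "hierarchies must be the same length"
    flat: list[int] = []
    for t1, t2 in zip(h1, h2):
        flat.extend([t1, t2])
    return flat


def unflatten(flat: list[int]) -> tuple[list[int], list[int]]:
    """Invert the interleaving: even positions -> hierarchy 1, odd positions -> hierarchy 2."""
    return flat[0::2], flat[1::2]


# Example: two 3-step hierarchies become one 6-token sequence, and back.
assert flatten_interleave([1, 2, 3], [10, 20, 30]) == [1, 10, 2, 20, 3, 30]
assert unflatten([1, 10, 2, 20, 3, 30]) == ([1, 2, 3], [10, 20, 30])
```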
186 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | project_root = os.path.dirname(os.path.abspath(__file__)) 5 | if project_root not in sys.path: 6 | sys.path.insert(0, project_root) 7 | 8 | 9 | import gradio as gr 10 | import tyro 11 | 12 | from fam.llm.fast_inference import TTS 13 | from fam.llm.utils import check_audio_file 14 | 15 | #### setup model 16 | TTS_MODEL = tyro.cli(TTS, args=["--telemetry_origin", "webapp"]) 17 | 18 | #### setup interface 19 | RADIO_CHOICES = ["Preset voices", "Upload target voice (at least 30s)"] 20 | MAX_CHARS = 220 21 | PRESET_VOICES = { 22 | # female 23 | "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3", 24 | # male 25 | "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3", 26 | "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav", 27 | } 28 | 29 | 30 | def denormalise_top_p(top_p): 31 | # returns top_p in the range [0.9, 1.0] 32 | return round(0.9 + top_p / 100, 2) 33 | 34 | 35 | def denormalise_guidance(guidance): 36 | # returns guidance in the range [1.0, 3.0] 37 | return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1) 38 | 39 | 40 | def _check_file_size(path): 41 | if not path: 42 | return 43 | filesize = os.path.getsize(path) 44 | filesize_mb = filesize / 1024 / 1024 45 | if filesize_mb >= 50: 46 | raise gr.Error(f"Please upload a sample less than 50MB for voice cloning. Provided: {round(filesize_mb)} MB") 47 | 48 | 49 | def _handle_edge_cases(to_say, upload_target): 50 | if not to_say: 51 | raise gr.Error("Please provide text to synthesise") 52 | 53 | if len(to_say) > MAX_CHARS: 54 | gr.Warning( 55 | f"Max {MAX_CHARS} characters allowed. Provided: {len(to_say)} characters. Truncating and generating speech... The end of the result may be unstable." 56 | ) 57 | 58 | if not upload_target: 59 | return 60 | 61 | check_audio_file(upload_target) # check file duration to be at least 30s 62 | _check_file_size(upload_target) 63 | 64 | 65 | def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target): 66 | try: 67 | d_top_p = denormalise_top_p(top_p) 68 | d_guidance = denormalise_guidance(guidance) 69 | 70 | _handle_edge_cases(to_say, upload_target) 71 | 72 | to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS] 73 | 74 | return TTS_MODEL.synthesise( 75 | text=to_say, 76 | spk_ref_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target, 77 | top_p=d_top_p, 78 | guidance_scale=d_guidance, 79 | ) 80 | except Exception as e: 81 | raise gr.Error(f"Something went wrong. Reason: {str(e)}") 82 | 83 | 84 | def change_voice_selection_layout(choice): 85 | if choice == RADIO_CHOICES[0]: 86 | return [gr.update(visible=True), gr.update(visible=False)] 87 | 88 | return [gr.update(visible=False), gr.update(visible=True)] 89 | 90 | 91 | title = """ 92 | 93 | 94 | MetaVoice logo 95 | 96 | 97 | \n# TTS by MetaVoice-1B 98 | """ 99 | 100 | description = """ 101 | MetaVoice-1B is a 1.2B parameter base model for TTS (text-to-speech). It has been built with the following priorities: 102 | \n 103 | * Emotional speech rhythm and tone in English. 104 | * Zero-shot cloning for American & British voices, with 30s reference audio. 105 | * Support for voice cloning with finetuning. 106 | * We have had success with as little as 1 minute of training data for Indian speakers. 107 | * Support for long-form synthesis.
108 | 109 | We are releasing the model under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0). See [Github](https://github.com/metavoiceio/metavoice-src) for details and to contribute. 110 | """ 111 | 112 | with gr.Blocks(title="TTS by MetaVoice") as demo: 113 | gr.Markdown(title) 114 | 115 | with gr.Row(): 116 | gr.Markdown(description) 117 | 118 | with gr.Row(): 119 | with gr.Column(): 120 | to_say = gr.TextArea( 121 | label=f"What should I say!? (max {MAX_CHARS} characters).", 122 | lines=4, 123 | value="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model by MetaVoice.", 124 | ) 125 | with gr.Row(), gr.Column(): 126 | # voice settings 127 | top_p = gr.Slider( 128 | value=5.0, 129 | minimum=0.0, 130 | maximum=10.0, 131 | step=1.0, 132 | label="Speech Stability - improves text following for a challenging speaker", 133 | ) 134 | guidance = gr.Slider( 135 | value=5.0, 136 | minimum=1.0, 137 | maximum=5.0, 138 | step=1.0, 139 | label="Speaker similarity - How closely to match speaker identity and speech style.", 140 | ) 141 | 142 | # voice select 143 | toggle = gr.Radio(choices=RADIO_CHOICES, label="Choose voice", value=RADIO_CHOICES[0]) 144 | 145 | with gr.Row(visible=True) as row_1: 146 | preset_dropdown = gr.Dropdown( 147 | PRESET_VOICES.keys(), label="Preset voices", value=list(PRESET_VOICES.keys())[0] 148 | ) 149 | with gr.Accordion("Preview: Preset voices", open=False): 150 | for label, path in PRESET_VOICES.items(): 151 | gr.Audio(value=path, label=label) 152 | 153 | with gr.Row(visible=False) as row_2: 154 | upload_target = gr.Audio( 155 | sources=["upload"], 156 | type="filepath", 157 | label="Upload a clean sample to clone. Sample should contain 1 speaker, be between 30-90 seconds and not contain background noise.", 158 | ) 159 | 160 | toggle.change( 161 | change_voice_selection_layout, 162 | inputs=toggle, 163 | outputs=[row_1, row_2], 164 | ) 165 | 166 | with gr.Column(): 167 | speech = gr.Audio( 168 | type="filepath", 169 | label="MetaVoice-1B says...", 170 | ) 171 | 172 | submit = gr.Button("Generate Speech") 173 | submit.click( 174 | fn=tts, 175 | inputs=[to_say, top_p, guidance, toggle, preset_dropdown, upload_target], 176 | outputs=speech, 177 | ) 178 | 179 | 180 | demo.queue() 181 | demo.launch( 182 | favicon_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/favicon.ico"), 183 | server_name="0.0.0.0", 184 | server_port=7861, 185 | ) 186 | -------------------------------------------------------------------------------- /assets/bria.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/assets/bria.mp3 -------------------------------------------------------------------------------- /assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/assets/favicon.ico -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/assets/logo.png -------------------------------------------------------------------------------- /colab_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Installation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Clone the repository" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "vscode": { 22 | "languageId": "plaintext" 23 | } 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "!git clone https://github.com/metavoiceio/metavoice-src.git\n", 28 | "%cd metavoice-src" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "### Install dependencies" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "vscode": { 43 | "languageId": "plaintext" 44 | } 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "!sudo apt install pipx\n", 49 | "!pipx install poetry\n", 50 | "!pipx run poetry install && pipx run poetry run pip install torch==2.2.1 torchaudio==2.2.1\n", 51 | "!pipx run poetry env list | sed 's/ (Activated)//' > poetry_env.txt\n", 52 | "# NOTE: pip's dependency resolver will error & complain, ignore it!\n", 53 | "# its due to a temporary dependency issue, `tts.synthesise` will still work as intended!" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "vscode": { 61 | "languageId": "plaintext" 62 | } 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "import sys, pathlib\n", 67 | "venv = pathlib.Path(\"poetry_env.txt\").read_text().strip(\"\\n\")\n", 68 | "sys.path.append(f\"/root/.cache/pypoetry/virtualenvs/{venv}/lib/python3.10/site-packages\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "## Inference" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "vscode": { 83 | "languageId": "plaintext" 84 | } 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "from IPython.display import Audio, display\n", 89 | "from fam.llm.fast_inference import TTS\n", 90 | "\n", 91 | "tts = TTS()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "vscode": { 99 | "languageId": "plaintext" 100 | } 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "wav_file = tts.synthesise(\n", 105 | " text=\"This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.\",\n", 106 | " spk_ref_path=\"assets/bria.mp3\" # you can use any speaker reference file (WAV, OGG, MP3, FLAC, etc.)\n", 107 | ")\n", 108 | "display(Audio(wav_file, autoplay=True))" 109 | ] 110 | } 111 | ], 112 | "metadata": { 113 | "accelerator": "GPU", 114 | "colab": { 115 | "gpuType": "T4", 116 | "provenance": [] 117 | }, 118 | "kernelspec": { 119 | "display_name": "Python 3", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "name": "python" 124 | } 125 | }, 126 | "nbformat": 4, 127 | "nbformat_minor": 2 128 | } 129 | -------------------------------------------------------------------------------- /data/audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/data/audio.wav -------------------------------------------------------------------------------- /data/caption.txt: -------------------------------------------------------------------------------- 1 | Please call Stella. 
-------------------------------------------------------------------------------- /datasets/sample_dataset.csv: -------------------------------------------------------------------------------- 1 | audio_files|captions 2 | ./data/audio.wav|./data/caption.txt 3 | ./data/audio.wav|./data/caption.txt 4 | ./data/audio.wav|./data/caption.txt 5 | ./data/audio.wav|./data/caption.txt 6 | ./data/audio.wav|./data/caption.txt 7 | ./data/audio.wav|./data/caption.txt 8 | ./data/audio.wav|./data/caption.txt 9 | ./data/audio.wav|./data/caption.txt 10 | ./data/audio.wav|./data/caption.txt 11 | ./data/audio.wav|./data/caption.txt 12 | ./data/audio.wav|./data/caption.txt 13 | ./data/audio.wav|./data/caption.txt 14 | ./data/audio.wav|./data/caption.txt 15 | ./data/audio.wav|./data/caption.txt 16 | ./data/audio.wav|./data/caption.txt 17 | ./data/audio.wav|./data/caption.txt 18 | ./data/audio.wav|./data/caption.txt 19 | ./data/audio.wav|./data/caption.txt 20 | ./data/audio.wav|./data/caption.txt 21 | ./data/audio.wav|./data/caption.txt 22 | ./data/audio.wav|./data/caption.txt 23 | ./data/audio.wav|./data/caption.txt 24 | ./data/audio.wav|./data/caption.txt 25 | ./data/audio.wav|./data/caption.txt 26 | ./data/audio.wav|./data/caption.txt 27 | ./data/audio.wav|./data/caption.txt 28 | ./data/audio.wav|./data/caption.txt 29 | ./data/audio.wav|./data/caption.txt 30 | ./data/audio.wav|./data/caption.txt 31 | ./data/audio.wav|./data/caption.txt 32 | ./data/audio.wav|./data/caption.txt 33 | ./data/audio.wav|./data/caption.txt 34 | ./data/audio.wav|./data/caption.txt 35 | ./data/audio.wav|./data/caption.txt 36 | ./data/audio.wav|./data/caption.txt 37 | ./data/audio.wav|./data/caption.txt 38 | ./data/audio.wav|./data/caption.txt 39 | ./data/audio.wav|./data/caption.txt 40 | ./data/audio.wav|./data/caption.txt 41 | ./data/audio.wav|./data/caption.txt 42 | ./data/audio.wav|./data/caption.txt 43 | ./data/audio.wav|./data/caption.txt 44 | ./data/audio.wav|./data/caption.txt 45 | ./data/audio.wav|./data/caption.txt 46 | ./data/audio.wav|./data/caption.txt 47 | ./data/audio.wav|./data/caption.txt 48 | ./data/audio.wav|./data/caption.txt 49 | ./data/audio.wav|./data/caption.txt 50 | ./data/audio.wav|./data/caption.txt 51 | ./data/audio.wav|./data/caption.txt 52 | ./data/audio.wav|./data/caption.txt 53 | ./data/audio.wav|./data/caption.txt 54 | ./data/audio.wav|./data/caption.txt 55 | ./data/audio.wav|./data/caption.txt 56 | ./data/audio.wav|./data/caption.txt 57 | ./data/audio.wav|./data/caption.txt 58 | ./data/audio.wav|./data/caption.txt 59 | ./data/audio.wav|./data/caption.txt 60 | ./data/audio.wav|./data/caption.txt 61 | ./data/audio.wav|./data/caption.txt 62 | ./data/audio.wav|./data/caption.txt 63 | ./data/audio.wav|./data/caption.txt 64 | ./data/audio.wav|./data/caption.txt 65 | ./data/audio.wav|./data/caption.txt 66 | ./data/audio.wav|./data/caption.txt 67 | ./data/audio.wav|./data/caption.txt 68 | ./data/audio.wav|./data/caption.txt 69 | ./data/audio.wav|./data/caption.txt 70 | ./data/audio.wav|./data/caption.txt 71 | ./data/audio.wav|./data/caption.txt 72 | ./data/audio.wav|./data/caption.txt 73 | ./data/audio.wav|./data/caption.txt 74 | ./data/audio.wav|./data/caption.txt 75 | ./data/audio.wav|./data/caption.txt 76 | ./data/audio.wav|./data/caption.txt 77 | ./data/audio.wav|./data/caption.txt 78 | ./data/audio.wav|./data/caption.txt 79 | ./data/audio.wav|./data/caption.txt 80 | ./data/audio.wav|./data/caption.txt 81 | ./data/audio.wav|./data/caption.txt 82 | ./data/audio.wav|./data/caption.txt 83 | 
./data/audio.wav|./data/caption.txt 84 | ./data/audio.wav|./data/caption.txt 85 | ./data/audio.wav|./data/caption.txt 86 | ./data/audio.wav|./data/caption.txt 87 | ./data/audio.wav|./data/caption.txt 88 | ./data/audio.wav|./data/caption.txt 89 | ./data/audio.wav|./data/caption.txt 90 | ./data/audio.wav|./data/caption.txt 91 | ./data/audio.wav|./data/caption.txt 92 | ./data/audio.wav|./data/caption.txt 93 | ./data/audio.wav|./data/caption.txt 94 | ./data/audio.wav|./data/caption.txt 95 | ./data/audio.wav|./data/caption.txt 96 | ./data/audio.wav|./data/caption.txt 97 | ./data/audio.wav|./data/caption.txt 98 | ./data/audio.wav|./data/caption.txt 99 | ./data/audio.wav|./data/caption.txt 100 | ./data/audio.wav|./data/caption.txt 101 | ./data/audio.wav|./data/caption.txt 102 | ./data/audio.wav|./data/caption.txt 103 | ./data/audio.wav|./data/caption.txt 104 | ./data/audio.wav|./data/caption.txt 105 | ./data/audio.wav|./data/caption.txt 106 | ./data/audio.wav|./data/caption.txt 107 | ./data/audio.wav|./data/caption.txt 108 | ./data/audio.wav|./data/caption.txt 109 | ./data/audio.wav|./data/caption.txt 110 | ./data/audio.wav|./data/caption.txt 111 | ./data/audio.wav|./data/caption.txt 112 | ./data/audio.wav|./data/caption.txt 113 | ./data/audio.wav|./data/caption.txt 114 | ./data/audio.wav|./data/caption.txt 115 | ./data/audio.wav|./data/caption.txt 116 | ./data/audio.wav|./data/caption.txt 117 | ./data/audio.wav|./data/caption.txt 118 | ./data/audio.wav|./data/caption.txt 119 | ./data/audio.wav|./data/caption.txt 120 | ./data/audio.wav|./data/caption.txt 121 | ./data/audio.wav|./data/caption.txt 122 | ./data/audio.wav|./data/caption.txt 123 | ./data/audio.wav|./data/caption.txt 124 | ./data/audio.wav|./data/caption.txt 125 | ./data/audio.wav|./data/caption.txt 126 | ./data/audio.wav|./data/caption.txt 127 | ./data/audio.wav|./data/caption.txt 128 | ./data/audio.wav|./data/caption.txt 129 | ./data/audio.wav|./data/caption.txt 130 | ./data/audio.wav|./data/caption.txt 131 | ./data/audio.wav|./data/caption.txt 132 | ./data/audio.wav|./data/caption.txt 133 | ./data/audio.wav|./data/caption.txt 134 | ./data/audio.wav|./data/caption.txt 135 | ./data/audio.wav|./data/caption.txt 136 | ./data/audio.wav|./data/caption.txt 137 | ./data/audio.wav|./data/caption.txt 138 | ./data/audio.wav|./data/caption.txt 139 | ./data/audio.wav|./data/caption.txt 140 | ./data/audio.wav|./data/caption.txt 141 | ./data/audio.wav|./data/caption.txt 142 | ./data/audio.wav|./data/caption.txt 143 | ./data/audio.wav|./data/caption.txt 144 | ./data/audio.wav|./data/caption.txt 145 | ./data/audio.wav|./data/caption.txt 146 | ./data/audio.wav|./data/caption.txt 147 | ./data/audio.wav|./data/caption.txt 148 | ./data/audio.wav|./data/caption.txt 149 | ./data/audio.wav|./data/caption.txt 150 | ./data/audio.wav|./data/caption.txt 151 | ./data/audio.wav|./data/caption.txt 152 | ./data/audio.wav|./data/caption.txt 153 | ./data/audio.wav|./data/caption.txt 154 | ./data/audio.wav|./data/caption.txt 155 | ./data/audio.wav|./data/caption.txt 156 | ./data/audio.wav|./data/caption.txt 157 | ./data/audio.wav|./data/caption.txt 158 | ./data/audio.wav|./data/caption.txt 159 | ./data/audio.wav|./data/caption.txt 160 | ./data/audio.wav|./data/caption.txt 161 | ./data/audio.wav|./data/caption.txt 162 | ./data/audio.wav|./data/caption.txt 163 | ./data/audio.wav|./data/caption.txt 164 | ./data/audio.wav|./data/caption.txt 165 | ./data/audio.wav|./data/caption.txt 166 | ./data/audio.wav|./data/caption.txt 167 | ./data/audio.wav|./data/caption.txt 168 | 
./data/audio.wav|./data/caption.txt 169 | ./data/audio.wav|./data/caption.txt 170 | ./data/audio.wav|./data/caption.txt 171 | ./data/audio.wav|./data/caption.txt 172 | ./data/audio.wav|./data/caption.txt 173 | ./data/audio.wav|./data/caption.txt 174 | ./data/audio.wav|./data/caption.txt 175 | ./data/audio.wav|./data/caption.txt 176 | ./data/audio.wav|./data/caption.txt 177 | ./data/audio.wav|./data/caption.txt 178 | ./data/audio.wav|./data/caption.txt 179 | ./data/audio.wav|./data/caption.txt 180 | ./data/audio.wav|./data/caption.txt 181 | ./data/audio.wav|./data/caption.txt 182 | ./data/audio.wav|./data/caption.txt 183 | ./data/audio.wav|./data/caption.txt 184 | ./data/audio.wav|./data/caption.txt 185 | ./data/audio.wav|./data/caption.txt 186 | ./data/audio.wav|./data/caption.txt 187 | ./data/audio.wav|./data/caption.txt 188 | ./data/audio.wav|./data/caption.txt 189 | ./data/audio.wav|./data/caption.txt 190 | ./data/audio.wav|./data/caption.txt 191 | ./data/audio.wav|./data/caption.txt 192 | ./data/audio.wav|./data/caption.txt 193 | ./data/audio.wav|./data/caption.txt 194 | ./data/audio.wav|./data/caption.txt 195 | ./data/audio.wav|./data/caption.txt 196 | ./data/audio.wav|./data/caption.txt 197 | ./data/audio.wav|./data/caption.txt 198 | ./data/audio.wav|./data/caption.txt 199 | ./data/audio.wav|./data/caption.txt 200 | ./data/audio.wav|./data/caption.txt 201 | ./data/audio.wav|./data/caption.txt 202 | ./data/audio.wav|./data/caption.txt 203 | ./data/audio.wav|./data/caption.txt 204 | ./data/audio.wav|./data/caption.txt 205 | ./data/audio.wav|./data/caption.txt 206 | ./data/audio.wav|./data/caption.txt 207 | ./data/audio.wav|./data/caption.txt 208 | ./data/audio.wav|./data/caption.txt 209 | ./data/audio.wav|./data/caption.txt 210 | ./data/audio.wav|./data/caption.txt 211 | ./data/audio.wav|./data/caption.txt 212 | ./data/audio.wav|./data/caption.txt 213 | ./data/audio.wav|./data/caption.txt 214 | ./data/audio.wav|./data/caption.txt 215 | ./data/audio.wav|./data/caption.txt 216 | ./data/audio.wav|./data/caption.txt 217 | ./data/audio.wav|./data/caption.txt 218 | ./data/audio.wav|./data/caption.txt 219 | ./data/audio.wav|./data/caption.txt 220 | ./data/audio.wav|./data/caption.txt 221 | ./data/audio.wav|./data/caption.txt 222 | ./data/audio.wav|./data/caption.txt 223 | ./data/audio.wav|./data/caption.txt 224 | ./data/audio.wav|./data/caption.txt 225 | ./data/audio.wav|./data/caption.txt 226 | ./data/audio.wav|./data/caption.txt 227 | ./data/audio.wav|./data/caption.txt 228 | ./data/audio.wav|./data/caption.txt 229 | ./data/audio.wav|./data/caption.txt 230 | ./data/audio.wav|./data/caption.txt 231 | ./data/audio.wav|./data/caption.txt 232 | ./data/audio.wav|./data/caption.txt 233 | ./data/audio.wav|./data/caption.txt 234 | ./data/audio.wav|./data/caption.txt 235 | ./data/audio.wav|./data/caption.txt 236 | ./data/audio.wav|./data/caption.txt 237 | ./data/audio.wav|./data/caption.txt 238 | ./data/audio.wav|./data/caption.txt 239 | ./data/audio.wav|./data/caption.txt 240 | ./data/audio.wav|./data/caption.txt 241 | ./data/audio.wav|./data/caption.txt 242 | ./data/audio.wav|./data/caption.txt 243 | ./data/audio.wav|./data/caption.txt 244 | ./data/audio.wav|./data/caption.txt 245 | ./data/audio.wav|./data/caption.txt 246 | ./data/audio.wav|./data/caption.txt 247 | ./data/audio.wav|./data/caption.txt 248 | ./data/audio.wav|./data/caption.txt 249 | ./data/audio.wav|./data/caption.txt 250 | ./data/audio.wav|./data/caption.txt 251 | ./data/audio.wav|./data/caption.txt 252 | 
./data/audio.wav|./data/caption.txt 253 | ./data/audio.wav|./data/caption.txt 254 | ./data/audio.wav|./data/caption.txt 255 | ./data/audio.wav|./data/caption.txt 256 | ./data/audio.wav|./data/caption.txt 257 | ./data/audio.wav|./data/caption.txt 258 | ./data/audio.wav|./data/caption.txt 259 | ./data/audio.wav|./data/caption.txt 260 | ./data/audio.wav|./data/caption.txt 261 | ./data/audio.wav|./data/caption.txt 262 | ./data/audio.wav|./data/caption.txt 263 | ./data/audio.wav|./data/caption.txt 264 | ./data/audio.wav|./data/caption.txt 265 | ./data/audio.wav|./data/caption.txt 266 | ./data/audio.wav|./data/caption.txt 267 | ./data/audio.wav|./data/caption.txt 268 | ./data/audio.wav|./data/caption.txt 269 | ./data/audio.wav|./data/caption.txt 270 | ./data/audio.wav|./data/caption.txt 271 | ./data/audio.wav|./data/caption.txt 272 | ./data/audio.wav|./data/caption.txt 273 | ./data/audio.wav|./data/caption.txt 274 | ./data/audio.wav|./data/caption.txt 275 | ./data/audio.wav|./data/caption.txt 276 | ./data/audio.wav|./data/caption.txt 277 | ./data/audio.wav|./data/caption.txt 278 | ./data/audio.wav|./data/caption.txt 279 | ./data/audio.wav|./data/caption.txt 280 | ./data/audio.wav|./data/caption.txt 281 | ./data/audio.wav|./data/caption.txt 282 | ./data/audio.wav|./data/caption.txt 283 | ./data/audio.wav|./data/caption.txt 284 | ./data/audio.wav|./data/caption.txt 285 | ./data/audio.wav|./data/caption.txt 286 | ./data/audio.wav|./data/caption.txt 287 | ./data/audio.wav|./data/caption.txt 288 | ./data/audio.wav|./data/caption.txt 289 | ./data/audio.wav|./data/caption.txt 290 | ./data/audio.wav|./data/caption.txt 291 | ./data/audio.wav|./data/caption.txt 292 | ./data/audio.wav|./data/caption.txt 293 | ./data/audio.wav|./data/caption.txt 294 | ./data/audio.wav|./data/caption.txt 295 | ./data/audio.wav|./data/caption.txt 296 | ./data/audio.wav|./data/caption.txt 297 | ./data/audio.wav|./data/caption.txt 298 | ./data/audio.wav|./data/caption.txt 299 | ./data/audio.wav|./data/caption.txt 300 | ./data/audio.wav|./data/caption.txt 301 | ./data/audio.wav|./data/caption.txt 302 | ./data/audio.wav|./data/caption.txt 303 | ./data/audio.wav|./data/caption.txt 304 | ./data/audio.wav|./data/caption.txt 305 | ./data/audio.wav|./data/caption.txt 306 | ./data/audio.wav|./data/caption.txt 307 | ./data/audio.wav|./data/caption.txt 308 | ./data/audio.wav|./data/caption.txt 309 | ./data/audio.wav|./data/caption.txt 310 | ./data/audio.wav|./data/caption.txt 311 | ./data/audio.wav|./data/caption.txt 312 | ./data/audio.wav|./data/caption.txt 313 | ./data/audio.wav|./data/caption.txt 314 | ./data/audio.wav|./data/caption.txt 315 | ./data/audio.wav|./data/caption.txt 316 | ./data/audio.wav|./data/caption.txt 317 | ./data/audio.wav|./data/caption.txt 318 | ./data/audio.wav|./data/caption.txt 319 | ./data/audio.wav|./data/caption.txt 320 | ./data/audio.wav|./data/caption.txt 321 | ./data/audio.wav|./data/caption.txt 322 | -------------------------------------------------------------------------------- /datasets/sample_val_dataset.csv: -------------------------------------------------------------------------------- 1 | audio_files|captions 2 | ./data/audio.wav|./data/caption.txt 3 | ./data/audio.wav|./data/caption.txt 4 | ./data/audio.wav|./data/caption.txt 5 | ./data/audio.wav|./data/caption.txt 6 | ./data/audio.wav|./data/caption.txt 7 | ./data/audio.wav|./data/caption.txt 8 | ./data/audio.wav|./data/caption.txt 9 | ./data/audio.wav|./data/caption.txt 10 | ./data/audio.wav|./data/caption.txt 11 | 
./data/audio.wav|./data/caption.txt 12 | ./data/audio.wav|./data/caption.txt 13 | ./data/audio.wav|./data/caption.txt 14 | ./data/audio.wav|./data/caption.txt 15 | ./data/audio.wav|./data/caption.txt 16 | ./data/audio.wav|./data/caption.txt 17 | ./data/audio.wav|./data/caption.txt 18 | ./data/audio.wav|./data/caption.txt 19 | ./data/audio.wav|./data/caption.txt 20 | ./data/audio.wav|./data/caption.txt 21 | ./data/audio.wav|./data/caption.txt 22 | ./data/audio.wav|./data/caption.txt 23 | ./data/audio.wav|./data/caption.txt 24 | ./data/audio.wav|./data/caption.txt 25 | ./data/audio.wav|./data/caption.txt 26 | ./data/audio.wav|./data/caption.txt 27 | ./data/audio.wav|./data/caption.txt 28 | ./data/audio.wav|./data/caption.txt 29 | ./data/audio.wav|./data/caption.txt 30 | ./data/audio.wav|./data/caption.txt 31 | ./data/audio.wav|./data/caption.txt 32 | ./data/audio.wav|./data/caption.txt 33 | ./data/audio.wav|./data/caption.txt 34 | ./data/audio.wav|./data/caption.txt 35 | ./data/audio.wav|./data/caption.txt 36 | ./data/audio.wav|./data/caption.txt 37 | ./data/audio.wav|./data/caption.txt 38 | ./data/audio.wav|./data/caption.txt 39 | ./data/audio.wav|./data/caption.txt 40 | ./data/audio.wav|./data/caption.txt 41 | ./data/audio.wav|./data/caption.txt 42 | ./data/audio.wav|./data/caption.txt 43 | ./data/audio.wav|./data/caption.txt 44 | ./data/audio.wav|./data/caption.txt 45 | ./data/audio.wav|./data/caption.txt 46 | ./data/audio.wav|./data/caption.txt 47 | ./data/audio.wav|./data/caption.txt 48 | ./data/audio.wav|./data/caption.txt 49 | ./data/audio.wav|./data/caption.txt 50 | ./data/audio.wav|./data/caption.txt 51 | ./data/audio.wav|./data/caption.txt 52 | ./data/audio.wav|./data/caption.txt 53 | ./data/audio.wav|./data/caption.txt 54 | ./data/audio.wav|./data/caption.txt 55 | ./data/audio.wav|./data/caption.txt 56 | ./data/audio.wav|./data/caption.txt 57 | ./data/audio.wav|./data/caption.txt 58 | ./data/audio.wav|./data/caption.txt 59 | ./data/audio.wav|./data/caption.txt 60 | ./data/audio.wav|./data/caption.txt 61 | ./data/audio.wav|./data/caption.txt 62 | ./data/audio.wav|./data/caption.txt 63 | ./data/audio.wav|./data/caption.txt 64 | ./data/audio.wav|./data/caption.txt 65 | ./data/audio.wav|./data/caption.txt 66 | ./data/audio.wav|./data/caption.txt 67 | ./data/audio.wav|./data/caption.txt 68 | ./data/audio.wav|./data/caption.txt 69 | ./data/audio.wav|./data/caption.txt 70 | ./data/audio.wav|./data/caption.txt 71 | ./data/audio.wav|./data/caption.txt 72 | ./data/audio.wav|./data/caption.txt 73 | ./data/audio.wav|./data/caption.txt 74 | ./data/audio.wav|./data/caption.txt 75 | ./data/audio.wav|./data/caption.txt 76 | ./data/audio.wav|./data/caption.txt 77 | ./data/audio.wav|./data/caption.txt 78 | ./data/audio.wav|./data/caption.txt 79 | ./data/audio.wav|./data/caption.txt 80 | ./data/audio.wav|./data/caption.txt 81 | ./data/audio.wav|./data/caption.txt 82 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.5" 2 | 3 | networks: 4 | metavoice-net: 5 | driver: bridge 6 | 7 | volumes: 8 | hf-cache: 9 | driver: local 10 | 11 | x-common-settings: &common-settings 12 | volumes: 13 | - hf-cache:/.hf-cache 14 | - ./assets:/app/assets 15 | deploy: 16 | replicas: 1 17 | resources: 18 | reservations: 19 | devices: 20 | - driver: nvidia 21 | count: 1 22 | capabilities: [ gpu ] 23 | runtime: nvidia 24 | ipc: host 25 | tty: true # enable colorized logs 26 | 
build: 27 | context: . 28 | image: metavoice-server:latest 29 | networks: 30 | - metavoice-net 31 | environment: 32 | - NVIDIA_VISIBLE_DEVICES=all 33 | - HF_HOME=/.hf-cache 34 | logging: 35 | options: 36 | max-size: "100m" 37 | max-file: "10" 38 | 39 | services: 40 | server: 41 | <<: *common-settings 42 | container_name: metavoice-server 43 | command: [ "--port=58004" ] 44 | ports: 45 | - 58004:58004 46 | healthcheck: 47 | test: [ "CMD", "curl", "http://metavoice-server:58004/health" ] 48 | interval: 1m 49 | timeout: 10s 50 | retries: 20 51 | ui: 52 | <<: *common-settings 53 | container_name: metavoice-ui 54 | entrypoint: [ "poetry", "run", "python", "app.py" ] 55 | ports: 56 | - 7861:7861 57 | healthcheck: 58 | test: [ "CMD", "curl", "http://localhost:7861" ] 59 | interval: 1m 60 | timeout: 10s 61 | retries: 1 62 | -------------------------------------------------------------------------------- /fam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/__init__.py -------------------------------------------------------------------------------- /fam/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/llm/__init__.py -------------------------------------------------------------------------------- /fam/llm/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook 2 | from fam.llm.adapters.tilted_encodec import TiltedEncodec 3 | -------------------------------------------------------------------------------- /fam/llm/adapters/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class BaseDataAdapter(ABC): 5 | pass 6 | -------------------------------------------------------------------------------- /fam/llm/adapters/flattened_encodec.py: -------------------------------------------------------------------------------- 1 | from fam.llm.adapters.base import BaseDataAdapter 2 | 3 | 4 | class FlattenedInterleavedEncodec2Codebook(BaseDataAdapter): 5 | def __init__(self, end_of_audio_token): 6 | self._end_of_audio_token = end_of_audio_token 7 | 8 | def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]: 9 | assert len(tokens) == 1 10 | tokens = tokens[0] 11 | 12 | text_ids = [] 13 | extracted_audio_ids = [[], []] 14 | 15 | for t in tokens: 16 | if t < self._end_of_audio_token: 17 | extracted_audio_ids[0].append(t) 18 | elif t >= self._end_of_audio_token and t < 2 * self._end_of_audio_token: 19 | extracted_audio_ids[1].append(t - self._end_of_audio_token) 20 | # We ignore t = 2 * self._end_of_audio_token, as it is the end of audio token 21 | elif t > 2 * self._end_of_audio_token: 22 | text_ids.append(t) 23 | 24 | if len(set([len(x) for x in extracted_audio_ids])) != 1: 25 | min_len = min([len(x) for x in extracted_audio_ids]) 26 | max_len = max([len(x) for x in extracted_audio_ids]) 27 | print("WARNING: Number of tokens at each hierarchy must be of the same length!") 28 | print(f"Truncating to min length of {min_len} tokens from {max_len} max.") 29 | print([len(x) for x in extracted_audio_ids]) 30 | extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids] 31 | 32 | return 
text_ids[:-1], extracted_audio_ids 33 | 34 | def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]): 35 | """ 36 | Performs the required combination and padding as needed. 37 | """ 38 | raise NotImplementedError 39 | -------------------------------------------------------------------------------- /fam/llm/adapters/tilted_encodec.py: -------------------------------------------------------------------------------- 1 | from fam.llm.adapters.base import BaseDataAdapter 2 | 3 | 4 | class TiltedEncodec(BaseDataAdapter): 5 | def __init__(self, end_of_audio_token): 6 | self._end_of_audio_token = end_of_audio_token 7 | 8 | def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]: 9 | assert len(tokens) > 1 10 | 11 | text_ids = [] 12 | extracted_audio_ids = [] 13 | 14 | extracted_audio_ids.append([]) 15 | # Handle first hierarchy as special case as it contains text tokens as well 16 | # TODO: maybe it doesn't need special case, and can be handled on it's own :) 17 | for t in tokens[0]: 18 | if t > self._end_of_audio_token: 19 | text_ids.append(t) 20 | elif t < self._end_of_audio_token: 21 | extracted_audio_ids[0].append(t) 22 | 23 | # Handle the rest of the hierarchies 24 | for i in range(1, len(tokens)): 25 | token_hierarchy_ids = tokens[i] 26 | extracted_audio_ids.append([]) 27 | for t in token_hierarchy_ids: 28 | if t < self._end_of_audio_token: 29 | extracted_audio_ids[i].append(t) 30 | 31 | if len(set([len(x) for x in extracted_audio_ids])) != 1: 32 | min_len = min([len(x) for x in extracted_audio_ids]) 33 | max_len = max([len(x) for x in extracted_audio_ids]) 34 | print("WARNING: Number of tokens at each hierarchy must be of the same length!") 35 | print(f"Truncating to min length of {min_len} tokens from {max_len} max.") 36 | print([len(x) for x in extracted_audio_ids]) 37 | extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids] 38 | 39 | return text_ids[:-1], extracted_audio_ids 40 | 41 | def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]): 42 | """ 43 | Performs the required combination and padding as needed. 
44 | """ 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /fam/llm/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/llm/config/__init__.py -------------------------------------------------------------------------------- /fam/llm/config/finetune_params.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | import os 3 | import uuid 4 | import pathlib 5 | from typing import Literal, Optional 6 | import torch 7 | 8 | batch_size = 2 9 | dataset_size: int = 400 10 | batched_ds_size = dataset_size // batch_size 11 | val_train_ratio = 0.2 12 | 13 | epochs: int = 2 14 | max_iters = batched_ds_size * epochs 15 | learning_rate = 3e-5 16 | last_n_blocks_to_finetune = 1 17 | decay_lr = False 18 | lr_decay_iters = 0 # decay learning rate after this many iterations 19 | min_lr = 3e-6 20 | 21 | eval_interval = batched_ds_size 22 | eval_iters = int(batched_ds_size*val_train_ratio) 23 | eval_only: bool = False # if True, script exits right after the first eval 24 | log_interval = batched_ds_size # don't print too too often 25 | save_interval: int = batched_ds_size * (epochs//2) # save a checkpoint every this many iterations 26 | assert save_interval % eval_interval == 0, "save_interval must be divisible by eval_interval." 27 | seed = 1337 28 | grad_clip: float = 1.0 # clip gradients at this value, or disable if == 0.0 29 | 30 | wandb_log = False 31 | wandb_project = "project-name" 32 | wandb_run_name = "run-name" 33 | wandb_tags = ["tag1", "tag2"] 34 | 35 | gradient_accumulation_steps = 1 36 | block_size = 2_048 37 | audio_token_mode = "flattened_interleaved" 38 | num_max_audio_tokens_timesteps = 1_024 39 | 40 | n_layer = 24 41 | n_head = 16 42 | n_embd = 2048 43 | dropout = 0.1 44 | 45 | weight_decay = 1e-1 46 | beta1 = 0.9 47 | beta2 = 0.95 48 | 49 | warmup_iters: int = 0 # how many steps to warm up for 50 | out_dir = f"finetune-{epochs=}-{learning_rate=}-{batch_size=}-{last_n_blocks_to_finetune=}-{dropout=}-{uuid.uuid4()}" 51 | 52 | compile = True 53 | num_codebooks = None 54 | norm_type = "rmsnorm" 55 | rmsnorm_eps = 1e-5 56 | nonlinearity_type = "swiglu" 57 | swiglu_multiple_of = 256 58 | attn_kernel_type = "torch_attn" 59 | meta_target_vocab_sizes: Optional[list[int]] = None 60 | speaker_emb_size: int = 256 61 | speaker_cond = True 62 | 63 | # always running finetuning on a single GPU 64 | master_process = True 65 | device: str = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 66 | ddp = False 67 | ddp_world_size = 1 68 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 69 | 70 | causal = True 71 | bias: bool = False # do we use bias inside LayerNorm and Linear layers? 
72 | spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not 73 | -------------------------------------------------------------------------------- /fam/llm/decoders.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import pathlib 4 | import uuid 5 | from abc import ABC, abstractmethod 6 | from typing import Callable, Optional, Union 7 | 8 | import julius 9 | import torch 10 | from audiocraft.data.audio import audio_read, audio_write 11 | from audiocraft.models import MultiBandDiffusion # type: ignore 12 | 13 | mbd = MultiBandDiffusion.get_mbd_24khz(bw=6) # 1.5 14 | 15 | 16 | class Decoder(ABC): 17 | @abstractmethod 18 | def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None): 19 | raise NotImplementedError 20 | 21 | 22 | class EncodecDecoder(Decoder): 23 | def __init__( 24 | self, 25 | tokeniser_decode_fn: Callable[[list[int]], str], 26 | data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]], 27 | output_dir: str, 28 | ): 29 | self._mbd_sample_rate = 24_000 30 | self._end_of_audio_token = 1024 31 | self._num_codebooks = 8 32 | self.mbd = mbd 33 | 34 | self.tokeniser_decode_fn = tokeniser_decode_fn 35 | self._data_adapter_fn = data_adapter_fn 36 | 37 | self.output_dir = pathlib.Path(output_dir).resolve() 38 | os.makedirs(self.output_dir, exist_ok=True) 39 | 40 | def _save_audio(self, name: str, wav: torch.Tensor): 41 | audio_write( 42 | name, 43 | wav.squeeze(0).cpu(), 44 | self._mbd_sample_rate, 45 | strategy="loudness", 46 | loudness_compressor=True, 47 | ) 48 | 49 | def get_tokens(self, audio_path: str) -> list[list[int]]: 50 | """ 51 | Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g. 52 | limited codebook reconstruction or sampling from second stage model only). 53 | """ 54 | pass 55 | wav, sr = audio_read(audio_path) 56 | if sr != self._mbd_sample_rate: 57 | wav = julius.resample_frac(wav, sr, self._mbd_sample_rate) 58 | if wav.ndim == 2: 59 | wav = wav.unsqueeze(1) 60 | wav = wav.to("cuda") 61 | tokens = self.mbd.codec_model.encode(wav) 62 | tokens = tokens[0][0] 63 | 64 | return tokens.tolist() 65 | 66 | def decode( 67 | self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None 68 | ) -> Union[str, torch.Tensor]: 69 | # TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file. 70 | text_ids, extracted_audio_ids = self._data_adapter_fn(tokens) 71 | text = self.tokeniser_decode_fn(text_ids) 72 | # print(f"Text: {text}") 73 | 74 | tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0) 75 | 76 | if tokens.shape[1] < self._num_codebooks: 77 | tokens = torch.cat( 78 | [tokens, *[torch.ones_like(tokens[0:1, 0:1]) * 0] * (self._num_codebooks - tokens.shape[1])], dim=1 79 | ) 80 | 81 | if causal: 82 | return tokens 83 | else: 84 | with torch.amp.autocast(device_type="cuda", dtype=torch.float32): 85 | wav = self.mbd.tokens_to_wav(tokens) 86 | # NOTE: we couldn't just return wav here as it goes through loudness compression etc :) 87 | 88 | if wav.shape[-1] < 9600: 89 | # this causes problem for the code below, and is also odd :) 90 | # first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!) 
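                # For context (arithmetic on the constants above): 9600 samples at the
                # 24 kHz MBD sample rate is exactly 400 ms, and with the ~320x
                # token-to-sample factor noted above this guard trips when fewer than
                # ~30 audio token frames per codebook were generated.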
91 | raise Exception("wav predicted is shorter than 400ms!") 92 | 93 | try: 94 | wav_file_name = self.output_dir / f"synth_{datetime.now().strftime('%y-%m-%d--%H-%M-%S')}_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}" 95 | self._save_audio(wav_file_name, wav) 96 | return wav_file_name 97 | except Exception as e: 98 | print(f"Failed to save audio! Reason: {e}") 99 | 100 | wav_file_name = self.output_dir / f"synth_{datetime.now().strftime('%y-%m-%d--%H-%M-%S')}_{uuid.uuid4()}" 101 | self._save_audio(wav_file_name, wav) 102 | return wav_file_name 103 | -------------------------------------------------------------------------------- /fam/llm/enhancers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC 3 | from typing import Literal, Optional 4 | 5 | from df.enhance import enhance, init_df, load_audio, save_audio 6 | from pydub import AudioSegment 7 | 8 | 9 | def convert_to_wav(input_file: str, output_file: str): 10 | """Convert an audio file to WAV format 11 | 12 | Args: 13 | input_file (str): path to input audio file 14 | output_file (str): path to output WAV file 15 | 16 | """ 17 | # Detect the format of the input file 18 | format = input_file.split(".")[-1].lower() 19 | 20 | # Read the audio file 21 | audio = AudioSegment.from_file(input_file, format=format) 22 | 23 | # Export as WAV 24 | audio.export(output_file, format="wav") 25 | 26 | 27 | def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str: 28 | """Generate the output file path 29 | 30 | Args: 31 | audio_file (str): path to input audio file 32 | tag (str): tag to append to the output file name 33 | ext (str, optional): extension of the output file. Defaults to None. 34 | 35 | Returns: 36 | str: path to output file 37 | """ 38 | 39 | directory = "./enhanced" 40 | # Get the name of the input file 41 | filename = os.path.basename(audio_file) 42 | 43 | # Get the name of the input file without the extension 44 | filename_without_extension = os.path.splitext(filename)[0] 45 | 46 | # Get the extension of the input file 47 | extension = ext or os.path.splitext(filename)[1] 48 | 49 | # Generate the output file path 50 | output_file = os.path.join(directory, filename_without_extension + tag + extension) 51 | 52 | return output_file 53 | 54 | 55 | class BaseEnhancer(ABC): 56 | """Base class for audio enhancers""" 57 | 58 | def __init__(self, *args, **kwargs): 59 | raise NotImplementedError 60 | 61 | def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str: 62 | raise NotImplementedError 63 | 64 | def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str: 65 | output_file = make_output_file_path(audio_file, tag, ext=ext) 66 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 67 | return output_file 68 | 69 | 70 | class DFEnhancer(BaseEnhancer): 71 | def __init__(self, *args, **kwargs): 72 | self.model, self.df_state, _ = init_df() 73 | 74 | def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str: 75 | output_file = output_file or self.get_output_file(audio_file, "_df") 76 | 77 | audio, _ = load_audio(audio_file, sr=self.df_state.sr()) 78 | 79 | enhanced = enhance(self.model, self.df_state, audio) 80 | 81 | save_audio(output_file, enhanced, self.df_state.sr()) 82 | 83 | return output_file 84 | 85 | 86 | def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer: 87 | """Get an audio enhancer 88 | 89 | Args: 90 | enhancer_name (Literal["df"]): name of the audio 
enhancer 91 | 92 | Raises: 93 | ValueError: if the enhancer name is not recognised 94 | 95 | Returns: 96 | BaseEnhancer: audio enhancer 97 | """ 98 | 99 | if enhancer_name == "df": 100 | import warnings 101 | 102 | warnings.filterwarnings( 103 | "ignore", 104 | message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.', 105 | ) 106 | return DFEnhancer() 107 | else: 108 | raise ValueError(f"Unknown enhancer name: {enhancer_name}") 109 | -------------------------------------------------------------------------------- /fam/llm/fast_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | import time 5 | from pathlib import Path 6 | from typing import Literal, Optional 7 | 8 | import librosa 9 | import torch 10 | import tyro 11 | from huggingface_hub import snapshot_download 12 | 13 | from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook 14 | from fam.llm.decoders import EncodecDecoder 15 | from fam.llm.fast_inference_utils import build_model, main 16 | from fam.llm.inference import ( 17 | EncodecDecoder, 18 | InferenceConfig, 19 | Model, 20 | TiltedEncodec, 21 | TrainedBPETokeniser, 22 | get_cached_embedding, 23 | get_cached_file, 24 | get_enhancer, 25 | ) 26 | from fam.llm.utils import ( 27 | check_audio_file, 28 | get_default_dtype, 29 | get_device, 30 | normalize_text, 31 | ) 32 | from fam.telemetry import TelemetryEvent 33 | from fam.telemetry.posthog import PosthogClient 34 | 35 | posthog = PosthogClient() # see fam/telemetry/README.md for more information 36 | 37 | 38 | class TTS: 39 | END_OF_AUDIO_TOKEN = 1024 40 | 41 | def __init__( 42 | self, 43 | model_name: str = "metavoiceio/metavoice-1B-v0.1", 44 | *, 45 | seed: int = 1337, 46 | output_dir: str = "outputs", 47 | quantisation_mode: Optional[Literal["int4", "int8"]] = None, 48 | first_stage_path: Optional[str] = None, 49 | telemetry_origin: Optional[str] = None, 50 | ): 51 | """ 52 | Initialise the TTS model. 53 | 54 | Args: 55 | model_name: refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio) 56 | seed: random seed for reproducibility 57 | output_dir: directory to save output files 58 | quantisation_mode: quantisation mode for first-stage LLM. 59 | Options: 60 | - None for no quantisation (bf16 or fp16 based on device), 61 | - int4 for int4 weight-only quantisation, 62 | - int8 for int8 weight-only quantisation. 63 | first_stage_path: path to first-stage LLM checkpoint. If provided, this will override the one grabbed from Hugging Face via `model_name`. 64 | telemetry_origin: A string identifier that specifies the origin of the telemetry data sent to PostHog. 65 | """ 66 | 67 | # NOTE: this needs to come first so that we don't change global state when we want to use 68 | # the torch.compiled-model. 
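        # Illustrative usage of this class (a sketch only, not executed here;
        # "speaker_ref.wav" is a placeholder path you supply yourself):
        #
        #   tts = TTS()                          # defaults to metavoiceio/metavoice-1B-v0.1
        #   tts = TTS(quantisation_mode="int4")  # weight-only int4; may degrade audio quality
        #   wav_path = tts.synthesise(
        #       text="Hello world.",
        #       spk_ref_path="speaker_ref.wav",  # >= 30s of reference audio (wav/flac/mp3 or URI)
        #   )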
69 | self._dtype = get_default_dtype() 70 | self._device = get_device() 71 | self._model_dir = snapshot_download(repo_id=model_name) 72 | self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN) 73 | self.output_dir = output_dir 74 | os.makedirs(self.output_dir, exist_ok=True) 75 | if first_stage_path: 76 | print(f"Overriding first stage checkpoint via provided model: {first_stage_path}") 77 | self._first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt" 78 | 79 | second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt" 80 | config_second_stage = InferenceConfig( 81 | ckpt_path=second_stage_ckpt_path, 82 | num_samples=1, 83 | seed=seed, 84 | device=self._device, 85 | dtype=self._dtype, 86 | compile=False, 87 | init_from="resume", 88 | output_dir=self.output_dir, 89 | ) 90 | data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN) 91 | self.llm_second_stage = Model( 92 | config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode 93 | ) 94 | self.enhancer = get_enhancer("df") 95 | 96 | self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype] 97 | self.model, self.tokenizer, self.smodel, self.model_size = build_model( 98 | precision=self.precision, 99 | checkpoint_path=Path(self._first_stage_ckpt), 100 | spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"), 101 | device=self._device, 102 | compile=True, 103 | compile_prefill=True, 104 | quantisation_mode=quantisation_mode, 105 | ) 106 | self._seed = seed 107 | self._quantisation_mode = quantisation_mode 108 | self._model_name = model_name 109 | self._telemetry_origin = telemetry_origin 110 | 111 | def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str: 112 | """ 113 | text: Text to speak 114 | spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3 115 | top_p: Top p for sampling applied to first-stage model. Range [0.9, 1.0] are good. This is a measure of speech stability - improves text following for a challenging speaker 116 | guidance_scale: Guidance scale [1.0, 3.0] for sampling. This is a measure of speaker similarity - how closely to match speaker identity and speech style. 
117 | temperature: Temperature for sampling applied to both LLMs (first & second stage) 118 | 119 | returns: path to speech .wav file 120 | """ 121 | text = normalize_text(text) 122 | spk_ref_path = get_cached_file(spk_ref_path) 123 | check_audio_file(spk_ref_path) 124 | spk_emb = get_cached_embedding( 125 | spk_ref_path, 126 | self.smodel, 127 | ).to(device=self._device, dtype=self.precision) 128 | 129 | start = time.time() 130 | # first stage LLM 131 | tokens = main( 132 | model=self.model, 133 | tokenizer=self.tokenizer, 134 | model_size=self.model_size, 135 | prompt=text, 136 | spk_emb=spk_emb, 137 | top_p=torch.tensor(top_p, device=self._device, dtype=self.precision), 138 | guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision), 139 | temperature=torch.tensor(temperature, device=self._device, dtype=self.precision), 140 | ) 141 | _, extracted_audio_ids = self.first_stage_adapter.decode([tokens]) 142 | 143 | b_speaker_embs = spk_emb.unsqueeze(0) 144 | 145 | # second stage LLM + multi-band diffusion model 146 | wav_files = self.llm_second_stage( 147 | texts=[text], 148 | encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)], 149 | speaker_embs=b_speaker_embs, 150 | batch_size=1, 151 | guidance_scale=None, 152 | top_p=None, 153 | top_k=200, 154 | temperature=1.0, 155 | max_new_tokens=None, 156 | ) 157 | 158 | # enhance using deepfilternet 159 | wav_file = wav_files[0] 160 | with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp: 161 | self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name) 162 | shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav") 163 | print(f"\nSaved audio to {wav_file}.wav") 164 | 165 | # calculating real-time factor (RTF) 166 | time_to_synth_s = time.time() - start 167 | audio, sr = librosa.load(str(wav_file) + ".wav") 168 | duration_s = librosa.get_duration(y=audio, sr=sr) 169 | real_time_factor = time_to_synth_s / duration_s 170 | print(f"\nTotal time to synth (s): {time_to_synth_s}") 171 | print(f"Real-time factor: {real_time_factor:.2f}") 172 | 173 | posthog.capture( 174 | TelemetryEvent( 175 | name="user_ran_tts", 176 | properties={ 177 | "model_name": self._model_name, 178 | "text": text, 179 | "temperature": temperature, 180 | "guidance_scale": guidance_scale, 181 | "top_p": top_p, 182 | "spk_ref_path": spk_ref_path, 183 | "speech_duration_s": duration_s, 184 | "time_to_synth_s": time_to_synth_s, 185 | "real_time_factor": round(real_time_factor, 2), 186 | "quantisation_mode": self._quantisation_mode, 187 | "seed": self._seed, 188 | "first_stage_ckpt": self._first_stage_ckpt, 189 | "gpu": torch.cuda.get_device_name(0), 190 | "telemetry_origin": self._telemetry_origin, 191 | }, 192 | ) 193 | ) 194 | 195 | return str(wav_file) + ".wav" 196 | 197 | 198 | if __name__ == "__main__": 199 | tts = tyro.cli(TTS) 200 | -------------------------------------------------------------------------------- /fam/llm/fast_inference_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted 5 | # provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this list of 8 | # conditions and the following disclaimer. 9 | # 10 | # 2. 
Redistributions in binary form must reproduce the above copyright notice, this 11 | # list of conditions and the following disclaimer in the documentation and/or other 12 | # materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its contributors 15 | # may be used to endorse or promote products derived from this software without 16 | # specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR 19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | import itertools 27 | import time 28 | import warnings 29 | from pathlib import Path 30 | from typing import Literal, Optional, Tuple 31 | 32 | import torch 33 | import torch._dynamo.config 34 | import torch._inductor.config 35 | import tqdm 36 | 37 | from fam.llm.fast_quantize import WeightOnlyInt4QuantHandler, WeightOnlyInt8QuantHandler 38 | 39 | 40 | def device_sync(device): 41 | if "cuda" in device: 42 | torch.cuda.synchronize() 43 | elif "cpu" in device: 44 | pass 45 | else: 46 | print(f"device={device} is not yet suppported") 47 | 48 | 49 | torch._inductor.config.coordinate_descent_tuning = True 50 | torch._inductor.config.triton.unique_kernel_names = True 51 | torch._inductor.config.fx_graph_cache = ( 52 | True # Experimental feature to reduce compilation times, will be on by default in future 53 | ) 54 | 55 | # imports need to happen after setting above flags 56 | from fam.llm.fast_model import Transformer 57 | from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder 58 | from fam.quantiser.text.tokenise import TrainedBPETokeniser 59 | 60 | 61 | def multinomial_sample_one_no_sync( 62 | probs_sort, 63 | ): # Does multinomial sampling without a cuda synchronization 64 | q = torch.empty_like(probs_sort).exponential_(1) 65 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 66 | 67 | 68 | def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor): 69 | # ref: huggingface/transformers 70 | 71 | sorted_logits, sorted_indices = torch.sort(logits, descending=False) 72 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 73 | 74 | # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) 75 | sorted_indices_to_remove = cumulative_probs <= (1 - top_p) 76 | # Keep at least min_tokens_to_keep 77 | sorted_indices_to_remove[-1:] = 0 78 | 79 | # scatter sorted tensors to original indexing 80 | indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove) 81 | scores = logits.masked_fill(indices_to_remove, -float("Inf")) 82 | return scores 83 | 84 | 85 | def logits_to_probs( 86 | logits, 87 | *, 88 | temperature: torch.Tensor, 89 | top_p: Optional[torch.Tensor] = None, 90 | top_k: Optional[torch.Tensor] = None, 91 | ): 92 | logits = logits / 
torch.max(temperature, 1e-5 * torch.ones_like(temperature)) 93 | 94 | if top_k is not None: 95 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 96 | pivot = v.select(-1, -1).unsqueeze(-1) 97 | logits = torch.where(logits < pivot, -float("Inf"), logits) 98 | 99 | if top_p is not None: 100 | logits = top_p_sample(logits, top_p) 101 | 102 | probs = torch.nn.functional.softmax(logits, dim=-1) 103 | 104 | return probs 105 | 106 | 107 | def sample( 108 | logits, 109 | guidance_scale: torch.Tensor, 110 | temperature: torch.Tensor, 111 | top_p: Optional[torch.Tensor] = None, 112 | top_k: Optional[torch.Tensor] = None, 113 | ): 114 | # (b, t, vocab_size) 115 | logits = logits[:, -1] 116 | logits_cond, logits_uncond_spkemb = logits.split(logits.size(0) // 2, dim=0) 117 | logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond_spkemb 118 | probs = logits_to_probs(logits[0], temperature=temperature, top_p=top_p, top_k=top_k) 119 | idx_next = multinomial_sample_one_no_sync(probs) 120 | return idx_next, probs 121 | 122 | 123 | def prefill( 124 | model: Transformer, 125 | x: torch.Tensor, 126 | spk_emb: torch.Tensor, 127 | input_pos: torch.Tensor, 128 | **sampling_kwargs, 129 | ) -> torch.Tensor: 130 | # input_pos: [B, S] 131 | logits = model(x, spk_emb, input_pos) 132 | return sample(logits, **sampling_kwargs)[0] 133 | 134 | 135 | def decode_one_token( 136 | model: Transformer, 137 | x: torch.Tensor, 138 | spk_emb: torch.Tensor, 139 | input_pos: torch.Tensor, 140 | **sampling_kwargs, 141 | ) -> Tuple[torch.Tensor, torch.Tensor]: 142 | # input_pos: [B, 1] 143 | assert input_pos.shape[-1] == 1 144 | logits = model(x, spk_emb, input_pos) 145 | return sample(logits, **sampling_kwargs) 146 | 147 | 148 | def decode_n_tokens( 149 | model: Transformer, 150 | cur_token: torch.Tensor, 151 | spk_emb: torch.Tensor, 152 | input_pos: torch.Tensor, 153 | num_new_tokens: int, 154 | callback=lambda _: _, 155 | return_probs: bool = False, 156 | end_of_audio_token: int = 2048, 157 | **sampling_kwargs, 158 | ): 159 | new_tokens, new_probs = [], [] 160 | for i in tqdm.tqdm(range(num_new_tokens)): 161 | if (cur_token == end_of_audio_token).any(): 162 | break 163 | with torch.backends.cuda.sdp_kernel( 164 | enable_flash=False, enable_mem_efficient=False, enable_math=True 165 | ): # Actually better for Inductor to codegen attention here 166 | next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs) 167 | input_pos += 1 168 | new_tokens.append(next_token.clone()) 169 | callback(new_tokens[-1]) 170 | if return_probs: 171 | new_probs.append(next_prob.clone()) 172 | cur_token = next_token.view(1, -1).repeat(2, 1) 173 | 174 | return new_tokens, new_probs 175 | 176 | 177 | def model_forward(model, x, spk_emb, input_pos): 178 | return model(x, spk_emb, input_pos) 179 | 180 | 181 | @torch.no_grad() 182 | def generate( 183 | model: Transformer, 184 | prompt: torch.Tensor, 185 | spk_emb: torch.Tensor, 186 | *, 187 | max_new_tokens: Optional[int] = None, 188 | callback=lambda x: x, 189 | end_of_audio_token: int = 2048, 190 | **sampling_kwargs, 191 | ) -> torch.Tensor: 192 | """ 193 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
194 | """ 195 | # create an empty tensor of the expected final shape and fill in the current tokens 196 | T = prompt.size(0) 197 | if max_new_tokens is None: 198 | max_seq_length = model.config.block_size 199 | else: 200 | max_seq_length = T + max_new_tokens 201 | max_seq_length = min(max_seq_length, model.config.block_size) 202 | max_new_tokens = max_seq_length - T 203 | if max_new_tokens <= 0: 204 | raise ValueError("Prompt is too long to generate more tokens") 205 | 206 | device, dtype = prompt.device, prompt.dtype 207 | 208 | seq = torch.clone(prompt) 209 | input_pos = torch.arange(0, T, device=device) 210 | 211 | next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs) 212 | seq = torch.cat([seq, next_token.view(1)]) 213 | 214 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 215 | 216 | generated_tokens, _ = decode_n_tokens( 217 | model, 218 | next_token.view(1, -1).repeat(2, 1), 219 | spk_emb, 220 | input_pos, 221 | max_new_tokens - 1, 222 | callback=callback, 223 | end_of_audio_token=end_of_audio_token, 224 | **sampling_kwargs, 225 | ) 226 | seq = torch.cat([seq, torch.cat(generated_tokens)]) 227 | 228 | return seq 229 | 230 | 231 | def encode_tokens(tokenizer: TrainedBPETokeniser, text: str, device="cuda") -> torch.Tensor: 232 | tokens = tokenizer.encode(text) 233 | return torch.tensor(tokens, dtype=torch.int, device=device) 234 | 235 | 236 | def _load_model( 237 | checkpoint_path, spk_emb_ckpt_path, device, precision, quantisation_mode: Optional[Literal["int4", "int8"]] = None 238 | ): 239 | ##### MODEL 240 | with torch.device("meta"): 241 | model = Transformer.from_name("metavoice-1B") 242 | 243 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False) 244 | state_dict = checkpoint["model"] 245 | # convert MetaVoice-1B model weights naming to gptfast naming 246 | unwanted_prefix = "_orig_mod." 247 | for k, v in list(state_dict.items()): 248 | if k.startswith(unwanted_prefix): 249 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) 250 | state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight") 251 | state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight") 252 | state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight") 253 | state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight") 254 | for k, v in list(state_dict.items()): 255 | if k.startswith("transformer.h."): 256 | state_dict[k.replace("transformer.h.", "layers.")] = state_dict.pop(k) 257 | k = k.replace("transformer.h.", "layers.") 258 | if ".attn.c_attn." in k: 259 | state_dict[k.replace(".attn.c_attn.", ".attention.wqkv.")] = state_dict.pop(k) 260 | k = k.replace(".attn.c_attn.", ".attention.wqkv.") 261 | if ".attn.c_proj." in k: 262 | state_dict[k.replace(".attn.c_proj.", ".attention.wo.")] = state_dict.pop(k) 263 | k = k.replace(".attn.c_proj.", ".attention.wo.") 264 | if ".mlp.swiglu.w1." in k: 265 | state_dict[k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")] = state_dict.pop(k) 266 | k = k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.") 267 | if ".mlp.swiglu.w3." in k: 268 | state_dict[k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")] = state_dict.pop(k) 269 | k = k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.") 270 | if ".ln_1." in k: 271 | state_dict[k.replace(".ln_1.", ".attention_norm.")] = state_dict.pop(k) 272 | k = k.replace(".ln_1.", ".attention_norm.") 273 | if ".ln_2." 
in k: 274 | state_dict[k.replace(".ln_2.", ".ffn_norm.")] = state_dict.pop(k) 275 | k = k.replace(".ln_2.", ".ffn_norm.") 276 | if ".mlp.c_proj." in k: 277 | state_dict[k.replace(".mlp.c_proj.", ".feed_forward.w2.")] = state_dict.pop(k) 278 | k = k.replace(".mlp.c_proj.", ".feed_forward.w2.") 279 | 280 | model.load_state_dict(state_dict, assign=True) 281 | model = model.to(device=device, dtype=torch.bfloat16) 282 | 283 | if quantisation_mode == "int8": 284 | warnings.warn( 285 | "int8 quantisation is slower than bf16/fp16 for undebugged reasons! Please set optimisation_mode to `None` or to `int4`." 286 | ) 287 | warnings.warn( 288 | "quantisation will degrade the quality of the audio! Please set optimisation_mode to `None` for best quality." 289 | ) 290 | simple_quantizer = WeightOnlyInt8QuantHandler(model) 291 | quantized_state_dict = simple_quantizer.create_quantized_state_dict() 292 | model = simple_quantizer.convert_for_runtime() 293 | model.load_state_dict(quantized_state_dict, assign=True) 294 | model = model.to(device=device, dtype=torch.bfloat16) 295 | # TODO: int8/int4 doesn't decrease VRAM usage substantially... fix that (might be linked to kv-cache) 296 | torch.cuda.empty_cache() 297 | elif quantisation_mode == "int4": 298 | warnings.warn( 299 | "quantisation will degrade the quality of the audio! Please set optimisation_mode to `None` for best quality." 300 | ) 301 | simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize=128) 302 | quantized_state_dict = simple_quantizer.create_quantized_state_dict() 303 | model = simple_quantizer.convert_for_runtime(use_cuda=True) 304 | model.load_state_dict(quantized_state_dict, assign=True) 305 | model = model.to(device=device, dtype=torch.bfloat16) 306 | torch.cuda.empty_cache() 307 | elif quantisation_mode is not None: 308 | raise Exception(f"Invalid quantisation mode {quantisation_mode}! 
Must be either 'int4' or 'int8'!") 309 | 310 | ###### TOKENIZER 311 | tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {}) 312 | tokenizer = TrainedBPETokeniser(**tokenizer_info) 313 | 314 | ###### SPEAKER EMBEDDER 315 | smodel = SpeakerEncoder( 316 | weights_fpath=spk_emb_ckpt_path, 317 | device=device, 318 | eval=True, 319 | verbose=False, 320 | ) 321 | return model.eval(), tokenizer, smodel 322 | 323 | 324 | def build_model( 325 | *, 326 | precision: torch.dtype, 327 | checkpoint_path: Path = Path(""), 328 | spk_emb_ckpt_path: Path = Path(""), 329 | compile_prefill: bool = False, 330 | compile: bool = True, 331 | device: str = "cuda", 332 | quantisation_mode: Optional[Literal["int4", "int8"]] = None, 333 | ): 334 | assert checkpoint_path.is_file(), checkpoint_path 335 | 336 | print(f"Using device={device}") 337 | 338 | print("Loading model ...") 339 | t0 = time.time() 340 | model, tokenizer, smodel = _load_model( 341 | checkpoint_path, spk_emb_ckpt_path, device, precision, quantisation_mode=quantisation_mode 342 | ) 343 | 344 | device_sync(device=device) # MKG 345 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 346 | 347 | torch.manual_seed(1234) 348 | model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) 349 | 350 | with torch.device(device): 351 | model.setup_spk_cond_mask() 352 | model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size) 353 | 354 | if compile: 355 | print("Compiling...Can take up to 2 mins.") 356 | global decode_one_token, prefill 357 | decode_one_token = torch.compile( 358 | decode_one_token, 359 | mode="max-autotune", 360 | fullgraph=True, 361 | ) 362 | 363 | if compile_prefill: 364 | prefill = torch.compile( 365 | prefill, 366 | fullgraph=True, 367 | dynamic=True, 368 | ) 369 | 370 | encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device) 371 | spk_emb = torch.randn((1, 256), device=device, dtype=precision) 372 | 373 | device_sync(device=device) # MKG 374 | t0 = time.perf_counter() 375 | y = generate( 376 | model, 377 | encoded, 378 | spk_emb, 379 | max_new_tokens=200, 380 | callback=lambda x: x, 381 | temperature=torch.tensor(1.0, device=device, dtype=precision), 382 | top_k=None, 383 | top_p=torch.tensor(0.95, device=device, dtype=precision), 384 | guidance_scale=torch.tensor(3.0, device=device, dtype=precision), 385 | end_of_audio_token=9999, # don't end early for compilation stage. 
386 | ) 387 | 388 | device_sync(device=device) # MKG 389 | 390 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") 391 | 392 | return model, tokenizer, smodel, model_size 393 | 394 | 395 | def main( 396 | *, 397 | model, 398 | tokenizer, 399 | model_size, 400 | prompt: str, 401 | guidance_scale: torch.Tensor, 402 | temperature: torch.Tensor, 403 | spk_emb: torch.Tensor, 404 | top_k: Optional[torch.Tensor] = None, 405 | top_p: Optional[torch.Tensor] = None, 406 | device: str = "cuda", 407 | ) -> list: 408 | """Generates text samples based on a pre-trained Transformer model and tokenizer.""" 409 | 410 | encoded = encode_tokens(tokenizer, prompt, device=device) 411 | prompt_length = encoded.size(0) 412 | 413 | aggregate_metrics: dict = { 414 | "tokens_per_sec": [], 415 | } 416 | 417 | device_sync(device=device) # MKG 418 | 419 | if True: 420 | callback = lambda x: x 421 | t0 = time.perf_counter() 422 | 423 | y = generate( 424 | model, 425 | encoded, 426 | spk_emb, 427 | callback=callback, 428 | temperature=temperature, 429 | top_k=top_k, 430 | top_p=top_p, 431 | guidance_scale=guidance_scale, 432 | ) 433 | 434 | device_sync(device=device) # MKG 435 | t = time.perf_counter() - t0 436 | 437 | tokens_generated = y.size(0) - prompt_length 438 | tokens_sec = tokens_generated / t 439 | aggregate_metrics["tokens_per_sec"].append(tokens_sec) 440 | print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") 441 | print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") 442 | # print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") 443 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n") 444 | 445 | return y.tolist() 446 | -------------------------------------------------------------------------------- /fam/llm/fast_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted 5 | # provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this list of 8 | # conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this 11 | # list of conditions and the following disclaimer in the documentation and/or other 12 | # materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its contributors 15 | # may be used to endorse or promote products derived from this software without 16 | # specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR 19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | from dataclasses import dataclass 27 | from functools import reduce 28 | from math import gcd 29 | from typing import Optional, Tuple 30 | 31 | import torch 32 | import torch.nn as nn 33 | from torch import Tensor 34 | from torch.nn import functional as F 35 | 36 | from fam.llm.utils import get_default_dtype 37 | 38 | import logging 39 | 40 | # Adjust the logging level 41 | logger = logging.getLogger("torch") 42 | logger.setLevel(logging.ERROR) 43 | 44 | 45 | def find_multiple(n: int, *args: Tuple[int]) -> int: 46 | k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) 47 | if n % k == 0: 48 | return n 49 | return n + k - (n % k) 50 | 51 | 52 | @dataclass 53 | class ModelArgs: 54 | block_size: int = 2048 55 | vocab_size: int = 32000 56 | n_layer: int = 32 57 | n_head: int = 32 58 | dim: int = 4096 59 | speaker_emb_dim: int = 256 60 | intermediate_size: int = None 61 | n_local_heads: int = -1 62 | head_dim: int = 64 63 | norm_eps: float = 1e-5 64 | dtype: torch.dtype = torch.bfloat16 65 | 66 | def __post_init__(self): 67 | if self.n_local_heads == -1: 68 | self.n_local_heads = self.n_head 69 | if self.intermediate_size is None: 70 | hidden_dim = 4 * self.dim 71 | n_hidden = int(2 * hidden_dim / 3) 72 | self.intermediate_size = find_multiple(n_hidden, 256) 73 | self.head_dim = self.dim // self.n_head 74 | 75 | self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()] 76 | 77 | @classmethod 78 | def from_name(cls, name: str): 79 | if name in transformer_configs: 80 | return cls(**transformer_configs[name]) 81 | # fuzzy search 82 | config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] 83 | assert len(config) == 1, name 84 | return cls(**transformer_configs[config[0]]) 85 | 86 | 87 | transformer_configs = { 88 | "metavoice-1B": dict( 89 | n_layer=24, 90 | n_head=16, 91 | dim=2048, 92 | vocab_size=2562, 93 | ), 94 | } 95 | 96 | 97 | class KVCache(nn.Module): 98 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype): 99 | super().__init__() 100 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 101 | self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) 102 | self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) 103 | 104 | def update(self, input_pos, k_val, v_val): 105 | # input_pos: [S], k_val: [B, H, S, D] 106 | assert input_pos.shape[0] == k_val.shape[2] 107 | 108 | k_out = self.k_cache 109 | v_out = self.v_cache 110 | k_out[:, :, input_pos] = k_val 111 | v_out[:, :, input_pos] = v_val 112 | 113 | return k_out, v_out 114 | 115 | 116 | class Transformer(nn.Module): 117 | def __init__(self, config: ModelArgs) -> None: 118 | super().__init__() 119 | self.config = config 120 | 121 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 122 | self.pos_embeddings = nn.Embedding(config.block_size, config.dim) 123 | self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, 
bias=False) 124 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) 125 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 126 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 127 | 128 | self.mask_cache: Optional[Tensor] = None 129 | self.max_batch_size = -1 130 | self.max_seq_length = -1 131 | 132 | def setup_spk_cond_mask(self): 133 | self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool) 134 | self.spk_cond_mask[0] = 1 135 | 136 | def setup_caches(self, max_batch_size, max_seq_length): 137 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 138 | return 139 | head_dim = self.config.dim // self.config.n_head 140 | max_seq_length = find_multiple(max_seq_length, 8) 141 | self.max_seq_length = max_seq_length 142 | self.max_batch_size = max_batch_size 143 | for b in self.layers: 144 | b.attention.kv_cache = KVCache( 145 | max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype 146 | ) 147 | 148 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 149 | 150 | def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor: 151 | mask = self.causal_mask[None, None, input_pos] 152 | x = ( 153 | self.tok_embeddings(idx) 154 | + self.pos_embeddings(input_pos) 155 | # masking for speaker condition free guidance 156 | + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask 157 | ) 158 | 159 | for i, layer in enumerate(self.layers): 160 | x = layer(x, input_pos, mask) 161 | x = self.norm(x) 162 | logits = self.output(x) 163 | return logits 164 | 165 | @classmethod 166 | def from_name(cls, name: str): 167 | return cls(ModelArgs.from_name(name)) 168 | 169 | 170 | class TransformerBlock(nn.Module): 171 | def __init__(self, config: ModelArgs) -> None: 172 | super().__init__() 173 | self.attention = Attention(config) 174 | self.feed_forward = FeedForward(config) 175 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 176 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 177 | 178 | def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor: 179 | h = x + self.attention(self.attention_norm(x), mask, input_pos) 180 | out = h + self.feed_forward(self.ffn_norm(h)) 181 | return out 182 | 183 | 184 | class Attention(nn.Module): 185 | def __init__(self, config: ModelArgs): 186 | super().__init__() 187 | assert config.dim % config.n_head == 0 188 | 189 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 190 | # key, query, value projections for all heads, but in a batch 191 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) 192 | self.wo = nn.Linear(config.dim, config.dim, bias=False) 193 | self.kv_cache = None 194 | 195 | self.n_head = config.n_head 196 | self.head_dim = config.head_dim 197 | self.n_local_heads = config.n_local_heads 198 | self.dim = config.dim 199 | 200 | def forward( 201 | self, 202 | x: Tensor, 203 | mask: Tensor, 204 | input_pos: Optional[Tensor] = None, 205 | ) -> Tensor: 206 | bsz, seqlen, _ = x.shape 207 | 208 | kv_size = self.n_local_heads * self.head_dim 209 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 210 | 211 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 212 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 213 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 214 | 215 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 216 | 217 | if self.kv_cache 
is not None: 218 | k, v = self.kv_cache.update(input_pos, k, v) 219 | 220 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 221 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 222 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) 223 | 224 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 225 | 226 | y = self.wo(y) 227 | return y 228 | 229 | 230 | class SwiGLU(nn.Module): 231 | def __init__(self, config: ModelArgs) -> None: 232 | super().__init__() 233 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) 234 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) 235 | 236 | def forward(self, x: Tensor) -> Tensor: 237 | return F.silu(self.w1(x)) * self.w3(x) 238 | 239 | 240 | class FeedForward(nn.Module): 241 | def __init__(self, config: ModelArgs) -> None: 242 | super().__init__() 243 | self.swiglu = SwiGLU(config) 244 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) 245 | 246 | def forward(self, x: Tensor) -> Tensor: 247 | return self.w2(self.swiglu(x)) 248 | 249 | 250 | class RMSNorm(nn.Module): 251 | def __init__(self, dim: int, eps: float = 1e-5): 252 | super().__init__() 253 | self.eps = eps 254 | self.weight = nn.Parameter(torch.ones(dim)) 255 | 256 | def _norm(self, x): 257 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 258 | 259 | def forward(self, x: Tensor) -> Tensor: 260 | output = self._norm(x.float()).type_as(x) 261 | return output * self.weight 262 | -------------------------------------------------------------------------------- /fam/llm/fast_quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted 5 | # provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this list of 8 | # conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this 11 | # list of conditions and the following disclaimer in the documentation and/or other 12 | # materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its contributors 15 | # may be used to endorse or promote products derived from this software without 16 | # specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR 19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
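# How the handlers below are driven elsewhere in this repo (see _load_model in
# fam/llm/fast_inference_utils.py); reproduced here only as an illustrative sketch:
#
#   handler = WeightOnlyInt8QuantHandler(model)   # or WeightOnlyInt4QuantHandler(model, groupsize=128)
#   quantized_state_dict = handler.create_quantized_state_dict()
#   model = handler.convert_for_runtime()         # the int4 variant takes convert_for_runtime(use_cuda=True)
#   model.load_state_dict(quantized_state_dict, assign=True)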
26 | import time 27 | from pathlib import Path 28 | 29 | import torch 30 | import torch.nn as nn 31 | import torch.nn.functional as F 32 | 33 | default_device = "cuda" if torch.cuda.is_available() else "cpu" 34 | 35 | ##### Quantization Primitives ###### 36 | 37 | 38 | def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): 39 | # assumes symmetric quantization 40 | # assumes axis == 0 41 | # assumes dense memory format 42 | # TODO(future): relax ^ as needed 43 | 44 | # default setup for affine quantization of activations 45 | eps = torch.finfo(torch.float32).eps 46 | 47 | # get min and max 48 | min_val, max_val = torch.aminmax(x, dim=1) 49 | 50 | # calculate scales and zero_points based on min and max 51 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) 52 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) 53 | device = min_val_neg.device 54 | 55 | max_val_pos = torch.max(-min_val_neg, max_val_pos) 56 | scales = max_val_pos / (float(quant_max - quant_min) / 2) 57 | # ensure scales is the same dtype as the original tensor 58 | scales = torch.clamp(scales, min=eps).to(x.dtype) 59 | zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) 60 | 61 | # quantize based on qmin/qmax/scales/zp 62 | x_div = x / scales.unsqueeze(-1) 63 | x_round = torch.round(x_div) 64 | x_zp = x_round + zero_points.unsqueeze(-1) 65 | quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) 66 | 67 | return quant, scales, zero_points 68 | 69 | 70 | def get_group_qparams(w, n_bit=4, groupsize=128): 71 | # needed for GPTQ with padding 72 | if groupsize > w.shape[-1]: 73 | groupsize = w.shape[-1] 74 | assert groupsize > 1 75 | assert w.shape[-1] % groupsize == 0 76 | assert w.dim() == 2 77 | 78 | to_quant = w.reshape(-1, groupsize) 79 | assert torch.isnan(to_quant).sum() == 0 80 | 81 | max_val = to_quant.amax(dim=1, keepdim=True) 82 | min_val = to_quant.amin(dim=1, keepdim=True) 83 | max_int = 2**n_bit - 1 84 | scales = (max_val - min_val).clamp(min=1e-6) / max_int 85 | zeros = min_val + scales * (2 ** (n_bit - 1)) 86 | return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(torch.bfloat16).reshape(w.shape[0], -1) 87 | 88 | 89 | def pack_scales_and_zeros(scales, zeros): 90 | assert scales.shape == zeros.shape 91 | assert scales.dtype == torch.bfloat16 92 | assert zeros.dtype == torch.bfloat16 93 | return ( 94 | torch.cat( 95 | [ 96 | scales.reshape(scales.size(0), scales.size(1), 1), 97 | zeros.reshape(zeros.size(0), zeros.size(1), 1), 98 | ], 99 | 2, 100 | ) 101 | .transpose(0, 1) 102 | .contiguous() 103 | ) 104 | 105 | 106 | def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): 107 | assert groupsize > 1 108 | # needed for GPTQ single column quantize 109 | if groupsize > w.shape[-1] and scales.shape[-1] == 1: 110 | groupsize = w.shape[-1] 111 | 112 | assert w.shape[-1] % groupsize == 0 113 | assert w.dim() == 2 114 | 115 | to_quant = w.reshape(-1, groupsize) 116 | assert torch.isnan(to_quant).sum() == 0 117 | 118 | scales = scales.reshape(-1, 1) 119 | zeros = zeros.reshape(-1, 1) 120 | min_val = zeros - scales * (2 ** (n_bit - 1)) 121 | max_int = 2**n_bit - 1 122 | min_int = 0 123 | w_int32 = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int).to(torch.int32).reshape_as(w) 124 | 125 | return w_int32 126 | 127 | 128 | def group_quantize_tensor(w, n_bit=4, groupsize=128): 129 | scales, zeros = get_group_qparams(w, n_bit, groupsize) 130 | w_int32 = group_quantize_tensor_from_qparams(w, scales, 
zeros, n_bit, groupsize) 131 | scales_and_zeros = pack_scales_and_zeros(scales, zeros) 132 | return w_int32, scales_and_zeros 133 | 134 | 135 | def group_dequantize_tensor_from_qparams(w_int32, scales, zeros, n_bit=4, groupsize=128): 136 | assert groupsize > 1 137 | # needed for GPTQ single column dequantize 138 | if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: 139 | groupsize = w_int32.shape[-1] 140 | assert w_int32.shape[-1] % groupsize == 0 141 | assert w_int32.dim() == 2 142 | 143 | w_int32_grouped = w_int32.reshape(-1, groupsize) 144 | scales = scales.reshape(-1, 1) 145 | zeros = zeros.reshape(-1, 1) 146 | 147 | w_dq = w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) 148 | return w_dq 149 | 150 | 151 | ##### Weight-only int8 per-channel quantized code ###### 152 | 153 | 154 | def replace_linear_weight_only_int8_per_channel(module): 155 | for name, child in module.named_children(): 156 | if isinstance(child, nn.Linear): 157 | setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features)) 158 | else: 159 | replace_linear_weight_only_int8_per_channel(child) 160 | 161 | 162 | class WeightOnlyInt8QuantHandler: 163 | def __init__(self, mod): 164 | self.mod = mod 165 | 166 | @torch.no_grad() 167 | def create_quantized_state_dict(self): 168 | cur_state_dict = self.mod.state_dict() 169 | for fqn, mod in self.mod.named_modules(): 170 | # TODO: quantise RMSNorm as well. 171 | if isinstance(mod, torch.nn.Linear): 172 | int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) 173 | cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu") 174 | cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to("cpu") 175 | 176 | return cur_state_dict 177 | 178 | def convert_for_runtime(self): 179 | replace_linear_weight_only_int8_per_channel(self.mod) 180 | return self.mod 181 | 182 | 183 | class WeightOnlyInt8Linear(torch.nn.Module): 184 | __constants__ = ["in_features", "out_features"] 185 | in_features: int 186 | out_features: int 187 | weight: torch.Tensor 188 | 189 | def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: 190 | factory_kwargs = {"device": device, "dtype": dtype} 191 | super().__init__() 192 | self.in_features = in_features 193 | self.out_features = out_features 194 | self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) 195 | self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) 196 | 197 | def forward(self, input: torch.Tensor) -> torch.Tensor: 198 | return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales 199 | 200 | 201 | ##### weight only int4 per channel groupwise quantized code ###### 202 | 203 | 204 | def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): 205 | weight_int32, scales_and_zeros = group_quantize_tensor(weight_bf16, n_bit=4, groupsize=groupsize) 206 | weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) 207 | return weight_int4pack, scales_and_zeros 208 | 209 | 210 | def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): 211 | origin_x_size = x.size() 212 | x = x.reshape(-1, origin_x_size[-1]) 213 | c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) 214 | new_shape = origin_x_size[:-1] + (out_features,) 215 | c = c.reshape(new_shape) 216 | return c 217 | 218 | 219 | def _check_linear_int4_k(k, 
groupsize=1, inner_k_tiles=1): 220 | return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 221 | 222 | 223 | def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda): 224 | for name, child in module.named_children(): 225 | if isinstance(child, nn.Linear) and child.out_features % 8 == 0: 226 | if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): 227 | setattr( 228 | module, 229 | name, 230 | WeightOnlyInt4Linear( 231 | child.in_features, 232 | child.out_features, 233 | bias=False, 234 | groupsize=groupsize, 235 | inner_k_tiles=inner_k_tiles, 236 | padding=False, 237 | use_cuda=use_cuda, 238 | ), 239 | ) 240 | elif padding: 241 | setattr( 242 | module, 243 | name, 244 | WeightOnlyInt4Linear( 245 | child.in_features, 246 | child.out_features, 247 | bias=False, 248 | groupsize=groupsize, 249 | inner_k_tiles=inner_k_tiles, 250 | padding=True, 251 | use_cuda=use_cuda, 252 | ), 253 | ) 254 | else: 255 | replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda) 256 | 257 | 258 | class WeightOnlyInt4QuantHandler: 259 | def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): 260 | self.mod = mod 261 | self.groupsize = groupsize 262 | self.inner_k_tiles = inner_k_tiles 263 | self.padding = padding 264 | assert groupsize in [32, 64, 128, 256] 265 | assert inner_k_tiles in [2, 4, 8] 266 | 267 | @torch.no_grad() 268 | def create_quantized_state_dict(self): 269 | cur_state_dict = self.mod.state_dict() 270 | for fqn, mod in self.mod.named_modules(): 271 | if isinstance(mod, torch.nn.Linear): 272 | assert not mod.bias 273 | out_features = mod.out_features 274 | in_features = mod.in_features 275 | if out_features % 8 != 0: 276 | continue 277 | assert out_features % 8 == 0, "require out_features % 8 == 0" 278 | print(f"linear: {fqn}, in={in_features}, out={out_features}") 279 | 280 | weight = mod.weight.data 281 | if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): 282 | if self.padding: 283 | import torch.nn.functional as F 284 | from model import find_multiple 285 | 286 | print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") 287 | padded_in_features = find_multiple(in_features, 1024) 288 | weight = F.pad(weight, pad=(0, padded_in_features - in_features)) 289 | else: 290 | print( 291 | f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " 292 | + "and that groupsize and inner_k_tiles*16 evenly divide into it" 293 | ) 294 | continue 295 | weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( 296 | weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles 297 | ) 298 | cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") 299 | cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") 300 | 301 | return cur_state_dict 302 | 303 | def convert_for_runtime(self, use_cuda): 304 | replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda) 305 | return self.mod 306 | 307 | 308 | class WeightOnlyInt4Linear(torch.nn.Module): 309 | __constants__ = ["in_features", "out_features"] 310 | in_features: int 311 | out_features: int 312 | weight: torch.Tensor 313 | 314 | def __init__( 315 | self, 316 | in_features: int, 317 | out_features: int, 318 | bias=True, 319 | device=None, 320 | dtype=None, 321 | groupsize: int = 128, 322 | inner_k_tiles: int = 8, 323 | padding: bool = True, 324 | use_cuda=True, 325 | ) -> None: 326 | super().__init__() 327 | self.padding = padding 328 | if padding: 329 
| from model import find_multiple 330 | 331 | self.origin_in_features = in_features 332 | in_features = find_multiple(in_features, 1024) 333 | 334 | self.in_features = in_features 335 | self.out_features = out_features 336 | assert not bias, "require bias=False" 337 | self.groupsize = groupsize 338 | self.inner_k_tiles = inner_k_tiles 339 | 340 | assert out_features % 8 == 0, "require out_features % 8 == 0" 341 | assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" 342 | if use_cuda: 343 | self.register_buffer( 344 | "weight", 345 | torch.empty( 346 | (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32 347 | ), 348 | ) 349 | else: 350 | self.register_buffer("weight", torch.empty((out_features, in_features // 2), dtype=torch.uint8)) 351 | self.register_buffer( 352 | "scales_and_zeros", torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) 353 | ) 354 | 355 | def forward(self, input: torch.Tensor) -> torch.Tensor: 356 | input = input.to(torch.bfloat16) 357 | if self.padding: 358 | import torch.nn.functional as F 359 | 360 | input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) 361 | return linear_forward_int4(input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize) 362 | -------------------------------------------------------------------------------- /fam/llm/finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module responsible for finetuning the first stage LLM. 3 | """ 4 | 5 | import itertools 6 | import math 7 | import time 8 | from pathlib import Path 9 | from typing import Any, Dict, Optional 10 | 11 | import click 12 | import torch 13 | from huggingface_hub import snapshot_download 14 | from torch.utils.data import DataLoader 15 | from tqdm import tqdm 16 | 17 | from fam.llm.config.finetune_params import * 18 | from fam.llm.loaders.training_data import DynamicComputeDataset 19 | from fam.llm.model import GPT, GPTConfig 20 | from fam.llm.preprocessing.audio_token_mode import get_params_for_mode 21 | from fam.llm.preprocessing.data_pipeline import get_training_tuple 22 | from fam.llm.utils import hash_dictionary 23 | from fam.telemetry import TelemetryEvent 24 | from fam.telemetry.posthog import PosthogClient 25 | 26 | # see fam/telemetry/README.md for more information 27 | posthog = PosthogClient() 28 | 29 | dtype: Literal["bfloat16", "float16", "tfloat32", "float32"] = ( 30 | "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16" 31 | ) # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 32 | seed_offset = 0 33 | 34 | torch.manual_seed(seed + seed_offset) 35 | torch.backends.cuda.matmul.allow_tf32 = True if dtype != "float32" else False 36 | torch.backends.cudnn.allow_tf32 = True if dtype != "float32" else False 37 | device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast 38 | # note: float16 data type will automatically use a GradScaler 39 | ptdtype = {"float32": torch.float32, "tfloat32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ 40 | dtype 41 | ] 42 | ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 43 | 44 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 45 | 46 | ckpts_base_dir = pathlib.Path(__file__).resolve().parent / "ckpts" 47 | if not os.path.exists(ckpts_base_dir) and 
master_process: 48 | print("Checkpoints directory didn't exist, creating...") 49 | ckpts_base_dir.mkdir(parents=True) 50 | 51 | if master_process: 52 | if "/" in out_dir: 53 | raise Exception("out_dir should be just a name, not a path with slashes") 54 | 55 | ckpts_save_dir = ckpts_base_dir / out_dir 56 | os.makedirs(ckpts_save_dir, exist_ok=True) 57 | 58 | 59 | def get_globals_state(): 60 | """Return entirety of configuration global state which can be used for logging.""" 61 | config_keys = [k for k, v in globals().items() if not k.startswith("_") and isinstance(v, (int, float, bool, str))] 62 | return {k: globals()[k] for k in config_keys} # will be useful for logging 63 | 64 | 65 | model_args: dict = dict( 66 | n_layer=n_layer, 67 | n_head=n_head, 68 | n_embd=n_embd, 69 | block_size=block_size, 70 | bias=bias, 71 | vocab_sizes=None, 72 | dropout=dropout, 73 | causal=causal, 74 | norm_type=norm_type, 75 | rmsnorm_eps=rmsnorm_eps, 76 | nonlinearity_type=nonlinearity_type, 77 | spk_emb_on_text=spk_emb_on_text, 78 | attn_kernel_type=attn_kernel_type, 79 | swiglu_multiple_of=swiglu_multiple_of, 80 | ) # start with model_args from command line 81 | 82 | 83 | def strip_prefix(state_dict: Dict[str, Any], unwanted_prefix: str): 84 | # TODO: this also appears in fast_inference_utils._load_model, it should be moved to a common place. 85 | for k, v in list(state_dict.items()): 86 | if k.startswith(unwanted_prefix): 87 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) 88 | return state_dict 89 | 90 | 91 | def force_ckpt_args(model_args, checkpoint_model_args) -> None: 92 | # force these config attributes to be equal otherwise we can't even resume training 93 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 94 | for k in ["n_layer", "n_head", "n_embd", "block_size", "bias", "vocab_sizes", "causal"]: 95 | model_args[k] = checkpoint_model_args[k] 96 | # this enables backward compatability with previously saved checkpoints. 97 | for k in [ 98 | "target_vocab_sizes", 99 | "norm_type", 100 | "rmsnorm_eps", 101 | "nonlinearity_type", 102 | "attn_kernel_type", 103 | "spk_emb_on_text", 104 | "swiglu_multiple_of", 105 | ]: 106 | if k in checkpoint_model_args: 107 | model_args[k] = checkpoint_model_args[k] 108 | if attn_kernel_type != model_args["attn_kernel_type"]: 109 | print( 110 | f'Found {model_args["attn_kernel_type"]} kernel type inside model,', 111 | f"but expected {attn_kernel_type}. 
Manually replacing it.", 112 | ) 113 | model_args["attn_kernel_type"] = attn_kernel_type 114 | 115 | 116 | @click.command() 117 | @click.option("--train", type=click.Path(exists=True, path_type=Path), required=True) 118 | @click.option("--val", type=click.Path(exists=True, path_type=Path), required=True) 119 | @click.option("--model-id", type=str, required=False, default="metavoiceio/metavoice-1B-v0.1") 120 | @click.option("--ckpt", type=click.Path(exists=True, path_type=Path)) 121 | @click.option("--spk-emb-ckpt", type=click.Path(exists=True, path_type=Path)) 122 | def main(train: Path, val: Path, model_id: str, ckpt: Optional[Path], spk_emb_ckpt: Optional[Path]): 123 | if ckpt and spk_emb_ckpt: 124 | checkpoint_path, spk_emb_ckpt_path = ckpt, spk_emb_ckpt 125 | else: 126 | _model_dir = snapshot_download(repo_id=model_id) 127 | checkpoint_path = Path(f"{_model_dir}/first_stage.pt") 128 | spk_emb_ckpt_path = Path(f"{_model_dir}/speaker_encoder.pt") 129 | 130 | mode_params = get_params_for_mode(audio_token_mode, num_max_audio_tokens_timesteps=num_max_audio_tokens_timesteps) 131 | config = get_globals_state() 132 | 133 | checkpoint = torch.load(str(checkpoint_path), mmap=True, map_location=device) 134 | iter_num = checkpoint.get("iter_num", 0) 135 | best_val_loss = checkpoint.get("best_val_loss", 1e9) 136 | checkpoint_model_args = checkpoint["model_args"] 137 | tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {}) 138 | force_ckpt_args(model_args, checkpoint_model_args) 139 | gptconf = GPTConfig(**model_args) # type: ignore 140 | model = GPT(gptconf, speaker_emb_dim=speaker_emb_size if speaker_cond else None) 141 | 142 | # removing torch.compile module prefixes for pre-compile loading 143 | state_dict = strip_prefix(checkpoint["model"], "_orig_mod.") 144 | model.load_state_dict(state_dict) 145 | model.to(device) 146 | # initialize a GradScaler. If enabled=False scaler is a no-op 147 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) 148 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) 149 | if compile: 150 | print("Compiling the model... 
(takes a ~minute)") 151 | # requires PyTorch 2.0 152 | from einops._torch_specific import allow_ops_in_compiled_graph 153 | 154 | allow_ops_in_compiled_graph() 155 | model = torch.compile(model) # type: ignore 156 | 157 | def estimate_loss(dataset, iters: int = eval_iters): 158 | """Estimate loss on a dataset by running on `iters` batches.""" 159 | if dataset is None: 160 | return torch.nan 161 | losses = [] 162 | for _, batch in zip(tqdm(range(iters)), dataset): 163 | X, Y, SE = get_training_tuple(batch, causal, num_codebooks, speaker_cond, device) 164 | with ctx: 165 | _, loss = model(X, Y, speaker_embs=SE, speaker_emb_mask=None) 166 | losses.append(loss.item()) 167 | return torch.tensor(losses).mean() 168 | 169 | # learning rate decay scheduler (cosine with warmup) 170 | def get_lr(it): 171 | # 1) linear warmup for warmup_iters steps 172 | if it < warmup_iters: 173 | return learning_rate * it / warmup_iters 174 | # 2) if it > lr_decay_iters, return min learning rate 175 | if it > lr_decay_iters: 176 | return min_lr 177 | # 3) in between, use cosine decay down to min learning rate 178 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 179 | assert 0 <= decay_ratio <= 1 180 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 181 | return min_lr + coeff * (learning_rate - min_lr) 182 | 183 | if wandb_log and master_process: 184 | import wandb 185 | 186 | if os.environ.get("WANDB_RUN_ID", None) is not None: 187 | resume = "must" 188 | else: 189 | resume = None 190 | 191 | wandb.init(project=wandb_project, name=wandb_run_name, tags=wandb_tags, config=config, resume=resume) 192 | 193 | train_dataset = DynamicComputeDataset.from_meta( 194 | tokenizer_info, 195 | mode_params["combine_func"], 196 | spk_emb_ckpt_path, 197 | train, 198 | mode_params["pad_token"], 199 | mode_params["ctx_window"], 200 | device, 201 | ) 202 | val_dataset = DynamicComputeDataset.from_meta( 203 | tokenizer_info, 204 | mode_params["combine_func"], 205 | spk_emb_ckpt_path, 206 | val, 207 | mode_params["pad_token"], 208 | mode_params["ctx_window"], 209 | device, 210 | ) 211 | train_dataloader = itertools.cycle(DataLoader(train_dataset, batch_size, shuffle=True)) 212 | train_data = iter(train_dataloader) 213 | # we do not perform any explicit checks for dataset overlap & leave it to the user 214 | # to handle this 215 | eval_val_data = DataLoader(val_dataset, batch_size, shuffle=True) 216 | # we can use the same Dataset object given it is a mapped dataset & not an iterable 217 | # one that can be exhausted. This implies we will be needlessly recomputing, fine 218 | # for now. 
219 | eval_train_data = DataLoader(train_dataset, batch_size, shuffle=True) 220 | 221 | batch = next(train_data) 222 | X, Y, SE = get_training_tuple(batch, causal, num_codebooks, speaker_cond, device) 223 | 224 | t0 = time.time() 225 | local_iter_num = 0 # number of iterations in the lifetime of this process 226 | raw_model = model.module if ddp else model # unwrap DDP container if needed 227 | running_mfu = -1.0 228 | total_norm = 0.0 229 | save_checkpoint = False 230 | if master_process: 231 | progress = tqdm(total=max_iters, desc="Training", initial=iter_num) 232 | else: 233 | progress = None 234 | 235 | # finetune last X transformer blocks and the ln_f layer 236 | trainable_count = lambda model: sum(p.numel() for p in model.parameters() if p.requires_grad) 237 | print(f"Before layer freezing {trainable_count(model)=}...") 238 | for param in model.parameters(): 239 | param.requires_grad = False 240 | for param in itertools.chain( 241 | model.transformer.ln_f.parameters(), model.transformer.h[last_n_blocks_to_finetune * -1 :].parameters() 242 | ): 243 | param.requires_grad = True 244 | print(f"After freezing excl. last {last_n_blocks_to_finetune} transformer blocks: {trainable_count(model)=}...") 245 | 246 | # log start of finetuning event 247 | properties = { 248 | **config, 249 | **model_args, 250 | "train": str(train), 251 | "val": str(val), 252 | "model_id": model_id, 253 | "ckpt": ckpt, 254 | "spk_emb_ckpt": spk_emb_ckpt, 255 | } 256 | finetune_jobid = hash_dictionary(properties) 257 | posthog.capture( 258 | TelemetryEvent( 259 | name="user_started_finetuning", 260 | properties={"finetune_jobid": finetune_jobid, **properties}, 261 | ) 262 | ) 263 | 264 | while True: 265 | lr = get_lr(iter_num) if decay_lr else learning_rate 266 | for param_group in optimizer.param_groups: 267 | param_group["lr"] = lr 268 | if master_process: 269 | if iter_num % eval_interval == 0 and master_process: 270 | ckpt_save_name = f"ckpt_{iter_num:07d}.pt" 271 | with torch.no_grad(): 272 | model.eval() 273 | losses = { 274 | "train": estimate_loss(eval_train_data), 275 | "val": estimate_loss(eval_val_data), 276 | } 277 | model.train() 278 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 279 | if wandb_log: 280 | wandb.log( 281 | { 282 | "iter": iter_num, 283 | "train/loss": losses["train"], 284 | "val/loss": losses["val"], 285 | "lr": lr, 286 | "mfu": running_mfu * 100, # convert to percentage 287 | "stats/total_norm": total_norm, 288 | } 289 | ) 290 | if losses["val"] < best_val_loss: 291 | best_val_loss = losses["val"] 292 | if iter_num > 0: 293 | ckpt_save_name = ckpt_save_name.replace( 294 | ".pt", f"_bestval_{best_val_loss}".replace(".", "_") + ".pt" 295 | ) 296 | save_checkpoint = True 297 | 298 | save_checkpoint = save_checkpoint or iter_num % save_interval == 0 299 | if save_checkpoint and iter_num > 0: 300 | checkpoint = { 301 | "model": raw_model.state_dict(), # type: ignore 302 | "optimizer": optimizer.state_dict(), 303 | "model_args": model_args, 304 | "iter_num": iter_num, 305 | "best_val_loss": best_val_loss, 306 | "config": config, 307 | "meta": { 308 | "speaker_cond": speaker_cond, 309 | "speaker_emb_size": speaker_emb_size, 310 | "tokenizer": tokenizer_info, 311 | }, 312 | } 313 | torch.save(checkpoint, os.path.join(ckpts_save_dir, ckpt_save_name)) 314 | print(f"saving checkpoint to {ckpts_save_dir}") 315 | save_checkpoint = False 316 | if iter_num == 0 and eval_only: 317 | break 318 | # forward backward update, with optional gradient accumulation 
to simulate larger batch size 319 | # and using the GradScaler if data type is float16 320 | for micro_step in range(gradient_accumulation_steps): 321 | if ddp: 322 | # in DDP training we only need to sync gradients at the last micro step. 323 | # the official way to do this is with model.no_sync() context manager, but 324 | # I really dislike that this bloats the code and forces us to repeat code 325 | # looking at the source of that context manager, it just toggles this variable 326 | model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 # type: ignore 327 | with ctx: # type: ignore 328 | logits, loss = model(X, Y, speaker_embs=SE, speaker_emb_mask=None) 329 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 330 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 331 | batch = next(train_data) 332 | X, Y, SE = get_training_tuple( 333 | batch, 334 | causal, 335 | num_codebooks, 336 | speaker_cond, 337 | device, 338 | ) 339 | # backward pass, with gradient scaling if training in fp16 340 | scaler.scale(loss).backward() 341 | # clip the gradient 342 | if grad_clip != 0.0: 343 | scaler.unscale_(optimizer) 344 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 345 | # step the optimizer and scaler if training in fp16 346 | scaler.step(optimizer) 347 | scaler.update() 348 | # flush the gradients as soon as we can, no need for this memory anymore 349 | optimizer.zero_grad(set_to_none=True) 350 | 351 | # timing and logging 352 | t1 = time.time() 353 | dt = t1 - t0 354 | t0 = t1 355 | if master_process: 356 | # get loss as float. note: this is a CPU-GPU sync point 357 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 358 | lossf = loss.item() * gradient_accumulation_steps 359 | progress.update(1) 360 | progress.set_description(f"Training: loss {lossf:.4f}, time {dt*1000:.2f}ms") 361 | if iter_num % log_interval == 0: 362 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms") 363 | 364 | iter_num += 1 365 | local_iter_num += 1 366 | 367 | # termination conditions 368 | if iter_num > max_iters: 369 | # log end of finetuning event 370 | posthog.capture( 371 | TelemetryEvent( 372 | name="user_completed_finetuning", 373 | properties={"finetune_jobid": finetune_jobid, "loss": round(lossf, 4)}, 374 | ) 375 | ) 376 | break 377 | 378 | 379 | if __name__ == "__main__": 380 | main() 381 | -------------------------------------------------------------------------------- /fam/llm/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from fam.llm.layers.attn import SelfAttention 2 | from fam.llm.layers.combined import Block 3 | from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm, SwiGLU 4 | -------------------------------------------------------------------------------- /fam/llm/layers/attn.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class SelfAttention(nn.Module): 9 | def __init__(self, config): 10 | """ 11 | Initializes the SelfAttention module. 12 | 13 | Args: 14 | config: An object containing the configuration parameters for the SelfAttention module. 
15 | """ 16 | super().__init__() 17 | self._validate_config(config) 18 | self._initialize_parameters(config) 19 | 20 | def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype): 21 | """ 22 | Empties the key-value cache. 23 | 24 | Args: 25 | batch_size: The batch size. 26 | kv_cache_maxlen: The maximum length of the key-value cache. 27 | dtype: The data type of the cache. 28 | 29 | Raises: 30 | Exception: If trying to empty the KV cache when it is disabled. 31 | """ 32 | if self.kv_cache_enabled is False: 33 | raise Exception("Trying to empty KV cache when it is disabled") 34 | 35 | # register so that the cache moves devices along with the module 36 | # TODO: get rid of re-allocation. 37 | self.register_buffer( 38 | "kv_cache", 39 | torch.zeros( 40 | 2, 41 | batch_size, 42 | kv_cache_maxlen, 43 | self.n_head, 44 | self.n_embd // self.n_head, 45 | dtype=dtype, 46 | device=self.c_attn.weight.device, 47 | ), 48 | persistent=False, 49 | ) 50 | 51 | self.kv_cache_first_empty_index = 0 52 | 53 | def _initialize_parameters(self, config): 54 | """ 55 | Initializes the parameters of the SelfAttention module. 56 | 57 | Args: 58 | config: An object containing the configuration parameters for the SelfAttention module. 59 | """ 60 | # key, query, value projections for all heads, but in a batch 61 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 62 | 63 | # output projection 64 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 65 | 66 | # regularization 67 | self.resid_dropout = nn.Dropout(config.dropout) 68 | self.n_head = config.n_head 69 | self.n_embd = config.n_embd 70 | self.dropout = config.dropout 71 | self.causal = config.causal 72 | self.attn_kernel_type = config.attn_kernel_type 73 | self.attn_dropout = nn.Dropout(config.dropout) 74 | 75 | self.kv_cache_enabled = False 76 | 77 | def _validate_config(self, config): 78 | """ 79 | Validates the configuration parameters. 80 | 81 | Args: 82 | config: An object containing the configuration parameters for the SelfAttention module. 83 | 84 | Raises: 85 | AssertionError: If the embedding dimension is not divisible by the number of heads. 86 | """ 87 | assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads" 88 | 89 | def _update_kv_cache(self, q, k, v): 90 | """ 91 | Updates the key-value cache. 92 | 93 | Args: 94 | q: The query tensor. 95 | k: The key tensor. 96 | v: The value tensor. 97 | 98 | Returns: 99 | The updated key and value tensors. 100 | 101 | Raises: 102 | AssertionError: If the dimensions of the query, key, and value tensors are not compatible. 
103 | """ 104 | q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1] 105 | 106 | if self.kv_cache_first_empty_index == 0: 107 | assert q_time == k_time and q_time == v_time 108 | else: 109 | assert ( 110 | q_time == 1 111 | ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}" 112 | 113 | self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k 114 | self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v 115 | self.kv_cache_first_empty_index += q_time 116 | 117 | k = self.kv_cache[0, :, : self.kv_cache_first_empty_index] 118 | v = self.kv_cache[1, :, : self.kv_cache_first_empty_index] 119 | 120 | return k, v 121 | 122 | def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor: 123 | """ 124 | Performs attention using the torch.nn.functional.scaled_dot_product_attention function. 125 | 126 | Args: 127 | c_x: The input tensor. 128 | 129 | Returns: 130 | The output tensor. 131 | """ 132 | q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, 1, nh, hs) 133 | q = q.squeeze(2) # (B, T, nh, hs) 134 | k = k.squeeze(2) # (B, T, nh, hs) 135 | v = v.squeeze(2) # (B, T, nh, hs) 136 | 137 | # if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and 138 | # use no mask for the "one time step" parts. 139 | # calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index 140 | is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0) 141 | 142 | if self.kv_cache_enabled: 143 | k, v = self._update_kv_cache(q, k, v) 144 | 145 | q = q.transpose(1, 2) # (B, nh, T, hs) 146 | k = k.transpose(1, 2) # (B, nh, T, hs) 147 | v = v.transpose(1, 2) # (B, nh, T, hs) 148 | y = torch.nn.functional.scaled_dot_product_attention( 149 | q, 150 | k, 151 | v, 152 | attn_mask=None, 153 | dropout_p=self.dropout if self.training else 0, 154 | is_causal=is_causal_attn_mask, 155 | ).transpose( 156 | 1, 2 157 | ) # (B, nh, T, hs) -> (B, T, nh, hs) 158 | 159 | return y 160 | 161 | def forward(self, x): 162 | """ 163 | Performs the forward pass of the SelfAttention module. 164 | 165 | Args: 166 | x: The input tensor. 167 | 168 | Returns: 169 | The output tensor. 170 | """ 171 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 172 | 173 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 174 | c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs) 175 | 176 | # causal self-attention; 177 | if self.attn_kernel_type == "torch_attn": 178 | y = self._torch_attn(c_x) 179 | else: 180 | raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}") 181 | 182 | y = y.contiguous().view(B, T, C) # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh) 183 | # output projection 184 | y = self.resid_dropout(self.c_proj(y)) 185 | return y 186 | -------------------------------------------------------------------------------- /fam/llm/layers/combined.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from fam.llm.layers.attn import SelfAttention 4 | from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm 5 | 6 | 7 | class Block(nn.Module): 8 | """ 9 | Block class represents a single block in the model. 
10 | 11 | Args: 12 | config (object): Configuration object containing parameters for the block. 13 | 14 | Attributes: 15 | ln_1 (object): Layer normalization for the attention layer. 16 | ln_2 (object): Layer normalization for the feed-forward layer. 17 | attn (object): Self-attention layer. 18 | mlp (object): Multi-layer perceptron layer. 19 | 20 | Methods: 21 | forward(x): Performs forward pass through the block. 22 | """ 23 | 24 | def __init__(self, config): 25 | super().__init__() 26 | if config.norm_type == "rmsnorm": 27 | if config.rmsnorm_eps is None: 28 | raise Exception("RMSNorm requires rmsnorm_eps to be set") 29 | self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # attn norm 30 | self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # ffn norm 31 | elif config.norm_type == "layernorm": 32 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) # attn norm 33 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) # ffn norm 34 | else: 35 | raise Exception(f"Unknown norm type: {config.norm_type}") 36 | self.attn = SelfAttention(config) 37 | 38 | self.mlp = MLP(config) 39 | 40 | def forward(self, x): 41 | """ 42 | Performs forward pass through the block. 43 | 44 | Args: 45 | x (tensor): Input tensor. 46 | 47 | Returns: 48 | tensor: Output tensor after passing through the block. 49 | """ 50 | x = x + self.attn(self.ln_1(x)) 51 | x = x + self.mlp(self.ln_2(x)) 52 | return x 53 | -------------------------------------------------------------------------------- /fam/llm/layers/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class LayerNorm(nn.Module): 9 | """LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False""" 10 | 11 | def __init__(self, ndim, bias): 12 | super().__init__() 13 | self.weight = nn.Parameter(torch.ones(ndim)) 14 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 15 | 16 | def forward(self, input): 17 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 18 | 19 | 20 | class RMSNorm(nn.Module): 21 | def __init__(self, ndim: int, eps: float): 22 | super().__init__() 23 | self.eps = eps 24 | self.weight = nn.Parameter(torch.ones(ndim)) 25 | 26 | def _norm(self, x): 27 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 28 | 29 | def forward(self, x): 30 | return self._norm(x) * self.weight 31 | 32 | 33 | class SwiGLU(nn.Module): 34 | def __init__(self, in_dim, out_dim, bias) -> None: 35 | super().__init__() 36 | self.w1 = nn.Linear(in_dim, out_dim, bias=bias) 37 | self.w3 = nn.Linear(in_dim, out_dim, bias=bias) 38 | 39 | def forward(self, x): 40 | return F.silu(self.w1(x)) * self.w3(x) 41 | 42 | 43 | class MLP(nn.Module): 44 | def __init__(self, config): 45 | super().__init__() 46 | self.non_linearity = config.nonlinearity_type 47 | hidden_dim = 4 * config.n_embd 48 | if config.nonlinearity_type == "gelu": 49 | self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias) 50 | self.gelu = nn.GELU() 51 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias) 52 | elif config.nonlinearity_type == "swiglu": 53 | if config.swiglu_multiple_of is None: 54 | raise Exception("SwiGLU requires swiglu_multiple_of to be set") 55 | hidden_dim = int(2 * hidden_dim / 3) 56 | hidden_dim = config.swiglu_multiple_of * math.ceil(hidden_dim / config.swiglu_multiple_of) 57 | # set name to `c_proj` so that the right initialisation gets applied to it in GPT.__init__() 58 | self.swiglu = SwiGLU(config.n_embd, hidden_dim, bias=config.bias) 59 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias) 60 | else: 61 | raise Exception(f"Unknown nonlinearity type: {config.nonlinearity_type}") 62 | self.dropout = nn.Dropout(config.dropout) 63 | 64 | def forward(self, x): 65 | if self.non_linearity == "gelu": 66 | x = self.c_fc(x) 67 | x = self.gelu(x) 68 | elif self.non_linearity == "swiglu": 69 | x = self.swiglu(x) 70 | x = self.c_proj(x) 71 | x = self.dropout(x) 72 | return x 73 | -------------------------------------------------------------------------------- /fam/llm/loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/llm/loaders/__init__.py -------------------------------------------------------------------------------- /fam/llm/loaders/training_data.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Mapping 3 | 4 | import julius 5 | import torch 6 | import math 7 | import numpy as np 8 | import pandas as pd 9 | from audiocraft.data.audio import audio_read 10 | from encodec import EncodecModel 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | from fam.llm.fast_inference_utils import encode_tokens 14 | from fam.llm.preprocessing.audio_token_mode import CombinerFuncT, CombinerFuncT 15 | from fam.llm.preprocessing.data_pipeline import pad_tokens 16 | from fam.llm.utils import normalize_text 17 | from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder 18 | from fam.quantiser.text.tokenise import TrainedBPETokeniser 
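# Expected dataset layout (illustrative; the column names below are placeholders — only the
# two-column, "|"-delimited shape is what __getitem__ relies on, see
# datasets/sample_dataset.csv for the actual format):
#
#   audio_path|caption
#   data/audio.wav|An example transcript for the clip.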
19 | 20 | MBD_SAMPLE_RATE = 24000 21 | ENCODEC_BANDWIDTH = 6 22 | 23 | 24 | class DynamicComputeDataset(Dataset): 25 | def __init__( 26 | self, 27 | dataset_dir: Path | str, 28 | encodec_model: EncodecModel, 29 | tokenizer: TrainedBPETokeniser, 30 | spkemb_model: SpeakerEncoder, 31 | combiner: CombinerFuncT, 32 | pad_token: int, 33 | ctx_window: int, 34 | device: str, 35 | ): 36 | self.dataset_dir = dataset_dir 37 | self.encodec_model = encodec_model 38 | self.tokenizer = tokenizer 39 | self.spkemb_model = spkemb_model 40 | self.device = device 41 | self.combiner = combiner 42 | self.pad_token = pad_token 43 | self.ctx_window = ctx_window 44 | self.df = pd.read_csv(dataset_dir, delimiter="|", index_col=False) 45 | 46 | @classmethod 47 | def from_meta( 48 | cls, 49 | tokenizer_info: Mapping[str, Any], 50 | combiner: CombinerFuncT, 51 | speaker_embedding_ckpt_path: Path | str, 52 | dataset_dir: Path | str, 53 | pad_token: int, 54 | ctx_window: int, 55 | device: str 56 | ): 57 | encodec = EncodecModel.encodec_model_24khz().to(device) 58 | encodec.set_target_bandwidth(ENCODEC_BANDWIDTH) 59 | smodel = SpeakerEncoder( 60 | weights_fpath=str(speaker_embedding_ckpt_path), 61 | eval=True, 62 | device=device, 63 | verbose=False, 64 | ) 65 | tokeniser = TrainedBPETokeniser(**tokenizer_info) 66 | 67 | return cls( 68 | dataset_dir, 69 | encodec, 70 | tokeniser, 71 | smodel, 72 | combiner, 73 | pad_token, 74 | ctx_window, 75 | device 76 | ) 77 | 78 | def __len__(self): 79 | return len(self.df) 80 | 81 | def __getitem__(self, idx): 82 | audio_path, text = self.df.iloc[idx].values.tolist() 83 | with torch.no_grad(): 84 | text_tokens = self._extract_text_tokens(text) 85 | encodec_tokens = self._extract_encodec_tokens(audio_path) 86 | speaker_embedding = self._extract_speaker_embedding(audio_path) 87 | combined = self.combiner(encodec_tokens, text_tokens) 88 | padded_combined_tokens = pad_tokens(combined, self.ctx_window, self.pad_token) 89 | 90 | return {"tokens": padded_combined_tokens, "spkemb": speaker_embedding} 91 | 92 | def _extract_text_tokens(self, text: str): 93 | _text = normalize_text(text) 94 | _tokens = encode_tokens(self.tokenizer, _text, self.device) 95 | 96 | return _tokens.detach().cpu().numpy() 97 | 98 | def _extract_encodec_tokens(self, audio_path: str): 99 | wav, sr = audio_read(audio_path) 100 | if sr != MBD_SAMPLE_RATE: 101 | wav = julius.resample_frac(wav, sr, MBD_SAMPLE_RATE) 102 | 103 | # Convert to mono and fix dimensionality 104 | if wav.ndim == 2: 105 | wav = wav.mean(axis=0, keepdims=True) 106 | wav = wav.unsqueeze(0) # Add batch dimension 107 | 108 | wav = wav.to(self.device) 109 | tokens = self.encodec_model.encode(wav) 110 | _tokens = tokens[0][0][0].detach().cpu().numpy() # shape = [8, T] 111 | 112 | return _tokens 113 | 114 | def _extract_speaker_embedding(self, audio_path: str): 115 | emb = self.spkemb_model.embed_utterance_from_file(audio_path, numpy=False) # shape = [256,] 116 | return emb.unsqueeze(0).detach() 117 | -------------------------------------------------------------------------------- /fam/llm/mixins/__init__.py: -------------------------------------------------------------------------------- 1 | from fam.llm.mixins.causal import CausalInferenceMixin 2 | from fam.llm.mixins.non_causal import NonCausalInferenceMixin 3 | -------------------------------------------------------------------------------- /fam/llm/mixins/non_causal.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import 
torch 4 | from torch.nn import functional as F 5 | 6 | 7 | class NonCausalInferenceMixin: 8 | """ 9 | Mixin class for non-causal inference in a language model. 10 | 11 | This class provides methods for performing non-causal sampling using a language model. 12 | """ 13 | 14 | @torch.no_grad() 15 | def _non_causal_sample( 16 | self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: int 17 | ): 18 | """ 19 | Perform non-causal sampling. 20 | 21 | Args: 22 | idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length). 23 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings tensor of shape (batch_size, embedding_size). 24 | temperature (float): Temperature parameter for scaling the logits. 25 | top_k (int): Number of top options to consider. 26 | 27 | Returns: 28 | torch.Tensor: Sampled output tensor of shape (batch_size, num_out_hierarchies, sequence_length). 29 | """ 30 | b, c, t = idx.size() 31 | assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}" 32 | # forward the model to get the logits for the index in the sequence 33 | list_logits, _ = self(idx, speaker_embs=speaker_embs) # c x (b, t, vocab_size) 34 | 35 | # scale by desired temperature 36 | list_logits = [logits / temperature for logits in list_logits] # c x (b, t, vocab_size) 37 | 38 | # optionally crop the logits to only the top k options 39 | if top_k is not None: 40 | for i in range(len(list_logits)): 41 | logits = list_logits[i] # (b, t, vocab_size) 42 | 43 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) # (b, t, top_k) 44 | logits[logits < v[:, :, [-1]]] = -float("Inf") 45 | list_logits[i] = logits # (b, t, vocab_size) 46 | assert logits.shape[0] == b and logits.shape[1] == t 47 | 48 | # apply softmax to convert logits to (normalized) probabilities 49 | # TODO: check shapes here! 
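# Shape note: each element of list_logits is still (b, t, vocab_size) at this point — the
# top-k filter above only sets non-top-k logits to -inf, so after the softmax below those
# positions simply get probability 0; the tensors are not narrowed to top_k columns.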
50 | probs = [F.softmax(logits, dim=-1) for logits in list_logits] # c x (b, t, top_k) 51 | assert probs[0].shape[0] == b and probs[0].shape[1] == t 52 | 53 | # TODO: output shape is as expected 54 | outs = [] 55 | for b_prob in probs: # c x (b, t, top_k) -> (b, t, top_k) 56 | out = [ 57 | torch.multinomial(prob, num_samples=1).transpose(0, 1).unsqueeze(0) for prob in b_prob 58 | ] # b x (t, top_k) -> b x (t, 1) -> b x (1, t) -> b x (1, 1, t) 59 | assert len(out) == b and out[0].shape[0] == 1 and out[0].shape[1] == 1 and out[0].shape[2] == t 60 | out = torch.cat(out, dim=0) # (b, 1, t) 61 | assert out.shape[0] == b and out.shape[1] == 1 and out.shape[2] == t 62 | outs.append(out) 63 | 64 | out = torch.cat(outs, dim=1) # (b, c, t) 65 | assert out.shape[0] == b and out.shape[2] == t 66 | 67 | return out 68 | -------------------------------------------------------------------------------- /fam/llm/model.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from dataclasses import dataclass, field 4 | from typing import Literal, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import tqdm 9 | from einops import rearrange 10 | from torch.nn import functional as F 11 | 12 | from fam.llm.layers import Block, LayerNorm, RMSNorm 13 | from fam.llm.mixins import CausalInferenceMixin, NonCausalInferenceMixin 14 | 15 | END_OF_TEXT_TOKEN = 1537 16 | 17 | 18 | def _select_spkemb(spkemb, mask): 19 | _, examples, _ = spkemb.shape 20 | mask = torch.nn.functional.one_hot(mask.long(), num_classes=examples).to(spkemb) # shape: (batch, time, examples) 21 | spkemb = spkemb.transpose(1, 2) # b ex c -> b c ex 22 | mask = mask.transpose(1, 2) # b t ex -> b ex t 23 | return torch.bmm(spkemb, mask).transpose(1, 2) # b c t -> b t c 24 | 25 | 26 | @dataclass 27 | class GPTConfig: 28 | block_size: int = 1024 29 | vocab_sizes: list = field(default_factory=list) 30 | target_vocab_sizes: Optional[list] = None 31 | n_layer: int = 12 32 | n_head: int = 12 33 | n_embd: int = 768 34 | dropout: float = 0.0 35 | spkemb_dropout: float = 0.0 36 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 37 | causal: bool = ( 38 | True # auto-regressive or not, i.e. whether to have attention mask that prevents attending to future tokens 39 | ) 40 | spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not 41 | norm_type: str = "layernorm" # "rmsnorm" or "layernorm 42 | rmsnorm_eps: Optional[float] = None # only used for rmsnorm 43 | nonlinearity_type: str = "gelu" # "gelu" or "swiglu" 44 | swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this 45 | attn_kernel_type: Literal["torch_attn"] = "torch_attn" 46 | kv_cache_enabled: bool = False # whether to use key-value cache for attention 47 | 48 | 49 | def _check_speaker_emb_dims( 50 | speaker_embs: Union[list, torch.Tensor], expected_speaker_emb_dim: int, expected_batch_size: int 51 | ) -> Union[torch.Tensor, list]: 52 | """ 53 | Checks that the speaker embedding dimensions are correct, and reshapes them if necessary. 
54 | """ 55 | if type(speaker_embs) == list: 56 | b_se = len(speaker_embs) 57 | for i, s in enumerate(speaker_embs): 58 | if s is not None: 59 | emb_dim = s.shape[-1] 60 | if s.ndim == 1: 61 | speaker_embs[i] = speaker_embs[i].unsqueeze(0) 62 | else: 63 | if speaker_embs.ndim == 2: 64 | # if we have a single speaker embedding for the whole sequence, 65 | # add a dummy dimension for backwards compatibility 66 | speaker_embs = speaker_embs[:, None, :] 67 | 68 | # num_examples is the number of utterances packed into this sequence 69 | b_se, num_examples, emb_dim = speaker_embs.size() 70 | 71 | assert b_se == expected_batch_size, f"Batch size mismatch: {b_se} != {expected_batch_size}" 72 | assert ( 73 | emb_dim == expected_speaker_emb_dim 74 | ), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}" 75 | 76 | return speaker_embs 77 | 78 | 79 | class GPT(nn.Module, NonCausalInferenceMixin, CausalInferenceMixin): 80 | def __init__(self, config: GPTConfig, speaker_emb_dim: Optional[int] = None): 81 | """ 82 | Initialize the GPT model. 83 | 84 | Args: 85 | config (GPTConfig): Configuration object for the model. 86 | speaker_emb_dim (Optional[int]): Dimension of the speaker embedding. Default is None. 87 | """ 88 | super().__init__() 89 | assert config.vocab_sizes is not None 90 | assert config.block_size is not None 91 | self.config = config 92 | 93 | self.kv_cache_enabled = False # disabled by default 94 | self.kv_pos = 0 95 | 96 | self.speaker_emb_dim = speaker_emb_dim 97 | self.spk_emb_on_text = config.spk_emb_on_text 98 | if self.config.causal is True and self.spk_emb_on_text is False: 99 | print("!!!!!!!!!!!!!!!!!!") 100 | print( 101 | f"!!!!!!!! Using DEFAULT of {END_OF_TEXT_TOKEN} as end of text token to find speaker cond masking!! You likely need to change this." 102 | ) 103 | print("!!!!!!!!!!!!!!!!!!") 104 | if self.config.causal is False and self.spk_emb_on_text is False: 105 | raise Exception( 106 | "Cannot use speaker embedding masking with non-causal model. This is unexpected. Check for relevant changes required in code before proceeding." 107 | ) 108 | 109 | if config.norm_type == "rmsnorm": 110 | if config.rmsnorm_eps is None: 111 | raise Exception("RMSNorm requires rmsnorm_eps to be set") 112 | ln_f = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) 113 | elif config.norm_type == "layernorm": 114 | ln_f = LayerNorm(config.n_embd, bias=config.bias) 115 | else: 116 | raise Exception(f"Unknown norm type: {config.norm_type}") 117 | 118 | self.transformer = nn.ModuleDict( 119 | dict( 120 | wtes=nn.ModuleList([nn.Embedding(vsize, config.n_embd) for vsize in config.vocab_sizes]), 121 | wpe=nn.Embedding(config.block_size, config.n_embd), 122 | drop=nn.Dropout(config.dropout), 123 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 124 | ln_f=ln_f, 125 | ) 126 | ) 127 | if speaker_emb_dim is not None: 128 | self.speaker_cond_pos = nn.Linear(speaker_emb_dim, config.n_embd, bias=False) 129 | 130 | self.lm_heads = nn.ModuleList() 131 | if config.target_vocab_sizes is not None: 132 | assert config.causal is False 133 | else: 134 | assert config.causal is True 135 | 136 | for vsize in config.vocab_sizes if config.target_vocab_sizes is None else config.target_vocab_sizes: 137 | self.lm_heads.append(nn.Linear(config.n_embd, vsize, bias=False)) 138 | 139 | if config.target_vocab_sizes is None: 140 | for i in range(len(config.vocab_sizes)): 141 | # TODO: do we not need to take the transpose here? 
142 | # https://paperswithcode.com/method/weight-tying 143 | self.lm_heads[i].weight = self.transformer.wtes[i].weight # type: ignore 144 | assert len(self.lm_heads) == len( 145 | self.transformer.wtes # type: ignore 146 | ), f"Number of heads ({len(self.lm_heads)}) must match number of one-hot embedding matrics ({len(self.transformer.wtes)})." # type: ignore 147 | 148 | # init all weights 149 | self.apply(self._init_weights) 150 | # apply special scaled init to the residual projections, per GPT-2 paper 151 | for pn, p in self.named_parameters(): 152 | if pn.endswith("c_proj.weight"): 153 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 154 | 155 | # report number of parameters 156 | print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) 157 | 158 | def get_num_params(self, non_embedding=True): 159 | """ 160 | Return the number of parameters in the model. 161 | For non-embedding count (default), the position embeddings get subtracted. 162 | The token embeddings would too, except due to the parameter sharing these 163 | params are actually used as weights in the final layer, so we include them. 164 | """ 165 | n_params = sum(p.numel() for p in self.parameters()) 166 | if non_embedding: 167 | n_params -= self.transformer.wpe.weight.numel() 168 | return n_params 169 | 170 | def _init_weights(self, module): 171 | if isinstance(module, nn.Linear): 172 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 173 | if module.bias is not None: 174 | torch.nn.init.zeros_(module.bias) 175 | elif isinstance(module, nn.Embedding): 176 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 177 | 178 | def _mask_spk_emb_on_text(self, idx: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor: 179 | """ 180 | This is in a separate function so we can test it easily. 181 | """ 182 | # find index of end of text token in each sequence, then generate a binary mask 183 | # of shape (b, 1, t) to mask out the speaker embedding for all tokens before the end of text token. 184 | # Note: this does NOT mask the token. This is important so that the first audio token predicted 185 | # has speaker information to use. 186 | 187 | # Check in channel dimension 0 as this is usually the first hierarchy where we put the text tokens. 188 | is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN 189 | # use > 0, in case end_of_text_token is repeated for any reason. 
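# Worked example (illustrative token names): for one sequence with
#   idx[:, 0, :]     = [t0, t1, EOT, a0, a1]
#   is_end_of_text   = [ 0,  0,   1,  0,  0]
#   cumsum > 0 mask  = [ 0,  0,   1,  1,  1]
# i.e. the speaker embedding is zeroed on the text tokens and kept from the end-of-text
# token onwards, so the first predicted audio token still sees speaker information.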
190 | mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float() 191 | spk_emb = spk_emb * mask[:, :, None] 192 | 193 | return spk_emb 194 | 195 | def forward( 196 | self, 197 | idx, 198 | targets=None, 199 | speaker_embs=None, 200 | speaker_emb_mask=None, 201 | loss_reduce: Literal["mean", "none"] = "mean", 202 | ): 203 | device = idx.device 204 | b, num_hierarchies, t = idx.size() 205 | 206 | if speaker_embs is not None: 207 | speaker_embs = _check_speaker_emb_dims( 208 | speaker_embs=speaker_embs, expected_speaker_emb_dim=self.speaker_emb_dim, expected_batch_size=b 209 | ) 210 | 211 | assert ( 212 | t <= self.config.block_size 213 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 214 | 215 | if self.kv_cache_enabled: 216 | if self.kv_pos == 0: 217 | pos = torch.arange(0, t, dtype=torch.long, device=device) 218 | self.kv_pos += t 219 | else: 220 | assert t == 1, "KV cache is only supported for single token inputs" 221 | pos = torch.tensor([self.kv_pos], dtype=torch.long, device=device) # shape (1) 222 | self.kv_pos += 1 223 | else: 224 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 225 | 226 | # forward the GPT model itself 227 | assert num_hierarchies == len( 228 | self.transformer.wtes 229 | ), f"Input tensor has {num_hierarchies} hierarchies, but model has {len(self.transformer.wtes)} set of input embeddings." 230 | 231 | # embed the tokens, positional encoding, and speaker embedding 232 | tok_emb = torch.zeros((b, t, self.config.n_embd), device=device) 233 | # ends up swapping (B, num_hierarchies, t) tokens -> (B, t, c) embeddings. 234 | for i, wte in enumerate(self.transformer.wtes): 235 | tok_emb += wte(idx[:, i, :]) 236 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 237 | 238 | spk_emb = 0.0 239 | if speaker_embs is not None: 240 | if type(speaker_embs) == list: 241 | assert speaker_emb_mask is None 242 | assert self.training is False 243 | assert self.spk_emb_on_text is True 244 | 245 | spk_emb = [] 246 | for speaker_emb_row in speaker_embs: 247 | if speaker_emb_row is not None: 248 | spk_emb.append(self.speaker_cond_pos(speaker_emb_row.unsqueeze(0))) 249 | assert spk_emb[-1].shape == (1, 1, self.config.n_embd), f"spk_emb[-1].shape={spk_emb[-1].shape}" 250 | else: 251 | spk_emb.append(torch.zeros((1, 1, self.config.n_embd), device=device, dtype=pos_emb.dtype)) 252 | spk_emb = torch.cat(spk_emb, dim=0) 253 | 254 | assert ( 255 | spk_emb.ndim == 3 and spk_emb.shape[1] == 1 and spk_emb.shape[0] == b 256 | ), f"spk_emb.ndim={spk_emb.ndim}, spk_emb.shape={spk_emb.shape}, len(speaker_embs)={len(speaker_embs)}" 257 | else: 258 | speakers_embedded = self.speaker_cond_pos(speaker_embs) # shape (b, num_examples, c) 259 | 260 | if speaker_emb_mask is not None: 261 | spk_emb = _select_spkemb(speakers_embedded, speaker_emb_mask) 262 | assert spk_emb.shape == (b, t, self.config.n_embd) 263 | else: 264 | spk_emb = speakers_embedded 265 | # if we don't have a mask, we assume that the speaker embedding is the same for all tokens 266 | # then num_examples dimension just becomes the time dimension 267 | assert spk_emb.ndim == 3 and spk_emb.shape[1] == 1 268 | 269 | if self.training and self.config.spkemb_dropout > 0.0: 270 | # Remove speaker conditioning at random. 
271 | dropout = torch.ones_like(speakers_embedded) * ( 272 | torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout 273 | ) 274 | spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded) 275 | 276 | if self.spk_emb_on_text is False: 277 | assert speaker_emb_mask is None, "Not implemented for spk_emb_on_text=False" 278 | spk_emb = self._mask_spk_emb_on_text(idx, spk_emb) 279 | 280 | x = self.transformer.drop(tok_emb + pos_emb + spk_emb) 281 | for block in self.transformer.h: 282 | x = block(x) 283 | x = self.transformer.ln_f(x) 284 | 285 | if targets is not None: 286 | # if we are given some desired targets also calculate the loss 287 | list_logits = [lm_head(x) for lm_head in self.lm_heads] 288 | 289 | losses = [ 290 | F.cross_entropy( 291 | logits.view(-1, logits.size(-1)), 292 | targets[:, i, :].contiguous().view(-1), 293 | ignore_index=-1, 294 | reduction=loss_reduce, 295 | ) 296 | for i, logits in enumerate(list_logits) 297 | ] 298 | # TODO: should we do this better without stack somehow? 299 | losses = torch.stack(losses) 300 | if loss_reduce == "mean": 301 | losses = losses.mean() 302 | else: 303 | losses = rearrange(losses, "h (b t) -> b h t", h=len(self.lm_heads), b=b, t=t) 304 | else: 305 | # inference-time mini-optimization: only forward the lm_head on the very last position 306 | if self.config.causal: 307 | list_logits = [ 308 | lm_head(x[:, [-1], :]) for lm_head in self.lm_heads 309 | ] # note: using list [-1] to preserve the time dim 310 | else: 311 | list_logits = [lm_head(x) for lm_head in self.lm_heads] 312 | losses = None 313 | 314 | return list_logits, losses 315 | 316 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 317 | # start with all of the candidate parameters 318 | param_dict = {pn: p for pn, p in self.named_parameters()} 319 | # filter out those that do not require grad 320 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 321 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 322 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
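# Concretely: the wte/wpe embedding tables and every Linear weight (dim == 2) fall into the
# weight-decay group, while LayerNorm/RMSNorm weights and all biases (dim == 1) do not.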
323 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 324 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 325 | optim_groups = [ 326 | {"params": decay_params, "weight_decay": weight_decay}, 327 | {"params": nodecay_params, "weight_decay": 0.0}, 328 | ] 329 | num_decay_params = sum(p.numel() for p in decay_params) 330 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 331 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 332 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 333 | # Create AdamW optimizer and use the fused version if it is available 334 | fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters 335 | use_fused = fused_available and device_type == "cuda" 336 | extra_args = dict(fused=True) if use_fused else dict() 337 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 338 | print(f"using fused AdamW: {use_fused}") 339 | 340 | return optimizer 341 | 342 | @torch.no_grad() 343 | def generate( 344 | self, 345 | idx: torch.Tensor, 346 | max_new_tokens: int, 347 | seq_lens: Optional[list] = None, 348 | temperature: float = 1.0, 349 | top_k: Optional[int] = None, 350 | top_p: Optional[float] = None, 351 | speaker_embs: Optional[torch.Tensor] = None, 352 | batch_size: Optional[int] = None, 353 | guidance_scale: Optional[Tuple[float, float]] = None, 354 | dtype: torch.dtype = torch.bfloat16, 355 | end_of_audio_token: int = 99999, # Dummy values will disable early termination / guidance features. 356 | end_of_text_token: int = 99999, 357 | ): 358 | """ 359 | Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete 360 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 361 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
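        Example (illustrative only — argument values are placeholders for the causal path,
        which requires seq_lens and batch_size):

            out = model.generate(
                idx,                                      # (b, num_hierarchies, t)
                max_new_tokens=256,
                seq_lens=[idx.shape[-1]] * idx.shape[0],
                batch_size=idx.shape[0],
                temperature=1.0,
                top_p=0.95,
                speaker_embs=speaker_embs,
            )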
362 | """ 363 | assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens" 364 | 365 | if self.config.causal: 366 | if seq_lens is None or batch_size is None: 367 | raise Exception("seq_lens and batch_size must be provided for causal sampling") 368 | 369 | return self._causal_sample( 370 | idx=idx, 371 | max_new_tokens=max_new_tokens, 372 | seq_lens=seq_lens, 373 | temperature=temperature, 374 | top_k=top_k, 375 | top_p=top_p, 376 | speaker_embs=speaker_embs, 377 | batch_size=batch_size, 378 | guidance_scale=guidance_scale, 379 | dtype=dtype, 380 | end_of_audio_token=end_of_audio_token, 381 | end_of_text_token=end_of_text_token, 382 | ) 383 | 384 | else: 385 | if seq_lens is not None: 386 | raise Exception("seq_lens is not supported yet for non-causal sampling") 387 | 388 | if batch_size is None: 389 | raise Exception("batch_size must be provided for non-causal sampling") 390 | 391 | if guidance_scale is not None: 392 | raise Exception("guidance_scale is not supported for non-causal sampling") 393 | 394 | if top_p is not None: 395 | raise Exception("top_p is not supported for non-causal sampling") 396 | 397 | out = [] 398 | for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="Non-causal batching"): 399 | end_index = min(start_index + batch_size, idx.shape[0]) 400 | out.append( 401 | self._non_causal_sample( 402 | idx=idx[start_index:end_index], 403 | speaker_embs=speaker_embs[start_index:end_index] if speaker_embs is not None else None, 404 | temperature=temperature, 405 | top_k=top_k, 406 | ) 407 | ) 408 | return torch.cat(out, dim=0) 409 | -------------------------------------------------------------------------------- /fam/llm/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/llm/preprocessing/__init__.py -------------------------------------------------------------------------------- /fam/llm/preprocessing/audio_token_mode.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, Literal, Optional 3 | 4 | import numpy as np 5 | 6 | AudioTokenModeT = Literal["flattened_interleaved"] 7 | CombinerWithOffsetFuncT = Callable[[np.ndarray, np.ndarray, int], np.ndarray] 8 | CombinerFuncT = Callable[[np.ndarray, np.ndarray], np.ndarray] 9 | 10 | 11 | def combine_tokens_flattened_interleaved( 12 | audio_tokens: np.ndarray, text_tokens: np.ndarray, second_hierarchy_flattening_offset: int 13 | ) -> np.ndarray: 14 | """ 15 | Flattens & interleaves first 2 of the audio token hierarchies. Note that the tokens for the second hierarchy 16 | are also offset by second_hierarchy_flattening_offset as part of this transform to avoid conflict with values for the 17 | first hierarchy. 18 | """ 19 | assert np.issubdtype(audio_tokens.dtype, np.integer) 20 | assert np.issubdtype(text_tokens.dtype, np.integer) 21 | 22 | num_hierarchies = audio_tokens.shape[0] 23 | assert num_hierarchies >= 2, f"Unexpected number of hierarchies: {num_hierarchies}. Expected at least 2." 24 | 25 | # choosing -5 so that we can't get error! 
26 | interleaved_audio_tokens = np.full((len(audio_tokens[0]) + len(audio_tokens[1]),), -5) 27 | interleaved_audio_tokens[::2] = audio_tokens[0] 28 | interleaved_audio_tokens[1::2] = audio_tokens[1] + second_hierarchy_flattening_offset 29 | 30 | tokens = np.concatenate([text_tokens, interleaved_audio_tokens]) 31 | 32 | return np.expand_dims(tokens, axis=0) 33 | 34 | 35 | def get_params_for_mode( 36 | audio_token_mode: AudioTokenModeT, num_max_audio_tokens_timesteps: Optional[int] = None 37 | ) -> dict[str, Any]: 38 | if audio_token_mode == "flattened_interleaved": 39 | return { 40 | "text_tokenisation_offset": 1024 * 2 + 1, 41 | "pad_token": 1024 * 2, 42 | "ctx_window": num_max_audio_tokens_timesteps * 2 if num_max_audio_tokens_timesteps else None, 43 | "second_hierarchy_flattening_offset": 1024, 44 | # TODO: fix the repeat of `second_hierarchy_flattening_offset` 45 | "combine_func": partial( 46 | combine_tokens_flattened_interleaved, 47 | second_hierarchy_flattening_offset=1024, 48 | ), 49 | } 50 | else: 51 | raise Exception(f"Unknown mode {audio_token_mode}") 52 | -------------------------------------------------------------------------------- /fam/llm/preprocessing/data_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def pad_tokens(tokens: np.ndarray, context_window: int, pad_token: int) -> np.ndarray: 8 | """Pads or truncates a single example to the context_window + 1 size. 9 | 10 | tokens: (..., example_length) 11 | """ 12 | example_length = tokens.shape[-1] 13 | if example_length > context_window + 1: 14 | # Truncate 15 | tokens = tokens[..., : context_window + 1] 16 | elif example_length < context_window + 1: 17 | # Pad 18 | padding = np.full(tokens.shape[:-1] + (context_window + 1 - example_length,), pad_token) 19 | tokens = np.concatenate([tokens, padding], axis=-1) 20 | assert tokens.shape[-1] == context_window + 1 21 | return tokens 22 | 23 | 24 | def get_training_tuple( 25 | batch: Dict[str, Any], 26 | causal: bool, 27 | num_codebooks: Optional[int], 28 | speaker_cond: bool, 29 | device: torch.device, 30 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 31 | # batch contains combined tokens as specified by audio_token_mode 32 | if causal: 33 | num_codebooks = batch["tokens"].shape[1] if num_codebooks is None else num_codebooks 34 | x = batch["tokens"][:, :num_codebooks, :-1] 35 | y = batch["tokens"][:, :num_codebooks, 1:] 36 | 37 | se = batch["spkemb"] 38 | 39 | x = x.to(device, non_blocking=True) 40 | y = y.to(device, non_blocking=True) 41 | se = se.to(device, non_blocking=True) if speaker_cond else None 42 | 43 | return x, y, se 44 | 45 | 46 | def pad_with_values(tensor, batch_size, value): 47 | """Pads the tensor up to batch_size with values.""" 48 | if tensor.shape[0] < batch_size: 49 | return torch.cat( 50 | [ 51 | tensor, 52 | torch.full( 53 | (batch_size - tensor.shape[0], *tensor.shape[1:]), value, dtype=tensor.dtype, device=tensor.device 54 | ), 55 | ] 56 | ) 57 | else: 58 | return tensor 59 | -------------------------------------------------------------------------------- /fam/llm/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os 4 | import re 5 | import subprocess 6 | import tempfile 7 | 8 | import librosa 9 | import torch 10 | 11 | 12 | def normalize_text(text: str) -> str: 13 | unicode_conversion = { 14 | 8175: 
"'", 15 | 8189: "'", 16 | 8190: "'", 17 | 8208: "-", 18 | 8209: "-", 19 | 8210: "-", 20 | 8211: "-", 21 | 8212: "-", 22 | 8213: "-", 23 | 8214: "||", 24 | 8216: "'", 25 | 8217: "'", 26 | 8218: ",", 27 | 8219: "`", 28 | 8220: '"', 29 | 8221: '"', 30 | 8222: ",,", 31 | 8223: '"', 32 | 8228: ".", 33 | 8229: "..", 34 | 8230: "...", 35 | 8242: "'", 36 | 8243: '"', 37 | 8245: "'", 38 | 8246: '"', 39 | 180: "'", 40 | 2122: "TM", # Trademark 41 | } 42 | 43 | text = text.translate(unicode_conversion) 44 | 45 | non_bpe_chars = set([c for c in list(text) if ord(c) >= 256]) 46 | if len(non_bpe_chars) > 0: 47 | non_bpe_points = [(c, ord(c)) for c in non_bpe_chars] 48 | raise ValueError(f"Non-supported character found: {non_bpe_points}") 49 | 50 | text = text.replace("\t", " ").replace("\n", " ").replace("\r", " ").replace("*", " ").strip() 51 | text = re.sub("\s\s+", " ", text) # remove multiple spaces 52 | return text 53 | 54 | 55 | def check_audio_file(path_or_uri, threshold_s=30): 56 | if "http" in path_or_uri: 57 | temp_fd, filepath = tempfile.mkstemp() 58 | os.close(temp_fd) # Close the file descriptor, curl will create a new connection 59 | curl_command = ["curl", "-L", path_or_uri, "-o", filepath] 60 | subprocess.run(curl_command, check=True) 61 | 62 | else: 63 | filepath = path_or_uri 64 | 65 | audio, sr = librosa.load(filepath) 66 | duration_s = librosa.get_duration(y=audio, sr=sr) 67 | if duration_s < threshold_s: 68 | raise Exception( 69 | f"The audio file is too short. Please provide an audio file that is at least {threshold_s} seconds long to proceed." 70 | ) 71 | 72 | # Clean up the temporary file if it was created 73 | if "http" in path_or_uri: 74 | os.remove(filepath) 75 | 76 | 77 | def get_default_dtype() -> str: 78 | """Compute default 'dtype' based on GPU architecture""" 79 | if torch.cuda.is_available(): 80 | for i in range(torch.cuda.device_count()): 81 | device_properties = torch.cuda.get_device_properties(i) 82 | dtype = "float16" if device_properties.major <= 7 else "bfloat16" # tesla and turing architectures 83 | else: 84 | dtype = "float16" 85 | 86 | print(f"using dtype={dtype}") 87 | return dtype 88 | 89 | 90 | def get_device() -> str: 91 | return "cuda" if torch.cuda.is_available() else "cpu" 92 | 93 | 94 | def hash_dictionary(d: dict): 95 | # Serialize the dictionary into JSON with sorted keys to ensure consistency 96 | serialized = json.dumps(d, sort_keys=True) 97 | # Encode the serialized string to bytes 98 | encoded = serialized.encode() 99 | # Create a hash object (you can also use sha1, sha512, etc.) 
100 | hash_object = hashlib.sha256(encoded) 101 | # Get the hexadecimal digest of the hash 102 | hash_digest = hash_object.hexdigest() 103 | return hash_digest 104 | -------------------------------------------------------------------------------- /fam/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/py.typed -------------------------------------------------------------------------------- /fam/quantiser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/quantiser/__init__.py -------------------------------------------------------------------------------- /fam/quantiser/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/quantiser/audio/__init__.py -------------------------------------------------------------------------------- /fam/quantiser/audio/speaker_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/quantiser/audio/speaker_encoder/__init__.py -------------------------------------------------------------------------------- /fam/quantiser/audio/speaker_encoder/audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | 4 | mel_window_length = 25 5 | mel_window_step = 10 6 | mel_n_channels = 40 7 | sampling_rate = 16000 8 | 9 | 10 | def wav_to_mel_spectrogram(wav): 11 | """ 12 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 13 | Note: this is not a log-mel spectrogram. 
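With the module-level settings above (25 ms window, 10 ms hop, 16 kHz sampling rate), the call below
uses n_fft = 400 and hop_length = 160 samples and returns an array of shape (num_frames, 40).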
14 | """ 15 | frames = librosa.feature.melspectrogram( 16 | y=wav, 17 | sr=sampling_rate, 18 | n_fft=int(sampling_rate * mel_window_length / 1000), 19 | hop_length=int(sampling_rate * mel_window_step / 1000), 20 | n_mels=mel_n_channels, 21 | ) 22 | return frames.astype(np.float32).T 23 | -------------------------------------------------------------------------------- /fam/quantiser/audio/speaker_encoder/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import perf_counter as timer 3 | from typing import List, Optional, Union 4 | 5 | import librosa 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | 10 | from fam.quantiser.audio.speaker_encoder import audio 11 | 12 | mel_window_step = 10 13 | mel_n_channels = 40 14 | sampling_rate = 16000 15 | partials_n_frames = 160 16 | model_hidden_size = 256 17 | model_embedding_size = 256 18 | model_num_layers = 3 19 | 20 | 21 | class SpeakerEncoder(nn.Module): 22 | def __init__( 23 | self, 24 | weights_fpath: Optional[str] = None, 25 | device: Optional[Union[str, torch.device]] = None, 26 | verbose: bool = True, 27 | eval: bool = False, 28 | ): 29 | super().__init__() 30 | 31 | # Define the network 32 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) 33 | self.linear = nn.Linear(model_hidden_size, model_embedding_size) 34 | self.relu = nn.ReLU() 35 | 36 | # Get the target device 37 | if device is None: 38 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | elif isinstance(device, str): 40 | device = torch.device(device) 41 | self.device = device 42 | 43 | start = timer() 44 | 45 | checkpoint = torch.load(weights_fpath, map_location="cpu") 46 | self.load_state_dict(checkpoint["model_state"], strict=False) 47 | self.to(device) 48 | 49 | if eval: 50 | self.eval() 51 | 52 | if verbose: 53 | print("Loaded the speaker embedding model on %s in %.2f seconds." 
% (device.type, timer() - start)) 54 | 55 | def forward(self, mels: torch.FloatTensor): 56 | _, (hidden, _) = self.lstm(mels) 57 | embeds_raw = self.relu(self.linear(hidden[-1])) 58 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 59 | 60 | @staticmethod 61 | def compute_partial_slices(n_samples: int, rate, min_coverage): 62 | # Compute how many frames separate two partial utterances 63 | samples_per_frame = int((sampling_rate * mel_window_step / 1000)) 64 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) 65 | frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) 66 | 67 | # Compute the slices 68 | wav_slices, mel_slices = [], [] 69 | steps = max(1, n_frames - partials_n_frames + frame_step + 1) 70 | for i in range(0, steps, frame_step): 71 | mel_range = np.array([i, i + partials_n_frames]) 72 | wav_range = mel_range * samples_per_frame 73 | mel_slices.append(slice(*mel_range)) 74 | wav_slices.append(slice(*wav_range)) 75 | 76 | # Evaluate whether extra padding is warranted or not 77 | last_wav_range = wav_slices[-1] 78 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) 79 | if coverage < min_coverage and len(mel_slices) > 1: 80 | mel_slices = mel_slices[:-1] 81 | wav_slices = wav_slices[:-1] 82 | 83 | return wav_slices, mel_slices 84 | 85 | def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75, numpy: bool = True): 86 | wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage) 87 | max_wave_length = wav_slices[-1].stop 88 | if max_wave_length >= len(wav): 89 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") 90 | 91 | mel = audio.wav_to_mel_spectrogram(wav) 92 | mels = np.array([mel[s] for s in mel_slices]) 93 | mels = torch.from_numpy(mels).to(self.device) # type: ignore 94 | with torch.no_grad(): 95 | partial_embeds = self(mels) 96 | 97 | if numpy: 98 | raw_embed = np.mean(partial_embeds.cpu().numpy(), axis=0) 99 | embed = raw_embed / np.linalg.norm(raw_embed, 2) 100 | else: 101 | raw_embed = partial_embeds.mean(dim=0) 102 | embed = raw_embed / torch.linalg.norm(raw_embed, 2) 103 | 104 | if return_partials: 105 | return embed, partial_embeds, wav_slices 106 | return embed 107 | 108 | def embed_speaker(self, wavs: List[np.ndarray], **kwargs): 109 | raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0) 110 | return raw_embed / np.linalg.norm(raw_embed, 2) 111 | 112 | def embed_utterance_from_file(self, fpath: str, numpy: bool) -> torch.Tensor: 113 | wav_tgt, _ = librosa.load(fpath, sr=sampling_rate) 114 | wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) 115 | 116 | embedding = self.embed_utterance(wav_tgt, numpy=numpy) 117 | return embedding 118 | -------------------------------------------------------------------------------- /fam/quantiser/text/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /fam/quantiser/text/tokenise.py: -------------------------------------------------------------------------------- 1 | import tiktoken 2 | 3 | 4 | class TrainedBPETokeniser: 5 | def __init__(self, name, pat_str, mergeable_ranks, special_tokens, offset=None) -> None: 6 | self.tokenizer = tiktoken.Encoding( 7 | name=name, 8 | pat_str=pat_str, 9 | mergeable_ranks=mergeable_ranks, 10 | special_tokens=special_tokens, 11 | ) 12 | self.offset = offset 
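# Note: the optional `offset` shifts every token id by a constant; in this codebase it is presumably
# how text token ids are kept out of the audio token id range (cf. text_tokenisation_offset in
# fam/llm/preprocessing/audio_token_mode.py).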
13 | 14 | def encode(self, text: str) -> list[int]: 15 | # note: we add an end-of-text token! 16 | tokens = self.tokenizer.encode(text) + [self.tokenizer.eot_token] 17 | if self.offset is not None: 18 | tokens = [x + self.offset for x in tokens] 19 | 20 | return tokens 21 | 22 | def decode(self, tokens: list[int]): 23 | if self.offset is not None: 24 | tokens = [x - self.offset for x in tokens] 25 | return self.tokenizer.decode(tokens) 26 | 27 | @property 28 | def eot_token(self): 29 | if self.offset is not None: 30 | return self.tokenizer.eot_token + self.offset 31 | else: 32 | return self.tokenizer.eot_token 33 | -------------------------------------------------------------------------------- /fam/telemetry/README.md: -------------------------------------------------------------------------------- 1 | # Telemetry 2 | 3 | This directory holds all the telemetry for MetaVoice. We, MetaVoice, capture anonymized telemetry to understand usage patterns. 4 | 5 | If you prefer to opt out of telemetry, set `ANONYMIZED_TELEMETRY=False` in a .env file at the root level of this repo. 6 | -------------------------------------------------------------------------------- /fam/telemetry/__init__.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | import uuid 4 | from abc import abstractmethod 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | 8 | 9 | @dataclass(frozen=True) 10 | class TelemetryEvent: 11 | name: str 12 | properties: dict 13 | 14 | 15 | class TelemetryClient(abc.ABC): 16 | USER_ID_PATH = str(Path.home() / ".cache" / "metavoice" / "telemetry_user_id") 17 | UNKNOWN_USER_ID = "UNKNOWN" 18 | _curr_user_id = None 19 | 20 | @abstractmethod 21 | def capture(self, event: TelemetryEvent) -> None: 22 | pass 23 | 24 | @property 25 | def user_id(self) -> str: 26 | if self._curr_user_id: 27 | return self._curr_user_id 28 | 29 | # File access may fail due to permissions or other reasons. We don't want to 30 | # crash so we catch all exceptions. 31 | try: 32 | if not os.path.exists(self.USER_ID_PATH): 33 | os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True) 34 | with open(self.USER_ID_PATH, "w") as f: 35 | new_user_id = str(uuid.uuid4()) 36 | f.write(new_user_id) 37 | self._curr_user_id = new_user_id 38 | else: 39 | with open(self.USER_ID_PATH, "r") as f: 40 | self._curr_user_id = f.read() 41 | except Exception: 42 | self._curr_user_id = self.UNKNOWN_USER_ID 43 | return self._curr_user_id 44 | -------------------------------------------------------------------------------- /fam/telemetry/posthog.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | from dotenv import load_dotenv 6 | from posthog import Posthog 7 | 8 | from fam.telemetry import TelemetryClient, TelemetryEvent 9 | 10 | load_dotenv() 11 | logger = logging.getLogger(__name__) 12 | logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout), logging.StreamHandler(sys.stderr)]) 13 | 14 | 15 | class PosthogClient(TelemetryClient): 16 | def __init__(self): 17 | self._posthog = Posthog( 18 | project_api_key="phc_tk7IUlV7Q7lEa9LNbXxyC1sMWlCqiW6DkHyhJrbWMCS", host="https://eu.posthog.com" 19 | ) 20 | 21 | if os.getenv("ANONYMIZED_TELEMETRY", "True").lower() in ("false", "0") or "pytest" in sys.modules: 22 | self._posthog.disabled = True 23 | logger.info("Anonymized telemetry disabled. 
See fam/telemetry/README.md for more information.") 24 | else: 25 | logger.info("Anonymized telemetry enabled. See fam/telemetry/README.md for more information.") 26 | 27 | posthog_logger = logging.getLogger("posthog") 28 | posthog_logger.disabled = True # Silence posthog's logging 29 | 30 | super().__init__() 31 | 32 | def capture(self, event: TelemetryEvent) -> None: 33 | try: 34 | self._posthog.capture( 35 | self.user_id, 36 | event.name, 37 | {**event.properties}, 38 | ) 39 | except Exception as e: 40 | logger.error(f"Failed to send telemetry event {event.name}: {e}") 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "fam" 3 | version = "0.1.0" 4 | description = "Foundational model for text to speech" 5 | authors = [] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | torch = "^2.1.0" 11 | torchaudio = "^2.1.0" 12 | librosa = "^0.10.1" 13 | tqdm = "^4.66.2" 14 | tiktoken = "==0.5.1" 15 | audiocraft = "^1.2.0" 16 | numpy = "^1.26.4" 17 | ninja = "^1.11.1" 18 | fastapi = "^0.110.0" 19 | uvicorn = "^0.27.1" 20 | tyro = "^0.7.3" 21 | deepfilternet = "^0.5.6" 22 | pydub = "^0.25.1" 23 | gradio = "^4.20.1" 24 | huggingface_hub = "^0.21.4" 25 | click = "^8.1.7" 26 | wandb = { version = "^0.16.4", optional = true } 27 | posthog = "^3.5.0" 28 | python-dotenv = "^1.0.1" 29 | 30 | [tool.poetry.dev-dependencies] 31 | pre-commit = "^3.7.0" 32 | pytest = "^8.0.2" 33 | ipdb = "^0.13.13" 34 | 35 | [tool.poetry.extras] 36 | observable = ["wandb"] 37 | 38 | [tool.poetry.scripts] 39 | finetune = "fam.llm.finetune:main" 40 | 41 | [build-system] 42 | requires = ["poetry-core"] 43 | build-backend = "poetry.core.masonry.api" 44 | 45 | [tool.black] 46 | line-length = 120 47 | exclude = ''' 48 | /( 49 | \.git 50 | | \.mypy_cache 51 | | \.tox 52 | | _build 53 | | build 54 | | dist 55 | )/ 56 | ''' 57 | 58 | [tool.isort] 59 | profile = "black" 60 | 61 | -------------------------------------------------------------------------------- /serving.py: -------------------------------------------------------------------------------- 1 | # curl -X POST http://127.0.0.1:58003/tts -F "text=Testing this inference server." -F "speaker_ref_path=https://cdn.themetavoice.xyz/speakers/bria.mp3" -F "guidance=3.0" -F "top_p=0.95" --output out.wav 2 | 3 | import logging 4 | import shlex 5 | import subprocess 6 | import tempfile 7 | import warnings 8 | from pathlib import Path 9 | from typing import Literal, Optional 10 | 11 | import fastapi 12 | import fastapi.middleware.cors 13 | import tyro 14 | import uvicorn 15 | from attr import dataclass 16 | from fastapi import File, Form, HTTPException, UploadFile, status 17 | from fastapi.responses import Response 18 | 19 | from fam.llm.fast_inference import TTS 20 | from fam.llm.utils import check_audio_file 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | ## Setup FastAPI server. 
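# The ServingConfig dataclass below is parsed with tyro.cli in __main__, so its fields
# (huggingface_repo_id, temperature, seed, port, quantisation_mode) can be overridden from the command
# line when launching serving.py; the example curl request at the top of this file then exercises the
# resulting /tts endpoint.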
26 | app = fastapi.FastAPI() 27 | 28 | 29 | @dataclass 30 | class ServingConfig: 31 | huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1" 32 | """Absolute path to the model directory.""" 33 | 34 | temperature: float = 1.0 35 | """Temperature for sampling applied to both models.""" 36 | 37 | seed: int = 1337 38 | """Random seed for sampling.""" 39 | 40 | port: int = 58003 41 | 42 | quantisation_mode: Optional[Literal["int4", "int8"]] = None 43 | 44 | 45 | # Singleton 46 | class _GlobalState: 47 | config: ServingConfig 48 | tts: TTS 49 | 50 | 51 | GlobalState = _GlobalState() 52 | 53 | 54 | @app.get("/health") 55 | async def health_check(): 56 | return {"status": "ok"} 57 | 58 | 59 | @app.post("/tts", response_class=Response) 60 | async def text_to_speech( 61 | text: str = Form(..., description="Text to convert to speech."), 62 | speaker_ref_path: Optional[str] = Form(None, description="Optional URL to an audio file of a reference speaker. Provide either this URL or audio data through 'audiodata'."), 63 | audiodata: Optional[UploadFile] = File(None, description="Optional audio data of a reference speaker. Provide either this file or a URL through 'speaker_ref_path'."), 64 | guidance: float = Form(3.0, description="Control speaker similarity - how closely to match speaker identity and speech style, range: 0.0 to 5.0.", ge=0.0, le=5.0), 65 | top_p: float = Form(0.95, description="Controls speech stability - improves text following for a challenging speaker, range: 0.0 to 1.0.", ge=0.0, le=1.0), 66 | ): 67 | # Ensure at least one of speaker_ref_path or audiodata is provided 68 | if not audiodata and not speaker_ref_path: 69 | raise HTTPException( 70 | status_code=status.HTTP_400_BAD_REQUEST, 71 | detail="Either an audio file or a speaker reference path must be provided.", 72 | ) 73 | 74 | wav_out_path = None 75 | 76 | try: 77 | with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp: 78 | if speaker_ref_path is None: 79 | wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp) 80 | check_audio_file(wav_path) 81 | else: 82 | # TODO: fix 83 | wav_path = speaker_ref_path 84 | 85 | if wav_path is None: 86 | warnings.warn("Running without speaker reference") 87 | assert guidance is None 88 | 89 | wav_out_path = GlobalState.tts.synthesise( 90 | text=text, 91 | spk_ref_path=wav_path, 92 | top_p=top_p, 93 | guidance_scale=guidance, 94 | ) 95 | 96 | with open(wav_out_path, "rb") as f: 97 | return Response(content=f.read(), media_type="audio/wav") 98 | except Exception as e: 99 | # traceback_str = "".join(traceback.format_tb(e.__traceback__)) 100 | logger.exception( 101 | f"Error processing request. text: {text}, speaker_ref_path: {speaker_ref_path}, guidance: {guidance}, top_p: {top_p}" 102 | ) 103 | return Response( 104 | content="Something went wrong. 
Please try again in a few mins or contact us on Discord", 105 | status_code=500, 106 | ) 107 | finally: 108 | if wav_out_path is not None: 109 | Path(wav_out_path).unlink(missing_ok=True) 110 | 111 | 112 | def _convert_audiodata_to_wav_path(audiodata: UploadFile, wav_tmp): 113 | with tempfile.NamedTemporaryFile() as unknown_format_tmp: 114 | if unknown_format_tmp.write(audiodata.read()) == 0: 115 | return None 116 | unknown_format_tmp.flush() 117 | 118 | subprocess.check_output( 119 | # arbitrary 2 minute cutoff 120 | shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}") 121 | ) 122 | 123 | return wav_tmp.name 124 | 125 | 126 | if __name__ == "__main__": 127 | for name in logging.root.manager.loggerDict: 128 | logger = logging.getLogger(name) 129 | logger.setLevel(logging.INFO) 130 | logging.root.setLevel(logging.INFO) 131 | 132 | GlobalState.config = tyro.cli(ServingConfig) 133 | GlobalState.tts = TTS( 134 | seed=GlobalState.config.seed, 135 | quantisation_mode=GlobalState.config.quantisation_mode, 136 | telemetry_origin="api_server", 137 | ) 138 | 139 | app.add_middleware( 140 | fastapi.middleware.cors.CORSMiddleware, 141 | allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"], 142 | allow_credentials=True, 143 | allow_methods=["*"], 144 | allow_headers=["*"], 145 | ) 146 | uvicorn.run( 147 | app, 148 | host="0.0.0.0", 149 | port=GlobalState.config.port, 150 | log_level="info", 151 | ) 152 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup # type: ignore 2 | 3 | setup( 4 | name="fam", 5 | packages=find_packages(".", exclude=["tests"]), 6 | ) 7 | -------------------------------------------------------------------------------- /tests/llm/loaders/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from pathlib import Path 3 | 4 | import pytest 5 | import torch 6 | from huggingface_hub import snapshot_download 7 | from torch.utils.data import DataLoader 8 | 9 | from fam.llm.config.finetune_params import audio_token_mode as atm 10 | from fam.llm.config.finetune_params import num_max_audio_tokens_timesteps 11 | from fam.llm.loaders.training_data import DynamicComputeDataset 12 | from fam.llm.preprocessing.audio_token_mode import get_params_for_mode 13 | 14 | 15 | @pytest.mark.parametrize("dataset", ["tests/resources/datasets/sample_dataset.csv"]) 16 | @pytest.mark.skip(reason="Requires ckpt download, not feasible as test suite") 17 | def test_dataset_preprocess_e2e(dataset): 18 | model_name = "metavoiceio/metavoice-1B-v0.1" 19 | device = "cuda" 20 | mode_params = get_params_for_mode(atm, num_max_audio_tokens_timesteps=num_max_audio_tokens_timesteps) 21 | 22 | _model_dir = snapshot_download(repo_id=model_name) 23 | checkpoint_path = Path(f"{_model_dir}/first_stage.pt") 24 | spk_emb_ckpt_path = Path(f"{_model_dir}/speaker_encoder.pt") 25 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False) 26 | tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {}) 27 | 28 | dataset = DynamicComputeDataset.from_meta( 29 | tokenizer_info, 30 | mode_params["combine_func"], 31 | spk_emb_ckpt_path, 32 | dataset, 33 | mode_params["pad_token"], 34 | mode_params["ctx_window"], 35 | device 36 | ) 37 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) 
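    # Each batch is expected to carry the combined token sequence plus the speaker embedding
    # (cf. batch["tokens"] / batch["spkemb"] in fam/llm/preprocessing/data_pipeline.py), which is what
    # the length-2 assertion below relies on.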
38 | result = next(iter(dataloader)) 39 | 40 | # TODO: better assertions based on sample input dims 41 | assert len(result) == 2 42 | -------------------------------------------------------------------------------- /tests/resources/data/caption.txt: -------------------------------------------------------------------------------- 1 | Please call Stella. -------------------------------------------------------------------------------- /tests/resources/datasets/sample_dataset.csv: -------------------------------------------------------------------------------- 1 | audio_files,captions 2 | ./data/audio.wav,./data/caption.txt 3 | ./data/audio.wav,./data/caption.txt 4 | ./data/audio.wav,./data/caption.txt 5 | ./data/audio.wav,./data/caption.txt 6 | ./data/audio.wav,./data/caption.txt 7 | ./data/audio.wav,./data/caption.txt 8 | ./data/audio.wav,./data/caption.txt 9 | ./data/audio.wav,./data/caption.txt 10 | ./data/audio.wav,./data/caption.txt 11 | ./data/audio.wav,./data/caption.txt 12 | ./data/audio.wav,./data/caption.txt 13 | ./data/audio.wav,./data/caption.txt 14 | ./data/audio.wav,./data/caption.txt 15 | ./data/audio.wav,./data/caption.txt 16 | ./data/audio.wav,./data/caption.txt 17 | ./data/audio.wav,./data/caption.txt 18 | ./data/audio.wav,./data/caption.txt 19 | ./data/audio.wav,./data/caption.txt 20 | ./data/audio.wav,./data/caption.txt 21 | ./data/audio.wav,./data/caption.txt 22 | ./data/audio.wav,./data/caption.txt 23 | ./data/audio.wav,./data/caption.txt 24 | ./data/audio.wav,./data/caption.txt 25 | ./data/audio.wav,./data/caption.txt 26 | ./data/audio.wav,./data/caption.txt 27 | ./data/audio.wav,./data/caption.txt 28 | ./data/audio.wav,./data/caption.txt 29 | ./data/audio.wav,./data/caption.txt 30 | ./data/audio.wav,./data/caption.txt 31 | ./data/audio.wav,./data/caption.txt 32 | ./data/audio.wav,./data/caption.txt 33 | ./data/audio.wav,./data/caption.txt 34 | ./data/audio.wav,./data/caption.txt 35 | ./data/audio.wav,./data/caption.txt 36 | ./data/audio.wav,./data/caption.txt 37 | ./data/audio.wav,./data/caption.txt 38 | ./data/audio.wav,./data/caption.txt 39 | ./data/audio.wav,./data/caption.txt 40 | ./data/audio.wav,./data/caption.txt 41 | ./data/audio.wav,./data/caption.txt 42 | ./data/audio.wav,./data/caption.txt 43 | ./data/audio.wav,./data/caption.txt 44 | ./data/audio.wav,./data/caption.txt 45 | ./data/audio.wav,./data/caption.txt 46 | ./data/audio.wav,./data/caption.txt 47 | ./data/audio.wav,./data/caption.txt 48 | ./data/audio.wav,./data/caption.txt 49 | ./data/audio.wav,./data/caption.txt 50 | ./data/audio.wav,./data/caption.txt 51 | ./data/audio.wav,./data/caption.txt 52 | ./data/audio.wav,./data/caption.txt 53 | ./data/audio.wav,./data/caption.txt 54 | ./data/audio.wav,./data/caption.txt 55 | ./data/audio.wav,./data/caption.txt 56 | ./data/audio.wav,./data/caption.txt 57 | ./data/audio.wav,./data/caption.txt 58 | ./data/audio.wav,./data/caption.txt 59 | ./data/audio.wav,./data/caption.txt 60 | ./data/audio.wav,./data/caption.txt 61 | ./data/audio.wav,./data/caption.txt 62 | ./data/audio.wav,./data/caption.txt 63 | ./data/audio.wav,./data/caption.txt 64 | ./data/audio.wav,./data/caption.txt 65 | ./data/audio.wav,./data/caption.txt 66 | ./data/audio.wav,./data/caption.txt 67 | ./data/audio.wav,./data/caption.txt 68 | ./data/audio.wav,./data/caption.txt 69 | ./data/audio.wav,./data/caption.txt 70 | ./data/audio.wav,./data/caption.txt 71 | ./data/audio.wav,./data/caption.txt 72 | ./data/audio.wav,./data/caption.txt 73 | ./data/audio.wav,./data/caption.txt 74 | 
./data/audio.wav,./data/caption.txt 75 | ./data/audio.wav,./data/caption.txt 76 | ./data/audio.wav,./data/caption.txt 77 | ./data/audio.wav,./data/caption.txt 78 | ./data/audio.wav,./data/caption.txt 79 | ./data/audio.wav,./data/caption.txt 80 | ./data/audio.wav,./data/caption.txt 81 | ./data/audio.wav,./data/caption.txt 82 | ./data/audio.wav,./data/caption.txt 83 | ./data/audio.wav,./data/caption.txt 84 | ./data/audio.wav,./data/caption.txt 85 | ./data/audio.wav,./data/caption.txt 86 | ./data/audio.wav,./data/caption.txt 87 | ./data/audio.wav,./data/caption.txt 88 | ./data/audio.wav,./data/caption.txt 89 | ./data/audio.wav,./data/caption.txt 90 | ./data/audio.wav,./data/caption.txt 91 | ./data/audio.wav,./data/caption.txt 92 | ./data/audio.wav,./data/caption.txt 93 | ./data/audio.wav,./data/caption.txt 94 | ./data/audio.wav,./data/caption.txt 95 | ./data/audio.wav,./data/caption.txt 96 | ./data/audio.wav,./data/caption.txt 97 | ./data/audio.wav,./data/caption.txt 98 | ./data/audio.wav,./data/caption.txt 99 | ./data/audio.wav,./data/caption.txt 100 | ./data/audio.wav,./data/caption.txt 101 | ./data/audio.wav,./data/caption.txt 102 | ./data/audio.wav,./data/caption.txt 103 | ./data/audio.wav,./data/caption.txt 104 | ./data/audio.wav,./data/caption.txt 105 | ./data/audio.wav,./data/caption.txt 106 | ./data/audio.wav,./data/caption.txt 107 | ./data/audio.wav,./data/caption.txt 108 | ./data/audio.wav,./data/caption.txt 109 | ./data/audio.wav,./data/caption.txt 110 | ./data/audio.wav,./data/caption.txt 111 | ./data/audio.wav,./data/caption.txt 112 | ./data/audio.wav,./data/caption.txt 113 | ./data/audio.wav,./data/caption.txt 114 | ./data/audio.wav,./data/caption.txt 115 | ./data/audio.wav,./data/caption.txt 116 | ./data/audio.wav,./data/caption.txt 117 | ./data/audio.wav,./data/caption.txt 118 | ./data/audio.wav,./data/caption.txt 119 | ./data/audio.wav,./data/caption.txt 120 | ./data/audio.wav,./data/caption.txt 121 | ./data/audio.wav,./data/caption.txt 122 | ./data/audio.wav,./data/caption.txt 123 | ./data/audio.wav,./data/caption.txt 124 | ./data/audio.wav,./data/caption.txt 125 | ./data/audio.wav,./data/caption.txt 126 | ./data/audio.wav,./data/caption.txt 127 | ./data/audio.wav,./data/caption.txt 128 | ./data/audio.wav,./data/caption.txt 129 | ./data/audio.wav,./data/caption.txt 130 | ./data/audio.wav,./data/caption.txt 131 | ./data/audio.wav,./data/caption.txt 132 | ./data/audio.wav,./data/caption.txt 133 | ./data/audio.wav,./data/caption.txt 134 | ./data/audio.wav,./data/caption.txt 135 | ./data/audio.wav,./data/caption.txt 136 | ./data/audio.wav,./data/caption.txt 137 | ./data/audio.wav,./data/caption.txt 138 | ./data/audio.wav,./data/caption.txt 139 | ./data/audio.wav,./data/caption.txt 140 | ./data/audio.wav,./data/caption.txt 141 | ./data/audio.wav,./data/caption.txt 142 | ./data/audio.wav,./data/caption.txt 143 | ./data/audio.wav,./data/caption.txt 144 | ./data/audio.wav,./data/caption.txt 145 | ./data/audio.wav,./data/caption.txt 146 | ./data/audio.wav,./data/caption.txt 147 | ./data/audio.wav,./data/caption.txt 148 | ./data/audio.wav,./data/caption.txt 149 | ./data/audio.wav,./data/caption.txt 150 | ./data/audio.wav,./data/caption.txt 151 | ./data/audio.wav,./data/caption.txt 152 | ./data/audio.wav,./data/caption.txt 153 | ./data/audio.wav,./data/caption.txt 154 | ./data/audio.wav,./data/caption.txt 155 | ./data/audio.wav,./data/caption.txt 156 | ./data/audio.wav,./data/caption.txt 157 | ./data/audio.wav,./data/caption.txt 158 | ./data/audio.wav,./data/caption.txt 159 | 
./data/audio.wav,./data/caption.txt 160 | ./data/audio.wav,./data/caption.txt 161 | ./data/audio.wav,./data/caption.txt 162 | ./data/audio.wav,./data/caption.txt 163 | ./data/audio.wav,./data/caption.txt 164 | ./data/audio.wav,./data/caption.txt 165 | ./data/audio.wav,./data/caption.txt 166 | ./data/audio.wav,./data/caption.txt 167 | ./data/audio.wav,./data/caption.txt 168 | ./data/audio.wav,./data/caption.txt 169 | ./data/audio.wav,./data/caption.txt 170 | ./data/audio.wav,./data/caption.txt 171 | ./data/audio.wav,./data/caption.txt 172 | ./data/audio.wav,./data/caption.txt 173 | ./data/audio.wav,./data/caption.txt 174 | ./data/audio.wav,./data/caption.txt 175 | ./data/audio.wav,./data/caption.txt 176 | ./data/audio.wav,./data/caption.txt 177 | ./data/audio.wav,./data/caption.txt 178 | ./data/audio.wav,./data/caption.txt 179 | ./data/audio.wav,./data/caption.txt 180 | ./data/audio.wav,./data/caption.txt 181 | ./data/audio.wav,./data/caption.txt 182 | ./data/audio.wav,./data/caption.txt 183 | ./data/audio.wav,./data/caption.txt 184 | ./data/audio.wav,./data/caption.txt 185 | ./data/audio.wav,./data/caption.txt 186 | ./data/audio.wav,./data/caption.txt 187 | ./data/audio.wav,./data/caption.txt 188 | ./data/audio.wav,./data/caption.txt 189 | ./data/audio.wav,./data/caption.txt 190 | ./data/audio.wav,./data/caption.txt 191 | ./data/audio.wav,./data/caption.txt 192 | ./data/audio.wav,./data/caption.txt 193 | ./data/audio.wav,./data/caption.txt 194 | ./data/audio.wav,./data/caption.txt 195 | ./data/audio.wav,./data/caption.txt 196 | ./data/audio.wav,./data/caption.txt 197 | ./data/audio.wav,./data/caption.txt 198 | ./data/audio.wav,./data/caption.txt 199 | ./data/audio.wav,./data/caption.txt 200 | ./data/audio.wav,./data/caption.txt 201 | ./data/audio.wav,./data/caption.txt 202 | ./data/audio.wav,./data/caption.txt 203 | ./data/audio.wav,./data/caption.txt 204 | ./data/audio.wav,./data/caption.txt 205 | ./data/audio.wav,./data/caption.txt 206 | ./data/audio.wav,./data/caption.txt 207 | ./data/audio.wav,./data/caption.txt 208 | ./data/audio.wav,./data/caption.txt 209 | ./data/audio.wav,./data/caption.txt 210 | ./data/audio.wav,./data/caption.txt 211 | ./data/audio.wav,./data/caption.txt 212 | ./data/audio.wav,./data/caption.txt 213 | ./data/audio.wav,./data/caption.txt 214 | ./data/audio.wav,./data/caption.txt 215 | ./data/audio.wav,./data/caption.txt 216 | ./data/audio.wav,./data/caption.txt 217 | ./data/audio.wav,./data/caption.txt 218 | ./data/audio.wav,./data/caption.txt 219 | ./data/audio.wav,./data/caption.txt 220 | ./data/audio.wav,./data/caption.txt 221 | ./data/audio.wav,./data/caption.txt 222 | ./data/audio.wav,./data/caption.txt 223 | ./data/audio.wav,./data/caption.txt 224 | ./data/audio.wav,./data/caption.txt 225 | ./data/audio.wav,./data/caption.txt 226 | ./data/audio.wav,./data/caption.txt 227 | ./data/audio.wav,./data/caption.txt 228 | ./data/audio.wav,./data/caption.txt 229 | ./data/audio.wav,./data/caption.txt 230 | ./data/audio.wav,./data/caption.txt 231 | ./data/audio.wav,./data/caption.txt 232 | ./data/audio.wav,./data/caption.txt 233 | ./data/audio.wav,./data/caption.txt 234 | ./data/audio.wav,./data/caption.txt 235 | ./data/audio.wav,./data/caption.txt 236 | ./data/audio.wav,./data/caption.txt 237 | ./data/audio.wav,./data/caption.txt 238 | ./data/audio.wav,./data/caption.txt 239 | ./data/audio.wav,./data/caption.txt 240 | ./data/audio.wav,./data/caption.txt 241 | ./data/audio.wav,./data/caption.txt 242 | ./data/audio.wav,./data/caption.txt 243 | 
./data/audio.wav,./data/caption.txt 244 | ./data/audio.wav,./data/caption.txt 245 | ./data/audio.wav,./data/caption.txt 246 | ./data/audio.wav,./data/caption.txt 247 | ./data/audio.wav,./data/caption.txt 248 | ./data/audio.wav,./data/caption.txt 249 | ./data/audio.wav,./data/caption.txt 250 | ./data/audio.wav,./data/caption.txt 251 | ./data/audio.wav,./data/caption.txt 252 | ./data/audio.wav,./data/caption.txt 253 | ./data/audio.wav,./data/caption.txt 254 | ./data/audio.wav,./data/caption.txt 255 | ./data/audio.wav,./data/caption.txt 256 | ./data/audio.wav,./data/caption.txt 257 | ./data/audio.wav,./data/caption.txt 258 | ./data/audio.wav,./data/caption.txt 259 | ./data/audio.wav,./data/caption.txt 260 | ./data/audio.wav,./data/caption.txt 261 | ./data/audio.wav,./data/caption.txt 262 | ./data/audio.wav,./data/caption.txt 263 | ./data/audio.wav,./data/caption.txt 264 | ./data/audio.wav,./data/caption.txt 265 | ./data/audio.wav,./data/caption.txt 266 | ./data/audio.wav,./data/caption.txt 267 | ./data/audio.wav,./data/caption.txt 268 | ./data/audio.wav,./data/caption.txt 269 | ./data/audio.wav,./data/caption.txt 270 | ./data/audio.wav,./data/caption.txt 271 | ./data/audio.wav,./data/caption.txt 272 | ./data/audio.wav,./data/caption.txt 273 | ./data/audio.wav,./data/caption.txt 274 | ./data/audio.wav,./data/caption.txt 275 | ./data/audio.wav,./data/caption.txt 276 | ./data/audio.wav,./data/caption.txt 277 | ./data/audio.wav,./data/caption.txt 278 | ./data/audio.wav,./data/caption.txt 279 | ./data/audio.wav,./data/caption.txt 280 | ./data/audio.wav,./data/caption.txt 281 | ./data/audio.wav,./data/caption.txt 282 | ./data/audio.wav,./data/caption.txt 283 | ./data/audio.wav,./data/caption.txt 284 | ./data/audio.wav,./data/caption.txt 285 | ./data/audio.wav,./data/caption.txt 286 | ./data/audio.wav,./data/caption.txt 287 | ./data/audio.wav,./data/caption.txt 288 | ./data/audio.wav,./data/caption.txt 289 | ./data/audio.wav,./data/caption.txt 290 | ./data/audio.wav,./data/caption.txt 291 | ./data/audio.wav,./data/caption.txt 292 | ./data/audio.wav,./data/caption.txt 293 | ./data/audio.wav,./data/caption.txt 294 | ./data/audio.wav,./data/caption.txt 295 | ./data/audio.wav,./data/caption.txt 296 | ./data/audio.wav,./data/caption.txt 297 | ./data/audio.wav,./data/caption.txt 298 | ./data/audio.wav,./data/caption.txt 299 | ./data/audio.wav,./data/caption.txt 300 | ./data/audio.wav,./data/caption.txt 301 | ./data/audio.wav,./data/caption.txt 302 | ./data/audio.wav,./data/caption.txt 303 | ./data/audio.wav,./data/caption.txt 304 | ./data/audio.wav,./data/caption.txt 305 | ./data/audio.wav,./data/caption.txt 306 | ./data/audio.wav,./data/caption.txt 307 | ./data/audio.wav,./data/caption.txt 308 | ./data/audio.wav,./data/caption.txt 309 | ./data/audio.wav,./data/caption.txt 310 | ./data/audio.wav,./data/caption.txt 311 | ./data/audio.wav,./data/caption.txt 312 | ./data/audio.wav,./data/caption.txt 313 | ./data/audio.wav,./data/caption.txt 314 | ./data/audio.wav,./data/caption.txt 315 | ./data/audio.wav,./data/caption.txt 316 | ./data/audio.wav,./data/caption.txt 317 | ./data/audio.wav,./data/caption.txt 318 | ./data/audio.wav,./data/caption.txt 319 | ./data/audio.wav,./data/caption.txt 320 | ./data/audio.wav,./data/caption.txt 321 | ./data/audio.wav,./data/caption.txt 322 | ./data/audio.wav,./data/caption.txt 323 | ./data/audio.wav,./data/caption.txt 324 | ./data/audio.wav,./data/caption.txt 325 | ./data/audio.wav,./data/caption.txt 326 | ./data/audio.wav,./data/caption.txt 327 | 
./data/audio.wav,./data/caption.txt 328 | ./data/audio.wav,./data/caption.txt 329 | ./data/audio.wav,./data/caption.txt 330 | ./data/audio.wav,./data/caption.txt 331 | ./data/audio.wav,./data/caption.txt 332 | ./data/audio.wav,./data/caption.txt 333 | ./data/audio.wav,./data/caption.txt 334 | ./data/audio.wav,./data/caption.txt 335 | ./data/audio.wav,./data/caption.txt 336 | ./data/audio.wav,./data/caption.txt 337 | ./data/audio.wav,./data/caption.txt 338 | ./data/audio.wav,./data/caption.txt 339 | ./data/audio.wav,./data/caption.txt 340 | ./data/audio.wav,./data/caption.txt 341 | ./data/audio.wav,./data/caption.txt 342 | ./data/audio.wav,./data/caption.txt 343 | ./data/audio.wav,./data/caption.txt 344 | ./data/audio.wav,./data/caption.txt 345 | ./data/audio.wav,./data/caption.txt 346 | ./data/audio.wav,./data/caption.txt 347 | ./data/audio.wav,./data/caption.txt 348 | ./data/audio.wav,./data/caption.txt 349 | ./data/audio.wav,./data/caption.txt 350 | ./data/audio.wav,./data/caption.txt 351 | ./data/audio.wav,./data/caption.txt 352 | ./data/audio.wav,./data/caption.txt 353 | ./data/audio.wav,./data/caption.txt 354 | ./data/audio.wav,./data/caption.txt 355 | ./data/audio.wav,./data/caption.txt 356 | ./data/audio.wav,./data/caption.txt 357 | ./data/audio.wav,./data/caption.txt 358 | ./data/audio.wav,./data/caption.txt 359 | ./data/audio.wav,./data/caption.txt 360 | ./data/audio.wav,./data/caption.txt 361 | ./data/audio.wav,./data/caption.txt 362 | ./data/audio.wav,./data/caption.txt 363 | ./data/audio.wav,./data/caption.txt 364 | ./data/audio.wav,./data/caption.txt 365 | ./data/audio.wav,./data/caption.txt 366 | ./data/audio.wav,./data/caption.txt 367 | ./data/audio.wav,./data/caption.txt 368 | ./data/audio.wav,./data/caption.txt 369 | ./data/audio.wav,./data/caption.txt 370 | ./data/audio.wav,./data/caption.txt 371 | ./data/audio.wav,./data/caption.txt 372 | ./data/audio.wav,./data/caption.txt 373 | ./data/audio.wav,./data/caption.txt 374 | ./data/audio.wav,./data/caption.txt 375 | ./data/audio.wav,./data/caption.txt 376 | ./data/audio.wav,./data/caption.txt 377 | ./data/audio.wav,./data/caption.txt 378 | ./data/audio.wav,./data/caption.txt 379 | ./data/audio.wav,./data/caption.txt 380 | ./data/audio.wav,./data/caption.txt 381 | ./data/audio.wav,./data/caption.txt 382 | ./data/audio.wav,./data/caption.txt 383 | ./data/audio.wav,./data/caption.txt 384 | ./data/audio.wav,./data/caption.txt 385 | ./data/audio.wav,./data/caption.txt 386 | ./data/audio.wav,./data/caption.txt 387 | ./data/audio.wav,./data/caption.txt 388 | ./data/audio.wav,./data/caption.txt 389 | ./data/audio.wav,./data/caption.txt 390 | ./data/audio.wav,./data/caption.txt 391 | ./data/audio.wav,./data/caption.txt 392 | ./data/audio.wav,./data/caption.txt 393 | ./data/audio.wav,./data/caption.txt 394 | ./data/audio.wav,./data/caption.txt 395 | ./data/audio.wav,./data/caption.txt 396 | ./data/audio.wav,./data/caption.txt 397 | ./data/audio.wav,./data/caption.txt 398 | ./data/audio.wav,./data/caption.txt 399 | ./data/audio.wav,./data/caption.txt 400 | ./data/audio.wav,./data/caption.txt 401 | ./data/audio.wav,./data/caption.txt 402 | --------------------------------------------------------------------------------