├── .github
│   └── workflows
│       └── build.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── README.md
├── app.py
├── assets
│   ├── bria.mp3
│   ├── favicon.ico
│   └── logo.png
├── colab_demo.ipynb
├── data
│   ├── audio.wav
│   └── caption.txt
├── datasets
│   ├── sample_dataset.csv
│   └── sample_val_dataset.csv
├── docker-compose.yml
├── fam
│   ├── __init__.py
│   ├── llm
│   │   ├── __init__.py
│   │   ├── adapters
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── flattened_encodec.py
│   │   │   └── tilted_encodec.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── finetune_params.py
│   │   ├── decoders.py
│   │   ├── enhancers.py
│   │   ├── fast_inference.py
│   │   ├── fast_inference_utils.py
│   │   ├── fast_model.py
│   │   ├── fast_quantize.py
│   │   ├── finetune.py
│   │   ├── inference.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── attn.py
│   │   │   ├── combined.py
│   │   │   └── layers.py
│   │   ├── loaders
│   │   │   ├── __init__.py
│   │   │   └── training_data.py
│   │   ├── mixins
│   │   │   ├── __init__.py
│   │   │   ├── causal.py
│   │   │   └── non_causal.py
│   │   ├── model.py
│   │   ├── preprocessing
│   │   │   ├── __init__.py
│   │   │   ├── audio_token_mode.py
│   │   │   └── data_pipeline.py
│   │   └── utils.py
│   ├── py.typed
│   ├── quantiser
│   │   ├── __init__.py
│   │   ├── audio
│   │   │   ├── __init__.py
│   │   │   └── speaker_encoder
│   │   │       ├── __init__.py
│   │   │       ├── audio.py
│   │   │       └── model.py
│   │   └── text
│   │       ├── __init__.py
│   │       └── tokenise.py
│   └── telemetry
│       ├── README.md
│       ├── __init__.py
│       └── posthog.py
├── poetry.lock
├── pyproject.toml
├── requirements.txt
├── serving.py
├── setup.py
└── tests
    ├── llm
    │   └── loaders
    │       └── test_dataloader.py
    └── resources
        ├── data
        │   └── caption.txt
        └── datasets
            └── sample_dataset.csv
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Poetry Install Matrix
2 |
3 | on:
4 |   push:
5 |     branches: [ main ]
6 |   pull_request:
7 |     branches: [ main ]
8 |
9 | jobs:
10 |   build:
11 |     runs-on: ubuntu-latest
12 |     strategy:
13 |       matrix:
14 |         python-version: ["3.10", "3.11"]
15 |
16 |     steps:
17 |       - name: Checkout
18 |         uses: actions/checkout@v3
19 |
20 |       - name: Set up Python ${{ matrix.python-version }}
21 |         uses: actions/setup-python@v4
22 |         with:
23 |           python-version: ${{ matrix.python-version }}
24 |
25 |       - name: Install Poetry
26 |         uses: Gr1N/setup-poetry@v8
27 |
28 |       - name: Install dependencies
29 |         run: |
30 |           poetry --version
31 |           poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | *.pkl
3 | *.flac
4 | *.npz
5 | *.wav
6 | *.m4a
7 | *.opus
8 | *.npy
9 | *wandb
10 | *.parquet
11 | *.wav
12 | *.pt
13 | *.bin
14 | *.png
15 | *.DS_Store
16 | *.idea
17 | *.ipynb_checkpoints/
18 | *__pycache__/
19 | *.pyc
20 | *.tsv
21 | *.bak
22 | *.tar
23 | *.db
24 | *.dat
25 | *.json
26 |
27 | # Byte-compiled / optimized / DLL files
28 | __pycache__/
29 | *.py[cod]
30 | *$py.class
31 |
32 | # C extensions
33 | *.so
34 |
35 | # Distribution / packaging
36 | .Python
37 | build/
38 | develop-eggs/
39 | dist/
40 | downloads/
41 | eggs/
42 | .eggs/
43 | lib/
44 | lib64/
45 | parts/
46 | sdist/
47 | var/
48 | wheels/
49 | share/python-wheels/
50 | *.egg-info/
51 | .installed.cfg
52 | *.egg
53 | MANIFEST
54 |
55 | # PyInstaller
56 | # Usually these files are written by a python script from a template
57 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
58 | *.manifest
59 | *.spec
60 |
61 | # Installer logs
62 | pip-log.txt
63 | pip-delete-this-directory.txt
64 |
65 | # Unit test / coverage reports
66 | htmlcov/
67 | .tox/
68 | .nox/
69 | .coverage
70 | .coverage.*
71 | .cache
72 | nosetests.xml
73 | coverage.xml
74 | *.cover
75 | *.py,cover
76 | .hypothesis/
77 | .pytest_cache/
78 | cover/
79 |
80 | # Translations
81 | *.mo
82 | *.pot
83 |
84 | # Django stuff:
85 | *.log
86 | local_settings.py
87 | db.sqlite3
88 | db.sqlite3-journal
89 |
90 | # Flask stuff:
91 | instance/
92 | .webassets-cache
93 |
94 | # Scrapy stuff:
95 | .scrapy
96 |
97 | # Sphinx documentation
98 | docs/_build/
99 |
100 | # PyBuilder
101 | .pybuilder/
102 | target/
103 |
104 | # Jupyter Notebook
105 | .ipynb_checkpoints
106 |
107 | # IPython
108 | profile_default/
109 | ipython_config.py
110 |
111 | # pyenv
112 | # For a library or package, you might want to ignore these files since the code is
113 | # intended to run in multiple environments; otherwise, check them in:
114 | # .python-version
115 |
116 | # pipenv
117 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
118 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
119 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
120 | # install all needed dependencies.
121 | #Pipfile.lock
122 |
123 | # poetry
124 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
125 | # This is especially recommended for binary packages to ensure reproducibility, and is more
126 | # commonly ignored for libraries.
127 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
128 | #poetry.lock
129 |
130 | # pdm
131 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
132 | #pdm.lock
133 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
134 | # in version control.
135 | # https://pdm.fming.dev/#use-with-ide
136 | .pdm.toml
137 |
138 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
139 | __pypackages__/
140 |
141 | # Celery stuff
142 | celerybeat-schedule
143 | celerybeat.pid
144 |
145 | # SageMath parsed files
146 | *.sage.py
147 |
148 | # Environments
149 | .env
150 | .venv
151 | env/
152 | venv/
153 | ENV/
154 | env.bak/
155 | venv.bak/
156 |
157 | # Spyder project settings
158 | .spyderproject
159 | .spyproject
160 |
161 | # Rope project settings
162 | .ropeproject
163 |
164 | # mkdocs documentation
165 | /site
166 |
167 | # mypy
168 | .mypy_cache/
169 | .dmypy.json
170 | dmypy.json
171 |
172 | # Pyre type checker
173 | .pyre/
174 |
175 | # pytype static type analyzer
176 | .pytype/
177 |
178 | # Cython debug symbols
179 | cython_debug/
180 |
181 | # PyCharm
182 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
183 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
184 | # and can be added to the global gitignore or merged into this file. For a more nuclear
185 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
186 | #.idea/
187 | **/.tmp
188 | !fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt
189 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/python-poetry/poetry
3 |     rev: "1.8"
4 |     hooks:
5 |       - id: poetry-lock
6 |       - id: poetry-export
7 |         args: ["--dev", "-f", "requirements.txt", "-o", "requirements.txt"]
8 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as base
2 |
3 | ENV POETRY_NO_INTERACTION=1 \
4 | POETRY_VIRTUALENVS_IN_PROJECT=1 \
5 | POETRY_VIRTUALENVS_CREATE=1 \
6 | POETRY_CACHE_DIR=/tmp/poetry_cache \
7 | DEBIAN_FRONTEND=noninteractive
8 |
9 | # Install system dependencies in a single RUN command to reduce layers
10 | # Combine apt-get update, upgrade, and installation of packages. Clean up in the same layer to reduce image size.
11 | RUN apt-get update && \
12 | apt-get upgrade -y && \
13 | apt-get install -y python3.10 python3-pip git wget curl build-essential pipx && \
14 | apt-get autoremove -y && \
15 | apt-get clean && \
16 | rm -rf /var/lib/apt/lists/*
17 |
18 | # install via pip given ubuntu 22.04 as per docs https://pipx.pypa.io/stable/installation/
19 | RUN python3 -m pip install --user pipx && \
20 | python3 -m pipx ensurepath && \
21 | python3 -m pipx install poetry==1.8.2
22 |
23 | # make pipx installs (i.e. poetry) available
24 | ENV PATH="/root/.local/bin:${PATH}"
25 |
26 | # install ffmpeg
27 | RUN wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz &&\
28 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5 &&\
29 | md5sum -c ffmpeg-git-amd64-static.tar.xz.md5 &&\
30 | tar xvf ffmpeg-git-amd64-static.tar.xz &&\
31 | mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/ &&\
32 | rm -rf ffmpeg-git-*
33 |
34 | WORKDIR /app
35 |
36 | COPY pyproject.toml poetry.lock ./
37 | RUN touch README.md # poetry will complain otherwise
38 |
39 | RUN poetry install --without dev --no-root
40 | RUN poetry run python -m pip install torch==2.2.1 torchaudio==2.2.1 && \
41 | rm -rf $POETRY_CACHE_DIR
42 |
43 | COPY fam ./fam
44 | COPY serving.py ./
45 | COPY app.py ./
46 |
47 | RUN poetry install --only-root
48 |
49 | ENTRYPOINT ["poetry", "run", "python", "serving.py"]
50 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MetaVoice-1B
2 |
3 |
4 |
5 | [Demo](https://ttsdemo.themetavoice.xyz/)
6 |
7 |
8 |
9 | [Discord](https://discord.gg/tbTbkGEgJM)
10 | [Twitter](https://twitter.com/metavoiceio)
11 |
12 |
13 |
14 | MetaVoice-1B is a 1.2B parameter base model trained on 100K hours of speech for TTS (text-to-speech). It has been built with the following priorities:
15 | * **Emotional speech rhythm and tone** in English.
16 | * **Zero-shot cloning for American & British voices**, with 30s reference audio.
17 | * Support for (cross-lingual) **voice cloning with finetuning**.
18 |   * We have had success with as little as 1 minute of training data for Indian speakers.
19 | * Synthesis of **arbitrary-length text**.
20 |
21 | We’re releasing MetaVoice-1B under the Apache 2.0 license; *it can be used without restrictions*.
22 |
23 |
24 | ## Quickstart - tl;dr
25 |
26 | Web UI
27 | ```bash
28 | docker-compose up -d ui && docker-compose ps && docker-compose logs -f
29 | ```
30 |
31 | Server
32 | ```bash
33 | # navigate to /docs for API definitions
34 | docker-compose up -d server && docker-compose ps && docker-compose logs -f
35 | ```
36 |
37 | ## Installation
38 |
39 | **Pre-requisites:**
40 | - GPU VRAM >=12GB
41 | - Python >=3.10,<3.12
42 | - pipx ([installation instructions](https://pipx.pypa.io/stable/installation/))
43 |
44 | **Environment setup**
45 | ```bash
46 | # install ffmpeg
47 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz
48 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5
49 | md5sum -c ffmpeg-git-amd64-static.tar.xz.md5
50 | tar xvf ffmpeg-git-amd64-static.tar.xz
51 | sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/
52 | rm -rf ffmpeg-git-*
53 |
54 | # install rust if not installed (ensure you've restarted your terminal after installation)
55 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
56 | ```
57 |
58 | ### Project dependencies installation
59 | 1. [Using poetry](#using-poetry-recommended)
60 | 2. [Using pip/conda](#using-pipconda)
61 |
62 | #### Using poetry (recommended)
63 | ```bash
64 | # install poetry if not installed (ensure you've restarted your terminal after installation)
65 | pipx install poetry
66 |
67 | # disable any conda envs that might interfere with poetry's venv
68 | conda deactivate
69 |
70 | # if running from Linux, keyring backend can hang on `poetry install`. This prevents that.
71 | export PYTHON_KEYRING_BACKEND=keyring.backends.fail.Keyring
72 |
73 | # pip's dependency resolver will complain, this is temporary expected behaviour
74 | # full inference & finetuning functionality will still be available
75 | poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1
76 | ```
77 |
78 | #### Using pip/conda
79 | NOTE 1: When raising issues, we'll ask you to try with poetry first.
80 | NOTE 2: Commands in this README use `poetry` by default; if you install with pip/conda instead, simply drop the `poetry run` prefix.
81 |
82 | ```bash
83 | pip install -r requirements.txt
84 | pip install torch==2.2.1 torchaudio==2.2.1
85 | pip install -e .
86 | ```
87 |
88 | ## Usage
89 | 1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py) (a programmatic sketch follows after this list)
90 | ```bash
91 | # You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade audio quality.
92 | # Note: int8 is slower than bf16/fp16 for reasons we haven't yet debugged. If you want speed, try int4, which is roughly 2x faster than bf16/fp16.
93 | poetry run python -i fam/llm/fast_inference.py
94 |
95 | # Run e.g. of API usage within the interactive python session
96 | tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.", spk_ref_path="assets/bria.mp3")
97 | ```
98 | > Note: The script takes 30-90s to start up (depending on hardware). This is because we torch.compile the model for fast inference.
99 |
100 | > On Ampere, Ada-Lovelace, and Hopper architecture GPUs, once compiled, the synthesise() API runs faster than real-time, with a Real-Time Factor (RTF) < 1.0.
101 |
102 | 2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py)
103 | ```bash
104 | # You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade audio quality.
105 | # Note: int8 is slower than bf16/fp16 for reasons we haven't yet debugged. If you want speed, try int4, which is roughly 2x faster than bf16/fp16.
106 |
107 | # navigate to /docs for API definitions
108 | poetry run python serving.py
109 |
110 | poetry run python app.py
111 | ```
112 |
113 | 3. Use it via [Hugging Face](https://huggingface.co/metavoiceio)
114 | 4. [Google Colab Demo](https://colab.research.google.com/github/metavoiceio/metavoice-src/blob/main/colab_demo.ipynb)
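
The same `TTS` class used by the options above can also be driven from a regular Python script. A minimal sketch, assuming a local checkout with the dependencies installed as described in the Installation section; it mirrors the interactive and Colab examples:

```python
# Programmatic usage sketch -- mirrors the interactive/Colab examples above.
from fam.llm.fast_inference import TTS

tts = TTS()  # loads the model; takes 30-90s because of torch.compile
wav_path = tts.synthesise(
    text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.",
    spk_ref_path="assets/bria.mp3",  # any clean speaker reference of at least 30s
)
# synthesise() returns a path to the generated audio (as in colab_demo.ipynb)
print(f"Synthesised audio written to: {wav_path}")
```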
115 |
116 | ## Finetuning
117 | We support finetuning the first-stage LLM (see the [Architecture](#architecture) section).
118 |
119 | In order to finetune, we expect a "|"-delimited CSV dataset of the following format:
120 |
121 | ```csv
122 | audio_files|captions
123 | ./data/audio.wav|./data/caption.txt
124 | ```
125 |
126 | Note that we don't perform any dataset overlap checks, so ensure that your train and val datasets are disjoint.
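
If your data is organised as matching `<clip>.wav` / `<clip>.txt` pairs, a small helper can emit this CSV. A hypothetical sketch (the `my_data` directory and output path are illustrative, not part of this repo); remember to build a separate, disjoint CSV for validation:

```python
# Hypothetical helper: build the "|"-delimited finetuning CSV from wav/txt pairs.
# Assumes ./my_data/<clip>.wav with a matching ./my_data/<clip>.txt caption file.
import csv
from pathlib import Path

data_dir = Path("./my_data")
rows = [
    (str(wav), str(wav.with_suffix(".txt")))
    for wav in sorted(data_dir.glob("*.wav"))
    if wav.with_suffix(".txt").exists()
]

with open("./datasets/my_dataset.csv", "w", newline="") as f:
    writer = csv.writer(f, delimiter="|")
    writer.writerow(["audio_files", "captions"])
    writer.writerows(rows)
```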
127 |
128 | Try it out using our sample datasets via:
129 | ```bash
130 | poetry run finetune --train ./datasets/sample_dataset.csv --val ./datasets/sample_val_dataset.csv
131 | ```
132 |
133 | Once you've trained your model, you can use it for inference via:
134 | ```bash
135 | poetry run python -i fam/llm/fast_inference.py --first_stage_path ./my-finetuned_model.pt
136 | ```
137 |
138 | ### Configuration
139 |
140 | To set hyperparameters such as the learning rate, which layers to freeze, etc., you
141 | can edit the [finetune_params.py](./fam/llm/config/finetune_params.py) file.
142 |
143 | We provide a light, optional integration with W&B that can be enabled by setting
144 | `wandb_log = True` and installing the appropriate dependencies:
145 |
146 | ```bash
147 | poetry install -E observable
148 | ```
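
For example, enabling the W&B logging above is a one-line change in that file. A sketch of what such an edit might look like; apart from `wandb_log`, the parameter names shown are illustrative assumptions, so check `finetune_params.py` for the real ones:

```python
# fam/llm/config/finetune_params.py (illustrative excerpt)
wandb_log = True            # turn on the optional W&B integration
# learning_rate = 3e-5      # hypothetical name: optimiser step size
# freeze_embeddings = True  # hypothetical name: control what gets frozen
```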
149 |
150 | ## Upcoming
151 | - [x] Faster inference ⚡
152 | - [x] Fine-tuning code 📐
153 | - [ ] Synthesis of arbitrary length text
154 |
155 |
156 | ## Architecture
157 | We predict EnCodec tokens from text and speaker information. These tokens are then diffused up to the waveform level, with post-processing applied to clean up the audio.
158 |
159 | * We use a causal GPT to predict the first two hierarchies of EnCodec tokens. Text and audio are part of the LLM context. Speaker information is passed via conditioning at the token embedding layer. This speaker conditioning is obtained from a separately trained speaker verification network.
160 |   - The two hierarchies are predicted in a "flattened interleaved" manner: we predict the first token of the first hierarchy, then the first token of the second hierarchy, then the second token of the first hierarchy, and so on (see the sketch after this list).
161 |   - We use condition-free sampling to boost the cloning capability of the model.
162 |   - The text is tokenised using a custom-trained BPE tokeniser with 512 tokens.
163 |   - Note that we've skipped predicting semantic tokens as done in other works, as we found this isn't strictly necessary.
164 | * We use a non-causal (encoder-style) transformer to predict the remaining 6 hierarchies from the first two. This is a very small model (~10M parameters) with extensive zero-shot generalisation to most speakers we've tried. Since it's non-causal, we can also predict all the timesteps in parallel.
165 | * We use multi-band diffusion to generate waveforms from the EnCodec tokens. We noticed that the speech is clearer than with the original RVQ decoder or VOCOS. However, diffusion at the waveform level leaves some background artifacts which are quite unpleasant to the ear. We clean these up in the next step.
166 | * We use DeepFilterNet to clear up the artifacts introduced by the multi-band diffusion.
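
To make the "flattened interleaved" ordering above concrete, here is a toy sketch of how two codebook hierarchies are merged into (and recovered from) a single token stream. It illustrates the ordering only and is not the model's actual implementation:

```python
# Toy illustration of "flattened interleaved" ordering over two EnCodec hierarchies.
# h1[t] and h2[t] are the codebook-1 and codebook-2 tokens at timestep t.
def flatten_interleave(h1: list[int], h2: list[int]) -> list[int]:
    flat: list[int] = []
    for tok_1, tok_2 in zip(h1, h2):
        flat.extend([tok_1, tok_2])  # h1[0], h2[0], h1[1], h2[1], ...
    return flat

def unflatten(flat: list[int]) -> tuple[list[int], list[int]]:
    return flat[0::2], flat[1::2]  # even positions -> hierarchy 1, odd -> hierarchy 2

assert flatten_interleave([10, 11, 12], [20, 21, 22]) == [10, 20, 11, 21, 12, 22]
```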
167 |
168 | ## Optimizations
169 | The model supports:
170 | 1. KV-caching via Flash Decoding
171 | 2. Batching (including texts of different lengths)
172 |
173 | ## Contribute
174 | - See all [active issues](https://github.com/metavoiceio/metavoice-src/issues)!
175 |
176 | ## Acknowledgements
177 | We are grateful to Together.ai for their 24/7 help in marshalling our cluster. We thank the teams of AWS, GCP & Hugging Face for support with their cloud platforms.
178 |
179 | - [A. Défossez et al.](https://arxiv.org/abs/2210.13438) for EnCodec.
180 | - [R. S. Roman et al.](https://arxiv.org/abs/2308.02560) for Multi-Band Diffusion.
181 | - [@liusongxiang](https://github.com/liusongxiang/ppg-vc/blob/main/speaker_encoder/inference.py) for speaker encoder implementation.
182 | - [@karpathy](https://github.com/karpathy/nanoGPT) for NanoGPT which our inference implementation is based on.
183 | - [@Rikorose](https://github.com/Rikorose) for DeepFilterNet.
184 |
185 | Apologies in advance if we've missed anyone out. Please let us know if we have.
186 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | project_root = os.path.dirname(os.path.abspath(__file__))
5 | if project_root not in sys.path:
6 |     sys.path.insert(0, project_root)
7 |
8 |
9 | import gradio as gr
10 | import tyro
11 |
12 | from fam.llm.fast_inference import TTS
13 | from fam.llm.utils import check_audio_file
14 |
15 | #### setup model
16 | TTS_MODEL = tyro.cli(TTS, args=["--telemetry_origin", "webapp"])
17 |
18 | #### setup interface
19 | RADIO_CHOICES = ["Preset voices", "Upload target voice (at least 30s)"]
20 | MAX_CHARS = 220
21 | PRESET_VOICES = {
22 |     # female
23 |     "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3",
24 |     # male
25 |     "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3",
26 |     "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav",
27 | }
28 |
29 |
30 | def denormalise_top_p(top_p):
31 |     # returns top_p in the range [0.9, 1.0]
32 |     return round(0.9 + top_p / 100, 2)
33 |
34 |
35 | def denormalise_guidance(guidance):
36 |     # returns guidance in the range [1.0, 3.0]
37 |     return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1)
38 |
39 |
40 | def _check_file_size(path):
41 |     if not path:
42 |         return
43 |     filesize = os.path.getsize(path)
44 |     filesize_mb = filesize / 1024 / 1024
45 |     if filesize_mb >= 50:
46 |         raise gr.Error(f"Please upload a sample less than 50MB for voice cloning. Provided: {round(filesize_mb)} MB")
47 |
48 |
49 | def _handle_edge_cases(to_say, upload_target):
50 |     if not to_say:
51 |         raise gr.Error("Please provide text to synthesise")
52 |
53 |     if len(to_say) > MAX_CHARS:
54 |         gr.Warning(
55 |             f"Max {MAX_CHARS} characters allowed. Provided: {len(to_say)} characters. Truncating and generating speech... The end of the result may be unstable."
56 |         )
57 |
58 |     if not upload_target:
59 |         return
60 |
61 |     check_audio_file(upload_target)  # check that the file duration is at least 30s
62 |     _check_file_size(upload_target)
63 |
64 |
65 | def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target):
66 |     try:
67 |         d_top_p = denormalise_top_p(top_p)
68 |         d_guidance = denormalise_guidance(guidance)
69 |
70 |         _handle_edge_cases(to_say, upload_target)
71 |
72 |         to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]
73 |
74 |         return TTS_MODEL.synthesise(
75 |             text=to_say,
76 |             spk_ref_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target,
77 |             top_p=d_top_p,
78 |             guidance_scale=d_guidance,
79 |         )
80 |     except Exception as e:
81 |         raise gr.Error(f"Something went wrong. Reason: {str(e)}")
82 |
83 |
84 | def change_voice_selection_layout(choice):
85 |     if choice == RADIO_CHOICES[0]:
86 |         return [gr.update(visible=True), gr.update(visible=False)]
87 |
88 |     return [gr.update(visible=False), gr.update(visible=True)]
89 |
90 |
91 | title = """
92 |
93 |
94 |
95 |
96 |
97 | \n# TTS by MetaVoice-1B
98 | """
99 |
100 | description = """
101 | MetaVoice-1B is a 1.2B parameter base model for TTS (text-to-speech). It has been built with the following priorities:
102 | \n
103 | * Emotional speech rhythm and tone in English.
104 | * Zero-shot cloning for American & British voices, with 30s reference audio.
105 | * Support for voice cloning with finetuning.
106 | * We have had success with as little as 1 minute of training data for Indian speakers.
107 | * Support for long-form synthesis.
108 |
109 | We are releasing the model under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0). See [GitHub](https://github.com/metavoiceio/metavoice-src) for details and to contribute.
110 | """
111 |
112 | with gr.Blocks(title="TTS by MetaVoice") as demo:
113 |     gr.Markdown(title)
114 |
115 |     with gr.Row():
116 |         gr.Markdown(description)
117 |
118 |     with gr.Row():
119 |         with gr.Column():
120 |             to_say = gr.TextArea(
121 |                 label=f"What should I say!? (max {MAX_CHARS} characters).",
122 |                 lines=4,
123 |                 value="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model by MetaVoice.",
124 |             )
125 |             with gr.Row(), gr.Column():
126 |                 # voice settings
127 |                 top_p = gr.Slider(
128 |                     value=5.0,
129 |                     minimum=0.0,
130 |                     maximum=10.0,
131 |                     step=1.0,
132 |                     label="Speech Stability - improves text following for a challenging speaker",
133 |                 )
134 |                 guidance = gr.Slider(
135 |                     value=5.0,
136 |                     minimum=1.0,
137 |                     maximum=5.0,
138 |                     step=1.0,
139 |                     label="Speaker similarity - How closely to match speaker identity and speech style.",
140 |                 )
141 |
142 |                 # voice select
143 |                 toggle = gr.Radio(choices=RADIO_CHOICES, label="Choose voice", value=RADIO_CHOICES[0])
144 |
145 |             with gr.Row(visible=True) as row_1:
146 |                 preset_dropdown = gr.Dropdown(
147 |                     PRESET_VOICES.keys(), label="Preset voices", value=list(PRESET_VOICES.keys())[0]
148 |                 )
149 |                 with gr.Accordion("Preview: Preset voices", open=False):
150 |                     for label, path in PRESET_VOICES.items():
151 |                         gr.Audio(value=path, label=label)
152 |
153 |             with gr.Row(visible=False) as row_2:
154 |                 upload_target = gr.Audio(
155 |                     sources=["upload"],
156 |                     type="filepath",
157 |                     label="Upload a clean sample to clone. Sample should contain 1 speaker, be between 30-90 seconds and not contain background noise.",
158 |                 )
159 |
160 |             toggle.change(
161 |                 change_voice_selection_layout,
162 |                 inputs=toggle,
163 |                 outputs=[row_1, row_2],
164 |             )
165 |
166 |         with gr.Column():
167 |             speech = gr.Audio(
168 |                 type="filepath",
169 |                 label="MetaVoice-1B says...",
170 |             )
171 |
172 |             submit = gr.Button("Generate Speech")
173 |             submit.click(
174 |                 fn=tts,
175 |                 inputs=[to_say, top_p, guidance, toggle, preset_dropdown, upload_target],
176 |                 outputs=speech,
177 |             )
178 |
179 |
180 | demo.queue()
181 | demo.launch(
182 |     favicon_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/favicon.ico"),
183 |     server_name="0.0.0.0",
184 |     server_port=7861,
185 | )
186 |
--------------------------------------------------------------------------------
/assets/bria.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/assets/bria.mp3
--------------------------------------------------------------------------------
/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/assets/favicon.ico
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/assets/logo.png
--------------------------------------------------------------------------------
/colab_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Installation"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "### Clone the repository"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "vscode": {
22 | "languageId": "plaintext"
23 | }
24 | },
25 | "outputs": [],
26 | "source": [
27 | "!git clone https://github.com/metavoiceio/metavoice-src.git\n",
28 | "%cd metavoice-src"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "### Install dependencies"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {
42 | "vscode": {
43 | "languageId": "plaintext"
44 | }
45 | },
46 | "outputs": [],
47 | "source": [
48 | "!sudo apt install pipx\n",
49 | "!pipx install poetry\n",
50 | "!pipx run poetry install && pipx run poetry run pip install torch==2.2.1 torchaudio==2.2.1\n",
51 | "!pipx run poetry env list | sed 's/ (Activated)//' > poetry_env.txt\n",
52 | "# NOTE: pip's dependency resolver will error & complain, ignore it!\n",
53 | "# its due to a temporary dependency issue, `tts.synthesise` will still work as intended!"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "vscode": {
61 | "languageId": "plaintext"
62 | }
63 | },
64 | "outputs": [],
65 | "source": [
66 | "import sys, pathlib\n",
67 | "venv = pathlib.Path(\"poetry_env.txt\").read_text().strip(\"\\n\")\n",
68 | "sys.path.append(f\"/root/.cache/pypoetry/virtualenvs/{venv}/lib/python3.10/site-packages\")"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "## Inference"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {
82 | "vscode": {
83 | "languageId": "plaintext"
84 | }
85 | },
86 | "outputs": [],
87 | "source": [
88 | "from IPython.display import Audio, display\n",
89 | "from fam.llm.fast_inference import TTS\n",
90 | "\n",
91 | "tts = TTS()"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {
98 | "vscode": {
99 | "languageId": "plaintext"
100 | }
101 | },
102 | "outputs": [],
103 | "source": [
104 | "wav_file = tts.synthesise(\n",
105 | " text=\"This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.\",\n",
106 | " spk_ref_path=\"assets/bria.mp3\" # you can use any speaker reference file (WAV, OGG, MP3, FLAC, etc.)\n",
107 | ")\n",
108 | "display(Audio(wav_file, autoplay=True))"
109 | ]
110 | }
111 | ],
112 | "metadata": {
113 | "accelerator": "GPU",
114 | "colab": {
115 | "gpuType": "T4",
116 | "provenance": []
117 | },
118 | "kernelspec": {
119 | "display_name": "Python 3",
120 | "name": "python3"
121 | },
122 | "language_info": {
123 | "name": "python"
124 | }
125 | },
126 | "nbformat": 4,
127 | "nbformat_minor": 2
128 | }
129 |
--------------------------------------------------------------------------------
/data/audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/data/audio.wav
--------------------------------------------------------------------------------
/data/caption.txt:
--------------------------------------------------------------------------------
1 | Please call Stella.
--------------------------------------------------------------------------------
/datasets/sample_dataset.csv:
--------------------------------------------------------------------------------
1 | audio_files|captions
2 | ./data/audio.wav|./data/caption.txt
3 | ./data/audio.wav|./data/caption.txt
4 | ./data/audio.wav|./data/caption.txt
5 | ./data/audio.wav|./data/caption.txt
6 | ./data/audio.wav|./data/caption.txt
7 | ./data/audio.wav|./data/caption.txt
8 | ./data/audio.wav|./data/caption.txt
9 | ./data/audio.wav|./data/caption.txt
10 | ./data/audio.wav|./data/caption.txt
11 | ./data/audio.wav|./data/caption.txt
12 | ./data/audio.wav|./data/caption.txt
13 | ./data/audio.wav|./data/caption.txt
14 | ./data/audio.wav|./data/caption.txt
15 | ./data/audio.wav|./data/caption.txt
16 | ./data/audio.wav|./data/caption.txt
17 | ./data/audio.wav|./data/caption.txt
18 | ./data/audio.wav|./data/caption.txt
19 | ./data/audio.wav|./data/caption.txt
20 | ./data/audio.wav|./data/caption.txt
21 | ./data/audio.wav|./data/caption.txt
22 | ./data/audio.wav|./data/caption.txt
23 | ./data/audio.wav|./data/caption.txt
24 | ./data/audio.wav|./data/caption.txt
25 | ./data/audio.wav|./data/caption.txt
26 | ./data/audio.wav|./data/caption.txt
27 | ./data/audio.wav|./data/caption.txt
28 | ./data/audio.wav|./data/caption.txt
29 | ./data/audio.wav|./data/caption.txt
30 | ./data/audio.wav|./data/caption.txt
31 | ./data/audio.wav|./data/caption.txt
32 | ./data/audio.wav|./data/caption.txt
33 | ./data/audio.wav|./data/caption.txt
34 | ./data/audio.wav|./data/caption.txt
35 | ./data/audio.wav|./data/caption.txt
36 | ./data/audio.wav|./data/caption.txt
37 | ./data/audio.wav|./data/caption.txt
38 | ./data/audio.wav|./data/caption.txt
39 | ./data/audio.wav|./data/caption.txt
40 | ./data/audio.wav|./data/caption.txt
41 | ./data/audio.wav|./data/caption.txt
42 | ./data/audio.wav|./data/caption.txt
43 | ./data/audio.wav|./data/caption.txt
44 | ./data/audio.wav|./data/caption.txt
45 | ./data/audio.wav|./data/caption.txt
46 | ./data/audio.wav|./data/caption.txt
47 | ./data/audio.wav|./data/caption.txt
48 | ./data/audio.wav|./data/caption.txt
49 | ./data/audio.wav|./data/caption.txt
50 | ./data/audio.wav|./data/caption.txt
51 | ./data/audio.wav|./data/caption.txt
52 | ./data/audio.wav|./data/caption.txt
53 | ./data/audio.wav|./data/caption.txt
54 | ./data/audio.wav|./data/caption.txt
55 | ./data/audio.wav|./data/caption.txt
56 | ./data/audio.wav|./data/caption.txt
57 | ./data/audio.wav|./data/caption.txt
58 | ./data/audio.wav|./data/caption.txt
59 | ./data/audio.wav|./data/caption.txt
60 | ./data/audio.wav|./data/caption.txt
61 | ./data/audio.wav|./data/caption.txt
62 | ./data/audio.wav|./data/caption.txt
63 | ./data/audio.wav|./data/caption.txt
64 | ./data/audio.wav|./data/caption.txt
65 | ./data/audio.wav|./data/caption.txt
66 | ./data/audio.wav|./data/caption.txt
67 | ./data/audio.wav|./data/caption.txt
68 | ./data/audio.wav|./data/caption.txt
69 | ./data/audio.wav|./data/caption.txt
70 | ./data/audio.wav|./data/caption.txt
71 | ./data/audio.wav|./data/caption.txt
72 | ./data/audio.wav|./data/caption.txt
73 | ./data/audio.wav|./data/caption.txt
74 | ./data/audio.wav|./data/caption.txt
75 | ./data/audio.wav|./data/caption.txt
76 | ./data/audio.wav|./data/caption.txt
77 | ./data/audio.wav|./data/caption.txt
78 | ./data/audio.wav|./data/caption.txt
79 | ./data/audio.wav|./data/caption.txt
80 | ./data/audio.wav|./data/caption.txt
81 | ./data/audio.wav|./data/caption.txt
82 | ./data/audio.wav|./data/caption.txt
83 | ./data/audio.wav|./data/caption.txt
84 | ./data/audio.wav|./data/caption.txt
85 | ./data/audio.wav|./data/caption.txt
86 | ./data/audio.wav|./data/caption.txt
87 | ./data/audio.wav|./data/caption.txt
88 | ./data/audio.wav|./data/caption.txt
89 | ./data/audio.wav|./data/caption.txt
90 | ./data/audio.wav|./data/caption.txt
91 | ./data/audio.wav|./data/caption.txt
92 | ./data/audio.wav|./data/caption.txt
93 | ./data/audio.wav|./data/caption.txt
94 | ./data/audio.wav|./data/caption.txt
95 | ./data/audio.wav|./data/caption.txt
96 | ./data/audio.wav|./data/caption.txt
97 | ./data/audio.wav|./data/caption.txt
98 | ./data/audio.wav|./data/caption.txt
99 | ./data/audio.wav|./data/caption.txt
100 | ./data/audio.wav|./data/caption.txt
101 | ./data/audio.wav|./data/caption.txt
102 | ./data/audio.wav|./data/caption.txt
103 | ./data/audio.wav|./data/caption.txt
104 | ./data/audio.wav|./data/caption.txt
105 | ./data/audio.wav|./data/caption.txt
106 | ./data/audio.wav|./data/caption.txt
107 | ./data/audio.wav|./data/caption.txt
108 | ./data/audio.wav|./data/caption.txt
109 | ./data/audio.wav|./data/caption.txt
110 | ./data/audio.wav|./data/caption.txt
111 | ./data/audio.wav|./data/caption.txt
112 | ./data/audio.wav|./data/caption.txt
113 | ./data/audio.wav|./data/caption.txt
114 | ./data/audio.wav|./data/caption.txt
115 | ./data/audio.wav|./data/caption.txt
116 | ./data/audio.wav|./data/caption.txt
117 | ./data/audio.wav|./data/caption.txt
118 | ./data/audio.wav|./data/caption.txt
119 | ./data/audio.wav|./data/caption.txt
120 | ./data/audio.wav|./data/caption.txt
121 | ./data/audio.wav|./data/caption.txt
122 | ./data/audio.wav|./data/caption.txt
123 | ./data/audio.wav|./data/caption.txt
124 | ./data/audio.wav|./data/caption.txt
125 | ./data/audio.wav|./data/caption.txt
126 | ./data/audio.wav|./data/caption.txt
127 | ./data/audio.wav|./data/caption.txt
128 | ./data/audio.wav|./data/caption.txt
129 | ./data/audio.wav|./data/caption.txt
130 | ./data/audio.wav|./data/caption.txt
131 | ./data/audio.wav|./data/caption.txt
132 | ./data/audio.wav|./data/caption.txt
133 | ./data/audio.wav|./data/caption.txt
134 | ./data/audio.wav|./data/caption.txt
135 | ./data/audio.wav|./data/caption.txt
136 | ./data/audio.wav|./data/caption.txt
137 | ./data/audio.wav|./data/caption.txt
138 | ./data/audio.wav|./data/caption.txt
139 | ./data/audio.wav|./data/caption.txt
140 | ./data/audio.wav|./data/caption.txt
141 | ./data/audio.wav|./data/caption.txt
142 | ./data/audio.wav|./data/caption.txt
143 | ./data/audio.wav|./data/caption.txt
144 | ./data/audio.wav|./data/caption.txt
145 | ./data/audio.wav|./data/caption.txt
146 | ./data/audio.wav|./data/caption.txt
147 | ./data/audio.wav|./data/caption.txt
148 | ./data/audio.wav|./data/caption.txt
149 | ./data/audio.wav|./data/caption.txt
150 | ./data/audio.wav|./data/caption.txt
151 | ./data/audio.wav|./data/caption.txt
152 | ./data/audio.wav|./data/caption.txt
153 | ./data/audio.wav|./data/caption.txt
154 | ./data/audio.wav|./data/caption.txt
155 | ./data/audio.wav|./data/caption.txt
156 | ./data/audio.wav|./data/caption.txt
157 | ./data/audio.wav|./data/caption.txt
158 | ./data/audio.wav|./data/caption.txt
159 | ./data/audio.wav|./data/caption.txt
160 | ./data/audio.wav|./data/caption.txt
161 | ./data/audio.wav|./data/caption.txt
162 | ./data/audio.wav|./data/caption.txt
163 | ./data/audio.wav|./data/caption.txt
164 | ./data/audio.wav|./data/caption.txt
165 | ./data/audio.wav|./data/caption.txt
166 | ./data/audio.wav|./data/caption.txt
167 | ./data/audio.wav|./data/caption.txt
168 | ./data/audio.wav|./data/caption.txt
169 | ./data/audio.wav|./data/caption.txt
170 | ./data/audio.wav|./data/caption.txt
171 | ./data/audio.wav|./data/caption.txt
172 | ./data/audio.wav|./data/caption.txt
173 | ./data/audio.wav|./data/caption.txt
174 | ./data/audio.wav|./data/caption.txt
175 | ./data/audio.wav|./data/caption.txt
176 | ./data/audio.wav|./data/caption.txt
177 | ./data/audio.wav|./data/caption.txt
178 | ./data/audio.wav|./data/caption.txt
179 | ./data/audio.wav|./data/caption.txt
180 | ./data/audio.wav|./data/caption.txt
181 | ./data/audio.wav|./data/caption.txt
182 | ./data/audio.wav|./data/caption.txt
183 | ./data/audio.wav|./data/caption.txt
184 | ./data/audio.wav|./data/caption.txt
185 | ./data/audio.wav|./data/caption.txt
186 | ./data/audio.wav|./data/caption.txt
187 | ./data/audio.wav|./data/caption.txt
188 | ./data/audio.wav|./data/caption.txt
189 | ./data/audio.wav|./data/caption.txt
190 | ./data/audio.wav|./data/caption.txt
191 | ./data/audio.wav|./data/caption.txt
192 | ./data/audio.wav|./data/caption.txt
193 | ./data/audio.wav|./data/caption.txt
194 | ./data/audio.wav|./data/caption.txt
195 | ./data/audio.wav|./data/caption.txt
196 | ./data/audio.wav|./data/caption.txt
197 | ./data/audio.wav|./data/caption.txt
198 | ./data/audio.wav|./data/caption.txt
199 | ./data/audio.wav|./data/caption.txt
200 | ./data/audio.wav|./data/caption.txt
201 | ./data/audio.wav|./data/caption.txt
202 | ./data/audio.wav|./data/caption.txt
203 | ./data/audio.wav|./data/caption.txt
204 | ./data/audio.wav|./data/caption.txt
205 | ./data/audio.wav|./data/caption.txt
206 | ./data/audio.wav|./data/caption.txt
207 | ./data/audio.wav|./data/caption.txt
208 | ./data/audio.wav|./data/caption.txt
209 | ./data/audio.wav|./data/caption.txt
210 | ./data/audio.wav|./data/caption.txt
211 | ./data/audio.wav|./data/caption.txt
212 | ./data/audio.wav|./data/caption.txt
213 | ./data/audio.wav|./data/caption.txt
214 | ./data/audio.wav|./data/caption.txt
215 | ./data/audio.wav|./data/caption.txt
216 | ./data/audio.wav|./data/caption.txt
217 | ./data/audio.wav|./data/caption.txt
218 | ./data/audio.wav|./data/caption.txt
219 | ./data/audio.wav|./data/caption.txt
220 | ./data/audio.wav|./data/caption.txt
221 | ./data/audio.wav|./data/caption.txt
222 | ./data/audio.wav|./data/caption.txt
223 | ./data/audio.wav|./data/caption.txt
224 | ./data/audio.wav|./data/caption.txt
225 | ./data/audio.wav|./data/caption.txt
226 | ./data/audio.wav|./data/caption.txt
227 | ./data/audio.wav|./data/caption.txt
228 | ./data/audio.wav|./data/caption.txt
229 | ./data/audio.wav|./data/caption.txt
230 | ./data/audio.wav|./data/caption.txt
231 | ./data/audio.wav|./data/caption.txt
232 | ./data/audio.wav|./data/caption.txt
233 | ./data/audio.wav|./data/caption.txt
234 | ./data/audio.wav|./data/caption.txt
235 | ./data/audio.wav|./data/caption.txt
236 | ./data/audio.wav|./data/caption.txt
237 | ./data/audio.wav|./data/caption.txt
238 | ./data/audio.wav|./data/caption.txt
239 | ./data/audio.wav|./data/caption.txt
240 | ./data/audio.wav|./data/caption.txt
241 | ./data/audio.wav|./data/caption.txt
242 | ./data/audio.wav|./data/caption.txt
243 | ./data/audio.wav|./data/caption.txt
244 | ./data/audio.wav|./data/caption.txt
245 | ./data/audio.wav|./data/caption.txt
246 | ./data/audio.wav|./data/caption.txt
247 | ./data/audio.wav|./data/caption.txt
248 | ./data/audio.wav|./data/caption.txt
249 | ./data/audio.wav|./data/caption.txt
250 | ./data/audio.wav|./data/caption.txt
251 | ./data/audio.wav|./data/caption.txt
252 | ./data/audio.wav|./data/caption.txt
253 | ./data/audio.wav|./data/caption.txt
254 | ./data/audio.wav|./data/caption.txt
255 | ./data/audio.wav|./data/caption.txt
256 | ./data/audio.wav|./data/caption.txt
257 | ./data/audio.wav|./data/caption.txt
258 | ./data/audio.wav|./data/caption.txt
259 | ./data/audio.wav|./data/caption.txt
260 | ./data/audio.wav|./data/caption.txt
261 | ./data/audio.wav|./data/caption.txt
262 | ./data/audio.wav|./data/caption.txt
263 | ./data/audio.wav|./data/caption.txt
264 | ./data/audio.wav|./data/caption.txt
265 | ./data/audio.wav|./data/caption.txt
266 | ./data/audio.wav|./data/caption.txt
267 | ./data/audio.wav|./data/caption.txt
268 | ./data/audio.wav|./data/caption.txt
269 | ./data/audio.wav|./data/caption.txt
270 | ./data/audio.wav|./data/caption.txt
271 | ./data/audio.wav|./data/caption.txt
272 | ./data/audio.wav|./data/caption.txt
273 | ./data/audio.wav|./data/caption.txt
274 | ./data/audio.wav|./data/caption.txt
275 | ./data/audio.wav|./data/caption.txt
276 | ./data/audio.wav|./data/caption.txt
277 | ./data/audio.wav|./data/caption.txt
278 | ./data/audio.wav|./data/caption.txt
279 | ./data/audio.wav|./data/caption.txt
280 | ./data/audio.wav|./data/caption.txt
281 | ./data/audio.wav|./data/caption.txt
282 | ./data/audio.wav|./data/caption.txt
283 | ./data/audio.wav|./data/caption.txt
284 | ./data/audio.wav|./data/caption.txt
285 | ./data/audio.wav|./data/caption.txt
286 | ./data/audio.wav|./data/caption.txt
287 | ./data/audio.wav|./data/caption.txt
288 | ./data/audio.wav|./data/caption.txt
289 | ./data/audio.wav|./data/caption.txt
290 | ./data/audio.wav|./data/caption.txt
291 | ./data/audio.wav|./data/caption.txt
292 | ./data/audio.wav|./data/caption.txt
293 | ./data/audio.wav|./data/caption.txt
294 | ./data/audio.wav|./data/caption.txt
295 | ./data/audio.wav|./data/caption.txt
296 | ./data/audio.wav|./data/caption.txt
297 | ./data/audio.wav|./data/caption.txt
298 | ./data/audio.wav|./data/caption.txt
299 | ./data/audio.wav|./data/caption.txt
300 | ./data/audio.wav|./data/caption.txt
301 | ./data/audio.wav|./data/caption.txt
302 | ./data/audio.wav|./data/caption.txt
303 | ./data/audio.wav|./data/caption.txt
304 | ./data/audio.wav|./data/caption.txt
305 | ./data/audio.wav|./data/caption.txt
306 | ./data/audio.wav|./data/caption.txt
307 | ./data/audio.wav|./data/caption.txt
308 | ./data/audio.wav|./data/caption.txt
309 | ./data/audio.wav|./data/caption.txt
310 | ./data/audio.wav|./data/caption.txt
311 | ./data/audio.wav|./data/caption.txt
312 | ./data/audio.wav|./data/caption.txt
313 | ./data/audio.wav|./data/caption.txt
314 | ./data/audio.wav|./data/caption.txt
315 | ./data/audio.wav|./data/caption.txt
316 | ./data/audio.wav|./data/caption.txt
317 | ./data/audio.wav|./data/caption.txt
318 | ./data/audio.wav|./data/caption.txt
319 | ./data/audio.wav|./data/caption.txt
320 | ./data/audio.wav|./data/caption.txt
321 | ./data/audio.wav|./data/caption.txt
322 |
--------------------------------------------------------------------------------
/datasets/sample_val_dataset.csv:
--------------------------------------------------------------------------------
1 | audio_files|captions
2 | ./data/audio.wav|./data/caption.txt
3 | ./data/audio.wav|./data/caption.txt
4 | ./data/audio.wav|./data/caption.txt
5 | ./data/audio.wav|./data/caption.txt
6 | ./data/audio.wav|./data/caption.txt
7 | ./data/audio.wav|./data/caption.txt
8 | ./data/audio.wav|./data/caption.txt
9 | ./data/audio.wav|./data/caption.txt
10 | ./data/audio.wav|./data/caption.txt
11 | ./data/audio.wav|./data/caption.txt
12 | ./data/audio.wav|./data/caption.txt
13 | ./data/audio.wav|./data/caption.txt
14 | ./data/audio.wav|./data/caption.txt
15 | ./data/audio.wav|./data/caption.txt
16 | ./data/audio.wav|./data/caption.txt
17 | ./data/audio.wav|./data/caption.txt
18 | ./data/audio.wav|./data/caption.txt
19 | ./data/audio.wav|./data/caption.txt
20 | ./data/audio.wav|./data/caption.txt
21 | ./data/audio.wav|./data/caption.txt
22 | ./data/audio.wav|./data/caption.txt
23 | ./data/audio.wav|./data/caption.txt
24 | ./data/audio.wav|./data/caption.txt
25 | ./data/audio.wav|./data/caption.txt
26 | ./data/audio.wav|./data/caption.txt
27 | ./data/audio.wav|./data/caption.txt
28 | ./data/audio.wav|./data/caption.txt
29 | ./data/audio.wav|./data/caption.txt
30 | ./data/audio.wav|./data/caption.txt
31 | ./data/audio.wav|./data/caption.txt
32 | ./data/audio.wav|./data/caption.txt
33 | ./data/audio.wav|./data/caption.txt
34 | ./data/audio.wav|./data/caption.txt
35 | ./data/audio.wav|./data/caption.txt
36 | ./data/audio.wav|./data/caption.txt
37 | ./data/audio.wav|./data/caption.txt
38 | ./data/audio.wav|./data/caption.txt
39 | ./data/audio.wav|./data/caption.txt
40 | ./data/audio.wav|./data/caption.txt
41 | ./data/audio.wav|./data/caption.txt
42 | ./data/audio.wav|./data/caption.txt
43 | ./data/audio.wav|./data/caption.txt
44 | ./data/audio.wav|./data/caption.txt
45 | ./data/audio.wav|./data/caption.txt
46 | ./data/audio.wav|./data/caption.txt
47 | ./data/audio.wav|./data/caption.txt
48 | ./data/audio.wav|./data/caption.txt
49 | ./data/audio.wav|./data/caption.txt
50 | ./data/audio.wav|./data/caption.txt
51 | ./data/audio.wav|./data/caption.txt
52 | ./data/audio.wav|./data/caption.txt
53 | ./data/audio.wav|./data/caption.txt
54 | ./data/audio.wav|./data/caption.txt
55 | ./data/audio.wav|./data/caption.txt
56 | ./data/audio.wav|./data/caption.txt
57 | ./data/audio.wav|./data/caption.txt
58 | ./data/audio.wav|./data/caption.txt
59 | ./data/audio.wav|./data/caption.txt
60 | ./data/audio.wav|./data/caption.txt
61 | ./data/audio.wav|./data/caption.txt
62 | ./data/audio.wav|./data/caption.txt
63 | ./data/audio.wav|./data/caption.txt
64 | ./data/audio.wav|./data/caption.txt
65 | ./data/audio.wav|./data/caption.txt
66 | ./data/audio.wav|./data/caption.txt
67 | ./data/audio.wav|./data/caption.txt
68 | ./data/audio.wav|./data/caption.txt
69 | ./data/audio.wav|./data/caption.txt
70 | ./data/audio.wav|./data/caption.txt
71 | ./data/audio.wav|./data/caption.txt
72 | ./data/audio.wav|./data/caption.txt
73 | ./data/audio.wav|./data/caption.txt
74 | ./data/audio.wav|./data/caption.txt
75 | ./data/audio.wav|./data/caption.txt
76 | ./data/audio.wav|./data/caption.txt
77 | ./data/audio.wav|./data/caption.txt
78 | ./data/audio.wav|./data/caption.txt
79 | ./data/audio.wav|./data/caption.txt
80 | ./data/audio.wav|./data/caption.txt
81 | ./data/audio.wav|./data/caption.txt
82 |
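Both sample CSVs share the same pipe-delimited layout: an audio_files|captions header followed by one "<path to audio>|<path to caption>" row per utterance (the sample rows simply repeat the bundled ./data pair). A minimal sketch for writing a dataset in this format from your own pairs; the output path below is illustrative:

    import csv

    # (audio, caption) path pairs; swap in your own data here
    rows = [("./data/audio.wav", "./data/caption.txt")] * 4

    with open("datasets/my_dataset.csv", "w", newline="") as f:
        writer = csv.writer(f, delimiter="|")
        writer.writerow(["audio_files", "captions"])  # header used by the sample datasets
        writer.writerows(rows)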
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.5"
2 |
3 | networks:
4 | metavoice-net:
5 | driver: bridge
6 |
7 | volumes:
8 | hf-cache:
9 | driver: local
10 |
11 | x-common-settings: &common-settings
12 | volumes:
13 | - hf-cache:/.hf-cache
14 | - ./assets:/app/assets
15 | deploy:
16 | replicas: 1
17 | resources:
18 | reservations:
19 | devices:
20 | - driver: nvidia
21 | count: 1
22 | capabilities: [ gpu ]
23 | runtime: nvidia
24 | ipc: host
25 | tty: true # enable colorized logs
26 | build:
27 | context: .
28 | image: metavoice-server:latest
29 | networks:
30 | - metavoice-net
31 | environment:
32 | - NVIDIA_VISIBLE_DEVICES=all
33 | - HF_HOME=/.hf-cache
34 | logging:
35 | options:
36 | max-size: "100m"
37 | max-file: "10"
38 |
39 | services:
40 | server:
41 | <<: *common-settings
42 | container_name: metavoice-server
43 | command: [ "--port=58004" ]
44 | ports:
45 | - 58004:58004
46 | healthcheck:
47 | test: [ "CMD", "curl", "http://metavoice-server:58004/health" ]
48 | interval: 1m
49 | timeout: 10s
50 | retries: 20
51 | ui:
52 | <<: *common-settings
53 | container_name: metavoice-ui
54 | entrypoint: [ "poetry", "run", "python", "app.py" ]
55 | ports:
56 | - 7861:7861
57 | healthcheck:
58 | test: [ "CMD", "curl", "http://localhost:7861" ]
59 | interval: 1m
60 | timeout: 10s
61 | retries: 1
62 |
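Once the stack is up (for example via "docker compose up --build"), both services can be probed from the host. A minimal liveness check, assuming the ports above are published on localhost:

    import urllib.request

    # Same endpoints the compose healthchecks curl, reached from the host side.
    checks = {
        "server": "http://localhost:58004/health",
        "ui": "http://localhost:7861/",
    }

    for name, url in checks.items():
        try:
            with urllib.request.urlopen(url, timeout=10) as resp:
                print(f"{name}: HTTP {resp.status}")
        except Exception as exc:  # connection refused, timeout, HTTP error, ...
            print(f"{name}: unreachable ({exc})")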
--------------------------------------------------------------------------------
/fam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/__init__.py
--------------------------------------------------------------------------------
/fam/llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/llm/__init__.py
--------------------------------------------------------------------------------
/fam/llm/adapters/__init__.py:
--------------------------------------------------------------------------------
1 | from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
2 | from fam.llm.adapters.tilted_encodec import TiltedEncodec
3 |
--------------------------------------------------------------------------------
/fam/llm/adapters/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 |
4 | class BaseDataAdapter(ABC):
5 | pass
6 |
--------------------------------------------------------------------------------
/fam/llm/adapters/flattened_encodec.py:
--------------------------------------------------------------------------------
1 | from fam.llm.adapters.base import BaseDataAdapter
2 |
3 |
4 | class FlattenedInterleavedEncodec2Codebook(BaseDataAdapter):
5 | def __init__(self, end_of_audio_token):
6 | self._end_of_audio_token = end_of_audio_token
7 |
8 | def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
9 | assert len(tokens) == 1
10 | tokens = tokens[0]
11 |
12 | text_ids = []
13 | extracted_audio_ids = [[], []]
14 |
15 | for t in tokens:
16 | if t < self._end_of_audio_token:
17 | extracted_audio_ids[0].append(t)
18 | elif t >= self._end_of_audio_token and t < 2 * self._end_of_audio_token:
19 | extracted_audio_ids[1].append(t - self._end_of_audio_token)
20 | # We ignore t = 2 * self._end_of_audio_token, as it is the end of audio token
21 | elif t > 2 * self._end_of_audio_token:
22 | text_ids.append(t)
23 |
24 | if len(set([len(x) for x in extracted_audio_ids])) != 1:
25 | min_len = min([len(x) for x in extracted_audio_ids])
26 | max_len = max([len(x) for x in extracted_audio_ids])
27 |             print("WARNING: All codebook hierarchies must contain the same number of tokens!")
28 |             print(f"Truncating to the min length of {min_len} tokens (longest hierarchy had {max_len}).")
29 | print([len(x) for x in extracted_audio_ids])
30 | extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
31 |
32 | return text_ids[:-1], extracted_audio_ids
33 |
34 | def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
35 | """
36 | Performs the required combination and padding as needed.
37 | """
38 | raise NotImplementedError
39 |
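A toy trace of the demultiplexing above, using a deliberately tiny, hypothetical end_of_audio_token of 4 (the shipped models use 1024): values below 4 go to codebook 0, values in [4, 8) go to codebook 1 with the offset removed, 8 is the end-of-audio marker and is dropped, and anything above 8 is a text id.

    from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook

    adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=4)
    text_ids, audio_ids = adapter.decode([[9, 10, 1, 5, 2, 6, 3, 7, 8]])

    print(text_ids)   # [9]  (the final text id is dropped by text_ids[:-1])
    print(audio_ids)  # [[1, 2, 3], [1, 2, 3]]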
--------------------------------------------------------------------------------
/fam/llm/adapters/tilted_encodec.py:
--------------------------------------------------------------------------------
1 | from fam.llm.adapters.base import BaseDataAdapter
2 |
3 |
4 | class TiltedEncodec(BaseDataAdapter):
5 | def __init__(self, end_of_audio_token):
6 | self._end_of_audio_token = end_of_audio_token
7 |
8 | def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
9 | assert len(tokens) > 1
10 |
11 | text_ids = []
12 | extracted_audio_ids = []
13 |
14 | extracted_audio_ids.append([])
15 | # Handle first hierarchy as special case as it contains text tokens as well
16 |         # TODO: maybe it doesn't need a special case and can be handled on its own :)
17 | for t in tokens[0]:
18 | if t > self._end_of_audio_token:
19 | text_ids.append(t)
20 | elif t < self._end_of_audio_token:
21 | extracted_audio_ids[0].append(t)
22 |
23 | # Handle the rest of the hierarchies
24 | for i in range(1, len(tokens)):
25 | token_hierarchy_ids = tokens[i]
26 | extracted_audio_ids.append([])
27 | for t in token_hierarchy_ids:
28 | if t < self._end_of_audio_token:
29 | extracted_audio_ids[i].append(t)
30 |
31 | if len(set([len(x) for x in extracted_audio_ids])) != 1:
32 | min_len = min([len(x) for x in extracted_audio_ids])
33 | max_len = max([len(x) for x in extracted_audio_ids])
34 |             print("WARNING: All codebook hierarchies must contain the same number of tokens!")
35 |             print(f"Truncating to the min length of {min_len} tokens (longest hierarchy had {max_len}).")
36 | print([len(x) for x in extracted_audio_ids])
37 | extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
38 |
39 | return text_ids[:-1], extracted_audio_ids
40 |
41 | def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
42 | """
43 | Performs the required combination and padding as needed.
44 | """
45 | raise NotImplementedError
46 |
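The same toy setup for the multi-hierarchy case (again with a hypothetical end_of_audio_token of 4): text ids only ever appear in the first hierarchy, and each hierarchy keeps its own audio stream.

    from fam.llm.adapters import TiltedEncodec

    adapter = TiltedEncodec(end_of_audio_token=4)
    text_ids, audio_ids = adapter.decode([
        [9, 10, 1, 2, 3],  # hierarchy 0: text ids (> 4) followed by codebook-0 ids (< 4)
        [1, 2, 3],         # hierarchy 1: codebook-1 ids
    ])

    print(text_ids)   # [9]  (the final text id is dropped)
    print(audio_ids)  # [[1, 2, 3], [1, 2, 3]]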
--------------------------------------------------------------------------------
/fam/llm/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/llm/config/__init__.py
--------------------------------------------------------------------------------
/fam/llm/config/finetune_params.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | import os
3 | import uuid
4 | import pathlib
5 | from typing import Literal, Optional
6 | import torch
7 |
8 | batch_size = 2
9 | dataset_size: int = 400
10 | batched_ds_size = dataset_size // batch_size
11 | val_train_ratio = 0.2
12 |
13 | epochs: int = 2
14 | max_iters = batched_ds_size * epochs
15 | learning_rate = 3e-5
16 | last_n_blocks_to_finetune = 1
17 | decay_lr = False
18 | lr_decay_iters = 0 # decay learning rate after this many iterations
19 | min_lr = 3e-6
20 |
21 | eval_interval = batched_ds_size
22 | eval_iters = int(batched_ds_size*val_train_ratio)
23 | eval_only: bool = False # if True, script exits right after the first eval
24 | log_interval = batched_ds_size # don't print too often
25 | save_interval: int = batched_ds_size * (epochs//2) # save a checkpoint every this many iterations
26 | assert save_interval % eval_interval == 0, "save_interval must be divisible by eval_interval."
27 | seed = 1337
28 | grad_clip: float = 1.0 # clip gradients at this value, or disable if == 0.0
29 |
30 | wandb_log = False
31 | wandb_project = "project-name"
32 | wandb_run_name = "run-name"
33 | wandb_tags = ["tag1", "tag2"]
34 |
35 | gradient_accumulation_steps = 1
36 | block_size = 2_048
37 | audio_token_mode = "flattened_interleaved"
38 | num_max_audio_tokens_timesteps = 1_024
39 |
40 | n_layer = 24
41 | n_head = 16
42 | n_embd = 2048
43 | dropout = 0.1
44 |
45 | weight_decay = 1e-1
46 | beta1 = 0.9
47 | beta2 = 0.95
48 |
49 | warmup_iters: int = 0 # how many steps to warm up for
50 | out_dir = f"finetune-{epochs=}-{learning_rate=}-{batch_size=}-{last_n_blocks_to_finetune=}-{dropout=}-{uuid.uuid4()}"
51 |
52 | compile = True
53 | num_codebooks = None
54 | norm_type = "rmsnorm"
55 | rmsnorm_eps = 1e-5
56 | nonlinearity_type = "swiglu"
57 | swiglu_multiple_of = 256
58 | attn_kernel_type = "torch_attn"
59 | meta_target_vocab_sizes: Optional[list[int]] = None
60 | speaker_emb_size: int = 256
61 | speaker_cond = True
62 |
63 | # always running finetuning on a single GPU
64 | master_process = True
65 | device: str = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
66 | ddp = False
67 | ddp_world_size = 1
68 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
69 |
70 | causal = True
71 | bias: bool = False # do we use bias inside LayerNorm and Linear layers?
72 | spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not
73 |
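The schedule values above are all derived from a handful of base settings; a worked example with the defaults (batch_size=2, dataset_size=400, epochs=2, val_train_ratio=0.2, block_size=2048):

    batch_size, dataset_size, epochs, val_train_ratio, block_size = 2, 400, 2, 0.2, 2048

    batched_ds_size = dataset_size // batch_size         # 200 optimiser steps per epoch
    max_iters = batched_ds_size * epochs                 # 400 steps in total
    eval_interval = batched_ds_size                      # evaluate once per epoch
    eval_iters = int(batched_ds_size * val_train_ratio)  # 40 validation batches per eval
    save_interval = batched_ds_size * (epochs // 2)      # 200 -> checkpoint every 200 steps
    tokens_per_iter = 1 * 1 * batch_size * block_size    # grad_accum * world_size * batch * block = 4096

    assert save_interval % eval_interval == 0
    print(max_iters, eval_iters, save_interval, tokens_per_iter)  # 400 40 200 4096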
--------------------------------------------------------------------------------
/fam/llm/decoders.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os
3 | import pathlib
4 | import uuid
5 | from abc import ABC, abstractmethod
6 | from typing import Callable, Optional, Union
7 |
8 | import julius
9 | import torch
10 | from audiocraft.data.audio import audio_read, audio_write
11 | from audiocraft.models import MultiBandDiffusion # type: ignore
12 |
13 | mbd = MultiBandDiffusion.get_mbd_24khz(bw=6) # 1.5
14 |
15 |
16 | class Decoder(ABC):
17 | @abstractmethod
18 | def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None):
19 | raise NotImplementedError
20 |
21 |
22 | class EncodecDecoder(Decoder):
23 | def __init__(
24 | self,
25 | tokeniser_decode_fn: Callable[[list[int]], str],
26 | data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]],
27 | output_dir: str,
28 | ):
29 | self._mbd_sample_rate = 24_000
30 | self._end_of_audio_token = 1024
31 | self._num_codebooks = 8
32 | self.mbd = mbd
33 |
34 | self.tokeniser_decode_fn = tokeniser_decode_fn
35 | self._data_adapter_fn = data_adapter_fn
36 |
37 | self.output_dir = pathlib.Path(output_dir).resolve()
38 | os.makedirs(self.output_dir, exist_ok=True)
39 |
40 | def _save_audio(self, name: str, wav: torch.Tensor):
41 | audio_write(
42 | name,
43 | wav.squeeze(0).cpu(),
44 | self._mbd_sample_rate,
45 | strategy="loudness",
46 | loudness_compressor=True,
47 | )
48 |
49 | def get_tokens(self, audio_path: str) -> list[list[int]]:
50 | """
51 | Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g.
52 | limited codebook reconstruction or sampling from second stage model only).
53 | """
54 |
55 | wav, sr = audio_read(audio_path)
56 | if sr != self._mbd_sample_rate:
57 | wav = julius.resample_frac(wav, sr, self._mbd_sample_rate)
58 | if wav.ndim == 2:
59 | wav = wav.unsqueeze(1)
60 | wav = wav.to("cuda")
61 | tokens = self.mbd.codec_model.encode(wav)
62 | tokens = tokens[0][0]
63 |
64 | return tokens.tolist()
65 |
66 | def decode(
67 | self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None
68 | ) -> Union[str, torch.Tensor]:
69 | # TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file.
70 | text_ids, extracted_audio_ids = self._data_adapter_fn(tokens)
71 | text = self.tokeniser_decode_fn(text_ids)
72 | # print(f"Text: {text}")
73 |
74 | tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0)
75 |
76 | if tokens.shape[1] < self._num_codebooks:
77 | tokens = torch.cat(
78 | [tokens, *[torch.ones_like(tokens[0:1, 0:1]) * 0] * (self._num_codebooks - tokens.shape[1])], dim=1
79 | )
80 |
81 | if causal:
82 | return tokens
83 | else:
84 | with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
85 | wav = self.mbd.tokens_to_wav(tokens)
86 | # NOTE: we couldn't just return wav here as it goes through loudness compression etc :)
87 |
88 | if wav.shape[-1] < 9600:
89 | # this causes problem for the code below, and is also odd :)
90 | # first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!)
91 | raise Exception("wav predicted is shorter than 400ms!")
92 |
93 | try:
94 | wav_file_name = self.output_dir / f"synth_{datetime.now().strftime('%y-%m-%d--%H-%M-%S')}_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}"
95 | self._save_audio(wav_file_name, wav)
96 | return wav_file_name
97 | except Exception as e:
98 | print(f"Failed to save audio! Reason: {e}")
99 |
100 | wav_file_name = self.output_dir / f"synth_{datetime.now().strftime('%y-%m-%d--%H-%M-%S')}_{uuid.uuid4()}"
101 | self._save_audio(wav_file_name, wav)
102 | return wav_file_name
103 |
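A sketch of how this decoder gets wired up elsewhere in the repo (note that merely importing the module instantiates MultiBandDiffusion, so it needs the audiocraft weights and a CUDA device; the trivial tokeniser decode function below is a placeholder):

    from fam.llm.adapters import TiltedEncodec
    from fam.llm.decoders import EncodecDecoder

    decoder = EncodecDecoder(
        tokeniser_decode_fn=lambda ids: "",  # placeholder; normally the BPE tokeniser's decode
        data_adapter_fn=TiltedEncodec(end_of_audio_token=1024).decode,
        output_dir="outputs",
    )
    # decoder.decode(tokens, causal=True)  returns a (1, num_codebooks, T) token tensor
    # decoder.decode(tokens, causal=False) runs multi-band diffusion and saves a .wav under outputs/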
--------------------------------------------------------------------------------
/fam/llm/enhancers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC
3 | from typing import Literal, Optional
4 |
5 | from df.enhance import enhance, init_df, load_audio, save_audio
6 | from pydub import AudioSegment
7 |
8 |
9 | def convert_to_wav(input_file: str, output_file: str):
10 | """Convert an audio file to WAV format
11 |
12 | Args:
13 | input_file (str): path to input audio file
14 | output_file (str): path to output WAV file
15 |
16 | """
17 | # Detect the format of the input file
18 | format = input_file.split(".")[-1].lower()
19 |
20 | # Read the audio file
21 | audio = AudioSegment.from_file(input_file, format=format)
22 |
23 | # Export as WAV
24 | audio.export(output_file, format="wav")
25 |
26 |
27 | def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str:
28 | """Generate the output file path
29 |
30 | Args:
31 | audio_file (str): path to input audio file
32 | tag (str): tag to append to the output file name
33 | ext (str, optional): extension of the output file. Defaults to None.
34 |
35 | Returns:
36 | str: path to output file
37 | """
38 |
39 | directory = "./enhanced"
40 | # Get the name of the input file
41 | filename = os.path.basename(audio_file)
42 |
43 | # Get the name of the input file without the extension
44 | filename_without_extension = os.path.splitext(filename)[0]
45 |
46 | # Get the extension of the input file
47 | extension = ext or os.path.splitext(filename)[1]
48 |
49 | # Generate the output file path
50 | output_file = os.path.join(directory, filename_without_extension + tag + extension)
51 |
52 | return output_file
53 |
54 |
55 | class BaseEnhancer(ABC):
56 | """Base class for audio enhancers"""
57 |
58 | def __init__(self, *args, **kwargs):
59 | raise NotImplementedError
60 |
61 | def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
62 | raise NotImplementedError
63 |
64 | def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
65 | output_file = make_output_file_path(audio_file, tag, ext=ext)
66 | os.makedirs(os.path.dirname(output_file), exist_ok=True)
67 | return output_file
68 |
69 |
70 | class DFEnhancer(BaseEnhancer):
71 | def __init__(self, *args, **kwargs):
72 | self.model, self.df_state, _ = init_df()
73 |
74 | def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
75 | output_file = output_file or self.get_output_file(audio_file, "_df")
76 |
77 | audio, _ = load_audio(audio_file, sr=self.df_state.sr())
78 |
79 | enhanced = enhance(self.model, self.df_state, audio)
80 |
81 | save_audio(output_file, enhanced, self.df_state.sr())
82 |
83 | return output_file
84 |
85 |
86 | def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
87 | """Get an audio enhancer
88 |
89 | Args:
90 | enhancer_name (Literal["df"]): name of the audio enhancer
91 |
92 | Raises:
93 | ValueError: if the enhancer name is not recognised
94 |
95 | Returns:
96 | BaseEnhancer: audio enhancer
97 | """
98 |
99 | if enhancer_name == "df":
100 | import warnings
101 |
102 | warnings.filterwarnings(
103 | "ignore",
104 | message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.',
105 | )
106 | return DFEnhancer()
107 | else:
108 | raise ValueError(f"Unknown enhancer name: {enhancer_name}")
109 |
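Typical usage, assuming the DeepFilterNet dependency (the df package imported above) is installed; the input path below is illustrative:

    from fam.llm.enhancers import get_enhancer

    enhancer = get_enhancer("df")          # DeepFilterNet-based denoiser
    clean = enhancer("noisy_speech.wav")   # writes ./enhanced/noisy_speech_df.wav by default
    # or choose the destination explicitly:
    clean = enhancer("noisy_speech.wav", output_file="clean.wav")
    print(clean)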
--------------------------------------------------------------------------------
/fam/llm/fast_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tempfile
4 | import time
5 | from pathlib import Path
6 | from typing import Literal, Optional
7 |
8 | import librosa
9 | import torch
10 | import tyro
11 | from huggingface_hub import snapshot_download
12 |
13 | from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
14 | from fam.llm.decoders import EncodecDecoder
15 | from fam.llm.fast_inference_utils import build_model, main
16 | from fam.llm.inference import (
17 | EncodecDecoder,
18 | InferenceConfig,
19 | Model,
20 | TiltedEncodec,
21 | TrainedBPETokeniser,
22 | get_cached_embedding,
23 | get_cached_file,
24 | get_enhancer,
25 | )
26 | from fam.llm.utils import (
27 | check_audio_file,
28 | get_default_dtype,
29 | get_device,
30 | normalize_text,
31 | )
32 | from fam.telemetry import TelemetryEvent
33 | from fam.telemetry.posthog import PosthogClient
34 |
35 | posthog = PosthogClient() # see fam/telemetry/README.md for more information
36 |
37 |
38 | class TTS:
39 | END_OF_AUDIO_TOKEN = 1024
40 |
41 | def __init__(
42 | self,
43 | model_name: str = "metavoiceio/metavoice-1B-v0.1",
44 | *,
45 | seed: int = 1337,
46 | output_dir: str = "outputs",
47 | quantisation_mode: Optional[Literal["int4", "int8"]] = None,
48 | first_stage_path: Optional[str] = None,
49 | telemetry_origin: Optional[str] = None,
50 | ):
51 | """
52 | Initialise the TTS model.
53 |
54 | Args:
55 | model_name: refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
56 | seed: random seed for reproducibility
57 | output_dir: directory to save output files
58 | quantisation_mode: quantisation mode for first-stage LLM.
59 | Options:
60 | - None for no quantisation (bf16 or fp16 based on device),
61 | - int4 for int4 weight-only quantisation,
62 | - int8 for int8 weight-only quantisation.
63 | first_stage_path: path to first-stage LLM checkpoint. If provided, this will override the one grabbed from Hugging Face via `model_name`.
64 | telemetry_origin: A string identifier that specifies the origin of the telemetry data sent to PostHog.
65 | """
66 |
67 | # NOTE: this needs to come first so that we don't change global state when we want to use
68 | # the torch.compiled-model.
69 | self._dtype = get_default_dtype()
70 | self._device = get_device()
71 | self._model_dir = snapshot_download(repo_id=model_name)
72 | self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
73 | self.output_dir = output_dir
74 | os.makedirs(self.output_dir, exist_ok=True)
75 | if first_stage_path:
76 | print(f"Overriding first stage checkpoint via provided model: {first_stage_path}")
77 | self._first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt"
78 |
79 | second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
80 | config_second_stage = InferenceConfig(
81 | ckpt_path=second_stage_ckpt_path,
82 | num_samples=1,
83 | seed=seed,
84 | device=self._device,
85 | dtype=self._dtype,
86 | compile=False,
87 | init_from="resume",
88 | output_dir=self.output_dir,
89 | )
90 | data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
91 | self.llm_second_stage = Model(
92 | config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
93 | )
94 | self.enhancer = get_enhancer("df")
95 |
96 | self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
97 | self.model, self.tokenizer, self.smodel, self.model_size = build_model(
98 | precision=self.precision,
99 | checkpoint_path=Path(self._first_stage_ckpt),
100 | spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
101 | device=self._device,
102 | compile=True,
103 | compile_prefill=True,
104 | quantisation_mode=quantisation_mode,
105 | )
106 | self._seed = seed
107 | self._quantisation_mode = quantisation_mode
108 | self._model_name = model_name
109 | self._telemetry_origin = telemetry_origin
110 |
111 | def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
112 | """
113 | text: Text to speak
114 | spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
115 |         top_p: Top p for sampling applied to the first-stage model. Values in [0.9, 1.0] work well. This acts as a speech-stability control and improves text-following for challenging speakers.
116 |         guidance_scale: Guidance scale [1.0, 3.0] for sampling. This acts as a speaker-similarity control, i.e. how closely the output matches the reference speaker's identity and speech style.
117 | temperature: Temperature for sampling applied to both LLMs (first & second stage)
118 |
119 | returns: path to speech .wav file
120 | """
121 | text = normalize_text(text)
122 | spk_ref_path = get_cached_file(spk_ref_path)
123 | check_audio_file(spk_ref_path)
124 | spk_emb = get_cached_embedding(
125 | spk_ref_path,
126 | self.smodel,
127 | ).to(device=self._device, dtype=self.precision)
128 |
129 | start = time.time()
130 | # first stage LLM
131 | tokens = main(
132 | model=self.model,
133 | tokenizer=self.tokenizer,
134 | model_size=self.model_size,
135 | prompt=text,
136 | spk_emb=spk_emb,
137 | top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
138 | guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
139 | temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
140 | )
141 | _, extracted_audio_ids = self.first_stage_adapter.decode([tokens])
142 |
143 | b_speaker_embs = spk_emb.unsqueeze(0)
144 |
145 | # second stage LLM + multi-band diffusion model
146 | wav_files = self.llm_second_stage(
147 | texts=[text],
148 | encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
149 | speaker_embs=b_speaker_embs,
150 | batch_size=1,
151 | guidance_scale=None,
152 | top_p=None,
153 | top_k=200,
154 | temperature=1.0,
155 | max_new_tokens=None,
156 | )
157 |
158 | # enhance using deepfilternet
159 | wav_file = wav_files[0]
160 | with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
161 | self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
162 | shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
163 | print(f"\nSaved audio to {wav_file}.wav")
164 |
165 | # calculating real-time factor (RTF)
166 | time_to_synth_s = time.time() - start
167 | audio, sr = librosa.load(str(wav_file) + ".wav")
168 | duration_s = librosa.get_duration(y=audio, sr=sr)
169 | real_time_factor = time_to_synth_s / duration_s
170 | print(f"\nTotal time to synth (s): {time_to_synth_s}")
171 | print(f"Real-time factor: {real_time_factor:.2f}")
172 |
173 | posthog.capture(
174 | TelemetryEvent(
175 | name="user_ran_tts",
176 | properties={
177 | "model_name": self._model_name,
178 | "text": text,
179 | "temperature": temperature,
180 | "guidance_scale": guidance_scale,
181 | "top_p": top_p,
182 | "spk_ref_path": spk_ref_path,
183 | "speech_duration_s": duration_s,
184 | "time_to_synth_s": time_to_synth_s,
185 | "real_time_factor": round(real_time_factor, 2),
186 | "quantisation_mode": self._quantisation_mode,
187 | "seed": self._seed,
188 | "first_stage_ckpt": self._first_stage_ckpt,
189 | "gpu": torch.cuda.get_device_name(0),
190 | "telemetry_origin": self._telemetry_origin,
191 | },
192 | )
193 | )
194 |
195 | return str(wav_file) + ".wav"
196 |
197 |
198 | if __name__ == "__main__":
199 | tts = tyro.cli(TTS)
200 |
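Besides the tyro CLI entry point above, the class can be used directly from Python. A sketch (it needs a CUDA GPU, pulls the checkpoints from Hugging Face on first use, and the reference clip path below is illustrative; as noted in the docstring, at least 30 s of audio is expected):

    from fam.llm.fast_inference import TTS

    tts = TTS()  # downloads metavoiceio/metavoice-1B-v0.1 and compiles the first stage
    wav_path = tts.synthesise(
        text="This is a test of the MetaVoice text-to-speech pipeline.",
        spk_ref_path="path/to/reference_speaker.wav",  # local file or public URL, >= 30 s
    )
    print(wav_path)  # e.g. outputs/synth_<timestamp>_<text>_<uuid>.wav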
--------------------------------------------------------------------------------
/fam/llm/fast_inference_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without modification, are permitted
5 | # provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this list of
8 | # conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this
11 | # list of conditions and the following disclaimer in the documentation and/or other
12 | # materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its contributors
15 | # may be used to endorse or promote products derived from this software without
16 | # specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 | import itertools
27 | import time
28 | import warnings
29 | from pathlib import Path
30 | from typing import Literal, Optional, Tuple
31 |
32 | import torch
33 | import torch._dynamo.config
34 | import torch._inductor.config
35 | import tqdm
36 |
37 | from fam.llm.fast_quantize import WeightOnlyInt4QuantHandler, WeightOnlyInt8QuantHandler
38 |
39 |
40 | def device_sync(device):
41 | if "cuda" in device:
42 | torch.cuda.synchronize()
43 | elif "cpu" in device:
44 | pass
45 | else:
46 |         print(f"device={device} is not yet supported")
47 |
48 |
49 | torch._inductor.config.coordinate_descent_tuning = True
50 | torch._inductor.config.triton.unique_kernel_names = True
51 | torch._inductor.config.fx_graph_cache = (
52 | True # Experimental feature to reduce compilation times, will be on by default in future
53 | )
54 |
55 | # imports need to happen after setting above flags
56 | from fam.llm.fast_model import Transformer
57 | from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
58 | from fam.quantiser.text.tokenise import TrainedBPETokeniser
59 |
60 |
61 | def multinomial_sample_one_no_sync(
62 | probs_sort,
63 | ): # Does multinomial sampling without a cuda synchronization
64 | q = torch.empty_like(probs_sort).exponential_(1)
65 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
66 |
67 |
68 | def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor):
69 | # ref: huggingface/transformers
70 |
71 | sorted_logits, sorted_indices = torch.sort(logits, descending=False)
72 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
73 |
74 |     # Remove the low-probability tail: tokens whose cumulative probability mass is at most (1 - top_p)
75 | sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
76 |     # Always keep at least the highest-probability token
77 | sorted_indices_to_remove[-1:] = 0
78 |
79 | # scatter sorted tensors to original indexing
80 | indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
81 | scores = logits.masked_fill(indices_to_remove, -float("Inf"))
82 | return scores
83 |
84 |
85 | def logits_to_probs(
86 | logits,
87 | *,
88 | temperature: torch.Tensor,
89 | top_p: Optional[torch.Tensor] = None,
90 | top_k: Optional[torch.Tensor] = None,
91 | ):
92 | logits = logits / torch.max(temperature, 1e-5 * torch.ones_like(temperature))
93 |
94 | if top_k is not None:
95 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
96 | pivot = v.select(-1, -1).unsqueeze(-1)
97 | logits = torch.where(logits < pivot, -float("Inf"), logits)
98 |
99 | if top_p is not None:
100 | logits = top_p_sample(logits, top_p)
101 |
102 | probs = torch.nn.functional.softmax(logits, dim=-1)
103 |
104 | return probs
105 |
106 |
107 | def sample(
108 | logits,
109 | guidance_scale: torch.Tensor,
110 | temperature: torch.Tensor,
111 | top_p: Optional[torch.Tensor] = None,
112 | top_k: Optional[torch.Tensor] = None,
113 | ):
114 | # (b, t, vocab_size)
115 | logits = logits[:, -1]
116 | logits_cond, logits_uncond_spkemb = logits.split(logits.size(0) // 2, dim=0)
117 | logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond_spkemb
118 | probs = logits_to_probs(logits[0], temperature=temperature, top_p=top_p, top_k=top_k)
119 | idx_next = multinomial_sample_one_no_sync(probs)
120 | return idx_next, probs
121 |
122 |
123 | def prefill(
124 | model: Transformer,
125 | x: torch.Tensor,
126 | spk_emb: torch.Tensor,
127 | input_pos: torch.Tensor,
128 | **sampling_kwargs,
129 | ) -> torch.Tensor:
130 | # input_pos: [B, S]
131 | logits = model(x, spk_emb, input_pos)
132 | return sample(logits, **sampling_kwargs)[0]
133 |
134 |
135 | def decode_one_token(
136 | model: Transformer,
137 | x: torch.Tensor,
138 | spk_emb: torch.Tensor,
139 | input_pos: torch.Tensor,
140 | **sampling_kwargs,
141 | ) -> Tuple[torch.Tensor, torch.Tensor]:
142 | # input_pos: [B, 1]
143 | assert input_pos.shape[-1] == 1
144 | logits = model(x, spk_emb, input_pos)
145 | return sample(logits, **sampling_kwargs)
146 |
147 |
148 | def decode_n_tokens(
149 | model: Transformer,
150 | cur_token: torch.Tensor,
151 | spk_emb: torch.Tensor,
152 | input_pos: torch.Tensor,
153 | num_new_tokens: int,
154 | callback=lambda _: _,
155 | return_probs: bool = False,
156 | end_of_audio_token: int = 2048,
157 | **sampling_kwargs,
158 | ):
159 | new_tokens, new_probs = [], []
160 | for i in tqdm.tqdm(range(num_new_tokens)):
161 | if (cur_token == end_of_audio_token).any():
162 | break
163 | with torch.backends.cuda.sdp_kernel(
164 | enable_flash=False, enable_mem_efficient=False, enable_math=True
165 | ): # Actually better for Inductor to codegen attention here
166 | next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs)
167 | input_pos += 1
168 | new_tokens.append(next_token.clone())
169 | callback(new_tokens[-1])
170 | if return_probs:
171 | new_probs.append(next_prob.clone())
172 | cur_token = next_token.view(1, -1).repeat(2, 1)
173 |
174 | return new_tokens, new_probs
175 |
176 |
177 | def model_forward(model, x, spk_emb, input_pos):
178 | return model(x, spk_emb, input_pos)
179 |
180 |
181 | @torch.no_grad()
182 | def generate(
183 | model: Transformer,
184 | prompt: torch.Tensor,
185 | spk_emb: torch.Tensor,
186 | *,
187 | max_new_tokens: Optional[int] = None,
188 | callback=lambda x: x,
189 | end_of_audio_token: int = 2048,
190 | **sampling_kwargs,
191 | ) -> torch.Tensor:
192 | """
193 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
194 | """
195 | # create an empty tensor of the expected final shape and fill in the current tokens
196 | T = prompt.size(0)
197 | if max_new_tokens is None:
198 | max_seq_length = model.config.block_size
199 | else:
200 | max_seq_length = T + max_new_tokens
201 | max_seq_length = min(max_seq_length, model.config.block_size)
202 | max_new_tokens = max_seq_length - T
203 | if max_new_tokens <= 0:
204 | raise ValueError("Prompt is too long to generate more tokens")
205 |
206 | device, dtype = prompt.device, prompt.dtype
207 |
208 | seq = torch.clone(prompt)
209 | input_pos = torch.arange(0, T, device=device)
210 |
211 | next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
212 | seq = torch.cat([seq, next_token.view(1)])
213 |
214 | input_pos = torch.tensor([T], device=device, dtype=torch.int)
215 |
216 | generated_tokens, _ = decode_n_tokens(
217 | model,
218 | next_token.view(1, -1).repeat(2, 1),
219 | spk_emb,
220 | input_pos,
221 | max_new_tokens - 1,
222 | callback=callback,
223 | end_of_audio_token=end_of_audio_token,
224 | **sampling_kwargs,
225 | )
226 | seq = torch.cat([seq, torch.cat(generated_tokens)])
227 |
228 | return seq
229 |
230 |
231 | def encode_tokens(tokenizer: TrainedBPETokeniser, text: str, device="cuda") -> torch.Tensor:
232 | tokens = tokenizer.encode(text)
233 | return torch.tensor(tokens, dtype=torch.int, device=device)
234 |
235 |
236 | def _load_model(
237 | checkpoint_path, spk_emb_ckpt_path, device, precision, quantisation_mode: Optional[Literal["int4", "int8"]] = None
238 | ):
239 | ##### MODEL
240 | with torch.device("meta"):
241 | model = Transformer.from_name("metavoice-1B")
242 |
243 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
244 | state_dict = checkpoint["model"]
245 | # convert MetaVoice-1B model weights naming to gptfast naming
246 | unwanted_prefix = "_orig_mod."
247 | for k, v in list(state_dict.items()):
248 | if k.startswith(unwanted_prefix):
249 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
250 | state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight")
251 | state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight")
252 | state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight")
253 | state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight")
254 | for k, v in list(state_dict.items()):
255 | if k.startswith("transformer.h."):
256 | state_dict[k.replace("transformer.h.", "layers.")] = state_dict.pop(k)
257 | k = k.replace("transformer.h.", "layers.")
258 | if ".attn.c_attn." in k:
259 | state_dict[k.replace(".attn.c_attn.", ".attention.wqkv.")] = state_dict.pop(k)
260 | k = k.replace(".attn.c_attn.", ".attention.wqkv.")
261 | if ".attn.c_proj." in k:
262 | state_dict[k.replace(".attn.c_proj.", ".attention.wo.")] = state_dict.pop(k)
263 | k = k.replace(".attn.c_proj.", ".attention.wo.")
264 | if ".mlp.swiglu.w1." in k:
265 | state_dict[k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")] = state_dict.pop(k)
266 | k = k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")
267 | if ".mlp.swiglu.w3." in k:
268 | state_dict[k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")] = state_dict.pop(k)
269 | k = k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")
270 | if ".ln_1." in k:
271 | state_dict[k.replace(".ln_1.", ".attention_norm.")] = state_dict.pop(k)
272 | k = k.replace(".ln_1.", ".attention_norm.")
273 | if ".ln_2." in k:
274 | state_dict[k.replace(".ln_2.", ".ffn_norm.")] = state_dict.pop(k)
275 | k = k.replace(".ln_2.", ".ffn_norm.")
276 | if ".mlp.c_proj." in k:
277 | state_dict[k.replace(".mlp.c_proj.", ".feed_forward.w2.")] = state_dict.pop(k)
278 | k = k.replace(".mlp.c_proj.", ".feed_forward.w2.")
279 |
280 | model.load_state_dict(state_dict, assign=True)
281 | model = model.to(device=device, dtype=torch.bfloat16)
282 |
283 | if quantisation_mode == "int8":
284 | warnings.warn(
285 |             "int8 quantisation is slower than bf16/fp16 for undebugged reasons! Please set quantisation_mode to `None` or to `int4`."
286 | )
287 | warnings.warn(
288 |             "quantisation will degrade the quality of the audio! Please set quantisation_mode to `None` for best quality."
289 | )
290 | simple_quantizer = WeightOnlyInt8QuantHandler(model)
291 | quantized_state_dict = simple_quantizer.create_quantized_state_dict()
292 | model = simple_quantizer.convert_for_runtime()
293 | model.load_state_dict(quantized_state_dict, assign=True)
294 | model = model.to(device=device, dtype=torch.bfloat16)
295 | # TODO: int8/int4 doesn't decrease VRAM usage substantially... fix that (might be linked to kv-cache)
296 | torch.cuda.empty_cache()
297 | elif quantisation_mode == "int4":
298 | warnings.warn(
299 |             "quantisation will degrade the quality of the audio! Please set quantisation_mode to `None` for best quality."
300 | )
301 | simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize=128)
302 | quantized_state_dict = simple_quantizer.create_quantized_state_dict()
303 | model = simple_quantizer.convert_for_runtime(use_cuda=True)
304 | model.load_state_dict(quantized_state_dict, assign=True)
305 | model = model.to(device=device, dtype=torch.bfloat16)
306 | torch.cuda.empty_cache()
307 | elif quantisation_mode is not None:
308 | raise Exception(f"Invalid quantisation mode {quantisation_mode}! Must be either 'int4' or 'int8'!")
309 |
310 | ###### TOKENIZER
311 | tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
312 | tokenizer = TrainedBPETokeniser(**tokenizer_info)
313 |
314 | ###### SPEAKER EMBEDDER
315 | smodel = SpeakerEncoder(
316 | weights_fpath=spk_emb_ckpt_path,
317 | device=device,
318 | eval=True,
319 | verbose=False,
320 | )
321 | return model.eval(), tokenizer, smodel
322 |
323 |
324 | def build_model(
325 | *,
326 | precision: torch.dtype,
327 | checkpoint_path: Path = Path(""),
328 | spk_emb_ckpt_path: Path = Path(""),
329 | compile_prefill: bool = False,
330 | compile: bool = True,
331 | device: str = "cuda",
332 | quantisation_mode: Optional[Literal["int4", "int8"]] = None,
333 | ):
334 | assert checkpoint_path.is_file(), checkpoint_path
335 |
336 | print(f"Using device={device}")
337 |
338 | print("Loading model ...")
339 | t0 = time.time()
340 | model, tokenizer, smodel = _load_model(
341 | checkpoint_path, spk_emb_ckpt_path, device, precision, quantisation_mode=quantisation_mode
342 | )
343 |
344 | device_sync(device=device) # MKG
345 | print(f"Time to load model: {time.time() - t0:.02f} seconds")
346 |
347 | torch.manual_seed(1234)
348 | model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
349 |
350 | with torch.device(device):
351 | model.setup_spk_cond_mask()
352 | model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size)
353 |
354 | if compile:
355 |         print("Compiling... this can take up to 2 minutes.")
356 | global decode_one_token, prefill
357 | decode_one_token = torch.compile(
358 | decode_one_token,
359 | mode="max-autotune",
360 | fullgraph=True,
361 | )
362 |
363 | if compile_prefill:
364 | prefill = torch.compile(
365 | prefill,
366 | fullgraph=True,
367 | dynamic=True,
368 | )
369 |
370 | encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
371 | spk_emb = torch.randn((1, 256), device=device, dtype=precision)
372 |
373 | device_sync(device=device) # MKG
374 | t0 = time.perf_counter()
375 | y = generate(
376 | model,
377 | encoded,
378 | spk_emb,
379 | max_new_tokens=200,
380 | callback=lambda x: x,
381 | temperature=torch.tensor(1.0, device=device, dtype=precision),
382 | top_k=None,
383 | top_p=torch.tensor(0.95, device=device, dtype=precision),
384 | guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
385 | end_of_audio_token=9999, # don't end early for compilation stage.
386 | )
387 |
388 | device_sync(device=device) # MKG
389 |
390 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
391 |
392 | return model, tokenizer, smodel, model_size
393 |
394 |
395 | def main(
396 | *,
397 | model,
398 | tokenizer,
399 | model_size,
400 | prompt: str,
401 | guidance_scale: torch.Tensor,
402 | temperature: torch.Tensor,
403 | spk_emb: torch.Tensor,
404 | top_k: Optional[torch.Tensor] = None,
405 | top_p: Optional[torch.Tensor] = None,
406 | device: str = "cuda",
407 | ) -> list:
408 | """Generates text samples based on a pre-trained Transformer model and tokenizer."""
409 |
410 | encoded = encode_tokens(tokenizer, prompt, device=device)
411 | prompt_length = encoded.size(0)
412 |
413 | aggregate_metrics: dict = {
414 | "tokens_per_sec": [],
415 | }
416 |
417 | device_sync(device=device) # MKG
418 |
419 | if True:
420 | callback = lambda x: x
421 | t0 = time.perf_counter()
422 |
423 | y = generate(
424 | model,
425 | encoded,
426 | spk_emb,
427 | callback=callback,
428 | temperature=temperature,
429 | top_k=top_k,
430 | top_p=top_p,
431 | guidance_scale=guidance_scale,
432 | )
433 |
434 | device_sync(device=device) # MKG
435 | t = time.perf_counter() - t0
436 |
437 | tokens_generated = y.size(0) - prompt_length
438 | tokens_sec = tokens_generated / t
439 | aggregate_metrics["tokens_per_sec"].append(tokens_sec)
440 | print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
441 | print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
442 | # print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
443 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n")
444 |
445 | return y.tolist()
446 |
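A tiny CPU-only illustration of the nucleus-sampling helpers above (a sketch; it assumes the module imports cleanly in your environment, since the imports at the top pull in the rest of the package):

    import torch
    from fam.llm.fast_inference_utils import logits_to_probs

    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    probs = logits_to_probs(
        logits,
        temperature=torch.tensor(1.0),
        top_p=torch.tensor(0.9),
    )
    # the lowest-scoring token falls outside the 0.9 nucleus, so its probability is zeroed
    # and the remaining mass is renormalised over the other three tokens
    print(probs)  # roughly tensor([0.63, 0.23, 0.14, 0.00])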
--------------------------------------------------------------------------------
/fam/llm/fast_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without modification, are permitted
5 | # provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this list of
8 | # conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this
11 | # list of conditions and the following disclaimer in the documentation and/or other
12 | # materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its contributors
15 | # may be used to endorse or promote products derived from this software without
16 | # specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 | from dataclasses import dataclass
27 | from functools import reduce
28 | from math import gcd
29 | from typing import Optional, Tuple
30 |
31 | import torch
32 | import torch.nn as nn
33 | from torch import Tensor
34 | from torch.nn import functional as F
35 |
36 | from fam.llm.utils import get_default_dtype
37 |
38 | import logging
39 |
40 | # Adjust the logging level
41 | logger = logging.getLogger("torch")
42 | logger.setLevel(logging.ERROR)
43 |
44 |
45 | def find_multiple(n: int, *args: Tuple[int]) -> int:
46 | k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
47 | if n % k == 0:
48 | return n
49 | return n + k - (n % k)
50 |
51 |
52 | @dataclass
53 | class ModelArgs:
54 | block_size: int = 2048
55 | vocab_size: int = 32000
56 | n_layer: int = 32
57 | n_head: int = 32
58 | dim: int = 4096
59 | speaker_emb_dim: int = 256
60 |     intermediate_size: Optional[int] = None
61 | n_local_heads: int = -1
62 | head_dim: int = 64
63 | norm_eps: float = 1e-5
64 | dtype: torch.dtype = torch.bfloat16
65 |
66 | def __post_init__(self):
67 | if self.n_local_heads == -1:
68 | self.n_local_heads = self.n_head
69 | if self.intermediate_size is None:
70 | hidden_dim = 4 * self.dim
71 | n_hidden = int(2 * hidden_dim / 3)
72 | self.intermediate_size = find_multiple(n_hidden, 256)
73 | self.head_dim = self.dim // self.n_head
74 |
75 | self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()]
76 |
77 | @classmethod
78 | def from_name(cls, name: str):
79 | if name in transformer_configs:
80 | return cls(**transformer_configs[name])
81 | # fuzzy search
82 | config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
83 | assert len(config) == 1, name
84 | return cls(**transformer_configs[config[0]])
85 |
86 |
87 | transformer_configs = {
88 | "metavoice-1B": dict(
89 | n_layer=24,
90 | n_head=16,
91 | dim=2048,
92 | vocab_size=2562,
93 | ),
94 | }
95 |
96 |
97 | class KVCache(nn.Module):
98 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
99 | super().__init__()
100 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
101 | self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
102 | self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
103 |
104 | def update(self, input_pos, k_val, v_val):
105 | # input_pos: [S], k_val: [B, H, S, D]
106 | assert input_pos.shape[0] == k_val.shape[2]
107 |
108 | k_out = self.k_cache
109 | v_out = self.v_cache
110 | k_out[:, :, input_pos] = k_val
111 | v_out[:, :, input_pos] = v_val
112 |
113 | return k_out, v_out
114 |
115 |
116 | class Transformer(nn.Module):
117 | def __init__(self, config: ModelArgs) -> None:
118 | super().__init__()
119 | self.config = config
120 |
121 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
122 | self.pos_embeddings = nn.Embedding(config.block_size, config.dim)
123 | self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False)
124 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
125 | self.norm = RMSNorm(config.dim, eps=config.norm_eps)
126 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
127 |
128 | self.mask_cache: Optional[Tensor] = None
129 | self.max_batch_size = -1
130 | self.max_seq_length = -1
131 |
132 | def setup_spk_cond_mask(self):
133 | self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool)
134 | self.spk_cond_mask[0] = 1
135 |
136 | def setup_caches(self, max_batch_size, max_seq_length):
137 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
138 | return
139 | head_dim = self.config.dim // self.config.n_head
140 | max_seq_length = find_multiple(max_seq_length, 8)
141 | self.max_seq_length = max_seq_length
142 | self.max_batch_size = max_batch_size
143 | for b in self.layers:
144 | b.attention.kv_cache = KVCache(
145 | max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype
146 | )
147 |
148 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
149 |
150 | def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor:
151 | mask = self.causal_mask[None, None, input_pos]
152 | x = (
153 | self.tok_embeddings(idx)
154 | + self.pos_embeddings(input_pos)
155 | # masking for speaker condition free guidance
156 | + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask
157 | )
158 |
159 | for i, layer in enumerate(self.layers):
160 | x = layer(x, input_pos, mask)
161 | x = self.norm(x)
162 | logits = self.output(x)
163 | return logits
164 |
165 | @classmethod
166 | def from_name(cls, name: str):
167 | return cls(ModelArgs.from_name(name))
168 |
169 |
170 | class TransformerBlock(nn.Module):
171 | def __init__(self, config: ModelArgs) -> None:
172 | super().__init__()
173 | self.attention = Attention(config)
174 | self.feed_forward = FeedForward(config)
175 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
176 | self.attention_norm = RMSNorm(config.dim, config.norm_eps)
177 |
178 | def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
179 | h = x + self.attention(self.attention_norm(x), mask, input_pos)
180 | out = h + self.feed_forward(self.ffn_norm(h))
181 | return out
182 |
183 |
184 | class Attention(nn.Module):
185 | def __init__(self, config: ModelArgs):
186 | super().__init__()
187 | assert config.dim % config.n_head == 0
188 |
189 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
190 | # key, query, value projections for all heads, but in a batch
191 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
192 | self.wo = nn.Linear(config.dim, config.dim, bias=False)
193 | self.kv_cache = None
194 |
195 | self.n_head = config.n_head
196 | self.head_dim = config.head_dim
197 | self.n_local_heads = config.n_local_heads
198 | self.dim = config.dim
199 |
200 | def forward(
201 | self,
202 | x: Tensor,
203 | mask: Tensor,
204 | input_pos: Optional[Tensor] = None,
205 | ) -> Tensor:
206 | bsz, seqlen, _ = x.shape
207 |
208 | kv_size = self.n_local_heads * self.head_dim
209 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
210 |
211 | q = q.view(bsz, seqlen, self.n_head, self.head_dim)
212 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
213 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
214 |
215 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
216 |
217 | if self.kv_cache is not None:
218 | k, v = self.kv_cache.update(input_pos, k, v)
219 |
220 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
221 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
222 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
223 |
224 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
225 |
226 | y = self.wo(y)
227 | return y
228 |
229 |
230 | class SwiGLU(nn.Module):
231 | def __init__(self, config: ModelArgs) -> None:
232 | super().__init__()
233 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
234 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
235 |
236 | def forward(self, x: Tensor) -> Tensor:
237 | return F.silu(self.w1(x)) * self.w3(x)
238 |
239 |
240 | class FeedForward(nn.Module):
241 | def __init__(self, config: ModelArgs) -> None:
242 | super().__init__()
243 | self.swiglu = SwiGLU(config)
244 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
245 |
246 | def forward(self, x: Tensor) -> Tensor:
247 | return self.w2(self.swiglu(x))
248 |
249 |
250 | class RMSNorm(nn.Module):
251 | def __init__(self, dim: int, eps: float = 1e-5):
252 | super().__init__()
253 | self.eps = eps
254 | self.weight = nn.Parameter(torch.ones(dim))
255 |
256 | def _norm(self, x):
257 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
258 |
259 | def forward(self, x: Tensor) -> Tensor:
260 | output = self._norm(x.float()).type_as(x)
261 | return output * self.weight
262 |
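To make the static-cache mechanics concrete, a small CPU-only sketch of KVCache with toy (hypothetical) sizes:

    import torch
    from fam.llm.fast_model import KVCache

    # one sequence, 2 heads, head_dim 4, room for 8 positions
    cache = KVCache(max_batch_size=1, max_seq_length=8, n_heads=2, head_dim=4, dtype=torch.float32)

    # write keys/values for positions 0..2 (shapes are [B, H, S, D])
    k = torch.randn(1, 2, 3, 4)
    v = torch.randn(1, 2, 3, 4)
    k_out, v_out = cache.update(input_pos=torch.tensor([0, 1, 2]), k_val=k, v_val=v)

    print(k_out.shape)                      # torch.Size([1, 2, 8, 4]): the full pre-allocated buffer
    print(torch.equal(k_out[:, :, :3], k))  # True: only the requested positions were written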
--------------------------------------------------------------------------------
/fam/llm/fast_quantize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without modification, are permitted
5 | # provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this list of
8 | # conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this
11 | # list of conditions and the following disclaimer in the documentation and/or other
12 | # materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its contributors
15 | # may be used to endorse or promote products derived from this software without
16 | # specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 | import time
27 | from pathlib import Path
28 |
29 | import torch
30 | import torch.nn as nn
31 | import torch.nn.functional as F
32 |
33 | default_device = "cuda" if torch.cuda.is_available() else "cpu"
34 |
35 | ##### Quantization Primitives ######
36 |
37 |
38 | def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
39 | # assumes symmetric quantization
40 | # assumes axis == 0
41 | # assumes dense memory format
42 | # TODO(future): relax ^ as needed
43 |
44 | # default setup for affine quantization of activations
45 | eps = torch.finfo(torch.float32).eps
46 |
47 | # get min and max
48 | min_val, max_val = torch.aminmax(x, dim=1)
49 |
50 | # calculate scales and zero_points based on min and max
51 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
52 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
53 | device = min_val_neg.device
54 |
55 | max_val_pos = torch.max(-min_val_neg, max_val_pos)
56 | scales = max_val_pos / (float(quant_max - quant_min) / 2)
57 | # ensure scales is the same dtype as the original tensor
58 | scales = torch.clamp(scales, min=eps).to(x.dtype)
59 | zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
60 |
61 | # quantize based on qmin/qmax/scales/zp
62 | x_div = x / scales.unsqueeze(-1)
63 | x_round = torch.round(x_div)
64 | x_zp = x_round + zero_points.unsqueeze(-1)
65 | quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
66 |
67 | return quant, scales, zero_points
68 |
69 |
70 | def get_group_qparams(w, n_bit=4, groupsize=128):
71 | # needed for GPTQ with padding
72 | if groupsize > w.shape[-1]:
73 | groupsize = w.shape[-1]
74 | assert groupsize > 1
75 | assert w.shape[-1] % groupsize == 0
76 | assert w.dim() == 2
77 |
78 | to_quant = w.reshape(-1, groupsize)
79 | assert torch.isnan(to_quant).sum() == 0
80 |
81 | max_val = to_quant.amax(dim=1, keepdim=True)
82 | min_val = to_quant.amin(dim=1, keepdim=True)
83 | max_int = 2**n_bit - 1
84 | scales = (max_val - min_val).clamp(min=1e-6) / max_int
85 | zeros = min_val + scales * (2 ** (n_bit - 1))
86 | return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(torch.bfloat16).reshape(w.shape[0], -1)
87 |
88 |
89 | def pack_scales_and_zeros(scales, zeros):
90 | assert scales.shape == zeros.shape
91 | assert scales.dtype == torch.bfloat16
92 | assert zeros.dtype == torch.bfloat16
93 | return (
94 | torch.cat(
95 | [
96 | scales.reshape(scales.size(0), scales.size(1), 1),
97 | zeros.reshape(zeros.size(0), zeros.size(1), 1),
98 | ],
99 | 2,
100 | )
101 | .transpose(0, 1)
102 | .contiguous()
103 | )
104 |
105 |
106 | def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
107 | assert groupsize > 1
108 | # needed for GPTQ single column quantize
109 | if groupsize > w.shape[-1] and scales.shape[-1] == 1:
110 | groupsize = w.shape[-1]
111 |
112 | assert w.shape[-1] % groupsize == 0
113 | assert w.dim() == 2
114 |
115 | to_quant = w.reshape(-1, groupsize)
116 | assert torch.isnan(to_quant).sum() == 0
117 |
118 | scales = scales.reshape(-1, 1)
119 | zeros = zeros.reshape(-1, 1)
120 | min_val = zeros - scales * (2 ** (n_bit - 1))
121 | max_int = 2**n_bit - 1
122 | min_int = 0
123 | w_int32 = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int).to(torch.int32).reshape_as(w)
124 |
125 | return w_int32
126 |
127 |
128 | def group_quantize_tensor(w, n_bit=4, groupsize=128):
129 | scales, zeros = get_group_qparams(w, n_bit, groupsize)
130 | w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
131 | scales_and_zeros = pack_scales_and_zeros(scales, zeros)
132 | return w_int32, scales_and_zeros
133 |
134 |
135 | def group_dequantize_tensor_from_qparams(w_int32, scales, zeros, n_bit=4, groupsize=128):
136 | assert groupsize > 1
137 | # needed for GPTQ single column dequantize
138 | if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
139 | groupsize = w_int32.shape[-1]
140 | assert w_int32.shape[-1] % groupsize == 0
141 | assert w_int32.dim() == 2
142 |
143 | w_int32_grouped = w_int32.reshape(-1, groupsize)
144 | scales = scales.reshape(-1, 1)
145 | zeros = zeros.reshape(-1, 1)
146 |
147 | w_dq = w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
148 | return w_dq
149 |
150 |
151 | ##### Weight-only int8 per-channel quantized code ######
152 |
153 |
154 | def replace_linear_weight_only_int8_per_channel(module):
155 | for name, child in module.named_children():
156 | if isinstance(child, nn.Linear):
157 | setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features))
158 | else:
159 | replace_linear_weight_only_int8_per_channel(child)
160 |
161 |
162 | class WeightOnlyInt8QuantHandler:
163 | def __init__(self, mod):
164 | self.mod = mod
165 |
166 | @torch.no_grad()
167 | def create_quantized_state_dict(self):
168 | cur_state_dict = self.mod.state_dict()
169 | for fqn, mod in self.mod.named_modules():
170 | # TODO: quantise RMSNorm as well.
171 | if isinstance(mod, torch.nn.Linear):
172 | int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
173 | cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu")
174 | cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to("cpu")
175 |
176 | return cur_state_dict
177 |
178 | def convert_for_runtime(self):
179 | replace_linear_weight_only_int8_per_channel(self.mod)
180 | return self.mod
181 |
182 |
183 | class WeightOnlyInt8Linear(torch.nn.Module):
184 | __constants__ = ["in_features", "out_features"]
185 | in_features: int
186 | out_features: int
187 | weight: torch.Tensor
188 |
189 | def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
190 | factory_kwargs = {"device": device, "dtype": dtype}
191 | super().__init__()
192 | self.in_features = in_features
193 | self.out_features = out_features
194 | self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
195 | self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
196 |
197 | def forward(self, input: torch.Tensor) -> torch.Tensor:
198 | return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
199 |
200 |
201 | ##### Weight-only int4 per-channel groupwise quantized code ######
202 |
203 |
204 | def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
205 | weight_int32, scales_and_zeros = group_quantize_tensor(weight_bf16, n_bit=4, groupsize=groupsize)
206 | weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
207 | return weight_int4pack, scales_and_zeros
208 |
209 |
210 | def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
211 | origin_x_size = x.size()
212 | x = x.reshape(-1, origin_x_size[-1])
213 | c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
214 | new_shape = origin_x_size[:-1] + (out_features,)
215 | c = c.reshape(new_shape)
216 | return c
217 |
218 |
219 | def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
220 | return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
221 |
222 |
223 | def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
224 | for name, child in module.named_children():
225 | if isinstance(child, nn.Linear) and child.out_features % 8 == 0:
226 | if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
227 | setattr(
228 | module,
229 | name,
230 | WeightOnlyInt4Linear(
231 | child.in_features,
232 | child.out_features,
233 | bias=False,
234 | groupsize=groupsize,
235 | inner_k_tiles=inner_k_tiles,
236 | padding=False,
237 | use_cuda=use_cuda,
238 | ),
239 | )
240 | elif padding:
241 | setattr(
242 | module,
243 | name,
244 | WeightOnlyInt4Linear(
245 | child.in_features,
246 | child.out_features,
247 | bias=False,
248 | groupsize=groupsize,
249 | inner_k_tiles=inner_k_tiles,
250 | padding=True,
251 | use_cuda=use_cuda,
252 | ),
253 | )
254 | else:
255 | replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
256 |
257 |
258 | class WeightOnlyInt4QuantHandler:
259 | def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
260 | self.mod = mod
261 | self.groupsize = groupsize
262 | self.inner_k_tiles = inner_k_tiles
263 | self.padding = padding
264 | assert groupsize in [32, 64, 128, 256]
265 | assert inner_k_tiles in [2, 4, 8]
266 |
267 | @torch.no_grad()
268 | def create_quantized_state_dict(self):
269 | cur_state_dict = self.mod.state_dict()
270 | for fqn, mod in self.mod.named_modules():
271 | if isinstance(mod, torch.nn.Linear):
272 | assert not mod.bias
273 | out_features = mod.out_features
274 | in_features = mod.in_features
275 | if out_features % 8 != 0:
276 | continue
277 | assert out_features % 8 == 0, "require out_features % 8 == 0"
278 | print(f"linear: {fqn}, in={in_features}, out={out_features}")
279 |
280 | weight = mod.weight.data
281 | if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
282 | if self.padding:
283 | import torch.nn.functional as F
284 | from fam.llm.fast_model import find_multiple
285 |
286 | print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
287 | padded_in_features = find_multiple(in_features, 1024)
288 | weight = F.pad(weight, pad=(0, padded_in_features - in_features))
289 | else:
290 | print(
291 | f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
292 | + "and that groupsize and inner_k_tiles*16 evenly divide into it"
293 | )
294 | continue
295 | weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
296 | weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
297 | )
298 | cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
299 | cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
300 |
301 | return cur_state_dict
302 |
303 | def convert_for_runtime(self, use_cuda):
304 | replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
305 | return self.mod
306 |
307 |
308 | class WeightOnlyInt4Linear(torch.nn.Module):
309 | __constants__ = ["in_features", "out_features"]
310 | in_features: int
311 | out_features: int
312 | weight: torch.Tensor
313 |
314 | def __init__(
315 | self,
316 | in_features: int,
317 | out_features: int,
318 | bias=True,
319 | device=None,
320 | dtype=None,
321 | groupsize: int = 128,
322 | inner_k_tiles: int = 8,
323 | padding: bool = True,
324 | use_cuda=True,
325 | ) -> None:
326 | super().__init__()
327 | self.padding = padding
328 | if padding:
329 | from fam.llm.fast_model import find_multiple
330 |
331 | self.origin_in_features = in_features
332 | in_features = find_multiple(in_features, 1024)
333 |
334 | self.in_features = in_features
335 | self.out_features = out_features
336 | assert not bias, "require bias=False"
337 | self.groupsize = groupsize
338 | self.inner_k_tiles = inner_k_tiles
339 |
340 | assert out_features % 8 == 0, "require out_features % 8 == 0"
341 | assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
342 | if use_cuda:
343 | self.register_buffer(
344 | "weight",
345 | torch.empty(
346 | (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32
347 | ),
348 | )
349 | else:
350 | self.register_buffer("weight", torch.empty((out_features, in_features // 2), dtype=torch.uint8))
351 | self.register_buffer(
352 | "scales_and_zeros", torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
353 | )
354 |
355 | def forward(self, input: torch.Tensor) -> torch.Tensor:
356 | input = input.to(torch.bfloat16)
357 | if self.padding:
358 | import torch.nn.functional as F
359 |
360 | input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
361 | return linear_forward_int4(input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize)
362 |
--------------------------------------------------------------------------------
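
Usage sketch for the int8 path above: build the quantized state dict from a float model, swap every nn.Linear for WeightOnlyInt8Linear via convert_for_runtime, then load the state dict into the converted module. This is a minimal illustration, not code from the repo: the fam.llm.fast_quantize import path and the toy bias-free model are assumptions.

import torch
import torch.nn as nn

from fam.llm.fast_quantize import WeightOnlyInt8QuantHandler  # import path is an assumption

# toy bias-free model; WeightOnlyInt8Linear does not carry a bias buffer
model = nn.Sequential(nn.Linear(64, 128, bias=False), nn.ReLU(), nn.Linear(128, 64, bias=False))

handler = WeightOnlyInt8QuantHandler(model)
int8_state_dict = handler.create_quantized_state_dict()  # per-channel int8 weights + scales
model = handler.convert_for_runtime()                     # nn.Linear -> WeightOnlyInt8Linear
model.load_state_dict(int8_state_dict)

with torch.no_grad():
    y = model(torch.randn(2, 64))                         # weight cast back to input dtype, then scaled
print(y.shape)  # torch.Size([2, 64])
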
/fam/llm/finetune.py:
--------------------------------------------------------------------------------
1 | """
2 | Module responsible for finetuning the first stage LLM.
3 | """
4 |
5 | import itertools
6 | import math
7 | import time
8 | from pathlib import Path
9 | from typing import Any, Dict, Literal, Optional
10 |
11 | import click
12 | import torch
13 | from huggingface_hub import snapshot_download
14 | from torch.utils.data import DataLoader
15 | from tqdm import tqdm
16 |
17 | from fam.llm.config.finetune_params import *
18 | from fam.llm.loaders.training_data import DynamicComputeDataset
19 | from fam.llm.model import GPT, GPTConfig
20 | from fam.llm.preprocessing.audio_token_mode import get_params_for_mode
21 | from fam.llm.preprocessing.data_pipeline import get_training_tuple
22 | from fam.llm.utils import hash_dictionary
23 | from fam.telemetry import TelemetryEvent
24 | from fam.telemetry.posthog import PosthogClient
25 |
26 | # see fam/telemetry/README.md for more information
27 | posthog = PosthogClient()
28 |
29 | dtype: Literal["bfloat16", "float16", "tfloat32", "float32"] = (
30 | "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16"
31 | ) # 'float32', 'tfloat32', 'bfloat16', or 'float16'; float16 will automatically use a GradScaler
32 | seed_offset = 0
33 |
34 | torch.manual_seed(seed + seed_offset)
35 | torch.backends.cuda.matmul.allow_tf32 = True if dtype != "float32" else False
36 | torch.backends.cudnn.allow_tf32 = True if dtype != "float32" else False
37 | device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
38 | # note: float16 data type will automatically use a GradScaler
39 | ptdtype = {"float32": torch.float32, "tfloat32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[
40 | dtype
41 | ]
42 | ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
43 |
44 | print(f"tokens per iteration will be: {tokens_per_iter:,}")
45 |
46 | ckpts_base_dir = Path(__file__).resolve().parent / "ckpts"
47 | if not os.path.exists(ckpts_base_dir) and master_process:
48 | print("Checkpoints directory didn't exist, creating...")
49 | ckpts_base_dir.mkdir(parents=True)
50 |
51 | if master_process:
52 | if "/" in out_dir:
53 | raise Exception("out_dir should be just a name, not a path with slashes")
54 |
55 | ckpts_save_dir = ckpts_base_dir / out_dir
56 | os.makedirs(ckpts_save_dir, exist_ok=True)
57 |
58 |
59 | def get_globals_state():
60 | """Return entirety of configuration global state which can be used for logging."""
61 | config_keys = [k for k, v in globals().items() if not k.startswith("_") and isinstance(v, (int, float, bool, str))]
62 | return {k: globals()[k] for k in config_keys} # will be useful for logging
63 |
64 |
65 | model_args: dict = dict(
66 | n_layer=n_layer,
67 | n_head=n_head,
68 | n_embd=n_embd,
69 | block_size=block_size,
70 | bias=bias,
71 | vocab_sizes=None,
72 | dropout=dropout,
73 | causal=causal,
74 | norm_type=norm_type,
75 | rmsnorm_eps=rmsnorm_eps,
76 | nonlinearity_type=nonlinearity_type,
77 | spk_emb_on_text=spk_emb_on_text,
78 | attn_kernel_type=attn_kernel_type,
79 | swiglu_multiple_of=swiglu_multiple_of,
80 | ) # start with model_args from command line
81 |
82 |
83 | def strip_prefix(state_dict: Dict[str, Any], unwanted_prefix: str):
84 | # TODO: this also appears in fast_inference_utils._load_model, it should be moved to a common place.
85 | for k, v in list(state_dict.items()):
86 | if k.startswith(unwanted_prefix):
87 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
88 | return state_dict
89 |
90 |
91 | def force_ckpt_args(model_args, checkpoint_model_args) -> None:
92 | # force these config attributes to be equal otherwise we can't even resume training
93 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
94 | for k in ["n_layer", "n_head", "n_embd", "block_size", "bias", "vocab_sizes", "causal"]:
95 | model_args[k] = checkpoint_model_args[k]
96 | # this enables backward compatibility with previously saved checkpoints.
97 | for k in [
98 | "target_vocab_sizes",
99 | "norm_type",
100 | "rmsnorm_eps",
101 | "nonlinearity_type",
102 | "attn_kernel_type",
103 | "spk_emb_on_text",
104 | "swiglu_multiple_of",
105 | ]:
106 | if k in checkpoint_model_args:
107 | model_args[k] = checkpoint_model_args[k]
108 | if attn_kernel_type != model_args["attn_kernel_type"]:
109 | print(
110 | f'Found {model_args["attn_kernel_type"]} kernel type inside model,',
111 | f"but expected {attn_kernel_type}. Manually replacing it.",
112 | )
113 | model_args["attn_kernel_type"] = attn_kernel_type
114 |
115 |
116 | @click.command()
117 | @click.option("--train", type=click.Path(exists=True, path_type=Path), required=True)
118 | @click.option("--val", type=click.Path(exists=True, path_type=Path), required=True)
119 | @click.option("--model-id", type=str, required=False, default="metavoiceio/metavoice-1B-v0.1")
120 | @click.option("--ckpt", type=click.Path(exists=True, path_type=Path))
121 | @click.option("--spk-emb-ckpt", type=click.Path(exists=True, path_type=Path))
122 | def main(train: Path, val: Path, model_id: str, ckpt: Optional[Path], spk_emb_ckpt: Optional[Path]):
123 | if ckpt and spk_emb_ckpt:
124 | checkpoint_path, spk_emb_ckpt_path = ckpt, spk_emb_ckpt
125 | else:
126 | _model_dir = snapshot_download(repo_id=model_id)
127 | checkpoint_path = Path(f"{_model_dir}/first_stage.pt")
128 | spk_emb_ckpt_path = Path(f"{_model_dir}/speaker_encoder.pt")
129 |
130 | mode_params = get_params_for_mode(audio_token_mode, num_max_audio_tokens_timesteps=num_max_audio_tokens_timesteps)
131 | config = get_globals_state()
132 |
133 | checkpoint = torch.load(str(checkpoint_path), mmap=True, map_location=device)
134 | iter_num = checkpoint.get("iter_num", 0)
135 | best_val_loss = checkpoint.get("best_val_loss", 1e9)
136 | checkpoint_model_args = checkpoint["model_args"]
137 | tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
138 | force_ckpt_args(model_args, checkpoint_model_args)
139 | gptconf = GPTConfig(**model_args) # type: ignore
140 | model = GPT(gptconf, speaker_emb_dim=speaker_emb_size if speaker_cond else None)
141 |
142 | # removing torch.compile module prefixes for pre-compile loading
143 | state_dict = strip_prefix(checkpoint["model"], "_orig_mod.")
144 | model.load_state_dict(state_dict)
145 | model.to(device)
146 | # initialize a GradScaler. If enabled=False scaler is a no-op
147 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
148 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
149 | if compile:
150 | print("Compiling the model... (takes about a minute)")
151 | # requires PyTorch 2.0
152 | from einops._torch_specific import allow_ops_in_compiled_graph
153 |
154 | allow_ops_in_compiled_graph()
155 | model = torch.compile(model) # type: ignore
156 |
157 | def estimate_loss(dataset, iters: int = eval_iters):
158 | """Estimate loss on a dataset by running on `iters` batches."""
159 | if dataset is None:
160 | return torch.nan
161 | losses = []
162 | for _, batch in zip(tqdm(range(iters)), dataset):
163 | X, Y, SE = get_training_tuple(batch, causal, num_codebooks, speaker_cond, device)
164 | with ctx:
165 | _, loss = model(X, Y, speaker_embs=SE, speaker_emb_mask=None)
166 | losses.append(loss.item())
167 | return torch.tensor(losses).mean()
168 |
169 | # learning rate decay scheduler (cosine with warmup)
170 | def get_lr(it):
171 | # 1) linear warmup for warmup_iters steps
172 | if it < warmup_iters:
173 | return learning_rate * it / warmup_iters
174 | # 2) if it > lr_decay_iters, return min learning rate
175 | if it > lr_decay_iters:
176 | return min_lr
177 | # 3) in between, use cosine decay down to min learning rate
178 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
179 | assert 0 <= decay_ratio <= 1
180 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
181 | return min_lr + coeff * (learning_rate - min_lr)
182 |
183 | if wandb_log and master_process:
184 | import wandb
185 |
186 | if os.environ.get("WANDB_RUN_ID", None) is not None:
187 | resume = "must"
188 | else:
189 | resume = None
190 |
191 | wandb.init(project=wandb_project, name=wandb_run_name, tags=wandb_tags, config=config, resume=resume)
192 |
193 | train_dataset = DynamicComputeDataset.from_meta(
194 | tokenizer_info,
195 | mode_params["combine_func"],
196 | spk_emb_ckpt_path,
197 | train,
198 | mode_params["pad_token"],
199 | mode_params["ctx_window"],
200 | device,
201 | )
202 | val_dataset = DynamicComputeDataset.from_meta(
203 | tokenizer_info,
204 | mode_params["combine_func"],
205 | spk_emb_ckpt_path,
206 | val,
207 | mode_params["pad_token"],
208 | mode_params["ctx_window"],
209 | device,
210 | )
211 | train_dataloader = itertools.cycle(DataLoader(train_dataset, batch_size, shuffle=True))
212 | train_data = iter(train_dataloader)
213 | # we do not perform any explicit checks for dataset overlap & leave it to the user
214 | # to handle this
215 | eval_val_data = DataLoader(val_dataset, batch_size, shuffle=True)
216 | # we can use the same Dataset object given it is a mapped dataset & not an iterable
217 | # one that can be exhausted. This implies we will be needlessly recomputing, fine
218 | # for now.
219 | eval_train_data = DataLoader(train_dataset, batch_size, shuffle=True)
220 |
221 | batch = next(train_data)
222 | X, Y, SE = get_training_tuple(batch, causal, num_codebooks, speaker_cond, device)
223 |
224 | t0 = time.time()
225 | local_iter_num = 0 # number of iterations in the lifetime of this process
226 | raw_model = model.module if ddp else model # unwrap DDP container if needed
227 | running_mfu = -1.0
228 | total_norm = 0.0
229 | save_checkpoint = False
230 | if master_process:
231 | progress = tqdm(total=max_iters, desc="Training", initial=iter_num)
232 | else:
233 | progress = None
234 |
235 | # finetune last X transformer blocks and the ln_f layer
236 | trainable_count = lambda model: sum(p.numel() for p in model.parameters() if p.requires_grad)
237 | print(f"Before layer freezing {trainable_count(model)=}...")
238 | for param in model.parameters():
239 | param.requires_grad = False
240 | for param in itertools.chain(
241 | model.transformer.ln_f.parameters(), model.transformer.h[last_n_blocks_to_finetune * -1 :].parameters()
242 | ):
243 | param.requires_grad = True
244 | print(f"After freezing excl. last {last_n_blocks_to_finetune} transformer blocks: {trainable_count(model)=}...")
245 |
246 | # log start of finetuning event
247 | properties = {
248 | **config,
249 | **model_args,
250 | "train": str(train),
251 | "val": str(val),
252 | "model_id": model_id,
253 | "ckpt": ckpt,
254 | "spk_emb_ckpt": spk_emb_ckpt,
255 | }
256 | finetune_jobid = hash_dictionary(properties)
257 | posthog.capture(
258 | TelemetryEvent(
259 | name="user_started_finetuning",
260 | properties={"finetune_jobid": finetune_jobid, **properties},
261 | )
262 | )
263 |
264 | while True:
265 | lr = get_lr(iter_num) if decay_lr else learning_rate
266 | for param_group in optimizer.param_groups:
267 | param_group["lr"] = lr
268 | if master_process:
269 | if iter_num % eval_interval == 0 and master_process:
270 | ckpt_save_name = f"ckpt_{iter_num:07d}.pt"
271 | with torch.no_grad():
272 | model.eval()
273 | losses = {
274 | "train": estimate_loss(eval_train_data),
275 | "val": estimate_loss(eval_val_data),
276 | }
277 | model.train()
278 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
279 | if wandb_log:
280 | wandb.log(
281 | {
282 | "iter": iter_num,
283 | "train/loss": losses["train"],
284 | "val/loss": losses["val"],
285 | "lr": lr,
286 | "mfu": running_mfu * 100, # convert to percentage
287 | "stats/total_norm": total_norm,
288 | }
289 | )
290 | if losses["val"] < best_val_loss:
291 | best_val_loss = losses["val"]
292 | if iter_num > 0:
293 | ckpt_save_name = ckpt_save_name.replace(
294 | ".pt", f"_bestval_{best_val_loss}".replace(".", "_") + ".pt"
295 | )
296 | save_checkpoint = True
297 |
298 | save_checkpoint = save_checkpoint or iter_num % save_interval == 0
299 | if save_checkpoint and iter_num > 0:
300 | checkpoint = {
301 | "model": raw_model.state_dict(), # type: ignore
302 | "optimizer": optimizer.state_dict(),
303 | "model_args": model_args,
304 | "iter_num": iter_num,
305 | "best_val_loss": best_val_loss,
306 | "config": config,
307 | "meta": {
308 | "speaker_cond": speaker_cond,
309 | "speaker_emb_size": speaker_emb_size,
310 | "tokenizer": tokenizer_info,
311 | },
312 | }
313 | torch.save(checkpoint, os.path.join(ckpts_save_dir, ckpt_save_name))
314 | print(f"saving checkpoint to {ckpts_save_dir}")
315 | save_checkpoint = False
316 | if iter_num == 0 and eval_only:
317 | break
318 | # forward backward update, with optional gradient accumulation to simulate larger batch size
319 | # and using the GradScaler if data type is float16
320 | for micro_step in range(gradient_accumulation_steps):
321 | if ddp:
322 | # in DDP training we only need to sync gradients at the last micro step.
323 | # the official way to do this is with model.no_sync() context manager, but
324 | # I really dislike that this bloats the code and forces us to repeat code
325 | # looking at the source of that context manager, it just toggles this variable
326 | model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 # type: ignore
327 | with ctx: # type: ignore
328 | logits, loss = model(X, Y, speaker_embs=SE, speaker_emb_mask=None)
329 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
330 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
331 | batch = next(train_data)
332 | X, Y, SE = get_training_tuple(
333 | batch,
334 | causal,
335 | num_codebooks,
336 | speaker_cond,
337 | device,
338 | )
339 | # backward pass, with gradient scaling if training in fp16
340 | scaler.scale(loss).backward()
341 | # clip the gradient
342 | if grad_clip != 0.0:
343 | scaler.unscale_(optimizer)
344 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
345 | # step the optimizer and scaler if training in fp16
346 | scaler.step(optimizer)
347 | scaler.update()
348 | # flush the gradients as soon as we can, no need for this memory anymore
349 | optimizer.zero_grad(set_to_none=True)
350 |
351 | # timing and logging
352 | t1 = time.time()
353 | dt = t1 - t0
354 | t0 = t1
355 | if master_process:
356 | # get loss as float. note: this is a CPU-GPU sync point
357 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
358 | lossf = loss.item() * gradient_accumulation_steps
359 | progress.update(1)
360 | progress.set_description(f"Training: loss {lossf:.4f}, time {dt*1000:.2f}ms")
361 | if iter_num % log_interval == 0:
362 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
363 |
364 | iter_num += 1
365 | local_iter_num += 1
366 |
367 | # termination conditions
368 | if iter_num > max_iters:
369 | # log end of finetuning event
370 | posthog.capture(
371 | TelemetryEvent(
372 | name="user_completed_finetuning",
373 | properties={"finetune_jobid": finetune_jobid, "loss": round(lossf, 4)},
374 | )
375 | )
376 | break
377 |
378 |
379 | if __name__ == "__main__":
380 | main()
381 |
--------------------------------------------------------------------------------
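
The learning-rate logic in get_lr above is linear warmup followed by cosine decay down to min_lr. A self-contained sketch of the same schedule with made-up hyperparameters (in the script they come from fam/llm/config/finetune_params.py), useful for sanity-checking the shape of the curve before launching a run:

import math

# made-up values; the real ones live in fam/llm/config/finetune_params.py
learning_rate, min_lr = 3e-5, 3e-6
warmup_iters, lr_decay_iters = 100, 1000

def get_lr(it: int) -> float:
    if it < warmup_iters:                                   # 1) linear warmup
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:                                 # 2) past the decay horizon, hold at min_lr
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))   # 3) cosine from 1 down to 0
    return min_lr + coeff * (learning_rate - min_lr)

for it in (0, 50, 100, 550, 1000, 2000):
    print(it, f"{get_lr(it):.2e}")
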
/fam/llm/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from fam.llm.layers.attn import SelfAttention
2 | from fam.llm.layers.combined import Block
3 | from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm, SwiGLU
4 |
--------------------------------------------------------------------------------
/fam/llm/layers/attn.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class SelfAttention(nn.Module):
9 | def __init__(self, config):
10 | """
11 | Initializes the SelfAttention module.
12 |
13 | Args:
14 | config: An object containing the configuration parameters for the SelfAttention module.
15 | """
16 | super().__init__()
17 | self._validate_config(config)
18 | self._initialize_parameters(config)
19 |
20 | def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
21 | """
22 | Empties the key-value cache.
23 |
24 | Args:
25 | batch_size: The batch size.
26 | kv_cache_maxlen: The maximum length of the key-value cache.
27 | dtype: The data type of the cache.
28 |
29 | Raises:
30 | Exception: If trying to empty the KV cache when it is disabled.
31 | """
32 | if self.kv_cache_enabled is False:
33 | raise Exception("Trying to empty KV cache when it is disabled")
34 |
35 | # register so that the cache moves devices along with the module
36 | # TODO: get rid of re-allocation.
37 | self.register_buffer(
38 | "kv_cache",
39 | torch.zeros(
40 | 2,
41 | batch_size,
42 | kv_cache_maxlen,
43 | self.n_head,
44 | self.n_embd // self.n_head,
45 | dtype=dtype,
46 | device=self.c_attn.weight.device,
47 | ),
48 | persistent=False,
49 | )
50 |
51 | self.kv_cache_first_empty_index = 0
52 |
53 | def _initialize_parameters(self, config):
54 | """
55 | Initializes the parameters of the SelfAttention module.
56 |
57 | Args:
58 | config: An object containing the configuration parameters for the SelfAttention module.
59 | """
60 | # key, query, value projections for all heads, but in a batch
61 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
62 |
63 | # output projection
64 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
65 |
66 | # regularization
67 | self.resid_dropout = nn.Dropout(config.dropout)
68 | self.n_head = config.n_head
69 | self.n_embd = config.n_embd
70 | self.dropout = config.dropout
71 | self.causal = config.causal
72 | self.attn_kernel_type = config.attn_kernel_type
73 | self.attn_dropout = nn.Dropout(config.dropout)
74 |
75 | self.kv_cache_enabled = False
76 |
77 | def _validate_config(self, config):
78 | """
79 | Validates the configuration parameters.
80 |
81 | Args:
82 | config: An object containing the configuration parameters for the SelfAttention module.
83 |
84 | Raises:
85 | AssertionError: If the embedding dimension is not divisible by the number of heads.
86 | """
87 | assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads"
88 |
89 | def _update_kv_cache(self, q, k, v):
90 | """
91 | Updates the key-value cache.
92 |
93 | Args:
94 | q: The query tensor.
95 | k: The key tensor.
96 | v: The value tensor.
97 |
98 | Returns:
99 | The updated key and value tensors.
100 |
101 | Raises:
102 | AssertionError: If the dimensions of the query, key, and value tensors are not compatible.
103 | """
104 | q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1]
105 |
106 | if self.kv_cache_first_empty_index == 0:
107 | assert q_time == k_time and q_time == v_time
108 | else:
109 | assert (
110 | q_time == 1
111 | ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}"
112 |
113 | self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k
114 | self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v
115 | self.kv_cache_first_empty_index += q_time
116 |
117 | k = self.kv_cache[0, :, : self.kv_cache_first_empty_index]
118 | v = self.kv_cache[1, :, : self.kv_cache_first_empty_index]
119 |
120 | return k, v
121 |
122 | def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
123 | """
124 | Performs attention using the torch.nn.functional.scaled_dot_product_attention function.
125 |
126 | Args:
127 | c_x: The input tensor.
128 |
129 | Returns:
130 | The output tensor.
131 | """
132 | q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, 1, nh, hs)
133 | q = q.squeeze(2) # (B, T, nh, hs)
134 | k = k.squeeze(2) # (B, T, nh, hs)
135 | v = v.squeeze(2) # (B, T, nh, hs)
136 |
137 | # if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and
138 | # use no mask for the "one time step" parts.
139 | # calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index
140 | is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)
141 |
142 | if self.kv_cache_enabled:
143 | k, v = self._update_kv_cache(q, k, v)
144 |
145 | q = q.transpose(1, 2) # (B, nh, T, hs)
146 | k = k.transpose(1, 2) # (B, nh, T, hs)
147 | v = v.transpose(1, 2) # (B, nh, T, hs)
148 | y = torch.nn.functional.scaled_dot_product_attention(
149 | q,
150 | k,
151 | v,
152 | attn_mask=None,
153 | dropout_p=self.dropout if self.training else 0,
154 | is_causal=is_causal_attn_mask,
155 | ).transpose(
156 | 1, 2
157 | ) # (B, nh, T, hs) -> (B, T, nh, hs)
158 |
159 | return y
160 |
161 | def forward(self, x):
162 | """
163 | Performs the forward pass of the SelfAttention module.
164 |
165 | Args:
166 | x: The input tensor.
167 |
168 | Returns:
169 | The output tensor.
170 | """
171 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
172 |
173 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
174 | c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs)
175 |
176 | # causal self-attention;
177 | if self.attn_kernel_type == "torch_attn":
178 | y = self._torch_attn(c_x)
179 | else:
180 | raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")
181 |
182 | y = y.contiguous().view(B, T, C) # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh)
183 | # output projection
184 | y = self.resid_dropout(self.c_proj(y))
185 | return y
186 |
--------------------------------------------------------------------------------
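
The KV-cache path in _torch_attn is prefill-then-decode: the first call after empty_kv_cache caches the whole prompt under a causal mask, and every later call must pass exactly one new time step. A minimal sketch with a hypothetical config object standing in for GPTConfig (in the real model the kv_cache_enabled flag is toggled by the inference code):

from types import SimpleNamespace

import torch

from fam.llm.layers.attn import SelfAttention

# hypothetical config carrying only the fields SelfAttention reads
config = SimpleNamespace(
    n_embd=64, n_head=4, bias=False, dropout=0.0, causal=True, attn_kernel_type="torch_attn"
)
attn = SelfAttention(config).eval()

attn.kv_cache_enabled = True
attn.empty_kv_cache(batch_size=1, kv_cache_maxlen=32, dtype=torch.float32)

prefill = torch.randn(1, 8, 64)   # (B, T, n_embd): cached with a causal mask
y = attn(prefill)

step = torch.randn(1, 1, 64)      # decode: exactly one query per call
y = attn(step)                    # attends over all 9 cached positions
print(y.shape)                    # torch.Size([1, 1, 64])
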
/fam/llm/layers/combined.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from fam.llm.layers.attn import SelfAttention
4 | from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm
5 |
6 |
7 | class Block(nn.Module):
8 | """
9 | Block class represents a single block in the model.
10 |
11 | Args:
12 | config (object): Configuration object containing parameters for the block.
13 |
14 | Attributes:
15 | ln_1 (object): Layer normalization for the attention layer.
16 | ln_2 (object): Layer normalization for the feed-forward layer.
17 | attn (object): Self-attention layer.
18 | mlp (object): Multi-layer perceptron layer.
19 |
20 | Methods:
21 | forward(x): Performs forward pass through the block.
22 | """
23 |
24 | def __init__(self, config):
25 | super().__init__()
26 | if config.norm_type == "rmsnorm":
27 | if config.rmsnorm_eps is None:
28 | raise Exception("RMSNorm requires rmsnorm_eps to be set")
29 | self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # attn norm
30 | self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # ffn norm
31 | elif config.norm_type == "layernorm":
32 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) # attn norm
33 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) # ffn norm
34 | else:
35 | raise Exception(f"Unknown norm type: {config.norm_type}")
36 | self.attn = SelfAttention(config)
37 |
38 | self.mlp = MLP(config)
39 |
40 | def forward(self, x):
41 | """
42 | Performs forward pass through the block.
43 |
44 | Args:
45 | x (tensor): Input tensor.
46 |
47 | Returns:
48 | tensor: Output tensor after passing through the block.
49 | """
50 | x = x + self.attn(self.ln_1(x))
51 | x = x + self.mlp(self.ln_2(x))
52 | return x
53 |
--------------------------------------------------------------------------------
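
Block is a standard pre-norm residual block: normalise, attend, add the residual, then normalise, apply the MLP, add again. A small construction sketch with a hypothetical config (the real GPTConfig lives in fam/llm/model.py; rmsnorm and swiglu are chosen here just to exercise those branches):

from types import SimpleNamespace

import torch

from fam.llm.layers import Block

# hypothetical config with the fields Block, SelfAttention and MLP read
config = SimpleNamespace(
    n_embd=64,
    n_head=4,
    bias=False,
    dropout=0.0,
    causal=True,
    attn_kernel_type="torch_attn",
    norm_type="rmsnorm",
    rmsnorm_eps=1e-5,
    nonlinearity_type="swiglu",
    swiglu_multiple_of=256,
)
block = Block(config).eval()

x = torch.randn(1, 8, 64)   # (B, T, n_embd)
y = block(x)                # x + attn(ln_1(x)), then x + mlp(ln_2(x))
print(y.shape)              # torch.Size([1, 8, 64])
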
/fam/llm/layers/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class LayerNorm(nn.Module):
9 | """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
10 |
11 | def __init__(self, ndim, bias):
12 | super().__init__()
13 | self.weight = nn.Parameter(torch.ones(ndim))
14 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
15 |
16 | def forward(self, input):
17 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
18 |
19 |
20 | class RMSNorm(nn.Module):
21 | def __init__(self, ndim: int, eps: float):
22 | super().__init__()
23 | self.eps = eps
24 | self.weight = nn.Parameter(torch.ones(ndim))
25 |
26 | def _norm(self, x):
27 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
28 |
29 | def forward(self, x):
30 | return self._norm(x) * self.weight
31 |
32 |
33 | class SwiGLU(nn.Module):
34 | def __init__(self, in_dim, out_dim, bias) -> None:
35 | super().__init__()
36 | self.w1 = nn.Linear(in_dim, out_dim, bias=bias)
37 | self.w3 = nn.Linear(in_dim, out_dim, bias=bias)
38 |
39 | def forward(self, x):
40 | return F.silu(self.w1(x)) * self.w3(x)
41 |
42 |
43 | class MLP(nn.Module):
44 | def __init__(self, config):
45 | super().__init__()
46 | self.non_linearity = config.nonlinearity_type
47 | hidden_dim = 4 * config.n_embd
48 | if config.nonlinearity_type == "gelu":
49 | self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
50 | self.gelu = nn.GELU()
51 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
52 | elif config.nonlinearity_type == "swiglu":
53 | if config.swiglu_multiple_of is None:
54 | raise Exception("SwiGLU requires swiglu_multiple_of to be set")
55 | hidden_dim = int(2 * hidden_dim / 3)
56 | hidden_dim = config.swiglu_multiple_of * math.ceil(hidden_dim / config.swiglu_multiple_of)
57 | # set name to `c_proj` so that the right initialisation gets applied to it in GPT.__init__()
58 | self.swiglu = SwiGLU(config.n_embd, hidden_dim, bias=config.bias)
59 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
60 | else:
61 | raise Exception(f"Unknown nonlinearity type: {config.nonlinearity_type}")
62 | self.dropout = nn.Dropout(config.dropout)
63 |
64 | def forward(self, x):
65 | if self.non_linearity == "gelu":
66 | x = self.c_fc(x)
67 | x = self.gelu(x)
68 | elif self.non_linearity == "swiglu":
69 | x = self.swiglu(x)
70 | x = self.c_proj(x)
71 | x = self.dropout(x)
72 | return x
73 |
--------------------------------------------------------------------------------
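
For the swiglu branch of MLP the hidden width is not simply 4 * n_embd: it is shrunk by 2/3 (SwiGLU has two input projections, so this keeps the parameter count close to a GELU MLP) and then rounded up to a multiple of swiglu_multiple_of. A worked example with illustrative values, not the model's actual configuration:

import math

n_embd, swiglu_multiple_of = 2048, 256     # illustrative values only

hidden_dim = 4 * n_embd                    # 8192, the width a GELU MLP would use
hidden_dim = int(2 * hidden_dim / 3)       # 5461
hidden_dim = swiglu_multiple_of * math.ceil(hidden_dim / swiglu_multiple_of)
print(hidden_dim)                          # 5632: w1/w3 map 2048 -> 5632, c_proj maps 5632 -> 2048
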
/fam/llm/loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/llm/loaders/__init__.py
--------------------------------------------------------------------------------
/fam/llm/loaders/training_data.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Mapping
3 |
4 | import julius
5 | import torch
6 | import math
7 | import numpy as np
8 | import pandas as pd
9 | from audiocraft.data.audio import audio_read
10 | from encodec import EncodecModel
11 | from torch.utils.data import DataLoader, Dataset
12 |
13 | from fam.llm.fast_inference_utils import encode_tokens
14 | from fam.llm.preprocessing.audio_token_mode import CombinerFuncT
15 | from fam.llm.preprocessing.data_pipeline import pad_tokens
16 | from fam.llm.utils import normalize_text
17 | from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
18 | from fam.quantiser.text.tokenise import TrainedBPETokeniser
19 |
20 | MBD_SAMPLE_RATE = 24000
21 | ENCODEC_BANDWIDTH = 6
22 |
23 |
24 | class DynamicComputeDataset(Dataset):
25 | def __init__(
26 | self,
27 | dataset_dir: Path | str,
28 | encodec_model: EncodecModel,
29 | tokenizer: TrainedBPETokeniser,
30 | spkemb_model: SpeakerEncoder,
31 | combiner: CombinerFuncT,
32 | pad_token: int,
33 | ctx_window: int,
34 | device: str,
35 | ):
36 | self.dataset_dir = dataset_dir
37 | self.encodec_model = encodec_model
38 | self.tokenizer = tokenizer
39 | self.spkemb_model = spkemb_model
40 | self.device = device
41 | self.combiner = combiner
42 | self.pad_token = pad_token
43 | self.ctx_window = ctx_window
44 | self.df = pd.read_csv(dataset_dir, delimiter="|", index_col=False)
45 |
46 | @classmethod
47 | def from_meta(
48 | cls,
49 | tokenizer_info: Mapping[str, Any],
50 | combiner: CombinerFuncT,
51 | speaker_embedding_ckpt_path: Path | str,
52 | dataset_dir: Path | str,
53 | pad_token: int,
54 | ctx_window: int,
55 | device: str
56 | ):
57 | encodec = EncodecModel.encodec_model_24khz().to(device)
58 | encodec.set_target_bandwidth(ENCODEC_BANDWIDTH)
59 | smodel = SpeakerEncoder(
60 | weights_fpath=str(speaker_embedding_ckpt_path),
61 | eval=True,
62 | device=device,
63 | verbose=False,
64 | )
65 | tokeniser = TrainedBPETokeniser(**tokenizer_info)
66 |
67 | return cls(
68 | dataset_dir,
69 | encodec,
70 | tokeniser,
71 | smodel,
72 | combiner,
73 | pad_token,
74 | ctx_window,
75 | device
76 | )
77 |
78 | def __len__(self):
79 | return len(self.df)
80 |
81 | def __getitem__(self, idx):
82 | audio_path, text = self.df.iloc[idx].values.tolist()
83 | with torch.no_grad():
84 | text_tokens = self._extract_text_tokens(text)
85 | encodec_tokens = self._extract_encodec_tokens(audio_path)
86 | speaker_embedding = self._extract_speaker_embedding(audio_path)
87 | combined = self.combiner(encodec_tokens, text_tokens)
88 | padded_combined_tokens = pad_tokens(combined, self.ctx_window, self.pad_token)
89 |
90 | return {"tokens": padded_combined_tokens, "spkemb": speaker_embedding}
91 |
92 | def _extract_text_tokens(self, text: str):
93 | _text = normalize_text(text)
94 | _tokens = encode_tokens(self.tokenizer, _text, self.device)
95 |
96 | return _tokens.detach().cpu().numpy()
97 |
98 | def _extract_encodec_tokens(self, audio_path: str):
99 | wav, sr = audio_read(audio_path)
100 | if sr != MBD_SAMPLE_RATE:
101 | wav = julius.resample_frac(wav, sr, MBD_SAMPLE_RATE)
102 |
103 | # Convert to mono and fix dimensionality
104 | if wav.ndim == 2:
105 | wav = wav.mean(axis=0, keepdims=True)
106 | wav = wav.unsqueeze(0) # Add batch dimension
107 |
108 | wav = wav.to(self.device)
109 | tokens = self.encodec_model.encode(wav)
110 | _tokens = tokens[0][0][0].detach().cpu().numpy() # shape = [8, T]
111 |
112 | return _tokens
113 |
114 | def _extract_speaker_embedding(self, audio_path: str):
115 | emb = self.spkemb_model.embed_utterance_from_file(audio_path, numpy=False) # shape = [256,]
116 | return emb.unsqueeze(0).detach()
117 |
--------------------------------------------------------------------------------
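
DynamicComputeDataset reads a pipe-delimited CSV with one row per utterance, audio path first and caption second (the delimiter="|" in __init__ fixes the format), and computes EnCodec tokens, text tokens and the speaker embedding lazily in __getitem__. A construction sketch mirroring how finetune.py wires it up; the checkpoint download and meta lookups follow finetune.py, while using datasets/sample_dataset.csv as the training CSV and pulling audio_token_mode from finetune_params are assumptions:

import torch
from huggingface_hub import snapshot_download
from torch.utils.data import DataLoader

from fam.llm.config.finetune_params import audio_token_mode, num_max_audio_tokens_timesteps
from fam.llm.loaders.training_data import DynamicComputeDataset
from fam.llm.preprocessing.audio_token_mode import get_params_for_mode

device = "cuda" if torch.cuda.is_available() else "cpu"

# tokenizer metadata and speaker-encoder weights ship alongside the first-stage checkpoint
model_dir = snapshot_download(repo_id="metavoiceio/metavoice-1B-v0.1")
checkpoint = torch.load(f"{model_dir}/first_stage.pt", map_location="cpu")
tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
mode_params = get_params_for_mode(audio_token_mode, num_max_audio_tokens_timesteps=num_max_audio_tokens_timesteps)

dataset = DynamicComputeDataset.from_meta(
    tokenizer_info,
    mode_params["combine_func"],
    f"{model_dir}/speaker_encoder.pt",
    "datasets/sample_dataset.csv",   # pipe-delimited rows: audio_path|caption
    mode_params["pad_token"],
    mode_params["ctx_window"],
    device,
)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))           # {"tokens": padded token grid, "spkemb": speaker embedding}
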
/fam/llm/mixins/__init__.py:
--------------------------------------------------------------------------------
1 | from fam.llm.mixins.causal import CausalInferenceMixin
2 | from fam.llm.mixins.non_causal import NonCausalInferenceMixin
3 |
--------------------------------------------------------------------------------
/fam/llm/mixins/non_causal.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | class NonCausalInferenceMixin:
8 | """
9 | Mixin class for non-causal inference in a language model.
10 |
11 | This class provides methods for performing non-causal sampling using a language model.
12 | """
13 |
14 | @torch.no_grad()
15 | def _non_causal_sample(
16 | self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: int
17 | ):
18 | """
19 | Perform non-causal sampling.
20 |
21 | Args:
22 | idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length).
23 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings tensor of shape (batch_size, embedding_size).
24 | temperature (float): Temperature parameter for scaling the logits.
25 | top_k (int): Number of top options to consider.
26 |
27 | Returns:
28 | torch.Tensor: Sampled output tensor of shape (batch_size, num_out_hierarchies, sequence_length).
29 | """
30 | b, c, t = idx.size()
31 | assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}"
32 | # forward the model to get the logits for the index in the sequence
33 | list_logits, _ = self(idx, speaker_embs=speaker_embs) # c x (b, t, vocab_size)
34 |
35 | # scale by desired temperature
36 | list_logits = [logits / temperature for logits in list_logits] # c x (b, t, vocab_size)
37 |
38 | # optionally crop the logits to only the top k options
39 | if top_k is not None:
40 | for i in range(len(list_logits)):
41 | logits = list_logits[i] # (b, t, vocab_size)
42 |
43 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) # (b, t, top_k)
44 | logits[logits < v[:, :, [-1]]] = -float("Inf")
45 | list_logits[i] = logits # (b, t, vocab_size)
46 | assert logits.shape[0] == b and logits.shape[1] == t
47 |
48 | # apply softmax to convert logits to (normalized) probabilities
49 | # TODO: check shapes here!
50 | probs = [F.softmax(logits, dim=-1) for logits in list_logits] # c x (b, t, top_k)
51 | assert probs[0].shape[0] == b and probs[0].shape[1] == t
52 |
53 | # TODO: check that the output shape is as expected
54 | outs = []
55 | for b_prob in probs: # c x (b, t, top_k) -> (b, t, top_k)
56 | out = [
57 | torch.multinomial(prob, num_samples=1).transpose(0, 1).unsqueeze(0) for prob in b_prob
58 | ] # b x (t, top_k) -> b x (t, 1) -> b x (1, t) -> b x (1, 1, t)
59 | assert len(out) == b and out[0].shape[0] == 1 and out[0].shape[1] == 1 and out[0].shape[2] == t
60 | out = torch.cat(out, dim=0) # (b, 1, t)
61 | assert out.shape[0] == b and out.shape[1] == 1 and out.shape[2] == t
62 | outs.append(out)
63 |
64 | out = torch.cat(outs, dim=1) # (b, c, t)
65 | assert out.shape[0] == b and out.shape[2] == t
66 |
67 | return out
68 |
--------------------------------------------------------------------------------
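
The loop in _non_causal_sample is plain top-k sampling applied independently to every hierarchy and every position: values below the k-th largest logit are pushed to -inf, the rest are softmaxed, and one token is drawn per (batch, time) slot. The same filtering step in isolation, on a single made-up logits tensor:

import torch
from torch.nn import functional as F

b, t, vocab_size, top_k, temperature = 2, 5, 11, 3, 0.8     # made-up sizes
logits = torch.randn(b, t, vocab_size) / temperature        # one hierarchy, already temperature-scaled

v, _ = torch.topk(logits, min(top_k, logits.size(-1)))      # (b, t, top_k)
logits[logits < v[:, :, [-1]]] = -float("Inf")              # mask everything below the k-th value

probs = F.softmax(logits, dim=-1)                           # non-zero only inside the top-k
samples = torch.multinomial(probs.view(-1, vocab_size), num_samples=1).view(b, t)
print(samples.shape)                                        # torch.Size([2, 5])
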
/fam/llm/model.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import math
3 | from dataclasses import dataclass, field
4 | from typing import Literal, Optional, Tuple, Union
5 |
6 | import torch
7 | import torch.nn as nn
8 | import tqdm
9 | from einops import rearrange
10 | from torch.nn import functional as F
11 |
12 | from fam.llm.layers import Block, LayerNorm, RMSNorm
13 | from fam.llm.mixins import CausalInferenceMixin, NonCausalInferenceMixin
14 |
15 | END_OF_TEXT_TOKEN = 1537
16 |
17 |
18 | def _select_spkemb(spkemb, mask):
19 | _, examples, _ = spkemb.shape
20 | mask = torch.nn.functional.one_hot(mask.long(), num_classes=examples).to(spkemb) # shape: (batch, time, examples)
21 | spkemb = spkemb.transpose(1, 2) # b ex c -> b c ex
22 | mask = mask.transpose(1, 2) # b t ex -> b ex t
23 | return torch.bmm(spkemb, mask).transpose(1, 2) # b c t -> b t c
24 |
25 |
26 | @dataclass
27 | class GPTConfig:
28 | block_size: int = 1024
29 | vocab_sizes: list = field(default_factory=list)
30 | target_vocab_sizes: Optional[list] = None
31 | n_layer: int = 12
32 | n_head: int = 12
33 | n_embd: int = 768
34 | dropout: float = 0.0
35 | spkemb_dropout: float = 0.0
36 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
37 | causal: bool = (
38 | True # auto-regressive or not, i.e. whether to have attention mask that prevents attending to future tokens
39 | )
40 | spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not
41 | norm_type: str = "layernorm" # "rmsnorm" or "layernorm"
42 | rmsnorm_eps: Optional[float] = None # only used for rmsnorm
43 | nonlinearity_type: str = "gelu" # "gelu" or "swiglu"
44 | swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this
45 | attn_kernel_type: Literal["torch_attn"] = "torch_attn"
46 | kv_cache_enabled: bool = False # whether to use key-value cache for attention
47 |
48 |
49 | def _check_speaker_emb_dims(
50 | speaker_embs: Union[list, torch.Tensor], expected_speaker_emb_dim: int, expected_batch_size: int
51 | ) -> Union[torch.Tensor, list]:
52 | """
53 | Checks that the speaker embedding dimensions are correct, and reshapes them if necessary.
54 | """
55 | if type(speaker_embs) == list:
56 | b_se = len(speaker_embs)
57 | for i, s in enumerate(speaker_embs):
58 | if s is not None:
59 | emb_dim = s.shape[-1]
60 | if s.ndim == 1:
61 | speaker_embs[i] = speaker_embs[i].unsqueeze(0)
62 | else:
63 | if speaker_embs.ndim == 2:
64 | # if we have a single speaker embedding for the whole sequence,
65 | # add a dummy dimension for backwards compatibility
66 | speaker_embs = speaker_embs[:, None, :]
67 |
68 | # num_examples is the number of utterances packed into this sequence
69 | b_se, num_examples, emb_dim = speaker_embs.size()
70 |
71 | assert b_se == expected_batch_size, f"Batch size mismatch: {b_se} != {expected_batch_size}"
72 | assert (
73 | emb_dim == expected_speaker_emb_dim
74 | ), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}"
75 |
76 | return speaker_embs
77 |
78 |
79 | class GPT(nn.Module, NonCausalInferenceMixin, CausalInferenceMixin):
80 | def __init__(self, config: GPTConfig, speaker_emb_dim: Optional[int] = None):
81 | """
82 | Initialize the GPT model.
83 |
84 | Args:
85 | config (GPTConfig): Configuration object for the model.
86 | speaker_emb_dim (Optional[int]): Dimension of the speaker embedding. Default is None.
87 | """
88 | super().__init__()
89 | assert config.vocab_sizes is not None
90 | assert config.block_size is not None
91 | self.config = config
92 |
93 | self.kv_cache_enabled = False # disabled by default
94 | self.kv_pos = 0
95 |
96 | self.speaker_emb_dim = speaker_emb_dim
97 | self.spk_emb_on_text = config.spk_emb_on_text
98 | if self.config.causal is True and self.spk_emb_on_text is False:
99 | print("!!!!!!!!!!!!!!!!!!")
100 | print(
101 | f"!!!!!!!! Using DEFAULT of {END_OF_TEXT_TOKEN} as end of text token to find speaker cond masking!! You likely need to change this."
102 | )
103 | print("!!!!!!!!!!!!!!!!!!")
104 | if self.config.causal is False and self.spk_emb_on_text is False:
105 | raise Exception(
106 | "Cannot use speaker embedding masking with non-causal model. This is unexpected. Check for relevant changes required in code before proceeding."
107 | )
108 |
109 | if config.norm_type == "rmsnorm":
110 | if config.rmsnorm_eps is None:
111 | raise Exception("RMSNorm requires rmsnorm_eps to be set")
112 | ln_f = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
113 | elif config.norm_type == "layernorm":
114 | ln_f = LayerNorm(config.n_embd, bias=config.bias)
115 | else:
116 | raise Exception(f"Unknown norm type: {config.norm_type}")
117 |
118 | self.transformer = nn.ModuleDict(
119 | dict(
120 | wtes=nn.ModuleList([nn.Embedding(vsize, config.n_embd) for vsize in config.vocab_sizes]),
121 | wpe=nn.Embedding(config.block_size, config.n_embd),
122 | drop=nn.Dropout(config.dropout),
123 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
124 | ln_f=ln_f,
125 | )
126 | )
127 | if speaker_emb_dim is not None:
128 | self.speaker_cond_pos = nn.Linear(speaker_emb_dim, config.n_embd, bias=False)
129 |
130 | self.lm_heads = nn.ModuleList()
131 | if config.target_vocab_sizes is not None:
132 | assert config.causal is False
133 | else:
134 | assert config.causal is True
135 |
136 | for vsize in config.vocab_sizes if config.target_vocab_sizes is None else config.target_vocab_sizes:
137 | self.lm_heads.append(nn.Linear(config.n_embd, vsize, bias=False))
138 |
139 | if config.target_vocab_sizes is None:
140 | for i in range(len(config.vocab_sizes)):
141 | # TODO: do we not need to take the transpose here?
142 | # https://paperswithcode.com/method/weight-tying
143 | self.lm_heads[i].weight = self.transformer.wtes[i].weight # type: ignore
144 | assert len(self.lm_heads) == len(
145 | self.transformer.wtes # type: ignore
146 | ), f"Number of heads ({len(self.lm_heads)}) must match number of one-hot embedding matrices ({len(self.transformer.wtes)})." # type: ignore
147 |
148 | # init all weights
149 | self.apply(self._init_weights)
150 | # apply special scaled init to the residual projections, per GPT-2 paper
151 | for pn, p in self.named_parameters():
152 | if pn.endswith("c_proj.weight"):
153 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
154 |
155 | # report number of parameters
156 | print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
157 |
158 | def get_num_params(self, non_embedding=True):
159 | """
160 | Return the number of parameters in the model.
161 | For non-embedding count (default), the position embeddings get subtracted.
162 | The token embeddings would too, except due to the parameter sharing these
163 | params are actually used as weights in the final layer, so we include them.
164 | """
165 | n_params = sum(p.numel() for p in self.parameters())
166 | if non_embedding:
167 | n_params -= self.transformer.wpe.weight.numel()
168 | return n_params
169 |
170 | def _init_weights(self, module):
171 | if isinstance(module, nn.Linear):
172 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
173 | if module.bias is not None:
174 | torch.nn.init.zeros_(module.bias)
175 | elif isinstance(module, nn.Embedding):
176 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
177 |
178 | def _mask_spk_emb_on_text(self, idx: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor:
179 | """
180 | This is in a separate function so we can test it easily.
181 | """
182 | # find index of end of text token in each sequence, then generate a binary mask
183 | # of shape (b, 1, t) to mask out the speaker embedding for all tokens before the end of text token.
184 | # Note: this does NOT mask the token. This is important so that the first audio token predicted
185 | # has speaker information to use.
186 |
187 | # Check in channel dimension 0 as this is usually the first hierarchy where we put the text tokens.
188 | is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN
189 | # use > 0, in case end_of_text_token is repeated for any reason.
190 | mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float()
191 | spk_emb = spk_emb * mask[:, :, None]
192 |
193 | return spk_emb
194 |
195 | def forward(
196 | self,
197 | idx,
198 | targets=None,
199 | speaker_embs=None,
200 | speaker_emb_mask=None,
201 | loss_reduce: Literal["mean", "none"] = "mean",
202 | ):
203 | device = idx.device
204 | b, num_hierarchies, t = idx.size()
205 |
206 | if speaker_embs is not None:
207 | speaker_embs = _check_speaker_emb_dims(
208 | speaker_embs=speaker_embs, expected_speaker_emb_dim=self.speaker_emb_dim, expected_batch_size=b
209 | )
210 |
211 | assert (
212 | t <= self.config.block_size
213 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
214 |
215 | if self.kv_cache_enabled:
216 | if self.kv_pos == 0:
217 | pos = torch.arange(0, t, dtype=torch.long, device=device)
218 | self.kv_pos += t
219 | else:
220 | assert t == 1, "KV cache is only supported for single token inputs"
221 | pos = torch.tensor([self.kv_pos], dtype=torch.long, device=device) # shape (1)
222 | self.kv_pos += 1
223 | else:
224 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
225 |
226 | # forward the GPT model itself
227 | assert num_hierarchies == len(
228 | self.transformer.wtes
229 | ), f"Input tensor has {num_hierarchies} hierarchies, but model has {len(self.transformer.wtes)} set of input embeddings."
230 |
231 | # embed the tokens, positional encoding, and speaker embedding
232 | tok_emb = torch.zeros((b, t, self.config.n_embd), device=device)
233 | # ends up swapping (B, num_hierarchies, t) tokens -> (B, t, c) embeddings.
234 | for i, wte in enumerate(self.transformer.wtes):
235 | tok_emb += wte(idx[:, i, :])
236 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
237 |
238 | spk_emb = 0.0
239 | if speaker_embs is not None:
240 | if type(speaker_embs) == list:
241 | assert speaker_emb_mask is None
242 | assert self.training is False
243 | assert self.spk_emb_on_text is True
244 |
245 | spk_emb = []
246 | for speaker_emb_row in speaker_embs:
247 | if speaker_emb_row is not None:
248 | spk_emb.append(self.speaker_cond_pos(speaker_emb_row.unsqueeze(0)))
249 | assert spk_emb[-1].shape == (1, 1, self.config.n_embd), f"spk_emb[-1].shape={spk_emb[-1].shape}"
250 | else:
251 | spk_emb.append(torch.zeros((1, 1, self.config.n_embd), device=device, dtype=pos_emb.dtype))
252 | spk_emb = torch.cat(spk_emb, dim=0)
253 |
254 | assert (
255 | spk_emb.ndim == 3 and spk_emb.shape[1] == 1 and spk_emb.shape[0] == b
256 | ), f"spk_emb.ndim={spk_emb.ndim}, spk_emb.shape={spk_emb.shape}, len(speaker_embs)={len(speaker_embs)}"
257 | else:
258 | speakers_embedded = self.speaker_cond_pos(speaker_embs) # shape (b, num_examples, c)
259 |
260 | if speaker_emb_mask is not None:
261 | spk_emb = _select_spkemb(speakers_embedded, speaker_emb_mask)
262 | assert spk_emb.shape == (b, t, self.config.n_embd)
263 | else:
264 | spk_emb = speakers_embedded
265 | # if we don't have a mask, we assume that the speaker embedding is the same for all tokens
266 | # then num_examples dimension just becomes the time dimension
267 | assert spk_emb.ndim == 3 and spk_emb.shape[1] == 1
268 |
269 | if self.training and self.config.spkemb_dropout > 0.0:
270 | # Remove speaker conditioning at random.
271 | dropout = torch.ones_like(speakers_embedded) * (
272 | torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout
273 | )
274 | spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded)
275 |
276 | if self.spk_emb_on_text is False:
277 | assert speaker_emb_mask is None, "Not implemented for spk_emb_on_text=False"
278 | spk_emb = self._mask_spk_emb_on_text(idx, spk_emb)
279 |
280 | x = self.transformer.drop(tok_emb + pos_emb + spk_emb)
281 | for block in self.transformer.h:
282 | x = block(x)
283 | x = self.transformer.ln_f(x)
284 |
285 | if targets is not None:
286 | # if we are given some desired targets also calculate the loss
287 | list_logits = [lm_head(x) for lm_head in self.lm_heads]
288 |
289 | losses = [
290 | F.cross_entropy(
291 | logits.view(-1, logits.size(-1)),
292 | targets[:, i, :].contiguous().view(-1),
293 | ignore_index=-1,
294 | reduction=loss_reduce,
295 | )
296 | for i, logits in enumerate(list_logits)
297 | ]
298 | # TODO: should we do this better without stack somehow?
299 | losses = torch.stack(losses)
300 | if loss_reduce == "mean":
301 | losses = losses.mean()
302 | else:
303 | losses = rearrange(losses, "h (b t) -> b h t", h=len(self.lm_heads), b=b, t=t)
304 | else:
305 | # inference-time mini-optimization: only forward the lm_head on the very last position
306 | if self.config.causal:
307 | list_logits = [
308 | lm_head(x[:, [-1], :]) for lm_head in self.lm_heads
309 | ] # note: using list [-1] to preserve the time dim
310 | else:
311 | list_logits = [lm_head(x) for lm_head in self.lm_heads]
312 | losses = None
313 |
314 | return list_logits, losses
315 |
316 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
317 | # start with all of the candidate parameters
318 | param_dict = {pn: p for pn, p in self.named_parameters()}
319 | # filter out those that do not require grad
320 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
321 |         # Create optim groups. Any parameter that is 2D will be weight-decayed; all others will not.
322 |         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
323 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
324 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
325 | optim_groups = [
326 | {"params": decay_params, "weight_decay": weight_decay},
327 | {"params": nodecay_params, "weight_decay": 0.0},
328 | ]
329 | num_decay_params = sum(p.numel() for p in decay_params)
330 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
331 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
332 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
333 | # Create AdamW optimizer and use the fused version if it is available
334 | fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
335 | use_fused = fused_available and device_type == "cuda"
336 | extra_args = dict(fused=True) if use_fused else dict()
337 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
338 | print(f"using fused AdamW: {use_fused}")
339 |
340 | return optimizer
341 |
342 | @torch.no_grad()
343 | def generate(
344 | self,
345 | idx: torch.Tensor,
346 | max_new_tokens: int,
347 | seq_lens: Optional[list] = None,
348 | temperature: float = 1.0,
349 | top_k: Optional[int] = None,
350 | top_p: Optional[float] = None,
351 | speaker_embs: Optional[torch.Tensor] = None,
352 | batch_size: Optional[int] = None,
353 | guidance_scale: Optional[Tuple[float, float]] = None,
354 | dtype: torch.dtype = torch.bfloat16,
355 | end_of_audio_token: int = 99999, # Dummy values will disable early termination / guidance features.
356 | end_of_text_token: int = 99999,
357 | ):
358 | """
359 | Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete
360 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
361 |         You will most likely want to put the model in model.eval() mode before calling this.
362 | """
363 | assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"
364 |
365 | if self.config.causal:
366 | if seq_lens is None or batch_size is None:
367 | raise Exception("seq_lens and batch_size must be provided for causal sampling")
368 |
369 | return self._causal_sample(
370 | idx=idx,
371 | max_new_tokens=max_new_tokens,
372 | seq_lens=seq_lens,
373 | temperature=temperature,
374 | top_k=top_k,
375 | top_p=top_p,
376 | speaker_embs=speaker_embs,
377 | batch_size=batch_size,
378 | guidance_scale=guidance_scale,
379 | dtype=dtype,
380 | end_of_audio_token=end_of_audio_token,
381 | end_of_text_token=end_of_text_token,
382 | )
383 |
384 | else:
385 | if seq_lens is not None:
386 | raise Exception("seq_lens is not supported yet for non-causal sampling")
387 |
388 | if batch_size is None:
389 | raise Exception("batch_size must be provided for non-causal sampling")
390 |
391 | if guidance_scale is not None:
392 | raise Exception("guidance_scale is not supported for non-causal sampling")
393 |
394 | if top_p is not None:
395 | raise Exception("top_p is not supported for non-causal sampling")
396 |
397 | out = []
398 | for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="Non-causal batching"):
399 | end_index = min(start_index + batch_size, idx.shape[0])
400 | out.append(
401 | self._non_causal_sample(
402 | idx=idx[start_index:end_index],
403 | speaker_embs=speaker_embs[start_index:end_index] if speaker_embs is not None else None,
404 | temperature=temperature,
405 | top_k=top_k,
406 | )
407 | )
408 | return torch.cat(out, dim=0)
409 |
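# Editor's usage sketch (not part of the original file): a minimal, hedged example of
# driving `generate` in causal mode. `model` is assumed to be an already-constructed,
# causal instance of the class above, in eval() mode; the shapes and token ids below
# are illustrative placeholders rather than values taken from this repository.
#
#   b, num_hierarchies, t = 2, 1, 16
#   idx = torch.randint(0, 1024, (b, num_hierarchies, t), device="cuda")
#   tokens = model.generate(
#       idx=idx,
#       max_new_tokens=128,
#       seq_lens=[t] * b,   # per-example prompt lengths (required for causal sampling)
#       temperature=1.0,
#       top_p=0.95,
#       batch_size=b,       # also required for causal sampling
#   )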
--------------------------------------------------------------------------------
/fam/llm/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/llm/preprocessing/__init__.py
--------------------------------------------------------------------------------
/fam/llm/preprocessing/audio_token_mode.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Any, Callable, Literal, Optional
3 |
4 | import numpy as np
5 |
6 | AudioTokenModeT = Literal["flattened_interleaved"]
7 | CombinerWithOffsetFuncT = Callable[[np.ndarray, np.ndarray, int], np.ndarray]
8 | CombinerFuncT = Callable[[np.ndarray, np.ndarray], np.ndarray]
9 |
10 |
11 | def combine_tokens_flattened_interleaved(
12 | audio_tokens: np.ndarray, text_tokens: np.ndarray, second_hierarchy_flattening_offset: int
13 | ) -> np.ndarray:
14 | """
15 |     Flattens & interleaves the first two audio token hierarchies. Tokens for the second hierarchy are
16 |     also offset by second_hierarchy_flattening_offset as part of this transform, so that their values
17 |     cannot collide with those of the first hierarchy.
18 | """
19 | assert np.issubdtype(audio_tokens.dtype, np.integer)
20 | assert np.issubdtype(text_tokens.dtype, np.integer)
21 |
22 | num_hierarchies = audio_tokens.shape[0]
23 | assert num_hierarchies >= 2, f"Unexpected number of hierarchies: {num_hierarchies}. Expected at least 2."
24 |
25 |     # Fill with -5 as a sentinel so any position left unfilled is easy to spot.
26 | interleaved_audio_tokens = np.full((len(audio_tokens[0]) + len(audio_tokens[1]),), -5)
27 | interleaved_audio_tokens[::2] = audio_tokens[0]
28 | interleaved_audio_tokens[1::2] = audio_tokens[1] + second_hierarchy_flattening_offset
29 |
30 | tokens = np.concatenate([text_tokens, interleaved_audio_tokens])
31 |
32 | return np.expand_dims(tokens, axis=0)
33 |
34 |
35 | def get_params_for_mode(
36 | audio_token_mode: AudioTokenModeT, num_max_audio_tokens_timesteps: Optional[int] = None
37 | ) -> dict[str, Any]:
38 | if audio_token_mode == "flattened_interleaved":
39 | return {
40 | "text_tokenisation_offset": 1024 * 2 + 1,
41 | "pad_token": 1024 * 2,
42 | "ctx_window": num_max_audio_tokens_timesteps * 2 if num_max_audio_tokens_timesteps else None,
43 | "second_hierarchy_flattening_offset": 1024,
44 | # TODO: fix the repeat of `second_hierarchy_flattening_offset`
45 | "combine_func": partial(
46 | combine_tokens_flattened_interleaved,
47 | second_hierarchy_flattening_offset=1024,
48 | ),
49 | }
50 | else:
51 | raise Exception(f"Unknown mode {audio_token_mode}")
52 |
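# Editor's example (not part of the original file): a minimal check of the interleaving
# described above, using toy token values. Hierarchy 1 is shifted by the offset so its
# values cannot collide with hierarchy 0, and the text tokens are prepended unchanged.
if __name__ == "__main__":
    audio = np.array([[1, 2, 3], [4, 5, 6]])
    text = np.array([7, 8])
    combined = combine_tokens_flattened_interleaved(audio, text, second_hierarchy_flattening_offset=1024)
    assert combined.tolist() == [[7, 8, 1, 1028, 2, 1029, 3, 1030]]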
--------------------------------------------------------------------------------
/fam/llm/preprocessing/data_pipeline.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def pad_tokens(tokens: np.ndarray, context_window: int, pad_token: int) -> np.ndarray:
8 |     """Pads or truncates a single example to length context_window + 1.
9 |
10 | tokens: (..., example_length)
11 | """
12 | example_length = tokens.shape[-1]
13 | if example_length > context_window + 1:
14 | # Truncate
15 | tokens = tokens[..., : context_window + 1]
16 | elif example_length < context_window + 1:
17 | # Pad
18 | padding = np.full(tokens.shape[:-1] + (context_window + 1 - example_length,), pad_token)
19 | tokens = np.concatenate([tokens, padding], axis=-1)
20 | assert tokens.shape[-1] == context_window + 1
21 | return tokens
22 |
23 |
24 | def get_training_tuple(
25 | batch: Dict[str, Any],
26 | causal: bool,
27 | num_codebooks: Optional[int],
28 | speaker_cond: bool,
29 | device: torch.device,
30 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
31 | # batch contains combined tokens as specified by audio_token_mode
32 | if causal:
33 | num_codebooks = batch["tokens"].shape[1] if num_codebooks is None else num_codebooks
34 | x = batch["tokens"][:, :num_codebooks, :-1]
35 | y = batch["tokens"][:, :num_codebooks, 1:]
36 |
37 | se = batch["spkemb"]
38 |
39 | x = x.to(device, non_blocking=True)
40 | y = y.to(device, non_blocking=True)
41 | se = se.to(device, non_blocking=True) if speaker_cond else None
42 |
43 | return x, y, se
44 |
45 |
46 | def pad_with_values(tensor, batch_size, value):
47 |     """Pads the tensor along the batch dimension up to batch_size with the given value."""
48 | if tensor.shape[0] < batch_size:
49 | return torch.cat(
50 | [
51 | tensor,
52 | torch.full(
53 | (batch_size - tensor.shape[0], *tensor.shape[1:]), value, dtype=tensor.dtype, device=tensor.device
54 | ),
55 | ]
56 | )
57 | else:
58 | return tensor
59 |
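# Editor's example (not part of the original file): pad_tokens pads (or truncates) the last
# axis of an example to exactly context_window + 1 entries, so that the x/y shift performed
# in get_training_tuple yields sequences of length context_window.
if __name__ == "__main__":
    short = np.array([[1, 2, 3]])
    padded = pad_tokens(short, context_window=5, pad_token=0)
    assert padded.tolist() == [[1, 2, 3, 0, 0, 0]]  # length == context_window + 1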
--------------------------------------------------------------------------------
/fam/llm/utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os
4 | import re
5 | import subprocess
6 | import tempfile
7 |
8 | import librosa
9 | import torch
10 |
11 |
12 | def normalize_text(text: str) -> str:
13 | unicode_conversion = {
14 | 8175: "'",
15 | 8189: "'",
16 | 8190: "'",
17 | 8208: "-",
18 | 8209: "-",
19 | 8210: "-",
20 | 8211: "-",
21 | 8212: "-",
22 | 8213: "-",
23 | 8214: "||",
24 | 8216: "'",
25 | 8217: "'",
26 | 8218: ",",
27 | 8219: "`",
28 | 8220: '"',
29 | 8221: '"',
30 | 8222: ",,",
31 | 8223: '"',
32 | 8228: ".",
33 | 8229: "..",
34 | 8230: "...",
35 | 8242: "'",
36 | 8243: '"',
37 | 8245: "'",
38 | 8246: '"',
39 | 180: "'",
40 | 2122: "TM", # Trademark
41 | }
42 |
43 | text = text.translate(unicode_conversion)
44 |
45 | non_bpe_chars = set([c for c in list(text) if ord(c) >= 256])
46 | if len(non_bpe_chars) > 0:
47 | non_bpe_points = [(c, ord(c)) for c in non_bpe_chars]
48 | raise ValueError(f"Non-supported character found: {non_bpe_points}")
49 |
50 | text = text.replace("\t", " ").replace("\n", " ").replace("\r", " ").replace("*", " ").strip()
51 |     text = re.sub(r"\s\s+", " ", text)  # remove multiple spaces
52 | return text
53 |
54 |
55 | def check_audio_file(path_or_uri, threshold_s=30):
56 | if "http" in path_or_uri:
57 | temp_fd, filepath = tempfile.mkstemp()
58 |         os.close(temp_fd)  # Close the file descriptor; curl will write to the path itself.
59 | curl_command = ["curl", "-L", path_or_uri, "-o", filepath]
60 | subprocess.run(curl_command, check=True)
61 |
62 | else:
63 | filepath = path_or_uri
64 |
65 | audio, sr = librosa.load(filepath)
66 | duration_s = librosa.get_duration(y=audio, sr=sr)
67 | if duration_s < threshold_s:
68 | raise Exception(
69 | f"The audio file is too short. Please provide an audio file that is at least {threshold_s} seconds long to proceed."
70 | )
71 |
72 | # Clean up the temporary file if it was created
73 | if "http" in path_or_uri:
74 | os.remove(filepath)
75 |
76 |
77 | def get_default_dtype() -> str:
78 | """Compute default 'dtype' based on GPU architecture"""
79 | if torch.cuda.is_available():
80 | for i in range(torch.cuda.device_count()):
81 | device_properties = torch.cuda.get_device_properties(i)
82 | dtype = "float16" if device_properties.major <= 7 else "bfloat16" # tesla and turing architectures
83 | else:
84 | dtype = "float16"
85 |
86 | print(f"using dtype={dtype}")
87 | return dtype
88 |
89 |
90 | def get_device() -> str:
91 | return "cuda" if torch.cuda.is_available() else "cpu"
92 |
93 |
94 | def hash_dictionary(d: dict):
95 | # Serialize the dictionary into JSON with sorted keys to ensure consistency
96 | serialized = json.dumps(d, sort_keys=True)
97 | # Encode the serialized string to bytes
98 | encoded = serialized.encode()
99 | # Create a hash object (you can also use sha1, sha512, etc.)
100 | hash_object = hashlib.sha256(encoded)
101 | # Get the hexadecimal digest of the hash
102 | hash_digest = hash_object.hexdigest()
103 | return hash_digest
104 |
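# Editor's example (not part of the original file): normalize_text maps curly quotes and
# unicode dashes to ASCII and collapses whitespace; hash_dictionary is deterministic for
# equal dictionaries because keys are sorted before serialisation.
if __name__ == "__main__":
    assert normalize_text("It\u2019s  a \u2013 test\n") == "It's a - test"
    assert hash_dictionary({"a": 1, "b": 2}) == hash_dictionary({"b": 2, "a": 1})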
--------------------------------------------------------------------------------
/fam/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/py.typed
--------------------------------------------------------------------------------
/fam/quantiser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/quantiser/__init__.py
--------------------------------------------------------------------------------
/fam/quantiser/audio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/quantiser/audio/__init__.py
--------------------------------------------------------------------------------
/fam/quantiser/audio/speaker_encoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/metavoiceio/metavoice-src/de3fa211ac4621e03a5f990651aeecc64da418f5/fam/quantiser/audio/speaker_encoder/__init__.py
--------------------------------------------------------------------------------
/fam/quantiser/audio/speaker_encoder/audio.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 |
4 | mel_window_length = 25
5 | mel_window_step = 10
6 | mel_n_channels = 40
7 | sampling_rate = 16000
8 |
9 |
10 | def wav_to_mel_spectrogram(wav):
11 | """
12 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
13 |     Note: this is not a log-mel spectrogram.
14 | """
15 | frames = librosa.feature.melspectrogram(
16 | y=wav,
17 | sr=sampling_rate,
18 | n_fft=int(sampling_rate * mel_window_length / 1000),
19 | hop_length=int(sampling_rate * mel_window_step / 1000),
20 | n_mels=mel_n_channels,
21 | )
22 | return frames.astype(np.float32).T
23 |
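# Editor's note (not part of the original file): with the constants above, n_fft is
# 16000 * 25 / 1000 = 400 samples and hop_length is 16000 * 10 / 1000 = 160 samples,
# so one second of 16 kHz audio yields roughly 100 frames of 40 mel channels.
if __name__ == "__main__":
    mel = wav_to_mel_spectrogram(np.zeros(sampling_rate, dtype=np.float32))
    assert mel.shape[1] == mel_n_channels  # shape is (~101, 40) for 1 s of audio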
--------------------------------------------------------------------------------
/fam/quantiser/audio/speaker_encoder/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from time import perf_counter as timer
3 | from typing import List, Optional, Union
4 |
5 | import librosa
6 | import numpy as np
7 | import torch
8 | from torch import nn
9 |
10 | from fam.quantiser.audio.speaker_encoder import audio
11 |
12 | mel_window_step = 10
13 | mel_n_channels = 40
14 | sampling_rate = 16000
15 | partials_n_frames = 160
16 | model_hidden_size = 256
17 | model_embedding_size = 256
18 | model_num_layers = 3
19 |
20 |
21 | class SpeakerEncoder(nn.Module):
22 | def __init__(
23 | self,
24 | weights_fpath: Optional[str] = None,
25 | device: Optional[Union[str, torch.device]] = None,
26 | verbose: bool = True,
27 | eval: bool = False,
28 | ):
29 | super().__init__()
30 |
31 | # Define the network
32 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
33 | self.linear = nn.Linear(model_hidden_size, model_embedding_size)
34 | self.relu = nn.ReLU()
35 |
36 | # Get the target device
37 | if device is None:
38 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39 | elif isinstance(device, str):
40 | device = torch.device(device)
41 | self.device = device
42 |
43 | start = timer()
44 |
45 | checkpoint = torch.load(weights_fpath, map_location="cpu")
46 | self.load_state_dict(checkpoint["model_state"], strict=False)
47 | self.to(device)
48 |
49 | if eval:
50 | self.eval()
51 |
52 | if verbose:
53 | print("Loaded the speaker embedding model on %s in %.2f seconds." % (device.type, timer() - start))
54 |
55 | def forward(self, mels: torch.FloatTensor):
56 | _, (hidden, _) = self.lstm(mels)
57 | embeds_raw = self.relu(self.linear(hidden[-1]))
58 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
59 |
60 | @staticmethod
61 | def compute_partial_slices(n_samples: int, rate, min_coverage):
62 | # Compute how many frames separate two partial utterances
63 | samples_per_frame = int((sampling_rate * mel_window_step / 1000))
64 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
65 | frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
66 |
67 | # Compute the slices
68 | wav_slices, mel_slices = [], []
69 | steps = max(1, n_frames - partials_n_frames + frame_step + 1)
70 | for i in range(0, steps, frame_step):
71 | mel_range = np.array([i, i + partials_n_frames])
72 | wav_range = mel_range * samples_per_frame
73 | mel_slices.append(slice(*mel_range))
74 | wav_slices.append(slice(*wav_range))
75 |
76 | # Evaluate whether extra padding is warranted or not
77 | last_wav_range = wav_slices[-1]
78 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
79 | if coverage < min_coverage and len(mel_slices) > 1:
80 | mel_slices = mel_slices[:-1]
81 | wav_slices = wav_slices[:-1]
82 |
83 | return wav_slices, mel_slices
84 |
85 | def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75, numpy: bool = True):
86 | wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
87 | max_wave_length = wav_slices[-1].stop
88 | if max_wave_length >= len(wav):
89 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
90 |
91 | mel = audio.wav_to_mel_spectrogram(wav)
92 | mels = np.array([mel[s] for s in mel_slices])
93 | mels = torch.from_numpy(mels).to(self.device) # type: ignore
94 | with torch.no_grad():
95 | partial_embeds = self(mels)
96 |
97 | if numpy:
98 | raw_embed = np.mean(partial_embeds.cpu().numpy(), axis=0)
99 | embed = raw_embed / np.linalg.norm(raw_embed, 2)
100 | else:
101 | raw_embed = partial_embeds.mean(dim=0)
102 | embed = raw_embed / torch.linalg.norm(raw_embed, 2)
103 |
104 | if return_partials:
105 | return embed, partial_embeds, wav_slices
106 | return embed
107 |
108 | def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
109 | raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
110 | return raw_embed / np.linalg.norm(raw_embed, 2)
111 |
112 | def embed_utterance_from_file(self, fpath: str, numpy: bool) -> torch.Tensor:
113 | wav_tgt, _ = librosa.load(fpath, sr=sampling_rate)
114 | wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
115 |
116 | embedding = self.embed_utterance(wav_tgt, numpy=numpy)
117 | return embedding
118 |
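# Editor's usage sketch (not part of the original file). The paths below are placeholders;
# in this repository the encoder weights are normally obtained as `speaker_encoder.pt`
# from the Hugging Face model snapshot (see tests/llm/loaders/test_dataloader.py).
#
#   encoder = SpeakerEncoder(weights_fpath="/path/to/speaker_encoder.pt", device="cpu", eval=True)
#   emb = encoder.embed_utterance_from_file("/path/to/reference.wav", numpy=True)
#   print(emb.shape)  # (model_embedding_size,) == (256,)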
--------------------------------------------------------------------------------
/fam/quantiser/text/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/fam/quantiser/text/tokenise.py:
--------------------------------------------------------------------------------
1 | import tiktoken
2 |
3 |
4 | class TrainedBPETokeniser:
5 | def __init__(self, name, pat_str, mergeable_ranks, special_tokens, offset=None) -> None:
6 | self.tokenizer = tiktoken.Encoding(
7 | name=name,
8 | pat_str=pat_str,
9 | mergeable_ranks=mergeable_ranks,
10 | special_tokens=special_tokens,
11 | )
12 | self.offset = offset
13 |
14 | def encode(self, text: str) -> list[int]:
15 |         # note: we append an end-of-text token!
16 | tokens = self.tokenizer.encode(text) + [self.tokenizer.eot_token]
17 | if self.offset is not None:
18 | tokens = [x + self.offset for x in tokens]
19 |
20 | return tokens
21 |
22 | def decode(self, tokens: list[int]):
23 | if self.offset is not None:
24 | tokens = [x - self.offset for x in tokens]
25 | return self.tokenizer.decode(tokens)
26 |
27 | @property
28 | def eot_token(self):
29 | if self.offset is not None:
30 | return self.tokenizer.eot_token + self.offset
31 | else:
32 | return self.tokenizer.eot_token
33 |
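# Editor's example (not part of the original file): a hedged sketch of the offset behaviour,
# using a stock tiktoken encoding as a stand-in for the trained ranks (which, in this
# repository, come from the first-stage checkpoint metadata). The offset of 2049 mirrors the
# text_tokenisation_offset in fam/llm/preprocessing/audio_token_mode.py; relying on the
# private `_pat_str` / `_mergeable_ranks` / `_special_tokens` attributes is an assumption
# about tiktoken's Encoding internals.
if __name__ == "__main__":
    base = tiktoken.get_encoding("gpt2")
    tok = TrainedBPETokeniser(
        name="example",
        pat_str=base._pat_str,
        mergeable_ranks=base._mergeable_ranks,
        special_tokens=base._special_tokens,
        offset=2049,
    )
    ids = tok.encode("Please call Stella.")
    assert ids[-1] == tok.eot_token  # encode() appends the (offset) end-of-text token
    assert tok.decode(ids[:-1]) == "Please call Stella."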
--------------------------------------------------------------------------------
/fam/telemetry/README.md:
--------------------------------------------------------------------------------
1 | # Telemetry
2 |
3 | This directory holds the telemetry code for MetaVoice. MetaVoice captures anonymized telemetry to understand usage patterns.
4 |
5 | If you prefer to opt out of telemetry, set `ANONYMIZED_TELEMETRY=False` in a `.env` file at the root of this repo.
6 |
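For example, a `.env` file at the repo root containing the single line below opts out:

```
ANONYMIZED_TELEMETRY=False
```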
--------------------------------------------------------------------------------
/fam/telemetry/__init__.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import os
3 | import uuid
4 | from abc import abstractmethod
5 | from dataclasses import dataclass
6 | from pathlib import Path
7 |
8 |
9 | @dataclass(frozen=True)
10 | class TelemetryEvent:
11 | name: str
12 | properties: dict
13 |
14 |
15 | class TelemetryClient(abc.ABC):
16 | USER_ID_PATH = str(Path.home() / ".cache" / "metavoice" / "telemetry_user_id")
17 | UNKNOWN_USER_ID = "UNKNOWN"
18 | _curr_user_id = None
19 |
20 | @abstractmethod
21 | def capture(self, event: TelemetryEvent) -> None:
22 | pass
23 |
24 | @property
25 | def user_id(self) -> str:
26 | if self._curr_user_id:
27 | return self._curr_user_id
28 |
29 | # File access may fail due to permissions or other reasons. We don't want to
30 | # crash so we catch all exceptions.
31 | try:
32 | if not os.path.exists(self.USER_ID_PATH):
33 | os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True)
34 | with open(self.USER_ID_PATH, "w") as f:
35 | new_user_id = str(uuid.uuid4())
36 | f.write(new_user_id)
37 | self._curr_user_id = new_user_id
38 | else:
39 | with open(self.USER_ID_PATH, "r") as f:
40 | self._curr_user_id = f.read()
41 | except Exception:
42 | self._curr_user_id = self.UNKNOWN_USER_ID
43 | return self._curr_user_id
44 |
--------------------------------------------------------------------------------
/fam/telemetry/posthog.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | from dotenv import load_dotenv
6 | from posthog import Posthog
7 |
8 | from fam.telemetry import TelemetryClient, TelemetryEvent
9 |
10 | load_dotenv()
11 | logger = logging.getLogger(__name__)
12 | logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout), logging.StreamHandler(sys.stderr)])
13 |
14 |
15 | class PosthogClient(TelemetryClient):
16 | def __init__(self):
17 | self._posthog = Posthog(
18 | project_api_key="phc_tk7IUlV7Q7lEa9LNbXxyC1sMWlCqiW6DkHyhJrbWMCS", host="https://eu.posthog.com"
19 | )
20 |
21 |         if os.getenv("ANONYMIZED_TELEMETRY", "True").lower() in ("false", "0") or "pytest" in sys.modules:
22 | self._posthog.disabled = True
23 | logger.info("Anonymized telemetry disabled. See fam/telemetry/README.md for more information.")
24 | else:
25 | logger.info("Anonymized telemetry enabled. See fam/telemetry/README.md for more information.")
26 |
27 | posthog_logger = logging.getLogger("posthog")
28 | posthog_logger.disabled = True # Silence posthog's logging
29 |
30 | super().__init__()
31 |
32 | def capture(self, event: TelemetryEvent) -> None:
33 | try:
34 | self._posthog.capture(
35 | self.user_id,
36 | event.name,
37 | {**event.properties},
38 | )
39 | except Exception as e:
40 | logger.error(f"Failed to send telemetry event {event.name}: {e}")
41 |
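# Editor's example (not part of the original file): a minimal sketch of emitting an event
# with the client above; the event name and properties are illustrative only.
if __name__ == "__main__":
    client = PosthogClient()
    client.capture(TelemetryEvent(name="example_event", properties={"origin": "docs_example"}))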
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "fam"
3 | version = "0.1.0"
4 | description = "Foundational model for text to speech"
5 | authors = []
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = "^3.10"
10 | torch = "^2.1.0"
11 | torchaudio = "^2.1.0"
12 | librosa = "^0.10.1"
13 | tqdm = "^4.66.2"
14 | tiktoken = "==0.5.1"
15 | audiocraft = "^1.2.0"
16 | numpy = "^1.26.4"
17 | ninja = "^1.11.1"
18 | fastapi = "^0.110.0"
19 | uvicorn = "^0.27.1"
20 | tyro = "^0.7.3"
21 | deepfilternet = "^0.5.6"
22 | pydub = "^0.25.1"
23 | gradio = "^4.20.1"
24 | huggingface_hub = "^0.21.4"
25 | click = "^8.1.7"
26 | wandb = { version = "^0.16.4", optional = true }
27 | posthog = "^3.5.0"
28 | python-dotenv = "^1.0.1"
29 |
30 | [tool.poetry.dev-dependencies]
31 | pre-commit = "^3.7.0"
32 | pytest = "^8.0.2"
33 | ipdb = "^0.13.13"
34 |
35 | [tool.poetry.extras]
36 | observable = ["wandb"]
37 |
38 | [tool.poetry.scripts]
39 | finetune = "fam.llm.finetune:main"
40 |
41 | [build-system]
42 | requires = ["poetry-core"]
43 | build-backend = "poetry.core.masonry.api"
44 |
45 | [tool.black]
46 | line-length = 120
47 | exclude = '''
48 | /(
49 | \.git
50 | | \.mypy_cache
51 | | \.tox
52 | | _build
53 | | build
54 | | dist
55 | )/
56 | '''
57 |
58 | [tool.isort]
59 | profile = "black"
60 |
61 |
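# Editor's note (not part of the original file): the optional `wandb` dependency above is
# pulled in via the "observable" extra, e.g. `poetry install --extras observable`.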
--------------------------------------------------------------------------------
/serving.py:
--------------------------------------------------------------------------------
1 | # curl -X POST http://127.0.0.1:58003/tts -F "text=Testing this inference server." -F "speaker_ref_path=https://cdn.themetavoice.xyz/speakers/bria.mp3" -F "guidance=3.0" -F "top_p=0.95" --output out.wav
2 |
3 | import logging
4 | import shlex
5 | import subprocess
6 | import tempfile
7 | import warnings
8 | from pathlib import Path
9 | from typing import Literal, Optional
10 |
11 | import fastapi
12 | import fastapi.middleware.cors
13 | import tyro
14 | import uvicorn
15 | from attr import dataclass
16 | from fastapi import File, Form, HTTPException, UploadFile, status
17 | from fastapi.responses import Response
18 |
19 | from fam.llm.fast_inference import TTS
20 | from fam.llm.utils import check_audio_file
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | ## Set up the FastAPI server.
26 | app = fastapi.FastAPI()
27 |
28 |
29 | @dataclass
30 | class ServingConfig:
31 | huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1"
32 |     """Hugging Face repo ID of the model to serve."""
33 |
34 | temperature: float = 1.0
35 | """Temperature for sampling applied to both models."""
36 |
37 | seed: int = 1337
38 | """Random seed for sampling."""
39 |
40 | port: int = 58003
41 |
42 | quantisation_mode: Optional[Literal["int4", "int8"]] = None
43 |
44 |
45 | # Singleton
46 | class _GlobalState:
47 | config: ServingConfig
48 | tts: TTS
49 |
50 |
51 | GlobalState = _GlobalState()
52 |
53 |
54 | @app.get("/health")
55 | async def health_check():
56 | return {"status": "ok"}
57 |
58 |
59 | @app.post("/tts", response_class=Response)
60 | async def text_to_speech(
61 | text: str = Form(..., description="Text to convert to speech."),
62 | speaker_ref_path: Optional[str] = Form(None, description="Optional URL to an audio file of a reference speaker. Provide either this URL or audio data through 'audiodata'."),
63 | audiodata: Optional[UploadFile] = File(None, description="Optional audio data of a reference speaker. Provide either this file or a URL through 'speaker_ref_path'."),
64 | guidance: float = Form(3.0, description="Control speaker similarity - how closely to match speaker identity and speech style, range: 0.0 to 5.0.", ge=0.0, le=5.0),
65 | top_p: float = Form(0.95, description="Controls speech stability - improves text following for a challenging speaker, range: 0.0 to 1.0.", ge=0.0, le=1.0),
66 | ):
67 | # Ensure at least one of speaker_ref_path or audiodata is provided
68 | if not audiodata and not speaker_ref_path:
69 | raise HTTPException(
70 | status_code=status.HTTP_400_BAD_REQUEST,
71 | detail="Either an audio file or a speaker reference path must be provided.",
72 | )
73 |
74 | wav_out_path = None
75 |
76 | try:
77 | with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
78 | if speaker_ref_path is None:
79 | wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
80 | check_audio_file(wav_path)
81 | else:
82 | # TODO: fix
83 | wav_path = speaker_ref_path
84 |
85 | if wav_path is None:
86 | warnings.warn("Running without speaker reference")
87 | assert guidance is None
88 |
89 | wav_out_path = GlobalState.tts.synthesise(
90 | text=text,
91 | spk_ref_path=wav_path,
92 | top_p=top_p,
93 | guidance_scale=guidance,
94 | )
95 |
96 | with open(wav_out_path, "rb") as f:
97 | return Response(content=f.read(), media_type="audio/wav")
98 | except Exception as e:
99 | # traceback_str = "".join(traceback.format_tb(e.__traceback__))
100 | logger.exception(
101 | f"Error processing request. text: {text}, speaker_ref_path: {speaker_ref_path}, guidance: {guidance}, top_p: {top_p}"
102 | )
103 | return Response(
104 | content="Something went wrong. Please try again in a few mins or contact us on Discord",
105 | status_code=500,
106 | )
107 | finally:
108 | if wav_out_path is not None:
109 | Path(wav_out_path).unlink(missing_ok=True)
110 |
111 |
112 | def _convert_audiodata_to_wav_path(audiodata: UploadFile, wav_tmp):
113 | with tempfile.NamedTemporaryFile() as unknown_format_tmp:
114 |         if unknown_format_tmp.write(audiodata.file.read()) == 0:  # read synchronously via the underlying file object
115 | return None
116 | unknown_format_tmp.flush()
117 |
118 | subprocess.check_output(
119 | # arbitrary 2 minute cutoff
120 | shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}")
121 | )
122 |
123 | return wav_tmp.name
124 |
125 |
126 | if __name__ == "__main__":
127 | for name in logging.root.manager.loggerDict:
128 | logger = logging.getLogger(name)
129 | logger.setLevel(logging.INFO)
130 | logging.root.setLevel(logging.INFO)
131 |
132 | GlobalState.config = tyro.cli(ServingConfig)
133 | GlobalState.tts = TTS(
134 | seed=GlobalState.config.seed,
135 | quantisation_mode=GlobalState.config.quantisation_mode,
136 | telemetry_origin="api_server",
137 | )
138 |
139 | app.add_middleware(
140 | fastapi.middleware.cors.CORSMiddleware,
141 | allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"],
142 | allow_credentials=True,
143 | allow_methods=["*"],
144 | allow_headers=["*"],
145 | )
146 | uvicorn.run(
147 | app,
148 | host="0.0.0.0",
149 | port=GlobalState.config.port,
150 | log_level="info",
151 | )
152 |
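# Editor's example (not part of the original file): the /tts endpoint also accepts a local
# reference file uploaded as `audiodata` instead of a `speaker_ref_path` URL; the path below
# is a placeholder, and check_audio_file requires the reference to be at least 30 s long.
#
#   curl -X POST http://127.0.0.1:58003/tts \
#        -F "text=Testing this inference server." \
#        -F "audiodata=@/path/to/reference.wav" \
#        -F "guidance=3.0" -F "top_p=0.95" --output out.wav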
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup # type: ignore
2 |
3 | setup(
4 | name="fam",
5 | packages=find_packages(".", exclude=["tests"]),
6 | )
7 |
--------------------------------------------------------------------------------
/tests/llm/loaders/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from pathlib import Path
3 |
4 | import pytest
5 | import torch
6 | from huggingface_hub import snapshot_download
7 | from torch.utils.data import DataLoader
8 |
9 | from fam.llm.config.finetune_params import audio_token_mode as atm
10 | from fam.llm.config.finetune_params import num_max_audio_tokens_timesteps
11 | from fam.llm.loaders.training_data import DynamicComputeDataset
12 | from fam.llm.preprocessing.audio_token_mode import get_params_for_mode
13 |
14 |
15 | @pytest.mark.parametrize("dataset", ["tests/resources/datasets/sample_dataset.csv"])
16 | @pytest.mark.skip(reason="Requires a checkpoint download; not feasible in the test suite")
17 | def test_dataset_preprocess_e2e(dataset):
18 | model_name = "metavoiceio/metavoice-1B-v0.1"
19 | device = "cuda"
20 | mode_params = get_params_for_mode(atm, num_max_audio_tokens_timesteps=num_max_audio_tokens_timesteps)
21 |
22 | _model_dir = snapshot_download(repo_id=model_name)
23 | checkpoint_path = Path(f"{_model_dir}/first_stage.pt")
24 | spk_emb_ckpt_path = Path(f"{_model_dir}/speaker_encoder.pt")
25 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
26 | tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
27 |
28 | dataset = DynamicComputeDataset.from_meta(
29 | tokenizer_info,
30 | mode_params["combine_func"],
31 | spk_emb_ckpt_path,
32 | dataset,
33 | mode_params["pad_token"],
34 | mode_params["ctx_window"],
35 | device
36 | )
37 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
38 | result = next(iter(dataloader))
39 |
40 | # TODO: better assertions based on sample input dims
41 | assert len(result) == 2
42 |
--------------------------------------------------------------------------------
/tests/resources/data/caption.txt:
--------------------------------------------------------------------------------
1 | Please call Stella.
--------------------------------------------------------------------------------
/tests/resources/datasets/sample_dataset.csv:
--------------------------------------------------------------------------------
1 | audio_files,captions
2 | ./data/audio.wav,./data/caption.txt
3 | ./data/audio.wav,./data/caption.txt
4 | ./data/audio.wav,./data/caption.txt
5 | ./data/audio.wav,./data/caption.txt
6 | ./data/audio.wav,./data/caption.txt
7 | ./data/audio.wav,./data/caption.txt
8 | ./data/audio.wav,./data/caption.txt
9 | ./data/audio.wav,./data/caption.txt
10 | ./data/audio.wav,./data/caption.txt
11 | ./data/audio.wav,./data/caption.txt
12 | ./data/audio.wav,./data/caption.txt
13 | ./data/audio.wav,./data/caption.txt
14 | ./data/audio.wav,./data/caption.txt
15 | ./data/audio.wav,./data/caption.txt
16 | ./data/audio.wav,./data/caption.txt
17 | ./data/audio.wav,./data/caption.txt
18 | ./data/audio.wav,./data/caption.txt
19 | ./data/audio.wav,./data/caption.txt
20 | ./data/audio.wav,./data/caption.txt
21 | ./data/audio.wav,./data/caption.txt
22 | ./data/audio.wav,./data/caption.txt
23 | ./data/audio.wav,./data/caption.txt
24 | ./data/audio.wav,./data/caption.txt
25 | ./data/audio.wav,./data/caption.txt
26 | ./data/audio.wav,./data/caption.txt
27 | ./data/audio.wav,./data/caption.txt
28 | ./data/audio.wav,./data/caption.txt
29 | ./data/audio.wav,./data/caption.txt
30 | ./data/audio.wav,./data/caption.txt
31 | ./data/audio.wav,./data/caption.txt
32 | ./data/audio.wav,./data/caption.txt
33 | ./data/audio.wav,./data/caption.txt
34 | ./data/audio.wav,./data/caption.txt
35 | ./data/audio.wav,./data/caption.txt
36 | ./data/audio.wav,./data/caption.txt
37 | ./data/audio.wav,./data/caption.txt
38 | ./data/audio.wav,./data/caption.txt
39 | ./data/audio.wav,./data/caption.txt
40 | ./data/audio.wav,./data/caption.txt
41 | ./data/audio.wav,./data/caption.txt
42 | ./data/audio.wav,./data/caption.txt
43 | ./data/audio.wav,./data/caption.txt
44 | ./data/audio.wav,./data/caption.txt
45 | ./data/audio.wav,./data/caption.txt
46 | ./data/audio.wav,./data/caption.txt
47 | ./data/audio.wav,./data/caption.txt
48 | ./data/audio.wav,./data/caption.txt
49 | ./data/audio.wav,./data/caption.txt
50 | ./data/audio.wav,./data/caption.txt
51 | ./data/audio.wav,./data/caption.txt
52 | ./data/audio.wav,./data/caption.txt
53 | ./data/audio.wav,./data/caption.txt
54 | ./data/audio.wav,./data/caption.txt
55 | ./data/audio.wav,./data/caption.txt
56 | ./data/audio.wav,./data/caption.txt
57 | ./data/audio.wav,./data/caption.txt
58 | ./data/audio.wav,./data/caption.txt
59 | ./data/audio.wav,./data/caption.txt
60 | ./data/audio.wav,./data/caption.txt
61 | ./data/audio.wav,./data/caption.txt
62 | ./data/audio.wav,./data/caption.txt
63 | ./data/audio.wav,./data/caption.txt
64 | ./data/audio.wav,./data/caption.txt
65 | ./data/audio.wav,./data/caption.txt
66 | ./data/audio.wav,./data/caption.txt
67 | ./data/audio.wav,./data/caption.txt
68 | ./data/audio.wav,./data/caption.txt
69 | ./data/audio.wav,./data/caption.txt
70 | ./data/audio.wav,./data/caption.txt
71 | ./data/audio.wav,./data/caption.txt
72 | ./data/audio.wav,./data/caption.txt
73 | ./data/audio.wav,./data/caption.txt
74 | ./data/audio.wav,./data/caption.txt
75 | ./data/audio.wav,./data/caption.txt
76 | ./data/audio.wav,./data/caption.txt
77 | ./data/audio.wav,./data/caption.txt
78 | ./data/audio.wav,./data/caption.txt
79 | ./data/audio.wav,./data/caption.txt
80 | ./data/audio.wav,./data/caption.txt
81 | ./data/audio.wav,./data/caption.txt
82 | ./data/audio.wav,./data/caption.txt
83 | ./data/audio.wav,./data/caption.txt
84 | ./data/audio.wav,./data/caption.txt
85 | ./data/audio.wav,./data/caption.txt
86 | ./data/audio.wav,./data/caption.txt
87 | ./data/audio.wav,./data/caption.txt
88 | ./data/audio.wav,./data/caption.txt
89 | ./data/audio.wav,./data/caption.txt
90 | ./data/audio.wav,./data/caption.txt
91 | ./data/audio.wav,./data/caption.txt
92 | ./data/audio.wav,./data/caption.txt
93 | ./data/audio.wav,./data/caption.txt
94 | ./data/audio.wav,./data/caption.txt
95 | ./data/audio.wav,./data/caption.txt
96 | ./data/audio.wav,./data/caption.txt
97 | ./data/audio.wav,./data/caption.txt
98 | ./data/audio.wav,./data/caption.txt
99 | ./data/audio.wav,./data/caption.txt
100 | ./data/audio.wav,./data/caption.txt
101 | ./data/audio.wav,./data/caption.txt
102 | ./data/audio.wav,./data/caption.txt
103 | ./data/audio.wav,./data/caption.txt
104 | ./data/audio.wav,./data/caption.txt
105 | ./data/audio.wav,./data/caption.txt
106 | ./data/audio.wav,./data/caption.txt
107 | ./data/audio.wav,./data/caption.txt
108 | ./data/audio.wav,./data/caption.txt
109 | ./data/audio.wav,./data/caption.txt
110 | ./data/audio.wav,./data/caption.txt
111 | ./data/audio.wav,./data/caption.txt
112 | ./data/audio.wav,./data/caption.txt
113 | ./data/audio.wav,./data/caption.txt
114 | ./data/audio.wav,./data/caption.txt
115 | ./data/audio.wav,./data/caption.txt
116 | ./data/audio.wav,./data/caption.txt
117 | ./data/audio.wav,./data/caption.txt
118 | ./data/audio.wav,./data/caption.txt
119 | ./data/audio.wav,./data/caption.txt
120 | ./data/audio.wav,./data/caption.txt
121 | ./data/audio.wav,./data/caption.txt
122 | ./data/audio.wav,./data/caption.txt
123 | ./data/audio.wav,./data/caption.txt
124 | ./data/audio.wav,./data/caption.txt
125 | ./data/audio.wav,./data/caption.txt
126 | ./data/audio.wav,./data/caption.txt
127 | ./data/audio.wav,./data/caption.txt
128 | ./data/audio.wav,./data/caption.txt
129 | ./data/audio.wav,./data/caption.txt
130 | ./data/audio.wav,./data/caption.txt
131 | ./data/audio.wav,./data/caption.txt
132 | ./data/audio.wav,./data/caption.txt
133 | ./data/audio.wav,./data/caption.txt
134 | ./data/audio.wav,./data/caption.txt
135 | ./data/audio.wav,./data/caption.txt
136 | ./data/audio.wav,./data/caption.txt
137 | ./data/audio.wav,./data/caption.txt
138 | ./data/audio.wav,./data/caption.txt
139 | ./data/audio.wav,./data/caption.txt
140 | ./data/audio.wav,./data/caption.txt
141 | ./data/audio.wav,./data/caption.txt
142 | ./data/audio.wav,./data/caption.txt
143 | ./data/audio.wav,./data/caption.txt
144 | ./data/audio.wav,./data/caption.txt
145 | ./data/audio.wav,./data/caption.txt
146 | ./data/audio.wav,./data/caption.txt
147 | ./data/audio.wav,./data/caption.txt
148 | ./data/audio.wav,./data/caption.txt
149 | ./data/audio.wav,./data/caption.txt
150 | ./data/audio.wav,./data/caption.txt
151 | ./data/audio.wav,./data/caption.txt
152 | ./data/audio.wav,./data/caption.txt
153 | ./data/audio.wav,./data/caption.txt
154 | ./data/audio.wav,./data/caption.txt
155 | ./data/audio.wav,./data/caption.txt
156 | ./data/audio.wav,./data/caption.txt
157 | ./data/audio.wav,./data/caption.txt
158 | ./data/audio.wav,./data/caption.txt
159 | ./data/audio.wav,./data/caption.txt
160 | ./data/audio.wav,./data/caption.txt
161 | ./data/audio.wav,./data/caption.txt
162 | ./data/audio.wav,./data/caption.txt
163 | ./data/audio.wav,./data/caption.txt
164 | ./data/audio.wav,./data/caption.txt
165 | ./data/audio.wav,./data/caption.txt
166 | ./data/audio.wav,./data/caption.txt
167 | ./data/audio.wav,./data/caption.txt
168 | ./data/audio.wav,./data/caption.txt
169 | ./data/audio.wav,./data/caption.txt
170 | ./data/audio.wav,./data/caption.txt
171 | ./data/audio.wav,./data/caption.txt
172 | ./data/audio.wav,./data/caption.txt
173 | ./data/audio.wav,./data/caption.txt
174 | ./data/audio.wav,./data/caption.txt
175 | ./data/audio.wav,./data/caption.txt
176 | ./data/audio.wav,./data/caption.txt
177 | ./data/audio.wav,./data/caption.txt
178 | ./data/audio.wav,./data/caption.txt
179 | ./data/audio.wav,./data/caption.txt
180 | ./data/audio.wav,./data/caption.txt
181 | ./data/audio.wav,./data/caption.txt
182 | ./data/audio.wav,./data/caption.txt
183 | ./data/audio.wav,./data/caption.txt
184 | ./data/audio.wav,./data/caption.txt
185 | ./data/audio.wav,./data/caption.txt
186 | ./data/audio.wav,./data/caption.txt
187 | ./data/audio.wav,./data/caption.txt
188 | ./data/audio.wav,./data/caption.txt
189 | ./data/audio.wav,./data/caption.txt
190 | ./data/audio.wav,./data/caption.txt
191 | ./data/audio.wav,./data/caption.txt
192 | ./data/audio.wav,./data/caption.txt
193 | ./data/audio.wav,./data/caption.txt
194 | ./data/audio.wav,./data/caption.txt
195 | ./data/audio.wav,./data/caption.txt
196 | ./data/audio.wav,./data/caption.txt
197 | ./data/audio.wav,./data/caption.txt
198 | ./data/audio.wav,./data/caption.txt
199 | ./data/audio.wav,./data/caption.txt
200 | ./data/audio.wav,./data/caption.txt
201 | ./data/audio.wav,./data/caption.txt
202 | ./data/audio.wav,./data/caption.txt
203 | ./data/audio.wav,./data/caption.txt
204 | ./data/audio.wav,./data/caption.txt
205 | ./data/audio.wav,./data/caption.txt
206 | ./data/audio.wav,./data/caption.txt
207 | ./data/audio.wav,./data/caption.txt
208 | ./data/audio.wav,./data/caption.txt
209 | ./data/audio.wav,./data/caption.txt
210 | ./data/audio.wav,./data/caption.txt
211 | ./data/audio.wav,./data/caption.txt
212 | ./data/audio.wav,./data/caption.txt
213 | ./data/audio.wav,./data/caption.txt
214 | ./data/audio.wav,./data/caption.txt
215 | ./data/audio.wav,./data/caption.txt
216 | ./data/audio.wav,./data/caption.txt
217 | ./data/audio.wav,./data/caption.txt
218 | ./data/audio.wav,./data/caption.txt
219 | ./data/audio.wav,./data/caption.txt
220 | ./data/audio.wav,./data/caption.txt
221 | ./data/audio.wav,./data/caption.txt
222 | ./data/audio.wav,./data/caption.txt
223 | ./data/audio.wav,./data/caption.txt
224 | ./data/audio.wav,./data/caption.txt
225 | ./data/audio.wav,./data/caption.txt
226 | ./data/audio.wav,./data/caption.txt
227 | ./data/audio.wav,./data/caption.txt
228 | ./data/audio.wav,./data/caption.txt
229 | ./data/audio.wav,./data/caption.txt
230 | ./data/audio.wav,./data/caption.txt
231 | ./data/audio.wav,./data/caption.txt
232 | ./data/audio.wav,./data/caption.txt
233 | ./data/audio.wav,./data/caption.txt
234 | ./data/audio.wav,./data/caption.txt
235 | ./data/audio.wav,./data/caption.txt
236 | ./data/audio.wav,./data/caption.txt
237 | ./data/audio.wav,./data/caption.txt
238 | ./data/audio.wav,./data/caption.txt
239 | ./data/audio.wav,./data/caption.txt
240 | ./data/audio.wav,./data/caption.txt
241 | ./data/audio.wav,./data/caption.txt
242 | ./data/audio.wav,./data/caption.txt
243 | ./data/audio.wav,./data/caption.txt
244 | ./data/audio.wav,./data/caption.txt
245 | ./data/audio.wav,./data/caption.txt
246 | ./data/audio.wav,./data/caption.txt
247 | ./data/audio.wav,./data/caption.txt
248 | ./data/audio.wav,./data/caption.txt
249 | ./data/audio.wav,./data/caption.txt
250 | ./data/audio.wav,./data/caption.txt
251 | ./data/audio.wav,./data/caption.txt
252 | ./data/audio.wav,./data/caption.txt
253 | ./data/audio.wav,./data/caption.txt
254 | ./data/audio.wav,./data/caption.txt
255 | ./data/audio.wav,./data/caption.txt
256 | ./data/audio.wav,./data/caption.txt
257 | ./data/audio.wav,./data/caption.txt
258 | ./data/audio.wav,./data/caption.txt
259 | ./data/audio.wav,./data/caption.txt
260 | ./data/audio.wav,./data/caption.txt
261 | ./data/audio.wav,./data/caption.txt
262 | ./data/audio.wav,./data/caption.txt
263 | ./data/audio.wav,./data/caption.txt
264 | ./data/audio.wav,./data/caption.txt
265 | ./data/audio.wav,./data/caption.txt
266 | ./data/audio.wav,./data/caption.txt
267 | ./data/audio.wav,./data/caption.txt
268 | ./data/audio.wav,./data/caption.txt
269 | ./data/audio.wav,./data/caption.txt
270 | ./data/audio.wav,./data/caption.txt
271 | ./data/audio.wav,./data/caption.txt
272 | ./data/audio.wav,./data/caption.txt
273 | ./data/audio.wav,./data/caption.txt
274 | ./data/audio.wav,./data/caption.txt
275 | ./data/audio.wav,./data/caption.txt
276 | ./data/audio.wav,./data/caption.txt
277 | ./data/audio.wav,./data/caption.txt
278 | ./data/audio.wav,./data/caption.txt
279 | ./data/audio.wav,./data/caption.txt
280 | ./data/audio.wav,./data/caption.txt
281 | ./data/audio.wav,./data/caption.txt
282 | ./data/audio.wav,./data/caption.txt
283 | ./data/audio.wav,./data/caption.txt
284 | ./data/audio.wav,./data/caption.txt
285 | ./data/audio.wav,./data/caption.txt
286 | ./data/audio.wav,./data/caption.txt
287 | ./data/audio.wav,./data/caption.txt
288 | ./data/audio.wav,./data/caption.txt
289 | ./data/audio.wav,./data/caption.txt
290 | ./data/audio.wav,./data/caption.txt
291 | ./data/audio.wav,./data/caption.txt
292 | ./data/audio.wav,./data/caption.txt
293 | ./data/audio.wav,./data/caption.txt
294 | ./data/audio.wav,./data/caption.txt
295 | ./data/audio.wav,./data/caption.txt
296 | ./data/audio.wav,./data/caption.txt
297 | ./data/audio.wav,./data/caption.txt
298 | ./data/audio.wav,./data/caption.txt
299 | ./data/audio.wav,./data/caption.txt
300 | ./data/audio.wav,./data/caption.txt
301 | ./data/audio.wav,./data/caption.txt
302 | ./data/audio.wav,./data/caption.txt
303 | ./data/audio.wav,./data/caption.txt
304 | ./data/audio.wav,./data/caption.txt
305 | ./data/audio.wav,./data/caption.txt
306 | ./data/audio.wav,./data/caption.txt
307 | ./data/audio.wav,./data/caption.txt
308 | ./data/audio.wav,./data/caption.txt
309 | ./data/audio.wav,./data/caption.txt
310 | ./data/audio.wav,./data/caption.txt
311 | ./data/audio.wav,./data/caption.txt
312 | ./data/audio.wav,./data/caption.txt
313 | ./data/audio.wav,./data/caption.txt
314 | ./data/audio.wav,./data/caption.txt
315 | ./data/audio.wav,./data/caption.txt
316 | ./data/audio.wav,./data/caption.txt
317 | ./data/audio.wav,./data/caption.txt
318 | ./data/audio.wav,./data/caption.txt
319 | ./data/audio.wav,./data/caption.txt
320 | ./data/audio.wav,./data/caption.txt
321 | ./data/audio.wav,./data/caption.txt
322 | ./data/audio.wav,./data/caption.txt
323 | ./data/audio.wav,./data/caption.txt
324 | ./data/audio.wav,./data/caption.txt
325 | ./data/audio.wav,./data/caption.txt
326 | ./data/audio.wav,./data/caption.txt
327 | ./data/audio.wav,./data/caption.txt
328 | ./data/audio.wav,./data/caption.txt
329 | ./data/audio.wav,./data/caption.txt
330 | ./data/audio.wav,./data/caption.txt
331 | ./data/audio.wav,./data/caption.txt
332 | ./data/audio.wav,./data/caption.txt
333 | ./data/audio.wav,./data/caption.txt
334 | ./data/audio.wav,./data/caption.txt
335 | ./data/audio.wav,./data/caption.txt
336 | ./data/audio.wav,./data/caption.txt
337 | ./data/audio.wav,./data/caption.txt
338 | ./data/audio.wav,./data/caption.txt
339 | ./data/audio.wav,./data/caption.txt
340 | ./data/audio.wav,./data/caption.txt
341 | ./data/audio.wav,./data/caption.txt
342 | ./data/audio.wav,./data/caption.txt
343 | ./data/audio.wav,./data/caption.txt
344 | ./data/audio.wav,./data/caption.txt
345 | ./data/audio.wav,./data/caption.txt
346 | ./data/audio.wav,./data/caption.txt
347 | ./data/audio.wav,./data/caption.txt
348 | ./data/audio.wav,./data/caption.txt
349 | ./data/audio.wav,./data/caption.txt
350 | ./data/audio.wav,./data/caption.txt
351 | ./data/audio.wav,./data/caption.txt
352 | ./data/audio.wav,./data/caption.txt
353 | ./data/audio.wav,./data/caption.txt
354 | ./data/audio.wav,./data/caption.txt
355 | ./data/audio.wav,./data/caption.txt
356 | ./data/audio.wav,./data/caption.txt
357 | ./data/audio.wav,./data/caption.txt
358 | ./data/audio.wav,./data/caption.txt
359 | ./data/audio.wav,./data/caption.txt
360 | ./data/audio.wav,./data/caption.txt
361 | ./data/audio.wav,./data/caption.txt
362 | ./data/audio.wav,./data/caption.txt
363 | ./data/audio.wav,./data/caption.txt
364 | ./data/audio.wav,./data/caption.txt
365 | ./data/audio.wav,./data/caption.txt
366 | ./data/audio.wav,./data/caption.txt
367 | ./data/audio.wav,./data/caption.txt
368 | ./data/audio.wav,./data/caption.txt
369 | ./data/audio.wav,./data/caption.txt
370 | ./data/audio.wav,./data/caption.txt
371 | ./data/audio.wav,./data/caption.txt
372 | ./data/audio.wav,./data/caption.txt
373 | ./data/audio.wav,./data/caption.txt
374 | ./data/audio.wav,./data/caption.txt
375 | ./data/audio.wav,./data/caption.txt
376 | ./data/audio.wav,./data/caption.txt
377 | ./data/audio.wav,./data/caption.txt
378 | ./data/audio.wav,./data/caption.txt
379 | ./data/audio.wav,./data/caption.txt
380 | ./data/audio.wav,./data/caption.txt
381 | ./data/audio.wav,./data/caption.txt
382 | ./data/audio.wav,./data/caption.txt
383 | ./data/audio.wav,./data/caption.txt
384 | ./data/audio.wav,./data/caption.txt
385 | ./data/audio.wav,./data/caption.txt
386 | ./data/audio.wav,./data/caption.txt
387 | ./data/audio.wav,./data/caption.txt
388 | ./data/audio.wav,./data/caption.txt
389 | ./data/audio.wav,./data/caption.txt
390 | ./data/audio.wav,./data/caption.txt
391 | ./data/audio.wav,./data/caption.txt
392 | ./data/audio.wav,./data/caption.txt
393 | ./data/audio.wav,./data/caption.txt
394 | ./data/audio.wav,./data/caption.txt
395 | ./data/audio.wav,./data/caption.txt
396 | ./data/audio.wav,./data/caption.txt
397 | ./data/audio.wav,./data/caption.txt
398 | ./data/audio.wav,./data/caption.txt
399 | ./data/audio.wav,./data/caption.txt
400 | ./data/audio.wav,./data/caption.txt
401 | ./data/audio.wav,./data/caption.txt
402 |
--------------------------------------------------------------------------------