├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── aws.mov
│   ├── aws.wav
│   ├── bria.mp3
│   ├── google.mov
│   ├── google.wav
│   ├── kansai_demo.mov
│   ├── kotoba-speech_demo.mov
│   ├── kotoba.mov
│   ├── kotoba.wav
│   ├── kotoba_cloning.mov
│   ├── kotoba_cloning.wav
│   └── logo.png
├── fam
│   ├── __init__.py
│   ├── llm
│   │   ├── __init__.py
│   │   ├── adapters
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── flattened_encodec.py
│   │   │   └── tilted_encodec.py
│   │   ├── decoders.py
│   │   ├── enhancers.py
│   │   ├── fast_inference.py
│   │   ├── fast_inference_utils.py
│   │   ├── fast_model.py
│   │   ├── inference.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── attn.py
│   │   │   ├── combined.py
│   │   │   └── layers.py
│   │   ├── mixins
│   │   │   ├── __init__.py
│   │   │   ├── causal.py
│   │   │   └── non_causal.py
│   │   ├── model.py
│   │   ├── sample.py
│   │   ├── serving.py
│   │   ├── train.py
│   │   ├── training
│   │   │   ├── __init__.py
│   │   │   ├── args.py
│   │   │   ├── datamodule.py
│   │   │   ├── dataset.py
│   │   │   ├── dist.py
│   │   │   ├── evaluator.py
│   │   │   ├── optimizer.py
│   │   │   ├── utils.py
│   │   │   └── wandb_utils.py
│   │   └── utils.py
│   ├── py.typed
│   ├── quantiser
│   │   ├── __init__.py
│   │   ├── audio
│   │   │   ├── __init__.py
│   │   │   └── speaker_encoder
│   │   │       ├── __init__.py
│   │   │       ├── audio.py
│   │   │       ├── ckpt
│   │   │       │   └── ckpt.pt
│   │   │       └── model.py
│   │   └── text
│   │       └── tokenise.py
│   └── ui
│       ├── app.py
│       └── assets
│           └── favicon.ico
├── preprocess
│   ├── audio_tokenize.py
│   ├── download_reazon.py
│   ├── spk_embed.py
│   ├── split.py
│   ├── text_tokenize.py
│   └── utils.py
├── pyproject.toml
├── requirements.txt
├── scripts
│   └── abci
│       ├── submit_ddp_1gpu.sh
│       ├── submit_ddp_1node.sh
│       ├── submit_fsdp_1node.sh
│       └── submit_fsdp_2node.sh
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pkl
2 | *.flac
3 | *.npz
4 | *.wav
5 | *.m4a
6 | *.opus
7 | *.npy
8 | *wandb
9 | *.parquet
10 | *.wav
11 | *.pt
12 | *.bin
13 | *.png
14 | *.DS_Store
15 | *.idea
16 | *.ipynb_checkpoints/
17 | *__pycache__/
18 | *.pyc
19 | *.tsv
20 | *.bak
21 | *.tar
22 | *.db
23 | *.dat
24 | *.json
25 |
26 | # Byte-compiled / optimized / DLL files
27 | __pycache__/
28 | *.py[cod]
29 | *$py.class
30 |
31 | # C extensions
32 | *.so
33 |
34 | # Distribution / packaging
35 | .Python
36 | build/
37 | develop-eggs/
38 | dist/
39 | downloads/
40 | eggs/
41 | .eggs/
42 | lib/
43 | lib64/
44 | parts/
45 | sdist/
46 | var/
47 | wheels/
48 | share/python-wheels/
49 | *.egg-info/
50 | .installed.cfg
51 | *.egg
52 | MANIFEST
53 |
54 | # PyInstaller
55 | # Usually these files are written by a python script from a template
56 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
57 | *.manifest
58 | *.spec
59 |
60 | # Installer logs
61 | pip-log.txt
62 | pip-delete-this-directory.txt
63 |
64 | # Unit test / coverage reports
65 | htmlcov/
66 | .tox/
67 | .nox/
68 | .coverage
69 | .coverage.*
70 | .cache
71 | nosetests.xml
72 | coverage.xml
73 | *.cover
74 | *.py,cover
75 | .hypothesis/
76 | .pytest_cache/
77 | cover/
78 |
79 | # Translations
80 | *.mo
81 | *.pot
82 |
83 | # Django stuff:
84 | *.log
85 | local_settings.py
86 | db.sqlite3
87 | db.sqlite3-journal
88 |
89 | # Flask stuff:
90 | instance/
91 | .webassets-cache
92 |
93 | # Scrapy stuff:
94 | .scrapy
95 |
96 | # Sphinx documentation
97 | docs/_build/
98 |
99 | # PyBuilder
100 | .pybuilder/
101 | target/
102 |
103 | # Jupyter Notebook
104 | .ipynb_checkpoints
105 |
106 | # IPython
107 | profile_default/
108 | ipython_config.py
109 |
110 | # pyenv
111 | # For a library or package, you might want to ignore these files since the code is
112 | # intended to run in multiple environments; otherwise, check them in:
113 | # .python-version
114 |
115 | # pipenv
116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
119 | # install all needed dependencies.
120 | #Pipfile.lock
121 |
122 | # poetry
123 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
124 | # This is especially recommended for binary packages to ensure reproducibility, and is more
125 | # commonly ignored for libraries.
126 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
127 | #poetry.lock
128 |
129 | # pdm
130 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
131 | #pdm.lock
132 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
133 | # in version control.
134 | # https://pdm.fming.dev/#use-with-ide
135 | .pdm.toml
136 |
137 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
138 | __pypackages__/
139 |
140 | # Celery stuff
141 | celerybeat-schedule
142 | celerybeat.pid
143 |
144 | # SageMath parsed files
145 | *.sage.py
146 |
147 | # Environments
148 | .env
149 | .venv
150 | env/
151 | venv/
152 | ENV/
153 | env.bak/
154 | venv.bak/
155 |
156 | # Spyder project settings
157 | .spyderproject
158 | .spyproject
159 |
160 | # Rope project settings
161 | .ropeproject
162 |
163 | # mkdocs documentation
164 | /site
165 |
166 | # mypy
167 | .mypy_cache/
168 | .dmypy.json
169 | dmypy.json
170 |
171 | # Pyre type checker
172 | .pyre/
173 |
174 | # pytype static type analyzer
175 | .pytype/
176 |
177 | # Cython debug symbols
178 | cython_debug/
179 |
180 | # PyCharm
181 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
182 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
183 | # and can be added to the global gitignore or merged into this file. For a more nuclear
184 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
185 | #.idea/
186 | **/.tmp
187 | !fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt
188 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Kotoba-Speech Version 0.1
2 | Welcome to the code repository for Kotoba-Speech v0.1, a 1.2B-parameter Transformer-based speech generative model designed for generating fluent Japanese speech. It is among the most capable open-source options currently available for Japanese text-to-speech.
3 |
4 | Questions, feature requests, or bug reports? Join [our Discord community](https://discord.com/invite/qPVFqhGN7Z)!
5 |
6 |
7 |
8 | ## About
9 | Kotoba-Speech Version 0.1 is an open-source model that generates high-quality Japanese speech from text prompts and also supports voice cloning from a reference speech prompt.
10 |
11 | - **_Demo:_** Experience Kotoba-Speech in action [here](https://huggingface.co/spaces/kotoba-tech/Kotoba-Speech).
12 | - **_Model Checkpoint:_** Access our commercially usable pre-trained model [here](https://huggingface.co/kotoba-tech/kotoba-speech-v0.1).
13 | - **_Open-sourced Code:_** This repository open-sources the training and inference code, along with the Gradio demo code. We build on code from [MetaVoice](https://github.com/metavoiceio/metavoice-src) as a starting point.
14 |
15 | ### Our Model vs. Leading TTS Providers for Japanese
16 | https://github.com/kotoba-tech/kotoba-speech-release/assets/18011504/516f56f4-db92-45cb-b2e5-92863b36f8cd
17 |
18 | ### Fine-tuning Our Pre-trained Model for Kansai Dialect (関西弁)
19 | https://github.com/kotoba-tech/kotoba-speech-release/assets/18011504/0204938e-7bb2-4c9f-9c6b-cb1e5c2dcff4
20 |
21 | ## Table of Contents
22 |
23 | 1. **Installation**
24 | 2. **Preparing Datasets**
25 | 3. **Training**
26 | 4. **Inference**
27 | 5. **Other Notes**
28 |
29 | ## 1. Installation
30 | ```bash
31 | # Installing ffmpeg
32 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz
33 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5
34 | md5sum -c ffmpeg-git-amd64-static.tar.xz.md5
35 | tar xvf ffmpeg-git-amd64-static.tar.xz
36 | sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/
37 | rm -rf ffmpeg-git-*
38 |
39 | # Setting-up Python virtual environment
40 | python -m venv myenv
41 | source myenv/bin/activate
42 | pip install -U --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
43 | pip install -r requirements.txt
44 | pip install flash-attn==2.5.3
45 | pip install -e .
46 | ```
47 |
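Optionally, you can sanity-check the installation before moving on; a minimal check from inside the activated `myenv` environment, assuming a CUDA-capable GPU:

```python
# Quick sanity check: ffmpeg on PATH and a CUDA-enabled PyTorch build
import shutil
import torch

print("ffmpeg:", shutil.which("ffmpeg"))  # should point to /usr/local/bin/ffmpeg
print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```
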
48 | ## 2. Preparing Datasets
49 | We provide an example of preparing a training dataset for our model, using Reazon Speech, the largest open-source Japanese speech dataset. (Note that our released model is not necessarily trained on Reazon Speech alone.)
50 | ```bash
51 | # Download & Format Data
52 | python preprocess/download_reazon.py
53 |
54 | # Pre-calculate Speaker Embeddings
55 | python preprocess/spk_embed.py
56 |
57 | # Tokenize Audio
58 | python preprocess/audio_tokenize.py
59 |
60 | # Tokenize Text Captions
61 | python preprocess/text_tokenize.py
62 |
63 | # Split data into (training/validation/test)
64 | python preprocess/split.py
65 | ```
66 |
67 | ## 3. Training
68 | ```bash
69 | # Fine-tuning from our pre-trained checkpoint
70 | # Replace YOUR_WANDB_ENTITY and YOUR_WANDB_PROJECT
71 | python fam/llm/train.py --num_gpus 1 --batch_size 32 --per_gpu_batchsize 2 --max_epoch 5 --learning_rate 0.00005 --data_dir data --exp_name reazon_small_exp_finetuning --spkemb_dropout 0.1 --check_val_every_n_epoch 1 --wandb_entity YOUR_WANDB_ENTITY --wandb_project YOUR_WANDB_PROJECT --use_wandb
72 |
73 | # Multi-GPU Fine-tuning (e.g., using 2 GPUs)
74 | # Replace YOUR_WANDB_ENTITY and YOUR_WANDB_PROJECT
75 | python fam/llm/train.py --num_gpus 2 --batch_size 32 --per_gpu_batchsize 2 --max_epoch 5 --learning_rate 0.00005 --data_dir data --exp_name reazon_small_exp_finetuning --spkemb_dropout 0.1 --check_val_every_n_epoch 1 --wandb_entity YOUR_WANDB_ENTITY --wandb_project YOUR_WANDB_PROJECT --use_wandb
76 |
77 | # Fine-tuning (without WandB logging)
78 | python fam/llm/train.py --num_gpus 1 --batch_size 32 --per_gpu_batchsize 2 --max_epoch 5 --learning_rate 0.00005 --data_dir data --exp_name reazon_small_exp_finetuning --spkemb_dropout 0.1 --check_val_every_n_epoch 1
79 |
80 | # Training from scratch
81 | # Replace YOUR_WANDB_ENTITY and YOUR_WANDB_PROJECT
82 | python fam/llm/train.py --num_gpus 1 --batch_size 64 --per_gpu_batchsize 2 --max_epoch 20 --learning_rate 0.0001 --data_dir data --exp_name reazon_small_exp --spkemb_dropout 0.1 --check_val_every_n_epoch 1 --wandb_entity YOUR_WANDB_ENTITY --wandb_project YOUR_WANDB_PROJECT --use_wandb --train_from_scratch
83 | ```
84 |
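For the runs above that pass `--use_wandb`, make sure you are authenticated with Weights & Biases first; a minimal sketch, assuming `wandb` is installed (e.g. via `requirements.txt`):

```python
# One-time Weights & Biases authentication for the --use_wandb runs above
import wandb

wandb.login()  # prompts for an API key if you are not already logged in
```
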
85 | ## 4. Inference
86 | ```bash
87 | # Our Pre-trained Checkpoint
88 | python -i fam/llm/fast_inference.py --model_name kotoba-tech/kotoba-speech-v0.1
89 | tts.synthesise(text="コトバテクノロジーズのミッションは音声基盤モデルを作る事です。", spk_ref_path="assets/bria.mp3")
90 |
91 | # Inference from Our Kansai-Dialect (関西弁) Checkpoint
92 | python -i fam/llm/fast_inference.py --model_name kotoba-tech/kotoba-speech-v0.1-kansai
93 | tts.synthesise(text="コトバテクノロジーズのミッションは音声基盤モデルを作る事です。", spk_ref_path="assets/bria.mp3")
94 |
95 | # Inference from Your Own Pre-trained Checkpoint
96 | # YOUR_CHECKPOINT_PATH is something like /home/checkpoints/epoch=0-step=1810.ckpt
97 | python -i fam/llm/fast_inference.py --first_model_path YOUR_CHECKPOINT_PATH
98 | tts.synthesise(text="コトバテクノロジーズのミッションは音声基盤モデルを作る事です。", spk_ref_path="assets/bria.mp3")
99 | ```
100 |
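The commands above drop into an interactive session (`python -i`) and call `tts.synthesise(...)` by hand. Equivalently, you can call the model from a regular Python script; a minimal sketch, assuming the default pre-trained checkpoint and the bundled reference audio:

```python
from fam.llm.fast_inference import TTS

# Downloads kotoba-tech/kotoba-speech-v0.1 from the Hugging Face Hub on first use
tts = TTS(model_name="kotoba-tech/kotoba-speech-v0.1")

# Returns the path to the synthesised .wav file
wav_path = tts.synthesise(
    text="コトバテクノロジーズのミッションは音声基盤モデルを作る事です。",
    spk_ref_path="assets/bria.mp3",
)
print(wav_path)
```
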
101 | ## 5. Other Notes
102 | ### 5.1 Contribute
103 | - See all [active issues](https://github.com/kotoba-tech/kotoba-speech-release/issues)!
104 |
105 | ### 5.2 Enhancements
106 | - [ ] Write an explanation of multi-node training
107 | - [ ] Integrate a Gradio demo
108 |
109 | ### 5.3 Acknowledgements
110 | We thank [MetaVoice](https://github.com/metavoiceio/metavoice-src) for releasing their code and their English pre-trained model.
111 |
--------------------------------------------------------------------------------
/assets/aws.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/aws.mov
--------------------------------------------------------------------------------
/assets/aws.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/aws.wav
--------------------------------------------------------------------------------
/assets/bria.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/bria.mp3
--------------------------------------------------------------------------------
/assets/google.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/google.mov
--------------------------------------------------------------------------------
/assets/google.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/google.wav
--------------------------------------------------------------------------------
/assets/kansai_demo.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kansai_demo.mov
--------------------------------------------------------------------------------
/assets/kotoba-speech_demo.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba-speech_demo.mov
--------------------------------------------------------------------------------
/assets/kotoba.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba.mov
--------------------------------------------------------------------------------
/assets/kotoba.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba.wav
--------------------------------------------------------------------------------
/assets/kotoba_cloning.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba_cloning.mov
--------------------------------------------------------------------------------
/assets/kotoba_cloning.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba_cloning.wav
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/logo.png
--------------------------------------------------------------------------------
/fam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/__init__.py
--------------------------------------------------------------------------------
/fam/llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/llm/__init__.py
--------------------------------------------------------------------------------
/fam/llm/adapters/__init__.py:
--------------------------------------------------------------------------------
1 | from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
2 | from fam.llm.adapters.tilted_encodec import TiltedEncodec
3 |
--------------------------------------------------------------------------------
/fam/llm/adapters/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 |
4 | class BaseDataAdapter(ABC):
5 | pass
6 |
--------------------------------------------------------------------------------
/fam/llm/adapters/flattened_encodec.py:
--------------------------------------------------------------------------------
1 | from fam.llm.adapters.base import BaseDataAdapter
2 |
3 |
4 | class FlattenedInterleavedEncodec2Codebook(BaseDataAdapter):
5 | def __init__(self, end_of_audio_token):
6 | self._end_of_audio_token = end_of_audio_token
7 |
8 | def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
9 | assert len(tokens) == 1
10 | tokens = tokens[0]
11 |
12 | text_ids = []
13 | extracted_audio_ids = [[], []]
14 |
15 | for t in tokens:
16 | if t < self._end_of_audio_token:
17 | extracted_audio_ids[0].append(t)
18 | elif t >= self._end_of_audio_token and t < 2 * self._end_of_audio_token:
19 | extracted_audio_ids[1].append(t - self._end_of_audio_token)
20 | # We ignore t = 2 * self._end_of_audio_token, as it is the end of audio token
21 | elif t > 2 * self._end_of_audio_token:
22 | text_ids.append(t)
23 |
24 | if len(set([len(x) for x in extracted_audio_ids])) != 1:
25 | min_len = min([len(x) for x in extracted_audio_ids])
26 | max_len = max([len(x) for x in extracted_audio_ids])
27 | print("WARNING: Number of tokens at each hierarchy must be of the same length!")
28 | print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
29 | print([len(x) for x in extracted_audio_ids])
30 | extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
31 |
32 | return text_ids[:-1], extracted_audio_ids
33 |
34 | def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
35 | """
36 | Performs the required combination and padding as needed.
37 | """
38 | raise NotImplementedError
39 |
--------------------------------------------------------------------------------
/fam/llm/adapters/tilted_encodec.py:
--------------------------------------------------------------------------------
1 | from fam.llm.adapters.base import BaseDataAdapter
2 |
3 |
4 | class TiltedEncodec(BaseDataAdapter):
5 | def __init__(self, end_of_audio_token):
6 | self._end_of_audio_token = end_of_audio_token
7 |
8 | def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
9 | assert len(tokens) > 1
10 |
11 | text_ids = []
12 | extracted_audio_ids = []
13 |
14 | extracted_audio_ids.append([])
15 | # Handle first hierarchy as special case as it contains text tokens as well
16 | # TODO: maybe it doesn't need a special case, and can be handled on its own :)
17 | for t in tokens[0]:
18 | if t > self._end_of_audio_token:
19 | text_ids.append(t)
20 | elif t < self._end_of_audio_token:
21 | extracted_audio_ids[0].append(t)
22 |
23 | # Handle the rest of the hierarchies
24 | for i in range(1, len(tokens)):
25 | token_hierarchy_ids = tokens[i]
26 | extracted_audio_ids.append([])
27 | for t in token_hierarchy_ids:
28 | if t < self._end_of_audio_token:
29 | extracted_audio_ids[i].append(t)
30 |
31 | if len(set([len(x) for x in extracted_audio_ids])) != 1:
32 | min_len = min([len(x) for x in extracted_audio_ids])
33 | max_len = max([len(x) for x in extracted_audio_ids])
34 | print("WARNING: Number of tokens at each hierarchy must be of the same length!")
35 | print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
36 | print([len(x) for x in extracted_audio_ids])
37 | extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
38 |
39 | return text_ids[:-1], extracted_audio_ids
40 |
41 | def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
42 | """
43 | Performs the required combination and padding as needed.
44 | """
45 | raise NotImplementedError
46 |
--------------------------------------------------------------------------------
/fam/llm/decoders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import uuid
4 | from abc import ABC, abstractmethod
5 | from typing import Callable, Optional, Union
6 |
7 | import julius
8 | import torch
9 | from audiocraft.data.audio import audio_read, audio_write
10 | from audiocraft.models import MultiBandDiffusion # type: ignore
11 |
12 | from IPython import embed
13 |
14 | class Decoder(ABC):
15 | @abstractmethod
16 | def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None):
17 | raise NotImplementedError
18 |
19 |
20 | class EncodecDecoder(Decoder):
21 | def __init__(
22 | self,
23 | tokeniser_decode_fn: Callable[[list[int]], str],
24 | data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]],
25 | output_dir: str,
26 | ):
27 | self._mbd_bandwidth = 6 # 1.5
28 | self._mbd_sample_rate = 24_000
29 | self._end_of_audio_token = 1024
30 | self._num_codebooks = 8
31 | self.mbd = MultiBandDiffusion.get_mbd_24khz(bw=self._mbd_bandwidth)
32 |
33 | self.tokeniser_decode_fn = tokeniser_decode_fn
34 | self._data_adapter_fn = data_adapter_fn
35 |
36 | self.output_dir = pathlib.Path(output_dir).resolve()
37 | os.makedirs(self.output_dir, exist_ok=True)
38 |
39 | def _save_audio(self, name: str, wav: torch.Tensor):
40 | audio_write(
41 | name,
42 | wav.squeeze(0).cpu(),
43 | self._mbd_sample_rate,
44 | strategy="loudness",
45 | loudness_compressor=True,
46 | )
47 |
48 | def get_tokens(self, audio_path: str) -> list[list[int]]:
49 | """
50 | Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g.
51 | limited codebook reconstruction or sampling from second stage model only).
52 | """
54 | wav, sr = audio_read(audio_path)
55 | if sr != self._mbd_sample_rate:
56 | wav = julius.resample_frac(wav, sr, self._mbd_sample_rate)
57 | if wav.ndim == 2:
58 | wav = wav.unsqueeze(1)
59 | wav = wav.to("cuda")
60 | tokens = self.mbd.codec_model.encode(wav)
61 | tokens = tokens[0][0]
62 | return tokens.tolist()
63 |
64 | def decode(
65 | self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None
66 | ) -> Union[str, torch.Tensor]:
67 | # TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file.
68 | text_ids, extracted_audio_ids = self._data_adapter_fn(tokens)
69 | text = self.tokeniser_decode_fn(text_ids)
70 | print(f"Text: {text}")
71 |
72 | tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0)
73 |
74 | if tokens.shape[1] < self._num_codebooks:
75 | tokens = torch.cat(
76 | [tokens, *[torch.ones_like(tokens[0:1, 0:1]) * 0] * (self._num_codebooks - tokens.shape[1])], dim=1
77 | )
78 |
79 | if causal:
80 | return tokens
81 | else:
82 | with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
83 | # embed()
84 | wav = self.mbd.tokens_to_wav(tokens)
85 | # NOTE: we couldn't just return wav here as it goes through loudness compression etc :)
86 |
87 | if wav.shape[-1] < 9600:
88 | # this causes problem for the code below, and is also odd :)
89 | # first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!)
90 | raise Exception("wav predicted is shorter than 400ms!")
91 |
92 | try:
93 | wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}"
94 | self._save_audio(wav_file_name, wav)
95 | print(f"\nSaved audio to {wav_file_name}.wav")
96 | return wav_file_name
97 | except Exception as e:
98 | print(f"Failed to save audio! Reason: {e}")
99 | wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}"
100 | self._save_audio(wav_file_name, wav)
101 | print(f"\nSaved audio to {wav_file_name}.wav")
102 | return wav_file_name
103 |
--------------------------------------------------------------------------------
/fam/llm/enhancers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC
3 | from typing import Literal, Optional
4 |
5 | from df.enhance import enhance, init_df, load_audio, save_audio
6 | from pydub import AudioSegment
7 |
8 |
9 | def convert_to_wav(input_file: str, output_file: str):
10 | """Convert an audio file to WAV format
11 |
12 | Args:
13 | input_file (str): path to input audio file
14 | output_file (str): path to output WAV file
15 |
16 | """
17 | # Detect the format of the input file
18 | format = input_file.split(".")[-1].lower()
19 |
20 | # Read the audio file
21 | audio = AudioSegment.from_file(input_file, format=format)
22 |
23 | # Export as WAV
24 | audio.export(output_file, format="wav")
25 |
26 |
27 | def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str:
28 | """Generate the output file path
29 |
30 | Args:
31 | audio_file (str): path to input audio file
32 | tag (str): tag to append to the output file name
33 | ext (str, optional): extension of the output file. Defaults to None.
34 |
35 | Returns:
36 | str: path to output file
37 | """
38 |
39 | directory = "./enhanced"
40 | # Get the name of the input file
41 | filename = os.path.basename(audio_file)
42 |
43 | # Get the name of the input file without the extension
44 | filename_without_extension = os.path.splitext(filename)[0]
45 |
46 | # Get the extension of the input file
47 | extension = ext or os.path.splitext(filename)[1]
48 |
49 | # Generate the output file path
50 | output_file = os.path.join(directory, filename_without_extension + tag + extension)
51 |
52 | return output_file
53 |
54 |
55 | class BaseEnhancer(ABC):
56 | """Base class for audio enhancers"""
57 |
58 | def __init__(self, *args, **kwargs):
59 | raise NotImplementedError
60 |
61 | def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
62 | raise NotImplementedError
63 |
64 | def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
65 | output_file = make_output_file_path(audio_file, tag, ext=ext)
66 | os.makedirs(os.path.dirname(output_file), exist_ok=True)
67 | return output_file
68 |
69 |
70 | class DFEnhancer(BaseEnhancer):
71 | def __init__(self, *args, **kwargs):
72 | self.model, self.df_state, _ = init_df()
73 |
74 | def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
75 | output_file = output_file or self.get_output_file(audio_file, "_df")
76 |
77 | audio, _ = load_audio(audio_file, sr=self.df_state.sr())
78 |
79 | enhanced = enhance(self.model, self.df_state, audio)
80 |
81 | save_audio(output_file, enhanced, self.df_state.sr())
82 |
83 | return output_file
84 |
85 |
86 | def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
87 | """Get an audio enhancer
88 |
89 | Args:
90 | enhancer_name (Literal["df"]): name of the audio enhancer
91 |
92 | Raises:
93 | ValueError: if the enhancer name is not recognised
94 |
95 | Returns:
96 | BaseEnhancer: audio enhancer
97 | """
98 |
99 | if enhancer_name == "df":
100 | return DFEnhancer()
101 | else:
102 | raise ValueError(f"Unknown enhancer name: {enhancer_name}")
103 |
--------------------------------------------------------------------------------
/fam/llm/fast_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tyro
4 | import tempfile
5 | import time
6 | from pathlib import Path
7 |
8 | import librosa
9 | import torch
10 | from huggingface_hub import snapshot_download
11 |
12 | from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
13 | from fam.llm.decoders import EncodecDecoder
14 | from fam.llm.fast_inference_utils import build_model, main
15 | from fam.llm.inference import (
16 | EncodecDecoder,
17 | InferenceConfig,
18 | Model,
19 | TiltedEncodec,
20 | TrainedBPETokeniser,
21 | get_cached_embedding,
22 | get_cached_file,
23 | get_enhancer,
24 | )
25 | from fam.llm.utils import (
26 | check_audio_file,
27 | get_default_dtype,
28 | get_device,
29 | normalize_text,
30 | )
31 | import argparse
32 |
33 |
34 | class TTS:
35 | def __init__(
36 | self, model_name: str = "kotoba-tech/kotoba-speech-v0.1", *, seed: int = 1337, output_dir: str = "outputs", first_model_path: str = None,
37 | ):
38 | """
39 | model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/kotoba-tech/)
40 | """
41 |
42 | # NOTE: this needs to come first so that we don't change global state when we want to use
43 | # the torch.compiled-model.
44 | self._dtype = get_default_dtype()
45 | self._device = get_device()
46 | self._first_model_dir = snapshot_download(repo_id=model_name)
47 | if model_name != "kotoba-tech/kotoba-speech-v0.1":
48 | self._other_model_dir = snapshot_download(repo_id="kotoba-tech/kotoba-speech-v0.1")
49 | else:
50 | self._other_model_dir = self._first_model_dir
51 |
52 | self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
53 | self.output_dir = output_dir
54 | os.makedirs(self.output_dir, exist_ok=True)
55 |
56 | second_stage_ckpt_path = f"{self._other_model_dir}/second_stage.pt"
57 | config_second_stage = InferenceConfig(
58 | ckpt_path=second_stage_ckpt_path,
59 | num_samples=1,
60 | seed=seed,
61 | device=self._device,
62 | dtype=self._dtype,
63 | compile=False,
64 | init_from="resume",
65 | output_dir=self.output_dir,
66 | )
67 | data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
68 | self.llm_second_stage = Model(
69 | config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
70 | )
71 | self.enhancer = get_enhancer("df")
72 |
73 | self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
74 | self.model, self.tokenizer, self.smodel, self.model_size = build_model(
75 | precision=self.precision,
76 | checkpoint_path=Path(f"{self._other_model_dir}/first_stage.pt"),
77 | spk_emb_ckpt_path=Path(f"{self._other_model_dir}/speaker_encoder.pt"),
78 | device=self._device,
79 | compile=True,
80 | compile_prefill=True,
81 | first_model_path= first_model_path if (first_model_path is not None) and (os.path.exists(first_model_path)) else Path(f"{self._first_model_dir}/first_stage.pt"),
82 | )
83 |
84 |
85 | def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
86 | """
87 | text: Text to speak
88 | spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
89 | top_p: Top-p for sampling, applied to the first-stage model. Values in the range [0.9, 1.0] work well. This controls speech stability and improves text following for challenging speakers.
90 | guidance_scale: Guidance scale for sampling, in the range [1.0, 3.0]. This controls speaker similarity, i.e. how closely the output matches the speaker identity and speech style.
91 | temperature: Temperature for sampling applied to both LLMs (first & second stage)
92 |
93 | returns: path to speech .wav file
94 | """
95 | text = normalize_text(text)
96 | spk_ref_path = get_cached_file(spk_ref_path)
97 | check_audio_file(spk_ref_path)
98 | spk_emb = get_cached_embedding(
99 | spk_ref_path,
100 | self.smodel,
101 | ).to(device=self._device, dtype=self.precision)
102 |
103 | start = time.time()
104 | # first stage LLM
105 | tokens = main(
106 | model=self.model,
107 | tokenizer=self.tokenizer,
108 | model_size=self.model_size,
109 | prompt=text,
110 | spk_emb=spk_emb,
111 | top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
112 | guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
113 | temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
114 | )
115 | text_ids, extracted_audio_ids = self.first_stage_adapter.decode([tokens])
116 |
117 | b_speaker_embs = spk_emb.unsqueeze(0)
118 |
119 | # second stage LLM + multi-band diffusion model
120 | wav_files = self.llm_second_stage(
121 | texts=[text],
122 | encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
123 | speaker_embs=b_speaker_embs,
124 | batch_size=1,
125 | guidance_scale=None,
126 | top_p=None,
127 | top_k=200,
128 | temperature=1.0,
129 | max_new_tokens=None,
130 | )
131 |
132 | # enhance using deepfilternet
133 | wav_file = wav_files[0]
134 | with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
135 | self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
136 | shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
137 | print(f"\nSaved audio to {wav_file}.wav")
138 |
139 | # calculating real-time factor (RTF)
140 | time_to_synth_s = time.time() - start
141 | audio, sr = librosa.load(str(wav_file) + ".wav")
142 | duration_s = librosa.get_duration(y=audio, sr=sr)
143 | print(f"\nTotal time to synth (s): {time_to_synth_s}")
144 | print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")
145 |
146 | return str(wav_file) + ".wav"
147 |
148 |
149 | if __name__ == "__main__":
150 | tts = tyro.cli(TTS)
--------------------------------------------------------------------------------
/fam/llm/fast_inference_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kotoba Technologies, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without modification, are permitted
5 | # provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this list of
8 | # conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this
11 | # list of conditions and the following disclaimer in the documentation and/or other
12 | # materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its contributors
15 | # may be used to endorse or promote products derived from this software without
16 | # specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 | import itertools
27 | import gc
28 | import time
29 | from pathlib import Path
30 | from typing import Optional, Tuple
31 |
32 | import torch
33 | import torch._dynamo.config
34 | import torch._inductor.config
35 | import tqdm
36 |
37 |
38 | def device_sync(device):
39 | if "cuda" in device:
40 | torch.cuda.synchronize()
41 | elif "cpu" in device:
42 | pass
43 | else:
44 | print(f"device={device} is not yet suppported")
45 |
46 |
47 | torch._inductor.config.coordinate_descent_tuning = True
48 | torch._inductor.config.triton.unique_kernel_names = True
49 | torch._inductor.config.fx_graph_cache = (
50 | True # Experimental feature to reduce compilation times, will be on by default in future
51 | )
52 |
53 | # imports need to happen after setting above flags
54 | from fam.llm.fast_model import Transformer
55 | from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
56 | from fam.quantiser.text.tokenise import TrainedBPETokeniser
57 |
58 |
59 | def multinomial_sample_one_no_sync(
60 | probs_sort,
61 | ): # Does multinomial sampling without a cuda synchronization
62 | q = torch.empty_like(probs_sort).exponential_(1)
63 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
64 |
65 |
66 | def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor):
67 | # ref: huggingface/transformers
68 |
69 | sorted_logits, sorted_indices = torch.sort(logits, descending=False)
70 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
71 |
72 | # Drop tokens whose cumulative probability is <= (1 - top_p), i.e. keep the smallest set of tokens whose probabilities sum to at least top_p
73 | sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
74 | # Always keep at least the most probable token
75 | sorted_indices_to_remove[-1:] = 0
76 |
77 | # scatter sorted tensors to original indexing
78 | indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
79 | scores = logits.masked_fill(indices_to_remove, -float("Inf"))
80 | return scores
81 |
82 |
83 | def logits_to_probs(
84 | logits,
85 | *,
86 | temperature: torch.Tensor,
87 | top_p: Optional[torch.Tensor] = None,
88 | top_k: Optional[torch.Tensor] = None,
89 | ):
90 | logits = logits / torch.max(temperature, 1e-5 * torch.ones_like(temperature))
91 |
92 | if top_k is not None:
93 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
94 | pivot = v.select(-1, -1).unsqueeze(-1)
95 | logits = torch.where(logits < pivot, -float("Inf"), logits)
96 |
97 | if top_p is not None:
98 | logits = top_p_sample(logits, top_p)
99 |
100 | probs = torch.nn.functional.softmax(logits, dim=-1)
101 |
102 | return probs
103 |
104 |
105 | def sample(
106 | logits,
107 | guidance_scale: torch.Tensor,
108 | temperature: torch.Tensor,
109 | top_p: Optional[torch.Tensor] = None,
110 | top_k: Optional[torch.Tensor] = None,
111 | ):
112 | # (b, t, vocab_size)
113 | logits = logits[:, -1]
114 | logits_cond, logits_uncond_spkemb = logits.split(logits.size(0) // 2, dim=0)  # batch holds [speaker-conditional; speaker-unconditional] halves
115 | logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond_spkemb  # classifier-free guidance over the speaker embedding
116 | probs = logits_to_probs(logits[0], temperature=temperature, top_p=top_p, top_k=top_k)
117 | idx_next = multinomial_sample_one_no_sync(probs)
118 | return idx_next, probs
119 |
120 |
121 | def prefill(
122 | model: Transformer,
123 | x: torch.Tensor,
124 | spk_emb: torch.Tensor,
125 | input_pos: torch.Tensor,
126 | **sampling_kwargs,
127 | ) -> torch.Tensor:
128 | # input_pos: [B, S]
129 | logits = model(x, spk_emb, input_pos)
130 | return sample(logits, **sampling_kwargs)[0]
131 |
132 |
133 | def decode_one_token(
134 | model: Transformer,
135 | x: torch.Tensor,
136 | spk_emb: torch.Tensor,
137 | input_pos: torch.Tensor,
138 | **sampling_kwargs,
139 | ) -> Tuple[torch.Tensor, torch.Tensor]:
140 | # input_pos: [B, 1]
141 | assert input_pos.shape[-1] == 1
142 | logits = model(x, spk_emb, input_pos)
143 | return sample(logits, **sampling_kwargs)
144 |
145 |
146 | def decode_n_tokens(
147 | model: Transformer,
148 | cur_token: torch.Tensor,
149 | spk_emb: torch.Tensor,
150 | input_pos: torch.Tensor,
151 | num_new_tokens: int,
152 | callback=lambda _: _,
153 | return_probs: bool = False,
154 | end_of_audio_token: int = 2048,
155 | **sampling_kwargs,
156 | ):
157 | new_tokens, new_probs = [], []
158 | for i in tqdm.tqdm(range(num_new_tokens)):
159 | if (cur_token == end_of_audio_token).any():
160 | break
161 | with torch.backends.cuda.sdp_kernel(
162 | enable_flash=False, enable_mem_efficient=False, enable_math=True
163 | ): # Actually better for Inductor to codegen attention here
164 | next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs)
165 | input_pos += 1
166 | new_tokens.append(next_token.clone())
167 | callback(new_tokens[-1])
168 | if return_probs:
169 | new_probs.append(next_prob.clone())
170 | cur_token = next_token.view(1, -1).repeat(2, 1)
171 |
172 | return new_tokens, new_probs
173 |
174 |
175 | def model_forward(model, x, spk_emb, input_pos):
176 | return model(x, spk_emb, input_pos)
177 |
178 |
179 | @torch.no_grad()
180 | def generate(
181 | model: Transformer,
182 | prompt: torch.Tensor,
183 | spk_emb: torch.Tensor,
184 | *,
185 | max_new_tokens: Optional[int] = None,
186 | callback=lambda x: x,
187 | end_of_audio_token: int = 2048,
188 | **sampling_kwargs,
189 | ) -> torch.Tensor:
190 | """
191 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
192 | """
193 | # create an empty tensor of the expected final shape and fill in the current tokens
194 | T = prompt.size(0)
195 | if max_new_tokens is None:
196 | max_seq_length = model.config.block_size
197 | else:
198 | max_seq_length = T + max_new_tokens
199 | max_seq_length = min(max_seq_length, model.config.block_size)
200 | max_new_tokens = max_seq_length - T
201 | if max_new_tokens <= 0:
202 | raise ValueError("Prompt is too long to generate more tokens")
203 |
204 | device, dtype = prompt.device, prompt.dtype
205 |
206 | seq = torch.clone(prompt)
207 | input_pos = torch.arange(0, T, device=device)
208 |
209 | next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
210 | seq = torch.cat([seq, next_token.view(1)])
211 |
212 | input_pos = torch.tensor([T], device=device, dtype=torch.int)
213 |
214 | generated_tokens, _ = decode_n_tokens(
215 | model,
216 | next_token.view(1, -1).repeat(2, 1),
217 | spk_emb,
218 | input_pos,
219 | max_new_tokens - 1,
220 | callback=callback,
221 | end_of_audio_token=end_of_audio_token,
222 | **sampling_kwargs,
223 | )
224 | seq = torch.cat([seq, torch.cat(generated_tokens)])
225 |
226 | return seq
227 |
228 |
229 | def encode_tokens(tokenizer, string, device="cuda"):
230 | tokens = tokenizer.encode(string)
231 | return torch.tensor(tokens, dtype=torch.int, device=device)
232 |
233 |
234 | def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision, first_model_path=None, unwanted_prefix="_orig_mod."):
235 | ##### MODEL
236 | with torch.device("meta"):
237 | model = Transformer.from_name("kotoba-speech-v0.1")
238 |
239 | # TODO(quantization): enable
240 | # if "int8" in str(checkpoint_path):
241 | # print("Using int8 weight-only quantization!")
242 | # from quantize import WeightOnlyInt8QuantHandler
243 | # simple_quantizer = WeightOnlyInt8QuantHandler(model)
244 | # model = simple_quantizer.convert_for_runtime()
245 | # from quantize import WeightOnlyInt8QuantHandler
246 |
247 | # if "int4" in str(checkpoint_path):
248 | # print("Using int4 quantization!")
249 | # path_comps = checkpoint_path.name.split(".")
250 | # assert path_comps[-2].startswith("g")
251 | # groupsize = int(path_comps[-2][1:])
252 | # from quantize import WeightOnlyInt4QuantHandler
253 | # simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
254 | # model = simple_quantizer.convert_for_runtime()
255 |
256 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
257 |
258 | ###### TOKENIZER
259 | tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
260 | tokenizer = TrainedBPETokeniser(**tokenizer_info)
261 |
262 | if first_model_path is not None:
263 | trained_ckpt = torch.load(str(first_model_path), mmap=True, weights_only=False)
264 | #trained_ckpt1 = torch.load(str(first_model_path), mmap=True, weights_only=False, map_location="cpu")
265 | #trained_ckpt2 = torch.load(str(checkpoint_path), mmap=True, weights_only=False, map_location="cpu")
266 | if "state_dict" in trained_ckpt.keys():
267 | state_dict = trained_ckpt["state_dict"]
268 | else:
269 | state_dict = trained_ckpt["model"]
270 | del checkpoint
271 | gc.collect()
272 | torch.cuda.empty_cache()
273 | else:
274 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
275 | if "state_dict" in checkpoint.keys():
276 | state_dict = checkpoint["state_dict"]
277 | else:
278 | state_dict = checkpoint["model"]
279 |
280 | # convert Kotoba-Speech model weights naming to gptfast naming
281 | for k, v in list(state_dict.items()):
282 | if k.startswith(unwanted_prefix):
283 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
284 | state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight")
285 | state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight")
286 | state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight")
287 | state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight")
288 | for k, v in list(state_dict.items()):
289 | if k.startswith("transformer.h."):
290 | state_dict[k.replace("transformer.h.", "layers.")] = state_dict.pop(k)
291 | k = k.replace("transformer.h.", "layers.")
292 | if ".attn.c_attn." in k:
293 | state_dict[k.replace(".attn.c_attn.", ".attention.wqkv.")] = state_dict.pop(k)
294 | k = k.replace(".attn.c_attn.", ".attention.wqkv.")
295 | if ".attn.c_proj." in k:
296 | state_dict[k.replace(".attn.c_proj.", ".attention.wo.")] = state_dict.pop(k)
297 | k = k.replace(".attn.c_proj.", ".attention.wo.")
298 | if ".mlp.swiglu.w1." in k:
299 | state_dict[k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")] = state_dict.pop(k)
300 | k = k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")
301 | if ".mlp.swiglu.w3." in k:
302 | state_dict[k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")] = state_dict.pop(k)
303 | k = k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")
304 | if ".ln_1." in k:
305 | state_dict[k.replace(".ln_1.", ".attention_norm.")] = state_dict.pop(k)
306 | k = k.replace(".ln_1.", ".attention_norm.")
307 | if ".ln_2." in k:
308 | state_dict[k.replace(".ln_2.", ".ffn_norm.")] = state_dict.pop(k)
309 | k = k.replace(".ln_2.", ".ffn_norm.")
310 | if ".mlp.c_proj." in k:
311 | state_dict[k.replace(".mlp.c_proj.", ".feed_forward.w2.")] = state_dict.pop(k)
312 | k = k.replace(".mlp.c_proj.", ".feed_forward.w2.")
313 |
314 | model.load_state_dict(state_dict, assign=True)
315 | # simple_quantizer = WeightOnlyInt8QuantHandler(model)
316 | # quantized_state_dict = simple_quantizer.create_quantized_state_dict()
317 | # model = simple_quantizer.convert_for_runtime()
318 | # model.load_state_dict(quantized_state_dict, assign=True)
319 | model = model.to(device=device, dtype=precision)
320 |
321 | ###### SPEAKER EMBEDDER
322 | # TODO: fix!
323 | smodel = SpeakerEncoder(
324 | weights_fpath=spk_emb_ckpt_path,
325 | device=device,
326 | eval=True,
327 | verbose=False,
328 | )
329 | return model.eval(), tokenizer, smodel
330 |
331 |
332 | def build_model(
333 | *,
334 | precision: torch.dtype,
335 | checkpoint_path: Path = Path(""),
336 | spk_emb_ckpt_path: Path = Path(""),
337 | compile_prefill: bool = False,
338 | compile: bool = True,
339 | device: str = "cuda",
340 |     first_model_path: Optional[str] = None,
341 | ):
342 | assert checkpoint_path.is_file(), checkpoint_path
343 |
344 | print(f"Using device={device}")
345 |
346 | print("Loading model ...")
347 | t0 = time.time()
348 | if first_model_path is None:
349 | # model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision)
350 | model, tokenizer, smodel = _load_model(
351 | checkpoint_path, spk_emb_ckpt_path, device, precision, unwanted_prefix="first_stage_model_transformer."
352 | )
353 |
354 | else:
355 | model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision, first_model_path, unwanted_prefix="first_stage_model_transformer.")
356 |
357 |
358 | device_sync(device=device) # MKG
359 | print(f"Time to load model: {time.time() - t0:.02f} seconds")
360 |
361 | torch.manual_seed(1234)
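    # total size of parameters and buffers in bytes; used for the memory-bandwidth estimate printed in main()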
362 | model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
363 |
364 | with torch.device(device):
365 | model.setup_spk_cond_mask()
366 | model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size)
367 |
368 | if compile:
369 | print("Compiling...Can take up to 2 mins.")
370 | global decode_one_token, prefill
371 | decode_one_token = torch.compile(
372 | decode_one_token,
373 | mode="max-autotune",
374 | fullgraph=True,
375 | )
376 |
377 | if compile_prefill:
378 | prefill = torch.compile(
379 | prefill,
380 | fullgraph=True,
381 | dynamic=True,
382 | )
383 |
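    # warm-up generation with a dummy prompt and a random speaker embedding so torch.compile traces and compiles
    # the decode graphs here rather than on the first real request (end_of_audio_token=9999 avoids early stopping)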
384 | encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
385 | spk_emb = torch.randn((1, 256), device=device, dtype=precision)
386 |
387 | device_sync(device=device) # MKG
388 | t0 = time.perf_counter()
389 | y = generate(
390 | model,
391 | encoded,
392 | spk_emb,
393 | max_new_tokens=200,
394 | callback=lambda x: x,
395 | temperature=torch.tensor(1.0, device=device, dtype=precision),
396 | top_k=None,
397 | top_p=torch.tensor(0.95, device=device, dtype=precision),
398 | guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
399 | end_of_audio_token=9999, # don't end early for compilation stage.
400 | )
401 |
402 | device_sync(device=device) # MKG
403 |
404 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
405 |
406 | return model, tokenizer, smodel, model_size
407 |
408 |
409 | def main(
410 | *,
411 | model,
412 | tokenizer,
413 | model_size,
414 | prompt: str,
415 | guidance_scale: torch.Tensor,
416 | temperature: torch.Tensor,
417 | spk_emb: torch.Tensor,
418 | top_k: Optional[torch.Tensor] = None,
419 | top_p: Optional[torch.Tensor] = None,
420 | device: str = "cuda",
421 | ) -> list:
422 | """Generates text samples based on a pre-trained Transformer model and tokenizer."""
423 |
424 | encoded = encode_tokens(tokenizer, prompt, device=device)
425 | prompt_length = encoded.size(0)
426 |
427 | aggregate_metrics: dict = {
428 | "tokens_per_sec": [],
429 | }
430 |
431 | device_sync(device=device) # MKG
432 |
433 | if True:
434 | callback = lambda x: x
435 | t0 = time.perf_counter()
436 |
437 | y = generate(
438 | model,
439 | encoded,
440 | spk_emb,
441 | callback=callback,
442 | temperature=temperature,
443 | top_k=top_k,
444 | top_p=top_p,
445 | guidance_scale=guidance_scale,
446 | )
447 |
448 | device_sync(device=device) # MKG
449 | t = time.perf_counter() - t0
450 |
451 | tokens_generated = y.size(0) - prompt_length
452 | tokens_sec = tokens_generated / t
453 | aggregate_metrics["tokens_per_sec"].append(tokens_sec)
454 | print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
455 | print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
456 | # print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
457 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n")
458 |
459 | return y.tolist()
460 |
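# A minimal usage sketch (hypothetical checkpoint paths; adjust to your own files):
#
#   model, tokenizer, smodel, model_size = build_model(
#       precision=torch.bfloat16,
#       checkpoint_path=Path("model.pt"),
#       spk_emb_ckpt_path=Path("speaker_encoder.pt"),
#       compile=False,
#       device="cuda",
#   )
#   tokens = main(
#       model=model,
#       tokenizer=tokenizer,
#       model_size=model_size,
#       prompt="Hello, what's up?",
#       guidance_scale=torch.tensor(3.0, device="cuda", dtype=torch.bfloat16),
#       temperature=torch.tensor(1.0, device="cuda", dtype=torch.bfloat16),
#       top_p=torch.tensor(0.95, device="cuda", dtype=torch.bfloat16),
#       spk_emb=torch.randn((1, 256), device="cuda", dtype=torch.bfloat16),
#   )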
--------------------------------------------------------------------------------
/fam/llm/fast_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kotoba Technologies, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without modification, are permitted
5 | # provided that the following conditions are met:
6 | #
7 | # 1. Redistributions of source code must retain the above copyright notice, this list of
8 | # conditions and the following disclaimer.
9 | #
10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this
11 | # list of conditions and the following disclaimer in the documentation and/or other
12 | # materials provided with the distribution.
13 | #
14 | # 3. Neither the name of the copyright holder nor the names of its contributors
15 | # may be used to endorse or promote products derived from this software without
16 | # specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 | from dataclasses import dataclass
27 | from functools import reduce
28 | from math import gcd
29 | from typing import Optional
30 |
31 | import torch
32 | import torch.nn as nn
33 | from torch import Tensor
34 | from torch.nn import functional as F
35 |
36 | from fam.llm.utils import get_default_dtype
37 |
38 | import logging
39 |
40 | # Adjust the logging level
41 | logger = logging.getLogger("torch")
42 | logger.setLevel(logging.ERROR)
43 |
44 |
45 | def find_multiple(n: int, *args: int) -> int:
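    # returns the smallest multiple of lcm(args) that is >= n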
46 | k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
47 | if n % k == 0:
48 | return n
49 | return n + k - (n % k)
50 |
51 |
52 | @dataclass
53 | class ModelArgs:
54 | block_size: int = 2048
55 | vocab_size: int = 32000
56 | n_layer: int = 32
57 | n_head: int = 32
58 | dim: int = 4096
59 | speaker_emb_dim: int = 256
60 |     intermediate_size: Optional[int] = None
61 | n_local_heads: int = -1
62 | head_dim: int = 64
63 | norm_eps: float = 1e-5
64 | dtype: torch.dtype = torch.bfloat16
65 |
66 | def __post_init__(self):
67 | if self.n_local_heads == -1:
68 | self.n_local_heads = self.n_head
69 | if self.intermediate_size is None:
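            # FFN width heuristic: 2/3 of 4*dim, rounded up to a multiple of 256 (LLaMA-style SwiGLU sizing)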
70 | hidden_dim = 4 * self.dim
71 | n_hidden = int(2 * hidden_dim / 3)
72 | self.intermediate_size = find_multiple(n_hidden, 256)
73 | self.head_dim = self.dim // self.n_head
74 |
75 | self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()]
76 |
77 | @classmethod
78 | def from_name(cls, name: str):
79 | if name in transformer_configs:
80 | return cls(**transformer_configs[name])
81 | # fuzzy search
82 | config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
83 | assert len(config) == 1, name
84 | return cls(**transformer_configs[config[0]])
85 |
86 |
87 | transformer_configs = {
88 | "kotoba-speech-v0.1": dict(
89 | n_layer=24,
90 | n_head=16,
91 | dim=2048,
92 | vocab_size=2562,
93 | ),
94 | }
95 |
96 |
97 | class KVCache(nn.Module):
98 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
99 | super().__init__()
100 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
101 | self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
102 | self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
103 |
104 | def update(self, input_pos, k_val, v_val):
105 | # input_pos: [S], k_val: [B, H, S, D]
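        # write k_val/v_val into the pre-allocated caches at positions input_pos and return the full caches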
106 | assert input_pos.shape[0] == k_val.shape[2]
107 |
108 | k_out = self.k_cache
109 | v_out = self.v_cache
110 | k_out[:, :, input_pos] = k_val
111 | v_out[:, :, input_pos] = v_val
112 |
113 | return k_out, v_out
114 |
115 |
116 | class Transformer(nn.Module):
117 | def __init__(self, config: ModelArgs) -> None:
118 | super().__init__()
119 | self.config = config
120 |
121 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
122 | self.pos_embeddings = nn.Embedding(config.block_size, config.dim)
123 | self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False)
124 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
125 | self.norm = RMSNorm(config.dim, eps=config.norm_eps)
126 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
127 |
128 | self.mask_cache: Optional[Tensor] = None
129 | self.max_batch_size = -1
130 | self.max_seq_length = -1
131 |
132 | def setup_spk_cond_mask(self):
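        # (2, 1, dim) boolean mask for classifier-free guidance: batch row 0 keeps the speaker conditioning, row 1 drops it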
133 | self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool)
134 | self.spk_cond_mask[0] = 1
135 |
136 | def setup_caches(self, max_batch_size, max_seq_length):
137 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
138 | return
139 | head_dim = self.config.dim // self.config.n_head
140 | max_seq_length = find_multiple(max_seq_length, 8)
141 | self.max_seq_length = max_seq_length
142 | self.max_batch_size = max_batch_size
143 | for b in self.layers:
144 | b.attention.kv_cache = KVCache(
145 | max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype
146 | )
147 |
148 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
149 |
150 | def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor:
151 | mask = self.causal_mask[None, None, input_pos]
152 | x = (
153 | self.tok_embeddings(idx)
154 | + self.pos_embeddings(input_pos)
155 | # masking for speaker condition free guidance
156 | + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask
157 | )
158 |
159 | for i, layer in enumerate(self.layers):
160 | x = layer(x, input_pos, mask)
161 | x = self.norm(x)
162 | logits = self.output(x)
163 | return logits
164 |
165 | @classmethod
166 | def from_name(cls, name: str):
167 | return cls(ModelArgs.from_name(name))
168 |
169 |
170 | class TransformerBlock(nn.Module):
171 | def __init__(self, config: ModelArgs) -> None:
172 | super().__init__()
173 | self.attention = Attention(config)
174 | self.feed_forward = FeedForward(config)
175 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
176 | self.attention_norm = RMSNorm(config.dim, config.norm_eps)
177 |
178 | def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
179 | h = x + self.attention(self.attention_norm(x), mask, input_pos)
180 | out = h + self.feed_forward(self.ffn_norm(h))
181 | return out
182 |
183 |
184 | class Attention(nn.Module):
185 | def __init__(self, config: ModelArgs):
186 | super().__init__()
187 | assert config.dim % config.n_head == 0
188 |
189 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
190 | # key, query, value projections for all heads, but in a batch
191 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
192 | self.wo = nn.Linear(config.dim, config.dim, bias=False)
193 | self.kv_cache = None
194 |
195 | self.n_head = config.n_head
196 | self.head_dim = config.head_dim
197 | self.n_local_heads = config.n_local_heads
198 | self.dim = config.dim
199 |
200 | def forward(
201 | self,
202 | x: Tensor,
203 | mask: Tensor,
204 | input_pos: Optional[Tensor] = None,
205 | ) -> Tensor:
206 | bsz, seqlen, _ = x.shape
207 |
208 | kv_size = self.n_local_heads * self.head_dim
209 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
210 |
211 | q = q.view(bsz, seqlen, self.n_head, self.head_dim)
212 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
213 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
214 |
215 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
216 |
217 | if self.kv_cache is not None:
218 | k, v = self.kv_cache.update(input_pos, k, v)
219 |
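        # expand grouped KV heads to the number of query heads (a no-op here since n_local_heads defaults to n_head)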
220 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
221 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
222 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
223 |
224 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
225 |
226 | y = self.wo(y)
227 | return y
228 |
229 |
230 | class SwiGLU(nn.Module):
231 | def __init__(self, config: ModelArgs) -> None:
232 | super().__init__()
233 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
234 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
235 |
236 | def forward(self, x: Tensor) -> Tensor:
237 | return F.silu(self.w1(x)) * self.w3(x)
238 |
239 |
240 | class FeedForward(nn.Module):
241 | def __init__(self, config: ModelArgs) -> None:
242 | super().__init__()
243 | self.swiglu = SwiGLU(config)
244 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
245 |
246 | def forward(self, x: Tensor) -> Tensor:
247 | return self.w2(self.swiglu(x))
248 |
249 |
250 | class RMSNorm(nn.Module):
251 | def __init__(self, dim: int, eps: float = 1e-5):
252 | super().__init__()
253 | self.eps = eps
254 | self.weight = nn.Parameter(torch.ones(dim))
255 |
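    # RMSNorm: x * rsqrt(mean(x^2) + eps), scaled by a learned per-channel weight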
256 | def _norm(self, x):
257 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
258 |
259 | def forward(self, x: Tensor) -> Tensor:
260 | output = self._norm(x.float()).type_as(x)
261 | return output * self.weight
262 |
--------------------------------------------------------------------------------
/fam/llm/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from fam.llm.layers.attn import SelfAttention
2 | from fam.llm.layers.combined import Block
3 | from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm, SwiGLU
4 |
--------------------------------------------------------------------------------
/fam/llm/layers/attn.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class SelfAttention(nn.Module):
9 | def __init__(self, config):
10 | """
11 | Initializes the SelfAttention module.
12 |
13 | Args:
14 | config: An object containing the configuration parameters for the SelfAttention module.
15 | """
16 | super().__init__()
17 | self._validate_config(config)
18 | self._initialize_parameters(config)
19 |
20 | def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
21 | """
22 | Empties the key-value cache.
23 |
24 | Args:
25 | batch_size: The batch size.
26 | kv_cache_maxlen: The maximum length of the key-value cache.
27 | dtype: The data type of the cache.
28 |
29 | Raises:
30 | Exception: If trying to empty the KV cache when it is disabled.
31 | """
32 | if self.kv_cache_enabled is False:
33 | raise Exception("Trying to empty KV cache when it is disabled")
34 |
35 | # register so that the cache moves devices along with the module
36 | # TODO: get rid of re-allocation.
37 | self.register_buffer(
38 | "kv_cache",
39 | torch.zeros(
40 | 2,
41 | batch_size,
42 | kv_cache_maxlen,
43 | self.n_head,
44 | self.n_embd // self.n_head,
45 | dtype=dtype,
46 | device=self.c_attn.weight.device,
47 | ),
48 | persistent=False,
49 | )
50 |
51 | self.kv_cache_first_empty_index = 0
52 |
53 | def _initialize_parameters(self, config):
54 | """
55 | Initializes the parameters of the SelfAttention module.
56 |
57 | Args:
58 | config: An object containing the configuration parameters for the SelfAttention module.
59 | """
60 | # key, query, value projections for all heads, but in a batch
61 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
62 |
63 | # output projection
64 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
65 |
66 | # regularization
67 | self.resid_dropout = nn.Dropout(config.dropout)
68 | self.n_head = config.n_head
69 | self.n_embd = config.n_embd
70 | self.dropout = config.dropout
71 | self.causal = config.causal
72 | self.attn_kernel_type = config.attn_kernel_type
73 | self.attn_dropout = nn.Dropout(config.dropout)
74 |
75 | self.kv_cache_enabled = False
76 |
77 | def _validate_config(self, config):
78 | """
79 | Validates the configuration parameters.
80 |
81 | Args:
82 | config: An object containing the configuration parameters for the SelfAttention module.
83 |
84 | Raises:
85 | AssertionError: If the embedding dimension is not divisible by the number of heads.
86 | """
87 | assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads"
88 |
89 | def _update_kv_cache(self, q, k, v):
90 | """
91 | Updates the key-value cache.
92 |
93 | Args:
94 | q: The query tensor.
95 | k: The key tensor.
96 | v: The value tensor.
97 |
98 | Returns:
99 | The updated key and value tensors.
100 |
101 | Raises:
102 | AssertionError: If the dimensions of the query, key, and value tensors are not compatible.
103 | """
104 | q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1]
105 |
106 | if self.kv_cache_first_empty_index == 0:
107 | assert q_time == k_time and q_time == v_time
108 | else:
109 | assert (
110 | q_time == 1
111 | ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}"
112 |
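        # append the new keys/values at the first empty position, then return only the filled prefix of the cache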
113 | self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k
114 | self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v
115 | self.kv_cache_first_empty_index += q_time
116 |
117 | k = self.kv_cache[0, :, : self.kv_cache_first_empty_index]
118 | v = self.kv_cache[1, :, : self.kv_cache_first_empty_index]
119 |
120 | return k, v
121 |
122 | def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
123 | """
124 | Performs attention using the torch.nn.functional.scaled_dot_product_attention function.
125 |
126 | Args:
127 | c_x: The input tensor.
128 |
129 | Returns:
130 | The output tensor.
131 | """
132 | q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, 1, nh, hs)
133 | q = q.squeeze(2) # (B, T, nh, hs)
134 | k = k.squeeze(2) # (B, T, nh, hs)
135 | v = v.squeeze(2) # (B, T, nh, hs)
136 |
137 | # if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and
138 | # use no mask for the "one time step" parts.
139 | # calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index
140 | is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)
141 |
142 | if self.kv_cache_enabled:
143 | k, v = self._update_kv_cache(q, k, v)
144 |
145 | q = q.transpose(1, 2) # (B, nh, T, hs)
146 | k = k.transpose(1, 2) # (B, nh, T, hs)
147 | v = v.transpose(1, 2) # (B, nh, T, hs)
148 | y = torch.nn.functional.scaled_dot_product_attention(
149 | q,
150 | k,
151 | v,
152 | attn_mask=None,
153 | dropout_p=self.dropout if self.training else 0,
154 | is_causal=is_causal_attn_mask,
155 | ).transpose(
156 | 1, 2
157 | ) # (B, nh, T, hs) -> (B, T, nh, hs)
158 |
159 | return y
160 |
161 | def forward(self, x):
162 | """
163 | Performs the forward pass of the SelfAttention module.
164 |
165 | Args:
166 | x: The input tensor.
167 |
168 | Returns:
169 | The output tensor.
170 | """
171 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
172 |
173 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
174 | c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs)
175 |
176 | # causal self-attention;
177 | if self.attn_kernel_type == "torch_attn":
178 | y = self._torch_attn(c_x)
179 | else:
180 | raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")
181 |
182 | y = y.contiguous().view(B, T, C) # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh)
183 | # output projection
184 | y = self.resid_dropout(self.c_proj(y))
185 | return y
186 |
--------------------------------------------------------------------------------
/fam/llm/layers/combined.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from fam.llm.layers.attn import SelfAttention
4 | from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm
5 |
6 |
7 | class Block(nn.Module):
8 | """
9 | Block class represents a single block in the model.
10 |
11 | Args:
12 | config (object): Configuration object containing parameters for the block.
13 |
14 | Attributes:
15 | ln_1 (object): Layer normalization for the attention layer.
16 | ln_2 (object): Layer normalization for the feed-forward layer.
17 | attn (object): Self-attention layer.
18 | mlp (object): Multi-layer perceptron layer.
19 |
20 | Methods:
21 | forward(x): Performs forward pass through the block.
22 | """
23 |
24 | def __init__(self, config):
25 | super().__init__()
26 | if config.norm_type == "rmsnorm":
27 | if config.rmsnorm_eps is None:
28 | raise Exception("RMSNorm requires rmsnorm_eps to be set")
29 | self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # attn norm
30 | self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # ffn norm
31 | elif config.norm_type == "layernorm":
32 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) # attn norm
33 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) # ffn norm
34 | else:
35 | raise Exception(f"Unknown norm type: {config.norm_type}")
36 | self.attn = SelfAttention(config)
37 |
38 | self.mlp = MLP(config)
39 |
40 | def forward(self, x):
41 | """
42 | Performs forward pass through the block.
43 |
44 | Args:
45 | x (tensor): Input tensor.
46 |
47 | Returns:
48 | tensor: Output tensor after passing through the block.
49 | """
50 | x = x + self.attn(self.ln_1(x))
51 | x = x + self.mlp(self.ln_2(x))
52 | return x
53 |
--------------------------------------------------------------------------------
/fam/llm/layers/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class LayerNorm(nn.Module):
9 | """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
10 |
11 | def __init__(self, ndim, bias):
12 | super().__init__()
13 | self.weight = nn.Parameter(torch.ones(ndim))
14 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
15 |
16 | def forward(self, input):
17 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
18 |
19 |
20 | class RMSNorm(torch.nn.Module):
21 | def __init__(self, ndim: int, eps: float):
22 | super().__init__()
23 | self.eps = eps
24 | self.weight = nn.Parameter(torch.ones(ndim))
25 |
26 | def _norm(self, x):
27 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
28 |
29 | def forward(self, x):
30 | return self._norm(x) * self.weight
31 |
32 |
33 | class SwiGLU(nn.Module):
34 | def __init__(self, in_dim, out_dim, bias) -> None:
35 | super().__init__()
36 | self.w1 = nn.Linear(in_dim, out_dim, bias=bias)
37 | self.w3 = nn.Linear(in_dim, out_dim, bias=bias)
38 |
39 | def forward(self, x):
40 | return F.silu(self.w1(x)) * self.w3(x)
41 |
42 |
43 | class MLP(nn.Module):
44 | def __init__(self, config):
45 | super().__init__()
46 | self.non_linearity = config.nonlinearity_type
47 | hidden_dim = 4 * config.n_embd
48 | if config.nonlinearity_type == "gelu":
49 | self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
50 | self.gelu = nn.GELU()
51 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
52 | elif config.nonlinearity_type == "swiglu":
53 | if config.swiglu_multiple_of is None:
54 | raise Exception("SwiGLU requires swiglu_multiple_of to be set")
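            # SwiGLU hidden size: 2/3 of 4*n_embd, rounded up to a multiple of swiglu_multiple_of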
55 | hidden_dim = int(2 * hidden_dim / 3)
56 | hidden_dim = config.swiglu_multiple_of * math.ceil(hidden_dim / config.swiglu_multiple_of)
57 | # set name to `c_proj` so that the right initialisation gets applied to it in GPT.__init__()
58 | self.swiglu = SwiGLU(config.n_embd, hidden_dim, bias=config.bias)
59 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
60 | else:
61 | raise Exception(f"Unknown nonlinearity type: {config.nonlinearity_type}")
62 | self.dropout = nn.Dropout(config.dropout)
63 |
64 | def forward(self, x):
65 | if self.non_linearity == "gelu":
66 | x = self.c_fc(x)
67 | x = self.gelu(x)
68 | elif self.non_linearity == "swiglu":
69 | x = self.swiglu(x)
70 | x = self.c_proj(x)
71 | x = self.dropout(x)
72 | return x
73 |
--------------------------------------------------------------------------------
/fam/llm/mixins/__init__.py:
--------------------------------------------------------------------------------
1 | from fam.llm.mixins.causal import CausalInferenceMixin
2 | from fam.llm.mixins.non_causal import NonCausalInferenceMixin
3 |
--------------------------------------------------------------------------------
/fam/llm/mixins/causal.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | import tqdm
6 | from torch.nn import functional as F
7 |
8 |
9 | def top_p_sample(prob_dist: torch.Tensor, top_p: float):
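    # nucleus (top-p) sampling: keep the smallest prefix of the sorted distribution whose cumulative mass exceeds top_p,
    # zero out the remaining tail, and renormalise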
10 | sorted_probs, sorted_indices = torch.sort(prob_dist, descending=True, dim=-1)
11 | cum_sum_probs = torch.cumsum(sorted_probs, dim=-1) # (b, vocab_size)
12 |
13 | sorted_indices_to_remove = cum_sum_probs > top_p
14 |
15 | # Shift the indices to the right to keep also the first token above the threshold
16 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
17 | sorted_indices_to_remove[:, 0] = 0
18 | sorted_indices_to_remove = sorted_indices_to_remove.bool()
19 |
20 | # replace probs to be removed with 0 in the sorted_probs
21 | sorted_probs[sorted_indices_to_remove] = 0
22 |
23 | # reverse the sorting process
24 | reversed_indices = torch.argsort(sorted_indices)
25 | prob_dist = torch.gather(sorted_probs, -1, reversed_indices)
26 |
27 | # normalize
28 | prob_dist = prob_dist / prob_dist.sum(dim=-1, keepdim=True)
29 |
30 | return prob_dist
31 |
32 |
33 | class CausalInferenceMixin:
34 | """
35 | Mixin class for performing inference in a causal language model.
36 |
37 | This mixin provides methods for predicting the next token in a sequence, sampling from the model,
38 | and applying token prediction masks.
39 |
40 | Attributes:
41 | None
42 |
43 | Methods:
44 | _sample_next_token: Predicts the next token in the sequence.
45 | _create_token_pred_mask: Creates a token prediction mask based on sequence lengths.
46 | _apply_token_pred_mask: Applies a token prediction mask to the next token predictions.
47 | _sample_batch: Samples a batch of tokens from the model.
48 | _sort_for_batching: Sorts the input sequences for efficient batching.
49 | _causal_sample: Generates a sequence of tokens using causal sampling.
50 |
51 | """
52 |
53 | @torch.no_grad()
54 | def _sample_next_token(
55 | self,
56 | *,
57 | idx: torch.Tensor,
58 | speaker_embs: Optional[torch.Tensor],
59 | temperature: float,
60 | top_k: Optional[int],
61 | top_p: Optional[float],
62 | guidance_scale: Optional[float],
63 | ) -> torch.Tensor:
64 | """
65 | Predict the next token in the sequence.
66 |
67 | Args:
68 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
69 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
70 | temperature (float): Sampling temperature.
71 | top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
72 | top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
73 | guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.
74 |
75 | Returns:
76 | torch.Tensor: Next index in the sequence after sampling. Shape: (batch, num_hierarchies).
77 | """
78 | if top_k is not None and top_p is not None:
79 | raise ValueError("Only one of top_k and top_p can be set")
80 |
81 | # if the sequence context is growing too long we must crop it at block_size
82 | idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, :, -self.config.block_size :]
83 |
84 | # forward the model to get the logits for the index in the sequence
85 | list_logits, _ = self(
86 | idx_cond, speaker_embs=speaker_embs
87 | ) # list with len num_hierarchies of (b,1,vocab_size) tensors
88 | # print(f'{list_logits[0].shape=}, {len(list_logits)=}')
89 | # print(f'{list_logits[0][:,:,:10]}')
90 |
91 | if guidance_scale is not None:
92 | assert idx_cond.shape[0] % 2 == 0
93 | assert list_logits[0].shape[0] % 2 == 0
94 |
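            # classifier-free guidance: the first half of the batch is speaker-conditioned, the second half unconditional;
            # per-hierarchy logits are mixed as guidance_scale * cond + (1 - guidance_scale) * uncond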
95 | for i, logits in enumerate(list_logits):
96 | logits_cond, logits_uncond = logits.split(logits.shape[0] // 2, dim=0)
97 | list_logits[i] = (guidance_scale) * logits_cond + (1 - guidance_scale) * logits_uncond
98 |
99 | assert list_logits[0].shape[0] == idx_cond.shape[0] // 2
100 |
101 | # pluck the logits at the final step and scale by desired temperature
102 | list_logits = [
103 | logits[:, -1, :] / temperature for logits in list_logits
104 | ] # list with len num_hierarchies of (b,vocab_size) tensors
105 |
106 | # optionally crop the logits to only the top k options
107 | if top_k is not None:
108 | for i in range(len(list_logits)):
109 | logits = list_logits[i]
110 | v, _ = torch.topk(
111 | logits, min(top_k, logits.size(-1))
112 | ) # returns a descending sorted list of values and indices of top_k values
113 | logits[logits < v[:, [-1]]] = -float("Inf") # set all logits below the smallest top_k value to -Inf
114 | list_logits[i] = logits
115 |
116 | # apply softmax to convert logits to (normalized) probabilities
117 | # embed()
118 | probs = [
119 | F.softmax(logits, dim=-1) for logits in list_logits
120 | ] # list of len num_hierarchies of (b,vocab_size) tensors
121 | # print(f'{probs[0].shape=}')
122 | # print(f'{probs[0][:,:,:10]}')
123 | if top_p is not None:
124 | for i in range(len(probs)):
125 | probs[i] = top_p_sample(probs[i], top_p)
126 |
127 | # sample from the distribution
128 | idx_next = [
129 | torch.multinomial(prob, num_samples=1) for prob in probs
130 | ] # list of len num_hierarchies of (b,1) tensors
131 | idx_next = torch.cat(idx_next, dim=-1) # (b, num_hierarchies) tensor
132 |
133 | return idx_next # (b, num_hierarchies) tensor
134 |
135 | @torch.no_grad()
136 | def _create_token_pred_mask(self, idx: torch.Tensor, seq_lens: list[int]) -> torch.Tensor:
137 | """
138 | Creates a token prediction mask based on sequence lengths.
139 |
140 | Args:
141 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
142 | seq_lens (list[int]): List of sequence lengths for each sequence in idx.
143 |
144 | Returns:
145 | torch.Tensor: Token prediction mask of shape (batch, time).
146 | """
147 | token_pred_mask = torch.zeros((idx.shape[0], idx.shape[-1]), dtype=torch.bool, device=idx.device)
148 | for i in range(len(seq_lens)):
149 | token_pred_mask[i, : seq_lens[i]] = True
150 |
151 | assert (token_pred_mask[:, : min(seq_lens)] == 1).all()
152 |
153 | return token_pred_mask
154 |
155 | @torch.no_grad()
156 | def _apply_token_pred_mask(
157 | self, *, idx_next: torch.Tensor, orig_input_at_t: torch.Tensor, token_pred_mask_at_t: torch.Tensor
158 | ) -> torch.Tensor:
159 | """
160 | Applies a token prediction mask to the next token predictions.
161 |
162 | Args:
163 | idx_next (torch.Tensor): Next token predictions of shape (batch, num_hierarchies).
164 | orig_input_at_t (torch.Tensor): Original input at time step t of shape (batch, num_hierarchies).
165 | token_pred_mask_at_t (torch.Tensor): Token prediction mask at time step t of shape (batch, 1).
166 |
167 | Returns:
168 | torch.Tensor: Updated next token predictions after applying the token prediction mask.
169 | """
170 | idx_next = idx_next * (~token_pred_mask_at_t) + orig_input_at_t * token_pred_mask_at_t
171 |
172 | return idx_next
173 |
174 | @torch.no_grad()
175 | def _sample_batch(
176 | self,
177 | *,
178 | idx: torch.Tensor,
179 | max_new_tokens: int,
180 | seq_lens: list[int],
181 | temperature: float,
182 | top_k: Optional[int],
183 | top_p: Optional[float],
184 | speaker_embs: Optional[torch.Tensor],
185 | guidance_scale: Optional[float],
186 | ):
187 | """
188 | Samples a batch of tokens from the model.
189 |
190 | Args:
191 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
192 | max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
193 | seq_lens (list[int]): List of sequence lengths for each sequence in idx.
194 | temperature (float): Sampling temperature.
195 | top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
196 | top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
197 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
198 | guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.
199 |
200 | Returns:
201 | torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time).
202 | """
203 | assert max(seq_lens) <= idx.shape[-1]
204 | token_pred_mask = self._create_token_pred_mask(idx, seq_lens)
205 | input = torch.clone(idx)
206 |
207 | min_seq_lens = min(seq_lens)
208 | idx = idx[:, :, :min_seq_lens]
209 |
210 | if guidance_scale is not None:
211 | if speaker_embs is None:
212 | raise Exception("Guidance is only supported for conditional models")
213 |
214 | # create speaker embeddings equivalent to the batch size, filling with None
215 | # for second half to do unconditional generation.
216 | speaker_embs = list(speaker_embs) + [None] * (speaker_embs.shape[0])
217 |
218 | for timestep in tqdm.tqdm(range(min_seq_lens, min_seq_lens + max_new_tokens), desc="tokens: "):
219 | if (self.kv_cache_enabled is True) and (timestep > min_seq_lens):
220 | idx_input = idx[:, :, -1:]
221 | else:
222 | idx_input = idx
223 |
224 | if guidance_scale is not None:
225 | # TODO: fix: will cause a problem with kv-caching as it's not expecting larger batch-size.
226 | if timestep == min_seq_lens:
227 | print("[hack!!!!] Guidance is on, so we're doubling batch size!")
228 |
229 | # replicate idx in the batch dimension
230 | idx_input = (
231 | idx_input.unsqueeze(0).repeat(2, 1, 1, 1).reshape(-1, idx_input.shape[1], idx_input.shape[2])
232 | )
233 |
234 | # sanity checks
235 | assert idx_input.shape[0] % 2 == 0
236 |
237 | idx_next = self._sample_next_token(
238 | idx=idx_input,
239 | speaker_embs=speaker_embs,
240 | temperature=temperature,
241 | top_k=top_k,
242 | top_p=top_p,
243 | guidance_scale=guidance_scale,
244 | ) # (b, num_hierarchies)
245 |
246 | assert idx_next.shape[0] == idx.shape[0]
247 |
248 | if timestep < token_pred_mask.shape[-1]:
249 | idx_next = self._apply_token_pred_mask(
250 | idx_next=idx_next,
251 | orig_input_at_t=input[:, :, timestep],
252 | token_pred_mask_at_t=token_pred_mask[:, [timestep]],
253 | )
254 |
255 | idx_next = idx_next.unsqueeze(-1) # (b, num_hierarchies, T=1) tensor
256 | # append sampled index to the running sequence and continue
257 | idx = torch.cat((idx, idx_next), dim=2)
258 |
259 | return idx
260 |
261 | @torch.no_grad()
262 | def _sort_for_batching(
263 | self,
264 | *,
265 | idx: torch.Tensor,
266 | seq_lens: list[int],
267 | speaker_embs: Optional[torch.Tensor],
268 | batch_size: int,
269 | max_new_tokens: int,
270 | ) -> Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]:
271 | """
272 | Sorts the input sequences for efficient batching.
273 |
274 | Args:
275 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
276 | seq_lens (list[int]): List of sequence lengths for each sequence in idx.
277 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
278 | batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling.
279 | max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
280 |
281 | Returns:
282 | Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]:
283 | - sorted_indices (list[int]): List of indices of the input sequences that transform it into sorted order.
284 | - invert_sorted_indices (list[int]): List of indices to invert the sorted sequences back to the original order.
285 | - idx (torch.Tensor): Input sequence indices in sorted order.
286 | - seq_lens (list[int]): Sequence lengths in sorted order.
287 | - speaker_embs (Optional[torch.Tensor]): speaker embeddings in sorted order.
288 | - max_token_len (int): Effective maximum number of tokens to generate.
289 | """
290 | assert len(seq_lens) == idx.shape[0]
291 | assert max(seq_lens) <= idx.shape[-1]
292 |
293 | sorted_indices = np.argsort(seq_lens)
294 | inverted_sorted_indices = np.zeros(len(seq_lens), dtype=np.int32)
295 | inverted_sorted_indices[sorted_indices] = np.arange(len(seq_lens), dtype=np.int32)
296 |
297 | idx = idx[sorted_indices]
298 | seq_lens = [seq_lens[i] for i in sorted_indices]
299 | speaker_embs = speaker_embs[sorted_indices] if speaker_embs is not None else None
300 | max_token_len = 0
301 |
302 | # figure out effective max_tokens to generate
303 | for start_index in range(0, len(seq_lens), batch_size):
304 | end_index = min(start_index + batch_size, len(seq_lens))
305 | batch_seq_lens = seq_lens[start_index:end_index]
306 | # random heuristic...
307 | # # TODO: fix!
308 | max_token_len = max(max_token_len, min(batch_seq_lens) + max_new_tokens)
309 |
310 | return sorted_indices, inverted_sorted_indices, idx, seq_lens, speaker_embs, max_token_len
311 |
312 | @torch.no_grad()
313 | def _causal_sample(
314 | self,
315 | *,
316 | idx: torch.Tensor,
317 | max_new_tokens: int,
318 | seq_lens: list[int],
319 | temperature: float,
320 | top_k: Optional[int],
321 | top_p: Optional[float],
322 | speaker_embs: Optional[torch.Tensor],
323 | batch_size: int,
324 | guidance_scale: Optional[float] = None,
325 | ) -> torch.Tensor:
326 | """
327 | Generates a sequence of tokens using causal sampling.
328 |
329 | Args:
330 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
331 | max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
332 | seq_lens (list[int]): List of sequence lengths for each sequence in idx.
333 | temperature (float): Sampling temperature.
334 | top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
335 | top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
336 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
337 | batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling.
338 | guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.
339 |
340 | Returns:
341 | torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time).
342 | """
343 | (
344 | _,
345 | invert_sorted_indices,
346 | idx,
347 | seq_lens,
348 | speaker_embs,
349 | max_token_len,
350 | ) = self._sort_for_batching(
351 | idx=idx, seq_lens=seq_lens, speaker_embs=speaker_embs, batch_size=batch_size, max_new_tokens=max_new_tokens
352 | )
353 |
354 | return_idx = torch.zeros((len(seq_lens), idx.size(1), max_token_len), dtype=torch.long, device=idx.device)
355 |
356 | for start_index in tqdm.tqdm(range(0, len(seq_lens), batch_size), desc="batch: "):
357 | end_index = min(start_index + batch_size, len(seq_lens))
358 |
359 | kv_batch_size = end_index - start_index
360 | if guidance_scale is not None:
361 | kv_batch_size = 2 * kv_batch_size
362 |
363 | if self.kv_cache_enabled:
364 | print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
365 | self.empty_kv_cache(
366 | batch_size=kv_batch_size,
367 | kv_cache_maxlen=self.config.block_size,
368 | dtype=torch.bfloat16,
369 | )
370 |
371 | batch_seq_lens = seq_lens[start_index:end_index]
372 | batch_max_new_tokens = max_token_len - min(batch_seq_lens)
373 |
374 | batch_idx = idx[start_index:end_index]
375 | batch_speaker_embs = speaker_embs[start_index:end_index] if speaker_embs is not None else None
376 |
377 | batch_idx = self._sample_batch(
378 | idx=batch_idx,
379 | max_new_tokens=batch_max_new_tokens,
380 | seq_lens=batch_seq_lens,
381 | temperature=temperature,
382 | top_k=top_k,
383 | top_p=top_p,
384 | speaker_embs=batch_speaker_embs,
385 | guidance_scale=guidance_scale,
386 | )
387 | return_idx[start_index:end_index] = batch_idx
388 |
389 | return return_idx[invert_sorted_indices]
390 |
391 | def empty_kv_cache(self, *, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
392 | """
393 | Empties key-value (KV) cache for causal attention.
394 |
395 | Args:
396 | batch_size (int): The batch size.
397 | kv_cache_maxlen (int): The maximum length of the KV cache.
398 | dtype (torch.dtype): The data type of the KV cache.
399 |
400 | Raises:
401 | Exception: If KV cache is enabled for non-causal attention.
402 |
403 | """
404 | if self.kv_cache_enabled is False:
405 | raise Exception("KV cache is not enabled")
406 | if self.config.causal is False:
407 | raise Exception("KV cache is not supported for non-causal attention")
408 |
409 | self.kv_pos = 0
410 | for block in self.transformer.h:
411 | block.attn.empty_kv_cache(batch_size=batch_size, kv_cache_maxlen=kv_cache_maxlen, dtype=dtype)
412 |
413 | def enable_kv_cache(self):
414 | """
415 | Enables key-value (KV) cache for causal attention.
416 |
417 | Raises:
418 | Exception: If KV cache is enabled for non-causal attention.
419 |
420 | """
421 | if self.config.causal is False:
422 | raise Exception("KV cache is not supported for non-causal attention")
423 |
424 | self.kv_cache_enabled = True
425 | for block in self.transformer.h:
426 | block.attn.kv_cache_enabled = True
427 |
428 | def disable_kv_cache(self):
429 | """
430 | Disables the key-value cache for the transformer and all its blocks.
431 | """
432 | self.kv_cache_enabled = False
433 | for block in self.transformer.h:
434 | block.attn.kv_cache_enabled = False
435 | block.attn.kv_cache = None
436 | block.attn.kv_cache_first_empty_index = 0
437 |
438 | @torch.no_grad()
439 | def _slow_causal_sampling_loop(
440 | self,
441 | idx: torch.Tensor,
442 | max_new_tokens: int,
443 | temperature: float = 1.0,
444 | top_k: Optional[int] = None,
445 | top_p: Optional[float] = None,
446 | speaker_embs: Optional[torch.Tensor] = None,
447 | guidance_scale: Optional[float] = None,
448 | ):
449 | """
450 | Old non-batched version of causal sampling. Kept for testing / reference.
451 |
452 | Take a conditioning sequence of indices idx (LongTensor of shape (b,n_head,t)) and complete
453 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
454 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
455 | """
456 | assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"
457 | assert idx.size(0) == 1, "can only do one sequence at a time for now"
458 | assert top_p is None, "nucleus sampling not supported yet with _slow_causal_sampling_loop"
459 |
460 | if self.config.causal is not True:
461 | raise Exception("Causal sampling is only supported for causal models")
462 |
463 | if self.kv_cache_enabled:
464 | print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
465 | self.empty_kv_cache(
466 | batch_size=1,
467 | kv_cache_maxlen=self.config.block_size,
468 | dtype=torch.bfloat16,
469 | )
470 |
471 | for i in range(max_new_tokens):
472 | # if the sequence context is growing too long we must crop it at block_size
473 | idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, -self.config.block_size :]
474 |
475 | if self.kv_cache_enabled:
476 | if i > 0:
477 | idx_cond = idx_cond[:, :, -1:]
478 |
479 | # forward the model to get the logits for the index in the sequence
480 | list_logits, _ = self(idx_cond, speaker_embs=speaker_embs)
481 |
482 | if guidance_scale is not None:
483 | # we've already checked that kv-caching is not switched on
484 | # so this should be ok.
485 | list_logits_uncond, _ = self(idx_cond, speaker_embs=None)
486 | list_logits = [
487 | (guidance_scale) * logits + (1 - guidance_scale) * logits_uncond
488 | for logits, logits_uncond in zip(list_logits, list_logits_uncond)
489 | ]
490 |
491 | # pluck the logits at the final step and scale by desired temperature
492 | list_logits = [logits[:, -1, :] / temperature for logits in list_logits]
493 |
494 | # optionally crop the logits to only the top k options
495 | if top_k is not None:
496 | for i in range(len(list_logits)):
497 | logits = list_logits[i]
498 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
499 | logits[logits < v[:, [-1]]] = -float("Inf")
500 | list_logits[i] = logits
501 |
502 | # apply softmax to convert logits to (normalized) probabilities
503 | probs = [F.softmax(logits, dim=-1) for logits in list_logits]
504 | # sample from the distribution
505 | idx_next = torch.tensor(
506 | [torch.multinomial(prob, num_samples=1) for prob in probs], device=idx.device
507 | ) # (c, 1)
508 | # append sampled index to the running sequence and continue
509 | idx = torch.cat((idx, idx_next.unsqueeze(0).unsqueeze(-1)), dim=2)
510 |
511 | return idx
512 |
--------------------------------------------------------------------------------
/fam/llm/mixins/non_causal.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | class NonCausalInferenceMixin:
8 | """
9 | Mixin class for non-causal inference in a language model.
10 |
11 | This class provides methods for performing non-causal sampling using a language model.
12 | """
13 |
14 | @torch.no_grad()
15 | def _non_causal_sample(
16 | self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: int
17 | ):
18 | """
19 | Perform non-causal sampling.
20 |
21 | Args:
22 | idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length).
23 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings tensor of shape (batch_size, embedding_size).
24 | temperature (float): Temperature parameter for scaling the logits.
25 | top_k (int): Number of top options to consider.
26 |
27 | Returns:
28 | torch.Tensor: Sampled output tensor of shape (batch_size, num_out_hierarchies, sequence_length).
29 | """
30 | b, c, t = idx.size()
31 | assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}"
32 | # forward the model to get the logits for the index in the sequence
33 | list_logits, _ = self(idx, speaker_embs=speaker_embs) # c x (b, t, vocab_size)
34 |
35 | # scale by desired temperature
36 | list_logits = [logits / temperature for logits in list_logits] # c x (b, t, vocab_size)
37 |
38 | # optionally crop the logits to only the top k options
39 | if top_k is not None:
40 | for i in range(len(list_logits)):
41 | logits = list_logits[i] # (b, t, vocab_size)
42 |
43 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) # (b, t, top_k)
44 | logits[logits < v[:, :, [-1]]] = -float("Inf")
45 | list_logits[i] = logits # (b, t, vocab_size)
46 | assert logits.shape[0] == b and logits.shape[1] == t
47 |
48 | # apply softmax to convert logits to (normalized) probabilities
49 | # TODO: check shapes here!
50 |         probs = [F.softmax(logits, dim=-1) for logits in list_logits] # c x (b, t, vocab_size)
51 | assert probs[0].shape[0] == b and probs[0].shape[1] == t
52 |
53 |         # TODO: check that the output shape is as expected
54 | outs = []
55 | for b_prob in probs: # c x (b, t, top_k) -> (b, t, top_k)
56 | out = [
57 | torch.multinomial(prob, num_samples=1).transpose(0, 1).unsqueeze(0) for prob in b_prob
58 | ] # b x (t, top_k) -> b x (t, 1) -> b x (1, t) -> b x (1, 1, t)
59 | assert len(out) == b and out[0].shape[0] == 1 and out[0].shape[1] == 1 and out[0].shape[2] == t
60 | out = torch.cat(out, dim=0) # (b, 1, t)
61 | assert out.shape[0] == b and out.shape[1] == 1 and out.shape[2] == t
62 | outs.append(out)
63 |
64 | out = torch.cat(outs, dim=1) # (b, c, t)
65 | assert out.shape[0] == b and out.shape[2] == t
66 |
67 | return out
68 |
--------------------------------------------------------------------------------
/fam/llm/model.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import math
3 | from dataclasses import dataclass, field
4 | from typing import Literal, Optional, Union
5 |
6 | import torch
7 | import torch.nn as nn
8 | import tqdm
9 | from einops import rearrange
10 | from torch.nn import functional as F
11 |
12 | from fam.llm.layers import Block, LayerNorm, RMSNorm
13 | from fam.llm.mixins import CausalInferenceMixin, NonCausalInferenceMixin
14 |
15 | from IPython import embed
16 | END_OF_TEXT_TOKEN = 1537
17 |
18 |
19 | def _select_spkemb(spkemb, mask):
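    # `mask` holds, per time step, the index of the utterance whose speaker embedding applies; a one-hot bmm gathers
    # those embeddings into shape (batch, time, channels)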
20 | _, examples, _ = spkemb.shape
21 | mask = torch.nn.functional.one_hot(mask.long(), num_classes=examples).to(spkemb) # shape: (batch, time, examples)
22 | spkemb = spkemb.transpose(1, 2) # b ex c -> b c ex
23 | mask = mask.transpose(1, 2) # b t ex -> b ex t
24 | return torch.bmm(spkemb, mask).transpose(1, 2) # b c t -> b t c
25 |
26 |
27 | @dataclass
28 | class GPTConfig:
29 | block_size: int = 1024
30 | vocab_sizes: list = field(default_factory=list)
31 | target_vocab_sizes: Optional[list] = None
32 | n_layer: int = 12
33 | n_head: int = 12
34 | n_embd: int = 768
35 | dropout: float = 0.0
36 | spkemb_dropout: float = 0.0
37 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
38 | causal: bool = (
39 | True # auto-regressive or not, i.e. whether to have attention mask that prevents attending to future tokens
40 | )
41 | spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not
42 |     norm_type: str = "layernorm" # "rmsnorm" or "layernorm"
43 | rmsnorm_eps: Optional[float] = None # only used for rmsnorm
44 | nonlinearity_type: str = "gelu" # "gelu" or "swiglu"
45 | swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this
46 | attn_kernel_type: Literal["torch_attn"] = "torch_attn"
47 | #Literal["fa2", "torch_attn", "hand"] = "fa2"
48 | kv_cache_enabled: bool = False # whether to use key-value cache for attention
49 |
50 |
51 | def _check_speaker_emb_dims(
52 | speaker_embs: Union[list, torch.Tensor], expected_speaker_emb_dim: int, expected_batch_size: int
53 | ) -> Union[torch.Tensor, list]:
54 | """
55 | Checks that the speaker embedding dimensions are correct, and reshapes them if necessary.
56 | """
57 | if type(speaker_embs) == list:
58 | b_se = len(speaker_embs)
59 | for i, s in enumerate(speaker_embs):
60 | if s is not None:
61 | emb_dim = s.shape[-1]
62 | if s.ndim == 1:
63 | speaker_embs[i] = speaker_embs[i].unsqueeze(0)
64 | else:
65 | if speaker_embs.ndim == 2:
66 | # if we have a single speaker embedding for the whole sequence,
67 | # add a dummy dimension for backwards compatibility
68 | speaker_embs = speaker_embs[:, None, :]
69 |
70 | # num_examples is the number of utterances packed into this sequence
71 | b_se, num_examples, emb_dim = speaker_embs.size()
72 |
73 | assert b_se == expected_batch_size, f"Batch size mismatch: {b_se} != {expected_batch_size}"
74 | assert (
75 | emb_dim == expected_speaker_emb_dim
76 | ), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}"
77 |
78 | return speaker_embs
79 |
80 |
81 | class GPT(nn.Module, NonCausalInferenceMixin, CausalInferenceMixin):
82 | def __init__(self, config: GPTConfig, speaker_emb_dim: Optional[int] = None):
83 | """
84 | Initialize the GPT model.
85 |
86 | Args:
87 | config (GPTConfig): Configuration object for the model.
88 | speaker_emb_dim (Optional[int]): Dimension of the speaker embedding. Default is None.
89 | """
90 | super().__init__()
91 | assert config.vocab_sizes is not None
92 | assert config.block_size is not None
93 | self.config = config
94 |
95 | self.kv_cache_enabled = False # disabled by default
96 | self.kv_pos = 0
97 |
98 | self.speaker_emb_dim = speaker_emb_dim
99 | self.spk_emb_on_text = config.spk_emb_on_text
100 | if self.config.causal is True and self.spk_emb_on_text is False:
101 | print("!!!!!!!!!!!!!!!!!!")
102 | print(
103 | f"!!!!!!!! Using DEFAULT of {END_OF_TEXT_TOKEN} as end of text token to find speaker cond masking!! You likely need to change this."
104 | )
105 | print("!!!!!!!!!!!!!!!!!!")
106 | if self.config.causal is False and self.spk_emb_on_text is False:
107 | raise Exception(
108 | "Cannot use speaker embedding masking with non-causal model. This is unexpected. Check for relevant changes required in code before proceeding."
109 | )
110 |
111 | if config.norm_type == "rmsnorm":
112 | if config.rmsnorm_eps is None:
113 | raise Exception("RMSNorm requires rmsnorm_eps to be set")
114 | ln_f = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
115 | elif config.norm_type == "layernorm":
116 | ln_f = LayerNorm(config.n_embd, bias=config.bias)
117 | else:
118 | raise Exception(f"Unknown norm type: {config.norm_type}")
119 |
120 | self.transformer = nn.ModuleDict(
121 | dict(
122 | wtes=nn.ModuleList([nn.Embedding(vsize, config.n_embd,) for vsize in config.vocab_sizes]),
123 | wpe=nn.Embedding(config.block_size, config.n_embd),
124 | drop=nn.Dropout(config.dropout),
125 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
126 | ln_f=ln_f,
127 | )
128 | )
129 | if speaker_emb_dim is not None:
130 |             self.speaker_cond_pos = nn.Linear(speaker_emb_dim, config.n_embd, bias=False) # speaker embedding projection, e.g. 256 -> 2048, happens here
131 |
132 | self.lm_heads = nn.ModuleList()
133 | if config.target_vocab_sizes is not None:
134 | assert config.causal is False
135 | else:
136 | assert config.causal is True
137 |
138 | for vsize in config.vocab_sizes if config.target_vocab_sizes is None else config.target_vocab_sizes:
139 | self.lm_heads.append(nn.Linear(config.n_embd, vsize, bias=False))
140 |
141 | if config.target_vocab_sizes is None:
142 | for i in range(len(config.vocab_sizes)):
143 | # TODO: do we not need to take the transpose here?
144 | # https://paperswithcode.com/method/weight-tying
145 | self.lm_heads[i].weight = self.transformer.wtes[i].weight # type: ignore
146 | assert len(self.lm_heads) == len(
147 | self.transformer.wtes # type: ignore
148 |         ), f"Number of heads ({len(self.lm_heads)}) must match number of one-hot embedding matrices ({len(self.transformer.wtes)})."  # type: ignore
149 | # - causal
150 | # GPT(
151 | # (transformer): ModuleDict(
152 | # (wtes): ModuleList(
153 | # (0): Embedding(2562, 2048)
154 | # )
155 | # (wpe): Embedding(2048, 2048)
156 | # (drop): Dropout(p=0.0, inplace=False)
157 | # (h): ModuleList(
158 | # (0-23): 24 x Block(
159 | # (ln_1): RMSNorm()
160 | # (ln_2): RMSNorm()
161 | # (attn): SelfAttention(
162 | # (c_attn): Linear(in_features=2048, out_features=6144, bias=False)
163 | # (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
164 | # (resid_dropout): Dropout(p=0.0, inplace=False)
165 | # )
166 | # (mlp): MLP(
167 | # (swiglu): SwiGLU(
168 | # (w1): Linear(in_features=2048, out_features=5632, bias=False)
169 | # (w3): Linear(in_features=2048, out_features=5632, bias=False)
170 | # )
171 | # (c_proj): Linear(in_features=5632, out_features=2048, bias=False)
172 | # (dropout): Dropout(p=0.0, inplace=False)
173 | # )
174 | # )
175 | # )
176 | # (ln_f): RMSNorm()
177 | # )
178 | # (speaker_cond_pos): Linear(in_features=256, out_features=2048, bias=False)
179 | # (lm_heads): ModuleList(
180 | # (0): Linear(in_features=2048, out_features=2562, bias=False)
181 | # )
182 | # )
183 | # GPTConfig(block_size=2048, vocab_sizes=[2562], target_vocab_sizes=None, n_layer=24, n_head=16, n_embd=2048, dropout=0.0, spkemb_dropout=0.1, bias=False, causal=True, spk_emb_on_text=True, norm_type='rmsnorm', rmsnorm_eps=1e-05, nonlinearity_type='swiglu', swiglu_multiple_of=256, attn_kernel_type='torch_attn', kv_cache_enabled=False)
184 | #
185 | # - non causal
186 | # GPT(
187 | # (transformer): ModuleDict(
188 | # (wtes): ModuleList(
189 | # (0): Embedding(1538, 384)
190 | # (1): Embedding(1025, 384)
191 | # )
192 | # (wpe): Embedding(1024, 384)
193 | # (drop): Dropout(p=0.0, inplace=False)
194 | # (h): ModuleList(
195 | # (0-5): 6 x Block(
196 | # (ln_1): LayerNorm()
197 | # (ln_2): LayerNorm()
198 | # (attn): SelfAttention(
199 | # (c_attn): Linear(in_features=384, out_features=1152, bias=False)
200 | # (c_proj): Linear(in_features=384, out_features=384, bias=False)
201 | # (resid_dropout): Dropout(p=0.0, inplace=False)
202 | # )
203 | # (mlp): MLP(
204 | # (c_fc): Linear(in_features=384, out_features=1536, bias=False)
205 | # (gelu): GELU(approximate='none')
206 | # (c_proj): Linear(in_features=1536, out_features=384, bias=False)
207 | # (dropout): Dropout(p=0.0, inplace=False)
208 | # )
209 | # )
210 | # )
211 | # (ln_f): LayerNorm()
212 | # )
213 | # (speaker_cond_pos): Linear(in_features=256, out_features=384, bias=False)
214 | # (lm_heads): ModuleList(
215 | # (0-5): 6 x Linear(in_features=384, out_features=1025, bias=False)
216 | # )
217 | # )
218 | # GPTConfig(block_size=1024, vocab_sizes=[1538, 1025], target_vocab_sizes=[1025, 1025, 1025, 1025, 1025, 1025], n_layer=6, n_head=6, n_embd=384, dropout=0.0, spkemb_dropout=0.0, bias=False, causal=False, spk_emb_on_text=True, norm_type='layernorm', rmsnorm_eps=None, nonlinearity_type='gelu', swiglu_multiple_of=None, attn_kernel_type='fa2', kv_cache_enabled=False)
219 | # if config.causal is False:
220 | # embed()
221 | # init all weights
222 | self.apply(self._init_weights)
223 | # apply special scaled init to the residual projections, per GPT-2 paper
224 | for pn, p in self.named_parameters():
225 | if pn.endswith("c_proj.weight"):
226 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
227 |
228 | # report number of parameters
229 | print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
230 |
231 | def get_num_params(self, non_embedding=True):
232 | """
233 | Return the number of parameters in the model.
234 | For non-embedding count (default), the position embeddings get subtracted.
235 | The token embeddings would too, except due to the parameter sharing these
236 | params are actually used as weights in the final layer, so we include them.
237 | """
238 | n_params = sum(p.numel() for p in self.parameters())
239 | if non_embedding:
240 | n_params -= self.transformer.wpe.weight.numel()
241 | return n_params
242 |
243 | def _init_weights(self, module):
244 | if isinstance(module, nn.Linear):
245 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
246 | if module.bias is not None:
247 | torch.nn.init.zeros_(module.bias)
248 | elif isinstance(module, nn.Embedding):
249 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
250 |
251 | def _mask_spk_emb_on_text(self, idx: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor:
252 | """
253 | This is in a separate function so we can test it easily.
254 | """
255 |         # find the index of the end-of-text token in each sequence, then build a binary mask
256 |         # of shape (b, t, 1) that zeroes the speaker embedding for all tokens before the end-of-text token.
257 |         # Note: the end-of-text token itself is NOT masked. This is important so that the first audio token
258 |         # predicted still has speaker information to use.
259 |
260 | # Check in channel dimension 0 as this is usually the first hierarchy where we put the text tokens.
261 | is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN
262 | # use > 0, in case end_of_text_token is repeated for any reason.
263 | mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float()
264 | spk_emb = spk_emb * mask[:, :, None]
265 |
266 | return spk_emb
267 |
268 | def forward(
269 | self,
270 | idx,
271 | targets=None,
272 | speaker_embs=None,
273 | embedding=None,
274 | speaker_emb_mask=None,
275 | loss_reduce: Literal["mean", "none"] = "mean",
276 | ):
277 | device = idx.device
278 | b, num_hierarchies, t = idx.size()
279 |
280 | if speaker_embs is not None:
281 | speaker_embs = _check_speaker_emb_dims(
282 | speaker_embs=speaker_embs, expected_speaker_emb_dim=self.speaker_emb_dim, expected_batch_size=b
283 | )
284 |
285 | assert (
286 | t <= self.config.block_size
287 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
288 |
289 | if self.kv_cache_enabled:
290 | if self.kv_pos == 0:
291 | pos = torch.arange(0, t, dtype=torch.long, device=device)
292 | self.kv_pos += t
293 | else:
294 | assert t == 1, "KV cache is only supported for single token inputs"
295 | pos = torch.tensor([self.kv_pos], dtype=torch.long, device=device) # shape (1)
296 | self.kv_pos += 1
297 | else:
298 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
299 | # embed the tokens, positional encoding, and speaker embedding
300 | tok_emb = torch.zeros((b, t, self.config.n_embd), device=device)
301 | # ends up swapping (B, num_hierarchies, t) tokens -> (B, t, c) embeddings.
302 | wte = self.transformer.wtes[0]
303 | for i, wte in enumerate(self.transformer.wtes):
304 | mask_pad = idx[:, i, :] == -1
305 | masked_idx = idx[:, i, :].clone()
306 | masked_idx[mask_pad] = 0
307 | embedded_idx = wte(masked_idx)
308 | embedded_idx[mask_pad] = 0
309 | tok_emb += embedded_idx
310 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
311 |
312 | spk_emb = 0.0
313 | if speaker_embs is not None:
314 | if type(speaker_embs) == list:
315 | assert speaker_emb_mask is None
316 | assert self.training is False
317 | assert self.spk_emb_on_text is True
318 | spk_emb = []
319 | for speaker_emb_row in speaker_embs:
320 | if speaker_emb_row is not None:
321 | spk_emb.append(self.speaker_cond_pos(speaker_emb_row.unsqueeze(0)))
322 | assert spk_emb[-1].shape == (1, 1, self.config.n_embd), f"spk_emb[-1].shape={spk_emb[-1].shape}"
323 | else:
324 | spk_emb.append(torch.zeros((1, 1, self.config.n_embd), device=device, dtype=pos_emb.dtype))
325 | spk_emb = torch.cat(spk_emb, dim=0)
326 |
327 | assert (
328 | spk_emb.ndim == 3 and spk_emb.shape[1] == 1 and spk_emb.shape[0] == b
329 | ), f"spk_emb.ndim={spk_emb.ndim}, spk_emb.shape={spk_emb.shape}, len(speaker_embs)={len(speaker_embs)}"
330 | else:
331 | speakers_embedded = self.speaker_cond_pos(speaker_embs) # shape (b, num_examples, c)
332 |
333 | if speaker_emb_mask is not None:
334 | spk_emb = _select_spkemb(speakers_embedded, speaker_emb_mask)
335 | assert spk_emb.shape == (b, t, self.config.n_embd)
336 | else:
337 | spk_emb = speakers_embedded
338 | # if we don't have a mask, we assume that the speaker embedding is the same for all tokens
339 | # then num_examples dimension just becomes the time dimension
340 | assert spk_emb.ndim == 3 and spk_emb.shape[1] == 1
341 |
342 | if self.training and self.config.spkemb_dropout > 0.0:
343 | # Remove speaker conditioning at random.
344 | dropout = torch.ones_like(speakers_embedded) * (
345 | torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout
346 | )
347 | spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded)
348 |
349 | if self.spk_emb_on_text is False:
350 | assert speaker_emb_mask is None, "Not implemented for spk_emb_on_text=False"
351 | spk_emb = self._mask_spk_emb_on_text(idx, spk_emb)
352 | elif embedding is not None:
353 | speakers_embedded = self.speaker_cond_pos(embedding)
354 |
355 | if self.training and self.config.spkemb_dropout > 0.0:
356 | # Remove speaker conditioning at random.
357 | dropout = torch.ones_like(speakers_embedded) * (
358 | torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout
359 | )
360 | spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded)
361 | else:
362 | spk_emb = speakers_embedded
363 |
364 | x = self.transformer.drop(tok_emb + pos_emb + spk_emb)
365 | for block in self.transformer.h:
366 | x = block(x)
367 | x = self.transformer.ln_f(x)
368 |
369 | if targets is not None:
370 | # if we are given some desired targets also calculate the loss
371 | list_logits = [lm_head(x) for lm_head in self.lm_heads]
372 |
373 | losses = [
374 | F.cross_entropy(
375 | logits.view(-1, logits.size(-1)),
376 | targets[:, i, :].contiguous().view(-1),
377 | ignore_index=-1,
378 | reduction=loss_reduce,
379 | )
380 | for i, logits in enumerate(list_logits)
381 | ]
382 | losses = torch.stack(losses)
383 | if loss_reduce == "mean":
384 | losses = losses.mean()
385 | else:
386 | losses = rearrange(losses, "h (b t) -> b h t", h=len(self.lm_heads), b=b, t=t)
387 | else:
388 | # inference-time mini-optimization: only forward the lm_head on the very last position
389 | if self.config.causal:
390 | list_logits = [
391 | lm_head(x[:, [-1], :]) for lm_head in self.lm_heads
392 | ] # note: using list [-1] to preserve the time dim
393 | # print(f'{len(list_logits)=}, {list_logits[0].shape=}')
394 | else:
395 | list_logits = [lm_head(x) for lm_head in self.lm_heads]
396 | losses = None
397 |
398 | return list_logits, losses
399 |
400 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
401 | # start with all of the candidate parameters
402 | param_dict = {pn: p for pn, p in self.named_parameters()}
403 | # filter out those that do not require grad
404 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
405 |         # create optim groups. Any parameter that is 2D will be weight-decayed, others won't.
406 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
407 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
408 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
409 | optim_groups = [
410 | {"params": decay_params, "weight_decay": weight_decay},
411 | {"params": nodecay_params, "weight_decay": 0.0},
412 | ]
413 | num_decay_params = sum(p.numel() for p in decay_params)
414 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
415 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
416 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
417 | # Create AdamW optimizer and use the fused version if it is available
418 | fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
419 | use_fused = fused_available and device_type == "cuda"
420 | extra_args = dict(fused=True) if use_fused else dict()
421 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
422 | print(f"using fused AdamW: {use_fused}")
423 |
424 | return optimizer
425 |
426 | @torch.no_grad()
427 | def generate(
428 | self,
429 | idx: torch.Tensor,
430 | max_new_tokens: int,
431 | seq_lens: Optional[list] = None,
432 | temperature: float = 1.0,
433 | top_k: Optional[int] = None,
434 | top_p: Optional[float] = None,
435 | speaker_embs: Optional[torch.Tensor] = None,
436 | batch_size: Optional[int] = None,
437 | guidance_scale: Optional[float] = None,
438 | ):
439 | """
440 | Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete
441 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
442 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
443 | """
444 | assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"
445 |
446 | if self.config.causal:
447 | if seq_lens is None or batch_size is None:
448 | raise Exception("seq_lens and batch_size must be provided for causal sampling")
449 |
450 | return self._causal_sample(
451 | idx=idx,
452 | max_new_tokens=max_new_tokens,
453 | seq_lens=seq_lens,
454 | temperature=temperature,
455 | top_k=top_k,
456 | top_p=top_p,
457 | speaker_embs=speaker_embs,
458 | batch_size=batch_size,
459 | guidance_scale=guidance_scale,
460 | )
461 |
462 | else:
463 | if seq_lens is not None:
464 | raise Exception("seq_lens is not supported yet for non-causal sampling")
465 |
466 | if batch_size is None:
467 | raise Exception("batch_size must be provided for non-causal sampling")
468 |
469 | if guidance_scale is not None:
470 | raise Exception("guidance_scale is not supported for non-causal sampling")
471 |
472 | if top_p is not None:
473 | raise Exception("top_p is not supported for non-causal sampling")
474 |
475 | out = []
476 | for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="non-causal batching"):
477 | end_index = min(start_index + batch_size, idx.shape[0])
478 | out.append(
479 | self._non_causal_sample(
480 | idx=idx[start_index:end_index],
481 | speaker_embs=speaker_embs[start_index:end_index] if speaker_embs is not None else None,
482 | temperature=temperature,
483 | top_k=top_k,
484 | )
485 | )
486 |             return torch.cat(out, dim=0)
487 |
488 |
--------------------------------------------------------------------------------
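Note: the end-of-text masking in _mask_spk_emb_on_text (fam/llm/model.py above) is compact enough that a toy run makes it concrete. A minimal sketch, assuming a made-up END_OF_TEXT_TOKEN value of 2 and tiny shapes; it only reuses the cumsum trick shown in the file and is not part of the repository.

import torch

# Toy re-run of the cumsum mask from GPT._mask_spk_emb_on_text.
# END_OF_TEXT_TOKEN = 2 is a hypothetical value used only for this demo.
END_OF_TEXT_TOKEN = 2
b, n_embd = 1, 4

# one hierarchy of token ids: two text tokens, the end-of-text marker, then audio tokens
idx = torch.tensor([[[7, 8, END_OF_TEXT_TOKEN, 11, 12, 13]]])   # (b, num_hierarchies, t)
spk_emb = torch.ones(b, 1, n_embd)                              # broadcast over time

is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN              # (b, t) bool
mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float()       # 0 before EOT, 1 from EOT onwards
masked = spk_emb * mask[:, :, None]                             # broadcasts to (b, t, n_embd)

print(mask)              # tensor([[0., 0., 1., 1., 1., 1.]])
print(masked[0, :, 0])   # speaker conditioning zeroed on text, kept from the EOT token onwards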
/fam/llm/serving.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import shlex
5 | import subprocess
6 | import tempfile
7 | from pathlib import Path
8 | from typing import Literal, Optional
9 |
10 | import fastapi
11 | import fastapi.middleware.cors
12 | import torch
13 | import tyro
14 | import uvicorn
15 | from attr import dataclass
16 | from fastapi import Request
17 | from fastapi.responses import Response
18 | from huggingface_hub import snapshot_download
19 |
20 | from fam.llm.sample import (
21 | InferenceConfig,
22 | Model,
23 | build_models,
24 | get_first_stage_path,
25 | get_second_stage_path,
26 | # sample_utterance,
27 | )
28 | from fam.llm.fast_inference import TTS
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | ## Setup FastAPI server.
34 | app = fastapi.FastAPI()
35 |
36 |
37 | @dataclass
38 | class ServingConfig:
39 | huggingface_repo_id: str
40 |     """Hugging Face Hub repo id to download the model from."""
41 |
42 | max_new_tokens: int = 864 * 2
43 | """Maximum number of new tokens to generate from the first stage model."""
44 |
45 | temperature: float = 1.0
46 | """Temperature for sampling applied to both models."""
47 |
48 | top_k: int = 200
49 | """Top k for sampling applied to both models."""
50 |
51 | seed: int = 1337
52 | """Random seed for sampling."""
53 |
54 | dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16"
55 | """Data type to use for sampling."""
56 |
57 | enhancer: Optional[Literal["df"]] = "df"
58 | """Enhancer to use for post-processing."""
59 |
60 | port: int = 58003
61 |
62 |
63 | # Singleton
64 | class _GlobalState:
65 | config: ServingConfig
66 | tts: TTS
67 |
68 |
69 | GlobalState = _GlobalState()
70 |
71 | @dataclass(frozen=True)
72 | class TTSRequest:
73 | text: str
74 | guidance: Optional[float] = 3.0
75 | top_p: Optional[float] = 0.95
76 | speaker_ref_path: Optional[str] = None
77 | top_k: Optional[int] = None
78 |
79 |
80 | def sample_utterance(
81 | text: str,
82 | spk_cond_path: str | None,
83 | guidance_scale,
84 | max_new_tokens,
85 | top_k,
86 | top_p,
87 | temperature,
88 | ) -> str:
89 | return GlobalState.tts.synthesise(
90 | text,
91 | spk_cond_path,
92 | top_p=top_p,
93 | guidance_scale=guidance_scale,
94 | temperature=temperature,
95 | )
96 |
97 |
98 | @app.post("/tts", response_class=Response)
99 | async def text_to_speech(req: Request):
100 | audiodata = await req.body()
101 | payload = None
102 | wav_out_path = None
103 |
104 | try:
105 | headers = req.headers
106 | payload = headers["X-Payload"]
107 | payload = json.loads(payload)
108 | tts_req = TTSRequest(**payload)
109 | with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
110 | if tts_req.speaker_ref_path is None:
111 | wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
112 | else:
113 | wav_path = tts_req.speaker_ref_path
114 | wav_out_path = sample_utterance(
115 | tts_req.text,
116 | wav_path,
117 | guidance_scale=tts_req.guidance,
118 | max_new_tokens=GlobalState.config.max_new_tokens,
119 | temperature=GlobalState.config.temperature,
120 | top_k=tts_req.top_k,
121 | top_p=tts_req.top_p,
122 | )
123 | with open(wav_out_path, "rb") as f:
124 | return Response(content=f.read(), media_type="audio/wav")
125 | except Exception as e:
126 | # traceback_str = "".join(traceback.format_tb(e.__traceback__))
127 | logger.exception(f"Error processing request {payload}")
128 | return Response(
129 | content="Something went wrong. Please try again in a few mins or contact us on Discord",
130 | status_code=500,
131 | )
132 | finally:
133 | if wav_out_path is not None:
134 | Path(wav_out_path).unlink(missing_ok=True)
135 |
136 |
137 | def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
138 | with tempfile.NamedTemporaryFile() as unknown_format_tmp:
139 | assert unknown_format_tmp.write(audiodata) > 0
140 | unknown_format_tmp.flush()
141 |
142 | subprocess.check_output(
143 | # arbitrary 2 minute cutoff
144 | shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}")
145 | )
146 |
147 | return wav_tmp.name
148 |
149 |
150 | if __name__ == "__main__":
151 |     # This has to be here to avoid some weird audiocraft shenanigans messing up matplotlib
152 | from fam.llm.enhancers import get_enhancer
153 |
154 | for name in logging.root.manager.loggerDict:
155 | logger = logging.getLogger(name)
156 | logger.setLevel(logging.INFO)
157 | logging.root.setLevel(logging.INFO)
158 |
159 | GlobalState.config = tyro.cli(ServingConfig)
160 | app.add_middleware(
161 | fastapi.middleware.cors.CORSMiddleware,
162 | allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"],
163 | allow_credentials=True,
164 | allow_methods=["*"],
165 | allow_headers=["*"],
166 | )
167 |
168 | device = "cuda" if torch.cuda.is_available() else "cpu"
169 | common_config = dict(
170 | num_samples=1,
171 | seed=1337,
172 | device=device,
173 | dtype=GlobalState.config.dtype,
174 | compile=False,
175 | init_from="resume",
176 | output_dir=tempfile.mkdtemp(),
177 | )
178 | model_dir = snapshot_download(repo_id=GlobalState.config.huggingface_repo_id)
179 | config1 = InferenceConfig(
180 | ckpt_path=get_first_stage_path(model_dir),
181 | **common_config,
182 | )
183 |
184 | config2 = InferenceConfig(
185 | ckpt_path=get_second_stage_path(model_dir),
186 | **common_config,
187 | )
188 |
189 | GlobalState.tts = TTS()
190 |
191 | # start server
192 | uvicorn.run(
193 | app,
194 | host="127.0.0.1",
195 | port=GlobalState.config.port,
196 | log_level="info",
197 | )
198 |
--------------------------------------------------------------------------------
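Note: for reference, a hypothetical client for the /tts endpoint in fam/llm/serving.py above. It follows what the handler reads: the TTSRequest fields as JSON in the X-Payload header, and raw reference-audio bytes in the body (used only when speaker_ref_path is not set). The host/port, the requests dependency, and the file names are assumptions for this sketch.

import json
import requests  # assumed to be available; not shown in this repository listing

payload = {
    "text": "Hello from Kotoba-Speech.",
    "guidance": 3.0,
    "top_p": 0.95,
    # "speaker_ref_path": "/path/to/reference.wav",  # alternative to sending bytes in the body
}

# reference audio goes in the request body; ffmpeg on the server converts it to wav
with open("ref.wav", "rb") as f:          # "ref.wav" is a placeholder file name
    audio_bytes = f.read()

resp = requests.post(
    "http://127.0.0.1:58003/tts",         # default ServingConfig port
    headers={"X-Payload": json.dumps(payload)},
    data=audio_bytes,
)
resp.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(resp.content)                 # the endpoint responds with audio/wav bytes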
/fam/llm/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import wandb
5 | from torch import nn
6 | from torch.nn import functional as F
7 | import pytorch_lightning as pl
8 | from pytorch_lightning import Trainer
9 | from pytorch_lightning.strategies import FSDPStrategy, DDPStrategy
10 | from pytorch_lightning.plugins.environments import ClusterEnvironment
11 | from huggingface_hub import snapshot_download
12 |
13 | # Specific to your project's structure
14 | from fam.llm.training import parse_args, get_first_stage_path, get_second_stage_path, TrainingConfig, VoiceDataModule, WandbLogger, dist_utils, optimizer_utils, Evaluator
15 | from fam.llm.decoders import Decoder, EncodecDecoder
16 | from fam.quantiser.text.tokenise import TrainedBPETokeniser
17 | from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
18 | from fam.llm.sample import Model
19 |
20 | class MyClusterEnvironment(ClusterEnvironment):
21 | @property
22 | def creates_processes_externally(self) -> bool:
23 | """Return True if the cluster is managed (you don't launch processes yourself)"""
24 | return True
25 |
26 | def world_size(self) -> int:
27 | return int(os.environ["OMPI_COMM_WORLD_SIZE"])
28 |
29 | def global_rank(self) -> int:
30 | return int(os.environ["OMPI_COMM_WORLD_RANK"])
31 |
32 | def local_rank(self) -> int:
33 | return int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
34 |
35 | def node_rank(self) -> int:
36 | return int(os.environ["OMPI_COMM_WORLD_NODE_RANK"])
37 |
38 | main_address = os.getenv("MASTER_ADDR","")
39 |
40 | main_port = int(os.getenv("MASTER_PORT", "0"))
41 |
42 | def set_global_rank(self, rank):
43 | self.global_rank_ = rank
44 |
45 | def set_world_size(self, size):
46 | self.world_size_ = size
47 |
48 | def detect(self):
49 | return True
50 |
51 | class KotobaSpeechModelFirstStage(pl.LightningModule):
52 | """
53 |     A PyTorch Lightning module for training the first stage of the Kotoba-Speech model.
54 | """
55 |
56 | def __init__(self, config_first_stage, config_second_stage, device, use_kv_cache, logger, is_debug=False, configs=None):
57 | super().__init__()
58 | self.configs = vars(configs)
59 | self.config_first_stage = config_first_stage
60 | self.use_kv_cache = use_kv_cache
61 | self.prev_step = -1
62 | self.evaluator = Evaluator()
63 | if dist_utils.is_main_process():
64 | self.wandb_logger = logger
65 |
66 | def configure_model(self):
67 | """
68 | Configures the model and its components.
69 | """
70 | self.data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
71 | self.first_stage_model_cls = Model(
72 | self.config_first_stage,
73 | TrainedBPETokeniser,
74 | EncodecDecoder,
75 | data_adapter_fn=self.data_adapter.decode,
76 | use_kv_cache=self.use_kv_cache,
77 | )
78 | self.first_stage_model_transformer = self.first_stage_model_cls.model
79 |
80 | def forward(self, text_tokens, embedding, inputs, targets):
81 | truncated_inputs = inputs[:,:2048]
82 | truncated_targets = targets[:,:2048]
83 | truncated_inputs = truncated_inputs.unsqueeze(1)
84 | truncated_targets = truncated_targets.unsqueeze(1)
85 | list_logits, first_stage_loss = self.first_stage_model_transformer(idx=truncated_inputs, embedding=embedding, targets=truncated_targets, loss_reduce="mean")
86 | return first_stage_loss
87 |
88 | def training_step(self, batch, batch_idx):
89 | tokens, embedding, audio_tokens, first_stage_input, first_stage_output = batch
90 | embedding = embedding.unsqueeze(1)
91 | loss = self(text_tokens=tokens, embedding=embedding, inputs=first_stage_input, targets=first_stage_output)
92 |
93 | if dist_utils.is_main_process():
94 | if self.prev_step == self.global_step:
95 | pass
96 | else:
97 | stats = {
98 | "loss": loss.item(),
99 | }
100 | if self.wandb_logger is not None:
101 | self.wandb_logger.log(results=stats, split="training", step=self.global_step, commit=False)
102 | lr_stats = {f"lr_group{i}": list(self.optimizers().param_groups)[i]["lr"] for i in range(len(self.optimizers().param_groups))}
103 | self.wandb_logger.log(results=lr_stats, split="lr", step=self.global_step, commit=True)
104 | self.prev_step = self.global_step
105 |
106 | return loss
107 |
108 | def on_train_start(self):
109 | if dist_utils.is_main_process():
110 | if self.wandb_logger is not None:
111 | wandb.watch(self.first_stage_model_transformer, log='parameters', log_freq=1000)
112 |
113 | def validation_step(self, batch, batch_idx):
114 | tokens, embedding, audio_tokens, first_stage_input, first_stage_output = batch
115 | tokens = tokens.unsqueeze(1)
116 | embedding = embedding.unsqueeze(1)
117 | loss = self(text_tokens=tokens, embedding=embedding, inputs=first_stage_input, targets=first_stage_output)
118 | stats = {
119 | "loss": [loss.item()],
120 | }
121 | self.evaluator.update(stats)
122 | return loss
123 |
124 | def on_validation_epoch_end(self):
125 | self.evaluator.synchronize_between_processes()
126 | if dist_utils.is_main_process():
127 | summary = self.evaluator.summarize()
128 | self.evaluator.reset()
129 | if dist_utils.is_main_process():
130 | if self.wandb_logger is not None:
131 | self.wandb_logger.log(results=summary, split="validation", step=self.global_step, commit=False)
132 |
133 | def configure_optimizers(self):
134 | return optimizer_utils.set_schedule(self)
135 |
136 | # Train the model
137 | if __name__ == "__main__":
138 | training_args = parse_args()
139 | data_dir = training_args.debug_data_dir if training_args.debug else training_args.data_dir
140 | data_module = VoiceDataModule(training_args.per_gpu_batchsize, data_dir)
141 |
142 | model_dir = snapshot_download(repo_id=training_args.huggingface_repo_id)
143 | first_stage_ckpt_path = get_first_stage_path(model_dir)
144 | second_stage_ckpt_path = get_second_stage_path(model_dir)
145 | config_first_stage = TrainingConfig(
146 | ckpt_path=first_stage_ckpt_path,
147 | num_samples=training_args.num_samples,
148 | seed=training_args.seed,
149 | device=training_args.device,
150 | dtype=training_args.dtype,
151 | compile=training_args.compile,
152 | init_from=training_args.init_from,
153 | train_from_scratch=training_args.train_from_scratch,
154 | output_dir=training_args.output_dir,
155 | spkemb_dropout=training_args.spkemb_dropout
156 | )
157 |
158 | config_second_stage = TrainingConfig(
159 | ckpt_path=second_stage_ckpt_path,
160 | num_samples=training_args.num_samples,
161 | seed=training_args.seed,
162 | device=training_args.device,
163 | dtype=training_args.dtype,
164 | compile=training_args.compile,
165 | init_from=training_args.init_from,
166 | train_from_scratch=training_args.train_from_scratch,
167 | output_dir=training_args.output_dir,
168 | spkemb_dropout=training_args.spkemb_dropout
169 | )
170 |
171 | is_debug = training_args.debug
172 | logger = WandbLogger(training_args) if training_args.use_wandb else None
173 | model = KotobaSpeechModelFirstStage(config_first_stage, config_second_stage, device=training_args.device, use_kv_cache=training_args.use_kv_cache, logger=logger, is_debug=is_debug, configs=training_args)
174 |
175 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
176 | checkpoint_callback = pl.callbacks.ModelCheckpoint(
177 | save_top_k=-1,
178 | verbose=True,
179 | monitor=None,
180 | )
181 | callbacks = [checkpoint_callback]
182 |
183 | num_gpus = (
184 | training_args.num_gpus
185 | if isinstance(training_args.num_gpus, int)
186 | else len(training_args.num_gpus)
187 | )
188 |
189 | grad_steps = training_args.batch_size // (
190 | training_args.per_gpu_batchsize * num_gpus * training_args.num_nodes
191 | )
192 |
193 |     max_steps = training_args.max_steps
194 |
195 |     if training_args.dtype == "bfloat16":
196 |         precision = "bf16-mixed"
197 |     else:
198 |         raise ValueError(f"Unsupported dtype {training_args.dtype!r}: only bfloat16 training has been validated.")
199 |
200 | if training_args.fsdp_strategy is not None:
201 | strategy = FSDPStrategy(
202 | sharding_strategy=training_args.fsdp_strategy
203 | )
204 | elif training_args.use_ddp_strategy:
205 | strategy = DDPStrategy(
206 | static_graph=True,
207 | )
208 | else:
209 | strategy = 'ddp_find_unused_parameters_true'
210 |
211 | if training_args.num_nodes > 1:
212 | plugins = [MyClusterEnvironment()]
213 | else:
214 | plugins = None
215 |
216 | if training_args.val_check_interval is not None:
217 | if training_args.val_check_interval >= 1:
218 |             training_args.val_check_interval = int(training_args.val_check_interval)
219 |
220 |
221 | trainer = Trainer(
222 | callbacks=callbacks,
223 | devices=training_args.num_gpus,
224 | strategy=strategy,
225 | num_nodes=training_args.num_nodes,
226 | precision=precision,
227 | accelerator="cuda",
228 | benchmark=True,
229 | deterministic=True,
230 | max_epochs=training_args.max_epoch if max_steps is None else 1000,
231 | accumulate_grad_batches=grad_steps,
232 | val_check_interval=training_args.val_check_interval,
233 | check_val_every_n_epoch=training_args.check_val_every_n_epoch,
234 | log_every_n_steps=10,
235 | fast_dev_run=training_args.fast_dev_run,
236 | plugins=plugins,
237 | gradient_clip_val=training_args.gradient_clip_val,
238 | )
239 |
240 | trainer.fit(
241 | model,
242 | ckpt_path=training_args.ckpt_path,
243 | datamodule=data_module
244 | )
245 | trainer.print(torch.cuda.memory_summary())
246 |
247 |
248 |
--------------------------------------------------------------------------------
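Note: the grad_steps computation in fam/llm/train.py above derives the number of gradient-accumulation steps from the global batch size. A small worked example with made-up GPU counts:

# Worked example of the accumulation arithmetic above (GPU counts are illustrative).
batch_size = 128         # --batch_size: global effective batch per optimizer step
per_gpu_batchsize = 16   # --per_gpu_batchsize: micro-batch on each GPU
num_gpus = 4             # GPUs per node
num_nodes = 2

grad_steps = batch_size // (per_gpu_batchsize * num_gpus * num_nodes)
print(grad_steps)            # 1 -> 16 * 4 * 2 = 128 samples already reach the global batch

print(128 // (16 * 1 * 1))   # 8 -> a single GPU needs 8 accumulation steps for the same batch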
/fam/llm/training/__init__.py:
--------------------------------------------------------------------------------
1 | from fam.llm.training.args import parse_args
2 | from fam.llm.training.utils import get_first_stage_path, get_second_stage_path, TrainingConfig
3 | from fam.llm.training.dataset import VoiceDataset
4 | from fam.llm.training.datamodule import VoiceDataModule
5 | from fam.llm.training.wandb_utils import WandbLogger
6 | from fam.llm.training.evaluator import Evaluator
7 | import fam.llm.training.optimizer as optimizer_utils
8 | import fam.llm.training.dist as dist_utils
--------------------------------------------------------------------------------
/fam/llm/training/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Optional, Literal
3 |
4 | def parse_args():
5 |     parser = argparse.ArgumentParser(description="Train the Kotoba-Speech model.")
6 |
7 | parser.add_argument("--batch_size", type=int, default=128,
8 | help="Batch size.")
9 |
10 | parser.add_argument("--compile", action='store_true',
11 | help="Whether to compile the model using PyTorch 2.0.")
12 |
13 | parser.add_argument("--ckpt_path", type=str, default=None,
14 | help="Path to a checkpoint file to resume training from.")
15 |
16 | parser.add_argument("--data_dir", type=str, default="",
17 | help="A path to the dataset dir.")
18 |
19 | parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
20 | help="Device to use for sampling.")
21 |
22 | parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32", "tfloat32"],
23 | help="Data type to use for sampling.")
24 |
25 | parser.add_argument("--exp_name", type=str, default="kotoba_voice_1.3B",
26 |                         help="Experiment name used for checkpointing and logging.")
27 |
28 | parser.add_argument("--fast_dev_run", action='store_true', default=False,
29 | help="Run a quick development check, usually used for debugging and testing purposes.")
30 |
31 | parser.add_argument("--huggingface_repo_id", type=str, required=False, default="kotoba-tech/kotoba-speech-v0.1",
32 |                         help="Hugging Face Hub repo id to download the pretrained model from.")
33 |
34 | parser.add_argument("--train_from_scratch", action='store_true', default=False,
35 |                         help="Train the model from scratch instead of initializing from the pretrained checkpoint.")
36 |
37 | parser.add_argument("--init_from", type=str, default="resume",
38 | help="Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl').")
39 |
40 | parser.add_argument("--max_epoch", type=int, default=5,
41 |                         help="Maximum number of training epochs.")
42 |
43 | parser.add_argument("--max_steps", type=int, default=None,
44 | help="Max steps to train")
45 |
46 | parser.add_argument("--num_gpus", type=int, default=1,
47 | help="Number of GPUs per node")
48 |
49 | parser.add_argument("--num_nodes", type=int, default=1,
50 | help="Number of nodes")
51 |
52 | parser.add_argument("--num_samples", type=int, default=1,
53 | help="Number of samples to generate from each model.")
54 |
55 | parser.add_argument("--output_dir", type=str, default="samples/",
56 | help="Relative path to output directory")
57 |
58 | parser.add_argument("--per_gpu_batchsize", type=int, default=16,
59 | help="Batch size per GPU.")
60 |
61 | parser.add_argument("--seed", type=int, default=1337,
62 | help="Random seed for sampling.")
63 |
64 | parser.add_argument("--use_kv_cache", type=str, default=None, choices=[None, "flash_decoding", "vanilla"],
65 | help="Type of kv caching to use for inference.")
66 |
67 | parser.add_argument("--val_check_interval", type=float, default=None,
68 | help="This overwrites check_val_every_n_epoch. If this value is less than 1, for example, 0.25, it means the validation set will be checked 4 times during a training epoch. If this value is greater than 1, for example, 1000, the validation set will be checked every 1000 training batches, either across complete epochs or during iteration-based training.")
69 |
70 | parser.add_argument("--check_val_every_n_epoch", type=int, default=1,
71 | help="Validate every n epoch.")
72 |
73 | parser.add_argument('--fsdp_strategy', type=str, default=None, choices=[None, 'FULL_SHARD', 'SHARD_GRAD_OP', 'HYBRID_SHARD', 'NO_SHARD'],
74 | help='Use fully sharded data parallel type,'
75 | 'FULL_SHARD: Shard weights, gradients, optimizer state (1 + 2 + 3)'
76 | 'SHARD_GRAD_OP: Shard gradients, optimizer state (2 + 3)'
77 | 'HYBRID_SHARD: Full-shard within a machine, replicate across machines'
78 | "NO_SHARD: Don't shard anything (similar to DDP, slower than DDP)"
79 | 'None: Use DDP'
80 | )
81 |
82 | parser.add_argument('--use_ddp_strategy', action='store_true', default=False,
83 | help='use DDPStrategy()')
84 |
85 | # WandB settings
86 | parser.add_argument('--use_wandb', action='store_true', default=False,
87 | help='Enable integration with Weights & Biases for experiment tracking.')
88 |
89 | parser.add_argument("--wandb_entity", type=str, default=None,
90 | help="Weights & Biases entity (team or user) under which the project is located (optional).")
91 |
92 | parser.add_argument("--wandb_project", type=str, default=None,
93 | help="Weights & Biases project name to which the run will be logged (optional).")
94 |
95 | parser.add_argument("--wandb_run_id", type=str, default=None,
96 | help="Unique identifier for the Weights & Biases run, allowing for run resumption or other operations (optional).")
97 |
98 | # Debug Setting
99 | parser.add_argument("--debug", action='store_true', default=False,
100 | help="Debug mode")
101 |
102 | parser.add_argument("--debug_data_dir", type=str, default="",
103 |                         help="A path to the debug dataset dir used when --debug is set.")
104 |
105 | # Optimizer Setting
106 | parser.add_argument("--optimizer_type", type=str, default="AdamW",
107 | choices=["Adam", "AdamW"], help="Type of optimizer to use.")
108 |
109 | parser.add_argument("--beta_one", type=float, default=0.9,
110 | help="Coefficient for computing running averages of gradient.")
111 |
112 | parser.add_argument("--beta_two", type=float, default=0.95,
113 | help="Coefficient for computing running averages of the square of the gradient.")
114 |
115 | parser.add_argument("--epsilon", type=float, default=1e-5,
116 | help="Term added to the denominator to improve numerical stability.")
117 |
118 | parser.add_argument("--weight_decay", type=float, default=0.1,
119 |                         help="Weight decay coefficient for the optimizer.")
120 |
121 | parser.add_argument("--decay_power", type=str, default="cosine",
122 | help="Type of learning rate decay to use.")
123 |
124 | parser.add_argument("--learning_rate", type=float, default=2e-4,
125 | help="Initial learning rate.")
126 |
127 | parser.add_argument("--warmup_steps", type=float, default=0.01,
128 | help="Fraction of total training steps to use for learning rate warmup. If warmup_steps is greater than 1, then the value specified in warmup_steps will represent the exact number of steps to be used for the warmup phase.")
129 |
130 | parser.add_argument("--end_lr", type=float, default=0,
131 | help="Final learning rate after decay.")
132 |
133 | parser.add_argument("--gradient_clip_val", type=float, default=0,
134 | help="Clip gradients' maximum magnitude.")
135 |
136 | # Model config setting
137 | parser.add_argument("--spkemb_dropout", type=float, default=0.1,
138 |                         help="Probability of dropping the speaker embedding during training (speaker-conditioning dropout).")
139 |
140 | return parser.parse_args()
141 |
142 | if __name__ == "__main__":
143 | args = parse_args()
144 | # Now you can use args to access your configuration, for example:
145 | # print(args.huggingface_repo_id)
146 |
--------------------------------------------------------------------------------
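Note: a quick way to see what namespace the flags in fam/llm/training/args.py produce is to call parse_args() with a substituted sys.argv; since the function reads sys.argv directly, this sketch patches it with illustrative values (the paths and sizes are made up):

import sys
from fam.llm.training.args import parse_args

# parse_args() reads sys.argv, so patch it with example flags (values are illustrative).
sys.argv = [
    "train.py",
    "--data_dir", "/path/to/dataset",
    "--per_gpu_batchsize", "8",
    "--batch_size", "64",
    "--num_gpus", "1",
    "--use_wandb",
]
args = parse_args()
print(args.batch_size, args.per_gpu_batchsize, args.learning_rate)  # 64 8 0.0002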
/fam/llm/training/datamodule.py:
--------------------------------------------------------------------------------
1 | # python fam/llm/training/datamodule.py
2 | import pytorch_lightning as pl
3 | from torch.utils.data import DataLoader
4 | from fam.llm.training.dataset import VoiceDataset
5 |
6 | class VoiceDataModule(pl.LightningDataModule):
7 | def __init__(self, batch_size=64, data_dir="/root/data/reazon_small/"):
8 | super().__init__()
9 | self.batch_size = batch_size
10 | self.data_dir = data_dir
11 |
12 | def prepare_data(self):
13 | pass
14 |
15 | def setup(self, stage):
16 | # Load and split data
17 | if stage == "train" or stage == "fit":
18 | self.voice_train = VoiceDataset(split="train", data_dir=self.data_dir)
19 | self.voice_val = VoiceDataset(split="val", data_dir=self.data_dir)
20 | self.voice_test = None
21 | elif stage == "test":
22 | self.voice_train = None
23 | self.voice_val = None
24 | self.voice_test = VoiceDataset(split="test", data_dir=self.data_dir)
25 | else:
26 | assert False
27 |
28 | def train_dataloader(self):
29 | return DataLoader(self.voice_train, batch_size=self.batch_size, collate_fn=self.voice_train.collate, shuffle=True)
30 |
31 | def val_dataloader(self):
32 | return DataLoader(self.voice_val, batch_size=self.batch_size, collate_fn=self.voice_val.collate)
33 |
34 | def test_dataloader(self):
35 | return DataLoader(self.voice_test, batch_size=self.batch_size, collate_fn=self.voice_test.collate)
36 |
37 | if __name__ == "__main__":
38 | voice_datamodule = VoiceDataModule(batch_size=32, data_dir="/root/data/reazon_small/")
39 | voice_datamodule.setup("train")
40 | train_dataloader = voice_datamodule.train_dataloader()
41 | for batch in train_dataloader:
42 | print(batch)
43 | val_dataloader = voice_datamodule.val_dataloader()
44 | test_dataloader = voice_datamodule.test_dataloader()
45 |
--------------------------------------------------------------------------------
/fam/llm/training/dataset.py:
--------------------------------------------------------------------------------
1 | # python fam/llm/training/dataset.py
2 | import os, json
3 | import torch
4 | from tqdm import tqdm
5 | from torch.utils.data import Dataset
6 |
7 | class VoiceDataset(Dataset):
8 | def __init__(self, split="train", data_dir=""):
9 | self.data_dir = data_dir
10 | self.data = self._load_data(os.path.join(data_dir, f'{split}.jsonl'))
11 |
12 | def _load_data(self, file_path):
13 | data = []
14 | with open(file_path, 'r') as f:
15 | for line in f:
16 | data.append(json.loads(line))
17 | return data
18 |
19 | def __len__(self):
20 | return len(self.data)
21 |
22 | def __getitem__(self, idx):
23 | train_item = self.data[idx]
24 | key = train_item["key"]
25 | tokens = torch.load(os.path.join(self.data_dir, 'txt', f"{key}.pt"))
26 | embedding = torch.load(os.path.join(self.data_dir, 'emb', f"{key}.emb"))
27 | audio_tokens = torch.load(os.path.join(self.data_dir, 'tok', f"{key}.tok"))
28 | return tokens, embedding, audio_tokens
29 |
30 | def collate(self, batch):
31 | audio_pad = 2048
32 | num_bands = 2
33 | eoa = 2048
34 | loss_pad = -1
35 | audio_vocab_size = 1024
36 |         # eoa marks the end of audio
37 |         # batch is a list of samples, where each sample is a (tokens, embedding, audio_tokens) tuple
38 | tokens = [sample[0] for sample in batch]
39 | padded_tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=0)
40 | embeddings = [sample[1] for sample in batch]
41 | padded_embeddings = torch.nn.utils.rnn.pad_sequence(embeddings, batch_first=True, padding_value=0)
42 | audios = [sample[2] for sample in batch]
43 |         # Determine the maximum audio length in the batch
44 | max_l = max(tensor.size(1) for tensor in audios)
45 | max_t2l = max(tensor_text.size(0) + tensor_audio.size(1)*2 for tensor_text, tensor_audio in zip(tokens, audios))
46 | # Pad tensors to have the same width
47 | padded_audios = [torch.nn.functional.pad(tensor, (0, max_l - tensor.shape[1]), mode='constant', value=audio_vocab_size) for tensor in audios]
48 | # Concatenate tensors along dimension 0
49 | padded_audios = torch.stack(padded_audios, dim=0)
50 |         # first_stage_input / first_stage_output: [B, T_text + 2 * L_audio] (padded)
51 | first_stage_input = torch.full([len(batch), max_t2l], audio_pad, dtype=torch.int64)
52 | first_stage_output = torch.full([len(batch), max_t2l], loss_pad, dtype=torch.int64)
53 |         ## first_stage_input:  [text tokens, interleaved audio tokens, audio_pad padding]
54 |         ## first_stage_output: [loss_pad over text, interleaved audio tokens, eoa, loss_pad padding]
55 | for idx, tensor_text, tensor_audio in zip(range(len(batch)), tokens, audios):
56 | text_size = tensor_text.size(0)
57 | audio_size = tensor_audio.size(1)
58 | bands = tensor_audio[:num_bands]
59 | bands += torch.arange(num_bands).view([num_bands, 1])*audio_vocab_size
60 | bands = bands.transpose(0, 1).contiguous().view(-1)
61 | first_stage_input[idx, :text_size] = tensor_text
62 | first_stage_input[idx, text_size:text_size+audio_size*num_bands] = bands
63 | first_stage_output[idx, :text_size-1] = torch.full([text_size-1], loss_pad, dtype=torch.int64)
64 | eoa_tensor = torch.full([1], eoa, dtype=torch.int64)
65 | first_stage_output[idx, text_size-1:text_size+audio_size*2] = torch.cat([bands, eoa_tensor], dim=0)
66 | first_stage_output[idx, text_size+audio_size*2:] = loss_pad
67 | return padded_tokens, padded_embeddings, padded_audios, first_stage_input, first_stage_output
68 |
69 |
70 | if __name__ == "__main__":
71 | for sp in ["train" , "val"]:
72 | voice_train = VoiceDataset(split=sp, data_dir="/debug_data")
73 | for i in tqdm(range(len(voice_train))):
74 | tokens, embedding, audio_tokens = voice_train.__getitem__(i)
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
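Note: the offset-and-interleave step inside VoiceDataset.collate above is the densest part of the file, so here is a standalone sketch of just that step on a toy [2, L] audio-token tensor. It mirrors the lines above except that it clones the input instead of modifying it in place:

import torch

audio_vocab_size = 1024
num_bands = 2

tensor_audio = torch.tensor([[3, 5, 7],        # codebook 0 tokens
                             [2, 4, 6]])       # codebook 1 tokens, shape (num_bands, L)

# shift codebook 1 by audio_vocab_size so both codebooks live in one flat vocabulary
bands = tensor_audio[:num_bands].clone()       # .clone() only to keep the toy input intact
bands += torch.arange(num_bands).view([num_bands, 1]) * audio_vocab_size
# bands -> [[   3,    5,    7],
#           [1026, 1028, 1030]]

# interleave the two codebooks along time, as collate() does before writing first_stage_input
flat = bands.transpose(0, 1).contiguous().view(-1)
print(flat)   # tensor([   3, 1026,    5, 1028,    7, 1030])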
/fam/llm/training/dist.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | This file contains primitives for multi-gpu communication.
4 | This is useful when doing distributed training.
5 | """
6 | import functools
7 | import logging
8 | import pickle
9 | import numpy as np
10 | import torch
11 | import torch.distributed as dist
12 |
13 | _LOCAL_PROCESS_GROUP = None
14 | """
15 | A torch process group which only includes processes that are on the same machine as the current process.
16 | This variable is set when processes are spawned by `launch()` in "engine/launch.py".
17 | """
18 |
19 |
20 | def get_world_size() -> int:
21 | if not dist.is_available():
22 | return 1
23 | if not dist.is_initialized():
24 | return 1
25 | return dist.get_world_size()
26 |
27 |
28 | def get_rank() -> int:
29 | if not dist.is_available():
30 | return 0
31 | if not dist.is_initialized():
32 | return 0
33 | return dist.get_rank()
34 |
35 |
36 | def get_local_rank() -> int:
37 | """
38 | Returns:
39 | The rank of the current process within the local (per-machine) process group.
40 | """
41 | if not dist.is_available():
42 | return 0
43 | if not dist.is_initialized():
44 | return 0
45 | assert _LOCAL_PROCESS_GROUP is not None
46 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
47 |
48 |
49 | def get_local_size() -> int:
50 | """
51 | Returns:
52 | The size of the per-machine process group,
53 | i.e. the number of processes per machine.
54 | """
55 | if not dist.is_available():
56 | return 1
57 | if not dist.is_initialized():
58 | return 1
59 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
60 |
61 |
62 | def is_initialized() -> bool:
63 | return dist.is_initialized()
64 |
65 | def is_available() -> bool:
66 | return dist.is_available()
67 |
68 | def is_main_process() -> bool:
69 | return get_rank() == 0
70 |
71 | def is_dist_avail_and_initialized():
72 | """
73 | Returns:
74 | True if distributed training is enabled
75 | """
76 | if not dist.is_available():
77 | return False
78 | if not dist.is_initialized():
79 | return False
80 | return True
81 |
82 |
83 | def synchronize():
84 | """
85 | Helper function to synchronize (barrier) among all processes when
86 | using distributed training
87 | """
88 | if not dist.is_available():
89 | return
90 | if not dist.is_initialized():
91 | return
92 | world_size = dist.get_world_size()
93 | if world_size == 1:
94 | return
95 | dist.barrier()
96 |
97 |
98 | @functools.lru_cache()
99 | def _get_global_gloo_group():
100 | """
101 | Return a process group based on gloo backend, containing all the ranks
102 | The result is cached.
103 | """
104 | if dist.get_backend() == "nccl":
105 | return dist.new_group(backend="gloo")
106 | else:
107 | return dist.group.WORLD
108 |
109 |
110 | def _serialize_to_tensor(data, group):
111 | backend = dist.get_backend(group)
112 | assert backend in ["gloo", "nccl"]
113 | device = torch.device("cpu" if backend == "gloo" else "cuda")
114 |
115 | buffer = pickle.dumps(data)
116 | if len(buffer) > 1024 ** 3:
117 | logger = logging.getLogger(__name__)
118 | logger.warning(
119 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
120 | get_rank(), len(buffer) / (1024 ** 3), device
121 | )
122 | )
123 | storage = torch.ByteStorage.from_buffer(buffer)
124 | tensor = torch.ByteTensor(storage).to(device=device)
125 | return tensor
126 |
127 |
128 | def _pad_to_largest_tensor(tensor, group):
129 | """
130 | Returns:
131 | list[int]: size of the tensor, on each rank
132 | Tensor: padded tensor that has the max size
133 | """
134 | world_size = dist.get_world_size(group=group)
135 | assert (
136 | world_size >= 1
137 | ), "comm.gather/all_gather must be called from ranks within the given group!"
138 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
139 | size_list = [
140 | torch.zeros([1], dtype=torch.int64, device=tensor.device)
141 | for _ in range(world_size)
142 | ]
143 | dist.all_gather(size_list, local_size, group=group)
144 | size_list = [int(size.item()) for size in size_list]
145 |
146 | max_size = max(size_list)
147 |
148 | # we pad the tensor because torch all_gather does not support
149 | # gathering tensors of different shapes
150 | if local_size != max_size:
151 | padding = torch.zeros(
152 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device
153 | )
154 | tensor = torch.cat((tensor, padding), dim=0)
155 | return size_list, tensor
156 |
157 |
158 | def all_gather(data, group=None):
159 | """
160 | Run all_gather on arbitrary picklable data (not necessarily tensors).
161 |
162 | Args:
163 | data: any picklable object
164 | group: a torch process group. By default, will use a group which
165 | contains all ranks on gloo backend.
166 |
167 | Returns:
168 | list[data]: list of data gathered from each rank
169 | """
170 | if get_world_size() == 1:
171 | return [data]
172 | if group is None:
173 | group = _get_global_gloo_group()
174 | if dist.get_world_size(group) == 1:
175 | return [data]
176 |
177 | tensor = _serialize_to_tensor(data, group)
178 |
179 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
180 | max_size = max(size_list)
181 |
182 | # receiving Tensor from all ranks
183 | tensor_list = [
184 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
185 | for _ in size_list
186 | ]
187 | dist.all_gather(tensor_list, tensor, group=group)
188 |
189 | data_list = []
190 | for size, tensor in zip(size_list, tensor_list):
191 | buffer = tensor.cpu().numpy().tobytes()[:size]
192 | data_list.append(pickle.loads(buffer))
193 |
194 | return data_list
195 |
196 |
197 | def gather(data, dst=0, group=None):
198 | """
199 | Run gather on arbitrary picklable data (not necessarily tensors).
200 |
201 | Args:
202 | data: any picklable object
203 | dst (int): destination rank
204 | group: a torch process group. By default, will use a group which
205 | contains all ranks on gloo backend.
206 |
207 | Returns:
208 | list[data]: on dst, a list of data gathered from each rank. Otherwise,
209 | an empty list.
210 | """
211 | if get_world_size() == 1:
212 | return [data]
213 | if group is None:
214 | group = _get_global_gloo_group()
215 | if dist.get_world_size(group=group) == 1:
216 | return [data]
217 | rank = dist.get_rank(group=group)
218 |
219 | tensor = _serialize_to_tensor(data, group)
220 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
221 |
222 | # receiving Tensor from all ranks
223 | if rank == dst:
224 | max_size = max(size_list)
225 | tensor_list = [
226 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
227 | for _ in size_list
228 | ]
229 | dist.gather(tensor, tensor_list, dst=dst, group=group)
230 |
231 | data_list = []
232 | for size, tensor in zip(size_list, tensor_list):
233 | buffer = tensor.cpu().numpy().tobytes()[:size]
234 | data_list.append(pickle.loads(buffer))
235 | return data_list
236 | else:
237 | dist.gather(tensor, [], dst=dst, group=group)
238 | return []
239 |
240 |
241 | def shared_random_seed():
242 | """
243 | Returns:
244 | int: a random number that is the same across all workers.
245 | If workers need a shared RNG, they can use this shared seed to
246 | create one.
247 |
248 | All workers must call this function, otherwise it will deadlock.
249 | """
250 | ints = np.random.randint(2 ** 31)
251 | all_ints = all_gather(ints)
252 | return all_ints[0]
253 |
254 |
255 | def reduce_dict(input_dict, average=True):
256 | """
257 | Reduce the values in the dictionary from all processes so that process with rank
258 | 0 has the reduced results.
259 |
260 | Args:
261 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
262 | average (bool): whether to do average or sum
263 |
264 | Returns:
265 | a dict with the same keys as input_dict, after reduction.
266 | """
267 | world_size = get_world_size()
268 | if world_size < 2:
269 | return input_dict
270 | with torch.no_grad():
271 | names = []
272 | values = []
273 | # sort the keys so that they are consistent across processes
274 | for k in sorted(input_dict.keys()):
275 | names.append(k)
276 | values.append(input_dict[k])
277 | values = torch.stack(values, dim=0)
278 | dist.reduce(values, dst=0)
279 | if dist.get_rank() == 0 and average:
280 | # only main process gets accumulated, so only divide by
281 | # world_size in this case
282 | values /= world_size
283 | reduced_dict = {k: v for k, v in zip(names, values)}
284 | return reduced_dict
--------------------------------------------------------------------------------
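Note: the helpers in fam/llm/training/dist.py are written to degrade gracefully when no process group is initialized, which is what makes the same training code usable on a single GPU. A minimal sketch of that single-process path (no torch.distributed initialization required):

import torch
import fam.llm.training.dist as dist_utils

stats = {"loss": torch.tensor(0.42)}

print(dist_utils.get_world_size())               # 1 when torch.distributed is not initialized
print(dist_utils.is_main_process())              # True
print(dist_utils.all_gather({"loss": [0.42]}))   # [{'loss': [0.42]}] -- input wrapped in a list
print(dist_utils.reduce_dict(stats))             # returned unchanged when world_size < 2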
/fam/llm/training/evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import fam.llm.training.dist as dist_utils
3 |
4 | class Evaluator(object):
5 | def __init__(self):
6 | self.reset()
7 |
8 | def update(self, results):
9 | for k in self.fields:
10 | self.results[k] += results[k]
11 |
12 | def synchronize_between_processes(self):
13 | all_results = dist_utils.all_gather(self.results)
14 | merged_results = {}
15 | for r in all_results:
16 | for k in r.keys():
17 | merged_results.setdefault(k, [])
18 | merged_results[k] += r[k]
19 | for k in merged_results.keys():
20 | merged_results[k] = np.array(merged_results[k])
21 | self.results = merged_results
22 |
23 | def summarize(self):
24 | #! With multi-gpu, the dataloader duplicates examples if the number of examples is not divisible by the batch size.
25 | if dist_utils.is_main_process():
26 | return {k: np.mean(self.results[k]) for k in self.fields}
27 |
28 | else:
29 |             assert False, "summarize() should only be called on the main process."
30 |
31 | def reset(self):
32 | self.results = {}
33 | self.fields = [
34 | "loss",
35 | ]
36 |
37 | for f in self.fields:
38 | self.results.setdefault(f, [])
39 |
--------------------------------------------------------------------------------
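Note: a short sketch of the Evaluator lifecycle as used in validation_step / on_validation_epoch_end earlier: per-batch lists of scalars are accumulated, gathered across ranks, then averaged on the main process. Run on a single process the gather is effectively a no-op; the loss values are made up:

from fam.llm.training.evaluator import Evaluator

evaluator = Evaluator()
evaluator.update({"loss": [1.0]})    # one validation batch
evaluator.update({"loss": [0.5]})    # another batch

evaluator.synchronize_between_processes()   # gathers across ranks; trivial on one process
print(evaluator.summarize())                # {'loss': 0.75} (mean of the accumulated values)
evaluator.reset()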
/fam/llm/training/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import inspect
3 | from IPython import embed
4 | from transformers import (
5 | get_polynomial_decay_schedule_with_warmup,
6 | get_cosine_schedule_with_warmup,
7 | )
8 |
9 |
10 | def set_schedule(pl_module):
11 | optimizer_type = pl_module.configs["optimizer_type"]
12 | lr = pl_module.configs["learning_rate"]
13 | beta_one = pl_module.configs["beta_one"]
14 | beta_two = pl_module.configs["beta_two"]
15 | eps = pl_module.configs["epsilon"]
16 | wd = pl_module.configs["weight_decay"]
17 | #fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
18 | #use_fused = fused_available and pl_module.device.type == "cuda"
19 | use_fused=False
20 | extra_args = dict(fused=True) if use_fused else dict()
21 |
22 | if optimizer_type == "Adam":
23 | optimizer = torch.optim.Adam(pl_module.first_stage_model_cls.model.parameters(), lr=lr, betas=(beta_one, beta_two), eps=eps, **extra_args)
24 | elif optimizer_type == "AdamW":
25 | model = pl_module.first_stage_model_cls.model
26 | # start with all of the candidate parameters
27 | param_dict = {pn: p for pn, p in model.named_parameters()}
28 | # filter out those that do not require grad
29 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
30 |         # create optim groups. Any parameter that is 2D will be weight-decayed, others won't.
31 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
32 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
33 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
34 | optim_groups = [
35 | {"params": decay_params, "weight_decay": wd},
36 | {"params": nodecay_params, "weight_decay": 0.0},
37 | ]
38 | num_decay_params = sum(p.numel() for p in decay_params)
39 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
40 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
41 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
42 |
43 | optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(beta_one, beta_two), eps=eps, **extra_args)
44 |
45 | if pl_module.trainer.max_steps == -1:
46 |         # TODO: this is not tested in a multi-node set-up.
47 | max_steps = (
48 | len(pl_module.trainer.datamodule.train_dataloader())
49 | // len(pl_module.trainer.device_ids)
50 | * pl_module.trainer.max_epochs
51 | // pl_module.trainer.accumulate_grad_batches
52 | )
53 | else:
54 | max_steps = pl_module.trainer.max_steps
55 |
56 | end_lr = pl_module.configs["end_lr"]
57 | decay_power = pl_module.configs["decay_power"]
58 | warmup_steps = pl_module.configs["warmup_steps"]
59 | if warmup_steps <= 1:
60 | warmup_steps = int(max_steps * warmup_steps)
61 |
62 | if decay_power == "cosine":
63 | scheduler = get_cosine_schedule_with_warmup(
64 | optimizer,
65 | num_warmup_steps=warmup_steps,
66 | num_training_steps=max_steps,
67 | )
68 | else:
69 | scheduler = get_polynomial_decay_schedule_with_warmup(
70 | optimizer,
71 | num_warmup_steps=warmup_steps,
72 | num_training_steps=max_steps,
73 | lr_end=end_lr,
74 | power=decay_power,
75 | )
76 |
77 | sched = {"scheduler": scheduler, "interval": "step"}
78 |
79 | return (
80 | [optimizer],
81 | [sched],
82 | )
--------------------------------------------------------------------------------
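`set_schedule` is written to be called from a PyTorch Lightning module's `configure_optimizers` hook, reading its hyperparameters from a `configs` dict on the module. A minimal sketch of that wiring, where the module name and the config values are illustrative assumptions rather than code shipped in the repo:

import pytorch_lightning as pl

from fam.llm.training.optimizer import set_schedule


class FirstStageLightningModule(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self, first_stage_model_cls, configs: dict):
        super().__init__()
        self.first_stage_model_cls = first_stage_model_cls  # wrapper exposing .model, as read by set_schedule
        # keys mirror the ones read by set_schedule; these example values are placeholders
        self.configs = configs  # e.g. {"optimizer_type": "AdamW", "learning_rate": 3e-4, "beta_one": 0.9,
                                #       "beta_two": 0.95, "epsilon": 1e-8, "weight_decay": 0.1,
                                #       "end_lr": 0.0, "decay_power": "cosine", "warmup_steps": 0.02}

    def configure_optimizers(self):
        # set_schedule returns ([optimizer], [{"scheduler": ..., "interval": "step"}]),
        # which is a return format Lightning accepts directly.
        return set_schedule(self)
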
/fam/llm/training/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, fields
3 |
4 | def get_first_stage_path(model_dir: str):
5 | """Absolute path to checkpoint for the first stage model."""
6 | return os.path.join(os.path.expanduser(model_dir), "first_stage.pt")
7 |
8 |
9 | def get_second_stage_path(model_dir: str):
10 | """Absolute path to checkpoint for the second stage model."""
11 | return os.path.join(os.path.expanduser(model_dir), "second_stage.pt")
12 |
13 |
14 | @dataclass
15 | class TrainingConfig:
16 | ckpt_path: str # path to checkpoint
17 | output_dir: str
18 | num_samples: int = 10 # number of samples to draw
19 | seed: int = 1337 # random seed
20 | device: str = "cuda"
21 | dtype: str = "bfloat16"
22 | compile: bool = False
23 | init_from: str = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
24 |     train_from_scratch: bool = False # if True, initialise weights from scratch rather than loading a checkpoint
25 | spkemb_dropout: float = 0.0
26 |
27 | def __str__(self):
28 | field_strs = []
29 |         for field in fields(self):
30 | value = getattr(self, field.name)
31 | field_strs.append(f" {field.name}: {value}")
32 |
33 | return "TrainingConfig:\n" + "\n".join(field_strs)
34 |
--------------------------------------------------------------------------------
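For reference, a quick sketch of how `TrainingConfig` and the checkpoint-path helpers above are used; the model directory shown here is a placeholder, not a path shipped with the repo:

from fam.llm.training.utils import TrainingConfig, get_first_stage_path

cfg = TrainingConfig(ckpt_path="~/models/kotoba/first_stage.pt", output_dir="outputs/")  # placeholder paths
print(cfg)  # uses the __str__ defined above to list every field and its value
print(get_first_stage_path("~/models/kotoba"))  # expands "~" and appends "first_stage.pt"
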
/fam/llm/training/wandb_utils.py:
--------------------------------------------------------------------------------
1 | # python fam/llm/training/wandb_utils.py
2 | from argparse import Namespace
3 | from datetime import datetime
4 |
5 | import wandb
6 |
7 |
8 | class WandbLogger(object):
9 | def __init__(self, args: Namespace):
10 | super().__init__()
11 | self._step = 0
12 | self._debug = args.debug
13 | now = datetime.now()
14 | exp_name = args.exp_name + "-" + str(now).replace(" ", "-")
15 | wandb_input = {
16 | "entity": args.wandb_entity,
17 | "name": exp_name,
18 | "config": args
19 | }
20 | wandb_input["project"] = args.wandb_project
21 | if args.wandb_run_id is not None:
22 | wandb_input["id"] = args.wandb_run_id
23 | wandb_input["resume"] = "must"
24 | wandb.init(**wandb_input)
25 |
26 | def log(self, results, split: str, step: int = None, commit=False):
27 |         formatted_results = {}
28 |         if step is not None:
29 |             self._step = step
30 |
31 |         for k, v in results.items():
32 |             formatted_results["{}/{}".format(split, k)] = v
33 |         wandb.log(formatted_results, step=self._step, commit=commit)
34 |
35 | def get_step(self):
36 | return self._step
37 |
38 |
39 | if __name__ == "__main__":
40 | from fam.llm.training import parse_args
41 | training_args = parse_args()
42 | logger = WandbLogger(training_args)
--------------------------------------------------------------------------------
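A small usage sketch for `WandbLogger`: the `Namespace` fields mirror what the constructor reads, and the entity/project values are placeholders, not real accounts:

from argparse import Namespace

from fam.llm.training.wandb_utils import WandbLogger

args = Namespace(
    debug=False,
    exp_name="kotoba_voice_debug",
    wandb_entity="my-entity",       # placeholder
    wandb_project="kotoba-speech",  # placeholder
    wandb_run_id=None,              # set to an existing run id to resume that run
)
logger = WandbLogger(args)
logger.log({"loss": 0.123}, split="train", step=10, commit=True)  # logged under "train/loss"
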
/fam/llm/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import subprocess
4 | import tempfile
5 |
6 | import librosa
7 | import torch
8 |
9 |
10 | def normalize_text(text: str) -> str:
11 | unicode_conversion = {
12 | 8175: "'",
13 | 8189: "'",
14 | 8190: "'",
15 | 8208: "-",
16 | 8209: "-",
17 | 8210: "-",
18 | 8211: "-",
19 | 8212: "-",
20 | 8213: "-",
21 | 8214: "||",
22 | 8216: "'",
23 | 8217: "'",
24 | 8218: ",",
25 | 8219: "`",
26 | 8220: '"',
27 | 8221: '"',
28 | 8222: ",,",
29 | 8223: '"',
30 | 8228: ".",
31 | 8229: "..",
32 | 8230: "...",
33 | 8242: "'",
34 | 8243: '"',
35 | 8245: "'",
36 | 8246: '"',
37 | 180: "'",
38 |         8482: "TM",  # Trademark sign (U+2122)
39 | }
40 |
41 | text = text.translate(unicode_conversion)
42 |
43 | non_bpe_chars = set([c for c in list(text) if ord(c) >= 256])
44 | #if len(non_bpe_chars) > 0:
45 | # non_bpe_points = [(c, ord(c)) for c in non_bpe_chars]
46 | # raise ValueError(f"Non-BPE single token characters found: {non_bpe_points}")
47 |
48 | text = text.replace("\t", " ")
49 | text = text.replace("\n", " ")
50 | text = text.replace("*", " ")
51 | text = text.strip()
52 |     text = re.sub(r"\s\s+", " ", text)  # collapse repeated whitespace into a single space
53 | return text
54 |
55 | def check_audio_file(path_or_uri, threshold_s=30):
56 | if "http" in path_or_uri:
57 | temp_fd, filepath = tempfile.mkstemp()
58 |         os.close(temp_fd)  # close the file descriptor; curl writes to this path itself
59 | curl_command = ["curl", "-L", path_or_uri, "-o", filepath]
60 | subprocess.run(curl_command, check=True)
61 |
62 | else:
63 | filepath = path_or_uri
64 |
65 | audio, sr = librosa.load(filepath)
66 | duration_s = librosa.get_duration(y=audio, sr=sr)
67 | #if duration_s < threshold_s:
68 | # raise Exception(
69 | # f"The audio file is too short. Please provide an audio file that is at least {threshold_s} seconds long to proceed."
70 | # )
71 |
72 | # Clean up the temporary file if it was created
73 | if "http" in path_or_uri:
74 | os.remove(filepath)
75 |
76 |
77 | def get_default_dtype() -> str:
78 | """Compute default 'dtype' based on GPU architecture"""
79 | if torch.cuda.is_available():
80 | for i in range(torch.cuda.device_count()):
81 | device_properties = torch.cuda.get_device_properties(i)
82 |             dtype = "float16" if device_properties.major <= 7 else "bfloat16"  # pre-Ampere GPUs (compute capability <= 7) lack bfloat16
83 | else:
84 | dtype = "float16"
85 |
86 | print(f"using dtype={dtype}")
87 | return dtype
88 |
89 |
90 | def get_device() -> str:
91 | return "cuda" if torch.cuda.is_available() else "cpu"
92 |
--------------------------------------------------------------------------------
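As a quick illustration of the helpers above (the input string is arbitrary):

from fam.llm.utils import get_default_dtype, get_device, normalize_text

print(normalize_text("Hello\tworld…  “quoted”"))  # -> 'Hello world... "quoted"'
print(get_device(), get_default_dtype())          # e.g. 'cuda bfloat16' on an Ampere GPU
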
/fam/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/py.typed
--------------------------------------------------------------------------------
/fam/quantiser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/quantiser/__init__.py
--------------------------------------------------------------------------------
/fam/quantiser/audio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/quantiser/audio/__init__.py
--------------------------------------------------------------------------------
/fam/quantiser/audio/speaker_encoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/quantiser/audio/speaker_encoder/__init__.py
--------------------------------------------------------------------------------
/fam/quantiser/audio/speaker_encoder/audio.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 |
4 | mel_window_length = 25
5 | mel_window_step = 10
6 | mel_n_channels = 40
7 | sampling_rate = 16000
8 |
9 |
10 | def wav_to_mel_spectrogram(wav):
11 | """
12 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
13 |     Note: this is not a log-mel spectrogram.
14 | """
15 | frames = librosa.feature.melspectrogram(
16 | y=wav,
17 | sr=sampling_rate,
18 | n_fft=int(sampling_rate * mel_window_length / 1000),
19 | hop_length=int(sampling_rate * mel_window_step / 1000),
20 | n_mels=mel_n_channels,
21 | )
22 | return frames.astype(np.float32).T
23 |
--------------------------------------------------------------------------------
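A small sketch of the helper above, using one of the sample clips shipped in `assets/` (any waveform resampled to 16 kHz works):

import librosa

from fam.quantiser.audio.speaker_encoder.audio import sampling_rate, wav_to_mel_spectrogram

wav, _ = librosa.load("assets/bria.mp3", sr=sampling_rate)  # resampled to 16 kHz on load
mels = wav_to_mel_spectrogram(wav)
print(mels.shape)  # (n_frames, 40): 40 mel channels, 25 ms windows with a 10 ms hop
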
/fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt
--------------------------------------------------------------------------------
/fam/quantiser/audio/speaker_encoder/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from time import perf_counter as timer
3 | from typing import List, Optional, Union
4 |
5 | import librosa
6 | import numpy as np
7 | import torch
8 | from torch import nn
9 |
10 | from fam.quantiser.audio.speaker_encoder import audio
11 |
12 | DEFAULT_SPKENC_CKPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/ckpt.pt")
13 |
14 | mel_window_step = 10
15 | mel_n_channels = 40
16 | sampling_rate = 16000
17 | partials_n_frames = 160
18 | model_hidden_size = 256
19 | model_embedding_size = 256
20 | model_num_layers = 3
21 |
22 |
23 | class SpeakerEncoder(nn.Module):
24 | def __init__(
25 | self,
26 | weights_fpath: Optional[str] = None,
27 | device: Optional[Union[str, torch.device]] = None,
28 | verbose: bool = True,
29 | eval: bool = False,
30 | ):
31 | super().__init__()
32 |
33 | # Define the network
34 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
35 | self.linear = nn.Linear(model_hidden_size, model_embedding_size)
36 | self.relu = nn.ReLU()
37 |
38 | # Get the target device
39 | if device is None:
40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41 | elif isinstance(device, str):
42 | device = torch.device(device)
43 | self.device = device
44 |
45 | start = timer()
46 | if eval and weights_fpath is None:
47 | weights_fpath = DEFAULT_SPKENC_CKPT_PATH
48 |
49 | if weights_fpath is not None:
50 | checkpoint = torch.load(weights_fpath, map_location="cpu")
51 |
52 | self.load_state_dict(checkpoint["model_state"], strict=False)
53 | self.to(device)
54 |
55 | if eval:
56 | self.eval()
57 |
58 | if verbose:
59 | print("Loaded the speaker embedding model on %s in %.2f seconds." % (device.type, timer() - start))
60 |
61 | def forward(self, mels: torch.FloatTensor):
62 | _, (hidden, _) = self.lstm(mels)
63 | embeds_raw = self.relu(self.linear(hidden[-1]))
64 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
65 |
66 | @staticmethod
67 | def compute_partial_slices(n_samples: int, rate, min_coverage):
68 | # Compute how many frames separate two partial utterances
69 | samples_per_frame = int((sampling_rate * mel_window_step / 1000))
70 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
71 | frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
72 |
73 | # Compute the slices
74 | wav_slices, mel_slices = [], []
75 | steps = max(1, n_frames - partials_n_frames + frame_step + 1)
76 | for i in range(0, steps, frame_step):
77 | mel_range = np.array([i, i + partials_n_frames])
78 | wav_range = mel_range * samples_per_frame
79 | mel_slices.append(slice(*mel_range))
80 | wav_slices.append(slice(*wav_range))
81 |
82 | # Evaluate whether extra padding is warranted or not
83 | last_wav_range = wav_slices[-1]
84 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
85 | if coverage < min_coverage and len(mel_slices) > 1:
86 | mel_slices = mel_slices[:-1]
87 | wav_slices = wav_slices[:-1]
88 |
89 | return wav_slices, mel_slices
90 |
91 | def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75, numpy: bool = True):
92 | wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
93 | max_wave_length = wav_slices[-1].stop
94 | if max_wave_length >= len(wav):
95 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
96 |
97 | mel = audio.wav_to_mel_spectrogram(wav)
98 | mels = np.array([mel[s] for s in mel_slices])
99 | with torch.no_grad():
100 | mels = torch.from_numpy(mels).to(self.device) # type: ignore
101 | partial_embeds = self(mels)
102 |
103 | if numpy:
104 | partial_embeds = partial_embeds.cpu().numpy()
105 | raw_embed = np.mean(partial_embeds, axis=0)
106 | embed = raw_embed / np.linalg.norm(raw_embed, 2)
107 | else:
108 | raw_embed = partial_embeds.mean(dim=0)
109 | embed = raw_embed / torch.linalg.norm(raw_embed, 2)
110 |
111 | if return_partials:
112 | return embed, partial_embeds, wav_slices
113 | return embed
114 |
115 | def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
116 | raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
117 | return raw_embed / np.linalg.norm(raw_embed, 2)
118 |
119 | def embed_utterance_from_file(self, fpath: str, numpy: bool) -> torch.Tensor:
120 | wav_tgt, _ = librosa.load(fpath, sr=16000)
121 | wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
122 | embedding = self.embed_utterance(wav_tgt, numpy=numpy)
123 | return embedding
124 |
--------------------------------------------------------------------------------
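`preprocess/spk_embed.py` further below uses this model in exactly this way; shown here as a stand-alone sketch, with the audio path pointing at one of the repo's sample clips:

import torch

from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder

smodel = SpeakerEncoder(device="cuda", eval=True, verbose=False)  # eval=True loads the bundled ckpt
with torch.no_grad():
    emb = smodel.embed_utterance_from_file("assets/bria.mp3", numpy=False)
print(emb.shape)  # torch.Size([256]); the embedding is L2-normalised
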
/fam/quantiser/text/tokenise.py:
--------------------------------------------------------------------------------
1 | import tiktoken
2 |
3 |
4 | class TrainedBPETokeniser:
5 | def __init__(self, name, pat_str, mergeable_ranks, special_tokens, offset=None) -> None:
6 | self.tokenizer = tiktoken.Encoding(
7 | name=name,
8 | pat_str=pat_str,
9 | mergeable_ranks=mergeable_ranks,
10 | special_tokens=special_tokens,
11 | )
12 | self.offset = offset
13 |
14 | def encode(self, text: str) -> list[int]:
15 |         # note: we append an end-of-text token
16 | tokens = self.tokenizer.encode(text) + [self.tokenizer.eot_token]
17 | if self.offset is not None:
18 | tokens = [x + self.offset for x in tokens]
19 |
20 | return tokens
21 |
22 | def decode(self, tokens: list[int]):
23 | if self.offset is not None:
24 | tokens = [x - self.offset for x in tokens]
25 | return self.tokenizer.decode(tokens)
26 |
27 | @property
28 | def eot_token(self):
29 | if self.offset is not None:
30 | return self.tokenizer.eot_token + self.offset
31 | else:
32 | return self.tokenizer.eot_token
33 |
--------------------------------------------------------------------------------
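As `preprocess/text_tokenize.py` further below shows, the tokeniser configuration (BPE ranks, special tokens, and the optional offset) is stored inside the first-stage checkpoint's `meta` entry. A sketch of loading and using it; note this downloads the metavoice checkpoint, so it is not a lightweight call:

import torch
from huggingface_hub import snapshot_download

from fam.llm.training import get_first_stage_path
from fam.quantiser.text.tokenise import TrainedBPETokeniser

model_dir = snapshot_download(repo_id="metavoiceio/metavoice-1B-v0.1")
checkpoint = torch.load(get_first_stage_path(model_dir), map_location="cpu")
tokenizer = TrainedBPETokeniser(**checkpoint["meta"]["tokenizer"])

tokens = tokenizer.encode("こんにちは")
assert tokens[-1] == tokenizer.eot_token  # encode always appends the (possibly offset) EOT token
print(tokenizer.decode(tokens[:-1]))
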
/fam/ui/app.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 | import os
4 |
5 | import gradio as gr
6 | import requests
7 | import soundfile as sf
8 |
9 | API_SERVER_URL = "http://127.0.0.1:58003/tts"
10 | RADIO_CHOICES = ["Preset voices", "Upload target voice", "Record your voice"]
11 | MAX_CHARS = 220
12 | PRESET_VOICES = {
13 | # female
14 | "Ava": "https://cdn.themetavoice.xyz/speakers/ava.flac",
15 | "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3",
16 | # male
17 | "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3",
18 | "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav",
19 | }
20 |
21 |
22 | def denormalise_top_p(top_p):
23 | # returns top_p in the range [0.9, 1.0]
24 | return round(0.9 + top_p / 100, 2)
25 |
26 |
27 | def denormalise_guidance(guidance):
28 | # returns guidance in the range [1.0, 3.0]
29 | return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1)
30 |
31 |
32 | def _handle_edge_cases(to_say, upload_target):
33 | if not to_say:
34 | raise gr.Error("Please provide text to synthesise")
35 |
36 | def _check_file_size(path):
37 | if not path:
38 | return
39 | filesize = os.path.getsize(path)
40 | filesize_mb = filesize / 1024 / 1024
41 | if filesize_mb >= 50:
42 | raise gr.Error(
43 |                 f"Please upload a sample smaller than 50MB for voice cloning. Provided: {round(filesize_mb)} MB"
44 | )
45 |
46 | _check_file_size(upload_target)
47 |
48 |
49 | def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target, record_target):
50 | d_top_p = denormalise_top_p(top_p)
51 | d_guidance = denormalise_guidance(guidance)
52 |
53 | _handle_edge_cases(to_say, upload_target)
54 |
55 | to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]
56 |
57 | custom_target_path = None
58 | if toggle == RADIO_CHOICES[1]:
59 | custom_target_path = upload_target
60 | elif toggle == RADIO_CHOICES[2]:
61 | custom_target_path = record_target
62 |
63 | config = {
64 | "text": to_say,
65 | "guidance": d_guidance,
66 | "top_p": d_top_p,
67 | "speaker_ref_path": PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else None,
68 | }
69 | headers = {"Content-Type": "audio/wav", "X-Payload": json.dumps(config)}
70 | if not custom_target_path:
71 | response = requests.post(API_SERVER_URL, headers=headers, data=None)
72 | else:
73 | with open(custom_target_path, "rb") as f:
74 | data = f.read()
75 | response = requests.post(API_SERVER_URL, headers=headers, data=data)
76 |
77 | wav, sr = None, None
78 | if response.status_code == 200:
79 | audio_buffer = io.BytesIO(response.content)
80 | audio_buffer.seek(0)
81 | wav, sr = sf.read(audio_buffer, dtype="float32")
82 | else:
83 | print(f"Something went wrong. response status code: {response.status_code}")
84 |
85 | return sr, wav
86 |
87 |
88 | def change_voice_selection_layout(choice):
89 | index = RADIO_CHOICES.index(choice)
90 | return [
91 | gr.update(visible=True)
92 | if i == index else gr.update(visible=False)
93 | for i in range(len(RADIO_CHOICES))
94 | ]
95 |
96 |
97 | title = "# TTS by Kotoba-Speech"
98 |
99 | description = """
100 | Kotoba-Speech v0.1は、1.2Bのトランスフォーマーに基づく音声生成モデルです。
101 | 以下の機能をサポートしています:
102 | \n
103 | * 日本語における滑らかなテキスト読み上げ生成
104 | * スピーチプロンプトを通じたOne-shot音声クローニング
105 |
106 | Kotoba Technologiesは、公開されたモデルを商用可能なApache 2.0ライセンスで公開します。
107 | 推論およびモデルコードは、Meta-Voiceをベースに作られており、学習コードは弊社のGitHubで近日中に公開する予定です。
108 | Kotoba Technologiesは、音声基盤モデルの開発に取り組んでおり、今後もモデルの公開を行なっていきます。是非、[Discord Community](https://discord.gg/qPVFqhGN7Z)に参加してご意見ください!
109 |
110 | Kotoba-Speech v0.1 is a 1.2B Transformer-based speech generative model. It supports the following properties:
111 | \n
112 | * Fluent text-to-speech generation in Japanese
113 | * One-shot voice cloning through speech prompt
114 |
115 | We are releasing our model under the Apache 2.0 license. Our inference and model code is adapted from Meta-Voice, and we will release our training code on our GitHub repository shortly.
116 | Kotoba Technologies is committed to developing speech foundation models, and we’ll continue releasing our models. Please join [our discord](https://discord.gg/qPVFqhGN7Z) to contribute to our community.
117 | """
118 |
119 | with gr.Blocks(title="TTS by Kotoba-Speech") as demo:
120 | gr.Markdown(title)
121 |
122 | with gr.Row():
123 | gr.Markdown(description)
124 |
125 | with gr.Row():
126 | with gr.Column():
127 | to_say = gr.TextArea(
128 | label="What should I say!?",
129 | lines=4,
130 | value="コトバテクノロジーズのミッションは、音声基盤モデルを作ることです。",
131 | )
132 |
133 | with gr.Row(), gr.Column():
134 | # voice settings
135 | top_p = gr.Slider(
136 | value=5.0,
137 | minimum=0.0,
138 | maximum=10.0,
139 | step=1.0,
140 | label="Speech Stability - improves text following for a challenging speaker",
141 | )
142 | guidance = gr.Slider(
143 | value=5.0,
144 | minimum=1.0,
145 | maximum=5.0,
146 | step=1.0,
147 | label="Speaker similarity - How closely to match speaker identity and speech style.",
148 | )
149 |
150 | # voice select
151 | toggle = gr.Radio(choices=RADIO_CHOICES, label="Choose voice", value=RADIO_CHOICES[0])
152 |
153 | with gr.Row(visible=True) as row_1:
154 | preset_dropdown = gr.Dropdown(
155 | PRESET_VOICES.keys(), label="Preset voices", value=list(PRESET_VOICES.keys())[0]
156 | )
157 | with gr.Accordion("Preview: Preset voices", open=False):
158 | for label, path in PRESET_VOICES.items():
159 | gr.Audio(value=path, label=label)
160 |
161 | with gr.Row(visible=False) as row_2:
162 | upload_target = gr.Audio(
163 | sources=["upload"],
164 | type="filepath",
165 | label="Upload a clean sample to clone. Sample should contain 1 speaker, be between 10-90 seconds and not contain background noise.",
166 | min_length=10,
167 | max_length=90,
168 | )
169 |
170 | with gr.Row(visible=False) as row_3:
171 | record_target = gr.Audio(
172 | sources=["microphone"],
173 | type="filepath",
174 | label="Record your voice with a microphone to clone. Sample should contain 1 speaker, be between 10-90 seconds and not contain background noise.",
175 | min_length=10,
176 | max_length=90,
177 | )
178 |
179 | toggle.change(
180 | change_voice_selection_layout,
181 | inputs=toggle,
182 | outputs=[row_1, row_2, row_3],
183 | )
184 |
185 | with gr.Column():
186 | speech = gr.Audio(
187 | type="numpy",
188 | label="Kotoba-Speech says...",
189 | )
190 |
191 | submit = gr.Button("Generate Speech")
192 | submit.click(
193 | fn=tts,
194 | inputs=[to_say, top_p, guidance, toggle, preset_dropdown, upload_target, record_target],
195 | outputs=speech,
196 | )
197 |
198 |
199 | demo.queue(default_concurrency_limit=2)
200 | demo.launch()
201 | # demo.launch(server_name="0.0.0.0", server_port=3000, share=True) # dev
202 |
--------------------------------------------------------------------------------
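The two slider helpers near the top of `app.py` map the 0-10 "Speech Stability" slider onto `top_p` in [0.9, 1.0] and the 1-5 "Speaker similarity" slider onto `guidance` in [1.0, 3.0]. A worked check of the arithmetic; the functions are restated here rather than imported, since importing `fam.ui.app` would launch the Gradio app:

def denormalise_top_p(top_p):
    return round(0.9 + top_p / 100, 2)

def denormalise_guidance(guidance):
    return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1)

assert denormalise_top_p(0) == 0.9       # slider minimum
assert denormalise_top_p(5) == 0.95      # default slider value
assert denormalise_top_p(10) == 1.0      # slider maximum
assert denormalise_guidance(1.0) == 1.0  # slider minimum
assert denormalise_guidance(3.0) == 2.0  # midpoint
assert denormalise_guidance(5.0) == 3.0  # slider maximum
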
/fam/ui/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/ui/assets/favicon.ico
--------------------------------------------------------------------------------
/preprocess/audio_tokenize.py:
--------------------------------------------------------------------------------
1 | import os, torch, json
2 | from tqdm import tqdm
3 | import argparse
4 | import julius
5 | from audiocraft.data.audio import audio_read
6 | from audiocraft.models import MultiBandDiffusion # type: ignore
7 | from utils import get_sorted_keys, get_ids, create_directory_recursive
8 |
9 | mbd_sample_rate = 24_000
10 | mbd_bandwidth = 6
11 | num_codebooks = 8
12 | mbd = MultiBandDiffusion.get_mbd_24khz(bw=mbd_bandwidth)
13 |
14 | def tokenize(audio_path, output_file):
15 | wav, sr = audio_read(audio_path)
16 | if sr != mbd_sample_rate:
17 | wav = julius.resample_frac(wav, sr, mbd_sample_rate)
18 | if wav.ndim == 2:
19 | wav = wav.unsqueeze(1)
20 | wav = wav.to("cuda")
21 | with torch.no_grad():
22 | tokens = mbd.codec_model.encode(wav)
23 | tokens = tokens[0][0].cpu()
24 | torch.save(tokens, output_file)
25 |
26 | def main(file_name, base_dir, num_shards, shard_id):
27 | files = get_sorted_keys(file_name)
28 | start, end = get_ids(len(files), num_shards, shard_id)
29 | print(start, end)
30 | for file_name in tqdm(files[start:end]):
31 | in_file = os.path.join(base_dir, "wav", file_name + ".wav")
32 | out_file = os.path.join(base_dir, "tok", file_name + ".tok")
33 | if os.path.exists(out_file):
34 | continue
35 | create_directory_recursive(out_file)
36 | tokenize(in_file, out_file)
37 |
38 | if __name__ == "__main__":
39 |     parser = argparse.ArgumentParser(description="Audio Tokenize")
40 | parser.add_argument("--in_file", default="data/small.jsonl", help="Name of the file")
41 | parser.add_argument("--base_dir", type=str, default="data/", help="base_dir")
42 | parser.add_argument("--num_shards", type=int, default=1, help="Number of shards")
43 | parser.add_argument("--shard_id", type=int, default=0, help="Shard ID")
44 | args = parser.parse_args()
45 | main(args.in_file, args.base_dir, args.num_shards, args.shard_id)
46 |
--------------------------------------------------------------------------------
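For a quick sanity check after running the script, the saved `.tok` files can be loaded back with `torch.load`. The path below uses a placeholder key, and the exact shape depends on the EnCodec model, but the tensor should carry one row of codes per codebook:

import torch

tokens = torch.load("data/tok/<some_key>.tok")  # placeholder path, mirroring the out_file layout above
print(tokens.shape)  # expected layout: (batch, codebook, frame) codes from the EnCodec encoder
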
/preprocess/download_reazon.py:
--------------------------------------------------------------------------------
1 | import os, subprocess, json
2 | import pandas as pd
3 | import argparse
4 | from utils import get_dirs, get_sorted_keys, get_ids, create_directory_recursive
5 |
6 | def download_tsv(split, base_dir):
7 | command = f"wget https://reazonspeech.s3.abci.ai/{split}.tsv -P {base_dir}"
8 | subprocess.run(command, shell=True)
9 | file_path = os.path.join(base_dir, f"{split}.tsv")
10 | return file_path
11 |
12 | def download_file(file_name, base_dir):
13 | base_url = "https://reazonspeech.s3.abci.ai/data/"
14 | path = os.path.join(base_url, file_name + ".tar")
15 | command = f"wget {path} -P {base_dir}"
16 | subprocess.run(command, shell=True)
17 | path = os.path.join(base_dir, file_name + ".tar")
18 | command = f"tar xvf {path} -C {base_dir}"
19 | subprocess.run(command, shell=True)
20 | command = f"rm {path}"
21 | subprocess.run(command, shell=True)
22 |
23 | def flac2wav(file_name, base_dir):
24 | flac_name = os.path.join(base_dir, file_name + ".flac")
25 | wav_name = os.path.join(base_dir, 'wav', file_name + '.wav')
26 | create_directory_recursive(wav_name)
27 | subprocess.run(f"ffmpeg -y -i {flac_name} {wav_name}", shell=True)
28 |
29 | def download_files(in_file, base_dir, num_shards, shard_id):
30 | files = get_dirs(in_file)
31 | start, end = get_ids(len(files), num_shards, shard_id)
32 | for file_name in files[start:end]:
33 | path = os.path.join(base_dir, file_name)
34 | if not os.path.exists(path):
35 | download_file(file_name, base_dir)
36 |
37 | def flac2wav_files(in_file, base_dir, num_shards, shard_id):
38 | files = get_sorted_keys(in_file)
39 | start, end = get_ids(len(files), num_shards, shard_id)
40 | for file_name in files[start:end]:
41 |         # target wav path; conversion is skipped if the file already exists
42 | path = os.path.join(base_dir, 'wav', file_name + '.wav')
43 | if not os.path.exists(path):
44 | flac2wav(file_name, base_dir)
45 |
46 | def tsv2jsonl(in_file, out_file):
47 | data = pd.read_csv(in_file, header=None, sep='\t')
48 | with open(out_file, "w") as f:
49 | for index, row in data.iterrows():
50 | json.dump({"key": row[0].replace(".flac", ""), "text": row[1]}, f, ensure_ascii=False)
51 | f.write('\n')
52 |
53 | def main(split, base_dir, num_shards, shard_id):
54 | create_directory_recursive(base_dir)
55 | in_file = download_tsv(split, base_dir)
56 | out_file = in_file.replace(".tsv", ".jsonl")
57 | tsv2jsonl(in_file, out_file)
58 | download_files(out_file, base_dir, num_shards, shard_id)
59 | flac2wav_files(out_file, base_dir, num_shards, shard_id)
60 |
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser(description="Download Reazon Dataset")
63 | parser.add_argument("--split", default="small", help="Reazon split")
64 | parser.add_argument("--base_dir", default="data/", help="path to data directory")
65 | parser.add_argument("--num_shards", type=int, default=1, help="Number of shards")
66 | parser.add_argument("--shard_id", type=int, default=0, help="Shard ID")
67 | args = parser.parse_args()
68 | main(args.split, args.base_dir, args.num_shards, args.shard_id)
69 |
--------------------------------------------------------------------------------
/preprocess/spk_embed.py:
--------------------------------------------------------------------------------
1 | import os, torch, json
2 | from tqdm import tqdm
3 | import argparse
4 | from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
5 | from utils import get_sorted_keys, get_ids, create_directory_recursive
6 |
7 | smodel = SpeakerEncoder(device="cuda", eval=True, verbose=False)
8 |
9 | def speaker_embed(audio_path, output_file):
10 | with torch.no_grad():
11 | embedding = smodel.embed_utterance_from_file(audio_path, numpy=False)
12 | embedding = embedding.cpu().detach()
13 | torch.save(embedding, output_file)
14 | return embedding
15 |
16 |
17 | def main(file_name, base_dir, num_shards, shard_id):
18 | files = get_sorted_keys(file_name)
19 | start, end = get_ids(len(files), num_shards, shard_id)
20 | print(start, end)
21 | for file_name in tqdm(files[start:end]):
22 | in_file = os.path.join(base_dir, "wav", file_name + ".wav")
23 | out_file = os.path.join(base_dir, "emb", file_name + ".emb")
24 | if os.path.exists(out_file):
25 | continue
26 | create_directory_recursive(out_file)
27 | speaker_embed(in_file, out_file)
28 |
29 | if __name__ == "__main__":
30 | parser = argparse.ArgumentParser(description="Create Speaker Embeddings")
31 | parser.add_argument("--in_file", default="data/small.jsonl", help="Name of the file")
32 | parser.add_argument("--base_dir", type=str, default="data/", help="base_dir")
33 | parser.add_argument("--num_shards", type=int, default=1, help="Number of shards")
34 | parser.add_argument("--shard_id", type=int, default=0, help="Shard ID")
35 | args = parser.parse_args()
36 | main(args.in_file, args.base_dir, args.num_shards, args.shard_id)
37 |
--------------------------------------------------------------------------------
/preprocess/split.py:
--------------------------------------------------------------------------------
1 | import os, subprocess
2 | import json
3 | import argparse
4 |
5 |
6 | def split_json_file(json_file_path, output_dir, val_size=500):
7 | with open(json_file_path, 'r') as f:
8 | # Read and modify lines
9 | lines = f.readlines()
10 |
11 | N = len(lines)
12 |     train_data = lines[:N-val_size*2]
13 | val_data = lines[N-val_size*2:N-val_size]
14 | test_data = lines[N-val_size:]
15 |
16 | # Write to train.jsonl
17 | with open(os.path.join(output_dir, 'train.jsonl'), 'w', encoding='utf-8') as f:
18 | f.writelines(train_data)
19 |
20 | # Write to val.jsonl
21 | with open(os.path.join(output_dir, 'val.jsonl'), 'w', encoding='utf-8') as f:
22 | f.writelines(val_data)
23 |
24 | # Write to test.jsonl
25 | with open(os.path.join(output_dir, 'test.jsonl'), 'w', encoding='utf-8') as f:
26 | f.writelines(test_data)
27 |
28 |
29 |
30 | # Example usage:
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser(description="Split list into chunks")
33 | parser.add_argument("--in_file", default="data/small.jsonl", help="Name of the file")
34 | parser.add_argument("--base_dir", type=str, default="data/", help="base_dir")
35 | args = parser.parse_args()
36 | split_json_file(args.in_file, args.base_dir)
37 |
--------------------------------------------------------------------------------
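With the non-overlapping slicing, a file with N lines yields N - 2*val_size training lines plus val_size validation and val_size test lines. A worked example with assumed sizes:

# Worked example of the split sizes for a 10,000-line JSONL file and val_size=500.
N, val_size = 10_000, 500
train_lines = N - 2 * val_size  # 9,000 lines (lines 0..8,999)
val_lines = val_size            # 500 lines (lines 9,000..9,499)
test_lines = val_size           # 500 lines (lines 9,500..9,999)
print(train_lines, val_lines, test_lines)
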
/preprocess/text_tokenize.py:
--------------------------------------------------------------------------------
1 | import os, torch, argparse
2 | from tqdm import tqdm
3 | from fam.quantiser.text.tokenise import TrainedBPETokeniser
4 | from utils import get_keys_texts, get_ids, create_directory_recursive
5 | from huggingface_hub import snapshot_download
6 | from fam.llm.training import get_first_stage_path
7 |
8 | model_dir = snapshot_download(repo_id="metavoiceio/metavoice-1B-v0.1")
9 | checkpoint = torch.load(get_first_stage_path(model_dir))
10 | tokenizer = TrainedBPETokeniser(**checkpoint["meta"]["tokenizer"])
11 |
12 | def text_tokenize(text, output_file):
13 | tokens = tokenizer.encode(text)
14 | tokens = torch.tensor(tokens, dtype=torch.int64)
15 | torch.save(tokens, output_file)
16 |
17 | def main(file_name, num_shards, shard_id, base_dir):
18 | keys_text = get_keys_texts(file_name)
19 | start, end = get_ids(len(keys_text), num_shards, shard_id)
20 | print(start, end)
21 | for keys_texts in tqdm(keys_text[start:end]):
22 | file_name = keys_texts["key"]
23 | text = keys_texts["text"]
24 | out_file = os.path.join(base_dir, "txt", file_name + ".pt")
25 | create_directory_recursive(out_file)
26 | if os.path.exists(out_file):
27 | continue
28 | text_tokenize(text, out_file)
29 |
30 | if __name__ == "__main__":
31 | parser = argparse.ArgumentParser(description="Tokenize Text")
32 |     parser.add_argument("--in_file", default="data/small.jsonl", help="Name of the file")
33 |     parser.add_argument("--base_dir", type=str, default="data/", help="base_dir")
34 |     parser.add_argument("--num_shards", type=int, default=1, help="Number of shards")
35 |     parser.add_argument("--shard_id", type=int, default=0, help="Shard ID")
36 |     args = parser.parse_args()
37 |     main(args.in_file, args.num_shards, args.shard_id, args.base_dir)
38 |
--------------------------------------------------------------------------------
/preprocess/utils.py:
--------------------------------------------------------------------------------
1 | import os, json
2 |
3 |
4 | def create_directory(dir_path):
5 | """
6 | Recursively create directories if they do not exist.
7 |
8 | Args:
9 |         dir_path (str): The directory path to create.
10 | """
11 | if not os.path.exists(dir_path):
12 | os.makedirs(dir_path)
13 |
14 | def list_parent_directories(path):
15 | """
16 | List all parent directories of a given path.
17 |
18 | Args:
19 | path (str): The directory path to find parent directories for.
20 |
21 | Returns:
22 | list: A list of parent directories.
23 | """
24 | parent_directories = []
25 | while True:
26 | path = os.path.dirname(path)
27 | if path == '/' or not path:
28 | break
29 | parent_directories.append(path)
30 | return parent_directories
31 |
32 | def create_directory_recursive(path):
33 | dirs = list_parent_directories(path)
34 | dirs.reverse()
35 | for dir in dirs:
36 | create_directory(dir)
37 |
38 | def get_sorted_keys(in_file, key="key"):
39 | files = []
40 | with open(in_file) as fin:
41 | for line in fin:
42 | line = json.loads(line)
43 | files.append(line[key])
44 | return files
45 |
46 | def get_keys_texts(in_file, key="key"):
47 | data = []
48 | with open(in_file) as fin:
49 | for line in fin:
50 | line = json.loads(line)
51 | data.append(line)
52 | return data
53 |
54 | def get_ids(num_keys, num_shards, shard_id):
55 | shard_size = num_keys//num_shards
56 | remainder = num_keys - shard_size*num_shards
57 | start_id = shard_size*shard_id + min([shard_id, remainder])
58 | end_id = shard_size*(shard_id+1) + min([shard_id+1, remainder])
59 | return start_id, end_id
60 |
61 | def get_dirs(in_file):
62 | files = set()
63 | with open(in_file) as fin:
64 | for line in fin:
65 | line = json.loads(line)
66 | files.add(line["key"].split("/")[0])
67 | files = sorted(files)
68 | return files
69 |
70 | def split_json_file(json_file_path, output_dir):
71 | with open(json_file_path, 'r') as f:
72 | # Read and modify lines
73 | lines = f.readlines()
74 |
75 | N = len(lines)
76 | train_data = lines[:N-1000]
77 | val_data = lines[N-1000:N-500]
78 | test_data = lines[N-500:]
79 |
80 | # Write to train.jsonl
81 | with open(os.path.join(output_dir, 'train.jsonl'), 'w', encoding='utf-8') as f:
82 | f.writelines(train_data)
83 |
84 | # Write to val.jsonl
85 |     with open(os.path.join(output_dir, 'val.jsonl'), 'w', encoding='utf-8') as f:
86 | f.writelines(val_data)
87 |
88 | # Write to test.jsonl
89 |     with open(os.path.join(output_dir, 'test.jsonl'), 'w', encoding='utf-8') as f:
90 | f.writelines(test_data)
91 |
--------------------------------------------------------------------------------
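`get_ids` splits `num_keys` items into contiguous, non-overlapping shards, handing the remainder out one item at a time to the lowest-numbered shards. A worked example, run from inside `preprocess/` to match how the other scripts import `utils`:

from utils import get_ids  # same import style as the preprocess scripts

print(get_ids(10, 3, 0))  # (0, 4)  -> items 0..3 (this shard absorbs the remainder)
print(get_ids(10, 3, 1))  # (4, 7)  -> items 4..6
print(get_ids(10, 3, 2))  # (7, 10) -> items 7..9
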
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "kotoba-speech"
3 | version = "0.1.0"
4 | description = "Foundational model for text to speech"
5 | requires-python = ">=3.10,<3.12"
6 |
7 | [tool.black]
8 | line-length = 120
9 | exclude = '''
10 | /(
11 | \.git
12 | | \.mypy_cache
13 | | \.tox
14 | | _build
15 | | build
16 | | dist
17 | )/
18 | '''
19 |
20 | [tool.isort]
21 | profile = "black"
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | packaging==23.1
2 | wheel==0.42.0
3 | transformers
4 | librosa
5 | tqdm
6 | tiktoken==0.5.1
7 | audiocraft
8 | numpy
9 | ninja
10 | fastapi
11 | uvicorn
12 | tyro
13 | deepfilternet
14 | pydub
15 | soundfile
16 | gradio
17 | huggingface_hub
18 | wandb
19 | pytorch_lightning
20 | IPython
21 |
--------------------------------------------------------------------------------
/scripts/abci/submit_ddp_1gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #$-l rt_AG.small=1
4 | #$-l h_rt=1:00:00
5 | #$-j y
6 | #$ -o outputs/a-node/
7 | #$-cwd
8 |
9 |#submit: qsub -g gcd50698 scripts/abci/submit_ddp_1gpu.sh
10 |
11 | source /etc/profile.d/modules.sh
12 | module load python/3.11/3.11.2 cuda/12.0/12.0.0 cudnn/8.9/8.9.7
13 | source myenv/bin/activate
14 |
15 | python fam/llm/train.py --num_gpus 1 --batch_size 1 --per_gpu_batchsize 1 --check_val_every_n_epoch 1 --exp_name kotoba_voice_1.3B_debug_abci_reazon_small --max_epoch 1 --debug_data_dir /groups/gcd50698/reazon_data/reazon_small --debug --use_ddp_strategy
16 |
--------------------------------------------------------------------------------
/scripts/abci/submit_ddp_1node.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #$-l rt_AF=1
4 | #$-l h_rt=0:10:00
5 | #$-j y
6 | #$ -o outputs/a-node/
7 | #$-cwd
8 |
9 | #submit: qsub -g gcd50698 scripts/abci/submit_ddp_1node.sh
10 |
11 | source /etc/profile.d/modules.sh
12 | module load python/3.11/3.11.2 cuda/12.0/12.0.0 cudnn/8.9/8.9.7
13 | source myenv/bin/activate
14 |
15 | python fam/llm/train.py --num_gpus 8 --batch_size 16 --per_gpu_batchsize 1 --check_val_every_n_epoch 1 --exp_name kotoba_voice_1.3B_debug_abci_reazon_small --max_epoch 1 --debug_data_dir /groups/gcd50698/reazon_data/reazon_small --debug --use_ddp_strategy
16 |
--------------------------------------------------------------------------------
/scripts/abci/submit_fsdp_1node.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #$-l rt_AF=1
4 | #$-l h_rt=0:10:00
5 | #$-j y
6 | #$ -o outputs/a-node/
7 | #$-cwd
8 |
9 | #submit: qsub -g gcd50698 scripts/abci/submit_fsdp_1node.sh
10 |
11 | source /etc/profile.d/modules.sh
12 | module load python/3.11/3.11.2 cuda/12.0/12.0.0 cudnn/8.9/8.9.7
13 | source myenv/bin/activate
14 |
15 | python fam/llm/train.py --num_gpus 8 --batch_size 16 --per_gpu_batchsize 1 --check_val_every_n_epoch 1 --exp_name kotoba_voice_1.3B_debug_abci_reazon_small --max_epoch 1 --debug_data_dir /groups/gcd50698/reazon_data/reazon_small --debug --fsdp_strategy SHARD_GRAD_OP
16 |
--------------------------------------------------------------------------------
/scripts/abci/submit_fsdp_2node.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #$-l rt_AF=2
4 | #$-l h_rt=0:10:00
5 | #$-j y
6 | #$ -o outputs/a-node/
7 | #$-cwd
8 |
9 | #submit: qsub -g gcd50698 scripts/abci/submit_fsdp_2node.sh
10 |
11 | # module
12 | source /etc/profile.d/modules.sh
13 | module load python/3.11/3.11.2 cuda/12.0/12.0.0 cudnn/8.9/8.9.7 hpcx/2.12
14 | source myenv/bin/activate
15 |
16 | # distributed settings
17 | export MASTER_ADDR=$(/usr/sbin/ip a show dev bond0 | grep 'inet ' | awk '{ print $2 }' | cut -d "/" -f 1)
18 | export MASTER_PORT=$((10000 + ($JOB_ID % 50000)))
19 | export NUM_GPU_PER_NODE=8
20 | NUM_NODES=$NHOSTS
21 | NUM_GPUS=$((${NUM_NODES} * ${NUM_GPU_PER_NODE}))
22 |
23 | mkdir -p ./hostfile
24 | HOSTFILE_NAME=./hostfile/hostfile_${JOB_ID}
25 | while read -r line; do
26 | echo "${line} slots=${NUM_GPU_PER_NODE}"
27 | done <"$SGE_JOB_HOSTLIST" >"$HOSTFILE_NAME"
28 |
29 | mpirun -np $NUM_GPUS \
30 | -hostfile $HOSTFILE_NAME \
31 | -x MASTER_ADDR=$MASTER_ADDR \
32 | -x MASTER_PORT=$MASTER_PORT \
33 | -bind-to none -map-by slot \
34 | -x PATH \
35 | python fam/llm/train.py \
36 | --num_nodes 2 \
37 | --num_gpus 8 \
38 | --batch_size 16 \
39 | --per_gpu_batchsize 1 \
40 | --check_val_every_n_epoch 1 \
41 | --exp_name kotoba_voice_1.3B_debug_abci_reazon_small \
42 | --max_epoch 1 \
43 | --debug_data_dir /groups/gcd50698/reazon_data/reazon_small \
44 | --debug \
45 | --fsdp_strategy SHARD_GRAD_OP
46 |
47 | # python -m torch.distributed.run \
48 | # --nnodes 2 \
49 | # --master_addr $MASTER_ADDR \
50 | # --master_port $MASTER_PORT \
51 | # --nproc_per_node 8 \
52 | # fam/llm/train.py \
53 | # --num_nodes 2 \
54 | # --num_gpus 8 \
55 | # --batch_size 16 \
56 | # --per_gpu_batchsize 1 \
57 | # --check_val_every_n_epoch 1 \
58 | # --exp_name kotoba_voice_1.3B_debug_abci_reazon_small \
59 | # --max_epoch 1 \
60 | # --debug_data_dir /groups/gcd50698/reazon_data/reazon_small \
61 | # --debug \
62 | # --fsdp_strategy SHARD_GRAD_OP
63 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup # type: ignore
2 |
3 | setup(
4 | name="fam",
5 | packages=find_packages(".", exclude=["tests"]),
6 | )
7 |
--------------------------------------------------------------------------------