├── .gitignore ├── LICENSE ├── README.md ├── assets ├── aws.mov ├── aws.wav ├── bria.mp3 ├── google.mov ├── google.wav ├── kansai_demo.mov ├── kotoba-speech_demo.mov ├── kotoba.mov ├── kotoba.wav ├── kotoba_cloning.mov ├── kotoba_cloning.wav └── logo.png ├── fam ├── __init__.py ├── llm │ ├── __init__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── base.py │ │ ├── flattened_encodec.py │ │ └── tilted_encodec.py │ ├── decoders.py │ ├── enhancers.py │ ├── fast_inference.py │ ├── fast_inference_utils.py │ ├── fast_model.py │ ├── inference.py │ ├── layers │ │ ├── __init__.py │ │ ├── attn.py │ │ ├── combined.py │ │ └── layers.py │ ├── mixins │ │ ├── __init__.py │ │ ├── causal.py │ │ └── non_causal.py │ ├── model.py │ ├── sample.py │ ├── serving.py │ ├── train.py │ ├── training │ │ ├── __init__.py │ │ ├── args.py │ │ ├── datamodule.py │ │ ├── dataset.py │ │ ├── dist.py │ │ ├── evaluator.py │ │ ├── optimizer.py │ │ ├── utils.py │ │ └── wandb_utils.py │ └── utils.py ├── py.typed ├── quantiser │ ├── __init__.py │ ├── audio │ │ ├── __init__.py │ │ └── speaker_encoder │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── ckpt │ │ │ └── ckpt.pt │ │ │ └── model.py │ └── text │ │ └── tokenise.py └── ui │ ├── app.py │ └── assets │ └── favicon.ico ├── preprocess ├── audio_tokenize.py ├── download_reazon.py ├── spk_embed.py ├── split.py ├── text_tokenize.py └── utils.py ├── pyproject.toml ├── requirements.txt ├── scripts └── abci │ ├── submit_ddp_1gpu.sh │ ├── submit_ddp_1node.sh │ ├── submit_fsdp_1node.sh │ └── submit_fsdp_2node.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.flac 3 | *.npz 4 | *.wav 5 | *.m4a 6 | *.opus 7 | *.npy 8 | *wandb 9 | *.parquet 10 | *.wav 11 | *.pt 12 | *.bin 13 | *.png 14 | *.DS_Store 15 | *.idea 16 | *.ipynb_checkpoints/ 17 | *__pycache__/ 18 | *.pyc 19 | *.tsv 20 | *.bak 21 | *.tar 22 | *.db 23 | *.dat 24 | *.json 25 | 26 | # Byte-compiled / optimized / DLL files 27 | __pycache__/ 28 | *.py[cod] 29 | *$py.class 30 | 31 | # C extensions 32 | *.so 33 | 34 | # Distribution / packaging 35 | .Python 36 | build/ 37 | develop-eggs/ 38 | dist/ 39 | downloads/ 40 | eggs/ 41 | .eggs/ 42 | lib/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | wheels/ 48 | share/python-wheels/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | MANIFEST 53 | 54 | # PyInstaller 55 | # Usually these files are written by a python script from a template 56 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
57 | *.manifest 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .nox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | *.py,cover 75 | .hypothesis/ 76 | .pytest_cache/ 77 | cover/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Django stuff: 84 | *.log 85 | local_settings.py 86 | db.sqlite3 87 | db.sqlite3-journal 88 | 89 | # Flask stuff: 90 | instance/ 91 | .webassets-cache 92 | 93 | # Scrapy stuff: 94 | .scrapy 95 | 96 | # Sphinx documentation 97 | docs/_build/ 98 | 99 | # PyBuilder 100 | .pybuilder/ 101 | target/ 102 | 103 | # Jupyter Notebook 104 | .ipynb_checkpoints 105 | 106 | # IPython 107 | profile_default/ 108 | ipython_config.py 109 | 110 | # pyenv 111 | # For a library or package, you might want to ignore these files since the code is 112 | # intended to run in multiple environments; otherwise, check them in: 113 | # .python-version 114 | 115 | # pipenv 116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 119 | # install all needed dependencies. 120 | #Pipfile.lock 121 | 122 | # poetry 123 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 124 | # This is especially recommended for binary packages to ensure reproducibility, and is more 125 | # commonly ignored for libraries. 126 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 127 | #poetry.lock 128 | 129 | # pdm 130 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 131 | #pdm.lock 132 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 133 | # in version control. 134 | # https://pdm.fming.dev/#use-with-ide 135 | .pdm.toml 136 | 137 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 138 | __pypackages__/ 139 | 140 | # Celery stuff 141 | celerybeat-schedule 142 | celerybeat.pid 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | # Pyre type checker 172 | .pyre/ 173 | 174 | # pytype static type analyzer 175 | .pytype/ 176 | 177 | # Cython debug symbols 178 | cython_debug/ 179 | 180 | # PyCharm 181 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 182 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 183 | # and can be added to the global gitignore or merged into this file. For a more nuclear 184 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
185 | #.idea/ 186 | **/.tmp 187 | !fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt 188 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kotoba-Speech Version 0.1 2 | Welcome to the code repository for Kotoba-Speech v0.1, a 1.2B-parameter Transformer-based generative model for fluent Japanese speech. This model represents one of the most advanced open-source options available in the field. 3 | 4 | Questions, feature requests, or bug reports? Join [our Discord community](https://discord.com/invite/qPVFqhGN7Z)! 5 | 6 | Kotoba-Speech Logo 7 | 8 | ## About 9 | Kotoba-Speech Version 0.1 distinguishes itself as an open-source solution for generating high-quality Japanese speech from text prompts, while also offering the capability for voice cloning through speech prompts. 10 | 11 | - **_Demo:_** Experience Kotoba-Speech in action [here](https://huggingface.co/spaces/kotoba-tech/Kotoba-Speech). 12 | - **_Model Checkpoint:_** Access our commercially usable pre-trained model [here](https://huggingface.co/kotoba-tech/kotoba-speech-v0.1). 13 | - **_Open-sourced Code:_** This repository open-sources the training and inference code, along with the Gradio demo code. We borrow code from [MetaVoice](https://github.com/metavoiceio/metavoice-src) as a starting point. 14 | 15 | ### Our Model vs. Leading TTS Providers for Japanese 16 | https://github.com/kotoba-tech/kotoba-speech-release/assets/18011504/516f56f4-db92-45cb-b2e5-92863b36f8cd 17 | 18 | ### Fine-tuning Our Pre-trained Model for 関西弁 19 | https://github.com/kotoba-tech/kotoba-speech-release/assets/18011504/0204938e-7bb2-4c9f-9c6b-cb1e5c2dcff4 20 | 21 | ## Table of Contents 22 | 23 | 1. **Installation** 24 | 2. **Preparing Datasets** 25 | 3. **Training** 26 | 4. **Inference** 27 | 5. **Other Notes** 28 | 29 | ## 1.
Installation 30 | ```bash 31 | # Installing ffmpeg 32 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz 33 | wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5 34 | md5sum -c ffmpeg-git-amd64-static.tar.xz.md5 35 | tar xvf ffmpeg-git-amd64-static.tar.xz 36 | sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/ 37 | rm -rf ffmpeg-git-* 38 | 39 | # Setting up a Python virtual environment 40 | python -m venv myenv 41 | source myenv/bin/activate 42 | pip install -U --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 43 | pip install -r requirements.txt 44 | pip install flash-attn==2.5.3 45 | pip install -e . 46 | ``` 47 | 48 | ## 2. Preparing Datasets 49 | We provide an example of preparing a training dataset for our model, using Reazon Speech, the largest open-source Japanese speech dataset. (Note that our model is not necessarily trained on Reazon Speech alone.) 50 | ```bash 51 | # Download & Format Data 52 | python preprocess/download_reazon.py 53 | 54 | # Pre-calculate Speaker Embeddings 55 | python preprocess/spk_embed.py 56 | 57 | # Tokenize Audio 58 | python preprocess/audio_tokenize.py 59 | 60 | # Tokenize Text Captions 61 | python preprocess/text_tokenize.py 62 | 63 | # Split data into (training/validation/test) 64 | python preprocess/split.py 65 | ``` 66 | 67 | ## 3. Training 68 | ```bash 69 | # Fine-tuning from our pre-trained checkpoint 70 | # Replace YOUR_WANDB_ENTITY and YOUR_WANDB_PROJECT 71 | python fam/llm/train.py --num_gpus 1 --batch_size 32 --per_gpu_batchsize 2 --max_epoch 5 --learning_rate 0.00005 --data_dir data --exp_name reazon_small_exp_finetuning --spkemb_dropout 0.1 --check_val_every_n_epoch 1 --wandb_entity YOUR_WANDB_ENTITY --wandb_project YOUR_WANDB_PROJECT --use_wandb 72 | 73 | # Multi-GPU Fine-tuning (e.g., using 2 GPUs) 74 | # Replace YOUR_WANDB_ENTITY and YOUR_WANDB_PROJECT 75 | python fam/llm/train.py --num_gpus 2 --batch_size 32 --per_gpu_batchsize 2 --max_epoch 5 --learning_rate 0.00005 --data_dir data --exp_name reazon_small_exp_finetuning --spkemb_dropout 0.1 --check_val_every_n_epoch 1 --wandb_entity YOUR_WANDB_ENTITY --wandb_project YOUR_WANDB_PROJECT --use_wandb 76 | 77 | # Fine-tuning (without WandB logging) 78 | python fam/llm/train.py --num_gpus 1 --batch_size 32 --per_gpu_batchsize 2 --max_epoch 5 --learning_rate 0.00005 --data_dir data --exp_name reazon_small_exp_finetuning --spkemb_dropout 0.1 --check_val_every_n_epoch 1 79 | 80 | # Training from scratch 81 | # Replace YOUR_WANDB_ENTITY and YOUR_WANDB_PROJECT 82 | python fam/llm/train.py --num_gpus 1 --batch_size 64 --per_gpu_batchsize 2 --max_epoch 20 --learning_rate 0.0001 --data_dir data --exp_name reazon_small_exp --spkemb_dropout 0.1 --check_val_every_n_epoch 1 --wandb_entity YOUR_WANDB_ENTITY --wandb_project YOUR_WANDB_PROJECT --use_wandb --train_from_scratch 83 | ``` 84 | 85 | ## 4.
Inference 86 | ```bash 87 | # Our Pre-trained Checkpoint 88 | python -i fam/llm/fast_inference.py --model_name kotoba-tech/kotoba-speech-v0.1 89 | tts.synthesise(text="コトバテクノロジーズのミッションは音声基盤モデルを作る事です。", spk_ref_path="assets/bria.mp3") 90 | 91 | # Inference from Our Pre-trained Checkpoint (関西弁) 92 | python -i fam/llm/fast_inference.py --model_name kotoba-tech/kotoba-speech-v0.1-kansai 93 | tts.synthesise(text="コトバテクノロジーズのミッションは音声基盤モデルを作る事です。", spk_ref_path="assets/bria.mp3") 94 | 95 | # Inference from Your Own Trained Checkpoint 96 | # YOUR_CHECKPOINT_PATH is something like /home/checkpoints/epoch=0-step=1810.ckpt 97 | python -i fam/llm/fast_inference.py --first_model_path YOUR_CHECKPOINT_PATH 98 | tts.synthesise(text="コトバテクノロジーズのミッションは音声基盤モデルを作る事です。", spk_ref_path="assets/bria.mp3") 99 | ``` 100 | 101 | ## 5. Other Notes 102 | ### 5.1 Contribute 103 | - See all [active issues](https://github.com/kotoba-tech/kotoba-speech-release/issues)! 104 | 105 | ### 5.2 Enhancements 106 | - [ ] Write an explanation about multi-node training 107 | - [ ] Integrate a Gradio demo 108 | 109 | ### 5.3 Acknowledgements 110 | We thank [MetaVoice](https://github.com/metavoiceio/metavoice-src) for releasing their code and their English pre-trained model. 111 | -------------------------------------------------------------------------------- /assets/aws.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/aws.mov -------------------------------------------------------------------------------- /assets/aws.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/aws.wav -------------------------------------------------------------------------------- /assets/bria.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/bria.mp3 -------------------------------------------------------------------------------- /assets/google.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/google.mov -------------------------------------------------------------------------------- /assets/google.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/google.wav -------------------------------------------------------------------------------- /assets/kansai_demo.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kansai_demo.mov -------------------------------------------------------------------------------- /assets/kotoba-speech_demo.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba-speech_demo.mov -------------------------------------------------------------------------------- /assets/kotoba.mov:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba.mov -------------------------------------------------------------------------------- /assets/kotoba.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba.wav -------------------------------------------------------------------------------- /assets/kotoba_cloning.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba_cloning.mov -------------------------------------------------------------------------------- /assets/kotoba_cloning.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/kotoba_cloning.wav -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/assets/logo.png -------------------------------------------------------------------------------- /fam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/__init__.py -------------------------------------------------------------------------------- /fam/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/llm/__init__.py -------------------------------------------------------------------------------- /fam/llm/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook 2 | from fam.llm.adapters.tilted_encodec import TiltedEncodec 3 | -------------------------------------------------------------------------------- /fam/llm/adapters/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class BaseDataAdapter(ABC): 5 | pass 6 | -------------------------------------------------------------------------------- /fam/llm/adapters/flattened_encodec.py: -------------------------------------------------------------------------------- 1 | from fam.llm.adapters.base import BaseDataAdapter 2 | 3 | 4 | class FlattenedInterleavedEncodec2Codebook(BaseDataAdapter): 5 | def __init__(self, end_of_audio_token): 6 | self._end_of_audio_token = end_of_audio_token 7 | 8 | def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]: 9 | assert len(tokens) == 1 10 | tokens = tokens[0] 11 | 12 | text_ids = [] 13 | extracted_audio_ids = [[], []] 14 | 15 | for t in tokens: 16 | if t < self._end_of_audio_token: 17 | extracted_audio_ids[0].append(t) 18 | elif t >= self._end_of_audio_token and t < 2 * self._end_of_audio_token: 19 | extracted_audio_ids[1].append(t - 
self._end_of_audio_token) 20 | # We ignore t = 2 * self._end_of_audio_token, as it is the end of audio token 21 | elif t > 2 * self._end_of_audio_token: 22 | text_ids.append(t) 23 | 24 | if len(set([len(x) for x in extracted_audio_ids])) != 1: 25 | min_len = min([len(x) for x in extracted_audio_ids]) 26 | max_len = max([len(x) for x in extracted_audio_ids]) 27 | print("WARNING: Number of tokens at each hierarchy must be of the same length!") 28 | print(f"Truncating to min length of {min_len} tokens (max hierarchy length was {max_len}).") 29 | print([len(x) for x in extracted_audio_ids]) 30 | extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids] 31 | 32 | return text_ids[:-1], extracted_audio_ids 33 | 34 | def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]): 35 | """ 36 | Performs the required combination and padding as needed. 37 | """ 38 | raise NotImplementedError 39 | -------------------------------------------------------------------------------- /fam/llm/adapters/tilted_encodec.py: -------------------------------------------------------------------------------- 1 | from fam.llm.adapters.base import BaseDataAdapter 2 | 3 | 4 | class TiltedEncodec(BaseDataAdapter): 5 | def __init__(self, end_of_audio_token): 6 | self._end_of_audio_token = end_of_audio_token 7 | 8 | def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]: 9 | assert len(tokens) > 1 10 | 11 | text_ids = [] 12 | extracted_audio_ids = [] 13 | 14 | extracted_audio_ids.append([]) 15 | # Handle first hierarchy as special case as it contains text tokens as well 16 | # TODO: maybe it doesn't need special case, and can be handled on its own :) 17 | for t in tokens[0]: 18 | if t > self._end_of_audio_token: 19 | text_ids.append(t) 20 | elif t < self._end_of_audio_token: 21 | extracted_audio_ids[0].append(t) 22 | 23 | # Handle the rest of the hierarchies 24 | for i in range(1, len(tokens)): 25 | token_hierarchy_ids = tokens[i] 26 | extracted_audio_ids.append([]) 27 | for t in token_hierarchy_ids: 28 | if t < self._end_of_audio_token: 29 | extracted_audio_ids[i].append(t) 30 | 31 | if len(set([len(x) for x in extracted_audio_ids])) != 1: 32 | min_len = min([len(x) for x in extracted_audio_ids]) 33 | max_len = max([len(x) for x in extracted_audio_ids]) 34 | print("WARNING: Number of tokens at each hierarchy must be of the same length!") 35 | print(f"Truncating to min length of {min_len} tokens (max hierarchy length was {max_len}).") 36 | print([len(x) for x in extracted_audio_ids]) 37 | extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids] 38 | 39 | return text_ids[:-1], extracted_audio_ids 40 | 41 | def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]): 42 | """ 43 | Performs the required combination and padding as needed.
44 | """ 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /fam/llm/decoders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import uuid 4 | from abc import ABC, abstractmethod 5 | from typing import Callable, Optional, Union 6 | 7 | import julius 8 | import torch 9 | from audiocraft.data.audio import audio_read, audio_write 10 | from audiocraft.models import MultiBandDiffusion # type: ignore 11 | 12 | from IPython import embed 13 | 14 | class Decoder(ABC): 15 | @abstractmethod 16 | def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None): 17 | raise NotImplementedError 18 | 19 | 20 | class EncodecDecoder(Decoder): 21 | def __init__( 22 | self, 23 | tokeniser_decode_fn: Callable[[list[int]], str], 24 | data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]], 25 | output_dir: str, 26 | ): 27 | self._mbd_bandwidth = 6 # 1.5 28 | self._mbd_sample_rate = 24_000 29 | self._end_of_audio_token = 1024 30 | self._num_codebooks = 8 31 | self.mbd = MultiBandDiffusion.get_mbd_24khz(bw=self._mbd_bandwidth) 32 | 33 | self.tokeniser_decode_fn = tokeniser_decode_fn 34 | self._data_adapter_fn = data_adapter_fn 35 | 36 | self.output_dir = pathlib.Path(output_dir).resolve() 37 | os.makedirs(self.output_dir, exist_ok=True) 38 | 39 | def _save_audio(self, name: str, wav: torch.Tensor): 40 | audio_write( 41 | name, 42 | wav.squeeze(0).cpu(), 43 | self._mbd_sample_rate, 44 | strategy="loudness", 45 | loudness_compressor=True, 46 | ) 47 | 48 | def get_tokens(self, audio_path: str) -> list[list[int]]: 49 | """ 50 | Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g. 51 | limited codebook reconstruction or sampling from second stage model only). 52 | """ 53 | pass 54 | wav, sr = audio_read(audio_path) 55 | if sr != self._mbd_sample_rate: 56 | wav = julius.resample_frac(wav, sr, self._mbd_sample_rate) 57 | if wav.ndim == 2: 58 | wav = wav.unsqueeze(1) 59 | wav = wav.to("cuda") 60 | tokens = self.mbd.codec_model.encode(wav) 61 | tokens = tokens[0][0] 62 | return tokens.tolist() 63 | 64 | def decode( 65 | self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None 66 | ) -> Union[str, torch.Tensor]: 67 | # TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file. 68 | text_ids, extracted_audio_ids = self._data_adapter_fn(tokens) 69 | text = self.tokeniser_decode_fn(text_ids) 70 | print(f"Text: {text}") 71 | 72 | tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0) 73 | 74 | if tokens.shape[1] < self._num_codebooks: 75 | tokens = torch.cat( 76 | [tokens, *[torch.ones_like(tokens[0:1, 0:1]) * 0] * (self._num_codebooks - tokens.shape[1])], dim=1 77 | ) 78 | 79 | if causal: 80 | return tokens 81 | else: 82 | with torch.amp.autocast(device_type="cuda", dtype=torch.float32): 83 | # embed() 84 | wav = self.mbd.tokens_to_wav(tokens) 85 | # NOTE: we couldn't just return wav here as it goes through loudness compression etc :) 86 | 87 | if wav.shape[-1] < 9600: 88 | # this causes problem for the code below, and is also odd :) 89 | # first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!) 
90 | raise Exception("wav predicted is shorter than 400ms!") 91 | 92 | try: 93 | wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}" 94 | self._save_audio(wav_file_name, wav) 95 | print(f"\nSaved audio to {wav_file_name}.wav") 96 | return wav_file_name 97 | except Exception as e: 98 | print(f"Failed to save audio! Reason: {e}") 99 | wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}" 100 | self._save_audio(wav_file_name, wav) 101 | print(f"\nSaved audio to {wav_file_name}.wav") 102 | return wav_file_name 103 | -------------------------------------------------------------------------------- /fam/llm/enhancers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC 3 | from typing import Literal, Optional 4 | 5 | from df.enhance import enhance, init_df, load_audio, save_audio 6 | from pydub import AudioSegment 7 | 8 | 9 | def convert_to_wav(input_file: str, output_file: str): 10 | """Convert an audio file to WAV format 11 | 12 | Args: 13 | input_file (str): path to input audio file 14 | output_file (str): path to output WAV file 15 | 16 | """ 17 | # Detect the format of the input file 18 | format = input_file.split(".")[-1].lower() 19 | 20 | # Read the audio file 21 | audio = AudioSegment.from_file(input_file, format=format) 22 | 23 | # Export as WAV 24 | audio.export(output_file, format="wav") 25 | 26 | 27 | def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str: 28 | """Generate the output file path 29 | 30 | Args: 31 | audio_file (str): path to input audio file 32 | tag (str): tag to append to the output file name 33 | ext (str, optional): extension of the output file. Defaults to None. 34 | 35 | Returns: 36 | str: path to output file 37 | """ 38 | 39 | directory = "./enhanced" 40 | # Get the name of the input file 41 | filename = os.path.basename(audio_file) 42 | 43 | # Get the name of the input file without the extension 44 | filename_without_extension = os.path.splitext(filename)[0] 45 | 46 | # Get the extension of the input file 47 | extension = ext or os.path.splitext(filename)[1] 48 | 49 | # Generate the output file path 50 | output_file = os.path.join(directory, filename_without_extension + tag + extension) 51 | 52 | return output_file 53 | 54 | 55 | class BaseEnhancer(ABC): 56 | """Base class for audio enhancers""" 57 | 58 | def __init__(self, *args, **kwargs): 59 | raise NotImplementedError 60 | 61 | def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str: 62 | raise NotImplementedError 63 | 64 | def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str: 65 | output_file = make_output_file_path(audio_file, tag, ext=ext) 66 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 67 | return output_file 68 | 69 | 70 | class DFEnhancer(BaseEnhancer): 71 | def __init__(self, *args, **kwargs): 72 | self.model, self.df_state, _ = init_df() 73 | 74 | def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str: 75 | output_file = output_file or self.get_output_file(audio_file, "_df") 76 | 77 | audio, _ = load_audio(audio_file, sr=self.df_state.sr()) 78 | 79 | enhanced = enhance(self.model, self.df_state, audio) 80 | 81 | save_audio(output_file, enhanced, self.df_state.sr()) 82 | 83 | return output_file 84 | 85 | 86 | def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer: 87 | """Get an audio enhancer 88 | 89 | Args: 90 | enhancer_name (Literal["df"]): name of the audio 
enhancer 91 | 92 | Raises: 93 | ValueError: if the enhancer name is not recognised 94 | 95 | Returns: 96 | BaseEnhancer: audio enhancer 97 | """ 98 | 99 | if enhancer_name == "df": 100 | return DFEnhancer() 101 | else: 102 | raise ValueError(f"Unknown enhancer name: {enhancer_name}") 103 | -------------------------------------------------------------------------------- /fam/llm/fast_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tyro 4 | import tempfile 5 | import time 6 | from pathlib import Path 7 | 8 | import librosa 9 | import torch 10 | from huggingface_hub import snapshot_download 11 | 12 | from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook 13 | from fam.llm.decoders import EncodecDecoder 14 | from fam.llm.fast_inference_utils import build_model, main 15 | from fam.llm.inference import ( 16 | EncodecDecoder, 17 | InferenceConfig, 18 | Model, 19 | TiltedEncodec, 20 | TrainedBPETokeniser, 21 | get_cached_embedding, 22 | get_cached_file, 23 | get_enhancer, 24 | ) 25 | from fam.llm.utils import ( 26 | check_audio_file, 27 | get_default_dtype, 28 | get_device, 29 | normalize_text, 30 | ) 31 | import argparse 32 | 33 | 34 | class TTS: 35 | def __init__( 36 | self, model_name: str = "kotoba-tech/kotoba-speech-v0.1", *, seed: int = 1337, output_dir: str = "outputs", first_model_path: str = None, 37 | ): 38 | """ 39 | model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/kotoba-tech/) 40 | """ 41 | 42 | # NOTE: this needs to come first so that we don't change global state when we want to use 43 | # the torch.compiled-model. 44 | self._dtype = get_default_dtype() 45 | self._device = get_device() 46 | self._first_model_dir = snapshot_download(repo_id=model_name) 47 | if model_name != "kotoba-tech/kotoba-speech-v0.1": 48 | self._other_model_dir = snapshot_download(repo_id="kotoba-tech/kotoba-speech-v0.1") 49 | else: 50 | self._other_model_dir = self._first_model_dir 51 | 52 | self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024) 53 | self.output_dir = output_dir 54 | os.makedirs(self.output_dir, exist_ok=True) 55 | 56 | second_stage_ckpt_path = f"{self._other_model_dir}/second_stage.pt" 57 | config_second_stage = InferenceConfig( 58 | ckpt_path=second_stage_ckpt_path, 59 | num_samples=1, 60 | seed=seed, 61 | device=self._device, 62 | dtype=self._dtype, 63 | compile=False, 64 | init_from="resume", 65 | output_dir=self.output_dir, 66 | ) 67 | data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024) 68 | self.llm_second_stage = Model( 69 | config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode 70 | ) 71 | self.enhancer = get_enhancer("df") 72 | 73 | self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype] 74 | self.model, self.tokenizer, self.smodel, self.model_size = build_model( 75 | precision=self.precision, 76 | checkpoint_path=Path(f"{self._other_model_dir}/first_stage.pt"), 77 | spk_emb_ckpt_path=Path(f"{self._other_model_dir}/speaker_encoder.pt"), 78 | device=self._device, 79 | compile=True, 80 | compile_prefill=True, 81 | first_model_path= first_model_path if (first_model_path is not None) and (os.path.exists(first_model_path)) else Path(f"{self._first_model_dir}/first_stage.pt"), 82 | ) 83 | 84 | 85 | def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str: 86 | 
""" 87 | text: Text to speak 88 | spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3 89 | top_p: Top p for sampling applied to first-stage model. Range [0.9, 1.0] are good. This is a measure of speech stability - improves text following for a challenging speaker 90 | guidance_scale: Guidance scale [1.0, 3.0] for sampling. This is a measure of speaker similarity - how closely to match speaker identity and speech style. 91 | temperature: Temperature for sampling applied to both LLMs (first & second stage) 92 | 93 | returns: path to speech .wav file 94 | """ 95 | text = normalize_text(text) 96 | spk_ref_path = get_cached_file(spk_ref_path) 97 | check_audio_file(spk_ref_path) 98 | spk_emb = get_cached_embedding( 99 | spk_ref_path, 100 | self.smodel, 101 | ).to(device=self._device, dtype=self.precision) 102 | 103 | start = time.time() 104 | # first stage LLM 105 | tokens = main( 106 | model=self.model, 107 | tokenizer=self.tokenizer, 108 | model_size=self.model_size, 109 | prompt=text, 110 | spk_emb=spk_emb, 111 | top_p=torch.tensor(top_p, device=self._device, dtype=self.precision), 112 | guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision), 113 | temperature=torch.tensor(temperature, device=self._device, dtype=self.precision), 114 | ) 115 | text_ids, extracted_audio_ids = self.first_stage_adapter.decode([tokens]) 116 | 117 | b_speaker_embs = spk_emb.unsqueeze(0) 118 | 119 | # second stage LLM + multi-band diffusion model 120 | wav_files = self.llm_second_stage( 121 | texts=[text], 122 | encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)], 123 | speaker_embs=b_speaker_embs, 124 | batch_size=1, 125 | guidance_scale=None, 126 | top_p=None, 127 | top_k=200, 128 | temperature=1.0, 129 | max_new_tokens=None, 130 | ) 131 | 132 | # enhance using deepfilternet 133 | wav_file = wav_files[0] 134 | with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp: 135 | self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name) 136 | shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav") 137 | print(f"\nSaved audio to {wav_file}.wav") 138 | 139 | # calculating real-time factor (RTF) 140 | time_to_synth_s = time.time() - start 141 | audio, sr = librosa.load(str(wav_file) + ".wav") 142 | duration_s = librosa.get_duration(y=audio, sr=sr) 143 | print(f"\nTotal time to synth (s): {time_to_synth_s}") 144 | print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}") 145 | 146 | return str(wav_file) + ".wav" 147 | 148 | 149 | if __name__ == "__main__": 150 | tts = tyro.cli(TTS) -------------------------------------------------------------------------------- /fam/llm/fast_inference_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kotoba Technologies, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted 5 | # provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this list of 8 | # conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this 11 | # list of conditions and the following disclaimer in the documentation and/or other 12 | # materials provided with the distribution. 13 | # 14 | # 3. 
Neither the name of the copyright holder nor the names of its contributors 15 | # may be used to endorse or promote products derived from this software without 16 | # specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR 19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | import itertools 27 | import gc 28 | import time 29 | from pathlib import Path 30 | from typing import Optional, Tuple 31 | 32 | import torch 33 | import torch._dynamo.config 34 | import torch._inductor.config 35 | import tqdm 36 | 37 | 38 | def device_sync(device): 39 | if "cuda" in device: 40 | torch.cuda.synchronize() 41 | elif "cpu" in device: 42 | pass 43 | else: 44 | print(f"device={device} is not yet suppported") 45 | 46 | 47 | torch._inductor.config.coordinate_descent_tuning = True 48 | torch._inductor.config.triton.unique_kernel_names = True 49 | torch._inductor.config.fx_graph_cache = ( 50 | True # Experimental feature to reduce compilation times, will be on by default in future 51 | ) 52 | 53 | # imports need to happen after setting above flags 54 | from fam.llm.fast_model import Transformer 55 | from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder 56 | from fam.quantiser.text.tokenise import TrainedBPETokeniser 57 | 58 | 59 | def multinomial_sample_one_no_sync( 60 | probs_sort, 61 | ): # Does multinomial sampling without a cuda synchronization 62 | q = torch.empty_like(probs_sort).exponential_(1) 63 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 64 | 65 | 66 | def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor): 67 | # ref: huggingface/transformers 68 | 69 | sorted_logits, sorted_indices = torch.sort(logits, descending=False) 70 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 71 | 72 | # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) 73 | sorted_indices_to_remove = cumulative_probs <= (1 - top_p) 74 | # Keep at least min_tokens_to_keep 75 | sorted_indices_to_remove[-1:] = 0 76 | 77 | # scatter sorted tensors to original indexing 78 | indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove) 79 | scores = logits.masked_fill(indices_to_remove, -float("Inf")) 80 | return scores 81 | 82 | 83 | def logits_to_probs( 84 | logits, 85 | *, 86 | temperature: torch.Tensor, 87 | top_p: Optional[torch.Tensor] = None, 88 | top_k: Optional[torch.Tensor] = None, 89 | ): 90 | logits = logits / torch.max(temperature, 1e-5 * torch.ones_like(temperature)) 91 | 92 | if top_k is not None: 93 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 94 | pivot = v.select(-1, -1).unsqueeze(-1) 95 | logits = torch.where(logits < pivot, -float("Inf"), logits) 96 | 97 | if top_p is not None: 98 | logits = top_p_sample(logits, top_p) 99 | 100 | probs = 
torch.nn.functional.softmax(logits, dim=-1) 101 | 102 | return probs 103 | 104 | 105 | def sample( 106 | logits, 107 | guidance_scale: torch.Tensor, 108 | temperature: torch.Tensor, 109 | top_p: Optional[torch.Tensor] = None, 110 | top_k: Optional[torch.Tensor] = None, 111 | ): 112 | # (b, t, vocab_size) 113 | logits = logits[:, -1] 114 | logits_cond, logits_uncond_spkemb = logits.split(logits.size(0) // 2, dim=0) 115 | logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond_spkemb 116 | probs = logits_to_probs(logits[0], temperature=temperature, top_p=top_p, top_k=top_k) 117 | idx_next = multinomial_sample_one_no_sync(probs) 118 | return idx_next, probs 119 | 120 | 121 | def prefill( 122 | model: Transformer, 123 | x: torch.Tensor, 124 | spk_emb: torch.Tensor, 125 | input_pos: torch.Tensor, 126 | **sampling_kwargs, 127 | ) -> torch.Tensor: 128 | # input_pos: [B, S] 129 | logits = model(x, spk_emb, input_pos) 130 | return sample(logits, **sampling_kwargs)[0] 131 | 132 | 133 | def decode_one_token( 134 | model: Transformer, 135 | x: torch.Tensor, 136 | spk_emb: torch.Tensor, 137 | input_pos: torch.Tensor, 138 | **sampling_kwargs, 139 | ) -> Tuple[torch.Tensor, torch.Tensor]: 140 | # input_pos: [B, 1] 141 | assert input_pos.shape[-1] == 1 142 | logits = model(x, spk_emb, input_pos) 143 | return sample(logits, **sampling_kwargs) 144 | 145 | 146 | def decode_n_tokens( 147 | model: Transformer, 148 | cur_token: torch.Tensor, 149 | spk_emb: torch.Tensor, 150 | input_pos: torch.Tensor, 151 | num_new_tokens: int, 152 | callback=lambda _: _, 153 | return_probs: bool = False, 154 | end_of_audio_token: int = 2048, 155 | **sampling_kwargs, 156 | ): 157 | new_tokens, new_probs = [], [] 158 | for i in tqdm.tqdm(range(num_new_tokens)): 159 | if (cur_token == end_of_audio_token).any(): 160 | break 161 | with torch.backends.cuda.sdp_kernel( 162 | enable_flash=False, enable_mem_efficient=False, enable_math=True 163 | ): # Actually better for Inductor to codegen attention here 164 | next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs) 165 | input_pos += 1 166 | new_tokens.append(next_token.clone()) 167 | callback(new_tokens[-1]) 168 | if return_probs: 169 | new_probs.append(next_prob.clone()) 170 | cur_token = next_token.view(1, -1).repeat(2, 1) 171 | 172 | return new_tokens, new_probs 173 | 174 | 175 | def model_forward(model, x, spk_emb, input_pos): 176 | return model(x, spk_emb, input_pos) 177 | 178 | 179 | @torch.no_grad() 180 | def generate( 181 | model: Transformer, 182 | prompt: torch.Tensor, 183 | spk_emb: torch.Tensor, 184 | *, 185 | max_new_tokens: Optional[int] = None, 186 | callback=lambda x: x, 187 | end_of_audio_token: int = 2048, 188 | **sampling_kwargs, 189 | ) -> torch.Tensor: 190 | """ 191 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
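The prompt is duplicated along the batch dimension so that speaker-conditional and speaker-unconditional logits are computed in a single forward pass and mixed via classifier-free guidance in `sample`.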
192 | """ 193 | # create an empty tensor of the expected final shape and fill in the current tokens 194 | T = prompt.size(0) 195 | if max_new_tokens is None: 196 | max_seq_length = model.config.block_size 197 | else: 198 | max_seq_length = T + max_new_tokens 199 | max_seq_length = min(max_seq_length, model.config.block_size) 200 | max_new_tokens = max_seq_length - T 201 | if max_new_tokens <= 0: 202 | raise ValueError("Prompt is too long to generate more tokens") 203 | 204 | device, dtype = prompt.device, prompt.dtype 205 | 206 | seq = torch.clone(prompt) 207 | input_pos = torch.arange(0, T, device=device) 208 | 209 | next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs) 210 | seq = torch.cat([seq, next_token.view(1)]) 211 | 212 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 213 | 214 | generated_tokens, _ = decode_n_tokens( 215 | model, 216 | next_token.view(1, -1).repeat(2, 1), 217 | spk_emb, 218 | input_pos, 219 | max_new_tokens - 1, 220 | callback=callback, 221 | end_of_audio_token=end_of_audio_token, 222 | **sampling_kwargs, 223 | ) 224 | seq = torch.cat([seq, torch.cat(generated_tokens)]) 225 | 226 | return seq 227 | 228 | 229 | def encode_tokens(tokenizer, string, device="cuda"): 230 | tokens = tokenizer.encode(string) 231 | return torch.tensor(tokens, dtype=torch.int, device=device) 232 | 233 | 234 | def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision, first_model_path=None, unwanted_prefix="_orig_mod."): 235 | ##### MODEL 236 | with torch.device("meta"): 237 | model = Transformer.from_name("kotoba-speech-v0.1") 238 | 239 | # TODO(quantization): enable 240 | # if "int8" in str(checkpoint_path): 241 | # print("Using int8 weight-only quantization!") 242 | # from quantize import WeightOnlyInt8QuantHandler 243 | # simple_quantizer = WeightOnlyInt8QuantHandler(model) 244 | # model = simple_quantizer.convert_for_runtime() 245 | # from quantize import WeightOnlyInt8QuantHandler 246 | 247 | # if "int4" in str(checkpoint_path): 248 | # print("Using int4 quantization!") 249 | # path_comps = checkpoint_path.name.split(".") 250 | # assert path_comps[-2].startswith("g") 251 | # groupsize = int(path_comps[-2][1:]) 252 | # from quantize import WeightOnlyInt4QuantHandler 253 | # simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) 254 | # model = simple_quantizer.convert_for_runtime() 255 | 256 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False) 257 | 258 | ###### TOKENIZER 259 | tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {}) 260 | tokenizer = TrainedBPETokeniser(**tokenizer_info) 261 | 262 | if first_model_path is not None: 263 | trained_ckpt = torch.load(str(first_model_path), mmap=True, weights_only=False) 264 | #trained_ckpt1 = torch.load(str(first_model_path), mmap=True, weights_only=False, map_location="cpu") 265 | #trained_ckpt2 = torch.load(str(checkpoint_path), mmap=True, weights_only=False, map_location="cpu") 266 | if "state_dict" in trained_ckpt.keys(): 267 | state_dict = trained_ckpt["state_dict"] 268 | else: 269 | state_dict = trained_ckpt["model"] 270 | del checkpoint 271 | gc.collect() 272 | torch.cuda.empty_cache() 273 | else: 274 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False) 275 | if "state_dict" in checkpoint.keys(): 276 | state_dict = checkpoint["state_dict"] 277 | else: 278 | state_dict = checkpoint["model"] 279 | 280 | # convert Kotoba-Speech model weights naming to gptfast naming 281 | for k, v in 
list(state_dict.items()): 282 | if k.startswith(unwanted_prefix): 283 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) 284 | state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight") 285 | state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight") 286 | state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight") 287 | state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight") 288 | for k, v in list(state_dict.items()): 289 | if k.startswith("transformer.h."): 290 | state_dict[k.replace("transformer.h.", "layers.")] = state_dict.pop(k) 291 | k = k.replace("transformer.h.", "layers.") 292 | if ".attn.c_attn." in k: 293 | state_dict[k.replace(".attn.c_attn.", ".attention.wqkv.")] = state_dict.pop(k) 294 | k = k.replace(".attn.c_attn.", ".attention.wqkv.") 295 | if ".attn.c_proj." in k: 296 | state_dict[k.replace(".attn.c_proj.", ".attention.wo.")] = state_dict.pop(k) 297 | k = k.replace(".attn.c_proj.", ".attention.wo.") 298 | if ".mlp.swiglu.w1." in k: 299 | state_dict[k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")] = state_dict.pop(k) 300 | k = k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.") 301 | if ".mlp.swiglu.w3." in k: 302 | state_dict[k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")] = state_dict.pop(k) 303 | k = k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.") 304 | if ".ln_1." in k: 305 | state_dict[k.replace(".ln_1.", ".attention_norm.")] = state_dict.pop(k) 306 | k = k.replace(".ln_1.", ".attention_norm.") 307 | if ".ln_2." in k: 308 | state_dict[k.replace(".ln_2.", ".ffn_norm.")] = state_dict.pop(k) 309 | k = k.replace(".ln_2.", ".ffn_norm.") 310 | if ".mlp.c_proj." in k: 311 | state_dict[k.replace(".mlp.c_proj.", ".feed_forward.w2.")] = state_dict.pop(k) 312 | k = k.replace(".mlp.c_proj.", ".feed_forward.w2.") 313 | 314 | model.load_state_dict(state_dict, assign=True) 315 | # simple_quantizer = WeightOnlyInt8QuantHandler(model) 316 | # quantized_state_dict = simple_quantizer.create_quantized_state_dict() 317 | # model = simple_quantizer.convert_for_runtime() 318 | # model.load_state_dict(quantized_state_dict, assign=True) 319 | model = model.to(device=device, dtype=precision) 320 | 321 | ###### SPEAKER EMBEDDER 322 | # TODO: fix! 323 | smodel = SpeakerEncoder( 324 | weights_fpath=spk_emb_ckpt_path, 325 | device=device, 326 | eval=True, 327 | verbose=False, 328 | ) 329 | return model.eval(), tokenizer, smodel 330 | 331 | 332 | def build_model( 333 | *, 334 | precision: torch.dtype, 335 | checkpoint_path: Path = Path(""), 336 | spk_emb_ckpt_path: Path = Path(""), 337 | compile_prefill: bool = False, 338 | compile: bool = True, 339 | device: str = "cuda", 340 | first_model_path: str = None, 341 | ): 342 | assert checkpoint_path.is_file(), checkpoint_path 343 | 344 | print(f"Using device={device}") 345 | 346 | print("Loading model ...") 347 | t0 = time.time() 348 | if first_model_path is None: 349 | # model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision) 350 | model, tokenizer, smodel = _load_model( 351 | checkpoint_path, spk_emb_ckpt_path, device, precision, unwanted_prefix="first_stage_model_transformer." 
352 | ) 353 | 354 | else: 355 | model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision, first_model_path, unwanted_prefix="first_stage_model_transformer.") 356 | 357 | 358 | device_sync(device=device) # MKG 359 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 360 | 361 | torch.manual_seed(1234) 362 | model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) 363 | 364 | with torch.device(device): 365 | model.setup_spk_cond_mask() 366 | model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size) 367 | 368 | if compile: 369 | print("Compiling...Can take up to 2 mins.") 370 | global decode_one_token, prefill 371 | decode_one_token = torch.compile( 372 | decode_one_token, 373 | mode="max-autotune", 374 | fullgraph=True, 375 | ) 376 | 377 | if compile_prefill: 378 | prefill = torch.compile( 379 | prefill, 380 | fullgraph=True, 381 | dynamic=True, 382 | ) 383 | 384 | encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device) 385 | spk_emb = torch.randn((1, 256), device=device, dtype=precision) 386 | 387 | device_sync(device=device) # MKG 388 | t0 = time.perf_counter() 389 | y = generate( 390 | model, 391 | encoded, 392 | spk_emb, 393 | max_new_tokens=200, 394 | callback=lambda x: x, 395 | temperature=torch.tensor(1.0, device=device, dtype=precision), 396 | top_k=None, 397 | top_p=torch.tensor(0.95, device=device, dtype=precision), 398 | guidance_scale=torch.tensor(3.0, device=device, dtype=precision), 399 | end_of_audio_token=9999, # don't end early for compilation stage. 400 | ) 401 | 402 | device_sync(device=device) # MKG 403 | 404 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") 405 | 406 | return model, tokenizer, smodel, model_size 407 | 408 | 409 | def main( 410 | *, 411 | model, 412 | tokenizer, 413 | model_size, 414 | prompt: str, 415 | guidance_scale: torch.Tensor, 416 | temperature: torch.Tensor, 417 | spk_emb: torch.Tensor, 418 | top_k: Optional[torch.Tensor] = None, 419 | top_p: Optional[torch.Tensor] = None, 420 | device: str = "cuda", 421 | ) -> list: 422 | """Generates text samples based on a pre-trained Transformer model and tokenizer.""" 423 | 424 | encoded = encode_tokens(tokenizer, prompt, device=device) 425 | prompt_length = encoded.size(0) 426 | 427 | aggregate_metrics: dict = { 428 | "tokens_per_sec": [], 429 | } 430 | 431 | device_sync(device=device) # MKG 432 | 433 | if True: 434 | callback = lambda x: x 435 | t0 = time.perf_counter() 436 | 437 | y = generate( 438 | model, 439 | encoded, 440 | spk_emb, 441 | callback=callback, 442 | temperature=temperature, 443 | top_k=top_k, 444 | top_p=top_p, 445 | guidance_scale=guidance_scale, 446 | ) 447 | 448 | device_sync(device=device) # MKG 449 | t = time.perf_counter() - t0 450 | 451 | tokens_generated = y.size(0) - prompt_length 452 | tokens_sec = tokens_generated / t 453 | aggregate_metrics["tokens_per_sec"].append(tokens_sec) 454 | print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") 455 | print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") 456 | # print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") 457 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n") 458 | 459 | return y.tolist() 460 | -------------------------------------------------------------------------------- /fam/llm/fast_model.py: 
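Before fast_model.py begins below, here is a minimal sketch of how the build_model and main entry points defined above might be driven end to end. The checkpoint paths and the randomly drawn speaker embedding are placeholders (a real embedding would come from the returned SpeakerEncoder), and the sampling values simply mirror the ones used in the compilation warm-up, so treat this as an illustrative assumption rather than the project's official CLI.

# Hypothetical driver for the fast-inference path above; paths and sampling values are placeholders.
import torch
from pathlib import Path

def run_fast_inference(prompt: str, ckpt_path: str, spk_emb_ckpt: str) -> list:
    precision = torch.bfloat16
    model, tokenizer, smodel, model_size = build_model(
        precision=precision,
        checkpoint_path=Path(ckpt_path),
        spk_emb_ckpt_path=Path(spk_emb_ckpt),
        compile=True,
        compile_prefill=False,
        device="cuda",
    )
    # Stand-in for an embedding produced by `smodel` on a few seconds of reference audio.
    spk_emb = torch.randn((1, 256), device="cuda", dtype=precision)
    return main(
        model=model,
        tokenizer=tokenizer,
        model_size=model_size,
        prompt=prompt,
        guidance_scale=torch.tensor(3.0, device="cuda", dtype=precision),
        temperature=torch.tensor(1.0, device="cuda", dtype=precision),
        spk_emb=spk_emb,
        top_p=torch.tensor(0.95, device="cuda", dtype=precision),
    )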
-------------------------------------------------------------------------------- 1 | # Copyright (c) Kotoba Technologies, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted 5 | # provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this list of 8 | # conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, this 11 | # list of conditions and the following disclaimer in the documentation and/or other 12 | # materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its contributors 15 | # may be used to endorse or promote products derived from this software without 16 | # specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR 19 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 20 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 21 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | from dataclasses import dataclass 27 | from functools import reduce 28 | from math import gcd 29 | from typing import Optional, Tuple 30 | 31 | import torch 32 | import torch.nn as nn 33 | from torch import Tensor 34 | from torch.nn import functional as F 35 | 36 | from fam.llm.utils import get_default_dtype 37 | 38 | import logging 39 | 40 | # Adjust the logging level 41 | logger = logging.getLogger("torch") 42 | logger.setLevel(logging.ERROR) 43 | 44 | 45 | def find_multiple(n: int, *args: Tuple[int]) -> int: 46 | k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) 47 | if n % k == 0: 48 | return n 49 | return n + k - (n % k) 50 | 51 | 52 | @dataclass 53 | class ModelArgs: 54 | block_size: int = 2048 55 | vocab_size: int = 32000 56 | n_layer: int = 32 57 | n_head: int = 32 58 | dim: int = 4096 59 | speaker_emb_dim: int = 256 60 | intermediate_size: int = None 61 | n_local_heads: int = -1 62 | head_dim: int = 64 63 | norm_eps: float = 1e-5 64 | dtype: torch.dtype = torch.bfloat16 65 | 66 | def __post_init__(self): 67 | if self.n_local_heads == -1: 68 | self.n_local_heads = self.n_head 69 | if self.intermediate_size is None: 70 | hidden_dim = 4 * self.dim 71 | n_hidden = int(2 * hidden_dim / 3) 72 | self.intermediate_size = find_multiple(n_hidden, 256) 73 | self.head_dim = self.dim // self.n_head 74 | 75 | self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()] 76 | 77 | @classmethod 78 | def from_name(cls, name: str): 79 | if name in transformer_configs: 80 | return cls(**transformer_configs[name]) 81 | # fuzzy search 82 | config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] 83 | assert len(config) == 1, name 84 | return cls(**transformer_configs[config[0]]) 85 | 86 | 87 | 
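# Worked example (added commentary, not in the original source): for the "kotoba-speech-v0.1"
# entry below, dim=2048 and n_head=16, so __post_init__ above yields head_dim = 2048 // 16 = 128
# and intermediate_size = find_multiple(int(2 * (4 * 2048) / 3), 256) = find_multiple(5461, 256) = 5632,
# matching the SwiGLU w1/w3 out_features printed for the full model in fam/llm/model.py.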
transformer_configs = { 88 | "kotoba-speech-v0.1": dict( 89 | n_layer=24, 90 | n_head=16, 91 | dim=2048, 92 | vocab_size=2562, 93 | ), 94 | } 95 | 96 | 97 | class KVCache(nn.Module): 98 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype): 99 | super().__init__() 100 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 101 | self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) 102 | self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) 103 | 104 | def update(self, input_pos, k_val, v_val): 105 | # input_pos: [S], k_val: [B, H, S, D] 106 | assert input_pos.shape[0] == k_val.shape[2] 107 | 108 | k_out = self.k_cache 109 | v_out = self.v_cache 110 | k_out[:, :, input_pos] = k_val 111 | v_out[:, :, input_pos] = v_val 112 | 113 | return k_out, v_out 114 | 115 | 116 | class Transformer(nn.Module): 117 | def __init__(self, config: ModelArgs) -> None: 118 | super().__init__() 119 | self.config = config 120 | 121 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 122 | self.pos_embeddings = nn.Embedding(config.block_size, config.dim) 123 | self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False) 124 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) 125 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 126 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 127 | 128 | self.mask_cache: Optional[Tensor] = None 129 | self.max_batch_size = -1 130 | self.max_seq_length = -1 131 | 132 | def setup_spk_cond_mask(self): 133 | self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool) 134 | self.spk_cond_mask[0] = 1 135 | 136 | def setup_caches(self, max_batch_size, max_seq_length): 137 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 138 | return 139 | head_dim = self.config.dim // self.config.n_head 140 | max_seq_length = find_multiple(max_seq_length, 8) 141 | self.max_seq_length = max_seq_length 142 | self.max_batch_size = max_batch_size 143 | for b in self.layers: 144 | b.attention.kv_cache = KVCache( 145 | max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype 146 | ) 147 | 148 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 149 | 150 | def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor: 151 | mask = self.causal_mask[None, None, input_pos] 152 | x = ( 153 | self.tok_embeddings(idx) 154 | + self.pos_embeddings(input_pos) 155 | # masking for speaker condition free guidance 156 | + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask 157 | ) 158 | 159 | for i, layer in enumerate(self.layers): 160 | x = layer(x, input_pos, mask) 161 | x = self.norm(x) 162 | logits = self.output(x) 163 | return logits 164 | 165 | @classmethod 166 | def from_name(cls, name: str): 167 | return cls(ModelArgs.from_name(name)) 168 | 169 | 170 | class TransformerBlock(nn.Module): 171 | def __init__(self, config: ModelArgs) -> None: 172 | super().__init__() 173 | self.attention = Attention(config) 174 | self.feed_forward = FeedForward(config) 175 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 176 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 177 | 178 | def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor: 179 | h = x + self.attention(self.attention_norm(x), mask, input_pos) 180 | out = h + self.feed_forward(self.ffn_norm(h)) 181 | return out 182 
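# Added note: setup_spk_cond_mask above creates a (2, 1, dim) boolean mask with row 0 set and
# row 1 cleared, and build_model in fast_inference_utils.py calls setup_caches with
# max_batch_size=2: the fast path runs a speaker-conditional and a speaker-unconditional copy
# of the sequence side by side, so the sampling code can combine their logits via guidance_scale
# (classifier-free guidance on the speaker embedding).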
| 183 | 184 | class Attention(nn.Module): 185 | def __init__(self, config: ModelArgs): 186 | super().__init__() 187 | assert config.dim % config.n_head == 0 188 | 189 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 190 | # key, query, value projections for all heads, but in a batch 191 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) 192 | self.wo = nn.Linear(config.dim, config.dim, bias=False) 193 | self.kv_cache = None 194 | 195 | self.n_head = config.n_head 196 | self.head_dim = config.head_dim 197 | self.n_local_heads = config.n_local_heads 198 | self.dim = config.dim 199 | 200 | def forward( 201 | self, 202 | x: Tensor, 203 | mask: Tensor, 204 | input_pos: Optional[Tensor] = None, 205 | ) -> Tensor: 206 | bsz, seqlen, _ = x.shape 207 | 208 | kv_size = self.n_local_heads * self.head_dim 209 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 210 | 211 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 212 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 213 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 214 | 215 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 216 | 217 | if self.kv_cache is not None: 218 | k, v = self.kv_cache.update(input_pos, k, v) 219 | 220 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 221 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 222 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) 223 | 224 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 225 | 226 | y = self.wo(y) 227 | return y 228 | 229 | 230 | class SwiGLU(nn.Module): 231 | def __init__(self, config: ModelArgs) -> None: 232 | super().__init__() 233 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) 234 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) 235 | 236 | def forward(self, x: Tensor) -> Tensor: 237 | return F.silu(self.w1(x)) * self.w3(x) 238 | 239 | 240 | class FeedForward(nn.Module): 241 | def __init__(self, config: ModelArgs) -> None: 242 | super().__init__() 243 | self.swiglu = SwiGLU(config) 244 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) 245 | 246 | def forward(self, x: Tensor) -> Tensor: 247 | return self.w2(self.swiglu(x)) 248 | 249 | 250 | class RMSNorm(nn.Module): 251 | def __init__(self, dim: int, eps: float = 1e-5): 252 | super().__init__() 253 | self.eps = eps 254 | self.weight = nn.Parameter(torch.ones(dim)) 255 | 256 | def _norm(self, x): 257 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 258 | 259 | def forward(self, x: Tensor) -> Tensor: 260 | output = self._norm(x.float()).type_as(x) 261 | return output * self.weight 262 | -------------------------------------------------------------------------------- /fam/llm/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from fam.llm.layers.attn import SelfAttention 2 | from fam.llm.layers.combined import Block 3 | from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm, SwiGLU 4 | -------------------------------------------------------------------------------- /fam/llm/layers/attn.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class SelfAttention(nn.Module): 9 | def __init__(self, config): 10 | """ 11 | Initializes the SelfAttention module. 
12 | 13 | Args: 14 | config: An object containing the configuration parameters for the SelfAttention module. 15 | """ 16 | super().__init__() 17 | self._validate_config(config) 18 | self._initialize_parameters(config) 19 | 20 | def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype): 21 | """ 22 | Empties the key-value cache. 23 | 24 | Args: 25 | batch_size: The batch size. 26 | kv_cache_maxlen: The maximum length of the key-value cache. 27 | dtype: The data type of the cache. 28 | 29 | Raises: 30 | Exception: If trying to empty the KV cache when it is disabled. 31 | """ 32 | if self.kv_cache_enabled is False: 33 | raise Exception("Trying to empty KV cache when it is disabled") 34 | 35 | # register so that the cache moves devices along with the module 36 | # TODO: get rid of re-allocation. 37 | self.register_buffer( 38 | "kv_cache", 39 | torch.zeros( 40 | 2, 41 | batch_size, 42 | kv_cache_maxlen, 43 | self.n_head, 44 | self.n_embd // self.n_head, 45 | dtype=dtype, 46 | device=self.c_attn.weight.device, 47 | ), 48 | persistent=False, 49 | ) 50 | 51 | self.kv_cache_first_empty_index = 0 52 | 53 | def _initialize_parameters(self, config): 54 | """ 55 | Initializes the parameters of the SelfAttention module. 56 | 57 | Args: 58 | config: An object containing the configuration parameters for the SelfAttention module. 59 | """ 60 | # key, query, value projections for all heads, but in a batch 61 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 62 | 63 | # output projection 64 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 65 | 66 | # regularization 67 | self.resid_dropout = nn.Dropout(config.dropout) 68 | self.n_head = config.n_head 69 | self.n_embd = config.n_embd 70 | self.dropout = config.dropout 71 | self.causal = config.causal 72 | self.attn_kernel_type = config.attn_kernel_type 73 | self.attn_dropout = nn.Dropout(config.dropout) 74 | 75 | self.kv_cache_enabled = False 76 | 77 | def _validate_config(self, config): 78 | """ 79 | Validates the configuration parameters. 80 | 81 | Args: 82 | config: An object containing the configuration parameters for the SelfAttention module. 83 | 84 | Raises: 85 | AssertionError: If the embedding dimension is not divisible by the number of heads. 86 | """ 87 | assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads" 88 | 89 | def _update_kv_cache(self, q, k, v): 90 | """ 91 | Updates the key-value cache. 92 | 93 | Args: 94 | q: The query tensor. 95 | k: The key tensor. 96 | v: The value tensor. 97 | 98 | Returns: 99 | The updated key and value tensors. 100 | 101 | Raises: 102 | AssertionError: If the dimensions of the query, key, and value tensors are not compatible. 
103 | """ 104 | q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1] 105 | 106 | if self.kv_cache_first_empty_index == 0: 107 | assert q_time == k_time and q_time == v_time 108 | else: 109 | assert ( 110 | q_time == 1 111 | ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}" 112 | 113 | self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k 114 | self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v 115 | self.kv_cache_first_empty_index += q_time 116 | 117 | k = self.kv_cache[0, :, : self.kv_cache_first_empty_index] 118 | v = self.kv_cache[1, :, : self.kv_cache_first_empty_index] 119 | 120 | return k, v 121 | 122 | def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor: 123 | """ 124 | Performs attention using the torch.nn.functional.scaled_dot_product_attention function. 125 | 126 | Args: 127 | c_x: The input tensor. 128 | 129 | Returns: 130 | The output tensor. 131 | """ 132 | q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, 1, nh, hs) 133 | q = q.squeeze(2) # (B, T, nh, hs) 134 | k = k.squeeze(2) # (B, T, nh, hs) 135 | v = v.squeeze(2) # (B, T, nh, hs) 136 | 137 | # if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and 138 | # use no mask for the "one time step" parts. 139 | # calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index 140 | is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0) 141 | 142 | if self.kv_cache_enabled: 143 | k, v = self._update_kv_cache(q, k, v) 144 | 145 | q = q.transpose(1, 2) # (B, nh, T, hs) 146 | k = k.transpose(1, 2) # (B, nh, T, hs) 147 | v = v.transpose(1, 2) # (B, nh, T, hs) 148 | y = torch.nn.functional.scaled_dot_product_attention( 149 | q, 150 | k, 151 | v, 152 | attn_mask=None, 153 | dropout_p=self.dropout if self.training else 0, 154 | is_causal=is_causal_attn_mask, 155 | ).transpose( 156 | 1, 2 157 | ) # (B, nh, T, hs) -> (B, T, nh, hs) 158 | 159 | return y 160 | 161 | def forward(self, x): 162 | """ 163 | Performs the forward pass of the SelfAttention module. 164 | 165 | Args: 166 | x: The input tensor. 167 | 168 | Returns: 169 | The output tensor. 170 | """ 171 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 172 | 173 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 174 | c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs) 175 | 176 | # causal self-attention; 177 | if self.attn_kernel_type == "torch_attn": 178 | y = self._torch_attn(c_x) 179 | else: 180 | raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}") 181 | 182 | y = y.contiguous().view(B, T, C) # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh) 183 | # output projection 184 | y = self.resid_dropout(self.c_proj(y)) 185 | return y 186 | -------------------------------------------------------------------------------- /fam/llm/layers/combined.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from fam.llm.layers.attn import SelfAttention 4 | from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm 5 | 6 | 7 | class Block(nn.Module): 8 | """ 9 | Block class represents a single block in the model. 
10 | 11 | Args: 12 | config (object): Configuration object containing parameters for the block. 13 | 14 | Attributes: 15 | ln_1 (object): Layer normalization for the attention layer. 16 | ln_2 (object): Layer normalization for the feed-forward layer. 17 | attn (object): Self-attention layer. 18 | mlp (object): Multi-layer perceptron layer. 19 | 20 | Methods: 21 | forward(x): Performs forward pass through the block. 22 | """ 23 | 24 | def __init__(self, config): 25 | super().__init__() 26 | if config.norm_type == "rmsnorm": 27 | if config.rmsnorm_eps is None: 28 | raise Exception("RMSNorm requires rmsnorm_eps to be set") 29 | self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # attn norm 30 | self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # ffn norm 31 | elif config.norm_type == "layernorm": 32 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) # attn norm 33 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) # ffn norm 34 | else: 35 | raise Exception(f"Unknown norm type: {config.norm_type}") 36 | self.attn = SelfAttention(config) 37 | 38 | self.mlp = MLP(config) 39 | 40 | def forward(self, x): 41 | """ 42 | Performs forward pass through the block. 43 | 44 | Args: 45 | x (tensor): Input tensor. 46 | 47 | Returns: 48 | tensor: Output tensor after passing through the block. 49 | """ 50 | x = x + self.attn(self.ln_1(x)) 51 | x = x + self.mlp(self.ln_2(x)) 52 | return x 53 | -------------------------------------------------------------------------------- /fam/llm/layers/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class LayerNorm(nn.Module): 9 | """LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False""" 10 | 11 | def __init__(self, ndim, bias): 12 | super().__init__() 13 | self.weight = nn.Parameter(torch.ones(ndim)) 14 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 15 | 16 | def forward(self, input): 17 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 18 | 19 | 20 | class RMSNorm(torch.nn.Module): 21 | def __init__(self, ndim: int, eps: float): 22 | super().__init__() 23 | self.eps = eps 24 | self.weight = nn.Parameter(torch.ones(ndim)) 25 | 26 | def _norm(self, x): 27 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 28 | 29 | def forward(self, x): 30 | return self._norm(x) * self.weight 31 | 32 | 33 | class SwiGLU(nn.Module): 34 | def __init__(self, in_dim, out_dim, bias) -> None: 35 | super().__init__() 36 | self.w1 = nn.Linear(in_dim, out_dim, bias=bias) 37 | self.w3 = nn.Linear(in_dim, out_dim, bias=bias) 38 | 39 | def forward(self, x): 40 | return F.silu(self.w1(x)) * self.w3(x) 41 | 42 | 43 | class MLP(nn.Module): 44 | def __init__(self, config): 45 | super().__init__() 46 | self.non_linearity = config.nonlinearity_type 47 | hidden_dim = 4 * config.n_embd 48 | if config.nonlinearity_type == "gelu": 49 | self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias) 50 | self.gelu = nn.GELU() 51 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias) 52 | elif config.nonlinearity_type == "swiglu": 53 | if config.swiglu_multiple_of is None: 54 | raise Exception("SwiGLU requires swiglu_multiple_of to be set") 55 | hidden_dim = int(2 * hidden_dim / 3) 56 | hidden_dim = config.swiglu_multiple_of * math.ceil(hidden_dim / config.swiglu_multiple_of) 57 | # set name to `c_proj` so that the right initialisation gets applied to it in GPT.__init__() 58 | self.swiglu = SwiGLU(config.n_embd, hidden_dim, bias=config.bias) 59 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias) 60 | else: 61 | raise Exception(f"Unknown nonlinearity type: {config.nonlinearity_type}") 62 | self.dropout = nn.Dropout(config.dropout) 63 | 64 | def forward(self, x): 65 | if self.non_linearity == "gelu": 66 | x = self.c_fc(x) 67 | x = self.gelu(x) 68 | elif self.non_linearity == "swiglu": 69 | x = self.swiglu(x) 70 | x = self.c_proj(x) 71 | x = self.dropout(x) 72 | return x 73 | -------------------------------------------------------------------------------- /fam/llm/mixins/__init__.py: -------------------------------------------------------------------------------- 1 | from fam.llm.mixins.causal import CausalInferenceMixin 2 | from fam.llm.mixins.non_causal import NonCausalInferenceMixin 3 | -------------------------------------------------------------------------------- /fam/llm/mixins/causal.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | import tqdm 6 | from torch.nn import functional as F 7 | from IPython import embed 8 | 9 | def top_p_sample(prob_dist: torch.Tensor, top_p: float): 10 | sorted_probs, sorted_indices = torch.sort(prob_dist, descending=True, dim=-1) 11 | cum_sum_probs = torch.cumsum(sorted_probs, dim=-1) # (b, vocab_size) 12 | 13 | sorted_indices_to_remove = cum_sum_probs > top_p 14 | 15 | # Shift the indices to the right to keep also the first token above the threshold 16 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 17 | sorted_indices_to_remove[:, 0] = 0 18 | sorted_indices_to_remove = 
sorted_indices_to_remove.bool() 19 | 20 | # replace probs to be removed with 0 in the sorted_probs 21 | sorted_probs[sorted_indices_to_remove] = 0 22 | 23 | # reverse the sorting process 24 | reversed_indices = torch.argsort(sorted_indices) 25 | prob_dist = torch.gather(sorted_probs, -1, reversed_indices) 26 | 27 | # normalize 28 | prob_dist = prob_dist / prob_dist.sum(dim=-1, keepdim=True) 29 | 30 | return prob_dist 31 | 32 | 33 | class CausalInferenceMixin: 34 | """ 35 | Mixin class for performing inference in a causal language model. 36 | 37 | This mixin provides methods for predicting the next token in a sequence, sampling from the model, 38 | and applying token prediction masks. 39 | 40 | Attributes: 41 | None 42 | 43 | Methods: 44 | _sample_next_token: Predicts the next token in the sequence. 45 | _create_token_pred_mask: Creates a token prediction mask based on sequence lengths. 46 | _apply_token_pred_mask: Applies a token prediction mask to the next token predictions. 47 | _sample_batch: Samples a batch of tokens from the model. 48 | _sort_for_batching: Sorts the input sequences for efficient batching. 49 | _causal_sample: Generates a sequence of tokens using causal sampling. 50 | 51 | """ 52 | 53 | @torch.no_grad() 54 | def _sample_next_token( 55 | self, 56 | *, 57 | idx: torch.Tensor, 58 | speaker_embs: Optional[torch.Tensor], 59 | temperature: float, 60 | top_k: Optional[int], 61 | top_p: Optional[float], 62 | guidance_scale: Optional[float], 63 | ) -> torch.Tensor: 64 | """ 65 | Predict the next token in the sequence. 66 | 67 | Args: 68 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). 69 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model. 70 | temperature (float): Sampling temperature. 71 | top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering. 72 | top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it. 73 | guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance. 74 | 75 | Returns: 76 | torch.Tensor: Next index in the sequence after sampling. Shape: (batch, num_hierarchies). 
77 | """ 78 | if top_k is not None and top_p is not None: 79 | raise ValueError("Only one of top_k and top_p can be set") 80 | 81 | # if the sequence context is growing too long we must crop it at block_size 82 | idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, :, -self.config.block_size :] 83 | 84 | # forward the model to get the logits for the index in the sequence 85 | list_logits, _ = self( 86 | idx_cond, speaker_embs=speaker_embs 87 | ) # list with len num_hierarchies of (b,1,vocab_size) tensors 88 | # print(f'{list_logits[0].shape=}, {len(list_logits)=}') 89 | # print(f'{list_logits[0][:,:,:10]}') 90 | 91 | if guidance_scale is not None: 92 | assert idx_cond.shape[0] % 2 == 0 93 | assert list_logits[0].shape[0] % 2 == 0 94 | 95 | for i, logits in enumerate(list_logits): 96 | logits_cond, logits_uncond = logits.split(logits.shape[0] // 2, dim=0) 97 | list_logits[i] = (guidance_scale) * logits_cond + (1 - guidance_scale) * logits_uncond 98 | 99 | assert list_logits[0].shape[0] == idx_cond.shape[0] // 2 100 | 101 | # pluck the logits at the final step and scale by desired temperature 102 | list_logits = [ 103 | logits[:, -1, :] / temperature for logits in list_logits 104 | ] # list with len num_hierarchies of (b,vocab_size) tensors 105 | 106 | # optionally crop the logits to only the top k options 107 | if top_k is not None: 108 | for i in range(len(list_logits)): 109 | logits = list_logits[i] 110 | v, _ = torch.topk( 111 | logits, min(top_k, logits.size(-1)) 112 | ) # returns a descending sorted list of values and indices of top_k values 113 | logits[logits < v[:, [-1]]] = -float("Inf") # set all logits below the smallest top_k value to -Inf 114 | list_logits[i] = logits 115 | 116 | # apply softmax to convert logits to (normalized) probabilities 117 | # embed() 118 | probs = [ 119 | F.softmax(logits, dim=-1) for logits in list_logits 120 | ] # list of len num_hierarchies of (b,vocab_size) tensors 121 | # print(f'{probs[0].shape=}') 122 | # print(f'{probs[0][:,:,:10]}') 123 | if top_p is not None: 124 | for i in range(len(probs)): 125 | probs[i] = top_p_sample(probs[i], top_p) 126 | 127 | # sample from the distribution 128 | idx_next = [ 129 | torch.multinomial(prob, num_samples=1) for prob in probs 130 | ] # list of len num_hierarchies of (b,1) tensors 131 | idx_next = torch.cat(idx_next, dim=-1) # (b, num_hierarchies) tensor 132 | 133 | return idx_next # (b, num_hierarchies) tensor 134 | 135 | @torch.no_grad() 136 | def _create_token_pred_mask(self, idx: torch.Tensor, seq_lens: list[int]) -> torch.Tensor: 137 | """ 138 | Creates a token prediction mask based on sequence lengths. 139 | 140 | Args: 141 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). 142 | seq_lens (list[int]): List of sequence lengths for each sequence in idx. 143 | 144 | Returns: 145 | torch.Tensor: Token prediction mask of shape (batch, time). 146 | """ 147 | token_pred_mask = torch.zeros((idx.shape[0], idx.shape[-1]), dtype=torch.bool, device=idx.device) 148 | for i in range(len(seq_lens)): 149 | token_pred_mask[i, : seq_lens[i]] = True 150 | 151 | assert (token_pred_mask[:, : min(seq_lens)] == 1).all() 152 | 153 | return token_pred_mask 154 | 155 | @torch.no_grad() 156 | def _apply_token_pred_mask( 157 | self, *, idx_next: torch.Tensor, orig_input_at_t: torch.Tensor, token_pred_mask_at_t: torch.Tensor 158 | ) -> torch.Tensor: 159 | """ 160 | Applies a token prediction mask to the next token predictions. 
161 | 162 | Args: 163 | idx_next (torch.Tensor): Next token predictions of shape (batch, num_hierarchies). 164 | orig_input_at_t (torch.Tensor): Original input at time step t of shape (batch, num_hierarchies). 165 | token_pred_mask_at_t (torch.Tensor): Token prediction mask at time step t of shape (batch, 1). 166 | 167 | Returns: 168 | torch.Tensor: Updated next token predictions after applying the token prediction mask. 169 | """ 170 | idx_next = idx_next * (~token_pred_mask_at_t) + orig_input_at_t * token_pred_mask_at_t 171 | 172 | return idx_next 173 | 174 | @torch.no_grad() 175 | def _sample_batch( 176 | self, 177 | *, 178 | idx: torch.Tensor, 179 | max_new_tokens: int, 180 | seq_lens: list[int], 181 | temperature: float, 182 | top_k: Optional[int], 183 | top_p: Optional[float], 184 | speaker_embs: Optional[torch.Tensor], 185 | guidance_scale: Optional[float], 186 | ): 187 | """ 188 | Samples a batch of tokens from the model. 189 | 190 | Args: 191 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). 192 | max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx). 193 | seq_lens (list[int]): List of sequence lengths for each sequence in idx. 194 | temperature (float): Sampling temperature. 195 | top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering. 196 | top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it. 197 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model. 198 | guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance. 199 | 200 | Returns: 201 | torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time). 202 | """ 203 | assert max(seq_lens) <= idx.shape[-1] 204 | token_pred_mask = self._create_token_pred_mask(idx, seq_lens) 205 | input = torch.clone(idx) 206 | 207 | min_seq_lens = min(seq_lens) 208 | idx = idx[:, :, :min_seq_lens] 209 | 210 | if guidance_scale is not None: 211 | if speaker_embs is None: 212 | raise Exception("Guidance is only supported for conditional models") 213 | 214 | # create speaker embeddings equivalent to the batch size, filling with None 215 | # for second half to do unconditional generation. 216 | speaker_embs = list(speaker_embs) + [None] * (speaker_embs.shape[0]) 217 | 218 | for timestep in tqdm.tqdm(range(min_seq_lens, min_seq_lens + max_new_tokens), desc="tokens: "): 219 | if (self.kv_cache_enabled is True) and (timestep > min_seq_lens): 220 | idx_input = idx[:, :, -1:] 221 | else: 222 | idx_input = idx 223 | 224 | if guidance_scale is not None: 225 | # TODO: fix: will cause a problem with kv-caching as it's not expecting larger batch-size. 226 | if timestep == min_seq_lens: 227 | print("[hack!!!!] 
Guidance is on, so we're doubling batch size!") 228 | 229 | # replicate idx in the batch dimension 230 | idx_input = ( 231 | idx_input.unsqueeze(0).repeat(2, 1, 1, 1).reshape(-1, idx_input.shape[1], idx_input.shape[2]) 232 | ) 233 | 234 | # sanity checks 235 | assert idx_input.shape[0] % 2 == 0 236 | 237 | idx_next = self._sample_next_token( 238 | idx=idx_input, 239 | speaker_embs=speaker_embs, 240 | temperature=temperature, 241 | top_k=top_k, 242 | top_p=top_p, 243 | guidance_scale=guidance_scale, 244 | ) # (b, num_hierarchies) 245 | 246 | assert idx_next.shape[0] == idx.shape[0] 247 | 248 | if timestep < token_pred_mask.shape[-1]: 249 | idx_next = self._apply_token_pred_mask( 250 | idx_next=idx_next, 251 | orig_input_at_t=input[:, :, timestep], 252 | token_pred_mask_at_t=token_pred_mask[:, [timestep]], 253 | ) 254 | 255 | idx_next = idx_next.unsqueeze(-1) # (b, num_hierarchies, T=1) tensor 256 | # append sampled index to the running sequence and continue 257 | idx = torch.cat((idx, idx_next), dim=2) 258 | 259 | return idx 260 | 261 | @torch.no_grad() 262 | def _sort_for_batching( 263 | self, 264 | *, 265 | idx: torch.Tensor, 266 | seq_lens: list[int], 267 | speaker_embs: Optional[torch.Tensor], 268 | batch_size: int, 269 | max_new_tokens: int, 270 | ) -> Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]: 271 | """ 272 | Sorts the input sequences for efficient batching. 273 | 274 | Args: 275 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). 276 | seq_lens (list[int]): List of sequence lengths for each sequence in idx. 277 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model. 278 | batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling. 279 | max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx). 280 | 281 | Returns: 282 | Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]: 283 | - sorted_indices (list[int]): List of indices of the input sequences that transform it into sorted order. 284 | - invert_sorted_indices (list[int]): List of indices to invert the sorted sequences back to the original order. 285 | - idx (torch.Tensor): Input sequence indices in sorted order. 286 | - seq_lens (list[int]): Sequence lengths in sorted order. 287 | - speaker_embs (Optional[torch.Tensor]): speaker embeddings in sorted order. 288 | - max_token_len (int): Effective maximum number of tokens to generate. 289 | """ 290 | assert len(seq_lens) == idx.shape[0] 291 | assert max(seq_lens) <= idx.shape[-1] 292 | 293 | sorted_indices = np.argsort(seq_lens) 294 | inverted_sorted_indices = np.zeros(len(seq_lens), dtype=np.int32) 295 | inverted_sorted_indices[sorted_indices] = np.arange(len(seq_lens), dtype=np.int32) 296 | 297 | idx = idx[sorted_indices] 298 | seq_lens = [seq_lens[i] for i in sorted_indices] 299 | speaker_embs = speaker_embs[sorted_indices] if speaker_embs is not None else None 300 | max_token_len = 0 301 | 302 | # figure out effective max_tokens to generate 303 | for start_index in range(0, len(seq_lens), batch_size): 304 | end_index = min(start_index + batch_size, len(seq_lens)) 305 | batch_seq_lens = seq_lens[start_index:end_index] 306 | # random heuristic... 307 | # # TODO: fix! 
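# Added note: the update on the next line makes max_token_len equal to
# (max over batches of min(batch_seq_lens)) + max_new_tokens, so every batch is decoded out to
# the same total length in _causal_sample; batches whose shortest prompt is below that maximum
# therefore generate more than max_new_tokens new tokens.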
308 | max_token_len = max(max_token_len, min(batch_seq_lens) + max_new_tokens) 309 | 310 | return sorted_indices, inverted_sorted_indices, idx, seq_lens, speaker_embs, max_token_len 311 | 312 | @torch.no_grad() 313 | def _causal_sample( 314 | self, 315 | *, 316 | idx: torch.Tensor, 317 | max_new_tokens: int, 318 | seq_lens: list[int], 319 | temperature: float, 320 | top_k: Optional[int], 321 | top_p: Optional[float], 322 | speaker_embs: Optional[torch.Tensor], 323 | batch_size: int, 324 | guidance_scale: Optional[float] = None, 325 | ) -> torch.Tensor: 326 | """ 327 | Generates a sequence of tokens using causal sampling. 328 | 329 | Args: 330 | idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). 331 | max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx). 332 | seq_lens (list[int]): List of sequence lengths for each sequence in idx. 333 | temperature (float): Sampling temperature. 334 | top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering. 335 | top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it. 336 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model. 337 | batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling. 338 | guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance. 339 | 340 | Returns: 341 | torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time). 342 | """ 343 | ( 344 | _, 345 | invert_sorted_indices, 346 | idx, 347 | seq_lens, 348 | speaker_embs, 349 | max_token_len, 350 | ) = self._sort_for_batching( 351 | idx=idx, seq_lens=seq_lens, speaker_embs=speaker_embs, batch_size=batch_size, max_new_tokens=max_new_tokens 352 | ) 353 | 354 | return_idx = torch.zeros((len(seq_lens), idx.size(1), max_token_len), dtype=torch.long, device=idx.device) 355 | 356 | for start_index in tqdm.tqdm(range(0, len(seq_lens), batch_size), desc="batch: "): 357 | end_index = min(start_index + batch_size, len(seq_lens)) 358 | 359 | kv_batch_size = end_index - start_index 360 | if guidance_scale is not None: 361 | kv_batch_size = 2 * kv_batch_size 362 | 363 | if self.kv_cache_enabled: 364 | print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16") 365 | self.empty_kv_cache( 366 | batch_size=kv_batch_size, 367 | kv_cache_maxlen=self.config.block_size, 368 | dtype=torch.bfloat16, 369 | ) 370 | 371 | batch_seq_lens = seq_lens[start_index:end_index] 372 | batch_max_new_tokens = max_token_len - min(batch_seq_lens) 373 | 374 | batch_idx = idx[start_index:end_index] 375 | batch_speaker_embs = speaker_embs[start_index:end_index] if speaker_embs is not None else None 376 | 377 | batch_idx = self._sample_batch( 378 | idx=batch_idx, 379 | max_new_tokens=batch_max_new_tokens, 380 | seq_lens=batch_seq_lens, 381 | temperature=temperature, 382 | top_k=top_k, 383 | top_p=top_p, 384 | speaker_embs=batch_speaker_embs, 385 | guidance_scale=guidance_scale, 386 | ) 387 | return_idx[start_index:end_index] = batch_idx 388 | 389 | return return_idx[invert_sorted_indices] 390 | 391 | def empty_kv_cache(self, *, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype): 392 | """ 393 | Empties key-value (KV) cache for causal attention. 394 | 395 | Args: 396 | batch_size (int): The batch size. 397 | kv_cache_maxlen (int): The maximum length of the KV cache. 398 | dtype (torch.dtype): The data type of the KV cache. 
399 | 400 | Raises: 401 | Exception: If KV cache is enabled for non-causal attention. 402 | 403 | """ 404 | if self.kv_cache_enabled is False: 405 | raise Exception("KV cache is not enabled") 406 | if self.config.causal is False: 407 | raise Exception("KV cache is not supported for non-causal attention") 408 | 409 | self.kv_pos = 0 410 | for block in self.transformer.h: 411 | block.attn.empty_kv_cache(batch_size=batch_size, kv_cache_maxlen=kv_cache_maxlen, dtype=dtype) 412 | 413 | def enable_kv_cache(self): 414 | """ 415 | Enables key-value (KV) cache for causal attention. 416 | 417 | Raises: 418 | Exception: If KV cache is enabled for non-causal attention. 419 | 420 | """ 421 | if self.config.causal is False: 422 | raise Exception("KV cache is not supported for non-causal attention") 423 | 424 | self.kv_cache_enabled = True 425 | for block in self.transformer.h: 426 | block.attn.kv_cache_enabled = True 427 | 428 | def disable_kv_cache(self): 429 | """ 430 | Disables the key-value cache for the transformer and all its blocks. 431 | """ 432 | self.kv_cache_enabled = False 433 | for block in self.transformer.h: 434 | block.attn.kv_cache_enabled = False 435 | block.attn.kv_cache = None 436 | block.attn.kv_cache_first_empty_index = 0 437 | 438 | @torch.no_grad() 439 | def _slow_causal_sampling_loop( 440 | self, 441 | idx: torch.Tensor, 442 | max_new_tokens: int, 443 | temperature: float = 1.0, 444 | top_k: Optional[int] = None, 445 | top_p: Optional[float] = None, 446 | speaker_embs: Optional[torch.Tensor] = None, 447 | guidance_scale: Optional[float] = None, 448 | ): 449 | """ 450 | Old non-batched version of causal sampling. Kept for testing / reference. 451 | 452 | Take a conditioning sequence of indices idx (LongTensor of shape (b,n_head,t)) and complete 453 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 454 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 455 | """ 456 | assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens" 457 | assert idx.size(0) == 1, "can only do one sequence at a time for now" 458 | assert top_p is None, "nucleus sampling not supported yet with _slow_causal_sampling_loop" 459 | 460 | if self.config.causal is not True: 461 | raise Exception("Causal sampling is only supported for causal models") 462 | 463 | if self.kv_cache_enabled: 464 | print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16") 465 | self.empty_kv_cache( 466 | batch_size=1, 467 | kv_cache_maxlen=self.config.block_size, 468 | dtype=torch.bfloat16, 469 | ) 470 | 471 | for i in range(max_new_tokens): 472 | # if the sequence context is growing too long we must crop it at block_size 473 | idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, -self.config.block_size :] 474 | 475 | if self.kv_cache_enabled: 476 | if i > 0: 477 | idx_cond = idx_cond[:, :, -1:] 478 | 479 | # forward the model to get the logits for the index in the sequence 480 | list_logits, _ = self(idx_cond, speaker_embs=speaker_embs) 481 | 482 | if guidance_scale is not None: 483 | # we've already checked that kv-caching is not switched on 484 | # so this should be ok. 
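# Added note: this is the same classifier-free guidance blend as the batched path in
# _sample_next_token above, i.e. guidance_scale * cond_logits + (1 - guidance_scale) * uncond_logits,
# obtained here via an explicit second forward pass with speaker_embs=None.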
485 | list_logits_uncond, _ = self(idx_cond, speaker_embs=None) 486 | list_logits = [ 487 | (guidance_scale) * logits + (1 - guidance_scale) * logits_uncond 488 | for logits, logits_uncond in zip(list_logits, list_logits_uncond) 489 | ] 490 | 491 | # pluck the logits at the final step and scale by desired temperature 492 | list_logits = [logits[:, -1, :] / temperature for logits in list_logits] 493 | 494 | # optionally crop the logits to only the top k options 495 | if top_k is not None: 496 | for i in range(len(list_logits)): 497 | logits = list_logits[i] 498 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 499 | logits[logits < v[:, [-1]]] = -float("Inf") 500 | list_logits[i] = logits 501 | 502 | # apply softmax to convert logits to (normalized) probabilities 503 | probs = [F.softmax(logits, dim=-1) for logits in list_logits] 504 | # sample from the distribution 505 | idx_next = torch.tensor( 506 | [torch.multinomial(prob, num_samples=1) for prob in probs], device=idx.device 507 | ) # (c, 1) 508 | # append sampled index to the running sequence and continue 509 | idx = torch.cat((idx, idx_next.unsqueeze(0).unsqueeze(-1)), dim=2) 510 | 511 | return idx 512 | -------------------------------------------------------------------------------- /fam/llm/mixins/non_causal.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | class NonCausalInferenceMixin: 8 | """ 9 | Mixin class for non-causal inference in a language model. 10 | 11 | This class provides methods for performing non-causal sampling using a language model. 12 | """ 13 | 14 | @torch.no_grad() 15 | def _non_causal_sample( 16 | self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: int 17 | ): 18 | """ 19 | Perform non-causal sampling. 20 | 21 | Args: 22 | idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length). 23 | speaker_embs (Optional[torch.Tensor]): Speaker embeddings tensor of shape (batch_size, embedding_size). 24 | temperature (float): Temperature parameter for scaling the logits. 25 | top_k (int): Number of top options to consider. 26 | 27 | Returns: 28 | torch.Tensor: Sampled output tensor of shape (batch_size, num_out_hierarchies, sequence_length). 29 | """ 30 | b, c, t = idx.size() 31 | assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}" 32 | # forward the model to get the logits for the index in the sequence 33 | list_logits, _ = self(idx, speaker_embs=speaker_embs) # c x (b, t, vocab_size) 34 | 35 | # scale by desired temperature 36 | list_logits = [logits / temperature for logits in list_logits] # c x (b, t, vocab_size) 37 | 38 | # optionally crop the logits to only the top k options 39 | if top_k is not None: 40 | for i in range(len(list_logits)): 41 | logits = list_logits[i] # (b, t, vocab_size) 42 | 43 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) # (b, t, top_k) 44 | logits[logits < v[:, :, [-1]]] = -float("Inf") 45 | list_logits[i] = logits # (b, t, vocab_size) 46 | assert logits.shape[0] == b and logits.shape[1] == t 47 | 48 | # apply softmax to convert logits to (normalized) probabilities 49 | # TODO: check shapes here! 
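# Added note on shapes: top-k filtering above only masks entries to -inf, so each element of
# list_logits is still (b, t, vocab_size); softmax over the last dim keeps that shape, and the
# filtered-out positions simply receive probability 0.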
50 | probs = [F.softmax(logits, dim=-1) for logits in list_logits] # c x (b, t, top_k) 51 | assert probs[0].shape[0] == b and probs[0].shape[1] == t 52 | 53 | # TODO: output shape is as expected 54 | outs = [] 55 | for b_prob in probs: # c x (b, t, top_k) -> (b, t, top_k) 56 | out = [ 57 | torch.multinomial(prob, num_samples=1).transpose(0, 1).unsqueeze(0) for prob in b_prob 58 | ] # b x (t, top_k) -> b x (t, 1) -> b x (1, t) -> b x (1, 1, t) 59 | assert len(out) == b and out[0].shape[0] == 1 and out[0].shape[1] == 1 and out[0].shape[2] == t 60 | out = torch.cat(out, dim=0) # (b, 1, t) 61 | assert out.shape[0] == b and out.shape[1] == 1 and out.shape[2] == t 62 | outs.append(out) 63 | 64 | out = torch.cat(outs, dim=1) # (b, c, t) 65 | assert out.shape[0] == b and out.shape[2] == t 66 | 67 | return out 68 | -------------------------------------------------------------------------------- /fam/llm/model.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from dataclasses import dataclass, field 4 | from typing import Literal, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import tqdm 9 | from einops import rearrange 10 | from torch.nn import functional as F 11 | 12 | from fam.llm.layers import Block, LayerNorm, RMSNorm 13 | from fam.llm.mixins import CausalInferenceMixin, NonCausalInferenceMixin 14 | 15 | from IPython import embed 16 | END_OF_TEXT_TOKEN = 1537 17 | 18 | 19 | def _select_spkemb(spkemb, mask): 20 | _, examples, _ = spkemb.shape 21 | mask = torch.nn.functional.one_hot(mask.long(), num_classes=examples).to(spkemb) # shape: (batch, time, examples) 22 | spkemb = spkemb.transpose(1, 2) # b ex c -> b c ex 23 | mask = mask.transpose(1, 2) # b t ex -> b ex t 24 | return torch.bmm(spkemb, mask).transpose(1, 2) # b c t -> b t c 25 | 26 | 27 | @dataclass 28 | class GPTConfig: 29 | block_size: int = 1024 30 | vocab_sizes: list = field(default_factory=list) 31 | target_vocab_sizes: Optional[list] = None 32 | n_layer: int = 12 33 | n_head: int = 12 34 | n_embd: int = 768 35 | dropout: float = 0.0 36 | spkemb_dropout: float = 0.0 37 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 38 | causal: bool = ( 39 | True # auto-regressive or not, i.e. whether to have attention mask that prevents attending to future tokens 40 | ) 41 | spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not 42 | norm_type: str = "layernorm" # "rmsnorm" or "layernorm 43 | rmsnorm_eps: Optional[float] = None # only used for rmsnorm 44 | nonlinearity_type: str = "gelu" # "gelu" or "swiglu" 45 | swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this 46 | attn_kernel_type: Literal["torch_attn"] = "torch_attn" 47 | #Literal["fa2", "torch_attn", "hand"] = "fa2" 48 | kv_cache_enabled: bool = False # whether to use key-value cache for attention 49 | 50 | 51 | def _check_speaker_emb_dims( 52 | speaker_embs: Union[list, torch.Tensor], expected_speaker_emb_dim: int, expected_batch_size: int 53 | ) -> Union[torch.Tensor, list]: 54 | """ 55 | Checks that the speaker embedding dimensions are correct, and reshapes them if necessary. 
56 | """ 57 | if type(speaker_embs) == list: 58 | b_se = len(speaker_embs) 59 | for i, s in enumerate(speaker_embs): 60 | if s is not None: 61 | emb_dim = s.shape[-1] 62 | if s.ndim == 1: 63 | speaker_embs[i] = speaker_embs[i].unsqueeze(0) 64 | else: 65 | if speaker_embs.ndim == 2: 66 | # if we have a single speaker embedding for the whole sequence, 67 | # add a dummy dimension for backwards compatibility 68 | speaker_embs = speaker_embs[:, None, :] 69 | 70 | # num_examples is the number of utterances packed into this sequence 71 | b_se, num_examples, emb_dim = speaker_embs.size() 72 | 73 | assert b_se == expected_batch_size, f"Batch size mismatch: {b_se} != {expected_batch_size}" 74 | assert ( 75 | emb_dim == expected_speaker_emb_dim 76 | ), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}" 77 | 78 | return speaker_embs 79 | 80 | 81 | class GPT(nn.Module, NonCausalInferenceMixin, CausalInferenceMixin): 82 | def __init__(self, config: GPTConfig, speaker_emb_dim: Optional[int] = None): 83 | """ 84 | Initialize the GPT model. 85 | 86 | Args: 87 | config (GPTConfig): Configuration object for the model. 88 | speaker_emb_dim (Optional[int]): Dimension of the speaker embedding. Default is None. 89 | """ 90 | super().__init__() 91 | assert config.vocab_sizes is not None 92 | assert config.block_size is not None 93 | self.config = config 94 | 95 | self.kv_cache_enabled = False # disabled by default 96 | self.kv_pos = 0 97 | 98 | self.speaker_emb_dim = speaker_emb_dim 99 | self.spk_emb_on_text = config.spk_emb_on_text 100 | if self.config.causal is True and self.spk_emb_on_text is False: 101 | print("!!!!!!!!!!!!!!!!!!") 102 | print( 103 | f"!!!!!!!! Using DEFAULT of {END_OF_TEXT_TOKEN} as end of text token to find speaker cond masking!! You likely need to change this." 104 | ) 105 | print("!!!!!!!!!!!!!!!!!!") 106 | if self.config.causal is False and self.spk_emb_on_text is False: 107 | raise Exception( 108 | "Cannot use speaker embedding masking with non-causal model. This is unexpected. Check for relevant changes required in code before proceeding." 109 | ) 110 | 111 | if config.norm_type == "rmsnorm": 112 | if config.rmsnorm_eps is None: 113 | raise Exception("RMSNorm requires rmsnorm_eps to be set") 114 | ln_f = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) 115 | elif config.norm_type == "layernorm": 116 | ln_f = LayerNorm(config.n_embd, bias=config.bias) 117 | else: 118 | raise Exception(f"Unknown norm type: {config.norm_type}") 119 | 120 | self.transformer = nn.ModuleDict( 121 | dict( 122 | wtes=nn.ModuleList([nn.Embedding(vsize, config.n_embd,) for vsize in config.vocab_sizes]), 123 | wpe=nn.Embedding(config.block_size, config.n_embd), 124 | drop=nn.Dropout(config.dropout), 125 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 126 | ln_f=ln_f, 127 | ) 128 | ) 129 | if speaker_emb_dim is not None: 130 | self.speaker_cond_pos = nn.Linear(speaker_emb_dim, config.n_embd, bias=False) # ここで256->2048 131 | 132 | self.lm_heads = nn.ModuleList() 133 | if config.target_vocab_sizes is not None: 134 | assert config.causal is False 135 | else: 136 | assert config.causal is True 137 | 138 | for vsize in config.vocab_sizes if config.target_vocab_sizes is None else config.target_vocab_sizes: 139 | self.lm_heads.append(nn.Linear(config.n_embd, vsize, bias=False)) 140 | 141 | if config.target_vocab_sizes is None: 142 | for i in range(len(config.vocab_sizes)): 143 | # TODO: do we not need to take the transpose here? 
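# Added note: no transpose is needed for the weight tying below. nn.Linear stores its weight as
# (out_features, in_features) = (vocab_size, n_embd), which is exactly the shape of the
# corresponding nn.Embedding weight, so the two parameters can be shared directly.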
144 | # https://paperswithcode.com/method/weight-tying 145 | self.lm_heads[i].weight = self.transformer.wtes[i].weight # type: ignore 146 | assert len(self.lm_heads) == len( 147 | self.transformer.wtes # type: ignore 148 | ), f"Number of heads ({len(self.lm_heads)}) must match number of one-hot embedding matrics ({len(self.transformer.wtes)})." # type: ignore 149 | # - causal 150 | # GPT( 151 | # (transformer): ModuleDict( 152 | # (wtes): ModuleList( 153 | # (0): Embedding(2562, 2048) 154 | # ) 155 | # (wpe): Embedding(2048, 2048) 156 | # (drop): Dropout(p=0.0, inplace=False) 157 | # (h): ModuleList( 158 | # (0-23): 24 x Block( 159 | # (ln_1): RMSNorm() 160 | # (ln_2): RMSNorm() 161 | # (attn): SelfAttention( 162 | # (c_attn): Linear(in_features=2048, out_features=6144, bias=False) 163 | # (c_proj): Linear(in_features=2048, out_features=2048, bias=False) 164 | # (resid_dropout): Dropout(p=0.0, inplace=False) 165 | # ) 166 | # (mlp): MLP( 167 | # (swiglu): SwiGLU( 168 | # (w1): Linear(in_features=2048, out_features=5632, bias=False) 169 | # (w3): Linear(in_features=2048, out_features=5632, bias=False) 170 | # ) 171 | # (c_proj): Linear(in_features=5632, out_features=2048, bias=False) 172 | # (dropout): Dropout(p=0.0, inplace=False) 173 | # ) 174 | # ) 175 | # ) 176 | # (ln_f): RMSNorm() 177 | # ) 178 | # (speaker_cond_pos): Linear(in_features=256, out_features=2048, bias=False) 179 | # (lm_heads): ModuleList( 180 | # (0): Linear(in_features=2048, out_features=2562, bias=False) 181 | # ) 182 | # ) 183 | # GPTConfig(block_size=2048, vocab_sizes=[2562], target_vocab_sizes=None, n_layer=24, n_head=16, n_embd=2048, dropout=0.0, spkemb_dropout=0.1, bias=False, causal=True, spk_emb_on_text=True, norm_type='rmsnorm', rmsnorm_eps=1e-05, nonlinearity_type='swiglu', swiglu_multiple_of=256, attn_kernel_type='torch_attn', kv_cache_enabled=False) 184 | # 185 | # - non causal 186 | # GPT( 187 | # (transformer): ModuleDict( 188 | # (wtes): ModuleList( 189 | # (0): Embedding(1538, 384) 190 | # (1): Embedding(1025, 384) 191 | # ) 192 | # (wpe): Embedding(1024, 384) 193 | # (drop): Dropout(p=0.0, inplace=False) 194 | # (h): ModuleList( 195 | # (0-5): 6 x Block( 196 | # (ln_1): LayerNorm() 197 | # (ln_2): LayerNorm() 198 | # (attn): SelfAttention( 199 | # (c_attn): Linear(in_features=384, out_features=1152, bias=False) 200 | # (c_proj): Linear(in_features=384, out_features=384, bias=False) 201 | # (resid_dropout): Dropout(p=0.0, inplace=False) 202 | # ) 203 | # (mlp): MLP( 204 | # (c_fc): Linear(in_features=384, out_features=1536, bias=False) 205 | # (gelu): GELU(approximate='none') 206 | # (c_proj): Linear(in_features=1536, out_features=384, bias=False) 207 | # (dropout): Dropout(p=0.0, inplace=False) 208 | # ) 209 | # ) 210 | # ) 211 | # (ln_f): LayerNorm() 212 | # ) 213 | # (speaker_cond_pos): Linear(in_features=256, out_features=384, bias=False) 214 | # (lm_heads): ModuleList( 215 | # (0-5): 6 x Linear(in_features=384, out_features=1025, bias=False) 216 | # ) 217 | # ) 218 | # GPTConfig(block_size=1024, vocab_sizes=[1538, 1025], target_vocab_sizes=[1025, 1025, 1025, 1025, 1025, 1025], n_layer=6, n_head=6, n_embd=384, dropout=0.0, spkemb_dropout=0.0, bias=False, causal=False, spk_emb_on_text=True, norm_type='layernorm', rmsnorm_eps=None, nonlinearity_type='gelu', swiglu_multiple_of=None, attn_kernel_type='fa2', kv_cache_enabled=False) 219 | # if config.causal is False: 220 | # embed() 221 | # init all weights 222 | self.apply(self._init_weights) 223 | # apply special scaled init to the residual projections, 
per GPT-2 paper 224 | for pn, p in self.named_parameters(): 225 | if pn.endswith("c_proj.weight"): 226 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 227 | 228 | # report number of parameters 229 | print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) 230 | 231 | def get_num_params(self, non_embedding=True): 232 | """ 233 | Return the number of parameters in the model. 234 | For non-embedding count (default), the position embeddings get subtracted. 235 | The token embeddings would too, except due to the parameter sharing these 236 | params are actually used as weights in the final layer, so we include them. 237 | """ 238 | n_params = sum(p.numel() for p in self.parameters()) 239 | if non_embedding: 240 | n_params -= self.transformer.wpe.weight.numel() 241 | return n_params 242 | 243 | def _init_weights(self, module): 244 | if isinstance(module, nn.Linear): 245 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 246 | if module.bias is not None: 247 | torch.nn.init.zeros_(module.bias) 248 | elif isinstance(module, nn.Embedding): 249 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 250 | 251 | def _mask_spk_emb_on_text(self, idx: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor: 252 | """ 253 | This is in a separate function so we can test it easily. 254 | """ 255 | # find index of end of text token in each sequence, then generate a binary mask 256 | # of shape (b, 1, t) to mask out the speaker embedding for all tokens before the end of text token. 257 | # Note: this does NOT mask the token. This is important so that the first audio token predicted 258 | # has speaker information to use. 259 | 260 | # Check in channel dimension 0 as this is usually the first hierarchy where we put the text tokens. 261 | is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN 262 | # use > 0, in case end_of_text_token is repeated for any reason. 263 | mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float() 264 | spk_emb = spk_emb * mask[:, :, None] 265 | 266 | return spk_emb 267 | 268 | def forward( 269 | self, 270 | idx, 271 | targets=None, 272 | speaker_embs=None, 273 | embedding=None, 274 | speaker_emb_mask=None, 275 | loss_reduce: Literal["mean", "none"] = "mean", 276 | ): 277 | device = idx.device 278 | b, num_hierarchies, t = idx.size() 279 | 280 | if speaker_embs is not None: 281 | speaker_embs = _check_speaker_emb_dims( 282 | speaker_embs=speaker_embs, expected_speaker_emb_dim=self.speaker_emb_dim, expected_batch_size=b 283 | ) 284 | 285 | assert ( 286 | t <= self.config.block_size 287 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 288 | 289 | if self.kv_cache_enabled: 290 | if self.kv_pos == 0: 291 | pos = torch.arange(0, t, dtype=torch.long, device=device) 292 | self.kv_pos += t 293 | else: 294 | assert t == 1, "KV cache is only supported for single token inputs" 295 | pos = torch.tensor([self.kv_pos], dtype=torch.long, device=device) # shape (1) 296 | self.kv_pos += 1 297 | else: 298 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 299 | # embed the tokens, positional encoding, and speaker embedding 300 | tok_emb = torch.zeros((b, t, self.config.n_embd), device=device) 301 | # ends up swapping (B, num_hierarchies, t) tokens -> (B, t, c) embeddings. 
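        # Illustrative shape walk-through (hypothetical sizes): with b=2, t=5, n_embd=2048
        # and two hierarchies, each wte below maps its (2, 5) slice of idx to (2, 5, 2048);
        # positions holding the padding value -1 are looked up as index 0 and then zeroed,
        # so tok_emb ends up as the plain sum of the per-hierarchy embeddings.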
302 | wte = self.transformer.wtes[0] 303 | for i, wte in enumerate(self.transformer.wtes): 304 | mask_pad = idx[:, i, :] == -1 305 | masked_idx = idx[:, i, :].clone() 306 | masked_idx[mask_pad] = 0 307 | embedded_idx = wte(masked_idx) 308 | embedded_idx[mask_pad] = 0 309 | tok_emb += embedded_idx 310 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 311 | 312 | spk_emb = 0.0 313 | if speaker_embs is not None: 314 | if type(speaker_embs) == list: 315 | assert speaker_emb_mask is None 316 | assert self.training is False 317 | assert self.spk_emb_on_text is True 318 | spk_emb = [] 319 | for speaker_emb_row in speaker_embs: 320 | if speaker_emb_row is not None: 321 | spk_emb.append(self.speaker_cond_pos(speaker_emb_row.unsqueeze(0))) 322 | assert spk_emb[-1].shape == (1, 1, self.config.n_embd), f"spk_emb[-1].shape={spk_emb[-1].shape}" 323 | else: 324 | spk_emb.append(torch.zeros((1, 1, self.config.n_embd), device=device, dtype=pos_emb.dtype)) 325 | spk_emb = torch.cat(spk_emb, dim=0) 326 | 327 | assert ( 328 | spk_emb.ndim == 3 and spk_emb.shape[1] == 1 and spk_emb.shape[0] == b 329 | ), f"spk_emb.ndim={spk_emb.ndim}, spk_emb.shape={spk_emb.shape}, len(speaker_embs)={len(speaker_embs)}" 330 | else: 331 | speakers_embedded = self.speaker_cond_pos(speaker_embs) # shape (b, num_examples, c) 332 | 333 | if speaker_emb_mask is not None: 334 | spk_emb = _select_spkemb(speakers_embedded, speaker_emb_mask) 335 | assert spk_emb.shape == (b, t, self.config.n_embd) 336 | else: 337 | spk_emb = speakers_embedded 338 | # if we don't have a mask, we assume that the speaker embedding is the same for all tokens 339 | # then num_examples dimension just becomes the time dimension 340 | assert spk_emb.ndim == 3 and spk_emb.shape[1] == 1 341 | 342 | if self.training and self.config.spkemb_dropout > 0.0: 343 | # Remove speaker conditioning at random. 344 | dropout = torch.ones_like(speakers_embedded) * ( 345 | torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout 346 | ) 347 | spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded) 348 | 349 | if self.spk_emb_on_text is False: 350 | assert speaker_emb_mask is None, "Not implemented for spk_emb_on_text=False" 351 | spk_emb = self._mask_spk_emb_on_text(idx, spk_emb) 352 | elif embedding is not None: 353 | speakers_embedded = self.speaker_cond_pos(embedding) 354 | 355 | if self.training and self.config.spkemb_dropout > 0.0: 356 | # Remove speaker conditioning at random. 
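                # e.g. with spkemb_dropout=0.1, each row of the batch independently has a
                # ~10% chance of receiving an all-zero speaker embedding, so the model also
                # learns an unconditional mode (useful for guidance at sampling time).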
357 | dropout = torch.ones_like(speakers_embedded) * ( 358 | torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout 359 | ) 360 | spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded) 361 | else: 362 | spk_emb = speakers_embedded 363 | 364 | x = self.transformer.drop(tok_emb + pos_emb + spk_emb) 365 | for block in self.transformer.h: 366 | x = block(x) 367 | x = self.transformer.ln_f(x) 368 | 369 | if targets is not None: 370 | # if we are given some desired targets also calculate the loss 371 | list_logits = [lm_head(x) for lm_head in self.lm_heads] 372 | 373 | losses = [ 374 | F.cross_entropy( 375 | logits.view(-1, logits.size(-1)), 376 | targets[:, i, :].contiguous().view(-1), 377 | ignore_index=-1, 378 | reduction=loss_reduce, 379 | ) 380 | for i, logits in enumerate(list_logits) 381 | ] 382 | losses = torch.stack(losses) 383 | if loss_reduce == "mean": 384 | losses = losses.mean() 385 | else: 386 | losses = rearrange(losses, "h (b t) -> b h t", h=len(self.lm_heads), b=b, t=t) 387 | else: 388 | # inference-time mini-optimization: only forward the lm_head on the very last position 389 | if self.config.causal: 390 | list_logits = [ 391 | lm_head(x[:, [-1], :]) for lm_head in self.lm_heads 392 | ] # note: using list [-1] to preserve the time dim 393 | # print(f'{len(list_logits)=}, {list_logits[0].shape=}') 394 | else: 395 | list_logits = [lm_head(x) for lm_head in self.lm_heads] 396 | losses = None 397 | 398 | return list_logits, losses 399 | 400 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 401 | # start with all of the candidate parameters 402 | param_dict = {pn: p for pn, p in self.named_parameters()} 403 | # filter out those that do not require grad 404 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 405 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 406 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
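        # e.g. the 2048x2048 c_proj weight and the token embedding tables have dim() == 2
        # and land in the decay group, while 1-D parameters such as the norm gains do not;
        # with bias=False there are no bias vectors to exclude.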
407 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 408 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 409 | optim_groups = [ 410 | {"params": decay_params, "weight_decay": weight_decay}, 411 | {"params": nodecay_params, "weight_decay": 0.0}, 412 | ] 413 | num_decay_params = sum(p.numel() for p in decay_params) 414 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 415 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 416 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 417 | # Create AdamW optimizer and use the fused version if it is available 418 | fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters 419 | use_fused = fused_available and device_type == "cuda" 420 | extra_args = dict(fused=True) if use_fused else dict() 421 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 422 | print(f"using fused AdamW: {use_fused}") 423 | 424 | return optimizer 425 | 426 | @torch.no_grad() 427 | def generate( 428 | self, 429 | idx: torch.Tensor, 430 | max_new_tokens: int, 431 | seq_lens: Optional[list] = None, 432 | temperature: float = 1.0, 433 | top_k: Optional[int] = None, 434 | top_p: Optional[float] = None, 435 | speaker_embs: Optional[torch.Tensor] = None, 436 | batch_size: Optional[int] = None, 437 | guidance_scale: Optional[float] = None, 438 | ): 439 | """ 440 | Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete 441 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 442 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
443 | """ 444 | assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens" 445 | 446 | if self.config.causal: 447 | if seq_lens is None or batch_size is None: 448 | raise Exception("seq_lens and batch_size must be provided for causal sampling") 449 | 450 | return self._causal_sample( 451 | idx=idx, 452 | max_new_tokens=max_new_tokens, 453 | seq_lens=seq_lens, 454 | temperature=temperature, 455 | top_k=top_k, 456 | top_p=top_p, 457 | speaker_embs=speaker_embs, 458 | batch_size=batch_size, 459 | guidance_scale=guidance_scale, 460 | ) 461 | 462 | else: 463 | if seq_lens is not None: 464 | raise Exception("seq_lens is not supported yet for non-causal sampling") 465 | 466 | if batch_size is None: 467 | raise Exception("batch_size must be provided for non-causal sampling") 468 | 469 | if guidance_scale is not None: 470 | raise Exception("guidance_scale is not supported for non-causal sampling") 471 | 472 | if top_p is not None: 473 | raise Exception("top_p is not supported for non-causal sampling") 474 | 475 | out = [] 476 | for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="non-causal batching"): 477 | end_index = min(start_index + batch_size, idx.shape[0]) 478 | out.append( 479 | self._non_causal_sample( 480 | idx=idx[start_index:end_index], 481 | speaker_embs=speaker_embs[start_index:end_index] if speaker_embs is not None else None, 482 | temperature=temperature, 483 | top_k=top_k, 484 | ) 485 | ) 486 | return torch.cat(out, dim=0) 487 | return torch.cat(out, dim=0) 488 | -------------------------------------------------------------------------------- /fam/llm/serving.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shlex 5 | import subprocess 6 | import tempfile 7 | from pathlib import Path 8 | from typing import Literal, Optional 9 | 10 | import fastapi 11 | import fastapi.middleware.cors 12 | import torch 13 | import tyro 14 | import uvicorn 15 | from attr import dataclass 16 | from fastapi import Request 17 | from fastapi.responses import Response 18 | from huggingface_hub import snapshot_download 19 | 20 | from fam.llm.sample import ( 21 | InferenceConfig, 22 | Model, 23 | build_models, 24 | get_first_stage_path, 25 | get_second_stage_path, 26 | # sample_utterance, 27 | ) 28 | from fam.llm.fast_inference import TTS 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | ## Setup FastAPI server. 
34 | app = fastapi.FastAPI() 35 | 36 | 37 | @dataclass 38 | class ServingConfig: 39 | huggingface_repo_id: str 40 | """Absolute path to the model directory.""" 41 | 42 | max_new_tokens: int = 864 * 2 43 | """Maximum number of new tokens to generate from the first stage model.""" 44 | 45 | temperature: float = 1.0 46 | """Temperature for sampling applied to both models.""" 47 | 48 | top_k: int = 200 49 | """Top k for sampling applied to both models.""" 50 | 51 | seed: int = 1337 52 | """Random seed for sampling.""" 53 | 54 | dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16" 55 | """Data type to use for sampling.""" 56 | 57 | enhancer: Optional[Literal["df"]] = "df" 58 | """Enhancer to use for post-processing.""" 59 | 60 | port: int = 58003 61 | 62 | 63 | # Singleton 64 | class _GlobalState: 65 | config: ServingConfig 66 | tts: TTS 67 | 68 | 69 | GlobalState = _GlobalState() 70 | 71 | @dataclass(frozen=True) 72 | class TTSRequest: 73 | text: str 74 | guidance: Optional[float] = 3.0 75 | top_p: Optional[float] = 0.95 76 | speaker_ref_path: Optional[str] = None 77 | top_k: Optional[int] = None 78 | 79 | 80 | def sample_utterance( 81 | text: str, 82 | spk_cond_path: str | None, 83 | guidance_scale, 84 | max_new_tokens, 85 | top_k, 86 | top_p, 87 | temperature, 88 | ) -> str: 89 | return GlobalState.tts.synthesise( 90 | text, 91 | spk_cond_path, 92 | top_p=top_p, 93 | guidance_scale=guidance_scale, 94 | temperature=temperature, 95 | ) 96 | 97 | 98 | @app.post("/tts", response_class=Response) 99 | async def text_to_speech(req: Request): 100 | audiodata = await req.body() 101 | payload = None 102 | wav_out_path = None 103 | 104 | try: 105 | headers = req.headers 106 | payload = headers["X-Payload"] 107 | payload = json.loads(payload) 108 | tts_req = TTSRequest(**payload) 109 | with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp: 110 | if tts_req.speaker_ref_path is None: 111 | wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp) 112 | else: 113 | wav_path = tts_req.speaker_ref_path 114 | wav_out_path = sample_utterance( 115 | tts_req.text, 116 | wav_path, 117 | guidance_scale=tts_req.guidance, 118 | max_new_tokens=GlobalState.config.max_new_tokens, 119 | temperature=GlobalState.config.temperature, 120 | top_k=tts_req.top_k, 121 | top_p=tts_req.top_p, 122 | ) 123 | with open(wav_out_path, "rb") as f: 124 | return Response(content=f.read(), media_type="audio/wav") 125 | except Exception as e: 126 | # traceback_str = "".join(traceback.format_tb(e.__traceback__)) 127 | logger.exception(f"Error processing request {payload}") 128 | return Response( 129 | content="Something went wrong. 
Please try again in a few mins or contact us on Discord", 130 | status_code=500, 131 | ) 132 | finally: 133 | if wav_out_path is not None: 134 | Path(wav_out_path).unlink(missing_ok=True) 135 | 136 | 137 | def _convert_audiodata_to_wav_path(audiodata, wav_tmp): 138 | with tempfile.NamedTemporaryFile() as unknown_format_tmp: 139 | assert unknown_format_tmp.write(audiodata) > 0 140 | unknown_format_tmp.flush() 141 | 142 | subprocess.check_output( 143 | # arbitrary 2 minute cutoff 144 | shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}") 145 | ) 146 | 147 | return wav_tmp.name 148 | 149 | 150 | if __name__ == "__main__": 151 | # This has to be here to avoid some weird audiocraft shenaningans messing up matplotlib 152 | from fam.llm.enhancers import get_enhancer 153 | 154 | for name in logging.root.manager.loggerDict: 155 | logger = logging.getLogger(name) 156 | logger.setLevel(logging.INFO) 157 | logging.root.setLevel(logging.INFO) 158 | 159 | GlobalState.config = tyro.cli(ServingConfig) 160 | app.add_middleware( 161 | fastapi.middleware.cors.CORSMiddleware, 162 | allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"], 163 | allow_credentials=True, 164 | allow_methods=["*"], 165 | allow_headers=["*"], 166 | ) 167 | 168 | device = "cuda" if torch.cuda.is_available() else "cpu" 169 | common_config = dict( 170 | num_samples=1, 171 | seed=1337, 172 | device=device, 173 | dtype=GlobalState.config.dtype, 174 | compile=False, 175 | init_from="resume", 176 | output_dir=tempfile.mkdtemp(), 177 | ) 178 | model_dir = snapshot_download(repo_id=GlobalState.config.huggingface_repo_id) 179 | config1 = InferenceConfig( 180 | ckpt_path=get_first_stage_path(model_dir), 181 | **common_config, 182 | ) 183 | 184 | config2 = InferenceConfig( 185 | ckpt_path=get_second_stage_path(model_dir), 186 | **common_config, 187 | ) 188 | 189 | GlobalState.tts = TTS() 190 | 191 | # start server 192 | uvicorn.run( 193 | app, 194 | host="127.0.0.1", 195 | port=GlobalState.config.port, 196 | log_level="info", 197 | ) 198 | -------------------------------------------------------------------------------- /fam/llm/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import wandb 5 | from torch import nn 6 | from torch.nn import functional as F 7 | import pytorch_lightning as pl 8 | from pytorch_lightning import Trainer 9 | from pytorch_lightning.strategies import FSDPStrategy, DDPStrategy 10 | from pytorch_lightning.plugins.environments import ClusterEnvironment 11 | from huggingface_hub import snapshot_download 12 | 13 | # Specific to your project's structure 14 | from fam.llm.training import parse_args, get_first_stage_path, get_second_stage_path, TrainingConfig, VoiceDataModule, WandbLogger, dist_utils, optimizer_utils, Evaluator 15 | from fam.llm.decoders import Decoder, EncodecDecoder 16 | from fam.quantiser.text.tokenise import TrainedBPETokeniser 17 | from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook 18 | from fam.llm.sample import Model 19 | 20 | class MyClusterEnvironment(ClusterEnvironment): 21 | @property 22 | def creates_processes_externally(self) -> bool: 23 | """Return True if the cluster is managed (you don't launch processes yourself)""" 24 | return True 25 | 26 | def world_size(self) -> int: 27 | return int(os.environ["OMPI_COMM_WORLD_SIZE"]) 28 | 29 | def global_rank(self) -> int: 30 | return int(os.environ["OMPI_COMM_WORLD_RANK"]) 31 | 32 | 
def local_rank(self) -> int: 33 | return int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 34 | 35 | def node_rank(self) -> int: 36 | return int(os.environ["OMPI_COMM_WORLD_NODE_RANK"]) 37 | 38 | main_address = os.getenv("MASTER_ADDR","") 39 | 40 | main_port = int(os.getenv("MASTER_PORT", "0")) 41 | 42 | def set_global_rank(self, rank): 43 | self.global_rank_ = rank 44 | 45 | def set_world_size(self, size): 46 | self.world_size_ = size 47 | 48 | def detect(self): 49 | return True 50 | 51 | class KotobaSpeechModelFirstStage(pl.LightningModule): 52 | """ 53 | A PyTorch Lightning module for the first stage of KotobaVoice model training. 54 | """ 55 | 56 | def __init__(self, config_first_stage, config_second_stage, device, use_kv_cache, logger, is_debug=False, configs=None): 57 | super().__init__() 58 | self.configs = vars(configs) 59 | self.config_first_stage = config_first_stage 60 | self.use_kv_cache = use_kv_cache 61 | self.prev_step = -1 62 | self.evaluator = Evaluator() 63 | if dist_utils.is_main_process(): 64 | self.wandb_logger = logger 65 | 66 | def configure_model(self): 67 | """ 68 | Configures the model and its components. 69 | """ 70 | self.data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024) 71 | self.first_stage_model_cls = Model( 72 | self.config_first_stage, 73 | TrainedBPETokeniser, 74 | EncodecDecoder, 75 | data_adapter_fn=self.data_adapter.decode, 76 | use_kv_cache=self.use_kv_cache, 77 | ) 78 | self.first_stage_model_transformer = self.first_stage_model_cls.model 79 | 80 | def forward(self, text_tokens, embedding, inputs, targets): 81 | truncated_inputs = inputs[:,:2048] 82 | truncated_targets = targets[:,:2048] 83 | truncated_inputs = truncated_inputs.unsqueeze(1) 84 | truncated_targets = truncated_targets.unsqueeze(1) 85 | list_logits, first_stage_loss = self.first_stage_model_transformer(idx=truncated_inputs, embedding=embedding, targets=truncated_targets, loss_reduce="mean") 86 | return first_stage_loss 87 | 88 | def training_step(self, batch, batch_idx): 89 | tokens, embedding, audio_tokens, first_stage_input, first_stage_output = batch 90 | embedding = embedding.unsqueeze(1) 91 | loss = self(text_tokens=tokens, embedding=embedding, inputs=first_stage_input, targets=first_stage_output) 92 | 93 | if dist_utils.is_main_process(): 94 | if self.prev_step == self.global_step: 95 | pass 96 | else: 97 | stats = { 98 | "loss": loss.item(), 99 | } 100 | if self.wandb_logger is not None: 101 | self.wandb_logger.log(results=stats, split="training", step=self.global_step, commit=False) 102 | lr_stats = {f"lr_group{i}": list(self.optimizers().param_groups)[i]["lr"] for i in range(len(self.optimizers().param_groups))} 103 | self.wandb_logger.log(results=lr_stats, split="lr", step=self.global_step, commit=True) 104 | self.prev_step = self.global_step 105 | 106 | return loss 107 | 108 | def on_train_start(self): 109 | if dist_utils.is_main_process(): 110 | if self.wandb_logger is not None: 111 | wandb.watch(self.first_stage_model_transformer, log='parameters', log_freq=1000) 112 | 113 | def validation_step(self, batch, batch_idx): 114 | tokens, embedding, audio_tokens, first_stage_input, first_stage_output = batch 115 | tokens = tokens.unsqueeze(1) 116 | embedding = embedding.unsqueeze(1) 117 | loss = self(text_tokens=tokens, embedding=embedding, inputs=first_stage_input, targets=first_stage_output) 118 | stats = { 119 | "loss": [loss.item()], 120 | } 121 | self.evaluator.update(stats) 122 | return loss 123 | 124 | def on_validation_epoch_end(self): 125 | 
self.evaluator.synchronize_between_processes() 126 | if dist_utils.is_main_process(): 127 | summary = self.evaluator.summarize() 128 | self.evaluator.reset() 129 | if dist_utils.is_main_process(): 130 | if self.wandb_logger is not None: 131 | self.wandb_logger.log(results=summary, split="validation", step=self.global_step, commit=False) 132 | 133 | def configure_optimizers(self): 134 | return optimizer_utils.set_schedule(self) 135 | 136 | # Train the model 137 | if __name__ == "__main__": 138 | training_args = parse_args() 139 | data_dir = training_args.debug_data_dir if training_args.debug else training_args.data_dir 140 | data_module = VoiceDataModule(training_args.per_gpu_batchsize, data_dir) 141 | 142 | model_dir = snapshot_download(repo_id=training_args.huggingface_repo_id) 143 | first_stage_ckpt_path = get_first_stage_path(model_dir) 144 | second_stage_ckpt_path = get_second_stage_path(model_dir) 145 | config_first_stage = TrainingConfig( 146 | ckpt_path=first_stage_ckpt_path, 147 | num_samples=training_args.num_samples, 148 | seed=training_args.seed, 149 | device=training_args.device, 150 | dtype=training_args.dtype, 151 | compile=training_args.compile, 152 | init_from=training_args.init_from, 153 | train_from_scratch=training_args.train_from_scratch, 154 | output_dir=training_args.output_dir, 155 | spkemb_dropout=training_args.spkemb_dropout 156 | ) 157 | 158 | config_second_stage = TrainingConfig( 159 | ckpt_path=second_stage_ckpt_path, 160 | num_samples=training_args.num_samples, 161 | seed=training_args.seed, 162 | device=training_args.device, 163 | dtype=training_args.dtype, 164 | compile=training_args.compile, 165 | init_from=training_args.init_from, 166 | train_from_scratch=training_args.train_from_scratch, 167 | output_dir=training_args.output_dir, 168 | spkemb_dropout=training_args.spkemb_dropout 169 | ) 170 | 171 | is_debug = training_args.debug 172 | logger = WandbLogger(training_args) if training_args.use_wandb else None 173 | model = KotobaSpeechModelFirstStage(config_first_stage, config_second_stage, device=training_args.device, use_kv_cache=training_args.use_kv_cache, logger=logger, is_debug=is_debug, configs=training_args) 174 | 175 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 176 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 177 | save_top_k=-1, 178 | verbose=True, 179 | monitor=None, 180 | ) 181 | callbacks = [checkpoint_callback] 182 | 183 | num_gpus = ( 184 | training_args.num_gpus 185 | if isinstance(training_args.num_gpus, int) 186 | else len(training_args.num_gpus) 187 | ) 188 | 189 | grad_steps = training_args.batch_size // ( 190 | training_args.per_gpu_batchsize * num_gpus * training_args.num_nodes 191 | ) 192 | 193 | max_steps = training_args.max_steps if training_args.max_steps is not None else None 194 | 195 | if training_args.dtype == "bfloat16": 196 | precision="bf16-mixed" 197 | else: 198 | raise "Precision needs to be studied well." 
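    # Worked example for the accumulation factor above (hypothetical values):
    # batch_size=128, per_gpu_batchsize=16, num_gpus=4, num_nodes=1 gives
    # grad_steps = 128 // (16 * 4 * 1) = 2, i.e. every optimizer step accumulates
    # two micro-batches per GPU so the effective global batch size stays at 128.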
199 | 200 | if training_args.fsdp_strategy is not None: 201 | strategy = FSDPStrategy( 202 | sharding_strategy=training_args.fsdp_strategy 203 | ) 204 | elif training_args.use_ddp_strategy: 205 | strategy = DDPStrategy( 206 | static_graph=True, 207 | ) 208 | else: 209 | strategy = 'ddp_find_unused_parameters_true' 210 | 211 | if training_args.num_nodes > 1: 212 | plugins = [MyClusterEnvironment()] 213 | else: 214 | plugins = None 215 | 216 | if training_args.val_check_interval is not None: 217 | if training_args.val_check_interval >= 1: 218 | training_args.val_check_interval=int(training_args.val_check_interval) 219 | 220 | 221 | trainer = Trainer( 222 | callbacks=callbacks, 223 | devices=training_args.num_gpus, 224 | strategy=strategy, 225 | num_nodes=training_args.num_nodes, 226 | precision=precision, 227 | accelerator="cuda", 228 | benchmark=True, 229 | deterministic=True, 230 | max_epochs=training_args.max_epoch if max_steps is None else 1000, 231 | accumulate_grad_batches=grad_steps, 232 | val_check_interval=training_args.val_check_interval, 233 | check_val_every_n_epoch=training_args.check_val_every_n_epoch, 234 | log_every_n_steps=10, 235 | fast_dev_run=training_args.fast_dev_run, 236 | plugins=plugins, 237 | gradient_clip_val=training_args.gradient_clip_val, 238 | ) 239 | 240 | trainer.fit( 241 | model, 242 | ckpt_path=training_args.ckpt_path, 243 | datamodule=data_module 244 | ) 245 | trainer.print(torch.cuda.memory_summary()) 246 | 247 | 248 | -------------------------------------------------------------------------------- /fam/llm/training/__init__.py: -------------------------------------------------------------------------------- 1 | from fam.llm.training.args import parse_args 2 | from fam.llm.training.utils import get_first_stage_path, get_second_stage_path, TrainingConfig 3 | from fam.llm.training.dataset import VoiceDataset 4 | from fam.llm.training.datamodule import VoiceDataModule 5 | from fam.llm.training.wandb_utils import WandbLogger 6 | from fam.llm.training.evaluator import Evaluator 7 | import fam.llm.training.optimizer as optimizer_utils 8 | import fam.llm.training.dist as dist_utils -------------------------------------------------------------------------------- /fam/llm/training/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Optional, Literal 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description="Sample from a trained model.") 6 | 7 | parser.add_argument("--batch_size", type=int, default=128, 8 | help="Batch size.") 9 | 10 | parser.add_argument("--compile", action='store_true', 11 | help="Whether to compile the model using PyTorch 2.0.") 12 | 13 | parser.add_argument("--ckpt_path", type=str, default=None, 14 | help="Path to a checkpoint file to resume training from.") 15 | 16 | parser.add_argument("--data_dir", type=str, default="", 17 | help="A path to the dataset dir.") 18 | 19 | parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], 20 | help="Device to use for sampling.") 21 | 22 | parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32", "tfloat32"], 23 | help="Data type to use for sampling.") 24 | 25 | parser.add_argument("--exp_name", type=str, default="kotoba_voice_1.3B", 26 | help="A path to the dataset dir.") 27 | 28 | parser.add_argument("--fast_dev_run", action='store_true', default=False, 29 | help="Run a quick development check, usually used for debugging and testing 
purposes.") 30 | 31 | parser.add_argument("--huggingface_repo_id", type=str, required=False, default="kotoba-tech/kotoba-speech-v0.1", 32 | help="Absolute path to the model directory.") 33 | 34 | parser.add_argument("--train_from_scratch", action='store_true', default=False, 35 | help="Run a quick development check, usually used for debugging and testing purposes.") 36 | 37 | parser.add_argument("--init_from", type=str, default="resume", 38 | help="Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl').") 39 | 40 | parser.add_argument("--max_epoch", type=int, default=5, 41 | help="Number of nodes") 42 | 43 | parser.add_argument("--max_steps", type=int, default=None, 44 | help="Max steps to train") 45 | 46 | parser.add_argument("--num_gpus", type=int, default=1, 47 | help="Number of GPUs per node") 48 | 49 | parser.add_argument("--num_nodes", type=int, default=1, 50 | help="Number of nodes") 51 | 52 | parser.add_argument("--num_samples", type=int, default=1, 53 | help="Number of samples to generate from each model.") 54 | 55 | parser.add_argument("--output_dir", type=str, default="samples/", 56 | help="Relative path to output directory") 57 | 58 | parser.add_argument("--per_gpu_batchsize", type=int, default=16, 59 | help="Batch size per GPU.") 60 | 61 | parser.add_argument("--seed", type=int, default=1337, 62 | help="Random seed for sampling.") 63 | 64 | parser.add_argument("--use_kv_cache", type=str, default=None, choices=[None, "flash_decoding", "vanilla"], 65 | help="Type of kv caching to use for inference.") 66 | 67 | parser.add_argument("--val_check_interval", type=float, default=None, 68 | help="This overwrites check_val_every_n_epoch. If this value is less than 1, for example, 0.25, it means the validation set will be checked 4 times during a training epoch. 
If this value is greater than 1, for example, 1000, the validation set will be checked every 1000 training batches, either across complete epochs or during iteration-based training.") 69 | 70 | parser.add_argument("--check_val_every_n_epoch", type=int, default=1, 71 | help="Validate every n epoch.") 72 | 73 | parser.add_argument('--fsdp_strategy', type=str, default=None, choices=[None, 'FULL_SHARD', 'SHARD_GRAD_OP', 'HYBRID_SHARD', 'NO_SHARD'], 74 | help='Use fully sharded data parallel type,' 75 | 'FULL_SHARD: Shard weights, gradients, optimizer state (1 + 2 + 3)' 76 | 'SHARD_GRAD_OP: Shard gradients, optimizer state (2 + 3)' 77 | 'HYBRID_SHARD: Full-shard within a machine, replicate across machines' 78 | "NO_SHARD: Don't shard anything (similar to DDP, slower than DDP)" 79 | 'None: Use DDP' 80 | ) 81 | 82 | parser.add_argument('--use_ddp_strategy', action='store_true', default=False, 83 | help='use DDPStrategy()') 84 | 85 | # WandB settings 86 | parser.add_argument('--use_wandb', action='store_true', default=False, 87 | help='Enable integration with Weights & Biases for experiment tracking.') 88 | 89 | parser.add_argument("--wandb_entity", type=str, default=None, 90 | help="Weights & Biases entity (team or user) under which the project is located (optional).") 91 | 92 | parser.add_argument("--wandb_project", type=str, default=None, 93 | help="Weights & Biases project name to which the run will be logged (optional).") 94 | 95 | parser.add_argument("--wandb_run_id", type=str, default=None, 96 | help="Unique identifier for the Weights & Biases run, allowing for run resumption or other operations (optional).") 97 | 98 | # Debug Setting 99 | parser.add_argument("--debug", action='store_true', default=False, 100 | help="Debug mode") 101 | 102 | parser.add_argument("--debug_data_dir", type=str, default="", 103 | help="A path to the dataset dir.") 104 | 105 | # Optimizer Setting 106 | parser.add_argument("--optimizer_type", type=str, default="AdamW", 107 | choices=["Adam", "AdamW"], help="Type of optimizer to use.") 108 | 109 | parser.add_argument("--beta_one", type=float, default=0.9, 110 | help="Coefficient for computing running averages of gradient.") 111 | 112 | parser.add_argument("--beta_two", type=float, default=0.95, 113 | help="Coefficient for computing running averages of the square of the gradient.") 114 | 115 | parser.add_argument("--epsilon", type=float, default=1e-5, 116 | help="Term added to the denominator to improve numerical stability.") 117 | 118 | parser.add_argument("--weight_decay", type=float, default=0.1, 119 | help="Coefficient for computing running averages of the square of the gradient.") 120 | 121 | parser.add_argument("--decay_power", type=str, default="cosine", 122 | help="Type of learning rate decay to use.") 123 | 124 | parser.add_argument("--learning_rate", type=float, default=2e-4, 125 | help="Initial learning rate.") 126 | 127 | parser.add_argument("--warmup_steps", type=float, default=0.01, 128 | help="Fraction of total training steps to use for learning rate warmup. 
If warmup_steps is greater than 1, then the value specified in warmup_steps will represent the exact number of steps to be used for the warmup phase.") 129 | 130 | parser.add_argument("--end_lr", type=float, default=0, 131 | help="Final learning rate after decay.") 132 | 133 | parser.add_argument("--gradient_clip_val", type=float, default=0, 134 | help="Clip gradients' maximum magnitude.") 135 | 136 | # Model config setting 137 | parser.add_argument("--spkemb_dropout", type=float, default=0.1, 138 | help="Fraction of total training steps to use for learning rate warmup. If warmup_steps is greater than 1, then the value specified in warmup_steps will represent the exact number of steps to be used for the warmup phase.") 139 | 140 | return parser.parse_args() 141 | 142 | if __name__ == "__main__": 143 | args = parse_args() 144 | # Now you can use args to access your configuration, for example: 145 | # print(args.huggingface_repo_id) 146 | -------------------------------------------------------------------------------- /fam/llm/training/datamodule.py: -------------------------------------------------------------------------------- 1 | # python fam/llm/training/datamodule.py 2 | import pytorch_lightning as pl 3 | from torch.utils.data import DataLoader 4 | from fam.llm.training.dataset import VoiceDataset 5 | 6 | class VoiceDataModule(pl.LightningDataModule): 7 | def __init__(self, batch_size=64, data_dir="/root/data/reazon_small/"): 8 | super().__init__() 9 | self.batch_size = batch_size 10 | self.data_dir = data_dir 11 | 12 | def prepare_data(self): 13 | pass 14 | 15 | def setup(self, stage): 16 | # Load and split data 17 | if stage == "train" or stage == "fit": 18 | self.voice_train = VoiceDataset(split="train", data_dir=self.data_dir) 19 | self.voice_val = VoiceDataset(split="val", data_dir=self.data_dir) 20 | self.voice_test = None 21 | elif stage == "test": 22 | self.voice_train = None 23 | self.voice_val = None 24 | self.voice_test = VoiceDataset(split="test", data_dir=self.data_dir) 25 | else: 26 | assert False 27 | 28 | def train_dataloader(self): 29 | return DataLoader(self.voice_train, batch_size=self.batch_size, collate_fn=self.voice_train.collate, shuffle=True) 30 | 31 | def val_dataloader(self): 32 | return DataLoader(self.voice_val, batch_size=self.batch_size, collate_fn=self.voice_val.collate) 33 | 34 | def test_dataloader(self): 35 | return DataLoader(self.voice_test, batch_size=self.batch_size, collate_fn=self.voice_test.collate) 36 | 37 | if __name__ == "__main__": 38 | voice_datamodule = VoiceDataModule(batch_size=32, data_dir="/root/data/reazon_small/") 39 | voice_datamodule.setup("train") 40 | train_dataloader = voice_datamodule.train_dataloader() 41 | for batch in train_dataloader: 42 | print(batch) 43 | val_dataloader = voice_datamodule.val_dataloader() 44 | test_dataloader = voice_datamodule.test_dataloader() 45 | -------------------------------------------------------------------------------- /fam/llm/training/dataset.py: -------------------------------------------------------------------------------- 1 | # python fam/llm/training/dataset.py 2 | import os, json 3 | import torch 4 | from tqdm import tqdm 5 | from torch.utils.data import Dataset 6 | 7 | class VoiceDataset(Dataset): 8 | def __init__(self, split="train", data_dir=""): 9 | self.data_dir = data_dir 10 | self.data = self._load_data(os.path.join(data_dir, f'{split}.jsonl')) 11 | 12 | def _load_data(self, file_path): 13 | data = [] 14 | with open(file_path, 'r') as f: 15 | for line in f: 16 | 
data.append(json.loads(line)) 17 | return data 18 | 19 | def __len__(self): 20 | return len(self.data) 21 | 22 | def __getitem__(self, idx): 23 | train_item = self.data[idx] 24 | key = train_item["key"] 25 | tokens = torch.load(os.path.join(self.data_dir, 'txt', f"{key}.pt")) 26 | embedding = torch.load(os.path.join(self.data_dir, 'emb', f"{key}.emb")) 27 | audio_tokens = torch.load(os.path.join(self.data_dir, 'tok', f"{key}.tok")) 28 | return tokens, embedding, audio_tokens 29 | 30 | def collate(self, batch): 31 | audio_pad = 2048 32 | num_bands = 2 33 | eoa = 2048 34 | loss_pad = -1 35 | audio_vocab_size = 1024 36 | # end of audio 37 | # batch is a list of samples, where each sample is a dictionary 38 | tokens = [sample[0] for sample in batch] 39 | padded_tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=0) 40 | embeddings = [sample[1] for sample in batch] 41 | padded_embeddings = torch.nn.utils.rnn.pad_sequence(embeddings, batch_first=True, padding_value=0) 42 | audios = [sample[2] for sample in batch] 43 | # Determine the maximum value of h 44 | max_l = max(tensor.size(1) for tensor in audios) 45 | max_t2l = max(tensor_text.size(0) + tensor_audio.size(1)*2 for tensor_text, tensor_audio in zip(tokens, audios)) 46 | # Pad tensors to have the same width 47 | padded_audios = [torch.nn.functional.pad(tensor, (0, max_l - tensor.shape[1]), mode='constant', value=audio_vocab_size) for tensor in audios] 48 | # Concatenate tensors along dimension 0 49 | padded_audios = torch.stack(padded_audios, dim=0) 50 | # first_stage_output [B, 2, T+T+L (padded)] 51 | first_stage_input = torch.full([len(batch), max_t2l], audio_pad, dtype=torch.int64) 52 | first_stage_output = torch.full([len(batch), max_t2l], loss_pad, dtype=torch.int64) 53 | ## first_stage_output: [text, text, audio, pad] 54 | ## first_stage_input: [-1, audio, eoa, pad] 55 | for idx, tensor_text, tensor_audio in zip(range(len(batch)), tokens, audios): 56 | text_size = tensor_text.size(0) 57 | audio_size = tensor_audio.size(1) 58 | bands = tensor_audio[:num_bands] 59 | bands += torch.arange(num_bands).view([num_bands, 1])*audio_vocab_size 60 | bands = bands.transpose(0, 1).contiguous().view(-1) 61 | first_stage_input[idx, :text_size] = tensor_text 62 | first_stage_input[idx, text_size:text_size+audio_size*num_bands] = bands 63 | first_stage_output[idx, :text_size-1] = torch.full([text_size-1], loss_pad, dtype=torch.int64) 64 | eoa_tensor = torch.full([1], eoa, dtype=torch.int64) 65 | first_stage_output[idx, text_size-1:text_size+audio_size*2] = torch.cat([bands, eoa_tensor], dim=0) 66 | first_stage_output[idx, text_size+audio_size*2:] = loss_pad 67 | return padded_tokens, padded_embeddings, padded_audios, first_stage_input, first_stage_output 68 | 69 | 70 | if __name__ == "__main__": 71 | for sp in ["train" , "val"]: 72 | voice_train = VoiceDataset(split=sp, data_dir="/debug_data") 73 | for i in tqdm(range(len(voice_train))): 74 | tokens, embedding, audio_tokens = voice_train.__getitem__(i) 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /fam/llm/training/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 
5 | """ 6 | import functools 7 | import logging 8 | import pickle 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | _LOCAL_PROCESS_GROUP = None 14 | """ 15 | A torch process group which only includes processes that on the same machine as the current process. 16 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 17 | """ 18 | 19 | 20 | def get_world_size() -> int: 21 | if not dist.is_available(): 22 | return 1 23 | if not dist.is_initialized(): 24 | return 1 25 | return dist.get_world_size() 26 | 27 | 28 | def get_rank() -> int: 29 | if not dist.is_available(): 30 | return 0 31 | if not dist.is_initialized(): 32 | return 0 33 | return dist.get_rank() 34 | 35 | 36 | def get_local_rank() -> int: 37 | """ 38 | Returns: 39 | The rank of the current process within the local (per-machine) process group. 40 | """ 41 | if not dist.is_available(): 42 | return 0 43 | if not dist.is_initialized(): 44 | return 0 45 | assert _LOCAL_PROCESS_GROUP is not None 46 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 47 | 48 | 49 | def get_local_size() -> int: 50 | """ 51 | Returns: 52 | The size of the per-machine process group, 53 | i.e. the number of processes per machine. 54 | """ 55 | if not dist.is_available(): 56 | return 1 57 | if not dist.is_initialized(): 58 | return 1 59 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 60 | 61 | 62 | def is_initialized() -> bool: 63 | return dist.is_initialized() 64 | 65 | def is_available() -> bool: 66 | return dist.is_available() 67 | 68 | def is_main_process() -> bool: 69 | return get_rank() == 0 70 | 71 | def is_dist_avail_and_initialized(): 72 | """ 73 | Returns: 74 | True if distributed training is enabled 75 | """ 76 | if not dist.is_available(): 77 | return False 78 | if not dist.is_initialized(): 79 | return False 80 | return True 81 | 82 | 83 | def synchronize(): 84 | """ 85 | Helper function to synchronize (barrier) among all processes when 86 | using distributed training 87 | """ 88 | if not dist.is_available(): 89 | return 90 | if not dist.is_initialized(): 91 | return 92 | world_size = dist.get_world_size() 93 | if world_size == 1: 94 | return 95 | dist.barrier() 96 | 97 | 98 | @functools.lru_cache() 99 | def _get_global_gloo_group(): 100 | """ 101 | Return a process group based on gloo backend, containing all the ranks 102 | The result is cached. 103 | """ 104 | if dist.get_backend() == "nccl": 105 | return dist.new_group(backend="gloo") 106 | else: 107 | return dist.group.WORLD 108 | 109 | 110 | def _serialize_to_tensor(data, group): 111 | backend = dist.get_backend(group) 112 | assert backend in ["gloo", "nccl"] 113 | device = torch.device("cpu" if backend == "gloo" else "cuda") 114 | 115 | buffer = pickle.dumps(data) 116 | if len(buffer) > 1024 ** 3: 117 | logger = logging.getLogger(__name__) 118 | logger.warning( 119 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 120 | get_rank(), len(buffer) / (1024 ** 3), device 121 | ) 122 | ) 123 | storage = torch.ByteStorage.from_buffer(buffer) 124 | tensor = torch.ByteTensor(storage).to(device=device) 125 | return tensor 126 | 127 | 128 | def _pad_to_largest_tensor(tensor, group): 129 | """ 130 | Returns: 131 | list[int]: size of the tensor, on each rank 132 | Tensor: padded tensor that has the max size 133 | """ 134 | world_size = dist.get_world_size(group=group) 135 | assert ( 136 | world_size >= 1 137 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
138 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 139 | size_list = [ 140 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 141 | for _ in range(world_size) 142 | ] 143 | dist.all_gather(size_list, local_size, group=group) 144 | size_list = [int(size.item()) for size in size_list] 145 | 146 | max_size = max(size_list) 147 | 148 | # we pad the tensor because torch all_gather does not support 149 | # gathering tensors of different shapes 150 | if local_size != max_size: 151 | padding = torch.zeros( 152 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 153 | ) 154 | tensor = torch.cat((tensor, padding), dim=0) 155 | return size_list, tensor 156 | 157 | 158 | def all_gather(data, group=None): 159 | """ 160 | Run all_gather on arbitrary picklable data (not necessarily tensors). 161 | 162 | Args: 163 | data: any picklable object 164 | group: a torch process group. By default, will use a group which 165 | contains all ranks on gloo backend. 166 | 167 | Returns: 168 | list[data]: list of data gathered from each rank 169 | """ 170 | if get_world_size() == 1: 171 | return [data] 172 | if group is None: 173 | group = _get_global_gloo_group() 174 | if dist.get_world_size(group) == 1: 175 | return [data] 176 | 177 | tensor = _serialize_to_tensor(data, group) 178 | 179 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 180 | max_size = max(size_list) 181 | 182 | # receiving Tensor from all ranks 183 | tensor_list = [ 184 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 185 | for _ in size_list 186 | ] 187 | dist.all_gather(tensor_list, tensor, group=group) 188 | 189 | data_list = [] 190 | for size, tensor in zip(size_list, tensor_list): 191 | buffer = tensor.cpu().numpy().tobytes()[:size] 192 | data_list.append(pickle.loads(buffer)) 193 | 194 | return data_list 195 | 196 | 197 | def gather(data, dst=0, group=None): 198 | """ 199 | Run gather on arbitrary picklable data (not necessarily tensors). 200 | 201 | Args: 202 | data: any picklable object 203 | dst (int): destination rank 204 | group: a torch process group. By default, will use a group which 205 | contains all ranks on gloo backend. 206 | 207 | Returns: 208 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 209 | an empty list. 210 | """ 211 | if get_world_size() == 1: 212 | return [data] 213 | if group is None: 214 | group = _get_global_gloo_group() 215 | if dist.get_world_size(group=group) == 1: 216 | return [data] 217 | rank = dist.get_rank(group=group) 218 | 219 | tensor = _serialize_to_tensor(data, group) 220 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 221 | 222 | # receiving Tensor from all ranks 223 | if rank == dst: 224 | max_size = max(size_list) 225 | tensor_list = [ 226 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 227 | for _ in size_list 228 | ] 229 | dist.gather(tensor, tensor_list, dst=dst, group=group) 230 | 231 | data_list = [] 232 | for size, tensor in zip(size_list, tensor_list): 233 | buffer = tensor.cpu().numpy().tobytes()[:size] 234 | data_list.append(pickle.loads(buffer)) 235 | return data_list 236 | else: 237 | dist.gather(tensor, [], dst=dst, group=group) 238 | return [] 239 | 240 | 241 | def shared_random_seed(): 242 | """ 243 | Returns: 244 | int: a random number that is the same across all workers. 245 | If workers need a shared RNG, they can use this shared seed to 246 | create one. 247 | 248 | All workers must call this function, otherwise it will deadlock. 
249 | """ 250 | ints = np.random.randint(2 ** 31) 251 | all_ints = all_gather(ints) 252 | return all_ints[0] 253 | 254 | 255 | def reduce_dict(input_dict, average=True): 256 | """ 257 | Reduce the values in the dictionary from all processes so that process with rank 258 | 0 has the reduced results. 259 | 260 | Args: 261 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 262 | average (bool): whether to do average or sum 263 | 264 | Returns: 265 | a dict with the same keys as input_dict, after reduction. 266 | """ 267 | world_size = get_world_size() 268 | if world_size < 2: 269 | return input_dict 270 | with torch.no_grad(): 271 | names = [] 272 | values = [] 273 | # sort the keys so that they are consistent across processes 274 | for k in sorted(input_dict.keys()): 275 | names.append(k) 276 | values.append(input_dict[k]) 277 | values = torch.stack(values, dim=0) 278 | dist.reduce(values, dst=0) 279 | if dist.get_rank() == 0 and average: 280 | # only main process gets accumulated, so only divide by 281 | # world_size in this case 282 | values /= world_size 283 | reduced_dict = {k: v for k, v in zip(names, values)} 284 | return reduced_dict -------------------------------------------------------------------------------- /fam/llm/training/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import fam.llm.training.dist as dist_utils 3 | 4 | class Evaluator(object): 5 | def __init__(self): 6 | self.reset() 7 | 8 | def update(self, results): 9 | for k in self.fields: 10 | self.results[k] += results[k] 11 | 12 | def synchronize_between_processes(self): 13 | all_results = dist_utils.all_gather(self.results) 14 | merged_results = {} 15 | for r in all_results: 16 | for k in r.keys(): 17 | merged_results.setdefault(k, []) 18 | merged_results[k] += r[k] 19 | for k in merged_results.keys(): 20 | merged_results[k] = np.array(merged_results[k]) 21 | self.results = merged_results 22 | 23 | def summarize(self): 24 | #! With multi-gpu, the dataloader duplicates examples if the number of examples is not divisible by the batch size. 25 | if dist_utils.is_main_process(): 26 | return {k: np.mean(self.results[k]) for k in self.fields} 27 | 28 | else: 29 | assert False, "This if function should not be called." 
30 | 31 | def reset(self): 32 | self.results = {} 33 | self.fields = [ 34 | "loss", 35 | ] 36 | 37 | for f in self.fields: 38 | self.results.setdefault(f, []) 39 | -------------------------------------------------------------------------------- /fam/llm/training/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import inspect 3 | from IPython import embed 4 | from transformers import ( 5 | get_polynomial_decay_schedule_with_warmup, 6 | get_cosine_schedule_with_warmup, 7 | ) 8 | 9 | 10 | def set_schedule(pl_module): 11 | optimizer_type = pl_module.configs["optimizer_type"] 12 | lr = pl_module.configs["learning_rate"] 13 | beta_one = pl_module.configs["beta_one"] 14 | beta_two = pl_module.configs["beta_two"] 15 | eps = pl_module.configs["epsilon"] 16 | wd = pl_module.configs["weight_decay"] 17 | #fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters 18 | #use_fused = fused_available and pl_module.device.type == "cuda" 19 | use_fused=False 20 | extra_args = dict(fused=True) if use_fused else dict() 21 | 22 | if optimizer_type == "Adam": 23 | optimizer = torch.optim.Adam(pl_module.first_stage_model_cls.model.parameters(), lr=lr, betas=(beta_one, beta_two), eps=eps, **extra_args) 24 | elif optimizer_type == "AdamW": 25 | model = pl_module.first_stage_model_cls.model 26 | # start with all of the candidate parameters 27 | param_dict = {pn: p for pn, p in model.named_parameters()} 28 | # filter out those that do not require grad 29 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 30 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 31 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 32 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 33 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 34 | optim_groups = [ 35 | {"params": decay_params, "weight_decay": wd}, 36 | {"params": nodecay_params, "weight_decay": 0.0}, 37 | ] 38 | num_decay_params = sum(p.numel() for p in decay_params) 39 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 40 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 41 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 42 | 43 | optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(beta_one, beta_two), eps=eps, **extra_args) 44 | 45 | if pl_module.trainer.max_steps == -1: 46 | # TODO: this is note tested in multi-node set-up. 
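        # Worked example (hypothetical numbers): len(train_dataloader()) = 10_000,
        # 4 device_ids, max_epochs = 5, accumulate_grad_batches = 2 gives
        # max_steps = 10_000 // 4 * 5 // 2 = 6_250, which is the horizon the warmup
        # and decay schedule below is stretched over.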
47 | max_steps = ( 48 | len(pl_module.trainer.datamodule.train_dataloader()) 49 | // len(pl_module.trainer.device_ids) 50 | * pl_module.trainer.max_epochs 51 | // pl_module.trainer.accumulate_grad_batches 52 | ) 53 | else: 54 | max_steps = pl_module.trainer.max_steps 55 | 56 | end_lr = pl_module.configs["end_lr"] 57 | decay_power = pl_module.configs["decay_power"] 58 | warmup_steps = pl_module.configs["warmup_steps"] 59 | if warmup_steps <= 1: 60 | warmup_steps = int(max_steps * warmup_steps) 61 | 62 | if decay_power == "cosine": 63 | scheduler = get_cosine_schedule_with_warmup( 64 | optimizer, 65 | num_warmup_steps=warmup_steps, 66 | num_training_steps=max_steps, 67 | ) 68 | else: 69 | scheduler = get_polynomial_decay_schedule_with_warmup( 70 | optimizer, 71 | num_warmup_steps=warmup_steps, 72 | num_training_steps=max_steps, 73 | lr_end=end_lr, 74 | power=decay_power, 75 | ) 76 | 77 | sched = {"scheduler": scheduler, "interval": "step"} 78 | 79 | return ( 80 | [optimizer], 81 | [sched], 82 | ) -------------------------------------------------------------------------------- /fam/llm/training/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | def get_first_stage_path(model_dir: str): 5 | """Absolute path to checkpoint for the first stage model.""" 6 | return os.path.join(os.path.expanduser(model_dir), "first_stage.pt") 7 | 8 | 9 | def get_second_stage_path(model_dir: str): 10 | """Absolute path to checkpoint for the second stage model.""" 11 | return os.path.join(os.path.expanduser(model_dir), "second_stage.pt") 12 | 13 | 14 | @dataclass 15 | class TrainingConfig: 16 | ckpt_path: str # path to checkpoint 17 | output_dir: str 18 | num_samples: int = 10 # number of samples to draw 19 | seed: int = 1337 # random seed 20 | device: str = "cuda" 21 | dtype: str = "bfloat16" 22 | compile: bool = False 23 | init_from: str = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl') 24 | train_from_scratch: bool = False # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 
'gpt2-xl') 25 | spkemb_dropout: float = 0.0 26 | 27 | def __str__(self): 28 | field_strs = [] 29 | for field in dataclasses.fields(self): 30 | value = getattr(self, field.name) 31 | field_strs.append(f" {field.name}: {value}") 32 | 33 | return "TrainingConfig:\n" + "\n".join(field_strs) 34 | -------------------------------------------------------------------------------- /fam/llm/training/wandb_utils.py: -------------------------------------------------------------------------------- 1 | # python fam/llm/training/wandb_utils.py 2 | import os 3 | import wandb 4 | from datetime import datetime 5 | from argparse import Namespace 6 | from IPython import embed 7 | 8 | class WandbLogger(object): 9 | def __init__(self, args: Namespace): 10 | super().__init__() 11 | self._step = 0 12 | self._debug = args.debug 13 | now = datetime.now() 14 | exp_name = args.exp_name + "-" + str(now).replace(" ", "-") 15 | wandb_input = { 16 | "entity": args.wandb_entity, 17 | "name": exp_name, 18 | "config": args 19 | } 20 | wandb_input["project"] = args.wandb_project 21 | if args.wandb_run_id is not None: 22 | wandb_input["id"] = args.wandb_run_id 23 | wandb_input["resume"] = "must" 24 | wandb.init(**wandb_input) 25 | 26 | def log(self, results, split: str, step: int = None, commit=False): 27 | formated_results = {} 28 | if step is not None: 29 | self._step = step 30 | 31 | for k, v in results.items(): 32 | formated_results["{}/{}".format(split, k)] = v 33 | wandb.log(formated_results, step=self._step, commit=commit) 34 | 35 | def get_step(self): 36 | return self._step 37 | 38 | 39 | if __name__ == "__main__": 40 | from fam.llm.training import parse_args 41 | training_args = parse_args() 42 | logger = WandbLogger(training_args) -------------------------------------------------------------------------------- /fam/llm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | import tempfile 5 | 6 | import librosa 7 | import torch 8 | 9 | 10 | def normalize_text(text: str) -> str: 11 | unicode_conversion = { 12 | 8175: "'", 13 | 8189: "'", 14 | 8190: "'", 15 | 8208: "-", 16 | 8209: "-", 17 | 8210: "-", 18 | 8211: "-", 19 | 8212: "-", 20 | 8213: "-", 21 | 8214: "||", 22 | 8216: "'", 23 | 8217: "'", 24 | 8218: ",", 25 | 8219: "`", 26 | 8220: '"', 27 | 8221: '"', 28 | 8222: ",,", 29 | 8223: '"', 30 | 8228: ".", 31 | 8229: "..", 32 | 8230: "...", 33 | 8242: "'", 34 | 8243: '"', 35 | 8245: "'", 36 | 8246: '"', 37 | 180: "'", 38 | 2122: "TM", # Trademark 39 | } 40 | 41 | text = text.translate(unicode_conversion) 42 | 43 | non_bpe_chars = set([c for c in list(text) if ord(c) >= 256]) 44 | #if len(non_bpe_chars) > 0: 45 | # non_bpe_points = [(c, ord(c)) for c in non_bpe_chars] 46 | # raise ValueError(f"Non-BPE single token characters found: {non_bpe_points}") 47 | 48 | text = text.replace("\t", " ") 49 | text = text.replace("\n", " ") 50 | text = text.replace("*", " ") 51 | text = text.strip() 52 | text = re.sub("\s\s+", " ", text) # remove multiple spaces 53 | return text 54 | 55 | def check_audio_file(path_or_uri, threshold_s=30): 56 | if "http" in path_or_uri: 57 | temp_fd, filepath = tempfile.mkstemp() 58 | os.close(temp_fd) # Close the file descriptor, curl will create a new connection 59 | curl_command = ["curl", "-L", path_or_uri, "-o", filepath] 60 | subprocess.run(curl_command, check=True) 61 | 62 | else: 63 | filepath = path_or_uri 64 | 65 | audio, sr = librosa.load(filepath) 66 | duration_s = librosa.get_duration(y=audio, sr=sr) 
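    # e.g. check_audio_file("https://example.com/ref.wav") downloads to a temporary file
    # via curl first, while a local path such as "assets/bria.mp3" is loaded directly;
    # the duration check below is currently commented out.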
67 | #if duration_s < threshold_s: 68 | # raise Exception( 69 | # f"The audio file is too short. Please provide an audio file that is at least {threshold_s} seconds long to proceed." 70 | # ) 71 | 72 | # Clean up the temporary file if it was created 73 | if "http" in path_or_uri: 74 | os.remove(filepath) 75 | 76 | 77 | def get_default_dtype() -> str: 78 | """Compute default 'dtype' based on GPU architecture""" 79 | if torch.cuda.is_available(): 80 | for i in range(torch.cuda.device_count()): 81 | device_properties = torch.cuda.get_device_properties(i) 82 | dtype = "float16" if device_properties.major <= 7 else "bfloat16" # float16 for Tesla/Turing (compute capability <= 7), bfloat16 for Ampere and newer 83 | else: 84 | dtype = "float16" 85 | 86 | print(f"using dtype={dtype}") 87 | return dtype 88 | 89 | 90 | def get_device() -> str: 91 | return "cuda" if torch.cuda.is_available() else "cpu" 92 | -------------------------------------------------------------------------------- /fam/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/py.typed -------------------------------------------------------------------------------- /fam/quantiser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/quantiser/__init__.py -------------------------------------------------------------------------------- /fam/quantiser/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/quantiser/audio/__init__.py -------------------------------------------------------------------------------- /fam/quantiser/audio/speaker_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/quantiser/audio/speaker_encoder/__init__.py -------------------------------------------------------------------------------- /fam/quantiser/audio/speaker_encoder/audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | 4 | mel_window_length = 25 5 | mel_window_step = 10 6 | mel_n_channels = 40 7 | sampling_rate = 16000 8 | 9 | 10 | def wav_to_mel_spectrogram(wav): 11 | """ 12 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 13 | Note: this is not a log-mel spectrogram.
14 | """ 15 | frames = librosa.feature.melspectrogram( 16 | y=wav, 17 | sr=sampling_rate, 18 | n_fft=int(sampling_rate * mel_window_length / 1000), 19 | hop_length=int(sampling_rate * mel_window_step / 1000), 20 | n_mels=mel_n_channels, 21 | ) 22 | return frames.astype(np.float32).T 23 | -------------------------------------------------------------------------------- /fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt -------------------------------------------------------------------------------- /fam/quantiser/audio/speaker_encoder/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import perf_counter as timer 3 | from typing import List, Optional, Union 4 | 5 | import librosa 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | 10 | from fam.quantiser.audio.speaker_encoder import audio 11 | 12 | DEFAULT_SPKENC_CKPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/ckpt.pt") 13 | 14 | mel_window_step = 10 15 | mel_n_channels = 40 16 | sampling_rate = 16000 17 | partials_n_frames = 160 18 | model_hidden_size = 256 19 | model_embedding_size = 256 20 | model_num_layers = 3 21 | 22 | 23 | class SpeakerEncoder(nn.Module): 24 | def __init__( 25 | self, 26 | weights_fpath: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | verbose: bool = True, 29 | eval: bool = False, 30 | ): 31 | super().__init__() 32 | 33 | # Define the network 34 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) 35 | self.linear = nn.Linear(model_hidden_size, model_embedding_size) 36 | self.relu = nn.ReLU() 37 | 38 | # Get the target device 39 | if device is None: 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | elif isinstance(device, str): 42 | device = torch.device(device) 43 | self.device = device 44 | 45 | start = timer() 46 | if eval and weights_fpath is None: 47 | weights_fpath = DEFAULT_SPKENC_CKPT_PATH 48 | 49 | if weights_fpath is not None: 50 | checkpoint = torch.load(weights_fpath, map_location="cpu") 51 | 52 | self.load_state_dict(checkpoint["model_state"], strict=False) 53 | self.to(device) 54 | 55 | if eval: 56 | self.eval() 57 | 58 | if verbose: 59 | print("Loaded the speaker embedding model on %s in %.2f seconds." 
% (device.type, timer() - start)) 60 | 61 | def forward(self, mels: torch.FloatTensor): 62 | _, (hidden, _) = self.lstm(mels) 63 | embeds_raw = self.relu(self.linear(hidden[-1])) 64 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 65 | 66 | @staticmethod 67 | def compute_partial_slices(n_samples: int, rate, min_coverage): 68 | # Compute how many frames separate two partial utterances 69 | samples_per_frame = int((sampling_rate * mel_window_step / 1000)) 70 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) 71 | frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) 72 | 73 | # Compute the slices 74 | wav_slices, mel_slices = [], [] 75 | steps = max(1, n_frames - partials_n_frames + frame_step + 1) 76 | for i in range(0, steps, frame_step): 77 | mel_range = np.array([i, i + partials_n_frames]) 78 | wav_range = mel_range * samples_per_frame 79 | mel_slices.append(slice(*mel_range)) 80 | wav_slices.append(slice(*wav_range)) 81 | 82 | # Evaluate whether extra padding is warranted or not 83 | last_wav_range = wav_slices[-1] 84 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) 85 | if coverage < min_coverage and len(mel_slices) > 1: 86 | mel_slices = mel_slices[:-1] 87 | wav_slices = wav_slices[:-1] 88 | 89 | return wav_slices, mel_slices 90 | 91 | def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75, numpy: bool = True): 92 | wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage) 93 | max_wave_length = wav_slices[-1].stop 94 | if max_wave_length >= len(wav): 95 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") 96 | 97 | mel = audio.wav_to_mel_spectrogram(wav) 98 | mels = np.array([mel[s] for s in mel_slices]) 99 | with torch.no_grad(): 100 | mels = torch.from_numpy(mels).to(self.device) # type: ignore 101 | partial_embeds = self(mels) 102 | 103 | if numpy: 104 | partial_embeds = partial_embeds.cpu().numpy() 105 | raw_embed = np.mean(partial_embeds, axis=0) 106 | embed = raw_embed / np.linalg.norm(raw_embed, 2) 107 | else: 108 | raw_embed = partial_embeds.mean(dim=0) 109 | embed = raw_embed / torch.linalg.norm(raw_embed, 2) 110 | 111 | if return_partials: 112 | return embed, partial_embeds, wav_slices 113 | return embed 114 | 115 | def embed_speaker(self, wavs: List[np.ndarray], **kwargs): 116 | raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0) 117 | return raw_embed / np.linalg.norm(raw_embed, 2) 118 | 119 | def embed_utterance_from_file(self, fpath: str, numpy: bool) -> torch.Tensor: 120 | wav_tgt, _ = librosa.load(fpath, sr=16000) 121 | wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) 122 | embedding = self.embed_utterance(wav_tgt, numpy=numpy) 123 | return embedding 124 | -------------------------------------------------------------------------------- /fam/quantiser/text/tokenise.py: -------------------------------------------------------------------------------- 1 | import tiktoken 2 | 3 | 4 | class TrainedBPETokeniser: 5 | def __init__(self, name, pat_str, mergeable_ranks, special_tokens, offset=None) -> None: 6 | self.tokenizer = tiktoken.Encoding( 7 | name=name, 8 | pat_str=pat_str, 9 | mergeable_ranks=mergeable_ranks, 10 | special_tokens=special_tokens, 11 | ) 12 | self.offset = offset 13 | 14 | def encode(self, text: str) -> list[int]: 15 | # note: we add a end of text token! 
16 | tokens = self.tokenizer.encode(text) + [self.tokenizer.eot_token] 17 | if self.offset is not None: 18 | tokens = [x + self.offset for x in tokens] 19 | 20 | return tokens 21 | 22 | def decode(self, tokens: list[int]): 23 | if self.offset is not None: 24 | tokens = [x - self.offset for x in tokens] 25 | return self.tokenizer.decode(tokens) 26 | 27 | @property 28 | def eot_token(self): 29 | if self.offset is not None: 30 | return self.tokenizer.eot_token + self.offset 31 | else: 32 | return self.tokenizer.eot_token 33 | -------------------------------------------------------------------------------- /fam/ui/app.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import os 4 | 5 | import gradio as gr 6 | import requests 7 | import soundfile as sf 8 | 9 | API_SERVER_URL = "http://127.0.0.1:58003/tts" 10 | RADIO_CHOICES = ["Preset voices", "Upload target voice", "Record your voice"] 11 | MAX_CHARS = 220 12 | PRESET_VOICES = { 13 | # female 14 | "Ava": "https://cdn.themetavoice.xyz/speakers/ava.flac", 15 | "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3", 16 | # male 17 | "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3", 18 | "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav", 19 | } 20 | 21 | 22 | def denormalise_top_p(top_p): 23 | # returns top_p in the range [0.9, 1.0] 24 | return round(0.9 + top_p / 100, 2) 25 | 26 | 27 | def denormalise_guidance(guidance): 28 | # returns guidance in the range [1.0, 3.0] 29 | return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1) 30 | 31 | 32 | def _handle_edge_cases(to_say, upload_target): 33 | if not to_say: 34 | raise gr.Error("Please provide text to synthesise") 35 | 36 | def _check_file_size(path): 37 | if not path: 38 | return 39 | filesize = os.path.getsize(path) 40 | filesize_mb = filesize / 1024 / 1024 41 | if filesize_mb >= 50: 42 | raise gr.Error( 43 | f"Please upload a sample less than 50MB for voice cloning. Provided: {round(filesize_mb)} MB" 44 | ) 45 | 46 | _check_file_size(upload_target) 47 | 48 | 49 | def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target, record_target): 50 | d_top_p = denormalise_top_p(top_p) 51 | d_guidance = denormalise_guidance(guidance) 52 | 53 | _handle_edge_cases(to_say, upload_target) 54 | 55 | to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS] 56 | 57 | custom_target_path = None 58 | if toggle == RADIO_CHOICES[1]: 59 | custom_target_path = upload_target 60 | elif toggle == RADIO_CHOICES[2]: 61 | custom_target_path = record_target 62 | 63 | config = { 64 | "text": to_say, 65 | "guidance": d_guidance, 66 | "top_p": d_top_p, 67 | "speaker_ref_path": PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else None, 68 | } 69 | headers = {"Content-Type": "audio/wav", "X-Payload": json.dumps(config)} 70 | if not custom_target_path: 71 | response = requests.post(API_SERVER_URL, headers=headers, data=None) 72 | else: 73 | with open(custom_target_path, "rb") as f: 74 | data = f.read() 75 | response = requests.post(API_SERVER_URL, headers=headers, data=data) 76 | 77 | wav, sr = None, None 78 | if response.status_code == 200: 79 | audio_buffer = io.BytesIO(response.content) 80 | audio_buffer.seek(0) 81 | wav, sr = sf.read(audio_buffer, dtype="float32") 82 | else: 83 | print(f"Something went wrong. 
response status code: {response.status_code}") 84 | 85 | return sr, wav 86 | 87 | 88 | def change_voice_selection_layout(choice): 89 | index = RADIO_CHOICES.index(choice) 90 | return [ 91 | gr.update(visible=True) 92 | if i == index else gr.update(visible=False) 93 | for i in range(len(RADIO_CHOICES)) 94 | ] 95 | 96 | 97 | title = "# TTS by Kotoba-Speech" 98 | 99 | description = """ 100 | Kotoba-Speech v0.1は、1.2Bのトランスフォーマーに基づく音声生成モデルです。 101 | 以下の機能をサポートしています: 102 | \n 103 | * 日本語における滑らかなテキスト読み上げ生成 104 | * スピーチプロンプトを通じたOne-shot音声クローニング 105 | 106 | Kotoba Technologiesは、公開されたモデルを商用可能なApache 2.0ライセンスで公開します。 107 | 推論およびモデルコードは、Meta-Voiceをベースに作られており、学習コードは弊社のGitHubで近日中に公開する予定です。 108 | Kotoba Technologiesは、音声基盤モデルの開発に取り組んでおり、今後もモデルの公開を行なっていきます。是非、[Discord Community](https://discord.gg/qPVFqhGN7Z)に参加してご意見ください! 109 | 110 | Kotoba-Speech v0.1 is a 1.2B Transformer-based speech generative model. It supports the following features: 111 | \n 112 | * Fluent text-to-speech generation in Japanese 113 | * One-shot voice cloning through a speech prompt 114 | 115 | We are releasing our model under the Apache 2.0 license. Our inference and model code is adapted from Meta-Voice, and we will release our training code on our GitHub repository shortly. 116 | Kotoba Technologies is committed to developing speech foundation models, and we’ll continue releasing our models. Please join [our discord](https://discord.gg/qPVFqhGN7Z) to contribute to our community. 117 | """ 118 | 119 | with gr.Blocks(title="TTS by Kotoba-Speech") as demo: 120 | gr.Markdown(title) 121 | 122 | with gr.Row(): 123 | gr.Markdown(description) 124 | 125 | with gr.Row(): 126 | with gr.Column(): 127 | to_say = gr.TextArea( 128 | label="What should I say!?", 129 | lines=4, 130 | value="コトバテクノロジーズのミッションは、音声基盤モデルを作ることです。", 131 | ) 132 | 133 | with gr.Row(), gr.Column(): 134 | # voice settings 135 | top_p = gr.Slider( 136 | value=5.0, 137 | minimum=0.0, 138 | maximum=10.0, 139 | step=1.0, 140 | label="Speech Stability - improves text following for a challenging speaker", 141 | ) 142 | guidance = gr.Slider( 143 | value=5.0, 144 | minimum=1.0, 145 | maximum=5.0, 146 | step=1.0, 147 | label="Speaker similarity - How closely to match speaker identity and speech style.", 148 | ) 149 | 150 | # voice select 151 | toggle = gr.Radio(choices=RADIO_CHOICES, label="Choose voice", value=RADIO_CHOICES[0]) 152 | 153 | with gr.Row(visible=True) as row_1: 154 | preset_dropdown = gr.Dropdown( 155 | PRESET_VOICES.keys(), label="Preset voices", value=list(PRESET_VOICES.keys())[0] 156 | ) 157 | with gr.Accordion("Preview: Preset voices", open=False): 158 | for label, path in PRESET_VOICES.items(): 159 | gr.Audio(value=path, label=label) 160 | 161 | with gr.Row(visible=False) as row_2: 162 | upload_target = gr.Audio( 163 | sources=["upload"], 164 | type="filepath", 165 | label="Upload a clean sample to clone. Sample should contain 1 speaker, be between 10-90 seconds and not contain background noise.", 166 | min_length=10, 167 | max_length=90, 168 | ) 169 | 170 | with gr.Row(visible=False) as row_3: 171 | record_target = gr.Audio( 172 | sources=["microphone"], 173 | type="filepath", 174 | label="Record your voice with a microphone to clone. 
Sample should contain 1 speaker, be between 10-90 seconds and not contain background noise.", 175 | min_length=10, 176 | max_length=90, 177 | ) 178 | 179 | toggle.change( 180 | change_voice_selection_layout, 181 | inputs=toggle, 182 | outputs=[row_1, row_2, row_3], 183 | ) 184 | 185 | with gr.Column(): 186 | speech = gr.Audio( 187 | type="numpy", 188 | label="Kotoba-Speech says...", 189 | ) 190 | 191 | submit = gr.Button("Generate Speech") 192 | submit.click( 193 | fn=tts, 194 | inputs=[to_say, top_p, guidance, toggle, preset_dropdown, upload_target, record_target], 195 | outputs=speech, 196 | ) 197 | 198 | 199 | demo.queue(default_concurrency_limit=2) 200 | demo.launch() 201 | # demo.launch(server_name="0.0.0.0", server_port=3000, share=True) # dev 202 | -------------------------------------------------------------------------------- /fam/ui/assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kotoba-tech/kotoba-speech-release/866b66049a6a94b51f918e504ab3d1ec0c128c69/fam/ui/assets/favicon.ico -------------------------------------------------------------------------------- /preprocess/audio_tokenize.py: -------------------------------------------------------------------------------- 1 | import os, torch, json 2 | from tqdm import tqdm 3 | import argparse 4 | import julius 5 | from audiocraft.data.audio import audio_read 6 | from audiocraft.models import MultiBandDiffusion # type: ignore 7 | from utils import get_sorted_keys, get_ids, create_directory_recursive 8 | 9 | mbd_sample_rate = 24_000 10 | mbd_bandwidth = 6 11 | num_codebooks = 8 12 | mbd = MultiBandDiffusion.get_mbd_24khz(bw=mbd_bandwidth) 13 | 14 | def tokenize(audio_path, output_file): 15 | wav, sr = audio_read(audio_path) 16 | if sr != mbd_sample_rate: 17 | wav = julius.resample_frac(wav, sr, mbd_sample_rate) 18 | if wav.ndim == 2: 19 | wav = wav.unsqueeze(1) 20 | wav = wav.to("cuda") 21 | with torch.no_grad(): 22 | tokens = mbd.codec_model.encode(wav) 23 | tokens = tokens[0][0].cpu() 24 | torch.save(tokens, output_file) 25 | 26 | def main(file_name, base_dir, num_shards, shard_id): 27 | files = get_sorted_keys(file_name) 28 | start, end = get_ids(len(files), num_shards, shard_id) 29 | print(start, end) 30 | for file_name in tqdm(files[start:end]): 31 | in_file = os.path.join(base_dir, "wav", file_name + ".wav") 32 | out_file = os.path.join(base_dir, "tok", file_name + ".tok") 33 | if os.path.exists(out_file): 34 | continue 35 | create_directory_recursive(out_file) 36 | tokenize(in_file, out_file) 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser(description="Audio Tokenize") 40 | parser.add_argument("--in_file", default="data/small.jsonl", help="Name of the file") 41 | parser.add_argument("--base_dir", type=str, default="data/", help="base_dir") 42 | parser.add_argument("--num_shards", type=int, default=1, help="Number of shards") 43 | parser.add_argument("--shard_id", type=int, default=0, help="Shard ID") 44 | args = parser.parse_args() 45 | main(args.in_file, args.base_dir, args.num_shards, args.shard_id) 46 | -------------------------------------------------------------------------------- /preprocess/download_reazon.py: -------------------------------------------------------------------------------- 1 | import os, subprocess, json 2 | import pandas as pd 3 | import argparse 4 | from utils import get_dirs, get_sorted_keys, get_ids, create_directory_recursive 5 | 6 | def download_tsv(split, base_dir): 7 | command = 
f"wget https://reazonspeech.s3.abci.ai/{split}.tsv -P {base_dir}" 8 | subprocess.run(command, shell=True) 9 | file_path = os.path.join(base_dir, f"{split}.tsv") 10 | return file_path 11 | 12 | def download_file(file_name, base_dir): 13 | base_url = "https://reazonspeech.s3.abci.ai/data/" 14 | path = os.path.join(base_url, file_name + ".tar") 15 | command = f"wget {path} -P {base_dir}" 16 | subprocess.run(command, shell=True) 17 | path = os.path.join(base_dir, file_name + ".tar") 18 | command = f"tar xvf {path} -C {base_dir}" 19 | subprocess.run(command, shell=True) 20 | command = f"rm {path}" 21 | subprocess.run(command, shell=True) 22 | 23 | def flac2wav(file_name, base_dir): 24 | flac_name = os.path.join(base_dir, file_name + ".flac") 25 | wav_name = os.path.join(base_dir, 'wav', file_name + '.wav') 26 | create_directory_recursive(wav_name) 27 | subprocess.run(f"ffmpeg -y -i {flac_name} {wav_name}", shell=True) 28 | 29 | def download_files(in_file, base_dir, num_shards, shard_id): 30 | files = get_dirs(in_file) 31 | start, end = get_ids(len(files), num_shards, shard_id) 32 | for file_name in files[start:end]: 33 | path = os.path.join(base_dir, file_name) 34 | if not os.path.exists(path): 35 | download_file(file_name, base_dir) 36 | 37 | def flac2wav_files(in_file, base_dir, num_shards, shard_id): 38 | files = get_sorted_keys(in_file) 39 | start, end = get_ids(len(files), num_shards, shard_id) 40 | for file_name in files[start:end]: 41 | path = os.path.join(base_dir, file_name) 42 | path = os.path.join(base_dir, 'wav', file_name + '.wav') 43 | if not os.path.exists(path): 44 | flac2wav(file_name, base_dir) 45 | 46 | def tsv2jsonl(in_file, out_file): 47 | data = pd.read_csv(in_file, header=None, sep='\t') 48 | with open(out_file, "w") as f: 49 | for index, row in data.iterrows(): 50 | json.dump({"key": row[0].replace(".flac", ""), "text": row[1]}, f, ensure_ascii=False) 51 | f.write('\n') 52 | 53 | def main(split, base_dir, num_shards, shard_id): 54 | create_directory_recursive(base_dir) 55 | in_file = download_tsv(split, base_dir) 56 | out_file = in_file.replace(".tsv", ".jsonl") 57 | tsv2jsonl(in_file, out_file) 58 | download_files(out_file, base_dir, num_shards, shard_id) 59 | flac2wav_files(out_file, base_dir, num_shards, shard_id) 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser(description="Download Reazon Dataset") 63 | parser.add_argument("--split", default="small", help="Reazon split") 64 | parser.add_argument("--base_dir", default="data/", help="path to data directory") 65 | parser.add_argument("--num_shards", type=int, default=1, help="Number of shards") 66 | parser.add_argument("--shard_id", type=int, default=0, help="Shard ID") 67 | args = parser.parse_args() 68 | main(args.split, args.base_dir, args.num_shards, args.shard_id) 69 | -------------------------------------------------------------------------------- /preprocess/spk_embed.py: -------------------------------------------------------------------------------- 1 | import os, torch, json 2 | from tqdm import tqdm 3 | import argparse 4 | from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder 5 | from utils import get_sorted_keys, get_ids, create_directory_recursive 6 | 7 | smodel = SpeakerEncoder(device="cuda", eval=True, verbose=False) 8 | 9 | def speaker_embed(audio_path, output_file): 10 | with torch.no_grad(): 11 | embedding = smodel.embed_utterance_from_file(audio_path, numpy=False) 12 | embedding = embedding.cpu().detach() 13 | torch.save(embedding, output_file) 14 | return 
embedding 15 | 16 | 17 | def main(file_name, base_dir, num_shards, shard_id): 18 | files = get_sorted_keys(file_name) 19 | start, end = get_ids(len(files), num_shards, shard_id) 20 | print(start, end) 21 | for file_name in tqdm(files[start:end]): 22 | in_file = os.path.join(base_dir, "wav", file_name + ".wav") 23 | out_file = os.path.join(base_dir, "emb", file_name + ".emb") 24 | if os.path.exists(out_file): 25 | continue 26 | create_directory_recursive(out_file) 27 | speaker_embed(in_file, out_file) 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description="Create Speaker Embeddings") 31 | parser.add_argument("--in_file", default="data/small.jsonl", help="Name of the file") 32 | parser.add_argument("--base_dir", type=str, default="data/", help="base_dir") 33 | parser.add_argument("--num_shards", type=int, default=1, help="Number of shards") 34 | parser.add_argument("--shard_id", type=int, default=0, help="Shard ID") 35 | args = parser.parse_args() 36 | main(args.in_file, args.base_dir, args.num_shards, args.shard_id) 37 | -------------------------------------------------------------------------------- /preprocess/split.py: -------------------------------------------------------------------------------- 1 | import os, subprocess 2 | import json 3 | import argparse 4 | 5 | 6 | def split_json_file(json_file_path, output_dir, val_size=500): 7 | with open(json_file_path, 'r') as f: 8 | # Read and modify lines 9 | lines = f.readlines() 10 | 11 | N = len(lines) 12 | train_data = lines[:N-val_size*2] # hold out the last 2*val_size lines for val and test so train does not overlap them 13 | val_data = lines[N-val_size*2:N-val_size] 14 | test_data = lines[N-val_size:] 15 | 16 | # Write to train.jsonl 17 | with open(os.path.join(output_dir, 'train.jsonl'), 'w', encoding='utf-8') as f: 18 | f.writelines(train_data) 19 | 20 | # Write to val.jsonl 21 | with open(os.path.join(output_dir, 'val.jsonl'), 'w', encoding='utf-8') as f: 22 | f.writelines(val_data) 23 | 24 | # Write to test.jsonl 25 | with open(os.path.join(output_dir, 'test.jsonl'), 'w', encoding='utf-8') as f: 26 | f.writelines(test_data) 27 | 28 | 29 | 30 | # Example usage: 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(description="Split list into chunks") 33 | parser.add_argument("--in_file", default="data/small.jsonl", help="Name of the file") 34 | parser.add_argument("--base_dir", type=str, default="data/", help="base_dir") 35 | args = parser.parse_args() 36 | split_json_file(args.in_file, args.base_dir) 37 | -------------------------------------------------------------------------------- /preprocess/text_tokenize.py: -------------------------------------------------------------------------------- 1 | import os, torch, argparse 2 | from tqdm import tqdm 3 | from fam.quantiser.text.tokenise import TrainedBPETokeniser 4 | from utils import get_keys_texts, get_ids, create_directory_recursive 5 | from huggingface_hub import snapshot_download 6 | from fam.llm.training import get_first_stage_path 7 | 8 | model_dir = snapshot_download(repo_id="metavoiceio/metavoice-1B-v0.1") 9 | checkpoint = torch.load(get_first_stage_path(model_dir)) 10 | tokenizer = TrainedBPETokeniser(**checkpoint["meta"]["tokenizer"]) 11 | 12 | def text_tokenize(text, output_file): 13 | tokens = tokenizer.encode(text) 14 | tokens = torch.tensor(tokens, dtype=torch.int64) 15 | torch.save(tokens, output_file) 16 | 17 | def main(file_name, num_shards, shard_id, base_dir): 18 | keys_text = get_keys_texts(file_name) 19 | start, end = get_ids(len(keys_text), num_shards, shard_id) 20 | print(start, end) 21 | 
for keys_texts in tqdm(keys_text[start:end]): 22 | file_name = keys_texts["key"] 23 | text = keys_texts["text"] 24 | out_file = os.path.join(base_dir, "txt", file_name + ".pt") 25 | create_directory_recursive(out_file) 26 | if os.path.exists(out_file): 27 | continue 28 | text_tokenize(text, out_file) 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser(description="Tokenize Text") 32 | parser.add_argument("--in_file", default="data/small.jsonl", help="Name of the file") 33 | parser.add_argument("--base_dir", type=str, default="data/", help="base_dir") 34 | parser.add_argument("--num_shards", type=int, default=1, help="Number of shards") 35 | parser.add_argument("--shard_id", type=int, default=0, help="Shard ID") 36 | args = parser.parse_args() 37 | main(args.in_file, args.num_shards, args.shard_id, args.base_dir) 38 | -------------------------------------------------------------------------------- /preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | 3 | 4 | def create_directory(dir_path): 5 | """ 6 | Create the directory if it does not exist. 7 | 8 | Args: 9 | dir_path (str): The directory path to create. 10 | """ 11 | if not os.path.exists(dir_path): 12 | os.makedirs(dir_path) 13 | 14 | def list_parent_directories(path): 15 | """ 16 | List all parent directories of a given path. 17 | 18 | Args: 19 | path (str): The directory path to find parent directories for. 20 | 21 | Returns: 22 | list: A list of parent directories. 23 | """ 24 | parent_directories = [] 25 | while True: 26 | path = os.path.dirname(path) 27 | if path == '/' or not path: 28 | break 29 | parent_directories.append(path) 30 | return parent_directories 31 | 32 | def create_directory_recursive(path): 33 | dirs = list_parent_directories(path) 34 | dirs.reverse() 35 | for dir in dirs: 36 | create_directory(dir) 37 | 38 | def get_sorted_keys(in_file, key="key"): 39 | files = [] 40 | with open(in_file) as fin: 41 | for line in fin: 42 | line = json.loads(line) 43 | files.append(line[key]) 44 | return files 45 | 46 | def get_keys_texts(in_file, key="key"): 47 | data = [] 48 | with open(in_file) as fin: 49 | for line in fin: 50 | line = json.loads(line) 51 | data.append(line) 52 | return data 53 | 54 | def get_ids(num_keys, num_shards, shard_id): 55 | shard_size = num_keys//num_shards 56 | remainder = num_keys - shard_size*num_shards 57 | start_id = shard_size*shard_id + min([shard_id, remainder]) 58 | end_id = shard_size*(shard_id+1) + min([shard_id+1, remainder]) 59 | return start_id, end_id 60 | 61 | def get_dirs(in_file): 62 | files = set() 63 | with open(in_file) as fin: 64 | for line in fin: 65 | line = json.loads(line) 66 | files.add(line["key"].split("/")[0]) 67 | files = sorted(files) 68 | return files 69 | 70 | def split_json_file(json_file_path, output_dir): 71 | with open(json_file_path, 'r') as f: 72 | # Read and modify lines 73 | lines = f.readlines() 74 | 75 | N = len(lines) 76 | train_data = lines[:N-1000] 77 | val_data = lines[N-1000:N-500] 78 | test_data = lines[N-500:] 79 | 80 | # Write to train.jsonl 81 | with open(os.path.join(output_dir, 'train.jsonl'), 'w', encoding='utf-8') as f: 82 | f.writelines(train_data) 83 | 84 | # Write to val.jsonl 85 | with open(os.path.join(output_dir, 'val.jsonl'), 'w') as f: 86 | f.writelines(val_data) 87 | 88 | # Write to test.jsonl 89 | with open(os.path.join(output_dir, 'test.jsonl'), 'w') as f: 90 | f.writelines(test_data) 91 | 
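92 | # Illustrative example (comments only, not executed anywhere): get_ids() splits num_keys items into 93 | # num_shards contiguous ranges and spreads the remainder over the first shards, e.g. 94 | # [get_ids(10, 3, i) for i in range(3)] -> [(0, 4), (4, 7), (7, 10)] 95 | # The preprocessing scripts (download_reazon.py, audio_tokenize.py, spk_embed.py, text_tokenize.py) 96 | # pass --num_shards/--shard_id through to get_ids() so that shards can be processed in parallel. 97 | 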
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "kotoba-speech" 3 | version = "0.1.0" 4 | description = "Foundational model for text to speech" 5 | requires-python = ">=3.10,<3.12" 6 | 7 | [tool.black] 8 | line-length = 120 9 | exclude = ''' 10 | /( 11 | \.git 12 | | \.mypy_cache 13 | | \.tox 14 | | _build 15 | | build 16 | | dist 17 | )/ 18 | ''' 19 | 20 | [tool.isort] 21 | profile = "black" 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | packaging==23.1 2 | wheel==0.42.0 3 | transformers 4 | librosa 5 | tqdm 6 | tiktoken==0.5.1 7 | audiocraft 8 | numpy 9 | ninja 10 | fastapi 11 | uvicorn 12 | tyro 13 | deepfilternet 14 | pydub 15 | soundfile 16 | gradio 17 | huggingface_hub 18 | wandb 19 | pytorch_lightning 20 | IPython 21 | -------------------------------------------------------------------------------- /scripts/abci/submit_ddp_1gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-l rt_AG.small=1 4 | #$-l h_rt=1:00:00 5 | #$-j y 6 | #$ -o outputs/a-node/ 7 | #$-cwd 8 | 9 | #submit: qsub -g gcd50698 scripts/abci/submit_ddp_1gpu.sh 10 | 11 | source /etc/profile.d/modules.sh 12 | module load python/3.11/3.11.2 cuda/12.0/12.0.0 cudnn/8.9/8.9.7 13 | source myenv/bin/activate 14 | 15 | python fam/llm/train.py --num_gpus 1 --batch_size 1 --per_gpu_batchsize 1 --check_val_every_n_epoch 1 --exp_name kotoba_voice_1.3B_debug_abci_reazon_small --max_epoch 1 --debug_data_dir /groups/gcd50698/reazon_data/reazon_small --debug --use_ddp_strategy 16 | -------------------------------------------------------------------------------- /scripts/abci/submit_ddp_1node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-l rt_AF=1 4 | #$-l h_rt=0:10:00 5 | #$-j y 6 | #$ -o outputs/a-node/ 7 | #$-cwd 8 | 9 | #submit: qsub -g gcd50698 scripts/abci/submit_ddp_1node.sh 10 | 11 | source /etc/profile.d/modules.sh 12 | module load python/3.11/3.11.2 cuda/12.0/12.0.0 cudnn/8.9/8.9.7 13 | source myenv/bin/activate 14 | 15 | python fam/llm/train.py --num_gpus 8 --batch_size 16 --per_gpu_batchsize 1 --check_val_every_n_epoch 1 --exp_name kotoba_voice_1.3B_debug_abci_reazon_small --max_epoch 1 --debug_data_dir /groups/gcd50698/reazon_data/reazon_small --debug --use_ddp_strategy 16 | -------------------------------------------------------------------------------- /scripts/abci/submit_fsdp_1node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-l rt_AF=1 4 | #$-l h_rt=0:10:00 5 | #$-j y 6 | #$ -o outputs/a-node/ 7 | #$-cwd 8 | 9 | #submit: qsub -g gcd50698 scripts/abci/submit_fsdp_1node.sh 10 | 11 | source /etc/profile.d/modules.sh 12 | module load python/3.11/3.11.2 cuda/12.0/12.0.0 cudnn/8.9/8.9.7 13 | source myenv/bin/activate 14 | 15 | python fam/llm/train.py --num_gpus 8 --batch_size 16 --per_gpu_batchsize 1 --check_val_every_n_epoch 1 --exp_name kotoba_voice_1.3B_debug_abci_reazon_small --max_epoch 1 --debug_data_dir /groups/gcd50698/reazon_data/reazon_small --debug --fsdp_strategy SHARD_GRAD_OP 16 | -------------------------------------------------------------------------------- /scripts/abci/submit_fsdp_2node.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$-l rt_AF=2 4 | #$-l h_rt=0:10:00 5 | #$-j y 6 | #$ -o outputs/a-node/ 7 | #$-cwd 8 | 9 | #submit: qsub -g gcd50698 scripts/abci/submit_fsdp_2node.sh 10 | 11 | # module 12 | source /etc/profile.d/modules.sh 13 | module load python/3.11/3.11.2 cuda/12.0/12.0.0 cudnn/8.9/8.9.7 hpcx/2.12 14 | source myenv/bin/activate 15 | 16 | # distributed settings 17 | export MASTER_ADDR=$(/usr/sbin/ip a show dev bond0 | grep 'inet ' | awk '{ print $2 }' | cut -d "/" -f 1) 18 | export MASTER_PORT=$((10000 + ($JOB_ID % 50000))) 19 | export NUM_GPU_PER_NODE=8 20 | NUM_NODES=$NHOSTS 21 | NUM_GPUS=$((${NUM_NODES} * ${NUM_GPU_PER_NODE})) 22 | 23 | mkdir -p ./hostfile 24 | HOSTFILE_NAME=./hostfile/hostfile_${JOB_ID} 25 | while read -r line; do 26 | echo "${line} slots=${NUM_GPU_PER_NODE}" 27 | done <"$SGE_JOB_HOSTLIST" >"$HOSTFILE_NAME" 28 | 29 | mpirun -np $NUM_GPUS \ 30 | -hostfile $HOSTFILE_NAME \ 31 | -x MASTER_ADDR=$MASTER_ADDR \ 32 | -x MASTER_PORT=$MASTER_PORT \ 33 | -bind-to none -map-by slot \ 34 | -x PATH \ 35 | python fam/llm/train.py \ 36 | --num_nodes 2 \ 37 | --num_gpus 8 \ 38 | --batch_size 16 \ 39 | --per_gpu_batchsize 1 \ 40 | --check_val_every_n_epoch 1 \ 41 | --exp_name kotoba_voice_1.3B_debug_abci_reazon_small \ 42 | --max_epoch 1 \ 43 | --debug_data_dir /groups/gcd50698/reazon_data/reazon_small \ 44 | --debug \ 45 | --fsdp_strategy SHARD_GRAD_OP 46 | 47 | # python -m torch.distributed.run \ 48 | # --nnodes 2 \ 49 | # --master_addr $MASTER_ADDR \ 50 | # --master_port $MASTER_PORT \ 51 | # --nproc_per_node 8 \ 52 | # fam/llm/train.py \ 53 | # --num_nodes 2 \ 54 | # --num_gpus 8 \ 55 | # --batch_size 16 \ 56 | # --per_gpu_batchsize 1 \ 57 | # --check_val_every_n_epoch 1 \ 58 | # --exp_name kotoba_voice_1.3B_debug_abci_reazon_small \ 59 | # --max_epoch 1 \ 60 | # --debug_data_dir /groups/gcd50698/reazon_data/reazon_small \ 61 | # --debug \ 62 | # --fsdp_strategy SHARD_GRAD_OP 63 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup # type: ignore 2 | 3 | setup( 4 | name="fam", 5 | packages=find_packages(".", exclude=["tests"]), 6 | ) 7 | --------------------------------------------------------------------------------
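Quick-start sketch (assumptions, not an official install guide: a Python 3.10/3.11 virtual environment and a CUDA-capable machine; the commands below are inferred from setup.py, requirements.txt and fam/ui/app.py):
# install dependencies and the fam package in editable mode
pip install -r requirements.txt
pip install -e .
# launch the Gradio demo; fam/ui/app.py posts requests to a TTS API at http://127.0.0.1:58003/tts,
# so a serving process (presumably fam/llm/serving.py) must be running separately
python fam/ui/app.py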